from __future__ import annotations import os from typing import Optional import torch from huggingface_hub import hf_hub_download from generator import Generator from models import MISO_TTS_8B_CONFIG, Model DEFAULT_QUANT_REPO_ID = "droyster/MisoTTS-8B-torchao-int4" DEFAULT_QUANT_FILENAME = "model_int4_torchao.pt" def _assert_no_meta_tensors(model: torch.nn.Module) -> None: leftovers = [] for name, tensor in list(model.named_parameters()) + list(model.named_buffers()): if getattr(tensor, "is_meta", False): leftovers.append(name) if leftovers: raise RuntimeError(f"Model still has meta tensors: {leftovers[:20]}{'...' if len(leftovers) > 20 else ''}") def load_miso_8b_torchao_int4( repo_id: str = DEFAULT_QUANT_REPO_ID, filename: str = DEFAULT_QUANT_FILENAME, device: str = "cuda", map_location: Optional[str] = None, disable_watermark: bool = False, ) -> Generator: """Load the pre-quantized TorchAO int4-weight-only MisoTTS checkpoint. The checkpoint stores TorchAO quantized tensor subclasses, so use weights_only=False and load_state_dict(..., assign=True). """ if disable_watermark: import generator as generator_module def _identity_watermark(_watermarker, audio_array, sample_rate, _watermark_key): return audio_array, sample_rate generator_module.load_watermarker = lambda device="cuda": None generator_module.watermark = _identity_watermark if os.path.isfile(repo_id): ckpt_path = repo_id elif os.path.isdir(repo_id): ckpt_path = os.path.join(repo_id, filename) else: ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) target_device = torch.device(device) if target_device.type == "cuda" and target_device.index is None: target_device = torch.device("cuda:0") # Load directly onto the target device. TorchAO's tensor-core tiled int4 # tensors cannot be loaded on CPU and then moved to CUDA. The export may # serialize tensors as bare "cuda", so remap that to explicit "cuda:0". load_map_location = map_location or ({"cuda": str(target_device)} if target_device.type == "cuda" else target_device) payload = torch.load(ckpt_path, map_location=load_map_location, weights_only=False) state_dict = payload.get("model_state_dict", payload) with torch.device("meta"): model = Model(MISO_TTS_8B_CONFIG) missing, unexpected = model.load_state_dict(state_dict, strict=True, assign=True) if missing or unexpected: raise RuntimeError(f"load_state_dict mismatch: missing={missing[:10]} unexpected={unexpected[:10]}") # Non-quantized BF16 tensors and TorchAO int4 weights are already loaded on # the target device. Avoid model.to(device): TensorCoreTiledAQTTensorImpl # cannot be converted from CPU to CUDA after deserialization. model.eval() _assert_no_meta_tensors(model) with torch.device(device): for mod in model.modules(): if hasattr(mod, "rope_init") and not getattr(mod, "is_cache_built", True): mod.rope_init() return Generator(model) __all__ = ["load_miso_8b_torchao_int4"]