from __future__ import annotations import gc import os from typing import Optional import torch import torch.nn as nn from huggingface_hub import hf_hub_download from safetensors import safe_open from torchao.quantization import Int4WeightOnlyConfig, quantize_ from generator import DEFAULT_MISO_TTS_REPO_ID, Generator from models import MISO_TTS_8B_CONFIG, Model, ModelArgs def _get_submodule(root: nn.Module, path: str) -> nn.Module: mod = root if not path: return mod for part in path.split("."): if part.isdigit(): mod = mod[int(part)] # type: ignore[index] else: mod = getattr(mod, part) return mod def _set_parameter(root: nn.Module, name: str, value: torch.Tensor) -> None: if "." in name: module_path, param_name = name.rsplit(".", 1) mod = _get_submodule(root, module_path) else: mod = root param_name = name setattr(mod, param_name, nn.Parameter(value, requires_grad=False)) def _is_linear_weight(root: nn.Module, name: str) -> bool: if not name.endswith(".weight"): return False module_path = name[: -len(".weight")] try: mod = _get_submodule(root, module_path) except Exception: return False return isinstance(mod, nn.Linear) def _assert_no_meta_tensors(model: nn.Module) -> None: leftovers = [] for name, tensor in list(model.named_parameters()) + list(model.named_buffers()): if getattr(tensor, "is_meta", False): leftovers.append(name) if leftovers: raise RuntimeError(f"Model still has meta tensors: {leftovers[:20]}{'...' if len(leftovers) > 20 else ''}") def load_miso_8b_int4_weight_only( device: str = "cuda", model_path_or_repo_id: Optional[str] = None, dtype: torch.dtype = torch.bfloat16, group_size: int = 128, quantize_output_heads: bool = True, ) -> Generator: """Load MisoTTS with torchao int4 weight-only quantization. This avoids ever materializing the full 8B model on GPU or holding both a full model and full safetensors state dict in CPU RAM. Linear weights are streamed one at a time from safetensors, moved to CUDA as BF16, quantized in-place, and kept on GPU. Non-linear parameters (embeddings, norms, audio_head) stay BF16. """ source = model_path_or_repo_id or os.environ.get("MISO_TTS_8B_MODEL", DEFAULT_MISO_TTS_REPO_ID) if os.path.isfile(source): model_file = source elif os.path.isdir(source): model_file = os.path.join(source, "model.safetensors") else: model_file = hf_hub_download(repo_id=source, filename="model.safetensors") with torch.device("meta"): model = Model(MISO_TTS_8B_CONFIG) qconfig = Int4WeightOnlyConfig(group_size=group_size) loaded = set() with safe_open(model_file, framework="pt", device="cpu") as f: keys = list(f.keys()) # Load small/BF16 parameters first so next(model.parameters()) sees CUDA. # Linear weights are quantized one at a time afterwards. for key in keys: if _is_linear_weight(model, key): continue tensor = f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True) _set_parameter(model, key, tensor) loaded.add(key) del tensor for idx, key in enumerate(keys, start=1): if not _is_linear_weight(model, key): continue module_path = key[: -len(".weight")] if (not quantize_output_heads) and (module_path.endswith("codebook0_head") or module_path.endswith("projection")): tensor = f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True) _set_parameter(model, key, tensor) loaded.add(key) del tensor continue mod = _get_submodule(model, module_path) tensor = f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True) mod.weight = nn.Parameter(tensor, requires_grad=False) # type: ignore[assignment] del tensor quantize_(mod, qconfig) loaded.add(key) if idx % 25 == 0: torch.cuda.empty_cache() gc.collect() missing = set(model.state_dict().keys()) - loaded # Quantized Linear state_dict entries remain as "weight" tensor subclasses; allow # no missing checkpoint tensors, but ignore runtime buffers/caches added later. if missing: raise RuntimeError(f"Missing checkpoint tensors: {sorted(missing)[:20]}") model.eval() _assert_no_meta_tensors(model) with torch.device(device): for mod in model.modules(): if hasattr(mod, "rope_init") and not getattr(mod, "is_cache_built", True): mod.rope_init() torch.cuda.empty_cache() return Generator(model)