| from __future__ import annotations |
|
|
| import gc |
| import os |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from huggingface_hub import hf_hub_download |
| from safetensors import safe_open |
| from torchao.quantization import Int4WeightOnlyConfig, quantize_ |
|
|
| from generator import DEFAULT_MISO_TTS_REPO_ID, Generator |
| from models import MISO_TTS_8B_CONFIG, Model, ModelArgs |
|
|
|
|
def _get_submodule(root: nn.Module, path: str) -> nn.Module:
    """Resolve a dotted ``path`` to the module it names beneath ``root``.

    An empty ``path`` yields ``root`` itself. Each dot-separated segment is
    followed in turn: a segment made only of digits is used as an integer
    index into the current module, any other segment is read as an attribute.
    """
    if not path:
        return root

    current = root
    for segment in path.split("."):
        # Numeric segments index into container modules; names are attributes.
        current = current[int(segment)] if segment.isdigit() else getattr(current, segment)
    return current
|
|
|
|
def _set_parameter(root: nn.Module, name: str, value: torch.Tensor) -> None:
    """Attach ``value`` under the dotted ``name`` as a frozen parameter.

    Everything before the last dot in ``name`` selects the owning submodule
    (resolved with ``_get_submodule``); the final component is the attribute
    name. A name without a dot targets ``root`` directly. The tensor is
    wrapped in an ``nn.Parameter`` with ``requires_grad=False`` before it is
    assigned.
    """
    owner_path, _, attr_name = name.rpartition(".")
    # An empty owner path means the attribute lives on the root module.
    owner = _get_submodule(root, owner_path) if owner_path else root
    setattr(owner, attr_name, nn.Parameter(value, requires_grad=False))
|
|
|
|
def _is_linear_weight(root: nn.Module, name: str) -> bool:
    """Report whether ``name`` is the ``weight`` of an ``nn.Linear`` under ``root``.

    ``name`` must end in ``.weight``; the part before that suffix is resolved
    with ``_get_submodule``. Any failure while resolving the path is treated
    as "not a linear weight" rather than propagated.
    """
    owner_path, dot, leaf = name.rpartition(".")
    if not dot or leaf != "weight":
        return False
    try:
        owner = _get_submodule(root, owner_path)
    except Exception:
        # Best-effort lookup: an unresolvable path simply is not a linear weight.
        return False
    else:
        return isinstance(owner, nn.Linear)
|
|
|
|
def _assert_no_meta_tensors(model: nn.Module) -> None:
    """Raise if any parameter or buffer of ``model`` is still a meta tensor.

    Parameters are inspected first, then buffers. The error message lists at
    most the first 20 offending names and appends ``...`` when there are more.

    Raises:
        RuntimeError: If at least one parameter or buffer reports ``is_meta``.
    """
    leftovers = [
        tensor_name
        for named_tensors in (model.named_parameters(), model.named_buffers())
        for tensor_name, tensor in named_tensors
        if getattr(tensor, "is_meta", False)
    ]
    if not leftovers:
        return
    raise RuntimeError(f"Model still has meta tensors: {leftovers[:20]}{'...' if len(leftovers) > 20 else ''}")
|
|
|
|
def load_miso_8b_int4_weight_only(
    device: str = "cuda",
    model_path_or_repo_id: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
    group_size: int = 128,
    quantize_output_heads: bool = True,
) -> Generator:
    """Load MisoTTS with torchao int4 weight-only quantization.

    The model skeleton is created on the ``meta`` device and tensors are then
    read one at a time from the safetensors checkpoint, so neither a full
    unquantized model nor a full state dict is held at once. Every tensor
    that is not the weight of an ``nn.Linear`` is cast to ``dtype`` and placed
    on ``device`` unchanged. Each ``nn.Linear`` weight is moved to ``device``
    in ``dtype`` and immediately quantized with ``Int4WeightOnlyConfig``.

    Args:
        device: Device every loaded tensor is moved to.
        model_path_or_repo_id: Path to a checkpoint file, a directory that
            contains ``model.safetensors``, or a Hugging Face Hub repo id.
            When omitted, the ``MISO_TTS_8B_MODEL`` environment variable is
            used, falling back to ``DEFAULT_MISO_TTS_REPO_ID``.
        dtype: Dtype every tensor is cast to when it is read (and the dtype
            of linear weights right before quantization).
        group_size: Forwarded to ``Int4WeightOnlyConfig``.
        quantize_output_heads: When False, linear modules whose path ends in
            ``codebook0_head`` or ``projection`` are loaded in ``dtype``
            without being quantized.

    Returns:
        A ``Generator`` wrapping the loaded model, which is in eval mode.

    Raises:
        RuntimeError: If the model's state dict has entries that were not
            found in the checkpoint, or if any parameter or buffer is still
            on the meta device after loading.
    """
    source = model_path_or_repo_id or os.environ.get("MISO_TTS_8B_MODEL", DEFAULT_MISO_TTS_REPO_ID)
    if os.path.isfile(source):
        model_file = source
    elif os.path.isdir(source):
        model_file = os.path.join(source, "model.safetensors")
    else:
        # Neither a file nor a directory: treat the value as a Hub repo id.
        model_file = hf_hub_download(repo_id=source, filename="model.safetensors")

    # Build the module tree without allocating storage for its tensors.
    with torch.device("meta"):
        model = Model(MISO_TTS_8B_CONFIG)

    qconfig = Int4WeightOnlyConfig(group_size=group_size)
    unquantized_head_suffixes = ("codebook0_head", "projection")
    flush_every = 25  # quantized layers between allocator-cache flushes
    loaded = set()

    with safe_open(model_file, framework="pt", device="cpu") as f:

        def fetch(key: str) -> torch.Tensor:
            """Read one checkpoint tensor and move it to the target device/dtype."""
            return f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True)

        # Pass 1: load everything that is not an nn.Linear weight, and
        # remember the linear weights so each key is classified only once.
        linear_keys = []
        for key in f.keys():
            if _is_linear_weight(model, key):
                linear_keys.append(key)
                continue
            _set_parameter(model, key, fetch(key))
            loaded.add(key)

        # Pass 2: load and quantize the linear weights one module at a time,
        # so at most one unquantized linear weight is newly resident per step.
        quantized_count = 0
        for key in linear_keys:
            module_path = key[: -len(".weight")]
            if not quantize_output_heads and module_path.endswith(unquantized_head_suffixes):
                _set_parameter(model, key, fetch(key))
                loaded.add(key)
                continue

            mod = _get_submodule(model, module_path)
            mod.weight = nn.Parameter(fetch(key), requires_grad=False)
            quantize_(mod, qconfig)
            loaded.add(key)

            # Count quantized layers (not positions in the checkpoint key
            # list) so the flush happens at a regular cadence.
            quantized_count += 1
            if quantized_count % flush_every == 0:
                # Collect first so tensors released by the garbage collector
                # can be handed back when the CUDA cache is emptied.
                gc.collect()
                torch.cuda.empty_cache()

    missing = set(model.state_dict().keys()) - loaded
    if missing:
        raise RuntimeError(f"Missing checkpoint tensors: {sorted(missing)[:20]}")

    model.eval()
    with torch.device(device):
        for mod in model.modules():
            # Modules that expose ``rope_init`` and report their cache as not
            # built get it built now, with ``device`` as the default device.
            if hasattr(mod, "rope_init") and not getattr(mod, "is_cache_built", True):
                mod.rope_init()

    # Validate the final model, including anything ``rope_init`` registered.
    _assert_no_meta_tensors(model)
    torch.cuda.empty_cache()
    return Generator(model)
|
|