# MisoTTS-8B-torchao-int4 / quant_loader_streaming.py
# Uploaded by droyster ("Add files using upload-large-folder tool"), commit 67196d0, 4.92 kB.
from __future__ import annotations
import gc
import os
from typing import Optional
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from generator import DEFAULT_MISO_TTS_REPO_ID, Generator
from models import MISO_TTS_8B_CONFIG, Model, ModelArgs
def _get_submodule(root: nn.Module, path: str) -> nn.Module:
    """Return the submodule of *root* addressed by the dotted *path*.

    An empty *path* returns *root* itself. Each component is looked up by its
    registered name first. This resolves children of ``nn.ModuleList`` /
    ``nn.Sequential`` (which register children under "0", "1", ...) and also
    ``nn.ModuleDict`` entries whose keys happen to be digits. Integer indexing
    is used only as a fallback, for a purely numeric component on a container
    that does not expose its children as attributes.

    Raises:
        AttributeError: If a non-numeric component cannot be found.
        IndexError / KeyError / TypeError: From the container, if the
            integer-index fallback for a numeric component fails.
    """
    mod = root
    if not path:
        return mod
    for part in path.split("."):
        try:
            # state_dict paths are built from registration names, so name lookup is authoritative.
            mod = getattr(mod, part)
        except AttributeError:
            if not part.isdigit():
                raise
            mod = mod[int(part)]  # type: ignore[index]
    return mod
def _set_parameter(root: nn.Module, name: str, value: torch.Tensor) -> None:
    """Install *value* on *root* under the dotted state_dict key *name*.

    *name* is split at its last "." into a module path and a leaf attribute;
    a name without a dot targets *root* directly.

    - If the leaf is already a registered buffer of that module, *value* is
      assigned as a plain tensor, so it stays a buffer and keeps its
      persistent / non-persistent registration.
    - Otherwise *value* is installed as a frozen ``nn.Parameter``
      (``requires_grad=False``).
    """
    module_path, _, leaf = name.rpartition(".")
    mod = _get_submodule(root, module_path) if module_path else root
    if leaf in mod._buffers:
        # Wrapping a buffer's value in nn.Parameter would make nn.Module.__setattr__
        # de-register the buffer and re-register it as a parameter.
        setattr(mod, leaf, value)
    else:
        setattr(mod, leaf, nn.Parameter(value, requires_grad=False))
def _is_linear_weight(root: nn.Module, name: str) -> bool:
    """Report whether the state_dict key *name* is the ``weight`` of an ``nn.Linear`` in *root*.

    Only names ending in ".weight" can qualify, so a bare "weight" never does.
    Any exception raised while resolving the owning module is treated as
    "not a Linear weight" rather than propagated.
    """
    suffix = ".weight"
    if name.endswith(suffix):
        owner_path = name[: len(name) - len(suffix)]
        try:
            owner = _get_submodule(root, owner_path)
        except Exception:
            owner = None  # unresolvable path -> isinstance below yields False
        return isinstance(owner, nn.Linear)
    return False
def _assert_no_meta_tensors(model: nn.Module) -> None:
    """Raise ``RuntimeError`` if any parameter or buffer of *model* is still a meta tensor.

    The error message lists at most the first 20 offending names and appends
    "..." when there are more.
    """
    named_tensors = [*model.named_parameters(), *model.named_buffers()]
    offenders = [n for n, t in named_tensors if getattr(t, "is_meta", False)]
    if not offenders:
        return
    shown = offenders[:20]
    more = "..." if len(offenders) > 20 else ""
    raise RuntimeError(f"Model still has meta tensors: {shown}{more}")
def _resolve_model_file(source: str) -> str:
    """Return a local path to the ``model.safetensors`` checkpoint named by *source*.

    *source* may be a checkpoint file or a directory containing
    ``model.safetensors``. Anything else is treated as a Hugging Face repo id
    and fetched with ``hf_hub_download``.
    """
    if os.path.isfile(source):
        return source
    if os.path.isdir(source):
        return os.path.join(source, "model.safetensors")
    return hf_hub_download(repo_id=source, filename="model.safetensors")


def _read_checkpoint_tensor(f, key: str, device: str, dtype: torch.dtype) -> torch.Tensor:
    """Read tensor *key* from the open safetensors handle *f* and move it to *device*.

    Floating-point tensors are cast to *dtype*. Non-floating tensors (integer,
    bool, ...) keep their stored dtype, because casting e.g. counters or index
    tensors to a float dtype would corrupt them.
    """
    tensor = f.get_tensor(key)
    target_dtype = dtype if tensor.is_floating_point() else tensor.dtype
    return tensor.to(device=device, dtype=target_dtype, non_blocking=True)


def load_miso_8b_int4_weight_only(
    device: str = "cuda",
    model_path_or_repo_id: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
    group_size: int = 128,
    quantize_output_heads: bool = True,
) -> Generator:
    """Load MisoTTS with torchao int4 weight-only quantization.

    The model skeleton is built on the ``meta`` device, and the safetensors
    checkpoint is read one tensor at a time. This avoids ever materializing the
    full-precision 8B model, and avoids holding both a model and a full state
    dict in CPU RAM.

    Every ``nn.Linear`` weight is moved to *device* as *dtype*, quantized with
    ``Int4WeightOnlyConfig`` and kept there. All other checkpoint tensors
    (embeddings, norms, ``audio_head``, biases, ...) are moved to *device*
    unquantized. Floating-point ones are cast to *dtype*; non-floating ones keep
    their stored dtype.

    Args:
        device: Target device for all weights. NOTE(review): torchao's int4
            weight-only path presumably requires CUDA; confirm before passing
            anything else.
        model_path_or_repo_id: A ``.safetensors`` file, a directory containing
            ``model.safetensors``, or a Hugging Face repo id. When omitted, the
            ``MISO_TTS_8B_MODEL`` environment variable is used, then
            ``DEFAULT_MISO_TTS_REPO_ID``.
        dtype: Floating-point dtype that weights are cast to before
            quantization, and that unquantized float tensors are kept in.
        group_size: Quantization group size passed to ``Int4WeightOnlyConfig``.
        quantize_output_heads: When False, Linear modules whose path ends in
            ``codebook0_head`` or ``projection`` are left unquantized.

    Returns:
        A ``Generator`` wrapping the loaded model, which is in eval mode.

    Raises:
        RuntimeError: If a model state_dict entry is absent from the
            checkpoint, or if any parameter/buffer is still on ``meta`` after
            loading.
    """
    source = model_path_or_repo_id or os.environ.get("MISO_TTS_8B_MODEL", DEFAULT_MISO_TTS_REPO_ID)
    model_file = _resolve_model_file(source)

    # Build the module tree without allocating any weight storage.
    with torch.device("meta"):
        model = Model(MISO_TTS_8B_CONFIG)

    qconfig = Int4WeightOnlyConfig(group_size=group_size)
    head_suffixes = ("codebook0_head", "projection")
    loaded = set()

    with safe_open(model_file, framework="pt", device="cpu") as f:
        # Classify every key once, against the untouched skeleton.
        linear_keys = []
        other_keys = []
        for key in f.keys():
            (linear_keys if _is_linear_weight(model, key) else other_keys).append(key)

        # Load small/unquantized tensors first so next(model.parameters()) sees the
        # target device; Linear weights are quantized one at a time afterwards.
        for key in other_keys:
            _set_parameter(model, key, _read_checkpoint_tensor(f, key, device, dtype))
            loaded.add(key)

        for count, key in enumerate(linear_keys, start=1):
            module_path = key[: -len(".weight")]
            # The full-precision tensor is passed inline, so this function holds no extra
            # reference to it once quantize_ has processed the module.
            _set_parameter(model, key, _read_checkpoint_tensor(f, key, device, dtype))
            loaded.add(key)
            if not quantize_output_heads and module_path.endswith(head_suffixes):
                continue  # output head kept unquantized at `dtype`
            quantize_(_get_submodule(model, module_path), qconfig)
            if count % 25 == 0:
                # Drop unreachable Python references first so their blocks are free
                # before the CUDA caching allocator is asked to release memory.
                gc.collect()
                torch.cuda.empty_cache()

    # Every state_dict entry of the skeleton must come from the checkpoint. Quantized
    # Linear weights keep their "<module>.weight" key (as tensor subclasses), so they
    # count as loaded. Caches built by rope_init below are created after this check.
    missing = set(model.state_dict().keys()) - loaded
    if missing:
        raise RuntimeError(f"Missing checkpoint tensors: {sorted(missing)[:20]}")

    model.eval()
    _assert_no_meta_tensors(model)

    # Modules exposing rope_init that report is_cache_built=False (presumably because
    # they were constructed on meta) get their caches built now on the target device.
    with torch.device(device):
        for mod in model.modules():
            if hasattr(mod, "rope_init") and not getattr(mod, "is_cache_built", True):
                mod.rope_init()
    torch.cuda.empty_cache()
    return Generator(model)