File size: 4,917 Bytes
67196d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import annotations

import gc
import os
from typing import Optional

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torchao.quantization import Int4WeightOnlyConfig, quantize_

from generator import DEFAULT_MISO_TTS_REPO_ID, Generator
from models import MISO_TTS_8B_CONFIG, Model, ModelArgs


def _get_submodule(root: nn.Module, path: str) -> nn.Module:
    """Walk a dotted module path starting at ``root`` and return what it names.

    Each dot-separated segment is resolved in turn: an all-digit segment is
    used as an integer index into the current object (e.g. a position inside a
    container module), any other segment is looked up as an attribute. An
    empty ``path`` yields ``root`` itself.
    """
    current = root
    segments = path.split(".") if path else []
    for segment in segments:
        # Digits mean "index into a container"; everything else is an attribute.
        current = (
            current[int(segment)]  # type: ignore[index]
            if segment.isdigit()
            else getattr(current, segment)
        )
    return current


def _set_parameter(root: nn.Module, name: str, value: torch.Tensor) -> None:
    """Attach ``value`` to the model as a frozen parameter called ``name``.

    ``name`` is a dotted path: everything before the last dot identifies the
    owning module (resolved via ``_get_submodule``) and the final segment is
    the attribute to assign. A name without a dot is assigned directly on
    ``root``. The tensor is wrapped in ``nn.Parameter`` with
    ``requires_grad=False``.
    """
    owner_path, dot, leaf = name.rpartition(".")
    # Without a dot the parameter lives directly on the root module.
    owner = _get_submodule(root, owner_path) if dot else root
    frozen = nn.Parameter(value, requires_grad=False)
    setattr(owner, leaf, frozen)


def _is_linear_weight(root: nn.Module, name: str) -> bool:
    """Tell whether ``name`` refers to the ``weight`` of an ``nn.Linear``.

    Returns ``False`` when ``name`` does not end in ``.weight``, when the
    owning module path cannot be resolved (any exception raised during the
    lookup is swallowed), or when the owner is not an ``nn.Linear``.
    """
    owner_path, dot, leaf = name.rpartition(".")
    if not dot or leaf != "weight":
        return False
    try:
        owner = _get_submodule(root, owner_path)
    except Exception:
        # Unresolvable path: treat it as "not a linear weight".
        return False
    else:
        return isinstance(owner, nn.Linear)


def _assert_no_meta_tensors(model: nn.Module) -> None:
    """Raise if any parameter or buffer of ``model`` is still a meta tensor.

    Parameters are inspected first, then buffers.

    Raises:
        RuntimeError: listing up to the first 20 offending names, followed by
            ``...`` when there are more.
    """
    named_tensors = [*model.named_parameters(), *model.named_buffers()]
    meta_names = [
        tensor_name
        for tensor_name, tensor in named_tensors
        if getattr(tensor, "is_meta", False)
    ]
    if not meta_names:
        return
    ellipsis = "..." if len(meta_names) > 20 else ""
    raise RuntimeError(f"Model still has meta tensors: {meta_names[:20]}{ellipsis}")


def load_miso_8b_int4_weight_only(
    device: str = "cuda",
    model_path_or_repo_id: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
    group_size: int = 128,
    quantize_output_heads: bool = True,
) -> Generator:
    """Load MisoTTS with torchao int4 weight-only quantization.

    The model skeleton is built on the ``meta`` device, so no weight storage is
    allocated up front. Tensors are then streamed one at a time from the
    safetensors checkpoint: everything that is not an ``nn.Linear`` weight is
    loaded first, unquantized, in ``dtype`` on ``device`` (the original author
    lists embeddings, norms and audio_head here); afterwards each ``nn.Linear``
    weight is loaded, attached to its module and quantized immediately with
    ``Int4WeightOnlyConfig``. This avoids holding a full unquantized model and
    a full checkpoint state dict at the same time.

    Args:
        device: Target device for all loaded tensors.
        model_path_or_repo_id: Path to a ``.safetensors`` file, a directory
            containing ``model.safetensors``, or anything else, which is
            treated as a Hugging Face Hub repo id. When falsy, the
            ``MISO_TTS_8B_MODEL`` environment variable is used, falling back
            to ``DEFAULT_MISO_TTS_REPO_ID``.
        dtype: Dtype every checkpoint tensor is cast to before it is attached
            (and, for quantized layers, before quantization).
        group_size: Group size passed to ``Int4WeightOnlyConfig``.
        quantize_output_heads: When ``False``, Linear modules whose path ends
            with ``codebook0_head`` or ``projection`` are kept unquantized.

    Returns:
        A ``Generator`` wrapping the loaded model (set to eval mode).

    Raises:
        RuntimeError: If a key of the model's state dict was not found in the
            checkpoint, or if any parameter/buffer is still on the meta device
            after loading.
    """
    source = model_path_or_repo_id or os.environ.get("MISO_TTS_8B_MODEL", DEFAULT_MISO_TTS_REPO_ID)
    if os.path.isfile(source):
        model_file = source
    elif os.path.isdir(source):
        model_file = os.path.join(source, "model.safetensors")
    else:
        model_file = hf_hub_download(repo_id=source, filename="model.safetensors")

    # Build only the module structure; parameters are placeholders on "meta".
    with torch.device("meta"):
        model = Model(MISO_TTS_8B_CONFIG)

    qconfig = Int4WeightOnlyConfig(group_size=group_size)
    loaded = set()
    # Release cached allocator blocks after this many quantized layers.
    flush_every = 25
    unquantized_head_suffixes = ("codebook0_head", "projection")

    with safe_open(model_file, framework="pt", device="cpu") as f:

        def load_unquantized(key: str) -> None:
            """Read one checkpoint tensor and attach it as a frozen parameter."""
            tensor = f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True)
            _set_parameter(model, key, tensor)
            loaded.add(key)

        keys = list(f.keys())
        # Classify every key once; both passes below reuse the result.
        linear_keys = [key for key in keys if _is_linear_weight(model, key)]
        linear_key_set = set(linear_keys)

        # Load small/unquantized parameters first so next(model.parameters())
        # already sees the target device (per the original author's note).
        for key in keys:
            if key not in linear_key_set:
                load_unquantized(key)

        # Linear weights are materialized and quantized one at a time so that
        # at most one full-precision Linear weight is live at any moment.
        quantized = 0
        for key in linear_keys:
            module_path = key[: -len(".weight")]
            if not quantize_output_heads and module_path.endswith(unquantized_head_suffixes):
                load_unquantized(key)
                continue

            mod = _get_submodule(model, module_path)
            tensor = f.get_tensor(key).to(device=device, dtype=dtype, non_blocking=True)
            mod.weight = nn.Parameter(tensor, requires_grad=False)  # type: ignore[assignment]
            del tensor
            quantize_(mod, qconfig)
            loaded.add(key)

            quantized += 1
            if quantized % flush_every == 0:
                # Collect Python garbage first so the memory it frees can be
                # released by empty_cache() in the same flush.
                gc.collect()
                torch.cuda.empty_cache()

    # Quantized Linear entries keep the "weight" key in the state dict (as
    # tensor subclasses), so a plain key comparison works. Every state-dict key
    # must have come from the checkpoint; state_dict() does not include
    # non-persistent buffers, so those are not checked here.
    missing = set(model.state_dict().keys()) - loaded
    if missing:
        raise RuntimeError(f"Missing checkpoint tensors: {sorted(missing)[:20]}")

    model.eval()
    _assert_no_meta_tensors(model)
    # Modules exposing rope_init whose is_cache_built flag is falsy get it
    # called on the target device. NOTE(review): presumably the RoPE cache is
    # skipped when the module is constructed on "meta" -- confirm in models.py.
    with torch.device(device):
        for mod in model.modules():
            if hasattr(mod, "rope_init") and not getattr(mod, "is_cache_built", True):
                mod.rope_init()
    torch.cuda.empty_cache()
    return Generator(model)