File size: 10,176 Bytes
7508877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# Copyright © 2025 Apple Inc.
#
# GLM-5 / GLM-5.2  (model_type: "glm_moe_dsa") — a DeepSeek-V3.2-style MoE with MLA +
# DeepSeek Sparse Attention (DSA). Built on top of the `deepseek_v32` model; the only
# architectural difference is the DSA `indexer_types` (full / shared) scheme: GLM places
# the lightning-indexer weights only on `'full'` layers, while `'shared'` layers reuse
# the top-k indices computed by the most recent full layer (in groups of `index_topk_freq`).
# A naive port that builds an indexer on every layer fails to load ("Missing parameters"
# on the shared layers); this handles both layer kinds.

import os
from dataclasses import dataclass
from typing import Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from . import deepseek_v32 as dsv32

# Per-layer streaming evaluation toggle. When enabled, the forward pass
# evaluates after every layer so a model larger than RAM can run on a 128GB
# machine via mmap paging. Export GLM_STREAM_EVAL=0 to turn it off (faster
# when the whole model fits in RAM, e.g. a pruned checkpoint). Any value other
# than the exact string "0" -- including an unset variable -- leaves it on.
GLM_STREAM_EVAL = os.getenv("GLM_STREAM_EVAL", "1") != "0"


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration fields for the ``glm_moe_dsa`` model.

    Fields without a default are required. ``rope_scaling`` and ``rope_theta``
    are derived values: ``__post_init__`` always overwrites them from
    ``rope_parameters``, so anything passed for them is discarded.
    """

    model_type: str
    vocab_size: int
    hidden_size: int
    # index_* fields: presumably consumed by the inherited deepseek_v32 indexer
    # (not read directly in this file) -- TODO confirm.
    index_head_dim: int
    index_n_heads: int
    index_topk: int
    intermediate_size: int
    moe_intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    n_shared_experts: Optional[int]
    # None disables MoE entirely: every decoder layer then gets the dense MLP.
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    kv_lora_rank: int
    q_lora_rank: int
    qk_rope_head_dim: int
    v_head_dim: int
    qk_nope_head_dim: int
    topk_method: str
    scoring_func: str
    norm_topk_prob: bool
    n_group: int
    topk_group: int
    num_experts_per_tok: int
    # A layer uses MoE only if its index is >= first_k_dense_replace and a
    # multiple of moe_layer_freq (see GlmDsaDecoderLayer).
    moe_layer_freq: int
    first_k_dense_replace: int
    max_position_embeddings: int
    rms_norm_eps: float
    # Must contain a "rope_theta" key; read in __post_init__.
    rope_parameters: Dict
    attention_bias: bool
    # GLM-5.2 DSA sharing: per-layer 'full' | 'shared'; full layers own an indexer.
    indexer_types: Optional[List[str]] = None
    # NOTE(review): not referenced by the code visible in this file; layer
    # sharing is driven by `indexer_types` alone.
    index_topk_freq: int = 4
    # Both fields below are overwritten unconditionally in __post_init__.
    rope_scaling: Dict = None
    rope_theta: Optional[float] = None

    def __post_init__(self):
        """Mirror ``rope_parameters`` into ``rope_scaling`` and ``rope_theta``.

        Raises:
            KeyError: if ``rope_parameters`` (a dict) has no ``"rope_theta"`` key.
        """
        # Presumably the inherited deepseek_v32 code reads these two names
        # rather than `rope_parameters` -- TODO confirm against deepseek_v32.
        self.rope_scaling = self.rope_parameters
        self.rope_theta = self.rope_parameters["rope_theta"]


def _is_full(config: ModelArgs, layer_idx: int) -> bool:
    """Report whether layer ``layer_idx`` owns its own DSA indexer.

    A layer is 'full' when ``config.indexer_types`` marks it as ``"full"``.
    If the list is missing/empty, or does not reach ``layer_idx``, the layer
    is treated as full.
    """
    kinds = config.indexer_types
    if kinds and layer_idx < len(kinds):
        return kinds[layer_idx] == "full"
    # No usable entry for this layer: fall back to all-full behaviour.
    return True


class GlmDsaAttention(dsv32.DeepseekV32Attention):
    """DeepSeek-V3.2 attention, but the indexer exists only on 'full' layers.

    'shared' layers receive `shared_topk` (the top-k indices returned by a
    preceding full layer) and reuse it instead of running an indexer.
    """

    def __init__(self, config: ModelArgs, is_full: bool):
        """Build the parent attention module, then drop the indexer if shared.

        Args:
            config: Model configuration, forwarded unchanged to the parent class.
            is_full: True when this layer keeps its own indexer; False when it
                reuses the indices passed to ``__call__`` as ``shared_topk``.
        """
        super().__init__(config)
        self.is_full = is_full
        if not is_full:
            # drop the indexer so no indexer weights are expected on this layer
            self.indexer = None

    def __call__(self, x, mask=None, cache=None, shared_topk=None):
        """Run attention, restricting keys to the top-k indices when available.

        Args:
            x: Input whose shape unpacks as ``(B, L, D)``.
            mask: Optional attention mask; it is combined with the top-k
                selection before being applied to the rotary scores.
            cache: Optional indexable cache. ``cache[0]`` receives the latent
                KV and rotary keys; ``cache[1]`` is handed to the indexer and
                is only accessed on full layers.
            shared_topk: Indices to reuse when this layer has no indexer.

        Returns:
            ``(output, topk_indices)``: the ``o_proj`` projection of the
            attention result, and the indices used for this layer (``None``
            when no top-k selection was applied).
        """
        B, L, D = x.shape

        # Query path: q_a_proj -> q_a_layernorm -> q_b_proj. `qr` (the
        # normalised intermediate) is also what the indexer consumes below.
        qr = self.q_a_layernorm(self.q_a_proj(x))
        q = self.q_b_proj(qr)
        # (B, L, heads, head_dim) -> (B, heads, L, head_dim)
        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
        # First qk_nope_head_dim features skip rope; the remainder gets rope.
        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
        # Joint KV projection: first kv_lora_rank features are the latent, the
        # rest is the rotary key, laid out as a single head (B, 1, L, rope_dim).
        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
        kv_latent = self.kv_a_layernorm(compressed_kv)

        # Rope position offset comes from the primary cache; 0 without a cache.
        offset = cache[0].offset if cache is not None else 0
        q_pe = self.rope(q_pe, offset)
        k_pe = self.rope(k_pe, offset)

        # Insert a singleton axis at position 1 so the latent matches k_pe's
        # single-head layout (assumes the layernorm preserves shape).
        kv_latent = mx.expand_dims(kv_latent, axis=1)
        if cache is not None:
            kv_latent, k_pe = cache[0].update_and_fetch(kv_latent, k_pe)
        else:
            # Placeholder so cache[0] / cache[1] can be indexed below.
            cache = [None] * 2

        # topk: compute on full layers, reuse the shared one otherwise.
        if self.is_full and self.indexer is not None:
            topk_indices = self.indexer(x, qr, mask, cache=cache[1])
        else:
            topk_indices = shared_topk

        if topk_indices is not None:
            if L == 1:
                # Single-token step: gather only the selected positions from
                # the latent and rotary keys along axis 2, and pick the
                # matching mask entries.
                idx = topk_indices[:, :, 0, :, None]
                kv_latent = mx.take_along_axis(
                    kv_latent,
                    mx.broadcast_to(idx, idx.shape[:-1] + (kv_latent.shape[-1],)),
                    axis=2,
                )
                k_pe = mx.take_along_axis(
                    k_pe,
                    mx.broadcast_to(idx, idx.shape[:-1] + (k_pe.shape[-1],)),
                    axis=2,
                )
                if mask is not None:
                    mask = mx.take_along_axis(mask, topk_indices, axis=-1)
            else:
                # Multi-token step: keep all keys but build a boolean mask that
                # is True only at the selected positions, then AND it with the
                # incoming mask (if any).
                shape = list(topk_indices.shape)
                shape[-1] = kv_latent.shape[2]
                sparse_mask = mx.zeros(shape, dtype=mx.bool_)
                sparse_mask = mx.put_along_axis(
                    sparse_mask, topk_indices, mx.array(True), axis=-1
                )
                if mask is not None:
                    sparse_mask = sparse_mask & mask
                mask = sparse_mask

        # keep the indexer cache in the graph only when this layer has one
        # (`self.is_full` is tested first, so cache[1] is never indexed on a
        # shared layer). NOTE(review): mx.depends presumably makes evaluating
        # cache[0].keys also evaluate the indexer cache arrays -- confirm.
        if (self.is_full and cache is not None and cache[0] is not None
                and cache[1] is not None):
            cache[0].keys = mx.depends(
                cache[0].keys, (cache[1].keys, cache[1].values))

        # Rotary-part scores are computed explicitly; masked-out positions are
        # set to the dtype's minimum. The result is passed to the attention
        # call through its `mask` argument.
        pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
        if mask is not None:
            pe_scores = mx.where(
                mask, pe_scores,
                mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype))

        if L == 1:
            # Single token: transform the query instead and attend directly
            # over the latent (used as both keys and values).
            q_nope = self.embed_q(q_nope)
            k = v = kv_latent
        else:
            # Multiple tokens: expand the latent into keys and values up front.
            k = self.embed_q(kv_latent, transpose=False)
            v = self.unembed_out(kv_latent)

        output = dsv32.scaled_dot_product_attention(
            q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores)
        if L == 1:
            # The output is still in latent space on the single-token path.
            output = self.unembed_out(output)

        # (B, heads, L, dim) -> (B, L, heads * dim)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), topk_indices


class GlmDsaDecoderLayer(nn.Module):
    """One pre-norm decoder block: DSA attention followed by an MoE or dense MLP."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        """Create the attention and feed-forward sub-modules for one layer.

        Args:
            config: Model configuration.
            layer_idx: Position of this layer in the stack; decides whether
                the layer owns an indexer and whether it uses MoE.
        """
        super().__init__()
        owns_indexer = _is_full(config, layer_idx)
        self.is_full = owns_indexer
        self.self_attn = GlmDsaAttention(config, owns_indexer)

        # MoE only when routed experts are configured, the layer is past the
        # leading dense layers, and it falls on the MoE layer frequency.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = dsv32.DeepseekV32MoE(config)
        else:
            self.mlp = dsv32.DeepseekV32MLP(config)

        eps = config.rms_norm_eps
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=eps)

    def __call__(self, x, mask=None, cache=None, shared_topk=None):
        """Apply attention and MLP with residual connections.

        Returns:
            ``(hidden, topk)``: the layer output and the top-k indices
            returned by the attention module.
        """
        normed_in = self.input_layernorm(x)
        attn_out, topk = self.self_attn(normed_in, mask, cache, shared_topk)
        hidden = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out, topk


class GlmDsaModel(dsv32.DeepseekV32Model):
    """Token embedding, a stack of ``GlmDsaDecoderLayer`` blocks, and a final norm."""

    def __init__(self, config: ModelArgs):
        """Build the stack without running the parent model's constructor.

        ``nn.Module.__init__`` is called directly, so every attribute the
        model needs is set up here.
        """
        nn.Module.__init__(self)
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        blocks = [
            GlmDsaDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = blocks
        depth = len(blocks)
        self.start_idx = 0
        self.end_idx = depth
        self.num_layers = depth
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pipeline_rank = 0
        self.pipeline_size = 1

    def __call__(self, x, cache=None):
        """Embed ``x``, run every layer, and return the normalised hidden state.

        Top-k indices produced by a full layer are carried forward and handed
        to the following layers until the next full layer replaces them.
        """
        h = self.embed_tokens(x)
        if cache is None:
            cache = [None] * self.num_layers

        first_cache = cache[0]
        mask = dsv32.create_attention_mask(
            h, first_cache[0] if first_cache else None, return_array=True)

        carried_topk = None
        for offset in range(self.num_layers):
            block = self.layers[self.start_idx + offset]
            h, layer_topk = block(h, mask, cache[offset], carried_topk)
            if block.is_full:
                # Full layer: its indices feed the shared layers that follow.
                carried_topk = layer_topk

            if not GLM_STREAM_EVAL:
                continue
            # Evaluate after each layer so the lazy graph never references
            # every layer's weights at once: the working set stays around one
            # layer and mmap can page out experts that were already used.
            # This is what lets a model larger than RAM run on 128GB.
            mx.eval(h)
            if carried_topk is not None:
                mx.eval(carried_topk)
        return self.norm(h)


class Model(dsv32.Model):
    """Top-level ``glm_moe_dsa`` model: ``GlmDsaModel`` backbone plus LM head.

    Everything not defined here is inherited from ``dsv32.Model``.
    """

    def __init__(self, config: ModelArgs):
        """Build the backbone and the bias-free output projection.

        ``nn.Module.__init__`` is called directly, so the parent ``dsv32.Model``
        constructor is not run.
        """
        nn.Module.__init__(self)
        self.args = config
        self.model_type = config.model_type
        self.model = GlmDsaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def make_cache(self):
        """Return one ``CacheList`` per layer, sized to what the layer uses.

        Full layers (those that own an indexer) get two ``KVCache`` objects:
        one for attention and one for the indexer. Shared layers get only
        the attention cache.

        Shared DSA layers never run their indexer (the indexer call and the
        mx.depends are both guarded by is_full), so a second KVCache would stay
        unpopulated (keys=None) and generate.py's per-prompt
        `mx.eval([c.state for c in cache])` crashes ('NoneType' has no 'shape').
        Giving shared layers ONLY the kv cache leaves generation unchanged
        (they never touch cache[1]) and gives every cache a valid
        .state/from_state, so the prompt-cache TTFT speedup works
        (set PROMPT_CACHE / --prompt-cache-size).
        """
        # Package-relative import, matching the module's other imports
        # (`.base`, `.deepseek_v32`): the cache classes then always come from
        # the same package copy as this model, even when the package is
        # vendored or shadowed by another installed mlx_lm.
        from .cache import CacheList, KVCache

        caches = []
        for layer in self.model.layers:
            attn = getattr(layer, "self_attn", None)
            has_indexer = (
                getattr(layer, "is_full", True)
                and getattr(attn, "indexer", None) is not None
            )
            if has_indexer:
                caches.append(CacheList(KVCache(), KVCache()))
            else:
                caches.append(CacheList(KVCache()))
        return caches