philipjohnbasile commited on
Commit
7508877
·
verified ·
1 Parent(s): 837d19a

Upload glm_moe_dsa.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. glm_moe_dsa.py +255 -0
glm_moe_dsa.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2025 Apple Inc.
2
+ #
3
+ # GLM-5 / GLM-5.2 (model_type: "glm_moe_dsa") — a DeepSeek-V3.2-style MoE with MLA +
4
+ # DeepSeek Sparse Attention (DSA). Built on top of the `deepseek_v32` model; the only
5
+ # architectural difference is the DSA `indexer_types` (full / shared) scheme: GLM places
6
+ # the lightning-indexer weights only on `'full'` layers, while `'shared'` layers reuse
7
+ # the top-k indices computed by the most recent full layer (in groups of `index_topk_freq`).
8
+ # A naive port that builds an indexer on every layer fails to load ("Missing parameters"
9
+ # on the shared layers); this handles both layer kinds.
10
+
11
+ import os
12
+ from dataclasses import dataclass
13
+ from typing import Dict, List, Optional
14
+
15
+ import mlx.core as mx
16
+ import mlx.nn as nn
17
+
18
+ from .base import BaseModelArgs
19
+ from . import deepseek_v32 as dsv32
20
+
21
# Stream eval per layer so a >RAM model can run on 128GB via mmap paging.
# Set GLM_STREAM_EVAL=0 to disable (faster when the model fits in RAM, e.g. pruned).
# The spellings "false", "no" and "off" (any case, surrounding whitespace ignored)
# are accepted as synonyms for "0"; an unset or empty variable keeps the default.
def _env_flag(name: str, default: bool = True) -> bool:
    """Read a boolean switch from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset or blank.

    Returns:
        ``False`` when the variable holds one of ``0``/``false``/``no``/``off``
        (case-insensitive, whitespace-trimmed), ``default`` when it is unset or
        blank, and ``True`` for any other value.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    value = raw.strip().lower()
    if not value:
        # An exported-but-empty variable is treated like an unset one.
        return default
    return value not in ("0", "false", "no", "off")


GLM_STREAM_EVAL = _env_flag("GLM_STREAM_EVAL", default=True)
24
+
25
+
26
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration fields for the ``glm_moe_dsa`` model type.

    The fields without defaults must be supplied by the caller. The two
    trailing fields, ``rope_scaling`` and ``rope_theta``, are not independent
    inputs: ``__post_init__`` always overwrites them from ``rope_parameters``.
    """

    model_type: str
    vocab_size: int
    hidden_size: int
    index_head_dim: int
    index_n_heads: int
    index_topk: int
    intermediate_size: int
    moe_intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    n_shared_experts: Optional[int]
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    kv_lora_rank: int
    q_lora_rank: int
    qk_rope_head_dim: int
    v_head_dim: int
    qk_nope_head_dim: int
    topk_method: str
    scoring_func: str
    norm_topk_prob: bool
    n_group: int
    topk_group: int
    num_experts_per_tok: int
    moe_layer_freq: int
    first_k_dense_replace: int
    max_position_embeddings: int
    rms_norm_eps: float
    rope_parameters: Dict
    attention_bias: bool
    # GLM-5.2 DSA sharing: one entry per layer, either 'full' or 'shared'.
    # Only 'full' layers own an indexer; None means every layer is 'full'.
    indexer_types: Optional[List[str]] = None
    index_topk_freq: int = 4
    # Both fields below are replaced in __post_init__; values passed in by the
    # caller are discarded.
    rope_scaling: Dict = None
    rope_theta: Optional[float] = None

    def __post_init__(self):
        """Mirror ``rope_parameters`` into ``rope_scaling`` / ``rope_theta``.

        Raises:
            KeyError: if ``rope_parameters`` has no ``"rope_theta"`` entry.
        """
        rope_cfg = self.rope_parameters
        self.rope_scaling = rope_cfg
        self.rope_theta = rope_cfg["rope_theta"]
68
+
69
+
70
def _is_full(config: ModelArgs, layer_idx: int) -> bool:
    """Report whether layer ``layer_idx`` owns its own DSA indexer.

    A layer counts as 'full' unless ``config.indexer_types`` has an entry for
    it that is something other than ``"full"``. A missing or empty list, or an
    index past the end of the list, falls back to ``True``.
    """
    kinds = config.indexer_types
    if kinds and layer_idx < len(kinds):
        return kinds[layer_idx] == "full"
    return True
77
+
78
+
79
class GlmDsaAttention(dsv32.DeepseekV32Attention):
    """DeepSeek-V3.2 attention whose indexer is kept only on 'full' layers.

    A 'shared' layer has ``self.indexer`` set to ``None`` and instead uses the
    ``shared_topk`` indices handed to ``__call__`` (the top-k selection
    produced by an earlier full layer).

    The projections, norms, ``rope``, ``embed_q``, ``unembed_out`` and
    ``scale`` used below are not defined here; presumably they are created by
    ``DeepseekV32Attention.__init__`` -- confirm against that class.
    """

    def __init__(self, config: ModelArgs, is_full: bool):
        super().__init__(config)
        self.is_full = is_full
        if not is_full:
            # Drop the indexer built by the parent so that no indexer weights
            # are expected on this layer.
            self.indexer = None

    def __call__(self, x, mask=None, cache=None, shared_topk=None):
        """Run attention and return ``(output, topk_indices)``.

        Args:
            x: Input with three dimensions, unpacked as ``(B, L, D)``.
            mask: Optional attention mask; ``None`` means no masking.
            cache: Optional indexable cache. Slot 0 is the KV cache; slot 1 is
                handed to the indexer and is only read on full layers.
            shared_topk: Top-k indices to reuse when this layer has no indexer.

        Returns:
            A tuple of the projected attention output and the top-k indices
            that were applied (``None`` when no sparsification happened).
        """
        B, L, _ = x.shape
        decoding = L == 1

        # Query path: low-rank projection, then split each head into its
        # non-rotary and rotary parts. Layout after transpose: (B, heads, L, dim).
        qr = self.q_a_layernorm(self.q_a_proj(x))
        q = self.q_b_proj(qr).reshape(B, L, self.num_heads, self.q_head_dim)
        q = q.transpose(0, 2, 1, 3)
        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)

        # Key/value path: one projection split into the latent and the rotary key.
        kv_joint = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = mx.split(kv_joint, [self.kv_lora_rank], axis=-1)
        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
        kv_latent = self.kv_a_layernorm(compressed_kv)

        # The offset must be read before the cache is updated below.
        offset = 0 if cache is None else cache[0].offset
        q_pe = self.rope(q_pe, offset)
        k_pe = self.rope(k_pe, offset)

        kv_latent = mx.expand_dims(kv_latent, axis=1)
        if cache is None:
            # Placeholder so cache[0] / cache[1] can be indexed uniformly below.
            cache = [None] * 2
        else:
            kv_latent, k_pe = cache[0].update_and_fetch(kv_latent, k_pe)

        # topk: compute on full layers, reuse the shared one otherwise.
        if self.is_full and self.indexer is not None:
            topk_indices = self.indexer(x, qr, mask, cache=cache[1])
        else:
            topk_indices = shared_topk

        if topk_indices is not None and decoding:
            # Single-token step: gather only the selected cache positions.
            gather = topk_indices[:, :, 0, :, None]
            kv_latent = mx.take_along_axis(
                kv_latent,
                mx.broadcast_to(gather, gather.shape[:-1] + (kv_latent.shape[-1],)),
                axis=2,
            )
            k_pe = mx.take_along_axis(
                k_pe,
                mx.broadcast_to(gather, gather.shape[:-1] + (k_pe.shape[-1],)),
                axis=2,
            )
            if mask is not None:
                mask = mx.take_along_axis(mask, topk_indices, axis=-1)
        elif topk_indices is not None:
            # Multi-token step: keep every key and express the selection as a
            # boolean mask, combined with the incoming mask when there is one.
            dense_shape = list(topk_indices.shape)
            dense_shape[-1] = kv_latent.shape[2]
            selected = mx.zeros(dense_shape, dtype=mx.bool_)
            selected = mx.put_along_axis(
                selected, topk_indices, mx.array(True), axis=-1
            )
            mask = selected if mask is None else selected & mask

        # Keep the indexer cache in the graph only when this layer has one.
        has_indexer_cache = (
            self.is_full
            and cache is not None
            and cache[0] is not None
            and cache[1] is not None
        )
        if has_indexer_cache:
            cache[0].keys = mx.depends(
                cache[0].keys, (cache[1].keys, cache[1].values)
            )

        # Rotary-part scores are passed to the attention kernel as its mask
        # argument; masked-out entries get the dtype's most negative value.
        pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
        if mask is not None:
            floor = mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype)
            pe_scores = mx.where(mask, pe_scores, floor)

        if decoding:
            # Single-token step works on the latent directly; the output is
            # expanded by unembed_out after attention.
            q_nope = self.embed_q(q_nope)
            k = v = kv_latent
        else:
            k = self.embed_q(kv_latent, transpose=False)
            v = self.unembed_out(kv_latent)

        output = dsv32.scaled_dot_product_attention(
            q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores
        )
        if decoding:
            output = self.unembed_out(output)

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), topk_indices
170
+
171
+
172
class GlmDsaDecoderLayer(nn.Module):
    """One decoder block: attention and feed-forward, each with a residual.

    The layer records whether it is a 'full' DSA layer in ``self.is_full`` and
    forwards the attention's top-k indices to the caller.
    """

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.is_full = _is_full(config, layer_idx)
        self.self_attn = GlmDsaAttention(config, self.is_full)

        # A layer gets the MoE feed-forward only when routed experts are
        # configured, the index is past the leading dense layers, and the index
        # is a multiple of moe_layer_freq; every other layer gets the dense MLP.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = dsv32.DeepseekV32MoE(config)
        else:
            self.mlp = dsv32.DeepseekV32MLP(config)

        self.input_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def __call__(self, x, mask=None, cache=None, shared_topk=None):
        """Apply the block and return ``(hidden, topk)``.

        ``topk`` is whatever the attention module reports as its top-k indices.
        """
        attn_out, topk = self.self_attn(
            self.input_layernorm(x), mask, cache, shared_topk
        )
        h = x + attn_out
        ffn_out = self.mlp(self.post_attention_layernorm(h))
        return h + ffn_out, topk
192
+
193
+
194
class GlmDsaModel(dsv32.DeepseekV32Model):
    """Embedding, decoder stack and final norm, with top-k sharing across layers."""

    def __init__(self, config: ModelArgs):
        # Deliberately calls nn.Module.__init__ rather than the parent's
        # __init__, so the layer stack below is the only one that gets built.
        nn.Module.__init__(self)
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            GlmDsaDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        # The whole stack runs here: a single pipeline stage covering all layers.
        self.start_idx = 0
        self.end_idx = len(self.layers)
        self.num_layers = self.end_idx
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pipeline_rank = 0
        self.pipeline_size = 1

    def __call__(self, x, cache=None):
        """Embed ``x``, run every decoder layer, and return the normed hidden state.

        Args:
            x: Token ids passed to the embedding table.
            cache: Optional list with one cache entry per layer.
        """
        h = self.embed_tokens(x)
        if cache is None:
            cache = [None] * self.num_layers

        first_kv_cache = cache[0][0] if cache[0] else None
        mask = dsv32.create_attention_mask(h, first_kv_cache, return_array=True)

        shared_topk = None
        for i in range(self.num_layers):
            layer = self.layers[self.start_idx + i]
            h, topk = layer(h, mask, cache[i], shared_topk)
            if layer.is_full:
                # Propagate to subsequent shared layers. A full layer that
                # returns None also clears the previous selection.
                shared_topk = topk
            if not GLM_STREAM_EVAL:
                continue
            # Incremental eval so the lazy graph doesn't hold every layer's
            # weights at once -- keeps the working set ~1 layer, lets mmap page
            # out used experts. Critical for running a >RAM model on 128GB.
            mx.eval(h)
            if shared_topk is not None:
                mx.eval(shared_topk)
        return self.norm(h)
230
+
231
+
232
class Model(dsv32.Model):
    """Top-level GLM DSA model: the decoder stack plus an untied LM head."""

    def __init__(self, config: ModelArgs):
        # Calls nn.Module.__init__ directly so the parent does not build its
        # own inner model; GlmDsaModel is installed instead.
        nn.Module.__init__(self)
        self.args = config
        self.model_type = config.model_type
        self.model = GlmDsaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def make_cache(self):
        """Build one cache entry per decoder layer.

        Layers that own an indexer get a two-slot ``CacheList`` (KV cache plus
        indexer cache); shared layers get a single-slot ``CacheList``.

        Shared DSA layers never run their indexer (the indexer call and the
        ``mx.depends`` hook are both guarded by ``is_full``), so a second
        ``KVCache`` on them would stay unpopulated (``keys=None``). Per the
        original author's note, that makes generate.py's per-prompt
        ``mx.eval([c.state for c in cache])`` crash, so shared layers are given
        ONLY the kv cache: generation is unchanged (they never touch
        ``cache[1]``) and every cache has a valid state for prompt caching.

        Returns:
            A list of ``CacheList`` objects, in layer order.
        """
        # Relative import, matching the module's other sibling imports, so the
        # cache classes come from the same package copy the rest of the model uses.
        from .cache import CacheList, KVCache

        caches = []
        for layer in self.model.layers:
            attn = getattr(layer, "self_attn", None)
            owns_indexer = (
                getattr(layer, "is_full", True)
                and getattr(attn, "indexer", None) is not None
            )
            if owns_indexer:
                caches.append(CacheList(KVCache(), KVCache()))
            else:
                caches.append(CacheList(KVCache()))
        return caches