philipjohnbasile's picture
Upload glm_moe_dsa.py with huggingface_hub
7508877 verified
Raw
History Blame
10.2 kB
# Copyright © 2025 Apple Inc.
#
# GLM-5 / GLM-5.2 (model_type: "glm_moe_dsa") — a DeepSeek-V3.2-style MoE with MLA +
# DeepSeek Sparse Attention (DSA). Built on top of the `deepseek_v32` model; the only
# architectural difference is the DSA `indexer_types` (full / shared) scheme: GLM places
# the lightning-indexer weights only on `'full'` layers, while `'shared'` layers reuse
# the top-k indices computed by the most recent full layer (in groups of `index_topk_freq`).
# A naive port that builds an indexer on every layer fails to load ("Missing parameters"
# on the shared layers); this handles both layer kinds.
import os
from dataclasses import dataclass
from typing import Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from . import deepseek_v32 as dsv32
# Stream eval per layer so a >RAM model can run on 128GB via mmap paging.
# Set GLM_STREAM_EVAL=0 to disable (faster when the model fits in RAM, e.g. pruned).
GLM_STREAM_EVAL = os.environ.get("GLM_STREAM_EVAL", "1") != "0"
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict
attention_bias: bool
# GLM-5.2 DSA sharing: per-layer 'full' | 'shared'; full layers own an indexer.
indexer_types: Optional[List[str]] = None
index_topk_freq: int = 4
rope_scaling: Dict = None
rope_theta: Optional[float] = None
def __post_init__(self):
self.rope_scaling = self.rope_parameters
self.rope_theta = self.rope_parameters["rope_theta"]
def _is_full(config: ModelArgs, layer_idx: int) -> bool:
"""Does this layer own a DSA indexer? Default to all-full if unspecified."""
if not config.indexer_types:
return True
if layer_idx < len(config.indexer_types):
return config.indexer_types[layer_idx] == "full"
return True
class GlmDsaAttention(dsv32.DeepseekV32Attention):
"""DeepSeek-V3.2 attention, but the indexer exists only on 'full' layers.
'shared' layers receive `shared_topk` (the full layer's topk) and reuse it."""
def __init__(self, config: ModelArgs, is_full: bool):
super().__init__(config)
self.is_full = is_full
if not is_full:
# drop the indexer so no indexer weights are expected on this layer
self.indexer = None
def __call__(self, x, mask=None, cache=None, shared_topk=None):
B, L, D = x.shape
qr = self.q_a_layernorm(self.q_a_proj(x))
q = self.q_b_proj(qr)
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv_latent = self.kv_a_layernorm(compressed_kv)
offset = cache[0].offset if cache is not None else 0
q_pe = self.rope(q_pe, offset)
k_pe = self.rope(k_pe, offset)
kv_latent = mx.expand_dims(kv_latent, axis=1)
if cache is not None:
kv_latent, k_pe = cache[0].update_and_fetch(kv_latent, k_pe)
else:
cache = [None] * 2
# topk: compute on full layers, reuse the shared one otherwise.
if self.is_full and self.indexer is not None:
topk_indices = self.indexer(x, qr, mask, cache=cache[1])
else:
topk_indices = shared_topk
if topk_indices is not None:
if L == 1:
idx = topk_indices[:, :, 0, :, None]
kv_latent = mx.take_along_axis(
kv_latent,
mx.broadcast_to(idx, idx.shape[:-1] + (kv_latent.shape[-1],)),
axis=2,
)
k_pe = mx.take_along_axis(
k_pe,
mx.broadcast_to(idx, idx.shape[:-1] + (k_pe.shape[-1],)),
axis=2,
)
if mask is not None:
mask = mx.take_along_axis(mask, topk_indices, axis=-1)
else:
shape = list(topk_indices.shape)
shape[-1] = kv_latent.shape[2]
sparse_mask = mx.zeros(shape, dtype=mx.bool_)
sparse_mask = mx.put_along_axis(
sparse_mask, topk_indices, mx.array(True), axis=-1
)
if mask is not None:
sparse_mask = sparse_mask & mask
mask = sparse_mask
# keep the indexer cache in the graph only when this layer has one
if (self.is_full and cache is not None and cache[0] is not None
and cache[1] is not None):
cache[0].keys = mx.depends(
cache[0].keys, (cache[1].keys, cache[1].values))
pe_scores = (q_pe * self.scale) @ k_pe.swapaxes(-1, -2)
if mask is not None:
pe_scores = mx.where(
mask, pe_scores,
mx.array(mx.finfo(pe_scores.dtype).min, pe_scores.dtype))
if L == 1:
q_nope = self.embed_q(q_nope)
k = v = kv_latent
else:
k = self.embed_q(kv_latent, transpose=False)
v = self.unembed_out(kv_latent)
output = dsv32.scaled_dot_product_attention(
q_nope, k, v, cache=cache, scale=self.scale, mask=pe_scores)
if L == 1:
output = self.unembed_out(output)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), topk_indices
class GlmDsaDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.is_full = _is_full(config, layer_idx)
self.self_attn = GlmDsaAttention(config, self.is_full)
self.mlp = (
dsv32.DeepseekV32MoE(config)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0)
else dsv32.DeepseekV32MLP(config))
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
def __call__(self, x, mask=None, cache=None, shared_topk=None):
r, topk = self.self_attn(self.input_layernorm(x), mask, cache, shared_topk)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r, topk
class GlmDsaModel(dsv32.DeepseekV32Model):
def __init__(self, config: ModelArgs):
nn.Module.__init__(self)
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
GlmDsaDecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def __call__(self, x, cache=None):
h = self.embed_tokens(x)
if cache is None:
cache = [None] * self.num_layers
mask = dsv32.create_attention_mask(
h, cache[0][0] if cache[0] else None, return_array=True)
shared_topk = None
for i in range(self.num_layers):
layer = self.layers[self.start_idx + i]
h, topk = layer(h, mask, cache[i], shared_topk)
if layer.is_full:
shared_topk = topk # propagate to subsequent shared layers
# Incremental eval so the lazy graph doesn't hold ALL 78 layers'
# weights at once — keeps the working set ~1 layer, lets mmap page
# out used experts. Critical for running a >RAM model on 128GB.
if GLM_STREAM_EVAL:
mx.eval(h)
if shared_topk is not None:
mx.eval(shared_topk)
return self.norm(h)
class Model(dsv32.Model):
def __init__(self, config: ModelArgs):
nn.Module.__init__(self)
self.args = config
self.model_type = config.model_type
self.model = GlmDsaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def make_cache(self):
# Shared DSA layers never run their indexer (the indexer call + the
# mx.depends are both guarded by is_full), so the base make_cache's second
# KVCache stays unpopulated (keys=None) and generate.py's per-prompt
# `mx.eval([c.state for c in cache])` crashes ('NoneType' has no 'shape').
# Give shared layers ONLY the kv cache: generation is unchanged (they never
# touch cache[1]) and every cache now has a valid .state/from_state -> the
# prompt-cache TTFT speedup works (set PROMPT_CACHE / --prompt-cache-size).
from mlx_lm.models.cache import CacheList, KVCache
caches = []
for layer in self.model.layers:
full = getattr(layer, "is_full", True) and getattr(
getattr(layer, "self_attn", None), "indexer", None) is not None
caches.append(CacheList(KVCache(), KVCache()) if full
else CacheList(KVCache()))
return caches