rahular committed on
Commit
2f915a7
·
verified ·
1 Parent(s): 8521c2a

Create sarvam.py

Browse files
Files changed (1) hide show
  1. sarvam.py +788 -0
sarvam.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ #
4
+ # Copyright 2026 Sarvam AI team. All rights reserved.
5
+ #
6
+ # This code is based on Llama, Deepseek, and Bailing MoE implementations
7
+ # in this library. It has been modified from its original forms to
8
+ # accommodate Sarvam's MoE architectures.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from __future__ import annotations
23
+
24
+ import math
25
+ from collections.abc import Iterable, Iterator
26
+ from itertools import islice
27
+
28
+ import torch
29
+ from torch import nn
30
+
31
+ from vllm.config import CacheConfig, ParallelConfig, VllmConfig
32
+ from vllm.distributed import (
33
+ get_pp_group,
34
+ get_tensor_model_parallel_rank,
35
+ get_tensor_model_parallel_world_size,
36
+ )
37
+ from vllm.model_executor.layers.activation import SiluAndMul
38
+ from vllm.model_executor.layers.fused_moe import SharedFusedMoE
39
+ from vllm.model_executor.layers.layernorm import RMSNorm
40
+ from vllm.model_executor.layers.linear import (
41
+ ColumnParallelLinear,
42
+ MergedColumnParallelLinear,
43
+ ReplicatedLinear,
44
+ RowParallelLinear,
45
+ )
46
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
+ from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
48
+ from vllm.model_executor.layers.quantization import QuantizationConfig
49
+ from vllm.model_executor.layers.rotary_embedding import get_rope
50
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
51
+ ParallelLMHead,
52
+ VocabParallelEmbedding,
53
+ )
54
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
55
+ from vllm.sequence import IntermediateTensors
56
+
57
+ from .bailing_moe import BailingMoeForCausalLM
58
+ from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
59
+ from .utils import (
60
+ AutoWeightsLoader,
61
+ PPMissingLayer,
62
+ is_pp_missing_parameter,
63
+ make_empty_intermediate_tensors_factory,
64
+ make_layers,
65
+ maybe_prefix,
66
+ )
67
+
68
+
69
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a RoPE scaling factor.

    Args:
        scale: RoPE scaling factor.
        mscale: Multiplier applied to the logarithmic correction term.

    Returns:
        ``1.0`` when ``scale <= 1``; otherwise
        ``0.1 * mscale * log(scale) + 1.0``.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
73
+
74
+
75
def _is_gate_expert_bias_name(name: str) -> bool:
    """Tell whether *name* is a router gate's ``e_score_correction_bias``.

    A name matches when it ends with either of the two suffixes below.

    Args:
        name: Dotted weight name from a checkpoint.

    Returns:
        True if the name ends with one of the gate-bias suffixes.
    """
    gate_bias_suffixes = (
        ".mlp.gate.e_score_correction_bias",
        ".gate.e_score_correction_bias",
    )
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return name.endswith(gate_bias_suffixes)
79
+
80
+
81
def _zero_mean_tensor(t: torch.Tensor) -> torch.Tensor:
    """Return *t* shifted so that the mean over all of its elements is removed.

    Args:
        t: Input tensor.

    Returns:
        ``t - t.mean()`` for a non-empty tensor; an empty tensor is returned
        unchanged (the same object).
    """
    if t.numel() > 0:
        return t - t.mean()
    # Nothing to centre; hand the empty tensor back unchanged.
    return t
85
+
86
+
87
def _normalized_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterator[tuple[str, torch.Tensor]]:
    """Lazily re-yield checkpoint weights, zero-centring gate expert biases.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Yields:
        The same pairs in the same order. Tensors whose name is recognised by
        ``_is_gate_expert_bias_name`` are passed through ``_zero_mean_tensor``;
        all other tensors are yielded untouched.
    """
    for name, tensor in weights:
        is_gate_bias = _is_gate_expert_bias_name(name)
        yield name, (_zero_mean_tensor(tensor) if is_gate_bias else tensor)
95
+
96
+
97
class SarvamMLAAttention(nn.Module):
    """Multi-head latent attention (MLA) layer.

    The constructor builds the query / KV projections, their norms and the
    rotary embedding, bundles them into ``MLAModules`` and hands them to
    ``MultiHeadLatentAttentionWrapper``, which performs the actual forward
    computation.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections and the wrapped MLA attention module.

        Args:
            vllm_config: Engine configuration (not read by this constructor).
            config: HF-style model config providing head dims, LoRA ranks,
                ``rms_norm_eps``, ``max_position_embeddings`` and
                ``rope_parameters``.
            cache_config: Forwarded to the MLA wrapper.
            quant_config: Forwarded to every linear layer and the wrapper.
            prefix: Dotted module-name prefix used to name sub-layers.
        """
        super().__init__()

        self.config = config
        self.hidden_size = config.hidden_size

        # A query/key head is the concatenation of a "nope" and a "rope" part.
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.v_head_dim = config.v_head_dim

        self.q_lora_rank = getattr(config, "q_lora_rank", None)
        self.kv_lora_rank = config.kv_lora_rank

        # Heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_local_heads = self.total_num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = config.max_position_embeddings

        # Width of the full (all heads) query projection output.
        q_out_features = self.total_num_heads * self.qk_head_dim

        if self.q_lora_rank is None:
            # Direct query projection; the low-rank path is disabled.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                q_out_features,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
            self.q_a_proj = None  # type: ignore
            self.q_a_layernorm = None  # type: ignore
            self.q_b_proj = None  # type: ignore
        else:
            # Low-rank query path: down-project, normalise, up-project.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                q_out_features,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
            self.q_proj = None  # type: ignore

        # KV latent (MQA-style) A-proj
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        # KV B-proj produces per-head K_nope and V
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.total_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.qk_rope_head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        if config.rope_parameters.get("rope_type", None) == "deepseek_yarn":
            # YaRN: fold the squared magnitude correction into the softmax scale.
            # A missing "mscale_all_dim" becomes float(False) == 0.0.
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # The unused query-path attributes are already None (set above), so
        # they can be passed straight through.
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=None,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=None,
            indexer_rotary_emb=None,
            is_sparse=False,
            topk_indices_buffer=None,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention by delegating to the wrapped MLA module.

        Args:
            positions: Position tensor forwarded to the wrapper.
            hidden_states: Input activations forwarded to the wrapper.

        Returns:
            Whatever the wrapped MLA module returns.
        """
        attn_output = self.mla_attn(positions, hidden_states, llama_4_scaling=None)
        return attn_output
233
+
234
+
235
class SarvamMLAMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        intermediate_size: int,
        config,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the two linear layers and the activation.

        Args:
            intermediate_size: Width of each of the gate and up projections.
            config: Model config; only ``hidden_size`` is read here.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to the down projection.
            prefix: Dotted module-name prefix used to name sub-layers.
        """
        super().__init__()

        hidden_size = config.hidden_size

        # Gate and up projections are stored as one merged column-parallel layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, then the down projection."""
        # Both linear layers return a 2-tuple; only the first item is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
268
+
269
+
270
class SarvamMLAMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with an optional shared expert.

    A plain ``nn.Linear`` router produces per-expert logits; routing and
    expert computation are delegated to ``SharedFusedMoE``.
    """

    def __init__(
        self,
        config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the router, the optional shared expert and the fused experts.

        Args:
            config: Model config; required fields are ``hidden_size``,
                ``num_experts``, ``num_experts_per_tok`` and
                ``moe_intermediate_size``. The remaining fields are optional
                and fall back to the defaults given below.
            parallel_config: Accepted for interface symmetry; not read here.
            quant_config: Forwarded to the shared expert and fused experts.
            prefix: Dotted module-name prefix used to name sub-layers.
        """
        super().__init__()

        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size

        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 2.5)

        # Grouped top-k is enabled only when both group settings are present.
        self.n_group = getattr(config, "n_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        self.use_grouped_topk = self.n_group is not None and self.topk_group is not None

        self.norm_expert_prob = getattr(config, "norm_topk_prob", True)

        # Router dtype: None keeps the default dtype, "fp32" selects float32,
        # and any other value selects bfloat16.
        router_dtype_name = getattr(config, "router_dtype", "fp32")
        if router_dtype_name is None:
            self.router_dtype = None
        else:
            self.router_dtype = (
                torch.float32 if router_dtype_name == "fp32" else torch.bfloat16
            )

        self.gate = nn.Linear(
            self.hidden_size,
            self.num_experts,
            bias=False,
            dtype=self.router_dtype,
        )

        # The correction bias is attached to the gate module so that it is
        # exposed as ``<prefix>.gate.e_score_correction_bias``.
        expert_bias = (
            nn.Parameter(torch.empty((self.num_experts,), dtype=torch.float32))
            if getattr(config, "moe_router_enable_expert_bias", True)
            else None
        )
        self.gate.e_score_correction_bias = expert_bias

        self.score_function = getattr(config, "score_function", "sigmoid")
        self.num_shared_experts = getattr(config, "num_shared_experts", 1)
        if self.num_shared_experts > 0:
            # All shared experts are fused into one MLP of proportional width.
            per_expert_width = (
                config.moe_shared_expert_intermediate_size
                if hasattr(config, "moe_shared_expert_intermediate_size")
                else config.moe_intermediate_size
            )
            self.shared_experts = SarvamMLAMLP(
                intermediate_size=per_expert_width * self.num_shared_experts,
                config=config,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=self.num_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=self.norm_expert_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.score_function,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            num_expert_group=self.n_group,
            topk_group=self.topk_group,
            use_grouped_topk=self.use_grouped_topk,
            routed_scaling_factor=self.routed_scaling_factor,
        )

    def maybe_get_fused_moe(self) -> SharedFusedMoE:
        """Return the fused-experts module held by this layer."""
        return self.experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine with the shared expert.

        Args:
            hidden_states: 2-D activations, unpacked as
                ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor viewed as ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Run the router in its own dtype, then cast logits back.
        if self.router_dtype is None:
            gate_input = hidden_states
        else:
            gate_input = hidden_states.to(self.router_dtype)
        router_logits = self.gate(gate_input).to(hidden_states.dtype)

        moe_result = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )

        # With a shared expert the fused module's result unpacks into
        # (shared, routed); otherwise it is the routed output alone.
        if self.shared_experts is None:
            shared_output, expert_output = None, moe_result
        else:
            shared_output, expert_output = moe_result

        if shared_output is not None:
            expert_output = expert_output + shared_output

        # Both sub-modules were built with reduce_results=False, so the sum is
        # reduced once here.
        if self.tp_size > 1:
            expert_output = self.experts.maybe_all_reduce_tensor_model_parallel(
                expert_output
            )

        return expert_output.view(num_tokens, hidden_dim)
389
+
390
+
391
class SarvamMLABlock(nn.Module):
    """One decoder layer: RMSNorm, MLA attention, RMSNorm, then a dense MLP or MoE."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the layer's norms, attention and feed-forward module.

        Args:
            vllm_config: Engine configuration; the HF config, cache config,
                quant config and parallel config are read from it.
            prefix: Dotted module name whose last component is the integer
                layer index (e.g. ``"model.layers.3"``).
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # The layer index is encoded as the final component of the prefix.
        layer_idx = int(prefix.split(".")[-1])
        hidden_size = config.hidden_size

        self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.self_attn = SarvamMLAAttention(
            vllm_config=vllm_config,
            config=config,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

        # A layer is MoE when the model has experts, the layer is past the
        # leading dense layers, and it falls on the MoE layer frequency.
        has_experts = getattr(config, "num_experts", None) is not None
        first_k_dense = getattr(config, "first_k_dense_replace", 1)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        is_moe_layer = (
            has_experts
            and layer_idx >= first_k_dense
            and (layer_idx - first_k_dense) % moe_layer_freq == 0
        )

        if not is_moe_layer:
            self.mlp = SarvamMLAMLP(
                intermediate_size=getattr(config, "intermediate_size", 16384),
                config=config,
                quant_config=quant_config,
                reduce_results=True,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = SarvamMLAMoE(
                config=config,
                parallel_config=vllm_config.parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            hidden_states: Input activations.
            positions: Positions forwarded to the attention module.
            residual: Residual from the previous layer, or ``None`` for the
                first layer processed.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is not None:
            # The norm is called with the residual and returns the updated one.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
460
+
461
+
462
class SarvamMLAModel(nn.Module):
    """Decoder stack: token embedding, ``SarvamMLABlock`` layers and a final norm.

    Supports pipeline parallelism: modules that do not live on the current
    pipeline rank are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the embedding, layers and final norm for this pipeline rank.

        Args:
            vllm_config: Engine configuration; the HF config and quant config
                are read from it.
            prefix: Dotted module-name prefix used to name sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pp_group = get_pp_group()

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

        # The embedding is built on the first rank, and also on the last rank
        # when word embeddings are tied.
        needs_embedding = pp_group.is_first_rank or (
            self.tie_word_embeddings and pp_group.is_last_rank
        )
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                self.embed_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.embedding_dropout = torch.nn.Dropout(
            getattr(config, "embedding_dropout", 0.0)
        )

        def build_block(prefix: str) -> SarvamMLABlock:
            # Parameter is named "prefix" to match the original factory's
            # signature, in case the caller passes it by keyword.
            return SarvamMLABlock(vllm_config=vllm_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.layers",
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        if pp_group.is_last_rank:
            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for *input_ids*."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            intermediate_tensors: Hidden state and residual from the previous
                pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            On the last rank, the normalised hidden states. On other ranks,
            ``IntermediateTensors`` carrying ``hidden_states`` and ``residual``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                # Dropout is applied only to embeddings looked up here, not to
                # caller-provided inputs_embeds.
                hidden_states = self.embedding_dropout(
                    self.embed_input_ids(input_ids)
                )
            else:
                hidden_states = inputs_embeds
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(hidden_states, positions, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed expert weights."""
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load weights with stacked gate+up and MoE expert remapping.

        Each incoming weight is tried, in order, as (1) a shard of a merged
        ``gate_up_proj`` parameter (routed experts excluded), (2) a routed
        expert weight, and (3) a parameter loaded under its own name. Names
        that match no local parameter, or that belong to a layer missing on
        this pipeline rank, are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        weights = _normalized_weights(weights)

        # (merged parameter name, checkpoint shard name, shard index)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            # 1) Dense / shared-expert gate and up projections.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or "mlp.experts" in name:
                    continue
                target = name.replace(weight_name, param_name)
                if target not in params_dict or is_pp_missing_parameter(target, self):
                    continue

                param = params_dict[target]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(target)
                break
            else:
                # 2) Routed expert weights.
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    target = name.replace(weight_name, param_name)
                    if (
                        is_pp_missing_parameter(target, self)
                        or target not in params_dict
                    ):
                        continue

                    param = params_dict[target]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    loaded_params.add(target)
                    break
                else:
                    # 3) Everything else loads under its checkpoint name.
                    if name not in params_dict or is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

        return loaded_params
638
+
639
+
640
class SarvamMixtureOfExperts(MixtureOfExperts):
    """Expert-count bookkeeping mixin for Sarvam MoE models.

    The concrete model class is expected to provide ``moe_mlp_layers`` (the
    ``SarvamMLAMoE`` modules) and ``moe_layers`` (their fused-experts modules).
    """

    def extract_moe_parameters(self, example_moe: SarvamMLAMoE | None) -> None:
        """Initialise expert counts from a representative MoE layer.

        Args:
            example_moe: Any MoE layer of the model, or ``None`` if none exists.

        Raises:
            RuntimeError: If ``example_moe`` is ``None``.
        """
        if example_moe is None:
            raise RuntimeError("No SarvamMLAMoE layer found in model.layers.")

        expert_count = example_moe.num_experts
        self.num_logical_experts = expert_count
        self.num_routed_experts = expert_count  # routed pool size
        self.num_shared_experts = getattr(example_moe.config, "num_shared_experts", 1)

        # Initially the physical counts equal the logical count, with no
        # redundant experts.
        self.num_physical_experts = self.num_logical_experts
        self.num_local_physical_experts = self.num_logical_experts
        self.num_redundant_experts = 0

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record new physical expert counts here and on every MoE layer.

        Args:
            num_physical_experts: Total number of physical experts.
            num_local_physical_experts: Number of physical experts on this rank.
        """
        num_redundant = num_physical_experts - self.num_logical_experts

        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_redundant

        # Attributes refreshed on the fused-experts module only when it
        # already defines them.
        fused_attr_updates = (
            ("n_local_physical_experts", num_local_physical_experts),
            ("n_physical_experts", num_physical_experts),
            ("n_redundant_experts", num_redundant),
        )

        for moe in self.moe_mlp_layers:
            moe.n_physical_experts = num_physical_experts
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_redundant_experts = num_redundant

            fused = moe.experts
            for attr_name, attr_value in fused_attr_updates:
                if hasattr(fused, attr_name):
                    setattr(fused, attr_name, attr_value)
            if hasattr(fused, "update_expert_map"):
                fused.update_expert_map()

    def set_eplb_state(self, eplb_state) -> None:
        """Store *eplb_state* and forward it to each fused layer that accepts it."""
        self.eplb_state = eplb_state
        for fused_layer in self.moe_layers:
            if hasattr(fused_layer, "set_eplb_state"):
                fused_layer.set_eplb_state(eplb_state)
682
+
683
+
684
class SarvamMLAForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SarvamMixtureOfExperts):
    """Causal LM head on top of ``SarvamMLAModel``, with PP, LoRA and MoE support."""

    # Maps each packed module name to the checkpoint module names it is built from.
    packed_modules_mapping = {
        "q_proj": ["q_proj"],
        "q_a_proj": ["q_a_proj"],
        "q_b_proj": ["q_b_proj"],
        "kv_a_proj_with_mqa": ["kv_a_proj_with_mqa"],
        "kv_b_proj": ["kv_b_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the decoder, LM head and MoE bookkeeping.

        Args:
            vllm_config: Engine configuration; the HF config and quant config
                are read from it.
            prefix: Dotted module-name prefix used to name sub-layers.

        Raises:
            RuntimeError: If no ``SarvamMLAMoE`` layer is present among the
                layers built on this rank (raised by
                ``extract_moe_parameters``).
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.model = SarvamMLAModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank computes logits.
            self.lm_head = PPMissingLayer()
            self.logits_processor = None  # type: ignore
        else:
            if self.tie_word_embeddings:
                # Tied embeddings: reuse the input embedding module as the head.
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "lm_head"),
                )
            self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        self.expert_weights = []
        self.num_moe_layers = 0

        # moe_mlp_layers holds the SarvamMLAMoE modules; moe_layers holds their
        # fused-experts modules.
        self.moe_layers = []
        self.moe_mlp_layers = []

        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue
            mlp = layer.mlp
            if not isinstance(mlp, SarvamMLAMoE):
                continue
            # The last MoE layer seen ends up as the representative.
            example_moe = mlp
            self.moe_mlp_layers.append(mlp)
            self.moe_layers.append(mlp.experts)
            self.num_moe_layers += 1

        self.extract_moe_parameters(example_moe)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the wrapped decoder."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped decoder; all arguments are forwarded unchanged."""
        decoder_output = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return decoder_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits on the last pipeline rank; return ``None`` elsewhere."""
        if get_pp_group().is_last_rank:
            return self.logits_processor(self.lm_head, hidden_states)
        return None

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of names returned by the loader.
        """
        # With tied embeddings, checkpoint entries under "lm_head." are skipped.
        skipped_prefixes = ["lm_head."] if self.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped_prefixes)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert-weight mapping of the wrapped decoder."""
        return self.model.get_expert_mapping()
782
+
783
+
784
class SarvamMoEForCausalLM(BailingMoeForCausalLM):
    """``BailingMoeForCausalLM`` variant that normalizes gate expert biases on load.

    The only difference from the parent class is that checkpoint weights are
    passed through ``_normalized_weights`` before loading.
    """

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Normalize gate expert-bias tensors, then defer to the parent loader.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Whatever the parent ``load_weights`` returns.
        """
        normalized = _normalized_weights(weights)
        return super().load_weights(normalized)