Lil2J commited on
Commit
e345725
·
verified ·
1 Parent(s): 4db1aba

Upload 4 files

Browse files
Files changed (4) hide show
  1. config.json +32 -0
  2. eagle_data.jsonl +0 -0
  3. pytorch_model.bin +3 -0
  4. qwen3_moe.py +913 -0
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLMEagle3"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151645,
9
+ "head_dim": 64,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 2048,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 12288,
14
+ "max_position_embeddings": 40960,
15
+ "max_window_layers": 36,
16
+ "model_type": "llama",
17
+ "num_attention_heads": 32,
18
+ "num_hidden_layers": 1,
19
+ "num_key_value_heads":4 ,
20
+ "rms_norm_eps": 1e-06,
21
+ "rope_scaling": null,
22
+ "rope_theta": 1000000,
23
+ "sliding_window": null,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "bfloat16",
26
+ "transformers_version": "4.51.0",
27
+ "use_cache": true,
28
+ "use_sliding_window": false,
29
+ "vocab_size": 151936,
30
+ "draft_vocab_size": 32000
31
+ }
32
+
eagle_data.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14e0dd4603c30e14a3664bfdc19aecbfd02e7a86a8c9c301687120f8e1bb2e9c
3
+ size 339239102
qwen3_moe.py ADDED
@@ -0,0 +1,913 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from qwen2_moe.py
2
+
3
+ # Copyright 2023-2024 SGLang Team
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+
18
+ """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
19
+
20
+ import logging
21
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+
26
+ from sglang.srt.distributed import (
27
+ get_pp_group,
28
+ get_tensor_model_parallel_rank,
29
+ get_tensor_model_parallel_world_size,
30
+ parallel_state,
31
+ split_tensor_along_last_dim,
32
+ tensor_model_parallel_all_gather,
33
+ tensor_model_parallel_all_reduce,
34
+ )
35
+ from sglang.srt.layers.activation import SiluAndMul
36
+ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
37
+ from sglang.srt.layers.dp_attention import (
38
+ attn_tp_all_gather,
39
+ attn_tp_reduce_scatter,
40
+ dp_gather_partial,
41
+ dp_scatter,
42
+ get_attention_tp_rank,
43
+ get_attention_tp_size,
44
+ get_local_attention_dp_size,
45
+ )
46
+ from sglang.srt.layers.layernorm import RMSNorm
47
+ from sglang.srt.layers.linear import (
48
+ MergedColumnParallelLinear,
49
+ QKVParallelLinear,
50
+ ReplicatedLinear,
51
+ RowParallelLinear,
52
+ )
53
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
54
+ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
55
+ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
56
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
57
+ from sglang.srt.layers.moe.topk import select_experts
58
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
59
+ from sglang.srt.layers.radix_attention import RadixAttention
60
+ from sglang.srt.layers.rotary_embedding import get_rope
61
+ from sglang.srt.layers.utils import get_layer_id
62
+ from sglang.srt.layers.vocab_parallel_embedding import (
63
+ ParallelLMHead,
64
+ VocabParallelEmbedding,
65
+ )
66
+ from sglang.srt.managers.expert_distribution import (
67
+ get_global_expert_distribution_recorder,
68
+ )
69
+ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
70
+ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
71
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
72
+ from sglang.srt.model_executor.forward_batch_info import (
73
+ ForwardBatch,
74
+ ForwardMode,
75
+ PPProxyTensors,
76
+ )
77
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
78
+ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
79
+ from sglang.srt.models.qwen2_moe import Qwen2MoeModel
80
+ from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher, model_forward_maybe_tbo, ScatterMode
81
+ from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
82
+
83
+ Qwen3MoeConfig = None
84
+
85
+ logger = logging.getLogger(__name__)
86
+
87
+
88
+ class Qwen3MoeSparseMoeBlock(nn.Module):
89
+ def __init__(
90
+ self,
91
+ layer_id: int,
92
+ config: Qwen3MoeConfig,
93
+ quant_config: Optional[QuantizationConfig] = None,
94
+ prefix: str = "",
95
+ ):
96
+ super().__init__()
97
+ self.tp_size = get_tensor_model_parallel_world_size()
98
+ self.layer_id = layer_id
99
+ if self.tp_size > config.num_experts:
100
+ raise ValueError(
101
+ f"Tensor parallel size {self.tp_size} is greater than "
102
+ f"the number of experts {config.num_experts}."
103
+ )
104
+
105
+ self.experts = get_moe_impl_class()(
106
+ num_experts=config.num_experts
107
+ + global_server_args_dict["ep_num_redundant_experts"],
108
+ top_k=config.num_experts_per_tok,
109
+ layer_id=layer_id,
110
+ hidden_size=config.hidden_size,
111
+ intermediate_size=config.moe_intermediate_size,
112
+ renormalize=config.norm_topk_prob,
113
+ quant_config=quant_config,
114
+ prefix=add_prefix("experts", prefix),
115
+ **(
116
+ dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
117
+ if global_server_args_dict["enable_deepep_moe"]
118
+ else {}
119
+ ),
120
+ )
121
+
122
+ self.gate = ReplicatedLinear(
123
+ config.hidden_size,
124
+ config.num_experts,
125
+ bias=False,
126
+ quant_config=None,
127
+ prefix=add_prefix("gate", prefix),
128
+ )
129
+
130
+ if global_server_args_dict["enable_deepep_moe"]:
131
+ # TODO: we will support tp < ep in the future
132
+ self.ep_size = get_tensor_model_parallel_world_size()
133
+ self.num_experts = (
134
+ config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
135
+ )
136
+ self.top_k = config.num_experts_per_tok
137
+ self.renormalize = config.norm_topk_prob
138
+
139
+ self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
140
+ group=parallel_state.get_tp_group().device_group,
141
+ router_topk=self.top_k,
142
+ permute_fusion=True,
143
+ num_experts=self.num_experts,
144
+ num_local_experts=config.num_experts // self.tp_size,
145
+ hidden_size=config.hidden_size,
146
+ params_dtype=config.torch_dtype,
147
+ deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
148
+ async_finish=True, # TODO
149
+ return_recv_hook=True,
150
+ )
151
+
152
+ def forward(
153
+ self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
154
+ ) -> torch.Tensor:
155
+
156
+ if not global_server_args_dict["enable_deepep_moe"]:
157
+ return self.forward_normal(hidden_states)
158
+ else:
159
+ return self.forward_deepep(hidden_states, forward_batch)
160
+
161
+ def get_moe_weights(self):
162
+ return [
163
+ x.data
164
+ for name, x in self.experts.named_parameters()
165
+ if name not in ["correction_bias"]
166
+ ]
167
+
168
+ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
169
+ num_tokens, hidden_dim = hidden_states.shape
170
+ hidden_states = hidden_states.view(-1, hidden_dim)
171
+
172
+ # router_logits: (num_tokens, n_experts)
173
+ router_logits, _ = self.gate(hidden_states)
174
+ final_hidden_states = self.experts(
175
+ hidden_states=hidden_states, router_logits=router_logits
176
+ )
177
+ if self.tp_size > 1:
178
+ final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
179
+
180
+ return final_hidden_states.view(num_tokens, hidden_dim)
181
+
182
+ def forward_deepep(
183
+ self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
184
+ ) -> torch.Tensor:
185
+ forward_mode = forward_batch.forward_mode
186
+ if is_non_idle_and_non_empty(forward_mode, hidden_states):
187
+ # router_logits: (num_tokens, n_experts)
188
+ router_logits, _ = self.gate(hidden_states)
189
+
190
+ topk_weights, topk_idx = select_experts(
191
+ hidden_states=hidden_states,
192
+ router_logits=router_logits,
193
+ top_k=self.top_k,
194
+ use_grouped_topk=False,
195
+ renormalize=self.renormalize,
196
+ num_token_non_padded=forward_batch.num_token_non_padded,
197
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
198
+ layer_id=self.layer_id,
199
+ ),
200
+ )
201
+ else:
202
+ topk_idx = torch.full(
203
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
204
+ )
205
+ topk_weights = torch.empty(
206
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
207
+ )
208
+ if self.ep_size > 1:
209
+ # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
210
+ (
211
+ hidden_states,
212
+ topk_idx,
213
+ topk_weights,
214
+ reorder_topk_ids,
215
+ num_recv_tokens_per_expert,
216
+ seg_indptr,
217
+ masked_m,
218
+ expected_m,
219
+ ) = self.deepep_dispatcher.dispatch(
220
+ hidden_states=hidden_states,
221
+ topk_idx=topk_idx,
222
+ topk_weights=topk_weights,
223
+ forward_mode=forward_mode,
224
+ )
225
+ final_hidden_states = self.experts(
226
+ hidden_states=hidden_states,
227
+ topk_idx=topk_idx,
228
+ topk_weights=topk_weights,
229
+ reorder_topk_ids=reorder_topk_ids,
230
+ seg_indptr=seg_indptr,
231
+ masked_m=masked_m,
232
+ expected_m=expected_m,
233
+ num_recv_tokens_per_expert=num_recv_tokens_per_expert,
234
+ forward_mode=forward_mode,
235
+ )
236
+ if self.ep_size > 1:
237
+ final_hidden_states = self.deepep_dispatcher.combine(
238
+ hidden_states=final_hidden_states,
239
+ topk_idx=topk_idx,
240
+ topk_weights=topk_weights,
241
+ forward_mode=forward_mode,
242
+ )
243
+ return final_hidden_states
244
+
245
+ def op_gate(self, state):
246
+ if is_non_idle_and_non_empty(
247
+ state.forward_batch.forward_mode, state.hidden_states_mlp_input
248
+ ):
249
+ # router_logits: (num_tokens, n_experts)
250
+ state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
251
+ else:
252
+ state.router_logits = None
253
+
254
+ def op_select_experts(self, state):
255
+ router_logits = state.pop("router_logits")
256
+ hidden_states = state.hidden_states_mlp_input
257
+ if router_logits is not None:
258
+ with get_global_expert_distribution_recorder().with_current_layer(
259
+ self.layer_id
260
+ ):
261
+ state.topk_weights_local, state.topk_idx_local = select_experts(
262
+ hidden_states=hidden_states,
263
+ router_logits=router_logits,
264
+ top_k=self.top_k,
265
+ use_grouped_topk=False,
266
+ renormalize=self.renormalize,
267
+ num_token_non_padded=state.forward_batch.num_token_non_padded,
268
+ expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
269
+ layer_id=self.layer_id,
270
+ ),
271
+ )
272
+ else:
273
+ state.topk_idx_local = torch.full(
274
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
275
+ )
276
+ state.topk_weights_local = torch.empty(
277
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
278
+ )
279
+
280
+ def op_dispatch_a(self, state):
281
+ if self.ep_size > 1:
282
+ # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
283
+ self.deepep_dispatcher.dispatch_a(
284
+ hidden_states=state.pop("hidden_states_mlp_input"),
285
+ topk_idx=state.pop("topk_idx_local"),
286
+ topk_weights=state.pop("topk_weights_local"),
287
+ forward_mode=state.forward_batch.forward_mode,
288
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
289
+ )
290
+
291
+ def op_dispatch_b(self, state):
292
+ if self.ep_size > 1:
293
+ with get_global_expert_distribution_recorder().with_current_layer(
294
+ self.layer_id
295
+ ):
296
+ (
297
+ state.hidden_states_experts_input,
298
+ state.topk_idx_dispatched,
299
+ state.topk_weights_dispatched,
300
+ state.reorder_topk_ids,
301
+ state.num_recv_tokens_per_expert,
302
+ state.seg_indptr,
303
+ state.masked_m,
304
+ state.expected_m,
305
+ ) = self.deepep_dispatcher.dispatch_b(
306
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
307
+ )
308
+
309
+ def op_experts(self, state):
310
+ state.hidden_states_experts_output = self.experts(
311
+ hidden_states=state.pop("hidden_states_experts_input"),
312
+ topk_idx=state.topk_idx_dispatched,
313
+ topk_weights=state.topk_weights_dispatched,
314
+ reorder_topk_ids=state.pop("reorder_topk_ids"),
315
+ seg_indptr=state.pop("seg_indptr"),
316
+ masked_m=state.pop("masked_m"),
317
+ expected_m=state.pop("expected_m"),
318
+ num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
319
+ forward_mode=state.forward_batch.forward_mode,
320
+ )
321
+
322
+ def op_combine_a(self, state):
323
+ if self.ep_size > 1:
324
+ self.deepep_dispatcher.combine_a(
325
+ hidden_states=state.pop("hidden_states_experts_output"),
326
+ topk_idx=state.pop("topk_idx_dispatched"),
327
+ topk_weights=state.pop("topk_weights_dispatched"),
328
+ forward_mode=state.forward_batch.forward_mode,
329
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
330
+ )
331
+
332
+ def op_combine_b(self, state):
333
+ if self.ep_size > 1:
334
+ state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
335
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
336
+ )
337
+
338
+ def op_output(self, state):
339
+ state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
340
+
341
+
342
+ class Qwen3MoeAttention(nn.Module):
343
+ def __init__(
344
+ self,
345
+ hidden_size: int,
346
+ num_heads: int,
347
+ num_kv_heads: int,
348
+ layer_id: int = 0,
349
+ rope_theta: float = 10000,
350
+ rope_scaling: Optional[Dict[str, Any]] = None,
351
+ max_position_embeddings: int = 8192,
352
+ head_dim: Optional[int] = None,
353
+ rms_norm_eps: float = 1e-06,
354
+ attention_bias: bool = False,
355
+ quant_config: Optional[QuantizationConfig] = None,
356
+ prefix: str = "",
357
+ ) -> None:
358
+ super().__init__()
359
+ self.hidden_size = hidden_size
360
+
361
+ attn_tp_rank = get_attention_tp_rank()
362
+ attn_tp_size = get_attention_tp_size()
363
+
364
+ self.total_num_heads = num_heads
365
+ assert self.total_num_heads % attn_tp_size == 0
366
+ self.num_heads = self.total_num_heads // attn_tp_size
367
+ self.total_num_kv_heads = num_kv_heads
368
+ if self.total_num_kv_heads >= attn_tp_size:
369
+ # Number of KV heads is greater than TP size, so we partition
370
+ # the KV heads across multiple tensor parallel GPUs.
371
+ assert self.total_num_kv_heads % attn_tp_size == 0
372
+ else:
373
+ # Number of KV heads is less than TP size, so we replicate
374
+ # the KV heads across multiple tensor parallel GPUs.
375
+ assert attn_tp_size % self.total_num_kv_heads == 0
376
+ self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
377
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
378
+ self.q_size = self.num_heads * self.head_dim
379
+ self.kv_size = self.num_kv_heads * self.head_dim
380
+ self.scaling = self.head_dim**-0.5
381
+ self.rope_theta = rope_theta
382
+ self.max_position_embeddings = max_position_embeddings
383
+ self.tp_rank = get_tensor_model_parallel_rank()
384
+
385
+ self.qkv_proj = QKVParallelLinear(
386
+ hidden_size,
387
+ self.head_dim,
388
+ self.total_num_heads,
389
+ self.total_num_kv_heads,
390
+ bias=attention_bias,
391
+ quant_config=quant_config,
392
+ tp_rank=attn_tp_rank,
393
+ tp_size=attn_tp_size,
394
+ prefix=add_prefix("qkv_proj", prefix),
395
+ )
396
+
397
+ self.o_proj = RowParallelLinear(
398
+ self.total_num_heads * self.head_dim,
399
+ hidden_size,
400
+ bias=attention_bias,
401
+ quant_config=quant_config,
402
+ tp_rank=attn_tp_rank,
403
+ tp_size=attn_tp_size,
404
+ reduce_results=False,
405
+ prefix=add_prefix("o_proj", prefix),
406
+ )
407
+
408
+ self.rotary_emb = get_rope(
409
+ self.head_dim,
410
+ rotary_dim=self.head_dim,
411
+ max_position=max_position_embeddings,
412
+ base=rope_theta,
413
+ rope_scaling=rope_scaling,
414
+ )
415
+ self.attn = RadixAttention(
416
+ self.num_heads,
417
+ self.head_dim,
418
+ self.scaling,
419
+ num_kv_heads=self.num_kv_heads,
420
+ layer_id=layer_id,
421
+ prefix=add_prefix("attn", prefix),
422
+ )
423
+
424
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
425
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
426
+
427
+ def _apply_qk_norm(
428
+ self, q: torch.Tensor, k: torch.Tensor
429
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
430
+ q_by_head = q.reshape(-1, self.head_dim)
431
+ q_by_head = self.q_norm(q_by_head)
432
+ q = q_by_head.view(q.shape)
433
+ k_by_head = k.reshape(-1, self.head_dim)
434
+ k_by_head = self.k_norm(k_by_head)
435
+ k = k_by_head.view(k.shape)
436
+ return q, k
437
+
438
+ def op_prepare(self, state):
439
+ state.attn_intermediate_state = self.forward_prepare(
440
+ positions=state.positions,
441
+ hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
442
+ forward_batch=state.forward_batch,
443
+ )
444
+
445
+ def op_core(self, state):
446
+ state.hidden_states_after_attn = self.forward_core(
447
+ state.pop("attn_intermediate_state")
448
+ )
449
+
450
+ def forward_prepare(
451
+ self,
452
+ positions: torch.Tensor,
453
+ hidden_states: torch.Tensor,
454
+ forward_batch: ForwardBatch,
455
+ ):
456
+ if hidden_states.shape[0] == 0:
457
+ return hidden_states, forward_batch, None
458
+ qkv, _ = self.qkv_proj(hidden_states)
459
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
460
+ q, k = self._apply_qk_norm(q, k)
461
+ q, k = self.rotary_emb(positions, q, k)
462
+ inner_state = q, k, v, forward_batch
463
+ return None, forward_batch, inner_state
464
+
465
+ def forward_core(self, intermediate_state):
466
+ hidden_states, forward_batch, inner_state = intermediate_state
467
+ if inner_state is None:
468
+ return hidden_states
469
+ attn_output = self.attn(*inner_state)
470
+ output, _ = self.o_proj(attn_output)
471
+ return output
472
+
473
+ def forward(
474
+ self,
475
+ positions: torch.Tensor,
476
+ hidden_states: torch.Tensor,
477
+ forward_batch: ForwardBatch,
478
+ ) -> torch.Tensor:
479
+ s = self.forward_prepare(
480
+ positions=positions,
481
+ hidden_states=hidden_states,
482
+ forward_batch=forward_batch,
483
+ )
484
+ return self.forward_core(s)
485
+
486
+
487
+ class Qwen3MoeDecoderLayer(nn.Module):
488
+ def __init__(
489
+ self,
490
+ config: Qwen3MoeConfig,
491
+ layer_id: int,
492
+ quant_config: Optional[QuantizationConfig] = None,
493
+ prefix: str = "",
494
+ ) -> None:
495
+ super().__init__()
496
+ self.config = config
497
+ self.hidden_size = config.hidden_size
498
+ rope_theta = getattr(config, "rope_theta", 10000)
499
+ rope_scaling = getattr(config, "rope_scaling", None)
500
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
501
+ head_dim = getattr(
502
+ config, "head_dim", config.hidden_size // config.num_attention_heads
503
+ )
504
+ rms_norm_eps = config.rms_norm_eps
505
+ attention_bias = config.attention_bias
506
+ self.self_attn = Qwen3MoeAttention(
507
+ hidden_size=self.hidden_size,
508
+ num_heads=config.num_attention_heads,
509
+ num_kv_heads=config.num_key_value_heads,
510
+ layer_id=layer_id,
511
+ rope_theta=rope_theta,
512
+ rope_scaling=rope_scaling,
513
+ max_position_embeddings=max_position_embeddings,
514
+ head_dim=head_dim,
515
+ rms_norm_eps=rms_norm_eps,
516
+ attention_bias=attention_bias,
517
+ quant_config=quant_config,
518
+ prefix=add_prefix("self_attn", prefix),
519
+ )
520
+
521
+ self.layer_id = layer_id
522
+
523
+ self.attn_tp_size = get_attention_tp_size()
524
+ self.attn_tp_rank = get_attention_tp_rank()
525
+ self.local_dp_size = get_local_attention_dp_size()
526
+
527
+ # Qwen3MoE all layers are sparse and have no nextn now
528
+ self.is_layer_sparse = True
529
+ is_previous_layer_sparse = True
530
+
531
+ self.layer_scatter_modes = LayerScatterModes.init_new(
532
+ layer_id=layer_id,
533
+ num_layers=config.num_hidden_layers,
534
+ is_layer_sparse=self.is_layer_sparse,
535
+ is_previous_layer_sparse=is_previous_layer_sparse,
536
+ )
537
+
538
+ if self.is_layer_sparse:
539
+ self.mlp = Qwen3MoeSparseMoeBlock(
540
+ layer_id=self.layer_id,
541
+ config=config,
542
+ quant_config=quant_config,
543
+ prefix=add_prefix("mlp", prefix),
544
+ )
545
+ else:
546
+ self.mlp = Qwen3MoeMLP(
547
+ hidden_size=config.hidden_size,
548
+ intermediate_size=config.intermediate_size,
549
+ hidden_act=config.hidden_act,
550
+ quant_config=quant_config,
551
+ prefix=add_prefix("mlp", prefix),
552
+ )
553
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
554
+ self.post_attention_layernorm = RMSNorm(
555
+ config.hidden_size, eps=config.rms_norm_eps
556
+ )
557
+
558
+ self.layer_communicator = LayerCommunicator(
559
+ layer_scatter_modes=self.layer_scatter_modes,
560
+ input_layernorm=self.input_layernorm,
561
+ post_attention_layernorm=self.post_attention_layernorm,
562
+ )
563
+
564
+ def forward(
565
+ self,
566
+ positions: torch.Tensor,
567
+ hidden_states: torch.Tensor,
568
+ forward_batch: ForwardBatch,
569
+ residual: Optional[torch.Tensor],
570
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
571
+
572
+ hidden_states, residual = self.layer_communicator.prepare_attn(
573
+ hidden_states, residual, forward_batch
574
+ )
575
+
576
+ if hidden_states.shape[0] != 0:
577
+ hidden_states = self.self_attn(
578
+ positions=positions,
579
+ hidden_states=hidden_states,
580
+ forward_batch=forward_batch,
581
+ )
582
+
583
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
584
+ hidden_states, residual, forward_batch
585
+ )
586
+
587
+ hidden_states = self.mlp(hidden_states, forward_batch)
588
+
589
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
590
+ hidden_states, residual, forward_batch
591
+ )
592
+
593
+ return hidden_states, residual
594
+
595
+ def op_comm_prepare_attn(
596
+ self,
597
+ state,
598
+ positions: torch.Tensor,
599
+ hidden_states: torch.Tensor,
600
+ forward_batch: ForwardBatch,
601
+ residual: Optional[torch.Tensor],
602
+ tbo_subbatch_index: Optional[int] = None,
603
+ ):
604
+ state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
605
+ self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
606
+ )
607
+ state.update(
608
+ dict(
609
+ forward_batch=forward_batch,
610
+ positions=positions,
611
+ tbo_subbatch_index=tbo_subbatch_index,
612
+ )
613
+ )
614
+
615
+ def op_comm_prepare_mlp(self, state):
616
+ state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
617
+ self.layer_communicator.prepare_mlp(
618
+ state.pop("hidden_states_after_attn"),
619
+ state.pop("residual_after_input_ln"),
620
+ state.forward_batch,
621
+ )
622
+ )
623
+
624
+ def op_mlp(self, state):
625
+ hidden_states = state.pop("hidden_states_mlp_input")
626
+ state.hidden_states_mlp_output = self.mlp(
627
+ hidden_states, state.forward_batch.forward_mode
628
+ )
629
+
630
+ def op_comm_postprocess_layer(self, state):
631
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
632
+ state.pop("hidden_states_mlp_output"),
633
+ state.pop("residual_after_comm_pre_mlp"),
634
+ state.forward_batch,
635
+ )
636
+
637
+ output = dict(
638
+ positions=state.positions,
639
+ hidden_states=hidden_states,
640
+ residual=residual,
641
+ forward_batch=state.forward_batch,
642
+ tbo_subbatch_index=state.tbo_subbatch_index,
643
+ )
644
+
645
+ state.clear(
646
+ expect_keys={
647
+ "positions",
648
+ "forward_batch",
649
+ "tbo_subbatch_index",
650
+ }
651
+ )
652
+ return output
653
+
654
+
655
+ class Qwen3MoeModel(Qwen2MoeModel):
656
+ def __init__(
657
+ self,
658
+ config: Qwen3MoeConfig,
659
+ quant_config: Optional[QuantizationConfig] = None,
660
+ prefix: str = "",
661
+ ) -> None:
662
+ super().__init__(
663
+ config=config,
664
+ quant_config=quant_config,
665
+ prefix=prefix,
666
+ decoder_layer_type=Qwen3MoeDecoderLayer,
667
+ )
668
+
669
+ # For EAGLE3 support
670
+ self.layers_to_capture = []
671
+
672
+ def forward(
673
+ self,
674
+ input_ids: torch.Tensor,
675
+ positions: torch.Tensor,
676
+ forward_batch: ForwardBatch,
677
+ input_embeds: torch.Tensor = None,
678
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
679
+ ) -> Union[torch.Tensor, PPProxyTensors]:
680
+ if self.pp_group.is_first_rank:
681
+ if input_embeds is None:
682
+ hidden_states = self.embed_tokens(input_ids)
683
+ else:
684
+ hidden_states = input_embeds
685
+ residual = None
686
+ else:
687
+ assert pp_proxy_tensors is not None
688
+ hidden_states = pp_proxy_tensors["hidden_states"]
689
+ residual = pp_proxy_tensors["residual"]
690
+
691
+ # For EAGLE3 support - collect auxiliary hidden states
692
+ aux_hidden_states = []
693
+
694
+ if forward_batch.can_run_tbo:
695
+ hidden_states, residual = model_forward_maybe_tbo(
696
+ layers=self.layers,
697
+ enable_tbo=True,
698
+ input_data_scatter_mode=ScatterMode.model_input_output(),
699
+ positions=positions,
700
+ forward_batch=forward_batch,
701
+ hidden_states=hidden_states,
702
+ residual=residual,
703
+ )
704
+ else:
705
+ for i in range(self.start_layer, self.end_layer):
706
+ # EAGLE3 support: capture hidden states from specified layers
707
+ if i in self.layers_to_capture:
708
+ aux_hidden_states.append(hidden_states + residual)
709
+
710
+ with get_global_expert_distribution_recorder().with_current_layer(i):
711
+ layer = self.layers[i]
712
+ hidden_states, residual = layer(
713
+ positions, hidden_states, forward_batch, residual
714
+ )
715
+ if not self.pp_group.is_last_rank:
716
+ return PPProxyTensors(
717
+ {
718
+ "hidden_states": hidden_states,
719
+ "residual": residual,
720
+ }
721
+ )
722
+ else:
723
+ if hidden_states.shape[0] != 0:
724
+ if residual is None:
725
+ hidden_states = self.norm(hidden_states)
726
+ else:
727
+ hidden_states, _ = self.norm(hidden_states, residual)
728
+
729
+ # Return aux_hidden_states if available for EAGLE3
730
+ if len(aux_hidden_states) == 0:
731
+ return hidden_states
732
+ return hidden_states, aux_hidden_states
733
+
734
+
735
+ class Qwen3MoeForCausalLM(nn.Module):
736
+ fall_back_to_pt_during_load = False
737
+
738
+ def __init__(
739
+ self,
740
+ config: Qwen3MoeConfig,
741
+ quant_config: Optional[QuantizationConfig] = None,
742
+ prefix: str = "",
743
+ ) -> None:
744
+ super().__init__()
745
+ self.pp_group = get_pp_group()
746
+ self.config = config
747
+ self.quant_config = quant_config
748
+ self.model = Qwen3MoeModel(
749
+ config, quant_config, prefix=add_prefix("model", prefix)
750
+ )
751
+ self.lm_head = ParallelLMHead(
752
+ config.vocab_size,
753
+ config.hidden_size,
754
+ quant_config=quant_config,
755
+ prefix=add_prefix("lm_head", prefix),
756
+ use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
757
+ )
758
+ self.logits_processor = LogitsProcessor(config)
759
+
760
+ # For EAGLE3 support
761
+ self.capture_aux_hidden_states = False
762
+
763
+ @torch.no_grad()
764
+ def forward(
765
+ self,
766
+ input_ids: torch.Tensor,
767
+ positions: torch.Tensor,
768
+ forward_batch: ForwardBatch,
769
+ input_embeds: torch.Tensor = None,
770
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
771
+ ) -> torch.Tensor:
772
+ hidden_states = self.model(
773
+ input_ids,
774
+ positions,
775
+ forward_batch,
776
+ input_embeds,
777
+ pp_proxy_tensors=pp_proxy_tensors,
778
+ )
779
+
780
+ aux_hidden_states = None
781
+ if self.capture_aux_hidden_states:
782
+ hidden_states, aux_hidden_states = hidden_states
783
+
784
+ if self.pp_group.is_last_rank:
785
+ return self.logits_processor(
786
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
787
+ )
788
+ else:
789
+ return hidden_states
790
+
791
+ @property
792
+ def start_layer(self):
793
+ return self.model.start_layer
794
+
795
+ @property
796
+ def end_layer(self):
797
+ return self.model.end_layer
798
+
799
+ def get_embed_and_head(self):
800
+ return self.model.embed_tokens.weight, self.lm_head.weight
801
+
802
+ def set_eagle3_layers_to_capture(self):
803
+ if not self.pp_group.is_last_rank:
804
+ return
805
+
806
+ self.capture_aux_hidden_states = True
807
+ num_layers = self.config.num_hidden_layers
808
+ self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
809
+
810
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
811
+ stacked_params_mapping = [
812
+ # (param_name, shard_name, shard_id)
813
+ ("qkv_proj", "q_proj", "q"),
814
+ ("qkv_proj", "k_proj", "k"),
815
+ ("qkv_proj", "v_proj", "v"),
816
+ ("gate_up_proj", "gate_proj", 0),
817
+ ("gate_up_proj", "up_proj", 1),
818
+ ]
819
+
820
+ expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
821
+ ckpt_gate_proj_name="gate_proj",
822
+ ckpt_down_proj_name="down_proj",
823
+ ckpt_up_proj_name="up_proj",
824
+ num_experts=self.config.num_experts,
825
+ )
826
+
827
+ params_dict = dict(self.named_parameters())
828
+ for name, loaded_weight in weights:
829
+ layer_id = get_layer_id(name)
830
+ if (
831
+ layer_id is not None
832
+ and hasattr(self.model, "start_layer")
833
+ and (
834
+ layer_id < self.model.start_layer
835
+ or layer_id >= self.model.end_layer
836
+ )
837
+ ):
838
+ continue
839
+
840
+ if "rotary_emb.inv_freq" in name:
841
+ continue
842
+ for param_name, weight_name, shard_id in stacked_params_mapping:
843
+ # Skip non-stacked layers and experts (experts handled below).
844
+ if weight_name not in name:
845
+ continue
846
+ # We have mlp.experts[0].gate_proj in the checkpoint.
847
+ # Since we handle the experts below in expert_params_mapping,
848
+ # we need to skip here BEFORE we update the name, otherwise
849
+ # name will be updated to mlp.experts[0].gate_up_proj, which
850
+ # will then be updated below in expert_params_mapping
851
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
852
+ if "mlp.experts" in name:
853
+ continue
854
+ name = name.replace(weight_name, param_name)
855
+ # Skip loading extra bias for GPTQ models.
856
+ if name.endswith(".bias") and name not in params_dict:
857
+ continue
858
+ if name not in params_dict:
859
+ continue
860
+
861
+ param = params_dict[name]
862
+ weight_loader = param.weight_loader
863
+ weight_loader(param, loaded_weight, shard_id)
864
+ break
865
+ else:
866
+ for mapping in expert_params_mapping:
867
+ param_name, weight_name, expert_id, shard_id = mapping
868
+ if weight_name not in name:
869
+ continue
870
+ name = name.replace(weight_name, param_name)
871
+ param = params_dict[name]
872
+ weight_loader = param.weight_loader
873
+ weight_loader(
874
+ param,
875
+ loaded_weight,
876
+ name,
877
+ shard_id=shard_id,
878
+ expert_id=expert_id,
879
+ )
880
+ break
881
+ else:
882
+ # Skip loading extra bias for GPTQ models.
883
+ if name.endswith(".bias") and name not in params_dict:
884
+ continue
885
+ if name not in params_dict:
886
+ continue
887
+
888
+ if name in params_dict.keys():
889
+ param = params_dict[name]
890
+ weight_loader = getattr(
891
+ param, "weight_loader", default_weight_loader
892
+ )
893
+ weight_loader(param, loaded_weight)
894
+ else:
895
+ logger.warning(f"Parameter {name} not found in params_dict")
896
+
897
+ # TODO mimic deepseek
898
+ self.routed_experts_weights_of_layer = {
899
+ layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
900
+ for layer_id in range(self.start_layer, self.end_layer)
901
+ if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
902
+ }
903
+
904
+ @classmethod
905
+ def get_model_config_for_expert_location(cls, config):
906
+ return ModelConfigForExpertLocation(
907
+ num_layers=config.num_hidden_layers,
908
+ num_logical_experts=config.num_experts,
909
+ num_groups=None,
910
+ )
911
+
912
+
913
+ EntryClass = Qwen3MoeForCausalLM