import copy
from dataclasses import dataclass, field
from enum import Enum
from types import SimpleNamespace
from typing import Optional, Tuple, Union

import torch
from fla.layers.delta_net import DeltaNet
from fla.layers.gated_deltanet import GatedDeltaNet
from fla.layers.gla import GatedLinearAttention
from fla.layers.kda import KimiDeltaAttention
from fla.layers.mamba2 import Mamba2
from fla.layers.mla import MultiheadLatentAttention
from fla.layers.nsa import NativeSparseAttention
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.kda.configuration_kda import KDAConfig
from fla.models.kda.modeling_kda import KDAPreTrainedModel
from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2Block
from fla.models.mla.configuration_mla import MLAConfig
from fla.models.nsa.configuration_nsa import NSAConfig
from transformers import Cache, GradientCheckpointingLayer
from transformers import Qwen3Config
from transformers.masking_utils import (
    create_causal_mask,
    create_sliding_window_causal_mask,
)
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3Attention,
    Qwen3MLP,
    Qwen3RMSNorm,
)
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLTextAttention,
    Qwen3VLTextMLP,
    Qwen3VLTextRMSNorm,
)


class FLACacheAdapter:
    """Expose a per-layer state store on top of an existing cache object.

    The wrapped ``cache`` gets a ``fla_states`` dict attribute (created on
    first use) that maps a layer index to a dict of state entries. The
    adapter offers the ``update`` / ``__getitem__`` / ``__len__`` /
    ``get_seq_length`` surface on top of that dict.

    NOTE(review): this surface presumably mirrors what the ``fla`` layers
    expect from their own cache class -- confirm against the fla version in
    use.
    """

    def __init__(self, cache):
        self.cache = cache
        if not hasattr(self.cache, 'fla_states'):
            self.cache.fla_states = {}

    def get_seq_length(self, layer_idx=None):
        """Return the cached length for ``layer_idx``.

        Only a stored ``attn_state`` that is a 2-tuple whose first element is
        a tensor counts; its dimension 1 is reported. Every other situation
        (no layer given, no state, other state layout) yields ``0``.
        """
        if layer_idx is None:
            return 0
        state = self.cache.fla_states.get(layer_idx)
        if state is None or 'attn_state' not in state:
            return 0
        attn_state = state['attn_state']
        if (
            isinstance(attn_state, tuple)
            and len(attn_state) == 2
            and isinstance(attn_state[0], torch.Tensor)
        ):
            return attn_state[0].shape[1]
        return 0

    def update(self, attn_state=None, layer_idx=None,
               offset=None, cache_kwargs=None, **kwargs):
        """Merge new state for ``layer_idx`` and return that layer's state dict.

        Args:
            attn_state: Optional new attention state. A 2-tuple of tensors is
                appended to any previously stored pair along dimension 1;
                anything else replaces the stored ``attn_state``.
            layer_idx: Layer whose state is updated. When ``None`` nothing is
                stored and an empty dict is returned.
            offset: Accepted for call compatibility; not stored.
            cache_kwargs: Accepted for call compatibility; not stored.
            **kwargs: Any additional named entries are stored verbatim in the
                layer's state dict.

        Returns:
            The (mutable) state dict of ``layer_idx``, or ``{}`` when no
            layer index was supplied.
        """
        # ``layer_idx`` is a named parameter, so it can never show up inside
        # ``kwargs``; no extra lookup or filtering is required.
        if layer_idx is None:
            return {}

        state = self.cache.fla_states.setdefault(layer_idx, {})

        if attn_state is not None:
            is_tensor_pair = (
                isinstance(attn_state, tuple)
                and len(attn_state) == 2
                and isinstance(attn_state[0], torch.Tensor)
                and isinstance(attn_state[1], torch.Tensor)
            )
            if is_tensor_pair:
                new_k, new_v = attn_state
                if 'attn_state' in state:
                    old_k, old_v = state['attn_state']
                    new_k = torch.cat([old_k, new_k], dim=1)
                    new_v = torch.cat([old_v, new_v], dim=1)
                state['attn_state'] = (new_k, new_v)
            else:
                state['attn_state'] = attn_state

        state.update(kwargs)
        return state

    def __getitem__(self, layer_idx):
        return self.cache.fla_states.get(layer_idx, None)

    def __setitem__(self, layer_idx, value):
        self.cache.fla_states[layer_idx] = value

    def __contains__(self, layer_idx):
        return layer_idx in self.cache.fla_states

    def __len__(self):
        """Return the highest stored integer layer index plus one (0 if none).

        ``fla_states`` is shared with ``NasVLDecoderLayer``'s NSA path, which
        stores entries under string keys (``"nsa_hidden_<idx>"``). Those keys
        are skipped here; taking ``max`` over mixed ``str``/``int`` keys would
        raise ``TypeError``.
        """
        layer_keys = [key for key in self.cache.fla_states if isinstance(key, int)]
        if not layer_keys:
            return 0
        return max(layer_keys) + 1


class AttentionType(str, Enum):
    """Token-mixer choices for a child layer.

    Members compare equal to their string value, so configs may carry either
    the enum member or the raw string. Note that members do not *hash* like
    their string value, so comparisons in this module use ``==`` / tuples
    rather than dict or set lookups.
    """

    FULL = "full_attention"
    SWA = "swa"
    MAMBA2 = "mamba2"
    GLA = "gla"
    GDN = "gdn"
    DN = "dn"
    KDA = "kda"
    NSA = "nsa"
    MLA = "mla"
    NOOP = "no-op"
    LINEAR = "linear"


class FFNType(str, Enum):
    """Feed-forward choices for a child layer."""

    FFN = "ffn"
    MOE = "moe"
    NOOP = "no-op"
    LINEAR = "linear"
    NFFN = "nffn"


class MetricType(str, Enum):
    """Names of block-level metrics referenced by ``ChildLayerVLConfig``."""

    mse = "mse"
    cosine = "cosine"
    kl = "kl"


@dataclass
class ChildLayerVLConfig:
    """Per-layer architecture description used to build a ``NasVLDecoderLayer``.

    Size-related fields left as ``None`` (or ``0``) fall back to the parent
    text config when the layer is built.
    """

    attention_type: Optional[AttentionType] = field(default=None)
    ffn_type: Optional[FFNType] = field(default=None)
    block_metric: Optional[MetricType] = field(default=None)
    child_hidden_size: Optional[int] = field(default=None)
    child_intermediate_size: Optional[int] = field(default=None)
    gqa_num_kv_heads: Optional[int] = field(default=None)
    child_num_attention_heads: Optional[int] = field(default=None)
    # Accepts a bool or a string flag; always a bool after ``__post_init__``.
    inherit: Union[bool, str, None] = field(default="false")
    sliding_window: Optional[int] = field(default=1024)

    def __post_init__(self):
        # Normalise the flag: "true" / "yes" / "1" (any case, surrounding
        # whitespace ignored) mean True; everything else, including None,
        # means False.
        if self.inherit is not None:
            cleaned = str(self.inherit).strip().lower()
            self.inherit = cleaned in ("true", "yes", "1")
        else:
            self.inherit = False


class NonGatedFFN(torch.nn.Module):
    """Two-layer bias-free MLP: ``down_proj(relu(up_proj(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = torch.nn.ReLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


class NasVLDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer whose token mixer and FFN are selected by a NAS config.

    The layer is assembled from ``nas_config`` (a ``ChildLayerVLConfig``, a
    dict of its fields, or any object whose ``vars()`` match its fields) and
    sized from ``parent_config.text_config``. When ``parent_model`` is given,
    the RMSNorm weights are always copied from the matching parent layer; the
    mixer / FFN weights are additionally initialised from the parent only if
    ``nas_config.inherit`` is true.

    The ``prune_qwen_attention_head*`` and ``init_student_ffn*`` helpers used
    for weight inheritance are not defined in this section and are expected
    to be provided elsewhere in the module.
    """

    def __init__(self, layer_idx: int, nas_config, parent_config, parent_model=None):
        """Build the layer.

        Args:
            layer_idx: Index of this layer; also selects the parent layer used
                for weight inheritance.
            nas_config: Layer description (see class docstring).
            parent_config: Parent model config exposing ``text_config``.
            parent_model: Optional parent model exposing
                ``model.language_model.layers``.

        Raises:
            Exception: If the attention or FFN type is not supported.
            ValueError: If weights are inherited and the requested
                intermediate size exceeds the parent's.
        """
        super().__init__()
        self.parent_config = parent_config
        self.parent_text_config = parent_config.text_config
        self.layer_idx = layer_idx

        if isinstance(nas_config, dict):
            nas_config = ChildLayerVLConfig(**nas_config)
        elif not isinstance(nas_config, ChildLayerVLConfig):
            nas_config = ChildLayerVLConfig(**vars(nas_config))
        self.nas_config = nas_config
        self.attention_type = nas_config.attention_type
        self.inherit = nas_config.inherit

        # Unset (None / 0) child sizes fall back to the parent's values.
        text_config = self.parent_text_config
        self.child_attn_heads = int(
            getattr(nas_config, "child_num_attention_heads", 0)
            or text_config.num_attention_heads
        )
        self.child_kv_heads = int(
            getattr(nas_config, "gqa_num_kv_heads", 0)
            or text_config.num_key_value_heads
        )
        self.child_inter_size = int(
            getattr(nas_config, "child_intermediate_size", 0)
            or text_config.intermediate_size
        )
        self.hidden_size = text_config.hidden_size

        # Sub-modules are created in a fixed order (mixer, FFN, norms).
        self.self_attn = self._build_attention(layer_idx, nas_config, parent_config, parent_model)
        self.mlp = self._build_ffn(layer_idx, nas_config, parent_model)

        norm_eps = text_config.rms_norm_eps
        if self.self_attn is not None:
            self.input_layernorm = Qwen3VLTextRMSNorm(self.hidden_size, eps=norm_eps)
            if parent_model is not None:
                self.input_layernorm.load_state_dict(
                    self._teacher_layer(parent_model, layer_idx).input_layernorm.state_dict()
                )
        else:
            self.input_layernorm = None

        if self.mlp is not None:
            self.post_attention_layernorm = Qwen3VLTextRMSNorm(self.hidden_size, eps=norm_eps)
            if parent_model is not None:
                self.post_attention_layernorm.load_state_dict(
                    self._teacher_layer(parent_model, layer_idx).post_attention_layernorm.state_dict()
                )
        else:
            self.post_attention_layernorm = None

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _teacher_layer(parent_model, layer_idx):
        """Return the parent decoder layer that corresponds to ``layer_idx``."""
        return parent_model.model.language_model.layers[layer_idx]

    def _build_softmax_attention(self, layer_idx, parent_model):
        """Create the Qwen3-VL attention module shared by FULL and SWA layers."""
        attn_config = copy.deepcopy(self.parent_text_config)
        attn_config.num_attention_heads = self.child_attn_heads
        attn_config.num_key_value_heads = self.child_kv_heads
        attn_config._attn_implementation = "sdpa"
        attn = Qwen3VLTextAttention(config=attn_config, layer_idx=layer_idx)

        if parent_model is not None and self.inherit:
            teacher_attn = self._teacher_layer(parent_model, layer_idx).self_attn
            same_heads = (
                self.child_attn_heads == self.parent_text_config.num_attention_heads
                and self.child_kv_heads == self.parent_text_config.num_key_value_heads
            )
            if same_heads:
                attn.load_state_dict(teacher_attn.state_dict(), strict=True)
            else:
                # Head counts differ, so a plain state-dict copy cannot work.
                prune_qwen_attention_head(
                    student_attn=attn,
                    teacher_attn=teacher_attn,
                    teacher_config=self.parent_text_config,
                    target_q_heads=self.child_attn_heads,
                    target_kv_heads=self.child_kv_heads,
                )
        return attn

    def _build_attention(self, layer_idx, nas_config, parent_config, parent_model):
        """Create the token mixer selected by ``nas_config.attention_type``.

        Returns ``None`` for the no-op type. Raises ``Exception`` for any type
        this layer does not implement.
        """
        attn_type = nas_config.attention_type
        inherit = parent_model is not None and self.inherit

        if attn_type == AttentionType.FULL:
            return self._build_softmax_attention(layer_idx, parent_model)

        if attn_type == AttentionType.SWA:
            self.sliding_window = int(
                getattr(nas_config, "sliding_window", 1024) or 1024
            )
            # Private copy of the parent config, used only for mask creation,
            # carrying this layer's window size.
            self._swa_mask_config = copy.deepcopy(parent_config)
            self._swa_mask_config.sliding_window = self.sliding_window
            if hasattr(self._swa_mask_config, "text_config"):
                self._swa_mask_config.text_config.sliding_window = self.sliding_window
            self._swa_mask_config._attn_implementation = "sdpa"
            if hasattr(self._swa_mask_config, "text_config"):
                self._swa_mask_config.text_config._attn_implementation = "sdpa"
            return self._build_softmax_attention(layer_idx, parent_model)

        if attn_type == AttentionType.LINEAR:
            attn = torch.nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            if inherit:
                prune_qwen_attention_head_linear(
                    student_attn=attn,
                    teacher_attn=self._teacher_layer(parent_model, layer_idx).self_attn,
                    teacher_config=parent_config.text_config,
                )
            return attn

        if attn_type == AttentionType.KDA:
            config = KDAConfig(hidden_size=self.hidden_size)
            config.expand_v = 1
            attn = KimiDeltaAttention(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                num_v_heads=config.num_v_heads,
                use_short_conv=config.use_short_conv,
                allow_neg_eigval=config.allow_neg_eigval,
                conv_size=config.conv_size,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx,
            )
            if inherit:
                prune_qwen_attention_head_kda(
                    student_attn=attn,
                    teacher_attn=self._teacher_layer(parent_model, layer_idx).self_attn,
                    teacher_config=parent_config.text_config,
                )
            return attn

        if attn_type == AttentionType.GDN:
            config = GatedDeltaNetConfig(hidden_size=self.hidden_size)
            attn = GatedDeltaNet(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                num_v_heads=config.num_v_heads,
                use_gate=config.use_gate,
                use_short_conv=config.use_short_conv,
                allow_neg_eigval=config.allow_neg_eigval,
                conv_size=config.conv_size,
                norm_eps=config.norm_eps,
                layer_idx=layer_idx,
            )
            if inherit:
                prune_qwen_attention_head_gdn(
                    student_attn=attn,
                    teacher_attn=self._teacher_layer(parent_model, layer_idx).self_attn,
                    teacher_config=parent_config.text_config,
                )
            return attn

        if attn_type == AttentionType.NSA:
            config = NSAConfig(hidden_size=self.hidden_size)
            attn = NativeSparseAttention(
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                head_dim=config.head_dim,
                qkv_bias=config.qkv_bias,
                block_size=config.block_size,
                block_counts=config.block_counts,
                window_size=config.window_size,
                rope_theta=config.rope_theta,
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
            if inherit:
                prune_qwen_attention_head_nsa(
                    student_attn=attn,
                    teacher_attn=self._teacher_layer(parent_model, layer_idx).self_attn,
                    teacher_config=parent_config.text_config,
                )
            return attn

        if attn_type == AttentionType.MLA:
            config = MLAConfig(hidden_size=self.hidden_size)
            attn = MultiheadLatentAttention(
                hidden_size=config.hidden_size,
                num_heads=config.num_heads,
                q_lora_rank=config.q_lora_rank,
                qk_rope_head_dim=config.qk_rope_head_dim,
                kv_lora_rank=config.kv_lora_rank,
                v_head_dim=config.v_head_dim,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_head_dim=config.qk_head_dim,
                window_size=config.window_size,
                rope_theta=config.rope_theta,
                max_position_embeddings=config.max_position_embeddings,
                rope_scaling=config.rope_scaling,
                layer_idx=layer_idx,
            )
            if inherit:
                prune_qwen_attention_head_mla(
                    student_attn=attn,
                    teacher_attn=self._teacher_layer(parent_model, layer_idx).self_attn,
                    teacher_config=parent_config.text_config,
                )
            return attn

        if attn_type == AttentionType.NOOP:
            return None

        raise Exception(f"Attention Type Not Define: {nas_config.attention_type}")

    def _build_ffn(self, layer_idx, nas_config, parent_model):
        """Create the feed-forward module selected by ``nas_config.ffn_type``.

        Returns ``None`` for the no-op type. Raises ``Exception`` for any type
        this layer does not implement, and ``ValueError`` when inheriting into
        an FFN wider than the parent's.
        """
        ffn_type = nas_config.ffn_type
        inherit = parent_model is not None and self.inherit

        if ffn_type == FFNType.FFN:
            mlp_config = copy.deepcopy(self.parent_text_config)
            mlp_config.intermediate_size = self.child_inter_size
            mlp = Qwen3VLTextMLP(mlp_config)
            if inherit:
                teacher_mlp = self._teacher_layer(parent_model, layer_idx).mlp
                teacher_inter_size = teacher_mlp.up_proj.weight.shape[0]
                if self.child_inter_size < teacher_inter_size:
                    init_student_ffn(mlp, teacher_mlp, self.child_inter_size)
                elif self.child_inter_size == teacher_inter_size:
                    mlp.load_state_dict(teacher_mlp.state_dict(), strict=True)
                else:
                    raise ValueError(
                        f"Layer {layer_idx}: Student intermediate size ({self.child_inter_size}) "
                        f"is larger than Teacher ({teacher_inter_size})."
                    )
            return mlp

        if ffn_type == FFNType.LINEAR:
            mlp = torch.nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            if inherit:
                init_student_ffn_linear(
                    mlp, self._teacher_layer(parent_model, layer_idx).mlp
                )
            return mlp

        if ffn_type == FFNType.NFFN:
            # No inheritance path exists for the non-gated FFN; it always
            # starts from its default initialisation.
            nffn_config = copy.deepcopy(self.parent_text_config)
            nffn_config.intermediate_size = self.child_inter_size
            return NonGatedFFN(nffn_config)

        if ffn_type == FFNType.NOOP:
            return None

        raise Exception(f"FFN Type Not Define: {nas_config.ffn_type}")

    # ------------------------------------------------------------------
    # Forward helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _first_output(outputs):
        """Return ``outputs[0]`` for tuple results, else ``outputs`` itself."""
        return outputs[0] if isinstance(outputs, tuple) else outputs

    @staticmethod
    def _resolve_cache_position(hidden_states, past_key_values, cache_position):
        """Return ``cache_position``, deriving it from the cache length if absent."""
        if cache_position is not None:
            return cache_position
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        return torch.arange(
            past_seen_tokens,
            past_seen_tokens + hidden_states.shape[1],
            device=hidden_states.device,
        )

    def _softmax_attention_mask(self, hidden_states, attention_mask, past_key_values, cache_position):
        """Build the mask handed to FULL / SWA attention.

        A 4D ``attention_mask`` is passed through unchanged. Any other
        supplied mask is expanded with the transformers mask builders. SWA
        layers additionally get a sliding-window mask built when no mask is
        available at this point.

        NOTE(review): a caller-supplied 4D mask is used as-is for SWA layers,
        so the window is only enforced if the caller already applied it --
        confirm this is intended.

        Returns:
            ``(mask_4d, cache_position)``; ``cache_position`` is the possibly
            derived value that must also be forwarded to the attention module.
        """
        is_swa = self.nas_config.attention_type == AttentionType.SWA
        mask_4d = None

        if attention_mask is not None:
            if attention_mask.ndim == 4:
                mask_4d = attention_mask
            else:
                cache_position = self._resolve_cache_position(
                    hidden_states, past_key_values, cache_position
                )
                if is_swa:
                    mask_4d = create_sliding_window_causal_mask(
                        config=self._swa_mask_config,
                        input_embeds=hidden_states,
                        attention_mask=attention_mask,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                    )
                else:
                    mask_4d = create_causal_mask(
                        input_embeds=hidden_states,
                        attention_mask=attention_mask,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                        config=self.parent_config,
                    )

        if is_swa and mask_4d is None:
            cache_position = self._resolve_cache_position(
                hidden_states, past_key_values, cache_position
            )
            mask_4d = create_sliding_window_causal_mask(
                config=self._swa_mask_config,
                input_embeds=hidden_states,
                attention_mask=None,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
        return mask_4d, cache_position

    def _forward_delta_mixer(self, hidden_states, mask_2d, past_key_values, use_cache, kwargs):
        """Run the KDA / GDN mixer and return the residual-added result."""
        residual = hidden_states

        fla_cache_proxy = None
        if use_cache and past_key_values is not None:
            fla_cache_proxy = FLACacheAdapter(past_key_values)

        if self.training:
            mode = "chunk"
        else:
            mode = "fused_recurrent" if use_cache else "chunk"

        q_len = hidden_states.shape[1]
        # NOTE(review): without caching, inputs of 64 tokens or fewer bypass
        # the mixer completely (identity). Presumably this avoids a kernel
        # path that is unsupported for short inputs -- confirm.
        if not (q_len > 64 or use_cache):
            return residual

        hidden_states = self.input_layernorm(hidden_states)
        outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=mask_2d,
            past_key_values=fla_cache_proxy,
            use_cache=use_cache,
            mode=mode,
            **kwargs,
        )
        return residual + self._first_output(outputs)

    def _forward_nsa(self, hidden_states, mask_2d, past_key_values, use_cache, kwargs):
        """Run the NSA mixer and return the residual-added result.

        The mixer is always called without its own cache. During cached
        inference the normalised hidden states of all previous steps are kept
        in ``past_key_values.fla_states`` and the mixer is re-run over the
        whole history; only the outputs for the new tokens are kept.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        nsa_kwargs = {k: v for k, v in kwargs.items() if k in ("cu_seqlens",)}

        if self.training:
            outputs = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=mask_2d,
                past_key_values=None,
                use_cache=False,
                **nsa_kwargs,
            )
            hidden_states = self._first_output(outputs)
        elif past_key_values is not None and use_cache:
            if not hasattr(past_key_values, "fla_states"):
                past_key_values.fla_states = {}
            state_key = f"nsa_hidden_{self.layer_idx}"
            nsa_state = past_key_values.fla_states.get(state_key, None)
            if nsa_state is not None:
                full_hidden = torch.cat([nsa_state, hidden_states], dim=1)
            else:
                full_hidden = hidden_states
            past_key_values.fla_states[state_key] = full_hidden.detach()

            full_mask = None
            if mask_2d is not None:
                # Pad only the columns the mask is actually missing. A mask
                # covering just the new tokens gets a prefix of ones for the
                # cached history; a mask that already spans the whole history
                # is used unchanged instead of being over-extended.
                missing = full_hidden.shape[1] - mask_2d.shape[1]
                if missing > 0:
                    prefix_mask = torch.ones(
                        mask_2d.shape[0],
                        missing,
                        dtype=mask_2d.dtype,
                        device=mask_2d.device,
                    )
                    full_mask = torch.cat([prefix_mask, mask_2d], dim=1)
                else:
                    full_mask = mask_2d

            outputs = self.self_attn(
                hidden_states=full_hidden,
                attention_mask=full_mask,
                past_key_values=None,
                use_cache=False,
                **nsa_kwargs,
            )
            full_output = self._first_output(outputs)
            hidden_states = full_output[:, -hidden_states.shape[1]:, :]
        else:
            outputs = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=mask_2d,
                past_key_values=None,
                use_cache=False,
            )
            hidden_states = self._first_output(outputs)

        # Defensive: unwrap once more in case the mixer nested its result.
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        return residual + hidden_states

    def _forward_mla(self, hidden_states, mask_2d, past_key_values, use_cache, kwargs):
        """Run the MLA mixer and return the residual-added result."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        fla_cache_proxy = None
        if past_key_values is not None:
            fla_cache_proxy = FLACacheAdapter(past_key_values)

        outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=mask_2d,
            past_key_values=fla_cache_proxy,
            use_cache=use_cache,
            **kwargs,
        )
        return residual + self._first_output(outputs)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[Cache]]:
        """Apply the configured mixer and FFN, each with a residual connection.

        Args:
            hidden_states: Layer input; indexed as ``(batch, seq, hidden)``.
            attention_mask: Optional mask. A 4D mask is handed unchanged to
                FULL / SWA attention; for the other mixers its
                ``[:, 0, -1, :]`` slice is used as a 2D mask. Masks of any
                other rank are used directly as the 2D mask and expanded for
                FULL / SWA.
            position_ids: Forwarded to FULL / SWA attention.
            past_key_values: Optional cache object. It is mutated in place by
                the mixers that use it.
            use_cache: Whether mixers should read / write the cache.
            cache_position: Forwarded to FULL / SWA attention; derived from
                the cache length when needed for mask creation and absent.
            position_embeddings: Forwarded to FULL / SWA attention.
            **kwargs: Extra arguments forwarded to the mixer (NSA only
                receives ``cu_seqlens``).

        Returns:
            ``(hidden_states, past_key_values)`` -- the second element is the
            same cache object that was passed in.

        Raises:
            Exception: If the configured attention or FFN type is unsupported.
        """
        residual = hidden_states
        present_key_values = past_key_values
        attn_type = self.nas_config.attention_type

        # 2D mask for the mixers that do not take a 4D mask.
        # NOTE(review): the slice of a 4D mask keeps that mask's value
        # convention (bool vs. additive float); the fla mixers presumably
        # expect a 1/0 padding mask -- confirm against the caller.
        mask_2d = None
        if attention_mask is not None:
            if attention_mask.ndim == 4:
                mask_2d = attention_mask[:, 0, -1, :]
            else:
                mask_2d = attention_mask

        if attn_type in (AttentionType.FULL, AttentionType.SWA):
            mask_4d, cache_position = self._softmax_attention_mask(
                hidden_states, attention_mask, past_key_values, cache_position
            )
            hidden_states = self.input_layernorm(hidden_states)
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=mask_4d,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = residual + hidden_states
        elif attn_type == AttentionType.LINEAR:
            hidden_states = self.input_layernorm(hidden_states)
            hidden_states = self.self_attn(hidden_states)
            hidden_states = residual + hidden_states
        elif attn_type == AttentionType.NOOP:
            hidden_states = residual
        elif attn_type in (AttentionType.KDA, AttentionType.GDN):
            hidden_states = self._forward_delta_mixer(
                hidden_states, mask_2d, past_key_values, use_cache, kwargs
            )
        elif attn_type == AttentionType.NSA:
            hidden_states = self._forward_nsa(
                hidden_states, mask_2d, past_key_values, use_cache, kwargs
            )
        elif attn_type == AttentionType.MLA:
            hidden_states = self._forward_mla(
                hidden_states, mask_2d, past_key_values, use_cache, kwargs
            )
        else:
            raise Exception(f"Attention Type Not Define: {self.nas_config.attention_type}")

        ffn_type = self.nas_config.ffn_type
        if ffn_type in (FFNType.FFN, FFNType.NFFN, FFNType.LINEAR):
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states
        elif ffn_type == FFNType.NOOP:
            pass
        else:
            raise Exception(f"FFN Type Not Define: {self.nas_config.ffn_type}")

        return hidden_states, present_key_values