# NAS "child" vision-language causal LM assembled from Qwen3-VL building blocks.

import torch
import torch.nn as nn
from typing import List, Optional, Any
from dataclasses import dataclass

from transformers.models.qwen3_vl import Qwen3VLPreTrainedModel
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLModel,
    Qwen3VLTextModel,
    Qwen3VLVisionModel,
    Qwen3VLTextRMSNorm,
    Qwen3VLTextRotaryEmbedding,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import DynamicCache
from transformers.utils import is_torchdynamo_compiling
from transformers.generation import GenerationMixin

from .config_nas_vl import NasChildVLConfig
from .nas_vl_layer import NasVLDecoderLayer, ChildLayerVLConfig


@dataclass
class Qwen3VLCausalLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output that additionally carries the RoPE position deltas."""

    # Filled by ``NasChildVLModelForCausalLM.forward`` with the value it
    # caches from ``get_rope_index``.
    rope_deltas: Optional[torch.Tensor] = None


class NasChildVLModelForCausalLM(Qwen3VLPreTrainedModel, GenerationMixin):
    """Vision-language causal LM whose decoder layers are ``NasVLDecoderLayer``s.

    The vision tower, token embedding, final norm, rotary embedding and LM head
    are built from the Qwen3-VL config; each decoder layer is built from the
    matching entry of ``config.nas_layer_config``.
    """

    config_class = NasChildVLConfig
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    # Reuse the upstream Qwen3-VL helpers as methods of this class. They are
    # called with this model as ``self``.
    get_image_features = Qwen3VLModel.get_image_features
    get_video_features = Qwen3VLModel.get_video_features
    get_placeholder_mask = Qwen3VLModel.get_placeholder_mask
    get_rope_index = Qwen3VLModel.get_rope_index
    _deepstack_process = Qwen3VLTextModel._deepstack_process

    def __init__(self, config: NasChildVLConfig):
        """Build the vision tower, embeddings, NAS decoder stack and LM head.

        Args:
            config: Model config. Must expose ``text_config``,
                ``vision_config``, ``tie_word_embeddings`` and a
                ``nas_layer_config`` sequence with one entry (a dict or a
                ``ChildLayerVLConfig``-like object) per hidden layer.
        """
        super().__init__(config)
        self.parent_config = config
        self.is_vl = True
        # Cached by forward() during prefill and reused on later decode steps.
        self.rope_deltas = None

        text_config = config.text_config
        self.parent_hidden_size = text_config.hidden_size
        self.child_hidden_size = self.parent_hidden_size
        self.vocab_size = text_config.vocab_size

        self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
        self.embed_tokens = nn.Embedding(
            text_config.vocab_size, text_config.hidden_size
        )

        # With the sizes equal (as assigned above) the projections are
        # identities; the Linear branch only matters if the sizes ever differ.
        if self.child_hidden_size != self.parent_hidden_size:
            self.input_proj = nn.Linear(
                self.parent_hidden_size, self.child_hidden_size, bias=False
            )
            self.output_proj = nn.Linear(
                self.child_hidden_size, self.parent_hidden_size, bias=False
            )
        else:
            self.input_proj = nn.Identity()
            self.output_proj = nn.Identity()

        # Derive the per-layer attention kind and one shared sliding window
        # (taken from the first "swa" layer) from the NAS layer configs.
        layer_types = []
        global_sliding_window = None
        for i in range(text_config.num_hidden_layers):
            cfg = config.nas_layer_config[i]
            if isinstance(cfg, dict):
                cfg = ChildLayerVLConfig(**cfg)
            # Accepts both plain strings and enum-like reprs ("Kind.SWA").
            attn_type = str(cfg.attention_type).split('.')[-1].lower()
            if attn_type == "swa":
                layer_types.append("sliding_attention")
                if global_sliding_window is None:
                    sw_val = getattr(cfg, "sliding_window", 1024)
                    global_sliding_window = int(sw_val) if sw_val else 1024
            else:
                layer_types.append("full_attention")

        # Publish the derived values on the config before the layers are built.
        if hasattr(self.config, "text_config"):
            self.config.text_config.layer_types = layer_types
            if global_sliding_window is not None:
                self.config.text_config.sliding_window = global_sliding_window
        else:
            self.config.layer_types = layer_types
            if global_sliding_window is not None:
                self.config.sliding_window = global_sliding_window

        self.layers = nn.ModuleList()
        for i in range(text_config.num_hidden_layers):
            cfg = config.nas_layer_config[i]
            self.layers.append(
                NasVLDecoderLayer(
                    layer_idx=i,
                    nas_config=cfg,
                    parent_config=config,
                    parent_model=None,
                )
            )

        self.norm = Qwen3VLTextRMSNorm(
            self.child_hidden_size, eps=text_config.rms_norm_eps
        )
        self.lm_head = nn.Linear(
            self.parent_hidden_size, self.vocab_size, bias=False
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=text_config)
        # NOTE(review): this stays False even when "sliding_attention" layers
        # were detected above -- confirm that is intentional.
        self.has_sliding_layers = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        """Assemble the forward kwargs for one generation step.

        Delegates to the parent implementation, then adds the visual inputs.
        ``position_ids`` is always cleared so that ``forward`` recomputes it,
        and pixel inputs are dropped once ``cache_position`` shows that the
        step is not the first one.

        Returns:
            The dict of model inputs for this step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )
        model_inputs.update({
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
        })
        # forward() derives position ids itself (using the cached rope deltas).
        model_inputs["position_ids"] = None

        # ``cache_position`` defaults to None; without a position we cannot
        # tell the step apart from prefill, so the pixel inputs are kept.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
        return model_inputs

    def _get_image_nums_and_video_nums(self, input_ids, inputs_embeds=None):
        """Count images and videos per batch row.

        A position counts as an image (video) when it holds the image (video)
        token and the previous position (after a roll along dim 1) holds the
        vision-start token.

        Args:
            input_ids: Token ids; used when ``inputs_embeds`` is None.
            inputs_embeds: If given, tokens are identified by comparing against
                the embedding of each special token and reading the first
                feature of the comparison.

        Returns:
            ``(image_nums, video_nums)``, each summed over dim 1.
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if inputs_embeds is not None:
            dev = inputs_embeds.device

            def _embed(tid):
                return self.embed_tokens(
                    torch.tensor(tid, dtype=torch.long, device=dev)
                )

            vision_start_mask = (
                inputs_embeds == _embed(vision_start_token_id)
            )[..., 0]
            image_mask = (inputs_embeds == _embed(image_token_id))[..., 0]
            video_mask = (inputs_embeds == _embed(video_token_id))[..., 0]
        else:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id

        # Shift by one so the mask marks the token right after a vision start.
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size=1,
        is_encoder_decoder=False,
        input_ids=None,
        **model_kwargs,
    ):
        """Repeat every input ``expand_size`` times along the batch dimension.

        Visual tensors are not laid out one row per sample, so they are split
        per sample (using the per-row image/video counts) and each sample's
        chunk is repeated; all other tensors use ``repeat_interleave``.

        Returns:
            ``(input_ids, model_kwargs)`` after expansion.

        Raises:
            ValueError: If ``is_encoder_decoder`` is set but
                ``encoder_outputs`` is missing.
        """
        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = [
            "pixel_values",
            "image_grid_thw",
            "pixel_values_videos",
            "video_grid_thw",
            "second_per_grid_ts",
        ]

        def _repeat_interleave_samples(x, lengths, repeat_times):
            # Split into per-sample chunks, tile each chunk along dim 0.
            samples = torch.split(x, lengths)
            repeat_args = [repeat_times] + [1] * (x.dim() - 1)
            return torch.cat([s.repeat(*repeat_args) for s in samples], dim=0)

        def _expand_visual(d):
            image_grid_thw = model_kwargs.get("image_grid_thw")
            video_grid_thw = model_kwargs.get("video_grid_thw")
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds")
            )

            for key in list(d):
                if d[key] is None:
                    continue
                if key == "pixel_values":
                    # Rows per sample = sum of prod(grid) over its images.
                    lens = [
                        torch.prod(s, dim=1).sum()
                        for s in torch.split(image_grid_thw, list(image_nums))
                    ]
                    d[key] = _repeat_interleave_samples(
                        d[key], lens, expand_size
                    )
                elif key == "image_grid_thw":
                    d[key] = _repeat_interleave_samples(
                        d[key], list(image_nums), expand_size
                    )
                elif key == "pixel_values_videos":
                    lens = [
                        torch.prod(s, dim=1).sum()
                        for s in torch.split(video_grid_thw, list(video_nums))
                    ]
                    d[key] = _repeat_interleave_samples(
                        d[key], lens, expand_size
                    )
                elif key in ("video_grid_thw", "second_per_grid_ts"):
                    d[key] = _repeat_interleave_samples(
                        d[key], list(video_nums), expand_size
                    )
            return d

        def _expand_general(d):
            for key in d:
                if (
                    key != "cache_position"
                    and d[key] is not None
                    and isinstance(d[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    d[key] = d[key].repeat_interleave(expand_size, dim=0)
            return d

        # Visual expansion must run first: it counts images/videos from the
        # not-yet-expanded input_ids.
        model_kwargs = _expand_visual(model_kwargs)
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        model_kwargs = _expand_general(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("encoder_outputs required for encoder-decoder")
            model_kwargs["encoder_outputs"] = _expand_general(
                model_kwargs["encoder_outputs"]
            )
        return input_ids, model_kwargs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        video_grid_thw: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the multimodal decoder and, if ``labels`` is given, the LM loss.

        Steps: embed tokens, scatter image/video features into the placeholder
        positions, build the deepstack visual embeddings, resolve
        ``cache_position`` and ``position_ids``, run every decoder layer
        (adding the deepstack embedding for layer ``i`` when one exists), then
        apply the final norm, output projection and LM head.

        Args:
            input_ids: Token ids. Required when ``inputs_embeds`` is None.
            attention_mask: Mask forwarded unchanged to every layer. May be a
                tensor or a dict holding a ``"full_attention"`` entry.
            position_ids: Optional precomputed position ids; computed here
                when None.
            past_key_values: Cache object; a ``DynamicCache`` is created when
                ``use_cache`` is truthy and none is passed.
            inputs_embeds: Optional precomputed token embeddings.
            labels: Optional targets; the loss uses next-token shifting.
            use_cache: Whether to use / create a KV cache.
            output_hidden_states: Whether to collect hidden states.
            return_dict: Resolved against the config but currently unused.
            pixel_values, pixel_values_videos: Visual inputs.
            image_grid_thw, video_grid_thw: Grid descriptors for the visual
                inputs.
            cache_position: Positions of the current tokens in the cache.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            ``Qwen3VLCausalLMOutputWithPast`` with loss, logits, cache, hidden
            states and the cached rope deltas.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        # NOTE(review): resolved but never consulted -- the structured output
        # is returned regardless. Kept as-is to preserve the return type.
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "You must specify at least one of input_ids or "
                    "inputs_embeds"
                )
            inputs_embeds = self.embed_tokens(input_ids)

        image_mask = video_mask = None
        deepstack_image_embeds = deepstack_video_embeds = None

        # --- scatter visual features into the placeholder positions --------
        if pixel_values is not None and self.visual is not None:
            image_embeds, deepstack_image_embeds = self.get_image_features(
                pixel_values, image_grid_thw
            )
            image_embeds = torch.cat(image_embeds, dim=0).to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            image_mask, _ = self.get_placeholder_mask(
                input_ids,
                inputs_embeds=inputs_embeds,
                image_features=image_embeds,
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                image_mask, image_embeds
            )

        if pixel_values_videos is not None and self.visual is not None:
            video_embeds, deepstack_video_embeds = self.get_video_features(
                pixel_values_videos, video_grid_thw
            )
            video_embeds = torch.cat(video_embeds, dim=0).to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            _, video_mask = self.get_placeholder_mask(
                input_ids,
                inputs_embeds=inputs_embeds,
                video_features=video_embeds,
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                video_mask, video_embeds
            )

        # --- merge image/video deepstack embeddings ------------------------
        visual_pos_masks = None
        deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            # Masks were expanded over the feature dim; keep one value per
            # position.
            image_mask = image_mask[..., 0]
            video_mask = video_mask[..., 0]
            visual_pos_masks = image_mask | video_mask
            deepstack_visual_embeds = []
            # Index of image / video rows within the joint visual positions.
            img_joint = image_mask[visual_pos_masks]
            vid_joint = video_mask[visual_pos_masks]
            for img_e, vid_e in zip(
                deepstack_image_embeds, deepstack_video_embeds
            ):
                joint = img_e.new_zeros(
                    visual_pos_masks.sum(), img_e.shape[-1]
                ).to(img_e.device)
                joint[img_joint, :] = img_e
                joint[vid_joint, :] = vid_e
                deepstack_visual_embeds.append(joint)
        elif image_mask is not None:
            image_mask = image_mask[..., 0]
            visual_pos_masks = image_mask
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            video_mask = video_mask[..., 0]
            visual_pos_masks = video_mask
            deepstack_visual_embeds = deepstack_video_embeds

        # --- cache bookkeeping ---------------------------------------------
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen = (
                past_key_values.get_seq_length()
                if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen,
                past_seen + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # The layers receive ``attention_mask`` untouched, but the bookkeeping
        # below needs an actual tensor; a dict contributes its
        # "full_attention" entry (which may be missing, i.e. None).
        attn_mask_tensor = (
            attention_mask if not isinstance(attention_mask, dict)
            else attention_mask.get("full_attention")
        )

        # A single-token step that claims position 0 while the mask covers
        # more than one token: take the position from the mask length instead.
        current_seq_len = inputs_embeds.shape[1]
        if (
            current_seq_len == 1
            and cache_position[0] == 0
            and attn_mask_tensor is not None
        ):
            real_past_seen = attn_mask_tensor.shape[-1] - 1
            if real_past_seen > 0:
                cache_position = torch.tensor(
                    [real_past_seen], device=inputs_embeds.device
                )

        # --- position ids --------------------------------------------------
        if position_ids is None:
            if attn_mask_tensor is not None and attn_mask_tensor.ndim == 4:
                # Reduce a 4-D mask to 2-D via the diagonal of its first
                # dim-1 slice.
                attn_mask_tensor = torch.diagonal(
                    attn_mask_tensor[:, 0], dim1=1, dim2=2
                )
                if attn_mask_tensor.dtype.is_floating_point:
                    # Additive float mask (0 / dtype-min) -> 1 / 0 ints.
                    attn_mask_tensor = (
                        attn_mask_tensor
                        / torch.finfo(attn_mask_tensor.dtype).min
                    )
                    attn_mask_tensor = (1.0 - attn_mask_tensor).int()

            is_real_prefill = (
                (input_ids is not None and input_ids.shape[1] > 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] > 1)
            )
            prefill_compiled = is_torchdynamo_compiling() and is_real_prefill
            prefill_noncompiled = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (
                    past_key_values is None
                    or past_key_values.get_seq_length() == 0
                )
            )
            should_calc_rope = (
                (prefill_compiled or prefill_noncompiled)
                or self.rope_deltas is None
            )
            # A single-token step with deltas already cached is a decode step
            # even if the checks above read it as prefill.
            if (
                should_calc_rope
                and not is_real_prefill
                and self.rope_deltas is not None
            ):
                should_calc_rope = False

            if should_calc_rope:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    attention_mask=attn_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size = inputs_embeds.shape[0]
                seq_length = inputs_embeds.shape[1]
                delta = (
                    (cache_position[0] + self.rope_deltas).to(
                        inputs_embeds.device
                    )
                    if cache_position is not None else 0
                )
                position_ids = torch.arange(
                    seq_length, device=inputs_embeds.device
                ).view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    # Match the (possibly expanded) batch size.
                    delta = delta.repeat_interleave(
                        batch_size // delta.shape[0], dim=0
                    )
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        # Split into the ids passed to the layers and the ids fed to the
        # rotary embedding. A leading dim of 4 carries text ids in row 0 and
        # the rotary ids in rows 1..3.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            rope_position_ids = position_ids[1:]
        elif position_ids.ndim == 3:
            text_position_ids = position_ids[0]
            rope_position_ids = position_ids
        else:
            text_position_ids = position_ids
            rope_position_ids = position_ids

        rotary_emb = self.rotary_emb(inputs_embeds, rope_position_ids)

        # --- decoder stack -------------------------------------------------
        hidden_states = self.input_proj(inputs_embeds)
        all_hidden_states = () if output_hidden_states else None
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        for i, layer in enumerate(self.layers):
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=text_position_ids,
                position_embeddings=rotary_emb,
                use_cache=use_cache,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = (
                layer_outputs[0]
                if isinstance(layer_outputs, tuple) else layer_outputs
            )
            # Only the first len(deepstack_visual_embeds) layers get a
            # deepstack embedding applied.
            if (
                deepstack_visual_embeds is not None
                and i < len(deepstack_visual_embeds)
            ):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[i],
                )
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        hidden_states = self.norm(hidden_states)
        hidden_states = self.output_proj(hidden_states)
        logits = self.lm_head(hidden_states)

        # --- loss ----------------------------------------------------------
        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on another device than the LM head output.
            shift_labels = labels[..., 1:].contiguous().to(
                shift_logits.device
            )
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
            )

        return Qwen3VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            rope_deltas=self.rope_deltas,
        )