"""Embedding wrapper around ``Qwen3VLModel`` that pools the last attended token."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.models.auto import AutoProcessor
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel, Qwen3VLPreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from .configuration_qwen3_vl_embedding import Qwen3VLEmbeddingConfig


@dataclass
class Qwen3VLForEmbeddingOutput(ModelOutput):
    """Output container returned by :meth:`Qwen3VLForEmbedding.forward`."""

    # Hidden states taken from the wrapped model's ``last_hidden_state``.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # The attention mask that was passed to ``forward`` (echoed back unchanged).
    attention_mask: Optional[torch.Tensor] = None
    # Pooled (and optionally L2-normalized) embeddings; ``None`` when no
    # attention mask was supplied, because pooling needs the mask.
    embeddings: Optional[torch.FloatTensor] = None


class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
    """``Qwen3VLModel`` with last-token pooling to produce one embedding per input row."""

    config_class = Qwen3VLEmbeddingConfig
    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False

    def __init__(self, config: Qwen3VLEmbeddingConfig):
        super().__init__(config)
        self.model = Qwen3VLModel(config)
        # Lazily populated by ``_get_or_load_processor``.
        self._processor = None
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped model."""
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Forward ``decoder`` to the wrapped model's ``set_decoder``."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder exposed by the wrapped model."""
        return self.model.get_decoder()

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Delegate video feature extraction to the wrapped model."""
        return self.model.get_video_features(pixel_values_videos, video_grid_thw)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Delegate image feature extraction to the wrapped model."""
        return self.model.get_image_features(pixel_values, image_grid_thw)

    @property
    def language_model(self):
        """The ``language_model`` attribute of the wrapped model."""
        return self.model.language_model

    @property
    def visual(self):
        """The ``visual`` attribute of the wrapped model."""
        return self.model.visual

    @staticmethod
    def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select, for every row, the hidden state at the last attended position.

        Works for left- and right-padded batches alike because the position is
        found by scanning the mask from the end. A row whose mask contains no
        non-zero entry falls back to the final sequence position.

        Args:
            hidden_state: Tensor indexed as ``(batch, seq, ...)``.
            attention_mask: 2-D mask indexed as ``(batch, seq)``; non-zero
                (or ``True``) marks attended tokens.

        Returns:
            ``hidden_state[i, last_i]`` stacked over the batch dimension.
        """
        # Cast to an integer dtype: ``argmax`` is not implemented for boolean
        # tensors on every backend. Keeping the mask on the hidden-state
        # device avoids mixed-device index tensors.
        mask = attention_mask.to(device=hidden_state.device, dtype=torch.long)
        # In the flipped mask the first maximal entry is the last attended
        # token; ``argmax`` returns the first maximal index.
        offset_from_end = mask.flip(dims=[1]).argmax(dim=1)
        col = mask.shape[1] - offset_from_end - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        normalize: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLForEmbeddingOutput]:
        """Run the wrapped model and pool one embedding per row.

        Args:
            attention_mask: Required for pooling. When ``None`` the returned
                ``embeddings`` field is ``None``.
            logits_to_keep: Accepted for signature compatibility; not used by
                this method.
            normalize: Whether to L2-normalize the pooled embeddings. ``None``
                defers to ``config.normalize_embeddings``.
            **kwargs: Forwarded to the wrapped model.

        All remaining arguments are passed through to the wrapped model
        unchanged.

        Returns:
            A ``Qwen3VLForEmbeddingOutput`` holding the last hidden state, the
            supplied attention mask and the pooled embeddings.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        embeddings = None
        if attention_mask is not None:
            embeddings = self._pooling_last(outputs.last_hidden_state, attention_mask)
            # Normalization only makes sense once something was pooled; doing
            # it outside this branch would call ``F.normalize(None)``.
            should_normalize = normalize if normalize is not None else self.config.normalize_embeddings
            if should_normalize:
                embeddings = F.normalize(embeddings, p=2, dim=-1)

        return Qwen3VLForEmbeddingOutput(
            last_hidden_state=outputs.last_hidden_state,
            attention_mask=attention_mask,
            embeddings=embeddings,
        )

    def _get_or_load_processor(self, processor=None):
        """Return ``processor`` if given, else a cached processor loaded from the config path.

        Raises:
            ValueError: If no processor is supplied, none is cached, and
                ``config._name_or_path`` is missing or empty.
        """
        if processor is not None:
            return processor
        if self._processor is None:
            name_or_path = getattr(self.config, "_name_or_path", None)
            if not name_or_path:
                raise ValueError("A processor must be provided because config._name_or_path is unavailable.")
            # NOTE(security): ``trust_remote_code=True`` executes Python code
            # shipped with the checkpoint; only use with trusted sources.
            self._processor = AutoProcessor.from_pretrained(name_or_path, trust_remote_code=True)
        return self._processor

    @torch.no_grad()
    def encode(
        self,
        inputs: List[Dict[str, Any]],
        *,
        processor: Optional[Any] = None,
        normalize: Optional[bool] = None,
        return_tensor: bool = True,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Preprocess ``inputs`` with the processor and embed them without gradients.

        Args:
            inputs: Items handed to ``processor.prepare_for_embedding``.
            processor: Processor to use; when ``None`` one is loaded and cached
                via ``_get_or_load_processor``.
            normalize: Forwarded to ``forward``.
            return_tensor: When true return only the embeddings tensor,
                otherwise the full ``Qwen3VLForEmbeddingOutput``.
            device: Device the prepared batch is moved to; defaults to the
                model's device.

        Raises:
            ValueError: Propagated from ``_get_or_load_processor`` when no
                processor can be resolved.
        """
        processor = self._get_or_load_processor(processor=processor)
        cfg = self.config
        batch = processor.prepare_for_embedding(
            inputs,
            max_length=cfg.max_length,
            min_pixels=cfg.min_pixels,
            max_pixels=cfg.max_pixels,
            total_pixels=cfg.max_total_pixels,
            fps=cfg.fps,
            num_frames=cfg.num_frames,
            max_frames=cfg.max_frames,
            default_instruction=cfg.default_instruction,
        )
        target_device = torch.device(device) if device is not None else self.device
        # Only move entries that can be moved; non-tensor values (lists,
        # ``None``, ints) are passed through instead of raising AttributeError.
        batch = {
            key: value.to(target_device) if hasattr(value, "to") else value
            for key, value in batch.items()
        }
        outputs = self(**batch, normalize=normalize)
        return outputs.embeddings if return_tensor else outputs