Qwen3-VL-Embedding-2B / processing_qwen3_vl_embedding.py
whybe-choi's picture
feat: add trust_remote_code support via AutoModel integration
e4fcc36
Raw
History Blame Contribute Delete
6.77 kB
import logging
import unicodedata
from typing import Any, Dict, List, Optional, Union
import numpy as np
from PIL import Image
from qwen_vl_utils.vision_process import process_vision_info
from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor
# Module-level logger; prepare_for_embedding reports vision-preprocessing failures through it.
logger = logging.getLogger(__name__)
def sample_frames(
    frames: List[Union[str, "Image.Image"]],
    num_segments: int,
    max_segments: int,
) -> List[Union[str, "Image.Image"]]:
    """Uniformly sample frames from a pre-extracted frame sequence.

    Selects ``min(num_segments, max_segments)`` indices evenly spaced over
    ``[0, len(frames) - 1]`` (converted to integers with ``np.linspace(...,
    dtype=int)``) and returns the corresponding frames in order. When more
    frames are requested than exist, indices repeat and frames are duplicated.

    Args:
        frames: Frames given as path/URL strings or PIL images.
        num_segments: Number of frames to sample.
        max_segments: Upper bound on the number of returned frames. The cap is
            applied *before* sampling so the result still spans the whole
            sequence rather than only its beginning.

    Returns:
        The sampled frames, or an empty list when ``frames`` is empty or the
        requested count is not positive.
    """
    count = min(num_segments, max_segments)
    if not frames or count <= 0:
        return []
    # linspace includes both endpoints (for count >= 2), so the first and last
    # frames are always part of the sample; every index is within range.
    indices = np.linspace(0, len(frames) - 1, count, dtype=int)
    return [frames[i] for i in indices.tolist()]
class Qwen3VLEmbeddingProcessor(Qwen3VLProcessor):
    """Qwen3-VL processor that builds and tokenizes embedding-style chat inputs.

    Each input item (text / image / video plus an optional instruction) is
    turned into a two-turn conversation: a system turn carrying the
    instruction and a user turn carrying the content.
    """

    # Instruction used for the system turn when an item supplies none.
    default_instruction = "Represent the user's input."

    def format_model_input(
        self,
        text: Optional[str] = None,
        image: Optional[Union[str, Image.Image]] = None,
        video: Optional[Union[str, List[Union[str, Image.Image]]]] = None,
        instruction: Optional[str] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        total_pixels: Optional[int] = None,
        fps: Optional[float] = None,
        num_frames: Optional[int] = None,
        max_frames: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Build the chat conversation for a single embedding input.

        Args:
            text: Optional text content.
            image: Optional image, either a PIL image or a string. Strings that
                do not start with ``http``, ``oss``, ``file://`` or
                ``data:image`` are treated as local paths and get a
                ``file://`` prefix.
            video: Optional video, either a path/URL string or a list of frames
                (path/URL strings or PIL images).
            instruction: Task instruction for the system turn. It is stripped
                and gets a trailing ``.`` unless it already ends with a Unicode
                punctuation character; empty values fall back to
                ``self.default_instruction``.
            min_pixels: Added to the image entry when not ``None``.
            max_pixels: Added to the image entry when not ``None``.
            total_pixels: Added to the video entry when not ``None``.
            fps: Added to the video entry for string videos when not ``None``.
            num_frames: For frame-list videos, the number of frames to sample.
            max_frames: For frame-list videos, caps the number of sampled
                frames; for string videos it is added to the video entry.

        Returns:
            ``[system_turn, user_turn]``. The user content holds the video (if
            any), then the image, then the text; when no content is given it
            holds a single ``"NULL"`` text entry.

        Raises:
            TypeError: If ``image`` or ``video`` has an unsupported type.
        """
        if instruction:
            instruction = instruction.strip()
            # Terminate with a period unless the instruction already ends in punctuation.
            if instruction and not unicodedata.category(instruction[-1]).startswith("P"):
                instruction = f"{instruction}."
        content: List[Dict[str, Any]] = []
        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": instruction or self.default_instruction}],
            },
            {"role": "user", "content": content},
        ]
        if not text and not image and not video:
            # Placeholder so an item without content still produces a prompt.
            content.append({"type": "text", "text": "NULL"})
            return conversation

        if video:
            video_content = None
            video_kwargs: Dict[str, Any] = {"total_pixels": total_pixels}
            if isinstance(video, list):
                video_content = video
                if num_frames is not None or max_frames is not None:
                    video_content = sample_frames(
                        video_content,
                        num_frames or len(video_content),
                        max_frames or len(video_content),
                    )
                video_content = [
                    (
                        "file://" + frame
                        if isinstance(frame, str)
                        and not frame.startswith(("http://", "https://", "file://"))
                        else frame
                    )
                    for frame in video_content
                ]
            elif isinstance(video, str):
                # Accept paths that are already file:// URIs without prefixing them twice.
                video_content = (
                    video
                    if video.startswith(("http://", "https://", "file://"))
                    else f"file://{video}"
                )
                # Keep the caller's total_pixels budget (frame-list videos forward
                # it too) and add the file-video sampling options.
                video_kwargs.update(fps=fps, max_frames=max_frames)
            else:
                raise TypeError(f"Unrecognized video type: {type(video)}")
            if video_content:
                # Only forward options that were actually set.
                content.append(
                    {
                        "type": "video",
                        "video": video_content,
                        **{k: v for k, v in video_kwargs.items() if v is not None},
                    }
                )

        if image:
            if isinstance(image, Image.Image):
                image_content: Union[str, Image.Image] = image
            elif isinstance(image, str):
                # URLs, file:// URIs and base64 data URIs are passed through
                # unchanged; anything else is treated as a local path.
                image_content = (
                    image
                    if image.startswith(("http", "oss", "file://", "data:image"))
                    else f"file://{image}"
                )
            else:
                raise TypeError(f"Unrecognized image type: {type(image)}")
            content.append(
                {
                    "type": "image",
                    "image": image_content,
                    **({} if min_pixels is None else {"min_pixels": min_pixels}),
                    **({} if max_pixels is None else {"max_pixels": max_pixels}),
                }
            )

        if text:
            content.append({"type": "text", "text": text})
        return conversation

    def prepare_for_embedding(
        self,
        inputs: List[Dict[str, Any]],
        *,
        max_length: int,
        min_pixels: int,
        max_pixels: int,
        total_pixels: int,
        fps: float,
        num_frames: int,
        max_frames: int,
        default_instruction: Optional[str] = None,
        return_tensors: str = "pt",
    ):
        """Format, preprocess and tokenize a batch of embedding inputs.

        Args:
            inputs: One dict per item. Reads the optional keys ``"text"``,
                ``"image"``, ``"video"``, ``"instruction"``, and per-item
                overrides ``"fps"``, ``"num_frames"`` and ``"max_frames"``.
            max_length: Truncation length passed to the processor call.
            min_pixels: Batch-wide value forwarded to ``format_model_input``.
            max_pixels: Batch-wide value forwarded to ``format_model_input``.
            total_pixels: Batch-wide value forwarded to ``format_model_input``.
            fps: Default when an item has no ``"fps"`` key.
            num_frames: Default when an item has no ``"num_frames"`` key.
            max_frames: Default when an item has no ``"max_frames"`` key.
            default_instruction: If given, temporarily replaces
                ``self.default_instruction`` for this call only.
            return_tensors: Tensor type forwarded to the processor call.

        Returns:
            The result of ``self(...)`` (the parent processor's call) for the
            whole batch, padded and truncated to ``max_length``.

        If vision preprocessing raises, the error is logged and every item is
        replaced by a text-only ``"NULL"`` conversation. The batch keeps one row
        per input, so outputs stay aligned with ``inputs``.
        """
        original_default_instruction = self.default_instruction
        if default_instruction is not None:
            self.default_instruction = default_instruction
        # The override is stored on the instance, so restore it on every exit
        # path, including errors raised while formatting or templating.
        try:
            conversations = [
                self.format_model_input(
                    text=item.get("text"),
                    image=item.get("image"),
                    video=item.get("video"),
                    instruction=item.get("instruction"),
                    min_pixels=min_pixels,
                    max_pixels=max_pixels,
                    total_pixels=total_pixels,
                    fps=item.get("fps", fps),
                    num_frames=item.get("num_frames", num_frames),
                    max_frames=item.get("max_frames", max_frames),
                )
                for item in inputs
            ]
            text = self.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
            try:
                # NOTE(review): 16 is assumed to match the vision encoder's patch size — confirm.
                images, video_inputs, video_kwargs = process_vision_info(
                    conversations,
                    image_patch_size=16,
                    return_video_metadata=True,
                    return_video_kwargs=True,
                )
            except Exception as exc:
                logger.error("Error in processing vision info: %s", exc)
                # One placeholder conversation per input keeps the batch size unchanged.
                text = self.apply_chat_template(
                    [
                        [{"role": "user", "content": [{"type": "text", "text": "NULL"}]}]
                        for _ in conversations
                    ],
                    add_generation_prompt=True,
                    tokenize=False,
                )
                images = None
                video_inputs = None
                video_kwargs = {"do_sample_frames": False}
            if video_inputs:
                # Each entry is a (video, metadata) pair; split them into two lists.
                videos, video_metadata = (list(x) for x in zip(*video_inputs))
            else:
                videos, video_metadata = None, None
            # do_resize=False: vision inputs are presumably already resized by
            # process_vision_info, so the processor should not resize again.
            return self(
                text=text,
                images=images,
                videos=videos,
                video_metadata=video_metadata,
                truncation=True,
                max_length=max_length,
                padding=True,
                do_resize=False,
                return_tensors=return_tensors,
                **video_kwargs,
            )
        finally:
            self.default_instruction = original_default_instruction