"""MiniCPM-V (Qwen2 backbone) model wrapper enriched with YOLO object-detection context.

Images found in chat messages are run through a YOLO detector; the detected
objects and their coarse positions are verbalized into the prompt alongside
the image placeholder before generation.
"""
import json
import logging
import math
from copy import deepcopy
from threading import Thread
from typing import List, Optional

import torch
import torchvision
from PIL import Image
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    Qwen2ForCausalLM,
    Qwen2PreTrainedModel,
    TextIteratorStreamer,
)
from ultralytics import YOLO

from .configuration_minicpm import MiniCPMVConfig
from .constants import ADE20K_847, HOUSE_OBJECTS, IMAGENET_CLASSES
from .modeling_navit_siglip import SiglipVisionTransformer
from .resampler import Resampler

logger = logging.getLogger(__name__)

# Default system prompt used by `MiniCPMV.chat` when the caller does not pass
# one. `{num_last_image}` is replaced with the index of the last image in the
# conversation.
_DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
    "You are a vision-language assistant. Your task is to answer questions based on the visual "
    "context provided, addressed to a visually impaired person.\n"
    "Each scene is presented through a sequence of images, and each image is labeled with a number "
    "(e.g., \"Image 0\", \"Image 1\", etc.).\n"
    "It's very important to know that the images are ordered chronologically: \"Image 0\" is the "
    "first image, and the last image (Image {num_last_image}) represents the current position of the person.\n"
    "Each image includes:\n"
    "- The object label (what the object is)\n"
    "- The approximate position in the image (top/middle/bottom and left/center/right)\n"
    "Use both the visual content of the images and the spatial information provided to understand "
    "the environment and guide with spatial instructions the person appropriately.\n"
    "Do not use visual information to guide but number of steps etc instead.\n"
    "Here is the scene context and question:\n"
)

# Placeholder the MiniCPM-V processor searches for to locate image slots in the prompt.
_IMAGE_PLACEHOLDER = "(<image>./</image>)"


class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
    config_class = MiniCPMVConfig


class MiniCPMV(MiniCPMVPreTrainedModel):
    """Vision-language model: SigLIP encoder + Resampler feeding a Qwen2 causal LM."""

    def __init__(self, config):
        super().__init__(config)

        self.llm = Qwen2ForCausalLM(config)
        self.vpm = self.init_vision_module()
        # YOLO detector, loaded lazily by `init_od_model`.
        self.od_model = None
        self.vision_dim = self.vpm.embed_dim
        self.embed_dim = self.llm.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        self.processor = None

        # Token strings that end generation and are stripped from decoded text.
        self.terminators = ['<|im_end|>', '<|endoftext|>']

        self._generate = self.generate

    def init_vision_module(self):
        """Build the SigLIP (NaViT-style) vision encoder from `config.vision_config`."""
        # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
        if self.config._attn_implementation == 'flash_attention_2':
            self.config.vision_config._attn_implementation = 'flash_attention_2'
        else:
            # SDPA is not supported by the vision tower; fall back to eager attention.
            self.config.vision_config._attn_implementation = 'eager'
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, 'embed_dim', model.embeddings.embed_dim)
        setattr(model, 'patch_size', model.embeddings.patch_size)

        return model

    def init_od_model(self):
        """Lazily load and cache the YOLO detector named by `config.od_model_name`."""
        if self.od_model is None:
            self.od_model = YOLO(self.config.od_model_name)
        return self.od_model

    def init_resampler(self, embed_dim, vision_dim):
        """Build the Resampler that maps vision features to `config.query_num` LLM-width tokens."""
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True
        )

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegate so the LLM actually uses the new table; `get_input_embeddings`
        # and `get_vllm_embedding` read `self.llm.model.embed_tokens`.
        self.llm.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llm.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.llm.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.llm = decoder

    def get_decoder(self):
        return self.llm

    def get_vllm_embedding(self, data):
        """Build LLM input embeddings with image features scattered into image slots.

        Args:
            data (dict): Must contain 'input_ids' and 'image_bound'. Either
                'vision_hidden_states' (precomputed, reused as-is) or both
                'pixel_values' and 'tgt_sizes' (encoded here).

        Returns:
            tuple: (embeddings with vision features written at the positions
            given by 'image_bound', per-sample vision hidden states).
        """
        if 'vision_hidden_states' not in data:
            dtype = self.llm.model.embed_tokens.weight.dtype
            device = self.llm.model.embed_tokens.weight.device
            tgt_sizes = data['tgt_sizes']
            pixel_values_list = data['pixel_values']
            vision_hidden_states = []
            all_pixel_values = []
            for pixel_values in pixel_values_list:
                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])

            # At least one image in the batch.
            if all_pixel_values:
                tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)

                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

                all_pixel_values = torch.nn.utils.rnn.pad_sequence(
                    all_pixel_values, batch_first=True, padding_value=0.0
                )
                B, L, _ = all_pixel_values.shape
                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

                # Mask out padded patches beyond each image's real patch count.
                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
                for i in range(B):
                    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

                vision_batch_size = self.config.vision_batch_size
                all_pixel_values = all_pixel_values.type(dtype)
                if B > vision_batch_size:
                    # Encode in chunks to bound peak memory in the vision tower.
                    hs = []
                    for i in range(0, B, vision_batch_size):
                        start_idx = i
                        end_idx = i + vision_batch_size
                        tmp_hs = self.vpm(
                            all_pixel_values[start_idx:end_idx],
                            patch_attention_mask=patch_attn_mask[start_idx:end_idx],
                            tgt_sizes=tgt_sizes[start_idx:end_idx],
                        ).last_hidden_state
                        hs.append(tmp_hs)
                    vision_embedding = torch.cat(hs, dim=0)
                else:
                    vision_embedding = self.vpm(
                        all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
                    ).last_hidden_state
                vision_embedding = self.resampler(vision_embedding, tgt_sizes)

                # Split the flat image batch back into per-sample groups.
                start = 0
                for pixel_values in pixel_values_list:
                    img_cnt = len(pixel_values)
                    if img_cnt > 0:
                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
                        start += img_cnt
                    else:
                        vision_hidden_states.append([])
            else:  # no image
                if self.training:
                    # Run a dummy image through the vision path so its parameters
                    # take part in the graph even for text-only batches.
                    dummy_image = torch.zeros(
                        (1, 3, 224, 224),
                        device=device, dtype=dtype
                    )
                    tgt_sizes = torch.Tensor(
                        [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                    ).type(torch.int32)
                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
                else:
                    dummy_feature = []
                for _ in range(len(pixel_values_list)):
                    vision_hidden_states.append(dummy_feature)

        else:
            vision_hidden_states = data['vision_hidden_states']

        if hasattr(self.llm.config, 'scale_emb'):
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
        else:
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])

        new_vllm_embedding = vllm_embedding.clone()

        vision_hidden_states = [
            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i
            for i in vision_hidden_states
        ]

        bs = len(data['input_ids'])
        for i in range(bs):
            cur_vs_hs = vision_hidden_states[i]
            if len(cur_vs_hs) > 0:
                cur_vllm_emb = vllm_embedding[i]
                cur_image_bound = data['image_bound'][i]
                if len(cur_image_bound) > 0:
                    image_indices = torch.stack(
                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                    ).to(vllm_embedding.device)

                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                        0,
                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                    )
                elif self.training:
                    # Zero-valued term keeps vision features in the autograd graph.
                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0

        return new_vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        """Run the LLM on embeddings produced by `get_vllm_embedding(data)`."""
        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
        position_ids = data["position_ids"]
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        # These are supplied explicitly below; drop duplicates from kwargs.
        for key in ['input_ids', 'inputs_embeds', 'position_ids']:
            if key in kwargs:
                del kwargs[key]

        return self.llm(
            input_ids=None,
            position_ids=position_ids,
            inputs_embeds=vllm_embedding,
            **kwargs
        )

    def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
        """Generate from `inputs_embeds`; optionally decode the ids into strings."""
        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        output = self.llm.generate(
            inputs_embeds=inputs_embeds,
            pad_token_id=0,
            eos_token_id=terminators,
            attention_mask=attention_mask,
            **kwargs
        )
        if decode_text:
            return self._decode_text(output, tokenizer)
        return output

    def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
        """Start generation on a background thread and return the text streamer."""
        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        streamer = TextIteratorStreamer(tokenizer=tokenizer)
        generation_kwargs = {
            'inputs_embeds': inputs_embeds,
            'pad_token_id': 0,
            'eos_token_id': terminators,
            'streamer': streamer
        }
        generation_kwargs.update(kwargs)

        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        thread.start()

        return streamer

    def _decode_text(self, result_ids, tokenizer):
        """Decode generated id rows into stripped strings.

        Pad ids (0), a leading BOS id and one trailing terminator id are removed
        before decoding. Rows that are empty after removing pads decode as "".
        """
        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        result_text = []
        for result in result_ids:
            result = result[result != 0]
            # Guard the indexing below: a row may be all padding.
            if len(result) > 0 and result[0] == tokenizer.bos_id:
                result = result[1:]
            if len(result) > 0 and result[-1] in terminators:
                result = result[:-1]
            result_text.append(tokenizer.decode(result).strip())
        return result_text

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        tgt_sizes=None,
        image_bound=None,
        attention_mask=None,
        tokenizer=None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        stream=False,
        decode_text=False,
        **kwargs
    ):
        """Embed the multimodal inputs and run generation (streaming or not).

        Returns the generation result, plus the vision hidden states when
        `return_vision_hidden_states` is True.
        """
        assert input_ids is not None
        assert len(input_ids) == len(pixel_values)

        model_inputs = {
            "input_ids": input_ids,
            "image_bound": image_bound,
        }

        if vision_hidden_states is None:
            model_inputs["pixel_values"] = pixel_values
            model_inputs['tgt_sizes'] = tgt_sizes
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        with torch.inference_mode():
            (
                model_inputs["inputs_embeds"],
                vision_hidden_states,
            ) = self.get_vllm_embedding(model_inputs)

            if stream:
                result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
            else:
                result = self._decode(
                    model_inputs["inputs_embeds"], tokenizer, attention_mask,
                    decode_text=decode_text, **kwargs
                )

        if return_vision_hidden_states:
            return result, vision_hidden_states

        return result

    @staticmethod
    def box2string(box):
        """Format a bounding box as text.

        Args:
            box: Iterable of box coordinates.

        Returns:
            str: Coordinates formatted to two decimals, e.g. "[1.00, 2.50]".
        """
        return f"[{', '.join(f'{x:.2f}' for x in box)}]"

    def make_od_prompt(self, od_results, img_width, img_height, conf_threshold=0.8):
        """Verbalize YOLO detections into a "Scene Context" text block.

        Args:
            od_results: Iterable of detection results exposing `.boxes`, whose
                items expose `.xyxy`, `.conf` and `.cls`.
            img_width: Image width in pixels, used to normalize box centers.
            img_height: Image height in pixels, used to normalize box centers.
            conf_threshold: Minimum confidence for a detection to be kept.

        Returns:
            str: One line per kept object with its label and coarse position,
            or "" when no detection reaches `conf_threshold`.
        """

        def get_position_description(x1, y1, x2, y2, img_width, img_height):
            # Normalized box center in [0, 1] (for boxes inside the image).
            center_x = (x1 + x2) / 2 / img_width
            center_y = (y1 + y2) / 2 / img_height

            # Horizontal third.
            if center_x < 0.33:
                horizontal = "on the left"
            elif center_x < 0.66:
                horizontal = "in the center"
            else:
                horizontal = "on the right"

            # Vertical third.
            if center_y < 0.33:
                vertical = "at the top"
            elif center_y < 0.66:
                vertical = "in the middle"
            else:
                vertical = "at the bottom"

            return f"{vertical} {horizontal}"

        def enrich_user_prompt(od_labels, od_boxes, img_width, img_height):
            context = []
            for label, box in zip(od_labels, od_boxes):
                x1, y1, x2, y2 = box
                logger.debug("Box: x1=%.2f, y1=%.2f, x2=%.2f, y2=%.2f", x1, y1, x2, y2)
                position = get_position_description(x1, y1, x2, y2, img_width, img_height)
                context.append(f'- Objet : "{label}", Position : {position}')

            # One detected object (with its position) per line.
            context_text = "\n".join(context)

            return f"Scene Context:\n{context_text}\n\n"

        od_boxes = []
        od_labels = []
        for result in od_results:
            for box in result.boxes:
                confidence = box.conf[0].item()
                if confidence < conf_threshold:
                    continue
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                class_id = int(box.cls[0].item())
                # init_od_model() returns the cached detector after the first call.
                label = self.init_od_model().names[class_id]
                od_boxes.append([x1, y1, x2, y2])
                od_labels.append(label)

        if not od_labels:
            return ""
        return enrich_user_prompt(od_labels, od_boxes, img_width, img_height)

    def od_to_prompt(self, image):
        """Run the object detector on a PIL image and return its scene-context text."""
        od_results = self.init_od_model()(image)
        return self.make_od_prompt(od_results, image.size[0], image.size[1])

    def chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=2048,
        min_new_tokens=0,
        sampling=True,
        max_inp_length=8192,
        system_prompt="",
        stream=False,
        max_slice_nums=None,
        use_image_id=None,
        **kwargs
    ):
        """Answer a (possibly batched) chat, adding object-detection context per image.

        Args:
            image: PIL image prepended to the first message (single conversation
                only; must be None when `msgs` is batched).
            msgs: List of {"role", "content"} dicts (or a JSON string of it), or a
                list of such lists for batched inference.
            tokenizer: Tokenizer used to decode generated ids.
            processor: MiniCPM-V processor; loaded from the model path when None.
            system_prompt: System message. When empty, a default navigation-
                assistant prompt referencing the last image index is used.
            stream: Return a generator of text chunks instead of the full answer.
            **kwargs: Only keys already present in the generation config
                (e.g. "top_p", "temperature") are forwarded.

        Returns:
            str for a single conversation, list[str] when batched, or a
            generator of str when `stream` is True.
        """
        batched = isinstance(msgs[0], list)
        msgs_list = msgs
        images_list = image

        if not batched:
            images_list, msgs_list = [images_list], [msgs_list]
        else:
            assert images_list is None, "Please integrate image to msgs when using batch inference."
            images_list = [None] * len(msgs_list)
        assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

        if processor is None:
            if self.processor is None:
                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
            processor = self.processor

        assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."

        prompts_lists = []
        input_images_lists = []
        for image, msgs in zip(images_list, msgs_list):
            if isinstance(msgs, str):
                msgs = json.loads(msgs)
            copy_msgs = deepcopy(msgs)

            assert len(msgs) > 0, "msgs is empty"
            assert sampling or not stream, "if use stream mode, make sure sampling=True"

            if image is not None and isinstance(copy_msgs[0]["content"], str):
                copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

            images = []
            for i, msg in enumerate(copy_msgs):
                role = msg["role"]
                content = msg["content"]
                assert role in ["user", "assistant"]
                if i == 0:
                    assert role == "user", "The role of first msg should be user"
                if isinstance(content, str):
                    content = [content]
                cur_msgs = []
                for c in content:
                    if isinstance(c, Image.Image):
                        images.append(c)
                        # Numbered image placeholder followed by its detected-object context.
                        cur_msgs.append(f"Image {len(images) - 1}:\n{_IMAGE_PLACEHOLDER}")
                        cur_msgs.append(self.od_to_prompt(c))
                    elif isinstance(c, str):
                        cur_msgs.append(c)
                msg["content"] = "\n".join(cur_msgs)

            # Use the caller's system prompt when given; otherwise the default
            # template, which needs the index of this conversation's last image.
            # The `system_prompt` argument itself is not reassigned, so every
            # conversation in a batch sees the caller's value.
            num_last_image = len(images) - 1
            cur_system_prompt = system_prompt or _DEFAULT_SYSTEM_PROMPT_TEMPLATE.format(
                num_last_image=num_last_image
            )
            copy_msgs = [{'role': 'system', 'content': cur_system_prompt}] + copy_msgs

            prompts_lists.append(
                processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
            )
            input_images_lists.append(images)

        logger.debug("Chat prompts: %s", prompts_lists)
        inputs = processor(
            prompts_lists,
            input_images_lists,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            return_tensors="pt",
            max_length=max_inp_length
        ).to(self.device)

        if sampling:
            generation_config = {
                "top_p": 0.6,
                "top_k": 50,
                "temperature": 0.6,
                "do_sample": True,
                "repetition_penalty": 1.05
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        if min_new_tokens > 0:
            generation_config['min_new_tokens'] = min_new_tokens

        # Only override keys that already exist in the chosen generation config.
        generation_config.update(
            (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
        )

        inputs.pop("image_sizes")
        with torch.inference_mode():
            res = self.generate(
                **inputs,
                tokenizer=tokenizer,
                max_new_tokens=max_new_tokens,
                vision_hidden_states=vision_hidden_states,
                stream=stream,
                decode_text=True,
                **generation_config
            )

        if stream:
            def stream_gen():
                for text in res:
                    for term in self.terminators:
                        text = text.replace(term, '')
                    yield text
            return stream_gen()

        if batched:
            return res
        return res[0]