import os
import random
from abc import ABCMeta, abstractmethod
from typing import Optional, Union, Dict, List

import numpy as np
import torch
from termcolor import colored
from transformers import (
    LlavaConfig,
)
import decord
import PIL.Image

from tarsier2.dataset.utils import format_one_sample
from tarsier2.modeling_tarsier2 import Tarsier2ForConditionalGeneration
from tarsier2.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
from tarsier2.dataset.tarsier_datamodule import init_processor
|
|
# Ask decord to hand back torch tensors from its readers (read_frames_decord
# still converts defensively if it receives a decord NDArray instead).
decord.bridge.set_bridge("torch")
|
|
|
|
# Prompt templates ending in a "one word" instruction; TARA embeds an input by
# taking the hidden state of the final prompt token. '<sent>' is replaced with
# the caller's text. '<image>' / '<video>' are media placeholders consumed
# downstream by format_one_sample / the processor (not visible here).
EOL_PROMPTS = dict(
    text='<sent>\nSummary above sentence in one word:',
    image='<image>\nSummary above image in one word:',
    video='<video>\nSummary above video in one word:',
    video_edit=(
        "USER: Source video: <video>\nEdit instruction: <sent>\n"
        "Look at the attached video carefully. The provided text is instruction to edit the video. "
        "Imagine this edit instruction being applied to the provided video frame.\n"
        "Summarize the resulting edited video in one word: ASSISTANT:"
    ),
)

# ARCHITECTURE name -> class; populated by BaseModel.__init_subclass__ and
# EncodeMixin.__init_subclass__ respectively.
base_registry = {}
encoder_registry = {}
|
|
|
|
class BaseModel(metaclass=ABCMeta):
    """Base class for models that are constructed via ``from_pretrained``.

    Any subclass that declares or inherits an ``ARCHITECTURE`` class attribute
    is stored in the module-level ``base_registry`` under that name; a later
    subclass with the same name overwrites the earlier entry.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        if hasattr(cls, 'ARCHITECTURE'):
            base_registry[cls.ARCHITECTURE] = cls

    @classmethod
    def from_pretrained(
            cls,
            model_name_or_path: str,
            load_llm: bool = False,
            device_map: Optional[Union[str, Dict[str, int]]] = None,
            **kwargs):
        """Instantiate ``cls`` from a checkpoint.

        Args:
            model_name_or_path: Checkpoint path or identifier, forwarded to
                ``cls.__init__``.
            load_llm: Forwarded to ``cls.__init__``.
            device_map: Forwarded to ``cls.__init__``.
            **kwargs: Any further keyword arguments, forwarded to ``cls.__init__``.

        Returns:
            A new ``cls`` instance.
        """
        # colored() only builds the (optionally ANSI-coloured) string; without
        # print() the loading message was silently discarded.
        print(colored(f'Loading {cls.__name__} from {model_name_or_path}'))

        return cls(model_name_or_path, load_llm=load_llm, device_map=device_map, **kwargs)
|
|
|
|
| class EncodeMixin(metaclass=ABCMeta): |
| def __init_subclass__(cls, **kwargs): |
| super().__init_subclass__(**kwargs) |
| |
| if hasattr(cls, 'ARCHITECTURE'): |
| encoder_registry[cls.ARCHITECTURE] = cls |
|
|
| @abstractmethod |
| def encode_vision(self, pixel_values: torch.Tensor | List[torch.Tensor]) -> torch.Tensor: |
| """ |
| Encodes vision data (images or videos) into a tensor representation. |
| |
| Args: |
| pixel_values (torch.Tensor | List[torch.Tensor]): The input pixel values. |
| - If a tensor, it should be of shape (B, C, H, W) for images or (B, T, C, H, W) for videos. |
| - If a list, it will be stacked into a tensor. |
| |
| Returns: |
| torch.Tensor: The encoded tensor representation of the input vision data. |
| |
| Raises: |
| ValueError: If `pixel_values` is not 4D or 5D. |
| |
| ## Notes: |
| - This function does not accept unbatched inputs. |
| - `pixel_values` should be of type uint8. |
| """ |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def encode_text(self, text: str | List[str]) -> torch.Tensor: |
| """ |
| Encodes the given text(s) into a tensor representation using the model. |
| |
| Args: |
| text (str | List[str]): A single string or a list of strings to be encoded. |
| |
| Returns: |
| torch.Tensor: The tensor representation of the encoded text(s). |
| |
| ## Notes: |
| - The method uses a prompt to encode the text. |
| - If a single string is provided, it is converted into a list containing that string. |
| - The method processes the prompts and generates the tensor representation using the model. |
| - The output tensor contains the hidden states of the last token for each input text. |
| """ |
| raise NotImplementedError |
|
|
|
|
class BaseModelForTARA(BaseModel):
    """Loads a Tarsier2 checkpoint, either the full MLLM or only its language model.

    After construction the instance exposes:
        model: The loaded model, switched to eval mode.
        base_config: Processor configuration read from the repo's YAML file.
        super_processor: Object returned by ``init_processor``.
        processor / tokenizer: ``super_processor.processor`` and its ``.tokenizer``.
    """

    ARCHITECTURE = "Tarsier2ForConditionalGeneration"
    LLM_CLASS = Qwen2VLForCausalLM
    MLLM_CLASS = Tarsier2ForConditionalGeneration

    # YAML processor config, relative to ``shared.utils.log.repo_path``.
    # NOTE(review): split_weights used to read 'models/tarsier2/default_config.yaml'
    # while __init__ read this path, so one of the two lookups could not succeed.
    # Both now use the __init__ path (the one exercised by the default
    # load_llm=False flow). Confirm this matches the repository layout.
    BASE_CONFIG_RELPATH = 'tarsier2/default_config.yaml'

    @property
    def describe_prompt(self):
        """Fixed instruction string asking for a detailed video description."""
        return "Describe the video in detail."

    @property
    def text_eol_prompt(self):
        """Default text prompt template (contains the '<sent>' placeholder)."""
        return EOL_PROMPTS["text"]

    @property
    def image_eol_prompt(self):
        """Default image prompt template (contains the '<image>' placeholder)."""
        return EOL_PROMPTS["image"]

    @property
    def video_eol_prompt(self):
        """Default video prompt template (contains the '<video>' placeholder)."""
        return EOL_PROMPTS["video"]

    @staticmethod
    def _resolve_attn_implementation(requested_attn_impl: Optional[str] = None) -> str:
        """Pick the attention backend to request from ``from_pretrained``.

        Any explicit non-flash request is returned unchanged. 'flash_attention_2'
        (also used when the request is None/empty) is downgraded to 'eager'
        when CUDA is unavailable or the current GPU's compute-capability major
        version is below 8.
        """
        attn_impl = requested_attn_impl or "flash_attention_2"
        if attn_impl != "flash_attention_2":
            return attn_impl
        if not torch.cuda.is_available():
            print("CUDA is unavailable; falling back attn_implementation to 'eager'.")
            return "eager"
        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        if major < 8:
            print(
                f"GPU compute capability {major}.x does not support FlashAttention-2; "
                "falling back attn_implementation to 'eager'."
            )
            return "eager"
        return "flash_attention_2"

    @classmethod
    def _load_base_config(cls):
        """Read the processor YAML config that ships with the repository."""
        import shared.utils as su
        return su.io.load_yml(os.path.join(su.log.repo_path, cls.BASE_CONFIG_RELPATH))

    @staticmethod
    def _llm_path(model_name_or_path: str) -> str:
        """Return the sibling ``*-llm`` directory holding the split LLM weights."""
        # Without stripping, 'ckpt/' + '-llm' == 'ckpt/-llm', i.e. a folder
        # *inside* the checkpoint rather than next to it.
        llm_path = model_name_or_path.rstrip('/' + os.sep) + '-llm'
        legacy_path = model_name_or_path + '-llm'
        # Keep using weights that an older version already split to the legacy spot.
        if legacy_path != llm_path and os.path.exists(legacy_path):
            return legacy_path
        return llm_path

    def __init__(
        self,
        model_name_or_path: str,
        load_llm: Optional[bool] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        **kwargs,
    ):
        """Load the processor and the model weights.

        Args:
            model_name_or_path: MLLM checkpoint path or identifier.
            load_llm: If truthy, split the language model out of the MLLM
                checkpoint (once, into a sibling ``*-llm`` directory) and load
                only that with ``LLM_CLASS``. Otherwise load the full model with
                ``MLLM_CLASS`` and a ``LlavaConfig``.
            device_map: Forwarded to ``from_pretrained``.
            **kwargs: Optional overrides read here:
                ``attn_implementation`` (default 'flash_attention_2', subject to
                hardware fallback), ``torch_dtype`` / ``dtype`` (default
                ``torch.bfloat16``) and ``low_cpu_mem_usage`` (default True).
                Other keys are ignored.
        """
        if load_llm:
            llm_path = self._llm_path(model_name_or_path)
            self.split_weights(model_name_or_path, llm_path)
            model_name_or_path = llm_path
            model_class = self.LLM_CLASS
            model_config = None
        else:
            model_class = self.MLLM_CLASS
            model_config = LlavaConfig.from_pretrained(
                model_name_or_path,
                trust_remote_code=True,
            )

        # The processor is loaded from the same directory as the weights; in the
        # LLM-only case split_weights saved the processor there as well.
        self.base_config = self._load_base_config()
        self.super_processor = init_processor(model_name_or_path, self.base_config)
        self.processor = self.super_processor.processor
        self.tokenizer = self.processor.tokenizer

        attn_implementation = self._resolve_attn_implementation(
            kwargs.get("attn_implementation", "flash_attention_2")
        )
        # Honour a caller-supplied dtype (either keyword); bfloat16 by default.
        torch_dtype = kwargs.get("torch_dtype", kwargs.get("dtype", torch.bfloat16))

        self.model = model_class.from_pretrained(
            model_name_or_path,
            config=model_config,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=True,
            low_cpu_mem_usage=kwargs.get("low_cpu_mem_usage", True),
        )
        self.model.eval()

    def split_weights(self, mllm_path, llm_path):
        """Save the MLLM's language model plus processor/tokenizer to ``llm_path``.

        Does nothing if ``llm_path`` already exists.
        """
        if os.path.exists(llm_path):
            print(f'{llm_path} already exists. Skip splitting weights.')
            return
        print('Splitting LLM weights from MLLM.')
        attn_implementation = self._resolve_attn_implementation("flash_attention_2")
        # No device_map is passed, so this presumably loads on CPU (HF default).
        model = self.MLLM_CLASS.from_pretrained(
            mllm_path,
            attn_implementation=attn_implementation,
            torch_dtype=torch.bfloat16,
        )
        model.language_model.save_pretrained(llm_path)
        # Drop the full MLLM before doing anything else.
        del model

        super_processor = init_processor(mllm_path, self._load_base_config())
        super_processor.processor.save_pretrained(llm_path)
        super_processor.processor.tokenizer.save_pretrained(llm_path)
|
|
|
|
class TARA(BaseModelForTARA, EncodeMixin):
    """Tarsier2-based embedding model.

    Every ``encode_*`` method wraps its input in a prompt and runs a single
    generation step. It returns the last-layer hidden state of the final prompt
    token, shaped (batch, hidden); batch is presumably 1 since one sample is
    processed per call. ``encode_text`` on a list concatenates the per-sentence
    results along dim 0.
    """

    # File extensions (compared case-insensitively) that are treated as videos.
    VIDEO_EXTENSIONS = ('mp4', 'avi', 'mov', 'mkv', 'webm')

    @classmethod
    def _is_video_path(cls, path: str) -> bool:
        """Return True if ``path``'s extension is one of ``VIDEO_EXTENSIONS``."""
        return path.split('.')[-1].lower() in cls.VIDEO_EXTENSIONS

    def _embed(self, media_file, prompt) -> torch.Tensor:
        """Run one prompt (optionally with a media file) and return its embedding."""
        sample = format_one_sample(media_file=media_file, prompt=prompt)
        sample = self.super_processor(sample)
        # Only tensor entries are model inputs; move them to the model's device.
        device = self.model.device
        model_inputs = {
            k: v.to(device) for k, v in sample.items() if isinstance(v, torch.Tensor)
        }
        with torch.inference_mode():
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
        # Per HF generate(): hidden_states[0] is the prompt (prefill) step,
        # [-1] its last layer, and [:, -1, :] the final prompt token.
        return output.hidden_states[0][-1][:, -1, :]

    def encode_vision(self, video_path: str, prompt=None) -> torch.Tensor:
        """Embed a video or image file.

        Args:
            video_path: Media file; video vs. image is decided by extension.
            prompt: Optional template containing '<video>' or '<image>'. The
                default is the video or image EOL prompt accordingly.
        """
        if prompt is None:
            prompt = self.video_eol_prompt if self._is_video_path(video_path) else self.image_eol_prompt
        else:
            assert "<video>" in prompt or "<image>" in prompt
        return self._embed(video_path, prompt)

    def encode_vision_with_text(self, video_path: str, text: str) -> torch.Tensor:
        """Embed a video together with an edit instruction (``EOL_PROMPTS['video_edit']``)."""
        prompt = EOL_PROMPTS["video_edit"].replace('<sent>', text)
        return self._embed(video_path, prompt)

    def encode_image(self, image_path: str, prompt=None):
        """Embed an image file; video extensions are rejected by an assertion.

        Args:
            image_path: Image file path.
            prompt: Optional template containing '<image>'.
        """
        assert not self._is_video_path(image_path)
        if prompt is None:
            prompt = self.image_eol_prompt
        else:
            assert "<image>" in prompt
        return self._embed(image_path, prompt)

    def encode_text(self, text: str | List[str], prompt=None) -> torch.Tensor:
        """Embed one sentence or a list of sentences.

        Args:
            text: A string, or a list of strings (one forward pass each).
            prompt: Optional template containing '<sent>'. It is applied to
                every sentence, including each element of a list.

        Returns:
            For a string, one embedding; for a list, the embeddings concatenated
            along dim 0 (``torch.cat`` raises on an empty list).

        Raises:
            ValueError: If ``text`` is neither a str nor a list.
        """
        if prompt is None:
            template = self.text_eol_prompt
        else:
            assert "<sent>" in prompt
            template = prompt

        if isinstance(text, str):
            return self._embed(None, template.replace('<sent>', text))
        if isinstance(text, list):
            # A separate template variable keeps '<sent>' intact across iterations.
            return torch.cat([self._embed(None, template.replace('<sent>', t)) for t in text])
        raise ValueError(f"Invalid type for text: {type(text)}")
|
|
|
|
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of ``vlen`` frames.

    Args:
        num_frames: Number of indices wanted in the 'rand' / 'middle' modes. If
            fewer are produced (short video), the last index is repeated so
            exactly ``num_frames`` are returned. Ignored by 'fps<rate>'.
        vlen: Total number of frames in the video.
        sample: Sampling mode:
            - 'rand': split [0, vlen) into min(num_frames, vlen) equal segments
              and pick a random index in each (falls back to a sorted random
              subset when a segment is too short to sample from).
            - 'middle': pick the middle of each segment, or
              ``segment_start + fix_start`` when ``fix_start`` is given.
            - 'fps<rate>' (e.g. 'fps1', 'fps0.5'): one index every 1/<rate>
              seconds, centred in each interval; indices >= vlen are dropped.
        fix_start: Offset from each segment start; only used with 'middle'.
        input_fps: Frame rate of the source video (used by 'fps<rate>').
        max_num_frames: Cap on the number of indices in 'fps<rate>' mode;
            <= 0 means no cap.

    Returns:
        A list of frame indices (elements may be NumPy integers).

    Raises:
        ValueError: For an unknown ``sample`` mode, or when 'rand'/'middle'
            yields no index to pad from (e.g. ``vlen`` == 0) while
            ``num_frames`` > 0.
    """
    if sample in ("rand", "middle"):
        acc_samples = min(num_frames, vlen)
        # Segment boundaries splitting [0, vlen) into acc_samples pieces.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [(start, end - 1) for start, end in zip(intervals[:-1], intervals[1:])]
        if sample == 'rand':
            try:
                # NOTE(review): range() excludes x[1], so the last frame of each
                # segment is never picked; kept to preserve existing sampling.
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except (ValueError, IndexError):
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        else:  # sample == 'middle'
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]

        if len(frame_indices) < num_frames:
            if not frame_indices:
                raise ValueError(
                    f"Cannot sample {num_frames} frames from a video with vlen={vlen}"
                )
            # Pad by repeating the last chosen index.
            frame_indices = list(frame_indices) + [frame_indices[-1]] * (num_frames - len(frame_indices))
    elif "fps" in sample:
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps
        # Timestamps at the centre of each 1/output_fps interval.
        frame_seconds = np.arange(delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unsupported frame sampling mode: {sample!r}")
    return frame_indices
|
|
|
|
def read_frames_decord(
    video_path, num_frames, sample='middle', fix_start=None,
    max_num_frames=-1, trimmed30=False, height=-1, width=-1
):
    """Decode a sampled subset of frames from a video file with decord.

    Args:
        video_path: Path to the video file.
        num_frames, sample, fix_start, max_num_frames: Forwarded to
            ``get_frame_indices``.
        trimmed30: If True and the video is longer than 30 seconds, only the
            first 30 seconds are considered when choosing frames.
        height, width: Passed to ``decord.VideoReader``; -1 presumably keeps
            the native resolution (decord default) — confirm.

    Returns:
        A tensor of frames permuted to (T, C, H, W), assuming decord's batch
        layout is (T, H, W, C).
    """
    decord.bridge.set_bridge('torch')

    # Referenced via the module so the function works when imported (the bare
    # ``VideoReader`` name used to exist only under ``__main__``).
    video_reader = decord.VideoReader(video_path, num_threads=1, height=height, width=width)
    try:
        vlen = len(video_reader)
        fps = video_reader.get_avg_fps()

        # Duration is only needed for trimming, so a zero fps no longer crashes
        # the default (untrimmed) path.
        if trimmed30 and vlen / float(fps) > 30:
            vlen = int(30 * float(fps))

        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            input_fps=fps, max_num_frames=max_num_frames
        )

        frames = video_reader.get_batch(frame_indices)
        # Convert in case the torch bridge was not in effect.
        if not isinstance(frames, torch.Tensor):
            frames = torch.from_numpy(frames.asnumpy())
        return frames.permute(0, 3, 1, 2)
    finally:
        del video_reader
|
|
|
|
def read_image_decord(image_path):
    """Load an image file as a uint8 RGB tensor of shape (1, 3, H, W).

    Despite the name, decoding is done with Pillow, not decord.
    """
    # Context manager guarantees the underlying file handle is closed.
    with PIL.Image.open(image_path) as image:
        rgb = np.array(image.convert('RGB'))
    # (H, W, C) -> (C, H, W), then add a leading frame/batch axis.
    return torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
|
|
|
|
def read_images_decord(image_paths):
    """Load several images and stack them along dim 0.

    Each path goes through ``read_image_decord`` (one (1, 3, H, W) tensor per
    image), so all images must share H and W for ``torch.cat`` to succeed. An
    empty ``image_paths`` makes ``torch.cat`` raise.
    """
    return torch.cat([read_image_decord(path) for path in image_paths])
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load a checkpoint, embed one video and a few sentences.
    # Override the (machine-specific) checkpoint path with TARA_CHECKPOINT.
    checkpoint = os.environ.get(
        "TARA_CHECKPOINT",
        "/work/piyush/experiments/CaRe/Tarsier2-7b-0115/covr/chiral10k-covr10k/merged_checkpoint/",
    )
    model = TARA.from_pretrained(
        checkpoint,
        device_map='auto',
        dtype=torch.bfloat16,
    )
    n_params = sum(p.numel() for p in model.model.parameters())
    print(f"Number of parameters: {round(n_params/1e9, 3)}B")

    print(colored("Testing video encoding...", 'cyan'))
    video_path = "./assets/folding_paper.mp4"
    with torch.no_grad():
        video_emb = model.encode_vision(video_path).cpu().squeeze(0).float()
    print("Video embedding shape:", video_emb.shape)

    print(colored("Testing text encoding...", 'cyan'))
    text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper']
    with torch.no_grad():
        text_emb = model.encode_text(text).cpu().float()
    print("Text:", text)
    print("Text embedding shape:", text_emb.shape)
|
|
|
|