Safetensors
English
llava
video-retrieval
text-to-video-search
multimodal-embedding
TARA / modeling_tara.py
bpiyush's picture
Update TARA to latest Tarsier2 checkpoint and runnable demo.
7daf628
import os
from abc import ABCMeta, abstractmethod
from typing import Optional, Union, Dict, List
from termcolor import colored
import torch
from transformers import (
LlavaConfig,
)
import decord
import PIL.Image
from tarsier2.dataset.utils import format_one_sample
from tarsier2.modeling_tarsier2 import Tarsier2ForConditionalGeneration
from tarsier2.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
from tarsier2.dataset.tarsier_datamodule import init_processor
decord.bridge.set_bridge("torch")
# "Explicit one-word limitation" prompt templates, keyed by input modality.
# `<sent>`, `<image>` and `<video>` are placeholders: `<sent>` is substituted
# with the caller's text via `str.replace` by the encoders in this file.
EOL_PROMPTS = dict(
    text='<sent>\nSummary above sentence in one word:',
    image='<image>\nSummary above image in one word:',
    video='<video>\nSummary above video in one word:',
    # Composed (video + edit instruction) query prompt.
    video_edit=(
        "USER: Source video: <video>\nEdit instruction: <sent>\n"
        "Look at the attached video carefully. The provided text is instruction to edit the video. "
        "Imagine this edit instruction being applied to the provided video frame.\n"
        "Summarize the resulting edited video in one word: ASSISTANT:"
    ),
)

# Maps an ARCHITECTURE string to the class registered under it; filled in by
# the `__init_subclass__` hooks of BaseModel and EncodeMixin respectively.
base_registry = dict()
encoder_registry = dict()
class BaseModel(metaclass=ABCMeta):
    """Base class for model wrappers; registers subclasses by architecture.

    Any subclass on which an ``ARCHITECTURE`` attribute is visible (defined
    directly or inherited) is stored in ``base_registry`` under that value.
    Because inherited attributes count, a later subclass that inherits the
    same ``ARCHITECTURE`` replaces the earlier registry entry.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the model architecture.
        if hasattr(cls, 'ARCHITECTURE'):
            base_registry[cls.ARCHITECTURE] = cls

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        load_llm: bool = False,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        **kwargs,
    ):
        """Construct ``cls`` from a checkpoint path or name.

        Args:
            model_name_or_path: Forwarded as the first positional argument of
                the class constructor.
            load_llm: Forwarded to the constructor as ``load_llm``.
            device_map: Forwarded to the constructor as ``device_map``.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        # The message was previously built and discarded; actually emit it.
        print(colored(f'Loading {cls.__name__} from {model_name_or_path}'))
        return cls(model_name_or_path, load_llm=load_llm, device_map=device_map, **kwargs)
class EncodeMixin(metaclass=ABCMeta):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# register model architecture
if hasattr(cls, 'ARCHITECTURE'):
encoder_registry[cls.ARCHITECTURE] = cls
@abstractmethod
def encode_vision(self, pixel_values: torch.Tensor | List[torch.Tensor]) -> torch.Tensor:
"""
Encodes vision data (images or videos) into a tensor representation.
Args:
pixel_values (torch.Tensor | List[torch.Tensor]): The input pixel values.
- If a tensor, it should be of shape (B, C, H, W) for images or (B, T, C, H, W) for videos.
- If a list, it will be stacked into a tensor.
Returns:
torch.Tensor: The encoded tensor representation of the input vision data.
Raises:
ValueError: If `pixel_values` is not 4D or 5D.
## Notes:
- This function does not accept unbatched inputs.
- `pixel_values` should be of type uint8.
"""
raise NotImplementedError
@abstractmethod
def encode_text(self, text: str | List[str]) -> torch.Tensor:
"""
Encodes the given text(s) into a tensor representation using the model.
Args:
text (str | List[str]): A single string or a list of strings to be encoded.
Returns:
torch.Tensor: The tensor representation of the encoded text(s).
## Notes:
- The method uses a prompt to encode the text.
- If a single string is provided, it is converted into a list containing that string.
- The method processes the prompts and generates the tensor representation using the model.
- The output tensor contains the hidden states of the last token for each input text.
"""
raise NotImplementedError
class BaseModelForTARA(BaseModel):
    """Loads a Tarsier2 checkpoint together with its processor and tokenizer.

    After construction the instance exposes ``model`` (in eval mode),
    ``super_processor``, ``processor``, ``tokenizer`` and ``base_config``.
    """

    ARCHITECTURE = "Tarsier2ForConditionalGeneration"
    LLM_CLASS = Qwen2VLForCausalLM
    MLLM_CLASS = Tarsier2ForConditionalGeneration
    # Processor config location, relative to the repository root reported by
    # `shared.utils`. Shared by `__init__` and `split_weights` so that both
    # read the same file.
    BASE_CONFIG_RELPATH = 'tarsier2/default_config.yaml'

    @property
    def describe_prompt(self):
        """Instruction asking for a detailed video description."""
        return "Describe the video in detail."

    @property
    def text_eol_prompt(self):
        """One-word-summary prompt template for text (contains `<sent>`)."""
        return EOL_PROMPTS["text"]

    @property
    def image_eol_prompt(self):
        """One-word-summary prompt template for images (contains `<image>`)."""
        return EOL_PROMPTS["image"]

    @property
    def video_eol_prompt(self):
        """One-word-summary prompt template for videos (contains `<video>`)."""
        return EOL_PROMPTS["video"]

    @staticmethod
    def _resolve_attn_implementation(requested_attn_impl: Optional[str] = None) -> str:
        """Pick an attention implementation usable on the current hardware.

        Args:
            requested_attn_impl: Requested implementation name; a falsy value
                means "flash_attention_2".

        Returns:
            The requested name unchanged if it is not "flash_attention_2".
            Otherwise "flash_attention_2" when CUDA is available and the
            current device has compute capability major version >= 8, else
            "eager" (a message is printed when falling back).
        """
        attn_impl = requested_attn_impl or "flash_attention_2"
        if attn_impl != "flash_attention_2":
            return attn_impl
        if not torch.cuda.is_available():
            print("CUDA is unavailable; falling back attn_implementation to 'eager'.")
            return "eager"
        major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
        if major < 8:
            print(
                f"GPU compute capability {major}.x does not support FlashAttention-2; "
                "falling back attn_implementation to 'eager'."
            )
            return "eager"
        return "flash_attention_2"

    @classmethod
    def _load_base_config(cls):
        """Load the processor YAML config found at `BASE_CONFIG_RELPATH`."""
        import shared.utils as su
        return su.io.load_yml(
            os.path.join(su.log.repo_path, cls.BASE_CONFIG_RELPATH)
        )

    def __init__(
        self,
        model_name_or_path: str,
        load_llm: Optional[bool] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        **kwargs,
    ):
        """Load the model, processor and tokenizer.

        Args:
            model_name_or_path: Checkpoint location.
            load_llm: When truthy, only the language model is loaded, from
                ``model_name_or_path + '-llm'`` (created on demand by
                `split_weights`). Otherwise the full multimodal model is
                loaded.
            device_map: Forwarded to ``from_pretrained``.
            **kwargs: Only ``attn_implementation`` and ``low_cpu_mem_usage``
                are read; any other key is ignored.
        """
        MODEL_CLASS = self.LLM_CLASS if load_llm else self.MLLM_CLASS
        if load_llm:
            # NOTE: the suffix is appended verbatim, so a trailing path
            # separator in `model_name_or_path` ends up inside the new name.
            self.split_weights(model_name_or_path, model_name_or_path + '-llm')
            model_name_or_path += '-llm'
            model_config = None
        else:
            model_config = LlavaConfig.from_pretrained(
                model_name_or_path,
                trust_remote_code=True,
            )

        # Processor setup is the same for both loading modes.
        self.base_config = self._load_base_config()
        self.super_processor = init_processor(model_name_or_path, self.base_config)
        self.processor = self.super_processor.processor
        self.tokenizer = self.processor.tokenizer

        attn_implementation = self._resolve_attn_implementation(
            kwargs.get("attn_implementation", "flash_attention_2")
        )
        self.model = MODEL_CLASS.from_pretrained(
            model_name_or_path,
            config=model_config,
            attn_implementation=attn_implementation,
            # The dtype is fixed; a `torch_dtype`/`dtype` kwarg is not honoured.
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            trust_remote_code=True,
            low_cpu_mem_usage=kwargs.get("low_cpu_mem_usage", True),  # Default to True for large models
        )
        self.model.eval()

    def split_weights(self, mllm_path, llm_path):
        """Extract the language model of an MLLM checkpoint into `llm_path`.

        Saves ``model.language_model``, the processor and the tokenizer to
        `llm_path`. Does nothing (beyond printing a message) when `llm_path`
        already exists.

        Args:
            mllm_path: Location of the full multimodal checkpoint.
            llm_path: Destination directory for the language-model-only copy.
        """
        if os.path.exists(llm_path):
            print(f'{llm_path} already exists. Skip splitting weights.')
            return
        print('Splitting LLM weights from MLLM.')

        # Build the processor before anything is written: if the config or
        # processor cannot be loaded, `llm_path` is not left half-populated
        # (an existing directory is treated as complete on later calls).
        super_processor = init_processor(mllm_path, self._load_base_config())

        attn_implementation = self._resolve_attn_implementation("flash_attention_2")
        model = self.MLLM_CLASS.from_pretrained(
            mllm_path,
            attn_implementation=attn_implementation,
            torch_dtype=torch.bfloat16,
        )
        model.language_model.save_pretrained(llm_path)
        super_processor.processor.save_pretrained(llm_path)
        super_processor.processor.tokenizer.save_pretrained(llm_path)
class TARA(BaseModelForTARA, EncodeMixin):
    """Embeds videos, images and text with a Tarsier2 model.

    Every encoder fills a prompt template, runs a single generation step and
    returns the last-layer hidden state at the final prompt position.
    """

    # Lower-case file extensions (without the dot) treated as video inputs.
    _VIDEO_EXTENSIONS = ('mp4', 'avi', 'mov', 'mkv', 'webm')

    @classmethod
    def _is_video_file(cls, path: str) -> bool:
        """Return True when the text after the last '.' is a video extension."""
        # Case-insensitive so that e.g. "clip.MP4" is recognised as a video.
        return path.split('.')[-1].lower() in cls._VIDEO_EXTENSIONS

    def _embed(self, media_file, prompt: str) -> torch.Tensor:
        """Run one generation step on (media_file, prompt) and return the embedding.

        Args:
            media_file: Path to an image/video file, or None for text-only input.
            prompt: Fully substituted prompt string.

        Returns:
            torch.Tensor: 2-D tensor taken from the hidden states; presumably
            (batch, hidden_size) with batch == 1 - TODO confirm against the
            processor output.
        """
        sample = format_one_sample(media_file=media_file, prompt=prompt)
        sample = self.super_processor(sample)
        # Only tensor entries are forwarded to the model; everything else in
        # the processed sample is dropped.
        model_inputs = {
            key: value.to(self.model.device)
            for key, value in sample.items()
            if isinstance(value, torch.Tensor)
        }
        with torch.inference_mode():
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
        # hidden_states[0]: first generation step; [-1]: last layer;
        # [:, -1, :]: last sequence position (layout as documented for
        # Hugging Face `generate` outputs).
        return output.hidden_states[0][-1][:, -1, :]

    def encode_vision(self, video_path: str, prompt=None) -> torch.Tensor:
        """Embed a video or image file.

        Args:
            video_path: Path to the media file. The extension decides whether
                the default video or image prompt is used.
            prompt: Optional custom prompt; must contain `<video>` or `<image>`.

        Returns:
            torch.Tensor: The embedding produced by the model.
        """
        if prompt is None:
            if self._is_video_file(video_path):
                prompt = self.video_eol_prompt
            else:
                prompt = self.image_eol_prompt
        else:
            assert "<video>" in prompt or "<image>" in prompt
        return self._embed(video_path, prompt)

    def encode_vision_with_text(self, video_path: str, text: str) -> torch.Tensor:
        """Embed a media file combined with an edit instruction.

        Args:
            video_path: Path to the media file.
            text: Edit instruction substituted for `<sent>` in the
                "video_edit" prompt template.

        Returns:
            torch.Tensor: The embedding produced by the model.
        """
        prompt = EOL_PROMPTS["video_edit"].replace('<sent>', text)
        return self._embed(video_path, prompt)

    def encode_image(self, image_path: str, prompt=None):
        """Embed an image file.

        Args:
            image_path: Path to the image; must not have a video extension.
            prompt: Optional custom prompt; must contain `<image>`.

        Returns:
            torch.Tensor: The embedding produced by the model.
        """
        assert not self._is_video_file(image_path)
        if prompt is None:
            prompt = self.image_eol_prompt
        else:
            assert "<image>" in prompt
        return self._embed(image_path, prompt)

    def encode_text(self, text: str | List[str], prompt=None) -> torch.Tensor:
        """Embed one string or a list of strings.

        Args:
            text: A single string, or a list of strings encoded one by one.
            prompt: Optional custom prompt containing `<sent>`; applied to
                every text, including each element of a list.

        Returns:
            torch.Tensor: The embedding for a single string, or the per-text
            embeddings concatenated along dim 0 for a list.

        Raises:
            ValueError: If `text` is neither a str nor a list.
        """
        if prompt is None:
            prompt = self.text_eol_prompt
        else:
            assert "<sent>" in prompt

        if isinstance(text, str):
            return self._embed(None, prompt.replace('<sent>', text))
        if isinstance(text, list):
            text_embs = [
                self._embed(None, prompt.replace('<sent>', item)) for item in text
            ]
            return torch.cat(text_embs)
        raise ValueError(f"Invalid type for text: {type(text)}")
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of `vlen` frames.

    Args:
        num_frames: Number of indices wanted in "rand"/"middle" mode.
        vlen: Number of frames available.
        sample: Sampling mode.
            - "rand": one random index per interval, where the video is cut
              into `min(num_frames, vlen)` equal intervals.
            - "middle": the middle index of each interval.
            - a string containing "fps", e.g. "fps0.5": sequential sampling
              at the rate given by the characters after the first three.
        fix_start: In "middle" mode, if not None, take interval start plus
            this offset instead of the interval middle.
        input_fps: Frame rate of the source video ("fps" mode only).
        max_num_frames: If > 0, cap on the number of indices in "fps" mode.

    Returns:
        list: Frame indices. In "rand"/"middle" mode the list is padded with
        its last index up to `num_frames` entries.

    Raises:
        ValueError: If `sample` is not a recognised mode.
    """
    # Imported here because the module only imports these names under its
    # `__main__` guard; without this the function fails when imported.
    import random
    import numpy as np

    if sample in ("rand", "middle"):  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals and sample from each.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [
            (start, intervals[idx + 1] - 1)
            for idx, start in enumerate(intervals[:-1])
        ]
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except (ValueError, IndexError):
                # An interval was too short to draw from; fall back to a
                # sorted random subset of all frames.
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. "fps0.5": sequentially sample at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        # Gap between sampled frames; also the clip length each one represents.
        delta = 1 / output_fps
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unsupported sample mode: {sample!r}")
    return frame_indices
def read_frames_decord(
    video_path, num_frames, sample='middle', fix_start=None,
    max_num_frames=-1, trimmed30=False, height=-1, width=-1
):
    """Decode a subset of frames from a video file with decord.

    Args:
        video_path: Path of the video to open.
        num_frames: Number of frames requested (see `get_frame_indices`).
        sample: Sampling mode forwarded to `get_frame_indices`.
        fix_start: Forwarded to `get_frame_indices`.
        max_num_frames: Forwarded to `get_frame_indices`.
        trimmed30: If True and the video is longer than 30 seconds, only
            the first 30 seconds are considered for sampling.
        height: Forwarded to the decord reader.
        width: Forwarded to the decord reader.

    Returns:
        torch.Tensor: The sampled frames with the last axis moved to
        position 1; presumably (T, C, H, W) uint8 - TODO confirm against
        decord's output layout.
    """
    decord.bridge.set_bridge('torch')
    num_threads = 1
    # Use the module-level `decord` import: a bare `VideoReader` name is only
    # bound when this file is executed as a script.
    video_reader = decord.VideoReader(
        video_path, num_threads=num_threads, height=height, width=width
    )
    try:
        vlen = len(video_reader)
        fps = video_reader.get_avg_fps()
        duration = vlen / float(fps)

        # Only use the first 30 seconds.
        if trimmed30 and duration > 30:
            vlen = int(30 * float(fps))

        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            input_fps=fps, max_num_frames=max_num_frames
        )
        frames = video_reader.get_batch(frame_indices)
        # Depending on the active decord bridge the batch may not be a torch
        # tensor yet; convert through numpy in that case.
        if not isinstance(frames, torch.Tensor):
            frames = torch.from_numpy(frames.asnumpy())
        frames = frames.permute(0, 3, 1, 2)
        return frames
    finally:
        # Explicitly release underlying resources to avoid file descriptor leaks
        del video_reader
def read_image_decord(image_path):
    """Load an image file as a batched, channel-first tensor.

    Args:
        image_path: Path of the image to open.

    Returns:
        torch.Tensor: The RGB image with axes reordered from (H, W, C) to
        (C, H, W) and a leading batch axis of size 1.
    """
    # Imported here because the module only imports numpy under its
    # `__main__` guard.
    import numpy as np

    # The context manager closes the file handle; `convert` returns an
    # independent, fully loaded image, so it stays valid afterwards.
    with PIL.Image.open(image_path) as handle:
        rgb = handle.convert('RGB')

    pixels = np.array(rgb).transpose(2, 0, 1)
    return torch.from_numpy(pixels).unsqueeze(0)
def read_images_decord(image_paths):
    """Load several images and concatenate them along the first axis.

    Args:
        image_paths: Iterable of image file paths, each read with
            `read_image_decord`.

    Returns:
        torch.Tensor: `torch.cat` of the per-image tensors, in input order.
    """
    per_image = [read_image_decord(path) for path in image_paths]
    return torch.cat(per_image)
if __name__ == "__main__":
    # Script-only imports. They land in the module globals, so a module-level
    # helper that refers to `random`, `np` or `VideoReader` without importing
    # it itself resolves the name only when this file is run as a script.
    from termcolor import colored
    import random
    import numpy as np
    from decord import VideoReader

    # Load model
    checkpoint_dir = "/work/piyush/experiments/CaRe/Tarsier2-7b-0115/covr/chiral10k-covr10k/merged_checkpoint/"
    model = TARA.from_pretrained(
        checkpoint_dir,
        device_map='auto',
        dtype=torch.bfloat16,
    )
    n_params = 0
    for param in model.model.parameters():
        n_params += param.numel()
    print(f"Number of parameters: {round(n_params/1e9, 3)}B")

    # Encode a sample video.
    print(colored("Testing video encoding...", 'cyan'))
    video_path = "./assets/folding_paper.mp4"
    with torch.no_grad():
        video_emb = model.encode_vision(video_path)
    video_emb = video_emb.cpu().squeeze(0).float()
    print("Video embedding shape:", video_emb.shape)  # torch.Size([4096])

    # Encode sample texts.
    print(colored("Testing text encoding...", 'cyan'))
    text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper']
    # NOTE: It can also take a single string
    with torch.no_grad():
        text_emb = model.encode_text(text)
    text_emb = text_emb.cpu().float()
    print("Text:", text)
    print("Text embedding shape:", text_emb.shape)  # torch.Size([3, 4096])