Video-Text-to-Text
Transformers
Safetensors
English
soccer_qa_4b
soccer
video-qa
question-answering
vision-language
multimodal
sports-analysis
Instructions to use sportsvision/soccer-qa-4b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use sportsvision/soccer-qa-4b with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("sportsvision/soccer-qa-4b", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Soccer QA Inference - Single Class, Clean API | |
| Usage in Colab: | |
| from soccer_qa_inference import SoccerQA | |
| model = SoccerQA("soccer-qa-3b-unified") | |
| answer = model.ask("video.mp4", "What happened?", max_tokens=128) | |
| """ | |
| import os | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from safetensors.torch import load_file | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from decord import VideoReader | |
| # Import your existing modules | |
| import src.datasets.utils.video.transforms as video_transforms | |
| import src.datasets.utils.video.volume_transforms as volume_transforms | |
| from src.models.vision_transformer import vit_giant_rope | |
# Per-channel mean and standard deviation handed to the Normalize step of
# the evaluation transform built in build_video_transform().
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
def get_video(fname, num_frames=16):
    """Decode ``num_frames`` evenly spaced frames from a video file.

    Frame indices are spread linearly from the first frame (0) to the last
    frame (``len(reader) - 1``) and truncated to integers, so a video with
    fewer frames than ``num_frames`` yields repeated frames.

    Args:
        fname: Path handed to decord's ``VideoReader``.
        num_frames: How many frames to sample.

    Returns:
        The sampled frames as a NumPy array (presumably laid out as
        T x H x W x C, since the caller permutes it with (0, 3, 1, 2) --
        confirm against decord's documentation).
    """
    reader = VideoReader(fname)
    last_index = len(reader) - 1
    sample_positions = np.linspace(0, last_index, num=num_frames).astype(np.int64)
    return reader.get_batch(sample_positions).asnumpy()
def build_video_transform(img_size):
    """Assemble the evaluation-time preprocessing pipeline for a clip.

    The pipeline resizes to ``int(256 / 224 * img_size)`` with bilinear
    interpolation, center-crops to ``img_size`` x ``img_size``, converts the
    clip to a tensor and normalizes it with the module-level mean/std
    constants.

    Args:
        img_size: Side length of the square crop fed to the vision model.

    Returns:
        The composed ``video_transforms.Compose`` callable.
    """
    # Resize slightly larger than the crop (same 256:224 ratio as the
    # original code) so the center crop trims the borders.
    resize_to = int(256.0 / 224 * img_size)
    steps = [
        video_transforms.Resize(resize_to, interpolation="bilinear"),
        video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    return video_transforms.Compose(steps)
class SoccerQA:
    """Single entry point for soccer video question answering.

    The constructor reads ``config.json`` and the tokenizer from
    ``model_dir``, builds a vision transformer, a causal language model and
    a projection MLP between them, then fills all three from
    ``model.safetensors``. Use :meth:`ask` for one question and
    :meth:`batch_ask` for several questions about the same video.
    """

    # Number of tokens the vision encoder is expected to emit per clip; used
    # to split a flat token sequence back into clips.
    _TOKENS_PER_CLIP = 2048

    def __init__(self, model_dir="/home/varunkodathala/jepa_llm/soccer_pretrain/soccer-qa-3b-unified"):
        """Load configuration, tokenizer, models and weights.

        Args:
            model_dir: Directory holding ``config.json``, the tokenizer files
                and ``model.safetensors``. The default is a developer-local
                path, kept only for backward compatibility; pass your own.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Loading Soccer QA from {model_dir}...")

        # Config and tokenizer first: the builders below read their values.
        self._load_config()
        self._load_tokenizer()

        self._build_vision_model()
        self._build_text_model()
        self._build_projection()

        self._load_weights()

        self.video_transform = build_video_transform(self.img_size)
        print("✅ Soccer QA ready!")

    def _load_config(self):
        """Read ``config.json`` from ``model_dir`` and cache the dimensions.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If a required key is missing from the config.
        """
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        # The trailing values are those noted by the original author for the
        # shipped checkpoint.
        self.vision_dim = self.config["vision_dim"]  # 1408
        self.projection_dim = self.config["projection_dim"]  # 2048
        self.text_dim = self.config["text_dim"]  # 3072
        self.img_size = self.config["img_size"]  # 256
        self.num_frames = self.config["num_frames"]  # 16

    def _load_tokenizer(self):
        """Load the tokenizer from ``model_dir``; fall back to EOS for padding."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_vision_model(self):
        """Instantiate the vision transformer, frozen and in eval mode."""
        self.vision_model = vit_giant_rope(
            img_size=(self.img_size, self.img_size),
            num_frames=self.num_frames,
        )
        self.vision_model.to(self.device).eval()
        for param in self.vision_model.parameters():
            param.requires_grad = False

    def _build_text_model(self):
        """Instantiate the base language model; merged weights are loaded later."""
        # NOTE(review): trust_remote_code=True allows the hub repository to
        # execute arbitrary code; confirm it is actually required here.
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B",
            torch_dtype=torch.float32,
            device_map=self.device,
            trust_remote_code=True,
        )
        # The tokenizer may contain extra tokens (e.g. <video>); the embedding
        # table must match its size before the saved weights are loaded.
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        self.text_model.eval()

    def _build_projection(self):
        """Build the MLP mapping vision features to the text embedding width."""
        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_dim, self.projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.projection_dim, self.text_dim),
            nn.LayerNorm(self.text_dim),
        ).to(self.device)
        # Modules start in training mode; without eval() the Dropout layer
        # would randomly zero activations during inference.
        self.vision_projection.eval()

    @staticmethod
    def _extract_prefixed(state_dict, prefix):
        """Return the entries whose key starts with ``prefix``, prefix removed.

        Only the leading occurrence is stripped, so a key that happens to
        contain the prefix text again further along is left intact.
        """
        cut = len(prefix)
        return {
            key[cut:]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _load_weights(self):
        """Fill the vision, projection and text models from ``model.safetensors``."""
        model_path = os.path.join(self.model_dir, "model.safetensors")
        print(f"Loading weights from: {model_path}")
        state_dict = load_file(model_path, device=str(self.device))

        # Vision encoder: non-strict, the load report is printed for inspection.
        vision_state = self._extract_prefixed(state_dict, "vision_encoder.")
        msg = self.vision_model.load_state_dict(vision_state, strict=False)
        print(f"Vision model loaded: {msg}")

        # Projection: strict, every parameter must be present in the file.
        projection_state = self._extract_prefixed(state_dict, "vision_projection.")
        self.vision_projection.load_state_dict(projection_state)
        print("Projection layer loaded")

        # Text model: non-strict; keys absent from the file keep base weights.
        text_state = self._extract_prefixed(state_dict, "text_model.")
        missing_keys, unexpected_keys = self.text_model.load_state_dict(text_state, strict=False)
        if missing_keys:
            print(f"Missing keys in text model: {len(missing_keys)} (this is normal)")
        if unexpected_keys:
            print(f"Unexpected keys in text model: {len(unexpected_keys)}")
        print("✅ Text model loaded with merged weights")

        # Drop every reference to the loaded tensors, including the
        # per-module sub-dicts, so the cache flush can actually free memory.
        del state_dict, vision_state, projection_state, text_state
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _split_into_clips(cls, features):
        """Reshape encoder output to ``[num_clips, tokens_per_clip, dim]``.

        Args:
            features: Encoder output with a leading batch dimension of 1.

        Returns:
            A 3-D tensor. When the token count is a multiple of
            ``_TOKENS_PER_CLIP`` the tokens are split into that many clips;
            otherwise all tokens are treated as a single clip.
        """
        tokens = features.squeeze(0)
        per_clip = cls._TOKENS_PER_CLIP
        if tokens.shape[0] % per_clip == 0:
            num_clips = tokens.shape[0] // per_clip
            # Take the feature width from the tensor instead of hard-coding it.
            return tokens.reshape(num_clips, per_clip, tokens.shape[-1])
        return tokens.unsqueeze(0)

    def _get_video_embeddings(self, video_path):
        """Decode a video and run it through the vision encoder.

        Args:
            video_path: Path to the video file.

        Returns:
            Vision features shaped ``[num_clips, tokens_per_clip, vision_dim]``.
        """
        with torch.inference_mode():
            video = get_video(video_path, self.num_frames)
            video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
            # A batch dimension is added in front. NOTE(review): the original
            # comment claimed [1, 16, 3, 256, 256]; confirm the channel/time
            # order produced by ClipToTensor.
            x = self.video_transform(video).to(self.device).unsqueeze(0)
            features = self.vision_model(x)
            return self._split_into_clips(features)

    def _project_vision_features(self, vision_features):
        """Project vision features into the text embedding space.

        Args:
            vision_features: Tensor ``[num_clips, tokens_per_clip, vision_dim]``.

        Returns:
            Tensor ``[num_clips * tokens_per_clip, text_dim]``.
        """
        feature_dim = vision_features.shape[-1]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flattened = vision_features.reshape(-1, feature_dim)
        return self.vision_projection(flattened)

    def _encode_video(self, video_path):
        """Return projected video embeddings ``[1, total_tokens, text_dim]``."""
        with torch.no_grad():
            video_features = self._get_video_embeddings(video_path)
            vision_embeds = self._project_vision_features(video_features)
            return vision_embeds.unsqueeze(0)

    def _answer_question(self, vision_embeds, question, max_tokens=128, temperature=0.7,
                         top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3):
        """Generate an answer for ``question`` given precomputed video embeddings."""
        with torch.no_grad():
            # The video is injected as embeddings, so a literal <video>
            # placeholder in the text is dropped.
            question_clean = question.replace("<video>", "").strip()
            question_tokens = self.tokenizer(
                question_clean,
                return_tensors="pt",
                add_special_tokens=True,
            ).to(self.device)
            text_embeds = self.text_model.get_input_embeddings()(question_tokens.input_ids)

            # Sequence layout: [video tokens][question tokens].
            combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
            vision_attention = torch.ones(
                1, vision_embeds.shape[1],
                dtype=question_tokens.attention_mask.dtype,
                device=self.device,
            )
            combined_attention_mask = torch.cat(
                [vision_attention, question_tokens.attention_mask], dim=1
            )

            gen_kwargs = {
                "inputs_embeds": combined_embeds,
                "attention_mask": combined_attention_mask,
                "max_new_tokens": max_tokens,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "repetition_penalty": repetition_penalty,
                "no_repeat_ngram_size": no_repeat_ngram_size,
                "use_cache": True,
                "return_dict_in_generate": False,
            }
            # A non-positive (or missing) temperature means greedy decoding;
            # sampling parameters are only forwarded when sampling is on.
            do_sample = temperature is not None and temperature > 0
            gen_kwargs["do_sample"] = do_sample
            if do_sample:
                gen_kwargs["temperature"] = temperature
                gen_kwargs["top_p"] = top_p

            generated_ids = self.text_model.generate(**gen_kwargs)

            # With only inputs_embeds as the prompt there are no prompt token
            # ids, so generate() returns just the newly generated tokens.
            # Slicing by the prompt length would discard real output whenever
            # max_tokens exceeds that length.
            generated_text = self.tokenizer.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )[0]
            return generated_text.strip()

    def ask(self, video_path, question, max_tokens=128, temperature=0.7, top_p=0.9,
            repetition_penalty=1.2, no_repeat_ngram_size=3):
        """Ask a question about a video.

        Args:
            video_path: Path to video file.
            question: Question about the video.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature; ``0`` or ``None`` selects
                greedy decoding.
            top_p: Nucleus sampling parameter (used only when sampling).
            repetition_penalty: Penalty for repetition.
            no_repeat_ngram_size: N-gram size for repetition blocking.

        Returns:
            Generated answer as a stripped string.
        """
        vision_embeds = self._encode_video(video_path)
        return self._answer_question(
            vision_embeds,
            question,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )

    def batch_ask(self, video_path, questions, **kwargs):
        """Ask multiple questions about the same video.

        The video is decoded and encoded once and the embeddings are reused
        for every question.

        Args:
            video_path: Path to video file.
            questions: Iterable of questions.
            **kwargs: Generation parameters accepted by :meth:`ask`.

        Returns:
            List of ``{"question": str, "answer": str}`` dicts, in input order.
        """
        questions = list(questions)
        if not questions:
            # Nothing to answer: skip the expensive video encoding entirely.
            return []
        vision_embeds = self._encode_video(video_path)
        return [
            {
                "question": question,
                "answer": self._answer_question(vision_embeds, question, **kwargs),
            }
            for question in questions
        ]