Video-Text-to-Text
Transformers
Safetensors
English
soccer_qa_4b
soccer
video-qa
question-answering
vision-language
multimodal
sports-analysis
Instructions to use sportsvision/soccer-qa-4b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use sportsvision/soccer-qa-4b with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("sportsvision/soccer-qa-4b", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import bisect | |
| import csv | |
| import io | |
| import time | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import _utils | |
| from torch.utils.data.dataloader import ExceptionWrapper, _DatasetKind, _MultiProcessingDataLoaderIter | |
| from src.utils.monitoring import ResourceMonitoringThread | |
class ConcatIndices:
    """Map a global index over concatenated/mixed datasets to a per-dataset index.

    Built from the sizes of the individual datasets. Indexing with a global
    position returns the pair ``(dataset_idx, sample_idx)``: which dataset the
    position falls into and the offset of the sample inside that dataset.
    """

    # Running totals of the dataset sizes, as produced by np.cumsum.
    cumulative_sizes: np.ndarray

    def __init__(self, sizes):
        """Store the cumulative sum of ``sizes`` (one entry per dataset)."""
        self.cumulative_sizes = np.cumsum(sizes)

    def __len__(self):
        """Return the total number of samples across all datasets (0 if there are none)."""
        # An empty `sizes` yields an empty array; indexing [-1] would raise IndexError.
        if self.cumulative_sizes.size == 0:
            return 0
        # Convert the NumPy scalar to a built-in int.
        return int(self.cumulative_sizes[-1])

    def __getitem__(self, idx):
        """Return ``(dataset_idx, sample_idx)`` for the global index ``idx``.

        Raises:
            ValueError: if ``idx`` is negative or not smaller than ``len(self)``.
        """
        total = len(self)
        if idx < 0 or idx >= total:
            raise ValueError(f"index must be between 0 and the total size ({total})")
        # bisect_right finds the first dataset whose cumulative size exceeds idx,
        # so an idx equal to a boundary belongs to the following dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        # Subtract the number of samples in all preceding datasets.
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]
class CSVLogger(object):
    """Buffered CSV log: rows accumulate in memory and reach the file only on flush()."""

    def __init__(self, fname, header):
        """Remember the target file name and queue the header row in the buffer."""
        self.fname = fname
        self.initialized = False
        self._reset_buffer()
        self.writer.writerow(header)

    def _reset_buffer(self) -> None:
        """Start a fresh in-memory buffer with a CSV writer attached to it."""
        self.buffer = io.StringIO()
        self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)

    def writerow(self, row) -> None:
        """Append one row to the in-memory buffer (nothing is written to disk)."""
        self.writer.writerow(row)

    def flush(self) -> None:
        """Write the buffered rows to the file and empty the buffer."""
        pending = self.buffer.getvalue()
        # The first flush truncates any pre-existing file; later flushes append.
        if self.initialized:
            mode = "a+"
        else:
            mode = "w"
        with open(self.fname, mode, newline="") as f:
            f.write(pending)
        self._reset_buffer()
        self.initialized = True
class MonitoredDataset(torch.utils.data.Dataset):
    """Implement resource monitoring on a per-worker basis.

    The sampling occurs every monitor_interval seconds and writes the log
    every log_interval seconds to a file specified by log_filename, which
    maps a worker id to a file using the '%w' placeholder.

    Monitoring starts lazily on the first item access, not at construction.

    Warning: Do not call this dataset before it is consumed in the DataLoader.
    """

    def __init__(
        self, dataset: torch.utils.data.Dataset, log_filename: str, log_interval: float, monitor_interval: float
    ):
        """Wrap ``dataset``; no monitoring thread is created here.

        Args:
            dataset: the dataset whose item access is forwarded to.
            log_filename: log path; every '%w' is replaced by the worker id.
            log_interval: minimum number of seconds between two flushes of the CSV log.
            monitor_interval: interval passed to the monitoring thread.
        """
        self.dataset = dataset
        self.log_filename = str(log_filename)
        self.log_interval = log_interval
        self.monitor_interval = monitor_interval
        self._csv_log = None
        self._monitoring_thread = None
        self._last_log_time = None

        # Expose batched fetching only when the wrapped dataset offers it.
        # The attribute must be a *bound* method: a plain function stored on an
        # instance is not bound on lookup, so its `self` parameter would swallow
        # the index argument and the call would fail with a TypeError.
        if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
            self.__getitems__ = self._monitored_getitems

    def __del__(self):
        self.stop_resource_monitoring()

    def _monitored_getitems(self, index):
        """Batched counterpart of __getitem__, forwarding to the wrapped dataset's __getitems__."""
        self.maybe_start_resource_monitoring()
        return self.dataset.__getitems__(index)

    def __getitem__(self, index):
        """Start monitoring if needed, then return the wrapped dataset's item."""
        self.maybe_start_resource_monitoring()
        return self.dataset.__getitem__(index)

    def __len__(self):
        return len(self.dataset)

    def _elapsed_log_time(self):
        """Return seconds since the last flush, or infinity if nothing was flushed yet."""
        if self._last_log_time is None:
            return float("inf")
        else:
            return time.perf_counter() - self._last_log_time

    def _update_log_time(self):
        self._last_log_time = time.perf_counter()

    def maybe_start_resource_monitoring(self):
        """Create and start the monitoring thread on first use; later calls do nothing."""
        if self._monitoring_thread is None:

            def callback_fn(resource_sample):
                # torch's get_worker_info() returns None outside a DataLoader worker,
                # in which case the attribute access below fails (see the class warning).
                worker_info = torch.utils.data.get_worker_info()
                worker_id = worker_info.id

                # The logger is created on the first sample, once the worker id
                # and the sample's field names are available.
                if self._csv_log is None:
                    header = [f.name for f in resource_sample.fields()]
                    log_filename = self.log_filename.replace("%w", str(worker_id))
                    self._csv_log = CSVLogger(log_filename, header)
                row_values = resource_sample.as_tuple()
                self._csv_log.writerow(row_values)

                # Rows are buffered; write to disk at most once per log_interval.
                if self._elapsed_log_time() > self.log_interval:
                    self._csv_log.flush()
                    self._update_log_time()

            # NOTE(review): the meaning of the first positional argument (None) is
            # defined by ResourceMonitoringThread, which is not visible here.
            self._monitoring_thread = ResourceMonitoringThread(
                None, self.monitor_interval, stats_callback_fn=callback_fn
            )
            self._monitoring_thread.start()

    def stop_resource_monitoring(self):
        """Stop the monitoring thread if one was started."""
        if self._monitoring_thread:
            self._monitoring_thread.stop()
class NondeterministicDataLoader(torch.utils.data.DataLoader):
    """torch DataLoader whose multi-worker iterator may return batches out of order."""

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to the parent constructor."""
        super().__init__(*args, **kwargs)

    def _get_iterator(self):
        # Without worker processes the stock iterator is used.
        if not self.num_workers:
            return super()._get_iterator()
        self.check_worker_number_rationality()
        return _SloppyMultiProcessingDataLoaderIter(self)
class _SloppyMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
    """Multi-process DataLoader iterator that may hand back batches out of order.

    Only `_next_data` is overridden: a batch received from a worker is returned
    right away even when its index differs from `self._rcvd_idx`.

    NOTE(review): this relies on private attributes of torch's
    `_MultiProcessingDataLoaderIter` (`_task_info`, `_rcvd_idx`, `_send_idx`,
    `_workers_status`, ...); confirm they match the installed torch version.
    """

    def __init__(self, *args, **kwargs):
        """Pass through constructor."""
        super().__init__(*args, **kwargs)

    def _next_data(self):
        """Adds out of order returns.

        Returns the next available batch. Raises StopIteration when no index
        between `self._rcvd_idx` and `self._send_idx` is left to wait for.
        """
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                if info is None:
                    # Found a reordered tombstone: this index was already returned
                    # out of order (see the end of this method), so drop the entry,
                    # move on, and try to queue another index.
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
                    self._try_put_index()
                else:
                    worker_id = info[0]
                    # has data or is still active
                    # (presumably a 2-element entry is (worker_id, data) -- confirm against torch)
                    if len(info) == 2 or self._workers_status[worker_id]:
                        break
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # Tombstone to receive later: mark the index as already delivered so the
                # scan loop above skips it once `self._rcvd_idx` reaches it.
                self._task_info[idx] = None
                # This path bypasses `self._process_data`, so a worker exception
                # is re-raised here directly.
                if isinstance(data, ExceptionWrapper):
                    data.reraise()
                return data
            else:
                # In-order batch: drop its bookkeeping entry and use the regular path.
                del self._task_info[idx]
                return self._process_data(data)
def get_worker_info():
    """Return ``(num_workers, worker_id)`` for the current process.

    When torch reports no worker info, the pair ``(1, 0)`` is returned.
    """
    info = torch.utils.data.get_worker_info()
    if info is not None:
        return info.num_workers, info.id
    return 1, 0