Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import colorsys | |
| import gc | |
| from copy import deepcopy | |
| import base64 | |
| import math | |
| import statistics | |
| from pathlib import Path | |
| import plotly.graph_objects as go | |
# Base64 text dump of the demo clip, and the MP4 it is decoded into.
BASE64_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.b64")
EXAMPLE_VIDEO_PATH = Path("Kickit-Video-2025-07-09-13-47-18-389.mp4")


def ensure_example_video() -> str:
    """Make sure the Kickit example MP4 is present on disk.

    The MP4 is produced by base64-decoding the text in ``BASE64_VIDEO_PATH``;
    when the MP4 already exists no decoding happens.

    Returns:
        The path of the MP4 file, as a string.

    Raises:
        FileNotFoundError: if neither the MP4 nor the base64 asset exists.
    """
    if not EXAMPLE_VIDEO_PATH.exists():
        if not BASE64_VIDEO_PATH.exists():
            raise FileNotFoundError("Base64 video asset not found.")
        encoded_text = BASE64_VIDEO_PATH.read_text()
        EXAMPLE_VIDEO_PATH.write_bytes(base64.b64decode(encoded_text))
    return str(EXAMPLE_VIDEO_PATH)
| from types import SimpleNamespace | |
| from typing import Optional, Any | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from gradio.themes import Soft | |
| from PIL import Image, ImageDraw | |
| from transformers import AutoModel, Sam2VideoProcessor | |
| from ultralytics import YOLO | |
| from huggingface_hub import hf_hub_download | |
# --- YOLO weights: where they come from and how loaded models are cached ---
YOLO_REPO_ID = "atalaydenknalbant/Yolov13"
YOLO_DEFAULT_MODEL = "yolov13n.pt"
# Loaded models keyed by weight filename (filled by get_yolo_model).
YOLO_MODEL_CACHE: dict[str, YOLO] = {}

# --- Detection settings passed to YOLO predict calls ---
YOLO_CONF_THRESHOLD = 0.0
YOLO_IOU_THRESHOLD = 0.02

# --- Class names looked up in the model's label table ---
YOLO_TARGET_NAME = "sports ball"
PLAYER_TARGET_NAME = "person"

# --- Object ids used for the tracked ball and player ---
BALL_OBJECT_ID = 1
PLAYER_OBJECT_ID = 2
def get_yolo_model(model_filename: str = YOLO_DEFAULT_MODEL) -> YOLO:
    """Return the YOLO model for ``model_filename``, loading it on first use.

    A cache hit in ``YOLO_MODEL_CACHE`` is returned directly. Otherwise the
    weight file is fetched from ``YOLO_REPO_ID`` via ``hf_hub_download``,
    wrapped in ``YOLO`` and stored in the cache before being returned.
    """
    try:
        return YOLO_MODEL_CACHE[model_filename]
    except KeyError:
        pass

    weights_path = hf_hub_download(repo_id=YOLO_REPO_ID, filename=model_filename)
    loaded = YOLO(weights_path)
    YOLO_MODEL_CACHE[model_filename] = loaded
    return loaded
def detect_ball_center(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """Detect the sports ball in one frame.

    Runs the cached YOLO model restricted to the class named
    ``YOLO_TARGET_NAME`` with ``max_det=1`` and reports the first box.

    Returns:
        ``(x_center, y_center, width, height, confidence)`` with the four
        geometry values rounded to ints, or ``None`` when the model has no
        matching class or produced no box.
    """
    detector = get_yolo_model(model_filename)
    ball_class_ids = [
        class_idx
        for class_idx, class_name in detector.names.items()
        if class_name.lower() == YOLO_TARGET_NAME
    ]
    if not ball_class_ids:
        return None

    predictions = detector.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=1,
        classes=ball_class_ids,
        imgsz=640,
        device="cpu",
        verbose=False,
    )
    if not predictions:
        return None

    detections = predictions[0].boxes
    if detections is None or len(detections) == 0:
        return None

    best = detections[0]
    # xywh layout: x_center, y_center, width, height
    geometry = best.xywh[0].cpu().tolist()
    if best.conf is not None:
        confidence = float(best.conf[0].cpu().item())
    else:
        confidence = 0.0
    center_x, center_y, box_w, box_h = geometry
    return (
        int(round(center_x)),
        int(round(center_y)),
        int(round(box_w)),
        int(round(box_h)),
        confidence,
    )
def detect_person_box(
    frame: Image.Image,
    model_filename: str = YOLO_DEFAULT_MODEL,
    conf_threshold: float = YOLO_CONF_THRESHOLD,
    iou_threshold: float = YOLO_IOU_THRESHOLD,
) -> Optional[tuple[int, int, int, int, float]]:
    """Detect a person in one frame and return a frame-clamped box.

    Runs the cached YOLO model restricted to the class named
    ``PLAYER_TARGET_NAME`` (up to 5 detections) and uses the first box of
    the result.

    Returns:
        ``(x_min, y_min, x_max, y_max, confidence)`` with integer corners
        clamped to ``[0, size - 1]`` on each axis, or ``None`` when there is
        no matching class, no detection, or the clamped box is empty.
    """
    detector = get_yolo_model(model_filename)
    person_class_ids = [
        class_idx
        for class_idx, class_name in detector.names.items()
        if class_name.lower() == PLAYER_TARGET_NAME
    ]
    if not person_class_ids:
        return None

    predictions = detector.predict(
        source=frame,
        conf=conf_threshold,
        iou=iou_threshold,
        max_det=5,
        classes=person_class_ids,
        imgsz=640,
        device="cpu",
        verbose=False,
    )
    if not predictions:
        return None

    detections = predictions[0].boxes
    if detections is None or len(detections) == 0:
        return None

    first_box = detections[0]
    corners = first_box.xyxy[0].cpu().tolist()
    if first_box.conf is not None:
        confidence = float(first_box.conf[0].cpu().item())
    else:
        confidence = 0.0
    left, top, right, bottom = corners
    frame_width, frame_height = frame.size

    def _clamp(coord: float, upper: int) -> int:
        # Round to the nearest pixel, then keep it inside [0, upper].
        return max(0, min(upper, int(round(coord))))

    x_min = _clamp(left, frame_width - 1)
    y_min = _clamp(top, frame_height - 1)
    x_max = _clamp(right, frame_width - 1)
    y_max = _clamp(bottom, frame_height - 1)
    if x_max <= x_min or y_max <= y_min:
        return None
    return x_min, y_min, x_max, y_max, confidence
def _compute_sam_window_from_kick(state: AppState, kick_frame: int | None) -> tuple[int, int]:
    """Pick a roughly 4-second frame window around the kick and store it on ``state``.

    The window length is ``round(fps * 4)`` frames (at least 1), using 25 fps
    when ``state.video_fps`` is missing or not positive. With a kick frame the
    window starts half a window before it (not below 0); without one it starts
    at frame 0. The end is capped at the number of frames.

    Returns:
        ``(start_idx, end_idx)``, also written to ``state.sam_window``.
        For an empty video ``(0, 0)`` is returned and ``state`` is untouched.
    """
    frame_count = state.num_frames
    if frame_count == 0:
        return 0, 0

    fps = state.video_fps
    if not (fps and fps > 0):
        fps = 25.0
    window_len = max(1, int(round(fps * 4.0)))

    if kick_frame is None:
        window_start = 0
    else:
        window_start = max(0, int(kick_frame) - window_len // 2)

    window_end = min(frame_count, window_start + window_len)
    if window_end <= window_start:
        window_end = min(frame_count, window_start + 1)

    state.sam_window = (window_start, window_end)
    return window_start, window_end
def _perform_yolo_ball_tracking(state: AppState, progress: gr.Progress | None = None) -> None:
    """Run YOLO ball detection on every loaded frame and estimate the kick frame.

    Stages, all results being written onto ``state``:

    1. Per-frame detection of the ``YOLO_TARGET_NAME`` class (one box per
       frame); centers, clamped boxes, confidences and box areas are recorded.
    2. With fewer than 3 detections the derived fields are cleared, a failure
       status is set and the function returns.
    3. Centers are exponentially smoothed and a speed is derived between
       consecutive detections.
    4. A speed threshold is built from the first few speed samples.
    5. The first sample after the baseline window that passes the speed,
       sustain, area and travel-distance checks becomes the kick frame.
    6. Kalman outputs from ``_run_kalman_filter`` are stored and, when a kick
       was found, the SAM window is computed around it.

    Args:
        state: application state holding ``video_frames`` / ``video_fps``.
        progress: optional callable invoked with the fraction of frames done.

    Raises:
        gr.Error: if no video is loaded or the model lacks the ball class.
    """
    if state is None or state.num_frames == 0:
        raise gr.Error("Load a video first, then track with YOLO.")
    model = get_yolo_model()
    class_ids = [
        idx for idx, name in model.names.items() if name.lower() == YOLO_TARGET_NAME
    ]
    if not class_ids:
        raise gr.Error("YOLO model does not contain the sports ball class.")
    frames = state.video_frames
    total = len(frames)
    # Per-frame detection results, keyed by frame index (only frames with a box).
    centers: dict[int, tuple[float, float]] = {}
    boxes: dict[int, tuple[int, int, int, int]] = {}
    confs: dict[int, float] = {}
    areas: dict[int, float] = {}
    first_detection_frame: int | None = None
    for idx, frame in enumerate(frames):
        if progress is not None:
            progress((idx + 1) / total)
        results = model.predict(
            source=frame,
            conf=YOLO_CONF_THRESHOLD,
            iou=YOLO_IOU_THRESHOLD,
            max_det=1,
            classes=class_ids,
            imgsz=640,
            device="cpu",
            verbose=False,
        )
        if not results:
            continue
        boxes_result = results[0].boxes
        if boxes_result is None or len(boxes_result) == 0:
            continue
        box = boxes_result[0]
        xywh = box.xywh[0].cpu().tolist()
        conf = float(box.conf[0].cpu().item()) if box.conf is not None else 0.0
        x_center, y_center, width, height = xywh
        x_center = float(x_center)
        y_center = float(y_center)
        # Force at least a 1x1 box so the corner math below cannot collapse.
        width = max(1.0, float(width))
        height = max(1.0, float(height))
        frame_width, frame_height = frame.size
        # Convert center/size to integer corners clamped to the frame.
        x_min = int(round(max(0.0, x_center - width / 2.0)))
        y_min = int(round(max(0.0, y_center - height / 2.0)))
        x_max = int(round(min(frame_width - 1.0, x_center + width / 2.0)))
        y_max = int(round(min(frame_height - 1.0, y_center + height / 2.0)))
        if x_max <= x_min or y_max <= y_min:
            # Degenerate after clamping/rounding: treat as "no detection".
            continue
        centers[idx] = (x_center, y_center)
        boxes[idx] = (x_min, y_min, x_max, y_max)
        confs[idx] = conf
        areas[idx] = float((x_max - x_min) * (y_max - y_min))
        if first_detection_frame is None:
            first_detection_frame = idx
    state.yolo_ball_centers = centers
    state.yolo_ball_boxes = boxes
    state.yolo_ball_conf = confs
    state.yolo_mask_area_proxy = [areas.get(k, 0.0) for k in sorted(centers.keys())]
    state.yolo_initial_frame = first_detection_frame
    if len(centers) < 3:
        # Too few samples for a speed profile: clear derived fields and bail out.
        # NOTE(review): yolo_kick_frames / yolo_kick_speeds / yolo_kick_distance
        # are not cleared on this path.
        state.yolo_smoothed_centers = {}
        state.yolo_speeds = {}
        state.yolo_distance_from_start = {}
        state.yolo_threshold = None
        state.yolo_baseline_speed = None
        state.yolo_speed_std = None
        state.yolo_kick_frame = None
        state.yolo_status = "❌ YOLO13: insufficient detections to estimate kick. Please retry or annotate manually."
        state.sam_window = None
        return
    items = sorted(centers.items())
    # Seconds per frame when fps is known; otherwise 1.0, so speeds are per frame.
    dt = 1.0 / state.video_fps if state.video_fps and state.video_fps > 1e-3 else 1.0
    # Weight of the newest sample in the exponential smoothing below.
    alpha = 0.35
    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    prev_frame = None
    prev_smooth = None
    for frame_idx, (cx, cy) in items:
        if prev_smooth is None:
            smooth_x, smooth_y = float(cx), float(cy)
        else:
            smooth_x = prev_smooth[0] + alpha * (cx - prev_smooth[0])
            smooth_y = prev_smooth[1] + alpha * (cy - prev_smooth[1])
        smoothed[frame_idx] = (smooth_x, smooth_y)
        if prev_smooth is None or prev_frame is None:
            # First detection has no predecessor, so its speed is defined as 0.
            speeds[frame_idx] = 0.0
        else:
            # Detections may skip frames; scale elapsed time by the frame gap.
            frame_delta = max(1, frame_idx - prev_frame)
            time_delta = frame_delta * dt
            dist = math.hypot(smooth_x - prev_smooth[0], smooth_y - prev_smooth[1])
            speed = dist / time_delta if time_delta > 0 else dist
            speeds[frame_idx] = speed
        prev_smooth = (smooth_x, smooth_y)
        prev_frame = frame_idx
    frames_ordered = [frame_idx for frame_idx, _ in items]
    speed_series = [speeds.get(f, 0.0) for f in frames_ordered]
    # Baseline = first third of the detections, between 1 and 10 samples.
    baseline_window = min(10, len(frames_ordered) // 3 or 1)
    baseline_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(baseline_speeds) if baseline_speeds else 0.0
    speed_std = statistics.pstdev(baseline_speeds) if len(baseline_speeds) > 1 else 0.0
    # Threshold = max(median + 4*std, 3*median, 15.0).
    base_threshold = baseline_speed + 4.0 * speed_std
    if base_threshold < baseline_speed * 3.0:
        base_threshold = baseline_speed * 3.0
    speed_threshold = max(base_threshold, 15.0)
    # Distance of every smoothed center from the first smoothed center.
    distance_dict: dict[int, float] = {}
    if smoothed:
        first_frame = frames_ordered[0]
        origin = smoothed[first_frame]
        for frame_idx, (sx, sy) in smoothed.items():
            distance_dict[frame_idx] = math.hypot(sx - origin[0], sy - origin[1])
    areas_dict = {idx: areas.get(idx, 0.0) for idx in frames_ordered}
    initial_area = areas_dict.get(frames_ordered[0], 1.0) or 1.0
    # Radius of a circle with the same area as the first detected box.
    radius_estimate = math.sqrt(initial_area / math.pi)
    # Minimum travel from the start required to accept a kick, kept in [8, 40].
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))
    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75
    kalman_pos, kalman_speed, _ = _run_kalman_filter(items, dt)
    # NOTE(review): kalman_speed_series is computed but never read in this function.
    kalman_speed_series = [kalman_speed.get(f, 0.0) for f in frames_ordered]
    kick_frame: int | None = None
    # Scan candidates after the baseline window; the first one passing all checks wins.
    for idx, frame in enumerate(frames_ordered[baseline_window:], start=baseline_window):
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue
        # Sustain check: the next `sustain_frames` samples must stay at or above
        # 70% of the threshold. Running off the end of the series does not fail it.
        sustain_ok = True
        for j in range(1, sustain_frames + 1):
            if idx + j >= len(frames_ordered):
                break
            if speed_series[idx + j] < speed_threshold * 0.7:
                sustain_ok = False
                break
        if not sustain_ok:
            continue
        # Area check: fails when the current box area is more than
        # `area_drop_ratio` of the median of the preceding `area_window` areas.
        area_pass = True
        current_area = areas_dict.get(frame)
        if current_area:
            prev_areas = [
                areas_dict.get(f)
                for f in frames_ordered[max(0, idx - area_window):idx]
                if areas_dict.get(f) is not None
            ]
            if prev_areas:
                median_prev = statistics.median(prev_areas)
                if median_prev > 0:
                    ratio = current_area / median_prev
                    if ratio > area_drop_ratio:
                        area_pass = False
        # A failed area check is overridden when speed reaches 1.2x the threshold.
        if not area_pass and speed < speed_threshold * 1.2:
            continue
        # Travel check: within the next `holdout_frames` samples (candidate
        # included) the ball must get far enough from its starting position.
        future_slice = frames_ordered[idx: min(len(frames_ordered), idx + holdout_frames)]
        max_future_dist = 0.0
        for future_frame in future_slice:
            dist = distance_dict.get(future_frame, 0.0)
            if dist > max_future_dist:
                max_future_dist = dist
        if max_future_dist < adaptive_return_distance:
            continue
        kick_frame = frame
        break
    # Publish derived series and the decision onto the state.
    state.yolo_smoothed_centers = smoothed
    state.yolo_speeds = speeds
    state.yolo_distance_from_start = distance_dict
    state.yolo_threshold = speed_threshold
    state.yolo_baseline_speed = baseline_speed
    state.yolo_speed_std = speed_std
    state.yolo_kick_frames = frames_ordered
    state.yolo_kick_speeds = speed_series
    state.yolo_kick_distance = [distance_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_mask_area_proxy = [areas_dict.get(f, 0.0) for f in frames_ordered]
    state.yolo_kick_frame = kick_frame
    coverage = len(centers) / total if total else 0.0
    if kick_frame is not None:
        state.yolo_status = f"✅ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%})."
    else:
        state.yolo_status = (
            f"⚠️ YOLO13 tracked {len(centers)}/{total} frames ({coverage:.0%}) but did not find a definitive kick."
        )
    state.kalman_centers[BALL_OBJECT_ID] = kalman_pos
    state.kalman_speeds[BALL_OBJECT_ID] = kalman_speed
    if kick_frame is not None:
        state.kick_frame = kick_frame
        _compute_sam_window_from_kick(state, kick_frame)
    else:
        state.sam_window = None
def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
    """Return a repeatable pastel RGB colour for an object id.

    The hue is ``obj_id`` times the golden-ratio conjugate, wrapped into
    [0, 1); saturation is fixed at 0.45 and value at 1.0. Channels are
    scaled to 0-255 and truncated to ints.
    """
    hue_step = 0.61803398875
    unit_rgb = colorsys.hsv_to_rgb((obj_id * hue_step) % 1.0, 0.45, 1.0)
    red, green, blue = (int(channel * 255) for channel in unit_rgb)
    return red, green, blue
def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
    """Decode every frame of a video into RGB PIL images using OpenCV.

    Args:
        video_path_or_url: anything ``cv2.VideoCapture`` accepts.

    Returns:
        ``(frames, info)`` where ``frames`` is the list of decoded images and
        ``info`` is ``{"num_frames": int, "fps": float | None}``. ``fps`` is
        ``None`` when OpenCV reports a missing or non-positive frame rate.
        A source that cannot be opened yields an empty frame list.
    """
    cap = cv2.VideoCapture(video_path_or_url)
    frames: list[Image.Image] = []
    print("loading video frames")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        # Gather fps if available
        fps_val = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Always release the capture handle, even if a frame conversion fails.
        cap.release()
    print("loaded video frames")
    info = {
        "num_frames": len(frames),
        "fps": float(fps_val) if fps_val and fps_val > 0 else None,
    }
    return frames, info
def overlay_masks_on_frame(
    frame: Image.Image,
    masks_per_object: dict[int, np.ndarray],
    color_by_obj: dict[int, tuple[int, int, int]],
    alpha: float = 0.5,
) -> Image.Image:
    """Alpha-blend each object's soft mask, in its colour, over the frame.

    masks_per_object: obj_id -> mask; values are clipped to [0, 1] and a
        3-D mask is squeezed. ``None`` entries are skipped.
    color_by_obj: obj_id -> (R, G, B) in 0-255; red is used when missing.
    alpha: maximum blend weight where a mask equals 1.

    Returns a new 8-bit image; masks are applied in dict iteration order.
    """
    canvas = np.array(frame).astype(np.float32) / 255.0  # H, W, 3 in [0,1]
    # Unpacking requires the frame array to have at least two dimensions.
    _rows, _cols = canvas.shape[:2]
    blended = canvas.copy()

    for obj_id, soft_mask in masks_per_object.items():
        if soft_mask is None:
            continue
        if soft_mask.dtype != np.float32:
            soft_mask = soft_mask.astype(np.float32)
        if soft_mask.ndim == 3:
            soft_mask = soft_mask.squeeze()
        soft_mask = np.clip(soft_mask, 0.0, 1.0)

        rgb = color_by_obj.get(obj_id, (255, 0, 0))
        tint = np.array(rgb, dtype=np.float32) / 255.0
        # Per-pixel blend weight, broadcast over the colour channels.
        weight = alpha * soft_mask[..., None]
        blended = (1.0 - weight) * blended + weight * tint

    as_bytes = np.clip(blended * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
def get_device_and_dtype() -> tuple[str, torch.dtype]:
    """Return the fixed inference target: the CPU device with bfloat16 weights."""
    return "cpu", torch.bfloat16
class AppState:
    """Mutable container for everything the demo tracks about one session.

    Holds the loaded video frames, the SAM2 model/processor/session handles,
    user prompts (clicks and boxes), per-object analysis results (centers,
    speeds, Kalman outputs, kick/impact debug series), YOLO tracking caches
    and the compositing-effect settings. All fields are (re)initialised by
    :meth:`reset`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every field back to its initial default value."""
        # --- Video, model and inference session ---
        self.video_frames: list[Image.Image] = []
        self.inference_session = None
        self.model: Optional[AutoModel] = None
        self.processor: Optional[Sam2VideoProcessor] = None
        self.device: str = "cpu"
        self.dtype: torch.dtype = torch.bfloat16
        self.video_fps: float | None = None
        # --- Masks and prompts, keyed by frame index then object id ---
        self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
        self.color_by_obj: dict[int, tuple[int, int, int]] = {}
        self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
        self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
        # Cache of composited frames (original + masks + clicks)
        self.composited_frames: dict[int, Image.Image] = {}
        # UI state for click handler
        self.current_frame_idx: int = 0
        self.current_obj_id: int = 1
        self.current_label: str = "positive"
        self.current_clear_old: bool = True
        self.current_prompt_type: str = "Points"  # or "Boxes"
        self.pending_box_start: tuple[int, int] | None = None
        self.pending_box_start_frame_idx: int | None = None
        self.pending_box_start_obj_id: int | None = None
        self.is_switching_model: bool = False
        # --- Per-object motion analysis ---
        self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
        self.mask_areas: dict[int, dict[int, float]] = {}
        self.smoothed_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.ball_speeds: dict[int, dict[int, float]] = {}
        self.distance_from_start: dict[int, dict[int, float]] = {}
        self.direction_change: dict[int, dict[int, float]] = {}
        # --- Kick detection result and debug series ---
        self.kick_frame: int | None = None
        self.kick_debug_frames: list[int] = []
        self.kick_debug_speeds: list[float] = []
        self.kick_debug_threshold: float | None = None
        self.kick_debug_baseline: float | None = None
        self.kick_debug_speed_std: float | None = None
        self.kick_debug_area: list[float] = []
        self.kick_debug_kick_frame: int | None = None
        self.kick_debug_distance: list[float] = []
        self.kick_debug_kalman_speeds: list[float] = []
        # --- Kalman filter outputs, keyed by object id then frame index ---
        self.kalman_centers: dict[int, dict[int, tuple[float, float]]] = {}
        self.kalman_speeds: dict[int, dict[int, float]] = {}
        self.kalman_residuals: dict[int, dict[int, float]] = {}
        # --- Impact detection settings, result and debug series ---
        self.min_impact_speed_kmh: float = 20.0
        self.goal_distance_m: float = 18.0
        self.impact_frame: int | None = None
        self.impact_debug_frames: list[int] = []
        self.impact_debug_innovation: list[float] = []
        self.impact_debug_innovation_threshold: float | None = None
        self.impact_debug_direction: list[float] = []
        self.impact_debug_direction_threshold: float | None = None
        self.impact_debug_speed_kmh: list[float] = []
        self.impact_debug_speed_threshold_px: float | None = None
        self.impact_meters_per_px: float | None = None
        # Model selection
        self.model_repo_key: str = "tiny"
        self.model_repo_id: str | None = None
        self.session_repo_id: str | None = None
        # --- Player detection ---
        self.player_obj_id: int | None = None
        self.player_detection_frame: int | None = None
        self.player_detection_conf: float | None = None
        # YOLO tracking caches
        self.yolo_ball_centers: dict[int, tuple[float, float]] = {}
        self.yolo_ball_boxes: dict[int, tuple[int, int, int, int]] = {}
        self.yolo_ball_conf: dict[int, float] = {}
        self.yolo_smoothed_centers: dict[int, tuple[float, float]] = {}
        self.yolo_speeds: dict[int, float] = {}
        self.yolo_distance_from_start: dict[int, float] = {}
        self.yolo_threshold: float | None = None
        self.yolo_baseline_speed: float | None = None
        self.yolo_speed_std: float | None = None
        self.yolo_kick_frame: int | None = None
        self.yolo_status: str = ""
        self.yolo_kick_frames: list[int] = []
        self.yolo_kick_speeds: list[float] = []
        self.yolo_kick_distance: list[float] = []
        self.yolo_mask_area_proxy: list[float] = []
        self.yolo_initial_frame: int | None = None
        # SAM window (start_idx inclusive, end_idx exclusive)
        self.sam_window: tuple[int, int] | None = None
        # Cutout / compositing effects
        self.fx_soft_matte_enabled: bool = True
        self.fx_soft_matte_feather: float = 4.0
        self.fx_soft_matte_erode: float = 0.5
        self.fx_blur_enabled: bool = True
        self.fx_blur_sigma: float = 0.0
        self.fx_bg_darkening: float = 1.0
        self.fx_light_wrap_enabled: bool = False
        self.fx_light_wrap_strength: float = 0.6
        self.fx_light_wrap_width: float = 15.0
        self.fx_glow_enabled: bool = False
        self.fx_glow_strength: float = 0.4
        self.fx_glow_radius: float = 10.0
        self.fx_ghost_trail_enabled: bool = True
        self.show_click_marks: bool = False
        self.fx_ball_ring_enabled: bool = False

    def __repr__(self):
        # Model / processor / session are shown as presence flags only.
        return (
            f"AppState(video_frames={self.video_frames}, "
            f"inference_session={self.inference_session is not None}, "
            f"model={self.model is not None}, "
            f"processor={self.processor is not None}, "
            f"device={self.device}, "
            f"dtype={self.dtype}, "
            f"video_fps={self.video_fps}, "
            f"masks_by_frame={self.masks_by_frame}, "
            f"color_by_obj={self.color_by_obj}, "
            f"clicks_by_frame_obj={self.clicks_by_frame_obj}, "
            f"boxes_by_frame_obj={self.boxes_by_frame_obj}, "
            f"composited_frames={self.composited_frames}, "
            f"current_frame_idx={self.current_frame_idx}, "
            f"current_obj_id={self.current_obj_id}, "
            f"current_label={self.current_label}, "
            f"current_clear_old={self.current_clear_old}, "
            f"current_prompt_type={self.current_prompt_type}, "
            f"pending_box_start={self.pending_box_start}, "
            f"pending_box_start_frame_idx={self.pending_box_start_frame_idx}, "
            f"pending_box_start_obj_id={self.pending_box_start_obj_id}, "
            f"is_switching_model={self.is_switching_model}, "
            f"model_repo_key={self.model_repo_key}, "
            f"model_repo_id={self.model_repo_id}, "
            f"session_repo_id={self.session_repo_id})"
        )

    @property
    def num_frames(self) -> int:
        """Number of frames currently loaded.

        Exposed as a property because the code in this module reads
        ``state.num_frames`` as a value (comparisons, ``min(...)``); as a
        plain method those reads would see a bound method, not an int.
        """
        return len(self.video_frames)
def _model_repo_from_key(key: str) -> str:
    """Translate a short model-size key into its SAM 2.1 Hub repo id.

    Keys other than "tiny", "small", "base_plus" and "large" resolve to the
    "base_plus" repo.
    """
    repo_by_size = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base_plus": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }
    try:
        return repo_by_size[key]
    except KeyError:
        return repo_by_size["base_plus"]
def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
    """Make sure the model selected by ``GLOBAL_STATE.model_repo_key`` is loaded.

    If a model and processor are already present for the desired repo they are
    reused. Otherwise any current pair is dropped, the desired repo is loaded
    with ``from_pretrained``, moved to the device/dtype from
    ``get_device_and_dtype`` and stored on ``GLOBAL_STATE`` together with
    ``model_repo_id``.

    Returns:
        ``(model, processor, device, dtype)`` for the active model, on both
        the cached and the freshly-loaded path.
    """
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
        if GLOBAL_STATE.model_repo_id == desired_repo:
            return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
        # Different repo requested: drop the references to the current pair
        # before loading the replacement.
        GLOBAL_STATE.model = None
        GLOBAL_STATE.processor = None
    print(f"Loading model from {desired_repo}")
    device, dtype = get_device_and_dtype()
    model = AutoModel.from_pretrained(desired_repo)
    processor = Sam2VideoProcessor.from_pretrained(desired_repo)
    model.to(device, dtype=dtype)
    GLOBAL_STATE.model = model
    GLOBAL_STATE.processor = processor
    GLOBAL_STATE.device = device
    GLOBAL_STATE.dtype = dtype
    GLOBAL_STATE.model_repo_id = desired_repo
    # Honour the declared return type on the load path as well.
    return model, processor, device, dtype
def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
    """Ensure the model/processor match the selected repo and inference_session exists.
    If a video is already loaded, re-initialize the inference session when needed.

    "When needed" means there is no session yet, or the existing session was
    created for a different repo than the one currently selected
    (``session_repo_id`` differs from the desired repo).
    """
    # NOTE(review): indentation was lost in the source this block was recovered
    # from. The nesting below (everything under the `video_frames` check)
    # follows the docstring; confirm against the original file.
    load_model_if_needed(GLOBAL_STATE)
    desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
    if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
        if GLOBAL_STATE.video_frames:
            # Clear session-related UI caches when switching model
            GLOBAL_STATE.masks_by_frame.clear()
            GLOBAL_STATE.clicks_by_frame_obj.clear()
            GLOBAL_STATE.boxes_by_frame_obj.clear()
            GLOBAL_STATE.composited_frames.clear()
            # Drop the old session before creating its replacement.
            GLOBAL_STATE.inference_session = None
            GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
                inference_device=GLOBAL_STATE.device,
                video_storage_device="cpu",
                dtype=GLOBAL_STATE.dtype,
            )
            # Remember which repo this session belongs to.
            GLOBAL_STATE.session_repo_id = desired_repo
def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppState, int, int, Image.Image, str]:
    """Gradio handler: load video, init session, return state, slider bounds, and first frame.

    Clears the video-dependent analysis fields on ``GLOBAL_STATE`` (the loaded
    model is kept), decodes the video, trims it to at most 8 seconds when the
    frame rate is known, and opens a fresh inference session.

    Args:
        GLOBAL_STATE: the application state object to update in place.
        video: a file path, or a dict carrying the path under ``"name"``,
            ``"path"`` or ``"data"``.

    Returns:
        ``(state, 0, last_frame_index, first_frame, status_text)``.

    Raises:
        gr.Error: if no usable path is given or no frames could be decoded.
    """
    # Reset ONLY video-related fields, keep model loaded
    GLOBAL_STATE.video_frames = []
    GLOBAL_STATE.inference_session = None
    GLOBAL_STATE.masks_by_frame = {}
    GLOBAL_STATE.color_by_obj = {}
    GLOBAL_STATE.ball_centers = {}
    GLOBAL_STATE.mask_areas = {}
    GLOBAL_STATE.smoothed_centers = {}
    GLOBAL_STATE.ball_speeds = {}
    GLOBAL_STATE.distance_from_start = {}
    GLOBAL_STATE.direction_change = {}
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kalman_centers = {}
    GLOBAL_STATE.kalman_speeds = {}
    GLOBAL_STATE.kalman_residuals = {}
    GLOBAL_STATE.kick_debug_kalman_speeds = []
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None
    GLOBAL_STATE.yolo_ball_centers = {}
    GLOBAL_STATE.yolo_ball_boxes = {}
    GLOBAL_STATE.yolo_ball_conf = {}
    GLOBAL_STATE.yolo_smoothed_centers = {}
    GLOBAL_STATE.yolo_speeds = {}
    GLOBAL_STATE.yolo_distance_from_start = {}
    GLOBAL_STATE.yolo_threshold = None
    GLOBAL_STATE.yolo_baseline_speed = None
    GLOBAL_STATE.yolo_speed_std = None
    GLOBAL_STATE.yolo_kick_frame = None
    GLOBAL_STATE.yolo_status = ""
    GLOBAL_STATE.yolo_kick_frames = []
    GLOBAL_STATE.yolo_kick_speeds = []
    GLOBAL_STATE.yolo_kick_distance = []
    GLOBAL_STATE.yolo_mask_area_proxy = []
    GLOBAL_STATE.yolo_initial_frame = None
    GLOBAL_STATE.sam_window = None
    GLOBAL_STATE.player_obj_id = None
    GLOBAL_STATE.player_detection_frame = None
    GLOBAL_STATE.player_detection_conf = None
    load_model_if_needed(GLOBAL_STATE)
    # Gradio Video may provide a dict with 'name' or a direct file path
    video_path: Optional[str] = None
    if isinstance(video, dict):
        video_path = video.get("name") or video.get("path") or video.get("data")
    elif isinstance(video, str):
        video_path = video
    if not video_path:
        raise gr.Error("Invalid video input.")
    frames, info = try_load_video_frames(video_path)
    if len(frames) == 0:
        raise gr.Error("No frames could be loaded from the video.")
    # Enforce max duration of 8 seconds (trim if longer)
    MAX_SECONDS = 8.0
    trimmed_note = ""
    # The loader reports fps as None when the container has no usable frame
    # rate; in that case the duration is unknown, so no trimming is applied.
    fps_in = info.get("fps")
    if fps_in:
        # Keep at least one frame even for extremely low frame rates.
        max_frames_allowed = max(1, int(MAX_SECONDS * fps_in))
        if len(frames) > max_frames_allowed:
            frames = frames[:max_frames_allowed]
            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
            info["num_frames"] = len(frames)
    GLOBAL_STATE.video_frames = frames
    # Try to capture original FPS if provided by loader
    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
    # Initialize session
    inference_session = GLOBAL_STATE.processor.init_video_session(
        inference_device=GLOBAL_STATE.device,
        video_storage_device="cpu",
        dtype=GLOBAL_STATE.dtype,
    )
    GLOBAL_STATE.inference_session = inference_session
    first_frame = frames[0]
    max_idx = len(frames) - 1
    status = (
        f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
        f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
    )
    return GLOBAL_STATE, 0, max_idx, first_frame, status
def _speed_to_color(ratio: float) -> tuple[int, int, int]:
    """Map a ratio in [0, 1] onto a red -> orange -> yellow -> green ramp.

    The ratio is clipped to [0, 1], then linearly interpolated between the
    two neighbouring colour stops; channels are truncated to ints.
    """
    stops = [
        (255, 0, 0),  # red
        (255, 165, 0),  # orange
        (255, 255, 0),  # yellow
        (0, 255, 0),  # green
    ]
    last_stop = len(stops) - 1
    clipped = float(np.clip(ratio, 0.0, 1.0))
    position = clipped * last_stop
    lower = int(position)
    if lower >= last_stop:
        # Exactly at (or clipped to) the top of the ramp.
        return stops[-1]
    weight = position - lower
    start_rgb = np.array(stops[lower], dtype=float)
    end_rgb = np.array(stops[lower + 1], dtype=float)
    mixed = (1 - weight) * start_rgb + weight * end_rgb
    return tuple(int(channel) for channel in mixed)
def _angle_between(v1: tuple[float, float], v2: tuple[float, float]) -> float:
    """Return the unsigned angle between two 2-D vectors, in degrees (0-180).

    If either vector is shorter than 1e-6 the angle is reported as 0.0.
    """
    ax, ay = v1
    bx, by = v2
    len_a = math.hypot(ax, ay)
    len_b = math.hypot(bx, by)
    if len_a < 1e-6 or len_b < 1e-6:
        return 0.0
    cosine = (ax * bx + ay * by) / (len_a * len_b)
    # Guard acos against rounding that pushes the cosine just outside [-1, 1].
    cosine = max(-1.0, min(1.0, cosine))
    return math.degrees(math.acos(cosine))
# Bounds applied to the width of images sent to the UI.
DISPLAY_MIN_WIDTH = 640
DISPLAY_MAX_WIDTH = 1280

# Compositing-effect constants; colours are RGB floats in [0, 1].
FX_EPS = 1e-6
FX_GLOW_COLOR = np.array([1.0, 0.1, 0.6], dtype=np.float32)
GHOST_TRAIL_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
GHOST_TRAIL_ALPHA = 0.55

# Ball ring: float colour, line thickness, and the same colour as 0-255 ints.
BALL_RING_COLOR = np.array([1.0, 0.0, 1.0], dtype=np.float32)
BALL_RING_THICKNESS_PX = 2
BALL_RING_COLOR_RGB = tuple(
    int(max(0, min(255, round(component * 255.0))))
    for component in BALL_RING_COLOR.tolist()
)
def _maybe_upscale_for_display(image: Image.Image) -> Image.Image:
    """Resize an image so its width lies within the display bounds.

    Widths below ``DISPLAY_MIN_WIDTH`` are raised to it and widths above
    ``DISPLAY_MAX_WIDTH`` are lowered to it, keeping the aspect ratio
    (bilinear resampling). ``None``, images with a non-positive dimension,
    and images already in range are returned unchanged.
    """
    if image is None:
        return image
    src_width, src_height = image.size
    if src_width <= 0 or src_height <= 0:
        return image

    if src_width < DISPLAY_MIN_WIDTH:
        dst_width = DISPLAY_MIN_WIDTH
    elif src_width > DISPLAY_MAX_WIDTH:
        dst_width = DISPLAY_MAX_WIDTH
    else:
        # Already inside the allowed range: nothing to do.
        return image

    factor = dst_width / float(src_width)
    dst_height = int(round(src_height * factor))
    return image.resize((dst_width, dst_height), Image.BILINEAR)
def _annotate_frame_index(image: Image.Image, frame_idx: int) -> Image.Image:
    """Return a copy of ``image`` with a "Frame N" label drawn in the top-left.

    The label is white text on a black rectangle. ``None`` is passed through
    unchanged; the input image itself is never modified.
    """
    if image is None:
        return image

    labelled = image.copy()
    pen = ImageDraw.Draw(labelled)
    label = f"Frame {frame_idx}"
    pad = 6

    try:
        left, top, right, bottom = pen.textbbox((0, 0), label)[:4]
        label_w = right - left
        label_h = bottom - top
    except AttributeError:
        # Fallback for drawing objects that lack ``textbbox``.
        label_w, label_h = pen.textsize(label)

    text_origin = (pad, pad)
    box_top_left = (pad - pad // 2, pad - pad // 2)
    box_bottom_right = (pad + label_w + pad, pad + label_h + pad)
    pen.rectangle([box_top_left, box_bottom_right], fill=(0, 0, 0))
    pen.text(text_origin, label, fill=(255, 255, 255))
    return labelled
def _apply_cutout_fx(state: "AppState", frame_np: np.ndarray, combined_mask: np.ndarray) -> np.ndarray:
    """Composite the masked foreground over a blurred / darkened background.

    Args:
        state: Supplies the ``fx_*`` settings (soft matte, blur, darkening,
            light wrap, glow).
        frame_np: Frame pixels; the result is multiplied by 255 before the
            uint8 cast, so values are presumably floats in 0..1 (the visible
            caller passes ``np.array(frame) / 255.0``).
        combined_mask: Foreground matte; clipped to 0..1 before use.

    Returns:
        The composited frame as a uint8 array.
    """

    def to_uint8(img: np.ndarray) -> np.ndarray:
        return np.clip(img * 255.0, 0, 255).astype(np.uint8)

    matte = np.clip(combined_mask.astype(np.float32), 0.0, 1.0)

    # Background source: the frame itself, optionally blurred. The darkened
    # version is what shows through wherever the matte is transparent.
    backdrop_src = frame_np.copy()
    if state.fx_blur_enabled and state.fx_blur_sigma > FX_EPS:
        backdrop_src = cv2.GaussianBlur(
            backdrop_src, (0, 0), sigmaX=state.fx_blur_sigma, sigmaY=state.fx_blur_sigma
        )
    backdrop = backdrop_src * (1.0 - np.clip(state.fx_bg_darkening, 0.0, 1.0))

    if matte.max() <= FX_EPS:
        # No foreground detected; fall back to the darkened background alone.
        return to_uint8(backdrop)

    soft = matte.copy()
    if state.fx_soft_matte_enabled:
        erode_px = max(0.0, float(state.fx_soft_matte_erode))
        if erode_px > FX_EPS:
            side = int(round(erode_px * 2 + 1))
            element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
            soft = cv2.erode(soft, element)
        feather = max(0.0, float(state.fx_soft_matte_feather))
        if feather > FX_EPS:
            soft = cv2.GaussianBlur(soft, (0, 0), sigmaX=feather, sigmaY=feather)
        # Slight gain on the matte, re-clipped to 0..1.
        soft = np.clip(soft * 1.05, 0.0, 1.0)

    alpha = soft[..., None]
    result = frame_np * alpha + backdrop * (1.0 - alpha)

    wrap_strength = float(state.fx_light_wrap_strength)
    wrap_width = max(0.0, float(state.fx_light_wrap_width))
    if state.fx_light_wrap_enabled and wrap_strength > FX_EPS and wrap_width > FX_EPS:
        # Band just inside the matte edge: matte minus its blurred self.
        blurred = cv2.GaussianBlur(soft, (0, 0), sigmaX=wrap_width, sigmaY=wrap_width)
        inner_edge = np.clip(soft - blurred, 0.0, 1.0)
        if inner_edge.max() > FX_EPS:
            inner_edge /= (inner_edge.max() + FX_EPS)
        wrap_light = cv2.GaussianBlur(
            backdrop_src, (0, 0), sigmaX=wrap_width * 1.5, sigmaY=wrap_width * 1.5
        )
        result = np.clip(result + inner_edge[..., None] * wrap_light * wrap_strength, 0.0, 1.0)

    glow_strength = float(state.fx_glow_strength)
    glow_radius = max(0.0, float(state.fx_glow_radius))
    if state.fx_glow_enabled and glow_strength > FX_EPS and glow_radius > FX_EPS:
        # Band just outside the matte edge: blurred matte minus the matte.
        spread = cv2.GaussianBlur(soft, (0, 0), sigmaX=glow_radius, sigmaY=glow_radius)
        glow_band = np.clip(spread - soft, 0.0, 1.0)
        if glow_band.max() > FX_EPS:
            glow_band /= (glow_band.max() + FX_EPS)
        tint = FX_GLOW_COLOR[None, None, :]
        result = np.clip(result + glow_band[..., None] * tint * glow_strength, 0.0, 1.0)

    return to_uint8(result)
def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> Image.Image:
    """Render one video frame with masks, ball trail and annotation marks.

    Args:
        state: Application state holding frames, masks, FX settings and
            recorded prompts.
        frame_idx: Frame to render; clamped to the valid frame range.
        remove_bg: When True, render the cut-out FX (tracked objects over a
            processed background) instead of coloured mask overlays. Results
            rendered with ``remove_bg=True`` are not cached.

    Returns:
        The composited PIL image, or ``None`` when ``state`` is ``None`` or
        holds no frames.
    """
    if state is None or state.video_frames is None or len(state.video_frames) == 0:
        return None
    frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
    frame = state.video_frames[frame_idx]
    masks = state.masks_by_frame.get(frame_idx, {})
    ball_mask_raw = masks.get(BALL_OBJECT_ID)

    # ImageDraw draws in place, so always work on a private copy. Drawing
    # straight onto ``frame`` would bake rings/crosses/boxes into
    # state.video_frames and they would survive a cache clear.
    # NOTE(review): when a cached composite exists it is still used as the
    # base, so overlays are blended on top of an already-composited image.
    # Confirm whether that is intended.
    cached = state.composited_frames.get(frame_idx)
    out_img = (cached if cached is not None else frame).copy()

    def draw_cross(pen, x, y, half, color, width):
        # Horizontal stroke first, then vertical.
        pen.line([(x - half, y), (x + half, y)], fill=color, width=width)
        pen.line([(x, y - half), (x, y + half)], fill=color, width=width)

    # Union of every mask, plus a separate union of ball + player only.
    current_union_mask: np.ndarray | None = None
    focus_mask: np.ndarray | None = None
    for obj_id, mask in masks.items():
        if mask is None:
            continue
        mask_np = mask.astype(np.float32)
        if mask_np.ndim == 3:
            mask_np = mask_np.squeeze()
        mask_np = np.clip(mask_np, 0.0, 1.0)
        if current_union_mask is None:
            current_union_mask = np.zeros_like(mask_np, dtype=np.float32)
        current_union_mask = np.maximum(current_union_mask, mask_np)
        if obj_id in (BALL_OBJECT_ID, PLAYER_OBJECT_ID):
            if focus_mask is None:
                focus_mask = np.zeros_like(mask_np, dtype=np.float32)
            focus_mask = np.maximum(focus_mask, mask_np)

    # The ball trail is rendered either as ring outlines or as a soft mask.
    ghost_mask: np.ndarray | None = None
    circle_trail: list[tuple[float, float, float]] = []
    if state.fx_ball_ring_enabled:
        circle_trail = _collect_ball_trail_circles(state, frame_idx)
    else:
        ghost_mask = _build_ball_trail_mask(state, frame_idx)

    if len(masks) != 0:
        if remove_bg:
            # Remove background - show only tracked objects
            frame_np = np.array(frame).astype(np.float32) / 255.0
            combined_mask = current_union_mask
            if combined_mask is None:
                combined_mask = np.zeros((frame_np.shape[0], frame_np.shape[1]), dtype=np.float32)
            # Apply falloff to ball component when rendering foreground.
            # NOTE(review): the union already contains the un-feathered ball
            # mask, so this maximum appears to leave combined_mask unchanged.
            # Confirm intent.
            if ball_mask_raw is not None:
                falloff_ball = _apply_radial_falloff(
                    np.clip(ball_mask_raw.astype(np.float32), 0.0, 1.0),
                    strength=1.0,
                    solid_ratio=0.8,
                )
                if falloff_ball is not None:
                    combined_mask = np.maximum(combined_mask, falloff_ball)
            result_np = _apply_cutout_fx(state, frame_np, combined_mask)
            out_img = Image.fromarray(result_np)
        else:
            # The ball gets its own rendering when a ring or ghost trail is
            # active, so drop it from the generic colour overlay.
            overlay_masks = masks
            if BALL_OBJECT_ID in masks and (state.fx_ball_ring_enabled or ghost_mask is not None):
                overlay_masks = {oid: mask for oid, mask in masks.items() if oid != BALL_OBJECT_ID}
            if overlay_masks:
                out_img = overlay_masks_on_frame(out_img, overlay_masks, state.color_by_obj, alpha=0.65)

            # Overlay feathered ball on top
            if not state.fx_ball_ring_enabled and ball_mask_raw is not None:
                ball_alpha = _apply_radial_falloff(ball_mask_raw, strength=1.0, solid_ratio=0.8)
                if ball_alpha is not None and ball_alpha.max() > FX_EPS:
                    base_np = np.array(out_img).astype(np.float32) / 255.0
                    color = np.array(state.color_by_obj.get(BALL_OBJECT_ID, (255, 255, 0)), dtype=np.float32) / 255.0
                    alpha = np.clip(ball_alpha[..., None], 0.0, 1.0)
                    base_np = (1.0 - alpha) * base_np + alpha * color
                    out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))

    if ghost_mask is not None:
        ghost_np = np.clip(ghost_mask.astype(np.float32), 0.0, 1.0)
        if current_union_mask is not None:
            # Keep the trail out of pixels covered by current-frame masks.
            ghost_np = ghost_np * np.clip(1.0 - current_union_mask, 0.0, 1.0)
        if ghost_np.max() > FX_EPS:
            base_np = np.array(out_img).astype(np.float32) / 255.0
            ghost_alpha = ghost_np[..., None] * GHOST_TRAIL_ALPHA
            base_np = (1.0 - ghost_alpha) * base_np + ghost_alpha * GHOST_TRAIL_COLOR
            if focus_mask is not None:
                # Restore original pixels where the ball / player masks are.
                focus_alpha = np.clip(focus_mask, 0.0, 1.0)[..., None]
                orig_np = np.array(frame).astype(np.float32) / 255.0
                base_np = focus_alpha * orig_np + (1.0 - focus_alpha) * base_np
            out_img = Image.fromarray(np.clip(base_np * 255.0, 0, 255).astype(np.uint8))

    if state.fx_ball_ring_enabled and circle_trail:
        draw = ImageDraw.Draw(out_img)
        ring_width = max(1, int(round(BALL_RING_THICKNESS_PX)))
        for cx, cy, radius in circle_trail:
            if radius <= 1.0:
                continue
            draw.ellipse(
                (cx - radius, cy - radius, cx + radius, cy + radius),
                outline=BALL_RING_COLOR_RGB,
                width=ring_width,
            )

    # Draw crosses for conditioning frames only (frames with recorded clicks)
    clicks_map = state.clicks_by_frame_obj.get(frame_idx)
    if state.show_click_marks and clicks_map:
        draw = ImageDraw.Draw(out_img)
        for pts in clicks_map.values():
            for x, y, lbl in pts:
                # Green for label 1, red for any other label.
                color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0)
                draw_cross(draw, x, y, 6, color, 2)

    # Draw temporary cross for first corner in box mode
    if (
        state.show_click_marks
        and state.pending_box_start is not None
        and state.pending_box_start_frame_idx == frame_idx
        and state.pending_box_start_obj_id is not None
    ):
        draw = ImageDraw.Draw(out_img)
        x, y = state.pending_box_start
        color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255))
        draw_cross(draw, x, y, 6, color, 2)

    # Draw boxes for conditioning frames
    box_map = state.boxes_by_frame_obj.get(frame_idx)
    if state.show_click_marks and box_map:
        draw = ImageDraw.Draw(out_img)
        for obj_id, boxes in box_map.items():
            color = state.color_by_obj.get(obj_id, (255, 255, 255))
            for x1, y1, x2, y2 in boxes:
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)

    # Draw trajectory centers (all frames)
    if state.show_click_marks and state.ball_centers:
        draw = ImageDraw.Draw(out_img)
        cross_half = 4
        for obj_id, centers in state.ball_centers.items():
            if not centers:
                continue
            # Raw centres in grey.
            for _, (rx, ry) in sorted(centers.items()):
                draw_cross(draw, rx, ry, cross_half, (160, 160, 160), 1)

            smooth_dict = state.smoothed_centers.get(obj_id, {})
            if not smooth_dict:
                continue
            smooth_items = sorted(smooth_dict.items())
            # Step length between consecutive smoothed centres; the first
            # entry has no predecessor and gets 0.0.
            distances: list[float] = [0.0] + [
                float(np.hypot(cur[0] - prev[0], cur[1] - prev[1]))
                for (_, prev), (_, cur) in zip(smooth_items, smooth_items[1:])
            ]
            max_dist = max(distances[1:], default=0.0)
            color_by_frame: dict[int, tuple[int, int, int]] = {}
            for (f_idx, _), dist in zip(smooth_items, distances):
                ratio = dist / max_dist if max_dist > 0 else 0.0
                color_by_frame[f_idx] = _speed_to_color(ratio)
            # Drawn last-to-first, so earlier frames end up on top.
            for f_idx, (sx, sy) in reversed(smooth_items):
                highlight = f_idx == frame_idx
                color = (255, 0, 0) if highlight else color_by_frame.get(f_idx, (255, 255, 0))
                draw_cross(draw, sx, sy, cross_half, color, 2 if highlight else 1)

    # Save to cache and return
    if not remove_bg:
        state.composited_frames[frame_idx] = out_img
    return out_img
def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
    """Return the display-sized composite for ``frame_idx``.

    The index is clamped to the valid range. A cached composite is reused
    when present; otherwise the frame is composed via ``compose_frame``.
    Returns ``None`` when ``state`` is ``None`` or holds no frames.
    """
    if state is None:
        return None
    frames = state.video_frames
    if frames is None or len(frames) == 0:
        return None

    frame_idx = int(np.clip(frame_idx, 0, len(frames) - 1))
    # Serve from cache when available
    image = state.composited_frames.get(frame_idx)
    if image is None:
        image = compose_frame(state, frame_idx)
    return _maybe_upscale_for_display(image)
def _update_fx_controls(
    state: AppState,
    soft_enabled: bool,
    soft_feather: float,
    soft_erode: float,
    blur_enabled: bool,
    blur_sigma: float,
    bg_darkening: float,
    wrap_enabled: bool,
    wrap_strength: float,
    wrap_width: float,
    glow_enabled: bool,
    glow_strength: float,
    glow_radius: float,
) -> Image.Image:
    """Store the cut-out FX settings on ``state`` and re-render the current frame.

    Toggles are coerced to ``bool``, magnitudes are floored at 0.0 and the
    background darkening is clipped to 0..1. The composite cache is cleared
    because cached frames were rendered with the previous settings.
    Returns ``None`` when ``state`` is ``None``.
    """
    if state is None:
        return None

    def non_negative(value) -> float:
        return max(0.0, float(value))

    # Soft matte
    state.fx_soft_matte_enabled = bool(soft_enabled)
    state.fx_soft_matte_feather = non_negative(soft_feather)
    state.fx_soft_matte_erode = non_negative(soft_erode)
    # Background blur / darkening
    state.fx_blur_enabled = bool(blur_enabled)
    state.fx_blur_sigma = non_negative(blur_sigma)
    state.fx_bg_darkening = float(np.clip(bg_darkening, 0.0, 1.0))
    # Light wrap
    state.fx_light_wrap_enabled = bool(wrap_enabled)
    state.fx_light_wrap_strength = non_negative(wrap_strength)
    state.fx_light_wrap_width = non_negative(wrap_width)
    # Glow
    state.fx_glow_enabled = bool(glow_enabled)
    state.fx_glow_strength = non_negative(glow_strength)
    state.fx_glow_radius = non_negative(glow_radius)

    state.composited_frames.clear()
    current = int(getattr(state, "current_frame_idx", 0))
    return update_frame_display(state, current)
def _toggle_ghost_trail(state: AppState, enabled: bool) -> Image.Image:
    """Switch the ghost trail on or off and re-render the current frame.

    Clears the composite cache, since cached frames reflect the previous
    setting. Returns ``None`` when ``state`` is ``None``.
    """
    if state is not None:
        state.fx_ghost_trail_enabled = bool(enabled)
        state.composited_frames.clear()
        current = int(getattr(state, "current_frame_idx", 0))
        return update_frame_display(state, current)
    return None
def _toggle_click_marks(state: AppState, enabled: bool) -> Image.Image:
    """Show or hide click/box/trajectory marks and re-render the current frame.

    Clears the composite cache, since cached frames reflect the previous
    setting. Returns ``None`` when ``state`` is ``None``.
    """
    if state is not None:
        state.show_click_marks = bool(enabled)
        state.composited_frames.clear()
        current = int(getattr(state, "current_frame_idx", 0))
        return update_frame_display(state, current)
    return None
def _toggle_ball_ring(state: AppState, enabled: bool) -> Image.Image:
    """Switch ball ring rendering on or off and re-render the current frame.

    Clears the composite cache, since cached frames reflect the previous
    setting. Returns ``None`` when ``state`` is ``None``.
    """
    if state is not None:
        state.fx_ball_ring_enabled = bool(enabled)
        state.composited_frames.clear()
        current = int(getattr(state, "current_frame_idx", 0))
        return update_frame_display(state, current)
    return None
def _build_ball_trail_mask(state: AppState, frame_idx: int) -> np.ndarray | None:
    """Merge the ball masks of upcoming frames into one soft trail mask.

    Only frames after both the kick frame and ``frame_idx`` contribute.
    Each ball mask is clipped to 0..1, feathered with a radial falloff and
    combined by per-pixel maximum.

    Returns:
        The trail mask, or ``None`` when the ghost trail is disabled, no
        kick frame is known, or no later frame has a ball mask.
    """
    if state is None or not state.fx_ghost_trail_enabled or state.masks_by_frame is None:
        return None

    # Prefer the confirmed kick frame; fall back to the debug estimate.
    kick = state.kick_frame
    if kick is None:
        kick = state.kick_debug_kick_frame
    if kick is None:
        return None

    first = max(int(kick) + 1, int(frame_idx) + 1)
    stop = state.num_frames
    if first >= stop:
        return None

    trail: np.ndarray | None = None
    for future_idx in range(first, stop):
        per_frame = state.masks_by_frame.get(future_idx)
        if not per_frame:
            continue
        ball = per_frame.get(BALL_OBJECT_ID)
        if ball is None:
            continue
        ball_np = ball.astype(np.float32)
        if ball_np.ndim == 3:
            ball_np = ball_np.squeeze()
        ball_np = np.clip(ball_np, 0.0, 1.0)
        ball_np = _apply_radial_falloff(ball_np, strength=1.0, solid_ratio=0.8)
        if trail is None:
            trail = np.zeros_like(ball_np, dtype=np.float32)
        # Masks whose shape differs from the first one found are ignored.
        if trail.shape != ball_np.shape:
            continue
        trail = np.maximum(trail, ball_np)
    return trail
def _collect_ball_trail_circles(state: AppState, frame_idx: int) -> list[tuple[float, float, float]]:
    """Return ``(cx, cy, radius)`` circles for the ball in upcoming frames.

    Only frames after both the kick frame and ``frame_idx`` are considered.
    The centre is the mask centroid and the radius is half the larger side
    of the bounding box of pixels above 0.05. Circles with radius <= 1.0
    are dropped.

    Returns an empty list when the ghost trail is disabled (this function
    is gated on ``fx_ghost_trail_enabled``), no kick frame is known, or no
    later frame has a usable ball mask.
    """
    if state is None or not state.fx_ghost_trail_enabled or state.masks_by_frame is None:
        return []

    # Prefer the confirmed kick frame; fall back to the debug estimate.
    kick = state.kick_frame
    if kick is None:
        kick = state.kick_debug_kick_frame
    if kick is None:
        return []

    first = max(int(kick) + 1, int(frame_idx) + 1)
    stop = state.num_frames
    if first >= stop:
        return []

    found: list[tuple[float, float, float]] = []
    for future_idx in range(first, stop):
        per_frame = state.masks_by_frame.get(future_idx)
        if not per_frame:
            continue
        ball = per_frame.get(BALL_OBJECT_ID)
        if ball is None:
            continue

        ball_np = np.array(ball, dtype=np.float32)
        if ball_np.ndim == 3:
            ball_np = ball_np.squeeze()
        if ball_np.size == 0:
            continue
        ball_np = np.clip(ball_np, 0.0, 1.0)
        if ball_np.max() <= FX_EPS:
            continue

        centre = _compute_mask_centroid(ball_np)
        if centre is None:
            continue

        rows, cols = np.nonzero(ball_np > 0.05)
        if cols.size == 0 or rows.size == 0:
            continue
        half_w = (cols.max() - cols.min() + 1) / 2.0
        half_h = (rows.max() - rows.min() + 1) / 2.0
        radius = float(max(half_w, half_h))
        if radius <= 1.0:
            continue
        found.append((float(centre[0]), float(centre[1]), radius))
    return found
def _ensure_color_for_obj(state: AppState, obj_id: int):
    """Assign ``obj_id`` a colour in ``state.color_by_obj`` if it has none yet."""
    if obj_id in state.color_by_obj:
        return
    state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
def _compute_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None:
    """Return the intensity-weighted centroid ``(cx, cy)`` of ``mask``.

    The mask is clipped to 0..1 and passed to ``cv2.moments`` as float32;
    3-D inputs are squeezed first. Coordinates are truncated to ``int``.
    Returns ``None`` for a ``None`` or empty mask, or when the zeroth
    moment is zero.
    """
    if mask is None:
        return None
    arr = np.array(mask)
    if arr.ndim == 3:
        arr = arr.squeeze()
    if arr.size == 0:
        return None

    moments = cv2.moments(np.clip(arr, 0.0, 1.0).astype(np.float32))
    mass = moments["m00"]
    if mass == 0:
        return None
    return int(moments["m10"] / mass), int(moments["m01"] / mass)
def _apply_radial_falloff(mask: np.ndarray, strength: float = 1.0, solid_ratio: float = 0.8) -> np.ndarray:
    """Fade ``mask`` towards its outer edge, measured from the mask centroid.

    Pixels within ``solid_ratio`` of the farthest masked pixel keep their
    value; beyond that the mask is multiplied by
    ``1 - t ** strength``, where ``t`` runs from 0 at the solid boundary to
    1 at the farthest masked pixel.

    Returns:
        The attenuated float32 mask clipped to 0..1. The clipped mask is
        returned without falloff when it is empty, has no centroid, or
        ``solid_ratio >= 1.0``. ``None`` is returned for a ``None`` input.
    """
    if mask is None:
        return None
    mask_np = np.clip(mask.astype(np.float32), 0.0, 1.0)
    if mask_np.ndim == 3:
        mask_np = mask_np.squeeze()
    if mask_np.max() <= FX_EPS:
        return mask_np

    centre = _compute_mask_centroid(mask_np)
    if centre is None:
        return mask_np
    cx, cy = centre

    rows, cols = mask_np.shape
    yy, xx = np.ogrid[:rows, :cols]
    dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)

    # Distance from the centroid to the farthest masked pixel.
    inside = mask_np > FX_EPS
    max_dist = dist[inside].max() if np.any(inside) else 0.0
    if max_dist <= FX_EPS or solid_ratio >= 1.0:
        return mask_np

    ramp = np.clip((dist / max_dist - solid_ratio) / (1.0 - solid_ratio), 0.0, 1.0)
    falloff = 1.0 - np.power(ramp, strength)
    return np.clip(mask_np * falloff, 0.0, 1.0)
def _update_centroids_for_frame(state: AppState, frame_idx: int):
    """Refresh per-object centroid and area records for one frame.

    For every mask stored at ``frame_idx`` this records the centroid in
    ``state.ball_centers`` and the count of pixels above 0.3 in
    ``state.mask_areas``, and makes sure the object has a display colour.
    Records for objects that have no mask at this frame are removed, then
    the motion metrics are recomputed.

    A ``None`` mask is treated as "no data": its centroid and area entries
    for this frame are removed instead of raising while computing the area.
    """
    if state is None:
        return

    frame_key = int(frame_idx)
    masks = state.masks_by_frame.get(frame_key, {})
    seen_obj_ids: set[int] = set()

    for raw_id, mask in masks.items():
        obj_id = int(raw_id)

        centroid = _compute_mask_centroid(mask)
        centers = state.ball_centers.setdefault(obj_id, {})
        if centroid is not None:
            centers[frame_key] = centroid
        else:
            centers.pop(frame_key, None)

        seen_obj_ids.add(obj_id)
        _ensure_color_for_obj(state, obj_id)

        areas = state.mask_areas.setdefault(obj_id, {})
        if mask is None:
            # np.clip on np.array(None) raises TypeError; drop the stale
            # area for this frame instead.
            areas.pop(frame_key, None)
            continue
        mask_np = np.array(mask)
        if mask_np.ndim == 3:
            mask_np = mask_np.squeeze()
        mask_np = np.clip(mask_np, 0.0, 1.0)
        areas[frame_key] = float(np.count_nonzero(mask_np > 0.3))

    # Remove frames for objects without masks at this frame
    for obj_id, centers in state.ball_centers.items():
        if obj_id not in seen_obj_ids:
            centers.pop(frame_key, None)
    for obj_id, areas in state.mask_areas.items():
        if obj_id not in seen_obj_ids:
            areas.pop(frame_key, None)

    _recompute_motion_metrics(state)
def _run_kalman_filter(
    ordered_items: list[tuple[int, tuple[float, float]]],
    base_dt: float,
) -> tuple[dict[int, tuple[float, float]], dict[int, float], dict[int, float]]:
    """Run a constant-velocity Kalman filter over 2-D position measurements.

    Args:
        ordered_items: ``(frame_idx, (x, y))`` measurements; the time step is
            derived from the difference between consecutive frame indices.
        base_dt: Time represented by one frame step.

    Returns:
        Three dicts keyed by frame index: filtered ``(x, y)`` position,
        speed (magnitude of the filtered velocity) and residual (magnitude
        of the innovation before the update). All are empty for empty input.
    """
    if not ordered_items:
        return {}, {}, {}

    identity4 = np.eye(4)
    # Only the position part of the [x, y, vx, vy] state is observed.
    observe = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=float)
    meas_noise = np.eye(2, dtype=float) * 25.0

    first_frame, (first_x, first_y) = ordered_items[0]
    estimate = np.array([first_x, first_y, 0.0, 0.0], dtype=float)
    covariance = np.eye(4, dtype=float) * 50.0

    positions: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    residuals: dict[int, float] = {}

    last_frame = first_frame
    for frame_idx, (cx, cy) in ordered_items:
        # At least one frame step, also for the first or a repeated index.
        dt = max(1, frame_idx - last_frame) * base_dt

        transition = np.array(
            [
                [1, 0, dt, 0],
                [0, 1, 0, dt],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ],
            dtype=float,
        )
        q = 0.5 * dt**2
        # NOTE(review): each [[q, dt], [dt, 1]] block has determinant
        # -0.5 * dt**2, so this process-noise matrix is not positive
        # semi-definite. Left unchanged to keep the filter output identical.
        process_noise = np.array(
            [
                [q, 0, dt, 0],
                [0, q, 0, dt],
                [dt, 0, 1, 0],
                [0, dt, 0, 1],
            ],
            dtype=float,
        ) * 0.05

        # Predict.
        estimate = transition @ estimate
        covariance = transition @ covariance @ transition.T + process_noise

        # Update with the measurement.
        measurement = np.array([cx, cy], dtype=float)
        innovation = measurement - observe @ estimate
        innovation_cov = observe @ covariance @ observe.T + meas_noise
        gain = covariance @ observe.T @ np.linalg.inv(innovation_cov)
        estimate = estimate + gain @ innovation
        covariance = (identity4 - gain @ observe) @ covariance

        positions[frame_idx] = (estimate[0], estimate[1])
        speeds[frame_idx] = float(math.hypot(estimate[2], estimate[3]))
        residuals[frame_idx] = float(math.hypot(innovation[0], innovation[1]))
        last_frame = frame_idx

    return positions, speeds, residuals
def _build_kick_plot(state: AppState):
    """Build the Plotly figure with kick and impact diagnostics.

    Plots ball speed with its threshold and baseline, mask area (on the
    secondary axis), distance from start rescaled to the primary axis, and
    where available the Kalman speed, innovation and direction-change
    curves with their thresholds. Detected kick and impact frames are drawn
    as vertical lines. With no kick debug data only an empty, titled figure
    is returned.
    """
    fig = go.Figure()
    if state is None or not state.kick_debug_frames or not state.kick_debug_speeds:
        fig.update_layout(
            title="Kick & impact diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    def add_line(x, y, name, line, mode="lines", **extra):
        fig.add_trace(go.Scatter(x=x, y=y, mode=mode, name=name, line=line, **extra))

    frames = state.kick_debug_frames
    speeds = state.kick_debug_speeds
    areas = state.kick_debug_area or [0.0] * len(frames)
    threshold = state.kick_debug_threshold or 0.0
    baseline = state.kick_debug_baseline or 0.0
    kick_frame = state.kick_debug_kick_frame
    distance = state.kick_debug_distance or [0.0] * len(frames)
    impact_frames = state.impact_debug_frames or frames
    frame_span = [frames[0], frames[-1]]

    add_line(frames, speeds, "Speed (px/s)", dict(color="#1f77b4"), mode="lines+markers")
    add_line(frame_span, [threshold, threshold], "Adaptive threshold", dict(color="#d62728", dash="dash"))
    add_line(frame_span, [baseline, baseline], "Baseline speed", dict(color="#ff7f0e", dash="dot"))
    add_line(frames, areas, "Mask area", dict(color="#2ca02c"), yaxis="y2")

    # Largest value on the primary axis; used to rescale the distance curve
    # and to size the vertical kick / impact markers.
    max_primary = max(
        max(speeds) if speeds else 0.0,
        threshold,
        baseline,
        max(state.kick_debug_kalman_speeds) if state.kick_debug_kalman_speeds else 0.0,
        state.impact_debug_innovation_threshold or 0.0,
        state.impact_debug_direction_threshold or 0.0,
        state.impact_debug_speed_threshold_px or 0.0,
        1.0,
    )

    max_distance = max(distance) if distance else 0.0
    if max_distance > 0 and max_primary > 0:
        rescale = max_primary / max_distance
        distance_scaled = [d * rescale for d in distance]
    else:
        distance_scaled = distance
    add_line(frames, distance_scaled, "Distance from start (scaled)", dict(color="#9467bd"))

    if state.kick_debug_kalman_speeds:
        add_line(frames, state.kick_debug_kalman_speeds, "Kalman speed", dict(color="#8c564b"))

    has_impact_span = bool(impact_frames) and len(impact_frames) >= 2

    if state.impact_debug_innovation:
        add_line(impact_frames, state.impact_debug_innovation, "Kalman innovation", dict(color="#17becf"))
        max_primary = max(max_primary, max(state.impact_debug_innovation))
    if state.impact_debug_innovation_threshold is not None and has_impact_span:
        level = state.impact_debug_innovation_threshold
        add_line(
            [impact_frames[0], impact_frames[-1]],
            [level, level],
            "Innovation threshold",
            dict(color="#17becf", dash="dash"),
        )
        max_primary = max(max_primary, level or 0.0)

    if state.impact_debug_direction:
        add_line(impact_frames, state.impact_debug_direction, "Direction change (deg)", dict(color="#bcbd22"))
        max_primary = max(max_primary, max(state.impact_debug_direction))
    if state.impact_debug_direction_threshold is not None and has_impact_span:
        level = state.impact_debug_direction_threshold
        add_line(
            [impact_frames[0], impact_frames[-1]],
            [level, level],
            "Direction threshold (deg)",
            dict(color="#bcbd22", dash="dot"),
        )
        max_primary = max(max_primary, level or 0.0)

    if state.impact_debug_speed_threshold_px:
        add_line(
            frame_span,
            [state.impact_debug_speed_threshold_px] * 2,
            "Min impact speed (px/s)",
            dict(color="#b82e2e", dash="dot"),
        )
        max_primary = max(max_primary, state.impact_debug_speed_threshold_px or 0.0)

    if kick_frame is not None:
        add_line(
            [kick_frame, kick_frame],
            [0, max_primary * 1.05],
            "Detected kick",
            dict(color="#ff00ff", dash="solid", width=3),
        )

    impact_frame = state.impact_frame
    if impact_frame is not None:
        add_line(
            [impact_frame, impact_frame],
            [0, max_primary * 1.05],
            "Detected impact",
            dict(color="#ff1493", width=3),
        )

    fig.update_layout(
        title="Kick & impact diagnostics",
        xaxis_title="Frame",
        yaxis_title="Speed (px/s)",
        yaxis=dict(side="left"),
        yaxis2=dict(
            title="Mask area / Distance / Direction",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig
def _ensure_ball_prompt_from_yolo(state: AppState):
    """Seed a positive ball click from the YOLO detection if none exists yet.

    Does nothing when there is no inference session, no YOLO ball centres,
    or any frame already has clicks recorded for the ball. Otherwise the
    centre at the anchor frame (``yolo_initial_frame``, or the earliest
    detected frame) is clamped to the frame bounds and passed to
    ``on_image_click`` as a synthetic click event.
    """
    if state is None or state.inference_session is None or not state.yolo_ball_centers:
        return

    # Check if we already have clicks for the ball
    if any(per_frame.get(BALL_OBJECT_ID) for per_frame in state.clicks_by_frame_obj.values()):
        return

    anchor = state.yolo_initial_frame
    if anchor is None and state.yolo_ball_centers:
        anchor = min(state.yolo_ball_centers.keys())
    if anchor is None or anchor >= state.num_frames:
        return

    center = state.yolo_ball_centers.get(anchor)
    if center is None:
        return

    raw_x, raw_y = center
    width, height = state.video_frames[anchor].size
    click_x = int(np.clip(round(raw_x), 0, width - 1))
    click_y = int(np.clip(round(raw_y), 0, height - 1))

    # Minimal stand-in carrying the ``index`` and ``value`` fields of a click.
    click_event = SimpleNamespace(
        index=(click_x, click_y),
        value={"x": click_x, "y": click_y},
    )

    state.current_obj_id = BALL_OBJECT_ID
    state.current_label = "positive"
    state.current_frame_idx = anchor
    on_image_click(
        update_frame_display(state, anchor),
        state,
        anchor,
        BALL_OBJECT_ID,
        "positive",
        False,
        click_event,
    )
def _build_yolo_plot(state: AppState):
    """Build the Plotly figure with YOLO-based kick diagnostics.

    Plots YOLO speed with its threshold and baseline on the primary axis,
    distance from start and the box-area proxy on the secondary axis, and a
    vertical line at the detected kick frame. With no YOLO kick data only an
    empty, titled figure is returned.
    """
    fig = go.Figure()
    if state is None or not state.yolo_kick_frames or not state.yolo_kick_speeds:
        fig.update_layout(
            title="YOLO kick diagnostics",
            xaxis_title="Frame",
            yaxis_title="Speed (px/s)",
        )
        return fig

    frames = state.yolo_kick_frames
    count = len(frames)
    speeds = state.yolo_kick_speeds
    distance = state.yolo_kick_distance or [0.0] * count
    areas = state.yolo_mask_area_proxy or [0.0] * count
    threshold = state.yolo_threshold or 0.0
    baseline = state.yolo_baseline_speed or 0.0
    kick_frame = state.yolo_kick_frame

    def add_series(values, name, line, mode="lines", **extra):
        fig.add_trace(go.Scatter(x=frames, y=values, mode=mode, name=name, line=line, **extra))

    add_series(speeds, "YOLO speed", dict(color="#4caf50"), mode="lines+markers")
    add_series([threshold] * count, "Adaptive threshold", dict(color="#ff9800", dash="dash"))
    add_series([baseline] * count, "Baseline speed", dict(color="#9e9e9e", dash="dot"))
    add_series(distance, "Distance from start", dict(color="#03a9f4"), yaxis="y2")
    add_series(areas, "Box area proxy", dict(color="#ab47bc", dash="dot"), yaxis="y2")

    if kick_frame is not None:
        fig.add_vline(
            x=kick_frame,
            line=dict(color="#e91e63", width=2),
            annotation_text=f"Kick {kick_frame}",
            annotation_position="top right",
        )

    fig.update_layout(
        title="YOLO kick diagnostics",
        xaxis=dict(title="Frame"),
        yaxis=dict(title="Speed (px/s)"),
        yaxis2=dict(
            title="Distance / Area",
            overlaying="y",
            side="right",
            showgrid=False,
        ),
        legend=dict(orientation="h"),
        margin=dict(t=40, l=40, r=40, b=40),
    )
    return fig
def _format_impact_status(state: AppState) -> str:
    """Return the one-line impact status shown to the user.

    Reports "not computed" when there is no impact debug data and "not
    detected" when no impact frame was found. Otherwise gives the frame,
    its timestamp when the FPS is known, and an estimated speed in km/h
    when a metres-per-pixel scale and a positive Kalman speed are available
    for the currently selected object.
    """
    if state is None or not state.impact_debug_frames:
        return "Impact frame: not computed"
    frame = state.impact_frame
    if frame is None:
        return "Impact frame: not detected"

    parts = [f"Impact frame: {frame}"]

    fps = state.video_fps
    if fps and fps > 1e-6:
        parts.append(f" (~{frame / fps:.2f}s)")

    meters_per_px = state.impact_meters_per_px
    obj_id = getattr(state, "current_obj_id", 1) or 1
    speed_px = state.kalman_speeds.get(int(obj_id), {}).get(frame, 0.0)
    if meters_per_px and meters_per_px > 0 and speed_px > 0:
        # px/s * m/px gives m/s; * 3.6 converts to km/h.
        speed_kmh = speed_px * meters_per_px * 3.6
        if speed_kmh > 0.1:
            parts.append(f", est. speed ≈ {speed_kmh:.1f} km/h")

    return "".join(parts)
def _format_kick_status(state: AppState) -> str:
    """Return the one-line kick status shown to the user.

    Uses ``state.kick_frame`` and falls back to the debug estimate
    ``kick_debug_kick_frame``; when the fallback is used, it is also stored
    back into ``state.kick_frame``. Reports "not detected" when debug data
    exists but no kick was found, and "not computed" otherwise.
    """
    if state is None or not isinstance(state, AppState):
        return "Kick frame: not computed"

    frame = state.kick_frame
    if frame is None:
        frame = getattr(state, "kick_debug_kick_frame", None)
        if frame is None:
            if state.kick_debug_frames:
                return "Kick frame: not detected"
            return "Kick frame: not computed"
        # Adopt the debug estimate as the kick frame.
        state.kick_frame = frame

    suffix = ""
    if state.video_fps and state.video_fps > 1e-6:
        suffix = f" (~{frame / state.video_fps:.2f}s)"
    return f"Kick frame ≈ {frame}{suffix}"
def _ball_has_masks(state: AppState, target_obj_id: int = BALL_OBJECT_ID) -> bool:
    """Return True when at least one stored frame has a mask for ``target_obj_id``.

    A ``None`` state counts as having no masks.
    """
    if state is None:
        return False
    return any(target_obj_id in frame_masks for frame_masks in state.masks_by_frame.values())
def _player_has_masks(state: AppState) -> bool:
    """Return True when at least one stored frame has a mask for the player object.

    Returns False when ``state`` is None or no player object id has been set.
    """
    if state is None:
        return False
    player_id = state.player_obj_id
    if player_id is None:
        return False
    return any(player_id in frame_masks for frame_masks in state.masks_by_frame.values())
def _button_updates(state: AppState) -> tuple[Any, Any, Any]:
    """Compute the ``interactive`` updates for the three action buttons.

    Returns, in order, updates for: the main propagate button (enabled when
    the ball has masks or a YOLO kick frame exists), the detect-player button
    (enabled when a YOLO kick frame exists) and the player propagate button
    (enabled when the player has masks).
    """
    yolo_ready = isinstance(state, AppState) and state.yolo_kick_frame is not None
    interactive_flags = (
        _ball_has_masks(state) or yolo_ready,
        yolo_ready,
        _player_has_masks(state),
    )
    return tuple(gr.update(interactive=flag) for flag in interactive_flags)
def _recompute_motion_metrics(state: AppState, target_obj_id: int = 1):
    """Rebuild the motion tracks and kick/impact results for one object.

    Reads the per-frame centres in ``state.ball_centers[target_obj_id]``.
    With fewer than three samples, every derived per-object track and all
    kick/impact diagnostics on ``state`` are reset and the function returns.
    Otherwise it stores:

    * an exponentially smoothed track (alpha = 0.35) and its per-frame speed
      (pixels per second when ``state.video_fps`` is usable, else per frame),
    * the distance of each smoothed sample from the first one,
    * the three outputs of ``_run_kalman_filter``,
    * the kick frame and then the impact frame (the impact detector reads
      ``state.kick_frame``, so the order matters).
    """
    centers = state.ball_centers.get(target_obj_id)
    if not centers or len(centers) < 3:
        # Not enough samples: clear everything derived from this track.
        for per_object in (
            state.smoothed_centers,
            state.ball_speeds,
            state.kalman_centers,
            state.kalman_speeds,
            state.kalman_residuals,
            state.distance_from_start,
            state.direction_change,
        ):
            per_object[target_obj_id] = {}
        for list_attr in (
            "kick_debug_frames",
            "kick_debug_speeds",
            "kick_debug_area",
            "kick_debug_distance",
            "kick_debug_kalman_speeds",
            "impact_debug_frames",
            "impact_debug_innovation",
            "impact_debug_direction",
            "impact_debug_speed_kmh",
        ):
            setattr(state, list_attr, [])  # fresh list per attribute
        for scalar_attr in (
            "kick_frame",
            "kick_debug_threshold",
            "kick_debug_baseline",
            "kick_debug_speed_std",
            "kick_debug_kick_frame",
            "impact_frame",
            "impact_debug_innovation_threshold",
            "impact_debug_direction_threshold",
            "impact_debug_speed_threshold_px",
            "impact_meters_per_px",
        ):
            setattr(state, scalar_attr, None)
        return

    ordered = sorted(centers.items())
    fps = state.video_fps
    dt = 1.0 / fps if fps and fps > 1e-3 else 1.0
    alpha = 0.35

    smoothed: dict[int, tuple[float, float]] = {}
    speeds: dict[int, float] = {}
    previous = None  # (frame index, smoothed point) of the last sample
    for frame_idx, (cx, cy) in ordered:
        if previous is None:
            point = (float(cx), float(cy))
            speed = 0.0
        else:
            last_frame, (last_x, last_y) = previous
            point = (last_x + alpha * (cx - last_x), last_y + alpha * (cy - last_y))
            # Gaps in the track count as several frames of elapsed time.
            elapsed = max(1, frame_idx - last_frame) * dt
            travelled = math.hypot(point[0] - last_x, point[1] - last_y)
            speed = travelled / elapsed if elapsed > 0 else travelled
        smoothed[frame_idx] = point
        speeds[frame_idx] = speed
        previous = (frame_idx, point)

    state.smoothed_centers[target_obj_id] = smoothed
    state.ball_speeds[target_obj_id] = speeds

    if smoothed:
        origin = smoothed[min(smoothed)]
        distances: dict[int, float] = {
            frame_idx: math.hypot(sx - origin[0], sy - origin[1])
            for frame_idx, (sx, sy) in smoothed.items()
        }
        state.distance_from_start[target_obj_id] = distances
        state.kick_debug_distance = [distances.get(f, 0.0) for f in sorted(smoothed)]

    positions, filtered_speeds, residuals = _run_kalman_filter(ordered, dt)
    state.kalman_centers[target_obj_id] = positions
    state.kalman_speeds[target_obj_id] = filtered_speeds
    state.kalman_residuals[target_obj_id] = residuals

    state.kick_frame = _detect_kick_frame(state, target_obj_id)
    state.impact_frame = _detect_impact_frame(state, target_obj_id)
def _detect_kick_frame(state: AppState, target_obj_id: int) -> int | None:
    """Find the first frame where the tracked object looks like it was kicked.

    Works on ``state.smoothed_centers`` / ``state.ball_speeds`` for the object
    and returns None without touching diagnostics when fewer than five
    smoothed samples exist. Otherwise the ``state.kick_debug_*`` series are
    refreshed and each frame after the baseline window is accepted as the
    kick when all of the following hold:

    * its speed reaches the adaptive threshold (baseline median + 4 sigma,
      at least 3x the baseline and at least 15),
    * the next up to three speeds stay above 70 % of that threshold,
    * the mask area dropped to at most 75 % of the median of the previous
      four areas, unless the speed is at least 1.2x the threshold,
    * within the next eight samples the object gets farther from its first
      position than an area-derived distance (between 8 and 40 pixels).

    Returns the accepted frame index (also stored in
    ``state.kick_debug_kick_frame``) or None.
    """
    track = state.smoothed_centers.get(target_obj_id, {})
    raw_speeds = state.ball_speeds.get(target_obj_id, {})
    if len(track) < 5:
        return None

    frames = sorted(track)
    n_frames = len(frames)
    speed_series = [raw_speeds.get(f, 0.0) for f in frames]

    # Baseline statistics come from the first third of the track (max 10 samples).
    baseline_window = min(10, n_frames // 3 or 1)
    calm_speeds = speed_series[:baseline_window]
    baseline_speed = statistics.median(calm_speeds) if calm_speeds else 0.0
    speed_std = statistics.pstdev(calm_speeds) if len(calm_speeds) > 1 else 0.0
    speed_threshold = max(baseline_speed + 4.0 * speed_std, baseline_speed * 3.0, 15.0)

    sustain_frames = 3
    holdout_frames = 8
    area_window = 4
    area_drop_ratio = 0.75

    areas = state.mask_areas.get(target_obj_id, {})
    start = track[frames[0]]
    # Treat the initial mask as a disc to get a radius for the "left the spot" test.
    initial_area = areas.get(frames[0], 1.0) or 1.0
    radius_estimate = math.sqrt(initial_area / math.pi)
    adaptive_return_distance = max(8.0, min(radius_estimate * 1.5, 40.0))

    offsets = [math.hypot(track[f][0] - start[0], track[f][1] - start[1]) for f in frames]
    kalman_speeds = state.kalman_speeds.get(target_obj_id, {})

    state.kick_debug_frames = frames
    state.kick_debug_speeds = speed_series
    state.kick_debug_threshold = speed_threshold
    state.kick_debug_baseline = baseline_speed
    state.kick_debug_speed_std = speed_std
    state.kick_debug_area = [areas.get(f, 0.0) for f in frames]
    state.kick_debug_distance = offsets
    state.kick_debug_kalman_speeds = [kalman_speeds.get(f, 0.0) for f in frames]
    state.kick_debug_kick_frame = None

    for idx in range(baseline_window, n_frames):
        speed = speed_series[idx]
        if speed < speed_threshold:
            continue

        # The burst must persist over the following samples (as many as exist).
        upcoming = speed_series[idx + 1:idx + 1 + sustain_frames]
        if any(later < speed_threshold * 0.7 for later in upcoming):
            continue

        frame = frames[idx]
        current_area = areas.get(frame)
        area_pass = True
        if current_area:
            recent = (areas.get(f) for f in frames[max(0, idx - area_window):idx])
            prior_areas = [area for area in recent if area is not None]
            if prior_areas:
                median_prev = statistics.median(prior_areas)
                if median_prev > 0 and current_area / median_prev > area_drop_ratio:
                    area_pass = False
        # A missing area drop is tolerated only for clearly fast motion.
        if not area_pass and speed < speed_threshold * 1.2:
            continue

        # Reject jitter: the object must actually leave its starting spot.
        farthest = max(0.0, *offsets[idx:idx + holdout_frames])
        if farthest < adaptive_return_distance:
            continue

        state.kick_debug_kick_frame = frame
        return frame

    state.kick_debug_kick_frame = None
    return None
def _detect_impact_frame(state: AppState, target_obj_id: int) -> int | None:
    """Pick the post-kick frame that looks most like an impact.

    Uses the Kalman residuals, positions and speeds stored on ``state`` for
    the object. All ``state.impact_debug_*`` diagnostics are refreshed. With
    no residuals or no kick frame the result (and ``state.impact_frame``) is
    None. Otherwise a frame is a candidate when it lies more than three
    frames after the kick, its residual reaches the innovation threshold
    (baseline median + 4 sigma, at least 3x the median and at least 5), its
    heading change is at least 25 (in whatever unit ``_angle_between``
    returns), it satisfies the optional minimum-speed and goal-distance
    limits, and its residual is not a strict local minimum. The candidate
    with the largest residual wins, ties going to the earliest frame.
    """
    residuals = state.kalman_residuals.get(target_obj_id, {})
    frames = sorted(residuals)

    # Publish the residual series and clear every other impact diagnostic.
    state.impact_debug_frames = frames
    state.impact_debug_innovation = [residuals[f] for f in frames]
    state.impact_debug_innovation_threshold = None
    state.impact_debug_direction = []
    state.impact_debug_direction_threshold = None
    state.impact_debug_speed_kmh = []
    state.impact_debug_speed_threshold_px = None
    state.impact_meters_per_px = None

    if state.kick_frame is None or not frames:
        state.impact_frame = None
        return None

    # Heading change between consecutive displacement vectors of the filtered track.
    positions = state.kalman_centers.get(target_obj_id, {})
    heading_change: dict[int, float] = {}
    last_pos: tuple[float, float] | None = None
    last_step: tuple[float, float] | None = None
    for f in frames:
        here = positions.get(f)
        if here is None:
            heading_change[f] = 0.0
            continue
        if last_pos is None:
            heading_change[f] = 0.0
            last_step = (0.0, 0.0)
        else:
            step = (here[0] - last_pos[0], here[1] - last_pos[1])
            heading_change[f] = 0.0 if last_step is None else _angle_between(last_step, step)
            last_step = step
        last_pos = here
    state.direction_change[target_obj_id] = heading_change
    state.impact_debug_direction = [heading_change.get(f, 0.0) for f in frames]

    # Pixel-to-metre scale: the farthest travelled distance is taken as the goal distance.
    distances = state.distance_from_start.get(target_obj_id, {})
    farthest_px = max(distances.values()) if distances else 0.0
    goal_distance_m = max(state.goal_distance_m, 0.0)
    if goal_distance_m > 0 and farthest_px > 1e-6:
        meters_per_px = goal_distance_m / farthest_px
    else:
        meters_per_px = None
    state.impact_meters_per_px = meters_per_px

    kalman_speeds = state.kalman_speeds.get(target_obj_id, {})
    if meters_per_px:
        state.impact_debug_speed_kmh = [kalman_speeds.get(f, 0.0) * meters_per_px * 3.6 for f in frames]
        if state.min_impact_speed_kmh > 0:
            state.impact_debug_speed_threshold_px = (state.min_impact_speed_kmh / 3.6) / meters_per_px
    else:
        state.impact_debug_speed_kmh = [0.0] * len(frames)
        state.impact_debug_speed_threshold_px = None

    # Residual statistics up to the kick define what counts as "unexpected" motion.
    calm_frames = [f for f in frames if f <= state.kick_frame] or frames[:10]
    calm_vals = [residuals.get(f, 0.0) for f in calm_frames]
    calm_median = statistics.median(calm_vals) if calm_vals else 0.0
    calm_std = statistics.pstdev(calm_vals) if len(calm_vals) > 1 else 0.0
    innovation_threshold = max(calm_median + 4.0 * calm_std, calm_median * 3.0, 5.0)
    state.impact_debug_innovation_threshold = innovation_threshold

    direction_threshold = 25.0
    state.impact_debug_direction_threshold = direction_threshold

    post_kick_buffer = 3
    earliest_excluded = state.kick_frame + post_kick_buffer
    min_speed_px = state.impact_debug_speed_threshold_px
    meters_limit = goal_distance_m * 1.1 if goal_distance_m > 0 else None
    last_index = len(frames) - 1

    candidates: list[tuple[float, float, int]] = []
    for idx, f in enumerate(frames):
        if f <= earliest_excluded:
            continue
        innovation = residuals.get(f, 0.0)
        if innovation < innovation_threshold:
            continue
        if heading_change.get(f, 0.0) < direction_threshold:
            continue
        if min_speed_px and kalman_speeds.get(f, 0.0) < min_speed_px:
            continue
        if meters_per_px and meters_limit is not None:
            if distances.get(f, 0.0) * meters_per_px > meters_limit:
                continue
        # Approximate local-peak filter: drop strict local minima only.
        before = residuals.get(frames[idx - 1], innovation) if idx > 0 else innovation
        after = residuals.get(frames[idx + 1], innovation) if idx < last_index else innovation
        if innovation < before and innovation < after:
            continue
        candidates.append((innovation, -f, f))

    if not candidates:
        state.impact_frame = None
        return None

    # Largest residual first; the negated frame breaks ties toward the earliest frame.
    impact_frame = sorted(candidates, reverse=True)[0][2]
    state.impact_frame = impact_frame
    return impact_frame
def on_image_click(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
) -> Image.Image:
    """Turn a click on the preview into a SAM2 prompt and refresh that frame.

    In "Boxes" mode two clicks define a box: the first click only records
    the corner and redraws the preview, the second submits the box. In any
    other mode each click is submitted as a positive or negative point.
    After a prompt is submitted the model is run on the clicked frame, the
    resulting masks replace ``state.masks_by_frame[frame_idx]`` and the
    centroids for that frame are updated.

    Args:
        img: Current preview; returned unchanged when no session exists.
        state: Application state holding the model, processor and session.
        frame_idx: Frame the click refers to; coerced to ``int`` once so the
            frame list can be indexed even if the UI delivers a float.
        obj_id: Object the prompt belongs to; coerced to ``int``.
        label: Point label; anything starting with "pos" (case-insensitive)
            is positive, everything else negative.
        clear_old: In points mode, whether earlier prompts for this object
            on this frame are discarded first.
        evt: Event carrying the click position, either as a two-item
            ``index`` or as a ``value`` dict with "x" and "y".

    Returns:
        The re-rendered preview for ``frame_idx``.

    Raises:
        gr.Error: If no click coordinates can be read from ``evt``.
    """
    if state is None or state.inference_session is None:
        return img  # no-op preview when not ready
    if state.is_switching_model:
        # Gracefully ignore input during model switch; return current preview unchanged
        return update_frame_display(state, int(frame_idx))

    # Parse click coordinates from the event; accept both known payload shapes.
    x = y = None
    if evt is not None:
        try:
            if hasattr(evt, "index") and isinstance(evt.index, (list, tuple)) and len(evt.index) == 2:
                x, y = int(evt.index[0]), int(evt.index[1])
            elif hasattr(evt, "value") and isinstance(evt.value, dict) and "x" in evt.value and "y" in evt.value:
                x, y = int(evt.value["x"]), int(evt.value["y"])
        except Exception:
            # Malformed payloads are reported uniformly through gr.Error below.
            x = y = None
    if x is None or y is None:
        raise gr.Error("Could not read click coordinates.")

    # Normalise the ids once. Previously the raw frame_idx was used for the
    # frame-list index and the processed-frames lookup while every other use
    # went through int().
    obj_id = int(obj_id)
    _ensure_color_for_obj(state, obj_id)
    frame_idx = int(frame_idx)

    processor = state.processor
    model = state.model
    inference_session = state.inference_session

    # Only frames the session has not seen yet need to be preprocessed here.
    original_size = None
    pixel_values = None
    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
        original_size = inputs.original_sizes[0]
        pixel_values = inputs.pixel_values[0]

    if state.current_prompt_type == "Boxes":
        # Two-click box input
        if state.pending_box_start is None:
            # First corner: drop this object's point prompts on the frame and
            # remember the corner; nothing is sent to the model yet.
            frame_clicks = state.clicks_by_frame_obj.setdefault(frame_idx, {})
            frame_clicks[obj_id] = []
            state.pending_box_start = (x, y)
            state.pending_box_start_frame_idx = frame_idx
            state.pending_box_start_obj_id = obj_id
            # Invalidate cache so the temporary cross is drawn
            state.composited_frames.pop(frame_idx, None)
            return update_frame_display(state, frame_idx)

        # Second corner: build the box from both clicks in any order.
        x1, y1 = state.pending_box_start
        state.pending_box_start = None
        state.pending_box_start_frame_idx = None
        state.pending_box_start_obj_id = None
        state.composited_frames.pop(frame_idx, None)
        x_min, y_min = min(x1, x), min(y1, y)
        x_max, y_max = max(x1, x), max(y1, y)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=obj_id,
            input_boxes=[[[x_min, y_min, x_max, y_max]]],
            clear_old_inputs=True,  # For boxes, always clear old inputs
            original_size=original_size,
        )
        frame_boxes = state.boxes_by_frame_obj.setdefault(frame_idx, {})
        obj_boxes = frame_boxes.setdefault(obj_id, [])
        # A box replaces whatever box was stored before.
        obj_boxes.clear()
        obj_boxes.append((x_min, y_min, x_max, y_max))
        state.composited_frames.pop(frame_idx, None)
    else:
        # Points mode
        label_int = 1 if str(label).lower().startswith("pos") else 0
        replace_existing = bool(clear_old)
        if replace_existing:
            # Clearing old inputs also removes any stored box for this object on this frame.
            frame_boxes = state.boxes_by_frame_obj.setdefault(frame_idx, {})
            frame_boxes[obj_id] = []
            state.composited_frames.pop(frame_idx, None)
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=obj_id,
            input_points=[[[[x, y]]]],
            input_labels=[[[label_int]]],
            original_size=original_size,
            clear_old_inputs=replace_existing,
        )
        frame_clicks = state.clicks_by_frame_obj.setdefault(frame_idx, {})
        obj_clicks = frame_clicks.setdefault(obj_id, [])
        if replace_existing:
            obj_clicks.clear()
        obj_clicks.append((x, y, label_int))
        state.composited_frames.pop(frame_idx, None)

    # Forward on that frame
    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
    H = inference_session.video_height
    W = inference_session.video_width
    # Detach and move off GPU as early as possible to reduce GPU memory pressure
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

    # Masks are returned in the order of the session's object ids; each entry
    # may be (1, H, W) or (H, W), so squeeze to 2D.
    masks_for_frame: dict[int, np.ndarray] = {}
    for i, oid in enumerate(list(inference_session.obj_ids)):
        masks_for_frame[int(oid)] = video_res_masks[i].cpu().numpy().squeeze()
    state.masks_by_frame[frame_idx] = masks_for_frame
    _update_centroids_for_frame(state, frame_idx)

    # Invalidate cache for this frame to force recomposition
    state.composited_frames.pop(frame_idx, None)
    # Return updated preview
    return update_frame_display(state, frame_idx)
def _on_image_click_with_updates(
    img: Image.Image | np.ndarray,
    state: AppState,
    frame_idx: int,
    obj_id: int,
    label: str,
    clear_old: bool,
    evt: gr.SelectData,
):
    """Handle a preview click and also refresh the three action buttons.

    Delegates to ``on_image_click`` and returns its preview followed by the
    three updates from ``_button_updates`` (computed after the click has
    been applied to ``state``).
    """
    preview = on_image_click(img, state, frame_idx, obj_id, label, clear_old, evt)
    main_btn, detect_btn, player_btn = _button_updates(state)
    return (preview, main_btn, detect_btn, player_btn)
def _propagation_outputs(state, status, slider_update):
    """Assemble the 10-item output tuple shared by every propagate_masks update."""
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(state)
    return (
        state,
        status,
        slider_update,
        _build_kick_plot(state),
        _build_yolo_plot(state),
        _format_impact_status(state),
        gr.update(value=_format_kick_status(state), visible=True),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
    )


def propagate_masks(GLOBAL_STATE: gr.State):
    """Run SAM2 over the selected frame window, streaming progress updates.

    This is a generator. Every yielded item is a 10-tuple: state, status
    text, frame-slider update, kick plot, YOLO plot, impact status, kick
    status update, and the three button updates.

    When no video/session is loaded a single "Load a video first." update
    is yielded. (The previous ``return <tuple>`` inside this generator
    produced no output at all, because a generator's return value is not
    delivered to whoever iterates it.)

    Otherwise the processor, model and inference session are deep-copied,
    the copies are pointed at CUDA, and every frame in ``state.sam_window``
    (or the whole video) is processed. Masks and centroids are written back
    to ``GLOBAL_STATE``; progress is yielded every 30 processed frames and
    once more at the end, where the slider is moved to the kick frame if one
    is known, else to the last processed frame.
    """
    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
        yield _propagation_outputs(GLOBAL_STATE, "Load a video first.", gr.update())
        return

    _ensure_ball_prompt_from_yolo(GLOBAL_STATE)
    processor = deepcopy(GLOBAL_STATE.processor)
    model = deepcopy(GLOBAL_STATE.model)
    inference_session = deepcopy(GLOBAL_STATE.inference_session)
    # set inference device to cuda to use zero gpu
    inference_session.inference_device = "cuda"
    inference_session.cache.inference_device = "cuda"
    model.to("cuda")

    if not GLOBAL_STATE.sam_window:
        _compute_sam_window_from_kick(
            GLOBAL_STATE,
            GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
        )
    start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
    # Clamp to the video and keep the window at least one frame wide.
    start_idx = max(0, int(start_idx))
    end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
    total = max(1, end_idx - start_idx)
    processed = 0
    _ensure_ball_prompt_from_yolo(GLOBAL_STATE)

    # Initial status; no slider change yet
    yield _propagation_outputs(GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update())

    last_frame_idx = start_idx
    with torch.inference_mode():
        for frame_idx in range(start_idx, end_idx):
            frame = GLOBAL_STATE.video_frames[frame_idx]
            # Frames the session already holds are not preprocessed again.
            pixel_values = None
            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
            last_frame_idx = frame_idx

            # Masks come back in the order of the session's object ids.
            masks_for_frame: dict[int, np.ndarray] = {}
            for i, oid in enumerate(list(inference_session.obj_ids)):
                masks_for_frame[int(oid)] = video_res_masks[i].cpu().numpy().squeeze()
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
            _update_centroids_for_frame(GLOBAL_STATE, frame_idx)
            # Invalidate cache for that frame to force recomposition
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)
            processed += 1

            # Every 30th processed frame (or the last), move the slider to the
            # current frame so the preview updates via the slider binding.
            if processed % 30 == 0 or processed == total:
                yield _propagation_outputs(
                    GLOBAL_STATE,
                    f"Propagating masks: {processed}/{total}",
                    gr.update(value=frame_idx),
                )

    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
    # Focus UI on kick frame if available; otherwise stick to last processed frame
    target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
    if target_frame is None:
        target_frame = last_frame_idx
    target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
    GLOBAL_STATE.current_frame_idx = target_frame

    # Final status; ensure slider points to the target frame (kick frame when detected)
    yield _propagation_outputs(GLOBAL_STATE, text, gr.update(value=target_frame))
def reset_session(GLOBAL_STATE: gr.State) -> tuple[Any, ...]:
    """Clear prompts, masks and motion analysis while keeping video and model.

    Returns a 12-tuple on every path: state, preview image, slider range
    update, slider value update, status text, an update that hides and
    empties one component, kick plot, YOLO plot, impact status, and the
    three button updates. The "nothing loaded" path previously omitted the
    YOLO plot and so returned one value fewer than the normal path.

    When frames are loaded, all per-frame prompts/masks/caches and every
    kick/impact result are cleared, the inference session is reset and
    recreated through ``ensure_session_for_current_model``, and the preview
    is re-rendered at the current (clamped) frame index.
    """
    if not GLOBAL_STATE.video_frames:
        # Nothing loaded; report that without touching any state.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            None,
            0,
            0,
            "Session reset. Load a new video.",
            gr.update(visible=False, value=""),
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _format_impact_status(GLOBAL_STATE),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
        )

    # Clear prompts and caches
    GLOBAL_STATE.masks_by_frame.clear()
    GLOBAL_STATE.clicks_by_frame_obj.clear()
    GLOBAL_STATE.boxes_by_frame_obj.clear()
    GLOBAL_STATE.composited_frames.clear()
    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None

    # Clear tracks derived from the masks
    GLOBAL_STATE.ball_centers.clear()
    GLOBAL_STATE.mask_areas.clear()
    GLOBAL_STATE.smoothed_centers.clear()
    GLOBAL_STATE.ball_speeds.clear()
    GLOBAL_STATE.distance_from_start.clear()
    GLOBAL_STATE.direction_change.clear()
    GLOBAL_STATE.kalman_centers.clear()
    GLOBAL_STATE.kalman_speeds.clear()
    GLOBAL_STATE.kalman_residuals.clear()

    # Clear kick detection results and diagnostics
    GLOBAL_STATE.kick_frame = None
    GLOBAL_STATE.kick_debug_frames = []
    GLOBAL_STATE.kick_debug_speeds = []
    GLOBAL_STATE.kick_debug_threshold = None
    GLOBAL_STATE.kick_debug_baseline = None
    GLOBAL_STATE.kick_debug_speed_std = None
    GLOBAL_STATE.kick_debug_area = []
    GLOBAL_STATE.kick_debug_kick_frame = None
    GLOBAL_STATE.kick_debug_distance = []
    GLOBAL_STATE.kick_debug_kalman_speeds = []

    # Clear impact detection results and diagnostics
    GLOBAL_STATE.impact_frame = None
    GLOBAL_STATE.impact_debug_frames = []
    GLOBAL_STATE.impact_debug_innovation = []
    GLOBAL_STATE.impact_debug_innovation_threshold = None
    GLOBAL_STATE.impact_debug_direction = []
    GLOBAL_STATE.impact_debug_direction_threshold = None
    GLOBAL_STATE.impact_debug_speed_kmh = []
    GLOBAL_STATE.impact_debug_speed_threshold_px = None
    GLOBAL_STATE.impact_meters_per_px = None

    # Dispose and re-init inference session for current model with existing frames
    try:
        if GLOBAL_STATE.inference_session is not None:
            GLOBAL_STATE.inference_session.reset_inference_session()
    except Exception:
        # Best effort only: the session is dropped and rebuilt right below,
        # so a failing reset must not abort the whole operation.
        pass
    GLOBAL_STATE.inference_session = None
    gc.collect()
    ensure_session_for_current_model(GLOBAL_STATE)

    # Keep current slider index if possible
    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
    slider_value = gr.update(value=current_idx)
    status = "Session reset. Prompts cleared; video preserved."
    propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
    return (
        GLOBAL_STATE,
        preview_img,
        slider_minmax,
        slider_value,
        status,
        gr.update(visible=False, value=""),
        _build_kick_plot(GLOBAL_STATE),
        _build_yolo_plot(GLOBAL_STATE),
        _format_impact_status(GLOBAL_STATE),
        propagate_main_update,
        detect_btn_update,
        propagate_player_update,
    )
def create_annotation_preview(video_file, annotations):
    """
    Create a preview image showing annotation points on video frames.

    Annotations are grouped by their "frame" value (a missing or null frame
    counts as frame 0). The three lowest annotated frames that can be read
    are decoded, each annotation is drawn as a crosshair with a circle and
    an "Obj<id> F<frame>" tag, and the frames are pasted side by side.

    Args:
        video_file: Path to video file
        annotations: List of annotation dicts; the keys read here are
            "frame", "x", "y" and "object_id"
    Returns:
        PIL Image with annotations visualized, or None when the video cannot
        be opened or none of the selected frames can be read
    """
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        return None

    annotated_frames = []
    try:
        # Group annotations by frame. A null frame (allowed when the client
        # sends only a timestamp) is mapped to 0 so the keys stay sortable.
        frames_to_show = {}
        for ann in annotations:
            frames_to_show.setdefault(ann.get("frame") or 0, []).append(ann)

        # Read and annotate frames
        for frame_idx in sorted(frames_to_show)[:3]:  # Show max 3 frames
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                continue
            # Convert BGR to RGB
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)

            for ann in frames_to_show[frame_idx]:
                x, y = ann.get("x", 0), ann.get("y", 0)
                obj_id = ann.get("object_id", 1)
                # Color based on object ID
                color = pastel_color_for_object(obj_id)
                # Draw crosshair
                size = 20
                draw.line([(x - size, y), (x + size, y)], fill=color, width=3)
                draw.line([(x, y - size), (x, y + size)], fill=color, width=3)
                draw.ellipse([(x - 10, y - 10), (x + 10, y + 10)], outline=color, width=3)
                # Draw label
                draw.text((x + 15, y - 15), f"Obj{obj_id} F{frame_idx}", fill=color)

            # Add frame number label
            draw.text((10, 10), f"Frame {frame_idx}", fill=(255, 255, 255))
            annotated_frames.append(pil_img)
    finally:
        # Release the capture even when decoding or drawing raises.
        cap.release()

    if not annotated_frames:
        return None

    # Combine frames horizontally
    total_width = sum(img.width for img in annotated_frames)
    max_height = max(img.height for img in annotated_frames)
    combined = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in annotated_frames:
        combined.paste(img, (x_offset, 0))
        x_offset += img.width
    return combined
| # Allocate GPU for up to 2 minutes | |
| def process_video_api( | |
| video_file, | |
| annotations_json_str: str, | |
| checkpoint: str = "base_plus", | |
| remove_background: bool = True | |
| ): | |
| """ | |
| Single-endpoint API for programmatic video processing. | |
| Args: | |
| video_file: Uploaded video file | |
| annotations_json_str: JSON string with format: | |
| { | |
| "annotations": [ | |
| {"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}, | |
| {"object_id": 1, "frame": 156, "x": 374, "y": 513, "label": "positive"}, | |
| {"object_id": 2, "frame": 156, "x": 374, "y": 257, "label": "positive"} | |
| ] | |
| } | |
| checkpoint: SAM2 model checkpoint (tiny, small, base_plus, large) | |
| remove_background: Whether to remove background (default: True) | |
| Returns: | |
| Tuple of (preview_image, processed_video_path) | |
| """ | |
| import json | |
| try: | |
| # Parse annotations | |
| annotations_data = json.loads(annotations_json_str) | |
| annotations = annotations_data.get("annotations", []) | |
| client_fps = annotations_data.get("fps", None) # FPS used by iOS app to calculate frame indices | |
| print(f"[API] Processing video with {len(annotations)} annotations") | |
| print(f"[API] Client FPS: {client_fps}") | |
| print(f"[API] Checkpoint: {checkpoint}") | |
| print(f"[API] Remove background: {remove_background}") | |
| # Create preview of annotation points | |
| preview_img = create_annotation_preview(video_file, annotations) | |
| # Create a temporary state for this API call | |
| api_state = AppState() | |
| api_state.model_repo_key = checkpoint | |
| # Step 1: Initialize session with video | |
| api_state, min_idx, max_idx, first_frame, status = init_video_session(api_state, video_file) | |
| space_fps = api_state.video_fps | |
| print(f"[API] Video loaded: {status}") | |
| print(f"[API] ⚠️ FPS mismatch check: Client={client_fps}, Space={space_fps}") | |
| # If FPS mismatch, warn about potential frame offset | |
| if client_fps and space_fps and abs(client_fps - space_fps) > 0.5: | |
| offset_estimate = abs(int((client_fps - space_fps) * (api_state.num_frames / client_fps))) | |
| print(f"[API] ⚠️ FPS mismatch detected! Frame indices may be off by ~{offset_estimate} frames") | |
| print(f"[API] ℹ️ Recommendation: Use timestamps instead of frame indices for accuracy") | |
| # Step 2: Apply each annotation | |
| for i, ann in enumerate(annotations): | |
| object_id = ann.get("object_id", 1) | |
| timestamp_ms = ann.get("timestamp_ms", None) | |
| frame_idx = ann.get("frame", None) | |
| x = ann.get("x", 0) | |
| y = ann.get("y", 0) | |
| label = ann.get("label", "positive") | |
| # Calculate frame from timestamp using Space's FPS (more accurate) | |
| if timestamp_ms is not None and space_fps and space_fps > 0: | |
| calculated_frame = int((timestamp_ms / 1000.0) * space_fps) | |
| if frame_idx is not None and calculated_frame != frame_idx: | |
| print(f"[API] ✅ Using timestamp: {timestamp_ms}ms → Frame {calculated_frame} (client sent frame {frame_idx})") | |
| else: | |
| print(f"[API] ✅ Calculated frame from timestamp: {timestamp_ms}ms → Frame {calculated_frame}") | |
| frame_idx = calculated_frame | |
| elif frame_idx is None: | |
| print(f"[API] ⚠️ Warning: No timestamp or frame provided, using frame 0") | |
| frame_idx = 0 | |
| print(f"[API] Adding annotation {i+1}/{len(annotations)}: " | |
| f"Object {object_id}, Frame {frame_idx}, ({x}, {y}), {label}") | |
| # Sync state | |
| api_state.current_frame_idx = int(frame_idx) | |
| api_state.current_obj_id = int(object_id) | |
| api_state.current_label = str(label) | |
| # Create a mock event with coordinates | |
| class MockEvent: | |
| def __init__(self, x, y): | |
| self.index = (x, y) | |
| mock_evt = MockEvent(x, y) | |
| # Add the point annotation | |
| preview_img = on_image_click( | |
| first_frame, | |
| api_state, | |
| frame_idx, | |
| object_id, | |
| label, | |
| clear_old=False, | |
| evt=mock_evt | |
| ) | |
| # Step 3: Propagate masks across all frames | |
| print("[API] Propagating masks across video...") | |
| # We need to consume the generator | |
| for outputs in propagate_masks(api_state): | |
| if not outputs: | |
| continue | |
| api_state = outputs[0] | |
| status_msg = outputs[1] if len(outputs) > 1 else "" | |
| if status_msg: | |
| print(f"[API] Progress: {status_msg}") | |
| # Step 4: Render the final video | |
| print(f"[API] Rendering video with remove_background={remove_background}...") | |
| result_video_path = _render_video(api_state, remove_background) | |
| print(f"[API] ✅ Processing complete: {result_video_path}") | |
| return preview_img, result_video_path | |
| except Exception as e: | |
| print(f"[API] ❌ Error: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| raise gr.Error(f"Processing failed: {str(e)}") | |
| theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate") | |
| CUSTOM_CSS = "" | |
| with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", theme=theme, css=CUSTOM_CSS) as demo: | |
| GLOBAL_STATE = gr.State(AppState()) | |
| gr.Markdown( | |
| """ | |
| ### SAM2 Video Tracking · powered by Hugging Face 🤗 Transformers | |
| Segment and track objects across a video with SAM2 (Segment Anything 2). This demo runs the official implementation from the Hugging Face Transformers library for interactive, promptable video segmentation. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown( | |
| """ | |
| **Quick start** | |
| - **Load a video**: Upload your own or pick an example below. | |
| - **Checkpoint**: Tiny / Small / Base+ / Large (trade speed vs. accuracy). | |
| - **Points mode**: Select an Object ID and point label (positive/negative), then click the frame to add guidance. You can add **multiple points per object** and define **multiple objects** across frames. | |
| - **Boxes mode**: Click two opposite corners to draw a box. Old inputs for that object are cleared automatically. | |
| """ | |
| ) | |
| with gr.Column(): | |
| gr.Markdown( | |
| """ | |
| **Working with results** | |
| - **Preview**: Use the slider to navigate frames and see the current masks. | |
| - **Track**: Click "Track ball (SAM2)" to track all defined objects across the selected window. The preview follows progress periodically to keep things responsive. | |
| - **Export**: Render an MP4 for smooth playback using the original video FPS. | |
| - **Note**: More info on the Hugging Face 🤗 Transformers implementation of SAM2 can be found [here](https://huggingface.co/docs/transformers/en/main/en/model_doc/sam2_video). | |
| """ | |
| ) | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=1): | |
| video_in = gr.Video( | |
| label="Upload video", | |
| sources=["upload", "webcam"], | |
| interactive=True, | |
| elem_id="video-pane", | |
| ) | |
| ckpt_radio = gr.Radio( | |
| choices=["tiny", "small", "base_plus", "large"], | |
| value="tiny", | |
| label="SAM2.1 checkpoint", | |
| ) | |
| ckpt_progress = gr.Markdown(visible=False) | |
| load_status = gr.Markdown(visible=True) | |
| reset_btn = gr.Button("Reset Session", variant="secondary") | |
| with gr.Column(scale=1): | |
| gr.Markdown("**Preview**") | |
| preview = gr.Image( | |
| interactive=True, | |
| elem_id="preview-pane", | |
| container=False, | |
| show_label=False, | |
| ) | |
| frame_slider = gr.Slider( | |
| label="Frame", | |
| minimum=0, | |
| maximum=0, | |
| step=1, | |
| value=0, | |
| interactive=True, | |
| elem_id="frame-slider", | |
| ) | |
| with gr.Row(): | |
| min_impact_speed_slider = gr.Slider( | |
| label="Min impact speed (km/h)", | |
| minimum=0, | |
| maximum=120, | |
| step=1, | |
| value=20, | |
| interactive=True, | |
| ) | |
| goal_distance_slider = gr.Slider( | |
| label="Distance to goal (m)", | |
| minimum=1, | |
| maximum=60, | |
| step=0.5, | |
| value=18, | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| detect_ball_btn = gr.Button("Detect Ball", variant="secondary") | |
| track_ball_yolo_btn = gr.Button("Track ball (YOLO13)", variant="secondary") | |
| propagate_btn = gr.Button("Track ball (SAM2)", variant="primary", interactive=False) | |
| detect_player_btn = gr.Button("Detect Player", variant="secondary", interactive=False) | |
| propagate_player_btn = gr.Button("Propagate Player", variant="primary", interactive=False) | |
| ball_status = gr.Markdown(visible=False) | |
| propagate_status = gr.Markdown(visible=True) | |
| impact_status = gr.Markdown("Impact frame: not computed") | |
| with gr.Row(): | |
| obj_id_inp = gr.Number(value=1, precision=0, label="Object ID", scale=0) | |
| label_radio = gr.Radio(choices=["positive", "negative"], value="positive", label="Point label") | |
| clear_old_chk = gr.Checkbox(value=False, label="Clear old inputs for this object") | |
| prompt_type = gr.Radio(choices=["Points", "Boxes"], value="Points", label="Prompt type") | |
| kick_plot = gr.Plot(label="Kick & impact diagnostics", show_label=True) | |
| yolo_plot = gr.Plot(label="YOLO kick diagnostics", show_label=True) | |
| # Wire events | |
def _on_video_change(GLOBAL_STATE: gr.State, video):
    """Start a session for the newly selected video and refresh the UI.

    Args:
        GLOBAL_STATE: Current session state. It is rebound to the first value
            returned by ``init_video_session``.
        video: Value delivered by the video input component.

    Returns:
        An 11-tuple in the order of the ``outputs`` list this handler is
        wired to: state, frame-slider update, first frame, load status,
        hidden/cleared ball status, kick plot, YOLO plot, impact status, and
        the three button updates produced by ``_button_updates``.
    """
    session = init_video_session(GLOBAL_STATE, video)
    GLOBAL_STATE, lowest_idx, highest_idx, first_frame, status = session
    track_btn_update, detect_btn_update, player_btn_update = _button_updates(GLOBAL_STATE)

    # The slider is re-ranged to the new video and parked on its first frame.
    slider_update = gr.update(
        minimum=lowest_idx,
        maximum=highest_idx,
        value=lowest_idx,
        interactive=True,
    )
    # Any ball-status message from a previous video is hidden and cleared.
    cleared_ball_status = gr.update(visible=False, value="")
    kick_figure = _build_kick_plot(GLOBAL_STATE)
    yolo_figure = _build_yolo_plot(GLOBAL_STATE)
    impact_text = _format_impact_status(GLOBAL_STATE)

    return (
        GLOBAL_STATE,
        slider_update,
        first_frame,
        status,
        cleared_ball_status,
        kick_figure,
        yolo_figure,
        impact_text,
        track_btn_update,
        detect_btn_update,
        player_btn_update,
    )
| video_in.change( | |
| _on_video_change, | |
| inputs=[GLOBAL_STATE, video_in], | |
| outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn], | |
| show_progress=True, | |
| ) | |
| example_video_path = ensure_example_video() | |
| examples_list = [ | |
| [None, example_video_path], | |
| ] | |
| with gr.Row(): | |
| gr.Examples( | |
| examples=examples_list, | |
| inputs=[GLOBAL_STATE, video_in], | |
| fn=_on_video_change, | |
| outputs=[GLOBAL_STATE, frame_slider, preview, load_status, ball_status, kick_plot, yolo_plot, impact_status, propagate_btn, detect_player_btn, propagate_player_btn], | |
| label="Examples", | |
| cache_examples=False, | |
| examples_per_page=5, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| remove_bg_checkbox = gr.Checkbox( | |
| label="Remove Background", | |
| value=True, | |
| info="If checked, shows only tracked objects on black background. If unchecked, overlays colored masks on original video.", | |
| ) | |
| with gr.Column(scale=1): | |
| ghost_trail_chk = gr.Checkbox( | |
| label="Ghost trail (ball)", | |
| value=True, | |
| info="Overlay post-kick SAM2 ball masks in magenta to visualize trajectory.", | |
| ) | |
| with gr.Column(scale=1): | |
| click_marks_chk = gr.Checkbox( | |
| label="Show annotation '+'", | |
| value=False, | |
| info="If unchecked, hides the '+' markers from clicks in preview and renders.", | |
| ) | |
| with gr.Column(scale=1): | |
| ring_outline_chk = gr.Checkbox( | |
| label="Ball outline (ring)", | |
| value=False, | |
| info="Render the ball and its ghost trail as a thin magenta ring instead of filled masks.", | |
| ) | |
| with gr.Accordion("Cutout FX", open=False): | |
| gr.Markdown("These options apply when rendering with background removal.") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| soft_matte_chk = gr.Checkbox(label="Soft matte", value=True) | |
| with gr.Column(scale=2): | |
| soft_matte_feather = gr.Slider( | |
| label="Feather radius (px)", | |
| minimum=0.0, | |
| maximum=12.0, | |
| step=0.5, | |
| value=4.0, | |
| ) | |
| with gr.Column(scale=2): | |
| soft_matte_erode = gr.Slider( | |
| label="Edge shrink (px)", | |
| minimum=0.0, | |
| maximum=5.0, | |
| step=0.5, | |
| value=0.5, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| blur_bg_chk = gr.Checkbox(label="Blur background", value=True) | |
| with gr.Column(scale=2): | |
| blur_radius = gr.Slider( | |
| label="Background blur (px)", | |
| minimum=0.0, | |
| maximum=45.0, | |
| step=1.0, | |
| value=0.0, | |
| ) | |
| with gr.Column(scale=2): | |
| bg_darkening = gr.Slider( | |
| label="Darken background", | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.05, | |
| value=1.0, | |
| info="0 keeps original brightness, 1 turns the background black.", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| light_wrap_chk = gr.Checkbox(label="Light wrap", value=False) | |
| with gr.Column(scale=2): | |
| light_wrap_strength = gr.Slider( | |
| label="Wrap strength", | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.6, | |
| ) | |
| with gr.Column(scale=2): | |
| light_wrap_width = gr.Slider( | |
| label="Wrap width (px)", | |
| minimum=0.0, | |
| maximum=25.0, | |
| step=0.5, | |
| value=15.0, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| glow_chk = gr.Checkbox(label="Neon glow", value=False) | |
| with gr.Column(scale=2): | |
| glow_strength = gr.Slider( | |
| label="Glow strength", | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.4, | |
| ) | |
| with gr.Column(scale=2): | |
| glow_radius = gr.Slider( | |
| label="Glow radius (px)", | |
| minimum=0.0, | |
| maximum=35.0, | |
| step=0.5, | |
| value=10.0, | |
| ) | |
| with gr.Row(): | |
| render_btn = gr.Button("Render MP4 for smooth playback", variant="primary") | |
| playback_video = gr.Video(label="Rendered Playback", interactive=False) | |
| fx_inputs = [ | |
| soft_matte_chk, | |
| soft_matte_feather, | |
| soft_matte_erode, | |
| blur_bg_chk, | |
| blur_radius, | |
| bg_darkening, | |
| light_wrap_chk, | |
| light_wrap_strength, | |
| light_wrap_width, | |
| glow_chk, | |
| glow_strength, | |
| glow_radius, | |
| ] | |
| for comp in fx_inputs: | |
| comp.change( | |
| _update_fx_controls, | |
| inputs=[GLOBAL_STATE] + fx_inputs, | |
| outputs=preview, | |
| ) | |
| ghost_trail_chk.change( | |
| _toggle_ghost_trail, | |
| inputs=[GLOBAL_STATE, ghost_trail_chk], | |
| outputs=preview, | |
| ) | |
| click_marks_chk.change( | |
| _toggle_click_marks, | |
| inputs=[GLOBAL_STATE, click_marks_chk], | |
| outputs=preview, | |
| ) | |
| ring_outline_chk.change( | |
| _toggle_ball_ring, | |
| inputs=[GLOBAL_STATE, ring_outline_chk], | |
| outputs=preview, | |
| ) | |
def _on_ckpt_change(s: AppState, key: str):
    """Switch the SAM2 checkpoint and stream a progress message while loading.

    This is a generator: the first yield makes the progress text visible,
    the second hides it again once the session has been rebuilt.

    Args:
        s: Session state; may be None, in which case no fields are updated.
        key: Checkpoint key chosen in the radio control.

    Yields:
        ``gr.update`` payloads for the checkpoint-progress component.
    """
    if s is not None and key:
        key = str(key)
        if key != s.model_repo_key:
            # Update and drop the current model so it is reloaded lazily.
            s.is_switching_model = True
            s.model_repo_key = key
            s.model_repo_id = None
            s.model = None
            s.processor = None
    # First yield shows the progress text while the checkpoint loads.
    yield gr.update(visible=True, value=f"Loading checkpoint: {key}...")
    try:
        ensure_session_for_current_model(s)
    finally:
        # Always clear the flag: if loading raised, leaving it set would keep
        # the session marked as "switching" until the next successful change.
        if s is not None:
            s.is_switching_model = False
    # Final yield hides the text.
    yield gr.update(visible=False, value="")
| ckpt_radio.change(_on_ckpt_change, inputs=[GLOBAL_STATE, ckpt_radio], outputs=[ckpt_progress]) | |
def _sync_frame_idx(state_in: AppState, idx: int):
    """Record the slider position on the state and return that frame's preview.

    Args:
        state_in: Session state; when None nothing is recorded.
        idx: Frame index from the slider (coerced with ``int``).

    Returns:
        Whatever ``update_frame_display`` returns for the coerced index.
    """
    frame_no = int(idx)
    if state_in is not None:
        state_in.current_frame_idx = frame_no
    return update_frame_display(state_in, frame_no)
| frame_slider.change( | |
| _sync_frame_idx, | |
| inputs=[GLOBAL_STATE, frame_slider], | |
| outputs=preview, | |
| ) | |
| def _sync_obj_id(s: AppState, oid): | |
| if s is not None and oid is not None: | |
| s.current_obj_id = int(oid) | |
| return gr.update() | |
| obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[]) | |
| def _sync_label(s: AppState, lab: str): | |
| if s is not None and lab is not None: | |
| s.current_label = str(lab) | |
| return gr.update() | |
| label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[]) | |
def _sync_prompt_type(s: AppState, val: str):
    """Store the prompt type and adapt the point-only controls.

    Args:
        s: Session state; left untouched when None.
        val: Selected prompt type; compared case-insensitively to "points".

    Returns:
        A two-item list of updates, for the label radio and the
        "clear old inputs" checkbox, in that order.
    """
    if not (s is None or val is None):
        s.current_prompt_type = str(val)
        # A half-drawn box is discarded whenever the prompt type changes.
        s.pending_box_start = None

    points_mode = str(val).lower() == "points"

    # The label radio is only shown in points mode.
    label_update = gr.update(visible=points_mode)
    if points_mode:
        clear_old_update = gr.update(interactive=True)
    else:
        # In boxes mode the checkbox is forced on and locked.
        clear_old_update = gr.update(value=True, interactive=False)
    return [label_update, clear_old_update]
| prompt_type.change( | |
| _sync_prompt_type, | |
| inputs=[GLOBAL_STATE, prompt_type], | |
| outputs=[label_radio, clear_old_chk], | |
| ) | |
def _update_min_impact_speed(s: AppState, val: float):
    """Apply a new minimum impact speed and refresh the dependent widgets.

    Args:
        s: Session state; when None no value is stored and no recompute runs.
        val: Slider value in km/h (per the slider label); ignored when None.

    Returns:
        A 6-tuple: kick plot, impact status text, ball-status update, and the
        three button updates from ``_button_updates``.
    """
    have_input = not (s is None or val is None)
    if have_input:
        s.min_impact_speed_kmh = float(val)
        _recompute_motion_metrics(s)

    track_btn_update, detect_btn_update, player_btn_update = _button_updates(s)
    kick_figure = _build_kick_plot(s)
    impact_text = _format_impact_status(s)
    ball_status_update = gr.update(value=_format_kick_status(s), visible=True)
    return (
        kick_figure,
        impact_text,
        ball_status_update,
        track_btn_update,
        detect_btn_update,
        player_btn_update,
    )
def _update_goal_distance(s: AppState, val: float):
    """Apply a new goal distance and refresh the dependent widgets.

    Args:
        s: Session state; when None no value is stored and no recompute runs.
        val: Slider value in metres (per the slider label); ignored when None.

    Returns:
        A 6-tuple: kick plot, impact status text, ball-status update, and the
        three button updates from ``_button_updates``.
    """
    have_input = not (s is None or val is None)
    if have_input:
        s.goal_distance_m = float(val)
        _recompute_motion_metrics(s)

    track_btn_update, detect_btn_update, player_btn_update = _button_updates(s)
    kick_figure = _build_kick_plot(s)
    impact_text = _format_impact_status(s)
    ball_status_update = gr.update(value=_format_kick_status(s), visible=True)
    return (
        kick_figure,
        impact_text,
        ball_status_update,
        track_btn_update,
        detect_btn_update,
        player_btn_update,
    )
| min_impact_speed_slider.change( | |
| _update_min_impact_speed, | |
| inputs=[GLOBAL_STATE, min_impact_speed_slider], | |
| outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
| goal_distance_slider.change( | |
| _update_goal_distance, | |
| inputs=[GLOBAL_STATE, goal_distance_slider], | |
| outputs=[kick_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
def _auto_detect_ball(
    state_in: AppState,
    obj_id,
    label_value: str,
    clear_old_value: bool,
):
    """Detect the ball on frame 0 and register it as a point prompt.

    The detected centre is clamped to the frame bounds and fed through
    ``on_image_click`` via a synthetic click event, so it follows the same
    path as a manual click.

    Args:
        state_in: Session state holding the decoded frames.
        obj_id: Object id from the UI; falls back to ``state_in.current_obj_id``
            when None.
        label_value: Point label from the UI; falls back to
            ``state_in.current_label`` when empty.
        clear_old_value: Whether earlier inputs for the object are cleared.

    Returns:
        A 7-tuple: preview image, ball-status update, frame-slider update,
        kick plot, and the three button updates from ``_button_updates``.

    Raises:
        gr.Error: If no video is loaded.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then try auto-detect.")

    def _clamp(value, upper):
        # Keep a coordinate inside [0, upper] after integer truncation.
        return max(0, min(upper, int(value)))

    frame_idx = 0
    frame = state_in.video_frames[frame_idx]
    detection = detect_ball_center(frame)

    if detection is None:
        track_btn_update, detect_btn_update, player_btn_update = _button_updates(state_in)
        failure_preview = update_frame_display(state_in, frame_idx)
        failure_status = gr.update(
            value="❌ Unable to auto-detect the ball. Please add a point manually.",
            visible=True,
        )
        return (
            failure_preview,
            failure_status,
            gr.update(value=frame_idx),
            _build_kick_plot(state_in),
            track_btn_update,
            detect_btn_update,
            player_btn_update,
        )

    raw_x, raw_y, _, _, conf = detection
    frame_width, frame_height = frame.size
    x_center = _clamp(raw_x, frame_width - 1)
    y_center = _clamp(raw_y, frame_height - 1)

    if obj_id is not None:
        obj_id_int = int(obj_id)
    else:
        obj_id_int = state_in.current_obj_id
    label_str = label_value if label_value else state_in.current_label
    clear_old_flag = bool(clear_old_value)

    # Build a synthetic click event to reuse the existing click handler.
    synthetic_evt = SimpleNamespace(
        index=(x_center, y_center),
        value={"x": x_center, "y": y_center},
    )
    state_in.current_frame_idx = frame_idx
    base_image = update_frame_display(state_in, frame_idx)
    preview_img = on_image_click(
        base_image,
        state_in,
        frame_idx,
        obj_id_int,
        label_str,
        clear_old_flag,
        synthetic_evt,
    )

    status_text = f"✅ Auto-detected ball at ({x_center}, {y_center}) (conf={conf:.2f})"
    status_text += f" | {_format_kick_status(state_in)}"
    track_btn_update, detect_btn_update, player_btn_update = _button_updates(state_in)
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=frame_idx),
        _build_kick_plot(state_in),
        track_btn_update,
        detect_btn_update,
        player_btn_update,
    )
| detect_ball_btn.click( | |
| _auto_detect_ball, | |
| inputs=[GLOBAL_STATE, obj_id_inp, label_radio, clear_old_chk], | |
| outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
def _track_ball_yolo(state_in: AppState):
    """Run YOLO ball tracking and jump the preview to the most relevant frame.

    The frame shown afterwards is the YOLO kick frame when one was found,
    otherwise the YOLO initial frame, otherwise frame 0.

    Args:
        state_in: Session state holding the decoded frames.

    Returns:
        An 8-tuple: preview image, ball-status update, frame-slider update,
        kick plot, YOLO plot, and the three button updates from
        ``_button_updates``.

    Raises:
        gr.Error: If no video is loaded.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then track the ball with YOLO.")

    progress = gr.Progress(track_tqdm=False)
    _perform_yolo_ball_tracking(state_in, progress=progress)

    if state_in.yolo_kick_frame is not None:
        target_frame = state_in.yolo_kick_frame
    elif state_in.yolo_initial_frame is not None:
        target_frame = state_in.yolo_initial_frame
    else:
        target_frame = 0
    if state_in.num_frames:
        # Keep the jump target inside the valid frame range.
        target_frame = int(np.clip(target_frame, 0, state_in.num_frames - 1))
    state_in.current_frame_idx = target_frame

    preview_img = update_frame_display(state_in, target_frame)

    base_msg = state_in.yolo_status or ""
    kick_msg = _format_kick_status(state_in)
    if base_msg:
        status_text = f"{base_msg} | {kick_msg}"
    else:
        status_text = kick_msg

    track_btn_update, detect_btn_update, player_btn_update = _button_updates(state_in)
    return (
        preview_img,
        gr.update(value=status_text, visible=True),
        gr.update(value=target_frame),
        _build_kick_plot(state_in),
        _build_yolo_plot(state_in),
        track_btn_update,
        detect_btn_update,
        player_btn_update,
    )
| track_ball_yolo_btn.click( | |
| _track_ball_yolo, | |
| inputs=[GLOBAL_STATE], | |
| outputs=[preview, ball_status, frame_slider, kick_plot, yolo_plot, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
def _auto_detect_player(state_in: AppState):
    """Detect the player on the kick frame and seed SAM2 with a box prompt.

    Earlier prompts and masks stored for ``PLAYER_OBJECT_ID`` are dropped,
    the detected box is added to the inference session on the kick frame, and
    the model is run on that frame so a player mask is available at once.

    Args:
        state_in: Session state with a ready model, processor and inference
            session.

    Returns:
        An 8-tuple: preview image, ball-status update, frame-slider update,
        kick plot, the three button updates from ``_button_updates``, and an
        update for the object-id input (a no-op when detection fails).

    Raises:
        gr.Error: If no video is loaded, the model session is not ready, or
            no kick frame is known yet.
    """
    if state_in is None or state_in.num_frames == 0:
        raise gr.Error("Load a video first, then try auto-detect.")
    session_ready = not (
        state_in.inference_session is None
        or state_in.processor is None
        or state_in.model is None
    )
    if not session_ready:
        raise gr.Error("Model session is not ready. Load a video and propagate masks first.")

    kick_frame = state_in.kick_frame or getattr(state_in, "kick_debug_kick_frame", None)
    if kick_frame is None:
        raise gr.Error("Detect the kick frame first by propagating the ball masks.")

    frame_idx = int(np.clip(int(kick_frame), 0, state_in.num_frames - 1))
    frame = state_in.video_frames[frame_idx]
    detection = detect_person_box(frame)

    if detection is None:
        track_btn_update, detect_btn_update, player_btn_update = _button_updates(state_in)
        status_text = (
            f"{_format_kick_status(state_in)} | ⚠️ Unable to auto-detect the player on frame {frame_idx}. "
            "Please add a box manually."
        )
        return (
            update_frame_display(state_in, frame_idx),
            gr.update(value=status_text, visible=True),
            gr.update(value=frame_idx),
            _build_kick_plot(state_in),
            track_btn_update,
            detect_btn_update,
            player_btn_update,
            gr.update(),
        )

    x_min, y_min, x_max, y_max, conf = detection
    state_in.player_obj_id = PLAYER_OBJECT_ID
    state_in.player_detection_frame = frame_idx
    state_in.player_detection_conf = conf
    state_in.current_obj_id = PLAYER_OBJECT_ID

    # Clear previous player-specific prompts and masks on every frame.
    for per_frame_store in (
        state_in.boxes_by_frame_obj,
        state_in.clicks_by_frame_obj,
        state_in.masks_by_frame,
    ):
        for per_object in per_frame_store.values():
            per_object.pop(PLAYER_OBJECT_ID, None)
    _ensure_color_for_obj(state_in, PLAYER_OBJECT_ID)

    processor = state_in.processor
    model = state_in.model
    inference_session = state_in.inference_session

    encoded = processor(images=frame, device=state_in.device, return_tensors="pt")
    original_size = encoded.original_sizes[0]
    pixel_values = encoded.pixel_values[0]
    player_box = (x_min, y_min, x_max, y_max)
    processor.add_inputs_to_inference_session(
        inference_session=inference_session,
        frame_idx=frame_idx,
        obj_ids=PLAYER_OBJECT_ID,
        input_boxes=[[list(player_box)]],
        clear_old_inputs=True,
        original_size=original_size,
    )
    state_in.boxes_by_frame_obj.setdefault(frame_idx, {})[PLAYER_OBJECT_ID] = [player_box]
    # Drop the cached composite so the frame is redrawn with the new prompt.
    state_in.composited_frames.pop(frame_idx, None)

    with torch.inference_mode():
        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)

    video_height = inference_session.video_height
    video_width = inference_session.video_width
    pred_masks = outputs.pred_masks.detach().cpu()
    video_res_masks = processor.post_process_masks(
        [pred_masks], original_sizes=[[video_height, video_width]]
    )[0]

    # Work on a copy so masks already stored for this frame are kept.
    masks_for_frame = state_in.masks_by_frame.get(frame_idx, {}).copy()
    for position, oid in enumerate(list(inference_session.obj_ids)):
        masks_for_frame[int(oid)] = video_res_masks[position].cpu().numpy().squeeze()
    state_in.masks_by_frame[frame_idx] = masks_for_frame
    _update_centroids_for_frame(state_in, frame_idx)
    state_in.composited_frames.pop(frame_idx, None)
    state_in.current_frame_idx = frame_idx

    track_btn_update, detect_btn_update, player_btn_update = _button_updates(state_in)
    status_text = (
        f"{_format_kick_status(state_in)} | ✅ Player auto-detected on frame {frame_idx} (conf={conf:.2f})"
    )
    return (
        update_frame_display(state_in, frame_idx),
        gr.update(value=status_text, visible=True),
        gr.update(value=frame_idx),
        _build_kick_plot(state_in),
        track_btn_update,
        detect_btn_update,
        player_btn_update,
        gr.update(value=PLAYER_OBJECT_ID),
    )
| detect_player_btn.click( | |
| _auto_detect_player, | |
| inputs=[GLOBAL_STATE], | |
| outputs=[preview, ball_status, frame_slider, kick_plot, propagate_btn, detect_player_btn, propagate_player_btn, obj_id_inp], | |
| ) | |
def propagate_player_masks(GLOBAL_STATE: gr.State):
    """Propagate the player mask across the SAM window, streaming progress.

    This is a generator, so every UI update has to be *yielded*. The early
    exits therefore yield their message and then return; a bare
    ``return (...)`` inside a generator would never reach the UI.

    Args:
        GLOBAL_STATE: Session state. It needs an inference session and a
            detected player with at least one mask.

    Yields:
        10-tuples in the order of the wired ``outputs`` list: state, status
        text, frame-slider update, kick plot, YOLO plot, impact status,
        ball-status update, and the three button updates from
        ``_button_updates``.
    """

    def _snapshot(status_text, slider_update):
        # Build one output tuple reflecting the current state.
        propagate_main_update, detect_btn_update, propagate_player_update = _button_updates(GLOBAL_STATE)
        return (
            GLOBAL_STATE,
            status_text,
            slider_update,
            _build_kick_plot(GLOBAL_STATE),
            _build_yolo_plot(GLOBAL_STATE),
            _format_impact_status(GLOBAL_STATE),
            gr.update(value=_format_kick_status(GLOBAL_STATE), visible=True),
            propagate_main_update,
            detect_btn_update,
            propagate_player_update,
        )

    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
        yield _snapshot("Load a video first.", gr.update())
        return
    if GLOBAL_STATE.player_obj_id is None or not _player_has_masks(GLOBAL_STATE):
        yield _snapshot("Detect the player before propagating.", gr.update())
        return

    # Work on copies moved to the GPU; the objects stored on the state are
    # left as they are.
    processor = deepcopy(GLOBAL_STATE.processor)
    model = deepcopy(GLOBAL_STATE.model)
    inference_session = deepcopy(GLOBAL_STATE.inference_session)
    inference_session.inference_device = "cuda"
    inference_session.cache.inference_device = "cuda"
    model.to("cuda")

    if not GLOBAL_STATE.sam_window:
        _compute_sam_window_from_kick(
            GLOBAL_STATE,
            GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None),
        )
    start_idx, end_idx = GLOBAL_STATE.sam_window or (0, GLOBAL_STATE.num_frames)
    start_idx = max(0, int(start_idx))
    # The window always covers at least one frame and never runs past the video.
    end_idx = min(GLOBAL_STATE.num_frames, max(start_idx + 1, int(end_idx)))
    total = max(1, end_idx - start_idx)
    processed = 0
    last_frame_idx = start_idx

    yield _snapshot(f"Propagating player: {processed}/{total}", gr.update())

    player_id = GLOBAL_STATE.player_obj_id or PLAYER_OBJECT_ID
    with torch.inference_mode():
        for frame_idx in range(start_idx, end_idx):
            frame = GLOBAL_STATE.video_frames[frame_idx]
            # Frames the session has already processed are passed as None.
            pixel_values = None
            if (
                inference_session.processed_frames is None
                or frame_idx not in inference_session.processed_frames
            ):
                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
            sam2_video_output = model(
                inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx
            )
            H = inference_session.video_height
            W = inference_session.video_width
            pred_masks = sam2_video_output.pred_masks.detach().cpu()
            video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]

            # Only the player's mask is replaced; masks of other objects
            # already stored for this frame are kept untouched.
            masks_for_frame = GLOBAL_STATE.masks_by_frame.get(frame_idx, {}).copy()
            for i, oid in enumerate(list(inference_session.obj_ids)):
                if int(oid) == int(player_id):
                    masks_for_frame[int(player_id)] = video_res_masks[i].cpu().numpy().squeeze()
            GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
            _update_centroids_for_frame(GLOBAL_STATE, frame_idx)
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)

            processed += 1
            last_frame_idx = frame_idx
            # Stream progress every 30 frames and once more on the last frame.
            if processed % 30 == 0 or processed == total:
                yield _snapshot(
                    f"Propagating player: {processed}/{total}",
                    gr.update(value=frame_idx),
                )

    text = f"Propagated player across {processed} frames."
    # Park the slider on the detection frame, else the kick frame, else the
    # last frame that was processed.
    target_frame = GLOBAL_STATE.player_detection_frame
    if target_frame is None:
        target_frame = GLOBAL_STATE.kick_frame or getattr(GLOBAL_STATE, "kick_debug_kick_frame", None)
    if target_frame is None:
        target_frame = last_frame_idx
    target_frame = int(np.clip(target_frame, 0, max(0, GLOBAL_STATE.num_frames - 1)))
    GLOBAL_STATE.current_frame_idx = target_frame
    yield _snapshot(text, gr.update(value=target_frame))
| propagate_player_btn.click( | |
| propagate_player_masks, | |
| inputs=[GLOBAL_STATE], | |
| outputs=[GLOBAL_STATE, propagate_status, frame_slider, kick_plot, yolo_plot, impact_status, ball_status, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
| # Image click to add a point and run forward on that frame | |
| preview.select( | |
| _on_image_click_with_updates, | |
| [preview, GLOBAL_STATE, frame_slider, obj_id_inp, label_radio, clear_old_chk], | |
| [preview, propagate_btn, detect_player_btn, propagate_player_btn], | |
| ) | |
| # Playback via MP4 rendering only | |
| # Render a smooth MP4 with OpenCV's VideoWriter (mp4v codec) | |
def _render_video(s: AppState, remove_bg: bool = False):
    """Render a trimmed MP4 of the composited frames with OpenCV.

    The clip is about four seconds long. When a kick frame is known the
    window is centred on it and shifted to stay inside the video; otherwise
    it starts at frame 0. Frames are written as they are composed, so the
    whole window is never held in memory at once.

    Args:
        s: Session state holding the frames, masks and FPS.
        remove_bg: Passed to ``compose_frame``. When True the composite cache
            is bypassed for every frame.

    Returns:
        Path of the written MP4 file.

    Raises:
        gr.Error: If no video is loaded, the output file cannot be opened,
            or composing/writing a frame fails.
    """
    if s is None or s.num_frames == 0:
        raise gr.Error("Load a video first.")

    # Fall back to 12 fps when the source FPS is missing or non-positive.
    fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
    trim_duration_sec = 4.0
    target_window_frames = max(1, int(round(fps * trim_duration_sec)))
    half_window = target_window_frames // 2

    kick_frame = s.kick_frame or getattr(s, "kick_debug_kick_frame", None)
    start_idx = 0
    end_idx = min(s.num_frames, target_window_frames)
    if kick_frame is not None:
        start_idx = max(0, int(kick_frame) - half_window)
        end_idx = start_idx + target_window_frames
        if end_idx > s.num_frames:
            # Window runs off the end: pin it to the last frame and pull the
            # start back so the clip keeps its full length when possible.
            end_idx = s.num_frames
            start_idx = max(0, end_idx - target_window_frames)
    if end_idx <= start_idx:
        end_idx = min(s.num_frames, start_idx + 1)

    # The first composited frame fixes the output size.
    first = compose_frame(s, start_idx, remove_bg=remove_bg)
    w, h = first.size[0], first.size[1]

    out_path = "/tmp/sam2_playback.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    # cv2.VideoWriter does not raise when it cannot open the target; it just
    # drops every frame, so the failure has to be checked explicitly.
    if not writer.isOpened():
        writer.release()
        print(f"Failed to render video with cv2: could not open {out_path}")
        raise gr.Error(f"Failed to render video: could not open {out_path} for writing")

    try:
        for idx in range(start_idx, end_idx):
            # Don't use the cache when remove_bg changes the rendering.
            if remove_bg:
                img = compose_frame(s, idx, remove_bg=True)
            else:
                img = s.composited_frames.get(idx)
                if img is None:
                    img = compose_frame(s, idx, remove_bg=False)
            img_with_idx = _annotate_frame_index(img, idx)
            # Reverse the channel axis (RGB -> BGR) for cv2.
            writer.write(np.array(img_with_idx)[:, :, ::-1])
            # Periodically release CPU memory to reduce pressure.
            if (idx + 1) % 60 == 0:
                gc.collect()
    except gr.Error:
        raise
    except Exception as e:
        print(f"Failed to render video with cv2: {e}")
        raise gr.Error(f"Failed to render video: {e}") from e
    finally:
        # Always finalise the container, even when a frame failed.
        writer.release()
    return out_path
# Render the trimmed playback clip into the playback video component.
render_btn.click(
    fn=_render_video,
    inputs=[GLOBAL_STATE, remove_bg_checkbox],
    outputs=[playback_video],
)

# Propagation refreshes the session state, status text, frame slider,
# both plots, the impact/ball status and the three action buttons.
propagate_btn.click(
    fn=propagate_masks,
    inputs=[GLOBAL_STATE],
    outputs=[
        GLOBAL_STATE,
        propagate_status,
        frame_slider,
        kick_plot,
        yolo_plot,
        impact_status,
        ball_status,
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
    ],
)

# Reset the session and every component that displays its state.
# NOTE(review): frame_slider is listed twice; presumably reset_session
# returns two separate updates for it - confirm against its return value.
reset_btn.click(
    fn=reset_session,
    inputs=GLOBAL_STATE,
    outputs=[
        GLOBAL_STATE,
        preview,
        frame_slider,
        frame_slider,
        load_status,
        ball_status,
        kick_plot,
        yolo_plot,
        impact_status,
        propagate_btn,
        detect_player_btn,
        propagate_player_btn,
    ],
)
# ============================================================================
# COMBINED INTERFACE WITH EXPLICIT API ENDPOINT
# ============================================================================
# Stand-alone Interface around process_video_api: takes a video, an
# annotations JSON string, a checkpoint choice and a background-removal flag;
# shows an annotation preview image and the processed video.
api_interface = gr.Interface(
    fn=process_video_api,
    title="SAM2 API",
    description="""
## Programmatic API for Video Background Removal
**The preview image shows where your annotation points are placed on the video frames.**
**Annotations JSON Format:**
```json
{
"annotations": [
{"object_id": 1, "frame": 0, "x": 363, "y": 631, "label": "positive"},
{"object_id": 1, "frame": 187, "x": 296, "y": 485, "label": "positive"},
{"object_id": 2, "frame": 187, "x": 296, "y": 412, "label": "positive"}
]
}
```
- **Object 1** (Ball): Frame 0 + Impact frame
- **Object 2** (Player): Impact frame
- Colors represent different objects
""",
    inputs=[
        gr.Video(label="Video File"),
        gr.Textbox(
            label="Annotations JSON",
            placeholder='{"annotations": [{"object_id": 1, "frame": 139, "x": 369, "y": 652, "label": "positive"}]}',
            lines=5,
        ),
        gr.Radio(
            choices=["tiny", "small", "base_plus", "large"],
            value="base_plus",
            label="SAM2 Checkpoint",
        ),
        gr.Checkbox(label="Remove Background", value=True),
    ],
    outputs=[
        gr.Image(label="Annotation Preview (shows where points are placed)"),
        gr.Video(label="Processed Video"),
    ],
)
# Wrap the interactive demo and the API interface in one tabbed Blocks app.
with gr.Blocks(title="SAM2 Video Tracking") as combined_demo:
    gr.Markdown("# SAM2 Video Tracking")

    with gr.Tabs():
        with gr.TabItem("Interactive UI"):
            demo.render()
        with gr.TabItem("API"):
            api_interface.render()

    # Invisible components that exist only to carry the inputs and outputs
    # of the named API event registered below.
    api_video_input_hidden = gr.Video(visible=False)
    api_annotations_input_hidden = gr.Textbox(visible=False)
    api_checkpoint_input_hidden = gr.Radio(
        choices=["tiny", "small", "base_plus", "large"],
        visible=False,
    )
    api_remove_bg_input_hidden = gr.Checkbox(visible=False)
    api_preview_output_hidden = gr.Image(visible=False)
    api_video_output_hidden = gr.Video(visible=False)

    # Hidden button whose click event registers process_video_api under the
    # API name "predict" so external clients can call it.
    api_dummy_btn = gr.Button("API", visible=False)
    api_dummy_btn.click(
        fn=process_video_api,
        inputs=[
            api_video_input_hidden,
            api_annotations_input_hidden,
            api_checkpoint_input_hidden,
            api_remove_bg_input_hidden,
        ],
        outputs=[
            api_preview_output_hidden,
            api_video_output_hidden,
        ],
        api_name="predict",
    )
# Script entry point: enable the queue with the API open, then start the app.
if __name__ == "__main__":
    combined_demo.queue(api_open=True)
    combined_demo.launch()