import colorsys import gc import tempfile from collections import defaultdict from collections.abc import Iterator, Mapping, Sequence from typing import Any import cv2 import gradio as gr import numpy as np import spaces import torch from gradio.themes import Soft import google.generativeai as genai import os import json import datetime from pathlib import Path from PIL import Image, ImageDraw, ImageFont from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor import dotenv # Import Supabase utilities import sys sys.path.insert(0, str(Path(__file__).parent / 'utils')) from storage import get_storage, generate_timestamped_filename from database import get_database, generate_job_id dotenv.load_dotenv(".env") # Initialize Supabase clients (will raise error if env vars not set) try: storage_client = get_storage() db_client = get_database() print("Supabase clients initialized successfully!") except Exception as e: print(f"Warning: Could not initialize Supabase clients: {str(e)}") print("File storage will fall back to local filesystem if Supabase is not configured.") storage_client = None db_client = None MODEL_ID = "facebook/sam3" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE).eval() TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval() TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID) print("Models loaded successfully!") MAX_SECONDS = 10.0 TEXT_PROMPT = "rat" FIXED_WIDTH = 720 # Fixed width in pixels FIXED_HEIGHT = 480 # Fixed height in pixels def to_device_recursive(obj: Any, device: str | torch.device) -> Any: # noqa: ANN401 """Return a new object where all torch.Tensors reachable from `obj` are moved to the given device. - Does NOT mutate the original object. - Handles: * torch.Tensor * Mapping (e.g. dict, defaultdict, OrderedDict, etc.) * Sequence (e.g. list, tuple) except str/bytes * Custom classes with attributes (__dict__) - Tries to preserve container types where reasonable. """ device = torch.device(device) memo = {} def _convert(x: Any) -> Any: # noqa: ANN401, C901 obj_id = id(x) if obj_id in memo: return memo[obj_id] # 1. Tensor if isinstance(x, torch.Tensor): y = x.to(device) memo[obj_id] = y return y # 2. Mapping (dict, defaultdict, etc.) if isinstance(x, Mapping): # Special case: defaultdict if isinstance(x, defaultdict): y = defaultdict(x.default_factory) memo[obj_id] = y for k, v in x.items(): y[k] = _convert(v) return y # Try to rebuild the same type using (key, value) pairs try: y = type(x)((k, _convert(v)) for k, v in x.items()) memo[obj_id] = y return y except TypeError: # Fallback: plain dict y = {k: _convert(v) for k, v in x.items()} memo[obj_id] = y return y # 3. Sequence (list/tuple/etc.) but not str/bytes if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)): if isinstance(x, list): y = [_convert(v) for v in x] elif isinstance(x, tuple): y = type(x)(_convert(v) for v in x) else: try: y = type(x)(_convert(v) for v in x) except TypeError: y = [_convert(v) for v in x] memo[obj_id] = y return y # 4. Custom object with attributes (__dict__) if hasattr(x, "__dict__") and not isinstance(x, type): new_obj = x.__class__.__new__(x.__class__) memo[obj_id] = new_obj for name, value in vars(x).items(): setattr(new_obj, name, _convert(value)) return new_obj # 5. Everything else → keep as-is memo[obj_id] = x return x return _convert(obj) def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]: cap = cv2.VideoCapture(video_path_or_url) frames = [] while cap.isOpened(): ret, frame = cap.read() if not ret: break # Resize frame to fixed resolution (720x480) for consistency frame_resized = cv2.resize(frame, (FIXED_WIDTH, FIXED_HEIGHT), interpolation=cv2.INTER_AREA) # Convert to RGB for PIL Image frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB) frames.append(Image.fromarray(frame_rgb)) fps_val = cap.get(cv2.CAP_PROP_FPS) cap.release() info = { "num_frames": len(frames), "fps": float(fps_val) if fps_val and fps_val > 0 else None, } return frames, info def overlay_masks_on_frame( frame: Image.Image, masks_per_object: dict[int, np.ndarray], color_by_obj: dict[int, tuple[int, int, int]], alpha: float = 0.5, ) -> Image.Image: base = np.array(frame).astype(np.float32) / 255.0 height, width = base.shape[:2] overlay = base.copy() for obj_id, mask in masks_per_object.items(): if mask is None: continue if mask.dtype != np.float32: mask = mask.astype(np.float32) if mask.ndim == 3: mask = mask.squeeze() mask = np.clip(mask, 0.0, 1.0) color = np.array(color_by_obj.get(obj_id, (255, 0, 0)), dtype=np.float32) / 255.0 a = alpha m = mask[..., None] overlay = (1.0 - a * m) * overlay + (a * m) * color out = np.clip(overlay * 255.0, 0, 255).astype(np.uint8) return Image.fromarray(out) def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]: golden_ratio_conjugate = 0.61 hue = (obj_id * golden_ratio_conjugate) % 1.0 saturation = 0.45 value = 1.0 r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value) return int(r_f * 255), int(g_f * 255), int(b_f * 255) def pastel_color_for_prompt(prompt_text: str) -> tuple[int, int, int]: """Generate a consistent color for a prompt text using a deterministic hash.""" # Use a deterministic hash by summing character codes # This ensures the same prompt always gets the same color char_sum = sum(ord(c) for c in prompt_text) # Use the sum to generate a hue that's well-distributed across the color spectrum # Multiply by a large prime to spread values out hue = ((char_sum * 2654435761) % 360) / 360.0 # Use pastel colors (lower saturation, high value) saturation = 0.5 value = 0.95 r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value) return int(r_f * 255), int(g_f * 255), int(b_f * 255) def get_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None: """Calculate the centroid (center of mass) of a binary mask.""" if mask is None: return None if mask.ndim == 3: mask = mask.squeeze() mask_binary = (mask > 0).astype(np.uint8) if not np.any(mask_binary): return None # Calculate moments to find centroid moments = cv2.moments(mask_binary) if moments["m00"] == 0: return None cx = int(moments["m10"] / moments["m00"]) cy = int(moments["m01"] / moments["m00"]) return (cx, cy) def generate_activity_heatmap( centroids_by_obj: dict[int, dict[int, tuple[int, int]]], masks_by_frame: dict[int, dict[int, np.ndarray]], frame_idx: int, width: int, height: int, use_masks: bool = True, blur_radius: int = 25, ) -> np.ndarray | None: """Generate an activity heatmap showing where objects have been. Args: centroids_by_obj: Dictionary mapping obj_id -> {frame_idx -> (x, y)} masks_by_frame: Dictionary mapping frame_idx -> {obj_id -> mask} frame_idx: Current frame index (accumulate activity up to this frame) width: Frame width height: Frame height use_masks: If True, use mask accumulation; if False, use centroid points blur_radius: Gaussian blur radius for smoothing Returns: Heatmap as RGB numpy array, or None if no activity data """ # Create accumulator for activity activity_map = np.zeros((height, width), dtype=np.float32) if use_masks and masks_by_frame: # Accumulate mask pixels across frames for f_idx in range(frame_idx + 1): if f_idx in masks_by_frame: for obj_id, mask in masks_by_frame[f_idx].items(): if mask is not None: mask_2d = mask.squeeze() if mask.ndim == 3 else mask if mask_2d.shape[0] == height and mask_2d.shape[1] == width: activity_map += (mask_2d > 0).astype(np.float32) elif centroids_by_obj: # Use centroid positions with gaussian points point_radius = 15 for obj_id, frame_centroids in centroids_by_obj.items(): for f_idx, (cx, cy) in frame_centroids.items(): if f_idx <= frame_idx: # Draw a filled circle at the centroid y_min = max(0, cy - point_radius) y_max = min(height, cy + point_radius) x_min = max(0, cx - point_radius) x_max = min(width, cx + point_radius) for y in range(y_min, y_max): for x in range(x_min, x_max): dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) if dist <= point_radius: activity_map[y, x] += 1.0 - (dist / point_radius) if not np.any(activity_map): return None # Apply gaussian blur for smooth heatmap if blur_radius > 0: activity_map = cv2.GaussianBlur(activity_map, (blur_radius * 2 + 1, blur_radius * 2 + 1), 0) # Normalize to 0-1 max_val = activity_map.max() if max_val > 0: activity_map = activity_map / max_val # Apply colormap (use a professional heat/fire colormap) heatmap_uint8 = (activity_map * 255).astype(np.uint8) # Using COLORMAP_HOT for a professional "fire" look heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_HOT) heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 return heatmap_rgb, activity_map class AppState: def __init__(self) -> None: self.reset() def reset(self) -> None: self.video_frames: list[Image.Image] = [] self.inference_session = None self.video_fps: float | None = None self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {} self.color_by_obj: dict[int, tuple[int, int, int]] = {} self.color_by_prompt: dict[str, tuple[int, int, int]] = {} self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {} self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {} self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {} self.composited_frames: dict[int, Image.Image] = {} self.centroids_by_obj: dict[int, dict[int, tuple[int, int]]] = {} # obj_id -> {frame_idx -> (x, y)} self.show_heatmap: bool = True self.show_trajectory: bool = True self.current_frame_idx: int = 0 self.current_obj_id: int = 1 self.current_label: str = "positive" self.current_clear_old: bool = True self.current_prompt_type: str = "Points" self.pending_box_start: tuple[int, int] | None = None self.pending_box_start_frame_idx: int | None = None self.pending_box_start_obj_id: int | None = None self.active_tab: str = "text" def __repr__(self) -> str: return f"AppState(video_frames={len(self.video_frames)}, video_fps={self.video_fps}, masks_by_frame={len(self.masks_by_frame)}, color_by_obj={len(self.color_by_obj)})" @property def num_frames(self) -> int: return len(self.video_frames) def init_video_session( state: AppState, video: str | dict, active_tab: str = "point_box" ) -> tuple[AppState, int, int, Image.Image, str]: state.video_frames = [] state.masks_by_frame = {} state.color_by_obj = {} state.color_by_prompt = {} state.text_prompts_by_frame_obj = {} state.clicks_by_frame_obj = {} state.boxes_by_frame_obj = {} state.composited_frames = {} state.inference_session = None state.active_tab = active_tab video_path: str | None = None if isinstance(video, dict): video_path = video.get("name") or video.get("path") or video.get("data") elif isinstance(video, str): video_path = video else: video_path = None if not video_path: raise gr.Error("Invalid video input.") frames, info = try_load_video_frames(video_path) if len(frames) == 0: raise gr.Error("No frames could be loaded from the video.") trimmed_note = "" fps_in = info.get("fps") max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames) if len(frames) > max_frames_allowed: frames = frames[:max_frames_allowed] trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)" if isinstance(info, dict): info["num_frames"] = len(frames) state.video_frames = frames state.video_fps = float(fps_in) if fps_in else None raw_video = [np.array(frame) for frame in frames] processor = TEXT_VIDEO_PROCESSOR state.inference_session = processor.init_video_session( video=frames, inference_device=DEVICE, inference_state_device=DEVICE, processing_device=DEVICE, video_storage_device=DEVICE, dtype=DTYPE, ) state.inference_session.inference_device = DEVICE state.inference_session.processing_device = DEVICE state.inference_session.cache.inference_device = DEVICE first_frame = frames[0] max_idx = len(frames) - 1 if active_tab == "text": status = ( f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. " ) else: status = ( f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. " ) return state, 0, max_idx, first_frame, status def compose_frame(state: AppState, frame_idx: int) -> Image.Image: if state is None or state.video_frames is None or len(state.video_frames) == 0: return None frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1)) frame = state.video_frames[frame_idx] masks = state.masks_by_frame.get(frame_idx, {}) out_img = frame # Overlay activity heatmap if enabled if getattr(state, "show_heatmap", False) and (state.centroids_by_obj or state.masks_by_frame): img_width, img_height = out_img.size heatmap_result = generate_activity_heatmap( state.centroids_by_obj, state.masks_by_frame, state.num_frames - 1, # Use full duration for summary heatmap img_width, img_height, use_masks=True, blur_radius=30, ) if heatmap_result is not None: heatmap_rgb, activity_map = heatmap_result # Convert frame to numpy for blending frame_np = np.array(out_img).astype(np.float32) / 255.0 # Blend heatmap with frame (only where there's activity) heatmap_alpha = 0.6 activity_mask = activity_map[..., None] # Add channel dim blended = frame_np * (1 - activity_mask * heatmap_alpha) + heatmap_rgb * activity_mask * heatmap_alpha blended = np.clip(blended * 255, 0, 255).astype(np.uint8) out_img = Image.fromarray(blended) if len(masks) != 0: out_img = overlay_masks_on_frame(out_img, masks, state.color_by_obj, alpha=0.65) clicks_map = state.clicks_by_frame_obj.get(frame_idx) if clicks_map: draw = ImageDraw.Draw(out_img) cross_half = 6 for obj_id, pts in clicks_map.items(): for x, y, lbl in pts: color = (0, 255, 0) if int(lbl) == 1 else (255, 0, 0) draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2) draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2) if ( state.pending_box_start is not None and state.pending_box_start_frame_idx == frame_idx and state.pending_box_start_obj_id is not None ): draw = ImageDraw.Draw(out_img) x, y = state.pending_box_start cross_half = 6 color = state.color_by_obj.get(state.pending_box_start_obj_id, (255, 255, 255)) draw.line([(x - cross_half, y), (x + cross_half, y)], fill=color, width=2) draw.line([(x, y - cross_half), (x, y + cross_half)], fill=color, width=2) box_map = state.boxes_by_frame_obj.get(frame_idx) if box_map: draw = ImageDraw.Draw(out_img) for obj_id, boxes in box_map.items(): color = state.color_by_obj.get(obj_id, (255, 255, 255)) for x1, y1, x2, y2 in boxes: draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2) text_prompts_by_obj = {} for frame_texts in state.text_prompts_by_frame_obj.values(): for obj_id, text_prompt in frame_texts.items(): if obj_id not in text_prompts_by_obj: text_prompts_by_obj[obj_id] = text_prompt if text_prompts_by_obj and len(masks) > 0: draw = ImageDraw.Draw(out_img) # Calculate scale factor based on image size (reference: 720p height = 720) img_width, img_height = out_img.size reference_height = 720.0 scale_factor = img_height / reference_height # Scale font size (base size ~13 pixels for default font, scale proportionally) base_font_size = 13 font_size = max(10, int(base_font_size * scale_factor)) # Try to load a scalable font, fall back to default if not available try: # Try common system fonts font_paths = [ "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", "/System/Library/Fonts/Helvetica.ttc", "arial.ttf", ] font = None for font_path in font_paths: try: font = ImageFont.truetype(font_path, font_size) break except OSError: continue if font is None: # Fallback to default font font = ImageFont.load_default() except Exception: font = ImageFont.load_default() for obj_id, text_prompt in text_prompts_by_obj.items(): obj_mask = masks.get(obj_id) if obj_mask is not None: mask_array = np.array(obj_mask) if mask_array.size > 0 and np.any(mask_array): rows = np.any(mask_array, axis=1) cols = np.any(mask_array, axis=0) if np.any(rows) and np.any(cols): y_min, y_max = np.where(rows)[0][[0, -1]] x_min, x_max = np.where(cols)[0][[0, -1]] label_x = int(x_min) # Scale vertical offset and padding vertical_offset = int(20 * scale_factor) padding = max(2, int(4 * scale_factor)) label_y = int(y_min) - vertical_offset label_y = max(int(5 * scale_factor), label_y) obj_color = state.color_by_obj.get(obj_id, (255, 255, 255)) # Include object ID in the label label_text = f"{text_prompt} - ID {obj_id}" bbox = draw.textbbox((label_x, label_y), label_text, font=font) draw.rectangle( [(bbox[0] - padding, bbox[1] - padding), (bbox[2] + padding, bbox[3] + padding)], fill=obj_color, outline=None, width=0, ) draw.text((label_x, label_y), label_text, fill=(255, 255, 255), font=font) # Draw trajectory lines for tracked objects if getattr(state, "show_trajectory", True) and hasattr(state, "centroids_by_obj") and state.centroids_by_obj: draw = ImageDraw.Draw(out_img) trajectory_tail_length = 999999 # Show full trajectory across whole video for obj_id, frame_centroids in state.centroids_by_obj.items(): if len(frame_centroids) < 2: continue # Get sorted frame indices up to and including current frame (with tail limit) sorted_frames = sorted( [f for f in frame_centroids.keys() if f <= frame_idx and f >= frame_idx - trajectory_tail_length] ) if len(sorted_frames) < 2: continue # Get the color for this object color = state.color_by_obj.get(obj_id, (255, 255, 255)) # Draw lines connecting consecutive centroids for i in range(1, len(sorted_frames)): prev_frame = sorted_frames[i - 1] curr_frame = sorted_frames[i] prev_centroid = frame_centroids[prev_frame] curr_centroid = frame_centroids[curr_frame] # Calculate line width based on recency (thicker for more recent) frames_ago = frame_idx - curr_frame line_width = max(1, 4 - frames_ago // 15) draw.line([prev_centroid, curr_centroid], fill=color, width=line_width) # Draw a small circle at the current position if frame_idx in frame_centroids: cx, cy = frame_centroids[frame_idx] radius = 6 draw.ellipse( [(cx - radius, cy - radius), (cx + radius, cy + radius)], fill=color, outline=(255, 255, 255), width=2, ) state.composited_frames[frame_idx] = out_img return out_img def update_frame_display(state: AppState, frame_idx: int) -> Image.Image: if state is None or state.video_frames is None or len(state.video_frames) == 0: return None frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1)) cached = state.composited_frames.get(frame_idx) if cached is not None: return cached return compose_frame(state, frame_idx) def generate_report(state: AppState, frame_idx: int) -> tuple: """Generate a markdown report using Gemini API and save to Supabase Storage.""" # Prioritize environment variable, then fallback to UI input effective_api_key = os.environ.get("GOOGLE_API_KEY") if not effective_api_key or not effective_api_key.strip(): return "Please enter a valid Gemini API Key or set the GOOGLE_API_KEY environment variable.", None if not state or not state.video_frames: return "No video loaded.", None try: genai.configure(api_key=effective_api_key) model = genai.GenerativeModel('gemini-2.5-flash-image') # Use Flash for speed and analysis # 1. Prepare image (current frame with overlays) current_frame = update_frame_display(state, frame_idx) if current_frame is None: return "Failed to capture current frame.", None # 2. Extract metadata num_rats = len(state.color_by_obj) total_frames = state.num_frames fps = state.video_fps or 12.0 duration_sec = total_frames / fps if fps > 0 else 0 # Determine tracking info tracking_summary = [] for obj_id, centroids in state.centroids_by_obj.items(): prompt = _get_prompt_for_obj(state, obj_id) or "Unknown object" num_points = len(centroids) tracking_summary.append(f"Object {obj_id} ({prompt}): tracked across {num_points} frames.") # 3. Create prompt for Gemini metadata_str = "\n".join(tracking_summary) prompt = ( f"You are an expert rodent behavior analyst. Analyze this image from a 'PetalGard' rat detection system.\n" f"The image shows a frame from a video with overlays (masks, trajectories, heatmaps).\n\n" f"Metadata:\n" f"- Total detected objects: {num_rats}\n" f"- Video duration: {duration_sec:.2f} seconds\n" f"- Tracking Data Summary:\n{metadata_str}\n\n" f"Please provide a short report (markdown format) including:\n" f"1. A brief interpretation of the rat(s) movement and behavior based on the trajectory and heatmap.\n" f"2. Observations about the environment (e.g., indoor, outdoor, obstacles).\n" f"3. Any other interesting insights from the image (e.g., group behavior, specific interest points).\n" f"Keep it professional and concise." ) # 4. Generate content (Text Report) response = model.generate_content([prompt, current_frame]) # Extract text manually from parts to handle responses that include # inline_data (images) alongside text - response.text fails on these. report_text_parts = [] for part in response.candidates[0].content.parts: if hasattr(part, 'text') and part.text: report_text_parts.append(part.text) report_text = "\n".join(report_text_parts) if not report_text: report_text = "No text analysis was returned by the model." # 5. Generate filenames and timestamps timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") report_filename = f"report_{timestamp}.md" # 6. Generate Annotated Image (Gemini 3 Pro Image Preview) # Use the last frame which includes all overlays (heatmap, trajectories, etc.) last_frame_idx = state.num_frames - 1 last_frame = update_frame_display(state, last_frame_idx) annotated_image_path = None annotated_image_filename = None image_explanation = "" image_path_for_gradio = None try: # Prompt to annotate the image annotation_prompt = ( "Annotate this image and draw 1-2 circles around the places where the trap will be most effective:\n" "1. The place where the rat spent most time (highlighted in the heatmap with red/yellow zones).\n" "2. The place that the rat moved to (if different from the first location, following the trajectory lines).\n" "Focus on accuracy based on the provided visual data.\n" "Write a short explanation for your choices on the image." ) image_model = genai.GenerativeModel('gemini-2.5-flash-image') img_response = image_model.generate_content([annotation_prompt, last_frame]) # Handle potential image and text output in the response for part in img_response.candidates[0].content.parts: if hasattr(part, 'inline_data') and part.inline_data: extension = ".jpg" if hasattr(part.inline_data, "mime_type"): if "png" in part.inline_data.mime_type: extension = ".png" elif "webp" in part.inline_data.mime_type: extension = ".webp" annotated_image_filename = f"annotated_{timestamp}{extension}" image_bytes = part.inline_data.data # Save to temp file for Gradio display tmp = tempfile.NamedTemporaryFile(suffix=extension, delete=False) tmp.write(image_bytes) tmp.close() image_path_for_gradio = tmp.name # Try to upload to Supabase, fall back to local storage if storage_client: try: upload_result = storage_client.upload_annotated_image( image_bytes=image_bytes, filename=annotated_image_filename, image_format=extension[1:] # Remove the dot ) annotated_image_path = upload_result['public_url'] print(f"Uploaded annotated image to Supabase: {annotated_image_path}") except Exception as upload_err: print(f"Failed to upload to Supabase, using local storage: {upload_err}") # Fall back to local storage reports_dir = Path("../platform/public/reports") reports_dir.mkdir(parents=True, exist_ok=True) local_path = reports_dir / annotated_image_filename with open(local_path, "wb") as f: f.write(image_bytes) annotated_image_path = f"/reports/{annotated_image_filename}" else: # Use local storage reports_dir = Path("../platform/public/reports") reports_dir.mkdir(parents=True, exist_ok=True) local_path = reports_dir / annotated_image_filename with open(local_path, "wb") as f: f.write(image_bytes) annotated_image_path = f"/reports/{annotated_image_filename}" elif hasattr(part, 'text') and part.text: image_explanation += part.text + "\n" except Exception as img_err: print(f"Annotated image generation failed: {img_err}") # Create full markdown components header = f"# PetalGard Analysis Report\n\n**Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" header += f"**Video Duration:** {duration_sec:.2f}s | **Detected Rats:** {num_rats}\n\n" # Section for saved report (uses Supabase URL or local path) visual_analysis_saved = "" if annotated_image_path: visual_analysis_saved += f"## Visual Analysis & Trap Placement\n" visual_analysis_saved += f"![AI Annotated Trap Locations]({annotated_image_path})\n\n" if image_explanation: visual_analysis_saved += f"**AI Explanation:** {image_explanation}\n\n" # Section for Gradio UI (text only, image returned separately) visual_analysis_gradio = "" if annotated_image_filename: visual_analysis_gradio += f"## Visual Analysis & Trap Placement\n" visual_analysis_gradio += f"*See annotated image below*\n\n" if image_explanation: visual_analysis_gradio += f"**AI Explanation:** {image_explanation}\n\n" behavioral_insights = "## Behavioral Insights\n" + report_text # Save the report to database (content stored directly in table) full_report = header + visual_analysis_saved + behavioral_insights # Store in database table instead of storage bucket if db_client: try: db_client.create_report( title=f"Report {timestamp}", content=full_report, storage_path=None, # Not using storage for markdown annotated_image_path=annotated_image_path, video_duration_seconds=duration_sec, num_detected_objects=num_rats ) print("Saved report content and metadata to database") except Exception as db_err: print(f"Failed to save to database: {db_err}") # Fall back to local file if database fails reports_dir = Path("../platform/public/reports") reports_dir.mkdir(parents=True, exist_ok=True) report_path = reports_dir / report_filename with open(report_path, "w") as f: f.write(full_report) # Return the report for Gradio preview (text and image separately) gradio_full_report = header + visual_analysis_gradio + behavioral_insights return f"Report generated successfully: {report_filename}\n\n{gradio_full_report}", image_path_for_gradio except Exception as e: return f"Error generating report: {str(e)}", None def _get_prompt_for_obj(state: AppState, obj_id: int) -> str | None: """Get the prompt text associated with an object ID.""" # Priority 1: Check text_prompts_by_frame_obj (most reliable) for frame_texts in state.text_prompts_by_frame_obj.values(): if obj_id in frame_texts: return frame_texts[obj_id].strip() # Priority 2: Check inference session mapping if state.inference_session is not None and ( hasattr(state.inference_session, "obj_id_to_prompt_id") and obj_id in state.inference_session.obj_id_to_prompt_id ): prompt_id = state.inference_session.obj_id_to_prompt_id[obj_id] if hasattr(state.inference_session, "prompts") and prompt_id in state.inference_session.prompts: return state.inference_session.prompts[prompt_id].strip() return None def _ensure_color_for_obj(state: AppState, obj_id: int) -> None: """Assign color to object based on its prompt if available, otherwise use object ID.""" prompt_text = _get_prompt_for_obj(state, obj_id) if prompt_text is not None: # Ensure prompt has a color assigned if prompt_text not in state.color_by_prompt: state.color_by_prompt[prompt_text] = pastel_color_for_prompt(prompt_text) # Always update to prompt-based color state.color_by_obj[obj_id] = state.color_by_prompt[prompt_text] elif obj_id not in state.color_by_obj: # Fallback to object ID-based color (for point/box prompting mode) state.color_by_obj[obj_id] = pastel_color_for_object(obj_id) # Removed on_image_click @spaces.GPU def on_text_prompt( state: AppState, frame_idx: int, progress: gr.Progress = gr.Progress(), ) -> Iterator[tuple[Image.Image, str, str, AppState, dict]]: """Automatically find object and propagate.""" if state is None or state.inference_session is None: yield None, "Upload a video first.", "**Active prompts:** None", state, gr.update() return model = TEXT_VIDEO_MODEL processor = TEXT_VIDEO_PROCESSOR text_prompt = TEXT_PROMPT # Parse comma-separated prompts or single prompt prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()] state.inference_session = to_device_recursive(state.inference_session, DEVICE) # Add text prompt(s) state.inference_session = processor.add_text_prompt( inference_session=state.inference_session, text=prompt_texts, ) # Store frames where rats are detected frames_with_rats = [] min_frames_needed = 10 num_frames = len(state.video_frames) # Search from current frame to the end for search_idx in range(int(frame_idx), num_frames): yield update_frame_display(state, search_idx), f"Searching for object in frame {search_idx}... (Found {len(frames_with_rats)}/{min_frames_needed} frames)", _get_active_prompts_display(state), state, gr.update(value=search_idx) masks_for_frame = state.masks_by_frame.setdefault(search_idx, {}) frame_texts = state.text_prompts_by_frame_obj.setdefault(int(search_idx), {}) detected_obj_ids = [] num_objects = 0 with torch.no_grad(): for model_outputs in model.propagate_in_video_iterator( inference_session=state.inference_session, start_frame_idx=search_idx, max_frame_num_to_track=1, ): processed_outputs = processor.postprocess_outputs( state.inference_session, model_outputs, ) if model_outputs.frame_idx == search_idx: object_ids = processed_outputs["object_ids"] masks = processed_outputs["masks"] scores = processed_outputs["scores"] prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {}) num_objects = len(object_ids) if num_objects > 0: if len(scores) > 0: sorted_indices = torch.argsort(scores, descending=True).cpu().tolist() else: sorted_indices = list(range(num_objects)) for mask_idx in sorted_indices: current_obj_id = int(object_ids[mask_idx].item()) detected_obj_ids.append(current_obj_id) mask_2d = masks[mask_idx].float().cpu().numpy() if mask_2d.ndim == 3: mask_2d = mask_2d.squeeze() mask_2d = (mask_2d > 0.0).astype(np.float32) masks_for_frame[current_obj_id] = mask_2d centroid = get_mask_centroid(mask_2d) if centroid: if current_obj_id not in state.centroids_by_obj: state.centroids_by_obj[current_obj_id] = {} state.centroids_by_obj[current_obj_id][search_idx] = centroid found_prompt = None for p, ids in prompt_to_obj_ids.items(): if current_obj_id in ids: found_prompt = p break if found_prompt: frame_texts[current_obj_id] = found_prompt.strip() _ensure_color_for_obj(state, current_obj_id) # Track frames where rats are detected if num_objects > 0: frames_with_rats.append(search_idx) state.composited_frames.pop(search_idx, None) # Check if we've found enough frames if len(frames_with_rats) >= min_frames_needed: break # Check if we found any frames with rats if len(frames_with_rats) == 0: yield update_frame_display(state, int(frame_idx)), "Object not found in the rest of the video.", _get_active_prompts_display(state), state, gr.update() return # Use the last frame with detected rats found_frame_idx = frames_with_rats[-1] # Automatically start propagation from the found frame (last frame with rats) state.current_frame_idx = found_frame_idx yield update_frame_display(state, found_frame_idx), f"Object found at frame {found_frame_idx} (found {len(frames_with_rats)} frames with rats)! Starting propagation...", _get_active_prompts_display(state), state, gr.update(value=found_frame_idx) # Run propagation for s, status, slider_update in _propagate_masks_core(state, progress=progress): # Adapt propagate_masks yield to match on_text_prompt yield yield update_frame_display(s, s.current_frame_idx), status, _get_active_prompts_display(s), s, slider_update def _get_active_prompts_display(state: AppState) -> str: """Get a formatted string showing all active prompts in the inference session.""" if state is None or state.inference_session is None: return "**Active prompts:** None" if hasattr(state.inference_session, "prompts") and state.inference_session.prompts: prompts_list = sorted(set(state.inference_session.prompts.values())) if prompts_list: prompts_str = ", ".join([f"'{p}'" for p in prompts_list]) return f"**Active prompts:** {prompts_str}" return "**Active prompts:** None" @spaces.GPU def propagate_masks(state: AppState, progress: gr.Progress = gr.Progress()) -> Iterator[tuple[AppState, str, dict]]: yield from _propagate_masks_core(state, progress=progress) def _propagate_masks_core(state: AppState, progress: gr.Progress = gr.Progress()) -> Iterator[tuple[AppState, str, dict]]: if state is None: yield state, "Load a video first.", gr.update() return if state.active_tab != "text" and state.inference_session is None: yield state, "Load a video first.", gr.update() return total = max(1, state.num_frames) processed = 0 yield state, f"Progress {int(processed / total * 100)}%", gr.update() last_frame_idx = 0 with torch.no_grad(): if state.active_tab == "text": if state.inference_session is None: yield state, "Text video model not loaded.", gr.update() return model = TEXT_VIDEO_MODEL processor = TEXT_VIDEO_PROCESSOR state.inference_session = to_device_recursive(state.inference_session, DEVICE) # Collect all unique prompts from existing frame annotations text_prompt_to_obj_ids = {} for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items(): for obj_id, text_prompt in frame_texts.items(): if text_prompt not in text_prompt_to_obj_ids: text_prompt_to_obj_ids[text_prompt] = [] if obj_id not in text_prompt_to_obj_ids[text_prompt]: text_prompt_to_obj_ids[text_prompt].append(obj_id) # Also check if there are prompts already in the inference session if hasattr(state.inference_session, "prompts") and state.inference_session.prompts: for prompt_text in state.inference_session.prompts.values(): if prompt_text not in text_prompt_to_obj_ids: text_prompt_to_obj_ids[prompt_text] = [] for text_prompt in text_prompt_to_obj_ids: text_prompt_to_obj_ids[text_prompt].sort() if not text_prompt_to_obj_ids: yield state, "No text prompts found. Please add a text prompt first.", gr.update() return # Add all prompts to the inference session (processor handles deduplication) for text_prompt in text_prompt_to_obj_ids: state.inference_session = processor.add_text_prompt( inference_session=state.inference_session, text=text_prompt, ) earliest_frame = min(state.text_prompts_by_frame_obj.keys()) if state.text_prompts_by_frame_obj else 0 frames_to_track = state.num_frames - earliest_frame outputs_per_frame = {} for model_outputs in model.propagate_in_video_iterator( inference_session=state.inference_session, start_frame_idx=earliest_frame, max_frame_num_to_track=frames_to_track, ): frame_idx = model_outputs.frame_idx # Handle potential KeyError in postprocess_outputs when obj_id_to_prompt_id is incomplete try: processed_outputs = processor.postprocess_outputs( state.inference_session, model_outputs, ) except KeyError as e: print(f"Warning: Skipping frame {frame_idx} due to missing prompt mapping: {e}") processed += 1 continue outputs_per_frame[frame_idx] = processed_outputs object_ids = processed_outputs["object_ids"] masks = processed_outputs["masks"] scores = processed_outputs["scores"] prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {}) masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {}) frame_texts = state.text_prompts_by_frame_obj.setdefault(frame_idx, {}) num_objects = len(object_ids) if num_objects > 0: if len(scores) > 0: sorted_indices = torch.argsort(scores, descending=True).cpu().tolist() else: sorted_indices = list(range(num_objects)) for mask_idx in sorted_indices: current_obj_id = int(object_ids[mask_idx].item()) mask_2d = masks[mask_idx].float().cpu().numpy() if mask_2d.ndim == 3: mask_2d = mask_2d.squeeze() mask_2d = (mask_2d > 0.0).astype(np.float32) masks_for_frame[current_obj_id] = mask_2d # Store centroid for trajectory tracking centroid = get_mask_centroid(mask_2d) if centroid: if current_obj_id not in state.centroids_by_obj: state.centroids_by_obj[current_obj_id] = {} state.centroids_by_obj[current_obj_id][frame_idx] = centroid # Find which prompt detected this object found_prompt = None for prompt, obj_ids in prompt_to_obj_ids.items(): if current_obj_id in obj_ids: found_prompt = prompt break # Store prompt and assign color if found_prompt: frame_texts[current_obj_id] = found_prompt.strip() _ensure_color_for_obj(state, current_obj_id) state.composited_frames.pop(frame_idx, None) last_frame_idx = frame_idx processed += 1 if processed % 30 == 0 or processed == total: state.current_frame_idx = int(frame_idx) progress(processed / total, desc=f"Progress {int(processed / total * 100)}%") yield state, f"Progress {int(processed / total * 100)}%", gr.update(value=frame_idx) text = f"Propagated masks across {processed} frames." state.current_frame_idx = int(last_frame_idx) yield state, text, gr.update(value=last_frame_idx) def reset_session(state: AppState) -> tuple[AppState, Image.Image, int, int, str, str]: if not state.video_frames: return state, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None" if state.video_frames: processor = TEXT_VIDEO_PROCESSOR state.inference_session = processor.init_video_session( video=state.video_frames, inference_device=DEVICE, processing_device=DEVICE, video_storage_device=DEVICE, dtype=DTYPE, ) state.masks_by_frame.clear() state.clicks_by_frame_obj.clear() state.boxes_by_frame_obj.clear() state.text_prompts_by_frame_obj.clear() state.composited_frames.clear() state.color_by_obj.clear() state.color_by_prompt.clear() state.centroids_by_obj.clear() state.pending_box_start = None state.pending_box_start_frame_idx = None state.pending_box_start_obj_id = None gc.collect() current_idx = int(getattr(state, "current_frame_idx", 0)) current_idx = max(0, min(current_idx, state.num_frames - 1)) preview_img = update_frame_display(state, current_idx) slider_minmax = gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True) slider_value = gr.update(value=current_idx) status = "Session reset. Prompts cleared; video preserved." active_prompts = _get_active_prompts_display(state) return state, preview_img, slider_minmax, slider_value, status, active_prompts # Removed _on_video_change_pointbox def _on_video_change_text(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]: if video is None: return state, None, None, None, None state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "text") active_prompts = _get_active_prompts_display(state) return ( state, gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True), first_frame, status, active_prompts, ) with gr.Blocks(title="PetalGard - Rat Detection System") as demo: app_state = gr.State(AppState()) gr.Markdown( """ Detect and track rats across video footage using AI-powered segmentation. Features include: - **Automatic rat detection** using text prompts - **Movement trajectory tracking** to visualize rat paths - **Activity heatmap** showing high-activity zones - **🎯 Optimal trap location finder** based on movement patterns """ ) with gr.Row(): with gr.Column(scale=1): video_in_text = gr.Video(label="Upload video", sources=["upload", "webcam"]) load_status_text = gr.Markdown(visible=True) reset_btn_text = gr.Button("Reset Session", variant="secondary") with gr.Column(scale=2): preview_text = gr.Image(label="Preview") with gr.Row(): frame_slider_text = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0) with gr.Column(scale=0): analyze_btn = gr.Button("Analyze Video", variant="primary") propagate_status_text = gr.Markdown(visible=True) active_prompts_display = gr.Markdown("**Active prompts:** None", visible=False) text_status = gr.Markdown(visible=True) with gr.Row(): gr.Markdown("**Visualization Options:**") with gr.Row(): show_trajectory_chk_text = gr.Checkbox(value=True, label="Show Trajectory Lines") show_heatmap_chk_text = gr.Checkbox(value=True, label="Show Activity Heatmap") with gr.Row(): render_btn_text = gr.Button("Render WebP for smooth playback", variant="primary") playback_video_text = gr.Image(label="Rendered Playback", interactive=False) with gr.Row(): with gr.Column(): generate_report_btn = gr.Button("Generate Behavior Report", variant="primary") report_output = gr.Markdown(label="Report Output") annotated_image_output = gr.Image(label="AI Annotated Trap Locations", visible=True) examples_list_text = [ [None, "./videos/VID0010.mp4"], [None, "./videos/VID0015.mp4"], [None, "./videos/VID0140.mp4"], [None, "./videos/VID0142.mp4"], [None, "./videos/VID0143.mp4"], [None, "./videos/VID0144.mp4"], ] with gr.Row(): gr.Examples( label="Examples", examples=examples_list_text, inputs=[app_state, video_in_text], examples_per_page=5, ) video_in_text.change( fn=_on_video_change_text, inputs=[app_state, video_in_text], outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display], show_progress=True, ) def _sync_frame_idx_text(state_in: AppState, idx: int) -> Image.Image: if state_in is not None: state_in.current_frame_idx = int(idx) return update_frame_display(state_in, int(idx)) frame_slider_text.change( fn=_sync_frame_idx_text, inputs=[app_state, frame_slider_text], outputs=preview_text, ) def _sync_visualization_options( state_in: AppState, show_trajectory: bool, show_heatmap: bool, frame_idx: int ) -> tuple[AppState, Image.Image]: if state_in is not None: state_in.show_trajectory = show_trajectory state_in.show_heatmap = show_heatmap # Clear cached composited frames to force redraw state_in.composited_frames.clear() return state_in, update_frame_display(state_in, int(frame_idx)) show_trajectory_chk_text.change( fn=_sync_visualization_options, inputs=[app_state, show_trajectory_chk_text, show_heatmap_chk_text, frame_slider_text], outputs=[app_state, preview_text], ) show_heatmap_chk_text.change( fn=_sync_visualization_options, inputs=[app_state, show_trajectory_chk_text, show_heatmap_chk_text, frame_slider_text], outputs=[app_state, preview_text], ) analyze_btn.click( fn=on_text_prompt, inputs=[app_state, frame_slider_text], outputs=[preview_text, text_status, active_prompts_display, app_state, frame_slider_text], ) def _render_video(s: AppState) -> str: if s is None or s.num_frames == 0: raise gr.Error("Load a video first.") fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12 frames_pil = [] for idx in range(s.num_frames): img = s.composited_frames.get(idx) if img is None: img = compose_frame(s, idx) frames_pil.append(img) if (idx + 1) % 60 == 0: gc.collect() try: # Create a temporary file for the WebP timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"recording_{timestamp}.webp" # Write WebP to temporary file first with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as temp_file: temp_path = temp_file.name # Save as animated WebP dur = int(1000 / fps) if fps > 0 else 100 frames_pil[0].save( temp_path, format="WEBP", save_all=True, append_images=frames_pil[1:], duration=dur, loop=0, quality=80, method=4 # Balance between speed and compression ) # Read the temporary file and try to upload to Supabase video_url = None with open(temp_path, 'rb') as f: video_bytes = f.read() if storage_client: try: # Upload to Supabase Storage upload_result = storage_client.upload_video( video_bytes=video_bytes, filename=filename ) video_url = upload_result['public_url'] print(f"Uploaded video to Supabase: {video_url}") # Save to database if available if db_client: try: # Check if there's an analysis job for this session job_id = generate_job_id() db_client.create_analysis_job(job_id=job_id, video_path=video_url) print(f"Created analysis job: {job_id}") except Exception as db_err: print(f"Failed to create analysis job: {db_err}") except Exception as upload_err: print(f"Failed to upload to Supabase, using local storage: {upload_err}") # Fall back to local storage out_dir = Path("../platform/public/recordings") out_dir.mkdir(parents=True, exist_ok=True) out_path = out_dir / filename with open(out_path, 'wb') as f: f.write(video_bytes) video_url = f"/recordings/{filename}" else: # Use local storage out_dir = Path("../platform/public/recordings") out_dir.mkdir(parents=True, exist_ok=True) out_path = out_dir / filename with open(out_path, 'wb') as f: f.write(video_bytes) video_url = f"/recordings/{filename}" # Clean up temporary file try: os.unlink(temp_path) except: pass return video_url except Exception as e: print(f"Failed to render video with cv2: {e}") raise gr.Error(f"Failed to render video: {e}") render_btn_text.click( fn=_render_video, inputs=app_state, outputs=playback_video_text, ) reset_btn_text.click( fn=reset_session, inputs=app_state, outputs=[ app_state, preview_text, frame_slider_text, frame_slider_text, load_status_text, active_prompts_display, ], ) generate_report_btn.click( fn=generate_report, inputs=[app_state, frame_slider_text], outputs=[report_output, annotated_image_output], ) if __name__ == "__main__": demo.queue(api_open=False).launch( theme=Soft(primary_hue="red", secondary_hue="orange", neutral_hue="slate") )