Mirko Trasciatti committed on
Commit
11e7a5f
·
1 Parent(s): 1205e2f

Visualize SAM2 ball trajectory using mask centroids

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py CHANGED
@@ -212,6 +212,7 @@ class AppState:
212
  self.pending_box_start_frame_idx: int | None = None
213
  self.pending_box_start_obj_id: int | None = None
214
  self.is_switching_model: bool = False
 
215
  # Model selection
216
  self.model_repo_key: str = "tiny"
217
  self.model_repo_id: str | None = None
@@ -286,6 +287,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
286
  GLOBAL_STATE.inference_session = None
287
  GLOBAL_STATE.masks_by_frame = {}
288
  GLOBAL_STATE.color_by_obj = {}
 
289
 
290
  load_model_if_needed(GLOBAL_STATE)
291
 
@@ -397,6 +399,15 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
397
  color = state.color_by_obj.get(obj_id, (255, 255, 255))
398
  for x1, y1, x2, y2 in boxes:
399
  draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
 
 
 
 
 
 
 
 
 
400
  # Save to cache and return
401
  state.composited_frames[frame_idx] = out_img
402
  return out_img
@@ -418,6 +429,43 @@ def _ensure_color_for_obj(state: AppState, obj_id: int):
418
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
419
 
420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  def on_image_click(
422
  img: Image.Image | np.ndarray,
423
  state: AppState,
@@ -545,6 +593,7 @@ def on_image_click(
545
  masks_for_frame[int(oid)] = mask_2d
546
 
547
  state.masks_by_frame[int(frame_idx)] = masks_for_frame
 
548
  # Invalidate cache for this frame to force recomposition
549
  state.composited_frames.pop(int(frame_idx), None)
550
 
@@ -590,6 +639,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
590
  mask_2d = video_res_masks[i].cpu().numpy().squeeze()
591
  masks_for_frame[int(oid)] = mask_2d
592
  GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
 
593
  # Invalidate cache for that frame to force recomposition
594
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
595
 
@@ -618,6 +668,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
618
  GLOBAL_STATE.pending_box_start = None
619
  GLOBAL_STATE.pending_box_start_frame_idx = None
620
  GLOBAL_STATE.pending_box_start_obj_id = None
 
621
 
622
  # Dispose and re-init inference session for current model with existing frames
623
  try:
 
212
  self.pending_box_start_frame_idx: int | None = None
213
  self.pending_box_start_obj_id: int | None = None
214
  self.is_switching_model: bool = False
215
+ self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
216
  # Model selection
217
  self.model_repo_key: str = "tiny"
218
  self.model_repo_id: str | None = None
 
287
  GLOBAL_STATE.inference_session = None
288
  GLOBAL_STATE.masks_by_frame = {}
289
  GLOBAL_STATE.color_by_obj = {}
290
+ GLOBAL_STATE.ball_centers = {}
291
 
292
  load_model_if_needed(GLOBAL_STATE)
293
 
 
399
  color = state.color_by_obj.get(obj_id, (255, 255, 255))
400
  for x1, y1, x2, y2 in boxes:
401
  draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
402
+ # Draw trajectory centers (all frames)
403
+ if state.ball_centers:
404
+ draw = ImageDraw.Draw(out_img)
405
+ cross_half = 4
406
+ for obj_id, centers in state.ball_centers.items():
407
+ color = state.color_by_obj.get(obj_id, (255, 255, 0))
408
+ for cx, cy in centers.values():
409
+ draw.line([(cx - cross_half, cy), (cx + cross_half, cy)], fill=color, width=2)
410
+ draw.line([(cx, cy - cross_half), (cx, cy + cross_half)], fill=color, width=2)
411
  # Save to cache and return
412
  state.composited_frames[frame_idx] = out_img
413
  return out_img
 
429
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
430
 
431
 
432
+ def _compute_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None:
433
+ if mask is None:
434
+ return None
435
+ mask_np = np.array(mask)
436
+ if mask_np.ndim == 3:
437
+ mask_np = mask_np.squeeze()
438
+ if mask_np.size == 0:
439
+ return None
440
+ mask_float = np.clip(mask_np, 0.0, 1.0).astype(np.float32)
441
+ moments = cv2.moments(mask_float)
442
+ if moments["m00"] == 0:
443
+ return None
444
+ cx = int(moments["m10"] / moments["m00"])
445
+ cy = int(moments["m01"] / moments["m00"])
446
+ return cx, cy
447
+
448
+
449
def _update_centroids_for_frame(state: AppState, frame_idx: int):
    """Refresh the stored trajectory centroids for one frame.

    For every object that has a mask in ``state.masks_by_frame`` at
    ``frame_idx``, the mask centroid is recomputed and written to
    ``state.ball_centers[obj_id][frame_idx]``; if no centroid can be
    computed, any existing entry for that frame is dropped. Each such object
    is also passed to ``_ensure_color_for_obj``. Objects already present in
    ``state.ball_centers`` that have no mask at this frame lose their entry
    for the frame.

    Args:
        state: Application state holding ``masks_by_frame`` and
            ``ball_centers``. Nothing happens when it is ``None``.
        frame_idx: Index of the frame to update (coerced with ``int``).
    """
    if state is None:
        return

    frame_key = int(frame_idx)
    frame_masks = state.masks_by_frame.get(frame_key, {})
    present_ids: set[int] = set()

    for raw_id, obj_mask in frame_masks.items():
        center = _compute_mask_centroid(obj_mask)
        obj_key = int(raw_id)
        per_frame = state.ball_centers.setdefault(obj_key, {})
        if center is None:
            # No centroid for this mask: forget any stale point for the frame.
            per_frame.pop(frame_key, None)
        else:
            per_frame[frame_key] = center
        present_ids.add(obj_key)
        _ensure_color_for_obj(state, obj_key)

    # Objects with no mask at this frame must not keep a point for it.
    for tracked_id, per_frame in state.ball_centers.items():
        if tracked_id in present_ids:
            continue
        per_frame.pop(frame_key, None)
467
+
468
+
469
  def on_image_click(
470
  img: Image.Image | np.ndarray,
471
  state: AppState,
 
593
  masks_for_frame[int(oid)] = mask_2d
594
 
595
  state.masks_by_frame[int(frame_idx)] = masks_for_frame
596
+ _update_centroids_for_frame(state, int(frame_idx))
597
  # Invalidate cache for this frame to force recomposition
598
  state.composited_frames.pop(int(frame_idx), None)
599
 
 
639
  mask_2d = video_res_masks[i].cpu().numpy().squeeze()
640
  masks_for_frame[int(oid)] = mask_2d
641
  GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
642
+ _update_centroids_for_frame(GLOBAL_STATE, frame_idx)
643
  # Invalidate cache for that frame to force recomposition
644
  GLOBAL_STATE.composited_frames.pop(frame_idx, None)
645
 
 
668
  GLOBAL_STATE.pending_box_start = None
669
  GLOBAL_STATE.pending_box_start_frame_idx = None
670
  GLOBAL_STATE.pending_box_start_obj_id = None
671
+ GLOBAL_STATE.ball_centers.clear()
672
 
673
  # Dispose and re-init inference session for current model with existing frames
674
  try: