Spaces:
Runtime error
Runtime error
Mirko Trasciatti committed on
Commit ·
11e7a5f
1
Parent(s): 1205e2f
Visualize SAM2 ball trajectory using mask centroids
Browse files
app.py
CHANGED
|
@@ -212,6 +212,7 @@ class AppState:
|
|
| 212 |
self.pending_box_start_frame_idx: int | None = None
|
| 213 |
self.pending_box_start_obj_id: int | None = None
|
| 214 |
self.is_switching_model: bool = False
|
|
|
|
| 215 |
# Model selection
|
| 216 |
self.model_repo_key: str = "tiny"
|
| 217 |
self.model_repo_id: str | None = None
|
|
@@ -286,6 +287,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
|
|
| 286 |
GLOBAL_STATE.inference_session = None
|
| 287 |
GLOBAL_STATE.masks_by_frame = {}
|
| 288 |
GLOBAL_STATE.color_by_obj = {}
|
|
|
|
| 289 |
|
| 290 |
load_model_if_needed(GLOBAL_STATE)
|
| 291 |
|
|
@@ -397,6 +399,15 @@ def compose_frame(state: AppState, frame_idx: int, remove_bg: bool = False) -> I
|
|
| 397 |
color = state.color_by_obj.get(obj_id, (255, 255, 255))
|
| 398 |
for x1, y1, x2, y2 in boxes:
|
| 399 |
draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
# Save to cache and return
|
| 401 |
state.composited_frames[frame_idx] = out_img
|
| 402 |
return out_img
|
|
@@ -418,6 +429,43 @@ def _ensure_color_for_obj(state: AppState, obj_id: int):
|
|
| 418 |
state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
| 419 |
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
def on_image_click(
|
| 422 |
img: Image.Image | np.ndarray,
|
| 423 |
state: AppState,
|
|
@@ -545,6 +593,7 @@ def on_image_click(
|
|
| 545 |
masks_for_frame[int(oid)] = mask_2d
|
| 546 |
|
| 547 |
state.masks_by_frame[int(frame_idx)] = masks_for_frame
|
|
|
|
| 548 |
# Invalidate cache for this frame to force recomposition
|
| 549 |
state.composited_frames.pop(int(frame_idx), None)
|
| 550 |
|
|
@@ -590,6 +639,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
|
|
| 590 |
mask_2d = video_res_masks[i].cpu().numpy().squeeze()
|
| 591 |
masks_for_frame[int(oid)] = mask_2d
|
| 592 |
GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
|
|
|
|
| 593 |
# Invalidate cache for that frame to force recomposition
|
| 594 |
GLOBAL_STATE.composited_frames.pop(frame_idx, None)
|
| 595 |
|
|
@@ -618,6 +668,7 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
|
|
| 618 |
GLOBAL_STATE.pending_box_start = None
|
| 619 |
GLOBAL_STATE.pending_box_start_frame_idx = None
|
| 620 |
GLOBAL_STATE.pending_box_start_obj_id = None
|
|
|
|
| 621 |
|
| 622 |
# Dispose and re-init inference session for current model with existing frames
|
| 623 |
try:
|
|
|
|
| 212 |
self.pending_box_start_frame_idx: int | None = None
|
| 213 |
self.pending_box_start_obj_id: int | None = None
|
| 214 |
self.is_switching_model: bool = False
|
| 215 |
+
self.ball_centers: dict[int, dict[int, tuple[int, int]]] = {}
|
| 216 |
# Model selection
|
| 217 |
self.model_repo_key: str = "tiny"
|
| 218 |
self.model_repo_id: str | None = None
|
|
|
|
| 287 |
GLOBAL_STATE.inference_session = None
|
| 288 |
GLOBAL_STATE.masks_by_frame = {}
|
| 289 |
GLOBAL_STATE.color_by_obj = {}
|
| 290 |
+
GLOBAL_STATE.ball_centers = {}
|
| 291 |
|
| 292 |
load_model_if_needed(GLOBAL_STATE)
|
| 293 |
|
|
|
|
| 399 |
color = state.color_by_obj.get(obj_id, (255, 255, 255))
|
| 400 |
for x1, y1, x2, y2 in boxes:
|
| 401 |
draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=2)
|
| 402 |
+
# Draw trajectory centers (all frames)
|
| 403 |
+
if state.ball_centers:
|
| 404 |
+
draw = ImageDraw.Draw(out_img)
|
| 405 |
+
cross_half = 4
|
| 406 |
+
for obj_id, centers in state.ball_centers.items():
|
| 407 |
+
color = state.color_by_obj.get(obj_id, (255, 255, 0))
|
| 408 |
+
for cx, cy in centers.values():
|
| 409 |
+
draw.line([(cx - cross_half, cy), (cx + cross_half, cy)], fill=color, width=2)
|
| 410 |
+
draw.line([(cx, cy - cross_half), (cx, cy + cross_half)], fill=color, width=2)
|
| 411 |
# Save to cache and return
|
| 412 |
state.composited_frames[frame_idx] = out_img
|
| 413 |
return out_img
|
|
|
|
| 429 |
state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
|
| 430 |
|
| 431 |
|
| 432 |
+
def _compute_mask_centroid(mask: np.ndarray) -> tuple[int, int] | None:
|
| 433 |
+
if mask is None:
|
| 434 |
+
return None
|
| 435 |
+
mask_np = np.array(mask)
|
| 436 |
+
if mask_np.ndim == 3:
|
| 437 |
+
mask_np = mask_np.squeeze()
|
| 438 |
+
if mask_np.size == 0:
|
| 439 |
+
return None
|
| 440 |
+
mask_float = np.clip(mask_np, 0.0, 1.0).astype(np.float32)
|
| 441 |
+
moments = cv2.moments(mask_float)
|
| 442 |
+
if moments["m00"] == 0:
|
| 443 |
+
return None
|
| 444 |
+
cx = int(moments["m10"] / moments["m00"])
|
| 445 |
+
cy = int(moments["m01"] / moments["m00"])
|
| 446 |
+
return cx, cy
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _update_centroids_for_frame(state: AppState, frame_idx: int):
    """Refresh the stored centroid of every object for one frame.

    For each mask recorded in ``state.masks_by_frame`` at ``frame_idx`` the
    centroid is recomputed and written to ``state.ball_centers[obj_id]``;
    when no centroid can be computed the entry for this frame is removed.
    Objects already tracked in ``state.ball_centers`` that have no mask on
    this frame also lose their entry for it. Does nothing if ``state`` is
    ``None``.
    """
    if state is None:
        return

    frame_key = int(frame_idx)
    frame_masks = state.masks_by_frame.get(frame_key, {})
    present_ids: set[int] = set()

    for raw_id, obj_mask in frame_masks.items():
        oid = int(raw_id)
        point = _compute_mask_centroid(obj_mask)
        per_object = state.ball_centers.setdefault(oid, {})
        if point is None:
            # No usable centroid: make sure no stale point survives.
            per_object.pop(frame_key, None)
        else:
            per_object[frame_key] = point
        present_ids.add(oid)
        _ensure_color_for_obj(state, oid)

    # Remove this frame's point for tracked objects that have no mask here.
    for oid, per_object in state.ball_centers.items():
        if oid in present_ids:
            continue
        per_object.pop(frame_key, None)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
def on_image_click(
|
| 470 |
img: Image.Image | np.ndarray,
|
| 471 |
state: AppState,
|
|
|
|
| 593 |
masks_for_frame[int(oid)] = mask_2d
|
| 594 |
|
| 595 |
state.masks_by_frame[int(frame_idx)] = masks_for_frame
|
| 596 |
+
_update_centroids_for_frame(state, int(frame_idx))
|
| 597 |
# Invalidate cache for this frame to force recomposition
|
| 598 |
state.composited_frames.pop(int(frame_idx), None)
|
| 599 |
|
|
|
|
| 639 |
mask_2d = video_res_masks[i].cpu().numpy().squeeze()
|
| 640 |
masks_for_frame[int(oid)] = mask_2d
|
| 641 |
GLOBAL_STATE.masks_by_frame[frame_idx] = masks_for_frame
|
| 642 |
+
_update_centroids_for_frame(GLOBAL_STATE, frame_idx)
|
| 643 |
# Invalidate cache for that frame to force recomposition
|
| 644 |
GLOBAL_STATE.composited_frames.pop(frame_idx, None)
|
| 645 |
|
|
|
|
| 668 |
GLOBAL_STATE.pending_box_start = None
|
| 669 |
GLOBAL_STATE.pending_box_start_frame_idx = None
|
| 670 |
GLOBAL_STATE.pending_box_start_obj_id = None
|
| 671 |
+
GLOBAL_STATE.ball_centers.clear()
|
| 672 |
|
| 673 |
# Dispose and re-init inference session for current model with existing frames
|
| 674 |
try:
|