"""Sinkhorn optimal transport loss for segment matching. Note: at eps=0.05, sinkhorn gradients are near-zero (~1e-7 norm) for typical matrix sizes. The loss value is tracked but does not meaningfully train the model. Default sinkhorn_weight=0.0. See worklog.md for details. Future: schedule eps from large (1.0) to small (0.05) during training to get useful gradients early and precise matching late. """ import torch def batched_sinkhorn_loss( pred_segments: torch.Tensor, gt_pad: torch.Tensor, gt_mask: torch.Tensor, eps: float, iters: int, dustbin_cost: float | torch.Tensor, pred_mass: torch.Tensor | None = None, ) -> torch.Tensor: """Batched sinkhorn segment matching loss. Args: pred_segments: [B, S, 2, 3] predicted segments gt_pad: [B, M, 2, 3] padded GT segments gt_mask: [B, M] bool mask (True = valid GT segment) eps: sinkhorn regularization iters: sinkhorn iterations dustbin_cost: cost for unmatched segments (scalar or [B]) pred_mass: [B, S] per-segment mass weights (e.g. sigmoid(conf)). If None, uniform masses are used. Returns: [B] per-sample sinkhorn transport cost """ B, S, _, _ = pred_segments.shape M = gt_pad.shape[1] # Allow per-sample dustbin cost dc = torch.as_tensor(dustbin_cost, device=pred_segments.device, dtype=pred_segments.dtype) if dc.dim() == 0: dc = dc.expand(B) # Compute cost matrices [B, S, M] in midpoint-halfvec space. # Decouples position from direction: mid gradient is pure position, # half gradient is pure direction/length. Sign-invariance on half # handles segment direction ambiguity cleanly. p0 = pred_segments[:, :, 0] # [B, S, 3] p1 = pred_segments[:, :, 1] # [B, S, 3] g0 = gt_pad[:, :, 0] # [B, M, 3] g1 = gt_pad[:, :, 1] # [B, M, 3] mid_pred = 0.5 * (p0 + p1) # [B, S, 3] half_pred = 0.5 * (p1 - p0) # [B, S, 3] mid_gt = 0.5 * (g0 + g1) # [B, M, 3] half_gt = 0.5 * (g1 - g0) # [B, M, 3] # Midpoint distance [B, S, M] d_mid = torch.linalg.norm( mid_pred.unsqueeze(2) - mid_gt.unsqueeze(1), dim=-1) # Decoupled direction + length distance (sign-invariant for direction ambiguity) len_pred = torch.linalg.norm(half_pred, dim=-1, keepdim=True).clamp(min=1e-6) # [B, S, 1] len_gt = torch.linalg.norm(half_gt, dim=-1, keepdim=True).clamp(min=1e-6) # [B, M, 1] dir_pred = half_pred / len_pred # [B, S, 3] dir_gt = half_gt / len_gt # [B, M, 3] # Direction distance: 1 - |cos(angle)|, sign-invariant [B, S, M] cos_angle = (dir_pred.unsqueeze(2) * dir_gt.unsqueeze(1)).sum(dim=-1) # [B, S, M] d_dir = 1.0 - cos_angle.abs() # Length distance [B, S, M] d_len = (len_pred.unsqueeze(2) - len_gt.unsqueeze(1)).squeeze(-1).abs() cost = d_mid + d_dir + d_len # [B, S, M] # Mask invalid GT segments with high cost so they go to dustbin cost = torch.where(gt_mask.unsqueeze(1), cost, dc[:, None, None] * 10.0) # Pad with dustbin row and column: [B, S+1, M+1] cost_pad = dc[:, None, None].expand(B, S + 1, M + 1).clone() cost_pad[:, :S, :M] = cost cost_pad[:, -1, -1] = 0.0 # Masses gt_counts = gt_mask.sum(dim=1).float() # [B] if pred_mass is not None: # Confidence-weighted masses (matches learned_v2 approach). # sigmoid(conf) gives per-segment mass; dustbin masses balance the totals. # No normalization -- sum(a) == sum(b) == max(sum_pred, sum_gt). pm = pred_mass.clamp(min=0.0) # [B, S] sum_pred = pm.sum(dim=1) # [B] sum_gt = gt_counts # [B] pred_dustbin = (sum_gt - sum_pred).clamp(min=0.0) # [B] gt_dustbin = (sum_pred - sum_gt).clamp(min=0.0) # [B] a = torch.cat([pm, pred_dustbin.unsqueeze(1)], dim=1) # [B, S+1] b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype) b_val[:, :M] = gt_mask.float() # 1.0 per valid GT segment b_val[:, -1] = gt_dustbin else: # Uniform masses (normalized) n = float(S) denom = n + gt_counts # [B] a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone() # [B, S+1] a[:, -1] = gt_counts / denom b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone() # [B, M+1] b_val[:, -1] = n / denom # Zero out mass for invalid GT b_val[:, :M] = b_val[:, :M] * gt_mask.float() # Log-domain sinkhorn log_a = torch.log(a + 1e-9) log_b = torch.log(b_val + 1e-9) log_k = -cost_pad / eps log_u = torch.zeros_like(a) log_v = torch.zeros_like(b_val) for _ in range(iters): log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2) log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1) transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k) return (transport * cost_pad).sum(dim=(1, 2)) # [B]