#!/usr/bin/env python3
"""
BASE TIER PATCHWORK SOUP — PROPERLY CALIBRATED
================================================
3 experts, all 768-d: clip_l14_openai, dinov2_b14, siglip_b16_384

Pipeline (from CaptionBERT research):
  1. GPA alignment at 768-d → consensus
  2. Measure consensus CV → CV loss target
  3. Per-expert whitened Procrustes calibration
  4. Initialize projectors from Procrustes rotations
  5. Train: projectors + constellation + patchwork + classifier
     against consensus targets with calibrated CV

Architecture:
  Per-expert: 768 → 128 (Procrustes-initialized projection)
  Constellation: 256 anchors × 128-d (geometric autograd)
  Patchwork: 8 compartments
  Classifier: patchwork + fused → 80 classes

Validation targets are built by replaying the transforms fitted on the train
split (GPA affines + PCA projection), so val metrics are measured in the same
frame the model is trained against.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
D_EXPERT = 768
D_ANCHOR = 128
N_ANCHORS = 256
N_CLASSES = 80
N_COMP = 8
D_COMP = 64
BATCH = 128
EPOCHS = 20
LR = 1e-3
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
DATASET_REPO = "AbstractPhil/bulk-coco-features"
N_CALIBRATION = 10000  # max train rows used for PCA / Procrustes fitting

print("=" * 65)
print("BASE TIER PATCHWORK SOUP — CALIBRATED")
print(f" {len(EXPERTS)} experts × {D_EXPERT}-d → {N_ANCHORS} anchors × {D_ANCHOR}-d")
print(f" Device: {DEVICE}")
print("=" * 65)


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Squared volume of each simplex in a batch via the Cayley–Menger determinant.

    Args:
        pts: (B, V, D) tensor — B simplices, each with V vertices in D dims.

    Returns:
        (B,) float32 tensor of squared (V-1)-dimensional volumes. Round-off can
        make near-degenerate results slightly negative; callers in this file
        clamp with ``F.relu`` before taking the square root.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)  # (B, V, V) pairwise squared distances
    B, V, _ = d2.shape
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    # vol² = (-1)^V / (2^(V-1) · ((V-1)!)²) · det(CM) for a (V-1)-simplex.
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)


def _sample_pentachora(emb, n_samples):
    """Stack ``n_samples`` random 5-row subsets of ``emb`` into (n_samples, 5, D).

    Draws one ``randperm`` per sample, in order, so the RNG stream matches a
    per-sample loop exactly.
    """
    B = emb.shape[0]
    idx = torch.stack([torch.randperm(B, device=emb.device)[:5]
                       for _ in range(n_samples)])
    return emb[idx]


def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty ``|CV - target|`` on random 5-point simplex volumes.

    CV is std/mean of the volumes of ``n_samples`` random 5-row simplices drawn
    from ``emb``. Returns a zero tensor when ``emb`` has fewer than 5 rows.
    """
    if emb.shape[0] < 5:
        return torch.tensor(0.0, device=emb.device)
    # One batched determinant instead of n_samples separate calls.
    v2 = cayley_menger_vol2(_sample_pentachora(emb, n_samples))
    vols = torch.sqrt(F.relu(v2) + 1e-12)  # eps keeps sqrt's gradient finite at 0
    return (vols.std() / (vols.mean() + 1e-8) - target).abs()


@torch.no_grad()
def cv_metric(emb, n_samples=500):
    """Non-differentiable CV (std/mean) of random 5-point simplex volumes.

    Zero-volume (degenerate) simplices are excluded. Returns 0.0 when ``emb``
    has fewer than 5 rows or fewer than 10 non-degenerate samples remain.
    """
    if emb.shape[0] < 5:
        return 0.0
    v2 = cayley_menger_vol2(_sample_pentachora(emb, n_samples))
    # No epsilon here: adding one would make every volume > 0 and the
    # degenerate-simplex filter below could never drop anything.
    vols = torch.sqrt(F.relu(v2))
    vols = vols[vols > 0].double().cpu().numpy()
    if vols.size < 10:
        return 0.0
    return float(vols.std() / (vols.mean() + 1e-8))


def anchor_spread_loss(anchors):
    """Mean squared off-diagonal cosine similarity between anchors."""
    a = F.normalize(anchors, dim=-1)
    sim = a @ a.T
    sim = sim - torch.diag(torch.diag(sim))  # zero the self-similarity diagonal
    return sim.pow(2).mean()


def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy of the softmax over (emb · normalized anchor) × sharpness."""
    a = F.normalize(anchors, dim=-1)
    probs = F.softmax(emb @ a.T * sharpness, dim=-1)
    return -(probs * (probs + 1e-12).log()).sum(-1).mean()


def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-paired ``a`` and ``b``.

    Returns:
        (loss, acc): loss averages the a→b and b→a cross-entropies; acc is the
        fraction of rows of ``a`` whose most similar row of ``b`` is its pair.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    ``forward(x, embedding, anchors, tang, sep)`` returns ``x`` unchanged. In
    backward, with ``e = normalize(embedding)`` and incoming gradient ``g``:

    * the radial part ``(g·e) e`` is kept at a factor of ``(1 - tang)``;
    * if ``sep > 0``, the component of that result along the anchor nearest
      to ``e`` (by cosine) is subtracted, scaled by ``sep``, wherever the
      component is positive.

    Only ``x`` receives a gradient; ``embedding``, ``anchors``, ``tang`` and
    ``sep`` get ``None`` through this op.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        return corrected.to(grad_output.dtype), None, None, None, None


# ══════════════════════════════════════════════════════════════════
# PROCRUSTES UTILITIES (from cotrain_bank.py)
# ══════════════════════════════════════════════════════════════════

def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return ``cov^(-1/2)`` via eigendecomposition, flooring eigenvalues at ``eps``."""
    evals, evecs = torch.linalg.eigh(cov)
    return evecs @ torch.diag(torch.clamp(evals, min=eps).rsqrt()) @ evecs.T


def procrustes_align(source, target, n_align=10000):
    """Whitened orthogonal Procrustes fit of ``source`` onto ``target``.

    Uses the first ``min(n_align, len(source), len(target))`` rows. Both sides
    are centered, whitened with their own covariance and row-normalized; the
    rotation is ``U @ Vt`` from the SVD of ``Tw.T @ Sw``.

    Returns:
        dict with ``rotation``, ``source_mean``, ``source_whitener``,
        ``target_unwhitener`` (pseudo-inverse of the target whitener) and
        ``cos_after`` (mean cosine between rotated source and target, both
        in whitened space).
    """
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float()
    T = target[:N].float()
    s_mean = S.mean(0, keepdim=True)
    t_mean = T.mean(0, keepdim=True)
    Sc = S - s_mean
    Tc = T - t_mean
    denom = max(N - 1, 1)
    s_whiten = symmetric_inv_sqrt((Sc.T @ Sc) / denom)
    t_whiten = symmetric_inv_sqrt((Tc.T @ Tc) / denom)
    Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
    Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
    return {"rotation": R, "source_mean": s_mean.squeeze(0),
            "source_whitener": s_whiten,
            "target_unwhitener": torch.linalg.pinv(t_whiten),
            "cos_after": cos_after}


def apply_align(emb, a):
    """Map ``emb`` through a ``procrustes_align`` result.

    Steps: center, whiten, rotate, un-whiten. The target mean is not added
    back, so the output is centered.
    """
    x = emb.float() - a["source_mean"]
    x = x @ a["source_whitener"]
    x = x @ a["rotation"].T
    x = x @ a["target_unwhitener"]
    return x


def alignment_as_affine(a):
    """Express ``apply_align(x, a)`` as ``x @ A + b``, with A and b in float64."""
    A = (a["source_whitener"].double() @ a["rotation"].T.double()
         @ a["target_unwhitener"].double())
    b = -(a["source_mean"].double() @ A)
    return A, b


def apply_affine(x, affine):
    """Apply an ``(A, b)`` pair from ``alignment_as_affine`` / GPA; returns float32."""
    A, b = affine
    return (x.double() @ A.to(x.device) + b.to(x.device)).float()


def generalized_procrustes(feats, n_iter=20, tol=1e-8, verbose=False):
    """Iteratively Procrustes-align every feature set to their running mean.

    Args:
        feats: dict name -> (N, D) tensor; all entries share N and D.
        n_iter: maximum number of GPA iterations.
        tol: stop once the summed per-set mean-squared update falls below this.
        verbose: print progress on the first iteration, every 5th, and on
            convergence.

    Returns:
        (aligned, affines): ``aligned`` maps name -> aligned (N, D) float32
        tensor. ``affines`` maps name -> (A, b) in float64 such that
        ``x @ A + b`` reproduces the whole alignment chain. This lets held-out
        data be mapped with transforms fitted only on ``feats``.
    """
    current = {n: f.float() for n, f in feats.items()}
    affines = {}
    for n, f in current.items():
        d = f.shape[1]
        affines[n] = (torch.eye(d, dtype=torch.float64, device=f.device),
                      torch.zeros(d, dtype=torch.float64, device=f.device))

    for it in range(n_iter):
        mean_shape = sum(current.values()) / len(current)
        total_delta = 0.0
        new_current = {}
        for n, x in current.items():
            info = procrustes_align(x, mean_shape)
            new_current[n] = apply_align(x, info)
            total_delta += (new_current[n] - x).pow(2).mean().item()
            # Compose: x0 @ A + b, then @ A_k + b_k  ->  x0 @ (A A_k) + (b A_k + b_k)
            A_k, b_k = alignment_as_affine(info)
            A, b = affines[n]
            affines[n] = (A @ A_k, b @ A_k + b_k)
        current = new_current
        if verbose and (it == 0 or (it + 1) % 5 == 0):
            print(f" GPA iter {it+1}: delta={total_delta:.8f}")
        if total_delta < tol:
            if verbose:
                print(f" Converged at iteration {it+1}")
            break
    return current, affines


# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════

class ExpertProjector(nn.Module):
    """Linear(d_in → d_out) → LayerNorm → L2 normalization."""

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(d_in, d_out), nn.LayerNorm(d_out))

    def forward(self, x):
        return F.normalize(self.proj(x), dim=-1)


class Constellation(nn.Module):
    """A learnable set of ``n_anchors`` anchor directions in ``d`` dims."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return ``(1 - cos to every anchor, index of the most similar anchor)``.

        ``cos`` is a plain dot product with the normalized anchors, so it is a
        true cosine only when ``emb`` rows are unit-norm. Callers in this file
        pass normalized embeddings.
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Split anchor distances into compartments and encode each with a small MLP.

    Anchor ``k`` goes to compartment ``k % n_comp``. The output concatenates the
    per-compartment features, giving ``(B, n_comp * d_comp)``.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear((asgn == k).sum().item(), d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp))
            for k in range(n_comp)])

    def forward(self, tri):
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)


class BaseTierSoup(nn.Module):
    """Per-expert projectors → fused embedding → constellation → patchwork → classifier.

    ``forward`` returns ``(logits, fused, tri, nearest, projected)``.
    """

    def __init__(self, n_experts=len(EXPERTS)):
        super().__init__()
        self.n_experts = n_experts
        self.projectors = nn.ModuleList([ExpertProjector() for _ in range(n_experts)])
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, N_CLASSES))

    def forward(self, expert_embeddings, apply_autograd=True):
        projected = [self.projectors[i](expert_embeddings[i])
                     for i in range(self.n_experts)]
        fused = F.normalize(sum(projected) / self.n_experts, dim=-1)
        if apply_autograd and self.training:
            # Identity forward; gradient reshaping only (tang=0.01, sep=1.0).
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, fused], -1))
        return logits, fused, tri, nearest, projected


# ══════════════════════════════════════════════════════════════════
# PHASE 0: LOAD DATA
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 0: LOAD DATA")
print(f"{'='*65}")

from datasets import load_dataset


def build_label_matrix(label_lists, n_rows, n_classes=N_CLASSES):
    """Multi-hot (n_rows, n_classes) matrix; ids outside [0, n_classes) are ignored."""
    mat = torch.zeros(n_rows, n_classes)
    for i, labs in enumerate(label_lists):
        for lab in labs:
            if 0 <= lab < n_classes:
                mat[i, lab] = 1.0
    return mat


def load_expert_features(name, split, id_map, n_rows):
    """Load one expert's features into an (n_rows, D_EXPERT) tensor ordered by ``id_map``.

    Rows whose image_id is absent from ``id_map`` are skipped; positions with
    no matching row stay zero.
    """
    ds = load_dataset(DATASET_REPO, name, split=split)
    feats = torch.zeros(n_rows, D_EXPERT)
    for row in ds:
        pos = id_map.get(row["image_id"])
        if pos is not None:
            feats[pos] = torch.tensor(row["features"], dtype=torch.float32)
    del ds
    return feats


ref = load_dataset(DATASET_REPO, EXPERTS[0], split="train")
train_ids = ref["image_id"]
N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_label_matrix = build_label_matrix(ref["labels"], N_train)

ref_val = load_dataset(DATASET_REPO, EXPERTS[0], split="val")
val_ids = ref_val["image_id"]
N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_label_matrix = build_label_matrix(ref_val["labels"], N_val)
print(f" Train: {N_train:,} Val: {N_val:,}")

train_raw = {}
val_raw = {}
for name in EXPERTS:
    train_raw[name] = load_expert_features(name, "train", train_id_map, N_train)
    val_raw[name] = load_expert_features(name, "val", val_id_map, N_val)
    print(f" {name:<30} loaded", flush=True)
    gc.collect()


# ══════════════════════════════════════════════════════════════════
# PHASE 1: GPA ALIGNMENT AT 768-d
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 1: GPA ALIGNMENT AT 768-d")
print(f"{'='*65}")

# gpa_affines[name] reproduces the full train-fitted alignment chain so it
# can be replayed on the val split in Phase 2.
current, gpa_affines = generalized_procrustes(
    {name: train_raw[name] for name in EXPERTS}, n_iter=20, tol=1e-8, verbose=True)

consensus_768 = F.normalize(
    sum(current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
for name in EXPERTS:
    c = F.cosine_similarity(consensus_768[:5000], current[name][:5000], dim=-1).mean().item()
    print(f" cos(consensus, {name}): {c:.4f}")
consensus_cv_768 = cv_metric(consensus_768[:5000].to(DEVICE))
print(f" Consensus CV at 768-d: {consensus_cv_768:.4f}")


# ══════════════════════════════════════════════════════════════════
# PHASE 2: PROJECT CONSENSUS TO 128-d + CALIBRATE CV
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 2: PROJECT TO 128-d + CALIBRATE")
print(f"{'='*65}")

cons_centered = consensus_768 - consensus_768.mean(0, keepdim=True)
U, S, Vt = torch.linalg.svd(cons_centered[:N_CALIBRATION], full_matrices=False)
pca_proj = Vt[:D_ANCHOR]  # (D_ANCHOR, D_EXPERT) top principal directions
consensus_128 = F.normalize(consensus_768 @ pca_proj.T, dim=-1)
var_retained = S[:D_ANCHOR].pow(2).sum() / S.pow(2).sum()
print(f" PCA 768→128: variance retained = {var_retained.item():.4f}")

consensus_cv_128 = cv_metric(consensus_128[:5000].to(DEVICE))
print(f" Consensus CV at 128-d: {consensus_cv_128:.4f}")

# Val consensus: replay the train-fitted GPA affines and PCA projection rather
# than refitting GPA on val, so val targets share the train targets' frame.
val_current = {name: apply_affine(val_raw[name], gpa_affines[name]) for name in EXPERTS}
val_consensus_768 = F.normalize(
    sum(val_current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
val_consensus_128 = F.normalize(val_consensus_768 @ pca_proj.T, dim=-1)
print(f" Val consensus: {val_consensus_128.shape}")


# ══════════════════════════════════════════════════════════════════
# PHASE 3: PER-EXPERT PROCRUSTES TO 128-d CONSENSUS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 3: PER-EXPERT PROCRUSTES CALIBRATION")
print(f"{'='*65}")

n_cal = min(N_CALIBRATION, N_train)
n_test = min(1000, n_cal)
cov_denom = max(n_cal - 1, 1)

# Target-side statistics are shared by every expert; compute them once.
tgt = consensus_128[:n_cal].float()
tgt_mean = tgt.mean(0, keepdim=True)
tgt_c = tgt - tgt_mean
tgt_whiten = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / cov_denom)
tgt_unwhiten = torch.linalg.pinv(tgt_whiten)
tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

expert_calibrations = {}
for name in EXPERTS:
    raw = train_raw[name][:n_cal].float()
    src_mean = raw.mean(0, keepdim=True)
    src_c = raw - src_mean
    src_whiten = symmetric_inv_sqrt((src_c.T @ src_c) / cov_denom)
    src_w = F.normalize(src_c @ src_whiten, dim=-1)

    U_r, S_r, Vt_r = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    R = U_r @ Vt_r  # (D_ANCHOR, D_EXPERT)

    # Full map into consensus_128 coordinates:
    #   x -> (x - src_mean) @ src_whiten @ R.T @ tgt_unwhiten + tgt_mean
    # Un-whitening and re-adding the target mean land the projection in the
    # same (un-whitened, uncentered) space as the training targets.
    A = src_whiten @ R.T @ tgt_unwhiten          # (D_EXPERT, D_ANCHOR)
    proj_W = A.T                                  # nn.Linear layout: (out, in)
    proj_b = (tgt_mean - src_mean @ A).squeeze(0)

    test_proj = raw[:n_test] @ proj_W.T + proj_b
    test_proj_n = F.normalize(test_proj, dim=-1)
    cos = F.cosine_similarity(test_proj_n, tgt[:n_test], dim=-1).mean().item()
    expert_calibrations[name] = {"weight": proj_W, "bias": proj_b,
                                 "cos": cos, "svd_S": S_r}
    print(f" {name:<30} cos={cos:.4f} svd: min={S_r.min():.4f} max={S_r.max():.4f}")


# ══════════════════════════════════════════════════════════════════
# PHASE 4: BUILD + INITIALIZE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 4: BUILD + INITIALIZE")
print(f"{'='*65}")

model = BaseTierSoup().to(DEVICE)

with torch.no_grad():
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.projectors[i].proj[0].weight.copy_(cal["weight"].to(DEVICE))
        model.projectors[i].proj[0].bias.copy_(cal["bias"].to(DEVICE))
        print(f" ✓ {name} projector initialized (cos={cal['cos']:.4f})")

    sample_idx = torch.randperm(min(N_CALIBRATION, N_train))[:N_ANCHORS]
    anchor_seeds = consensus_128[sample_idx].to(DEVICE)
    model.constellation.anchors.copy_(F.normalize(anchor_seeds, dim=-1))
    print(" ✓ Constellation seeded from consensus")

# Verify
with torch.no_grad():
    test_in = [train_raw[name][:200].to(DEVICE) for name in EXPERTS]
    _, test_fused, _, test_nearest, test_proj = model(test_in, apply_autograd=False)
    test_tgt = consensus_128[:200].to(DEVICE)
    init_cos = F.cosine_similarity(test_fused, test_tgt, dim=-1).mean().item()
    init_cv = cv_metric(test_fused)
    n_active = test_nearest.unique().numel()
    for short, proj_e in zip(["clip", "dino", "siglip"], test_proj):
        c = F.cosine_similarity(proj_e, test_tgt, dim=-1).mean().item()
        print(f" {short} proj→consensus cos: {c:.4f}")
    print(f" Init: cos={init_cos:.4f} cv={init_cv:.4f} active_anchors={n_active}/{N_ANCHORS}")

params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {params:,}")
print(f" CV target: {consensus_cv_128:.4f}")


# ══════════════════════════════════════════════════════════════════
# PHASE 5: TRAINING
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 5: TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, CV target={consensus_cv_128:.4f}")
print(f"{'='*65}")

train_targets = consensus_128.to(DEVICE)
val_targets = val_consensus_128.to(DEVICE)
val_targets_cpu = val_targets.cpu()
train_labels_gpu = train_label_matrix.to(DEVICE)
R1_CHUNK = 1024  # rows per similarity block when computing R@1

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/base_tier_calibrated")

best_mAP = 0.0
gs = 0

for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train)
    tl, tn, nb = 0, 0, 0

    for i in range(0, N_train, BATCH):
        idx = perm[i:i + BATCH]
        if len(idx) < 4:
            continue
        batch = [train_raw[name][idx].to(DEVICE) for name in EXPERTS]
        labels = train_labels_gpu[idx]
        targets = train_targets[idx]

        logits, fused, tri, nearest, projected = model(batch)
        anchors = model.constellation.anchors

        l_nce, nce_acc = infonce(fused, targets)
        l_mse = F.mse_loss(fused, targets)
        l_cls = F.binary_cross_entropy_with_logits(logits, labels)
        l_cv = cv_loss(fused, target=consensus_cv_128)
        l_spread = anchor_spread_loss(anchors)
        l_ent = anchor_entropy_loss(fused, anchors)

        loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_cls + 0.001 * l_cv
                + 1e-3 * l_spread + 1e-4 * l_ent)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        tl += loss.item()
        tn += nce_acc
        nb += 1
        gs += 1

        if gs % 100 == 0:
            writer.add_scalar("train/loss", loss.item(), gs)
            writer.add_scalar("train/nce", l_nce.item(), gs)
            writer.add_scalar("train/cls", l_cls.item(), gs)
            writer.add_scalar("train/cv", l_cv.item(), gs)
            writer.add_scalar("train/nce_acc", nce_acc, gs)

    # Validation
    model.eval()
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j + BATCH, N_val)
            batch_v = [val_raw[name][j:end].to(DEVICE) for name in EXPERTS]
            lo, em, _, _, _ = model(batch_v, apply_autograd=False)
            all_lo.append(lo.cpu())
            all_em.append(em.cpu())
        v_lo = torch.cat(all_lo)
        v_em = torch.cat(all_em)
        v_lab = val_label_matrix

        # mAP: average precision per class with >= 1 val positive, then mean.
        ap_sum, nv = 0.0, 0
        for c in range(N_CLASSES):
            n_pos = v_lab[:, c].sum()
            if n_pos > 0:
                si = v_lo[:, c].argsort(descending=True)
                st = v_lab[:, c][si]
                pak = st.cumsum(0) / torch.arange(1, len(st) + 1).float()
                ap_sum += (pak * st).sum().item() / n_pos.item()
                nv += 1
        mAP = ap_sum / max(nv, 1)

        vp = (v_lo.sigmoid() > 0.5).float()
        tp = (vp * v_lab).sum(0)
        fp = (vp * (1 - v_lab)).sum(0)
        fn = ((1 - vp) * v_lab).sum(0)
        pr = tp / (tp + fp + 1e-8)
        rc = tp / (tp + fn + 1e-8)
        f1 = 2 * pr * rc / (pr + rc + 1e-8)
        # Macro F1 over every class that has val positives. Filtering on
        # f1 > 0 would drop the classes the model misses entirely (inflating
        # the score) and yield NaN when no class scores above zero.
        present = v_lab.sum(0) > 0
        macro_f1 = f1[present].mean().item() if present.any() else 0.0

        v_cos = F.cosine_similarity(v_em, val_targets_cpu, dim=-1).mean().item()
        v_cv = cv_metric(v_em.to(DEVICE))

        # R@1 in row chunks to avoid materializing an N_val x N_val matrix.
        hits = 0
        for j in range(0, N_val, R1_CHUNK):
            sim = v_em[j:j + R1_CHUNK] @ val_targets_cpu.T
            rows = torch.arange(j, j + sim.shape[0])
            hits += (sim.argmax(-1) == rows).sum().item()
        r1 = hits / max(N_val, 1)

        _, v_nearest = model.constellation.triangulate(v_em.to(DEVICE))
        n_active = v_nearest.cpu().unique().numel()

    writer.add_scalar("val/mAP", mAP, epoch + 1)
    writer.add_scalar("val/cos", v_cos, epoch + 1)
    writer.add_scalar("val/cv", v_cv, epoch + 1)
    writer.add_scalar("val/R@1", r1, epoch + 1)
    writer.add_scalar("val/active_anchors", n_active, epoch + 1)

    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({
            "state_dict": model.state_dict(),
            "config": {"d_expert": D_EXPERT, "d_anchor": D_ANCHOR,
                       "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                       "d_comp": D_COMP, "n_classes": N_CLASSES,
                       "experts": EXPERTS, "cv_target": consensus_cv_128},
            "pca_proj": pca_proj,
            "consensus_cv_768": consensus_cv_768,
            "consensus_cv_128": consensus_cv_128,
            "epoch": epoch + 1, "mAP": mAP, "cv": v_cv, "r1": r1,
        }, "checkpoints/base_tier_best.pt")
        mk = " ★"

    print(f" E{epoch+1:2d}: mAP={mAP:.3f} F1={macro_f1:.3f} R@1={r1:.3f} "
          f"cos={v_cos:.3f} cv={v_cv:.4f} anchors={n_active}/{N_ANCHORS} "
          f"nce={tn/max(nb, 1):.3f} loss={tl/max(nb, 1):.4f}{mk}")

writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" CV target: {consensus_cv_128:.4f}")
print(f"\n{'='*65}\nDONE\n{'='*65}")