#!/usr/bin/env python3
"""
BASE TIER PATCHWORK SOUP

First we make the soup, then we refine it into stew, and then crystallize it
into matter.

Params in current configuration:
  Parameters:
    projectors     :    296,064
    constellation  :     32,768
    patchwork      :    100,864
    classifier     :    370,256
    total          :    799,952

=========================
3 experts, all 768-d, no dimensional mismatch:
    clip_l14_openai  — semantic (text-supervised)
    dinov2_b14       — structural (self-supervised)
    siglip_b16_384   — semantic (sigmoid contrastive)

Architecture:
    Per-expert:    projection 768 → 128 (learned, on hypersphere)
    Constellation: 256 anchors at 128-d (dynamic, geometric autograd)
    Patchwork:     8 compartments reading triangulation distances
    Classifier:    patchwork output → 80-class multi-label

Phase 1: Soup the three experts into the constellation
Phase 2: Alignment bank reads the crystallized geometry

2,201 patches per image across 3 experts, compressed into 256 anchors
on the 128-d hypersphere. The sphere has more than enough capacity —
this is a fraction of what 128-d holds.

Importing this module only defines the model and helpers; data loading and
training run when the file is executed as a script.
"""

import gc
import math
import os
from itertools import combinations

import numpy as np  # noqa: F401  (not used below; import retained from the original file)
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture
D_EXPERT = 768
D_ANCHOR = 128      # anchor/constellation dimension
N_ANCHORS = 256     # 256 anchors on the 128-d hypersphere
N_CLASSES = 80
N_COMP = 8          # patchwork compartments
D_COMP = 64         # per-compartment output

# Training
BATCH = 128
EPOCHS = 20
LR = 1e-3

EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

# Hugging Face dataset repository holding one config per expert.
DATASET_REPO = "AbstractPhil/bulk-coco-features"

# TensorBoard tags for the pairwise validation agreement, in the order
# produced by combinations(range(3), 2): (0, 1), (0, 2), (1, 2).
AGREE_TAGS = ("val/agree_clip_dino", "val/agree_clip_siglip",
              "val/agree_dino_siglip")


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Return the squared simplex volume of each point set via Cayley-Menger.

    Args:
        pts: tensor of shape (B, V, D) — B simplices with V vertices each.

    Returns:
        Tensor of shape (B,) with the squared volume of every simplex
        (computed in float32; may be slightly negative from round-off).
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)                      # pairwise squared distances
    B, V, _ = d2.shape
    # Bordered distance matrix: zero corner, ones on first row/column.
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)


def cv_loss(emb, target=0.2, n_samples=16):
    """Penalize deviation of the simplex-volume coefficient of variation.

    Draws ``n_samples`` random 5-point subsets of ``emb``, computes each
    subset's simplex volume, and returns ``|std / mean - target|``.
    Returns a zero tensor when fewer than 5 embeddings are available.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        # relu clamps round-off negatives; epsilon keeps sqrt differentiable.
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    return (stacked.std() / (stacked.mean() + 1e-8) - target).abs()


@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Return the coefficient of variation of random 5-point simplex volumes.

    Non-differentiable counterpart of :func:`cv_loss` used for reporting.
    Returns 0.0 when fewer than 5 embeddings are given or fewer than 10
    usable volumes were sampled.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:   # also rejects NaN, for which the comparison is False
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = torch.tensor(vols)
    return float(a.std() / (a.mean() + 1e-8))


def anchor_spread_loss(anchors):
    """Return the mean squared off-diagonal cosine similarity of the anchors."""
    a = F.normalize(anchors, dim=-1)
    sim = a @ a.T
    sim = sim - torch.diag(torch.diag(sim))     # zero the self-similarities
    return sim.pow(2).mean()


def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Return the mean entropy of the softmax assignment of ``emb`` to anchors.

    The assignment is ``softmax(emb @ normalize(anchors).T * sharpness)``.
    """
    a = F.normalize(anchors, dim=-1)
    probs = F.softmax(emb @ a.T * sharpness, dim=-1)
    return -(probs * (probs + 1e-12).log()).sum(-1).mean()


def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned batches ``a`` and ``b``.

    Returns:
        (loss, acc): the averaged two-direction cross-entropy and the
        a→b top-1 retrieval accuracy as a Python float.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels)
            + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    Backward, given the incoming gradient ``g`` for ``x``:
      * the component of ``g`` along the normalized ``embedding`` (radial
        component) is scaled by ``1 - tang``; the rest is left untouched;
      * if ``sep > 0``, the component of the result along the nearest
        (highest-cosine) anchor is reduced by a factor ``sep`` wherever that
        component is positive.
    Only ``x`` receives a gradient; the other inputs get ``None``.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        return corrected.to(grad_output.dtype), None, None, None, None


# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════

class ExpertProjector(nn.Module):
    """768-d → 128-d, L2-normalized onto hypersphere."""

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.LayerNorm(d_out),
        )

    def forward(self, x):
        return F.normalize(self.proj(x), dim=-1)


class Constellation(nn.Module):
    """Learnable set of anchors, initialized as random unit vectors."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchors = nn.Parameter(F.normalize(
            torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return ``(1 - cos, argmax cos)`` of ``emb`` against every anchor."""
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Per-compartment MLPs over disjoint subsets of the anchor distances.

    Anchor ``i`` is assigned to compartment ``i % n_comp``; each compartment
    maps its slice of the triangulation to ``d_comp`` features and the
    compartment outputs are concatenated.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear((asgn == k).sum().item(), d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for k in range(n_comp)])

    def forward(self, tri):
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)


class BaseTierSoup(nn.Module):
    """
    3-expert soup on 128-d hypersphere.

    Each expert: 768-d → projector → 128-d (on sphere)
    Per-image: 3 projected embeddings → mean → on sphere
    Constellation: 256 anchors at 128-d
    Patchwork: 8 compartments → classifier

    The projectors learn to place each expert's perspective
    into the shared 128-d anchor space. The constellation
    crystallizes through geometric autograd.
    """

    def __init__(self, n_experts=3, d_expert=D_EXPERT, d_anchor=D_ANCHOR,
                 n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
                 n_classes=N_CLASSES):
        super().__init__()
        self.n_experts = n_experts
        self.d_anchor = d_anchor

        # Per-expert projection to anchor space
        self.projectors = nn.ModuleList([
            ExpertProjector(d_expert, d_anchor) for _ in range(n_experts)])

        # Geometric pipeline
        self.constellation = Constellation(n_anchors, d_anchor)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)

        # Classifier
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_anchor, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, n_classes))

    def forward(self, expert_embeddings, apply_autograd=True):
        """
        expert_embeddings: list of (B, 768) tensors, one per expert

        Returns ``(logits, fused, tri, nearest, projected)``: class logits,
        the fused unit embedding, anchor distances, nearest-anchor indices
        and the list of per-expert projections.
        """
        # Project each expert to 128-d hypersphere
        projected = [self.projectors[i](expert_embeddings[i])
                     for i in range(self.n_experts)]

        # Fuse: mean on hypersphere (normalize after averaging)
        fused = F.normalize(sum(projected) / self.n_experts, dim=-1)

        # Geometric autograd (training only)
        if apply_autograd and self.training:
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)

        # Triangulate + patchwork
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)

        # Classify
        logits = self.classifier(torch.cat([pw, fused], -1))

        return logits, fused, tri, nearest, projected

    def count_params(self):
        """Return parameter counts per sub-module plus their total."""
        proj = sum(sum(p.numel() for p in pr.parameters())
                   for pr in self.projectors)
        const = sum(p.numel() for p in self.constellation.parameters())
        pw = sum(p.numel() for p in self.patchwork.parameters())
        cls = sum(p.numel() for p in self.classifier.parameters())
        return {"projectors": proj, "constellation": const,
                "patchwork": pw, "classifier": cls,
                "total": proj + const + pw + cls}


# ══════════════════════════════════════════════════════════════════
# DATA / METRIC / TRAINING HELPERS
# ══════════════════════════════════════════════════════════════════

def build_label_matrix(labels_raw, n_classes=N_CLASSES):
    """Turn per-sample lists of class ids into a (N, n_classes) multi-hot matrix.

    Ids outside ``[0, n_classes)`` are ignored (negative ids would otherwise
    wrap around when used as an index).
    """
    matrix = torch.zeros(len(labels_raw), n_classes)
    for row, labs in enumerate(labels_raw):
        for lab in labs:
            if 0 <= lab < n_classes:
                matrix[row, lab] = 1.0
    return matrix


def align_features(rows, id_map, n_rows, d_expert=D_EXPERT):
    """Place each row's ``"features"`` at the position given by its ``"image_id"``.

    Args:
        rows: iterable of mappings with ``"image_id"`` and ``"features"`` keys.
        id_map: mapping image_id → destination row index.
        n_rows: number of rows of the output matrix.
        d_expert: feature dimensionality.

    Returns:
        Float tensor (n_rows, d_expert). Rows whose id is not in ``id_map``
        are skipped; positions never written stay zero.
    """
    feats = torch.zeros(n_rows, d_expert)
    for row in rows:
        pos = id_map.get(row["image_id"])
        if pos is not None:
            feats[pos] = torch.tensor(row["features"], dtype=torch.float32)
    return feats


def mean_average_precision(logits, labels):
    """Return mAP over the classes that have at least one positive label."""
    ranks = torch.arange(1, logits.shape[0] + 1, device=logits.device).float()
    ap_sum, n_valid = 0.0, 0
    for c in range(labels.shape[1]):
        n_pos = labels[:, c].sum().item()
        if n_pos > 0:
            order = logits[:, c].argsort(descending=True)
            hits = labels[:, c][order]
            prec_at_k = hits.cumsum(0) / ranks
            ap_sum += (prec_at_k * hits).sum().item() / n_pos
            n_valid += 1
    return ap_sum / max(n_valid, 1)


def macro_f1_score(logits, labels, threshold=0.5):
    """Return the mean per-class F1 over classes whose F1 is non-zero.

    Predictions are ``sigmoid(logits) > threshold``. Returns 0.0 (rather
    than NaN) when no class has a positive F1.
    """
    pred = (logits.sigmoid() > threshold).float()
    tp = (pred * labels).sum(0)
    fp = (pred * (1 - labels)).sum(0)
    fn = ((1 - pred) * labels).sum(0)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    nonzero = f1[f1 > 0]
    if nonzero.numel() == 0:
        return 0.0
    return nonzero.mean().item()


def agreement_loss(projected):
    """Return the mean ``1 - cosine`` over all pairs of expert projections."""
    pair_losses = [(1.0 - F.cosine_similarity(a, b, dim=-1)).mean()
                   for a, b in combinations(projected, 2)]
    if not pair_losses:     # a single expert has nothing to agree with
        return projected[0].new_zeros(())
    return sum(pair_losses) / len(pair_losses)


def train_one_epoch(model, optimizer, feats, labels, writer=None,
                    global_step=0, batch=BATCH, device=DEVICE):
    """Run one shuffled pass over the training set.

    Args:
        model: a :class:`BaseTierSoup`.
        optimizer: optimizer over ``model.parameters()``.
        feats: list of per-expert feature tensors, each (N, d_expert).
        labels: (N, n_classes) multi-hot tensor.
        writer: optional object with ``add_scalar(tag, value, step)``;
            scalars are logged every 100 global steps.
        global_step: step counter to continue from.
        batch: batch size.
        device: device each batch of features is moved to.

    Returns:
        (mean_loss, global_step): mean loss over the processed batches
        (0.0 if none were processed) and the updated step counter.
    """
    model.train()
    n_samples = labels.shape[0]
    perm = torch.randperm(n_samples)
    total, n_batches = 0.0, 0
    for start in range(0, n_samples, batch):
        idx = perm[start:start + batch]
        if len(idx) < 4:
            continue

        batch_experts = [f[idx].to(device) for f in feats]
        batch_labels = labels[idx]

        logits, fused, tri, nearest, projected = model(batch_experts)
        anchors = model.constellation.anchors

        # Classification
        l_cls = F.binary_cross_entropy_with_logits(logits, batch_labels)

        # Geometric losses
        l_cv = cv_loss(fused, target=0.2)
        l_spread = anchor_spread_loss(anchors)
        l_ent = anchor_entropy_loss(fused, anchors)

        # Per-expert agreement: all projections should be close
        l_agree = agreement_loss(projected)

        loss = (l_cls + 0.001 * l_cv + 1e-3 * l_spread
                + 1e-4 * l_ent + 0.1 * l_agree)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        total += loss.item()
        n_batches += 1
        global_step += 1

        if writer is not None and global_step % 100 == 0:
            writer.add_scalar("train/loss", loss.item(), global_step)
            writer.add_scalar("train/cls", l_cls.item(), global_step)
            writer.add_scalar("train/cv", l_cv.item(), global_step)
            writer.add_scalar("train/agree", l_agree.item(), global_step)

    return total / max(n_batches, 1), global_step


@torch.no_grad()
def evaluate(model, feats, labels, batch=BATCH):
    """Evaluate ``model`` in a single pass over a held-out set.

    Args:
        model: a :class:`BaseTierSoup` (switched to eval mode here).
        feats: list of per-expert feature tensors on the model's device.
        labels: (N, n_classes) multi-hot tensor.
        batch: evaluation batch size.

    Returns:
        dict with ``"mAP"``, ``"f1"``, ``"cv"`` (floats) and ``"agree"``,
        the list of mean pairwise cosine similarities between expert
        projections in ``combinations`` order.
    """
    model.eval()
    labels = labels.cpu()
    n_samples = labels.shape[0]
    logit_parts, fused_parts, proj_parts = [], [], []
    for start in range(0, n_samples, batch):
        chunk = [f[start:start + batch] for f in feats]
        logits, fused, _, _, projected = model(chunk, apply_autograd=False)
        logit_parts.append(logits.cpu())
        fused_parts.append(fused.cpu())
        proj_parts.append([p.cpu() for p in projected])

    all_logits = torch.cat(logit_parts)
    all_fused = torch.cat(fused_parts)
    all_proj = [torch.cat(parts) for parts in zip(*proj_parts)]

    return {
        "mAP": mean_average_precision(all_logits, labels),
        "f1": macro_f1_score(all_logits, labels),
        "cv": cv_metric(all_fused),
        "agree": [F.cosine_similarity(a, b, dim=-1).mean().item()
                  for a, b in combinations(all_proj, 2)],
    }


# ══════════════════════════════════════════════════════════════════
# SCRIPT
# ══════════════════════════════════════════════════════════════════
# Kept flat (not wrapped in a function) so that `model`, `optimizer`, etc.
# remain module-level names after a script / notebook run.

if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.tensorboard import SummaryWriter

    print("=" * 65)
    print("BASE TIER PATCHWORK SOUP")
    print(f" 3 experts × {D_EXPERT}-d → {N_ANCHORS} anchors × {D_ANCHOR}-d")
    print(f" Device: {DEVICE}")
    print("=" * 65)

    # ── LOAD DATA ────────────────────────────────────────────────
    print(f"\n{'='*65}")
    print("LOADING DATA")
    print(f"{'='*65}")

    # Reference for image_ids and labels
    ref = load_dataset(DATASET_REPO, EXPERTS[0], split="train")
    train_ids = ref["image_id"]
    N_train = len(train_ids)
    train_id_map = {iid: i for i, iid in enumerate(train_ids)}
    train_labels_raw = ref["labels"]
    train_label_matrix = build_label_matrix(train_labels_raw)

    ref_val = load_dataset(DATASET_REPO, EXPERTS[0], split="val")
    val_ids = ref_val["image_id"]
    N_val = len(val_ids)
    val_id_map = {iid: i for i, iid in enumerate(val_ids)}
    val_labels_raw = ref_val["labels"]
    val_label_matrix = build_label_matrix(val_labels_raw)

    print(f" Train: {N_train:,} Val: {N_val:,}")

    # Load every expert, aligned to the reference row order
    train_feats = []
    val_feats = []
    for name in EXPERTS:
        ds = load_dataset(DATASET_REPO, name, split="train")
        train_feats.append(align_features(ds, train_id_map, N_train))

        ds_v = load_dataset(DATASET_REPO, name, split="val")
        val_feats.append(align_features(ds_v, val_id_map, N_val))

        print(f" {name:<30} loaded", flush=True)
        del ds, ds_v
        gc.collect()

    # Move val to GPU
    val_feats_gpu = [f.to(DEVICE) for f in val_feats]
    val_labels_gpu = val_label_matrix.to(DEVICE)
    train_labels_gpu = train_label_matrix.to(DEVICE)

    # ── BUILD MODEL ──────────────────────────────────────────────
    print(f"\n{'='*65}")
    print("BUILDING MODEL")
    print(f"{'='*65}")

    model = BaseTierSoup(
        n_experts=len(EXPERTS), d_expert=D_EXPERT, d_anchor=D_ANCHOR,
        n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
        n_classes=N_CLASSES).to(DEVICE)

    params = model.count_params()
    print(f" Parameters:")
    for k, v in params.items():
        print(f" {k:<15}: {v:>10,}")

    # ── TRAIN ────────────────────────────────────────────────────
    print(f"\n{'='*65}")
    print("TRAINING")
    print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
    print(f" Adam, no weight decay (geometry IS the regularization)")
    print(f"{'='*65}")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    best_mAP = 0.0

    os.makedirs("checkpoints", exist_ok=True)
    writer = SummaryWriter("runs/base_tier_soup")
    gs = 0

    try:
        for epoch in range(EPOCHS):
            mean_loss, gs = train_one_epoch(
                model, optimizer, train_feats, train_labels_gpu,
                writer=writer, global_step=gs)

            # Validation (one pass yields logits, embeddings and projections)
            metrics = evaluate(model, val_feats_gpu, val_label_matrix)
            mAP = metrics["mAP"]
            macro_f1 = metrics["f1"]
            v_cv = metrics["cv"]
            agree = metrics["agree"]

            writer.add_scalar("val/mAP", mAP, epoch + 1)
            writer.add_scalar("val/F1", macro_f1, epoch + 1)
            writer.add_scalar("val/cv", v_cv, epoch + 1)
            for tag, value in zip(AGREE_TAGS, agree):
                writer.add_scalar(tag, value, epoch + 1)

            mk = ""
            if mAP > best_mAP:
                best_mAP = mAP
                torch.save({
                    "state_dict": model.state_dict(),
                    "config": {"d_expert": D_EXPERT, "d_anchor": D_ANCHOR,
                               "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                               "d_comp": D_COMP, "n_classes": N_CLASSES,
                               "experts": EXPERTS},
                    "epoch": epoch + 1, "mAP": mAP, "cv": v_cv,
                }, "checkpoints/base_tier_best.pt")
                mk = " ★"

            agree_txt = ",".join(f"{a:.3f}" for a in agree)
            print(f" E{epoch+1:2d}: mAP={mAP:.3f} F1={macro_f1:.3f} cv={v_cv:.4f} "
                  f"agree=[{agree_txt}] "
                  f"loss={mean_loss:.4f}{mk}")
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    print(f"\n Best mAP: {best_mAP:.3f}")
    print(f" Model: {params['total']:,} params")
    print(f" Anchors: {N_ANCHORS} × {D_ANCHOR}-d")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")