| |
| """ |
| BASE TIER PATCHWORK SOUP |
| |
| First we make the soup, then we refine it into stew, and then crystallize it into matter. |
| |
| Params in current configuration: |
| |
| Parameters: |
| projectors : 296,064 |
| constellation : 32,768 |
| patchwork : 100,864 |
| classifier : 370,256 |
| total : 799,952 |
| |
| |
| ========================= |
| 3 experts, all 768-d, no dimensional mismatch: |
| clip_l14_openai — semantic (text-supervised) |
| dinov2_b14 — structural (self-supervised) |
| siglip_b16_384 — semantic (sigmoid contrastive) |
| |
| Architecture: |
| Per-expert: projection 768 → 128 (learned, on hypersphere) |
| Constellation: 256 anchors at 128-d (dynamic, geometric autograd) |
| Patchwork: 8 compartments reading triangulation distances |
| Classifier: patchwork output → 80-class multi-label |
| |
| Phase 1: Soup the three experts into the constellation |
| Phase 2: Alignment bank reads the crystallized geometry |
| |
| 2,201 patches per image across 3 experts, compressed into |
| 256 anchors on the 128-d hypersphere. The sphere has more |
| than enough capacity — this is a fraction of what 128-d holds. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import os |
| import gc |
|
|
# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model geometry.
D_EXPERT = 768    # width of each expert's input feature vector
D_ANCHOR = 128    # width of the shared anchor space
N_ANCHORS = 256   # number of learned anchors
N_CLASSES = 80    # number of multi-label output classes
N_COMP = 8        # number of patchwork compartments
D_COMP = 64       # output width of each compartment

# Optimisation settings.
BATCH = 128
EPOCHS = 20
LR = 1e-3

# Dataset configuration names, one per expert, in projector order.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

# Start-up banner, emitted as one write instead of five.
print("\n".join([
    "=" * 65,
    "BASE TIER PATCHWORK SOUP",
    f" 3 experts × {D_EXPERT}-d → {N_ANCHORS} anchors × {D_ANCHOR}-d",
    f" Device: {DEVICE}",
    "=" * 65,
]))
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Return squared simplex volumes from the Cayley-Menger determinant.

    Args:
        pts: tensor of shape (B, V, D) holding B simplices, each given by
            V vertices in D dimensions. It is cast to float32.

    Returns:
        A float32 tensor of shape (B,) with the squared volume of each
        (V-1)-simplex. The value is not clamped, so round-off can make it
        slightly negative for near-degenerate inputs.
    """
    pts = pts.float()

    # Pairwise squared Euclidean distances between vertices: (B, V, V).
    deltas = pts[..., :, None, :] - pts[..., None, :, :]
    sq_dist = (deltas * deltas).sum(dim=-1)
    n_batch, n_vert, _ = sq_dist.shape

    # Bordered distance matrix: zero corner, ones along the first row and
    # column, squared distances in the remaining V x V block.
    bordered = sq_dist.new_zeros(n_batch, n_vert + 1, n_vert + 1)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist

    # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(bordered)
    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    scale = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
|
|
def cv_loss(emb, target=0.2, n_samples=16):
    """Penalise deviation of the simplex-volume coefficient of variation.

    Draws ``n_samples`` random 5-point subsets of ``emb``, computes the
    volume of the simplex each subset spans, and returns the absolute
    difference between the coefficient of variation (std / mean) of those
    volumes and ``target``.

    Args:
        emb: 2-D tensor of embeddings, one row per point.
        target: desired coefficient of variation.
        n_samples: number of random 5-point subsets to draw.

    Returns:
        A scalar tensor. It is a constant zero on ``emb``'s device when
        fewer than five rows are available.
    """
    n_points = emb.shape[0]
    if n_points < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        # Five random rows form one 4-simplex.
        pick = torch.randperm(n_points, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))[0]
        # Clamp round-off negatives; the epsilon keeps sqrt differentiable at 0.
        return torch.sqrt(F.relu(vol_sq) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
|
|
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Estimate the coefficient of variation of random simplex volumes.

    Draws ``n_samples`` random 5-point subsets of ``emb``, computes the
    volume of the simplex each subset spans, and returns std / mean over
    the valid volumes.

    A sample is valid only when its clamped squared volume is strictly
    positive. The original check ran after ``+ 1e-12`` and ``sqrt``, so it
    could reject NaN but never a zero-volume (degenerate) simplex; the test
    is now applied to the squared volume before the epsilon is added.

    Args:
        emb: 2-D tensor of embeddings, one row per point.
        n_samples: number of random 5-point subsets to draw.

    Returns:
        The coefficient of variation as a Python float, or ``0.0`` when
        ``emb`` has fewer than five rows or fewer than ten samples are valid.
    """
    n_points = emb.shape[0]
    if n_points < 5:
        return 0.0

    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(n_points)[:5]
        vol_sq = F.relu(cayley_menger_vol2(emb[idx].unsqueeze(0))[0])
        # `not (x > 0)` rejects both non-positive and NaN squared volumes.
        if not (vol_sq.item() > 0):
            continue
        # Same epsilon as before, so valid samples keep their old value.
        vols.append(torch.sqrt(vol_sq + 1e-12).item())

    if len(vols) < 10:
        return 0.0
    a = torch.tensor(vols)
    return float(a.std() / (a.mean() + 1e-8))
|
|
def anchor_spread_loss(anchors):
    """Mean squared off-diagonal cosine similarity between anchors.

    Args:
        anchors: 2-D tensor with one anchor per row; rows are L2-normalised
            here before comparison.

    Returns:
        A scalar tensor. The mean is taken over every entry of the N x N
        similarity matrix, with the diagonal set to zero rather than
        excluded.
    """
    unit = F.normalize(anchors, dim=-1)
    gram = unit @ unit.T
    # Remove each anchor's similarity with itself.
    off_diag = gram - torch.diag(gram.diagonal())
    return off_diag.pow(2).mean()
|
|
def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy of the soft assignment of embeddings to anchors.

    Args:
        emb: 2-D tensor of embeddings, one per row. It is used as given
            and is not normalised here.
        anchors: 2-D tensor of anchors, one per row; L2-normalised here.
        sharpness: multiplier applied to the similarities before softmax.

    Returns:
        A scalar tensor: the entropy of each row's softmax over anchors,
        averaged across rows.
    """
    unit_anchors = F.normalize(anchors, dim=-1)
    scores = emb @ unit_anchors.T * sharpness
    probs = F.softmax(scores, dim=-1)
    # The epsilon keeps log finite when a probability underflows to zero.
    row_entropy = -(probs * (probs + 1e-12).log()).sum(-1)
    return row_entropy.mean()
|
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between two row-aligned batches.

    Row ``i`` of ``a`` is treated as the positive for row ``i`` of ``b``;
    every other row is a negative.

    Args:
        a: 2-D tensor, one vector per row.
        b: 2-D tensor with the same number of rows as ``a``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        ``(loss, acc)``. ``loss`` is the mean of the a-to-b and b-to-a
        cross-entropies as a scalar tensor. ``acc`` is a Python float: the
        fraction of rows of ``a`` whose highest-scoring row of ``b`` is the
        aligned one.
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    logits = (a_unit @ b_unit.T) / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)

    loss_ab = F.cross_entropy(logits, targets)
    loss_ba = F.cross_entropy(logits.T, targets)
    loss = (loss_ab + loss_ba) / 2

    with torch.no_grad():
        hits = logits.argmax(-1) == targets
        acc = hits.float().mean().item()
    return loss, acc
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in backward.

    Backward splits the incoming gradient into its component along the
    unit-normalised ``embedding`` (radial) and the remainder (tangential).
    The radial part is scaled by ``1 - tang``. When ``sep > 0``, the part
    of the result that has a positive projection onto the nearest
    normalised anchor is then reduced by ``sep`` times that projection.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        """Return ``x`` unchanged and stash what backward needs.

        Args:
            ctx: autograd context.
            x: tensor passed through as the output.
            embedding: tensor whose direction defines "radial" in backward.
            anchors: anchor matrix, one anchor per row.
            tang: fraction of the radial gradient component to remove.
            sep: strength of the nearest-anchor correction (0 disables it).
        """
        ctx.tang = tang
        ctx.sep = sep
        ctx.save_for_backward(embedding, anchors)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return the reshaped gradient for ``x``; other inputs get None."""
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()

        # Component of the gradient along the embedding direction.
        radial = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        shaped = (grad - radial) + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            # Nearest anchor by cosine similarity, per row.
            similarity = unit_emb @ unit_anchors.T
            closest = unit_anchors[similarity.argmax(dim=-1)]
            pull = (shaped * closest).sum(-1, keepdim=True)
            # Only a positive projection onto the nearest anchor is reduced.
            shaped = shaped - ctx.sep * (pull > 0).float() * pull * closest

        return shaped.to(grad_output.dtype), None, None, None, None
|
|
|
|
| |
| |
| |
|
|
class ExpertProjector(nn.Module):
    """Map one expert's features into the anchor space as unit vectors.

    A linear layer followed by LayerNorm; the result is L2-normalised
    along the last dimension.
    """

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        """Build the projection from ``d_in`` features to ``d_out``."""
        super().__init__()
        layers = [nn.Linear(d_in, d_out), nn.LayerNorm(d_out)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and return it with unit L2 norm on the last dim."""
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
|
|
|
|
class Constellation(nn.Module):
    """A learnable set of anchor vectors, initialised with unit norm."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        """Create ``n_anchors`` random anchors of width ``d``."""
        super().__init__()
        self.n_anchors = n_anchors
        initial = F.normalize(torch.randn(n_anchors, d), dim=-1)
        self.anchors = nn.Parameter(initial)

    def triangulate(self, emb):
        """Measure ``emb`` against every anchor.

        Args:
            emb: tensor whose last dimension matches the anchor width. It
                is used as given; only the anchors are re-normalised here.

        Returns:
            ``(distances, nearest)``. ``distances`` is ``1 - emb @ anchors.T``,
            one column per anchor. ``nearest`` is the index of the anchor
            with the largest similarity for each row.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        cosine = emb @ unit_anchors.T
        nearest = cosine.argmax(dim=-1)
        return 1.0 - cosine, nearest
|
|
|
|
class Patchwork(nn.Module):
    """Split anchor distances into compartments, each with a small MLP.

    Anchor ``i`` belongs to compartment ``i % n_comp``. Each compartment
    reads only its own anchors' columns and emits ``d_comp`` features.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        """Assign anchors round-robin and build one MLP per compartment."""
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)

        def _compartment(width):
            # Linear -> GELU -> Linear -> LayerNorm, widening to 2*d_comp.
            hidden = d_comp * 2
            return nn.Sequential(
                nn.Linear(width, hidden),
                nn.GELU(),
                nn.Linear(hidden, d_comp),
                nn.LayerNorm(d_comp),
            )

        widths = [int((asgn == k).sum()) for k in range(n_comp)]
        self.comps = nn.ModuleList([_compartment(w) for w in widths])

    def forward(self, tri):
        """Run every compartment on its columns and concatenate the results.

        Args:
            tri: 2-D tensor with one column per anchor.

        Returns:
            Tensor with ``n_comp * d_comp`` features on the last dimension.
        """
        pieces = []
        for k in range(self.n_comp):
            columns = tri[:, self.asgn == k]
            pieces.append(self.comps[k](columns))
        return torch.cat(pieces, -1)
|
|
|
|
class BaseTierSoup(nn.Module):
    """Fuse several expert embeddings on a shared unit hypersphere.

    Pipeline, per sample:
      1. each expert's features go through its own ``ExpertProjector``;
      2. the projections are averaged and re-normalised ("fused");
      3. in training mode the fused vector can pass through
         ``EmbeddingAutograd``, which only alters its gradient;
      4. the ``Constellation`` turns it into per-anchor distances;
      5. the ``Patchwork`` compresses those distances;
      6. a classifier reads the patchwork output concatenated with the
         fused vector and emits one logit per class.
    """

    def __init__(self, n_experts=3, d_expert=D_EXPERT, d_anchor=D_ANCHOR,
                 n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
                 n_classes=N_CLASSES):
        """Build projectors, constellation, patchwork and classifier."""
        super().__init__()
        self.n_experts = n_experts
        self.d_anchor = d_anchor

        # One independent projector per expert.
        projector_list = [ExpertProjector(d_expert, d_anchor)
                          for _ in range(n_experts)]
        self.projectors = nn.ModuleList(projector_list)

        # Shared anchor geometry and its compartmentalised reader.
        self.constellation = Constellation(n_anchors, d_anchor)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)

        # Classifier input = patchwork features + the fused embedding.
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_anchor, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, n_classes),
        )

    def forward(self, expert_embeddings, apply_autograd=True):
        """Classify a batch from its per-expert feature tensors.

        Args:
            expert_embeddings: indexable sequence with one feature tensor
                per expert, in projector order.
            apply_autograd: when true and the module is in training mode,
                route the fused embedding through ``EmbeddingAutograd``.

        Returns:
            ``(logits, fused, tri, nearest, projected)``: class logits, the
            fused unit embedding, per-anchor distances, nearest-anchor
            indices, and the list of per-expert projections.
        """
        projected = []
        for i in range(self.n_experts):
            projected.append(self.projectors[i](expert_embeddings[i]))

        # Average the experts, then put the result back on the unit sphere.
        mean_projection = sum(projected) / self.n_experts
        fused = F.normalize(mean_projection, dim=-1)

        if apply_autograd and self.training:
            # Forward is the identity; only the gradient is reshaped.
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)

        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)

        classifier_input = torch.cat([pw, fused], -1)
        logits = self.classifier(classifier_input)
        return logits, fused, tri, nearest, projected

    def count_params(self):
        """Return parameter counts per component plus their total."""

        def _numel(module):
            return sum(p.numel() for p in module.parameters())

        proj = sum(_numel(pr) for pr in self.projectors)
        const = _numel(self.constellation)
        pw = _numel(self.patchwork)
        cls = _numel(self.classifier)
        return {
            "projectors": proj,
            "constellation": const,
            "patchwork": pw,
            "classifier": cls,
            "total": proj + const + pw + cls,
        }
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("LOADING DATA")
print(f"{'='*65}")

from datasets import load_dataset

_FEATURE_REPO = "AbstractPhil/bulk-coco-features"


def _multi_hot(label_lists, n_rows):
    """Build an (n_rows, N_CLASSES) 0/1 matrix from per-image label lists.

    Labels that are not below N_CLASSES are ignored.
    """
    matrix = torch.zeros(n_rows, N_CLASSES)
    for row, labs in enumerate(label_lists):
        for lab in labs:
            if lab < N_CLASSES:
                matrix[row, lab] = 1.0
    return matrix


def _aligned_features(expert, split, id_map, n_rows):
    """Load one expert's features for a split, ordered by ``id_map``.

    Rows whose image id is missing from ``id_map`` are skipped; ids in
    ``id_map`` that never appear in the dataset keep an all-zero row.
    """
    table = torch.zeros(n_rows, D_EXPERT)
    for record in load_dataset(_FEATURE_REPO, expert, split=split):
        slot = id_map.get(record["image_id"])
        if slot is not None:
            table[slot] = torch.tensor(record["features"], dtype=torch.float32)
    return table


# The first expert's dataset defines the image order and the labels.
ref = load_dataset(_FEATURE_REPO, EXPERTS[0], split="train")
train_ids = ref["image_id"]
N_train = len(train_ids)
train_id_map = {iid: pos for pos, iid in enumerate(train_ids)}
train_labels_raw = ref["labels"]
train_label_matrix = _multi_hot(train_labels_raw, N_train)

ref_val = load_dataset(_FEATURE_REPO, EXPERTS[0], split="val")
val_ids = ref_val["image_id"]
N_val = len(val_ids)
val_id_map = {iid: pos for pos, iid in enumerate(val_ids)}
val_labels_raw = ref_val["labels"]
val_label_matrix = _multi_hot(val_labels_raw, N_val)

print(f" Train: {N_train:,} Val: {N_val:,}")

# Per-expert feature tables, row-aligned with the reference image order.
train_feats = []
val_feats = []
for name in EXPERTS:
    train_feats.append(_aligned_features(name, "train", train_id_map, N_train))
    val_feats.append(_aligned_features(name, "val", val_id_map, N_val))
    print(f" {name:<30} loaded", flush=True)
    # The datasets are local to the helper; collect them before the next one.
    gc.collect()

# Validation features and both label matrices are moved to the device once;
# training features stay on the host and are moved per batch.
val_feats_gpu = [f.to(DEVICE) for f in val_feats]
val_labels_gpu = val_label_matrix.to(DEVICE)
train_labels_gpu = train_label_matrix.to(DEVICE)
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("BUILDING MODEL")
print(f"{'='*65}")

# Instantiate the model from the module-level configuration, then move it
# to the selected device.
_model_kwargs = {
    "n_experts": 3,
    "d_expert": D_EXPERT,
    "d_anchor": D_ANCHOR,
    "n_anchors": N_ANCHORS,
    "n_comp": N_COMP,
    "d_comp": D_COMP,
    "n_classes": N_CLASSES,
}
model = BaseTierSoup(**_model_kwargs).to(DEVICE)

# Report the parameter count of each component and the total.
params = model.count_params()
print(" Parameters:")
for k, v in params.items():
    print(f" {k:<15}: {v:>10,}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(" Adam, no weight decay (geometry IS the regularization)")
print(f"{'='*65}")

from torch.utils.tensorboard import SummaryWriter

# Plain Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Best validation mAP seen so far; drives checkpointing.
best_mAP = 0.0

# Checkpoint directory and TensorBoard log directory.
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/base_tier_soup")

# Global optimisation-step counter used as the x-axis for training scalars.
gs = 0
|
|
# Number of expert streams the model consumes, and every unordered pair of
# them (used by the agreement loss). With three experts this is three pairs.
n_experts = model.n_experts
expert_pairs = [(a, b) for a in range(n_experts)
                for b in range(a + 1, n_experts)]

for epoch in range(EPOCHS):
    # ------------------------------------------------------------- train
    model.train()
    perm = torch.randperm(N_train)
    tl, nb = 0.0, 0

    for i in range(0, N_train, BATCH):
        idx = perm[i:i + BATCH]
        if len(idx) < 4:
            continue

        # Training features live on the host; move only this batch.
        batch_experts = [train_feats[e][idx].to(DEVICE)
                         for e in range(n_experts)]
        labels = train_labels_gpu[idx]

        logits, fused, tri, nearest, projected = model(batch_experts)
        anchors = model.constellation.anchors

        # Multi-label classification loss.
        l_cls = F.binary_cross_entropy_with_logits(logits, labels)

        # Geometric regularisers on the fused embedding and the anchors.
        l_cv = cv_loss(fused, target=0.2)
        l_spread = anchor_spread_loss(anchors)
        l_ent = anchor_entropy_loss(fused, anchors)

        # Mean (1 - cosine) over all expert pairs.
        l_agree = 0.0
        for pi, pj in expert_pairs:
            l_agree += (1.0 - F.cosine_similarity(
                projected[pi], projected[pj], dim=-1)).mean()
        l_agree = l_agree / max(len(expert_pairs), 1)

        loss = (l_cls
                + 0.001 * l_cv
                + 1e-3 * l_spread
                + 1e-4 * l_ent
                + 0.1 * l_agree)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        batch_loss = loss.item()
        tl += batch_loss
        nb += 1
        gs += 1

        if gs % 100 == 0:
            writer.add_scalar("train/loss", batch_loss, gs)
            writer.add_scalar("train/cls", l_cls.item(), gs)
            writer.add_scalar("train/cv", l_cv.item(), gs)
            # Log a plain float, like the other scalars, rather than a
            # tensor that is still attached to the autograd graph.
            writer.add_scalar("train/agree", float(l_agree), gs)

    # -------------------------------------------------------- validation
    model.eval()
    all_lo, all_em = [], []
    all_proj = [[] for _ in range(n_experts)]
    with torch.no_grad():
        # One pass collects logits, fused embeddings and per-expert
        # projections; the model is deterministic in eval mode, so a second
        # pass for the projections would only repeat the same work.
        for j in range(0, N_val, BATCH):
            end = min(j + BATCH, N_val)
            batch_v = [val_feats_gpu[e][j:end] for e in range(n_experts)]
            lo, em, _, _, proj = model(batch_v, apply_autograd=False)
            all_lo.append(lo.cpu())
            all_em.append(em.cpu())
            for e, p in enumerate(proj):
                all_proj[e].append(p.cpu())
    v_lo = torch.cat(all_lo)
    v_em = torch.cat(all_em)
    proj_stacked = [torch.cat(chunks) for chunks in all_proj]

    # Mean average precision over classes with at least one positive.
    v_lab = val_label_matrix
    ap_sum, nv = 0, 0
    for c in range(N_CLASSES):
        if v_lab[:, c].sum() > 0:
            si = v_lo[:, c].argsort(descending=True)
            st = v_lab[:, c][si]
            # Precision at every rank, counted only where a positive sits.
            pak = st.cumsum(0) / torch.arange(1, len(st) + 1).float()
            ap_sum += (pak * st).sum().item() / st.sum().item()
            nv += 1
    mAP = ap_sum / max(nv, 1)

    # F1 at a 0.5 probability threshold, averaged over classes whose F1 is
    # positive.
    vp = (v_lo.sigmoid() > 0.5).float()
    tp = (vp * v_lab).sum(0)
    fp = (vp * (1 - v_lab)).sum(0)
    fn = ((1 - vp) * v_lab).sum(0)
    pr = tp / (tp + fp + 1e-8)
    rc = tp / (tp + fn + 1e-8)
    f1 = 2 * pr * rc / (pr + rc + 1e-8)
    positive_f1 = f1[f1 > 0]
    # Mean of an empty tensor is NaN; report 0.0 when no class scores.
    macro_f1 = positive_f1.mean().item() if positive_f1.numel() > 0 else 0.0

    v_cv = cv_metric(v_em)

    # Mean cosine agreement between expert projections. The three tags
    # below are specific to the three experts named in EXPERTS.
    agree_01 = F.cosine_similarity(
        proj_stacked[0], proj_stacked[1], dim=-1).mean().item()
    agree_02 = F.cosine_similarity(
        proj_stacked[0], proj_stacked[2], dim=-1).mean().item()
    agree_12 = F.cosine_similarity(
        proj_stacked[1], proj_stacked[2], dim=-1).mean().item()

    writer.add_scalar("val/mAP", mAP, epoch + 1)
    writer.add_scalar("val/F1", macro_f1, epoch + 1)
    writer.add_scalar("val/cv", v_cv, epoch + 1)
    writer.add_scalar("val/agree_clip_dino", agree_01, epoch + 1)
    writer.add_scalar("val/agree_clip_siglip", agree_02, epoch + 1)
    writer.add_scalar("val/agree_dino_siglip", agree_12, epoch + 1)

    # Keep only the checkpoint with the best validation mAP.
    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({
            "state_dict": model.state_dict(),
            "config": {"d_expert": D_EXPERT, "d_anchor": D_ANCHOR,
                       "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                       "d_comp": D_COMP, "n_classes": N_CLASSES,
                       "experts": EXPERTS},
            "epoch": epoch + 1, "mAP": mAP, "cv": v_cv,
        }, "checkpoints/base_tier_best.pt")
        mk = " ★"

    # max(nb, 1) avoids a ZeroDivisionError if no batch ran this epoch.
    print(f" E{epoch+1:2d}: mAP={mAP:.3f} F1={macro_f1:.3f} cv={v_cv:.4f} "
          f"agree=[{agree_01:.3f},{agree_02:.3f},{agree_12:.3f}] "
          f"loss={tl/max(nb, 1):.4f}{mk}")
|
|
# Flush and close the TensorBoard writer before printing the summary.
writer.close()

# Final summary, emitted as one write; the text matches the original
# sequence of prints line for line.
print("\n".join([
    f"\n Best mAP: {best_mAP:.3f}",
    f" Model: {params['total']:,} params",
    f" Anchors: {N_ANCHORS} × {D_ANCHOR}-d",
    f"\n{'='*65}",
    "DONE",
    f"{'='*65}",
]))