#!/usr/bin/env python3
"""
GEOLIP MASSIVE SOUP — ORTHO SPECTRUM HYPERSPHERE
==================================================
2048 anchors × 256-d × 3 expert perspectives.

Orthogonal initialization: 8 rotated orthogonal bases of 256 vectors = 2048.
Each base is a random rotation of the first, so the anchors inside one base
are mutually orthogonal while the different bases cover S^255 from different
orientations. Together they form a structured mesh on the sphere.

Multi-depth patchwork:
  Level 0 (coarse):  16 compartments × 128 anchors × 3 experts = 384 inputs each → 128-d
  Level 1 (fine):    64 compartments ×  32 anchors × 3 experts =  96 inputs each →  64-d
  Level 2 (micro):  128 compartments ×  16 anchors × 3 experts =  48 inputs each →  32-d

  Total patchwork output: 16×128 + 64×64 + 128×32 = 2048 + 4096 + 4096 = 10240-d
  → projected down to 1024 before the classifier.

The depth levels read the sphere at different resolutions: coarse catches
global position, fine catches local neighborhood, micro catches sub-anchor
structure. Each level has its own expert-aware triangulation view.
(Compartments are strided anchor-index groups — anchor i goes to compartment
i % n_comp — rather than spatially clustered anchors.)

Pipeline: GPA → PCA 256-d → Procrustes calibration → train with the full loss stack.
"""
import gc
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Geometry
D_EXPERT = 768
D_ANCHOR = 256
N_ANCHORS = 2048
N_ORTHO_BASES = 8        # 8 × 256 = 2048
N_EXPERTS_COUNT = 3
N_CLASSES = 80
ANCHOR_DROP = 0.30

# Multi-depth patchwork
COARSE_COMP = 16         # 2048/16  = 128 anchors per comp
FINE_COMP = 64           # 2048/64  = 32 anchors per comp
MICRO_COMP = 128         # 2048/128 = 16 anchors per comp
D_COARSE = 128
D_FINE = 64
D_MICRO = 32
D_PW_PROJ = 1024         # project combined patchwork to this

# Training
BATCH = 128
EPOCHS = 30
LR = 1e-3
QUEUE_SIZE = 4096
GRAD_CLIP = 1.0

EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
TRI_DIM = N_ANCHORS * N_EXPERTS_COUNT

print("=" * 65)
print("GEOLIP MASSIVE SOUP — ORTHO SPECTRUM")
print(f" {N_ANCHORS} anchors × {D_ANCHOR}-d × {N_EXPERTS_COUNT} perspectives")
print(f" Ortho bases: {N_ORTHO_BASES} × {D_ANCHOR} = {N_ANCHORS}")
print(f" Patchwork: coarse({COARSE_COMP}) + fine({FINE_COMP}) + micro({MICRO_COMP})")
print(f" Device: {DEVICE}")
print("=" * 65)


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley–Menger determinant.

    Args:
        pts: (B, V, D) tensor — B sets of V points; each set spans a
            (V-1)-simplex.

    Returns:
        (B,) float32 tensor of squared (V-1)-volumes. Round-off can make
        degenerate simplices slightly negative, so callers clamp with relu.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)                      # pairwise squared distances
    B, V, _ = d2.shape
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)


def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty |CV(pentachoron volumes) - target|.

    Draws `n_samples` random 5-point subsets of the batch, computes their
    4-simplex volumes, and penalizes the distance of their coefficient of
    variation (std / mean) from `target`. Returns 0 when the batch has fewer
    than 5 rows.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    return (stacked.std() / (stacked.mean() + 1e-8) - target).abs()


@torch.no_grad()
def cv_metric(emb, n_samples=500):
    """Non-differentiable coefficient of variation of random pentachoron volumes.

    Returns 0.0 when the batch has fewer than 5 rows or fewer than 10
    volumes were collected.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = np.array(vols)
    return float(a.std() / (a.mean() + 1e-8))


def infonce_queued(emb, targets, queue_emb, queue_tgt, temperature=0.07):
    """Symmetric InfoNCE between embeddings and targets, with queued negatives.

    Row b of `emb` is the positive for row b of `targets`. Every other batch
    row, plus the queue entries, acts as a negative in both directions.
    Queue entries are used as given (not re-normalized).

    Returns:
        (loss, acc) — acc is the emb→target top-1 accuracy as a Python float.
    """
    B = emb.shape[0]
    e = F.normalize(emb, dim=-1)
    t = F.normalize(targets, dim=-1)
    if queue_tgt is not None and queue_tgt.shape[0] > 0:
        at = torch.cat([t, queue_tgt], 0)
        ae = torch.cat([e, queue_emb], 0)
    else:
        at = t
        ae = e
    l_e2t = (e @ at.T) / temperature
    l_t2e = (t @ ae.T) / temperature
    labels = torch.arange(B, device=emb.device)   # positives are the first B columns
    loss = (F.cross_entropy(l_e2t, labels) + F.cross_entropy(l_t2e, labels)) / 2
    with torch.no_grad():
        acc = (l_e2t.argmax(-1) == labels).float().mean().item()
    return loss, acc


def whitened_procrustes_loss(emb, targets):
    """1 - mean cosine between batch-centered embeddings and batch-centered targets.

    Despite the name, no whitening is applied; only the batch means are
    removed. Returns 0 for batches with fewer than 10 rows.
    """
    B = emb.shape[0]
    if B < 10:
        return torch.tensor(0.0, device=emb.device)
    em = emb.float().mean(0, keepdim=True)
    tm = targets.float().mean(0, keepdim=True)
    return 1.0 - F.cosine_similarity(emb.float() - em, targets.float() - tm, dim=-1).mean()


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    forward(x, embedding, anchors, tang, sep) returns x unchanged. In the
    backward pass the incoming gradient is split into its radial part (along
    the normalized `embedding`) and its tangential part, and the radial part
    is scaled by (1 - tang). If sep > 0, a positive gradient component along
    each row's nearest anchor (by cosine) is scaled by (1 - sep). Only x
    receives a gradient; embedding, anchors, tang and sep get None.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            # Only damp components with a positive projection on the nearest anchor.
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        return corrected.to(grad_output.dtype), None, None, None, None


def symmetric_inv_sqrt(cov, eps=1e-6):
    """Symmetric inverse square root of a covariance matrix.

    Eigenvalues are clamped to at least `eps` before inverting, so
    rank-deficient covariances do not blow up.
    """
    evals, evecs = torch.linalg.eigh(cov)
    return evecs @ torch.diag(torch.clamp(evals, min=eps).rsqrt()) @ evecs.T


def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal Procrustes map from `source` to `target`.

    Uses only the first `n_align` rows (and no more than either tensor has).
    Both sides are centered, whitened and row-normalized, and the rotation
    comes from the SVD of the cross-covariance.

    Returns:
        dict with "rotation", "source_mean", "source_whitener" and
        "target_unwhitener", as consumed by `apply_align`.
    """
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float()
    T = target[:N].float()
    sm = S.mean(0, keepdim=True)
    tm = T.mean(0, keepdim=True)
    Sc = S - sm
    Tc = T - tm
    Ns = Sc.shape[0]
    sw = symmetric_inv_sqrt((Sc.T @ Sc) / max(Ns - 1, 1))
    tw = symmetric_inv_sqrt((Tc.T @ Tc) / max(Ns - 1, 1))
    Sc_w = F.normalize(Sc @ sw, dim=-1)
    Tc_w = F.normalize(Tc @ tw, dim=-1)
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    return {"rotation": U @ Vt, "source_mean": sm.squeeze(0),
            "source_whitener": sw, "target_unwhitener": torch.linalg.pinv(tw)}


def apply_align(emb, a):
    """Apply an alignment from `procrustes_align` to `emb`.

    Centers on the source mean, whitens, rotates and un-whitens into the
    target's frame. The target mean is not added back.
    """
    x = emb.float() - a["source_mean"]
    return x @ a["source_whitener"] @ a["rotation"].T @ a["target_unwhitener"]


def _gpa(feats, n_iter=20, tol=1e-8, verbose=False):
    """Generalized Procrustes analysis over row-aligned expert feature sets.

    Each iteration aligns every expert once (via `procrustes_align`) to the
    mean shape computed at the start of that iteration. It stops after
    `n_iter` iterations, or once the summed per-expert mean-squared update
    falls below `tol`.

    Args:
        feats: dict name → (N, D) tensor; iteration order is dict order.
        verbose: print the update size on the first and every 5th iteration.

    Returns:
        dict name → aligned (N, D) float tensor, in the same key order.
    """
    current = {name: x.float() for name, x in feats.items()}
    for it in range(n_iter):
        mean_shape = sum(current.values()) / len(current)
        delta = 0.0
        updated = {}
        for name, x in current.items():
            aligned = apply_align(x, procrustes_align(x, mean_shape))
            delta += (aligned - x).pow(2).mean().item()
            updated[name] = aligned
        current = updated
        if verbose and (it == 0 or (it + 1) % 5 == 0):
            print(f" GPA iter {it+1}: delta={delta:.8f}")
        if delta < tol:
            break
    return current


# ══════════════════════════════════════════════════════════════════
# ORTHO ANCHOR INITIALIZATION
# ══════════════════════════════════════════════════════════════════

def init_ortho_anchors(d, n_bases):
    """Generate n_bases × d unit anchors from randomly rotated orthonormal bases.

    The first base is the Q factor of a random Gaussian matrix. Each further
    base is that base multiplied by an independent random orthogonal matrix.
    Rows within one base are mutually orthogonal; different bases are
    independent random orientations.

    Returns:
        (n_bases * d, d) tensor of row-normalized anchors.
    """
    base = torch.randn(d, d)
    Q, _ = torch.linalg.qr(base)
    all_anchors = [Q]                     # d × d, orthonormal rows
    for _ in range(1, n_bases):
        R_q, _ = torch.linalg.qr(torch.randn(d, d))
        all_anchors.append(Q @ R_q.T)     # product of orthogonal matrices stays orthogonal
    anchors = torch.cat(all_anchors, dim=0)   # (n_bases*d, d)
    return F.normalize(anchors, dim=-1)


# ══════════════════════════════════════════════════════════════════
# FUSED CONSTELLATION (2048 anchors)
# ══════════════════════════════════════════════════════════════════

class FusedConstellation(nn.Module):
    """Learnable anchor set viewed through one affine "perspective" per expert.

    Each perspective i maps an embedding as
    normalize(((emb - mean_i) @ whitener_i) @ rotation_i.T). Triangulation
    measures 1 - cos from every perspective to every anchor.
    """

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR,
                 n_experts=N_EXPERTS_COUNT, drop_rate=ANCHOR_DROP):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_experts = n_experts
        self.drop_rate = drop_rate
        self.d = d
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d)) for _ in range(n_experts)])

    def _expert_view(self, emb, i):
        """Project `emb` into expert i's perspective (unit-normalized rows)."""
        centered = emb.float() - self.expert_means[i]
        whitened = centered @ self.expert_whiteners[i]
        return F.normalize(whitened @ self.expert_rotations[i].T, dim=-1)

    def triangulate(self, emb, training=False):
        """Distances 1 - cos from every expert view of `emb` to every anchor.

        With training=True and drop_rate > 0, a random subset of anchors
        (keeping at least 128) is used. Dropped anchors are filled with
        cos = -1, i.e. a distance of 2.

        Returns:
            tri_fused: (B, n_anchors * n_experts) flattened distances.
            nearest:   (B,) index of the anchor with the highest
                       expert-averaged cosine.
            tri_stacked: (B, n_anchors, n_experts) distances.
        """
        B = emb.shape[0]
        anchors_n = F.normalize(self.anchors, dim=-1)
        expert_embs = [self._expert_view(emb, i) for i in range(self.n_experts)]

        if training and self.drop_rate > 0:
            n_keep = max(int(self.n_anchors * (1 - self.drop_rate)), 128)
            keep_idx = torch.randperm(self.n_anchors, device=emb.device)[:n_keep]
            a_masked = anchors_n[keep_idx]
            expert_cos_list = []
            for rotated in expert_embs:
                cos = rotated @ a_masked.T
                full_cos = torch.full((B, self.n_anchors), -1.0,
                                      device=emb.device, dtype=cos.dtype)
                full_cos[:, keep_idx] = cos
                expert_cos_list.append(full_cos)
        else:
            expert_cos_list = [rotated @ anchors_n.T for rotated in expert_embs]

        expert_tris = [1.0 - cos for cos in expert_cos_list]
        tri_stacked = torch.stack(expert_tris, dim=-1)    # (B, n_anchors, n_experts)
        tri_fused = tri_stacked.reshape(B, -1)
        mean_cos = torch.stack(expert_cos_list, dim=-1).mean(dim=-1)
        nearest = mean_cos.argmax(dim=-1)
        return tri_fused, nearest, tri_stacked

    def anchor_spread_loss(self):
        """Mean squared off-diagonal cosine among 512 randomly sampled anchors."""
        # Sampled because a full n_anchors × n_anchors Gram matrix is large.
        a = F.normalize(self.anchors, dim=-1)
        idx = torch.randperm(self.n_anchors, device=a.device)[:512]
        a_sub = a[idx]
        sim = a_sub @ a_sub.T
        sim = sim - torch.diag(torch.diag(sim))
        return sim.pow(2).mean()

    def expert_agreement_loss(self, emb):
        """Penalize |mean cross-expert std of anchor cosines - 0.05|.

        Uses only the first 512 anchors, to save compute.
        """
        anchors_n = F.normalize(self.anchors[:512], dim=-1)
        expert_cos = [self._expert_view(emb, i) @ anchors_n.T
                      for i in range(self.n_experts)]
        stacked = torch.stack(expert_cos, dim=-1)
        disagree = stacked.std(dim=-1)
        return (disagree.mean() - 0.05).abs()


# ══════════════════════════════════════════════════════════════════
# MULTI-DEPTH PATCHWORK
# ══════════════════════════════════════════════════════════════════

class DepthLevel(nn.Module):
    """Single depth level of the patchwork — reads a specific granularity.

    Anchor i belongs to compartment i % n_comp. Each compartment has its own
    MLP over the distances of its anchors as seen by all experts.
    """

    def __init__(self, n_anchors, n_comp, n_experts, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.n_experts = n_experts
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        inputs_per_comp = (n_anchors // n_comp) * n_experts
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(inputs_per_comp, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for _ in range(n_comp)])

    def forward(self, tri_3d):
        """tri_3d: (B, n_anchors, n_experts) → (B, n_comp * d_comp)."""
        B, n_anchors, n_experts = tri_3d.shape
        if n_anchors % self.n_comp == 0:
            # asgn == arange % n_comp, so compartment k holds anchors
            # k, k + n_comp, k + 2*n_comp, ... in ascending order — exactly
            # slice [:, :, k] of this view. This avoids a boolean-mask gather
            # (and device sync) per compartment with identical element order.
            grouped = tri_3d.reshape(B, n_anchors // self.n_comp, self.n_comp, n_experts)
            inputs = [grouped[:, :, k, :].reshape(B, -1) for k in range(self.n_comp)]
        else:
            inputs = [tri_3d[:, self.asgn == k, :].reshape(B, -1)
                      for k in range(self.n_comp)]
        return torch.cat([comp(x) for comp, x in zip(self.comps, inputs)], dim=-1)


class MultiDepthPatchwork(nn.Module):
    """
    Reads the sphere at 3 resolutions:
      Coarse: 16 compartments, 128 anchors each — global position
      Fine:   64 compartments, 32 anchors each — local neighborhood
      Micro: 128 compartments, 16 anchors each — sub-anchor structure
    The combined output is projected to D_PW_PROJ.
    """

    def __init__(self):
        super().__init__()
        self.coarse = DepthLevel(N_ANCHORS, COARSE_COMP, N_EXPERTS_COUNT, D_COARSE)
        self.fine = DepthLevel(N_ANCHORS, FINE_COMP, N_EXPERTS_COUNT, D_FINE)
        self.micro = DepthLevel(N_ANCHORS, MICRO_COMP, N_EXPERTS_COUNT, D_MICRO)
        # Coarse: 16 × 128 = 2048, Fine: 64 × 64 = 4096, Micro: 128 × 32 = 4096
        total_dim = COARSE_COMP * D_COARSE + FINE_COMP * D_FINE + MICRO_COMP * D_MICRO
        self.proj = nn.Sequential(
            nn.Linear(total_dim, D_PW_PROJ), nn.GELU(), nn.LayerNorm(D_PW_PROJ))

    def forward(self, tri_3d):
        """tri_3d: (B, N_ANCHORS, N_EXPERTS_COUNT) → (B, D_PW_PROJ)."""
        combined = torch.cat(
            [self.coarse(tri_3d), self.fine(tri_3d), self.micro(tri_3d)], dim=-1)
        return self.proj(combined)


# ══════════════════════════════════════════════════════════════════
# EXPERT PROJECTOR + MODEL
# ══════════════════════════════════════════════════════════════════

class ExpertProjector(nn.Module):
    """Linear D_EXPERT → D_ANCHOR, then LayerNorm and L2 normalization."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(D_EXPERT, D_ANCHOR), nn.LayerNorm(D_ANCHOR))

    def forward(self, x):
        return F.normalize(self.proj(x), dim=-1)


class MassiveSoup(nn.Module):
    """Fuses expert features, triangulates against the constellation, and classifies."""

    def __init__(self):
        super().__init__()
        self.projectors = nn.ModuleList([ExpertProjector() for _ in range(N_EXPERTS_COUNT)])
        self.constellation = FusedConstellation()
        self.patchwork = MultiDepthPatchwork()
        self.classifier = nn.Sequential(
            nn.Linear(D_PW_PROJ + D_ANCHOR, D_PW_PROJ), nn.GELU(),
            nn.LayerNorm(D_PW_PROJ), nn.Dropout(0.1),
            nn.Linear(D_PW_PROJ, N_CLASSES))

    def forward(self, expert_features, apply_autograd=True):
        """
        Args:
            expert_features: sequence of N_EXPERTS_COUNT tensors, indexed in
                EXPERTS order; each is fed to the matching projector.
            apply_autograd: if True (and in training mode), route the fused
                embedding through EmbeddingAutograd for gradient shaping.

        Returns:
            (logits, fused, tri_fused, nearest, projected)
        """
        projected = [self.projectors[i](expert_features[i]) for i in range(N_EXPERTS_COUNT)]
        fused = F.normalize(sum(projected) / N_EXPERTS_COUNT, dim=-1)
        if apply_autograd and self.training:
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)
        tri_fused, nearest, tri_3d = self.constellation.triangulate(
            fused, training=self.training)
        pw = self.patchwork(tri_3d)
        logits = self.classifier(torch.cat([pw, fused], dim=-1))
        return logits, fused, tri_fused, nearest, projected


# ══════════════════════════════════════════════════════════════════
# LOAD DATA + GPA + CALIBRATE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 0: LOAD DATA")
print(f"{'='*65}")

from datasets import load_dataset


def _multi_hot(label_lists, n_rows):
    """Build an (n_rows, N_CLASSES) multi-hot matrix; labels >= N_CLASSES are ignored."""
    out = torch.zeros(n_rows, N_CLASSES)
    for i, labs in enumerate(label_lists):
        for lab in labs:
            if lab < N_CLASSES:
                out[i, lab] = 1.0
    return out


def _load_expert_split(name, split, id_map, n_rows):
    """Load one expert's features for `split`, row-aligned through `id_map`.

    Rows whose image_id never appears in the expert's dataset stay all-zero.

    Returns:
        (features (n_rows, D_EXPERT), number of rows left unfilled)
    """
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split)
    feats = torch.zeros(n_rows, D_EXPERT)
    filled = torch.zeros(n_rows, dtype=torch.bool)
    for row in ds:
        j = id_map.get(row["image_id"])
        if j is not None:
            feats[j] = torch.tensor(row["features"], dtype=torch.float32)
            filled[j] = True
    del ds
    return feats, int((~filled).sum())


ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="train")
train_ids = ref["image_id"]
N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_labels = _multi_hot(ref["labels"], N_train)

ref_val = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref_val["image_id"]
N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels = _multi_hot(ref_val["labels"], N_val)
print(f" Train: {N_train:,} Val: {N_val:,}")

train_raw, val_raw = {}, {}
for name in EXPERTS:
    train_raw[name], miss_tr = _load_expert_split(name, "train", train_id_map, N_train)
    val_raw[name], miss_va = _load_expert_split(name, "val", val_id_map, N_val)
    print(f" {name:<30} loaded")
    if miss_tr or miss_va:
        print(f"   WARNING: {miss_tr} train / {miss_va} val rows missing (left as zeros)")
    gc.collect()

# GPA
print(f"\n{'='*65}")
print("PHASE 1: GPA + PCA + PROCRUSTES")
print(f"{'='*65}")

current = _gpa({name: train_raw[name] for name in EXPERTS}, verbose=True)

consensus_768 = F.normalize(sum(current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
for name in EXPERTS:
    c = F.cosine_similarity(consensus_768[:5000], current[name][:5000], dim=-1).mean().item()
    print(f" cos(consensus, {name}): {c:.4f}")

# PCA → D_ANCHOR-d (components fitted on the first 10000 centered rows)
cc = consensus_768 - consensus_768.mean(0, keepdim=True)
_, S, Vt = torch.linalg.svd(cc[:10000], full_matrices=False)
pca_proj = Vt[:D_ANCHOR]
# NOTE(review): the components come from mean-centered data but are applied to
# the *uncentered* consensus (train and val alike), so the mean direction
# contributes to the targets. Kept as-is to preserve the existing targets.
consensus_d = F.normalize(consensus_768 @ pca_proj.T, dim=-1)
var_ret = S[:D_ANCHOR].pow(2).sum() / S.pow(2).sum()
print(f" PCA 768→{D_ANCHOR}: var_retained={var_ret.item():.4f}")

consensus_cv = cv_metric(consensus_d[:5000].to(DEVICE))
print(f" Consensus CV at {D_ANCHOR}-d: {consensus_cv:.4f}")

# Val consensus: an independent GPA on the val split, projected with the train PCA.
val_current = _gpa({name: val_raw[name] for name in EXPERTS})
val_consensus_768 = F.normalize(sum(val_current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
val_consensus_d = F.normalize(val_consensus_768 @ pca_proj.T, dim=-1)

# Per-expert Procrustes: raw D_EXPERT features → D_ANCHOR consensus targets.
expert_calibrations = {}
n_cal = min(10000, N_train)
cov_denom = max(n_cal - 1, 1)      # unbiased covariance; 9999 when n_cal == 10000
for name in EXPERTS:
    raw = train_raw[name][:n_cal].float()
    tgt = consensus_d[:n_cal].float()
    sm = raw.mean(0, keepdim=True)
    tm = tgt.mean(0, keepdim=True)
    sc = raw - sm
    tc = tgt - tm
    sw = symmetric_inv_sqrt((sc.T @ sc) / cov_denom)
    tw = symmetric_inv_sqrt((tc.T @ tc) / cov_denom)
    src_w = F.normalize(sc @ sw, dim=-1)
    tgt_w = F.normalize(tc @ tw, dim=-1)
    M = tgt_w.T @ src_w                            # (D_ANCHOR, D_EXPERT)
    U_r, _, Vt_r = torch.linalg.svd(M, full_matrices=False)
    R = U_r @ Vt_r                                 # (D_ANCHOR, D_EXPERT)
    proj_W = (sw @ R.T).T                          # nn.Linear(D_EXPERT, D_ANCHOR) weight
    proj_b = -(sm.squeeze(0) @ sw @ R.T).squeeze(0)
    test = F.normalize(raw[:1000] @ proj_W.T + proj_b, dim=-1)
    cos = F.cosine_similarity(test, tgt[:1000], dim=-1).mean().item()
    # NOTE(review): "R" keeps only the first D_ANCHOR columns of a
    # (D_ANCHOR, D_EXPERT) matrix, so it is generally not orthogonal.
    expert_calibrations[name] = {"W": proj_W, "b": proj_b, "cos": cos,
                                 "R": R[:D_ANCHOR, :D_ANCHOR],
                                 "whiten": tw, "mean": tm.squeeze(0)}
    print(f" {name:<30} cos={cos:.4f}")


# ══════════════════════════════════════════════════════════════════
# BUILD + INITIALIZE
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 2: BUILD MODEL")
print(f"{'='*65}")

model = MassiveSoup().to(DEVICE)

with torch.no_grad():
    # Projectors
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.projectors[i].proj[0].weight.copy_(cal["W"].to(DEVICE))
        model.projectors[i].proj[0].bias.copy_(cal["b"].to(DEVICE))
    print(f" ✓ Projectors from Procrustes")

    # Ortho anchors
    ortho_anchors = init_ortho_anchors(D_ANCHOR, N_ORTHO_BASES)
    model.constellation.anchors.copy_(ortho_anchors.to(DEVICE))
    print(f" ✓ {N_ANCHORS} ortho-spectrum anchors ({N_ORTHO_BASES} bases × {D_ANCHOR})")

    # Expert perspectives
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.constellation.expert_rotations[i].copy_(cal["R"].to(DEVICE))
        model.constellation.expert_whiteners[i].copy_(cal["whiten"].to(DEVICE))
        model.constellation.expert_means[i].copy_(cal["mean"].to(DEVICE))
    print(f" ✓ Expert perspectives calibrated")

# Verify. Eval mode so anchor dropout and classifier dropout do not randomize
# the reported cosine / active-anchor count (a fresh module is in train mode).
model.eval()
with torch.no_grad():
    test_in = [train_raw[name][:200].to(DEVICE) for name in EXPERTS]
    _, test_fused, _, test_nearest, _ = model(test_in, apply_autograd=False)
    test_tgt = consensus_d[:200].to(DEVICE)
    init_cos = F.cosine_similarity(test_fused, test_tgt, dim=-1).mean().item()
    n_active = test_nearest.unique().numel()
    print(f" Init: cos={init_cos:.4f} active_anchors={n_active}/{N_ANCHORS}")


# Count params
def count_params(module):
    """Total number of elements across all parameters of `module`."""
    return sum(p.numel() for p in module.parameters())


n_total = count_params(model)
print(f"\n Parameters:")
print(f" projectors: {sum(count_params(p) for p in model.projectors):>12,}")
print(f" constellation: {count_params(model.constellation):>12,}")
print(f" patchwork: {count_params(model.patchwork):>12,}")
print(f" classifier: {count_params(model.classifier):>12,}")
print(f" total: {n_total:>12,}")


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("PHASE 3: TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Queue: {QUEUE_SIZE} | Anchor dropout: {ANCHOR_DROP}")
print(f" CV target: {consensus_cv:.4f}")
print(f"{'='*65}")

train_targets = consensus_d.to(DEVICE)
val_targets = val_consensus_d.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/massive_soup")

best_mAP = 0.0
gs = 0
# FIFO queue of recent (embedding, target) pairs used as extra InfoNCE negatives.
queue_e = torch.zeros(0, D_ANCHOR, device=DEVICE)
queue_t = torch.zeros(0, D_ANCHOR, device=DEVICE)

for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train)
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0,
           "spread": 0, "agree": 0, "align": 0, "nce_acc": 0, "n": 0}

    pbar = tqdm(range(0, N_train, BATCH), desc=f"E{epoch+1:2d}/{EPOCHS}", unit="batch")
    for i in pbar:
        idx = perm[i:i + BATCH]
        if len(idx) < 4:
            continue
        batch = [train_raw[name][idx].to(DEVICE) for name in EXPERTS]
        labels = train_labels_gpu[idx]
        targets = train_targets[idx]

        logits, fused, tri, nearest, projected = model(batch)

        # InfoNCE against the queue as it stood before this batch, then enqueue.
        l_nce, nce_acc = infonce_queued(fused, targets, queue_e, queue_t)
        with torch.no_grad():
            queue_e = torch.cat([queue_e, fused.detach()], 0)[-QUEUE_SIZE:]
            queue_t = torch.cat([queue_t, targets.detach()], 0)[-QUEUE_SIZE:]

        l_mse = F.mse_loss(fused, targets)
        l_bce = F.binary_cross_entropy_with_logits(logits, labels)
        l_align = whitened_procrustes_loss(fused, targets)
        l_cv = cv_loss(fused, target=consensus_cv)
        l_spread = model.constellation.anchor_spread_loss()
        l_agree = model.constellation.expert_agreement_loss(fused)

        loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce + 0.5 * l_align
                + 0.001 * l_cv + 1e-3 * l_spread + 0.1 * l_agree)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        acc["loss"] += loss.item()
        acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item()
        acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item()
        acc["spread"] += l_spread.item()
        acc["agree"] += l_agree.item()
        acc["align"] += l_align.item()
        acc["nce_acc"] += nce_acc
        acc["n"] += 1
        gs += 1

        if gs % 50 == 0:
            for k in ["loss", "nce", "bce", "nce_acc"]:
                writer.add_scalar(f"step/{k}", acc[k] / max(acc["n"], 1), gs)
        if acc["n"] % 20 == 0:
            d = acc["n"]
            pbar.set_postfix(loss=f"{acc['loss']/d:.4f}",
                             nce_acc=f"{acc['nce_acc']/d:.3f}",
                             cos=f"{1-acc['align']/d:.3f}")

    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"bce={acc['bce']/d:.4f} agree={acc['agree']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")

    # Validation (entirely under no_grad: nothing here needs an autograd graph)
    model.eval()
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j + BATCH, N_val)
            batch_v = [val_raw[name][j:end].to(DEVICE) for name in EXPERTS]
            lo, em, _, _, _ = model(batch_v, apply_autograd=False)
            all_lo.append(lo.cpu())
            all_em.append(em.cpu())
        v_lo = torch.cat(all_lo)
        v_em = torch.cat(all_em)
        v_lab = val_labels

        # mAP over classes that have at least one positive.
        ap_sum, nv = 0, 0
        for c in range(N_CLASSES):
            if v_lab[:, c].sum() > 0:
                si = v_lo[:, c].argsort(descending=True)
                st = v_lab[:, c][si]
                pak = st.cumsum(0) / torch.arange(1, len(st) + 1).float()
                ap_sum += (pak * st).sum().item() / st.sum().item()
                nv += 1
        mAP = ap_sum / max(nv, 1)

        vp = (v_lo.sigmoid() > 0.5).float()
        tp = (vp * v_lab).sum(0)
        fp = (vp * (1 - v_lab)).sum(0)
        fn = ((1 - vp) * v_lab).sum(0)
        pr_ = tp / (tp + fp + 1e-8)
        rc_ = tp / (tp + fn + 1e-8)
        f1_ = 2 * pr_ * rc_ / (pr_ + rc_ + 1e-8)
        pos_f1 = f1_[f1_ > 0]
        f1_mean = pos_f1.mean().item() if pos_f1.numel() > 0 else 0.0  # avoid nan

        v_cos = F.cosine_similarity(v_em, val_targets.cpu(), dim=-1).mean().item()
        sim = v_em @ val_targets.cpu().T
        r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()
        _, v_nearest, _ = model.constellation.triangulate(v_em.to(DEVICE), training=False)
        n_active = v_nearest.cpu().unique().numel()
        v_cv = cv_metric(v_em[:2000].to(DEVICE))

    for k in acc:
        if k != "n":
            writer.add_scalar(f"epoch/{k}", acc[k] / d, epoch + 1)
    writer.add_scalar("val/mAP", mAP, epoch + 1)
    writer.add_scalar("val/cos", v_cos, epoch + 1)
    writer.add_scalar("val/R@1", r1, epoch + 1)
    writer.add_scalar("val/anchors", n_active, epoch + 1)
    writer.add_scalar("val/cv", v_cv, epoch + 1)

    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({"state_dict": model.state_dict(),
                    "config": {"d_anchor": D_ANCHOR, "n_anchors": N_ANCHORS,
                               "n_ortho_bases": N_ORTHO_BASES,
                               "n_experts": N_EXPERTS_COUNT,
                               "coarse_comp": COARSE_COMP, "fine_comp": FINE_COMP,
                               "micro_comp": MICRO_COMP,
                               "d_coarse": D_COARSE, "d_fine": D_FINE,
                               "d_micro": D_MICRO, "d_pw_proj": D_PW_PROJ,
                               "anchor_drop": ANCHOR_DROP, "experts": EXPERTS,
                               "cv_target": consensus_cv},
                    "pca_proj": pca_proj, "consensus_cv": consensus_cv,
                    "mAP": mAP, "r1": r1, "cos": v_cos, "cv": v_cv,
                    "epoch": epoch + 1, "n_active": n_active},
                   "checkpoints/massive_soup_best.pt")
        mk = " ★"

    torch.save({"state_dict": model.state_dict(), "epoch": epoch + 1, "mAP": mAP,
                "optimizer": optimizer.state_dict(), "gs": gs},
               f"checkpoints/massive_soup_e{epoch+1:02d}.pt")

    print(f" E{epoch+1} val: mAP={mAP:.3f} F1={f1_mean:.3f} "
          f"R@1={r1:.3f} cos={v_cos:.3f} cv={v_cv:.4f} "
          f"anchors={n_active}/{N_ANCHORS}{mk}")

writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Total: {n_total:,} params")
print(f"\n{'='*65}\nDONE\n{'='*65}")