| |
| """ |
| GEOLIP MASSIVE SOUP — ORTHO SPECTRUM HYPERSPHERE |
| ================================================== |
| 2048 anchors × 256-d × 3 expert perspectives. |
| |
| Orthogonal initialization: 8 rotated orthogonal bases of 256 vectors = 2048. |
| Each base tiles a different region of S^255. Together they form |
| a structured mesh with known geometric relationships. |
| |
| Multi-depth patchwork: |
| Level 0 (coarse): 16 compartments × 128 anchors × 3 experts = 384 inputs each → 128-d |
| Level 1 (fine): 64 compartments × 32 anchors × 3 experts = 96 inputs each → 64-d |
| Level 2 (micro): 128 compartments × 16 anchors × 3 experts = 48 inputs each → 32-d |
| |
| Total patchwork output: 16×128 + 64×64 + 128×32 = 2048 + 4096 + 4096 = 10240-d |
| → project down to 1024 before classifier |
| |
| The depth levels read the sphere at different resolutions. Coarse catches |
| global position, fine catches local neighborhood, micro catches sub-anchor |
| structure. Each level has its own expert-aware triangulation view. |
| |
| GPA → PCA 256-d → Procrustes calibration → train with full loss stack. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import os |
| import gc |
| from tqdm import tqdm |
|
|
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Geometry / model dimensions -------------------------------------------
D_EXPERT = 768          # width of each raw expert feature row (input of ExpertProjector)
D_ANCHOR = 256          # width of the shared anchor space (output of ExpertProjector)
N_ANCHORS = 2048        # number of anchors; the banner below shows it as N_ORTHO_BASES x D_ANCHOR
N_ORTHO_BASES = 8       # number of orthogonal bases stacked by init_ortho_anchors
N_EXPERTS_COUNT = 3     # number of expert perspectives; must match len(EXPERTS)
N_CLASSES = 80          # width of the multi-hot label vectors and of the classifier output
ANCHOR_DROP = 0.30      # default drop_rate of FusedConstellation (training-time anchor dropout)


# --- Patchwork depth levels: compartment counts and per-compartment widths ---
COARSE_COMP = 16
FINE_COMP = 64
MICRO_COMP = 128
D_COARSE = 128
D_FINE = 64
D_MICRO = 32
D_PW_PROJ = 1024        # width the concatenated patchwork output is projected to


# --- Optimisation ----------------------------------------------------------
BATCH = 128
EPOCHS = 30
LR = 1e-3
QUEUE_SIZE = 4096       # maximum number of rows kept in the InfoNCE queue
GRAD_CLIP = 1.0         # max norm passed to clip_grad_norm_


# Dataset config names, one per expert; order defines the expert index.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
# Flattened triangulation width (anchors x experts).
TRI_DIM = N_ANCHORS * N_EXPERTS_COUNT


# Startup banner summarising the configuration.
print("=" * 65)
print("GEOLIP MASSIVE SOUP — ORTHO SPECTRUM")
print(f" {N_ANCHORS} anchors × {D_ANCHOR}-d × {N_EXPERTS_COUNT} perspectives")
print(f" Ortho bases: {N_ORTHO_BASES} × {D_ANCHOR} = {N_ANCHORS}")
print(f" Patchwork: coarse({COARSE_COMP}) + fine({FINE_COMP}) + micro({MICRO_COMP})")
print(f" Device: {DEVICE}")
print("=" * 65)
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume per batch entry via the Cayley-Menger determinant.

    Args:
        pts: tensor of shape (B, V, D) holding B simplices with V vertices
            each. Exactly one batch dimension is required because the
            pairwise-distance tensor is unpacked as three axes.

    Returns:
        Float32 tensor of shape (B,) with the squared (V-1)-dimensional
        volume of every simplex.
    """
    points = pts.float()
    # Pairwise squared Euclidean distances between vertices: (B, V, V).
    delta = points.unsqueeze(-2) - points.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)

    n_batch, n_vert, _ = sq_dist.shape
    size = n_vert + 1

    # Bordered distance matrix: zero corner, ones along the first row and
    # column, squared distances in the remaining block.
    bordered = torch.zeros(n_batch, size, size,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist

    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    scale = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
|
|
def cv_loss(emb, target=0.2, n_samples=16):
    """Penalise the gap between the coefficient of variation of random simplex
    volumes and ``target``.

    Draws ``n_samples`` random groups of 5 rows from ``emb``, measures each
    group's simplex volume with ``cayley_menger_vol2`` and returns
    ``|std / mean - target|`` over those volumes.

    Args:
        emb: 2-D tensor of embeddings, one row per sample.
        target: desired coefficient of variation.
        n_samples: number of random 5-row groups to draw.

    Returns:
        Scalar tensor; a constant 0.0 when fewer than 5 rows are available.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        picked = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[picked].unsqueeze(0))
        # relu guards against slightly negative determinants; the epsilon
        # keeps sqrt differentiable at zero.
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
|
|
@torch.no_grad()
def cv_metric(emb, n_samples=500):
    """Coefficient of variation of random 5-point simplex volumes (no gradients).

    Args:
        emb: 2-D tensor of embeddings, one row per sample.
        n_samples: number of random 5-row groups to measure.

    Returns:
        Python float ``std / mean`` of the sampled volumes, or 0.0 when there
        are fewer than 5 rows or fewer than 10 positive volumes.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    def _sample_volume():
        picked = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[picked].unsqueeze(0))
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()

    sampled = (_sample_volume() for _ in range(n_samples))
    volumes = [volume for volume in sampled if volume > 0]
    if len(volumes) < 10:
        return 0.0

    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
|
|
def infonce_queued(emb, targets, queue_emb, queue_tgt, temperature=0.07):
    """Symmetric InfoNCE between embeddings and targets with optional queued negatives.

    Row ``i`` of ``emb`` is the positive for row ``i`` of ``targets``. When a
    non-empty queue is supplied, its rows are appended after the batch rows
    as extra negatives on both sides; queue rows are used as given (they are
    not re-normalised here).

    Args:
        emb: (B, D) embeddings.
        targets: (B, D) target embeddings.
        queue_emb: extra embedding rows, or anything when ``queue_tgt`` is
            ``None``/empty (it is ignored in that case).
        queue_tgt: extra target rows, ``None`` or an empty tensor to disable.
        temperature: softmax temperature dividing the cosine logits.

    Returns:
        ``(loss, acc)`` — the mean of both cross-entropy directions as a
        tensor, and the embedding-to-target top-1 accuracy as a float.
    """
    batch_size = emb.shape[0]
    emb_unit = F.normalize(emb, dim=-1)
    tgt_unit = F.normalize(targets, dim=-1)

    has_queue = queue_tgt is not None and queue_tgt.shape[0] > 0
    if has_queue:
        tgt_bank = torch.cat([tgt_unit, queue_tgt], 0)
        emb_bank = torch.cat([emb_unit, queue_emb], 0)
    else:
        tgt_bank, emb_bank = tgt_unit, emb_unit

    logits_e2t = (emb_unit @ tgt_bank.T) / temperature
    logits_t2e = (tgt_unit @ emb_bank.T) / temperature

    # Batch rows come first in both banks, so the positive of row i is column i.
    positives = torch.arange(batch_size, device=emb.device)
    loss_e2t = F.cross_entropy(logits_e2t, positives)
    loss_t2e = F.cross_entropy(logits_t2e, positives)
    loss = (loss_e2t + loss_t2e) / 2

    with torch.no_grad():
        hits = logits_e2t.argmax(-1) == positives
        acc = hits.float().mean().item()
    return loss, acc
|
|
def whitened_procrustes_loss(emb, targets):
    """One minus the mean cosine similarity of batch-centred rows.

    Both inputs are cast to float32 and centred on their own batch mean
    before the row-wise cosine similarity is taken. Despite the name, no
    whitening is applied inside this function.

    Args:
        emb: (B, D) embeddings.
        targets: (B, D) targets, row-aligned with ``emb``.

    Returns:
        Scalar tensor; a constant 0.0 when the batch has fewer than 10 rows.
    """
    if emb.shape[0] < 10:
        return torch.tensor(0.0, device=emb.device)

    emb_f = emb.float()
    tgt_f = targets.float()
    emb_centered = emb_f - emb_f.mean(0, keepdim=True)
    tgt_centered = tgt_f - tgt_f.mean(0, keepdim=True)
    similarity = F.cosine_similarity(emb_centered, tgt_centered, dim=-1)
    return 1.0 - similarity.mean()
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    ``apply(x, embedding, anchors, tang, sep)`` returns ``x`` unchanged. On
    the way back the gradient w.r.t. ``x`` is modified in two steps:

    1. its component along the (normalised) embedding direction is scaled by
       ``1 - tang`` while the tangential component is kept as is;
    2. when ``sep > 0``, any positive component along each row's nearest
       (cosine) anchor is reduced by the factor ``sep``.

    Only ``x`` receives a gradient; the other four inputs get ``None``.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.tang = tang
        ctx.sep = sep
        ctx.save_for_backward(embedding, anchors)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()

        # Split the gradient into radial (along the embedding) and tangential parts.
        radial = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        tangential = grad - radial
        shaped = tangential + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            closest = unit_anchors[(unit_emb @ unit_anchors.T).argmax(dim=-1)]
            pull = (shaped * closest).sum(-1, keepdim=True)
            # Only the part pointing towards the nearest anchor is damped.
            shaped = shaped - ctx.sep * (pull > 0).float() * pull * closest

        return shaped.to(grad_output.dtype), None, None, None, None
|
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric matrix via eigendecomposition.

    Args:
        cov: square symmetric 2-D tensor.
        eps: floor applied to the eigenvalues before taking ``rsqrt``, which
            keeps the result finite for singular or slightly indefinite input.

    Returns:
        Tensor of the same shape as ``cov``.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    inv_root = eigenvalues.clamp(min=eps).rsqrt()
    return eigenvectors @ torch.diag(inv_root) @ eigenvectors.T
|
|
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal Procrustes map from ``source`` to ``target``.

    Only the first ``min(n_align, len(source), len(target))`` rows are used.
    Both sides are centred, whitened with ``symmetric_inv_sqrt`` of their
    covariance and row-normalised; the rotation is then taken from the SVD
    of the cross-covariance of the whitened rows.

    Args:
        source: (N, D_s) tensor to be aligned.
        target: (M, D_t) reference tensor, row-aligned with ``source``.
        n_align: maximum number of leading rows used for the fit.

    Returns:
        Dict with keys ``"rotation"``, ``"source_mean"``,
        ``"source_whitener"`` and ``"target_unwhitener"`` as consumed by
        ``apply_align``. The target mean is not part of the result.
    """
    n_rows = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_rows].float()
    tgt = target[:n_rows].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean

    dof = max(src_c.shape[0] - 1, 1)
    src_whitener = symmetric_inv_sqrt((src_c.T @ src_c) / dof)
    tgt_whitener = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / dof)

    src_w = F.normalize(src_c @ src_whitener, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whitener, dim=-1)

    left, _, right_t = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    return {
        "rotation": left @ right_t,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whitener,
        "target_unwhitener": torch.linalg.pinv(tgt_whitener),
    }
|
|
def apply_align(emb, a):
    """Apply an alignment dict produced by ``procrustes_align`` to ``emb``.

    The steps are: subtract ``a["source_mean"]``, multiply by
    ``a["source_whitener"]``, by the transpose of ``a["rotation"]`` and
    finally by ``a["target_unwhitener"]``.

    Args:
        emb: (N, D) tensor; cast to float32 first.
        a: dict with the four keys listed above.

    Returns:
        The aligned float32 tensor.
    """
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
|
|
|
|
| |
| |
| |
|
|
def init_ortho_anchors(d, n_bases):
    """Build ``n_bases * d`` unit anchors from stacked orthogonal bases.

    The first base is the Q factor of a random ``d x d`` Gaussian matrix.
    Every further base is that same Q multiplied by the transpose of a fresh
    random orthogonal matrix, so each block of ``d`` rows is itself an
    orthonormal set.

    Args:
        d: dimensionality of the anchors (and number of anchors per base).
        n_bases: number of bases to stack.

    Returns:
        Tensor of shape ``(n_bases * d, d)`` with unit-norm rows.
    """
    def _random_orthogonal():
        q_factor, _ = torch.linalg.qr(torch.randn(d, d))
        return q_factor

    first = _random_orthogonal()
    bases = [first]
    # Each extra base draws its own random rotation, in order.
    bases.extend(first @ _random_orthogonal().T for _ in range(1, n_bases))
    return F.normalize(torch.cat(bases, dim=0), dim=-1)
|
|
|
|
| |
| |
| |
|
|
class FusedConstellation(nn.Module):
    """Learnable anchor set read through one affine "perspective" per expert.

    Each expert owns a mean, a whitening matrix and a rotation (all learnable
    parameters, initialised to zero / identity / identity). An embedding is
    viewed by every expert and compared to the anchors by cosine similarity.

    Args:
        n_anchors: number of anchors.
        d: dimensionality of anchors and embeddings.
        n_experts: number of expert perspectives.
        drop_rate: fraction of anchors hidden per call in training mode.
    """

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR, n_experts=N_EXPERTS_COUNT,
                 drop_rate=ANCHOR_DROP):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_experts = n_experts
        self.drop_rate = drop_rate
        self.d = d

        def per_expert(factory):
            return nn.ParameterList(
                [nn.Parameter(factory()) for _ in range(n_experts)])

        # Registration order is kept: anchors, rotations, whiteners, means.
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))
        self.expert_rotations = per_expert(lambda: torch.eye(d))
        self.expert_whiteners = per_expert(lambda: torch.eye(d))
        self.expert_means = per_expert(lambda: torch.zeros(d))

    def _expert_views(self, emb):
        """Return one unit-norm view of ``emb`` per expert (centre, whiten, rotate)."""
        feats = emb.float()
        views = []
        for mean, whitener, rotation in zip(
                self.expert_means, self.expert_whiteners, self.expert_rotations):
            centered = feats - mean
            views.append(F.normalize(centered @ whitener @ rotation.T, dim=-1))
        return views

    def triangulate(self, emb, training=False):
        """Cosine distances from every expert view of ``emb`` to every anchor.

        With ``training=True`` and a positive ``drop_rate`` a random subset
        of anchors is kept (at least 128 are requested); the dropped anchors
        are reported with cosine -1, i.e. distance 2.

        Args:
            emb: (B, d) embeddings.
            training: enables anchor dropout.

        Returns:
            ``(tri_fused, nearest, tri_stacked)`` where ``tri_stacked`` is
            (B, n_anchors, n_experts), ``tri_fused`` is the same data
            flattened to (B, n_anchors * n_experts) and ``nearest`` is the
            (B,) index of the anchor with the highest expert-averaged cosine.
        """
        batch = emb.shape[0]
        unit_anchors = F.normalize(self.anchors, dim=-1)
        views = self._expert_views(emb)

        dropping = training and self.drop_rate > 0
        if dropping:
            n_keep = max(int(self.n_anchors * (1 - self.drop_rate)), 128)
            kept = torch.randperm(self.n_anchors, device=emb.device)[:n_keep]
            kept_anchors = unit_anchors[kept]

        cos_per_expert = []
        for view in views:
            if dropping:
                partial = view @ kept_anchors.T
                cos = torch.full((batch, self.n_anchors), -1.0,
                                 device=emb.device, dtype=partial.dtype)
                cos[:, kept] = partial
            else:
                cos = view @ unit_anchors.T
            cos_per_expert.append(cos)

        tri_stacked = torch.stack([1.0 - cos for cos in cos_per_expert], dim=-1)
        tri_fused = tri_stacked.reshape(batch, -1)
        mean_cos = torch.stack(cos_per_expert, dim=-1).mean(dim=-1)
        nearest = mean_cos.argmax(dim=-1)
        return tri_fused, nearest, tri_stacked

    def anchor_spread_loss(self):
        """Mean squared off-diagonal cosine among a random subset of up to 512 anchors."""
        unit = F.normalize(self.anchors, dim=-1)
        pick = torch.randperm(self.n_anchors, device=unit.device)[:512]
        subset = unit[pick]
        gram = subset @ subset.T
        off_diag = gram - torch.diag(torch.diag(gram))
        return off_diag.pow(2).mean()

    def expert_agreement_loss(self, emb):
        """Distance of the mean cross-expert cosine spread from 0.05.

        Uses only the first 512 anchors; the spread is the standard deviation
        over experts of each (sample, anchor) cosine.
        """
        unit_anchors = F.normalize(self.anchors[:512], dim=-1)
        cos = torch.stack(
            [view @ unit_anchors.T for view in self._expert_views(emb)], dim=-1)
        disagreement = cos.std(dim=-1)
        return (disagreement.mean() - 0.05).abs()
|
|
|
|
| |
| |
| |
|
|
class DepthLevel(nn.Module):
    """Single depth level of the patchwork — reads a specific granularity.

    Anchors are dealt round-robin into ``n_comp`` compartments (anchor ``a``
    belongs to compartment ``a % n_comp``). Every compartment owns a small
    MLP that maps its anchors' per-expert distances to a ``d_comp`` vector.

    Args:
        n_anchors: total number of anchors in the triangulation.
        n_comp: number of compartments; must divide ``n_anchors`` evenly.
        n_experts: size of the last axis of the triangulation.
        d_comp: output width of every compartment MLP.

    Raises:
        ValueError: if ``n_comp`` is not positive or does not divide
            ``n_anchors``. Without this check the compartments would have
            unequal sizes and the mismatch would only surface as a shape
            error inside ``forward``.
    """

    def __init__(self, n_anchors, n_comp, n_experts, d_comp):
        super().__init__()
        if n_comp <= 0 or n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be a positive multiple of "
                f"n_comp ({n_comp})")
        self.n_comp = n_comp
        self.n_experts = n_experts
        # Kept as a buffer so it follows the module across devices and stays
        # part of the state dict.
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        inputs_per_comp = (n_anchors // n_comp) * n_experts
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(inputs_per_comp, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for _ in range(n_comp)])

    def forward(self, tri_3d):
        """Run every compartment MLP and concatenate the results.

        Args:
            tri_3d: (B, n_anchors, n_experts) triangulation tensor.

        Returns:
            (B, n_comp * d_comp) tensor, compartments in index order.
        """
        B = tri_3d.shape[0]
        results = []
        for k, comp in enumerate(self.comps):
            mask = self.asgn == k
            comp_input = tri_3d[:, mask, :].reshape(B, -1)
            results.append(comp(comp_input))
        return torch.cat(results, dim=-1)
|
|
|
|
class MultiDepthPatchwork(nn.Module):
    """
    Reads the sphere at 3 resolutions:
    Coarse: 16 compartments, 128 anchors each — global position
    Fine: 64 compartments, 32 anchors each — local neighborhood
    Micro: 128 compartments, 16 anchors each — sub-anchor structure

    The three level outputs are concatenated and projected to D_PW_PROJ
    (Linear → GELU → LayerNorm).
    """

    def __init__(self):
        super().__init__()
        # Construction order (coarse, fine, micro, proj) is kept so parameter
        # initialisation consumes the RNG in the same sequence.
        self.coarse = DepthLevel(N_ANCHORS, COARSE_COMP, N_EXPERTS_COUNT, D_COARSE)
        self.fine = DepthLevel(N_ANCHORS, FINE_COMP, N_EXPERTS_COUNT, D_FINE)
        self.micro = DepthLevel(N_ANCHORS, MICRO_COMP, N_EXPERTS_COUNT, D_MICRO)

        level_widths = (
            COARSE_COMP * D_COARSE,
            FINE_COMP * D_FINE,
            MICRO_COMP * D_MICRO,
        )
        self.proj = nn.Sequential(
            nn.Linear(sum(level_widths), D_PW_PROJ),
            nn.GELU(),
            nn.LayerNorm(D_PW_PROJ),
        )

    def forward(self, tri_3d):
        """tri_3d: (B, N_ANCHORS, N_EXPERTS_COUNT) → (B, D_PW_PROJ)"""
        levels = (self.coarse, self.fine, self.micro)
        per_level = [level(tri_3d) for level in levels]
        return self.proj(torch.cat(per_level, dim=-1))
|
|
|
|
| |
| |
| |
|
|
class ExpertProjector(nn.Module):
    """Maps one expert's raw features into the anchor space.

    A Linear(D_EXPERT → D_ANCHOR) followed by LayerNorm; the output rows are
    L2-normalised.
    """

    def __init__(self):
        super().__init__()
        linear = nn.Linear(D_EXPERT, D_ANCHOR)
        norm = nn.LayerNorm(D_ANCHOR)
        # Kept as a Sequential so the layers stay addressable as proj[0] / proj[1].
        self.proj = nn.Sequential(linear, norm)

    def forward(self, x):
        """Return the unit-norm projection of ``x`` along its last axis."""
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
|
|
|
|
class MassiveSoup(nn.Module):
    """Full model: expert projectors → fused embedding → constellation → patchwork → classifier."""

    def __init__(self):
        super().__init__()
        self.projectors = nn.ModuleList(
            [ExpertProjector() for _ in range(N_EXPERTS_COUNT)])
        self.constellation = FusedConstellation()
        self.patchwork = MultiDepthPatchwork()

        # The classifier sees the patchwork summary next to the fused embedding.
        self.classifier = nn.Sequential(
            nn.Linear(D_PW_PROJ + D_ANCHOR, D_PW_PROJ),
            nn.GELU(),
            nn.LayerNorm(D_PW_PROJ),
            nn.Dropout(0.1),
            nn.Linear(D_PW_PROJ, N_CLASSES),
        )

    def forward(self, expert_features, apply_autograd=True):
        """Classify a batch from per-expert feature tensors.

        Args:
            expert_features: indexable of N_EXPERTS_COUNT tensors, one per
                expert in projector order.
            apply_autograd: when true and the module is in training mode, the
                fused embedding is routed through ``EmbeddingAutograd`` so
                its gradient is reshaped on the way back.

        Returns:
            ``(logits, fused, tri_fused, nearest, projected)`` — class
            logits, the unit-norm fused embedding, the flattened
            triangulation, the nearest-anchor index per row and the list of
            per-expert projections.
        """
        projected = [
            projector(expert_features[i])
            for i, projector in enumerate(self.projectors)
        ]
        fused = F.normalize(sum(projected) / N_EXPERTS_COUNT, dim=-1)

        if apply_autograd and self.training:
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)

        tri_fused, nearest, tri_3d = self.constellation.triangulate(
            fused, training=self.training)
        patch_summary = self.patchwork(tri_3d)
        logits = self.classifier(torch.cat([patch_summary, fused], dim=-1))

        return logits, fused, tri_fused, nearest, projected
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 0: LOAD DATA")
print(f"{'='*65}")


from datasets import load_dataset


def _load_reference_split(split):
    """Load one split of the first expert and derive ids, id→row map and multi-hot labels."""
    ref_ds = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split=split)
    ids = ref_ds["image_id"]
    id_to_row = {image_id: row for row, image_id in enumerate(ids)}
    multi_hot = torch.zeros(len(ids), N_CLASSES)
    for row, class_ids in enumerate(ref_ds["labels"]):
        for class_id in class_ids:
            # Class ids at or above N_CLASSES are ignored.
            if class_id < N_CLASSES:
                multi_hot[row, class_id] = 1.0
    return ref_ds, ids, id_to_row, multi_hot


def _load_features(name, split, id_to_row, n_rows):
    """Load one expert's features for a split, placed in reference row order.

    Rows whose image id is absent from ``id_to_row`` are skipped; reference
    rows that never appear in the expert's dataset stay all-zero.
    """
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split)
    feats = torch.zeros(n_rows, D_EXPERT)
    for row in ds:
        image_id = row["image_id"]
        if image_id in id_to_row:
            feats[id_to_row[image_id]] = torch.tensor(row["features"], dtype=torch.float32)
    return feats


ref, train_ids, train_id_map, train_labels = _load_reference_split("train")
N_train = len(train_ids)

ref_val, val_ids, val_id_map, val_labels = _load_reference_split("val")
N_val = len(val_ids)

print(f" Train: {N_train:,} Val: {N_val:,}")


train_raw, val_raw = {}, {}
for name in EXPERTS:
    train_raw[name] = _load_features(name, "train", train_id_map, N_train)
    val_raw[name] = _load_features(name, "val", val_id_map, N_val)
    print(f" {name:<30} loaded")
    # The dataset objects went out of scope with the helper calls; collect
    # eagerly before loading the next expert.
    gc.collect()
|
|
| |
print(f"\n{'='*65}")
print("PHASE 1: GPA + PCA + PROCRUSTES")
print(f"{'='*65}")


def _gpa_consensus(raw_features, n_iters=20, tol=1e-8, verbose=False):
    """Generalised Procrustes alignment of all experts towards their running mean.

    Every iteration computes the mean shape once, aligns each expert to it
    exactly once and sums the mean squared change as ``delta``. The train
    loop used to run an extra alignment pass per iteration whose delta was
    computed and immediately discarded; train and validation now share this
    single-pass procedure.

    Args:
        raw_features: dict expert name → (N, D) feature tensor.
        n_iters: maximum number of GPA iterations.
        tol: stop once the summed change drops below this value.
        verbose: print ``delta`` on the first and every fifth iteration.

    Returns:
        ``(aligned, consensus)`` — the dict of aligned features and the
        row-normalised mean over experts.
    """
    aligned = {name: raw_features[name].float() for name in EXPERTS}
    for gpa_iter in range(n_iters):
        mean_shape = sum(aligned[n] for n in EXPERTS) / len(EXPERTS)
        delta = 0.0
        updated = {}
        for name in EXPERTS:
            info = procrustes_align(aligned[name], mean_shape)
            updated[name] = apply_align(aligned[name], info)
            delta += (updated[name] - aligned[name]).pow(2).mean().item()
        aligned = updated
        if verbose and (gpa_iter == 0 or (gpa_iter + 1) % 5 == 0):
            print(f" GPA iter {gpa_iter+1}: delta={delta:.8f}")
        if delta < tol:
            break
    consensus = F.normalize(sum(aligned[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
    return aligned, consensus


# --- Train consensus in the raw expert width --------------------------------
current, consensus_768 = _gpa_consensus(train_raw, verbose=True)
for name in EXPERTS:
    c = F.cosine_similarity(consensus_768[:5000], current[name][:5000], dim=-1).mean().item()
    print(f" cos(consensus, {name}): {c:.4f}")


# --- PCA down to D_ANCHOR ----------------------------------------------------
# The basis is fitted on the first 10000 centred rows; the projection itself
# is applied to the un-centred consensus and re-normalised.
cc = consensus_768 - consensus_768.mean(0, keepdim=True)
_, S, Vt = torch.linalg.svd(cc[:10000], full_matrices=False)
pca_proj = Vt[:D_ANCHOR]
consensus_d = F.normalize(consensus_768 @ pca_proj.T, dim=-1)
var_ret = S[:D_ANCHOR].pow(2).sum() / S.pow(2).sum()
print(f" PCA 768→{D_ANCHOR}: var_retained={var_ret.item():.4f}")
consensus_cv = cv_metric(consensus_d[:5000].to(DEVICE))
print(f" Consensus CV at {D_ANCHOR}-d: {consensus_cv:.4f}")


# --- Validation consensus ----------------------------------------------------
# NOTE(review): this GPA is fitted on the validation features independently of
# the train GPA and only shares the train PCA basis — confirm that the two
# consensus spaces are meant to be comparable.
val_current, val_consensus_768 = _gpa_consensus(val_raw)
val_consensus_d = F.normalize(val_consensus_768 @ pca_proj.T, dim=-1)


# --- Per-expert calibration: raw features → consensus_d ----------------------
expert_calibrations = {}
for name in EXPERTS:
    raw = train_raw[name][:10000].float()
    tgt = consensus_d[:10000].float()
    # Unbiased covariance divisor from the actual row count (was a fixed 9999,
    # which is only right when at least 10000 rows exist).
    dof = max(raw.shape[0] - 1, 1)
    sm = raw.mean(0, keepdim=True); tm = tgt.mean(0, keepdim=True)
    sc = raw - sm; tc = tgt - tm
    sw = symmetric_inv_sqrt((sc.T @ sc) / dof)
    tw = symmetric_inv_sqrt((tc.T @ tc) / dof)
    src_w = F.normalize(sc @ sw, dim=-1); tgt_w = F.normalize(tc @ tw, dim=-1)
    M = tgt_w.T @ src_w
    U_r, S_r, Vt_r = torch.linalg.svd(M, full_matrices=False)
    R = U_r @ Vt_r
    # Fold centring, whitening and rotation into one affine map so it can be
    # copied straight into a Linear layer (weight, bias).
    proj_W = (sw @ R.T).T; proj_b = -(sm.squeeze(0) @ sw @ R.T).squeeze(0)
    test = F.normalize(raw[:1000] @ proj_W.T + proj_b, dim=-1)
    cos = F.cosine_similarity(test, tgt[:1000], dim=-1).mean().item()
    expert_calibrations[name] = {"W": proj_W, "b": proj_b, "cos": cos,
                                 "R": R[:D_ANCHOR, :D_ANCHOR],
                                 "whiten": tw, "mean": tm.squeeze(0)}
    print(f" {name:<30} cos={cos:.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 2: BUILD MODEL")
print(f"{'='*65}")


model = MassiveSoup().to(DEVICE)


with torch.no_grad():
    # Projectors start from the closed-form Procrustes calibration.
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.projectors[i].proj[0].weight.copy_(cal["W"].to(DEVICE))
        model.projectors[i].proj[0].bias.copy_(cal["b"].to(DEVICE))
    print(f" ✓ Projectors from Procrustes")

    # Anchors start from the stacked orthogonal bases.
    ortho_anchors = init_ortho_anchors(D_ANCHOR, N_ORTHO_BASES)
    model.constellation.anchors.copy_(ortho_anchors.to(DEVICE))
    print(f" ✓ {N_ANCHORS} ortho-spectrum anchors ({N_ORTHO_BASES} bases × {D_ANCHOR})")

    # Expert perspectives start from the same calibration.
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.constellation.expert_rotations[i].copy_(cal["R"].to(DEVICE))
        model.constellation.expert_whiteners[i].copy_(cal["whiten"].to(DEVICE))
        model.constellation.expert_means[i].copy_(cal["mean"].to(DEVICE))
    print(f" ✓ Expert perspectives calibrated")


# Sanity check of the initialisation. It runs in eval mode: in train mode the
# constellation drops a random share of anchors, which would understate the
# number of active anchors reported here. The previous mode is restored after.
was_training = model.training
model.eval()
with torch.no_grad():
    test_in = [train_raw[name][:200].to(DEVICE) for name in EXPERTS]
    _, test_fused, _, test_nearest, _ = model(test_in, apply_autograd=False)
    test_tgt = consensus_d[:200].to(DEVICE)
    init_cos = F.cosine_similarity(test_fused, test_tgt, dim=-1).mean().item()
    n_active = test_nearest.unique().numel()
    print(f" Init: cos={init_cos:.4f} active_anchors={n_active}/{N_ANCHORS}")
model.train(was_training)


def count_params(module):
    """Total number of scalar parameters in ``module``."""
    return sum(p.numel() for p in module.parameters())


n_total = count_params(model)
print(f"\n Parameters:")
print(f" projectors: {sum(count_params(p) for p in model.projectors):>12,}")
print(f" constellation: {count_params(model.constellation):>12,}")
print(f" patchwork: {count_params(model.patchwork):>12,}")
print(f" classifier: {count_params(model.classifier):>12,}")
print(f" total: {n_total:>12,}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 3: TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Queue: {QUEUE_SIZE} | Anchor dropout: {ANCHOR_DROP}")
print(f" CV target: {consensus_cv:.4f}")
print(f"{'='*65}")


train_targets = consensus_d.to(DEVICE)
val_targets = val_consensus_d.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)


optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/massive_soup")
best_mAP = 0.0; gs = 0
# FIFO of past (embedding, target) rows used as extra InfoNCE negatives.
queue_e = torch.zeros(0, D_ANCHOR, device=DEVICE)
queue_t = torch.zeros(0, D_ANCHOR, device=DEVICE)


for epoch in range(EPOCHS):
    # ------------------------------ train ------------------------------
    model.train()
    perm = torch.randperm(N_train)
    # Running sums over the epoch; "n" counts the optimisation steps.
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0,
           "spread": 0, "agree": 0, "align": 0, "nce_acc": 0, "n": 0}

    pbar = tqdm(range(0, N_train, BATCH), desc=f"E{epoch+1:2d}/{EPOCHS}", unit="batch")
    for i in pbar:
        idx = perm[i:i+BATCH]
        # Skip a tiny trailing batch.
        if len(idx) < 4:
            continue

        batch = [train_raw[name][idx].to(DEVICE) for name in EXPERTS]
        labels = train_labels_gpu[idx]
        targets = train_targets[idx]

        logits, fused, tri, nearest, projected = model(batch)

        # The queue is updated after the loss so a batch never sees itself
        # as a queued negative.
        l_nce, nce_acc = infonce_queued(fused, targets, queue_e, queue_t)
        with torch.no_grad():
            queue_e = torch.cat([queue_e, fused.detach()], 0)[-QUEUE_SIZE:]
            queue_t = torch.cat([queue_t, targets.detach()], 0)[-QUEUE_SIZE:]

        l_mse = F.mse_loss(fused, targets)
        l_bce = F.binary_cross_entropy_with_logits(logits, labels)
        l_align = whitened_procrustes_loss(fused, targets)
        l_cv = cv_loss(fused, target=consensus_cv)
        l_spread = model.constellation.anchor_spread_loss()
        l_agree = model.constellation.expert_agreement_loss(fused)

        loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                + 0.5 * l_align + 0.001 * l_cv
                + 1e-3 * l_spread + 0.1 * l_agree)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)

        acc["loss"] += loss.item(); acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item(); acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item(); acc["spread"] += l_spread.item()
        acc["agree"] += l_agree.item() if torch.is_tensor(l_agree) else l_agree
        acc["align"] += l_align.item(); acc["nce_acc"] += nce_acc
        acc["n"] += 1; gs += 1

        if gs % 50 == 0:
            # Logged values are running epoch averages, not per-step values.
            for k in ["loss", "nce", "bce", "nce_acc"]:
                writer.add_scalar(f"step/{k}", acc[k]/max(acc["n"],1), gs)

        if acc["n"] % 20 == 0:
            d = acc["n"]
            # set_postfix shows every keyword as a statistic, so only real
            # metrics are passed (a stray ordered=True used to be displayed).
            pbar.set_postfix(loss=f"{acc['loss']/d:.4f}",
                             nce_acc=f"{acc['nce_acc']/d:.3f}",
                             cos=f"{1-acc['align']/d:.3f}")

    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"bce={acc['bce']/d:.4f} agree={acc['agree']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")

    # ---------------------------- validation ----------------------------
    model.eval()
    # Everything below is metric computation, so no autograd graph is needed.
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j + BATCH, N_val)
            batch_v = [val_raw[name][j:end].to(DEVICE) for name in EXPERTS]
            lo, em, _, _, _ = model(batch_v, apply_autograd=False)
            all_lo.append(lo.cpu()); all_em.append(em.cpu())
        v_lo = torch.cat(all_lo); v_em = torch.cat(all_em)

        # Mean average precision over the classes that have a positive.
        v_lab = val_labels
        ap_sum, nv = 0, 0
        for c in range(N_CLASSES):
            if v_lab[:, c].sum() > 0:
                si = v_lo[:, c].argsort(descending=True); st = v_lab[:, c][si]
                pak = st.cumsum(0) / torch.arange(1, len(st)+1).float()
                ap_sum += (pak * st).sum().item() / st.sum().item(); nv += 1
        mAP = ap_sum / max(nv, 1)

        # Per-class F1 at a 0.5 probability threshold.
        vp = (v_lo.sigmoid() > 0.5).float()
        tp = (vp * v_lab).sum(0); fp = (vp * (1-v_lab)).sum(0); fn = ((1-vp) * v_lab).sum(0)
        pr_ = tp/(tp+fp+1e-8); rc_ = tp/(tp+fn+1e-8); f1_ = 2*pr_*rc_/(pr_+rc_+1e-8)

        # Embedding quality against the validation consensus targets.
        val_targets_cpu = val_targets.cpu()
        v_cos = F.cosine_similarity(v_em, val_targets_cpu, dim=-1).mean().item()
        sim = v_em @ val_targets_cpu.T
        r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()
        _, v_nearest, _ = model.constellation.triangulate(v_em.to(DEVICE), training=False)
        n_active = v_nearest.cpu().unique().numel()
        v_cv = cv_metric(v_em[:2000].to(DEVICE))

    for k in acc:
        if k != "n": writer.add_scalar(f"epoch/{k}", acc[k]/d, epoch+1)
    writer.add_scalar("val/mAP", mAP, epoch+1)
    writer.add_scalar("val/cos", v_cos, epoch+1)
    writer.add_scalar("val/R@1", r1, epoch+1)
    writer.add_scalar("val/anchors", n_active, epoch+1)
    writer.add_scalar("val/cv", v_cv, epoch+1)

    # Best-by-mAP checkpoint carries the config and the PCA basis.
    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({"state_dict": model.state_dict(),
                    "config": {"d_anchor": D_ANCHOR, "n_anchors": N_ANCHORS,
                               "n_ortho_bases": N_ORTHO_BASES,
                               "n_experts": N_EXPERTS_COUNT,
                               "coarse_comp": COARSE_COMP, "fine_comp": FINE_COMP,
                               "micro_comp": MICRO_COMP,
                               "d_coarse": D_COARSE, "d_fine": D_FINE,
                               "d_micro": D_MICRO, "d_pw_proj": D_PW_PROJ,
                               "anchor_drop": ANCHOR_DROP, "experts": EXPERTS,
                               "cv_target": consensus_cv},
                    "pca_proj": pca_proj, "consensus_cv": consensus_cv,
                    "mAP": mAP, "r1": r1, "cos": v_cos, "cv": v_cv,
                    "epoch": epoch+1, "n_active": n_active},
                   "checkpoints/massive_soup_best.pt")
        mk = " ★"

    # Per-epoch checkpoint with optimizer state.
    torch.save({"state_dict": model.state_dict(), "epoch": epoch+1,
                "mAP": mAP, "optimizer": optimizer.state_dict(), "gs": gs},
               f"checkpoints/massive_soup_e{epoch+1:02d}.pt")

    print(f" E{epoch+1} val: mAP={mAP:.3f} F1={f1_[f1_>0].mean():.3f} "
          f"R@1={r1:.3f} cos={v_cos:.3f} cv={v_cv:.4f} "
          f"anchors={n_active}/{N_ANCHORS}{mk}")


writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Total: {n_total:,} params")
print(f"\n{'='*65}\nDONE\n{'='*65}")