geolip-vit-base-x3 / run_3_heavy_soup_trainer.py
AbstractPhil's picture
Update run_3_heavy_soup_trainer.py
bb8153a verified
Raw
History Blame Contribute Delete
30.8 kB
#!/usr/bin/env python3
"""
GEOLIP MASSIVE SOUP — ORTHO SPECTRUM HYPERSPHERE
==================================================
2048 anchors × 256-d × 3 expert perspectives.
Orthogonal initialization: 8 rotated orthogonal bases of 256 vectors = 2048.
Each base tiles a different region of S^255. Together they form
a structured mesh with known geometric relationships.
Multi-depth patchwork:
Level 0 (coarse): 16 compartments × 128 anchors × 3 experts = 384 inputs each → 128-d
Level 1 (fine): 64 compartments × 32 anchors × 3 experts = 96 inputs each → 64-d
Level 2 (micro): 128 compartments × 16 anchors × 3 experts = 48 inputs each → 32-d
Total patchwork output: 16×128 + 64×64 + 128×32 = 2048 + 4096 + 4096 = 10240-d
→ project down to 1024 before classifier
The depth levels read the sphere at different resolutions. Coarse catches
global position, fine catches local neighborhood, micro catches sub-anchor
structure. Each level has its own expert-aware triangulation view.
GPA → PCA 256-d → Procrustes calibration → train with full loss stack.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc
from tqdm import tqdm
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Geometry -----------------------------------------------------------
D_EXPERT = 768
D_ANCHOR = 256
N_ORTHO_BASES = 8
N_ANCHORS = N_ORTHO_BASES * D_ANCHOR  # 8 × 256 = 2048
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
N_EXPERTS_COUNT = len(EXPERTS)  # 3 expert perspectives
N_CLASSES = 80
ANCHOR_DROP = 0.30

# --- Multi-depth patchwork ----------------------------------------------
COARSE_COMP = 16   # 2048/16 = 128 anchors per comp
FINE_COMP = 64     # 2048/64 = 32 anchors per comp
MICRO_COMP = 128   # 2048/128 = 16 anchors per comp
D_COARSE = 128
D_FINE = 64
D_MICRO = 32
D_PW_PROJ = 1024   # project combined patchwork to this

# --- Training -----------------------------------------------------------
BATCH = 128
EPOCHS = 30
LR = 1e-3
QUEUE_SIZE = 4096
GRAD_CLIP = 1.0

# Width of the flattened triangulation: one distance per (anchor, expert).
TRI_DIM = N_ANCHORS * N_EXPERTS_COUNT

# Startup banner summarising the configuration.
for _banner_line in (
    "=" * 65,
    "GEOLIP MASSIVE SOUP — ORTHO SPECTRUM",
    f" {N_ANCHORS} anchors × {D_ANCHOR}-d × {N_EXPERTS_COUNT} perspectives",
    f" Ortho bases: {N_ORTHO_BASES} × {D_ANCHOR} = {N_ANCHORS}",
    f" Patchwork: coarse({COARSE_COMP}) + fine({FINE_COMP}) + micro({MICRO_COMP})",
    f" Device: {DEVICE}",
    "=" * 65,
):
    print(_banner_line)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley–Menger determinant.

    Args:
        pts: tensor indexed as (batch, vertex, coordinate); the pairwise
            distance tensor built from it is unpacked as 3-D.

    Returns:
        One float32 value per batch entry. The result can be slightly
        negative for degenerate simplices because of round-off.
    """
    coords = pts.float()
    # Pairwise squared Euclidean distances between the vertices.
    delta = coords.unsqueeze(-2) - coords.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)
    n_batch, n_vert, _ = sq_dist.shape

    # Bordered matrix: zero corner, ones along the first row and column,
    # squared distances in the remaining block.
    bordered = torch.ones(n_batch, n_vert + 1, n_vert + 1,
                          device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 0] = 0
    bordered[:, 1:, 1:] = sq_dist

    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    scale = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
def cv_loss(emb, target=0.2, n_samples=16):
    """Penalty on the coefficient of variation of random 5-point simplex volumes.

    Draws ``n_samples`` random 5-row subsets of ``emb``, measures each
    subset's simplex volume, and returns ``|std/mean - target|``.

    Args:
        emb: embeddings, one row per sample.
        target: desired coefficient of variation.
        n_samples: number of random simplices to draw.

    Returns:
        A scalar tensor; a constant zero when fewer than 5 rows are given.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu guards against small negative round-off; the floor keeps the
        # sqrt differentiable at zero.
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
@torch.no_grad()
def cv_metric(emb, n_samples=500):
    """Coefficient of variation (std/mean) of random 5-point simplex volumes.

    Args:
        emb: embeddings, one row per sample.
        n_samples: number of random 5-row subsets to measure.

    Returns:
        A Python float; 0.0 when there are fewer than 5 rows or fewer than
        10 usable volume samples.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    def _sample_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()

    # NOTE: because of the 1e-12 floor under the sqrt every finite sample is
    # at least 1e-6, so this filter only ever rejects NaN values.
    kept = [vol for vol in (_sample_volume() for _ in range(n_samples)) if vol > 0]
    if len(kept) < 10:
        return 0.0
    arr = np.array(kept)
    return float(arr.std() / (arr.mean() + 1e-8))
def infonce_queued(emb, targets, queue_emb, queue_tgt, temperature=0.07):
    """Symmetric InfoNCE loss with optional extra negatives from a queue.

    Row ``i`` of ``emb`` is the positive for row ``i`` of ``targets``. When
    ``queue_tgt`` is non-empty, the queued rows are appended as additional
    negatives in both directions; they are used as given, without being
    re-normalised here.

    Args:
        emb: batch embeddings (normalised inside).
        targets: batch targets (normalised inside).
        queue_emb: queued embeddings, only read when the queue is non-empty.
        queue_tgt: queued targets, or ``None`` / empty for no queue.
        temperature: softmax temperature.

    Returns:
        ``(loss, acc)`` — the averaged two-direction cross-entropy and the
        embedding→target top-1 accuracy as a Python float.
    """
    batch = emb.shape[0]
    emb_n = F.normalize(emb, dim=-1)
    tgt_n = F.normalize(targets, dim=-1)

    has_queue = queue_tgt is not None and queue_tgt.shape[0] > 0
    if has_queue:
        all_tgt = torch.cat([tgt_n, queue_tgt], 0)
        all_emb = torch.cat([emb_n, queue_emb], 0)
    else:
        all_tgt, all_emb = tgt_n, emb_n

    logits_e2t = (emb_n @ all_tgt.T) / temperature
    logits_t2e = (tgt_n @ all_emb.T) / temperature
    # The positive for row i sits at column i (batch rows come first).
    positives = torch.arange(batch, device=emb.device)
    loss = (F.cross_entropy(logits_e2t, positives)
            + F.cross_entropy(logits_t2e, positives)) / 2
    with torch.no_grad():
        acc = (logits_e2t.argmax(-1) == positives).float().mean().item()
    return loss, acc
def whitened_procrustes_loss(emb, targets):
    """One minus the mean row-wise cosine between mean-centred inputs.

    Despite the name, the only preprocessing done here is subtracting each
    tensor's batch mean; no whitening or rotation is applied.

    Returns:
        A scalar tensor; a constant zero when the batch has fewer than 10 rows.
    """
    if emb.shape[0] < 10:
        return torch.tensor(0.0, device=emb.device)
    emb_f = emb.float()
    tgt_f = targets.float()
    emb_centered = emb_f - emb_f.mean(0, keepdim=True)
    tgt_centered = tgt_f - tgt_f.mean(0, keepdim=True)
    cosine = F.cosine_similarity(emb_centered, tgt_centered, dim=-1)
    return 1.0 - cosine.mean()
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient on the way back.

    Backward does two things to the incoming gradient of ``x``:

    * shrinks its component along the (normalised) ``embedding`` direction
      by the fraction ``tang``;
    * when ``sep > 0``, removes ``sep`` times its positive component along
      the anchor with the highest cosine to each embedding.

    Only ``x`` receives a gradient; the other four inputs get ``None``.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.tang = tang
        ctx.sep = sep
        ctx.save_for_backward(embedding, anchors)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()

        # Split the gradient into radial (along the embedding) and tangential parts.
        radial_part = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        tangential_part = grad - radial_part
        adjusted = tangential_part + (1.0 - ctx.tang) * radial_part

        if ctx.sep > 0:
            nearest_idx = (unit_emb @ unit_anchors.T).argmax(dim=-1)
            nearest_anchor = unit_anchors[nearest_idx]
            along = (adjusted * nearest_anchor).sum(-1, keepdim=True)
            # Only the positive projection onto the nearest anchor is removed.
            is_positive = (along > 0).float()
            adjusted = adjusted - ctx.sep * is_positive * along * nearest_anchor

        return adjusted.to(grad_output.dtype), None, None, None, None
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric matrix via ``eigh``.

    Eigenvalues are clamped to at least ``eps`` before taking the
    reciprocal square root, so rank-deficient input stays finite.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = eigvals.clamp(min=eps).rsqrt()
    # Scaling the eigenvector columns is the same product as V @ diag(inv_root).
    return (eigvecs * inv_root.unsqueeze(0)) @ eigvecs.T
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal alignment from ``source`` towards ``target``.

    Only the first ``min(n_align, len(source), len(target))`` rows are used.
    Both sides are mean-centred, whitened with their own covariance and
    row-normalised; the rotation comes from the SVD of their cross product.

    Returns:
        dict with keys ``rotation``, ``source_mean``, ``source_whitener``
        and ``target_unwhitener`` (pseudo-inverse of the target whitener),
        in the form ``apply_align`` consumes.
    """
    n_rows = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_rows].float()
    tgt = target[:n_rows].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean

    # Unbiased covariance divisor, guarded against a single-row fit.
    dof = max(src_c.shape[0] - 1, 1)
    src_whitener = symmetric_inv_sqrt((src_c.T @ src_c) / dof)
    tgt_whitener = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / dof)

    src_w = F.normalize(src_c @ src_whitener, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whitener, dim=-1)
    left, _, right_t = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)

    return {
        "rotation": left @ right_t,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whitener,
        "target_unwhitener": torch.linalg.pinv(tgt_whitener),
    }
def apply_align(emb, a):
    """Apply an alignment dict: centre, whiten, rotate, then un-whiten.

    Args:
        emb: rows to transform.
        a: dict with ``source_mean``, ``source_whitener``, ``rotation`` and
            ``target_unwhitener`` entries.
    """
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
# ══════════════════════════════════════════════════════════════════
# ORTHO ANCHOR INITIALIZATION
# ══════════════════════════════════════════════════════════════════
def init_ortho_anchors(d, n_bases):
    """Build ``n_bases × d`` unit anchors from random orthonormal bases.

    The first base is the Q factor of a random ``d × d`` Gaussian matrix.
    Every further base is that same Q multiplied by the transpose of a
    fresh random orthogonal matrix, so each block of ``d`` rows is itself
    orthonormal.

    Returns:
        Tensor with ``d`` columns whose rows are unit vectors; blocks of
        ``d`` consecutive rows belong to one base.
    """
    # Reference basis; its rows are mutually orthogonal unit vectors.
    base_q, _ = torch.linalg.qr(torch.randn(d, d))
    bases = [base_q]
    for _ in range(1, n_bases):
        # Fresh random orthogonal matrix used to rotate the reference basis.
        rot_q, _ = torch.linalg.qr(torch.randn(d, d))
        bases.append(base_q @ rot_q.T)
    stacked = torch.cat(bases, dim=0)
    return F.normalize(stacked, dim=-1)
# ══════════════════════════════════════════════════════════════════
# FUSED CONSTELLATION (2048 anchors)
# ══════════════════════════════════════════════════════════════════
class FusedConstellation(nn.Module):
    """Learnable anchor set viewed through per-expert linear "perspectives".

    Holds ``n_anchors`` anchor vectors plus, for every expert, a mean, a
    whitening matrix and a rotation matrix (initialised to zero / identity).
    An embedding is centred, whitened, rotated and re-normalised once per
    expert, then compared against the normalised anchors by cosine.
    """

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR, n_experts=N_EXPERTS_COUNT,
                 drop_rate=ANCHOR_DROP):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_experts = n_experts
        self.drop_rate = drop_rate
        self.d = d
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))
        self.expert_rotations = nn.ParameterList(
            [nn.Parameter(torch.eye(d)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList(
            [nn.Parameter(torch.eye(d)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList(
            [nn.Parameter(torch.zeros(d)) for _ in range(n_experts)])

    def _expert_views(self, emb):
        """Return one unit-normalised, expert-transformed copy of ``emb`` per expert."""
        views = []
        for mean, whitener, rotation in zip(
                self.expert_means, self.expert_whiteners, self.expert_rotations):
            centered = emb.float() - mean
            views.append(F.normalize(centered @ whitener @ rotation.T, dim=-1))
        return views

    def triangulate(self, emb, training=False):
        """Cosine distance from every expert view of ``emb`` to every anchor.

        With ``training=True`` and a positive drop rate, a random subset of
        anchors is kept; dropped anchors are reported with cosine -1
        (distance 2).

        Returns:
            ``(tri_fused, nearest, tri_stacked)`` where ``tri_stacked`` is
            (B, n_anchors, n_experts), ``tri_fused`` is its (B, -1)
            flattening and ``nearest`` is the per-row argmax of the
            expert-averaged cosine.
        """
        batch = emb.shape[0]
        unit_anchors = F.normalize(self.anchors, dim=-1)
        views = self._expert_views(emb)

        if training and self.drop_rate > 0:
            n_keep = max(int(self.n_anchors * (1 - self.drop_rate)), 128)
            keep_idx = torch.randperm(self.n_anchors, device=emb.device)[:n_keep]
            kept_anchors = unit_anchors[keep_idx]
            cos_per_expert = []
            for view in views:
                kept_cos = view @ kept_anchors.T
                # Dropped anchors keep the filler value -1.0.
                full_cos = torch.full((batch, self.n_anchors), -1.0,
                                      device=emb.device, dtype=kept_cos.dtype)
                full_cos[:, keep_idx] = kept_cos
                cos_per_expert.append(full_cos)
        else:
            cos_per_expert = [view @ unit_anchors.T for view in views]

        tri_stacked = torch.stack([1.0 - cos for cos in cos_per_expert], dim=-1)
        tri_fused = tri_stacked.reshape(batch, -1)
        mean_cos = torch.stack(cos_per_expert, dim=-1).mean(dim=-1)
        nearest = mean_cos.argmax(dim=-1)
        return tri_fused, nearest, tri_stacked

    def anchor_spread_loss(self):
        """Mean squared off-diagonal cosine among up to 512 randomly chosen anchors."""
        # Sample-based for 2048 anchors (full 2048×2048 is too big)
        unit = F.normalize(self.anchors, dim=-1)
        pick = torch.randperm(self.n_anchors, device=unit.device)[:512]
        subset = unit[pick]
        sim = subset @ subset.T
        off_diag = sim - torch.diag(torch.diag(sim))
        return off_diag.pow(2).mean()

    def expert_agreement_loss(self, emb):
        """``|mean cross-expert std of anchor cosines - 0.05|``.

        Uses the first 512 anchors (a fixed slice, not a random sample) to
        keep the computation cheap.
        """
        unit_anchors = F.normalize(self.anchors[:512], dim=-1)
        cos_stack = torch.stack(
            [view @ unit_anchors.T for view in self._expert_views(emb)], dim=-1)
        disagreement = cos_stack.std(dim=-1)
        return (disagreement.mean() - 0.05).abs()
# ══════════════════════════════════════════════════════════════════
# MULTI-DEPTH PATCHWORK
# ══════════════════════════════════════════════════════════════════
class DepthLevel(nn.Module):
    """Single depth level of the patchwork — reads a specific granularity.

    Anchors are dealt round-robin into ``n_comp`` compartments (anchor ``a``
    belongs to compartment ``a % n_comp``). Each compartment has its own
    small MLP that maps its anchors' per-expert distances to ``d_comp``
    features; the outputs are concatenated in compartment order.

    Args:
        n_anchors: total number of anchors; must be a multiple of ``n_comp``.
        n_comp: number of compartments.
        n_experts: number of expert channels per anchor.
        d_comp: output width of every compartment MLP.

    Raises:
        ValueError: if ``n_comp`` is not positive or does not divide
            ``n_anchors`` (the per-compartment input width would not match).
    """

    def __init__(self, n_anchors, n_comp, n_experts, d_comp):
        super().__init__()
        if n_comp <= 0 or n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be a positive multiple of "
                f"n_comp ({n_comp})")
        self.n_comp = n_comp
        self.n_experts = n_experts
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        # Anchor indices grouped by compartment (ascending inside each group).
        # Non-persistent so the state_dict layout stays exactly as before.
        self.register_buffer(
            "comp_order", torch.sort(asgn, stable=True).indices, persistent=False)
        inputs_per_comp = (n_anchors // n_comp) * n_experts
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(inputs_per_comp, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for _ in range(n_comp)])

    def forward(self, tri_3d):
        """tri_3d: (B, n_anchors, n_experts) → (B, n_comp * d_comp)."""
        B = tri_3d.shape[0]
        # One gather instead of a boolean-mask selection per compartment:
        # reorder anchors by compartment, then flatten (anchor, expert) per group.
        grouped = tri_3d[:, self.comp_order, :].reshape(B, self.n_comp, -1)
        results = [comp(grouped[:, k]) for k, comp in enumerate(self.comps)]
        return torch.cat(results, dim=-1)
class MultiDepthPatchwork(nn.Module):
    """
    Reads the sphere at 3 resolutions and fuses them:
      coarse: COARSE_COMP compartments (16 × 128 anchors) — global position
      fine:   FINE_COMP compartments   (64 × 32 anchors)  — local neighborhood
      micro:  MICRO_COMP compartments  (128 × 16 anchors) — sub-anchor structure
    The three level outputs are concatenated and projected to D_PW_PROJ.
    """

    def __init__(self):
        super().__init__()
        self.coarse = DepthLevel(N_ANCHORS, COARSE_COMP, N_EXPERTS_COUNT, D_COARSE)
        self.fine = DepthLevel(N_ANCHORS, FINE_COMP, N_EXPERTS_COUNT, D_FINE)
        self.micro = DepthLevel(N_ANCHORS, MICRO_COMP, N_EXPERTS_COUNT, D_MICRO)
        # Concatenated width: 16×128 + 64×64 + 128×32 = 2048 + 4096 + 4096.
        level_shapes = (
            (COARSE_COMP, D_COARSE),
            (FINE_COMP, D_FINE),
            (MICRO_COMP, D_MICRO),
        )
        concat_width = sum(n_comp * width for n_comp, width in level_shapes)
        self.proj = nn.Sequential(
            nn.Linear(concat_width, D_PW_PROJ),
            nn.GELU(),
            nn.LayerNorm(D_PW_PROJ),
        )

    def forward(self, tri_3d):
        """tri_3d: (B, N_ANCHORS, N_EXPERTS_COUNT) → (B, D_PW_PROJ)."""
        per_level = [level(tri_3d) for level in (self.coarse, self.fine, self.micro)]
        return self.proj(torch.cat(per_level, dim=-1))
# ══════════════════════════════════════════════════════════════════
# EXPERT PROJECTOR + MODEL
# ══════════════════════════════════════════════════════════════════
class ExpertProjector(nn.Module):
    """Linear + LayerNorm map from D_EXPERT to D_ANCHOR, followed by L2 normalisation."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(D_EXPERT, D_ANCHOR), nn.LayerNorm(D_ANCHOR)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and return unit-length rows."""
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
class MassiveSoup(nn.Module):
    """Full model: per-expert projectors → fused embedding → constellation
    triangulation → multi-depth patchwork → multi-label classifier."""

    def __init__(self):
        super().__init__()
        self.projectors = nn.ModuleList(
            [ExpertProjector() for _ in range(N_EXPERTS_COUNT)])
        self.constellation = FusedConstellation()
        self.patchwork = MultiDepthPatchwork()
        # The classifier sees the patchwork features concatenated with the
        # fused embedding itself.
        self.classifier = nn.Sequential(
            nn.Linear(D_PW_PROJ + D_ANCHOR, D_PW_PROJ),
            nn.GELU(),
            nn.LayerNorm(D_PW_PROJ),
            nn.Dropout(0.1),
            nn.Linear(D_PW_PROJ, N_CLASSES),
        )

    def forward(self, expert_features, apply_autograd=True):
        """Run the model on one batch.

        Args:
            expert_features: indexable holding one feature tensor per
                expert, in projector order.
            apply_autograd: when true and the module is in training mode,
                route the fused embedding through ``EmbeddingAutograd``
                (tang=0.01, sep=1.0) so its gradient is reshaped.

        Returns:
            ``(logits, fused, tri_fused, nearest, projected)``.
        """
        projected = [self.projectors[k](expert_features[k])
                     for k in range(N_EXPERTS_COUNT)]
        mean_embedding = sum(projected) / N_EXPERTS_COUNT
        fused = F.normalize(mean_embedding, dim=-1)

        if apply_autograd and self.training:
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)

        tri_fused, nearest, tri_3d = self.constellation.triangulate(
            fused, training=self.training)
        patch_features = self.patchwork(tri_3d)
        logits = self.classifier(torch.cat([patch_features, fused], dim=-1))
        return logits, fused, tri_fused, nearest, projected
# ══════════════════════════════════════════════════════════════════
# LOAD DATA + GPA + CALIBRATE
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 0: LOAD DATA")
print(f"{'='*65}")
from datasets import load_dataset


def _multi_hot_labels(label_lists, n_rows, n_classes):
    """Build an (n_rows, n_classes) 0/1 matrix from per-image label-id lists.

    Ids outside ``[0, n_classes)`` are ignored; a negative id would
    otherwise wrap around and mark the wrong class.
    """
    out = torch.zeros(n_rows, n_classes)
    for row, labs in enumerate(label_lists):
        for l in labs:
            if 0 <= l < n_classes:
                out[row, l] = 1.0
    return out


def _load_expert_features(name, split, id_map, n_rows):
    """Load one expert's features for ``split`` into the row order of ``id_map``.

    Rows whose image id never appears in the expert's dataset stay zero;
    a warning reports how many, instead of leaving them silently empty.
    """
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split)
    feats = torch.zeros(n_rows, D_EXPERT)
    filled = torch.zeros(n_rows, dtype=torch.bool)
    for row in ds:
        pos = id_map.get(row["image_id"])
        if pos is not None:
            feats[pos] = torch.tensor(row["features"], dtype=torch.float32)
            filled[pos] = True
    n_missing = int((~filled).sum().item())
    if n_missing:
        print(f" WARNING: {name} [{split}] has no features for "
              f"{n_missing:,} of {n_rows:,} images (rows left as zeros)")
    return feats


# The first expert's dataset defines the image order and the labels.
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="train")
train_ids = ref["image_id"]; N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_labels = _multi_hot_labels(ref["labels"], N_train, N_CLASSES)
ref_val = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref_val["image_id"]; N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels = _multi_hot_labels(ref_val["labels"], N_val, N_CLASSES)
print(f" Train: {N_train:,} Val: {N_val:,}")

# Per-expert feature matrices, row-aligned with the reference ids above.
train_raw, val_raw = {}, {}
for name in EXPERTS:
    train_raw[name] = _load_expert_features(name, "train", train_id_map, N_train)
    val_raw[name] = _load_expert_features(name, "val", val_id_map, N_val)
    print(f" {name:<30} loaded")
    # The dataset objects are local to the loader; reclaim them promptly.
    gc.collect()
# GPA
print(f"\n{'='*65}")
print("PHASE 1: GPA + PCA + PROCRUSTES")
print(f"{'='*65}")
# Generalised Procrustes: repeatedly align every expert to the current mean
# shape until the experts stop moving (or 20 iterations have run).
current = {name: train_raw[name].float() for name in EXPERTS}
for gpa_iter in range(20):
    mean_shape = sum(current[n] for n in EXPERTS) / len(EXPERTS)
    # One alignment per expert per iteration. An earlier draft pass aligned
    # each expert a second time and computed a delta against the raw
    # features that was immediately discarded; it has been removed so this
    # loop matches the validation loop.
    new_current = {}
    delta = 0.0
    for name in EXPERTS:
        info = procrustes_align(current[name], mean_shape)
        new_current[name] = apply_align(current[name], info)
        # Mean squared movement of this expert in the current iteration.
        delta += (new_current[name] - current[name]).pow(2).mean().item()
    current = new_current
    if gpa_iter == 0 or (gpa_iter+1) % 5 == 0:
        print(f" GPA iter {gpa_iter+1}: delta={delta:.8f}")
    if delta < 1e-8: break
# Consensus = row-normalised mean of the aligned experts.
consensus_768 = F.normalize(sum(current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
for name in EXPERTS:
    c = F.cosine_similarity(consensus_768[:5000], current[name][:5000], dim=-1).mean().item()
    print(f" cos(consensus, {name}): {c:.4f}")
# PCA → 256-d
# Centre the consensus so the SVD below yields principal directions.
cc = consensus_768 - consensus_768.mean(0, keepdim=True)
# The principal axes are estimated from the first 10,000 rows only.
U, S, Vt = torch.linalg.svd(cc[:10000], full_matrices=False)
# Keep the leading D_ANCHOR right-singular vectors (one per row).
pca_proj = Vt[:D_ANCHOR]
# NOTE(review): the projection is applied to the uncentred consensus_768,
# not to cc — confirm this is intended.
consensus_d = F.normalize(consensus_768 @ pca_proj.T, dim=-1)
# Share of squared singular values kept, measured on the 10,000-row subset.
var_ret = S[:D_ANCHOR].pow(2).sum() / S.pow(2).sum()
print(f" PCA 768→{D_ANCHOR}: var_retained={var_ret.item():.4f}")
# Volume coefficient of variation of the consensus (first 5,000 rows);
# reused later as the CV target during training.
consensus_cv = cv_metric(consensus_d[:5000].to(DEVICE))
print(f" Consensus CV at {D_ANCHOR}-d: {consensus_cv:.4f}")
# Val consensus
# Same iterative alignment as for train, run independently on the
# validation features (at most 20 iterations, early exit once converged).
# NOTE(review): the validation experts are aligned by their own GPA run,
# not with the transforms fitted on train — confirm the two consensus
# frames are comparable before reading val cos / R@1 as exact.
val_current = {name: val_raw[name].float() for name in EXPERTS}
for _ in range(20):
    # Mean shape is fixed for the whole iteration.
    vm = sum(val_current[n] for n in EXPERTS) / len(EXPERTS)
    d = 0.0
    for name in EXPERTS:
        info = procrustes_align(val_current[name], vm)
        new = apply_align(val_current[name], info)
        # Accumulate the mean squared movement of each expert.
        d += (new - val_current[name]).pow(2).mean().item()
        val_current[name] = new
    if d < 1e-8: break
val_consensus_768 = F.normalize(sum(val_current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
# Project with the PCA basis fitted on the train consensus.
val_consensus_d = F.normalize(val_consensus_768 @ pca_proj.T, dim=-1)
# Per-expert Procrustes
# For every expert, fit a whitened rotation from its raw features to the
# reduced consensus and fold it into an affine map (W, b) that initialises
# that expert's projector.
expert_calibrations = {}
for name in EXPERTS:
    raw = train_raw[name][:10000].float()
    tgt = consensus_d[:10000].float()
    # Covariance divisor follows the number of rows actually available
    # (was hard-coded to 9999, which is only right with >= 10,000 rows).
    dof = max(raw.shape[0] - 1, 1)
    sm = raw.mean(0, keepdim=True); tm = tgt.mean(0, keepdim=True)
    sc = raw - sm; tc = tgt - tm
    sw = symmetric_inv_sqrt((sc.T @ sc) / dof)
    tw = symmetric_inv_sqrt((tc.T @ tc) / dof)
    src_w = F.normalize(sc @ sw, dim=-1); tgt_w = F.normalize(tc @ tw, dim=-1)
    # Cross-covariance between the whitened target and source rows.
    M = tgt_w.T @ src_w
    U_r, S_r, Vt_r = torch.linalg.svd(M, full_matrices=False)
    R = U_r @ Vt_r
    # Fold centring, whitening and rotation into y = x @ W.T + b.
    proj_W = (sw @ R.T).T; proj_b = -(sm.squeeze(0) @ sw @ R.T).squeeze(0)
    # Quality check on the first 1,000 rows.
    test = F.normalize(raw[:1000] @ proj_W.T + proj_b, dim=-1)
    cos = F.cosine_similarity(test, tgt[:1000], dim=-1).mean().item()
    # NOTE(review): "R" keeps only the leading D_ANCHOR × D_ANCHOR corner of
    # R, and "whiten"/"mean" are computed from the consensus alone, so they
    # are the same for every expert — confirm this is the intended
    # initialisation for the constellation's expert perspectives.
    expert_calibrations[name] = {"W": proj_W, "b": proj_b, "cos": cos,
                                 "R": R[:D_ANCHOR, :D_ANCHOR],
                                 "whiten": tw, "mean": tm.squeeze(0)}
    print(f" {name:<30} cos={cos:.4f}")
# ══════════════════════════════════════════════════════════════════
# BUILD + INITIALIZE
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 2: BUILD MODEL")
print(f"{'='*65}")
model = MassiveSoup().to(DEVICE)
with torch.no_grad():
    # Projectors: start each expert's linear layer from its Procrustes fit.
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.projectors[i].proj[0].weight.copy_(cal["W"].to(DEVICE))
        model.projectors[i].proj[0].bias.copy_(cal["b"].to(DEVICE))
    print(f" ✓ Projectors from Procrustes")
    # Ortho anchors
    ortho_anchors = init_ortho_anchors(D_ANCHOR, N_ORTHO_BASES)
    model.constellation.anchors.copy_(ortho_anchors.to(DEVICE))
    print(f" ✓ {N_ANCHORS} ortho-spectrum anchors ({N_ORTHO_BASES} bases × {D_ANCHOR})")
    # Expert perspectives
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        model.constellation.expert_rotations[i].copy_(cal["R"].to(DEVICE))
        model.constellation.expert_whiteners[i].copy_(cal["whiten"].to(DEVICE))
        model.constellation.expert_means[i].copy_(cal["mean"].to(DEVICE))
    print(f" ✓ Expert perspectives calibrated")
# Verify
# Run the sanity check in eval mode: in train mode the constellation drops a
# random subset of anchors and the classifier applies dropout, which would
# distort the reported number of active anchors.
was_training = model.training
model.eval()
with torch.no_grad():
    test_in = [train_raw[name][:200].to(DEVICE) for name in EXPERTS]
    _, test_fused, _, test_nearest, _ = model(test_in, apply_autograd=False)
    test_tgt = consensus_d[:200].to(DEVICE)
    init_cos = F.cosine_similarity(test_fused, test_tgt, dim=-1).mean().item()
    n_active = test_nearest.unique().numel()
# Put the module back into whatever mode it was in before the check.
model.train(was_training)
print(f" Init: cos={init_cos:.4f} active_anchors={n_active}/{N_ANCHORS}")
# Count params
def count_params(module):
    """Return the total number of scalar entries over ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
# Parameter totals: whole model first (reused in the final summary), then
# a per-component breakdown.
n_total = count_params(model)
print(f"\n Parameters:")
print(f" projectors: {sum(count_params(p) for p in model.projectors):>12,}")
print(f" constellation: {count_params(model.constellation):>12,}")
print(f" patchwork: {count_params(model.patchwork):>12,}")
print(f" classifier: {count_params(model.classifier):>12,}")
print(f" total: {n_total:>12,}")
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 3: TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Queue: {QUEUE_SIZE} | Anchor dropout: {ANCHOR_DROP}")
print(f" CV target: {consensus_cv:.4f}")
print(f"{'='*65}")
# Targets and train labels are moved to the device once; the raw expert
# features are moved batch by batch inside the loop.
train_targets = consensus_d.to(DEVICE)
val_targets = val_consensus_d.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/massive_soup")
# best_mAP tracks the best validation score so far; gs is the global step.
best_mAP = 0.0; gs = 0
# Contrastive queues start empty (zero rows) and grow up to QUEUE_SIZE rows.
queue_e = torch.zeros(0, D_ANCHOR, device=DEVICE)
queue_t = torch.zeros(0, D_ANCHOR, device=DEVICE)
# CPU copy of the validation targets, reused by every epoch's retrieval metrics.
val_targets_cpu = val_targets.cpu()
for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train)
    # Running sums for this epoch; "n" counts the batches actually trained on.
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0,
           "spread": 0, "agree": 0, "align": 0, "nce_acc": 0, "n": 0}
    pbar = tqdm(range(0, N_train, BATCH), desc=f"E{epoch+1:2d}/{EPOCHS}", unit="batch")
    for i in pbar:
        idx = perm[i:i+BATCH]
        if len(idx) < 4: continue
        # One feature tensor per expert, in EXPERTS order.
        batch = [train_raw[name][idx].to(DEVICE) for name in EXPERTS]
        labels = train_labels_gpu[idx]
        targets = train_targets[idx]
        logits, fused, tri, nearest, projected = model(batch)
        l_nce, nce_acc = infonce_queued(fused, targets, queue_e, queue_t)
        # Push this batch into the queues (detached) and keep the newest rows.
        with torch.no_grad():
            queue_e = torch.cat([queue_e, fused.detach()], 0)[-QUEUE_SIZE:]
            queue_t = torch.cat([queue_t, targets.detach()], 0)[-QUEUE_SIZE:]
        l_mse = F.mse_loss(fused, targets)
        l_bce = F.binary_cross_entropy_with_logits(logits, labels)
        l_align = whitened_procrustes_loss(fused, targets)
        l_cv = cv_loss(fused, target=consensus_cv)
        l_spread = model.constellation.anchor_spread_loss()
        l_agree = model.constellation.expert_agreement_loss(fused)
        loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                + 0.5 * l_align + 0.001 * l_cv
                + 1e-3 * l_spread + 0.1 * l_agree)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)
        acc["loss"] += loss.item(); acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item(); acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item(); acc["spread"] += l_spread.item()
        acc["agree"] += l_agree.item() if torch.is_tensor(l_agree) else l_agree
        acc["align"] += l_align.item(); acc["nce_acc"] += nce_acc
        acc["n"] += 1; gs += 1
        if gs % 50 == 0:
            # Step-level curves log the running epoch average, not the raw batch value.
            for k in ["loss", "nce", "bce", "nce_acc"]:
                writer.add_scalar(f"step/{k}", acc[k]/max(acc["n"],1), gs)
        if acc["n"] % 20 == 0:
            d = acc["n"]
            # Extra keyword arguments to set_postfix are rendered as fields,
            # so only real metrics are passed.
            pbar.set_postfix(loss=f"{acc['loss']/d:.4f}",
                             nce_acc=f"{acc['nce_acc']/d:.3f}",
                             cos=f"{1-acc['align']/d:.3f}")
    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"bce={acc['bce']/d:.4f} agree={acc['agree']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")
    # Validation
    model.eval()
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j + BATCH, N_val)
            batch_v = [val_raw[name][j:end].to(DEVICE) for name in EXPERTS]
            lo, em, _, _, _ = model(batch_v, apply_autograd=False)
            all_lo.append(lo.cpu()); all_em.append(em.cpu())
        v_lo = torch.cat(all_lo); v_em = torch.cat(all_em)
        v_lab = val_labels
        # Mean average precision over the classes that occur in the val labels.
        ap_sum, nv = 0, 0
        for c in range(N_CLASSES):
            if v_lab[:, c].sum() > 0:
                si = v_lo[:, c].argsort(descending=True); st = v_lab[:, c][si]
                pak = st.cumsum(0) / torch.arange(1, len(st)+1).float()
                ap_sum += (pak * st).sum().item() / st.sum().item(); nv += 1
        mAP = ap_sum / max(nv, 1)
        # Per-class precision / recall / F1 at a 0.5 probability threshold.
        vp = (v_lo.sigmoid() > 0.5).float()
        tp = (vp * v_lab).sum(0); fp = (vp * (1-v_lab)).sum(0); fn = ((1-vp) * v_lab).sum(0)
        pr_ = tp/(tp+fp+1e-8); rc_ = tp/(tp+fn+1e-8); f1_ = 2*pr_*rc_/(pr_+rc_+1e-8)
        # Average F1 over classes with a positive score; report 0.0 rather
        # than NaN when no class has one.
        f1_pos = f1_[f1_ > 0]
        macro_f1 = f1_pos.mean().item() if f1_pos.numel() > 0 else 0.0
        v_cos = F.cosine_similarity(v_em, val_targets_cpu, dim=-1).mean().item()
        # Retrieval: is each embedding's best-matching target its own row?
        sim = v_em @ val_targets_cpu.T
        r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()
        _, v_nearest, _ = model.constellation.triangulate(v_em.to(DEVICE), training=False)
        n_active = v_nearest.cpu().unique().numel()
        v_cv = cv_metric(v_em[:2000].to(DEVICE))
    for k in acc:
        if k != "n": writer.add_scalar(f"epoch/{k}", acc[k]/d, epoch+1)
    writer.add_scalar("val/mAP", mAP, epoch+1)
    writer.add_scalar("val/cos", v_cos, epoch+1)
    writer.add_scalar("val/R@1", r1, epoch+1)
    writer.add_scalar("val/anchors", n_active, epoch+1)
    writer.add_scalar("val/cv", v_cv, epoch+1)
    mk = ""
    if mAP > best_mAP:
        # New best: save weights plus the config needed to rebuild the model.
        best_mAP = mAP
        torch.save({"state_dict": model.state_dict(),
                    "config": {"d_anchor": D_ANCHOR, "n_anchors": N_ANCHORS,
                               "n_ortho_bases": N_ORTHO_BASES,
                               "n_experts": N_EXPERTS_COUNT,
                               "coarse_comp": COARSE_COMP, "fine_comp": FINE_COMP,
                               "micro_comp": MICRO_COMP,
                               "d_coarse": D_COARSE, "d_fine": D_FINE,
                               "d_micro": D_MICRO, "d_pw_proj": D_PW_PROJ,
                               "anchor_drop": ANCHOR_DROP, "experts": EXPERTS,
                               "cv_target": consensus_cv},
                    "pca_proj": pca_proj, "consensus_cv": consensus_cv,
                    "mAP": mAP, "r1": r1, "cos": v_cos, "cv": v_cv,
                    "epoch": epoch+1, "n_active": n_active},
                   "checkpoints/massive_soup_best.pt")
        mk = " ★"
    # Per-epoch checkpoint with optimizer state.
    torch.save({"state_dict": model.state_dict(), "epoch": epoch+1,
                "mAP": mAP, "optimizer": optimizer.state_dict(), "gs": gs},
               f"checkpoints/massive_soup_e{epoch+1:02d}.pt")
    print(f" E{epoch+1} val: mAP={mAP:.3f} F1={macro_f1:.3f} "
          f"R@1={r1:.3f} cos={v_cos:.3f} cv={v_cv:.4f} "
          f"anchors={n_active}/{N_ANCHORS}{mk}")
writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Total: {n_total:,} params")
print(f"\n{'='*65}\nDONE\n{'='*65}")