geolip-vit-base-x3 / run_2_soup_trainer.py
AbstractPhil's picture
Create run_2_soup_trainer.py
38ec99d verified
Raw
History Blame
23.2 kB
#!/usr/bin/env python3
"""
BASE TIER PATCHWORK SOUP β€” PROPERLY CALIBRATED
================================================
3 experts, all 768-d:
clip_l14_openai, dinov2_b14, siglip_b16_384
Pipeline (from CaptionBERT research):
1. GPA alignment at 768-d β†’ consensus
2. Measure consensus CV β†’ CV loss target
3. Per-expert whitened Procrustes calibration
4. Initialize projectors from Procrustes rotations
5. Train: projectors + constellation + patchwork + classifier
against consensus targets with calibrated CV
Architecture:
Per-expert: 768 β†’ 128 (Procrustes-initialized projection)
Constellation: 256 anchors Γ— 128-d (geometric autograd)
Patchwork: 8 compartments
Classifier: patchwork + fused β†’ 80 classes
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc
# ----------------------------------------------------------------------
# Run configuration
# ----------------------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Geometry of the model.
D_EXPERT = 768    # input width of every expert projector
D_ANCHOR = 128    # width of the shared space the projectors map into
N_ANCHORS = 256   # number of constellation anchors
N_CLASSES = 80    # width of the classifier output / label matrix
N_COMP = 8        # number of patchwork compartments
D_COMP = 64       # output width of each compartment

# Optimisation schedule.
BATCH = 128
EPOCHS = 20
LR = 1e-3

# Feature-set names. Their order fixes which projector index each expert
# is wired to later in the script.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

# Start-up banner. The literals are written with the intended Unicode
# characters (the previous text was mis-decoded UTF-8), and the expert count
# is read from EXPERTS instead of being typed into the string.
print("=" * 65)
print("BASE TIER PATCHWORK SOUP — CALIBRATED")
print(f" {len(EXPERTS)} experts × {D_EXPERT}-d → {N_ANCHORS} anchors × {D_ANCHOR}-d")
print(f" Device: {DEVICE}")
print("=" * 65)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Return the squared simplex volume via the Cayley-Menger determinant.

    Args:
        pts: tensor indexed as (batch, vertex, coordinate); the pairwise
            distance tensor built from it is unpacked into three axes, so
            the input must be 3-D.

    Returns:
        A float32 tensor with one value per batch entry. The value can be
        slightly negative for degenerate simplices because of round-off.
    """
    coords = pts.float()
    # Pairwise squared Euclidean distances between the vertices.
    delta = coords.unsqueeze(-2) - coords.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)
    n_batch, n_vert, _ = sq_dist.shape

    # Bordered distance matrix: zero corner, ones on the first row/column,
    # squared distances in the remaining block.
    bordered = torch.zeros(
        n_batch, n_vert + 1, n_vert + 1,
        device=sq_dist.device, dtype=torch.float32,
    )
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist

    # Normalisation: (-1)^V / (2^(V-1) * ((V-1)!)^2) for V vertices.
    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    scale = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
def cv_loss(emb, target=0.2, n_samples=16):
    """Penalise deviation of the simplex-volume coefficient of variation.

    Draws ``n_samples`` random 5-row subsets of ``emb``, measures the volume
    of each subset with ``cayley_menger_vol2`` and returns
    ``|std / mean - target|`` over those volumes.

    Args:
        emb: 2-D tensor of embeddings, one row per sample.
        target: coefficient of variation to aim for.
        n_samples: number of random 5-point subsets to draw.

    Returns:
        A scalar tensor. With fewer than five rows it is a constant zero on
        ``emb``'s device.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        # One random 5-point simplex; the clamp and epsilon keep sqrt finite.
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
@torch.no_grad()
def cv_metric(emb, n_samples=500):
    """Measure the coefficient of variation of random 5-point simplex volumes.

    Non-differentiable counterpart of ``cv_loss``: volumes are pulled out as
    Python floats and the statistic is computed with NumPy.

    Args:
        emb: 2-D tensor of embeddings, one row per sample.
        n_samples: number of random 5-row subsets to measure.

    Returns:
        ``std / mean`` of the sampled volumes as a float, or ``0.0`` when
        there are fewer than five rows or fewer than ten usable volumes.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    kept = []
    for _ in range(n_samples):
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()
        # The 1e-12 floor keeps every finite volume strictly positive, so
        # this comparison only rejects NaN results.
        if vol > 0:
            kept.append(vol)

    if len(kept) < 10:
        return 0.0
    volumes = np.array(kept)
    spread = volumes.std()
    centre = volumes.mean() + 1e-8
    return float(spread / centre)
def anchor_spread_loss(anchors):
    """Mean squared off-diagonal cosine similarity between anchors.

    The diagonal of the similarity matrix is zeroed but still counted in the
    mean, so the result is the sum of squared off-diagonal cosines divided
    by ``n_anchors ** 2``.
    """
    unit = F.normalize(anchors, dim=-1)
    gram = unit @ unit.T
    off_diagonal = gram - torch.diag(torch.diag(gram))
    return off_diagonal.pow(2).mean()
def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy of the softmax assignment of embeddings to anchors.

    Args:
        emb: embeddings, one row per sample (used as given, not re-normalised).
        anchors: anchor matrix; rows are L2-normalised before the dot product.
        sharpness: multiplier applied to the similarities before the softmax.

    Returns:
        Scalar tensor: the entropy of each row's assignment distribution,
        averaged over rows.
    """
    unit_anchors = F.normalize(anchors, dim=-1)
    scaled_sim = emb @ unit_anchors.T * sharpness
    assign = F.softmax(scaled_sim, dim=-1)
    # The 1e-12 inside the log guards against log(0) for saturated rows.
    neg_entropy = (assign * (assign + 1e-12).log()).sum(-1)
    return -neg_entropy.mean()
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between two row-aligned batches.

    Row ``i`` of ``a`` is treated as the positive for row ``i`` of ``b``.

    Args:
        a, b: tensors with matching row counts; both are L2-normalised here.
        temperature: divisor applied to the cosine similarities.

    Returns:
        ``(loss, acc)``: the mean of the a-to-b and b-to-a cross-entropies,
        and the a-to-b top-1 retrieval accuracy as a Python float.
    """
    left = F.normalize(a, dim=-1)
    right = F.normalize(b, dim=-1)
    logits = (left @ right.T) / temperature
    diag = torch.arange(logits.shape[0], device=logits.device)

    a_to_b = F.cross_entropy(logits, diag)
    b_to_a = F.cross_entropy(logits.T, diag)
    loss = (a_to_b + b_to_a) / 2

    with torch.no_grad():
        hits = logits.argmax(-1) == diag
        acc = hits.float().mean().item()
    return loss, acc
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass that rewrites the gradient on the way back.

    ``forward`` returns ``x`` unchanged. ``backward`` edits the incoming
    gradient in two steps:

    1. The component parallel to the normalised ``embedding`` (the radial
       part) is scaled by ``1 - tang``; the rest is kept as is.
    2. When ``sep > 0``, any positive component along the nearest
       normalised anchor is reduced by ``sep`` times that component.

    Only ``x`` receives a gradient; the other four inputs get ``None``.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.tang = tang
        ctx.sep = sep
        ctx.save_for_backward(embedding, anchors)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        grad = grad_output.float()

        # Split into radial (along the embedding) and tangential parts.
        radial = (grad * unit_emb).sum(-1, keepdim=True) * unit_emb
        tangential = grad - radial
        shaped = tangential + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            # Nearest anchor by cosine similarity for every row.
            closest = unit_anchors[(unit_emb @ unit_anchors.T).argmax(dim=-1)]
            along = (shaped * closest).sum(-1, keepdim=True)
            shaped = shaped - ctx.sep * (along > 0).float() * along * closest

        return shaped.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# PROCRUSTES UTILITIES (from cotrain_bank.py)
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return the inverse matrix square root of a symmetric matrix.

    Eigenvalues are clamped to at least ``eps`` before the reciprocal square
    root, so rank-deficient inputs give a finite result.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = torch.clamp(eigvals, min=eps).rsqrt()
    return eigvecs @ torch.diag(inv_root) @ eigvecs.T
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal Procrustes map from ``source`` to ``target``.

    Only the first ``min(n_align, len(source), len(target))`` rows are used.
    Both sets are centred, whitened with their own covariance, row-normalised,
    and the rotation comes from the SVD of their cross-correlation.

    Returns:
        dict with keys ``rotation``, ``source_mean``, ``source_whitener``,
        ``target_unwhitener`` (pseudo-inverse of the target whitener) and
        ``cos_after`` (mean cosine between rotated source and target in the
        whitened space, as a float).
    """
    n_used = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_used].float()
    tgt = target[:n_used].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean

    # Unbiased covariance; the max() avoids dividing by zero for one row.
    denom = max(src_c.shape[0] - 1, 1)
    src_whiten = symmetric_inv_sqrt((src_c.T @ src_c) / denom)
    tgt_whiten = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / denom)

    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    left, _, right_t = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = left @ right_t
    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    return {
        "rotation": rotation,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_after": cos_after,
    }
def apply_align(emb, a):
    """Apply an alignment dict (as produced by ``procrustes_align``) to ``emb``.

    Steps, in order: subtract ``source_mean``, multiply by ``source_whitener``,
    rotate by ``rotation`` transposed, multiply by ``target_unwhitener``.
    The target mean is not added back.
    """
    centred = emb.float() - a["source_mean"]
    whitened = centred @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════
class ExpertProjector(nn.Module):
    """Map one expert's embedding to a unit vector in the shared space.

    ``proj`` is ``Linear(d_in, d_out)`` followed by ``LayerNorm(d_out)``;
    the output of ``forward`` is L2-normalised along the last axis.
    """

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        linear = nn.Linear(d_in, d_out)
        norm = nn.LayerNorm(d_out)
        # Index 0 of this Sequential is the linear layer.
        self.proj = nn.Sequential(linear, norm)

    def forward(self, x):
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
class Constellation(nn.Module):
    """Learnable set of anchor vectors, initialised as random unit vectors."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        seeds = torch.randn(n_anchors, d)
        self.anchors = nn.Parameter(F.normalize(seeds, dim=-1))

    def triangulate(self, emb):
        """Return ``(1 - cosine, nearest)`` of ``emb`` against the anchors.

        The anchors are re-normalised on every call; ``emb`` is used as
        given. ``nearest`` is the index of the highest-cosine anchor per row.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        cosine = emb @ unit_anchors.T
        nearest = cosine.argmax(dim=-1)
        return 1.0 - cosine, nearest
class Patchwork(nn.Module):
    """Split anchor distances into compartments and encode each with an MLP.

    Anchor ``i`` belongs to compartment ``i % n_comp``. Each compartment has
    its own ``Linear -> GELU -> Linear -> LayerNorm`` stack with output width
    ``d_comp``; ``forward`` concatenates the compartment outputs.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)

        stacks = []
        for k in range(n_comp):
            # Input width equals the number of anchors in compartment k.
            width = (asgn == k).sum().item()
            stacks.append(nn.Sequential(
                nn.Linear(width, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(stacks)

    def forward(self, tri):
        encoded = []
        for k in range(self.n_comp):
            members = self.asgn == k
            encoded.append(self.comps[k](tri[:, members]))
        return torch.cat(encoded, -1)
class BaseTierSoup(nn.Module):
    """Fuse several expert embeddings and classify via anchor triangulation.

    Pipeline: per-expert projector -> mean + L2-normalise -> (training only)
    gradient shaping through ``EmbeddingAutograd`` -> distances to the
    constellation anchors -> patchwork encoding -> classifier over the
    concatenated patchwork features and fused embedding.

    Args:
        n_experts: number of expert projectors to build. Defaults to 3, the
            value that used to be hard-coded.
    """

    def __init__(self, n_experts=3):
        super().__init__()
        self.n_experts = n_experts
        self.projectors = nn.ModuleList(
            [ExpertProjector() for _ in range(n_experts)])
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, N_CLASSES))

    def forward(self, expert_embeddings, apply_autograd=True):
        """Run the full model.

        Args:
            expert_embeddings: sequence with one tensor per expert, in the
                same order as ``self.projectors``.
            apply_autograd: when true and the module is in training mode,
                route the fused embedding through ``EmbeddingAutograd``.

        Returns:
            ``(logits, fused, tri, nearest, projected)``.

        Raises:
            ValueError: if the number of inputs differs from ``n_experts``.
        """
        # A wrong-length input used to be silently truncated (extra entries)
        # or fail with a bare IndexError (missing entries).
        if len(expert_embeddings) != self.n_experts:
            raise ValueError(
                f"expected {self.n_experts} expert embeddings, "
                f"got {len(expert_embeddings)}")

        projected = [proj(x) for proj, x in zip(self.projectors, expert_embeddings)]
        fused = F.normalize(sum(projected) / self.n_experts, dim=-1)
        if apply_autograd and self.training:
            # tang=0.01, sep=1.0 are the gradient-shaping strengths.
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, 0.01, 1.0)
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, fused], -1))
        return logits, fused, tri, nearest, projected
# ══════════════════════════════════════════════════════════════════
# PHASE 0: LOAD DATA
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 0: LOAD DATA")
print(f"{'='*65}")
from datasets import load_dataset

_FEATURE_REPO = "AbstractPhil/bulk-coco-features"


def _multi_hot(label_lists, n_rows, n_classes):
    """Build an ``(n_rows, n_classes)`` 0/1 matrix from per-row class-id lists.

    Ids outside ``[0, n_classes)`` are skipped. The lower bound stops a
    negative id from wrapping around to a column at the end.
    """
    matrix = torch.zeros(n_rows, n_classes)
    for row, labs in enumerate(label_lists):
        for lab in labs:
            if 0 <= lab < n_classes:
                matrix[row, lab] = 1.0
    return matrix


def _gather_features(dataset, id_map, n_rows, dim, tag):
    """Copy each row's ``features`` into the slot given by its ``image_id``.

    Rows whose id is not in ``id_map`` are ignored. Slots that never receive
    a feature vector stay all-zero; their count is reported, because such
    rows would otherwise flow unnoticed into the alignment statistics.
    """
    feats = torch.zeros(n_rows, dim)
    filled = torch.zeros(n_rows, dtype=torch.bool)
    for row in dataset:
        slot = id_map.get(row["image_id"])
        if slot is not None:
            feats[slot] = torch.tensor(row["features"], dtype=torch.float32)
            filled[slot] = True
    n_missing = n_rows - int(filled.sum())
    if n_missing:
        print(f" WARNING: {tag}: {n_missing:,} of {n_rows:,} rows have no "
              f"features and were left as zeros", flush=True)
    return feats


# The first expert's splits define the row order and carry the labels.
ref = load_dataset(_FEATURE_REPO, EXPERTS[0], split="train")
train_ids = ref["image_id"]
N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_label_matrix = _multi_hot(ref["labels"], N_train, N_CLASSES)

ref_val = load_dataset(_FEATURE_REPO, EXPERTS[0], split="val")
val_ids = ref_val["image_id"]
N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_label_matrix = _multi_hot(ref_val["labels"], N_val, N_CLASSES)
print(f" Train: {N_train:,} Val: {N_val:,}")

train_raw = {}
val_raw = {}
for name in EXPERTS:
    # The reference expert is already in memory; do not load it a second time.
    reuse = name == EXPERTS[0]
    ds = ref if reuse else load_dataset(_FEATURE_REPO, name, split="train")
    train_raw[name] = _gather_features(
        ds, train_id_map, N_train, D_EXPERT, f"{name}/train")
    ds_v = ref_val if reuse else load_dataset(_FEATURE_REPO, name, split="val")
    val_raw[name] = _gather_features(
        ds_v, val_id_map, N_val, D_EXPERT, f"{name}/val")
    print(f" {name:<30} loaded", flush=True)
    del ds, ds_v
    gc.collect()
# ══════════════════════════════════════════════════════════════════
# PHASE 1: GPA ALIGNMENT AT 768-d
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 1: GPA ALIGNMENT AT 768-d")
print(f"{'='*65}")

# Generalised Procrustes: repeatedly align every expert to the current mean
# of all experts until the embeddings stop moving (or 20 rounds have passed).
current = {name: train_raw[name][:N_train].float() for name in EXPERTS}
for gpa_iter in range(20):
    mean_shape = sum(current[n] for n in EXPERTS) / len(EXPERTS)

    # Every expert is aligned against the same mean; the swap to the new
    # embeddings happens only after all of them are done.
    new_current = {
        name: apply_align(current[name], procrustes_align(current[name], mean_shape))
        for name in EXPERTS
    }
    total_delta = sum(
        (new_current[name] - current[name]).pow(2).mean().item()
        for name in EXPERTS
    )
    current = new_current

    round_no = gpa_iter + 1
    if gpa_iter == 0 or round_no % 5 == 0:
        print(f" GPA iter {gpa_iter+1}: delta={total_delta:.8f}")
    if total_delta < 1e-8:
        print(f" Converged at iteration {gpa_iter+1}")
        break

# Consensus = row-normalised mean of the aligned experts.
consensus_768 = F.normalize(
    sum(current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)

# Agreement of each aligned expert with the consensus (first 5000 rows).
for name in EXPERTS:
    c = F.cosine_similarity(
        consensus_768[:5000], current[name][:5000], dim=-1).mean().item()
    print(f" cos(consensus, {name}): {c:.4f}")

consensus_cv_768 = cv_metric(consensus_768[:5000].to(DEVICE))
print(f" Consensus CV at 768-d: {consensus_cv_768:.4f}")
# ══════════════════════════════════════════════════════════════════
# PHASE 2: PROJECT CONSENSUS TO 128-d + CALIBRATE CV
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 2: PROJECT TO 128-d + CALIBRATE")
print(f"{'='*65}")

# PCA basis from the first 10000 centred consensus rows.
cons_centered = consensus_768 - consensus_768.mean(0, keepdim=True)
# The left singular vectors are not needed, so they are discarded.
_, S, Vt = torch.linalg.svd(cons_centered[:10000], full_matrices=False)
pca_proj = Vt[:D_ANCHOR]
# NOTE(review): the basis is fitted on centred data but applied to the
# uncentred consensus. Kept as is because the checkpoint stores only
# `pca_proj` (no mean); confirm this is intended.
consensus_128 = F.normalize(consensus_768 @ pca_proj.T, dim=-1)
energy = S.pow(2)
var_retained = energy[:D_ANCHOR].sum() / energy.sum()
# The arrow was mis-decoded UTF-8 in the previous log text.
print(f" PCA 768→128: variance retained = {var_retained.item():.4f}")
consensus_cv_128 = cv_metric(consensus_128[:5000].to(DEVICE))
print(f" Consensus CV at 128-d: {consensus_cv_128:.4f}")

# Val consensus
# NOTE(review): this GPA run is fitted on the validation features alone,
# independently of the training run, and only the PCA basis is shared.
# The two consensus frames are therefore not guaranteed to coincide.
val_current = {name: val_raw[name].float() for name in EXPERTS}
for gpa_iter in range(20):
    val_mean = sum(val_current[n] for n in EXPERTS) / len(EXPERTS)
    delta = 0.0
    for name in EXPERTS:
        # In-place update is safe: val_mean was computed before this loop and
        # each expert's alignment reads only its own embedding.
        info = procrustes_align(val_current[name], val_mean)
        aligned = apply_align(val_current[name], info)
        delta += (aligned - val_current[name]).pow(2).mean().item()
        val_current[name] = aligned
    if delta < 1e-8:
        break

val_consensus_768 = F.normalize(
    sum(val_current[n] for n in EXPERTS) / len(EXPERTS), dim=-1)
val_consensus_128 = F.normalize(val_consensus_768 @ pca_proj.T, dim=-1)
print(f" Val consensus: {val_consensus_128.shape}")
# ══════════════════════════════════════════════════════════════════
# PHASE 3: PER-EXPERT PROCRUSTES TO 128-d CONSENSUS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 3: PER-EXPERT PROCRUSTES CALIBRATION")
print(f"{'='*65}")

# Upper bound on the number of rows used to fit each calibration.
_N_CALIB = 10000

expert_calibrations = {}
for name in EXPERTS:
    raw = train_raw[name][:_N_CALIB].float()
    tgt = consensus_128[:_N_CALIB].float()

    # Divide by the number of rows actually present. The divisor used to be
    # a literal 9999, which is wrong whenever fewer than 10000 rows exist.
    cov_denom = max(raw.shape[0] - 1, 1)

    src_mean = raw.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = raw - src_mean
    tgt_c = tgt - tgt_mean

    src_cov = (src_c.T @ src_c) / cov_denom
    src_whiten = symmetric_inv_sqrt(src_cov)
    tgt_cov = (tgt_c.T @ tgt_c) / cov_denom
    tgt_whiten = symmetric_inv_sqrt(tgt_cov)

    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    # Rectangular Procrustes: R maps the whitened source onto the target axes.
    M = tgt_w.T @ src_w
    U_r, S_r, Vt_r = torch.linalg.svd(M, full_matrices=False)
    R = U_r @ Vt_r

    # Fold centring, whitening and rotation into one affine map laid out
    # like nn.Linear: y = x @ proj_W.T + proj_b.
    proj_W = (src_whiten @ R.T).T
    proj_b = -(src_mean.squeeze(0) @ src_whiten @ R.T)

    # Sanity check on the first 1000 rows against the consensus targets.
    test_proj = raw[:1000] @ proj_W.T + proj_b
    test_proj_n = F.normalize(test_proj, dim=-1)
    cos = F.cosine_similarity(test_proj_n, tgt[:1000], dim=-1).mean().item()

    expert_calibrations[name] = {
        "weight": proj_W, "bias": proj_b, "cos": cos, "svd_S": S_r}
    print(f" {name:<30} cos={cos:.4f} svd: min={S_r.min():.4f} max={S_r.max():.4f}")
# ══════════════════════════════════════════════════════════════════
# PHASE 4: BUILD + INITIALIZE
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 4: BUILD + INITIALIZE")
print(f"{'='*65}")

model = BaseTierSoup().to(DEVICE)

# Parameters are overwritten in place, so autograd must be off.
with torch.no_grad():
    # Load each expert's calibrated affine map into the linear layer of its
    # projector (index 0 of the projector's Sequential).
    for i, name in enumerate(EXPERTS):
        cal = expert_calibrations[name]
        linear = model.projectors[i].proj[0]
        linear.weight.copy_(cal["weight"].to(DEVICE))
        linear.bias.copy_(cal["bias"].to(DEVICE))
        # The check mark was mis-decoded UTF-8 in the previous log text.
        print(f" ✓ {name} projector initialized (cos={cal['cos']:.4f})")

    # Seed the anchors with randomly chosen consensus embeddings.
    sample_idx = torch.randperm(min(10000, N_train))[:N_ANCHORS]
    anchor_seeds = consensus_128[sample_idx].to(DEVICE)
    model.constellation.anchors.copy_(F.normalize(anchor_seeds, dim=-1))
    print(f" ✓ Constellation seeded from consensus")

# Verify
with torch.no_grad():
    # One input per expert, in EXPERTS order (no hard-coded count).
    test_in = [train_raw[name][:200].to(DEVICE) for name in EXPERTS]
    _, test_fused, _, test_nearest, test_proj = model(test_in, apply_autograd=False)
    test_tgt = consensus_128[:200].to(DEVICE)
    init_cos = F.cosine_similarity(test_fused, test_tgt, dim=-1).mean().item()
    init_cv = cv_metric(test_fused)
    n_active = test_nearest.unique().numel()
    # Short display labels, paired positionally with the projector outputs.
    for short_name, proj_out in zip(["clip", "dino", "siglip"], test_proj):
        c = F.cosine_similarity(proj_out, test_tgt, dim=-1).mean().item()
        print(f" {short_name} proj→consensus cos: {c:.4f}")

# The anchor total comes from the configuration instead of a literal 256.
print(f" Init: cos={init_cos:.4f} cv={init_cv:.4f} active_anchors={n_active}/{N_ANCHORS}")
params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {params:,}")
print(f" CV target: {consensus_cv_128:.4f}")
# ══════════════════════════════════════════════════════════════════
# PHASE 5: TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 5: TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, CV target={consensus_cv_128:.4f}")
print(f"{'='*65}")


def _mean_average_precision(scores, labels):
    """Mean per-class average precision over classes with a positive label.

    ``scores`` and ``labels`` are (n_samples, n_classes); labels are 0/1.
    """
    ranks = torch.arange(1, scores.shape[0] + 1).float()
    ap_total, n_scored = 0.0, 0
    for cls in range(labels.shape[1]):
        truth = labels[:, cls]
        n_pos = truth.sum().item()
        if n_pos > 0:
            hits = truth[scores[:, cls].argsort(descending=True)]
            precision_at_k = hits.cumsum(0) / ranks
            ap_total += (precision_at_k * hits).sum().item() / n_pos
            n_scored += 1
    return ap_total / max(n_scored, 1)


def _macro_f1(scores, labels, threshold=0.5):
    """Macro F1 at a fixed sigmoid threshold, over classes with support.

    Every class that has at least one positive label is averaged, including
    classes the model gets entirely wrong. Averaging only classes with
    F1 > 0 inflated the score and gave NaN when no class scored.
    """
    pred = (scores.sigmoid() > threshold).float()
    tp = (pred * labels).sum(0)
    fp = (pred * (1 - labels)).sum(0)
    fn = ((1 - pred) * labels).sum(0)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    supported = labels.sum(0) > 0
    if not supported.any():
        return 0.0
    return f1[supported].mean().item()


train_targets = consensus_128.to(DEVICE)
val_targets = val_consensus_128.to(DEVICE)
# CPU copy made once for the validation metrics, not once per use.
val_targets_cpu = val_targets.cpu()
train_labels_gpu = train_label_matrix.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/base_tier_calibrated")
best_mAP = 0.0
gs = 0  # global optimisation step, used as the TensorBoard x-axis

# try/finally so the event file is flushed and closed if training aborts.
try:
    for epoch in range(EPOCHS):
        model.train()
        perm = torch.randperm(N_train)
        tl, tn, nb = 0, 0, 0
        for i in range(0, N_train, BATCH):
            idx = perm[i:i+BATCH]
            if len(idx) < 4:
                continue
            batch = [train_raw[name][idx].to(DEVICE) for name in EXPERTS]
            labels = train_labels_gpu[idx]
            targets = train_targets[idx]
            logits, fused, tri, nearest, projected = model(batch)
            anchors = model.constellation.anchors

            l_nce, nce_acc = infonce(fused, targets)
            l_mse = F.mse_loss(fused, targets)
            l_cls = F.binary_cross_entropy_with_logits(logits, labels)
            l_cv = cv_loss(fused, target=consensus_cv_128)
            l_spread = anchor_spread_loss(anchors)
            l_ent = anchor_entropy_loss(fused, anchors)
            loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_cls
                    + 0.001 * l_cv + 1e-3 * l_spread + 1e-4 * l_ent)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            tl += loss.item(); tn += nce_acc; nb += 1; gs += 1
            if gs % 100 == 0:
                writer.add_scalar("train/loss", loss.item(), gs)
                writer.add_scalar("train/nce", l_nce.item(), gs)
                writer.add_scalar("train/cls", l_cls.item(), gs)
                writer.add_scalar("train/cv", l_cv.item(), gs)
                writer.add_scalar("train/nce_acc", nce_acc, gs)

        # Validation
        model.eval()
        with torch.no_grad():
            all_lo, all_em = [], []
            for j in range(0, N_val, BATCH):
                end = min(j + BATCH, N_val)
                batch_v = [val_raw[name][j:end].to(DEVICE) for name in EXPERTS]
                lo, em, _, _, _ = model(batch_v, apply_autograd=False)
                all_lo.append(lo.cpu()); all_em.append(em.cpu())
            v_lo = torch.cat(all_lo); v_em = torch.cat(all_em)
            v_lab = val_label_matrix

            mAP = _mean_average_precision(v_lo, v_lab)
            macro_f1 = _macro_f1(v_lo, v_lab)

            v_cos = F.cosine_similarity(v_em, val_targets_cpu, dim=-1).mean().item()
            v_cv = cv_metric(v_em.to(DEVICE))
            # Retrieval@1 of each embedding against all validation targets.
            # This builds an N_val x N_val similarity matrix on the CPU.
            sim = v_em @ val_targets_cpu.T
            r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()
            _, v_nearest = model.constellation.triangulate(v_em.to(DEVICE))
            n_active = v_nearest.cpu().unique().numel()

        writer.add_scalar("val/mAP", mAP, epoch+1)
        writer.add_scalar("val/cos", v_cos, epoch+1)
        writer.add_scalar("val/cv", v_cv, epoch+1)
        writer.add_scalar("val/R@1", r1, epoch+1)
        writer.add_scalar("val/active_anchors", n_active, epoch+1)

        mk = ""
        if mAP > best_mAP:
            best_mAP = mAP
            torch.save({
                "state_dict": model.state_dict(),
                "config": {"d_expert": D_EXPERT, "d_anchor": D_ANCHOR,
                           "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                           "d_comp": D_COMP, "n_classes": N_CLASSES,
                           "experts": EXPERTS, "cv_target": consensus_cv_128},
                "pca_proj": pca_proj,
                "consensus_cv_768": consensus_cv_768,
                "consensus_cv_128": consensus_cv_128,
                "epoch": epoch+1, "mAP": mAP, "cv": v_cv, "r1": r1,
            }, "checkpoints/base_tier_best.pt")
            # Best-epoch marker (was mis-decoded UTF-8 in the previous text).
            mk = " ★"

        # max(nb, 1) guards the averages when no batch was large enough to run.
        n_batches = max(nb, 1)
        print(f" E{epoch+1:2d}: mAP={mAP:.3f} F1={macro_f1:.3f} R@1={r1:.3f} "
              f"cos={v_cos:.3f} cv={v_cv:.4f} anchors={n_active}/{N_ANCHORS} "
              f"nce={tn/n_batches:.3f} loss={tl/n_batches:.4f}{mk}")
finally:
    writer.close()

print(f"\n Best mAP: {best_mAP:.3f}")
print(f" CV target: {consensus_cv_128:.4f}")
print(f"\n{'='*65}\nDONE\n{'='*65}")