AbstractPhil's picture
Rename soup_maker.py to run_1-collapse/soup_maker.py
62e2bca verified
Raw
History Blame
18.4 kB
#!/usr/bin/env python3
"""
BASE TIER PATCHWORK SOUP
First we make the soup, then we refine it into stew, and then crystallize it into matter.
Parameter counts in the current configuration:
Parameters:
projectors : 296,064
constellation : 32,768
patchwork : 100,864
classifier : 370,256
total : 799,952
=========================
3 experts, all 768-d, no dimensional mismatch:
clip_l14_openai — semantic (text-supervised)
dinov2_b14 — structural (self-supervised)
siglip_b16_384 — semantic (sigmoid contrastive)
Architecture:
Per-expert: projection 768 → 128 (learned, on hypersphere)
Constellation: 256 anchors at 128-d (dynamic, geometric autograd)
Patchwork: 8 compartments reading triangulation distances
Classifier: patchwork output → 80-class multi-label
Phase 1: Soup the three experts into the constellation
Phase 2: Alignment bank reads the crystallized geometry
2,201 patches per image across 3 experts, compressed into
256 anchors on the 128-d hypersphere. The sphere has more
than enough capacity — this is a fraction of what 128-d holds.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Architecture
D_EXPERT = 768   # width of every expert's feature vector
D_ANCHOR = 128   # anchor/constellation dimension
N_ANCHORS = 256  # anchors on the D_ANCHOR-d hypersphere
N_CLASSES = 80   # multi-label output classes
N_COMP = 8       # patchwork compartments
D_COMP = 64      # output width of each compartment

# Training
BATCH = 128
EPOCHS = 20
LR = 1e-3

# Expert feature configs; list position is the expert index used by the script.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

for _banner_line in (
    "=" * 65,
    "BASE TIER PATCHWORK SOUP",
    f" 3 experts × {D_EXPERT}-d → {N_ANCHORS} anchors × {D_ANCHOR}-d",
    f" Device: {DEVICE}",
    "=" * 65,
):
    print(_banner_line)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each batch of points.

    Uses the Cayley–Menger determinant. ``pts`` is unpacked as (B, V, D):
    B independent point sets of V vertices each, spanning a (V-1)-simplex.
    Returns a float32 tensor of shape (B,). For (near-)degenerate simplices
    round-off may make the result slightly negative; callers in this file
    clamp it with ``relu`` before taking a square root.
    """
    points = pts.float()
    sq_dist = (points.unsqueeze(-2) - points.unsqueeze(-3)).pow(2).sum(-1)
    batch, n_vert, _ = sq_dist.shape

    # Bordered matrix: 0 in the corner, ones on the border, squared distances inside.
    bordered = torch.ones(batch, n_vert + 1, n_vert + 1,
                          device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 0] = 0
    bordered[:, 1:, 1:] = sq_dist

    dim = n_vert - 1  # simplex dimension
    sign = -1.0 if n_vert % 2 else 1.0
    fact = math.factorial(dim)
    coeff = sign / ((2.0 ** dim) * fact * fact)
    return coeff * torch.linalg.det(bordered)
def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty pulling the volume CV of ``emb`` toward ``target``.

    Draws ``n_samples`` random 5-row subsets of ``emb`` (indices sampled on
    ``emb``'s device), takes the volume of each 4-simplex via the Cayley–Menger
    determinant, and returns |std/mean - target| as a scalar tensor.
    Returns a zero tensor when ``emb`` has fewer than 5 rows.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return torch.tensor(0.0, device=emb.device)

    def _sample_volume():
        pick = torch.randperm(n_rows, device=emb.device)[:5]
        vol2 = cayley_menger_vol2(emb[pick].unsqueeze(0))[0]
        # relu guards against tiny negative round-off; epsilon keeps sqrt's grad finite
        return torch.sqrt(F.relu(vol2) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Coefficient of variation of random 4-simplex volumes, as a float.

    Evaluation-only counterpart of ``cv_loss``. It runs without autograd and
    draws ``n_samples`` random 5-row subsets of ``emb``, with indices sampled
    on the CPU. It returns std/mean of the positive volumes collected.
    Returns 0.0 when ``emb`` has fewer than 5 rows or fewer than 10 volumes
    were kept.
    """
    n_rows = emb.shape[0]
    if n_rows < 5:
        return 0.0

    def _draw_volume():
        chosen = torch.randperm(n_rows)[:5]
        vol2 = cayley_menger_vol2(emb[chosen].unsqueeze(0))[0]
        return torch.sqrt(F.relu(vol2) + 1e-12).item()

    kept = [vol for vol in (_draw_volume() for _ in range(n_samples)) if vol > 0]
    if len(kept) < 10:
        return 0.0
    volumes = torch.tensor(kept)
    return float(volumes.std() / (volumes.mean() + 1e-8))
def anchor_spread_loss(anchors):
    """Mean squared pairwise cosine similarity between distinct anchors.

    Anchors are L2-normalized first. Self-similarities on the diagonal are
    zeroed but still counted in the mean, which is taken over all N*N entries.
    Smaller values mean the anchors are closer to mutually orthogonal.
    """
    unit = F.normalize(anchors, dim=-1)
    gram = unit @ unit.T
    self_pairs = torch.eye(gram.shape[0], dtype=torch.bool, device=gram.device)
    return gram.masked_fill(self_pairs, 0.0).pow(2).mean()
def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy (nats) of each row's soft assignment over the anchors.

    Each row's assignment is softmax(sharpness * emb @ normalize(anchors).T).
    ``emb`` is used as given and is not normalized here. A 1e-12 floor inside
    the log avoids log(0).
    """
    unit_anchors = F.normalize(anchors, dim=-1)
    assign = F.softmax(sharpness * (emb @ unit_anchors.T), dim=-1)
    row_neg_entropy = (assign * torch.log(assign + 1e-12)).sum(dim=-1)
    return row_neg_entropy.neg().mean()
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between paired rows of ``a`` and ``b``.

    Row i of ``a`` is the positive for row i of ``b``. Both inputs are
    L2-normalized, and the cosine logits are divided by ``temperature``.
    Returns (loss, acc): loss is the mean of the a→b and b→a cross-entropies;
    acc is the a→b top-1 retrieval accuracy as a Python float.
    """
    za = F.normalize(a, dim=-1)
    zb = F.normalize(b, dim=-1)
    sim = (za @ zb.T) / temperature
    target = torch.arange(sim.shape[0], device=sim.device)
    forward_ce = F.cross_entropy(sim, target)
    backward_ce = F.cross_entropy(sim.T, target)
    loss = (forward_ce + backward_ce) / 2
    with torch.no_grad():
        acc = sim.argmax(dim=-1).eq(target).float().mean().item()
    return loss, acc
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    forward(x, embedding, anchors, tang, sep) returns ``x`` unchanged.

    In backward, the incoming gradient g is computed in float32. It is then
    reshaped relative to the unit-normalized rows of ``embedding``:

    * the component of g along each embedding row (radial) is scaled by
      (1 - tang);
    * if sep > 0, each row's nearest anchor is found by cosine similarity.
      When the reshaped gradient has a positive component along that anchor,
      sep times that component is subtracted.

    The result is cast back to the incoming gradient's dtype. Only ``x``
    receives a gradient; the other four inputs get None.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang, ctx.sep = tang, sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        unit_emb = F.normalize(embedding.detach().float(), dim=-1)
        unit_anchors = F.normalize(anchors.detach().float(), dim=-1)
        g = grad_output.float()

        # Split g into radial (along unit_emb) and tangential parts.
        radial_coeff = (g * unit_emb).sum(-1, keepdim=True)
        g_radial = radial_coeff * unit_emb
        shaped = (g - g_radial) + (1.0 - ctx.tang) * g_radial

        if ctx.sep > 0:
            nearest_idx = (unit_emb @ unit_anchors.T).argmax(dim=-1)
            nearest_dir = unit_anchors[nearest_idx]
            along = (shaped * nearest_dir).sum(-1, keepdim=True)
            positive = (along > 0).float()
            shaped = shaped - ctx.sep * positive * along * nearest_dir

        return shaped.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════
class ExpertProjector(nn.Module):
    """Project one expert's features onto the unit sphere of the anchor space.

    Applies Linear(d_in → d_out), then LayerNorm(d_out), then L2-normalizes
    along the last dimension. The defaults take D_EXPERT-d inputs to
    D_ANCHOR-d outputs.
    """

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        layers = [nn.Linear(d_in, d_out), nn.LayerNorm(d_out)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        z = self.proj(x)
        return F.normalize(z, dim=-1)
class Constellation(nn.Module):
    """A learnable set of anchor vectors plus cosine-distance triangulation.

    The anchors are initialized as L2-normalized Gaussian vectors of shape
    (n_anchors, d). They are stored as a free parameter, so they are not
    constrained to unit norm during training. ``triangulate`` re-normalizes
    them on every call.
    """

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        init = torch.randn(n_anchors, d)
        self.anchors = nn.Parameter(F.normalize(init, dim=-1))

    def triangulate(self, emb):
        """Return (1 - cosine to each normalized anchor, index of the most similar anchor).

        The cosine is the plain dot product ``emb @ anchors.T``, so it is a
        true cosine only when ``emb`` rows are already unit-norm.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = torch.matmul(emb, unit_anchors.T)
        nearest_idx = similarity.argmax(dim=-1)
        return 1.0 - similarity, nearest_idx
class Patchwork(nn.Module):
    """Split the anchor distances into interleaved compartments and encode each.

    Anchor j belongs to compartment j % n_comp; the assignment is kept in the
    ``asgn`` buffer. Each compartment is an MLP (Linear → GELU → Linear →
    LayerNorm) that reads only its own anchors' columns and emits ``d_comp``
    features. The compartment outputs are concatenated, giving
    n_comp * d_comp features.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        self.register_buffer("asgn", torch.arange(n_anchors) % n_comp)
        hidden = d_comp * 2
        compartments = []
        for k in range(n_comp):
            width = (self.asgn == k).sum().item()
            compartments.append(nn.Sequential(
                nn.Linear(width, hidden), nn.GELU(),
                nn.Linear(hidden, d_comp), nn.LayerNorm(d_comp)))
        self.comps = nn.ModuleList(compartments)

    def forward(self, tri):
        pieces = []
        for k, comp in enumerate(self.comps):
            pieces.append(comp(tri[:, self.asgn == k]))
        return torch.cat(pieces, -1)
class BaseTierSoup(nn.Module):
    """
    Multi-expert soup on the anchor hypersphere (defaults: 3 experts, 128-d).

    Each expert: d_expert → ExpertProjector → d_anchor (unit norm)
    Per image: the projected embeddings are averaged, then re-normalized
    Constellation: n_anchors learnable anchors at d_anchor
    Patchwork: n_comp compartments over anchor distances → classifier

    The projectors learn to place each expert's perspective into the shared
    anchor space. During training, EmbeddingAutograd reshapes the gradient of
    the fused embedding.

    Args:
        n_experts: number of expert tensors ``forward`` expects.
        d_expert, d_anchor, n_anchors, n_comp, d_comp, n_classes: layer sizes.
        tang: passed to EmbeddingAutograd; the radial part of the fused
            embedding's gradient is scaled by (1 - tang).
        sep: passed to EmbeddingAutograd; the strength with which the positive
            gradient component along each row's nearest anchor is removed.
    """

    def __init__(self, n_experts=3, d_expert=D_EXPERT, d_anchor=D_ANCHOR,
                 n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
                 n_classes=N_CLASSES, tang=0.01, sep=1.0):
        super().__init__()
        self.n_experts = n_experts
        self.d_anchor = d_anchor
        # Plain floats, not parameters/buffers: state_dict layout is unchanged.
        self.tang = tang
        self.sep = sep

        # Per-expert projection to anchor space
        self.projectors = nn.ModuleList([
            ExpertProjector(d_expert, d_anchor) for _ in range(n_experts)])

        # Geometric pipeline
        self.constellation = Constellation(n_anchors, d_anchor)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)

        # Classifier reads the patchwork features plus the fused embedding itself
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_anchor, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, n_classes))

    def forward(self, expert_embeddings, apply_autograd=True):
        """Run projection, fusion, triangulation, patchwork and classification.

        Args:
            expert_embeddings: sequence of ``n_experts`` tensors, one per
                expert, each (B, d_expert).
            apply_autograd: when True and the module is in training mode, the
                fused embedding is routed through EmbeddingAutograd with
                ``self.tang`` / ``self.sep``.

        Returns:
            (logits, fused, tri, nearest, projected):
            logits (B, n_classes); fused (B, d_anchor); tri = 1 - cosine to
            each anchor; nearest = index of the most similar anchor;
            projected = list of per-expert (B, d_anchor) embeddings.

        Raises:
            ValueError: if the number of expert tensors is not ``n_experts``.
        """
        if len(expert_embeddings) != self.n_experts:
            raise ValueError(
                f"expected {self.n_experts} expert embeddings, "
                f"got {len(expert_embeddings)}")

        # Project each expert onto the hypersphere
        projected = [proj(x) for proj, x in zip(self.projectors, expert_embeddings)]

        # Fuse: mean on the hypersphere (normalize after averaging)
        fused = F.normalize(sum(projected) / self.n_experts, dim=-1)

        # Geometric autograd: identity forward, reshaped gradient backward
        if apply_autograd and self.training:
            fused = EmbeddingAutograd.apply(
                fused, fused, self.constellation.anchors, self.tang, self.sep)

        # Triangulate + patchwork
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)

        # Classify
        logits = self.classifier(torch.cat([pw, fused], -1))
        return logits, fused, tri, nearest, projected

    def count_params(self):
        """Return parameter counts per component and their total."""
        proj = sum(sum(p.numel() for p in pr.parameters()) for pr in self.projectors)
        const = sum(p.numel() for p in self.constellation.parameters())
        pw = sum(p.numel() for p in self.patchwork.parameters())
        cls = sum(p.numel() for p in self.classifier.parameters())
        return {"projectors": proj, "constellation": const,
                "patchwork": pw, "classifier": cls,
                "total": proj + const + pw + cls}
# ══════════════════════════════════════════════════════════════════
# LOAD DATA
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("LOADING DATA")
print(f"{'='*65}")
from datasets import load_dataset

FEATURE_REPO = "AbstractPhil/bulk-coco-features"


def _build_label_matrix(labels_raw, n_rows):
    """Return an (n_rows, N_CLASSES) multi-hot float matrix.

    ``labels_raw[i]`` is the list of class ids for row i; ids >= N_CLASSES
    are ignored.
    """
    matrix = torch.zeros(n_rows, N_CLASSES)
    for i, labs in enumerate(labels_raw):
        for lab in labs:
            if lab < N_CLASSES:
                matrix[i, lab] = 1.0
    return matrix


def _load_expert_split(name, split, id_map, n_rows):
    """Load one expert's features for one split, reordered to match ``id_map``.

    Returns (feats, n_missing). ``feats`` is an (n_rows, D_EXPERT) float32
    tensor. Rows whose image_id never appears in this expert's split stay
    all-zero; their count is ``n_missing``, so the caller can report it
    instead of silently training on zero vectors.
    """
    ds = load_dataset(FEATURE_REPO, name, split=split)
    feats = torch.zeros(n_rows, D_EXPERT)
    seen = torch.zeros(n_rows, dtype=torch.bool)
    for row in ds:
        pos = id_map.get(row["image_id"])
        if pos is not None:
            feats[pos] = torch.tensor(row["features"], dtype=torch.float32)
            seen[pos] = True
    return feats, int((~seen).sum())


# Reference (first expert) fixes the image order and provides the labels
ref = load_dataset(FEATURE_REPO, EXPERTS[0], split="train")
train_ids = ref["image_id"]
N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_labels_raw = ref["labels"]
train_label_matrix = _build_label_matrix(train_labels_raw, N_train)

ref_val = load_dataset(FEATURE_REPO, EXPERTS[0], split="val")
val_ids = ref_val["image_id"]
N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels_raw = ref_val["labels"]
val_label_matrix = _build_label_matrix(val_labels_raw, N_val)

print(f" Train: {N_train:,} Val: {N_val:,}")

# Load every expert, aligned row-for-row with the reference order
train_feats = []
val_feats = []
for name in EXPERTS:
    feats, n_miss = _load_expert_split(name, "train", train_id_map, N_train)
    train_feats.append(feats)
    feats_v, n_miss_v = _load_expert_split(name, "val", val_id_map, N_val)
    val_feats.append(feats_v)
    print(f" {name:<30} loaded", flush=True)
    if n_miss or n_miss_v:
        print(f" WARNING: {name}: {n_miss:,} train / {n_miss_v:,} val images "
              f"have no features (left as zeros)", flush=True)
    gc.collect()

# Move val to GPU
val_feats_gpu = [f.to(DEVICE) for f in val_feats]
val_labels_gpu = val_label_matrix.to(DEVICE)
train_labels_gpu = train_label_matrix.to(DEVICE)
# ══════════════════════════════════════════════════════════════════
# BUILD MODEL
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("BUILDING MODEL")
print(f"{'='*65}")
# All sizes come from the module-level configuration constants.
model_config = dict(
    n_experts=3,
    d_expert=D_EXPERT,
    d_anchor=D_ANCHOR,
    n_anchors=N_ANCHORS,
    n_comp=N_COMP,
    d_comp=D_COMP,
    n_classes=N_CLASSES,
)
model = BaseTierSoup(**model_config).to(DEVICE)
params = model.count_params()
print(f" Parameters:")
for k, v in params.items():
    print(f" {k:<15}: {v:>10,}")
# ══════════════════════════════════════════════════════════════════
# TRAIN
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Adam, no weight decay (geometry IS the regularization)")
print(f"{'='*65}")
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
best_mAP = 0.0
from torch.utils.tensorboard import SummaryWriter
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/base_tier_soup")
gs = 0


def _mean_average_precision(scores, labels):
    """Mean over classes with >= 1 positive of non-interpolated average precision.

    scores, labels: (N, C) tensors on the same device; labels are 0/1.
    """
    ap_sum, n_valid = 0.0, 0
    for c in range(labels.shape[1]):
        if labels[:, c].sum() > 0:
            order = scores[:, c].argsort(descending=True)
            hits = labels[:, c][order]
            prec_at_k = hits.cumsum(0) / torch.arange(1, len(hits) + 1).float()
            ap_sum += (prec_at_k * hits).sum().item() / hits.sum().item()
            n_valid += 1
    return ap_sum / max(n_valid, 1)


def _macro_f1(logits, labels, threshold=0.5):
    """Macro F1 at sigmoid(logits) > threshold over classes with >= 1 positive.

    Classes whose F1 is 0 stay in the average, so failures count against the
    score, matching the class set used by ``_mean_average_precision``.
    Returns 0.0 if no class has a positive.
    """
    pred = (logits.sigmoid() > threshold).float()
    tp = (pred * labels).sum(0)
    fp = (pred * (1 - labels)).sum(0)
    fn = ((1 - pred) * labels).sum(0)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    present = labels.sum(0) > 0
    if not present.any():
        return 0.0
    return f1[present].mean().item()


@torch.no_grad()
def _evaluate(net, feats_gpu, n_rows, batch):
    """One no-grad validation pass in eval mode.

    Returns (logits, fused, projected) concatenated on the CPU, where
    ``projected`` is a list with one tensor per expert.
    """
    net.eval()
    all_lo, all_em = [], []
    all_proj = [[] for _ in feats_gpu]
    for j in range(0, n_rows, batch):
        end = min(j + batch, n_rows)
        batch_v = [f[j:end] for f in feats_gpu]
        lo, em, _, _, proj = net(batch_v, apply_autograd=False)
        all_lo.append(lo.cpu())
        all_em.append(em.cpu())
        for bucket, p in zip(all_proj, proj):
            bucket.append(p.cpu())
    return (torch.cat(all_lo), torch.cat(all_em),
            [torch.cat(bucket) for bucket in all_proj])


for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train)
    tl, nb = 0, 0
    for i in range(0, N_train, BATCH):
        idx = perm[i:i + BATCH]
        if len(idx) < 4:
            continue  # skip a tiny trailing batch
        # Move batch to GPU
        batch_experts = [train_feats[e][idx].to(DEVICE) for e in range(3)]
        labels = train_labels_gpu[idx]
        logits, fused, tri, nearest, projected = model(batch_experts)
        anchors = model.constellation.anchors

        # Classification
        l_cls = F.binary_cross_entropy_with_logits(logits, labels)
        # Geometric losses
        l_cv = cv_loss(fused, target=0.2)
        l_spread = anchor_spread_loss(anchors)
        l_ent = anchor_entropy_loss(fused, anchors)
        # Per-expert agreement: mean (1 - cosine) over all expert pairs
        l_agree = 0.0
        n_pairs = 0
        for pi in range(3):
            for pj in range(pi + 1, 3):
                l_agree += (1.0 - F.cosine_similarity(
                    projected[pi], projected[pj], dim=-1)).mean()
                n_pairs += 1
        l_agree = l_agree / n_pairs

        loss = (l_cls
                + 0.001 * l_cv
                + 1e-3 * l_spread
                + 1e-4 * l_ent
                + 0.1 * l_agree)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        tl += loss.item()
        nb += 1
        gs += 1
        if gs % 100 == 0:
            writer.add_scalar("train/loss", loss.item(), gs)
            writer.add_scalar("train/cls", l_cls.item(), gs)
            writer.add_scalar("train/cv", l_cv.item(), gs)
            writer.add_scalar("train/agree", l_agree.item(), gs)

    # Validation: a single pass yields logits, fused embeddings and projections
    v_lo, v_em, proj_stacked = _evaluate(model, val_feats_gpu, N_val, BATCH)
    v_lab = val_label_matrix
    mAP = _mean_average_precision(v_lo, v_lab)
    macro_f1 = _macro_f1(v_lo, v_lab)
    v_cv = cv_metric(v_em)

    # Expert agreement on the validation set
    agree_01 = F.cosine_similarity(proj_stacked[0], proj_stacked[1], dim=-1).mean().item()
    agree_02 = F.cosine_similarity(proj_stacked[0], proj_stacked[2], dim=-1).mean().item()
    agree_12 = F.cosine_similarity(proj_stacked[1], proj_stacked[2], dim=-1).mean().item()

    writer.add_scalar("val/mAP", mAP, epoch+1)
    writer.add_scalar("val/F1", macro_f1, epoch+1)
    writer.add_scalar("val/cv", v_cv, epoch+1)
    writer.add_scalar("val/agree_clip_dino", agree_01, epoch+1)
    writer.add_scalar("val/agree_clip_siglip", agree_02, epoch+1)
    writer.add_scalar("val/agree_dino_siglip", agree_12, epoch+1)

    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({
            "state_dict": model.state_dict(),
            "config": {"d_expert": D_EXPERT, "d_anchor": D_ANCHOR,
                       "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                       "d_comp": D_COMP, "n_classes": N_CLASSES,
                       "experts": EXPERTS},
            "epoch": epoch+1, "mAP": mAP, "cv": v_cv,
        }, "checkpoints/base_tier_best.pt")
        mk = " ★"
    print(f" E{epoch+1:2d}: mAP={mAP:.3f} F1={macro_f1:.3f} cv={v_cv:.4f} "
          f"agree=[{agree_01:.3f},{agree_02:.3f},{agree_12:.3f}] "
          f"loss={tl / max(nb, 1):.4f}{mk}")

writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Model: {params['total']:,} params")
print(f" Anchors: {N_ANCHORS} × {D_ANCHOR}-d")
print(f"\n{'='*65}")
print("DONE")
print(f"{'='*65}")