| |
| """ |
| BASE TIER SOUP ANALYSIS |
| ======================== |
| Load the trained 800K param soup and examine: |
| - Anchor geometry on the 128-d hypersphere |
| - Projector alignment (do the 3 experts converge?) |
| - Triangulation patterns (which anchors are used?) |
| - Patchwork compartment activation profiles |
| - Per-expert projected distributions |
| - CV and volume geometry of the learned space |
| - Per-class anchor affinity (which anchors serve which COCO classes?) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import os |
| import gc |
|
|
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Default sizes consumed by the module constructors and analysis loops in this file.
D_EXPERT, D_ANCHOR = 768, 128      # expert feature width / anchor-space width
N_ANCHORS, N_CLASSES = 256, 80     # number of anchors / number of label columns
N_COMP, D_COMP = 8, 64             # patchwork compartments / output width of each one

# Dataset configuration names, one per expert; list order fixes the expert index.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

# Startup banner.
for banner_line in (
    "=" * 65,
    "BASE TIER SOUP ANALYSIS",
    f" Device: {DEVICE}",
    "=" * 65,
):
    print(banner_line)
|
|
|
|
| |
| |
| |
|
|
| |
class ExpertProjector(nn.Module):
    """Project one expert's features into the anchor space and L2-normalize them.

    Args:
        d_in: width of the incoming feature vectors.
        d_out: width of the projected vectors.
    """

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        linear = nn.Linear(d_in, d_out)
        layer_norm = nn.LayerNorm(d_out)
        # Stored as a two-stage Sequential under the attribute ``proj`` so the
        # parameter names remain ``proj.0.*`` and ``proj.1.*``.
        self.proj = nn.Sequential(linear, layer_norm)

    def forward(self, x):
        """Return ``proj(x)`` scaled to unit L2 norm along the last dimension."""
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
|
|
class Constellation(nn.Module):
    """Learnable set of anchor vectors used to triangulate embeddings.

    Args:
        n_anchors: number of anchor vectors.
        d: dimensionality of each anchor.
    """

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        # Random Gaussian rows, scaled to unit length before becoming a parameter.
        initial = torch.randn(n_anchors, d)
        self.anchors = nn.Parameter(F.normalize(initial, dim=-1))

    def triangulate(self, emb):
        """Compare ``emb`` against every anchor.

        Returns:
            A pair ``(distances, nearest)`` where ``distances`` is ``1 - cos``
            for each (embedding, anchor) pair and ``nearest`` is the index of
            the anchor with the highest cosine for each embedding.
        """
        # The parameter is re-normalized on every call, so the dot product is
        # taken against unit-length anchors.
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = torch.matmul(emb, unit_anchors.T)
        distances = 1.0 - similarity
        nearest_idx = similarity.argmax(dim=-1)
        return distances, nearest_idx
|
|
class Patchwork(nn.Module):
    """Split the triangulation columns into compartments, one small MLP each.

    Anchor ``a`` is assigned to compartment ``a % n_comp``. Each compartment
    MLP reads only its own anchors' columns and emits ``d_comp`` features; the
    per-compartment outputs are concatenated along the last dimension.

    Args:
        n_anchors: number of triangulation columns.
        n_comp: number of compartments.
        d_comp: output width of each compartment.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)

        hidden = d_comp * 2
        compartments = []
        for k in range(n_comp):
            # Input width is the number of anchors assigned to compartment k.
            width = (asgn == k).sum().item()
            compartments.append(
                nn.Sequential(
                    nn.Linear(width, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, d_comp),
                    nn.LayerNorm(d_comp),
                )
            )
        self.comps = nn.ModuleList(compartments)

    def forward(self, tri):
        """Run every compartment on its column subset of ``tri`` and concatenate."""
        outputs = []
        for k in range(self.n_comp):
            columns = tri[:, self.asgn == k]
            outputs.append(self.comps[k](columns))
        return torch.cat(outputs, -1)
|
|
class BaseTierSoup(nn.Module):
    """Three-expert fusion model: projectors -> constellation -> patchwork -> classifier."""

    def __init__(self):
        super().__init__()
        self.n_experts = 3
        self.projectors = nn.ModuleList(
            [ExpertProjector() for _ in range(self.n_experts)]
        )
        self.constellation = Constellation()
        self.patchwork = Patchwork()

        pw_dim = N_COMP * D_COMP
        # The classifier reads the patchwork output concatenated with the fused embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, N_CLASSES),
        )

    def forward(self, expert_embeddings, apply_autograd=False):
        """Fuse the expert features and classify.

        Args:
            expert_embeddings: indexable container; entry ``i`` is fed to
                projector ``i``.
            apply_autograd: accepted but not used by this method.

        Returns:
            ``(logits, fused, tri, nearest, projected)`` where ``projected`` is
            the list of per-expert projections, ``fused`` their normalized
            mean, and ``tri`` / ``nearest`` the constellation outputs for
            ``fused``.
        """
        projected = []
        for idx in range(self.n_experts):
            projected.append(self.projectors[idx](expert_embeddings[idx]))

        mean_projection = sum(projected) / self.n_experts
        fused = F.normalize(mean_projection, dim=-1)

        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)
        classifier_in = torch.cat([pw, fused], -1)
        logits = self.classifier(classifier_in)
        return logits, fused, tri, nearest, projected
|
|
# Restore the trained model from disk and switch it to eval mode on DEVICE.
print("\n Loading checkpoint...")
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary objects,
# which can execute code embedded in the file. Only load checkpoints from a
# trusted source.
ckpt = torch.load(
    "checkpoints/base_tier_best.pt",
    map_location="cpu",
    weights_only=False,
)
model = BaseTierSoup()
model.load_state_dict(ckpt["state_dict"])
model.eval()
model = model.to(DEVICE)
# The checkpoint dict also carries the metrics recorded alongside the weights.
print(
    " Loaded: mAP={:.3f} cv={:.4f} epoch={}".format(
        ckpt["mAP"], ckpt["cv"], ckpt["epoch"]
    )
)
|
|
| |
| from datasets import load_dataset |
# Reference split: the first expert's "val" split fixes the row order used by
# every tensor below (row i corresponds to val_ids[i]).
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
id_map = {iid: i for i, iid in enumerate(val_ids)}

# Multi-hot label matrix: val_labels[i, c] == 1.0 when label c is listed for image i.
val_labels = torch.zeros(N_val, N_CLASSES)
for i, labs in enumerate(ref["labels"]):
    for label in labs:
        # Lower bound is checked explicitly: a negative value would otherwise
        # wrap around and mark a column counted from the end.
        if 0 <= label < N_CLASSES:
            val_labels[i, label] = 1.0

# One (N_val, D_EXPERT) feature matrix per expert, aligned to the reference order.
val_feats = []
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, D_EXPERT)
    filled = torch.zeros(N_val, dtype=torch.bool)
    for row in ds:
        pos = id_map.get(row["image_id"])
        if pos is None:
            # Image is not part of the reference split; ignore it.
            continue
        feats[pos] = torch.tensor(row["features"], dtype=torch.float32)
        filled[pos] = True
    # Rows never written stay all-zero and would silently skew every statistic
    # computed later, so report them instead of hiding them.
    n_missing = N_val - int(filled.sum())
    if n_missing:
        print(f" WARNING: {name} has no features for {n_missing}/{N_val} "
              f"reference images (rows left as zeros)")
    val_feats.append(feats.to(DEVICE))
    print(f" {name} loaded")
    # Drop the dataset handle before loading the next expert.
    del ds; gc.collect()
|
|
| |
# Batched forward pass over the whole validation set; all outputs are moved to
# the CPU as they are produced and concatenated at the end.
print(f"\n Running inference on {N_val} val images...")
BATCH = 256
all_logits, all_fused, all_tri, all_nearest = [], [], [], []
all_proj = [[], [], []]

with torch.no_grad():
    for start in range(0, N_val, BATCH):
        stop = min(start + BATCH, N_val)
        batch = [val_feats[e][start:stop] for e in range(3)]
        out_logits, out_fused, out_tri, out_nearest, out_proj = model(batch)

        for sink, out in zip(
            (all_logits, all_fused, all_tri, all_nearest),
            (out_logits, out_fused, out_tri, out_nearest),
        ):
            sink.append(out.cpu())
        # Per-expert projections are collected in their own per-expert lists.
        for e in range(3):
            all_proj[e].append(out_proj[e].cpu())

logits = torch.cat(all_logits)
fused = torch.cat(all_fused)
tri = torch.cat(all_tri)
nearest = torch.cat(all_nearest)
proj = [torch.cat(chunks) for chunks in all_proj]
print(f" Done: fused={fused.shape} tri={tri.shape}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("SCAN 1: ANCHOR GEOMETRY")
print(f"{'='*65}")

# Unit-length copy of the learned anchors on the CPU.
anchors = F.normalize(model.constellation.anchors.detach().cpu(), dim=-1)

# Pairwise cosine between anchors. The diagonal (self-similarity) is zeroed in
# the stored matrix, but the statistics below are taken over the off-diagonal
# entries only: counting the zeroed diagonal would pull the mean toward 0,
# distort the std, and could report a spurious 0 as the min or as an anchor's
# best neighbour.
anchor_sim = anchors @ anchors.T
anchor_sim.fill_diagonal_(0)
diag_mask = torch.eye(anchor_sim.shape[0], dtype=torch.bool)
pair_cos = anchor_sim[~diag_mask]

print(f" Anchor pairwise cosine:")
print(f" mean={pair_cos.mean():.4f} std={pair_cos.std():.4f}")
print(f" max={pair_cos.max():.4f} min={pair_cos.min():.4f}")

# Highest cosine to any *other* anchor; -inf on the diagonal keeps an anchor
# from being selected as its own neighbour.
max_neighbor = anchor_sim.masked_fill(diag_mask, float("-inf")).max(dim=1).values
print(f" Max neighbor cosine per anchor:")
print(f" mean={max_neighbor.mean():.4f} std={max_neighbor.std():.4f}")
print(f" max={max_neighbor.max():.4f} min={max_neighbor.min():.4f}")

# Sanity check: after normalization every norm should be 1.
anchor_norms = anchors.norm(dim=-1)
print(f" Anchor norms: mean={anchor_norms.mean():.6f} std={anchor_norms.std():.6f}")

# Spectrum of the anchor matrix; effective rank is (sum sv)^2 / sum(sv^2).
sv = torch.linalg.svdvals(anchors)
eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
print(f" Anchor spectral: eff_rank={eff_rank:.1f}/{min(anchors.shape)}")
print(f" sv_max={sv[0]:.4f} sv_10={sv[9]:.4f} sv_50={sv[49]:.4f} sv_min={sv[-1]:.6f}")
|
|
| |
def cayley_menger_vol2(pts):
    """Return the squared volume of simplices via the Cayley-Menger determinant.

    Args:
        pts: tensor of shape ``(..., V, D)`` holding ``V`` vertices of
            dimension ``D`` per simplex. Any number of leading batch
            dimensions is accepted, including none.

    Returns:
        Tensor of shape ``(...)`` (one value per simplex, a 0-d tensor for
        unbatched input) with the squared ``(V-1)``-dimensional volume.
        Rounding in the float32 determinant can make the result slightly
        negative for near-degenerate simplices.
    """
    pts = pts.float()
    # Squared Euclidean distance between every pair of vertices: (..., V, V).
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    V = d2.shape[-1]

    # Bordered matrix: a zero corner, a border of ones, and d2 in the interior.
    cm = torch.zeros(
        (*d2.shape[:-2], V + 1, V + 1), device=d2.device, dtype=torch.float32
    )
    cm[..., 0, 1:] = 1
    cm[..., 1:, 0] = 1
    cm[..., 1:, 1:] = d2

    # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(cm)
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)
|
|
# Coefficient of variation of the volumes of 500 random 5-anchor simplices.
vols = []
for _ in range(500):
    idx = torch.randperm(N_ANCHORS)[:5]
    v2 = cayley_menger_vol2(anchors[idx].unsqueeze(0))[0]
    # Filter on the squared volume itself. Testing the square root after the
    # 1e-12 offset can never fail (it is always >= 1e-6), which let degenerate
    # or numerically negative simplices into the sample as tiny fake volumes.
    if v2 > 0:
        vols.append(torch.sqrt(v2 + 1e-12).item())

if vols:
    mean_vol, std_vol = float(np.mean(vols)), float(np.std(vols))
else:
    # No valid simplex was sampled; report NaN rather than averaging nothing.
    mean_vol = std_vol = float("nan")
anchor_cv = std_vol / (mean_vol + 1e-8)
print(f" Anchor pentachoron CV: {anchor_cv:.4f}")
print(f" mean_vol={mean_vol:.6f} std_vol={std_vol:.6f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("SCAN 2: ANCHOR UTILIZATION")
print(f"{'='*65}")

# Number of validation images whose nearest anchor is each anchor.
anchor_counts = torch.bincount(nearest, minlength=N_ANCHORS).float()
active = (anchor_counts > 0).sum().item()
print(f" Active anchors: {active}/{N_ANCHORS} ({active/N_ANCHORS*100:.1f}%)")
print(f" Visit counts: mean={anchor_counts.mean():.1f} std={anchor_counts.std():.1f}")
print(f" max={anchor_counts.max():.0f} min={anchor_counts.min():.0f}")
print(f" top 10: {anchor_counts.topk(10).values.long().tolist()}")
print(f" bottom 10: {anchor_counts.sort().values[:10].long().tolist()}")

# Shannon entropy of the visit distribution against the uniform maximum.
probs = anchor_counts / anchor_counts.sum()
nonzero_probs = probs[probs > 0]
entropy = -(nonzero_probs * nonzero_probs.log()).sum().item()
max_entropy = math.log(N_ANCHORS)
print(f" Anchor entropy: {entropy:.4f} / {max_entropy:.4f} ({entropy/max_entropy*100:.1f}%)")

# Mean pairwise cosine among the embeddings assigned to each anchor.
print(f"\n Per-anchor embedding density:")
anchor_mean_cos = []
for a_idx in range(N_ANCHORS):
    cluster_embs = fused[nearest == a_idx]
    n = cluster_embs.shape[0]
    if n < 2:
        # NaN marks "no pair available"; a 0.0 marker would be
        # indistinguishable from a genuine zero mean cosine.
        anchor_mean_cos.append(float("nan"))
        continue
    # For unit vectors, sum_{i,j} cos(x_i, x_j) = ||sum_i x_i||^2, so the
    # off-diagonal mean needs no n x n (x d) intermediate tensor.
    unit = F.normalize(cluster_embs.double(), dim=-1)
    gram_total = unit.sum(dim=0).pow(2).sum().item()
    gram_diag = unit.pow(2).sum().item()
    anchor_mean_cos.append((gram_total - gram_diag) / (n * (n - 1)))
amc = np.array(anchor_mean_cos)
valid = amc[~np.isnan(amc)]
if valid.size:
    cos_mean, cos_std = valid.mean(), valid.std()
else:
    cos_mean = cos_std = float("nan")
print(f" Intra-cluster cosine: mean={cos_mean:.4f} std={cos_std:.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("SCAN 3: PROJECTOR ANALYSIS")
print(f"{'='*65}")

expert_names = ["clip_l14", "dinov2_b14", "siglip_b16"]
n_exp = len(expert_names)
# Every unordered pair of experts, reused by the pairwise reports below.
expert_pairs = [(i, j) for i in range(n_exp) for j in range(i + 1, n_exp)]

# Per-expert statistics of the projected embeddings.
for e, name in enumerate(expert_names):
    p = proj[e]
    print(f"\n {name}:")
    print(f" norm: mean={p.norm(dim=-1).mean():.6f} (should be 1.0)")

    # Mean off-diagonal cosine. For unit rows the sum of all pairwise cosines
    # is ||sum of rows||^2; subtracting the diagonal and dividing by N*(N-1)
    # averages over true pairs only (a mean over the full matrix with a zeroed
    # diagonal divides by N^2 and is biased toward 0) and needs no N x N matrix.
    unit = F.normalize(p.double(), dim=-1)
    n_rows = unit.shape[0]
    gram_total = unit.sum(dim=0).pow(2).sum()
    gram_diag = unit.pow(2).sum()
    proj_self_sim = ((gram_total - gram_diag) / max(n_rows * (n_rows - 1), 1)).item()
    print(f" self-sim off-diag: {proj_self_sim:.4f}")

    # Effective dimensionality of the centered projections: (sum sv)^2 / sum(sv^2).
    pc = p.float() - p.float().mean(0, keepdim=True)
    sv = torch.linalg.svdvals(pc)
    eff_dim = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
    print(f" eff_dim: {eff_dim:.1f}/{D_ANCHOR}")

# Per-image cosine between the projections of each expert pair.
# NOTE(review): the pair separator was a mis-decoded character in the source;
# it is restored here as the multiplication sign. Confirm against the original.
print(f"\n Expert agreement (cosine in 128-d):")
for i, j in expert_pairs:
    cos = F.cosine_similarity(proj[i], proj[j], dim=-1)
    print(f" {expert_names[i]:<15} × {expert_names[j]:<15}: "
          f"mean={cos.mean():.4f} std={cos.std():.4f} min={cos.min():.4f}")

# Fraction of images for which two experts select the same nearest anchor.
print(f"\n Per-expert nearest anchor agreement:")
unit_anchors = F.normalize(anchors, dim=-1)  # loop-invariant, computed once
expert_nearest = [(proj[e] @ unit_anchors.T).argmax(dim=-1) for e in range(n_exp)]
for i, j in expert_pairs:
    agree = (expert_nearest[i] == expert_nearest[j]).float().mean().item()
    print(f" {expert_names[i]:<15} × {expert_names[j]:<15}: "
          f"same_anchor={agree:.4f} ({agree*100:.1f}%)")

# Norm and effective rank of each projector's linear weight matrix.
print(f"\n Projector weight comparison:")
proj_weights = []
for e in range(n_exp):
    w = model.projectors[e].proj[0].weight.detach().float()
    proj_weights.append(w)
    sv = torch.linalg.svdvals(w)
    eff_r = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
    print(f" {expert_names[e]:<15}: norm={w.norm():.4f} eff_rank={eff_r:.1f}/{min(w.shape)}")

# Cosine between the flattened weight matrices of each expert pair.
for i, j in expert_pairs:
    cos = F.cosine_similarity(
        proj_weights[i].reshape(-1).unsqueeze(0),
        proj_weights[j].reshape(-1).unsqueeze(0)).item()
    print(f" {expert_names[i]:<15} × {expert_names[j]:<15} weight_cos={cos:.4f}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 65)
print("SCAN 4: PATCHWORK COMPARTMENTS")
print("=" * 65)

# Number of anchors routed to each compartment.
asgn = model.patchwork.asgn.cpu()
for k in range(N_COMP):
    anchor_ids = torch.nonzero(asgn == k, as_tuple=True)[0]
    n_in_comp = len(anchor_ids)
    print(f" Comp {k}: {n_in_comp} anchors")

# Re-run only the patchwork stage on the stored triangulation, batch by batch.
with torch.no_grad():
    pw_all = [
        model.patchwork(tri[lo:lo + BATCH].to(DEVICE)).cpu()
        for lo in range(0, N_val, BATCH)
    ]
pw_cat = torch.cat(pw_all)

pw_norms = pw_cat.norm(dim=-1)
print(f"\n Patchwork output: {pw_cat.shape}")
print(f" norm: mean={pw_norms.mean():.4f} std={pw_norms.std():.4f}")

# Each compartment owns one contiguous D_COMP-wide slice of the output.
for k in range(N_COMP):
    lo_col = k * D_COMP
    comp_out = pw_cat[:, lo_col:lo_col + D_COMP]
    mean_norm = comp_out.norm(dim=-1).mean()
    mean_dim_std = comp_out.std(dim=0).mean()
    print(f" comp {k}: norm={mean_norm:.4f} "
          f"std_across_dims={mean_dim_std:.4f}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 65)
print("SCAN 5: TRIANGULATION PATTERNS")
print("=" * 65)

# Global statistics over every (image, anchor) distance.
print(" Triangulation distances (1-cosine):")
print(f" mean={tri.mean():.4f} std={tri.std():.4f}")
print(f" min={tri.min():.4f} max={tri.max():.4f}")

# Distance from each image to the anchor recorded as its nearest.
row_ids = torch.arange(tri.shape[0])
nearest_dist = tri[row_ids, nearest]
print(" Nearest anchor distance:")
print(f" mean={nearest_dist.mean():.4f} std={nearest_dist.std():.4f}")
print(f" max={nearest_dist.max():.4f} min={nearest_dist.min():.4f}")

# Per image, how many anchors lie closer than distance 0.5 (cosine above 0.5).
close_count = (tri < 0.5).sum(dim=1).float()
print(" Anchors within cos>0.5 per image:")
print(f" mean={close_count.mean():.1f} std={close_count.std():.1f}")

# Distance profile of the ten closest anchors, rank by rank.
topk_dists = tri.topk(10, dim=1, largest=False)
print(" Top-10 nearest anchor distances:")
for k_idx, d in enumerate(topk_dists.values.unbind(dim=1)):
    print(f" k={k_idx}: mean={d.mean():.4f} std={d.std():.4f}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 65)
print("SCAN 6: PER-CLASS ANCHOR AFFINITY")
print("=" * 65)

# Display names for the first 20 class indices; later indices fall back to
# "class_<index>".
coco_names = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
]

# For each class, the three anchors that are most often the nearest anchor of
# images carrying that class.
print("\n Top-3 anchors per class (first 20 classes):")
for c in range(min(20, N_CLASSES)):
    mask = val_labels[:, c] > 0
    total = mask.sum().item()
    if total < 5:
        # Too few images for a meaningful ranking.
        continue
    top3 = torch.bincount(nearest[mask], minlength=N_ANCHORS).topk(3)
    if c < len(coco_names):
        name = coco_names[c]
    else:
        name = f"class_{c}"
    pcts = [
        f"{a_id}({hits}/{total})"
        for a_id, hits in zip(top3.indices.tolist(), top3.values.tolist())
    ]
    print(f" {name:<15} (n={total:4d}): {' '.join(pcts)}")

# Number of distinct classes seen among the images assigned to each anchor.
# Label rows are accumulated per nearest anchor; a positive cell means the
# class occurs at least once for that anchor.
class_hits = torch.zeros(N_ANCHORS, N_CLASSES).index_add_(0, nearest, val_labels)
anchor_class_count = (class_hits > 0).sum(dim=1).float()
used = anchor_class_count[anchor_class_count > 0]
print("\n Anchor specialization:")
print(f" classes per anchor: mean={used.mean():.1f} "
      f"std={used.std():.1f}")
print(f" max={anchor_class_count.max():.0f} min={used.min():.0f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("SCAN 7: FUSED EMBEDDING GEOMETRY")
print(f"{'='*65}")

# Norms of the fused embeddings as produced by the model.
fused_norms = fused.norm(dim=-1)
print(f" Norms: mean={fused_norms.mean():.6f} std={fused_norms.std():.6f}")

# Mean off-diagonal cosine. For unit rows the sum of all pairwise cosines is
# ||sum of rows||^2, so the N x N similarity matrix is never materialized.
fused_n = F.normalize(fused, dim=-1)
unit64 = fused_n.double()
sim_total = unit64.sum(dim=0).pow(2).sum()
sim_diag = unit64.pow(2).sum()
self_sim_off = (sim_total - sim_diag) / (N_val**2 - N_val)
print(f" Self-sim (off-diag): {self_sim_off:.4f}")

# Spectrum of the centered embeddings: effective dimension and the cumulative
# share of squared singular values.
fc = fused.float() - fused.float().mean(0, keepdim=True)
sv = torch.linalg.svdvals(fc)
eff_dim = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
print(f" Effective dim: {eff_dim:.1f}/{D_ANCHOR}")
cumvar = sv.pow(2).cumsum(0) / sv.pow(2).sum()
for k in [5, 10, 20, 50, 100]:
    if k-1 < len(cumvar):
        print(f" top-{k} SVs explain {cumvar[k-1]*100:.1f}%")

# Coefficient of variation of the volumes of 500 random 5-point simplices.
vols = []
for _ in range(500):
    idx = torch.randperm(N_val)[:5]
    v2 = cayley_menger_vol2(fused_n[idx].unsqueeze(0))[0]
    # Filter on the squared volume itself. Testing the square root after the
    # 1e-12 offset can never fail (it is always >= 1e-6), which let degenerate
    # or numerically negative simplices into the sample as tiny fake volumes.
    if v2 > 0:
        vols.append(torch.sqrt(v2 + 1e-12).item())
if vols:
    fused_cv = float(np.std(vols)) / (float(np.mean(vols)) + 1e-8)
else:
    # No valid simplex was sampled; report NaN rather than averaging nothing.
    fused_cv = float("nan")
print(f" Pentachoron CV: {fused_cv:.4f}")
|
|
|
|
| |
| |
| |
|
|
print("\n" + "=" * 65)
print("SCAN 8: EXPERT CONTRIBUTION")
print("=" * 65)

# Cosine between each expert's projection and the fused embedding.
for e, name in enumerate(expert_names):
    to_fused = F.cosine_similarity(proj[e], fused, dim=-1)
    print(f" {name:<15}: cos_to_fused mean={to_fused.mean():.4f} std={to_fused.std():.4f}")

# Leave-one-out: rebuild the fused vector from the other two experts and
# measure how far it moves from the full fusion.
for e, name in enumerate(expert_names):
    others = [proj[i] for i in range(3) if i != e]
    fused_without = F.normalize(sum(others) / 2, dim=-1)
    delta = F.cosine_similarity(fused, fused_without, dim=-1)
    mean_delta = delta.mean()
    print(f" Without {name:<15}: cos_to_full={mean_delta:.4f} (uniqueness={1-mean_delta:.4f})")

# Per image: mean and spread of the three pairwise expert cosines.
print("\n Per-image expert disagreement:")
all_cos = [
    F.cosine_similarity(proj[i], proj[j], dim=-1)
    for i in range(3)
    for j in range(i + 1, 3)
]
stacked = torch.stack(all_cos, dim=1)
per_image_agree = stacked.mean(dim=1)
per_image_disagree = stacked.std(dim=1)
print(f" Agreement: mean={per_image_agree.mean():.4f} std={per_image_agree.std():.4f}")
print(f" Disagreement: mean={per_image_disagree.mean():.4f} std={per_image_disagree.std():.4f}")

# The images with the highest and the lowest mean pairwise agreement.
most_agree_idx = per_image_agree.argmax().item()
most_disagree_idx = per_image_agree.argmin().item()
agree_labels = val_labels[most_agree_idx].nonzero(as_tuple=True)[0].tolist()
disagree_labels = val_labels[most_disagree_idx].nonzero(as_tuple=True)[0].tolist()
print(f"\n Most agreed image ({most_agree_idx}): agreement={per_image_agree[most_agree_idx]:.4f}")
print(f" labels: {agree_labels}")
print(f" Most disagreed image ({most_disagree_idx}): agreement={per_image_agree[most_disagree_idx]:.4f}")
print(f" labels: {disagree_labels}")

print("\n" + "=" * 65)
print("ANALYSIS COMPLETE")
print("=" * 65)