#!/usr/bin/env python3
"""
BASE TIER SOUP ANALYSIS
========================
Load the trained 800K param soup and examine:
  - Anchor geometry on the 128-d hypersphere
  - Projector alignment (do the 3 experts converge?)
  - Triangulation patterns (which anchors are used?)
  - Patchwork compartment activation profiles
  - Per-expert projected distributions
  - CV and volume geometry of the learned space
  - Per-class anchor affinity (which anchors serve which COCO classes?)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

D_EXPERT = 768
D_ANCHOR = 128
N_ANCHORS = 256
N_CLASSES = 80
N_COMP = 8
D_COMP = 64
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

print("=" * 65)
print("BASE TIER SOUP ANALYSIS")
print(f" Device: {DEVICE}")
print("=" * 65)

# ══════════════════════════════════════════════════════════════════
# LOAD MODEL + DATA
# ══════════════════════════════════════════════════════════════════

# Rebuild model class (minimal, for loading).
# Attribute names and sub-module order define the state_dict keys, so they
# must not be changed or the checkpoint will no longer load.


class ExpertProjector(nn.Module):
    """Linear + LayerNorm projection of one expert's features, L2-normalised."""

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(d_in, d_out), nn.LayerNorm(d_out))

    def forward(self, x):
        """Return the projection of ``x`` normalised along the last dim."""
        return F.normalize(self.proj(x), dim=-1)


class Constellation(nn.Module):
    """A learnable set of ``n_anchors`` anchor vectors of dimension ``d``."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return ``(1 - cos, argmax cos)`` of ``emb`` against all anchors.

        The anchors are re-normalised here; ``emb`` is used as given.
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Per-compartment MLPs over disjoint subsets of the anchor distances.

    Anchor ``i`` is assigned to compartment ``i % n_comp``; each compartment
    maps its slice of the triangulation vector to ``d_comp`` features.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear((asgn == k).sum().item(), d_comp * 2),
            nn.GELU(),
            nn.Linear(d_comp * 2, d_comp),
            nn.LayerNorm(d_comp)) for k in range(n_comp)])

    def forward(self, tri):
        """Concatenate every compartment's output along the last dim."""
        return torch.cat(
            [self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)], -1)


class BaseTierSoup(nn.Module):
    """Three expert projectors -> fused embedding -> anchors -> patchwork -> classifier."""

    def __init__(self):
        super().__init__()
        self.n_experts = 3
        self.projectors = nn.ModuleList(
            [ExpertProjector() for _ in range(self.n_experts)])
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, N_CLASSES))

    def forward(self, expert_embeddings, apply_autograd=False):
        """Run the full pipeline on one batch.

        Args:
            expert_embeddings: sequence indexed by expert, one feature tensor
                per expert.
            apply_autograd: accepted for interface compatibility; not used by
                this minimal rebuild.

        Returns:
            ``(logits, fused, tri, nearest, projected)`` where ``projected``
            is the list of per-expert normalised projections, ``fused`` their
            normalised mean, and ``tri`` / ``nearest`` come from
            ``Constellation.triangulate``.
        """
        projected = [self.projectors[i](expert_embeddings[i])
                     for i in range(self.n_experts)]
        fused = F.normalize(sum(projected) / self.n_experts, dim=-1)
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, fused], -1))
        return logits, fused, tri, nearest, projected


def _mean_offdiag_cosine(x):
    """Mean cosine similarity over all ordered row pairs ``i != j`` of ``x``.

    Uses ``sum_{i,j} <u_i, u_j> = ||sum_i u_i||^2`` on the unit-normalised
    rows, so no (n, n) or (n, n, d) intermediate is ever materialised.
    Accumulates in float64 because the two terms being subtracted can be
    large. Returns NaN when ``x`` has fewer than two rows.
    """
    n = x.shape[0]
    if n < 2:
        return float("nan")
    unit = F.normalize(x.double(), dim=-1)
    all_pairs = unit.sum(dim=0).pow(2).sum()   # includes the i == j terms
    self_pairs = unit.pow(2).sum()             # the i == j terms alone
    return ((all_pairs - self_pairs) / (n * (n - 1))).item()


print("\n Loading checkpoint...")
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects. Only point this at checkpoints you produced or otherwise trust.
ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu", weights_only=False)
model = BaseTierSoup()
model.load_state_dict(ckpt["state_dict"])
model = model.eval().to(DEVICE)
print(f" Loaded: mAP={ckpt['mAP']:.3f} cv={ckpt['cv']:.4f} epoch={ckpt['epoch']}")

# Load val data
from datasets import load_dataset

ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
id_map = {iid: i for i, iid in enumerate(val_ids)}

# Multi-hot label matrix; out-of-range ids are ignored (a negative id would
# otherwise silently index from the end).
val_labels = torch.zeros(N_val, N_CLASSES)
for i, labs in enumerate(ref["labels"]):
    for lab in labs:
        if 0 <= lab < N_CLASSES:
            val_labels[i, lab] = 1.0

val_feats = []
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, D_EXPERT)
    seen = torch.zeros(N_val, dtype=torch.bool)
    for row in ds:
        idx = id_map.get(row["image_id"])
        if idx is None:
            continue
        feats[idx] = torch.tensor(row["features"], dtype=torch.float32)
        seen[idx] = True
    missing = N_val - int(seen.sum())
    if missing:
        # Rows never filled stay all-zero and would skew every scan below.
        print(f" WARNING: {name} has no features for {missing}/{N_val} val images")
    val_feats.append(feats.to(DEVICE))
    print(f" {name} loaded")
    del ds
    gc.collect()

# Run full val through model
print(f"\n Running inference on {N_val} val images...")
all_logits, all_fused, all_tri, all_nearest = [], [], [], []
all_proj = [[] for _ in EXPERTS]
BATCH = 256
with torch.no_grad():
    for j in range(0, N_val, BATCH):
        end = min(j + BATCH, N_val)
        batch = [val_feats[e][j:end] for e in range(len(EXPERTS))]
        lo, fu, tr, ne, pr = model(batch)
        all_logits.append(lo.cpu())
        all_fused.append(fu.cpu())
        all_tri.append(tr.cpu())
        all_nearest.append(ne.cpu())
        for e in range(len(EXPERTS)):
            all_proj[e].append(pr[e].cpu())

logits = torch.cat(all_logits)
fused = torch.cat(all_fused)
tri = torch.cat(all_tri)
nearest = torch.cat(all_nearest)
proj = [torch.cat(all_proj[e]) for e in range(len(EXPERTS))]
print(f" Done: fused={fused.shape} tri={tri.shape}")

# ══════════════════════════════════════════════════════════════════
# SCAN 1: ANCHOR GEOMETRY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 1: ANCHOR GEOMETRY")
print(f"{'='*65}")

anchors = F.normalize(model.constellation.anchors.detach().cpu(), dim=-1)

# Pairwise cosine. Statistics are taken over the distinct pairs i < j only:
# zeroing the diagonal and then reducing over the full matrix would mix
# n fake zeros into mean/std/min/max.
anchor_sim = anchors @ anchors.T
n_anch = anchors.shape[0]
upper = torch.triu_indices(n_anch, n_anch, offset=1)
pair_cos = anchor_sim[upper[0], upper[1]]
print(" Anchor pairwise cosine:")
print(f" mean={pair_cos.mean():.4f} std={pair_cos.std():.4f}")
print(f" max={pair_cos.max():.4f} min={pair_cos.min():.4f}")

# Distribution of max-neighbor cosine. The self-similarity is masked with
# -inf (not 0) so an anchor whose neighbours are all negative is reported
# with its true, negative, best neighbour.
self_mask = torch.eye(n_anch, dtype=torch.bool)
max_neighbor = anchor_sim.masked_fill(self_mask, float("-inf")).max(dim=1).values
print(" Max neighbor cosine per anchor:")
print(f" mean={max_neighbor.mean():.4f} std={max_neighbor.std():.4f}")
print(f" max={max_neighbor.max():.4f} min={max_neighbor.min():.4f}")

# Anchor norms (should be ~1.0 after normalize)
anchor_norms = anchors.norm(dim=-1)
print(f" Anchor norms: mean={anchor_norms.mean():.6f} std={anchor_norms.std():.6f}")

# SVD of anchor matrix; eff_rank is (sum sv)^2 / sum(sv^2).
sv = torch.linalg.svdvals(anchors)
eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
print(f" Anchor spectral: eff_rank={eff_rank:.1f}/{min(anchors.shape)}")
print(f" sv_max={sv[0]:.4f} sv_10={sv[9]:.4f} sv_50={sv[49]:.4f} sv_min={sv[-1]:.6f}")


# Volume CV of anchors
def cayley_menger_vol2(pts):
    """Squared simplex volume from a Cayley-Menger determinant.

    Args:
        pts: tensor of shape (B, V, D) -- B simplices of V points each
            (the unpacking of ``d2.shape`` below requires a 3-d input).

    Returns:
        Tensor of shape (B,) with ``(-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)``.
        The value can come out slightly negative for near-degenerate
        simplices; callers clamp it.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)                      # pairwise squared distances
    B, V, _ = d2.shape
    cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    s = (-1.0)**V
    f = math.factorial(V-1)
    return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)


vols = []
for _ in range(500):
    idx = torch.randperm(N_ANCHORS)[:5]
    v2 = cayley_menger_vol2(anchors[idx].unsqueeze(0))
    v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
    # The +1e-12 keeps v >= 1e-6 for finite input, so in practice this test
    # only rejects NaN volumes.
    if v > 0:
        vols.append(v)
anchor_cv = np.std(vols) / (np.mean(vols) + 1e-8)
print(f" Anchor pentachoron CV: {anchor_cv:.4f}")
print(f" mean_vol={np.mean(vols):.6f} std_vol={np.std(vols):.6f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 2: ANCHOR UTILIZATION
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 2: ANCHOR UTILIZATION")
print(f"{'='*65}")

# How many images use each anchor as nearest
anchor_counts = torch.bincount(nearest, minlength=N_ANCHORS).float()
active = (anchor_counts > 0).sum().item()
print(f" Active anchors: {active}/{N_ANCHORS} ({active/N_ANCHORS*100:.1f}%)")
print(f" Visit counts: mean={anchor_counts.mean():.1f} std={anchor_counts.std():.1f}")
print(f" max={anchor_counts.max():.0f} min={anchor_counts.min():.0f}")
print(f" top 10: {anchor_counts.topk(10).values.long().tolist()}")
print(f" bottom 10: {anchor_counts.sort().values[:10].long().tolist()}")

# Entropy of anchor distribution
probs = anchor_counts / anchor_counts.sum()
entropy = -(probs[probs > 0] * probs[probs > 0].log()).sum().item()
max_entropy = math.log(N_ANCHORS)
print(f" Anchor entropy: {entropy:.4f} / {max_entropy:.4f} ({entropy/max_entropy*100:.1f}%)")

# Per-anchor mean cosine between the fused embeddings assigned to it.
# Anchors with fewer than two images get NaN (not 0.0) so they cannot be
# confused with a cluster whose mean cosine really is non-positive.
print("\n Per-anchor embedding density:")
anchor_mean_cos = []
for a_idx in range(N_ANCHORS):
    anchor_mean_cos.append(_mean_offdiag_cosine(fused[nearest == a_idx]))
amc = np.array(anchor_mean_cos)
amc_valid = amc[np.isfinite(amc)]
if amc_valid.size:
    print(f" Intra-cluster cosine: mean={amc_valid.mean():.4f} std={amc_valid.std():.4f}")
else:
    print(" Intra-cluster cosine: n/a (no anchor has at least 2 images)")

# ══════════════════════════════════════════════════════════════════
# SCAN 3: PROJECTOR ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 3: PROJECTOR ANALYSIS")
print(f"{'='*65}")

expert_names = ["clip_l14", "dinov2_b14", "siglip_b16"]

# Per-expert projection stats
for e, name in enumerate(expert_names):
    p = proj[e]
    print(f"\n {name}:")
    print(f" norm: mean={p.norm(dim=-1).mean():.6f} (should be 1.0)")
    # Averaged over the N*(N-1) off-diagonal pairs, the same normaliser the
    # fused-embedding self-similarity in Scan 7 uses.
    print(f" self-sim off-diag: {_mean_offdiag_cosine(p):.4f}")
    # SVD of projected embeddings
    pc = p.float() - p.float().mean(0, keepdim=True)
    sv = torch.linalg.svdvals(pc)
    eff_dim = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
    print(f" eff_dim: {eff_dim:.1f}/{D_ANCHOR}")

# Pairwise agreement
print("\n Expert agreement (cosine in 128-d):")
for i in range(3):
    for j in range(i+1, 3):
        cos = F.cosine_similarity(proj[i], proj[j], dim=-1)
        print(f" {expert_names[i]:<15} × {expert_names[j]:<15}: "
              f"mean={cos.mean():.4f} std={cos.std():.4f} min={cos.min():.4f}")

# How different are the nearest anchors per expert?
print("\n Per-expert nearest anchor agreement:")
# Normalising the anchors does not depend on the expert, so do it once.
unit_anchors = F.normalize(anchors, dim=-1)
expert_nearest = [(proj[e] @ unit_anchors.T).argmax(dim=-1) for e in range(3)]

# All unordered expert pairs, in the order (0,1), (0,2), (1,2).
expert_pairs = [(i, j) for i in range(3) for j in range(i + 1, 3)]
for i, j in expert_pairs:
    same = expert_nearest[i] == expert_nearest[j]
    agree = same.float().mean().item()
    print(f" {expert_names[i]:<15} × {expert_names[j]:<15}: same_anchor={agree:.4f} ({agree*100:.1f}%)")

# Projector weight analysis: Frobenius norm and (sum sv)^2 / sum(sv^2)
# of the first layer of each projector.
print("\n Projector weight comparison:")
proj_weights = [model.projectors[e].proj[0].weight.detach().float() for e in range(3)]
for label, w in zip(expert_names, proj_weights):
    sv = torch.linalg.svdvals(w)
    eff_r = (sv.sum() ** 2 / (sv.pow(2).sum() + 1e-12)).item()
    print(f" {label:<15}: norm={w.norm():.4f} eff_rank={eff_r:.1f}/{min(w.shape)}")

# Cross-projector cosine between the flattened weight matrices
for i, j in expert_pairs:
    flat_i = proj_weights[i].reshape(1, -1)
    flat_j = proj_weights[j].reshape(1, -1)
    cos = F.cosine_similarity(flat_i, flat_j).item()
    print(f" {expert_names[i]:<15} × {expert_names[j]:<15} weight_cos={cos:.4f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 4: PATCHWORK COMPARTMENT ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 4: PATCHWORK COMPARTMENTS")
print(f"{'='*65}")

# Which anchors are in which compartment
asgn = model.patchwork.asgn.cpu()
for k in range(N_COMP):
    members = int((asgn == k).sum())
    print(f" Comp {k}: {members} anchors")

# Patchwork output analysis: re-run only the patchwork on the stored
# triangulation, one batch at a time.
pw_all = []
with torch.no_grad():
    for start in range(0, N_val, BATCH):
        stop = min(start + BATCH, N_val)
        pw_all.append(model.patchwork(tri[start:stop].to(DEVICE)).cpu())
    pw_cat = torch.cat(pw_all)

pw_norms = pw_cat.norm(dim=-1)
print(f"\n Patchwork output: {pw_cat.shape}")
print(f" norm: mean={pw_norms.mean():.4f} std={pw_norms.std():.4f}")

# Per-compartment output magnitude (each compartment owns D_COMP columns)
for k in range(N_COMP):
    lo_col = k * D_COMP
    comp_out = pw_cat[:, lo_col:lo_col + D_COMP]
    comp_norm = comp_out.norm(dim=-1).mean()
    comp_spread = comp_out.std(dim=0).mean()
    print(f" comp {k}: norm={comp_norm:.4f} std_across_dims={comp_spread:.4f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 5: TRIANGULATION PATTERN ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 5: TRIANGULATION PATTERNS")
print(f"{'='*65}")

# Triangulation distance stats
print(" Triangulation distances (1-cosine):")
print(f" mean={tri.mean():.4f} std={tri.std():.4f}")
print(f" min={tri.min():.4f} max={tri.max():.4f}")

# Nearest anchor distance: pick, per image, the column of its nearest anchor
nearest_dist = tri[torch.arange(tri.shape[0]), nearest]
print(" Nearest anchor distance:")
print(f" mean={nearest_dist.mean():.4f} std={nearest_dist.std():.4f}")
print(f" max={nearest_dist.max():.4f} min={nearest_dist.min():.4f}")

# How many anchors are "close" (cosine > 0.5, i.e. dist < 0.5)
close_count = (tri < 0.5).float().sum(dim=1)
print(" Anchors within cos>0.5 per image:")
print(f" mean={close_count.mean():.1f} std={close_count.std():.1f}")

# Top-k nearest anchors — how spread are they?
topk_dists = tri.topk(10, dim=1, largest=False)
print(" Top-10 nearest anchor distances:")
for k_idx, d in enumerate(topk_dists.values.unbind(dim=1)):
    print(f" k={k_idx}: mean={d.mean():.4f} std={d.std():.4f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 6: PER-CLASS ANCHOR AFFINITY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 6: PER-CLASS ANCHOR AFFINITY")
print(f"{'='*65}")

# COCO class names (subset)
coco_names = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
]

# For each class, which anchors are most associated?
print("\n Top-3 anchors per class (first 20 classes):")
for c in range(min(20, N_CLASSES)):
    mask = val_labels[:, c] > 0
    total = int(mask.sum())
    if total < 5:
        # Too few positives for the top-3 breakdown to mean anything.
        continue
    counts = torch.bincount(nearest[mask], minlength=N_ANCHORS)
    top_counts, top_ids = counts.topk(3)
    if c < len(coco_names):
        name = coco_names[c]
    else:
        name = f"class_{c}"
    pcts = [f"{aid}({cnt}/{total})"
            for aid, cnt in zip(top_ids.tolist(), top_counts.tolist())]
    summary = " ".join(pcts)
    print(f" {name:<15} (n={total:4d}): {summary}")

# Anchor specialization: how many classes does each anchor serve?
# An anchor "serves" a class when at least one image assigned to it carries
# that label; anchors with no images keep a count of 0.
anchor_class_count = torch.zeros(N_ANCHORS)
for a in range(N_ANCHORS):
    members = nearest == a
    if members.any():
        classes_seen = val_labels[members].sum(0) > 0
        anchor_class_count[a] = classes_seen.sum().item()

served = anchor_class_count[anchor_class_count > 0]
print("\n Anchor specialization:")
print(f" classes per anchor: mean={served.mean():.1f} std={served.std():.1f}")
print(f" max={anchor_class_count.max():.0f} min={served.min():.0f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 7: FUSED EMBEDDING GEOMETRY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 7: FUSED EMBEDDING GEOMETRY")
print(f"{'='*65}")

# Norms (should be 1.0)
fused_norms = fused.norm(dim=-1)
print(f" Norms: mean={fused_norms.mean():.6f} std={fused_norms.std():.6f}")

# Self-similarity: mean of the Gram matrix with the diagonal excluded
fused_n = F.normalize(fused, dim=-1)
self_sim = fused_n @ fused_n.T
offdiag_total = self_sim.sum() - self_sim.diag().sum()
self_sim_off = offdiag_total / (N_val**2 - N_val)
print(f" Self-sim (off-diag): {self_sim_off:.4f}")

# SVD of the mean-centred embeddings
fused_f32 = fused.float()
fc = fused_f32 - fused_f32.mean(0, keepdim=True)
sv = torch.linalg.svdvals(fc)
energy = sv.pow(2)
eff_dim = ((sv.sum()**2) / (energy.sum() + 1e-12)).item()
print(f" Effective dim: {eff_dim:.1f}/{D_ANCHOR}")
cumvar = energy.cumsum(0) / energy.sum()
for k in (5, 10, 20, 50, 100):
    if k <= len(cumvar):
        print(f" top-{k} SVs explain {cumvar[k-1]*100:.1f}%")

# CV: coefficient of variation of the volumes of 500 random 5-point simplices
vols = []
for _ in range(500):
    idx = torch.randperm(N_val)[:5]
    vol_sq = cayley_menger_vol2(fused_n[idx].unsqueeze(0))[0]
    v = torch.sqrt(F.relu(vol_sq) + 1e-12).item()
    if v > 0:
        vols.append(v)
fused_cv = np.std(vols) / (np.mean(vols) + 1e-8)
print(f" Pentachoron CV: {fused_cv:.4f}")

# ══════════════════════════════════════════════════════════════════
# SCAN 8: EXPERT CONTRIBUTION ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 8: EXPERT CONTRIBUTION")
print(f"{'='*65}")

# How much does each expert contribute to the fused embedding?
# cos(expert_proj, fused) tells us alignment
for name, expert_proj in zip(expert_names, proj):
    cos = F.cosine_similarity(expert_proj, fused, dim=-1)
    print(f" {name:<15}: cos_to_fused mean={cos.mean():.4f} std={cos.std():.4f}")

# Residual after removing each expert: re-fuse the remaining two and compare
# against the full three-expert embedding.
for e, name in enumerate(expert_names):
    others = [p for i, p in enumerate(proj) if i != e]
    fused_without = F.normalize(sum(others) / 2, dim=-1)
    delta = F.cosine_similarity(fused, fused_without, dim=-1)
    mean_delta = delta.mean()
    print(f" Without {name:<15}: cos_to_full={mean_delta:.4f} (uniqueness={1-mean_delta:.4f})")

# Per-image expert disagreement: one column per expert pair, in the order
# (0,1), (0,2), (1,2).
print("\n Per-image expert disagreement:")
all_cos = [F.cosine_similarity(proj[i], proj[j], dim=-1)
           for i in range(3) for j in range(i + 1, 3)]
stacked = torch.stack(all_cos, dim=1)  # (N, 3)
per_image_agree = stacked.mean(dim=1)
per_image_disagree = stacked.std(dim=1)
print(f" Agreement: mean={per_image_agree.mean():.4f} std={per_image_agree.std():.4f}")
print(f" Disagreement: mean={per_image_disagree.mean():.4f} std={per_image_disagree.std():.4f}")

# Most agreed and disagreed images
most_agree_idx = per_image_agree.argmax().item()
most_disagree_idx = per_image_agree.argmin().item()
agree_labels = val_labels[most_agree_idx].nonzero(as_tuple=True)[0].tolist()
print(f"\n Most agreed image ({most_agree_idx}): agreement={per_image_agree[most_agree_idx]:.4f}")
print(f" labels: {agree_labels}")
# Report the image on which the three experts agree least, then close out.
worst_agreement = per_image_agree[most_disagree_idx]
worst_labels = torch.nonzero(val_labels[most_disagree_idx], as_tuple=True)[0].tolist()
print(f" Most disagreed image ({most_disagree_idx}): agreement={worst_agreement:.4f}")
print(f" labels: {worst_labels}")

closing_rule = "=" * 65
print(f"\n{closing_rule}")
print("ANALYSIS COMPLETE")
print(closing_rule)