File size: 21,147 Bytes
797187d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
#!/usr/bin/env python3
"""
BASE TIER SOUP ANALYSIS
========================
Load the trained 800K param soup and examine:
  - Anchor geometry on the 128-d hypersphere
  - Projector alignment (do the 3 experts converge?)
  - Triangulation patterns (which anchors are used?)
  - Patchwork compartment activation profiles
  - Per-expert projected distributions
  - CV and volume geometry of the learned space
  - Per-class anchor affinity (which anchors serve which COCO classes?)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import gc

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture constants. They are the default sizes of the model classes
# defined below; the checkpoint is loaded into a model built with them.
# D_EXPERT: width of each expert's input feature vector.
D_EXPERT = 768
# D_ANCHOR: width of the shared projected / anchor space.
D_ANCHOR = 128
# N_ANCHORS: number of anchor vectors held by the Constellation.
N_ANCHORS = 256
# N_CLASSES: number of classifier outputs and label columns.
N_CLASSES = 80
# N_COMP / D_COMP: number of Patchwork compartments and the output width of each.
N_COMP = 8
D_COMP = 64
# Dataset config names, one per expert; each is passed to load_dataset() below.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

# Startup banner.
print("=" * 65)
print("BASE TIER SOUP ANALYSIS")
print(f"  Device: {DEVICE}")
print("=" * 65)


# ══════════════════════════════════════════════════════════════════
# LOAD MODEL + DATA
# ══════════════════════════════════════════════════════════════════

# Rebuild model class (minimal, for loading)
class ExpertProjector(nn.Module):
    """Map one expert's features into the shared anchor space.

    The mapping is a linear layer followed by LayerNorm; the result is
    L2-normalised along the last dimension.
    """

    def __init__(self, d_in=D_EXPERT, d_out=D_ANCHOR):
        super().__init__()
        linear = nn.Linear(d_in, d_out)
        layer_norm = nn.LayerNorm(d_out)
        # The attribute name and layer order determine the state_dict keys
        # (proj.0.*, proj.1.*), so both are kept as they are.
        self.proj = nn.Sequential(linear, layer_norm)

    def forward(self, x):
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)

class Constellation(nn.Module):
    """Learnable set of anchor vectors, initialised as random unit vectors."""

    def __init__(self, n_anchors=N_ANCHORS, d=D_ANCHOR):
        super().__init__()
        self.n_anchors = n_anchors
        initial = torch.randn(n_anchors, d)
        self.anchors = nn.Parameter(F.normalize(initial, dim=-1))

    def triangulate(self, emb):
        """Compare `emb` against every anchor.

        Returns a pair: (1 - similarity to each anchor, index of the most
        similar anchor). The anchors are re-normalised here, so the
        similarity is a cosine whenever `emb` itself is unit-norm.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        similarity = emb @ unit_anchors.T
        nearest_idx = similarity.argmax(dim=-1)
        return 1.0 - similarity, nearest_idx

class Patchwork(nn.Module):
    """Split the anchor-distance vector into compartments, one small MLP each."""

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        # Round-robin assignment: anchor i belongs to compartment i % n_comp.
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)

        hidden = d_comp * 2
        branches = []
        for k in range(n_comp):
            # Input width = number of anchors assigned to compartment k.
            width = (asgn == k).sum().item()
            branches.append(nn.Sequential(
                nn.Linear(width, hidden),
                nn.GELU(),
                nn.Linear(hidden, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(branches)

    def forward(self, tri):
        """Encode each compartment's columns of `tri` and concatenate the results."""
        pieces = []
        for k in range(self.n_comp):
            columns = tri[:, self.asgn == k]
            pieces.append(self.comps[k](columns))
        return torch.cat(pieces, -1)

class BaseTierSoup(nn.Module):
    """Three expert projectors fused on the anchor sphere, plus a classifier head.

    Pipeline: project each expert -> average and re-normalise -> distances to
    all anchors -> Patchwork encoding -> classifier on [patchwork, fused].
    """

    def __init__(self):
        super().__init__()
        self.n_experts = 3
        projector_list = []
        for _ in range(3):
            projector_list.append(ExpertProjector())
        self.projectors = nn.ModuleList(projector_list)
        self.constellation = Constellation()
        self.patchwork = Patchwork()

        pw_dim = N_COMP * D_COMP
        head = [
            nn.Linear(pw_dim + D_ANCHOR, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, N_CLASSES),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, expert_embeddings, apply_autograd=False):
        """Return (logits, fused, tri, nearest, projected).

        `expert_embeddings` is indexed 0..2, one entry per expert.
        `apply_autograd` is accepted but not read by this implementation.
        """
        projected = []
        for idx in range(3):
            projected.append(self.projectors[idx](expert_embeddings[idx]))
        averaged = sum(projected) / 3
        fused = F.normalize(averaged, dim=-1)
        tri, nearest = self.constellation.triangulate(fused)
        pw = self.patchwork(tri)
        head_input = torch.cat([pw, fused], -1)
        logits = self.classifier(head_input)
        return logits, fused, tri, nearest, projected

# Restore the trained weights: deserialise onto the CPU, load them into a
# freshly built model, then move the model to DEVICE.
print(f"\n  Loading checkpoint...")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can run code embedded in the file. Only point this at a
# checkpoint you produced or trust. Switching to weights_only=True should be
# possible if the stored metadata are plain Python/tensor types — TODO confirm.
ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu", weights_only=False)
model = BaseTierSoup()
# The checkpoint is indexed like a dict; this script reads its "state_dict",
# "mAP", "cv" and "epoch" entries.
model.load_state_dict(ckpt["state_dict"])
# eval() puts the classifier's Dropout layer in inference mode.
model = model.eval().to(DEVICE)
print(f"  Loaded: mAP={ckpt['mAP']:.3f} cv={ckpt['cv']:.4f} epoch={ckpt['epoch']}")

# Load val data.
# The first expert's split is the reference: it fixes the row order
# (image_id -> row index) and supplies the multi-hot label matrix.
from datasets import load_dataset
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels = torch.zeros(N_val, N_CLASSES)
for i, labs in enumerate(ref["labels"]):
    for label in labs:
        # Label ids at or above N_CLASSES have no column and are skipped.
        if label < N_CLASSES:
            val_labels[i, label] = 1.0

# One (N_val, D_EXPERT) feature matrix per expert, aligned to the reference
# row order through id_map.
val_feats = []
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, D_EXPERT)
    filled = torch.zeros(N_val, dtype=torch.bool)
    for row in ds:
        row_idx = id_map.get(row["image_id"])
        if row_idx is None:
            # Image not present in the reference split; nothing to align it to.
            continue
        feats[row_idx] = torch.tensor(row["features"], dtype=torch.float32)
        filled[row_idx] = True
    # Rows never written stay all-zero and would silently distort every scan
    # below, so report them instead of hiding them.
    n_missing = N_val - int(filled.sum())
    if n_missing:
        print(f"  WARNING: {name} has no features for {n_missing}/{N_val} "
              f"val images (rows left as zeros)")
    val_feats.append(feats.to(DEVICE))
    print(f"  {name} loaded")
    # Release the dataset before loading the next one to keep peak memory down.
    del ds
    gc.collect()

# Push the whole validation set through the model in mini-batches and keep
# every output on the CPU for the scans below.
print(f"\n  Running inference on {N_val} val images...")
all_logits, all_fused, all_tri, all_nearest = [], [], [], []
all_proj = [[] for _ in range(3)]
BATCH = 256
with torch.no_grad():
    for start in range(0, N_val, BATCH):
        stop = min(start + BATCH, N_val)
        expert_batch = [val_feats[e][start:stop] for e in range(3)]
        outputs = model(expert_batch)
        batch_logits, batch_fused, batch_tri, batch_nearest, batch_proj = outputs
        all_logits.append(batch_logits.cpu())
        all_fused.append(batch_fused.cpu())
        all_tri.append(batch_tri.cpu())
        all_nearest.append(batch_nearest.cpu())
        for e in range(3):
            all_proj[e].append(batch_proj[e].cpu())

# Stitch the per-batch pieces back into full-dataset tensors.
logits = torch.cat(all_logits)
fused = torch.cat(all_fused)
tri = torch.cat(all_tri)
nearest = torch.cat(all_nearest)
proj = [torch.cat(chunks) for chunks in all_proj]
print(f"  Done: fused={fused.shape} tri={tri.shape}")


# ══════════════════════════════════════════════════════════════════
# SCAN 1: ANCHOR GEOMETRY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 1: ANCHOR GEOMETRY")
print(f"{'='*65}")

# Unit-norm CPU copy of the learned anchors.
anchors = F.normalize(model.constellation.anchors.detach().cpu(), dim=-1)

# Pairwise cosine. The diagonal (each anchor against itself) is zeroed in
# anchor_sim, and the statistics are taken over the off-diagonal entries only,
# so those placeholder zeros cannot bias the mean/std or be reported as the
# min/max.
anchor_sim = anchors @ anchors.T
anchor_sim.fill_diagonal_(0)
off_diag = ~torch.eye(anchor_sim.shape[0], dtype=torch.bool)
pair_cos = anchor_sim[off_diag]

print(f"  Anchor pairwise cosine:")
print(f"    mean={pair_cos.mean():.4f} std={pair_cos.std():.4f}")
print(f"    max={pair_cos.max():.4f} min={pair_cos.min():.4f}")

# Distribution of max-neighbor cosine. The diagonal is set to -inf on a copy
# so an anchor is never picked as its own neighbour, even when every real
# neighbour has a negative cosine.
max_neighbor = anchor_sim.clone().fill_diagonal_(float("-inf")).max(dim=1).values
print(f"  Max neighbor cosine per anchor:")
print(f"    mean={max_neighbor.mean():.4f} std={max_neighbor.std():.4f}")
print(f"    max={max_neighbor.max():.4f} min={max_neighbor.min():.4f}")

# Anchor norms (should be ~1.0 after normalize)
anchor_norms = anchors.norm(dim=-1)
print(f"  Anchor norms: mean={anchor_norms.mean():.6f} std={anchor_norms.std():.6f}")

# SVD of anchor matrix. eff_rank = (sum of singular values)^2 / (sum of their
# squares); the 1e-12 only guards against division by zero.
sv = torch.linalg.svdvals(anchors)
eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
print(f"  Anchor spectral: eff_rank={eff_rank:.1f}/{min(anchors.shape)}")
# sv[9] / sv[49] are the 10th and 50th singular values; this needs at least
# 50 of them, which holds for the 256 x 128 anchor matrix configured above.
print(f"    sv_max={sv[0]:.4f} sv_10={sv[9]:.4f} sv_50={sv[49]:.4f} sv_min={sv[-1]:.6f}")

# Volume CV of anchors
def cayley_menger_vol2(pts):
    """Return squared simplex volumes via the Cayley-Menger determinant.

    Args:
        pts: tensor of shape (..., V, D) with any number of leading batch
            dimensions (including none), V vertices per simplex and D
            coordinates per vertex.

    Returns:
        float32 tensor of shape (...) with the squared (V-1)-dimensional
        volume of each simplex. Round-off can make the value slightly
        negative for (near-)degenerate simplices, so callers should clamp
        before taking a square root.
    """
    pts = pts.float()
    # Squared Euclidean distance between every pair of vertices: (..., V, V).
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    V = d2.shape[-1]

    # Bordered matrix: zero corner, ones along the first row and column,
    # squared distances in the remaining V x V block.
    cm = torch.zeros(*d2.shape[:-2], V + 1, V + 1,
                     device=d2.device, dtype=torch.float32)
    cm[..., 0, 1:] = 1
    cm[..., 1:, 0] = 1
    cm[..., 1:, 1:] = d2

    # With n = V - 1 the simplex dimension:
    #   vol^2 = (-1)^(n+1) / (2^n * (n!)^2) * det(cm)
    n = V - 1
    sign = (-1.0) ** V
    denom = (2.0 ** n) * math.factorial(n) ** 2
    return sign / denom * torch.linalg.det(cm)

# Draw 500 random 5-anchor simplices and measure the spread of their volumes
# (coefficient of variation = std / mean).
vols = []
for _trial in range(500):
    picked = torch.randperm(N_ANCHORS)[:5]
    simplex = anchors[picked].unsqueeze(0)
    vol_sq = cayley_menger_vol2(simplex)[0]
    vol = torch.sqrt(F.relu(vol_sq) + 1e-12).item()
    # NOTE(review): the +1e-12 keeps every finite result strictly positive,
    # so in practice this guard only rejects NaN; degenerate simplices are
    # kept with a volume of about 1e-6.
    if not (vol > 0):
        continue
    vols.append(vol)
vol_mean = np.mean(vols)
vol_std = np.std(vols)
anchor_cv = vol_std / (vol_mean + 1e-8)
print(f"  Anchor pentachoron CV: {anchor_cv:.4f}")
print(f"    mean_vol={vol_mean:.6f} std_vol={vol_std:.6f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 2: ANCHOR UTILIZATION
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 2: ANCHOR UTILIZATION")
print(f"{'='*65}")

# Histogram of nearest-anchor assignments over the validation set.
anchor_counts = torch.bincount(nearest, minlength=N_ANCHORS).float()
active = (anchor_counts > 0).sum().item()
active_pct = active/N_ANCHORS*100
top10 = anchor_counts.topk(10).values.long().tolist()
bottom10 = anchor_counts.sort().values[:10].long().tolist()
print(f"  Active anchors: {active}/{N_ANCHORS} ({active_pct:.1f}%)")
print(f"  Visit counts: mean={anchor_counts.mean():.1f} std={anchor_counts.std():.1f}")
print(f"    max={anchor_counts.max():.0f} min={anchor_counts.min():.0f}")
print(f"    top 10: {top10}")
print(f"    bottom 10: {bottom10}")

# Shannon entropy (in nats) of the assignment distribution, compared with the
# uniform maximum log(N_ANCHORS). Unused anchors are left out of the sum.
probs = anchor_counts / anchor_counts.sum()
used_probs = probs[probs > 0]
entropy = -(used_probs * used_probs.log()).sum().item()
max_entropy = math.log(N_ANCHORS)
entropy_pct = entropy/max_entropy*100
print(f"  Anchor entropy: {entropy:.4f} / {max_entropy:.4f} ({entropy_pct:.1f}%)")

# Per-anchor density: mean pairwise cosine between the fused embeddings that
# share the same nearest anchor.
print(f"\n  Per-anchor embedding density:")
anchor_mean_cos = []
anchor_has_pairs = []
for a_idx in range(N_ANCHORS):
    mask = nearest == a_idx
    n = int(mask.sum())
    if n < 2:
        # No pair to compare; keep a placeholder and mark it as invalid so it
        # is excluded from the summary without relying on its value.
        anchor_mean_cos.append(0.0)
        anchor_has_pairs.append(False)
        continue
    # Normalise then multiply: an (n, n) similarity matrix instead of the
    # (n, n, d) temporary a broadcasted cosine_similarity would build.
    unit = F.normalize(fused[mask], dim=-1)
    sim = unit @ unit.T
    off_diag_sum = sim.sum() - sim.diag().sum()
    anchor_mean_cos.append(off_diag_sum.item() / (n * (n - 1)))
    anchor_has_pairs.append(True)
amc = np.array(anchor_mean_cos)
valid = np.array(anchor_has_pairs, dtype=bool)
if valid.any():
    print(f"    Intra-cluster cosine: mean={amc[valid].mean():.4f} std={amc[valid].std():.4f}")
else:
    print(f"    Intra-cluster cosine: n/a (no anchor has 2+ images)")


# ══════════════════════════════════════════════════════════════════
# SCAN 3: PROJECTOR ANALYSIS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 3: PROJECTOR ANALYSIS")
print(f"{'='*65}")

# Short display names, in the same order as the model's projectors.
expert_names = ["clip_l14", "dinov2_b14", "siglip_b16"]

# Per-expert projection stats
for e, name in enumerate(expert_names):
    p = proj[e]
    p_unit = F.normalize(p, dim=-1)
    sim = p_unit @ p_unit.T
    n_rows = sim.shape[0]
    # Average over the N*(N-1) off-diagonal pairs only, the same way SCAN 7
    # does; dividing by N*N would count the diagonal as N zero-valued pairs.
    off_diag_mean = (sim.sum() - sim.diag().sum()) / max(n_rows * (n_rows - 1), 1)
    print(f"\n  {name}:")
    print(f"    norm: mean={p.norm(dim=-1).mean():.6f} (should be 1.0)")
    print(f"    self-sim off-diag: {off_diag_mean:.4f}")

    # SVD of the mean-centred projections; eff_dim = (sum sv)^2 / sum(sv^2).
    pc = p.float() - p.float().mean(0, keepdim=True)
    sv = torch.linalg.svdvals(pc)
    eff_dim = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
    print(f"    eff_dim: {eff_dim:.1f}/{D_ANCHOR}")

# Pairwise agreement: per-image cosine between two experts' projections.
print(f"\n  Expert agreement (cosine in {D_ANCHOR}-d):")
n_proj = len(proj)
for i in range(n_proj):
    for j in range(i + 1, n_proj):
        cos = F.cosine_similarity(proj[i], proj[j], dim=-1)
        # The separator is a multiplication sign; the original literal held
        # its mis-decoded byte sequence.
        print(f"    {expert_names[i]:<15} × {expert_names[j]:<15}: "
              f"mean={cos.mean():.4f} std={cos.std():.4f} min={cos.min():.4f}")

# How different are the nearest anchors per expert?
print(f"\n  Per-expert nearest anchor agreement:")
# Normalise the anchors once, outside the loop; the result does not depend on
# the expert.
anchors_unit = F.normalize(anchors, dim=-1)
expert_nearest = []
for expert_proj in proj:
    expert_cos = expert_proj @ anchors_unit.T
    expert_nearest.append(expert_cos.argmax(dim=-1))
n_nearest = len(expert_nearest)
for i in range(n_nearest):
    for j in range(i + 1, n_nearest):
        # Fraction of images for which both experts pick the same anchor.
        agree = (expert_nearest[i] == expert_nearest[j]).float().mean().item()
        # The separator is a multiplication sign; the original literal held
        # its mis-decoded byte sequence.
        print(f"    {expert_names[i]:<15} × {expert_names[j]:<15}: "
              f"same_anchor={agree:.4f} ({agree*100:.1f}%)")

# Projector weight analysis
print(f"\n  Projector weight comparison:")
proj_weights = []
for e, projector in enumerate(model.projectors):
    # Weight of the projector's Linear layer, shape (D_ANCHOR, D_EXPERT).
    w = projector.proj[0].weight.detach().float()
    proj_weights.append(w)
    sv = torch.linalg.svdvals(w)
    eff_r = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
    print(f"    {expert_names[e]:<15}: norm={w.norm():.4f} eff_rank={eff_r:.1f}/{min(w.shape)}")

# Cross-projector cosine between the flattened weight matrices.
n_weights = len(proj_weights)
for i in range(n_weights):
    for j in range(i + 1, n_weights):
        cos = F.cosine_similarity(
            proj_weights[i].reshape(-1).unsqueeze(0),
            proj_weights[j].reshape(-1).unsqueeze(0)).item()
        # The separator is a multiplication sign; the original literal held
        # its mis-decoded byte sequence.
        print(f"    {expert_names[i]:<15} × {expert_names[j]:<15} weight_cos={cos:.4f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 4: PATCHWORK COMPARTMENT ANALYSIS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 4: PATCHWORK COMPARTMENTS")
print(f"{'='*65}")

# Anchor -> compartment assignment stored on the model.
asgn = model.patchwork.asgn.cpu()
for k in range(N_COMP):
    anchor_ids = torch.nonzero(asgn == k, as_tuple=True)[0]
    print(f"  Comp {k}: {len(anchor_ids)} anchors")

# Re-run only the Patchwork on the stored triangulation distances, in batches.
pw_all = []
with torch.no_grad():
    for start in range(0, N_val, BATCH):
        stop = min(start + BATCH, N_val)
        chunk = tri[start:stop].to(DEVICE)
        pw_all.append(model.patchwork(chunk).cpu())
pw_cat = torch.cat(pw_all)

pw_norms = pw_cat.norm(dim=-1)
print(f"\n  Patchwork output: {pw_cat.shape}")
print(f"    norm: mean={pw_norms.mean():.4f} std={pw_norms.std():.4f}")

# Each compartment owns a contiguous D_COMP-wide slice of the output.
for k in range(N_COMP):
    lo_col = k*D_COMP
    hi_col = lo_col + D_COMP
    comp_out = pw_cat[:, lo_col:hi_col]
    comp_norm = comp_out.norm(dim=-1).mean()
    comp_spread = comp_out.std(dim=0).mean()
    print(f"    comp {k}: norm={comp_norm:.4f} "
          f"std_across_dims={comp_spread:.4f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 5: TRIANGULATION PATTERN ANALYSIS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 5: TRIANGULATION PATTERNS")
print(f"{'='*65}")

# Statistics over every (image, anchor) distance.
print(f"  Triangulation distances (1-cosine):")
print(f"    mean={tri.mean():.4f} std={tri.std():.4f}")
print(f"    min={tri.min():.4f} max={tri.max():.4f}")

# Distance from each image to the anchor it was assigned to.
row_ids = torch.arange(tri.shape[0])
nearest_dist = tri[row_ids, nearest]
print(f"  Nearest anchor distance:")
print(f"    mean={nearest_dist.mean():.4f} std={nearest_dist.std():.4f}")
print(f"    max={nearest_dist.max():.4f} min={nearest_dist.min():.4f}")

# Number of "close" anchors per image (cosine > 0.5, i.e. distance < 0.5).
is_close = tri < 0.5
close_count = is_close.float().sum(dim=1)
print(f"  Anchors within cos>0.5 per image:")
print(f"    mean={close_count.mean():.1f} std={close_count.std():.1f}")

# Ten smallest distances per image — how spread out are the nearest anchors?
topk_dists = tri.topk(10, dim=1, largest=False)
print(f"  Top-10 nearest anchor distances:")
for k_idx in range(10):
    rank_dist = topk_dists.values[:, k_idx]
    print(f"    k={k_idx}: mean={rank_dist.mean():.4f} std={rank_dist.std():.4f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 6: PER-CLASS ANCHOR AFFINITY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 6: PER-CLASS ANCHOR AFFINITY")
print(f"{'='*65}")

# Display names for the first 20 class indices only; any other index falls
# back to "class_<idx>".
coco_names = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
]

# For each class: the three anchors that are most often the nearest anchor of
# an image carrying that label, printed as anchor(count/total).
print(f"\n  Top-3 anchors per class (first 20 classes):")
for c in range(min(20, N_CLASSES)):
    mask = val_labels[:, c] > 0
    total = mask.sum().item()
    # Classes with fewer than 5 positive images are skipped.
    if total < 5:
        continue
    counts = torch.bincount(nearest[mask], minlength=N_ANCHORS)
    top3 = counts.topk(3)
    if c < len(coco_names):
        name = coco_names[c]
    else:
        name = f"class_{c}"
    pcts = []
    for anchor_id, hits in zip(top3.indices.tolist(), top3.values.tolist()):
        pcts.append(f"{anchor_id}({hits}/{total})")
    print(f"    {name:<15} (n={total:4d}): {' '.join(pcts)}")

# Anchor specialization: number of distinct classes that appear among the
# images assigned to each anchor. Anchors with no images stay at 0.
anchor_class_count = torch.zeros(N_ANCHORS)
for a in range(N_ANCHORS):
    member_labels = val_labels[nearest == a]
    if member_labels.shape[0] == 0:
        continue
    classes_seen = member_labels.sum(0) > 0
    anchor_class_count[a] = classes_seen.sum().item()
# Summary statistics only consider anchors that serve at least one class.
serving = anchor_class_count[anchor_class_count > 0]
print(f"\n  Anchor specialization:")
print(f"    classes per anchor: mean={serving.mean():.1f} "
      f"std={serving.std():.1f}")
print(f"    max={anchor_class_count.max():.0f} min={serving.min():.0f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 7: FUSED EMBEDDING GEOMETRY
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 7: FUSED EMBEDDING GEOMETRY")
print(f"{'='*65}")

# The model normalises the fused embedding, so these should all be 1.0.
fused_norms = fused.norm(dim=-1)
print(f"  Norms: mean={fused_norms.mean():.6f} std={fused_norms.std():.6f}")

# Mean cosine over all ordered pairs of distinct images.
fused_n = F.normalize(fused, dim=-1)
self_sim = fused_n @ fused_n.T
n_pairs = N_val**2 - N_val
off_diag_total = self_sim.sum() - self_sim.diag().sum()
self_sim_off = off_diag_total / n_pairs
print(f"  Self-sim (off-diag): {self_sim_off:.4f}")

# Spectrum of the mean-centred embeddings.
fused_f = fused.float()
fc = fused_f - fused_f.mean(0, keepdim=True)
sv = torch.linalg.svdvals(fc)
sv_energy = sv.pow(2)
eff_dim = ((sv.sum()**2) / (sv_energy.sum() + 1e-12)).item()
print(f"  Effective dim: {eff_dim:.1f}/{D_ANCHOR}")
# Cumulative share of squared singular values captured by the leading k.
cumvar = sv_energy.cumsum(0) / sv_energy.sum()
for k in (5, 10, 20, 50, 100):
    if k > len(cumvar):
        continue
    print(f"    top-{k} SVs explain {cumvar[k-1]*100:.1f}%")

# CV: spread of the volumes of 500 random 5-point simplices drawn from the
# normalised fused embeddings (std / mean).
vols = []
for _trial in range(500):
    picked = torch.randperm(N_val)[:5]
    simplex = fused_n[picked].unsqueeze(0)
    vol_sq = cayley_menger_vol2(simplex)[0]
    vol = torch.sqrt(F.relu(vol_sq) + 1e-12).item()
    # NOTE(review): the +1e-12 keeps every finite result strictly positive,
    # so in practice this guard only rejects NaN.
    if not (vol > 0):
        continue
    vols.append(vol)
fused_cv = np.std(vols) / (np.mean(vols) + 1e-8)
print(f"  Pentachoron CV: {fused_cv:.4f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 8: EXPERT CONTRIBUTION ANALYSIS
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("SCAN 8: EXPERT CONTRIBUTION")
print(f"{'='*65}")

# Alignment of each expert's projection with the fused embedding.
for name, expert_proj in zip(expert_names, proj):
    cos = F.cosine_similarity(expert_proj, fused, dim=-1)
    print(f"  {name:<15}: cos_to_fused mean={cos.mean():.4f} std={cos.std():.4f}")

# Leave-one-out: rebuild the fused embedding from the other two experts and
# compare it with the full one. uniqueness = 1 - mean cosine.
for e, name in enumerate(expert_names):
    others = proj[:e] + proj[e+1:]
    pair_mean = sum(others) / 2
    fused_without = F.normalize(pair_mean, dim=-1)
    delta = F.cosine_similarity(fused, fused_without, dim=-1)
    print(f"  Without {name:<15}: cos_to_full={delta.mean():.4f} (uniqueness={1-delta.mean():.4f})")

# Per-image expert disagreement: the three pairwise cosines between expert
# projections, summarised per image by their mean (agreement) and std
# (disagreement).
print(f"\n  Per-image expert disagreement:")
expert_pairs = [(i, j) for i in range(3) for j in range(i+1, 3)]
all_cos = [F.cosine_similarity(proj[i], proj[j], dim=-1) for i, j in expert_pairs]
stacked = torch.stack(all_cos, dim=1)  # one column per expert pair
per_image_agree = stacked.mean(dim=1)
per_image_disagree = stacked.std(dim=1)
print(f"  Agreement: mean={per_image_agree.mean():.4f} std={per_image_agree.std():.4f}")
print(f"  Disagreement: mean={per_image_disagree.mean():.4f} std={per_image_disagree.std():.4f}")

# Extremes, both ranked by mean agreement (highest and lowest).
most_agree_idx = per_image_agree.argmax().item()
most_disagree_idx = per_image_agree.argmin().item()
agree_labels = val_labels[most_agree_idx].nonzero(as_tuple=True)[0].tolist()
disagree_labels = val_labels[most_disagree_idx].nonzero(as_tuple=True)[0].tolist()
print(f"\n  Most agreed image ({most_agree_idx}): agreement={per_image_agree[most_agree_idx]:.4f}")
print(f"    labels: {agree_labels}")
print(f"  Most disagreed image ({most_disagree_idx}): agreement={per_image_agree[most_disagree_idx]:.4f}")
print(f"    labels: {disagree_labels}")


print(f"\n{'='*65}")
print("ANALYSIS COMPLETE")
print(f"{'='*65}")