#!/usr/bin/env python3
"""
GEOLIP VISION ENCODER — FROM SCRATCH
======================================
CaptionBERT paradigm applied to vision:

  CaptionBERT: from-scratch 6L/384d transformer → 768-d
               → InfoNCE+MSE+CV against 5-BERT consensus
  This:        from-scratch ViT (6L/384d, patch16) → 128-d
               → full GeoLIP losses against 3-expert consensus

Phase 0: Pre-compute targets from frozen soup (3 expert features → 128-d fused)
Phase 1: Build from-scratch ViT, Xavier init, Procrustes-init output_proj
Phase 2: Train on raw COCO images, targets = frozen soup consensus embeddings

Losses:
  - InfoNCE (embedding vs consensus target)
  - MSE (embedding vs consensus target)
  - CV loss (calibrated to consensus CV)
  - BCE (through frozen soup pipeline for task signal)
  - Whitened Procrustes alignment loss
  - Geometric autograd (tangential + separation)

NOTE(review): no Procrustes initialisation of output_proj is visible in this
script; Phase 1 as written only applies Xavier init — confirm.
"""

import gc
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Architecture
D_MODEL = 384       # internal transformer dim (same as CaptionBERT)
N_HEADS = 6
N_LAYERS = 6
D_FF = 1536
PATCH_SIZE = 16
IMAGE_SIZE = 224
D_ANCHOR = 128      # output dim (matches soup)
N_ANCHORS = 256
N_CLASSES = 80
N_COMP = 8
D_COMP = 64
DROPOUT = 0.1

# Training
BATCH = 48
EPOCHS = 20
LR = 3e-4
WARMUP_STEPS = 500
GRAD_CLIP = 1.0

EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]

N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 14×14 = 196

print("=" * 65)
print("GEOLIP VISION ENCODER — FROM SCRATCH")
print(f" ViT: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" {N_PATCHES} patches + CLS → {D_ANCHOR}-d output")
print(f" Device: {DEVICE}")
print("=" * 65)


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Return the squared simplex volume of each point set via Cayley-Menger.

    Args:
        pts: tensor indexed as (batch, vertex, coordinate); it is cast to
            float32 before any arithmetic.

    Returns:
        A (batch,) float32 tensor. For V vertices this is
        (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM), where CM is the bordered
        matrix of pairwise squared distances.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)          # pairwise squared distances
    B, V, _ = d2.shape
    # Bordered matrix: zero corner, ones on the first row/column, d2 inside.
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    s = (-1.0) ** V
    f = math.factorial(V - 1)
    return s / ((2.0 ** (V - 1)) * f * f) * torch.linalg.det(cm)


def cv_loss(emb, target=0.2, n_samples=16):
    """Absolute gap between the sampled volume CV and ``target``.

    Draws ``n_samples`` random 5-row subsets of ``emb``, takes the square root
    of each subset's clamped Cayley-Menger squared volume, and compares the
    coefficient of variation (std / mean) of those volumes with ``target``.

    Args:
        emb: (B, D) embeddings.
        target: desired coefficient of variation.
        n_samples: number of random 5-point subsets.

    Returns:
        A scalar tensor; a constant 0.0 (no gradient) when B < 5.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    # One randperm per sample (same draws as a per-sample loop), but a single
    # batched determinant instead of n_samples separate ones.
    idx = torch.stack([torch.randperm(B, device=emb.device)[:5]
                       for _ in range(n_samples)])
    v2 = cayley_menger_vol2(emb[idx])                 # (n_samples,)
    vols = torch.sqrt(F.relu(v2) + 1e-12)             # relu guards tiny negative det noise
    return (vols.std() / (vols.mean() + 1e-8) - target).abs()


def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned batches ``a`` and ``b``.

    Both inputs are L2-normalised; row i of ``a`` is the positive for row i
    of ``b``.

    Returns:
        (loss, acc): the mean of the a→b and b→a cross-entropies, and the
        a→b top-1 accuracy as a Python float.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in the backward pass.

    Backward, per row:
      * the component of the gradient along the (normalised) embedding is
        scaled by ``1 - tang``;
      * if ``sep > 0``, ``sep`` times the positive component of the gradient
        along the nearest (by cosine) normalised anchor is subtracted.
    Only ``x`` receives a gradient; the other four inputs get ``None``.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()

        # Split off the radial part and keep only (1 - tang) of it.
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            # Only the positive projection onto the nearest anchor is damped.
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest

        return corrected.to(grad_output.dtype), None, None, None, None


def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return cov^(-1/2) for a symmetric matrix via eigendecomposition.

    Eigenvalues are clamped to at least ``eps`` before the reciprocal square
    root, so the result stays finite for rank-deficient input.
    """
    evals, evecs = torch.linalg.eigh(cov)
    return evecs @ torch.diag(torch.clamp(evals, min=eps).rsqrt()) @ evecs.T
def whitened_procrustes_loss(emb, targets):
    """Centred-cosine alignment loss (differentiable proxy for Procrustes).

    Despite the name, no whitening or rotation is solved here: both inputs are
    mean-centred over the batch and the loss is ``1 - mean(cosine)`` between
    matching rows.

    Args:
        emb: (B, D) student embeddings.
        targets: (B, D) target embeddings, row-aligned with ``emb``.

    Returns:
        A scalar tensor; a constant 0.0 (no gradient) when B < 10.
    """
    B = emb.shape[0]
    if B < 10:
        return torch.tensor(0.0, device=emb.device)
    emb_f = emb.float()
    tgt_f = targets.float()
    ec = emb_f - emb_f.mean(0, keepdim=True)
    tc = tgt_f - tgt_f.mean(0, keepdim=True)
    # Cosine alignment per sample (differentiable proxy for Procrustes)
    cos = F.cosine_similarity(ec, tc, dim=-1)
    return 1.0 - cos.mean()


# ══════════════════════════════════════════════════════════════════
# FROZEN SOUP
# ══════════════════════════════════════════════════════════════════

class Constellation(nn.Module):
    """Holds N_ANCHORS anchor directions in the D_ANCHOR-dim embedding space."""

    def __init__(self):
        super().__init__()
        self.anchors = nn.Parameter(F.normalize(torch.randn(N_ANCHORS, D_ANCHOR), dim=-1))

    def triangulate(self, emb):
        """Return (1 - cosine to every anchor, index of the nearest anchor).

        Anchors are re-normalised here; ``emb`` is used as given, so the
        first output is a true cosine distance only if ``emb`` is unit-norm.
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Splits the anchor-distance vector into N_COMP groups, one MLP per group.

    Anchor i belongs to group ``i % N_COMP``; each group's distances go
    through Linear → GELU → Linear → LayerNorm and the outputs are
    concatenated.
    """

    def __init__(self):
        super().__init__()
        self.n_comp = N_COMP
        asgn = torch.arange(N_ANCHORS) % N_COMP
        # Registered as a buffer, so it is part of the state_dict.
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear((asgn == k).sum().item(), D_COMP * 2),
                nn.GELU(),
                nn.Linear(D_COMP * 2, D_COMP),
                nn.LayerNorm(D_COMP),
            )
            for k in range(N_COMP)
        ])

    def forward(self, tri):
        """Map (B, N_ANCHORS) distances to (B, N_COMP * D_COMP) features."""
        return torch.cat(
            [self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)], -1)


class FrozenSoup(nn.Module):
    """Constellation → Patchwork → classifier head over a 128-d embedding.

    The layer layout (including the no-op ``Dropout(0.0)``) fixes the
    state_dict key indices, so it must not be reordered.
    """

    def __init__(self):
        super().__init__()
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.0),
            nn.Linear(pw_dim, N_CLASSES),
        )

    def forward(self, emb_128):
        """Return (class logits, anchor distances, nearest-anchor indices)."""
        tri, nearest = self.constellation.triangulate(emb_128)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, emb_128], -1))
        return logits, tri, nearest


# ══════════════════════════════════════════════════════════════════
# FROM-SCRATCH ViT ENCODER
# ══════════════════════════════════════════════════════════════════

class GeoLIPViTEncoder(nn.Module):
    """
    From-scratch ViT. Same pattern as CaptionBERT's CaptionEncoder.

    patch_embed (Xavier) → pos_embed (learned) → CLS token
    → transformer encoder (Xavier, norm_first)
    → mean pool → output_proj (→ 128-d) → L2-norm

    A CLS token is prepended to the sequence but excluded from the pooling.
    No pretrained weights anywhere.
    """

    def __init__(self, image_size=IMAGE_SIZE, patch_size=PATCH_SIZE,
                 d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS,
                 d_ff=D_FF, output_dim=D_ANCHOR, dropout=DROPOUT):
        """Build the encoder.

        Args:
            image_size: side length of the (square) input; fixes the size of
                the learned position table.
            patch_size: side length of each square patch.
            d_model: transformer width.
            n_heads: attention heads per layer.
            n_layers: number of encoder layers.
            d_ff: feed-forward hidden width.
            output_dim: dimension of the final embedding.
            dropout: dropout used after the embedding and inside each layer.
        """
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        n_patches = (image_size // patch_size) ** 2

        # Patch embedding: (B, 3, H, W) → (B, n_patches, d_model)
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
        self.embed_norm = nn.LayerNorm(d_model)
        self.embed_drop = nn.Dropout(dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers, enable_nested_tensor=False)

        # Output projection → anchor space
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear/Conv2d weights, reset LayerNorms, init tokens."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # Position and CLS token init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, pixel_values):
        """Encode images to unit-norm embeddings.

        Args:
            pixel_values: (B, 3, H, W) image batch. H and W must yield the
                same number of patches the position table was built for,
                otherwise the position add fails.

        Returns:
            (B, output_dim) L2-normalised embeddings.
        """
        B = pixel_values.shape[0]

        # Patch embed
        x = self.patch_embed(pixel_values)      # (B, d_model, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)        # (B, n_patches, d_model)

        # Prepend CLS token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)          # (B, n_patches+1, d_model)
        x = x + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))

        # Transformer
        x = self.encoder(x)

        # Pool: mean over patch tokens (exclude CLS for richer pooling)
        pooled = x[:, 1:, :].mean(dim=1)

        # Project to anchor space + L2 normalize
        return F.normalize(self.output_proj(pooled), dim=-1)


# ══════════════════════════════════════════════════════════════════
# LOAD SOUP + PRE-COMPUTE TARGETS
# ══════════════════════════════════════════════════════════════════

print(f"\n Loading soup...")
# SECURITY: weights_only=False unpickles arbitrary objects. Only load
# checkpoints from a trusted source.
ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu", weights_only=False)
soup = FrozenSoup()
soup_sd = {k: v for k, v in ckpt["state_dict"].items()
           if k.startswith(("constellation.", "patchwork.", "classifier."))}
soup.load_state_dict(soup_sd, strict=True)
soup = soup.eval().to(DEVICE)
for p in soup.parameters():
    p.requires_grad = False
consensus_cv = ckpt.get("consensus_cv_128", 0.27)
print(f" Soup: mAP={ckpt['mAP']:.3f} CV_target={consensus_cv:.4f}")


# Load projectors from soup to generate targets
class ExpertProjector(nn.Module):
    """Linear + LayerNorm from a 768-d expert feature to a unit-norm D_ANCHOR vector."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(768, D_ANCHOR), nn.LayerNorm(D_ANCHOR))

    def forward(self, x):
        return F.normalize(self.proj(x), dim=-1)


print(f"\n Pre-computing consensus targets from 3 experts...")
from datasets import load_dataset

# Rebuild projectors (one per expert, in EXPERTS order)
projectors = nn.ModuleList([ExpertProjector() for _ in EXPERTS])
_PROJ_PREFIX = "projectors."
# Strip only the leading prefix so the keys line up with the ModuleList.
proj_sd = {k[len(_PROJ_PREFIX):]: v
           for k, v in ckpt["state_dict"].items()
           if k.startswith(_PROJ_PREFIX)}
projectors.load_state_dict(proj_sd)
projectors = projectors.eval().to(DEVICE)

# Load expert features + compute
# targets (per split: labels, expert features, fused consensus embeddings)
n_experts = len(EXPERTS)
for split_name, split_key in [("train", "train"), ("val", "val")]:
    # The first expert's dataset defines the row order for this split.
    ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split=split_key)
    ids = ref["image_id"]
    N = len(ids)
    id_map = {iid: i for i, iid in enumerate(ids)}

    # Multi-hot label matrix; class ids >= N_CLASSES are ignored.
    labels = torch.zeros(N, N_CLASSES)
    for i, labs in enumerate(ref["labels"]):
        for l in labs:
            if l < N_CLASSES:
                labels[i, l] = 1.0

    expert_feats = []
    for name in EXPERTS:
        ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split_key)
        feats = torch.zeros(N, 768)
        for row in ds:
            row_i = id_map.get(row["image_id"])
            if row_i is not None:
                feats[row_i] = torch.tensor(row["features"], dtype=torch.float32)
        expert_feats.append(feats)
        del ds

    # Compute fused embeddings through projectors
    targets = torch.zeros(N, D_ANCHOR)
    with torch.no_grad():
        for j in range(0, N, 512):
            end = min(j + 512, N)
            projected = [projectors[e](expert_feats[e][j:end].to(DEVICE))
                         for e in range(n_experts)]
            fused = F.normalize(sum(projected) / n_experts, dim=-1)
            targets[j:end] = fused.cpu()

    if split_name == "train":
        train_targets = targets; train_labels = labels
        train_ids = ids; train_id_map = id_map; N_train = N
    else:
        val_targets = targets; val_labels = labels
        val_ids = ids; val_id_map = id_map; N_val = N

    print(f" {split_name}: {N:,} targets computed")
    del expert_feats; gc.collect()

del projectors, proj_sd; gc.collect()

print(f" CV of train targets: ", end="")
t_cv = cv_loss(train_targets[:5000].to(DEVICE), target=0.0).item()
print(f"{t_cv:.4f} (raw CV)")

# Move to GPU
train_targets_gpu = train_targets.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)
val_targets_gpu = val_targets.to(DEVICE)

# Anchors for geometric autograd
anchors_frozen = soup.constellation.anchors.detach()


# ══════════════════════════════════════════════════════════════════
# BUILD ENCODER
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("BUILD ENCODER")
print(f"{'='*65}")

encoder = GeoLIPViTEncoder().to(DEVICE)
n_params = sum(p.numel() for p in encoder.parameters())
print(f" Architecture: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" Input: {IMAGE_SIZE}×{IMAGE_SIZE} → {N_PATCHES} patches")
print(f" Output: {D_ANCHOR}-d (on hypersphere)")
print(f" Parameters: {n_params:,}")

# Image preprocessing (simple, no pretrained processor dependency)
from torchvision import transforms
img_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Losses: InfoNCE + MSE + CV + BCE + Procrustes alignment")
print(f" CV target: {consensus_cv:.4f}")
print(f"{'='*65}")

optimizer = torch.optim.Adam(encoder.parameters(), lr=LR)
# Upper bound on optimizer steps: partial batches at epoch end are dropped.
total_steps = (N_train // BATCH) * EPOCHS
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=WARMUP_STEPS),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - WARMUP_STEPS, 1), eta_min=1e-6)],
    milestones=[WARMUP_STEPS])
scaler = torch.amp.GradScaler("cuda")

os.makedirs("checkpoints", exist_ok=True)
best_mAP = 0.0
gs = 0  # global optimizer-step counter


def _score_val_batch(imgs, rows, logits_out, embs_out, seen_out):
    """Encode one validation batch and scatter results into the CPU buffers.

    Must be called under no_grad/autocast by the caller. ``rows`` are the
    target-table row indices of ``imgs``. Returns the number of images scored.
    """
    pixels = torch.stack(imgs).to(DEVICE)
    emb = encoder(pixels)
    logits, _, _ = soup(emb)
    row_idx = torch.tensor(rows, dtype=torch.long)
    # One device→host copy per tensor instead of one per image.
    logits_out[row_idx] = logits.float().cpu()
    embs_out[row_idx] = emb.float().cpu()
    seen_out[row_idx] = True
    return len(imgs)


for epoch in range(EPOCHS):
    encoder.train()
    t0 = time.time()

    coco_train = load_dataset("rafaelpadilla/coco2017", split="train",
                              revision="refs/convert/parquet", streaming=True)

    tl_total, tl_nce, tl_mse, tl_bce, tn_acc = 0.0, 0.0, 0.0, 0.0, 0.0
    nb, n_images, n_skipped = 0, 0, 0
    batch_imgs, batch_idx = [], []

    for row in coco_train:
        iid = row.get("image_id")
        if iid not in train_id_map:
            continue
        try:
            pixel_values = img_transform(row["image"].convert("RGB"))
        except Exception:
            # Best-effort: skip undecodable images, but never swallow
            # KeyboardInterrupt/SystemExit as a bare `except:` would.
            n_skipped += 1
            continue
        batch_imgs.append(pixel_values)
        batch_idx.append(train_id_map[iid])

        if len(batch_imgs) < BATCH:
            continue

        # ── Process batch ──
        pixels = torch.stack(batch_imgs).to(DEVICE)
        indices = torch.tensor(batch_idx, device=DEVICE)
        targets = train_targets_gpu[indices]
        labels = train_labels_gpu[indices]

        # NOTE(review): the collapsed source does not preserve where the
        # autocast block ended. The loss computation is kept inside it here,
        # matching the validation path (which runs the soup under autocast);
        # confirm against the original layout.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            emb = encoder(pixels)

            # Geometric autograd
            emb = EmbeddingAutograd.apply(emb, emb, anchors_frozen, 0.01, 1.0)

            # Student losses (CaptionBERT style)
            l_nce, nce_acc = infonce(emb, targets)
            l_mse = F.mse_loss(emb, targets)
            l_cv = cv_loss(emb, target=consensus_cv)
            l_align = whitened_procrustes_loss(emb, targets)

            # Task loss through frozen soup
            logits, _, _ = soup(emb)
            l_bce = F.binary_cross_entropy_with_logits(logits, labels)

            loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                    + 0.5 * l_align + 0.001 * l_cv)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        tl_total += loss.item()
        tl_nce += l_nce.item()
        tl_mse += l_mse.item()
        tl_bce += l_bce.item()
        tn_acc += nce_acc
        nb += 1; n_images += len(batch_imgs); gs += 1
        batch_imgs, batch_idx = [], []

        if n_images % (BATCH * 50) == 0:
            elapsed = time.time() - t0
            print(f" {n_images:>6}/{N_train} ({n_images/elapsed:.0f} img/s) "
                  f"loss={tl_total/nb:.4f} nce_acc={tn_acc/nb:.3f}", flush=True)

    elapsed = time.time() - t0
    d = max(nb, 1)
    print(f" E{epoch+1} train: {n_images} imgs, {elapsed:.0f}s, "
          f"loss={tl_total/d:.4f} nce={tl_nce/d:.4f} mse={tl_mse/d:.4f} "
          f"bce={tl_bce/d:.4f} nce_acc={tn_acc/d:.3f}")
    if n_skipped:
        print(f" E{epoch+1} train: skipped {n_skipped} undecodable images")

    # ── Validation ──
    encoder.eval()
    coco_val = load_dataset("rafaelpadilla/coco2017", split="validation",
                            revision="refs/convert/parquet", streaming=True)

    all_logits = torch.zeros(N_val, N_CLASSES)
    all_embs = torch.zeros(N_val, D_ANCHOR)
    seen = torch.zeros(N_val, dtype=torch.bool)
    n_val_seen = 0
    vbatch_imgs, vbatch_idx = [], []

    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for row in coco_val:
            iid = row.get("image_id")
            if iid not in val_id_map:
                continue
            try:
                pixel_values = img_transform(row["image"].convert("RGB"))
            except Exception:
                continue
            vbatch_imgs.append(pixel_values)
            vbatch_idx.append(val_id_map[iid])

            if len(vbatch_imgs) >= BATCH:
                n_val_seen += _score_val_batch(
                    vbatch_imgs, vbatch_idx, all_logits, all_embs, seen)
                vbatch_imgs, vbatch_idx = [], []

        if vbatch_imgs:  # final partial batch
            n_val_seen += _score_val_batch(
                vbatch_imgs, vbatch_idx, all_logits, all_embs, seen)

    # Score only rows that were actually encoded; never-seen rows would
    # otherwise enter every metric as all-zero logits/embeddings.
    seen_rows = seen.nonzero(as_tuple=True)[0]
    if seen_rows.numel() == 0:
        print(f" E{epoch+1} val: no validation images matched; skipping metrics")
        continue
    v_lab = val_labels[seen_rows]
    v_logits = all_logits[seen_rows]
    v_embs = all_embs[seen_rows]
    v_tgt = val_targets[seen_rows]

    # mAP
    ranks = torch.arange(1, len(seen_rows) + 1).float()
    ap_sum, nv = 0.0, 0
    for c in range(N_CLASSES):
        pos = v_lab[:, c]
        n_pos = pos.sum().item()
        if n_pos > 0:
            st = pos[v_logits[:, c].argsort(descending=True)]
            pak = st.cumsum(0) / ranks          # precision at each rank
            ap_sum += (pak * st).sum().item() / n_pos
            nv += 1
    mAP = ap_sum / max(nv, 1)

    vp = (v_logits.sigmoid() > 0.5).float()
    tp = (vp * v_lab).sum(0); fp = (vp * (1 - v_lab)).sum(0)
    fn = ((1 - vp) * v_lab).sum(0)
    pr = tp / (tp + fp + 1e-8); rc = tp / (tp + fn + 1e-8)
    f1 = 2 * pr * rc / (pr + rc + 1e-8)
    f1_pos = f1[f1 > 0]
    # An empty selection would average to NaN (common early in training).
    macro_f1 = f1_pos.mean().item() if f1_pos.numel() > 0 else 0.0

    v_cos = F.cosine_similarity(v_embs, v_tgt, dim=-1).mean().item()
    _, v_nearest = soup.constellation.triangulate(v_embs.to(DEVICE))
    n_active = v_nearest.cpu().unique().numel()

    # R@1: each seen embedding retrieves among ALL validation targets
    sim = v_embs @ val_targets.T
    r1 = (sim.argmax(-1) == seen_rows).float().mean().item()

    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({
            "encoder_state_dict": encoder.state_dict(),
            "config": {"d_model": D_MODEL, "n_heads": N_HEADS,
                       "n_layers": N_LAYERS, "d_ff": D_FF,
                       "patch_size": PATCH_SIZE, "image_size": IMAGE_SIZE,
                       "output_dim": D_ANCHOR},
            "mAP": mAP, "f1": macro_f1, "r1": r1, "cos": v_cos,
            "epoch": epoch + 1, "n_active": n_active,
        }, "checkpoints/geolip_vit_encoder_best.pt")
        mk = " ★"

    print(f" E{epoch+1} val: mAP={mAP:.3f} F1={macro_f1:.3f} R@1={r1:.3f} "
          f"cos={v_cos:.3f} anchors={n_active}/{N_ANCHORS} seen={n_val_seen}/{N_val}{mk}")

print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Encoder: {n_params:,} params")
print(f"\n{'='*65}\nDONE\n{'='*65}")