#!/usr/bin/env python3
"""
GEOLIP VISION ENCODER — FROM SCRATCH
======================================
From-scratch ViT trained against frozen soup consensus targets.

Phase 0: Pre-compute consensus targets from frozen soup
Phase 1: Pre-cache all COCO images as tensors (once, then reuse)
Phase 2: Train from-scratch ViT with full GeoLIP loss stack
"""

import gc
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

DEVICE = "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Architecture
D_MODEL = 384
N_HEADS = 6
N_LAYERS = 6
D_FF = 1536
PATCH_SIZE = 16
IMAGE_SIZE = 224
D_ANCHOR = 128
N_ANCHORS = 256
N_CLASSES = 80
N_COMP = 8
D_COMP = 64
DROPOUT = 0.1

# Training
BATCH = 48
EPOCHS = 20
LR = 3e-4
WARMUP_STEPS = 500
GRAD_CLIP = 1.0

EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

print("=" * 65)
print("GEOLIP VISION ENCODER — FROM SCRATCH")
print(f" ViT: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" {N_PATCHES} patches + CLS → {D_ANCHOR}-d output")
print(f" Device: {DEVICE}")
print("=" * 65)


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def cayley_menger_vol2(pts):
    """Return the squared simplex volume for each batch of vertices.

    Args:
        pts: tensor of shape (B, V, D) — B simplices with V vertices each.
            The shape unpacking below requires exactly three dimensions.

    Returns:
        Float32 tensor of shape (B,) holding the squared volume computed
        from the Cayley-Menger determinant. Values can come out slightly
        negative for degenerate simplices; callers clamp with ``relu``.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)  # pairwise squared distances, (B, V, V)
    B, V, _ = d2.shape

    # Bordered distance matrix: zero corner, ones on first row/column.
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2

    # With n = V - 1 the simplex dimension:
    # vol^2 = (-1)^(n+1) / (2^n * (n!)^2) * det(CM)
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)


def _sample_simplex_volumes(emb, n_samples, n_points=5):
    """Return volumes of ``n_samples`` random ``n_points``-vertex simplices.

    Index sets are drawn with one ``torch.randperm`` call per sample, then
    all simplices go through a single batched determinant instead of one
    batch-of-1 call per sample.

    Args:
        emb: tensor of shape (B, D) with B >= n_points.
        n_samples: number of random simplices to draw.
        n_points: vertices per simplex.

    Returns:
        Tensor of shape (n_samples,) with ``sqrt(relu(vol^2) + 1e-12)``.
    """
    n_rows = emb.shape[0]
    idx = torch.stack([
        torch.randperm(n_rows, device=emb.device)[:n_points]
        for _ in range(n_samples)
    ])
    vol2 = cayley_menger_vol2(emb[idx])  # emb[idx]: (n_samples, n_points, D)
    return torch.sqrt(F.relu(vol2) + 1e-12)


def cv_loss(emb, target=0.2, n_samples=16):
    """Penalize deviation of the simplex-volume coefficient of variation.

    Args:
        emb: embeddings of shape (B, D).
        target: desired coefficient of variation (std / mean) of volumes.
        n_samples: number of random 5-point simplices sampled per call.

    Returns:
        Scalar tensor ``|std/mean - target|``; a constant zero tensor when
        the batch has fewer than 5 rows.
    """
    if emb.shape[0] < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = _sample_simplex_volumes(emb, n_samples)
    return (vols.std() / (vols.mean() + 1e-8) - target).abs()


@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Measure the simplex-volume coefficient of variation as a float.

    Args:
        emb: embeddings of shape (B, D).
        n_samples: number of random 5-point simplices to sample.

    Returns:
        ``std / mean`` of the sampled volumes (NumPy population std), or
        0.0 when the batch has fewer than 5 rows or fewer than 10 usable
        samples remain.
    """
    if emb.shape[0] < 5:
        return 0.0
    vols = _sample_simplex_volumes(emb, n_samples).cpu().numpy().astype(np.float64)
    # The 1e-12 floor keeps every finite volume positive, so this filter
    # effectively drops only NaN determinants.
    vols = vols[vols > 0]
    if len(vols) < 10:
        return 0.0
    return float(vols.std() / (vols.mean() + 1e-8))


def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between row-aligned embedding sets.

    Args:
        a: tensor of shape (B, D).
        b: tensor of shape (B, D); row i is the positive for ``a[i]``.
        temperature: divisor applied to the cosine-similarity logits.

    Returns:
        ``(loss, acc)``: the mean of the a→b and b→a cross-entropies, and
        the a→b top-1 retrieval accuracy as a Python float.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


def whitened_procrustes_loss(emb, targets):
    """Return one minus the mean cosine between mean-centered rows.

    Despite the name, no whitening or rotation is solved for here: both
    inputs are centered by their batch mean and compared row by row.

    Args:
        emb: tensor of shape (B, D).
        targets: tensor of shape (B, D), row-aligned with ``emb``.

    Returns:
        Scalar tensor; a constant zero tensor when B < 10.
    """
    if emb.shape[0] < 10:
        return torch.tensor(0.0, device=emb.device)
    emb_f = emb.float()
    tgt_f = targets.float()
    emb_c = emb_f - emb_f.mean(0, keepdim=True)
    tgt_c = tgt_f - tgt_f.mean(0, keepdim=True)
    return 1.0 - F.cosine_similarity(emb_c, tgt_c, dim=-1).mean()


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass with a reshaped gradient in backward.

    Backward scales the radial component of the incoming gradient (the
    part parallel to the normalized embedding) by ``1 - tang`` and, when
    ``sep > 0``, subtracts ``sep`` times the positive gradient component
    along each row's nearest (max-cosine) anchor.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        """Return ``x`` unchanged; stash what backward needs.

        Args:
            x: tensor passed through; the only input that receives grad.
            embedding: tensor of shape (B, D) used to define the radial
                direction in backward.
            anchors: tensor of shape (A, D).
            tang: float, fraction of the radial gradient to remove.
            sep: float, strength of the nearest-anchor correction.
        """
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang
        ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return the corrected gradient for ``x``; None for other inputs."""
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()

        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            # Only the positive projection onto the nearest anchor is removed.
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest

        return corrected.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# FROZEN SOUP
# ══════════════════════════════════════════════════════════════════

class Constellation(nn.Module):
    """Holds ``N_ANCHORS`` anchor vectors of width ``D_ANCHOR``."""

    def __init__(self):
        super().__init__()
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(N_ANCHORS, D_ANCHOR), dim=-1))

    def triangulate(self, emb):
        """Return ``(1 - cos, nearest)`` of ``emb`` against all anchors.

        Args:
            emb: tensor of shape (B, D_ANCHOR). It is not normalized here,
                so ``cos`` is a true cosine only for unit-norm rows.

        Returns:
            Tuple of the (B, N_ANCHORS) distance matrix and the (B,) index
            of the max-cosine anchor per row.
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Splits anchor distances into ``N_COMP`` groups, one MLP per group.

    Anchor ``i`` is assigned to compartment ``i % N_COMP``.
    """

    def __init__(self):
        super().__init__()
        self.n_comp = N_COMP
        asgn = torch.arange(N_ANCHORS) % N_COMP
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear((asgn == k).sum().item(), D_COMP * 2),
                nn.GELU(),
                nn.Linear(D_COMP * 2, D_COMP),
                nn.LayerNorm(D_COMP))
            for k in range(N_COMP)])

    def forward(self, tri):
        """Map (B, N_ANCHORS) distances to (B, N_COMP * D_COMP) features."""
        return torch.cat(
            [self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)],
            -1)


class FrozenSoup(nn.Module):
    """Constellation + patchwork + classifier head over 128-d embeddings.

    Attribute names and the layer order inside each ``nn.Sequential`` fix
    the state-dict keys loaded from the checkpoint; do not reorder them.
    """

    def __init__(self):
        super().__init__()
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim),
            nn.GELU(),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.0),
            nn.Linear(pw_dim, N_CLASSES))

    def forward(self, emb_128):
        """Return ``(logits, tri, nearest)`` for a batch of embeddings."""
        tri, nearest = self.constellation.triangulate(emb_128)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, emb_128], -1))
        return logits, tri, nearest


# ══════════════════════════════════════════════════════════════════
# FROM-SCRATCH ViT ENCODER
# ══════════════════════════════════════════════════════════════════

class GeoLIPViTEncoder(nn.Module):
    """Pre-norm ViT mapping images to unit-norm ``D_ANCHOR``-d embeddings."""

    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            3, D_MODEL, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        self.pos_embed = nn.Parameter(torch.zeros(1, N_PATCHES + 1, D_MODEL))
        self.embed_norm = nn.LayerNorm(D_MODEL)
        self.embed_drop = nn.Dropout(DROPOUT)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL, nhead=N_HEADS, dim_feedforward=D_FF,
            dropout=DROPOUT, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=N_LAYERS, enable_nested_tensor=False)

        self.output_proj = nn.Sequential(
            nn.Linear(D_MODEL, D_MODEL),
            nn.GELU(),
            nn.LayerNorm(D_MODEL),
            nn.Linear(D_MODEL, D_ANCHOR))

        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear/Conv2d, reset LayerNorm, trunc-normal tokens."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, pixel_values):
        """Encode images of shape (B, 3, IMAGE_SIZE, IMAGE_SIZE).

        Returns:
            Tensor of shape (B, D_ANCHOR), L2-normalized along the last dim.
        """
        B = pixel_values.shape[0]
        x = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))
        x = self.encoder(x)
        # The CLS token takes part in attention, but pooling averages the
        # patch tokens only.
        pooled = x[:, 1:, :].mean(dim=1)
        return F.normalize(self.output_proj(pooled), dim=-1)


# ══════════════════════════════════════════════════════════════════
# LOAD SOUP + PRE-COMPUTE TARGETS
# ══════════════════════════════════════════════════════════════════

print(f"\n Loading soup...")
# SECURITY: weights_only=False unpickles arbitrary objects; only load
# checkpoints produced by a trusted run.
ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu",
                  weights_only=False)
soup = FrozenSoup()
_SOUP_PREFIXES = ("constellation.", "patchwork.", "classifier.")
soup_sd = {k: v for k, v in ckpt["state_dict"].items()
           if k.startswith(_SOUP_PREFIXES)}
soup.load_state_dict(soup_sd, strict=True)
soup = soup.eval().to(DEVICE)
for p in soup.parameters():
    p.requires_grad = False
consensus_cv = ckpt.get("consensus_cv_128", 0.27)
print(f" Soup: mAP={ckpt['mAP']:.3f} CV_target={consensus_cv:.4f}")

# Width of the pre-extracted expert features fed to the projectors.
D_EXPERT = 768


# Rebuild projectors for target generation
class ExpertProjector(nn.Module):
    """Linear + LayerNorm projection of expert features onto the sphere."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(D_EXPERT, D_ANCHOR), nn.LayerNorm(D_ANCHOR))

    def forward(self, x):
        return F.normalize(self.proj(x), dim=-1)


from datasets import load_dataset

projectors = nn.ModuleList([ExpertProjector() for _ in EXPERTS])
_PROJ_PREFIX = "projectors."
proj_sd = {k[len(_PROJ_PREFIX):]: v for k, v in ckpt["state_dict"].items()
           if k.startswith(_PROJ_PREFIX)}
projectors.load_state_dict(proj_sd)
projectors = projectors.eval().to(DEVICE)


def _load_or_compute_targets(split_name, split_key):
    """Return ``(targets, labels, image_ids)`` for one split.

    Loads ``cached_{split_name}_targets.pt`` when it exists; otherwise fuses
    the projected expert features into consensus targets and writes that
    cache file.
    """
    cache_path = f"cached_{split_name}_targets.pt"
    if os.path.exists(cache_path):
        # SECURITY: weights_only=False unpickles arbitrary objects; the
        # cache is expected to be a file this script wrote itself.
        cached = torch.load(cache_path, weights_only=False)
        print(f" {split_name}: loaded cached targets ({len(cached['targets']):,})")
        return cached["targets"], cached["labels"], cached["image_ids"]

    print(f" Computing {split_name} targets...")
    ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0],
                       split=split_key)
    ids = list(ref["image_id"])
    N = len(ids)
    id_map = {iid: i for i, iid in enumerate(ids)}

    labels = torch.zeros(N, N_CLASSES)
    for i, labs in enumerate(ref["labels"]):
        for cls_idx in labs:
            if cls_idx < N_CLASSES:
                labels[i, cls_idx] = 1.0

    expert_feats = []
    for name in tqdm(EXPERTS, desc=f" Loading {split_name} experts"):
        ds = load_dataset("AbstractPhil/bulk-coco-features", name,
                          split=split_key)
        feats = torch.zeros(N, D_EXPERT)
        for row in ds:
            if row["image_id"] in id_map:
                feats[id_map[row["image_id"]]] = torch.tensor(
                    row["features"], dtype=torch.float32)
        expert_feats.append(feats)
        del ds

    targets = torch.zeros(N, D_ANCHOR)
    with torch.no_grad():
        for j in tqdm(range(0, N, 512), desc=f" Fusing {split_name}"):
            end = min(j + 512, N)
            projected = [proj(feats[j:end].to(DEVICE))
                         for proj, feats in zip(projectors, expert_feats)]
            fused = F.normalize(sum(projected) / len(projected), dim=-1)
            targets[j:end] = fused.cpu()

    torch.save({"targets": targets, "labels": labels, "image_ids": ids},
               cache_path)
    print(f" {split_name}: {N:,} targets computed and cached")
    del expert_feats
    gc.collect()
    return targets, labels, ids


train_targets, train_labels, train_ids = _load_or_compute_targets("train", "train")
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
N_train = len(train_ids)

val_targets, val_labels, val_ids = _load_or_compute_targets("val", "val")
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
N_val = len(val_ids)

del projectors, proj_sd
gc.collect()

train_targets_gpu = train_targets.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)
val_targets_gpu = val_targets.to(DEVICE)
anchors_frozen = soup.constellation.anchors.detach()

# Image preprocessing
from torchvision import transforms

img_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ══════════════════════════════════════════════════════════════════
# PRE-CACHE IMAGES AS TENSORS
# ══════════════════════════════════════════════════════════════════

def cache_images(split_name, split_key, id_map, N):
    """Return an (N, 3, IMAGE_SIZE, IMAGE_SIZE) float16 image tensor.

    Loads ``cached_{split_name}_images.pt`` when it exists; otherwise
    streams the COCO split, preprocesses every image whose id is in
    ``id_map`` and saves the result. Rows for images that are missing or
    fail to decode stay all-zero; downstream code masks those out.

    Args:
        split_name: label used in the cache file name and log lines.
        split_key: split name passed to ``load_dataset``.
        id_map: mapping from image id to row index in the output tensor.
        N: number of rows to allocate.
    """
    cache_path = f"cached_{split_name}_images.pt"
    if os.path.exists(cache_path):
        print(f" Loading cached {split_name} images...")
        data = torch.load(cache_path, weights_only=True)
        per_img_mb = (data.element_size() * data.nelement()
                      / max(data.shape[0], 1) / 1e6)
        print(f" {split_name}: {data.shape} ({per_img_mb:.1f} MB/img)")
        return data

    print(f" Caching {split_name} images ({N:,})...")
    images = torch.zeros(N, 3, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float16)
    stream = load_dataset("rafaelpadilla/coco2017", split=split_key,
                          revision="refs/convert/parquet", streaming=True)
    cached = 0
    failed = 0
    for row in tqdm(stream, desc=f" Caching {split_name}", total=N):
        iid = row.get("image_id")
        if iid not in id_map:
            continue
        try:
            img = row["image"].convert("RGB")
            tensor = img_transform(img).half()
        except Exception as exc:
            # Best effort: one undecodable image must not abort the whole
            # pass, but KeyboardInterrupt/SystemExit still propagate.
            failed += 1
            if failed <= 5:
                print(f" Skipped image {iid}: {exc!r}")
            continue
        images[id_map[iid]] = tensor
        cached += 1

    print(f" Cached {cached}/{N} images")
    if failed:
        print(f" Failed to decode {failed} images")
    torch.save(images, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Saved: {cache_path} ({size_mb:.0f} MB)")
    return images


train_images = cache_images("train", "train", train_id_map, N_train)
val_images = cache_images("val", "validation", val_id_map, N_val)


# ══════════════════════════════════════════════════════════════════
# BUILD ENCODER
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("BUILD ENCODER")
print(f"{'='*65}")

encoder = GeoLIPViTEncoder().to(DEVICE)
n_params = sum(p.numel() for p in encoder.parameters())
print(f" Architecture: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" Input: {IMAGE_SIZE}×{IMAGE_SIZE} → {N_PATCHES} patches")
print(f" Output: {D_ANCHOR}-d (on hypersphere)")
print(f" Parameters: {n_params:,}")


# ══════════════════════════════════════════════════════════════════
# EVALUATION
# ══════════════════════════════════════════════════════════════════

def _cached_mask(pixels):
    """Return a (B,) bool mask of images that are not all-zero rows."""
    return pixels.abs().sum(dim=(1, 2, 3)) > 0.1


@torch.no_grad()
def evaluate(encoder, soup, val_images, val_targets, val_labels, desc="Val"):
    """Run the encoder over the validation cache and compute metrics.

    Args:
        encoder: model mapping pixel batches to (B, D_ANCHOR) embeddings.
            It is switched to eval mode and left that way.
        soup: frozen head returning ``(logits, tri, nearest)``.
        val_images: (N, 3, H, W) CPU tensor; all-zero rows are skipped.
        val_targets: (N, D_ANCHOR) CPU tensor of consensus targets.
        val_labels: (N, N_CLASSES) CPU multi-hot labels.
        desc: progress-bar label.

    Returns:
        Dict with keys ``mAP``, ``f1``, ``r1``, ``cos``, ``cv``,
        ``n_active`` and ``n_seen``. Skipped images keep zero logits and
        still take part in the mAP/F1 ranking.
    """
    encoder.eval()
    N = val_images.shape[0]
    all_logits = torch.zeros(N, N_CLASSES)
    all_embs = torch.zeros(N, D_ANCHOR)
    n_seen = 0

    for j in tqdm(range(0, N, BATCH), desc=f" {desc}", leave=False):
        end = min(j + BATCH, N)
        pixels = val_images[j:end].float().to(DEVICE)

        # Skip zero images (failed to cache)
        mask = _cached_mask(pixels)
        if not mask.any():
            continue

        emb = encoder(pixels[mask])
        logits, _, _ = soup(emb)

        # One scatter per batch instead of a per-row loop over a CUDA mask.
        rows = torch.arange(j, end)[mask.cpu()]
        all_logits[rows] = logits.float().cpu()
        all_embs[rows] = emb.float().cpu()
        n_seen += rows.numel()

    # mAP
    v_lab = val_labels
    ap_sum, nv = 0, 0
    ranks = torch.arange(1, N + 1).float()
    for c in range(N_CLASSES):
        if v_lab[:, c].sum() > 0:
            order = all_logits[:, c].argsort(descending=True)
            hits = v_lab[:, c][order]
            prec_at_k = hits.cumsum(0) / ranks
            ap_sum += (prec_at_k * hits).sum().item() / hits.sum().item()
            nv += 1
    mAP = ap_sum / max(nv, 1)

    # F1
    vp = (all_logits.sigmoid() > 0.5).float()
    tp = (vp * v_lab).sum(0)
    fp = (vp * (1 - v_lab)).sum(0)
    fn = ((1 - vp) * v_lab).sum(0)
    pr = tp / (tp + fp + 1e-8)
    rc = tp / (tp + fn + 1e-8)
    f1 = 2 * pr * rc / (pr + rc + 1e-8)
    scored = f1 > 0
    # Mean of an empty selection is NaN; report 0.0 when no class scores.
    macro_f1 = f1[scored].mean().item() if scored.any() else 0.0

    # Cosine to targets
    valid = all_embs.norm(dim=-1) > 0.1
    n_valid = int(valid.sum())
    v_cos = F.cosine_similarity(
        all_embs[valid], val_targets[valid], dim=-1
    ).mean().item() if n_valid > 0 else 0.0

    # R@1
    if n_valid > 100:
        sim = all_embs[valid] @ val_targets[valid].T
        r1 = (sim.argmax(-1) == torch.arange(n_valid)).float().mean().item()
    else:
        r1 = 0.0

    # Active anchors
    valid_embs = all_embs[valid].to(DEVICE)
    if valid_embs.shape[0] > 0:
        _, v_nearest = soup.constellation.triangulate(valid_embs)
        n_active = v_nearest.cpu().unique().numel()
    else:
        n_active = 0

    # CV
    v_cv = cv_metric(valid_embs[:2000]) if valid_embs.shape[0] > 100 else 0.0

    return {
        "mAP": mAP, "f1": macro_f1, "r1": r1, "cos": v_cos,
        "cv": v_cv, "n_active": n_active, "n_seen": n_seen,
    }


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Losses: InfoNCE + MSE + CV + BCE + Procrustes alignment")
print(f" CV target: {consensus_cv:.4f}")
print(f" Images: train={N_train:,} val={N_val:,} (cached as tensors)")
print(f"{'='*65}")

optimizer = torch.optim.Adam(encoder.parameters(), lr=LR)
# The epoch loop also visits the trailing partial batch, so count it;
# otherwise the cosine schedule could be stepped past T_max.
n_batches = math.ceil(N_train / BATCH)
total_steps = n_batches * EPOCHS
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=WARMUP_STEPS),
     torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(total_steps - WARMUP_STEPS, 1), eta_min=1e-6)],
    milestones=[WARMUP_STEPS])

# NOTE(review): bf16 autocast normally does not need loss scaling; the
# scaler is kept because its state is stored in the per-epoch checkpoints.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/geolip_vit_encoder")
best_mAP = 0.0
gs = 0

for epoch in range(EPOCHS):
    encoder.train()
    t0 = time.time()
    perm = torch.randperm(N_train)

    # Accumulators
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0, "align": 0,
           "nce_acc": 0, "n": 0}

    pbar = tqdm(range(0, N_train, BATCH),
                desc=f"E{epoch+1:2d}/{EPOCHS} train", unit="batch")
    for i in pbar:
        idx = perm[i:i + BATCH]
        if len(idx) < 4:
            continue

        pixels = train_images[idx].float().to(DEVICE)
        targets = train_targets_gpu[idx]
        labels = train_labels_gpu[idx]

        # Skip batches with too many zero images
        valid = _cached_mask(pixels)
        if valid.sum() < 4:
            continue
        pixels = pixels[valid]
        targets = targets[valid]
        labels = labels[valid]

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            emb = encoder(pixels)
            emb = EmbeddingAutograd.apply(emb, emb, anchors_frozen, 0.01, 1.0)

            l_nce, nce_acc = infonce(emb, targets)
            l_mse = F.mse_loss(emb, targets)
            l_cv = cv_loss(emb, target=consensus_cv)
            l_align = whitened_procrustes_loss(emb, targets)

            logits, _, _ = soup(emb)
            l_bce = F.binary_cross_entropy_with_logits(logits, labels)

            loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                    + 0.5 * l_align + 0.001 * l_cv)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        acc["loss"] += loss.item()
        acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item()
        acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item()
        acc["align"] += l_align.item()
        acc["nce_acc"] += nce_acc
        acc["n"] += 1
        gs += 1

        # Tensorboard step logging
        if gs % 50 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/nce", l_nce.item(), gs)
            writer.add_scalar("step/mse", l_mse.item(), gs)
            writer.add_scalar("step/bce", l_bce.item(), gs)
            writer.add_scalar("step/cv", l_cv.item(), gs)
            writer.add_scalar("step/align", l_align.item(), gs)
            writer.add_scalar("step/nce_acc", nce_acc, gs)
            writer.add_scalar("step/lr", scheduler.get_last_lr()[0], gs)

        # Update tqdm
        if acc["n"] % 20 == 0:
            d = acc["n"]
            # set_postfix has no "ordered" option; passing one only adds a
            # spurious "ordered=True" entry to the bar.
            pbar.set_postfix(
                loss=f"{acc['loss']/d:.4f}",
                nce_acc=f"{acc['nce_acc']/d:.3f}",
                cos=f"{1-acc['align']/d:.3f}")

    elapsed = time.time() - t0
    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: {elapsed:.0f}s "
          f"loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"mse={acc['mse']/d:.4f} bce={acc['bce']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")

    # Epoch tensorboard
    writer.add_scalar("epoch/train_loss", acc["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce", acc["nce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_mse", acc["mse"] / d, epoch + 1)
    writer.add_scalar("epoch/train_bce", acc["bce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_cv", acc["cv"] / d, epoch + 1)
    writer.add_scalar("epoch/train_align", acc["align"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce_acc", acc["nce_acc"] / d, epoch + 1)

    # ── Validation ──
    m = evaluate(encoder, soup, val_images, val_targets, val_labels)

    writer.add_scalar("epoch/val_mAP", m["mAP"], epoch + 1)
    writer.add_scalar("epoch/val_F1", m["f1"], epoch + 1)
    writer.add_scalar("epoch/val_R@1", m["r1"], epoch + 1)
    writer.add_scalar("epoch/val_cos", m["cos"], epoch + 1)
    writer.add_scalar("epoch/val_cv", m["cv"], epoch + 1)
    writer.add_scalar("epoch/val_anchors", m["n_active"], epoch + 1)

    mk = ""
    if m["mAP"] > best_mAP:
        best_mAP = m["mAP"]
        torch.save({
            "encoder_state_dict": encoder.state_dict(),
            "config": {"d_model": D_MODEL, "n_heads": N_HEADS,
                       "n_layers": N_LAYERS, "d_ff": D_FF,
                       "patch_size": PATCH_SIZE, "image_size": IMAGE_SIZE,
                       "output_dim": D_ANCHOR},
            "mAP": m["mAP"], "f1": m["f1"], "r1": m["r1"],
            "cos": m["cos"], "cv": m["cv"],
            "epoch": epoch + 1, "n_active": m["n_active"],
            "consensus_cv": consensus_cv,
        }, "checkpoints/geolip_vit_encoder_best.pt")
        mk = " ★"

    # Save every epoch checkpoint
    torch.save({
        "encoder_state_dict": encoder.state_dict(),
        "epoch": epoch + 1, "mAP": m["mAP"],
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "gs": gs,
    }, f"checkpoints/geolip_vit_e{epoch+1:02d}.pt")

    print(f" E{epoch+1} val: mAP={m['mAP']:.3f} F1={m['f1']:.3f} "
          f"R@1={m['r1']:.3f} cos={m['cos']:.3f} cv={m['cv']:.4f} "
          f"anchors={m['n_active']}/{N_ANCHORS} seen={m['n_seen']}/{N_val}{mk}")

writer.close()

print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Encoder: {n_params:,} params (from scratch)")
print(f" Checkpoints saved every epoch in checkpoints/")
print(f" Tensorboard: runs/geolip_vit_encoder")
print(f"\n{'='*65}\nDONE\n{'='*65}")