# ============================================================================
# GeoLIP ViT: HuggingFace AutoModel
#
# Usage:
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("AbstractPhil/geolip-vit-base-x3",
#                                     trust_remote_code=True)
#
#   from torchvision import transforms
#   transform = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
#   ])
#   pixel_values = transform(image).unsqueeze(0)
#   outputs = model(pixel_values)
#
#   # 128-d embedding on hypersphere (L2-normalized)
#   embedding = outputs.embedding            # (B, 128)
#
#   # Multi-label classification logits (80 COCO classes)
#   logits = outputs.logits                  # (B, 80) — if soup_enabled
#
#   # Triangulation distances to 256 constellation anchors
#   triangulation = outputs.triangulation    # (B, 256) — if soup_enabled
#
#   # Nearest anchor index per sample
#   nearest = outputs.nearest                # (B,) — if soup_enabled
#
#   # Geometric diagnostics
#   diagnostics = outputs.diagnostics        # dict
# ============================================================================

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from dataclasses import dataclass, field
from typing import Optional, Dict, Any


# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════

class GeoLIPViTConfig(PretrainedConfig):
    """Hyper-parameters for :class:`GeoLIPViTModel`.

    Args:
        image_size: Input resolution used to size the positional embedding.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        hidden_size: Transformer token width.
        num_attention_heads: Heads per transformer encoder layer.
        num_hidden_layers: Number of transformer encoder layers.
        intermediate_size: Feed-forward width inside each encoder layer.
        output_dim: Width of the final L2-normalized embedding (also the
            dimensionality of the constellation anchors).
        n_anchors: Number of constellation anchors.
        n_comp: Number of patchwork compartments.
        d_comp: Output width of each patchwork compartment.
        n_classes: Number of classifier logits.
        hidden_dropout_prob: Dropout used after the embedding norm and inside
            the encoder layers.
        soup_enabled: Build the constellation / patchwork / classifier modules.
        consensus_cv: Stored on the config; not read by the model code in
            this file.
        experts: Names of the teacher models; stored on the config only.
            Defaults to ``["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]``
            when ``None`` or empty.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "geolip_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=128,
        n_anchors=256,
        n_comp=8,
        d_comp=64,
        n_classes=80,
        hidden_dropout_prob=0.1,
        soup_enabled=True,
        consensus_cv=0.2731,
        experts=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.output_dim = output_dim
        self.n_anchors = n_anchors
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.n_classes = n_classes
        self.hidden_dropout_prob = hidden_dropout_prob
        self.soup_enabled = soup_enabled
        self.consensus_cv = consensus_cv
        self.experts = experts or ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]


# ══════════════════════════════════════════════════════════════════
# OUTPUT
# ══════════════════════════════════════════════════════════════════

@dataclass
class GeoLIPViTOutput:
    """
    Output fields:
        embedding:      (B, output_dim) L2-normalized on hypersphere
        logits:         (B, n_classes) multi-label classification
                        (None unless the soup pipeline is enabled)
        triangulation:  (B, n_anchors) distances to constellation anchors
                        (None unless the soup pipeline is enabled)
        nearest:        (B,) nearest anchor index
                        (None unless the soup pipeline is enabled)
        patch_tokens:   (B, n_patches, hidden_size) pre-pooling patch
                        representations (None unless requested)
        diagnostics:    dict of geometric metrics (empty when the soup
                        pipeline is disabled or diagnostics were not requested)
    """

    embedding: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    triangulation: Optional[torch.Tensor] = None
    nearest: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None
    diagnostics: Optional[Dict[str, Any]] = None


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC COMPONENTS
# ══════════════════════════════════════════════════════════════════

class Constellation(nn.Module):
    """Learnable set of ``n_anchors`` reference directions in ``d`` dimensions.

    Anchors are drawn from a standard normal and unit-normalized at
    construction; they are re-normalized on every use, so the stored
    parameter itself is free to drift off the unit sphere during training.
    """

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def cosine(self, emb: torch.Tensor) -> torch.Tensor:
        """Return ``emb @ normalize(anchors).T``, shape ``(..., n_anchors)``.

        This is a true cosine similarity only when ``emb`` is itself
        unit-normalized; the method does not normalize ``emb``.
        """
        return emb @ F.normalize(self.anchors, dim=-1).T

    def triangulate(self, emb: torch.Tensor):
        """Return ``(1 - cosine, argmax(cosine))`` for ``emb`` against all anchors.

        Returns:
            A tuple ``(distances, nearest)`` where ``distances`` has shape
            ``(..., n_anchors)`` and ``nearest`` holds the index of the most
            similar anchor along the last dimension.
        """
        cos = self.cosine(emb)
        return 1.0 - cos, cos.argmax(dim=-1)


class Patchwork(nn.Module):
    """Splits anchor distances round-robin into ``n_comp`` groups, one MLP each.

    Anchor ``i`` belongs to compartment ``i % n_comp``. Each compartment maps
    its slice through Linear → GELU → Linear → LayerNorm to ``d_comp``
    features; the compartment outputs are concatenated along the last axis,
    giving ``n_comp * d_comp`` features per sample.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        self.n_anchors = n_anchors
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        # Compute input sizes from ints, not tensors (meta-tensor safe).
        # With round-robin assignment the first `remainder` compartments
        # receive one extra anchor.
        anchors_per_comp = n_anchors // n_comp
        remainder = n_anchors % n_comp
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per_comp + (1 if k < remainder else 0), d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            )
            for k in range(n_comp)
        ])

    def forward(self, tri):
        """Map ``tri`` of shape ``(B, n_anchors)`` to ``(B, n_comp * d_comp)``."""
        # The mask is built from the `asgn` buffer (rather than a fixed
        # stride) so the grouping follows whatever assignment the loaded
        # state dict carries.
        return torch.cat(
            [self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)],
            -1,
        )


# ══════════════════════════════════════════════════════════════════
# MODEL
# ══════════════════════════════════════════════════════════════════

class GeoLIPViTModel(PreTrainedModel):
    """
    From-scratch Vision Transformer producing L2-normalized embeddings
    on a 128-d hypersphere, geometrically anchored by a constellation
    of 256 reference points trained via 3-expert consensus distillation.

    The encoder is trained from Xavier initialization against consensus
    targets from CLIP ViT-L/14, DINOv2 ViT-B/14, and SigLIP ViT-B/16.

    Optional soup pipeline (constellation + patchwork + classifier)
    provides multi-label COCO classification from the embedding.

    Output fields (shapes shown for the default config):
        embedding:      (B, 128) L2-normalized, consensus-aligned
        logits:         (B, 80) multi-label COCO logits (if soup_enabled)
        triangulation:  (B, 256) distances to constellation anchors
                        (if soup_enabled)
        nearest:        (B,) nearest anchor index (if soup_enabled)
        patch_tokens:   (B, 196, 384) pre-pooling patch representations
                        (only when ``output_patch_tokens=True``)
        diagnostics:    dict geometric metrics
    """

    config_class = GeoLIPViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        n_patches = (config.image_size // config.patch_size) ** 2

        # ── Encoder ──
        self.patch_embed = nn.Conv2d(
            3, config.hidden_size,
            kernel_size=config.patch_size, stride=config.patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size))
        self.embed_norm = nn.LayerNorm(config.hidden_size)
        self.embed_drop = nn.Dropout(config.hidden_dropout_prob)

        # Individual layers for geometric injection between each
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.hidden_dropout_prob,
                activation="gelu",
                batch_first=True,
                norm_first=True)
            for _ in range(config.num_hidden_layers)])

        # Geometric injection: pool → anchor_dim → triangulate → hidden_size
        self.geo_pool_proj = nn.Linear(config.hidden_size, config.output_dim)
        self.geo_tri_proj = nn.Sequential(
            nn.Linear(config.n_anchors, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size))

        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # ── Soup Pipeline (optional) ──
        if getattr(config, "soup_enabled", False):
            self.constellation = Constellation(config.n_anchors, config.output_dim)
            self.patchwork = Patchwork(
                config.n_anchors, config.n_comp, config.d_comp)
            pw_dim = config.n_comp * config.d_comp
            self.classifier = nn.Sequential(
                nn.Linear(pw_dim + config.output_dim, pw_dim),
                nn.GELU(),
                nn.LayerNorm(pw_dim),
                nn.Dropout(0.0),
                nn.Linear(pw_dim, config.n_classes))
        else:
            self.constellation = None
            self.patchwork = None
            self.classifier = None

        self.post_init()

    def _init_weights(self, module):
        """Xavier-uniform for Linear/Conv2d weights, zeros for biases,
        identity (ones/zeros) for LayerNorm affine parameters."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for a non-affine (or bias-free) LayerNorm.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, pixel_values, output_patch_tokens=False,
                output_diagnostics=True, **kwargs):
        """Encode a batch of images.

        Args:
            pixel_values: Image batch of shape ``(B, 3, H, W)``. ``H`` and
                ``W`` must produce the number of patches the positional
                embedding was built for.
            output_patch_tokens: Include the pre-pooling patch tokens in the
                returned output.
            output_diagnostics: Compute the scalar geometric diagnostics.
                Each metric is pulled to the host with ``.item()``; pass
                ``False`` to skip that work (``diagnostics`` is then ``{}``).
            **kwargs: Accepted and ignored.

        Returns:
            GeoLIPViTOutput. ``logits``, ``triangulation`` and ``nearest``
            are ``None`` when the soup pipeline is disabled.

        Raises:
            ValueError: If ``pixel_values`` is not 4-D, or its resolution
                yields a different token count than the positional embedding.
        """
        if pixel_values.dim() != 4:
            raise ValueError(
                "pixel_values must have shape (B, 3, H, W); "
                f"got a {pixel_values.dim()}-D tensor.")
        B = pixel_values.shape[0]

        # ── Encode ──
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        if x.shape[1] != self.pos_embed.shape[1]:
            # Fail with an actionable message instead of a broadcast error.
            raise ValueError(
                f"pixel_values yields {x.shape[1] - 1} patches but the "
                f"positional embedding was built for "
                f"{self.pos_embed.shape[1] - 1} "
                f"(image_size={self.config.image_size}, "
                f"patch_size={self.config.patch_size}); "
                "resize the input to the configured resolution.")
        x = x + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))

        # ── Transformer with geometric injection ──
        # Anchors are detached here: the injection path reads the
        # constellation but does not back-propagate into it.
        anchors_n = None
        if self.constellation is not None:
            anchors_n = F.normalize(self.constellation.anchors.detach(), dim=-1)

        for layer in self.layers:
            if anchors_n is None:
                x = layer(x)
                continue
            # Pool → project → triangulate → geo token
            pooled = x[:, 1:, :].mean(dim=1)
            geo_128 = F.normalize(self.geo_pool_proj(pooled), dim=-1)
            tri_dists = 1.0 - geo_128 @ anchors_n.T
            geo_token = self.geo_tri_proj(tri_dists).unsqueeze(1)
            # The geo token is prepended for this layer only and dropped
            # again afterwards, so x keeps its [CLS, patches...] layout.
            x = layer(torch.cat([geo_token, x], dim=1))[:, 1:, :]

        # ── Pool + Project ──
        patch_tokens = x[:, 1:, :]
        pooled = patch_tokens.mean(dim=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # ── Soup Pipeline ──
        logits = None
        triangulation = None
        nearest = None
        diagnostics = {}

        if self.constellation is not None:
            # One cosine matrix serves triangulation, nearest-anchor lookup
            # and the diagnostics below.
            cos_to_anchors = self.constellation.cosine(embedding)
            triangulation = 1.0 - cos_to_anchors
            nearest = cos_to_anchors.argmax(dim=-1)

            if self.patchwork is not None and self.classifier is not None:
                pw = self.patchwork(triangulation)
                logits = self.classifier(torch.cat([pw, embedding], -1))

            # Geometric diagnostics
            if output_diagnostics:
                with torch.no_grad():
                    cos_stats = cos_to_anchors.detach()
                    diagnostics = {
                        "nearest_cos": cos_stats.max(dim=-1).values.mean().item(),
                        "mean_anchor_cos": cos_stats.mean().item(),
                        "n_active_anchors": nearest.unique().numel(),
                        "embedding_norm": embedding.norm(dim=-1).mean().item(),
                    }

        return GeoLIPViTOutput(
            embedding=embedding,
            logits=logits,
            triangulation=triangulation,
            nearest=nearest,
            patch_tokens=patch_tokens if output_patch_tokens else None,
            diagnostics=diagnostics,
        )