"""OgmaModel — top-level model wrapping any architecture variant.""" from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from .config import OgmaConfig, TaskToken, VariantType from .embeddings import TokenEmbedding from .pooling import create_pooling from .variants.conv import ConvVariant from .variants.deep_narrow import DeepNarrowVariant from .variants.linear_attention import LinearAttentionVariant from .variants.mlp_mixer import MLPMixerVariant from .variants.transformer import TransformerVariant from .variants.transformer_resa import TransformerReSAVariant from .variants.gla import GLAVariant __all__ = ["OgmaModel"] MAX_PARAMS = 10_000_000 def _build_variant(config: OgmaConfig) -> nn.Module: """Instantiate the appropriate architecture variant.""" if config.variant == VariantType.TRANSFORMER: return TransformerVariant(config) elif config.variant == VariantType.DEEP_NARROW: return DeepNarrowVariant(config) elif config.variant == VariantType.CONV: return ConvVariant(config) elif config.variant == VariantType.LINEAR_ATTENTION: return LinearAttentionVariant(config) elif config.variant == VariantType.MLP_MIXER: return MLPMixerVariant(config) elif config.variant == VariantType.TRANSFORMER_RESA: return TransformerReSAVariant(config) elif config.variant == VariantType.GLA: return GLAVariant(config) raise ValueError(f"Unknown variant: {config.variant}") class OgmaModel(nn.Module): """Ogma embedding model. Wraps any architecture variant with shared embedding, pooling, and normalization. Produces L2-normalized embeddings at d_output dimensions, Matryoshka-compatible at configured sub-dimensions. """ def __init__(self, config: OgmaConfig) -> None: super().__init__() self.config = config self.embedding = TokenEmbedding(config) self.variant = _build_variant(config) self.pooling = create_pooling(config) # Output projection if variant output != d_output needs_proj = ( config.variant == VariantType.DEEP_NARROW and config.d_model != config.d_output ) # DeepNarrowVariant already has output_proj, so no extra needed here if not needs_proj and config.d_model != config.d_output: self.output_proj: nn.Module = nn.Linear( config.d_model, config.d_output ) else: self.output_proj = nn.Identity() def forward( self, token_ids: torch.Tensor, attention_mask: torch.Tensor, task_token_ids: torch.Tensor, ) -> torch.Tensor: """Forward pass producing L2-normalized embeddings. Args: token_ids: (B, S) token IDs. attention_mask: (B, S) attention mask (1=valid, 0=pad). task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM). Returns: (B, d_output) L2-normalized embeddings. """ # Embed tokens with task token prepended -> (B, S+1, d_model) x = self.embedding(token_ids, task_token_ids) # Extend attention mask for prepended task token task_mask = torch.ones( attention_mask.shape[0], 1, device=attention_mask.device, dtype=attention_mask.dtype, ) extended_mask = torch.cat([task_mask, attention_mask], dim=1) # Run through variant x = self.variant(x, extended_mask) # Pool x = self.pooling(x, extended_mask) # Project if needed x = self.output_proj(x) # L2 normalize return F.normalize(x, p=2, dim=-1) def encode( self, token_ids: torch.Tensor, attention_mask: torch.Tensor, task: TaskToken = TaskToken.SYM, ) -> torch.Tensor: """Encode tokens with a specified task mode. Args: token_ids: (B, S) token IDs. attention_mask: (B, S) attention mask. task: Task token to use. Returns: (B, d_output) L2-normalized embeddings. """ task_ids = torch.full( (token_ids.shape[0],), self.config.task_token_id(task), device=token_ids.device, dtype=torch.long, ) return self.forward(token_ids, attention_mask, task_ids) def param_count(self) -> int: """Count total trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad) def assert_param_budget(self) -> None: """Assert model is under the 10M parameter budget.""" count = self.param_count() assert count < MAX_PARAMS, ( f"Model has {count:,} params, exceeds {MAX_PARAMS:,} budget" ) @classmethod def from_config(cls, config: OgmaConfig) -> OgmaModel: """Factory method to build a model from config.""" model = cls(config) model.assert_param_budget() return model @classmethod def from_checkpoint( cls, path: str, device: str = "cpu", ) -> OgmaModel: """Load model from a checkpoint directory. Args: path: Path to checkpoint directory containing config.yaml and model.pt. device: Device to load model to. Returns: Loaded OgmaModel. """ from pathlib import Path import yaml ckpt_path = Path(path) with open(ckpt_path / "config.yaml") as f: config_dict = yaml.safe_load(f) config = OgmaConfig.from_dict(config_dict) model = cls(config) state_dict = torch.load( ckpt_path / "model.pt", map_location=device, weights_only=True, ) model.load_state_dict(state_dict) model.to(device) model.eval() return model def save_checkpoint(self, path: str) -> None: """Save model checkpoint. Args: path: Directory to save config.yaml and model.pt. """ from pathlib import Path import yaml ckpt_path = Path(path) ckpt_path.mkdir(parents=True, exist_ok=True) with open(ckpt_path / "config.yaml", "w") as f: yaml.dump(self.config.to_dict(), f, default_flow_style=False) torch.save(self.state_dict(), ckpt_path / "model.pt")