"""Model configuration for Ogma.""" from __future__ import annotations from dataclasses import dataclass, field from enum import StrEnum from typing import Any __all__ = ["OgmaConfig", "VariantType", "PoolingType", "TaskToken"] class VariantType(StrEnum): """Architecture variant identifiers.""" TRANSFORMER = "transformer" DEEP_NARROW = "deep_narrow" CONV = "conv" LINEAR_ATTENTION = "linear_attention" MLP_MIXER = "mlp_mixer" TRANSFORMER_RESA = "transformer_resa" GLA = "gla" class PoolingType(StrEnum): """Pooling strategy identifiers.""" TASK_TOKEN = "task_token" LATENT_ATTENTION = "latent_attention" MEAN = "mean" class TaskToken(StrEnum): """Task token identifiers for asymmetric encoding.""" QRY = "QRY" DOC = "DOC" SYM = "SYM" @dataclass class OgmaConfig: """Configuration for an Ogma model instance. Args: variant: Architecture variant to use. d_embed: Token embedding dimension (from teacher PCA). d_model: Internal model dimension after projection. n_layers: Number of fusion layers/blocks. n_heads: Number of attention heads (attention variants only). vocab_size: Vocabulary size for embedding table. max_seq_len: Maximum sequence length. matryoshka_dims: Nested output dimensions for Matryoshka. pooling: Pooling strategy. d_output: Final output dimension. ffn_mult: SwiGLU FFN hidden dimension multiplier. conv_kernel_size: Kernel size for conv variant. spatial_rank: Rank of spatial mixing in MLP mixer. n_random_features: Random features for linear attention. dropout: Dropout rate (0 for inference). """ variant: VariantType = VariantType.TRANSFORMER d_embed: int = 128 d_model: int = 256 n_layers: int = 1 n_heads: int = 4 vocab_size: int = 30_000 max_seq_len: int = 512 matryoshka_dims: list[int] = field( default_factory=lambda: [32, 64, 128, 256] ) pooling: PoolingType = PoolingType.TASK_TOKEN d_output: int = 256 ffn_mult: float = 8 / 3 # SwiGLU: 8/3 * d_model ≈ 683 for d=256 conv_kernel_size: int = 7 spatial_rank: int = 32 n_random_features: int = 128 dropout: float = 0.0 # ReSA scorer settings scorer_type: str = "dot" scorer_alpha_init: float = 0.1 scorer_hidden: int = 0 # 0 defaults to d_head # GLA (Gated Linear Attention) settings gla_expand_k: float = 0.5 # key dim expansion (key_dim = d_model * expand_k) gla_expand_v: float = 1.0 # value dim expansion (value_dim = d_model * expand_v) gla_gate_low_rank_dim: int = 16 # low-rank dim for gating projection gla_gate_logit_normalizer: int = 16 # normalizer for gate logits gla_use_short_conv: bool = True # whether to use short conv on Q,K,V gla_conv_size: int = 4 # short conv kernel size # Special token IDs pad_id: int = 0 unk_id: int = 1 bos_id: int = 2 eos_id: int = 3 qry_id: int = 4 doc_id: int = 5 sym_id: int = 6 n_special_tokens: int = 7 @property def d_head(self) -> int: """Per-head dimension.""" return self.d_model // self.n_heads @property def ffn_hidden(self) -> int: """SwiGLU FFN hidden dimension.""" return int(self.d_model * self.ffn_mult) def task_token_id(self, task: TaskToken) -> int: """Return token ID for a task token.""" mapping = { TaskToken.QRY: self.qry_id, TaskToken.DOC: self.doc_id, TaskToken.SYM: self.sym_id, } return mapping[task] def to_dict(self) -> dict[str, Any]: """Serialize config to dictionary.""" return { "variant": self.variant.value, "d_embed": self.d_embed, "d_model": self.d_model, "n_layers": self.n_layers, "n_heads": self.n_heads, "vocab_size": self.vocab_size, "max_seq_len": self.max_seq_len, "matryoshka_dims": self.matryoshka_dims, "pooling": self.pooling.value, "d_output": self.d_output, "ffn_mult": self.ffn_mult, "conv_kernel_size": self.conv_kernel_size, "spatial_rank": self.spatial_rank, "n_random_features": self.n_random_features, "dropout": self.dropout, "scorer_type": self.scorer_type, "scorer_alpha_init": self.scorer_alpha_init, "scorer_hidden": self.scorer_hidden, "gla_expand_k": self.gla_expand_k, "gla_expand_v": self.gla_expand_v, "gla_gate_low_rank_dim": self.gla_gate_low_rank_dim, "gla_gate_logit_normalizer": self.gla_gate_logit_normalizer, "gla_use_short_conv": self.gla_use_short_conv, "gla_conv_size": self.gla_conv_size, } @classmethod def from_dict(cls, data: dict[str, Any]) -> OgmaConfig: """Deserialize config from dictionary.""" data = dict(data) if "variant" in data: data["variant"] = VariantType(data["variant"]) if "pooling" in data: data["pooling"] = PoolingType(data["pooling"]) known = {f.name for f in cls.__dataclass_fields__.values()} filtered = {k: v for k, v in data.items() if k in known} return cls(**filtered)