# ogma-mini / configuration_ogma.py
# Antreas's picture
# Enable AutoModel loading
# bbae1b8 verified
"""Hugging Face AutoConfig support for Ogma models."""
from __future__ import annotations
from enum import StrEnum
from typing import Any
from transformers import PretrainedConfig
# Names exported by a star-import of this module.
__all__ = [
    "OgmaConfig",
    "VariantType",
    "PoolingType",
    "TaskToken",
]
class VariantType(StrEnum):
"""Architecture variant identifiers."""
TRANSFORMER = "transformer"
DEEP_NARROW = "deep_narrow"
CONV = "conv"
LINEAR_ATTENTION = "linear_attention"
MLP_MIXER = "mlp_mixer"
TRANSFORMER_RESA = "transformer_resa"
GLA = "gla"
class PoolingType(StrEnum):
"""Pooling strategy identifiers."""
TASK_TOKEN = "task_token"
LATENT_ATTENTION = "latent_attention"
MEAN = "mean"
class TaskToken(StrEnum):
"""Task token identifiers for asymmetric encoding."""
QRY = "QRY"
DOC = "DOC"
SYM = "SYM"
class OgmaConfig(PretrainedConfig):
    """Configuration for Ogma embedding models.

    Every constructor argument is stored on an attribute of the same name.
    ``variant`` and ``pooling`` are normalized to their enum types on
    construction and emitted as plain strings again by :meth:`to_dict`.
    """

    model_type = "ogma"

    def __init__(
        self,
        variant: str | VariantType = VariantType.TRANSFORMER,
        d_embed: int = 128,
        d_model: int = 256,
        n_layers: int = 1,
        n_heads: int = 4,
        vocab_size: int = 30_000,
        max_seq_len: int = 512,
        matryoshka_dims: list[int] | None = None,
        pooling: str | PoolingType = PoolingType.TASK_TOKEN,
        d_output: int = 256,
        ffn_mult: float = 8 / 3,
        conv_kernel_size: int = 7,
        spatial_rank: int = 32,
        n_random_features: int = 128,
        dropout: float = 0.0,
        scorer_type: str = "dot",
        scorer_alpha_init: float = 0.1,
        scorer_hidden: int = 0,
        gla_expand_k: float = 0.5,
        gla_expand_v: float = 1.0,
        gla_gate_low_rank_dim: int = 16,
        gla_gate_logit_normalizer: int = 16,
        gla_use_short_conv: bool = True,
        gla_conv_size: int = 4,
        pad_id: int = 0,
        unk_id: int = 1,
        bos_id: int = 2,
        eos_id: int = 3,
        qry_id: int = 4,
        doc_id: int = 5,
        sym_id: int = 6,
        n_special_tokens: int = 7,
        **kwargs: Any,
    ) -> None:
        """Store the hyperparameters and initialize the base config.

        Args:
            variant: Architecture variant; coerced with ``VariantType(...)``.
            pooling: Pooling strategy; coerced with ``PoolingType(...)``.
            matryoshka_dims: List of dimensions, copied into a new list.
                ``None`` or an empty sequence selects ``[32, 64, 128, 256]``.
            pad_id: Stored as ``pad_id`` and used as the default for the base
                class's ``pad_token_id`` when the caller did not pass one.
            bos_id: Same as ``pad_id``, for ``bos_token_id``.
            eos_id: Same as ``pad_id``, for ``eos_token_id``.
            qry_id: Token ID returned by :meth:`task_token_id` for ``QRY``.
            doc_id: Token ID returned by :meth:`task_token_id` for ``DOC``.
            sym_id: Token ID returned by :meth:`task_token_id` for ``SYM``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        All remaining arguments are stored unchanged; their meaning is
        defined by the model code that reads this config.

        Raises:
            ValueError: If ``variant`` or ``pooling`` is not a valid value of
                its enum.
        """
        # Explicit *_token_id values supplied by the caller win over the
        # Ogma-specific IDs; setdefault only fills the gaps.
        kwargs.setdefault("pad_token_id", pad_id)
        kwargs.setdefault("bos_token_id", bos_id)
        kwargs.setdefault("eos_token_id", eos_id)
        super().__init__(**kwargs)

        # Architecture selection and core sizes.
        self.variant = VariantType(variant)
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Copy so that later mutation of the caller's sequence cannot silently
        # alter this config, and so a tuple argument is normalized to the list
        # form a JSON round-trip produces. Empty/None still means "default".
        self.matryoshka_dims = (
            list(matryoshka_dims) if matryoshka_dims else [32, 64, 128, 256]
        )
        self.pooling = PoolingType(pooling)
        self.d_output = d_output
        self.ffn_mult = ffn_mult
        self.conv_kernel_size = conv_kernel_size
        self.spatial_rank = spatial_rank
        self.n_random_features = n_random_features
        self.dropout = dropout

        # Scorer options.
        self.scorer_type = scorer_type
        self.scorer_alpha_init = scorer_alpha_init
        self.scorer_hidden = scorer_hidden

        # gla_* options (presumably read only by the GLA variant -- confirm
        # against the model code).
        self.gla_expand_k = gla_expand_k
        self.gla_expand_v = gla_expand_v
        self.gla_gate_low_rank_dim = gla_gate_low_rank_dim
        self.gla_gate_logit_normalizer = gla_gate_logit_normalizer
        self.gla_use_short_conv = gla_use_short_conv
        self.gla_conv_size = gla_conv_size

        # Special-token IDs.
        self.pad_id = pad_id
        self.unk_id = unk_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.qry_id = qry_id
        self.doc_id = doc_id
        self.sym_id = sym_id
        self.n_special_tokens = n_special_tokens

    @property
    def d_head(self) -> int:
        """Per-head dimension: ``d_model // n_heads`` (floor division)."""
        return self.d_model // self.n_heads

    @property
    def ffn_hidden(self) -> int:
        """SwiGLU FFN hidden dimension: ``d_model * ffn_mult`` truncated to int."""
        return int(self.d_model * self.ffn_mult)

    def task_token_id(self, task: TaskToken | str) -> int:
        """Return the token ID configured for a task token.

        Args:
            task: A ``TaskToken`` member or its exact string value
                (``"QRY"``, ``"DOC"`` or ``"SYM"``; case-sensitive).

        Returns:
            ``qry_id``, ``doc_id`` or ``sym_id`` respectively.

        Raises:
            ValueError: If ``task`` is not a valid ``TaskToken`` value.
        """
        task = TaskToken(task)
        return {
            TaskToken.QRY: self.qry_id,
            TaskToken.DOC: self.doc_id,
            TaskToken.SYM: self.sym_id,
        }[task]

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a JSON-compatible dictionary.

        Starts from the base-class dictionary and replaces the ``variant``
        and ``pooling`` enum members with their plain string values.
        """
        output = super().to_dict()
        output["variant"] = self.variant.value
        output["pooling"] = self.pooling.value
        return output