# coding=utf-8 # Copyright 2026 The OdinNext authors. # Licensed under the Apache License, Version 2.0. """OdinNext: 138M HGRN2+RoPE hybrid causal language model. This is a self-contained HuggingFace `trust_remote_code=True` port of the production OdinNext model used to train the 6.84B-token early checkpoint. The training-time machinery (DiffusionBlocks, TST, gate-absorption, torch.compile zone helpers) is dropped — only the inference path remains. Architecture summary: * 16 layers, d=768, 6 heads, ffn=2048, vocab=32768. * Even layers (0,2,...,14) get RoPE on q/k. * Odd layers (1,3,...,15) are position-free recurrent. * SwiGLU2 FFN: silu(gate)^2 * up. * ZCRMSNorm normalization, gated residuals (frozen at training time). * Tied input/output embeddings. * HGRN2 recurrence: O(T) train, O(1) per-token decode. Hardware notes: * Uses `flash-linear-attention` (`fla`) Triton kernels when available. Falls back to a pure-PyTorch implementation (~10-30x slower) otherwise, so the model loads on any backend including CPU. * Trained in fp16 on AMD Strix Halo (gfx1151, RDNA 3.5, ROCm 7.13). fp16 is the recommended inference dtype. bf16 was never validated on this checkpoint. """ from __future__ import annotations import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_odinnext import OdinNextConfig # --------------------------------------------------------------------------- # HGRN2 kernel: prefer flash-linear-attention, fall back to pure PyTorch # --------------------------------------------------------------------------- try: from fla.ops.gla import chunk_gla as _chunk_gla from fla.ops.gla import fused_recurrent_gla as _fused_recurrent_gla # `fla.ops.gla.chunk.ChunkGLAFunction` is decorated with # @torch.compiler.disable. Marking it allow_in_graph lets Dynamo treat # it as an opaque leaf op, preventing graph breaks if the user does # `torch.compile(model)`. Best-effort, ignored if internals shift. try: from fla.ops.gla.chunk import ChunkGLAFunction torch.compiler.allow_in_graph(ChunkGLAFunction) except Exception: pass _HAS_FLA = True except Exception: # ImportError, missing Triton, no CUDA/ROCm, ... from ._hgrn2_fallback import chunk_gla as _chunk_gla from ._hgrn2_fallback import fused_recurrent_gla as _fused_recurrent_gla _HAS_FLA = False # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- class ZCRMSNorm(nn.Module): """Zero-Centered RMSNorm. Stored weight is initialized to 1.0; F.rms_norm sees a leaf parameter directly. Mathematically equivalent to RMSNorm with `gamma = weight - 1`. """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) self._normalized_shape = (dim,) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.rms_norm(x, self._normalized_shape, self.weight, self.eps) class SwiGLU2(nn.Module): """SwiGLU squared FFN: silu(gate)^2 * up -> down.""" def __init__(self, d_model: int, ffn_inner: int): super().__init__() self.w_gate_up = nn.Linear(d_model, 2 * ffn_inner, bias=False) self.w_down = nn.Linear(ffn_inner, d_model, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: gate, up = self.w_gate_up(x).chunk(2, dim=-1) return self.w_down(F.silu(gate).square() * up) def _apply_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> torch.Tensor: """Apply RoPE to x[B,T,H,D] using real arithmetic. cos/sin: [1, T, 1, D/2] pre-broadcast. """ x_even = x[..., 0::2] x_odd = x[..., 1::2] out_even = x_even * cos - x_odd * sin out_odd = x_even * sin + x_odd * cos return torch.stack([out_even, out_odd], dim=-1).flatten(-2) class OdinNextAttention(nn.Module): """HGRN2 attention with optional RoPE on q/k.""" def __init__( self, d_model: int = 768, n_heads: int = 6, expand_ratio: Optional[int] = None, use_rope: bool = True, ): super().__init__() self.d_model = d_model self.n_heads = n_heads if expand_ratio is None: expand_ratio = d_model // n_heads self.expand_ratio = expand_ratio self.head_f_dim = expand_ratio self.head_i_dim = d_model // n_heads self.forget_dim = n_heads * expand_ratio self.use_rope = use_rope self.q_proj = nn.Linear(d_model, self.forget_dim, bias=False) self.f_proj = nn.Linear(d_model, self.forget_dim, bias=False) self.i_proj = nn.Linear(d_model, d_model, bias=False) self.g_norm = ZCRMSNorm(d_model) self.o_proj = nn.Linear(d_model, d_model, bias=False) def forward( self, x: torch.Tensor, cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, recurrent_state: Optional[torch.Tensor] = None, output_state: bool = False, use_recurrent_kernel: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: x: [B, T, D] hidden states. cos, sin: RoPE caches if `use_rope`, else ignored. recurrent_state: optional [B, H, K, V] HGRN2 state to seed the scan. output_state: if True, return the final HGRN2 state alongside output. use_recurrent_kernel: if True (single-token decode), call the fused recurrent kernel; otherwise call chunk_gla. """ B, T, D = x.shape q = F.silu(self.q_proj(x)) forget_logits = self.f_proj(x) g = F.logsigmoid(forget_logits) k = torch.sigmoid(-forget_logits) v = self.i_proj(x) q = q.view(B, T, self.n_heads, self.head_f_dim) k = k.view(B, T, self.n_heads, self.head_f_dim) g = g.view(B, T, self.n_heads, self.head_f_dim) v = v.view(B, T, self.n_heads, self.head_i_dim) if self.use_rope and cos is not None: q = _apply_rope(q, cos, sin) k = _apply_rope(k, cos, sin) if use_recurrent_kernel: o, final_state = _fused_recurrent_gla( q=q, k=k, v=v, gk=g, initial_state=recurrent_state, output_final_state=True, ) else: o, final_state = _chunk_gla( q=q, k=k, v=v, g=g, initial_state=recurrent_state, output_final_state=output_state, ) o = o.reshape(B, T, D) o = self.g_norm(o) o = self.o_proj(o) if output_state: return o, final_state return o, None class OdinNextBlock(nn.Module): """Pre-norm block with gated residuals. Gates were absorbed and frozen at training time: `gate_attn` and `gate_ffn` are stored as scalars whose `sigmoid()` ≈ 1 by the time of this checkpoint. They remain in the state_dict for compatibility. """ def __init__( self, d_model: int, n_heads: int, ffn_inner: int, use_rope: bool = True, ): super().__init__() self.pre_norm = ZCRMSNorm(d_model) self.attn = OdinNextAttention( d_model=d_model, n_heads=n_heads, use_rope=use_rope ) self.ffn_norm = ZCRMSNorm(d_model) self.ffn = SwiGLU2(d_model, ffn_inner) self.gate_attn = nn.Parameter(torch.zeros(1)) self.gate_ffn = nn.Parameter(torch.zeros(1)) def forward( self, x: torch.Tensor, cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, recurrent_state: Optional[torch.Tensor] = None, output_state: bool = False, use_recurrent_kernel: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: attn_out, new_state = self.attn( self.pre_norm(x), cos=cos, sin=sin, recurrent_state=recurrent_state, output_state=output_state, use_recurrent_kernel=use_recurrent_kernel, ) x = x + torch.sigmoid(self.gate_attn) * attn_out x = x + torch.sigmoid(self.gate_ffn) * self.ffn(self.ffn_norm(x)) return x, new_state # --------------------------------------------------------------------------- # OdinNext recurrent-state cache # --------------------------------------------------------------------------- class OdinNextCache: """Container for HGRN2 recurrent states across all layers. Wraps `List[Optional[Tensor]]` (one per layer, each [B, H, K, V]) with just enough surface to satisfy HuggingFace `generate()`'s expectations for `past_key_values`. Importantly: cache size is independent of T — it is the per-layer hidden-state matrix S, not a growing K/V tape. Also tracks `seen_tokens`, the number of input positions the cache has consumed so far, which OdinNext uses to look up the correct RoPE position offset during decode. """ def __init__(self, n_layers: int): self.n_layers = n_layers self.states: List[Optional[torch.Tensor]] = [None] * n_layers self.seen_tokens: int = 0 def __len__(self) -> int: return self.n_layers def __getitem__(self, idx: int) -> Optional[torch.Tensor]: return self.states[idx] def __setitem__(self, idx: int, value: Optional[torch.Tensor]) -> None: self.states[idx] = value def __iter__(self): return iter(self.states) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.seen_tokens def get_max_length(self) -> Optional[int]: return None # HGRN2 has no hard cache length cap def update_seen(self, n_new_tokens: int) -> None: self.seen_tokens += n_new_tokens def to(self, device: torch.device) -> "OdinNextCache": for i, s in enumerate(self.states): if s is not None: self.states[i] = s.to(device) return self # --------------------------------------------------------------------------- # OdinNext PreTrainedModel: HF integration # --------------------------------------------------------------------------- class OdinNextPreTrainedModel(PreTrainedModel): """Base class wiring up HF infrastructure for OdinNext.""" config_class = OdinNextConfig base_model_prefix = "model" supports_gradient_checkpointing = False _no_split_modules = ["OdinNextBlock"] _skip_keys_device_placement = "past_key_values" _supports_cache_class = False # we use our own OdinNextCache def _init_weights(self, module: nn.Module) -> None: """Conservative init — at inference we only need to define defaults in case someone constructs an OdinNext from scratch. """ std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) class OdinNextModel(OdinNextPreTrainedModel): """Backbone (no LM head).""" def __init__(self, config: OdinNextConfig): super().__init__(config) self.config = config self.tok_embeddings = nn.Embedding(config.vocab_size, config.d_model) self.layers = nn.ModuleList([ OdinNextBlock( d_model=config.d_model, n_heads=config.n_heads, ffn_inner=config.ffn_inner, use_rope=(i % 2 == 0), ) for i in range(config.n_layers) ]) self.final_norm = ZCRMSNorm(config.d_model) # RoPE caches are lazy-built on first forward. Storing them as # `register_buffer(..., persistent=False)` is incompatible with # `from_pretrained(low_cpu_mem_usage=True)`: HF builds the model on # the meta device and only materializes tensors that appear in the # checkpoint. Non-persistent buffers are NOT in the checkpoint and # so end up backed by uninitialized memory after meta -> real # transfer. We side-step this entirely by computing cos/sin on the # first forward, cached on the model object as plain attributes. self._cos_cache: Optional[torch.Tensor] = None self._sin_cache: Optional[torch.Tensor] = None # Skip _init_weights here — we expect to load weights from a # pretrained checkpoint immediately after construction. def get_input_embeddings(self) -> nn.Embedding: return self.tok_embeddings def set_input_embeddings(self, value: nn.Embedding) -> None: self.tok_embeddings = value # ----------------------------------------------------------------- # Forward # ----------------------------------------------------------------- def _ensure_rope_cache(self, target_device: torch.device) -> None: """Build the RoPE cos/sin caches on `target_device` if not already. Cached as plain Python attributes (not buffers) to avoid HF's `low_cpu_mem_usage=True` meta-device materialization issue with non-persistent buffers. """ need_build = ( self._cos_cache is None or self._cos_cache.device != target_device ) if not need_build: return head_f_dim = self.config.d_model // self.config.n_heads half_dim = head_f_dim // 2 freqs = 1.0 / ( self.config.rope_theta ** ( torch.arange(0, half_dim, dtype=torch.float32, device=target_device) / half_dim ) ) t = torch.arange(self.config.max_seq_len, dtype=torch.float32, device=target_device) angles = torch.outer(t, freqs) self._cos_cache = angles.cos() self._sin_cache = angles.sin() def _rope_slice( self, seq_len: int, offset: int, target_dtype: torch.dtype, target_device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: end = offset + seq_len if end > self.config.max_seq_len: raise ValueError( f"Position {end} exceeds max_seq_len={self.config.max_seq_len}. " "OdinNext was trained with a 2048-token RoPE cache." ) self._ensure_rope_cache(target_device) cos = self._cos_cache[offset:end].to(dtype=target_dtype) sin = self._sin_cache[offset:end].to(dtype=target_dtype) cos = cos.unsqueeze(0).unsqueeze(2) # [1, T, 1, D/2] sin = sin.unsqueeze(0).unsqueeze(2) return cos, sin def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[OdinNextCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **_unused, ) -> Tuple[torch.Tensor, Optional[OdinNextCache]]: """Backbone forward. Returns `(hidden_states, past_key_values)`. The LM-head wrapper (`OdinNextForCausalLM`) projects to logits. Note: `attention_mask` is accepted for HF API compatibility but is NOT used. HGRN2 is causal by construction (the recurrence is strictly forward-in-time) and cannot honor a left-padded mask. For correct results with batched generation, callers must right-pad and ensure all sequences in a batch have valid tokens at every position they process. Single-sequence generation is unaffected. """ if use_cache is None: use_cache = self.config.use_cache B, T = input_ids.shape # Determine if we're in single-token decode mode. single_step = (T == 1) and (past_key_values is not None) # RoPE position offset if past_key_values is not None: offset = past_key_values.seen_tokens else: offset = 0 h = self.tok_embeddings(input_ids) # Prepare RoPE caches in the embedding's dtype. cos, sin = self._rope_slice( seq_len=T, offset=offset, target_dtype=h.dtype, target_device=h.device, ) # Coerce past_key_values to our expected type. HF generate may # try to auto-instantiate a DynamicCache or pass a legacy tuple; # we want strict OdinNextCache or None. if past_key_values is not None and not isinstance(past_key_values, OdinNextCache): past_key_values = None if past_key_values is None and use_cache: past_key_values = OdinNextCache(self.config.n_layers) for i, layer in enumerate(self.layers): prev_state = past_key_values[i] if past_key_values is not None else None h, new_state = layer( h, cos=cos, sin=sin, recurrent_state=prev_state, output_state=use_cache, use_recurrent_kernel=single_step, ) if use_cache and past_key_values is not None: past_key_values[i] = new_state h = self.final_norm(h) if past_key_values is not None: past_key_values.update_seen(T) return h, past_key_values class OdinNextForCausalLM(OdinNextPreTrainedModel): """Top-level wrapper exposing logits + HF generate().""" # Map tied output -> source. Newer `transformers` (>=4.45) expects a # dict; older versions tolerate (and used) a list of keys. Provide the # dict form which is forward-compatible. _tied_weights_keys = {"lm_head.weight": "model.tok_embeddings.weight"} def __init__(self, config: OdinNextConfig): super().__init__(config) self.model = OdinNextModel(config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) if config.tie_embeddings: self.lm_head.weight = self.model.tok_embeddings.weight self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.model.tok_embeddings def set_input_embeddings(self, value: nn.Embedding) -> None: self.model.tok_embeddings = value def get_output_embeddings(self) -> nn.Linear: return self.lm_head def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: self.lm_head = new_embeddings def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[OdinNextCache] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **_unused, ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else True hidden_states, past_key_values = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, ) logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1).long(), ignore_index=-100, ) if not return_dict: output = (logits,) + ((past_key_values,) if past_key_values is not None else ()) return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=None, attentions=None, ) # ----------------------------------------------------------------- # generate() integration # ----------------------------------------------------------------- def prepare_inputs_for_generation( self, input_ids: torch.Tensor, past_key_values: Optional[OdinNextCache] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = True, **kwargs, ) -> dict: """Trim input_ids to only the new positions when a cache exists. After the first forward, the recurrent state already encodes the prompt. Subsequent calls only need to pass the most recently generated token. """ if past_key_values is not None and past_key_values.seen_tokens > 0: # New tokens since last call. new_count = input_ids.shape[1] - past_key_values.seen_tokens if new_count <= 0: # generate() can occasionally call us with the same length # twice (e.g., assistant-decoding paths). Default to feeding # the last token only. input_ids = input_ids[:, -1:] else: input_ids = input_ids[:, -new_count:] return { "input_ids": input_ids, "past_key_values": past_key_values, "attention_mask": attention_mask, "use_cache": use_cache, } def _reorder_cache( self, past_key_values: OdinNextCache, beam_idx: torch.Tensor ) -> OdinNextCache: """Beam-search support: reorder per-layer states along the batch axis.""" for i, state in enumerate(past_key_values.states): if state is not None: past_key_values.states[i] = state.index_select(0, beam_idx.to(state.device)) return past_key_values @staticmethod def _supports_default_dynamic_cache() -> bool: return False # Re-export for convenience __all__ = [ "OdinNextConfig", "OdinNextModel", "OdinNextForCausalLM", "OdinNextCache", ]