# Hub page header captured with the file (not part of the module):
#   OdinNext-138M-Early-Checkpoint / modeling_odinnext.py
#   joelhenwang's picture
#   Initial release: 6.84B token early checkpoint (EMA weights)
#   cb8708f verified
# coding=utf-8
# Copyright 2026 The OdinNext authors.
# Licensed under the Apache License, Version 2.0.
"""OdinNext: 138M HGRN2+RoPE hybrid causal language model.
This is a self-contained HuggingFace `trust_remote_code=True` port of the
production OdinNext model used to train the 6.84B-token early checkpoint.
The training-time machinery (DiffusionBlocks, TST, gate-absorption,
torch.compile zone helpers) is dropped — only the inference path remains.
Architecture summary:
* 16 layers, d=768, 6 heads, ffn=2048, vocab=32768.
* Even layers (0,2,...,14) get RoPE on q/k.
* Odd layers (1,3,...,15) are position-free recurrent.
* SwiGLU2 FFN: silu(gate)^2 * up.
* ZCRMSNorm normalization, gated residuals (frozen at training time).
* Tied input/output embeddings.
* HGRN2 recurrence: O(T) train, O(1) per-token decode.
Hardware notes:
* Uses `flash-linear-attention` (`fla`) Triton kernels when available.
Falls back to a pure-PyTorch implementation (~10-30x slower) otherwise,
so the model loads on any backend including CPU.
* Trained in fp16 on AMD Strix Halo (gfx1151, RDNA 3.5, ROCm 7.13).
fp16 is the recommended inference dtype. bf16 was never validated on
this checkpoint.
"""
from __future__ import annotations

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_odinnext import OdinNextConfig
# ---------------------------------------------------------------------------
# HGRN2 kernel: prefer flash-linear-attention, fall back to pure PyTorch
# ---------------------------------------------------------------------------
# Bind `_chunk_gla` / `_fused_recurrent_gla` to one of two implementations and
# record the choice in `_HAS_FLA`. The `fla` package is tried first; if
# anything goes wrong while importing it, the module-local fallback supplies
# functions under the same two names so the rest of the file is unaffected.
try:
    from fla.ops.gla import chunk_gla as _chunk_gla
    from fla.ops.gla import fused_recurrent_gla as _fused_recurrent_gla
    # `fla.ops.gla.chunk.ChunkGLAFunction` is decorated with
    # @torch.compiler.disable. Marking it allow_in_graph lets Dynamo treat
    # it as an opaque leaf op, preventing graph breaks if the user does
    # `torch.compile(model)`. Best-effort, ignored if internals shift.
    try:
        from fla.ops.gla.chunk import ChunkGLAFunction
        torch.compiler.allow_in_graph(ChunkGLAFunction)
    except Exception:
        # Deliberately swallowed: the registration above is an optional
        # optimization and must never prevent the fla kernels from being used.
        pass
    _HAS_FLA = True
except Exception:  # ImportError, missing Triton, no CUDA/ROCm, ...
    # Broad on purpose: any failure of the fla import is treated as
    # "fla unavailable" and routes to the bundled implementation instead.
    from ._hgrn2_fallback import chunk_gla as _chunk_gla
    from ._hgrn2_fallback import fused_recurrent_gla as _fused_recurrent_gla
    _HAS_FLA = False
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class ZCRMSNorm(nn.Module):
"""Zero-Centered RMSNorm.
Stored weight is initialized to 1.0; F.rms_norm sees a leaf parameter
directly. Mathematically equivalent to RMSNorm with `gamma = weight - 1`.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self._normalized_shape = (dim,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, self._normalized_shape, self.weight, self.eps)
class SwiGLU2(nn.Module):
    """Feed-forward block computing `w_down(silu(gate)**2 * up)`.

    One bias-free projection produces both the gate and the up activations;
    they are the two halves of its output along the last dimension.
    """

    def __init__(self, d_model: int, ffn_inner: int):
        super().__init__()
        # Fused gate+up projection, followed by the projection back to d_model.
        self.w_gate_up = nn.Linear(d_model, 2 * ffn_inner, bias=False)
        self.w_down = nn.Linear(ffn_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the squared-SiLU gated FFN to `x`."""
        projected = self.w_gate_up(x)
        gate, up = torch.chunk(projected, 2, dim=-1)
        gated = torch.square(F.silu(gate)) * up
        return self.w_down(gated)
def _apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate interleaved feature pairs of `x` using real arithmetic.

    Features at even and odd indices of the last dimension form the pairs
    that are rotated; the result keeps the interleaved layout of `x`.

    Args:
        x: [B, T, H, D] tensor.
        cos, sin: [1, T, 1, D/2] tensors, already shaped for broadcasting.
    """
    first, second = x[..., 0::2], x[..., 1::2]
    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos
    # Stack the pair members on a new trailing axis, then fold it back into D
    # so each rotated pair lands at its original even/odd positions.
    paired = torch.stack((rotated_first, rotated_second), dim=-1)
    return paired.flatten(-2)
class OdinNextAttention(nn.Module):
    """HGRN2 gated linear attention layer; RoPE on q/k is optional."""

    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 6,
        expand_ratio: Optional[int] = None,
        use_rope: bool = True,
    ):
        super().__init__()
        per_head_width = d_model // n_heads
        # Without an explicit ratio, q/k heads are as wide as value heads.
        ratio = per_head_width if expand_ratio is None else expand_ratio

        self.d_model = d_model
        self.n_heads = n_heads
        self.expand_ratio = ratio
        self.head_f_dim = ratio
        self.head_i_dim = per_head_width
        self.forget_dim = n_heads * ratio
        self.use_rope = use_rope

        self.q_proj = nn.Linear(d_model, self.forget_dim, bias=False)
        self.f_proj = nn.Linear(d_model, self.forget_dim, bias=False)
        self.i_proj = nn.Linear(d_model, d_model, bias=False)
        self.g_norm = ZCRMSNorm(d_model)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        recurrent_state: Optional[torch.Tensor] = None,
        output_state: bool = False,
        use_recurrent_kernel: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the attention layer over `x`.

        Args:
            x: [B, T, D] hidden states.
            cos, sin: RoPE tables; applied to q/k only when the layer was
                built with `use_rope` and `cos` is not None.
            recurrent_state: optional [B, H, K, V] HGRN2 state that seeds
                the scan.
            output_state: when True the final HGRN2 state is returned as the
                second tuple element, otherwise that element is None.
            use_recurrent_kernel: when True (single-token decode) the fused
                recurrent kernel is used instead of `chunk_gla`.

        Returns:
            `(output, state)` where `output` is [B, T, D].
        """
        batch, steps, width = x.shape
        heads = self.n_heads

        def split_heads(t: torch.Tensor, head_dim: int) -> torch.Tensor:
            return t.view(batch, steps, heads, head_dim)

        forget_logits = self.f_proj(x)
        # A single forget projection yields both the log-space decay (g) and
        # the key, which is sigmoid(-logits), i.e. 1 - sigmoid(logits).
        q = split_heads(F.silu(self.q_proj(x)), self.head_f_dim)
        k = split_heads(torch.sigmoid(-forget_logits), self.head_f_dim)
        g = split_heads(F.logsigmoid(forget_logits), self.head_f_dim)
        v = split_heads(self.i_proj(x), self.head_i_dim)

        rope_enabled = self.use_rope and cos is not None
        if rope_enabled:
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)

        if not use_recurrent_kernel:
            o, final_state = _chunk_gla(
                q=q, k=k, v=v, g=g,
                initial_state=recurrent_state,
                output_final_state=output_state,
            )
        else:
            # The recurrent kernel is always asked for its final state; it is
            # only handed back to the caller when `output_state` is set.
            o, final_state = _fused_recurrent_gla(
                q=q, k=k, v=v, gk=g,
                initial_state=recurrent_state,
                output_final_state=True,
            )

        merged = o.reshape(batch, steps, width)
        projected = self.o_proj(self.g_norm(merged))
        return projected, (final_state if output_state else None)
class OdinNextBlock(nn.Module):
    """Pre-norm residual block: gated attention followed by a gated FFN.

    Each residual branch is multiplied by `sigmoid()` of a learnable scalar
    (`gate_attn`, `gate_ffn`), both initialized to zero here. Per the
    original notes, the gates were absorbed and frozen at training time and
    their `sigmoid()` is ≈ 1 in this checkpoint; they remain in the
    state_dict for compatibility.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        ffn_inner: int,
        use_rope: bool = True,
    ):
        super().__init__()
        # Attention branch.
        self.pre_norm = ZCRMSNorm(d_model)
        self.attn = OdinNextAttention(
            d_model=d_model, n_heads=n_heads, use_rope=use_rope
        )
        # Feed-forward branch.
        self.ffn_norm = ZCRMSNorm(d_model)
        self.ffn = SwiGLU2(d_model, ffn_inner)
        # Scalar residual gates.
        self.gate_attn = nn.Parameter(torch.zeros(1))
        self.gate_ffn = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        recurrent_state: Optional[torch.Tensor] = None,
        output_state: bool = False,
        use_recurrent_kernel: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply both residual branches.

        All arguments other than `x` are forwarded unchanged to the
        attention layer. Returns `(hidden_states, state)` where `state` is
        whatever the attention layer returned as its second element.
        """
        normed = self.pre_norm(x)
        attn_out, new_state = self.attn(
            normed,
            cos=cos,
            sin=sin,
            recurrent_state=recurrent_state,
            output_state=output_state,
            use_recurrent_kernel=use_recurrent_kernel,
        )
        attn_scale = torch.sigmoid(self.gate_attn)
        x = x + attn_scale * attn_out

        # The FFN branch reads the stream after the attention update.
        ffn_scale = torch.sigmoid(self.gate_ffn)
        ffn_out = self.ffn(self.ffn_norm(x))
        x = x + ffn_scale * ffn_out
        return x, new_state
# ---------------------------------------------------------------------------
# OdinNext recurrent-state cache
# ---------------------------------------------------------------------------
class OdinNextCache:
    """Per-layer holder for HGRN2 recurrent states.

    Keeps one optional tensor per layer (each [B, H, K, V]) in `states` and
    offers the small list-like / cache-like surface that HuggingFace
    `generate()` expects of `past_key_values`. The number of slots is fixed
    at construction, so the container does not grow with sequence length.

    `seen_tokens` counts how many input positions have been consumed so
    far; OdinNext uses it as the RoPE position offset during decode.
    """

    def __init__(self, n_layers: int):
        self.n_layers = n_layers
        self.states: List[Optional[torch.Tensor]] = [None for _ in range(n_layers)]
        self.seen_tokens: int = 0

    # -- list-like access to the per-layer states ---------------------
    def __len__(self) -> int:
        return self.n_layers

    def __getitem__(self, idx: int) -> Optional[torch.Tensor]:
        return self.states[idx]

    def __setitem__(self, idx: int, value: Optional[torch.Tensor]) -> None:
        self.states[idx] = value

    def __iter__(self):
        return iter(self.states)

    # -- cache-like queries --------------------------------------------
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the consumed-token count; `layer_idx` is ignored."""
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Return None: HGRN2 has no hard cache length cap."""
        return None

    def update_seen(self, n_new_tokens: int) -> None:
        """Advance the consumed-token counter by `n_new_tokens`."""
        self.seen_tokens += n_new_tokens

    def to(self, device: torch.device) -> "OdinNextCache":
        """Move every stored state to `device` in place and return self."""
        # Slice assignment keeps the same list object; empty slots stay None.
        self.states[:] = [
            state if state is None else state.to(device)
            for state in self.states
        ]
        return self
# ---------------------------------------------------------------------------
# OdinNext PreTrainedModel: HF integration
# ---------------------------------------------------------------------------
class OdinNextPreTrainedModel(PreTrainedModel):
    """Shared base class carrying the HF class-level settings for OdinNext."""

    config_class = OdinNextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["OdinNextBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = False  # we use our own OdinNextCache

    def _init_weights(self, module: nn.Module) -> None:
        """Default initialization for modules constructed from scratch.

        Embeddings are drawn from a normal distribution whose std is the
        config's `initializer_range` (0.02 when the config lacks it).
        Linear weights get Xavier-uniform init and any bias is zeroed.
        Every other module type is left untouched.
        """
        std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            return

        if not isinstance(module, nn.Linear):
            return

        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
class OdinNextModel(OdinNextPreTrainedModel):
    """Backbone (no LM head): token embedding, block stack, final norm.

    Layers at even indices are built with RoPE enabled, layers at odd
    indices without it. The RoPE cos/sin tables are built lazily on the
    first forward pass and kept as plain attributes.
    """

    def __init__(self, config: OdinNextConfig):
        super().__init__(config)
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([
            OdinNextBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                ffn_inner=config.ffn_inner,
                use_rope=(i % 2 == 0),
            )
            for i in range(config.n_layers)
        ])
        self.final_norm = ZCRMSNorm(config.d_model)
        # RoPE caches are lazy-built on first forward. Storing them as
        # `register_buffer(..., persistent=False)` is incompatible with
        # `from_pretrained(low_cpu_mem_usage=True)`: HF builds the model on
        # the meta device and only materializes tensors that appear in the
        # checkpoint. Non-persistent buffers are NOT in the checkpoint and
        # so end up backed by uninitialized memory after meta -> real
        # transfer. We side-step this entirely by computing cos/sin on the
        # first forward, cached on the model object as plain attributes.
        self._cos_cache: Optional[torch.Tensor] = None
        self._sin_cache: Optional[torch.Tensor] = None
        # Skip _init_weights here — we expect to load weights from a
        # pretrained checkpoint immediately after construction.

    def get_input_embeddings(self) -> nn.Embedding:
        return self.tok_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.tok_embeddings = value

    # -----------------------------------------------------------------
    # Forward
    # -----------------------------------------------------------------
    def _ensure_rope_cache(self, target_device: torch.device) -> None:
        """Build the RoPE cos/sin caches on `target_device` if not already.

        The tables are (re)built when none exist yet or when the existing
        ones live on a different device. They cover `config.max_seq_len`
        positions and half of the per-head width (`d_model // n_heads`),
        computed in float32.

        Cached as plain Python attributes (not buffers) to avoid HF's
        `low_cpu_mem_usage=True` meta-device materialization issue with
        non-persistent buffers.
        """
        if self._cos_cache is not None and self._cos_cache.device == target_device:
            return
        head_f_dim = self.config.d_model // self.config.n_heads
        half_dim = head_f_dim // 2
        # One frequency per rotated pair: theta ** (-i / half_dim).
        exponents = (
            torch.arange(0, half_dim, dtype=torch.float32, device=target_device)
            / half_dim
        )
        freqs = 1.0 / (self.config.rope_theta ** exponents)
        positions = torch.arange(
            self.config.max_seq_len, dtype=torch.float32, device=target_device
        )
        angles = torch.outer(positions, freqs)
        self._cos_cache = angles.cos()
        self._sin_cache = angles.sin()

    def _rope_slice(
        self,
        seq_len: int,
        offset: int,
        target_dtype: torch.dtype,
        target_device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cos/sin for positions `[offset, offset + seq_len)`.

        Both tensors are cast to `target_dtype` and shaped [1, T, 1, D/2].

        Raises:
            ValueError: if the requested range runs past
                `config.max_seq_len`.
        """
        end = offset + seq_len
        if end > self.config.max_seq_len:
            raise ValueError(
                f"Position {end} exceeds max_seq_len={self.config.max_seq_len}. "
                "OdinNext was trained with a 2048-token RoPE cache."
            )
        self._ensure_rope_cache(target_device)
        cos = self._cos_cache[offset:end].to(dtype=target_dtype)
        sin = self._sin_cache[offset:end].to(dtype=target_dtype)
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, T, 1, D/2]
        sin = sin.unsqueeze(0).unsqueeze(2)
        return cos, sin

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[OdinNextCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **_unused,
    ) -> Tuple[torch.Tensor, Optional[OdinNextCache]]:
        """Backbone forward.

        Returns `(hidden_states, past_key_values)`. The LM-head wrapper
        (`OdinNextForCausalLM`) projects to logits.

        `past_key_values` that is not an `OdinNextCache` (e.g. a cache
        object or legacy tuple supplied by HF generate) is ignored and the
        call is treated as a fresh sequence starting at position 0. When
        `use_cache` is false a supplied cache is only read: neither its
        states nor its `seen_tokens` counter are modified.

        `output_hidden_states` and `return_dict` are accepted but not used;
        the return value is always a tuple.

        Note: `attention_mask` is accepted for HF API compatibility but is
        NOT used. HGRN2 is causal by construction (the recurrence is strictly
        forward-in-time) and cannot honor a left-padded mask. For correct
        results with batched generation, callers must right-pad and ensure
        all sequences in a batch have valid tokens at every position they
        process. Single-sequence generation is unaffected.

        Raises:
            ValueError: if the positions to process exceed
                `config.max_seq_len`.
        """
        if use_cache is None:
            use_cache = self.config.use_cache
        _, seq_len = input_ids.shape

        # Coerce past_key_values to our expected type BEFORE anything reads
        # it. HF generate may try to auto-instantiate a DynamicCache or pass
        # a legacy tuple; we want strict OdinNextCache or None. Doing this
        # first keeps the RoPE offset and the kernel choice consistent with
        # the state that is actually used, and avoids touching
        # `.seen_tokens` on an object that may not have it.
        if not isinstance(past_key_values, OdinNextCache):
            past_key_values = None

        # Single-token decode mode requires a caller-supplied cache; a cache
        # created below for a first call does not switch kernels.
        single_step = (seq_len == 1) and (past_key_values is not None)

        # RoPE position offset.
        offset = past_key_values.seen_tokens if past_key_values is not None else 0

        h = self.tok_embeddings(input_ids)

        # Prepare RoPE caches in the embedding's dtype.
        cos, sin = self._rope_slice(
            seq_len=seq_len, offset=offset,
            target_dtype=h.dtype, target_device=h.device,
        )

        if past_key_values is None and use_cache:
            past_key_values = OdinNextCache(self.config.n_layers)

        # States are written back only when caching was requested.
        write_back = bool(use_cache) and past_key_values is not None

        for i, layer in enumerate(self.layers):
            prev_state = past_key_values[i] if past_key_values is not None else None
            h, new_state = layer(
                h,
                cos=cos, sin=sin,
                recurrent_state=prev_state,
                output_state=use_cache,
                use_recurrent_kernel=single_step,
            )
            if write_back:
                past_key_values[i] = new_state

        h = self.final_norm(h)

        # Advance the position counter only together with the states, so a
        # read-only pass cannot leave the counter ahead of stale states.
        if write_back:
            past_key_values.update_seen(seq_len)
        return h, past_key_values
class OdinNextForCausalLM(OdinNextPreTrainedModel, GenerationMixin):
    """Top-level wrapper exposing logits + HF generate().

    `GenerationMixin` is inherited explicitly so that `generate()` is
    available regardless of whether the installed `transformers` still
    mixes it into `PreTrainedModel`.
    """

    # Map tied output -> source. Newer `transformers` (>=4.45) expects a
    # dict; older versions tolerate (and used) a list of keys. Provide the
    # dict form which is forward-compatible.
    # NOTE(review): the version boundary above is carried over from the
    # original comment — confirm against the transformers release in use.
    _tied_weights_keys = {"lm_head.weight": "model.tok_embeddings.weight"}

    def __init__(self, config: OdinNextConfig):
        super().__init__(config)
        self.model = OdinNextModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            # Share one parameter between the input embedding and the head.
            self.lm_head.weight = self.model.tok_embeddings.weight
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.tok_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.tok_embeddings = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[OdinNextCache] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **_unused,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, project to logits and optionally compute loss.

        Args:
            input_ids: token ids, passed to the backbone.
            attention_mask: forwarded to the backbone.
            past_key_values: recurrent-state cache, forwarded to the backbone.
            labels: if given, a next-token cross-entropy loss is computed
                by shifting logits and labels by one position; label value
                -100 is ignored.
            use_cache: forwarded to the backbone.
            output_hidden_states: accepted but unused; `hidden_states` in
                the output is always None.
            return_dict: when None, defaults to True.

        Returns:
            A `CausalLMOutputWithPast`, or when `return_dict` is false a
            tuple `([loss,] logits[, past_key_values])` in which the
            optional entries appear only when they are not None.
        """
        return_dict = return_dict if return_dict is not None else True
        hidden_states, past_key_values = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # The loss is evaluated in float32 regardless of the model dtype.
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1).long(),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + ((past_key_values,) if past_key_values is not None else ())
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    # -----------------------------------------------------------------
    # generate() integration
    # -----------------------------------------------------------------
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[OdinNextCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = True,
        **kwargs,
    ) -> dict:
        """Trim input_ids to only the new positions when a cache exists.

        After the first forward, the recurrent state already encodes the
        prompt. Subsequent calls only need to pass the most recently
        generated token.

        A `past_key_values` that is not an `OdinNextCache` is replaced by
        None and `input_ids` is passed through untrimmed, matching how the
        backbone treats such objects (it starts from an empty state).
        Extra keyword arguments are accepted and dropped.
        """
        if not isinstance(past_key_values, OdinNextCache):
            # Covers None as well as foreign cache objects / legacy tuples,
            # which do not necessarily carry a `seen_tokens` attribute.
            past_key_values = None
        elif past_key_values.seen_tokens > 0:
            # New tokens since last call.
            new_count = input_ids.shape[1] - past_key_values.seen_tokens
            if new_count <= 0:
                # generate() can occasionally call us with the same length
                # twice (e.g., assistant-decoding paths). Default to feeding
                # the last token only.
                input_ids = input_ids[:, -1:]
            else:
                input_ids = input_ids[:, -new_count:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(
        self, past_key_values: OdinNextCache, beam_idx: torch.Tensor
    ) -> OdinNextCache:
        """Beam-search support: reorder per-layer states along the batch axis.

        The cache is modified in place and returned; empty slots stay None.
        """
        for i, state in enumerate(past_key_values.states):
            if state is not None:
                past_key_values.states[i] = state.index_select(0, beam_idx.to(state.device))
        return past_key_values

    @staticmethod
    def _supports_default_dynamic_cache() -> bool:
        """Report that the default dynamic cache is not supported."""
        return False
# Names exported by a star-import of this module. `OdinNextConfig` is
# imported from `.configuration_odinnext` at the top of the file and is
# re-exported here for convenience alongside the classes defined in this file.
__all__ = [
    "OdinNextConfig",
    "OdinNextModel",
    "OdinNextForCausalLM",
    "OdinNextCache",
]