OdinNext-138M-Early-Checkpoint / _hgrn2_fallback.py
joelhenwang's picture
Initial release: 6.84B token early checkpoint (EMA weights)
cb8708f verified
# coding=utf-8
# Copyright 2026 The OdinNext authors.
# Licensed under the Apache License, Version 2.0.
"""Pure-PyTorch HGRN2 recurrence — slow fallback when flash-linear-attention
(`fla`) is unavailable.
The `fla` library provides Triton/CUDA kernels for `chunk_gla` (chunk-wise
parallel scan over T) and `fused_recurrent_gla` (token-by-token serial scan).
On platforms without those kernels (CPU, non-CUDA/non-ROCm GPUs) we provide
a reference implementation here.
Speed: ~10-30x slower than `fla` at training shapes; comparable for
single-token decode (since both are serial). Numerical match: bitwise on
fp32, within fp16 noise on fp16.
The recurrence (per head):
S_t = diag(exp(g_t)) @ S_{t-1} + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2)
o_t = q_t @ S_t
Shapes (matching `fla.ops.gla.chunk_gla`):
q: [B, T, H, K] (K = head_f_dim, e.g. 128)
k: [B, T, H, K]
g: [B, T, H, K] (already in log-space, expected to be <= 0)
v: [B, T, H, V] (V = head_i_dim, e.g. 128)
-> o: [B, T, H, V]
final_state: [B, H, K, V] if output_final_state else None
"""
from typing import Optional, Tuple
import torch
def chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: Optional[float] = None,
    **_unused,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the gated-linear-attention recurrence as a serial PyTorch scan.

    Fallback for ``fla.ops.gla.chunk_gla``. For each time step ``t`` the
    per-head state is decayed, updated with an outer product and read out::

        S_t = exp(g_t)[..., None] * S_{t-1} + k_t[..., None] * v_t[..., None, :]
        o_t = (scale * q_t) @ S_t

    All arithmetic is done in fp32 so the repeated decay multiplication stays
    numerically sane over long sequences; only the output is cast back to
    the dtype of ``q``.

    Args:
        q: Queries, ``[B, T, H, K]``.
        k: Keys, ``[B, T, H, K]``.
        v: Values, ``[B, T, H, V]``.
        g: Log-space decay, ``[B, T, H, K]``; ``exp(g)`` scales the state
            along ``K``.
        initial_state: Optional starting state ``[B, H, K, V]``. Zeros are
            used when ``None``. The tensor is copied to fp32 on ``q``'s
            device, so the returned state never aliases the caller's tensor
            (this matters when ``T == 0``).
        output_final_state: When true, also return the state after the last
            step.
        scale: Optional multiplier applied to ``q`` before the read-out.
            ``None`` applies no scaling, which is what this fallback has
            always done. NOTE(review): the upstream ``fla`` kernels are
            believed to default to ``K ** -0.5`` when ``scale`` is omitted;
            confirm callers pass ``scale`` explicitly where parity matters.
        **_unused: Any other keyword arguments accepted by the ``fla``
            kernels; they are ignored here.

    Returns:
        ``(o, final_state)`` where ``o`` is ``[B, T, H, V]`` in ``q``'s dtype
        and ``final_state`` is the fp32 state after the last step, or
        ``None`` when ``output_final_state`` is false.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device
    in_dtype = q.dtype

    # Promote scan internals to fp32 for stability.
    q32 = q.float()
    if scale is not None:
        q32 = q32 * scale
    k32 = k.float()
    v32 = v.float()
    # One vectorised exp over the whole sequence instead of one per step;
    # the trailing axis lets it broadcast over V.  [B, T, H, K, 1]
    decay = g.float().exp().unsqueeze(-1)

    if initial_state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        # copy=True: without it an fp32 state on the right device would be
        # returned as-is when T == 0, aliasing the caller's tensor.
        S = initial_state.to(device=device, dtype=torch.float32, copy=True)

    out = torch.empty(B, T, H, V, device=device, dtype=torch.float32)

    # Serial scan over time; indexing by t (rather than zip) keeps the
    # IndexError if k / v / g are shorter than q along T.
    for t in range(T):
        # Decay the state along K, then add the outer product k_t x v_t.
        S = decay[:, t] * S + k32[:, t].unsqueeze(-1) * v32[:, t].unsqueeze(-2)
        # o_t = q_t (1xK) @ S (KxV) per head -> [B, H, V]
        out[:, t] = (q32[:, t].unsqueeze(-2) @ S).squeeze(-2)

    out = out.to(in_dtype)
    return out, (S if output_final_state else None)
def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    scale: Optional[float] = None,
    **_unused,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Pure-PyTorch stand-in for ``fla.ops.gla.fused_recurrent_gla``.

    Intended for single-token (or short-T) decode. The recurrence is the
    same serial scan that ``chunk_gla`` in this module performs, so this is
    a thin wrapper that renames the log-decay argument (``gk`` here, ``g``
    there) and defaults ``output_final_state`` to ``True``.

    Args:
        q: Queries, passed through as ``q``.
        k: Keys, passed through as ``k``.
        v: Values, passed through as ``v``.
        gk: Log-space decay, passed to ``chunk_gla`` as ``g``.
        initial_state: Optional starting recurrent state, passed through.
        output_final_state: Whether to return the state after the last step.
        scale: Optional query multiplier, forwarded to ``chunk_gla``
            unchanged instead of being dropped with the ignored keywords.
        **_unused: Any other keyword arguments accepted by the ``fla``
            kernel; they are ignored and not forwarded.

    Returns:
        Whatever ``chunk_gla`` returns: ``(output, final_state_or_None)``.
    """
    return chunk_gla(
        q=q,
        k=k,
        v=v,
        g=gk,
        initial_state=initial_state,
        output_final_state=output_final_state,
        scale=scale,
    )