# coding=utf-8
# Copyright 2026 The OdinNext authors.
# Licensed under the Apache License, Version 2.0.
"""Pure-PyTorch HGRN2 recurrence — slow fallback when flash-linear-attention
(`fla`) is unavailable.

The `fla` library provides Triton/CUDA kernels for `chunk_gla` (chunk-wise
parallel scan over T) and `fused_recurrent_gla` (token-by-token serial scan).
On platforms without those kernels (CPU, non-CUDA/non-ROCm GPUs) we provide a
reference implementation here.

Speed: ~10-30x slower than `fla` at training shapes; comparable for
single-token decode (since both are serial).

Numerical match: intended to agree with `fla` up to floating-point rounding
in fp32 and within fp16 noise on fp16, provided both sides use the same
`scale` (see the `scale` note in `chunk_gla`).

The recurrence (per head):

    S_t = diag(exp(g_t)) @ S_{t-1} + k_t.unsqueeze(-1) @ v_t.unsqueeze(-2)
    o_t = (scale * q_t) @ S_t

Shapes (matching `fla.ops.gla.chunk_gla`):

    q: [B, T, H, K]   (K = head_f_dim, e.g. 128)
    k: [B, T, H, K]
    g: [B, T, H, K]   (already in log-space, expected to be <= 0)
    v: [B, T, H, V]   (V = head_i_dim, e.g. 128)
    -> o: [B, T, H, V]
       final_state: [B, H, K, V] if output_final_state else None
"""

from typing import Optional, Tuple

import torch


def chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    *,
    scale: Optional[float] = None,
    **_unused,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Pure-PyTorch chunk_gla replacement.

    Implements a serial (token-by-token) scan. Internals are promoted to fp32
    to keep the cumulative product of decays numerically sane over long T.

    Args:
        q: Queries, [B, T, H, K].
        k: Keys, [B, T, H, K].
        v: Values, [B, T, H, V].
        g: Log-space decay gates, [B, T, H, K].
        initial_state: Optional starting state, [B, H, K, V]. It is cast to
            fp32 and moved to `q`'s device; zeros are used when omitted.
        output_final_state: When True, also return the fp32 state after the
            last token.
        scale: Optional multiplier applied to `q` before the read-out. When
            None, no scaling is applied (the historical behavior of this
            fallback).
            NOTE(review): upstream `fla` is understood to default `scale` to
            K ** -0.5 when it is None — confirm which convention callers
            expect and pass `scale` explicitly if parity is required.
        **_unused: Any other keyword arguments accepted by the `fla` kernels
            (e.g. `cu_seqlens`, `head_first`) are accepted and ignored.

    Returns:
        `(o, final_state)` where `o` is [B, T, H, V] in `q`'s dtype and
        `final_state` is the fp32 [B, H, K, V] state, or None when
        `output_final_state` is False.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    device = q.device
    in_dtype = q.dtype

    # Promote scan internals to fp32 for stability.
    q32 = q.float()
    if scale is not None:
        # Previously `scale` was swallowed by **_unused and silently ignored.
        q32 = q32 * scale
    k32 = k.float()
    v32 = v.float()
    # One vectorized exp over all timesteps instead of one per loop iteration.
    decay_all = g.float().exp()

    if initial_state is None:
        S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
    else:
        # Also move to q's device so a state cached elsewhere still works.
        S = initial_state.to(device=device, dtype=torch.float32)

    out = torch.empty(B, T, H, V, device=device, dtype=torch.float32)

    # Serial scan. exp(g_t) decays the state element-wise along K;
    # k_t outer v_t -> [B, H, K, V] additive update.
    for t in range(T):
        decay = decay_all[:, t].unsqueeze(-1)  # [B, H, K, 1]
        S = decay * S + k32[:, t].unsqueeze(-1) * v32[:, t].unsqueeze(-2)
        # o_t = q_t (1xK) @ S (KxV) per head
        out[:, t] = (q32[:, t].unsqueeze(-2) @ S).squeeze(-2)  # [B, H, V]

    out = out.to(in_dtype)
    if output_final_state:
        return out, S
    return out, None


def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    *,
    scale: Optional[float] = None,
    **_unused,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Pure-PyTorch single-token (or short-T) recurrence.

    `fla.ops.gla.fused_recurrent_gla` is what OdinNext.generate uses for O(1)
    per-token decode. The signature matches: `gk` = log-decay (instead of
    `g`). This delegates to `chunk_gla` — it is mathematically the same scan,
    just packaged with different defaults (`output_final_state=True` here).

    Args:
        q: Queries, [B, T, H, K].
        k: Keys, [B, T, H, K].
        v: Values, [B, T, H, V].
        gk: Log-space decay gates, [B, T, H, K].
        initial_state: Optional starting state, [B, H, K, V].
        output_final_state: When True (default), also return the final state.
        scale: Optional multiplier applied to `q`; forwarded to `chunk_gla`.
        **_unused: Other `fla` keyword arguments are accepted and ignored.

    Returns:
        The `(o, final_state)` pair produced by `chunk_gla`.
    """
    return chunk_gla(
        q=q,
        k=k,
        v=v,
        g=gk,
        initial_state=initial_state,
        output_final_state=output_final_state,
        scale=scale,
    )