Text Generation
Transformers
Safetensors
English
odinnext
hgrn2
linear-attention
recurrent
causal-lm
custom_code
early-checkpoint
fp16
amd
rocm
Instructions to use joelhenwang/OdinNext-138M-Early-Checkpoint with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use joelhenwang/OdinNext-138M-Early-Checkpoint with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="joelhenwang/OdinNext-138M-Early-Checkpoint", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("joelhenwang/OdinNext-138M-Early-Checkpoint", trust_remote_code=True, dtype="auto")
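`generate()` drives the model through `OdinNextForCausalLM.prepare_inputs_for_generation` (in the source listing). That method feeds in only the tokens the recurrent cache has not yet consumed, because each layer's HGRN2 state already summarizes everything before them. The dependency-free sketch below mirrors that slicing rule on plain Python lists. `tokens_to_feed` is a hypothetical name, and the real method operates on `[B, T]` tensors.

```python
from typing import List


def tokens_to_feed(input_ids: List[int], seen_tokens: int) -> List[int]:
    """Hypothetical mirror of the trimming rule in prepare_inputs_for_generation.

    input_ids: the full sequence generate() has accumulated so far.
    seen_tokens: how many positions the recurrent cache has already consumed.
    """
    if seen_tokens <= 0:
        return input_ids  # no cache yet: the whole prompt is processed
    new_count = len(input_ids) - seen_tokens
    if new_count <= 0:
        return input_ids[-1:]  # same length seen twice: fall back to the last token
    return input_ids[-new_count:]  # only the positions the cache has not seen


# Simulated greedy loop: the cache counter advances by however many tokens were fed.
sequence = [17, 4, 255]  # toy prompt token ids
seen = 0
fed_per_step = []
for next_token in [9, 31]:  # toy "sampled" token ids
    fed = tokens_to_feed(sequence, seen)
    fed_per_step.append(fed)
    seen += len(fed)
    sequence = sequence + [next_token]
```

The RoPE position offset is read from the same `seen_tokens` counter. Prompt length plus generated tokens must therefore stay within `max_seq_len`; otherwise `_rope_slice` raises `ValueError`.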
- Local Apps
- vLLM
How to use joelhenwang/OdinNext-138M-Early-Checkpoint with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server (trusting the repository's custom modeling code):
vllm serve "joelhenwang/OdinNext-138M-Early-Checkpoint" --trust-remote-code

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "joelhenwang/OdinNext-138M-Early-Checkpoint",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

Use Docker
docker model run hf.co/joelhenwang/OdinNext-138M-Early-Checkpoint
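The completion request shown in the curl call can also be sent from Python using only the standard library. This sketch assumes a server is already running on `localhost:8000` as started with `vllm serve`; the SGLang server listens on port 30000 instead. `build_completion_request` is a hypothetical helper name.

```python
import json
import urllib.request


def build_completion_request(
    base_url: str,
    model: str,
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.5,
) -> urllib.request.Request:
    """Build a POST to an OpenAI-compatible /v1/completions endpoint."""
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_completion_request(
    "http://localhost:8000",
    "joelhenwang/OdinNext-138M-Early-Checkpoint",
    "Once upon a time,",
)
# Sending the request needs a running server:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["choices"][0]["text"])
```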
- SGLang
How to use joelhenwang/OdinNext-138M-Early-Checkpoint with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server (trusting the repository's custom modeling code):
python3 -m sglang.launch_server \
    --model-path "joelhenwang/OdinNext-138M-Early-Checkpoint" \
    --trust-remote-code \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "joelhenwang/OdinNext-138M-Early-Checkpoint",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "joelhenwang/OdinNext-138M-Early-Checkpoint" \
        --trust-remote-code \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "joelhenwang/OdinNext-138M-Early-Checkpoint",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

- Docker Model Runner
How to use joelhenwang/OdinNext-138M-Early-Checkpoint with Docker Model Runner:
docker model run hf.co/joelhenwang/OdinNext-138M-Early-Checkpoint
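The "138M" in the checkpoint name can be checked against the architecture summary in the module docstring of the source that follows. This sketch tallies parameters from the module definitions:

- the q/f/i/o projections and `g_norm` in each attention layer;
- the SwiGLU2 FFN;
- two norms and two scalar gates per block;
- a final norm and a tied LM head.

It assumes the config matches the docstring (16 layers, d=768, 6 heads, ffn=2048, vocab=32768) and uses the default `expand_ratio = d_model // n_heads`. `configuration_odinnext.py` itself is not shown here, and `odinnext_param_count` is a hypothetical helper.

```python
def odinnext_param_count(
    n_layers: int = 16,
    d_model: int = 768,
    n_heads: int = 6,
    ffn_inner: int = 2048,
    vocab_size: int = 32768,
    tie_embeddings: bool = True,
) -> int:
    """Count parameters implied by the OdinNext module definitions (bias-free Linears)."""
    head_f_dim = d_model // n_heads  # default expand_ratio
    forget_dim = n_heads * head_f_dim
    attn = (
        d_model * forget_dim  # q_proj
        + d_model * forget_dim  # f_proj
        + d_model * d_model  # i_proj
        + d_model  # g_norm weight
        + d_model * d_model  # o_proj
    )
    ffn = d_model * 2 * ffn_inner + ffn_inner * d_model  # w_gate_up + w_down
    block = attn + ffn + 2 * d_model + 2  # pre_norm, ffn_norm, gate_attn, gate_ffn
    embeddings = vocab_size * d_model
    lm_head = 0 if tie_embeddings else vocab_size * d_model
    return embeddings + n_layers * block + d_model + lm_head  # + final_norm


total = odinnext_param_count()
print(f"{total:,}")  # 138,449,696
```

If those assumptions hold, the result should be comparable to `sum(p.numel() for p in model.parameters())` on the loaded model, which counts the shared embedding/LM-head weight once. Untying the head would add another 32768 × 768 = 25,165,824 parameters.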
# coding=utf-8
# Copyright 2026 The OdinNext authors.
# Licensed under the Apache License, Version 2.0.
"""OdinNext: 138M HGRN2+RoPE hybrid causal language model.

This is a self-contained HuggingFace `trust_remote_code=True` port of the
production OdinNext model used to train the 6.84B-token early checkpoint.
The training-time machinery (DiffusionBlocks, TST, gate-absorption,
torch.compile zone helpers) is dropped; only the inference path remains.

Architecture summary:
  * 16 layers, d=768, 6 heads, ffn=2048, vocab=32768.
  * Even layers (0,2,...,14) get RoPE on q/k.
  * Odd layers (1,3,...,15) are position-free recurrent.
  * SwiGLU2 FFN: silu(gate)^2 * up.
  * ZCRMSNorm normalization, gated residuals (frozen at training time).
  * Tied input/output embeddings.
  * HGRN2 recurrence: O(T) train, O(1) per-token decode.

Hardware notes:
  * Uses `flash-linear-attention` (`fla`) Triton kernels when available.
    Falls back to a pure-PyTorch implementation (~10-30x slower) otherwise,
    so the model loads on any backend including CPU.
  * Trained in fp16 on AMD Strix Halo (gfx1151, RDNA 3.5, ROCm 7.13).
    fp16 is the recommended inference dtype. bf16 was never validated on
    this checkpoint.
"""
from __future__ import annotations

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_odinnext import OdinNextConfig


# ---------------------------------------------------------------------------
# HGRN2 kernel: prefer flash-linear-attention, fall back to pure PyTorch
# ---------------------------------------------------------------------------
try:
    from fla.ops.gla import chunk_gla as _chunk_gla
    from fla.ops.gla import fused_recurrent_gla as _fused_recurrent_gla

    # `fla.ops.gla.chunk.ChunkGLAFunction` is decorated with
    # @torch.compiler.disable. Marking it allow_in_graph lets Dynamo treat
    # it as an opaque leaf op, preventing graph breaks if the user does
    # `torch.compile(model)`. Best-effort, ignored if internals shift.
    try:
        from fla.ops.gla.chunk import ChunkGLAFunction

        torch.compiler.allow_in_graph(ChunkGLAFunction)
    except Exception:
        pass
    _HAS_FLA = True
except Exception:  # ImportError, missing Triton, no CUDA/ROCm, ...
    from ._hgrn2_fallback import chunk_gla as _chunk_gla
    from ._hgrn2_fallback import fused_recurrent_gla as _fused_recurrent_gla

    _HAS_FLA = False


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class ZCRMSNorm(nn.Module):
    """Zero-Centered RMSNorm.

    The stored weight is initialized to 1.0 and passed straight to
    F.rms_norm as a leaf parameter, so the effective scale is `weight`
    itself. This is equivalent to the zero-centered form
    `scale = 1 + gamma` with `gamma = weight - 1`.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self._normalized_shape = (dim,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.rms_norm(x, self._normalized_shape, self.weight, self.eps)


class SwiGLU2(nn.Module):
    """SwiGLU squared FFN: silu(gate)^2 * up -> down."""

    def __init__(self, d_model: int, ffn_inner: int):
        super().__init__()
        self.w_gate_up = nn.Linear(d_model, 2 * ffn_inner, bias=False)
        self.w_down = nn.Linear(ffn_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.w_gate_up(x).chunk(2, dim=-1)
        return self.w_down(F.silu(gate).square() * up)


def _apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply RoPE to x[B,T,H,D] using real arithmetic.

    Rotates interleaved channel pairs (x[..., 2i], x[..., 2i+1]).
    cos/sin: [1, T, 1, D/2] pre-broadcast.
    """
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    out_even = x_even * cos - x_odd * sin
    out_odd = x_even * sin + x_odd * cos
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2)


class OdinNextAttention(nn.Module):
    """HGRN2 attention with optional RoPE on q/k."""

    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 6,
        expand_ratio: Optional[int] = None,
        use_rope: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        if expand_ratio is None:
            expand_ratio = d_model // n_heads
        self.expand_ratio = expand_ratio
        self.head_f_dim = expand_ratio
        self.head_i_dim = d_model // n_heads
        self.forget_dim = n_heads * expand_ratio
        self.use_rope = use_rope
        self.q_proj = nn.Linear(d_model, self.forget_dim, bias=False)
        self.f_proj = nn.Linear(d_model, self.forget_dim, bias=False)
        self.i_proj = nn.Linear(d_model, d_model, bias=False)
        self.g_norm = ZCRMSNorm(d_model)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        recurrent_state: Optional[torch.Tensor] = None,
        output_state: bool = False,
        use_recurrent_kernel: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            x: [B, T, D] hidden states.
            cos, sin: RoPE caches if `use_rope`, else ignored.
            recurrent_state: optional [B, H, K, V] HGRN2 state to seed the scan.
            output_state: if True, return the final HGRN2 state alongside output.
            use_recurrent_kernel: if True (single-token decode), call the
                fused recurrent kernel; otherwise call chunk_gla.
        """
        B, T, D = x.shape
        q = F.silu(self.q_proj(x))
        forget_logits = self.f_proj(x)
        # HGRN2 parameterization: log forget gate g = log(sigmoid(f)), and the
        # key is tied to it as k = sigmoid(-f) = 1 - sigmoid(f).
        g = F.logsigmoid(forget_logits)
        k = torch.sigmoid(-forget_logits)
        v = self.i_proj(x)
        q = q.view(B, T, self.n_heads, self.head_f_dim)
        k = k.view(B, T, self.n_heads, self.head_f_dim)
        g = g.view(B, T, self.n_heads, self.head_f_dim)
        v = v.view(B, T, self.n_heads, self.head_i_dim)
        if self.use_rope and cos is not None:
            q = _apply_rope(q, cos, sin)
            k = _apply_rope(k, cos, sin)
        if use_recurrent_kernel:
            o, final_state = _fused_recurrent_gla(
                q=q, k=k, v=v, gk=g,
                initial_state=recurrent_state,
                output_final_state=True,
            )
        else:
            o, final_state = _chunk_gla(
                q=q, k=k, v=v, g=g,
                initial_state=recurrent_state,
                output_final_state=output_state,
            )
        o = o.reshape(B, T, D)
        o = self.g_norm(o)
        o = self.o_proj(o)
        if output_state:
            return o, final_state
        return o, None


class OdinNextBlock(nn.Module):
    """Pre-norm block with gated residuals.

    Gates were absorbed and frozen at training time: `gate_attn` and
    `gate_ffn` are stored as scalars whose `sigmoid()` ≈ 1 by the time of
    this checkpoint. They remain in the state_dict for compatibility.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        ffn_inner: int,
        use_rope: bool = True,
    ):
        super().__init__()
        self.pre_norm = ZCRMSNorm(d_model)
        self.attn = OdinNextAttention(
            d_model=d_model, n_heads=n_heads, use_rope=use_rope
        )
        self.ffn_norm = ZCRMSNorm(d_model)
        self.ffn = SwiGLU2(d_model, ffn_inner)
        self.gate_attn = nn.Parameter(torch.zeros(1))
        self.gate_ffn = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        x: torch.Tensor,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        recurrent_state: Optional[torch.Tensor] = None,
        output_state: bool = False,
        use_recurrent_kernel: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        attn_out, new_state = self.attn(
            self.pre_norm(x),
            cos=cos, sin=sin,
            recurrent_state=recurrent_state,
            output_state=output_state,
            use_recurrent_kernel=use_recurrent_kernel,
        )
        x = x + torch.sigmoid(self.gate_attn) * attn_out
        x = x + torch.sigmoid(self.gate_ffn) * self.ffn(self.ffn_norm(x))
        return x, new_state


# ---------------------------------------------------------------------------
# OdinNext recurrent-state cache
# ---------------------------------------------------------------------------
class OdinNextCache:
    """Container for HGRN2 recurrent states across all layers.

    Wraps `List[Optional[Tensor]]` (one per layer, each [B, H, K, V]) with
    just enough surface to satisfy HuggingFace `generate()`'s expectations
    for `past_key_values`. Importantly, the cache size is independent of T:
    it is the per-layer hidden-state matrix S, not a growing K/V tape.

    Also tracks `seen_tokens`, the number of input positions the cache has
    consumed so far, which OdinNext uses to look up the correct RoPE
    position offset during decode.
    """

    def __init__(self, n_layers: int):
        self.n_layers = n_layers
        self.states: List[Optional[torch.Tensor]] = [None] * n_layers
        self.seen_tokens: int = 0

    def __len__(self) -> int:
        return self.n_layers

    def __getitem__(self, idx: int) -> Optional[torch.Tensor]:
        return self.states[idx]

    def __setitem__(self, idx: int, value: Optional[torch.Tensor]) -> None:
        self.states[idx] = value

    def __iter__(self):
        return iter(self.states)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        return None  # HGRN2 has no hard cache length cap

    def update_seen(self, n_new_tokens: int) -> None:
        self.seen_tokens += n_new_tokens

    def to(self, device: torch.device) -> "OdinNextCache":
        for i, s in enumerate(self.states):
            if s is not None:
                self.states[i] = s.to(device)
        return self


# ---------------------------------------------------------------------------
# OdinNext PreTrainedModel: HF integration
# ---------------------------------------------------------------------------
class OdinNextPreTrainedModel(PreTrainedModel):
    """Base class wiring up HF infrastructure for OdinNext."""

    config_class = OdinNextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["OdinNextBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = False  # we use our own OdinNextCache

    def _init_weights(self, module: nn.Module) -> None:
        """Conservative init: at inference we only need to define defaults
        in case someone constructs an OdinNext from scratch.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)


class OdinNextModel(OdinNextPreTrainedModel):
    """Backbone (no LM head)."""

    def __init__(self, config: OdinNextConfig):
        super().__init__(config)
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([
            OdinNextBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                ffn_inner=config.ffn_inner,
                use_rope=(i % 2 == 0),
            )
            for i in range(config.n_layers)
        ])
        self.final_norm = ZCRMSNorm(config.d_model)
        # RoPE caches are lazy-built on first forward. Storing them as
        # `register_buffer(..., persistent=False)` is incompatible with
        # `from_pretrained(low_cpu_mem_usage=True)`: HF builds the model on
        # the meta device and only materializes tensors that appear in the
        # checkpoint. Non-persistent buffers are NOT in the checkpoint and
        # so end up backed by uninitialized memory after meta -> real
        # transfer. We side-step this entirely by computing cos/sin on the
        # first forward, cached on the model object as plain attributes.
        self._cos_cache: Optional[torch.Tensor] = None
        self._sin_cache: Optional[torch.Tensor] = None
        # Skip _init_weights here: we expect to load weights from a
        # pretrained checkpoint immediately after construction.

    def get_input_embeddings(self) -> nn.Embedding:
        return self.tok_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.tok_embeddings = value

    # -----------------------------------------------------------------
    # Forward
    # -----------------------------------------------------------------
    def _ensure_rope_cache(self, target_device: torch.device) -> None:
        """Build the RoPE cos/sin caches on `target_device` if not already.

        Cached as plain Python attributes (not buffers) to avoid HF's
        `low_cpu_mem_usage=True` meta-device materialization issue with
        non-persistent buffers.
        """
        need_build = (
            self._cos_cache is None
            or self._cos_cache.device != target_device
        )
        if not need_build:
            return
        head_f_dim = self.config.d_model // self.config.n_heads
        half_dim = head_f_dim // 2
        freqs = 1.0 / (
            self.config.rope_theta
            ** (
                torch.arange(0, half_dim, dtype=torch.float32, device=target_device)
                / half_dim
            )
        )
        t = torch.arange(self.config.max_seq_len, dtype=torch.float32, device=target_device)
        angles = torch.outer(t, freqs)
        self._cos_cache = angles.cos()
        self._sin_cache = angles.sin()

    def _rope_slice(
        self,
        seq_len: int,
        offset: int,
        target_dtype: torch.dtype,
        target_device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        end = offset + seq_len
        if end > self.config.max_seq_len:
            raise ValueError(
                f"Position {end} exceeds max_seq_len={self.config.max_seq_len}. "
                "OdinNext was trained with a 2048-token RoPE cache."
            )
        self._ensure_rope_cache(target_device)
        cos = self._cos_cache[offset:end].to(dtype=target_dtype)
        sin = self._sin_cache[offset:end].to(dtype=target_dtype)
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, T, 1, D/2]
        sin = sin.unsqueeze(0).unsqueeze(2)
        return cos, sin

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[OdinNextCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **_unused,
    ) -> Tuple[torch.Tensor, Optional[OdinNextCache]]:
        """Backbone forward.

        Returns `(hidden_states, past_key_values)`. The LM-head wrapper
        (`OdinNextForCausalLM`) projects to logits.

        Note: `attention_mask` is accepted for HF API compatibility but is
        NOT used. HGRN2 is causal by construction (the recurrence is strictly
        forward-in-time) and cannot honor a left-padded mask. For batched
        scoring, right-pad: trailing pad tokens cannot affect earlier
        positions. For batched generation, every sequence must have a valid
        token at every position it processes, which in practice means all
        prompts in the batch have the same length. Single-sequence
        generation is unaffected.
        """
        if use_cache is None:
            use_cache = self.config.use_cache
        # Coerce past_key_values to our expected type. HF generate may
        # try to auto-instantiate a DynamicCache or pass a legacy tuple;
        # we want strict OdinNextCache or None. This must happen before
        # `seen_tokens` is read below, since a foreign cache object does
        # not have that attribute.
        if past_key_values is not None and not isinstance(past_key_values, OdinNextCache):
            past_key_values = None
        B, T = input_ids.shape
        # Determine if we're in single-token decode mode.
        single_step = (T == 1) and (past_key_values is not None)
        # RoPE position offset
        if past_key_values is not None:
            offset = past_key_values.seen_tokens
        else:
            offset = 0
        h = self.tok_embeddings(input_ids)
        # Prepare RoPE caches in the embedding's dtype.
        cos, sin = self._rope_slice(
            seq_len=T, offset=offset,
            target_dtype=h.dtype, target_device=h.device,
        )
        if past_key_values is None and use_cache:
            past_key_values = OdinNextCache(self.config.n_layers)
        for i, layer in enumerate(self.layers):
            prev_state = past_key_values[i] if past_key_values is not None else None
            h, new_state = layer(
                h,
                cos=cos, sin=sin,
                recurrent_state=prev_state,
                output_state=use_cache,
                use_recurrent_kernel=single_step,
            )
            if use_cache and past_key_values is not None:
                past_key_values[i] = new_state
        h = self.final_norm(h)
        if past_key_values is not None:
            past_key_values.update_seen(T)
        return h, past_key_values


class OdinNextForCausalLM(OdinNextPreTrainedModel):
    """Top-level wrapper exposing logits + HF generate()."""

    # Map tied output -> source. Newer `transformers` (>=4.45) expects a
    # dict; older versions tolerate (and used) a list of keys. Provide the
    # dict form which is forward-compatible.
    _tied_weights_keys = {"lm_head.weight": "model.tok_embeddings.weight"}

    def __init__(self, config: OdinNextConfig):
        super().__init__(config)
        self.model = OdinNextModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            self.lm_head.weight = self.model.tok_embeddings.weight
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.tok_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.tok_embeddings = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[OdinNextCache] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **_unused,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else True
        hidden_states, past_key_values = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1).long(),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,) + ((past_key_values,) if past_key_values is not None else ())
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    # -----------------------------------------------------------------
    # generate() integration
    # -----------------------------------------------------------------
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[OdinNextCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = True,
        **kwargs,
    ) -> dict:
        """Trim input_ids to only the new positions when a cache exists.

        After the first forward, the recurrent state already encodes the
        prompt. Subsequent calls only need to pass the most recently
        generated token.
        """
        # Only an OdinNextCache tracks `seen_tokens`; any other cache object
        # is discarded in forward(), so the full input must be passed then.
        if isinstance(past_key_values, OdinNextCache) and past_key_values.seen_tokens > 0:
            # New tokens since last call.
            new_count = input_ids.shape[1] - past_key_values.seen_tokens
            if new_count <= 0:
                # generate() can occasionally call us with the same length
                # twice (e.g., assistant-decoding paths). Default to feeding
                # the last token only.
                input_ids = input_ids[:, -1:]
            else:
                input_ids = input_ids[:, -new_count:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(
        self, past_key_values: OdinNextCache, beam_idx: torch.Tensor
    ) -> OdinNextCache:
        """Beam-search support: reorder per-layer states along the batch axis."""
        for i, state in enumerate(past_key_values.states):
            if state is not None:
                past_key_values.states[i] = state.index_select(0, beam_idx.to(state.device))
        return past_key_values
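
    @classmethod
    def _supports_default_dynamic_cache(cls) -> bool:
        # OdinNext allocates its own OdinNextCache in forward(), so generate()
        # should not create a DynamicCache. Declared as a classmethod (taking
        # `cls`) so it works whether it is called on the class or an instance.
        return False


# Re-export for convenience
__all__ = [
    "OdinNextConfig",
    "OdinNextModel",
    "OdinNextForCausalLM",
    "OdinNextCache",
]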
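`OdinNextCache` is described as a fixed-size state rather than a growing K/V tape. A naive reference scan of the gated linear-attention recurrence that `chunk_gla` and `fused_recurrent_gla` compute illustrates why.

The sketch assumes a single head, no RoPE, and no output scaling (the kernels' `scale` argument is ignored here). Gates and keys follow `OdinNextAttention.forward`: g = log sigmoid(f) and k = sigmoid(-f). Each step updates S_t = diag(exp(g_t)) S_{t-1} + k_t^T v_t and emits o_t = q_t S_t.

Splitting the sequence and carrying S across the split reproduces the full-sequence output, which is the property that prefill followed by token-by-token decode relies on. `gla_scan` and `log_sigmoid` are hypothetical helper names.

```python
import math
import random
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def gla_scan(
    q: Matrix, k: Matrix, v: Matrix, g: Matrix, state: Optional[Matrix] = None
) -> Tuple[Matrix, Matrix]:
    """Naive single-head gated linear-attention scan.

    q, k, g: [T][K] (g holds log forget gates); v: [T][V].
    Returns (outputs [T][V], final state S [K][V]).
    """
    K, V = len(k[0]), len(v[0])
    # Copy the incoming state so the caller's matrix is not mutated.
    S = [row[:] for row in state] if state is not None else [[0.0] * V for _ in range(K)]
    outputs = []
    for q_t, k_t, v_t, g_t in zip(q, k, v, g):
        for i in range(K):
            decay = math.exp(g_t[i])  # forget gate in (0, 1)
            for j in range(V):
                S[i][j] = decay * S[i][j] + k_t[i] * v_t[j]
        outputs.append([sum(q_t[i] * S[i][j] for i in range(K)) for j in range(V)])
    return outputs, S


# Toy inputs with the HGRN2 parameterization: g = log sigmoid(f), k = sigmoid(-f).
rng = random.Random(0)
T, K, V = 6, 4, 3
f = [[rng.uniform(-2.0, 2.0) for _ in range(K)] for _ in range(T)]
q = [[rng.uniform(-1.0, 1.0) for _ in range(K)] for _ in range(T)]
v = [[rng.uniform(-1.0, 1.0) for _ in range(V)] for _ in range(T)]
g = [[log_sigmoid(x) for x in row] for row in f]
k = [[1.0 / (1.0 + math.exp(x)) for x in row] for row in f]

# One pass over all T tokens ...
full_out, full_state = gla_scan(q, k, v, g)

# ... versus a 4-token "prefill" followed by 1-token "decode" steps.
out, state = gla_scan(q[:4], k[:4], v[:4], g[:4])
for t in range(4, T):
    step_out, state = gla_scan(q[t:t + 1], k[t:t + 1], v[t:t + 1], g[t:t + 1], state)
    out += step_out
```

In the model, each layer's S has shape [B, H, K, V]; with the default config that is K = V = 128 per head. `OdinNextCache` therefore stays the same size however long the prompt is.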