joelhenwang commited on
Commit
51b0052
·
verified ·
1 Parent(s): 244ade3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +239 -90
README.md CHANGED
@@ -1,142 +1,291 @@
1
  ---
2
  license: apache-2.0
 
 
3
  library_name: transformers
4
  pipeline_tag: text-generation
5
- language:
6
- - en
7
  tags:
8
- - odinnext
9
- - hgrn2
10
- - linear-attention
11
- - recurrent
12
- - custom_code
13
- - early-checkpoint
14
- - causal-lm
15
- - amd
16
- - rocm
 
 
 
 
 
17
  ---
18
 
19
  # OdinNext-138M-Early-Checkpoint
20
 
21
- Early-stage checkpoint of **OdinNext**, a 138M-parameter HGRN2 linear-attention LM trained from scratch on AMD Strix Halo (gfx1151). **6.84B tokens, ~3% of the planned pretraining budget.** Still in active development — output quality is weak, no SFT, no alignment, no context extension.
 
 
 
 
 
 
 
 
22
 
23
- > **Variant**: `main (EMA)` EMA-shadowed weights (decay 0.999), recommended for evaluation. See the [`live` revision](https://huggingface.co/joelhenwang/OdinNext-138M-Early-Checkpoint/tree/live) for the raw training weights.
24
 
25
  ## At a glance
26
 
27
- | | |
28
- |---|---|
29
- | Params | 138.4M (113.3M non-embedding) |
30
- | Architecture | 16 layers, HGRN2 + alternating RoPE, SwiGLU², ZCRMSNorm |
31
- | Hidden / Heads / FFN | 768 / 6 / 2048 |
32
- | Vocab / Context | 32,768 (custom BPE) / 2,048 |
33
- | Inference | **O(1) per token, fixed 3 MB recurrent state** (no growing KV cache) |
34
- | Training | fp16 + GradScaler, NorMuon (2D) + AdamW (1D/embed), WSD, EMA 0.999 |
35
- | Curriculum | TST bag-size-4 active throughout this checkpoint |
36
- | Hardware | 2× AMD Strix Halo (gfx1151), ROCm 7.13, gloo over Thunderbolt 4 |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  ## Architecture
39
 
40
- OdinNext replaces softmax attention with the **HGRN2 gated linear recurrence** [1]:
41
- `S_t = diag(exp(g_t)) · S_{t-1} + k_t ⊗ v_t`, `o_t = q_t · S_t`. The state is a fixed-size matrix updated in place, so per-token decode is O(1) in compute and memory regardless of context length.
42
 
43
- Sixteen identical pre-norm blocks: `x + σ(gate_attn) · HGRN2(ZCRMSNorm(x))`, `x + σ(gate_ffn) · SwiGLU²(ZCRMSNorm(x))`. Tied embeddings + LM head. No biases on linear layers.
 
 
 
44
 
45
- **Hybrid RoPE**: even layers apply RoPE on q/k (θ=100,000); odd layers are position-free. Half the depth thus generalizes to arbitrary length without ABF, simplifying future context extension.
46
 
47
- ### Decisions and why
 
 
 
48
 
49
- Choices below come from 25+ ablations on a 100M proxy model; only the BPB-winning configuration shipped.
50
 
51
- - **Linear attention (HGRN2) over softmax**: gfx1151 has no MFMA tensor cores. Custom HIP attention can't beat rocBLAS, and softmax attention's O(T²) memory dominates step time on this platform. HGRN2 is dominated by element-wise ops + small GEMMs that fit rocBLAS-friendly shapes.
52
- - **fp16, not bf16**: bf16 GEMMs are 24% slower on gfx1151 and trigger Inductor crashes under `torch.compile`. fp16 + GradScaler + z-loss + activation soft-cap is stable.
53
- - **SwiGLU² over SwiGLU**: −0.009 BPB at iso-parameter count. The squared SiLU gate gives sharper sparsity with smooth gradients.
54
- - **ZCRMSNorm + zero-init gates**: block at init is approximately identity (γ=0, σ(0)=0.5). Loss starts at ≈ln(V), no spike-and-recover phase. Required for future block-wise denoising training [3].
55
- - **NorMuon (2D, fp16 NS) + AdamW (1D, embed @ 0.3× LR)**: each parameter group gets the right update rule for its geometry; Newton-Schulz in fp16 is ~10× faster than fp32 on this platform with no measurable quality loss.
56
- - **TST bag-size-4 curriculum** [2]: every position averages 4 stochastic subword tokenizations of the same text. Forces tokenization-invariant representations early. **Note**: this checkpoint is fully pre-transition (still bagged) → single-stream inference is slightly OOD. Quality is expected to lift after the planned bag-size→1 transition.
57
 
58
- ## Training
59
 
60
- | | |
61
- |---|---|
62
- | Batch | 32 seqs × 4 grad-accum × 2 ranks = 256 effective sequences (524,288 tokens/step) |
63
- | Optimizer steps | 3,259 |
64
- | LR schedule | WSD, peak 8e-4 (NorMuon), warmup 500, MIN_LR 0.1× |
65
- | Stability | z-loss 1e-4, attn-softcap 50, EMA decay 0.999, GradScaler growth 500 |
66
- | Compile | `max-autotune-no-cudagraphs`, per-layer (`compile_zones`) |
67
- | Throughput | ~427K tokens/s aggregate across 2 nodes |
68
- | Run health | 0 NaN events; GradScaler scale 1024→65,536 cleanly |
69
- | Final loss / BPB (step 3,200) | 1.886 / 0.755 |
70
-
71
- ## Memory: HGRN2 state vs Transformer KV cache
72
-
73
- | Context | Transformer KV (typical d=768) | OdinNext HGRN2 |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  |---:|---:|---:|
75
- | 1K | ~24 MB | **~3 MB** |
76
- | 4K | ~96 MB | **~3 MB** |
77
- | 16K | ~384 MB | **~3 MB** |
78
- | 64K | ~1.5 GB | **~3 MB** |
 
 
 
 
79
 
80
- State size: `n_layers × n_heads × head_f_dim × head_i_dim × 2 bytes` ≈ 3 MB, **constant**.
81
 
82
- ## Usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  ```python
85
- from transformers import AutoModelForCausalLM, AutoTokenizer
86
  import torch
 
 
 
 
 
87
 
88
- name = "joelhenwang/OdinNext-138M-Early-Checkpoint"
89
- tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
 
 
90
  model = AutoModelForCausalLM.from_pretrained(
91
- name, trust_remote_code=True, torch_dtype=torch.float16
92
- ).to("cuda" if torch.cuda.is_available() else "cpu").eval()
 
 
 
 
 
 
 
 
 
 
93
 
94
- inputs = tok("The night was quiet and the streets were empty", return_tensors="pt").to(model.device)
95
  with torch.inference_mode():
96
  out = model.generate(
97
- **inputs, max_new_tokens=80, do_sample=True,
98
- temperature=0.8, top_p=0.95, repetition_penalty=1.1,
99
- pad_token_id=tok.pad_token_id, use_cache=True,
 
 
 
 
 
100
  )
 
101
  print(tok.decode(out[0], skip_special_tokens=True))
102
  ```
103
 
104
- - `use_cache=True` is essential — without it, the model re-processes the full prefix each step.
105
- - `past_key_values` is **not** a KV cache; it's a fixed-size HGRN2 state (`OdinNextCache`).
106
- - Hard cap at 2,048 cumulative positions. Recurrence is causal-only — for batched generation, **right-pad**.
107
- - [`flash-linear-attention`](https://github.com/sustcsonglin/flash-linear-attention) is recommended (~10–30× faster Triton kernels). The model auto-falls-back to a pure-PyTorch reference if `fla` is unavailable.
108
 
109
- ## Caveats
 
 
 
 
 
 
 
110
 
111
- - ❌ No SFT, no DPO/RLHF, no chat template, no safety training.
112
- - ❌ No context extension (max 2,048 tokens).
113
- - ❌ English-only mixture; multilingual and code outputs will be poor.
114
- - ❌ TST bagging still active → expect a quality jump at the planned bag→1 transition.
115
- - ❌ bf16 inference untested on this checkpoint.
116
- - ❌ Formal benchmarks (HellaSwag, ARC, etc.) pending.
117
 
118
- ## Revisions
119
 
120
- - **`main`** — EMA-shadowed weights (decay 0.999). Recommended for evaluation.
121
- - **`live`** — raw training weights at step 3,259.
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- ## License
 
124
 
125
- Apache-2.0.
126
 
127
  ## Citation
128
 
129
  ```bibtex
130
- @misc{{odinnext_138m_early_2026,
131
- title = {{OdinNext-138M-Early-Checkpoint}},
132
- author = {{Wang, Joel}},
133
- year = {{2026}},
134
- url = {{https://huggingface.co/joelhenwang/OdinNext-138M-Early-Checkpoint}},
135
- }}
 
136
  ```
137
 
138
  ## References
139
 
140
- [1] Qin, Yang, Sun, et al. **HGRN2: Gated Linear RNNs with State Expansion.** arXiv:2404.07904, 2024.
141
- [2] **Token Superposition Training (TST).** arXiv:2605.06546. (Related: PatchTrain, Shao et al., arXiv:2407.12665, ICLR 2025.)
142
- [3] **DiffusionBlocks** — block-wise training via score-matching denoising (Iizuka et al., 2025). Used in the planned post-this-checkpoint phase, not in this run.
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ language:
4
+ - en
5
  library_name: transformers
6
  pipeline_tag: text-generation
 
 
7
  tags:
8
+ - odinnext
9
+ - hgrn2
10
+ - linear-attention
11
+ - recurrent
12
+ - causal-lm
13
+ - custom_code
14
+ - early-checkpoint
15
+ - fp16
16
+ - amd
17
+ - rocm
18
+ - arxiv:2404.07904
19
+ - arxiv:2605.06546
20
+ - arxiv:2407.12665
21
+ - arxiv:2506.14202
22
  ---
23
 
24
  # OdinNext-138M-Early-Checkpoint
25
 
26
+ Early research checkpoint of **OdinNext**, a 138M-parameter causal language model using an HGRN2-style gated linear recurrence instead of softmax self-attention.
27
+
28
+ This is **not** a chat model and not a production release. It is an early pretraining checkpoint intended for architecture inspection, qualitative sampling, and continued research.
29
+
30
+ - **Repo:** `joelhenwang/OdinNext-138M-Early-Checkpoint`
31
+ - **Recommended revision:** `main` / EMA-shadowed weights
32
+ - **Training status:** early checkpoint at step 3,259
33
+ - **Context window:** 2,048 tokens in the released inference code
34
+ - **License:** Apache-2.0
35
 
36
+ > The model uses custom Transformers code. Loading it with `trust_remote_code=True` executes Python code from this repository. Only do this after reviewing the files or pinning a known commit.
37
 
38
  ## At a glance
39
 
40
+ | Item | Value |
41
+ |---|---:|
42
+ | Unique tied parameters | **138,449,696** |
43
+ | Non-embedding parameters | **113,283,872** |
44
+ | Layers | 16 |
45
+ | Hidden size | 768 |
46
+ | Heads | 6 |
47
+ | Head state dims | 128 × 128 per head |
48
+ | FFN inner size | 2,048 |
49
+ | Vocabulary | 32,768 custom BPE tokens |
50
+ | Max sequence length | 2,048 |
51
+ | Checkpoint dtype | fp16 |
52
+ | Architecture | HGRN2 recurrence + alternating RoPE + SwiGLU² FFN + RMSNorm-style normalization |
53
+ | Cache type | Fixed recurrent state, not a growing Transformer KV cache |
54
+
55
+ ## What this checkpoint is good for
56
+
57
+ Use this checkpoint for:
58
+
59
+ - inspecting a compact recurrent/linear-attention LM implementation;
60
+ - testing HGRN2-style recurrent decoding inside the Hugging Face `generate()` API;
61
+ - studying fixed-state decoding memory behavior;
62
+ - continuing pretraining or running controlled ablations.
63
+
64
+ Do **not** use it for:
65
+
66
+ - chat, instruction following, or agentic tasks;
67
+ - safety-sensitive output generation;
68
+ - benchmark claims without running your own evaluation;
69
+ - multilingual, coding, or long-context claims.
70
 
71
  ## Architecture
72
 
73
+ OdinNext is a decoder-only causal LM. Each block uses a pre-norm residual layout:
 
74
 
75
+ ```text
76
+ x = x + sigmoid(gate_attn) * HGRN2(norm(x))
77
+ x = x + sigmoid(gate_ffn) * SwiGLU²(norm(x))
78
+ ```
79
 
80
+ The HGRN2-style recurrent state is updated per token as:
81
 
82
+ ```text
83
+ S_t = diag(exp(g_t)) S_{t-1} + k_t ⊗ v_t
84
+ o_t = q_t S_t
85
+ ```
86
 
87
+ where each layer keeps a per-batch recurrent state shaped:
88
 
89
+ ```text
90
+ [B, n_heads, head_f_dim, head_i_dim]
91
+ ```
 
 
 
92
 
93
+ For this checkpoint:
94
 
95
+ ```text
96
+ n_heads = 6
97
+ head_f_dim = 128
98
+ head_i_dim = 128
99
+ ```
100
+
101
+ Even-numbered layers apply RoPE to `q` and `k`; odd-numbered layers are position-free. The current inference implementation still enforces a hard 2,048-token cumulative position limit because the RoPE cache is built for `max_seq_len = 2048`.
102
+
103
+ ### Important implementation details
104
+
105
+ - The exported Hugging Face code contains only the inference path. Training-time machinery is not part of this repository.
106
+ - `past_key_values` is an `OdinNextCache`, a list of recurrent states. It is **not** a Transformer KV cache.
107
+ - `attention_mask` is accepted for API compatibility but ignored by the backbone. Left-padding is not supported.
108
+ - Batched generation is safest when all prompts have the same valid length. Padding tokens are still tokens to the recurrence if they are processed.
109
+ - `use_cache=True` is important for generation. Without it, every generation step reprocesses the full prefix.
110
+
111
+ ## Parameter accounting
112
+
113
+ The 138M headline is the **unique tied-parameter runtime count**. The input embedding and LM head are tied and should be counted once for model-capacity comparisons.
114
+
115
+
116
+
117
+ Hugging Face or file-size-derived parameter summaries may round this checkpoint near 0.2B because stored checkpoint tensors and tied runtime parameters are not always counted the same way.
118
+
119
+ ## Memory: recurrent state vs Transformer KV cache
120
+
121
+ For batch size 1 in fp16, OdinNext's recurrent state size is:
122
+
123
+ ```text
124
+ layers × heads × head_f_dim × head_i_dim × bytes
125
+ = 16 × 6 × 128 × 128 × 2
126
+ = 3,145,728 bytes ≈ 3.0 MiB
127
+ ```
128
+
129
+ That state is constant with respect to generated context length. It scales linearly with batch size and with dtype size. In the pure-PyTorch fallback path, the scan state is promoted to fp32, so the returned recurrent state can be about **6.0 MiB per sequence** instead of 3.0 MiB.
130
+
131
+ A same-depth 16-layer, `d_model = 768`, fp16 Transformer with full multi-head K/V cache would use approximately:
132
+
133
+ ```text
134
+ layers × 2(K,V) × hidden_size × context_tokens × bytes
135
+ = 16 × 2 × 768 × T × 2
136
+ ```
137
+
138
+ | Context tokens | Typical Transformer KV cache | OdinNext recurrent state |
139
  |---:|---:|---:|
140
+ | 1,024 | 48 MiB | ~3 MiB fp16 / ~6 MiB fp32 fallback |
141
+ | 4,096 | 192 MiB | ~3 MiB fp16 / ~6 MiB fp32 fallback |
142
+ | 16,384 | 768 MiB | ~3 MiB fp16 / ~6 MiB fp32 fallback |
143
+ | 65,536 | 3,072 MiB | ~3 MiB fp16 / ~6 MiB fp32 fallback |
144
+
145
+ This table is a cache-state comparison only. It is not a claim about total GPU memory, throughput, benchmark quality, or usable context length. The released OdinNext code is still limited to 2,048 cumulative positions.
146
+
147
+ ## Training snapshot
148
 
149
+ Values verified from the public config:
150
 
151
+ | Field | Value |
152
+ |---|---:|
153
+ | `_training_step` | 3,259 |
154
+ | `_total_tokens` | 6,835,666,944 |
155
+ | `_weights_source` | `ema_state_dict` |
156
+ | `torch_dtype` | `float16` |
157
+ | `max_position_embeddings` | 2,048 |
158
+
159
+ Author-reported training notes for this early checkpoint:
160
+
161
+ | Item | Value |
162
+ |---|---|
163
+ | Hardware | 2× AMD Strix Halo / gfx1151, ROCm stack |
164
+ | Training precision | fp16 + GradScaler |
165
+ | Optimizers | NorMuon for 2D tensors; AdamW for 1D/embed tensors |
166
+ | LR schedule | WSD, peak `8e-4`, warmup 500, min LR 0.1× peak |
167
+ | Stabilization | z-loss `1e-4`, attention soft-cap 50, EMA decay 0.999 |
168
+ | Curriculum | TST-style bag-size-4 phase active at this checkpoint |
169
+ | Public benchmarks | not yet provided |
170
+
171
+ ### Token accounting note
172
+
173
+ The public config records `_total_tokens = 6,835,666,944`. Do not reinterpret that as plain next-token positions from:
174
+
175
+ ```text
176
+ 3,259 optimizer steps × 256 effective sequences × 2,048 tokens
177
+ = 1,708,916,224 position tokens
178
+ ```
179
+
180
+ The 6.84B figure appears to be token-superposition/original-token-equivalent accounting rather than simple next-token position accounting. A full reproducibility report should define whether the total counts original text tokens, bagged targets, loss terms, or optimizer-position tokens.
181
+
182
+ ### TST note
183
+
184
+ The cited Token-Superposition Training paper defines TST as a two-phase method: a superposition phase that combines contiguous tokens into bags and uses a multi-hot cross-entropy objective, followed by a recovery phase that returns to ordinary next-token training.
185
+
186
+ This checkpoint is described as still being in a bag-size-4 phase. That means ordinary single-stream autoregressive inference is not necessarily the final intended training distribution. Treat quality as preliminary until a bag-size-1 recovery checkpoint and benchmark results are published.
187
+
188
+ ## Usage with Transformers
189
+
190
+ Install the basics:
191
+
192
+ ```bash
193
+ pip install "transformers>=4.46" torch safetensors
194
+ ```
195
+
196
+ Optional: install `flash-linear-attention` if your platform supports it. Without it, the model falls back to a pure-PyTorch reference implementation that is useful for correctness and portability but slower for long prompts.
197
 
198
  ```python
 
199
  import torch
200
+ from transformers import AutoModelForCausalLM, AutoTokenizer
201
+
202
+ repo = "joelhenwang/OdinNext-138M-Early-Checkpoint"
203
+ # For reproducible experiments, replace "main" with a specific commit hash.
204
+ revision = "main"
205
 
206
+ device = "cuda" if torch.cuda.is_available() else "cpu"
207
+ dtype = torch.float16 if device == "cuda" else torch.float32
208
+
209
+ tok = AutoTokenizer.from_pretrained(repo, revision=revision)
210
  model = AutoModelForCausalLM.from_pretrained(
211
+ repo,
212
+ revision=revision,
213
+ trust_remote_code=True,
214
+ torch_dtype=dtype,
215
+ ).to(device).eval()
216
+
217
+ prompt = "The night was quiet and the streets were empty"
218
+ inputs = tok(prompt, return_tensors="pt").to(device)
219
+
220
+ # The released code is capped at 2,048 cumulative positions.
221
+ remaining = model.config.max_position_embeddings - inputs.input_ids.shape[1]
222
+ max_new_tokens = max(0, min(80, remaining))
223
 
 
224
  with torch.inference_mode():
225
  out = model.generate(
226
+ **inputs,
227
+ max_new_tokens=max_new_tokens,
228
+ do_sample=True,
229
+ temperature=0.8,
230
+ top_p=0.95,
231
+ repetition_penalty=1.1,
232
+ pad_token_id=tok.pad_token_id,
233
+ use_cache=True,
234
  )
235
+
236
  print(tok.decode(out[0], skip_special_tokens=True))
237
  ```
238
 
239
+ ### Batching guidance
 
 
 
240
 
241
+ The model's recurrent scan does not apply an attention mask. For correct batched generation:
242
+
243
+ - avoid left padding;
244
+ - prefer same-length prompts in a batch;
245
+ - avoid processing pad tokens as if they were real prompt tokens;
246
+ - test batched output against single-sample output before relying on batched generation.
247
+
248
+ Single-prompt generation is the safest path for basic use.
249
 
 
 
 
 
 
 
250
 
 
251
 
252
+ ## Known limitations
253
+
254
+ - **No instruction tuning:** no SFT, DPO, RLHF, RLAIF, or chat template.
255
+ - **No safety training:** outputs can be unsafe, biased, false, or incoherent.
256
+ - **Early quality:** this is about 3% of the planned pretraining budget according to the original release notes.
257
+ - **No formal benchmarks yet:** HellaSwag, ARC, MMLU, perplexity suites, and long-context tests are not provided here.
258
+ - **Hard 2,048-token cap:** recurrent cache size is constant, but the released RoPE cache still limits positions.
259
+ - **Masking caveat:** `attention_mask` is ignored in the backbone; padding can affect recurrent state.
260
+ - **English-focused:** multilingual and code generation should be assumed weak unless tested.
261
+ - **bf16 unvalidated:** fp16 is the intended inference dtype for this checkpoint; CPU fallback should use fp32 for portability.
262
+ - **Training data not fully documented in this card:** treat data provenance, memorization risk, and bias profile as uncharacterized unless separately documented.
263
+
264
+ ## Revisions
265
 
266
+ - `main`: EMA-shadowed weights from `_weights_source = ema_state_dict`; recommended for evaluation.
267
+ - `live`: raw training weights at step 3,259, if this branch is retained.
268
 
269
+ For reproducible experiments, pin a commit hash rather than a moving branch name.
270
 
271
  ## Citation
272
 
273
  ```bibtex
274
+ @misc{odinnext_138m_early_2026,
275
+ title = {OdinNext-138M-Early-Checkpoint},
276
+ author = {Wang, Joel},
277
+ year = {2026},
278
+ howpublished = {\url{https://huggingface.co/joelhenwang/OdinNext-138M-Early-Checkpoint}},
279
+ note = {Early HGRN2 recurrent language-model checkpoint}
280
+ }
281
  ```
282
 
283
  ## References
284
 
285
+ - Zhen Qin, Songlin Yang, Weixuan Sun, Xuyang Shen, Dong Li, Weigao Sun, Yiran Zhong. **HGRN2: Gated Linear RNNs with State Expansion.** arXiv:2404.07904. https://arxiv.org/abs/2404.07904
286
+ - Bowen Peng, Théo Gigant, Jeffrey Quesnelle. **Efficient Pre-Training with Token Superposition.** arXiv:2605.06546. https://arxiv.org/abs/2605.06546
287
+ - Chenze Shao, Fandong Meng, Jie Zhou. **Patch-Level Training for Large Language Models.** arXiv:2407.12665. https://arxiv.org/abs/2407.12665
288
+ - Makoto Shing, Masanori Koyama, Takuya Akiba. **DiffusionBlocks: Block-wise Neural Network Training via Diffusion Interpretation.** arXiv:2506.14202. https://arxiv.org/abs/2506.14202
289
+ - Hugging Face Transformers custom-model documentation: https://huggingface.co/docs/transformers/custom_models
290
+ - vLLM custom/Transformers backend documentation: https://docs.vllm.ai/en/latest/models/supported_models/
291
+ - SGLang Transformers backend documentation: https://huggingface.co/docs/transformers/en/community_integrations/sglang