Upload 3 files
Browse files- build_ascii_vocab_bundle_v9.py +744 -0
- final_infer.pt +3 -0
- pgsm_sparse_rope_lm.py +627 -0
build_ascii_vocab_bundle_v9.py
ADDED
|
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ASCII-limited English-first vocab bundle builder for a tiny LLM.
|
| 4 |
+
|
| 5 |
+
Design goals
|
| 6 |
+
------------
|
| 7 |
+
- English-only as much as reasonably possible
|
| 8 |
+
- Keep text intact instead of creating holes
|
| 9 |
+
- Fold uppercase -> lowercase
|
| 10 |
+
- Fold accented Latin letters -> plain ASCII where reasonable
|
| 11 |
+
- Drop emoji and non-Latin scripts
|
| 12 |
+
- Keep only a small practical punctuation set
|
| 13 |
+
- Learn multi-character tokens from LETTERS ONLY
|
| 14 |
+
- Keep digits and punctuation atomic as single-character tokens
|
| 15 |
+
- Stream from Hugging Face without local dataset files
|
| 16 |
+
|
| 17 |
+
Default source
|
| 18 |
+
--------------
|
| 19 |
+
Streams:
|
| 20 |
+
HuggingFaceFW/fineweb-edu
|
| 21 |
+
config=sample-10BT
|
| 22 |
+
split=train
|
| 23 |
+
|
| 24 |
+
Outputs
|
| 25 |
+
-------
|
| 26 |
+
Creates a bundle directory containing:
|
| 27 |
+
manifest.json
|
| 28 |
+
vocab.json
|
| 29 |
+
token_stats.npz
|
| 30 |
+
pair_stats.npz
|
| 31 |
+
|
| 32 |
+
What gets kept
|
| 33 |
+
--------------
|
| 34 |
+
- letters: a-z
|
| 35 |
+
- digits: 0-9
|
| 36 |
+
- whitespace: space + newline
|
| 37 |
+
- limited punctuation:
|
| 38 |
+
. , ! ? ' " - ( ) : ; @ # + % = / \ *
|
| 39 |
+
|
| 40 |
+
Tokenization policy
|
| 41 |
+
-------------------
|
| 42 |
+
- learned multi-character tokens: letters only
|
| 43 |
+
- digits remain atomic single-character tokens
|
| 44 |
+
- punctuation remains atomic single-character tokens
|
| 45 |
+
|
| 46 |
+
Examples
|
| 47 |
+
--------
|
| 48 |
+
PowerShell smoke test:
|
| 49 |
+
python F:\\TokenizerUltra\\build_ascii_vocab_bundle_v9.py --output "F:\\TokenizerUltra\\vocab_bundle_test" --max-examples 5000 --bpe-train-chars 2000000 --final-token-budget 2000000
|
| 50 |
+
|
| 51 |
+
PowerShell full build:
|
| 52 |
+
python F:\\TokenizerUltra\\build_ascii_vocab_bundle_v9.py --output "F:\\TokenizerUltra\\vocab_bundle" --bpe-train-chars 100000000 --final-token-budget 100000000
|
| 53 |
+
|
| 54 |
+
Dependencies
|
| 55 |
+
------------
|
| 56 |
+
python -m pip install numpy datasets
|
| 57 |
+
"""
|
| 58 |
+
from __future__ import annotations
|
| 59 |
+
|
| 60 |
+
import argparse
|
| 61 |
+
import json
|
| 62 |
+
import re
|
| 63 |
+
import unicodedata
|
| 64 |
+
from collections import Counter
|
| 65 |
+
from dataclasses import dataclass
|
| 66 |
+
from pathlib import Path
|
| 67 |
+
from typing import Dict, Iterator, List, Optional, Sequence, Tuple
|
| 68 |
+
|
| 69 |
+
import numpy as np
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Default Hugging Face streaming source.
DEFAULT_DATASET = "HuggingFaceFW/fineweb-edu"
DEFAULT_CONFIG = "sample-10BT"
DEFAULT_SPLIT = "train"

# Placed at the very start of the vocabulary; "<unk>" is looked up by name.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]

ASCII_LETTERS = "abcdefghijklmnopqrstuvwxyz"
ASCII_DIGITS = "0123456789"
ALLOWED_PUNCT = ".,!?\'\"-():;@#+%=/\\*"

SPACE_TOKEN = " "
NEWLINE_TOKEN = "\n"

# Every character that may survive normalization unchanged.
ALLOWED_CHARS = set(ASCII_LETTERS + ASCII_DIGITS + ALLOWED_PUNCT + SPACE_TOKEN + NEWLINE_TOKEN)
# Row keys probed, in order, for the text of a streamed example.
TEXT_FIELDS = ("text", "content", "body", "document", "raw_content", "message")

ESCAPED_PUNCT = re.escape(ALLOWED_PUNCT)
# Pre-tokenizer: a newline, a run of spaces, a run of letters, one digit,
# or one allowed punctuation mark.  Anything else is skipped by findall.
TOKEN_RE = re.compile(rf"\n| +|[a-z]+|[0-9]|[{ESCAPED_PUNCT}]")
MULTISPACE_RE = re.compile(r"[ \t\f\v]+")
MULTINEWLINE_RE = re.compile(r"\n{3,}")

# Literal substitutions applied before case folding and NFKD.
SEQUENCE_REPLACEMENTS = {
    # Typographic quotes and dashes -> ASCII.
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
    "\u2013": "-",
    "\u2014": "-",
    "\u2015": "-",
    "\u2212": "-",
    "\u2026": "...",
    # Bullets and non-breaking space -> plain space.
    "\u2022": " ",
    "\u00b7": " ",
    "\u00a0": " ",
    # Zero-width characters, BOM and soft hyphen -> removed.
    "\u200b": "",
    "\u200c": "",
    "\u200d": "",
    "\ufeff": "",
    "\u00ad": "",
    "\t": " ",
    "\r": "\n",
    # ASCII bracket-like marks -> parentheses.
    "[": "(",
    "]": ")",
    "{": "(",
    "}": ")",
    "<": "(",
    ">": ")",
    # Fullwidth bracket forms -> parentheses.  They must be folded here:
    # NFKD would otherwise turn them into ASCII "[", "{" ... which are not
    # in the allowed set and would end up as spaces.
    "\uff08": "(",
    "\uff09": ")",
    "\uff3b": "(",
    "\uff3d": ")",
    "\uff5b": "(",
    "\uff5d": ")",
    # CJK bracket pairs -> parentheses.
    "【": "(",
    "】": ")",
    "〈": "(",
    "〉": ")",
    "《": "(",
    "》": ")",
    "「": "(",
    "」": ")",
    "『": "(",
    "』": ")",
    "〔": "(",
    "〕": ")",
    "〖": "(",
    "〗": ")",
}

# Latin letters that NFKD cannot decompose into an ASCII base letter.
LATIN_FOLD_REPLACEMENTS = {
    "ß": "ss",
    "ẞ": "ss",
    "æ": "ae",
    "ǽ": "ae",
    "œ": "oe",
    "ø": "o",
    "ð": "d",
    "þ": "th",
    "ł": "l",
    "đ": "d",
    "ħ": "h",
    "ı": "i",
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@dataclass
class BundleConfig:
    """All settings for one bundle build; defaults mirror the CLI defaults."""

    # Directory the bundle files are written into.
    output: Path
    # Streaming source passed to ``load_dataset``.
    dataset: str = DEFAULT_DATASET
    config: str = DEFAULT_CONFIG
    split: str = DEFAULT_SPLIT
    # Requested total vocabulary size.
    vocab_size: int = 2000
    # First pass stops once this many normalized characters were read.
    bpe_train_chars: int = 100_000_000
    # Second pass stops once this many tokens were counted.
    final_token_budget: int = 100_000_000
    # Optional cap on the number of streamed examples (None = no cap).
    max_examples: Optional[int] = None
    # Pairs seen fewer times than this get the neutral prior.
    min_pair_count: int = 5
    # z-score clip bounds for the token and pair priors.
    token_prior_clip: float = 3.0
    pair_prior_clip: float = 3.0
    # Maximum number of words memoized by the tokenizer.
    word_cache_size: int = 200_000
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _import_load_dataset():
    """Import ``datasets`` lazily and return its ``load_dataset`` function.

    Raises:
        SystemExit: if the ``datasets`` package cannot be imported; the exit
            message contains the pip command that installs it.
    """
    try:
        from datasets import load_dataset
    # Only a failed import means "not installed".  Any other error raised
    # while importing the package is left to propagate with its own traceback
    # instead of being reported as a missing dependency.
    except ImportError as exc:
        raise SystemExit(
            "Missing dependency: datasets. Install with:\n"
            " python -m pip install datasets numpy"
        ) from exc
    return load_dataset
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def normalize_text(text: str) -> str:
    """Normalize raw text to lowercase ASCII letters, digits and a small
    punctuation set.

    Steps, in order:
      1. literal substitutions from ``SEQUENCE_REPLACEMENTS``;
      2. ``str.casefold``;
      3. Latin letter folding from ``LATIN_FOLD_REPLACEMENTS``;
      4. NFKD decomposition, after which combining marks are dropped;
      5. characters in ``ALLOWED_CHARS`` are kept; other whitespace, letters,
         digits, punctuation and symbols become a single space; everything
         else (e.g. control characters) is removed;
      6. space runs collapse, spaces around newlines are removed, runs of
         three or more newlines shrink to two, outer spaces are stripped.

    Args:
        text: Raw input text; may be empty.

    Returns:
        The normalized text ("" for empty input).
    """
    if not text:
        return ""

    for src, dst in SEQUENCE_REPLACEMENTS.items():
        text = text.replace(src, dst)

    # Casefold before the Latin fold table: the table lists lowercase forms
    # only, so capitals such as "Ø", "Æ" or "Ł" must be lowercased first or
    # they are never folded and end up replaced by a space.
    text = text.casefold()
    for src, dst in LATIN_FOLD_REPLACEMENTS.items():
        text = text.replace(src, dst)

    text = unicodedata.normalize("NFKD", text)

    out_chars: List[str] = []
    last_was_space = False

    for ch in text:
        cat = unicodedata.category(ch)

        # Combining marks left over from NFKD (accents etc.) are dropped.
        if cat.startswith("M"):
            continue

        # Newline is part of ALLOWED_CHARS, so it is handled here as well.
        if ch in ALLOWED_CHARS:
            out_chars.append(ch)
            last_was_space = ch == " "
            continue

        # Unsupported whitespace, letters, numbers, punctuation and symbols
        # leave one space behind so the neighbouring words are not glued
        # together.  All remaining categories are removed silently.
        if ch.isspace() or cat[:1] in {"L", "N", "P", "S"}:
            if not last_was_space:
                out_chars.append(" ")
                last_was_space = True

    normalized = "".join(out_chars)
    normalized = MULTISPACE_RE.sub(" ", normalized)
    normalized = re.sub(r" *\n *", "\n", normalized)
    normalized = MULTINEWLINE_RE.sub("\n\n", normalized)
    normalized = normalized.strip(" ")
    return normalized
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def iter_stream_examples(
    dataset_name: str,
    config_name: str,
    split: str,
    max_examples: Optional[int],
) -> Iterator[str]:
    """Yield the raw text of each example streamed from a Hugging Face dataset.

    For dict rows the first entry of ``TEXT_FIELDS`` holding a string is used;
    if none is found and the row has a ``"messages"`` list, the string
    ``content`` values of its dict items are joined with newlines.  Rows that
    are plain strings are used as-is.  Rows without non-empty text are skipped
    and do not count towards ``max_examples``.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        config_name: Dataset configuration name.
        split: Split name.
        max_examples: Maximum number of texts to yield, or None for no limit.
            Zero or a negative value yields nothing.

    Yields:
        Non-empty raw text strings.
    """
    # The limit is otherwise only checked after a yield, so a non-positive
    # limit has to be handled up front (before the dataset is even opened).
    if max_examples is not None and max_examples <= 0:
        return

    load_dataset = _import_load_dataset()
    ds = load_dataset(dataset_name, config_name, split=split, streaming=True)

    seen = 0
    for row in ds:
        text = None

        if isinstance(row, dict):
            text = next(
                (row[field] for field in TEXT_FIELDS if isinstance(row.get(field), str)),
                None,
            )

            messages = row.get("messages")
            if text is None and isinstance(messages, list):
                chunks = [
                    msg["content"]
                    for msg in messages
                    if isinstance(msg, dict) and isinstance(msg.get("content"), str)
                ]
                if chunks:
                    text = "\n".join(chunks)

        elif isinstance(row, str):
            text = row

        if text:
            yield text
            seen += 1
            if max_examples is not None and seen >= max_examples:
                break
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def iter_normalized_text(cfg: BundleConfig) -> Iterator[str]:
    """Stream the configured dataset and yield each example after
    ``normalize_text``, skipping examples that normalize to an empty string."""
    raw_examples = iter_stream_examples(
        cfg.dataset, cfg.config, cfg.split, cfg.max_examples
    )
    # filter(None, ...) drops the empty strings.
    yield from filter(None, map(normalize_text, raw_examples))
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def iter_pre_tokens(text: str) -> Iterator[str]:
    """Yield the pieces of ``text`` matched by ``TOKEN_RE``, left to right.

    Characters that match none of the pattern's alternatives are skipped.
    """
    yield from TOKEN_RE.findall(text)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def count_words_for_bpe(cfg: BundleConfig) -> Counter[str]:
    """Count alphabetic pre-tokens over the normalized stream.

    Reading stops after the first document that brings the number of
    normalized characters read to ``cfg.bpe_train_chars`` or beyond; that
    document is still counted in full.

    Returns:
        Mapping from letter-only word to its number of occurrences.
    """
    tally: Counter[str] = Counter()
    chars_read = 0

    for doc in iter_normalized_text(cfg):
        chars_read += len(doc)
        tally.update(tok for tok in iter_pre_tokens(doc) if tok.isalpha())
        if chars_read >= cfg.bpe_train_chars:
            break

    return tally
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def word_to_symbols(word: str) -> Tuple[str, ...]:
    """Split ``word`` into a tuple of its individual characters."""
    return (*word,)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def compute_pair_counts_from_vocab(
    vocab_words: Dict[Tuple[str, ...], int]
) -> Counter[Tuple[str, str]]:
    """Sum, over all words, the frequency-weighted count of each adjacent
    symbol pair in which both symbols are alphabetic.

    Args:
        vocab_words: Mapping from a word's symbol tuple to its frequency.

    Returns:
        Counter keyed by ``(left, right)`` symbol pairs.
    """
    tally: Counter[Tuple[str, str]] = Counter()
    for pieces, weight in vocab_words.items():
        # zip() yields nothing for words shorter than two symbols.
        for lhs, rhs in zip(pieces, pieces[1:]):
            if lhs.isalpha() and rhs.isalpha():
                tally[(lhs, rhs)] += weight
    return tally
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def merge_word_symbols(
    symbols: Tuple[str, ...],
    pair: Tuple[str, str],
) -> Tuple[str, ...]:
    """Return ``symbols`` with every non-overlapping occurrence of ``pair``,
    scanning left to right, replaced by the concatenation of its two parts."""
    lhs = pair[0]
    rhs = pair[1]
    out: List[str] = []
    pos = 0
    end = len(symbols)
    while pos < end:
        if pos + 1 < end and symbols[pos] == lhs and symbols[pos + 1] == rhs:
            out.append(lhs + rhs)
            pos += 2
            continue
        out.append(symbols[pos])
        pos += 1
    return tuple(out)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def train_bpe_from_words(
    word_freq: Counter[str],
    vocab_size: int,
) -> Tuple[List[str], List[Tuple[str, str]]]:
    """Learn letter-only BPE merges from word frequencies.

    The most frequent adjacent pair is merged repeatedly until the number of
    alphabetic pieces reaches the target, no pairs remain, or the best pair
    occurs fewer than two times.

    Args:
        word_freq: Letter-only words and their counts.
        vocab_size: Requested total vocabulary size.

    Returns:
        ``(vocab, merges)``.  ``vocab`` is special tokens, space, newline,
        digits, punctuation and the sorted alphabetic pieces, cut to
        ``vocab_size`` entries.  ``merges`` lists the merged pairs in the
        order they were learned.
    """
    # Slots taken by everything that is not an alphabetic piece
    # (the "+ 2" is the space and newline tokens).
    reserved_slots = len(SPECIAL_TOKENS) + 2 + len(ASCII_DIGITS) + len(ALLOWED_PUNCT)
    alpha_target = max(vocab_size - reserved_slots, len(ASCII_LETTERS))

    corpus: Dict[Tuple[str, ...], int] = {
        word_to_symbols(word): count for word, count in word_freq.items()
    }
    alpha_pieces = set(ASCII_LETTERS)
    merges: List[Tuple[str, str]] = []

    while len(alpha_pieces) < alpha_target:
        pair_counts = compute_pair_counts_from_vocab(corpus)
        if not pair_counts:
            break

        ((best_pair, best_count),) = pair_counts.most_common(1)
        if best_count < 2:
            break

        merges.append(best_pair)

        # Re-segment every word with the new merge; words that collapse to
        # the same symbol tuple have their counts added together.
        resegmented: Dict[Tuple[str, ...], int] = {}
        for pieces, count in corpus.items():
            key = merge_word_symbols(pieces, best_pair)
            resegmented[key] = resegmented.get(key, 0) + count
        corpus = resegmented

        alpha_pieces.add(best_pair[0] + best_pair[1])

        if len(alpha_pieces) % 100 == 0:
            print(f"[bpe] learned alpha pieces: {len(alpha_pieces)}", flush=True)

    final_vocab = [
        *SPECIAL_TOKENS,
        SPACE_TOKEN,
        NEWLINE_TOKEN,
        *ASCII_DIGITS,
        *ALLOWED_PUNCT,
        *sorted(alpha_pieces),
    ]
    return final_vocab[:vocab_size], merges
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class GreedyTokenizer:
    """Tokenizer that applies learned merges, lowest rank first, to letter
    runs and maps every other pre-token to a single id."""

    def __init__(
        self,
        vocab: Sequence[str],
        merges: Sequence[Tuple[str, str]],
        word_cache_size: int = 200000,
    ) -> None:
        """Build the lookup tables.

        Args:
            vocab: Tokens in id order; must contain ``"<unk>"``.
            merges: Merge pairs in the order they were learned; the position
                in this sequence is the pair's rank.
            word_cache_size: Maximum number of memoized words; values of
                zero or less disable the cache.
        """
        self.vocab = list(vocab)
        self.merges = list(merges)
        self.token_to_id = {tok: idx for idx, tok in enumerate(self.vocab)}
        self.unk_id = self.token_to_id["<unk>"]
        self.alpha_token_ids = {
            idx for tok, idx in self.token_to_id.items() if tok.isalpha()
        }
        self.merge_ranks: Dict[Tuple[str, str], int] = dict(
            zip(self.merges, range(len(self.merges)))
        )
        self.word_cache_size = max(int(word_cache_size), 0)
        self._word_cache: Dict[str, Tuple[int, ...]] = {}

    def _get_pairs(self, symbols: Tuple[str, ...]) -> set[Tuple[str, str]]:
        """Return the set of adjacent symbol pairs in ``symbols``."""
        return {(lhs, rhs) for lhs, rhs in zip(symbols, symbols[1:])}

    def _merge_once(self, symbols: Tuple[str, ...], pair: Tuple[str, str]) -> Tuple[str, ...]:
        """Replace each left-to-right, non-overlapping occurrence of ``pair``
        in ``symbols`` by the joined symbol."""
        lhs, rhs = pair
        joined = lhs + rhs
        out: List[str] = []
        pos = 0
        end = len(symbols)
        while pos < end:
            if pos + 1 < end and symbols[pos] == lhs and symbols[pos + 1] == rhs:
                out.append(joined)
                pos += 2
                continue
            out.append(symbols[pos])
            pos += 1
        return tuple(out)

    def tokenize_alpha_run(self, span: str) -> List[int]:
        """Tokenize one run of letters into token ids.

        Starting from single characters, the adjacent pair with the lowest
        merge rank is merged until no ranked pair is left.  Pieces missing
        from the vocabulary map to the ``<unk>`` id.  Results are memoized
        per word; when the cache is full it is emptied before the new entry
        is stored.
        """
        if not span:
            return []

        hit = self._word_cache.get(span)
        if hit is not None:
            return list(hit)

        ranks = self.merge_ranks
        pieces: Tuple[str, ...] = tuple(span)

        # A single remaining piece has no pairs, so nothing is left to merge.
        while len(pieces) > 1:
            candidates = [pair for pair in self._get_pairs(pieces) if pair in ranks]
            if not candidates:
                break
            pieces = self._merge_once(pieces, min(candidates, key=ranks.__getitem__))

        ids = tuple(self.token_to_id.get(piece, self.unk_id) for piece in pieces)

        if self.word_cache_size > 0:
            if len(self._word_cache) >= self.word_cache_size:
                self._word_cache.clear()
            self._word_cache[span] = ids

        return list(ids)

    def is_alpha_id(self, token_id: int) -> bool:
        """Return True if ``token_id`` belongs to an all-letter token."""
        return token_id in self.alpha_token_ids

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into token ids.

        A newline maps to the newline token, a run of spaces to one space
        token, a run of letters goes through ``tokenize_alpha_run``, and any
        other piece (one digit or punctuation mark) is looked up directly,
        falling back to ``<unk>``.
        """
        lookup = self.token_to_id
        out: List[int] = []
        for piece in iter_pre_tokens(text):
            if piece == "\n":
                out.append(lookup[NEWLINE_TOKEN])
            elif piece.isspace():
                out.append(lookup[SPACE_TOKEN])
            elif piece.isalpha():
                out.extend(self.tokenize_alpha_run(piece))
            else:
                out.append(lookup.get(piece, self.unk_id))
        return out
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _safe_zscore(values: np.ndarray) -> np.ndarray:
    """Standardize ``values`` as float32: subtract the mean and divide by the
    standard deviation.

    Returns an all-zero float32 array of the same shape when the standard
    deviation is below 1e-8.
    """
    data = values.astype(np.float32, copy=False)
    center = float(np.mean(data))
    spread = float(np.std(data))
    if spread < 1e-8:
        # (Almost) constant input: avoid dividing by ~0.
        return np.zeros_like(data, dtype=np.float32)
    centered = data - center
    return centered / spread
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def build_priors_from_counts(
    counts: np.ndarray,
    clip_value: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Turn token counts into probabilities, surprisal and a [0, 1] prior.

    Probabilities use add-one smoothing.  Surprisal is ``-log(prob)``.  The
    prior is the z-scored surprisal, clipped to ``[-clip_value, clip_value]``
    and rescaled linearly to ``[0, 1]``.

    Args:
        counts: Per-token occurrence counts.
        clip_value: Positive z-score clip bound.

    Returns:
        ``(probs, surprisal, prior)``, each a float32 array shaped like
        ``counts``.

    Raises:
        ValueError: if ``clip_value`` is not positive, or if the counts sum
            to zero or less.
    """
    # The rescaling below divides by 2 * clip_value; a zero bound would turn
    # every prior into NaN (0 / 0) and a negative one inverts the clip range.
    if not clip_value > 0:
        raise ValueError("clip_value must be positive.")

    counts = counts.astype(np.float64, copy=False)
    total = counts.sum()
    if total <= 0:
        raise ValueError("Counts are empty; cannot build priors.")

    probs = (counts + 1.0) / (total + counts.size)
    surprisal = -np.log(probs)
    z = _safe_zscore(surprisal.astype(np.float32))
    z = np.clip(z, -clip_value, clip_value)
    prior = (z + clip_value) / (2.0 * clip_value)
    return probs.astype(np.float32), surprisal.astype(np.float32), prior.astype(np.float32)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def build_pair_priors(
    pair_counts: np.ndarray,
    min_pair_count: int,
    clip_value: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Turn a pair-count matrix into row-conditional probabilities,
    surprisal, a [0, 1] prior and a validity mask.

    Probabilities are add-one smoothed within each row.  The prior is the
    surprisal z-scored over the whole matrix, clipped to
    ``[-clip_value, clip_value]`` and rescaled to ``[0, 1]``; cells whose
    count is below ``min_pair_count`` receive the neutral value 0.5.

    Args:
        pair_counts: Square count matrix indexed ``[previous, current]``.
        min_pair_count: Minimum count for a cell to be marked valid.
        clip_value: Positive z-score clip bound.

    Returns:
        ``(probs, surprisal, prior, valid_mask)``; the first three are
        float32, ``valid_mask`` is uint8 (1 = count >= ``min_pair_count``).

    Raises:
        ValueError: if ``clip_value`` is not positive.
    """
    # The rescaling below divides by 2 * clip_value.
    if not clip_value > 0:
        raise ValueError("clip_value must be positive.")

    # Accumulate in float64 explicitly.  Without a dtype, integer sums use
    # the platform integer, which can be 32-bit (Windows with NumPy 1.x) and
    # would wrap around on large rows of the int32 count matrix.
    row_sums = pair_counts.sum(axis=1, keepdims=True, dtype=np.float64)
    vocab_size = pair_counts.shape[0]

    probs = (pair_counts.astype(np.float64) + 1.0) / (row_sums + vocab_size)
    surprisal = -np.log(probs)
    valid_mask = (pair_counts >= min_pair_count).astype(np.uint8)

    flat_surprisal = surprisal.astype(np.float32).reshape(-1)
    z = _safe_zscore(flat_surprisal).reshape(pair_counts.shape)
    z = np.clip(z, -clip_value, clip_value)

    prior = (z + clip_value) / (2.0 * clip_value)
    prior = np.where(valid_mask == 1, prior, 0.5)

    return (
        probs.astype(np.float32),
        surprisal.astype(np.float32),
        prior.astype(np.float32),
        valid_mask,
    )
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def second_pass_stats(
    cfg: BundleConfig,
    tokenizer: GreedyTokenizer,
) -> Tuple[np.ndarray, np.ndarray]:
    """Tokenize the normalized stream and count tokens and alpha-alpha pairs.

    A pair ``(previous, current)`` of adjacent tokens is counted only when
    ``tokenizer.is_alpha_id`` is true for both.  Reading stops after the
    first document that brings the running token total to
    ``cfg.final_token_budget`` or beyond.

    Returns:
        ``(token_counts, pair_counts)``: an int64 vector of length V and an
        int32 ``V x V`` matrix indexed ``[previous, current]``, where V is
        ``len(tokenizer.vocab)``.
    """
    vocab_size = len(tokenizer.vocab)
    token_counts = np.zeros(vocab_size, dtype=np.int64)
    pair_counts = np.zeros((vocab_size, vocab_size), dtype=np.int32)

    # Boolean lookup table built once, so the per-document work below can be
    # done with array operations instead of a Python loop over every token.
    is_alpha = np.fromiter(
        (tokenizer.is_alpha_id(token_id) for token_id in range(vocab_size)),
        dtype=bool,
        count=vocab_size,
    )

    token_budget = 0
    for text in iter_normalized_text(cfg):
        ids = np.asarray(tokenizer.encode(text), dtype=np.intp)
        if ids.size == 0:
            continue

        token_counts += np.bincount(ids, minlength=vocab_size)

        prev_ids = ids[:-1]
        cur_ids = ids[1:]
        both_alpha = is_alpha[prev_ids] & is_alpha[cur_ids]
        # np.add.at is unbuffered: a pair occurring several times in one
        # document is incremented once per occurrence.
        np.add.at(pair_counts, (prev_ids[both_alpha], cur_ids[both_alpha]), 1)

        token_budget += int(ids.size)
        if token_budget >= cfg.final_token_budget:
            break

    return token_counts, pair_counts
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def save_bundle(
    cfg: BundleConfig,
    vocab: Sequence[str],
    merges: Sequence[Tuple[str, str]],
    token_counts: np.ndarray,
    token_probs: np.ndarray,
    token_surprisal: np.ndarray,
    token_prior: np.ndarray,
    pair_counts: np.ndarray,
    pair_probs: np.ndarray,
    pair_surprisal: np.ndarray,
    pair_prior: np.ndarray,
    pair_valid_mask: np.ndarray,
) -> None:
    """Write the bundle into ``cfg.output`` (created if missing).

    Files written, in this order:
      - ``vocab.json``: token/id maps, special tokens and the merge list;
      - ``manifest.json``: build settings and policy flags;
      - ``token_stats.npz``: ``count``, ``prob``, ``surprisal``,
        ``importance_prior``;
      - ``pair_stats.npz``: ``pair_count``, ``pair_prob``, ``pair_surprisal``,
        ``pair_importance_prior``, ``pair_valid_mask``.
    """
    out_dir = cfg.output
    out_dir.mkdir(parents=True, exist_ok=True)

    def write_json(filename: str, payload: dict) -> None:
        # ASCII-escaped, 2-space indented JSON stored as UTF-8.
        (out_dir / filename).write_text(
            json.dumps(payload, indent=2, ensure_ascii=True),
            encoding="utf-8",
        )

    write_json(
        "vocab.json",
        {
            "token_to_id": {token: idx for idx, token in enumerate(vocab)},
            "id_to_token": {str(idx): token for idx, token in enumerate(vocab)},
            "special_tokens": SPECIAL_TOKENS,
            "space_token": SPACE_TOKEN,
            "newline_token": NEWLINE_TOKEN,
            "merges": [[left, right] for left, right in merges],
        },
    )

    normalization_flags = {
        "casefold_uppercase_to_lowercase": True,
        "latin_accent_folding": True,
        "bracket_like_marks_folded_to_parentheses": True,
        "non_latin_scripts_to_space": True,
        "emoji_removed": True,
        "unsupported_symbols_to_space": True,
        "collapse_spaces": True,
        "trim_long_newlines": True,
        "runtime_tokenization": "ranked_bpe_letters_only_with_word_cache",
    }
    shape_policy = {
        "learned_multi_character_tokens": "letters_only",
        "digits": "atomic_single_character",
        "punctuation": "atomic_single_character",
        "spaces": "atomic_single_character",
        "newlines": "atomic_single_character",
    }
    pair_scope = {
        "counted_pairs": "alpha_to_alpha_only",
        "non_alpha_pairs": "neutral_default_prior",
    }
    write_json(
        "manifest.json",
        {
            "bundle_version": 9,
            "description": "english-first ascii-limited vocab bundle with letter-only learned tokens, atomic digits and punctuation, latin accent folding, bracket folding, faster ranked-bpe runtime tokenization, and alpha-only pair priors",
            "dataset": cfg.dataset,
            "config": cfg.config,
            "split": cfg.split,
            "vocab_size": len(vocab),
            "requested_vocab_size": cfg.vocab_size,
            "special_tokens": SPECIAL_TOKENS,
            "allowed_ascii_letters": ASCII_LETTERS,
            "allowed_ascii_digits": ASCII_DIGITS,
            "allowed_ascii_punctuation": ALLOWED_PUNCT,
            "normalization": normalization_flags,
            "token_shape_policy": shape_policy,
            "pair_prior_scope": pair_scope,
            "bpe_train_chars": cfg.bpe_train_chars,
            "final_token_budget": cfg.final_token_budget,
            "min_pair_count": cfg.min_pair_count,
            "token_prior_clip": cfg.token_prior_clip,
            "pair_prior_clip": cfg.pair_prior_clip,
            "word_cache_size": cfg.word_cache_size,
        },
    )

    np.savez_compressed(
        out_dir / "token_stats.npz",
        count=token_counts,
        prob=token_probs,
        surprisal=token_surprisal,
        importance_prior=token_prior,
    )

    np.savez_compressed(
        out_dir / "pair_stats.npz",
        pair_count=pair_counts,
        pair_prob=pair_probs,
        pair_surprisal=pair_surprisal,
        pair_importance_prior=pair_prior,
        pair_valid_mask=pair_valid_mask,
    )
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def build_bundle(cfg: BundleConfig) -> None:
    """Run the full build: count words, train the letter-only BPE vocab,
    collect token and pair statistics in a second streaming pass, derive the
    priors and write the bundle.

    Raises:
        SystemExit: if the first pass finds no letter-words or the second
            pass produces no tokens.
    """
    print("[1/4] Counting normalized words for letter-only BPE training...", flush=True)
    word_freq = count_words_for_bpe(cfg)
    if not word_freq:
        raise SystemExit("No usable normalized text found in the stream.")
    print(f"[1/4] Unique normalized letter-words: {len(word_freq):,}", flush=True)

    print("[2/4] Training letter-only BPE-style subword vocab...", flush=True)
    vocab, merges = train_bpe_from_words(word_freq, cfg.vocab_size)
    print(f"[2/4] Final vocab size: {len(vocab)}", flush=True)

    print("[3/4] Streaming second pass for token and pair stats...", flush=True)
    tokenizer = GreedyTokenizer(vocab, merges, word_cache_size=cfg.word_cache_size)
    token_counts, pair_counts = second_pass_stats(cfg, tokenizer)
    if token_counts.sum() <= 0:
        raise SystemExit("Second pass produced no tokens. Check dataset fields or normalization rules.")
    print(f"[3/4] Final token count: {int(token_counts.sum()):,}", flush=True)

    print("[4/4] Building priors and saving bundle...", flush=True)
    token_stats = build_priors_from_counts(token_counts, cfg.token_prior_clip)
    pair_stats = build_pair_priors(pair_counts, cfg.min_pair_count, cfg.pair_prior_clip)

    token_probs, token_surprisal, token_prior = token_stats
    pair_probs, pair_surprisal, pair_prior, pair_valid_mask = pair_stats

    save_bundle(
        cfg=cfg,
        vocab=vocab,
        merges=merges,
        token_counts=token_counts,
        token_probs=token_probs,
        token_surprisal=token_surprisal,
        token_prior=token_prior,
        pair_counts=pair_counts,
        pair_probs=pair_probs,
        pair_surprisal=pair_surprisal,
        pair_prior=pair_prior,
        pair_valid_mask=pair_valid_mask,
    )
    print(f"Done. Bundle written to: {cfg.output}", flush=True)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def parse_args(argv: Optional[Sequence[str]] = None) -> BundleConfig:
    """Turn command-line options into a ``BundleConfig``.

    Args:
        argv: Argument strings to parse. ``None`` defers to ``argparse``,
            which then reads the process command line.

    Returns:
        A ``BundleConfig`` whose fields mirror the parsed options; ``output``
        is converted to a ``Path``.
    """
    parser = argparse.ArgumentParser(
        description="Build an ASCII-limited English vocab + prior bundle from a streamed dataset."
    )

    # (flag, add_argument keyword arguments), registered in the order they
    # should appear in ``--help``.
    option_table = [
        ("--output", dict(required=True, help="Output directory for the bundle.")),
        ("--dataset", dict(default=DEFAULT_DATASET, help=f"Hugging Face dataset name. Default: {DEFAULT_DATASET}")),
        ("--config", dict(default=DEFAULT_CONFIG, help=f"Hugging Face dataset config. Default: {DEFAULT_CONFIG}")),
        ("--split", dict(default=DEFAULT_SPLIT, help=f"Dataset split. Default: {DEFAULT_SPLIT}")),
        ("--vocab-size", dict(type=int, default=2000, help="Final vocab size including special tokens.")),
        ("--bpe-train-chars", dict(type=int, default=100_000_000, help="Normalized character budget for vocab learning.")),
        ("--final-token-budget", dict(type=int, default=100_000_000, help="Final tokenizer token budget for priors.")),
        ("--max-examples", dict(type=int, default=None, help="Optional cap on streamed examples for testing.")),
        ("--min-pair-count", dict(type=int, default=5, help="Minimum pair count to trust a pair prior.")),
        ("--token-prior-clip", dict(type=float, default=3.0, help="Clip for token prior z-scores.")),
        ("--pair-prior-clip", dict(type=float, default=3.0, help="Clip for pair prior z-scores.")),
        ("--word-cache-size", dict(type=int, default=200000, help="Max cached normalized words for faster runtime tokenization.")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    ns = parser.parse_args(argv)

    return BundleConfig(
        output=Path(ns.output),
        dataset=ns.dataset,
        config=ns.config,
        split=ns.split,
        vocab_size=ns.vocab_size,
        bpe_train_chars=ns.bpe_train_chars,
        final_token_budget=ns.final_token_budget,
        max_examples=ns.max_examples,
        min_pair_count=ns.min_pair_count,
        token_prior_clip=ns.token_prior_clip,
        pair_prior_clip=ns.pair_prior_clip,
        word_cache_size=ns.word_cache_size,
    )
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Command-line entry point: parse options, build the bundle, report success.

    Args:
        argv: Argument strings forwarded to ``parse_args``.

    Returns:
        ``0`` once ``build_bundle`` has returned.
    """
    build_bundle(parse_args(argv))
    return 0
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
# Script entry: use main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|
final_infer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cbed84c723a50e97b426806ac5d070e7820d46c4634c14e41d20a8d2bada02ce
|
| 3 |
+
size 15957156
|
pgsm_sparse_rope_lm.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
pgsm_sparse_rope_lm.py
|
| 4 |
+
|
| 5 |
+
Reusable model module for the custom LLM architecture developed from the
|
| 6 |
+
long-memory experiments:
|
| 7 |
+
|
| 8 |
+
Parallel Geometric State Model (PGSM)
|
| 9 |
+
+ optional query-only sparse RoPE retrieval head
|
| 10 |
+
|
| 11 |
+
Core design:
|
| 12 |
+
- Fast attention-free local backbone.
|
| 13 |
+
- Depthwise causal convolution for local state propagation.
|
| 14 |
+
- Gated state mixing.
|
| 15 |
+
- Gated MLP blocks.
|
| 16 |
+
- Optional sparse retrieval only at selected query positions.
|
| 17 |
+
- Retrieval dimension is configurable; experiments showed retrieval_dim=512
|
| 18 |
+
was the first strong setting at block_size=1024 / distance=768.
|
| 19 |
+
|
| 20 |
+
This file is intentionally model-only. It does not include training loops,
|
| 21 |
+
datasets, benchmark code, or CLI handling. Import it from your training module.
|
| 22 |
+
|
| 23 |
+
Example:
|
| 24 |
+
|
| 25 |
+
from pgsm_sparse_rope_lm import PGSMConfig, PGSMSparseRoPELM
|
| 26 |
+
|
| 27 |
+
cfg = PGSMConfig.small(vocab_size=256, block_size=1024)
|
| 28 |
+
model = PGSMSparseRoPELM(cfg)
|
| 29 |
+
|
| 30 |
+
logits, loss = model(input_ids, labels)
|
| 31 |
+
|
| 32 |
+
For retrieval tasks where only specific answer/query positions should do sparse
|
| 33 |
+
long-range retrieval:
|
| 34 |
+
|
| 35 |
+
logits, loss = model(input_ids, labels, retrieval_positions=answer_pos)
|
| 36 |
+
|
| 37 |
+
For normal causal LM pretraining, you can disable sparse retrieval or use
|
| 38 |
+
automatic query-token detection if your data marks query positions.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
from __future__ import annotations
|
| 42 |
+
|
| 43 |
+
import math
|
| 44 |
+
from dataclasses import asdict, dataclass, replace
|
| 45 |
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
import torch.nn as nn
|
| 49 |
+
import torch.nn.functional as F
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# -----------------------------
|
| 53 |
+
# Configuration
|
| 54 |
+
# -----------------------------
|
| 55 |
+
|
| 56 |
+
@dataclass(frozen=True)
class PGSMConfig:
    """Immutable hyperparameter bundle for the PGSM sparse-RoPE language model.

    The ``tiny`` / ``small`` / ``medium`` / ``large`` classmethods return preset
    configurations; every preset accepts keyword overrides for any field.
    """

    # --- Vocabulary / sequence ---
    vocab_size: int = 256
    block_size: int = 1024

    # --- Backbone ---
    dim: int = 192
    layers: int = 3
    hidden: int = 384
    kernel_size: int = 17
    dropout: float = 0.0

    # --- Sparse retrieval ---
    use_sparse_retrieval: bool = True
    retrieval_dim: int = 512
    retrieval_heads: int = 4
    retrieval_dropout: float = 0.0

    # --- Retrieval positioning ---
    # Precedence: an explicit ``retrieval_positions`` argument to forward()
    # comes first; next, positions matching ``query_token_id`` when automatic
    # detection is enabled; finally the last token if that fallback is
    # enabled. With none of these, retrieval is skipped.
    query_token_id: Optional[int] = None
    auto_retrieve_on_query_token: bool = False
    retrieve_at_last_token_if_unspecified: bool = False

    # --- Output / loss behavior ---
    tie_weights: bool = True
    use_post_retrieval_block: bool = True
    ignore_index: int = -100

    # --- Init ---
    init_std: float = 0.02

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``dict`` of field name to value."""
        return asdict(self)

    @classmethod
    def _from_preset(cls, preset: Dict[str, Any], overrides: Dict[str, Any]) -> "PGSMConfig":
        """Instantiate ``preset`` and then apply caller ``overrides`` on top."""
        return replace(cls(**preset), **overrides)

    @classmethod
    def tiny(
        cls,
        vocab_size: int = 256,
        block_size: int = 512,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Smallest preset: dim 128, 3 layers, retrieval_dim 256."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=128,
            layers=3,
            hidden=256,
            kernel_size=17,
            retrieval_dim=256,
            retrieval_heads=4,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def small(
        cls,
        vocab_size: int = 256,
        block_size: int = 1024,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Preset closest to the successful experiment (retrieval_dim=512)."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=192,
            layers=3,
            hidden=384,
            kernel_size=17,
            retrieval_dim=512,
            retrieval_heads=4,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def medium(
        cls,
        vocab_size: int,
        block_size: int = 2048,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Mid-size preset: dim 384, 6 layers, retrieval_dim 768."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=384,
            layers=6,
            hidden=1024,
            kernel_size=21,
            retrieval_dim=768,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def large(
        cls,
        vocab_size: int,
        block_size: int = 4096,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Largest preset: dim 768, 12 layers, retrieval_dim 1024."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=768,
            layers=12,
            hidden=2048,
            kernel_size=25,
            retrieval_dim=1024,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return cls._from_preset(preset, overrides)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# -----------------------------
|
| 177 |
+
# Utility functions
|
| 178 |
+
# -----------------------------
|
| 179 |
+
|
| 180 |
+
def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: Module whose ``parameters()`` are counted.
        trainable_only: When true, parameters with ``requires_grad=False``
            are left out of the total.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad or not trainable_only:
            total += param.numel()
    return total
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def init_pgsm_weights(module: nn.Module, std: float = 0.02) -> None:
    """Initialise one ``nn.Linear`` or ``nn.Embedding`` in place.

    Weights are drawn from a zero-mean normal with standard deviation ``std``;
    a ``Linear`` bias, when present, is zeroed. Any other module type is left
    untouched, which makes the function safe to pass to ``nn.Module.apply``.
    """
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    nn.init.normal_(module.weight, mean=0.0, std=std)

    has_bias = isinstance(module, nn.Linear) and module.bias is not None
    if has_bias:
        nn.init.zeros_(module.bias)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent channel pair ``(a, b)`` of the last dim to ``(-b, a)``.

    The last dimension is read as interleaved pairs ``(x[2i], x[2i+1])`` and
    must therefore have even length; the output has the same shape as ``x``.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(-2)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _positions_from_query_tokens(input_ids: torch.Tensor, query_token_id: int) -> torch.Tensor:
    """
    Return one retrieval position per batch row.

    If multiple query tokens exist, the last one is used.
    If none exist in a row, the final token is used.

    Args:
        input_ids: Token ids of shape ``[batch, steps]``.
        query_token_id: Token id that marks a query position.

    Returns:
        ``LongTensor`` of shape ``[batch]`` on the same device as ``input_ids``.
    """
    batch, steps = input_ids.shape
    device = input_ids.device
    fallback = torch.full((batch,), steps - 1, dtype=torch.long, device=device)
    if steps == 0:
        # Nothing to scan; a row-wise max over an empty dimension would raise.
        return fallback

    matches = input_ids.eq(int(query_token_id))
    index = torch.arange(steps, dtype=torch.long, device=device)

    # Replace non-matching slots by -1 so the row-wise maximum is the index of
    # the last match, or -1 when the row has none. This is a single batched
    # computation instead of one nonzero() call (and host sync) per row.
    last_match = torch.where(matches, index, index.new_tensor(-1)).max(dim=1).values
    return torch.where(last_match >= 0, last_match, fallback)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# -----------------------------
|
| 219 |
+
# Backbone blocks
|
| 220 |
+
# -----------------------------
|
| 221 |
+
|
| 222 |
+
class CausalDepthwiseConv(nn.Module):
    """
    Per-channel (depthwise) 1-D convolution that only looks backward in time.

    Every channel has its own filter (``groups=dim``). The whole sequence is
    processed in one convolution call, with no attention matrix involved.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        history = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=history,
            groups=dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[B,T,D]`` to ``[B,T,D]``; output step ``t`` depends on inputs ``<= t``."""
        seq_len = x.size(1)
        channels_first = x.transpose(1, 2)
        padded_out = self.conv(channels_first)
        # Conv1d pads both ends by kernel_size - 1. Keeping only the first T
        # outputs discards every window that would reach past the current step.
        causal_out = padded_out[:, :, :seq_len]
        return causal_out.transpose(1, 2)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ParallelGeometricBlock(nn.Module):
    """
    Attention-free residual block with two pre-norm branches.

    1. State branch: LayerNorm -> causal depthwise conv -> ``tanh`` projection,
       multiplied by a sigmoid gate computed from the normalised input.
    2. Feed-forward branch: LayerNorm -> gated MLP (``silu(gate) * value``).

    Each branch output passes through dropout and is added back to the input.
    """

    def __init__(self, dim: int, hidden: int, kernel_size: int, dropout: float = 0.0):
        super().__init__()
        # State-mixing branch.
        self.norm_state = nn.LayerNorm(dim)
        self.local_state = CausalDepthwiseConv(dim, kernel_size)
        self.state_mix = nn.Linear(dim, dim)
        self.state_gate = nn.Linear(dim, dim)
        self.drop_state = nn.Dropout(dropout)

        # Gated feed-forward branch; ff_in produces value and gate in one matmul.
        self.norm_ff = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, hidden * 2)
        self.ff_out = nn.Linear(hidden, dim)
        self.drop_ff = nn.Dropout(dropout)

    def _state_update(self, x: torch.Tensor) -> torch.Tensor:
        """Gated local-state residual for ``x`` (not yet added to ``x``)."""
        normed = self.norm_state(x)
        candidate = torch.tanh(self.state_mix(self.local_state(normed)))
        keep = torch.sigmoid(self.state_gate(normed))
        return self.drop_state(candidate * keep)

    def _ff_update(self, x: torch.Tensor) -> torch.Tensor:
        """Gated-MLP residual for ``x`` (not yet added to ``x``)."""
        value, gate = self.ff_in(self.norm_ff(x)).chunk(2, dim=-1)
        return self.drop_ff(self.ff_out(F.silu(gate) * value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self._state_update(x)
        return x + self._ff_update(x)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# -----------------------------
|
| 285 |
+
# Sparse RoPE retrieval
|
| 286 |
+
# -----------------------------
|
| 287 |
+
|
| 288 |
+
class RotaryCache(nn.Module):
    """
    Precomputed rotary-embedding cos/sin tables.

    ``apply_sequence`` rotates a full ``[B,H,T,D]`` tensor using positions
    ``0..T-1``; ``apply_query_positions`` rotates a ``[B,H,1,D]`` query using
    one explicit position per batch row.

    NOTE(review): the tables repeat the frequency vector as two concatenated
    halves, whereas ``rotate_half`` in this module pairs adjacent channels
    ``(2i, 2i+1)``. The two channels of a pair therefore generally receive
    different angles. Changing either side alters model numerics, so any
    weights trained with this module depend on the current layout. Confirm
    before "fixing".
    """

    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.head_dim = int(head_dim)
        self.max_seq_len = int(max_seq_len)

        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (base ** exponents)
        timesteps = torch.arange(max_seq_len).float()
        angles = torch.einsum("i,j->ij", timesteps, inv_freq)  # [T, D/2]

        # Tile the half-width angle table twice along the channel axis -> [T, D].
        table = angles.repeat(1, 2)
        # Non-persistent: rebuilt from the constructor arguments, never saved
        # in the state_dict.
        self.register_buffer("cos", table.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin", table.sin()[None, None, :, :], persistent=False)

    def apply_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` of shape ``[B,H,T,D]`` using positions ``0..T-1``.

        Raises:
            ValueError: if ``T`` exceeds the cached ``max_seq_len``.
        """
        steps = x.size(-2)
        if steps > self.max_seq_len:
            raise ValueError(
                f"Sequence length {steps} exceeds RoPE cache length {self.max_seq_len}. "
                "Increase config.block_size."
            )
        cos_t = self.cos[..., :steps, :].to(device=x.device, dtype=x.dtype)
        sin_t = self.sin[..., :steps, :].to(device=x.device, dtype=x.dtype)
        return (x * cos_t) + (rotate_half(x) * sin_t)

    def apply_query_positions(self, q: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate ``q`` of shape ``[B,H,1,D]`` by the per-row ``positions`` (``[B]``)."""
        cos_rows = self.cos[0, 0][positions].to(device=q.device, dtype=q.dtype)
        sin_rows = self.sin[0, 0][positions].to(device=q.device, dtype=q.dtype)
        # [B, D] -> [B, 1, 1, D] so the tables broadcast over heads and the
        # singleton time axis.
        cos_q = cos_rows[:, None, None, :]
        sin_q = sin_rows[:, None, None, :]
        return (q * cos_q) + (rotate_half(q) * sin_q)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class QueryOnlyRoPERetriever(nn.Module):
    """
    Sparse retrieval applied only to selected positions.

    For each batch row, one retrieval position attends backward over prior token
    states using RoPE Q/K. This is O(T) per retrieved position, not O(T^2).

    Only the hidden state at the retrieval position is modified (gated residual
    add of the retrieved read); all other positions pass through unchanged. A
    retrieval position of 0 has nothing before it, so that row is returned
    unchanged.
    """

    def __init__(
        self,
        dim: int,
        retrieval_dim: int,
        retrieval_heads: int,
        block_size: int,
        dropout: float = 0.0,
    ):
        """
        Args:
            dim: Width of the backbone hidden states.
            retrieval_dim: Total width of the Q/K/V projections (all heads).
            retrieval_heads: Number of heads; must divide ``retrieval_dim``
                into an even per-head width.
            block_size: Maximum sequence length; sizes the RoPE cache.
            dropout: Dropout probability on the attention weights.

        Raises:
            ValueError: if ``retrieval_dim`` is not divisible by
                ``retrieval_heads`` or the per-head width is odd.
        """
        super().__init__()
        if retrieval_dim % retrieval_heads != 0:
            raise ValueError("retrieval_dim must be divisible by retrieval_heads")
        self.dim = int(dim)
        self.retrieval_dim = int(retrieval_dim)
        self.retrieval_heads = int(retrieval_heads)
        self.head_dim = retrieval_dim // retrieval_heads
        if self.head_dim % 2 != 0:
            raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")

        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, retrieval_dim)
        self.k = nn.Linear(dim, retrieval_dim)
        self.v = nn.Linear(dim, retrieval_dim)
        self.out = nn.Linear(retrieval_dim, dim)
        self.gate = nn.Linear(dim * 2, dim)
        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryCache(self.head_dim, max_seq_len=block_size + 8)

    def forward(self, x: torch.Tensor, retrieval_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Hidden states ``[B,T,D]``.
            retrieval_positions: ``LongTensor [B]`` with one index in
                ``[0, T)`` per batch row.

        Returns:
            A copy of ``x`` in which only the retrieval position of each row
            has been updated.
        """
        batch, steps, _ = x.shape
        device = x.device
        bidx = torch.arange(batch, device=device)

        h = self.norm(x)

        # Keys/values for every position: [B,T,R] -> [B,H,T,head_dim].
        k = self.k(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        v = self.v(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        k = self.rope.apply_sequence(k)

        # A single query per row, taken at its retrieval position.
        qh = h[bidx, retrieval_positions]
        q = self.q(qh).view(batch, self.retrieval_heads, 1, self.head_dim)
        q = self.rope.apply_query_positions(q, retrieval_positions)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Strictly backward. The retrieval position cannot read itself.
        pos = torch.arange(steps, device=device)[None, None, None, :]
        causal_mask = pos < retrieval_positions[:, None, None, None]
        has_context = causal_mask.any(dim=-1, keepdim=True)  # [B,1,1,1]

        scores = scores.masked_fill(~causal_mask, float("-inf"))
        # A row whose retrieval position is 0 has every key masked; softmax of
        # an all -inf row is NaN. Give such rows finite scores here and zero
        # their read below, so they neither produce NaN nor leak later tokens.
        scores = scores.masked_fill(~has_context, 0.0)

        att = F.softmax(scores, dim=-1)
        att = self.dropout(att)

        read = (att @ v).transpose(1, 2).contiguous().view(batch, self.retrieval_dim)
        read = self.out(read)
        read = read.masked_fill(~has_context.view(batch, 1), 0.0)

        old = x[bidx, retrieval_positions]
        gate = torch.sigmoid(self.gate(torch.cat([qh, read], dim=-1)))
        new = old + gate * read

        # Write into a clone so the caller's tensor is not modified in place.
        out = x.clone()
        out[bidx, retrieval_positions] = new
        return out
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# -----------------------------
|
| 404 |
+
# Main model
|
| 405 |
+
# -----------------------------
|
| 406 |
+
|
| 407 |
+
class PGSMSparseRoPELM(nn.Module):
    """
    Parallel Geometric State Model with optional query-only sparse RoPE retrieval.

    Forward API:
        logits, loss = model(input_ids, labels=None, retrieval_positions=None)

    input_ids:
        LongTensor [B,T]

    labels:
        LongTensor [B,T], optional.
        Compared position-by-position with the logits (no shifting is done
        here, so pass labels already aligned as next-token targets).
        Use config.ignore_index for ignored positions.

    retrieval_positions:
        Optional LongTensor [B].
        If supplied, sparse retrieval is applied exactly at these positions.
        If omitted, config controls whether to auto-detect query-token positions,
        use final token, or skip retrieval.
    """

    def __init__(self, config: PGSMConfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [
                ParallelGeometricBlock(
                    dim=config.dim,
                    hidden=config.hidden,
                    kernel_size=config.kernel_size,
                    dropout=config.dropout,
                )
                for _ in range(config.layers)
            ]
        )

        self.retriever: Optional[QueryOnlyRoPERetriever]
        if config.use_sparse_retrieval:
            self.retriever = QueryOnlyRoPERetriever(
                dim=config.dim,
                retrieval_dim=config.retrieval_dim,
                retrieval_heads=config.retrieval_heads,
                block_size=config.block_size,
                dropout=config.retrieval_dropout,
            )
        else:
            self.retriever = None

        # Optional extra backbone block applied after retrieval.
        self.post_retrieval_block: Optional[ParallelGeometricBlock]
        if config.use_sparse_retrieval and config.use_post_retrieval_block:
            self.post_retrieval_block = ParallelGeometricBlock(
                dim=config.dim,
                hidden=config.hidden,
                kernel_size=config.kernel_size,
                dropout=config.dropout,
            )
        else:
            self.post_retrieval_block = None

        self.final_norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.apply(lambda module: init_pgsm_weights(module, std=config.init_std))

        # Tie after initialisation so the embedding and the output projection
        # share one parameter tensor.
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight

    @property
    def block_size(self) -> int:
        """Maximum input length, from the config."""
        return self.config.block_size

    @property
    def vocab_size(self) -> int:
        """Vocabulary size, from the config."""
        return self.config.vocab_size

    def num_parameters(self, trainable_only: bool = True) -> int:
        """Count this model's parameters (see ``count_parameters``)."""
        return count_parameters(self, trainable_only=trainable_only)

    def _resolve_retrieval_positions(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Pick the retrieval position per row, or ``None`` to skip retrieval.

        Precedence: explicit ``retrieval_positions`` -> last query token
        (when auto-detection is enabled and ``query_token_id`` is set) ->
        final token (when that fallback is enabled) -> ``None``.
        """
        if not self.config.use_sparse_retrieval:
            return None

        if retrieval_positions is not None:
            return retrieval_positions.to(device=input_ids.device, dtype=torch.long)

        if (
            self.config.auto_retrieve_on_query_token
            and self.config.query_token_id is not None
        ):
            return _positions_from_query_tokens(input_ids, self.config.query_token_id)

        if self.config.retrieve_at_last_token_if_unspecified:
            return torch.full(
                (input_ids.size(0),),
                input_ids.size(1) - 1,
                dtype=torch.long,
                device=input_ids.device,
            )

        return None

    def encode(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return final-norm hidden states ``[B,T,dim]`` for ``input_ids``.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than
                ``config.block_size``.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must have shape [batch, steps]")
        if input_ids.size(1) > self.config.block_size:
            raise ValueError(
                f"Input length {input_ids.size(1)} exceeds config.block_size={self.config.block_size}"
            )

        x = self.token_emb(input_ids)

        for block in self.blocks:
            x = block(x)

        positions = self._resolve_retrieval_positions(input_ids, retrieval_positions)
        if positions is not None:
            if self.retriever is None:
                raise RuntimeError("retriever is None but retrieval positions were resolved")
            x = self.retriever(x, positions)
            if self.post_retrieval_block is not None:
                x = self.post_retrieval_block(x)

        return self.final_norm(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits [B,T,vocab], loss)``; ``loss`` is ``None`` without labels."""
        x = self.encode(input_ids, retrieval_positions=retrieval_positions)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=self.config.ignore_index,
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Simple generation helper.

        For normal generation, sparse retrieval is not automatically applied unless
        config.retrieve_at_last_token_if_unspecified=True or query-token detection
        is enabled. Training modules can provide their own generation loop if they
        need custom retrieval-position behavior.

        Args:
            input_ids: Prompt token ids ``[B,T]``.
            max_new_tokens: Number of tokens to append.
            temperature: ``<= 0`` selects greedy argmax; otherwise logits are
                divided by it before sampling.
            top_k: Keep only the ``top_k`` highest logits when sampling.
                ``None`` or a non-positive value disables the filter.

        Returns:
            ``input_ids`` with ``max_new_tokens`` generated ids appended.

        The module runs in eval mode while generating and is put back into
        whatever mode it was in on entry.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit the model.
                idx_cond = input_ids[:, -self.config.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if temperature <= 0:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        # values[:, [-1]] is the k-th largest logit of each row.
                        logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_id], dim=1)
        finally:
            # Do not leave a training run silently stuck in eval mode.
            self.train(was_training)

        return input_ids
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
# -----------------------------
|
| 599 |
+
# Convenience factory
|
| 600 |
+
# -----------------------------
|
| 601 |
+
|
| 602 |
+
def build_pgsm_model(
    size: str = "small",
    vocab_size: int = 256,
    block_size: int = 1024,
    **overrides: Any,
) -> PGSMSparseRoPELM:
    """Build a model from a named ``PGSMConfig`` preset.

    Args:
        size: ``"tiny"``, ``"small"``, ``"medium"`` or ``"large"``; case and
            surrounding whitespace are ignored.
        vocab_size: Forwarded to the preset.
        block_size: Forwarded to the preset.
        **overrides: Extra ``PGSMConfig`` field overrides.

    Raises:
        ValueError: if ``size`` is not one of the known preset names.
    """
    size = size.lower().strip()
    if size not in ("tiny", "small", "medium", "large"):
        raise ValueError(f"Unknown model size: {size!r}. Use tiny, small, medium, or large.")

    # Each accepted name is also the name of the PGSMConfig classmethod that
    # builds that preset.
    make_config = getattr(PGSMConfig, size)
    cfg = make_config(vocab_size=vocab_size, block_size=block_size, **overrides)
    return PGSMSparseRoPELM(cfg)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# Names exported by ``from pgsm_sparse_rope_lm import *``.
__all__ = ["PGSMConfig", "PGSMSparseRoPELM", "build_pgsm_model", "count_parameters"]
|