nilmeruo commited on
Commit
01b6330
·
verified ·
1 Parent(s): 4444c8c

Upload 3 files

Browse files
build_ascii_vocab_bundle_v9.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ASCII-limited English-first vocab bundle builder for a tiny LLM.
4
+
5
+ Design goals
6
+ ------------
7
+ - English-only as much as reasonably possible
8
+ - Keep text intact instead of creating holes
9
+ - Fold uppercase -> lowercase
10
+ - Fold accented Latin letters -> plain ASCII where reasonable
11
+ - Drop emoji and non-Latin scripts
12
+ - Keep only a small practical punctuation set
13
+ - Learn multi-character tokens from LETTERS ONLY
14
+ - Keep digits and punctuation atomic as single-character tokens
15
+ - Stream from Hugging Face without local dataset files
16
+
17
+ Default source
18
+ --------------
19
+ Streams:
20
+ HuggingFaceFW/fineweb-edu
21
+ config=sample-10BT
22
+ split=train
23
+
24
+ Outputs
25
+ -------
26
+ Creates a bundle directory containing:
27
+ manifest.json
28
+ vocab.json
29
+ token_stats.npz
30
+ pair_stats.npz
31
+
32
+ What gets kept
33
+ --------------
34
+ - letters: a-z
35
+ - digits: 0-9
36
+ - whitespace: space + newline
37
+ - limited punctuation:
38
+ . , ! ? ' " - ( ) : ; @ # + % = / \ *
39
+
40
+ Tokenization policy
41
+ -------------------
42
+ - learned multi-character tokens: letters only
43
+ - digits remain atomic single-character tokens
44
+ - punctuation remains atomic single-character tokens
45
+
46
+ Examples
47
+ --------
48
+ PowerShell smoke test:
49
+ python F:\\TokenizerUltra\\build_ascii_vocab_bundle_v9.py --output "F:\\TokenizerUltra\\vocab_bundle_test" --max-examples 5000 --bpe-train-chars 2000000 --final-token-budget 2000000
50
+
51
+ PowerShell full build:
52
+ python F:\\TokenizerUltra\\build_ascii_vocab_bundle_v9.py --output "F:\\TokenizerUltra\\vocab_bundle" --bpe-train-chars 100000000 --final-token-budget 100000000
53
+
54
+ Dependencies
55
+ ------------
56
+ python -m pip install numpy datasets
57
+ """
58
+ from __future__ import annotations
59
+
60
+ import argparse
61
+ import json
62
+ import re
63
+ import unicodedata
64
+ from collections import Counter
65
+ from dataclasses import dataclass
66
+ from pathlib import Path
67
+ from typing import Dict, Iterator, List, Optional, Sequence, Tuple
68
+
69
+ import numpy as np
70
+
71
+
72
+ DEFAULT_DATASET = "HuggingFaceFW/fineweb-edu"
73
+ DEFAULT_CONFIG = "sample-10BT"
74
+ DEFAULT_SPLIT = "train"
75
+
76
+ SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
77
+
78
+ ASCII_LETTERS = "abcdefghijklmnopqrstuvwxyz"
79
+ ASCII_DIGITS = "0123456789"
80
+ ALLOWED_PUNCT = ".,!?\'\"-():;@#+%=/\\*"
81
+
82
+ SPACE_TOKEN = " "
83
+ NEWLINE_TOKEN = "\n"
84
+
85
+ ALLOWED_CHARS = set(ASCII_LETTERS + ASCII_DIGITS + ALLOWED_PUNCT + SPACE_TOKEN + NEWLINE_TOKEN)
86
+ TEXT_FIELDS = ("text", "content", "body", "document", "raw_content", "message")
87
+
88
+ ESCAPED_PUNCT = re.escape(ALLOWED_PUNCT)
89
+ TOKEN_RE = re.compile(rf"\n| +|[a-z]+|[0-9]|[{ESCAPED_PUNCT}]")
90
+ MULTISPACE_RE = re.compile(r"[ \t\f\v]+")
91
+ MULTINEWLINE_RE = re.compile(r"\n{3,}")
92
+
93
+ SEQUENCE_REPLACEMENTS = {
94
+ "\u2018": "'",
95
+ "\u2019": "'",
96
+ "\u201c": '"',
97
+ "\u201d": '"',
98
+ "\u2013": "-",
99
+ "\u2014": "-",
100
+ "\u2015": "-",
101
+ "\u2212": "-",
102
+ "\u2026": "...",
103
+ "\u2022": " ",
104
+ "\u00b7": " ",
105
+ "\u00a0": " ",
106
+ "\u200b": "",
107
+ "\u200c": "",
108
+ "\u200d": "",
109
+ "\ufeff": "",
110
+ "\u00ad": "",
111
+ "\t": " ",
112
+ "\r": "\n",
113
+ "[": "(",
114
+ "]": ")",
115
+ "{": "(",
116
+ "}": ")",
117
+ "<": "(",
118
+ ">": ")",
119
+ "(": "(",
120
+ ")": ")",
121
+ "[": "(",
122
+ "]": ")",
123
+ "{": "(",
124
+ "}": ")",
125
+ "【": "(",
126
+ "】": ")",
127
+ "〈": "(",
128
+ "〉": ")",
129
+ "《": "(",
130
+ "》": ")",
131
+ "「": "(",
132
+ "」": ")",
133
+ "『": "(",
134
+ "』": ")",
135
+ "〔": "(",
136
+ "〕": ")",
137
+ "〖": "(",
138
+ "〗": ")",
139
+ }
140
+
141
+ LATIN_FOLD_REPLACEMENTS = {
142
+ "ß": "ss",
143
+ "ẞ": "ss",
144
+ "æ": "ae",
145
+ "ǽ": "ae",
146
+ "œ": "oe",
147
+ "ø": "o",
148
+ "ð": "d",
149
+ "þ": "th",
150
+ "ł": "l",
151
+ "đ": "d",
152
+ "ħ": "h",
153
+ "ı": "i",
154
+ }
155
+
156
+
157
+ @dataclass
158
+ class BundleConfig:
159
+ output: Path
160
+ dataset: str = DEFAULT_DATASET
161
+ config: str = DEFAULT_CONFIG
162
+ split: str = DEFAULT_SPLIT
163
+ vocab_size: int = 2000
164
+ bpe_train_chars: int = 100_000_000
165
+ final_token_budget: int = 100_000_000
166
+ max_examples: Optional[int] = None
167
+ min_pair_count: int = 5
168
+ token_prior_clip: float = 3.0
169
+ pair_prior_clip: float = 3.0
170
+ word_cache_size: int = 200000
171
+
172
+
173
+ def _import_load_dataset():
174
+ try:
175
+ from datasets import load_dataset
176
+ except Exception as exc:
177
+ raise SystemExit(
178
+ "Missing dependency: datasets. Install with:\n"
179
+ " python -m pip install datasets numpy"
180
+ ) from exc
181
+ return load_dataset
182
+
183
+
184
+ def normalize_text(text: str) -> str:
185
+ if not text:
186
+ return ""
187
+
188
+ for src, dst in SEQUENCE_REPLACEMENTS.items():
189
+ text = text.replace(src, dst)
190
+ for src, dst in LATIN_FOLD_REPLACEMENTS.items():
191
+ text = text.replace(src, dst)
192
+
193
+ text = text.casefold()
194
+ text = unicodedata.normalize("NFKD", text)
195
+
196
+ out_chars: List[str] = []
197
+ last_was_space = False
198
+
199
+ for ch in text:
200
+ cat = unicodedata.category(ch)
201
+
202
+ if cat.startswith("M"):
203
+ continue
204
+
205
+ if ch in ALLOWED_CHARS:
206
+ out_chars.append(ch)
207
+ last_was_space = (ch == " ")
208
+ continue
209
+
210
+ if ch == "\n":
211
+ out_chars.append("\n")
212
+ last_was_space = False
213
+ continue
214
+
215
+ if ch.isspace():
216
+ if not last_was_space:
217
+ out_chars.append(" ")
218
+ last_was_space = True
219
+ continue
220
+
221
+ if ord(ch) < 128:
222
+ if cat[:1] in {"P", "S"} or ch in "[]{}<>_|~^$&`":
223
+ if not last_was_space:
224
+ out_chars.append(" ")
225
+ last_was_space = True
226
+ continue
227
+
228
+ if cat[:1] in {"L", "N", "P", "S"}:
229
+ if not last_was_space:
230
+ out_chars.append(" ")
231
+ last_was_space = True
232
+ continue
233
+
234
+ normalized = "".join(out_chars)
235
+ normalized = MULTISPACE_RE.sub(" ", normalized)
236
+ normalized = re.sub(r" *\n *", "\n", normalized)
237
+ normalized = MULTINEWLINE_RE.sub("\n\n", normalized)
238
+ normalized = normalized.strip(" ")
239
+ return normalized
240
+
241
+
242
+ def iter_stream_examples(
243
+ dataset_name: str,
244
+ config_name: str,
245
+ split: str,
246
+ max_examples: Optional[int],
247
+ ) -> Iterator[str]:
248
+ load_dataset = _import_load_dataset()
249
+ ds = load_dataset(dataset_name, config_name, split=split, streaming=True)
250
+
251
+ seen = 0
252
+ for row in ds:
253
+ text = None
254
+
255
+ if isinstance(row, dict):
256
+ for field in TEXT_FIELDS:
257
+ if field in row and isinstance(row[field], str):
258
+ text = row[field]
259
+ break
260
+
261
+ if text is None and "messages" in row and isinstance(row["messages"], list):
262
+ chunks: List[str] = []
263
+ for msg in row["messages"]:
264
+ if isinstance(msg, dict):
265
+ content = msg.get("content")
266
+ if isinstance(content, str):
267
+ chunks.append(content)
268
+ if chunks:
269
+ text = "\n".join(chunks)
270
+
271
+ elif isinstance(row, str):
272
+ text = row
273
+
274
+ if text:
275
+ yield text
276
+ seen += 1
277
+ if max_examples is not None and seen >= max_examples:
278
+ break
279
+
280
+
281
+ def iter_normalized_text(cfg: BundleConfig) -> Iterator[str]:
282
+ for raw in iter_stream_examples(cfg.dataset, cfg.config, cfg.split, cfg.max_examples):
283
+ text = normalize_text(raw)
284
+ if text:
285
+ yield text
286
+
287
+
288
+ def iter_pre_tokens(text: str) -> Iterator[str]:
289
+ for piece in TOKEN_RE.findall(text):
290
+ yield piece
291
+
292
+
293
+ def count_words_for_bpe(cfg: BundleConfig) -> Counter[str]:
294
+ word_freq: Counter[str] = Counter()
295
+ char_budget = 0
296
+
297
+ for text in iter_normalized_text(cfg):
298
+ char_budget += len(text)
299
+ for piece in iter_pre_tokens(text):
300
+ if piece.isalpha():
301
+ word_freq[piece] += 1
302
+
303
+ if char_budget >= cfg.bpe_train_chars:
304
+ break
305
+
306
+ return word_freq
307
+
308
+
309
+ def word_to_symbols(word: str) -> Tuple[str, ...]:
310
+ return tuple(word)
311
+
312
+
313
+ def compute_pair_counts_from_vocab(
314
+ vocab_words: Dict[Tuple[str, ...], int]
315
+ ) -> Counter[Tuple[str, str]]:
316
+ pair_counts: Counter[Tuple[str, str]] = Counter()
317
+ for symbols, freq in vocab_words.items():
318
+ if len(symbols) < 2:
319
+ continue
320
+ for i in range(len(symbols) - 1):
321
+ left = symbols[i]
322
+ right = symbols[i + 1]
323
+ if left.isalpha() and right.isalpha():
324
+ pair_counts[(left, right)] += freq
325
+ return pair_counts
326
+
327
+
328
+ def merge_word_symbols(
329
+ symbols: Tuple[str, ...],
330
+ pair: Tuple[str, str],
331
+ ) -> Tuple[str, ...]:
332
+ merged: List[str] = []
333
+ i = 0
334
+ while i < len(symbols):
335
+ if i < len(symbols) - 1 and symbols[i] == pair[0] and symbols[i + 1] == pair[1]:
336
+ merged.append(symbols[i] + symbols[i + 1])
337
+ i += 2
338
+ else:
339
+ merged.append(symbols[i])
340
+ i += 1
341
+ return tuple(merged)
342
+
343
+
344
+ def train_bpe_from_words(
345
+ word_freq: Counter[str],
346
+ vocab_size: int,
347
+ ) -> Tuple[List[str], List[Tuple[str, str]]]:
348
+ fixed_non_alpha_count = len(SPECIAL_TOKENS) + 2 + len(ASCII_DIGITS) + len(ALLOWED_PUNCT)
349
+ target_alpha_piece_count = max(vocab_size - fixed_non_alpha_count, len(ASCII_LETTERS))
350
+
351
+ vocab_words: Dict[Tuple[str, ...], int] = {
352
+ word_to_symbols(word): freq for word, freq in word_freq.items()
353
+ }
354
+
355
+ current_symbols = set(ASCII_LETTERS)
356
+ merges: List[Tuple[str, str]] = []
357
+
358
+ while len(current_symbols) < target_alpha_piece_count:
359
+ pair_counts = compute_pair_counts_from_vocab(vocab_words)
360
+ if not pair_counts:
361
+ break
362
+
363
+ best_pair, best_count = pair_counts.most_common(1)[0]
364
+ if best_count < 2:
365
+ break
366
+
367
+ merges.append(best_pair)
368
+ new_vocab_words: Dict[Tuple[str, ...], int] = {}
369
+ for symbols, freq in vocab_words.items():
370
+ merged_symbols = merge_word_symbols(symbols, best_pair)
371
+ new_vocab_words[merged_symbols] = new_vocab_words.get(merged_symbols, 0) + freq
372
+ vocab_words = new_vocab_words
373
+ current_symbols.add(best_pair[0] + best_pair[1])
374
+
375
+ if len(current_symbols) % 100 == 0:
376
+ print(f"[bpe] learned alpha pieces: {len(current_symbols)}", flush=True)
377
+
378
+ learned_alpha_pieces = sorted(current_symbols)
379
+ final_vocab = (
380
+ SPECIAL_TOKENS
381
+ + [SPACE_TOKEN, NEWLINE_TOKEN]
382
+ + list(ASCII_DIGITS)
383
+ + list(ALLOWED_PUNCT)
384
+ + learned_alpha_pieces
385
+ )
386
+ final_vocab = final_vocab[:vocab_size]
387
+ return final_vocab, merges
388
+
389
+
390
+ class GreedyTokenizer:
391
+ def __init__(
392
+ self,
393
+ vocab: Sequence[str],
394
+ merges: Sequence[Tuple[str, str]],
395
+ word_cache_size: int = 200000,
396
+ ) -> None:
397
+ self.vocab = list(vocab)
398
+ self.merges = list(merges)
399
+ self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
400
+ self.unk_id = self.token_to_id["<unk>"]
401
+ self.alpha_token_ids = {
402
+ tid for tok, tid in self.token_to_id.items() if tok.isalpha()
403
+ }
404
+ self.merge_ranks: Dict[Tuple[str, str], int] = {
405
+ pair: rank for rank, pair in enumerate(self.merges)
406
+ }
407
+ self.word_cache_size = max(int(word_cache_size), 0)
408
+ self._word_cache: Dict[str, Tuple[int, ...]] = {}
409
+
410
+ def _get_pairs(self, symbols: Tuple[str, ...]) -> set[Tuple[str, str]]:
411
+ return set(zip(symbols[:-1], symbols[1:]))
412
+
413
+ def _merge_once(self, symbols: Tuple[str, ...], pair: Tuple[str, str]) -> Tuple[str, ...]:
414
+ first, second = pair
415
+ merged: List[str] = []
416
+ i = 0
417
+ while i < len(symbols):
418
+ if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
419
+ merged.append(first + second)
420
+ i += 2
421
+ else:
422
+ merged.append(symbols[i])
423
+ i += 1
424
+ return tuple(merged)
425
+
426
+ def tokenize_alpha_run(self, span: str) -> List[int]:
427
+ if not span:
428
+ return []
429
+
430
+ cached = self._word_cache.get(span)
431
+ if cached is not None:
432
+ return list(cached)
433
+
434
+ symbols: Tuple[str, ...] = tuple(span)
435
+
436
+ while True:
437
+ pairs = self._get_pairs(symbols)
438
+ if not pairs:
439
+ break
440
+
441
+ ranked_pairs = [
442
+ (self.merge_ranks[pair], pair)
443
+ for pair in pairs
444
+ if pair in self.merge_ranks
445
+ ]
446
+ if not ranked_pairs:
447
+ break
448
+
449
+ _, best_pair = min(ranked_pairs)
450
+ symbols = self._merge_once(symbols, best_pair)
451
+ if len(symbols) == 1:
452
+ break
453
+
454
+ token_ids = tuple(self.token_to_id.get(piece, self.unk_id) for piece in symbols)
455
+
456
+ if self.word_cache_size > 0:
457
+ if len(self._word_cache) >= self.word_cache_size:
458
+ self._word_cache.clear()
459
+ self._word_cache[span] = token_ids
460
+
461
+ return list(token_ids)
462
+
463
+ def is_alpha_id(self, token_id: int) -> bool:
464
+ return token_id in self.alpha_token_ids
465
+
466
+ def encode(self, text: str) -> List[int]:
467
+ ids: List[int] = []
468
+ for piece in iter_pre_tokens(text):
469
+ if piece == "\n":
470
+ ids.append(self.token_to_id[NEWLINE_TOKEN])
471
+ elif piece.isspace():
472
+ ids.append(self.token_to_id[SPACE_TOKEN])
473
+ elif piece.isalpha():
474
+ ids.extend(self.tokenize_alpha_run(piece))
475
+ else:
476
+ ids.append(self.token_to_id.get(piece, self.unk_id))
477
+ return ids
478
+
479
+
480
+ def _safe_zscore(values: np.ndarray) -> np.ndarray:
481
+ values = values.astype(np.float32, copy=False)
482
+ mean = float(values.mean())
483
+ std = float(values.std())
484
+ if std < 1e-8:
485
+ return np.zeros_like(values, dtype=np.float32)
486
+ return (values - mean) / std
487
+
488
+
489
+ def build_priors_from_counts(
490
+ counts: np.ndarray,
491
+ clip_value: float,
492
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
493
+ counts = counts.astype(np.float64, copy=False)
494
+ total = counts.sum()
495
+ if total <= 0:
496
+ raise ValueError("Counts are empty; cannot build priors.")
497
+
498
+ probs = (counts + 1.0) / (total + counts.size)
499
+ surprisal = -np.log(probs)
500
+ z = _safe_zscore(surprisal.astype(np.float32))
501
+ z = np.clip(z, -clip_value, clip_value)
502
+ prior = (z + clip_value) / (2.0 * clip_value)
503
+ return probs.astype(np.float32), surprisal.astype(np.float32), prior.astype(np.float32)
504
+
505
+
506
+ def build_pair_priors(
507
+ pair_counts: np.ndarray,
508
+ min_pair_count: int,
509
+ clip_value: float,
510
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
511
+ row_sums = pair_counts.sum(axis=1, keepdims=True).astype(np.float64)
512
+ vocab_size = pair_counts.shape[0]
513
+
514
+ probs = (pair_counts.astype(np.float64) + 1.0) / (row_sums + vocab_size)
515
+ surprisal = -np.log(probs)
516
+ valid_mask = (pair_counts >= min_pair_count).astype(np.uint8)
517
+
518
+ flat_surprisal = surprisal.astype(np.float32).reshape(-1)
519
+ z = _safe_zscore(flat_surprisal).reshape(pair_counts.shape)
520
+ z = np.clip(z, -clip_value, clip_value)
521
+
522
+ prior = (z + clip_value) / (2.0 * clip_value)
523
+ prior = np.where(valid_mask == 1, prior, 0.5)
524
+
525
+ return (
526
+ probs.astype(np.float32),
527
+ surprisal.astype(np.float32),
528
+ prior.astype(np.float32),
529
+ valid_mask,
530
+ )
531
+
532
+
533
+ def second_pass_stats(
534
+ cfg: BundleConfig,
535
+ tokenizer: GreedyTokenizer,
536
+ ) -> Tuple[np.ndarray, np.ndarray]:
537
+ vocab_size = len(tokenizer.vocab)
538
+ token_counts = np.zeros(vocab_size, dtype=np.int64)
539
+ pair_counts = np.zeros((vocab_size, vocab_size), dtype=np.int32)
540
+
541
+ token_budget = 0
542
+ for text in iter_normalized_text(cfg):
543
+ ids = tokenizer.encode(text)
544
+ if not ids:
545
+ continue
546
+
547
+ for tid in ids:
548
+ token_counts[tid] += 1
549
+
550
+ prev = ids[0]
551
+ for cur in ids[1:]:
552
+ if tokenizer.is_alpha_id(prev) and tokenizer.is_alpha_id(cur):
553
+ pair_counts[prev, cur] += 1
554
+ prev = cur
555
+
556
+ token_budget += len(ids)
557
+ if token_budget >= cfg.final_token_budget:
558
+ break
559
+
560
+ return token_counts, pair_counts
561
+
562
+
563
+ def save_bundle(
564
+ cfg: BundleConfig,
565
+ vocab: Sequence[str],
566
+ merges: Sequence[Tuple[str, str]],
567
+ token_counts: np.ndarray,
568
+ token_probs: np.ndarray,
569
+ token_surprisal: np.ndarray,
570
+ token_prior: np.ndarray,
571
+ pair_counts: np.ndarray,
572
+ pair_probs: np.ndarray,
573
+ pair_surprisal: np.ndarray,
574
+ pair_prior: np.ndarray,
575
+ pair_valid_mask: np.ndarray,
576
+ ) -> None:
577
+ cfg.output.mkdir(parents=True, exist_ok=True)
578
+
579
+ vocab_json = {
580
+ "token_to_id": {tok: i for i, tok in enumerate(vocab)},
581
+ "id_to_token": {str(i): tok for i, tok in enumerate(vocab)},
582
+ "special_tokens": SPECIAL_TOKENS,
583
+ "space_token": SPACE_TOKEN,
584
+ "newline_token": NEWLINE_TOKEN,
585
+ "merges": [[a, b] for a, b in merges],
586
+ }
587
+ (cfg.output / "vocab.json").write_text(
588
+ json.dumps(vocab_json, indent=2, ensure_ascii=True),
589
+ encoding="utf-8",
590
+ )
591
+
592
+ manifest = {
593
+ "bundle_version": 9,
594
+ "description": "english-first ascii-limited vocab bundle with letter-only learned tokens, atomic digits and punctuation, latin accent folding, bracket folding, faster ranked-bpe runtime tokenization, and alpha-only pair priors",
595
+ "dataset": cfg.dataset,
596
+ "config": cfg.config,
597
+ "split": cfg.split,
598
+ "vocab_size": len(vocab),
599
+ "requested_vocab_size": cfg.vocab_size,
600
+ "special_tokens": SPECIAL_TOKENS,
601
+ "allowed_ascii_letters": ASCII_LETTERS,
602
+ "allowed_ascii_digits": ASCII_DIGITS,
603
+ "allowed_ascii_punctuation": ALLOWED_PUNCT,
604
+ "normalization": {
605
+ "casefold_uppercase_to_lowercase": True,
606
+ "latin_accent_folding": True,
607
+ "bracket_like_marks_folded_to_parentheses": True,
608
+ "non_latin_scripts_to_space": True,
609
+ "emoji_removed": True,
610
+ "unsupported_symbols_to_space": True,
611
+ "collapse_spaces": True,
612
+ "trim_long_newlines": True,
613
+ "runtime_tokenization": "ranked_bpe_letters_only_with_word_cache",
614
+ },
615
+ "token_shape_policy": {
616
+ "learned_multi_character_tokens": "letters_only",
617
+ "digits": "atomic_single_character",
618
+ "punctuation": "atomic_single_character",
619
+ "spaces": "atomic_single_character",
620
+ "newlines": "atomic_single_character",
621
+ },
622
+ "pair_prior_scope": {
623
+ "counted_pairs": "alpha_to_alpha_only",
624
+ "non_alpha_pairs": "neutral_default_prior",
625
+ },
626
+ "bpe_train_chars": cfg.bpe_train_chars,
627
+ "final_token_budget": cfg.final_token_budget,
628
+ "min_pair_count": cfg.min_pair_count,
629
+ "token_prior_clip": cfg.token_prior_clip,
630
+ "pair_prior_clip": cfg.pair_prior_clip,
631
+ "word_cache_size": cfg.word_cache_size,
632
+ }
633
+ (cfg.output / "manifest.json").write_text(
634
+ json.dumps(manifest, indent=2, ensure_ascii=True),
635
+ encoding="utf-8",
636
+ )
637
+
638
+ np.savez_compressed(
639
+ cfg.output / "token_stats.npz",
640
+ count=token_counts,
641
+ prob=token_probs,
642
+ surprisal=token_surprisal,
643
+ importance_prior=token_prior,
644
+ )
645
+
646
+ np.savez_compressed(
647
+ cfg.output / "pair_stats.npz",
648
+ pair_count=pair_counts,
649
+ pair_prob=pair_probs,
650
+ pair_surprisal=pair_surprisal,
651
+ pair_importance_prior=pair_prior,
652
+ pair_valid_mask=pair_valid_mask,
653
+ )
654
+
655
+
656
+ def build_bundle(cfg: BundleConfig) -> None:
657
+ print("[1/4] Counting normalized words for letter-only BPE training...", flush=True)
658
+ word_freq = count_words_for_bpe(cfg)
659
+ if not word_freq:
660
+ raise SystemExit("No usable normalized text found in the stream.")
661
+
662
+ print(f"[1/4] Unique normalized letter-words: {len(word_freq):,}", flush=True)
663
+
664
+ print("[2/4] Training letter-only BPE-style subword vocab...", flush=True)
665
+ vocab, merges = train_bpe_from_words(word_freq, cfg.vocab_size)
666
+ print(f"[2/4] Final vocab size: {len(vocab)}", flush=True)
667
+
668
+ print("[3/4] Streaming second pass for token and pair stats...", flush=True)
669
+ tokenizer = GreedyTokenizer(vocab, merges, word_cache_size=cfg.word_cache_size)
670
+ token_counts, pair_counts = second_pass_stats(cfg, tokenizer)
671
+
672
+ if token_counts.sum() <= 0:
673
+ raise SystemExit("Second pass produced no tokens. Check dataset fields or normalization rules.")
674
+
675
+ print(f"[3/4] Final token count: {int(token_counts.sum()):,}", flush=True)
676
+
677
+ print("[4/4] Building priors and saving bundle...", flush=True)
678
+ token_probs, token_surprisal, token_prior = build_priors_from_counts(
679
+ token_counts,
680
+ cfg.token_prior_clip,
681
+ )
682
+ pair_probs, pair_surprisal, pair_prior, pair_valid_mask = build_pair_priors(
683
+ pair_counts,
684
+ cfg.min_pair_count,
685
+ cfg.pair_prior_clip,
686
+ )
687
+
688
+ save_bundle(
689
+ cfg=cfg,
690
+ vocab=vocab,
691
+ merges=merges,
692
+ token_counts=token_counts,
693
+ token_probs=token_probs,
694
+ token_surprisal=token_surprisal,
695
+ token_prior=token_prior,
696
+ pair_counts=pair_counts,
697
+ pair_probs=pair_probs,
698
+ pair_surprisal=pair_surprisal,
699
+ pair_prior=pair_prior,
700
+ pair_valid_mask=pair_valid_mask,
701
+ )
702
+ print(f"Done. Bundle written to: {cfg.output}", flush=True)
703
+
704
+
705
+ def parse_args(argv: Optional[Sequence[str]] = None) -> BundleConfig:
706
+ parser = argparse.ArgumentParser(description="Build an ASCII-limited English vocab + prior bundle from a streamed dataset.")
707
+ parser.add_argument("--output", required=True, help="Output directory for the bundle.")
708
+ parser.add_argument("--dataset", default=DEFAULT_DATASET, help=f"Hugging Face dataset name. Default: {DEFAULT_DATASET}")
709
+ parser.add_argument("--config", default=DEFAULT_CONFIG, help=f"Hugging Face dataset config. Default: {DEFAULT_CONFIG}")
710
+ parser.add_argument("--split", default=DEFAULT_SPLIT, help=f"Dataset split. Default: {DEFAULT_SPLIT}")
711
+ parser.add_argument("--vocab-size", type=int, default=2000, help="Final vocab size including special tokens.")
712
+ parser.add_argument("--bpe-train-chars", type=int, default=100_000_000, help="Normalized character budget for vocab learning.")
713
+ parser.add_argument("--final-token-budget", type=int, default=100_000_000, help="Final tokenizer token budget for priors.")
714
+ parser.add_argument("--max-examples", type=int, default=None, help="Optional cap on streamed examples for testing.")
715
+ parser.add_argument("--min-pair-count", type=int, default=5, help="Minimum pair count to trust a pair prior.")
716
+ parser.add_argument("--token-prior-clip", type=float, default=3.0, help="Clip for token prior z-scores.")
717
+ parser.add_argument("--pair-prior-clip", type=float, default=3.0, help="Clip for pair prior z-scores.")
718
+ parser.add_argument("--word-cache-size", type=int, default=200000, help="Max cached normalized words for faster runtime tokenization.")
719
+ args = parser.parse_args(argv)
720
+
721
+ return BundleConfig(
722
+ output=Path(args.output),
723
+ dataset=args.dataset,
724
+ config=args.config,
725
+ split=args.split,
726
+ vocab_size=args.vocab_size,
727
+ bpe_train_chars=args.bpe_train_chars,
728
+ final_token_budget=args.final_token_budget,
729
+ max_examples=args.max_examples,
730
+ min_pair_count=args.min_pair_count,
731
+ token_prior_clip=args.token_prior_clip,
732
+ pair_prior_clip=args.pair_prior_clip,
733
+ word_cache_size=args.word_cache_size,
734
+ )
735
+
736
+
737
+ def main(argv: Optional[Sequence[str]] = None) -> int:
738
+ cfg = parse_args(argv)
739
+ build_bundle(cfg)
740
+ return 0
741
+
742
+
743
+ if __name__ == "__main__":
744
+ raise SystemExit(main())
final_infer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbed84c723a50e97b426806ac5d070e7820d46c4634c14e41d20a8d2bada02ce
3
+ size 15957156
pgsm_sparse_rope_lm.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ pgsm_sparse_rope_lm.py
4
+
5
+ Reusable model module for the custom LLM architecture developed from the
6
+ long-memory experiments:
7
+
8
+ Parallel Geometric State Model (PGSM)
9
+ + optional query-only sparse RoPE retrieval head
10
+
11
+ Core design:
12
+ - Fast attention-free local backbone.
13
+ - Depthwise causal convolution for local state propagation.
14
+ - Gated state mixing.
15
+ - Gated MLP blocks.
16
+ - Optional sparse retrieval only at selected query positions.
17
+ - Retrieval dimension is configurable; experiments showed retrieval_dim=512
18
+ was the first strong setting at block_size=1024 / distance=768.
19
+
20
+ This file is intentionally model-only. It does not include training loops,
21
+ datasets, benchmark code, or CLI handling. Import it from your training module.
22
+
23
+ Example:
24
+
25
+ from pgsm_sparse_rope_lm import PGSMConfig, PGSMSparseRoPELM
26
+
27
+ cfg = PGSMConfig.small(vocab_size=256, block_size=1024)
28
+ model = PGSMSparseRoPELM(cfg)
29
+
30
+ logits, loss = model(input_ids, labels)
31
+
32
+ For retrieval tasks where only specific answer/query positions should do sparse
33
+ long-range retrieval:
34
+
35
+ logits, loss = model(input_ids, labels, retrieval_positions=answer_pos)
36
+
37
+ For normal causal LM pretraining, you can disable sparse retrieval or use
38
+ automatic query-token detection if your data marks query positions.
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ import math
44
+ from dataclasses import asdict, dataclass, replace
45
+ from typing import Any, Dict, Iterable, Optional, Tuple
46
+
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+
51
+
52
+ # -----------------------------
53
+ # Configuration
54
+ # -----------------------------
55
+
56
+ @dataclass(frozen=True)
57
+ class PGSMConfig:
58
+ # Vocabulary / sequence
59
+ vocab_size: int = 256
60
+ block_size: int = 1024
61
+
62
+ # Backbone
63
+ dim: int = 192
64
+ layers: int = 3
65
+ hidden: int = 384
66
+ kernel_size: int = 17
67
+ dropout: float = 0.0
68
+
69
+ # Sparse retrieval
70
+ use_sparse_retrieval: bool = True
71
+ retrieval_dim: int = 512
72
+ retrieval_heads: int = 4
73
+ retrieval_dropout: float = 0.0
74
+
75
+ # Retrieval positioning
76
+ # If retrieval_positions is passed to forward(), that wins.
77
+ # Otherwise, if query_token_id is set, positions matching it can be used.
78
+ # Otherwise, retrieval can be skipped or applied to the final token.
79
+ query_token_id: Optional[int] = None
80
+ auto_retrieve_on_query_token: bool = False
81
+ retrieve_at_last_token_if_unspecified: bool = False
82
+
83
+ # Output / loss behavior
84
+ tie_weights: bool = True
85
+ use_post_retrieval_block: bool = True
86
+ ignore_index: int = -100
87
+
88
+ # Init
89
+ init_std: float = 0.02
90
+
91
+ def to_dict(self) -> Dict[str, Any]:
92
+ return asdict(self)
93
+
94
+ @classmethod
95
+ def tiny(
96
+ cls,
97
+ vocab_size: int = 256,
98
+ block_size: int = 512,
99
+ **overrides: Any,
100
+ ) -> "PGSMConfig":
101
+ cfg = cls(
102
+ vocab_size=vocab_size,
103
+ block_size=block_size,
104
+ dim=128,
105
+ layers=3,
106
+ hidden=256,
107
+ kernel_size=17,
108
+ retrieval_dim=256,
109
+ retrieval_heads=4,
110
+ )
111
+ return replace(cfg, **overrides)
112
+
113
+ @classmethod
114
+ def small(
115
+ cls,
116
+ vocab_size: int = 256,
117
+ block_size: int = 1024,
118
+ **overrides: Any,
119
+ ) -> "PGSMConfig":
120
+ # Closest to the successful experiment, with retrieval_dim=512.
121
+ cfg = cls(
122
+ vocab_size=vocab_size,
123
+ block_size=block_size,
124
+ dim=192,
125
+ layers=3,
126
+ hidden=384,
127
+ kernel_size=17,
128
+ retrieval_dim=512,
129
+ retrieval_heads=4,
130
+ )
131
+ return replace(cfg, **overrides)
132
+
133
+ @classmethod
134
+ def medium(
135
+ cls,
136
+ vocab_size: int,
137
+ block_size: int = 2048,
138
+ **overrides: Any,
139
+ ) -> "PGSMConfig":
140
+ cfg = cls(
141
+ vocab_size=vocab_size,
142
+ block_size=block_size,
143
+ dim=384,
144
+ layers=6,
145
+ hidden=1024,
146
+ kernel_size=21,
147
+ retrieval_dim=768,
148
+ retrieval_heads=8,
149
+ dropout=0.0,
150
+ retrieval_dropout=0.0,
151
+ )
152
+ return replace(cfg, **overrides)
153
+
154
+ @classmethod
155
+ def large(
156
+ cls,
157
+ vocab_size: int,
158
+ block_size: int = 4096,
159
+ **overrides: Any,
160
+ ) -> "PGSMConfig":
161
+ cfg = cls(
162
+ vocab_size=vocab_size,
163
+ block_size=block_size,
164
+ dim=768,
165
+ layers=12,
166
+ hidden=2048,
167
+ kernel_size=25,
168
+ retrieval_dim=1024,
169
+ retrieval_heads=8,
170
+ dropout=0.0,
171
+ retrieval_dropout=0.0,
172
+ )
173
+ return replace(cfg, **overrides)
174
+
175
+
176
+ # -----------------------------
177
+ # Utility functions
178
+ # -----------------------------
179
+
180
+ def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
181
+ if trainable_only:
182
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
183
+ return sum(p.numel() for p in module.parameters())
184
+
185
+
186
+ def init_pgsm_weights(module: nn.Module, std: float = 0.02) -> None:
187
+ if isinstance(module, (nn.Linear, nn.Embedding)):
188
+ nn.init.normal_(module.weight, mean=0.0, std=std)
189
+ if isinstance(module, nn.Linear) and module.bias is not None:
190
+ nn.init.zeros_(module.bias)
191
+
192
+
193
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
194
+ x_even = x[..., 0::2]
195
+ x_odd = x[..., 1::2]
196
+ return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
197
+
198
+
199
+ def _positions_from_query_tokens(input_ids: torch.Tensor, query_token_id: int) -> torch.Tensor:
200
+ """
201
+ Return one retrieval position per batch row.
202
+
203
+ If multiple query tokens exist, the last one is used.
204
+ If none exist in a row, the final token is used.
205
+ """
206
+ batch, steps = input_ids.shape
207
+ device = input_ids.device
208
+ matches = input_ids.eq(int(query_token_id))
209
+ positions = torch.full((batch,), steps - 1, dtype=torch.long, device=device)
210
+
211
+ for b in range(batch):
212
+ found = torch.nonzero(matches[b], as_tuple=False).flatten()
213
+ if found.numel() > 0:
214
+ positions[b] = found[-1]
215
+ return positions
216
+
217
+
218
+ # -----------------------------
219
+ # Backbone blocks
220
+ # -----------------------------
221
+
222
+ class CausalDepthwiseConv(nn.Module):
223
+ """
224
+ Depthwise causal convolution.
225
+
226
+ This is the main local state propagation primitive. It is parallel over time
227
+ during training and does not construct an attention matrix.
228
+ """
229
+
230
+ def __init__(self, dim: int, kernel_size: int):
231
+ super().__init__()
232
+ self.dim = int(dim)
233
+ self.kernel_size = int(kernel_size)
234
+ self.conv = nn.Conv1d(
235
+ dim,
236
+ dim,
237
+ kernel_size,
238
+ groups=dim,
239
+ padding=kernel_size - 1,
240
+ )
241
+
242
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
243
+ # x: [B,T,D]
244
+ y = self.conv(x.transpose(1, 2))
245
+ y = y[:, :, : x.size(1)]
246
+ return y.transpose(1, 2)
247
+
248
+
249
+ class ParallelGeometricBlock(nn.Module):
250
+ """
251
+ Attention-free parallel geometric/state-mixing block.
252
+
253
+ Structure:
254
+ norm -> causal depthwise local state -> gated state residual
255
+ norm -> gated MLP -> residual
256
+ """
257
+
258
+ def __init__(self, dim: int, hidden: int, kernel_size: int, dropout: float = 0.0):
259
+ super().__init__()
260
+ self.norm_state = nn.LayerNorm(dim)
261
+ self.local_state = CausalDepthwiseConv(dim, kernel_size)
262
+ self.state_mix = nn.Linear(dim, dim)
263
+ self.state_gate = nn.Linear(dim, dim)
264
+ self.drop_state = nn.Dropout(dropout)
265
+
266
+ self.norm_ff = nn.LayerNorm(dim)
267
+ self.ff_in = nn.Linear(dim, hidden * 2)
268
+ self.ff_out = nn.Linear(hidden, dim)
269
+ self.drop_ff = nn.Dropout(dropout)
270
+
271
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
272
+ h = self.norm_state(x)
273
+ local = self.local_state(h)
274
+ gated_state = torch.tanh(self.state_mix(local)) * torch.sigmoid(self.state_gate(h))
275
+ x = x + self.drop_state(gated_state)
276
+
277
+ h = self.norm_ff(x)
278
+ value, gate = self.ff_in(h).chunk(2, dim=-1)
279
+ ff = self.ff_out(F.silu(gate) * value)
280
+ x = x + self.drop_ff(ff)
281
+ return x
282
+
283
+
284
+ # -----------------------------
285
+ # Sparse RoPE retrieval
286
+ # -----------------------------
287
+
288
+ class RotaryCache(nn.Module):
289
+ """
290
+ RoPE cache for tensors shaped [B,H,T,D] and query tensors [B,H,1,D].
291
+ """
292
+
293
+ def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0):
294
+ super().__init__()
295
+ if head_dim % 2 != 0:
296
+ raise ValueError("head_dim must be even for RoPE")
297
+ self.head_dim = int(head_dim)
298
+ self.max_seq_len = int(max_seq_len)
299
+
300
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
301
+ t = torch.arange(max_seq_len).float()
302
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
303
+
304
+ # Duplicate so cos/sin match [D] after rotate_half.
305
+ emb = torch.cat((freqs, freqs), dim=-1)
306
+ self.register_buffer("cos", emb.cos()[None, None, :, :], persistent=False)
307
+ self.register_buffer("sin", emb.sin()[None, None, :, :], persistent=False)
308
+
309
+ def apply_sequence(self, x: torch.Tensor) -> torch.Tensor:
310
+ # x: [B,H,T,D]
311
+ steps = x.size(-2)
312
+ if steps > self.max_seq_len:
313
+ raise ValueError(
314
+ f"Sequence length {steps} exceeds RoPE cache length {self.max_seq_len}. "
315
+ "Increase config.block_size."
316
+ )
317
+ cos = self.cos[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
318
+ sin = self.sin[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
319
+ return (x * cos) + (rotate_half(x) * sin)
320
+
321
+ def apply_query_positions(self, q: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
322
+ # q: [B,H,1,D], positions: [B]
323
+ cos = self.cos[0, 0, positions, :].to(device=q.device, dtype=q.dtype)[:, None, None, :]
324
+ sin = self.sin[0, 0, positions, :].to(device=q.device, dtype=q.dtype)[:, None, None, :]
325
+ return (q * cos) + (rotate_half(q) * sin)
326
+
327
+
328
+ class QueryOnlyRoPERetriever(nn.Module):
329
+ """
330
+ Sparse retrieval applied only to selected positions.
331
+
332
+ For each batch row, one retrieval position attends backward over prior token
333
+ states using RoPE Q/K. This is O(T) per retrieved position, not O(T^2).
334
+
335
+ This module is the key successful retrieval primitive from the experiments.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ dim: int,
341
+ retrieval_dim: int,
342
+ retrieval_heads: int,
343
+ block_size: int,
344
+ dropout: float = 0.0,
345
+ ):
346
+ super().__init__()
347
+ if retrieval_dim % retrieval_heads != 0:
348
+ raise ValueError("retrieval_dim must be divisible by retrieval_heads")
349
+ self.dim = int(dim)
350
+ self.retrieval_dim = int(retrieval_dim)
351
+ self.retrieval_heads = int(retrieval_heads)
352
+ self.head_dim = retrieval_dim // retrieval_heads
353
+ if self.head_dim % 2 != 0:
354
+ raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")
355
+
356
+ self.norm = nn.LayerNorm(dim)
357
+ self.q = nn.Linear(dim, retrieval_dim)
358
+ self.k = nn.Linear(dim, retrieval_dim)
359
+ self.v = nn.Linear(dim, retrieval_dim)
360
+ self.out = nn.Linear(retrieval_dim, dim)
361
+ self.gate = nn.Linear(dim * 2, dim)
362
+ self.dropout = nn.Dropout(dropout)
363
+ self.rope = RotaryCache(self.head_dim, max_seq_len=block_size + 8)
364
+
365
+ def forward(self, x: torch.Tensor, retrieval_positions: torch.Tensor) -> torch.Tensor:
366
+ # x: [B,T,D], retrieval_positions: [B]
367
+ batch, steps, _ = x.shape
368
+ device = x.device
369
+ bidx = torch.arange(batch, device=device)
370
+
371
+ h = self.norm(x)
372
+
373
+ k = self.k(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
374
+ v = self.v(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
375
+ k = self.rope.apply_sequence(k)
376
+
377
+ qh = h[bidx, retrieval_positions]
378
+ q = self.q(qh).view(batch, self.retrieval_heads, 1, self.head_dim)
379
+ q = self.rope.apply_query_positions(q, retrieval_positions)
380
+
381
+ scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
382
+
383
+ # Strictly backward. The retrieval position cannot read itself.
384
+ pos = torch.arange(steps, device=device)[None, None, None, :]
385
+ causal_mask = pos < retrieval_positions[:, None, None, None]
386
+ scores = scores.masked_fill(~causal_mask, float("-inf"))
387
+
388
+ att = F.softmax(scores, dim=-1)
389
+ att = self.dropout(att)
390
+
391
+ read = (att @ v).transpose(1, 2).contiguous().view(batch, self.retrieval_dim)
392
+ read = self.out(read)
393
+
394
+ old = x[bidx, retrieval_positions]
395
+ gate = torch.sigmoid(self.gate(torch.cat([qh, read], dim=-1)))
396
+ new = old + gate * read
397
+
398
+ out = x.clone()
399
+ out[bidx, retrieval_positions] = new
400
+ return out
401
+
402
+
403
+ # -----------------------------
404
+ # Main model
405
+ # -----------------------------
406
+
407
+ class PGSMSparseRoPELM(nn.Module):
408
+ """
409
+ Parallel Geometric State Model with optional query-only sparse RoPE retrieval.
410
+
411
+ Forward API:
412
+ logits, loss = model(input_ids, labels=None, retrieval_positions=None)
413
+
414
+ input_ids:
415
+ LongTensor [B,T]
416
+
417
+ labels:
418
+ LongTensor [B,T], optional.
419
+ Standard next-token labels are supported.
420
+ Use config.ignore_index for ignored positions.
421
+
422
+ retrieval_positions:
423
+ Optional LongTensor [B].
424
+ If supplied, sparse retrieval is applied exactly at these positions.
425
+ If omitted, config controls whether to auto-detect query-token positions,
426
+ use final token, or skip retrieval.
427
+ """
428
+
429
+ def __init__(self, config: PGSMConfig):
430
+ super().__init__()
431
+ self.config = config
432
+
433
+ self.token_emb = nn.Embedding(config.vocab_size, config.dim)
434
+ self.blocks = nn.ModuleList(
435
+ [
436
+ ParallelGeometricBlock(
437
+ dim=config.dim,
438
+ hidden=config.hidden,
439
+ kernel_size=config.kernel_size,
440
+ dropout=config.dropout,
441
+ )
442
+ for _ in range(config.layers)
443
+ ]
444
+ )
445
+
446
+ self.retriever: Optional[QueryOnlyRoPERetriever]
447
+ if config.use_sparse_retrieval:
448
+ self.retriever = QueryOnlyRoPERetriever(
449
+ dim=config.dim,
450
+ retrieval_dim=config.retrieval_dim,
451
+ retrieval_heads=config.retrieval_heads,
452
+ block_size=config.block_size,
453
+ dropout=config.retrieval_dropout,
454
+ )
455
+ else:
456
+ self.retriever = None
457
+
458
+ self.post_retrieval_block: Optional[ParallelGeometricBlock]
459
+ if config.use_sparse_retrieval and config.use_post_retrieval_block:
460
+ self.post_retrieval_block = ParallelGeometricBlock(
461
+ dim=config.dim,
462
+ hidden=config.hidden,
463
+ kernel_size=config.kernel_size,
464
+ dropout=config.dropout,
465
+ )
466
+ else:
467
+ self.post_retrieval_block = None
468
+
469
+ self.final_norm = nn.LayerNorm(config.dim)
470
+ self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
471
+
472
+ self.apply(lambda module: init_pgsm_weights(module, std=config.init_std))
473
+
474
+ if config.tie_weights:
475
+ self.lm_head.weight = self.token_emb.weight
476
+
477
+ @property
478
+ def block_size(self) -> int:
479
+ return self.config.block_size
480
+
481
+ @property
482
+ def vocab_size(self) -> int:
483
+ return self.config.vocab_size
484
+
485
+ def num_parameters(self, trainable_only: bool = True) -> int:
486
+ return count_parameters(self, trainable_only=trainable_only)
487
+
488
+ def _resolve_retrieval_positions(
489
+ self,
490
+ input_ids: torch.Tensor,
491
+ retrieval_positions: Optional[torch.Tensor],
492
+ ) -> Optional[torch.Tensor]:
493
+ if not self.config.use_sparse_retrieval:
494
+ return None
495
+
496
+ if retrieval_positions is not None:
497
+ return retrieval_positions.to(device=input_ids.device, dtype=torch.long)
498
+
499
+ if (
500
+ self.config.auto_retrieve_on_query_token
501
+ and self.config.query_token_id is not None
502
+ ):
503
+ return _positions_from_query_tokens(input_ids, self.config.query_token_id)
504
+
505
+ if self.config.retrieve_at_last_token_if_unspecified:
506
+ return torch.full(
507
+ (input_ids.size(0),),
508
+ input_ids.size(1) - 1,
509
+ dtype=torch.long,
510
+ device=input_ids.device,
511
+ )
512
+
513
+ return None
514
+
515
+ def encode(
516
+ self,
517
+ input_ids: torch.Tensor,
518
+ retrieval_positions: Optional[torch.Tensor] = None,
519
+ ) -> torch.Tensor:
520
+ if input_ids.dim() != 2:
521
+ raise ValueError("input_ids must have shape [batch, steps]")
522
+ if input_ids.size(1) > self.config.block_size:
523
+ raise ValueError(
524
+ f"Input length {input_ids.size(1)} exceeds config.block_size={self.config.block_size}"
525
+ )
526
+
527
+ x = self.token_emb(input_ids)
528
+
529
+ for block in self.blocks:
530
+ x = block(x)
531
+
532
+ positions = self._resolve_retrieval_positions(input_ids, retrieval_positions)
533
+ if positions is not None:
534
+ if self.retriever is None:
535
+ raise RuntimeError("retriever is None but retrieval positions were resolved")
536
+ x = self.retriever(x, positions)
537
+ if self.post_retrieval_block is not None:
538
+ x = self.post_retrieval_block(x)
539
+
540
+ return self.final_norm(x)
541
+
542
+ def forward(
543
+ self,
544
+ input_ids: torch.Tensor,
545
+ labels: Optional[torch.Tensor] = None,
546
+ retrieval_positions: Optional[torch.Tensor] = None,
547
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
548
+ x = self.encode(input_ids, retrieval_positions=retrieval_positions)
549
+ logits = self.lm_head(x)
550
+
551
+ loss: Optional[torch.Tensor] = None
552
+ if labels is not None:
553
+ loss = F.cross_entropy(
554
+ logits.reshape(-1, logits.size(-1)),
555
+ labels.reshape(-1),
556
+ ignore_index=self.config.ignore_index,
557
+ )
558
+
559
+ return logits, loss
560
+
561
+ @torch.no_grad()
562
+ def generate(
563
+ self,
564
+ input_ids: torch.Tensor,
565
+ max_new_tokens: int,
566
+ temperature: float = 1.0,
567
+ top_k: Optional[int] = None,
568
+ ) -> torch.Tensor:
569
+ """
570
+ Simple generation helper.
571
+
572
+ For normal generation, sparse retrieval is not automatically applied unless
573
+ config.retrieve_at_last_token_if_unspecified=True or query-token detection
574
+ is enabled. Training modules can provide their own generation loop if they
575
+ need custom retrieval-position behavior.
576
+ """
577
+ self.eval()
578
+ for _ in range(max_new_tokens):
579
+ idx_cond = input_ids[:, -self.config.block_size :]
580
+ logits, _ = self(idx_cond)
581
+ logits = logits[:, -1, :]
582
+
583
+ if temperature <= 0:
584
+ next_id = torch.argmax(logits, dim=-1, keepdim=True)
585
+ else:
586
+ logits = logits / temperature
587
+ if top_k is not None:
588
+ values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
589
+ logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
590
+ probs = F.softmax(logits, dim=-1)
591
+ next_id = torch.multinomial(probs, num_samples=1)
592
+
593
+ input_ids = torch.cat([input_ids, next_id], dim=1)
594
+
595
+ return input_ids
596
+
597
+
598
+ # -----------------------------
599
+ # Convenience factory
600
+ # -----------------------------
601
+
602
+ def build_pgsm_model(
603
+ size: str = "small",
604
+ vocab_size: int = 256,
605
+ block_size: int = 1024,
606
+ **overrides: Any,
607
+ ) -> PGSMSparseRoPELM:
608
+ size = size.lower().strip()
609
+ if size == "tiny":
610
+ cfg = PGSMConfig.tiny(vocab_size=vocab_size, block_size=block_size, **overrides)
611
+ elif size == "small":
612
+ cfg = PGSMConfig.small(vocab_size=vocab_size, block_size=block_size, **overrides)
613
+ elif size == "medium":
614
+ cfg = PGSMConfig.medium(vocab_size=vocab_size, block_size=block_size, **overrides)
615
+ elif size == "large":
616
+ cfg = PGSMConfig.large(vocab_size=vocab_size, block_size=block_size, **overrides)
617
+ else:
618
+ raise ValueError(f"Unknown model size: {size!r}. Use tiny, small, medium, or large.")
619
+ return PGSMSparseRoPELM(cfg)
620
+
621
+
622
+ __all__ = [
623
+ "PGSMConfig",
624
+ "PGSMSparseRoPELM",
625
+ "build_pgsm_model",
626
+ "count_parameters",
627
+ ]