| |
"""
pgsm_sparse_rope_lm.py

Reusable model module for the custom LLM architecture developed from the
long-memory experiments:

    Parallel Geometric State Model (PGSM)
    + optional query-only sparse RoPE retrieval head

Core design:
- Fast attention-free local backbone.
- Depthwise causal convolution for local state propagation.
- Gated state mixing.
- Gated MLP blocks.
- Optional sparse retrieval only at selected query positions.
- Retrieval dimension is configurable; experiments showed retrieval_dim=512
  was the first strong setting at block_size=1024 / distance=768.

This file is intentionally model-only. It does not include training loops,
datasets, benchmark code, or CLI handling. Import it from your training module.

Example:

    from pgsm_sparse_rope_lm import PGSMConfig, PGSMSparseRoPELM

    cfg = PGSMConfig.small(vocab_size=256, block_size=1024)
    model = PGSMSparseRoPELM(cfg)

    logits, loss = model(input_ids, labels)

For retrieval tasks where only specific answer/query positions should do sparse
long-range retrieval:

    logits, loss = model(input_ids, labels, retrieval_positions=answer_pos)

For normal causal LM pretraining, you can disable sparse retrieval or use
automatic query-token detection if your data marks query positions.
"""
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import asdict, dataclass, replace |
| from typing import Any, Dict, Iterable, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
@dataclass(frozen=True)
class PGSMConfig:
    """Immutable hyperparameters for ``PGSMSparseRoPELM``.

    Fields are grouped as: vocabulary/context size, backbone shape, sparse
    retrieval head, retrieval-position policy, model wiring, and init scale.
    The ``tiny`` / ``small`` / ``medium`` / ``large`` classmethods return
    presets; every field except ``vocab_size`` / ``block_size`` (which are
    explicit parameters) can be overridden through keyword arguments.

    Raises:
        ValueError: If a size field is smaller than 1, ``layers`` is negative,
            or (only when ``use_sparse_retrieval`` is true) ``retrieval_dim``
            cannot be split into ``retrieval_heads`` heads of even width.
    """

    # Vocabulary and maximum sequence length.
    vocab_size: int = 256
    block_size: int = 1024

    # Backbone: model width, number of blocks, gated-MLP hidden width,
    # depthwise causal convolution kernel size, dropout probability.
    dim: int = 192
    layers: int = 3
    hidden: int = 384
    kernel_size: int = 17
    dropout: float = 0.0

    # Sparse retrieval head.
    use_sparse_retrieval: bool = True
    retrieval_dim: int = 512
    retrieval_heads: int = 4
    retrieval_dropout: float = 0.0

    # Policy used when the caller does not pass retrieval positions:
    # detect the last ``query_token_id`` per row, fall back to the final
    # token, or (both flags false) skip retrieval.
    query_token_id: Optional[int] = None
    auto_retrieve_on_query_token: bool = False
    retrieve_at_last_token_if_unspecified: bool = False

    # Model wiring and loss handling.
    tie_weights: bool = True
    use_post_retrieval_block: bool = True
    ignore_index: int = -100

    # Standard deviation for normal weight initialisation.
    init_std: float = 0.02

    def __post_init__(self) -> None:
        """Reject configurations that cannot produce a working model."""
        for name in ("vocab_size", "block_size", "dim", "hidden", "kernel_size"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value!r}")
        if self.layers < 0:
            raise ValueError(f"layers must be >= 0, got {self.layers!r}")

        # Retrieval geometry only matters when the retrieval head is built.
        if self.use_sparse_retrieval:
            if self.retrieval_dim < 1 or self.retrieval_heads < 1:
                raise ValueError("retrieval_dim and retrieval_heads must be >= 1")
            if self.retrieval_dim % self.retrieval_heads != 0:
                raise ValueError("retrieval_dim must be divisible by retrieval_heads")
            if (self.retrieval_dim // self.retrieval_heads) % 2 != 0:
                raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary."""
        return asdict(self)

    @classmethod
    def _from_preset(cls, preset: Dict[str, Any], overrides: Dict[str, Any]) -> "PGSMConfig":
        """Build a config from preset values, with ``overrides`` taking priority."""
        # Merge first so the dataclass is constructed (and validated) once,
        # on the final values only.
        return cls(**{**preset, **overrides})

    @classmethod
    def tiny(
        cls,
        vocab_size: int = 256,
        block_size: int = 512,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Return the smallest preset (dim=128, 3 layers, retrieval_dim=256)."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=128,
            layers=3,
            hidden=256,
            kernel_size=17,
            retrieval_dim=256,
            retrieval_heads=4,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def small(
        cls,
        vocab_size: int = 256,
        block_size: int = 1024,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Return the small preset (dim=192, 3 layers, retrieval_dim=512)."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=192,
            layers=3,
            hidden=384,
            kernel_size=17,
            retrieval_dim=512,
            retrieval_heads=4,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def medium(
        cls,
        vocab_size: int,
        block_size: int = 2048,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Return the medium preset (dim=384, 6 layers, retrieval_dim=768)."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=384,
            layers=6,
            hidden=1024,
            kernel_size=21,
            retrieval_dim=768,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return cls._from_preset(preset, overrides)

    @classmethod
    def large(
        cls,
        vocab_size: int,
        block_size: int = 4096,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Return the large preset (dim=768, 12 layers, retrieval_dim=1024)."""
        preset = dict(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=768,
            layers=12,
            hidden=2048,
            kernel_size=25,
            retrieval_dim=1024,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return cls._from_preset(preset, overrides)
|
|
|
|
| |
| |
| |
|
|
def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: Module whose ``parameters()`` are counted.
        trainable_only: When true, parameters with ``requires_grad=False``
            are left out of the total.

    Returns:
        Sum of ``numel()`` over the selected parameters.
    """
    total = 0
    for param in module.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
|
|
def init_pgsm_weights(module: nn.Module, std: float = 0.02) -> None:
    """Initialise one module in place; intended for ``nn.Module.apply``.

    ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
    distribution with mean 0 and the given ``std``; ``nn.Linear`` biases are
    zeroed. Every other module type is left untouched.

    Args:
        module: Module to initialise.
        std: Standard deviation of the normal distribution.
    """
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    nn.init.normal_(module.weight, mean=0.0, std=std)

    # Only linear layers carry a bias, and it may have been disabled.
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return the 90-degree rotated companion of ``x`` used by RoPE.

    The last dimension is split into two halves ``(a, b)`` and the result is
    ``(-b, a)``, so channel ``i`` is paired with channel ``i + D/2``. This is
    the pairing required by ``RotaryCache``, whose cos/sin tables are built
    as ``cat((freqs, freqs))``: both channels of a pair then share one
    frequency and ``x * cos + rotate_half(x) * sin`` is a true rotation.

    The previous implementation paired adjacent channels ``(2i, 2i + 1)``,
    which gave the two channels of a pair different frequencies under that
    table layout, so the transform was not a rotation.

    NOTE: this changes the numerical output of the retrieval head for
    checkpoints trained with the old pairing.

    Args:
        x: Tensor whose last dimension has even size.

    Returns:
        Tensor of the same shape as ``x``.

    Raises:
        ValueError: If the last dimension of ``x`` is odd.
    """
    if x.size(-1) % 2 != 0:
        raise ValueError("rotate_half requires an even last dimension")
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
| def _positions_from_query_tokens(input_ids: torch.Tensor, query_token_id: int) -> torch.Tensor: |
| """ |
| Return one retrieval position per batch row. |
| |
| If multiple query tokens exist, the last one is used. |
| If none exist in a row, the final token is used. |
| """ |
| batch, steps = input_ids.shape |
| device = input_ids.device |
| matches = input_ids.eq(int(query_token_id)) |
| positions = torch.full((batch,), steps - 1, dtype=torch.long, device=device) |
|
|
| for b in range(batch): |
| found = torch.nonzero(matches[b], as_tuple=False).flatten() |
| if found.numel() > 0: |
| positions[b] = found[-1] |
| return positions |
|
|
|
|
| |
| |
| |
|
|
class CausalDepthwiseConv(nn.Module):
    """Depthwise 1-D convolution that only looks at current and past steps.

    Each channel has its own filter (``groups=dim``). The convolution is
    padded by ``kernel_size - 1`` and its output is cut back to the input
    length, so output step ``t`` depends only on input steps ``<= t``. All
    time steps are computed in one convolution call.
    """

    def __init__(self, dim: int, kernel_size: int):
        """Create the depthwise convolution.

        Args:
            dim: Number of channels (input and output).
            kernel_size: Number of time steps covered by each filter.
        """
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)

        history = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=history,
            groups=dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution.

        Args:
            x: Tensor indexed as ``[batch, steps, channels]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        steps = x.size(1)
        channels_first = x.transpose(1, 2)
        convolved = self.conv(channels_first)
        # Conv1d pads both ends; keeping only the first ``steps`` outputs
        # drops the ones that would have seen future inputs.
        causal = convolved[:, :, :steps]
        return causal.transpose(1, 2)
|
|
|
|
class ParallelGeometricBlock(nn.Module):
    """Attention-free residual block with two sub-layers.

    Sub-layer 1 (state mixing):
        LayerNorm -> causal depthwise conv -> ``tanh(Linear)`` candidate,
        multiplied by a ``sigmoid(Linear)`` gate computed from the normed
        input, then dropout and residual add.

    Sub-layer 2 (gated MLP):
        LayerNorm -> Linear to ``2 * hidden`` split into value and gate ->
        ``silu(gate) * value`` -> Linear back to ``dim``, then dropout and
        residual add.
    """

    def __init__(self, dim: int, hidden: int, kernel_size: int, dropout: float = 0.0):
        """Build both sub-layers.

        Args:
            dim: Model width.
            hidden: Width of the gated MLP.
            kernel_size: Kernel size of the causal depthwise convolution.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()

        # State-mixing sub-layer.
        self.norm_state = nn.LayerNorm(dim)
        self.local_state = CausalDepthwiseConv(dim, kernel_size)
        self.state_mix = nn.Linear(dim, dim)
        self.state_gate = nn.Linear(dim, dim)
        self.drop_state = nn.Dropout(dropout)

        # Gated MLP sub-layer; ff_in produces value and gate in one matmul.
        self.norm_ff = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, hidden * 2)
        self.ff_out = nn.Linear(hidden, dim)
        self.drop_ff = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-layers and return a tensor shaped like ``x``."""
        normed = self.norm_state(x)
        candidate = torch.tanh(self.state_mix(self.local_state(normed)))
        keep = torch.sigmoid(self.state_gate(normed))
        x = x + self.drop_state(candidate * keep)

        value, gate = self.ff_in(self.norm_ff(x)).chunk(2, dim=-1)
        mlp_out = self.ff_out(F.silu(gate) * value)
        return x + self.drop_ff(mlp_out)
|
|
|
|
| |
| |
| |
|
|
class RotaryCache(nn.Module):
    """Precomputed RoPE cos/sin tables.

    The tables are stored as non-persistent buffers ``cos`` and ``sin`` of
    shape ``[1, 1, max_seq_len, head_dim]``. Along the last dimension the
    per-frequency angles are laid out as ``[freqs, freqs]``, i.e. channel
    ``i`` and channel ``i + head_dim / 2`` share a frequency.

    NOTE(review): ``x * cos + rotate_half(x) * sin`` is only a true rotation
    if ``rotate_half`` pairs channel ``i`` with ``i + head_dim / 2`` — keep
    the two conventions in sync.
    """

    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0):
        """Build the tables.

        Args:
            head_dim: Per-head feature size; must be even.
            max_seq_len: Number of positions to precompute.
            base: Base of the geometric frequency progression.

        Raises:
            ValueError: If ``head_dim`` is odd.
        """
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.head_dim = int(head_dim)
        self.max_seq_len = int(max_seq_len)

        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (base ** exponents)
        timesteps = torch.arange(max_seq_len).float()
        # Outer product: angles[t, i] = t * inv_freq[i].
        angles = timesteps[:, None] * inv_freq[None, :]

        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos", table.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin", table.sin()[None, None, :, :], persistent=False)

    @staticmethod
    def _rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Combine ``x`` with broadcastable cos/sin tables."""
        return (x * cos) + (rotate_half(x) * sin)

    def apply_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate every step of ``x`` by its own position ``0 .. steps - 1``.

        Args:
            x: Tensor whose second-to-last dimension is time and whose last
                dimension is ``head_dim``.

        Raises:
            ValueError: If the sequence is longer than the cached tables.
        """
        steps = x.size(-2)
        if steps > self.max_seq_len:
            raise ValueError(
                f"Sequence length {steps} exceeds RoPE cache length {self.max_seq_len}. "
                "Increase config.block_size."
            )
        cos = self.cos[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        sin = self.sin[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        return self._rotate(x, cos, sin)

    def apply_query_positions(self, q: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate one query per batch row by that row's position.

        Args:
            q: Query tensor broadcastable against ``[batch, 1, 1, head_dim]``.
            positions: Integer tensor of positions used to index the tables.
        """
        # Gather one table row per batch element, then add head/time axes.
        cos = self.cos[0, 0, positions, :].to(device=q.device, dtype=q.dtype)
        sin = self.sin[0, 0, positions, :].to(device=q.device, dtype=q.dtype)
        return self._rotate(q, cos[:, None, None, :], sin[:, None, None, :])
|
|
|
|
class QueryOnlyRoPERetriever(nn.Module):
    """Sparse retrieval applied only to selected positions.

    For each batch row, one retrieval position attends backward over strictly
    earlier token states using RoPE-rotated queries and keys. The score
    tensor is ``[batch, heads, 1, steps]``, so the cost is O(T) per retrieved
    position, not O(T^2).

    The retrieved vector is projected back to ``dim``, scaled by a sigmoid
    gate, and added to the state at the retrieval position only; all other
    positions pass through unchanged.

    A row whose retrieval position has no earlier token (for example
    position 0) is a no-op: its state is returned unchanged instead of
    producing NaN from a softmax over an entirely masked row.
    """

    def __init__(
        self,
        dim: int,
        retrieval_dim: int,
        retrieval_heads: int,
        block_size: int,
        dropout: float = 0.0,
    ):
        """Build projections, gate and the RoPE cache.

        Args:
            dim: Model width of the incoming states.
            retrieval_dim: Total width of the q/k/v projections.
            retrieval_heads: Number of heads ``retrieval_dim`` is split into.
            block_size: Maximum sequence length; the RoPE cache is built
                for ``block_size + 8`` positions.
            dropout: Dropout probability applied to attention weights.

        Raises:
            ValueError: If ``retrieval_dim`` is not divisible by
                ``retrieval_heads`` or the per-head width is odd.
        """
        super().__init__()
        if retrieval_dim % retrieval_heads != 0:
            raise ValueError("retrieval_dim must be divisible by retrieval_heads")
        self.dim = int(dim)
        self.retrieval_dim = int(retrieval_dim)
        self.retrieval_heads = int(retrieval_heads)
        self.head_dim = retrieval_dim // retrieval_heads
        if self.head_dim % 2 != 0:
            raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")

        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, retrieval_dim)
        self.k = nn.Linear(dim, retrieval_dim)
        self.v = nn.Linear(dim, retrieval_dim)
        self.out = nn.Linear(retrieval_dim, dim)
        self.gate = nn.Linear(dim * 2, dim)
        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryCache(self.head_dim, max_seq_len=block_size + 8)

    def forward(self, x: torch.Tensor, retrieval_positions: torch.Tensor) -> torch.Tensor:
        """Apply gated retrieval at one position per batch row.

        Args:
            x: States of shape ``[batch, steps, dim]``.
            retrieval_positions: Long tensor of shape ``[batch]`` holding
                the position that performs retrieval in each row.

        Returns:
            A copy of ``x`` in which only the retrieval positions differ.
        """
        batch, steps, _ = x.shape
        device = x.device
        bidx = torch.arange(batch, device=device)

        h = self.norm(x)

        # Keys/values for every step: [B, H, T, Dh]. Only keys are rotated.
        k = self.k(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        v = self.v(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        k = self.rope.apply_sequence(k)

        # A single query per row, rotated by its own position: [B, H, 1, Dh].
        qh = h[bidx, retrieval_positions]
        q = self.q(qh).view(batch, self.retrieval_heads, 1, self.head_dim)
        q = self.rope.apply_query_positions(q, retrieval_positions)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Only strictly earlier steps are visible to the query.
        pos = torch.arange(steps, device=device)[None, None, None, :]
        visible = pos < retrieval_positions[:, None, None, None]
        has_context = visible.any(dim=-1, keepdim=True)

        # Rows without any visible key are left unmasked so the softmax (and
        # its gradient) stays finite; their update is zeroed further down.
        scores = scores.masked_fill(~visible & has_context, float("-inf"))

        att = F.softmax(scores, dim=-1)
        att = self.dropout(att)

        read = (att @ v).transpose(1, 2).contiguous().view(batch, self.retrieval_dim)
        read = self.out(read)

        old = x[bidx, retrieval_positions]
        gate = torch.sigmoid(self.gate(torch.cat([qh, read], dim=-1)))
        update = gate * read
        # Nothing to retrieve from -> leave that row's state untouched.
        update = update * has_context.reshape(-1, 1).to(update.dtype)
        new = old + update

        out = x.clone()
        out[bidx, retrieval_positions] = new
        return out
|
|
|
|
| |
| |
| |
|
|
class PGSMSparseRoPELM(nn.Module):
    """Parallel Geometric State Model with optional query-only sparse RoPE retrieval.

    Forward API:
        logits, loss = model(input_ids, labels=None, retrieval_positions=None)

    input_ids:
        LongTensor [B,T]

    labels:
        LongTensor [B,T], optional.
        The loss compares ``logits[:, t]`` with ``labels[:, t]``; no shift is
        applied internally, so for next-token training the caller must pass
        labels that are already aligned with the inputs.
        Use config.ignore_index for ignored positions.

    retrieval_positions:
        Optional LongTensor [B].
        If supplied, sparse retrieval is applied exactly at these positions.
        If omitted, config controls whether to auto-detect query-token positions,
        use final token, or skip retrieval.
    """

    def __init__(self, config: PGSMConfig):
        """Build embedding, backbone blocks, optional retriever and LM head."""
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [
                ParallelGeometricBlock(
                    dim=config.dim,
                    hidden=config.hidden,
                    kernel_size=config.kernel_size,
                    dropout=config.dropout,
                )
                for _ in range(config.layers)
            ]
        )

        self.retriever: Optional[QueryOnlyRoPERetriever]
        if config.use_sparse_retrieval:
            self.retriever = QueryOnlyRoPERetriever(
                dim=config.dim,
                retrieval_dim=config.retrieval_dim,
                retrieval_heads=config.retrieval_heads,
                block_size=config.block_size,
                dropout=config.retrieval_dropout,
            )
        else:
            self.retriever = None

        # Extra backbone block that runs after retrieval; it only exists
        # when the retrieval head itself is enabled.
        self.post_retrieval_block: Optional[ParallelGeometricBlock]
        if config.use_sparse_retrieval and config.use_post_retrieval_block:
            self.post_retrieval_block = ParallelGeometricBlock(
                dim=config.dim,
                hidden=config.hidden,
                kernel_size=config.kernel_size,
                dropout=config.dropout,
            )
        else:
            self.post_retrieval_block = None

        self.final_norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.apply(lambda module: init_pgsm_weights(module, std=config.init_std))

        # Tie after initialisation so both names share the embedding tensor.
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight

    @property
    def block_size(self) -> int:
        """Maximum sequence length accepted by ``encode``."""
        return self.config.block_size

    @property
    def vocab_size(self) -> int:
        """Size of the token vocabulary."""
        return self.config.vocab_size

    def num_parameters(self, trainable_only: bool = True) -> int:
        """Return the parameter count of the whole model."""
        return count_parameters(self, trainable_only=trainable_only)

    def _resolve_retrieval_positions(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Decide where (if anywhere) sparse retrieval runs.

        Priority: retrieval disabled -> ``None``; explicit positions; last
        query token per row (when auto-detection is configured); final token
        (when configured); otherwise ``None``.
        """
        if not self.config.use_sparse_retrieval:
            return None

        if retrieval_positions is not None:
            return retrieval_positions.to(device=input_ids.device, dtype=torch.long)

        if (
            self.config.auto_retrieve_on_query_token
            and self.config.query_token_id is not None
        ):
            return _positions_from_query_tokens(input_ids, self.config.query_token_id)

        if self.config.retrieve_at_last_token_if_unspecified:
            return torch.full(
                (input_ids.size(0),),
                input_ids.size(1) - 1,
                dtype=torch.long,
                device=input_ids.device,
            )

        return None

    def encode(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return final-norm hidden states of shape ``[B, T, dim]``.

        Raises:
            ValueError: If ``input_ids`` is not 2-D or is longer than
                ``config.block_size``.
            RuntimeError: If positions were resolved but no retriever exists.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must have shape [batch, steps]")
        if input_ids.size(1) > self.config.block_size:
            raise ValueError(
                f"Input length {input_ids.size(1)} exceeds config.block_size={self.config.block_size}"
            )

        x = self.token_emb(input_ids)

        for block in self.blocks:
            x = block(x)

        positions = self._resolve_retrieval_positions(input_ids, retrieval_positions)
        if positions is not None:
            if self.retriever is None:
                raise RuntimeError("retriever is None but retrieval positions were resolved")
            x = self.retriever(x, positions)
            if self.post_retrieval_block is not None:
                x = self.post_retrieval_block(x)

        return self.final_norm(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, loss)``; ``loss`` is ``None`` without labels.

        The loss is a cross entropy over all positions, skipping labels equal
        to ``config.ignore_index``. Labels are compared position-for-position
        with the logits (no internal shift).
        """
        x = self.encode(input_ids, retrieval_positions=retrieval_positions)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=self.config.ignore_index,
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Simple generation helper.

        For normal generation, sparse retrieval is not automatically applied unless
        config.retrieve_at_last_token_if_unspecified=True or query-token detection
        is enabled. Training modules can provide their own generation loop if they
        need custom retrieval-position behavior.

        Sampling runs in eval mode; the train/eval flag of every submodule is
        restored afterwards, so calling this in the middle of training does
        not leave dropout disabled.

        Args:
            input_ids: Prompt tokens, LongTensor [B,T].
            max_new_tokens: Number of tokens to append.
            temperature: Values <= 0 select greedy argmax decoding; otherwise
                logits are divided by this value before sampling.
            top_k: If given, sample only from the ``top_k`` highest logits.

        Returns:
            LongTensor [B, T + max_new_tokens].

        Raises:
            ValueError: If ``top_k`` is given and smaller than 1.
        """
        if top_k is not None and top_k < 1:
            raise ValueError("top_k must be a positive integer or None")

        # Remember per-module modes so mixed train/eval setups are restored
        # exactly, not flattened to a single mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Condition on at most the last block_size tokens.
                idx_cond = input_ids[:, -self.config.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if temperature <= 0:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None:
                        values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        # Everything below the k-th largest logit is excluded.
                        logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_id], dim=1)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        return input_ids
|
|
|
|
| |
| |
| |
|
|
def build_pgsm_model(
    size: str = "small",
    vocab_size: int = 256,
    block_size: int = 1024,
    **overrides: Any,
) -> PGSMSparseRoPELM:
    """Construct a model from a named preset.

    Args:
        size: One of ``tiny``, ``small``, ``medium`` or ``large``; matched
            case-insensitively after stripping surrounding whitespace.
        vocab_size: Forwarded to the preset.
        block_size: Forwarded to the preset. This value is always passed, so
            the presets' own ``block_size`` defaults are not used here.
        **overrides: Extra ``PGSMConfig`` fields forwarded to the preset.

    Returns:
        A freshly initialised ``PGSMSparseRoPELM``.

    Raises:
        ValueError: If ``size`` is not a known preset name.
    """
    size = size.lower().strip()
    presets = {
        "tiny": PGSMConfig.tiny,
        "small": PGSMConfig.small,
        "medium": PGSMConfig.medium,
        "large": PGSMConfig.large,
    }
    make_config = presets.get(size)
    if make_config is None:
        raise ValueError(f"Unknown model size: {size!r}. Use tiny, small, medium, or large.")
    cfg = make_config(vocab_size=vocab_size, block_size=block_size, **overrides)
    return PGSMSparseRoPELM(cfg)
|
|
|
|
# Names exported by ``from pgsm_sparse_rope_lm import *``: the config, the
# model, the preset factory and the parameter-count helper.
__all__ = [
    "PGSMConfig",
    "PGSMSparseRoPELM",
    "build_pgsm_model",
    "count_parameters",
]
|
|