#!/usr/bin/env python3
"""
pgsm_sparse_rope_lm.py

Reusable model module for the custom LLM architecture developed from the
long-memory experiments:

    Parallel Geometric State Model (PGSM)
    + optional query-only sparse RoPE retrieval head

Core design:
    - Fast attention-free local backbone.
    - Depthwise causal convolution for local state propagation.
    - Gated state mixing.
    - Gated MLP blocks.
    - Optional sparse retrieval only at selected query positions.
    - Retrieval dimension is configurable; experiments showed retrieval_dim=512
      was the first strong setting at block_size=1024 / distance=768.

This file is intentionally model-only. It does not include training loops,
datasets, benchmark code, or CLI handling. Import it from your training module.

Example:

    from pgsm_sparse_rope_lm import PGSMConfig, PGSMSparseRoPELM

    cfg = PGSMConfig.small(vocab_size=256, block_size=1024)
    model = PGSMSparseRoPELM(cfg)
    logits, loss = model(input_ids, labels)

For retrieval tasks where only specific answer/query positions should do
sparse long-range retrieval:

    logits, loss = model(input_ids, labels, retrieval_positions=answer_pos)

For normal causal LM pretraining, you can disable sparse retrieval or use
automatic query-token detection if your data marks query positions.

Labels are aligned with logits: ``labels[:, t]`` is the target for
``logits[:, t]``. The model does not shift labels internally, so pass
already-shifted next-token targets.
"""

from __future__ import annotations

import math
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# Configuration
# -----------------------------


@dataclass(frozen=True)
class PGSMConfig:
    """Hyperparameters for :class:`PGSMSparseRoPELM`.

    Use the ``tiny`` / ``small`` / ``medium`` / ``large`` presets (each accepts
    keyword overrides) or construct the dataclass directly.
    """

    # Vocabulary / sequence
    vocab_size: int = 256
    block_size: int = 1024

    # Backbone
    dim: int = 192
    layers: int = 3
    hidden: int = 384
    kernel_size: int = 17
    dropout: float = 0.0

    # Sparse retrieval
    use_sparse_retrieval: bool = True
    retrieval_dim: int = 512
    retrieval_heads: int = 4
    retrieval_dropout: float = 0.0

    # Retrieval positioning
    # If retrieval_positions is passed to forward(), that wins.
    # Otherwise, if query_token_id is set, positions matching it can be used.
    # Otherwise, retrieval can be skipped or applied to the final token.
    query_token_id: Optional[int] = None
    auto_retrieve_on_query_token: bool = False
    retrieve_at_last_token_if_unspecified: bool = False

    # Output / loss behavior
    tie_weights: bool = True
    use_post_retrieval_block: bool = True
    ignore_index: int = -100

    # Init
    init_std: float = 0.02

    # RoPE for the retrieval head. These fields are appended last so that
    # positional construction of the dataclass keeps working.
    rope_base: float = 10000.0
    # False (default): both channels of each rotated pair (2i, 2i+1) share one
    # angle, so the q.k score depends only on relative position.
    # True: reproduce the older cos/sin table layout, where the two channels of a
    # pair used different frequencies (not a true rotation). Only set this to keep
    # numerics identical for checkpoints trained with that layout.
    rope_legacy_layout: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict (e.g. for checkpoint metadata)."""
        return asdict(self)

    @classmethod
    def tiny(
        cls,
        vocab_size: int = 256,
        block_size: int = 512,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Smallest preset; ``overrides`` replace any field afterwards."""
        cfg = cls(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=128,
            layers=3,
            hidden=256,
            kernel_size=17,
            retrieval_dim=256,
            retrieval_heads=4,
        )
        return replace(cfg, **overrides)

    @classmethod
    def small(
        cls,
        vocab_size: int = 256,
        block_size: int = 1024,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Preset closest to the successful experiment (retrieval_dim=512)."""
        cfg = cls(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=192,
            layers=3,
            hidden=384,
            kernel_size=17,
            retrieval_dim=512,
            retrieval_heads=4,
        )
        return replace(cfg, **overrides)

    @classmethod
    def medium(
        cls,
        vocab_size: int,
        block_size: int = 2048,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Medium preset; ``overrides`` replace any field afterwards."""
        cfg = cls(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=384,
            layers=6,
            hidden=1024,
            kernel_size=21,
            retrieval_dim=768,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return replace(cfg, **overrides)

    @classmethod
    def large(
        cls,
        vocab_size: int,
        block_size: int = 4096,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Large preset; ``overrides`` replace any field afterwards."""
        cfg = cls(
            vocab_size=vocab_size,
            block_size=block_size,
            dim=768,
            layers=12,
            hidden=2048,
            kernel_size=25,
            retrieval_dim=1024,
            retrieval_heads=8,
            dropout=0.0,
            retrieval_dropout=0.0,
        )
        return replace(cfg, **overrides)


# -----------------------------
# Utility functions
# -----------------------------


def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
    """Return the number of parameter elements in ``module``.

    ``Module.parameters()`` de-duplicates shared tensors, so tied weights are
    counted once. With ``trainable_only`` only parameters with
    ``requires_grad=True`` are counted.
    """
    params = module.parameters()
    if trainable_only:
        return sum(p.numel() for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


def init_pgsm_weights(module: nn.Module, std: float = 0.02) -> None:
    """Normal(0, std) init for Linear/Embedding weights; zero Linear biases.

    Other module types (Conv1d, LayerNorm) keep their PyTorch defaults.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=std)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs by 90 degrees: (a, b) -> (-b, a).

    Channels (2i, 2i+1) form one pair, so the RoPE cos/sin tables used with
    this function must assign the same angle to both channels of a pair.
    """
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)


def _positions_from_query_tokens(input_ids: torch.Tensor, query_token_id: int) -> torch.Tensor:
    """
    Return one retrieval position per batch row.

    If multiple query tokens exist, the last one is used.
    If none exist in a row, the final token is used.
    Vectorized over the batch (no per-row Python loop / host sync).
    """
    batch, steps = input_ids.shape
    device = input_ids.device
    if steps == 0:
        # Nothing to point at; mirrors the "final token" rule (steps - 1 == -1).
        return torch.full((batch,), -1, dtype=torch.long, device=device)

    matches = input_ids.eq(int(query_token_id))
    index = torch.arange(steps, device=device).expand(batch, steps)
    # Highest matching index per row, or -1 when the row has no match.
    last = torch.where(matches, index, torch.full_like(index, -1)).max(dim=1).values
    return torch.where(last >= 0, last, torch.full_like(last, steps - 1))


# -----------------------------
# Backbone blocks
# -----------------------------


class CausalDepthwiseConv(nn.Module):
    """
    Depthwise causal convolution.

    This is the main local state propagation primitive. It is parallel over
    time during training and does not construct an attention matrix.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        # Symmetric padding of kernel_size - 1; forward() keeps only the first T
        # outputs, so output t sees inputs t - kernel_size + 1 .. t only.
        self.conv = nn.Conv1d(
            dim,
            dim,
            kernel_size,
            groups=dim,
            padding=kernel_size - 1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B,T,D]
        y = self.conv(x.transpose(1, 2))
        y = y[:, :, : x.size(1)]
        return y.transpose(1, 2)


class ParallelGeometricBlock(nn.Module):
    """
    Attention-free parallel geometric/state-mixing block.

    Structure:
        norm -> causal depthwise local state -> gated state residual
        norm -> gated MLP -> residual
    """

    def __init__(self, dim: int, hidden: int, kernel_size: int, dropout: float = 0.0):
        super().__init__()
        self.norm_state = nn.LayerNorm(dim)
        self.local_state = CausalDepthwiseConv(dim, kernel_size)
        self.state_mix = nn.Linear(dim, dim)
        self.state_gate = nn.Linear(dim, dim)
        self.drop_state = nn.Dropout(dropout)

        self.norm_ff = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, hidden * 2)
        self.ff_out = nn.Linear(hidden, dim)
        self.drop_ff = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B,T,D] -> [B,T,D]
        h = self.norm_state(x)
        local = self.local_state(h)
        gated_state = torch.tanh(self.state_mix(local)) * torch.sigmoid(self.state_gate(h))
        x = x + self.drop_state(gated_state)

        h = self.norm_ff(x)
        value, gate = self.ff_in(h).chunk(2, dim=-1)
        ff = self.ff_out(F.silu(gate) * value)
        x = x + self.drop_ff(ff)
        return x


# -----------------------------
# Sparse RoPE retrieval
# -----------------------------


class RotaryCache(nn.Module):
    """
    RoPE cache for tensors shaped [B,H,T,D] and query tensors [B,H,1,D].

    Tables are laid out for :func:`rotate_half`, which rotates interleaved
    channel pairs (2i, 2i+1); both channels of pair i use frequency i. Set
    ``legacy_layout=True`` to reproduce the older half-split table layout.
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int,
        base: float = 10000.0,
        legacy_layout: bool = False,
    ):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.head_dim = int(head_dim)
        self.max_seq_len = int(max_seq_len)
        self.legacy_layout = bool(legacy_layout)

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [T, D/2]
        if self.legacy_layout:
            # Half-split layout: channel j uses frequency j mod D/2, so the two
            # channels of an interleaved pair get different angles.
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            # Interleaved layout matching rotate_half: [f0, f0, f1, f1, ...].
            emb = freqs.repeat_interleave(2, dim=-1)
        self.register_buffer("cos", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin", emb.sin()[None, None, :, :], persistent=False)

    def apply_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate x: [B,H,T,D] with positions 0..T-1."""
        steps = x.size(-2)
        if steps > self.max_seq_len:
            raise ValueError(
                f"Sequence length {steps} exceeds RoPE cache length {self.max_seq_len}. "
                "Increase config.block_size."
            )
        cos = self.cos[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        sin = self.sin[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        return (x * cos) + (rotate_half(x) * sin)

    def apply_query_positions(self, q: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate q: [B,H,1,D], row b at position positions[b] (positions: [B])."""
        cos = self.cos[0, 0, positions, :].to(device=q.device, dtype=q.dtype)[:, None, None, :]
        sin = self.sin[0, 0, positions, :].to(device=q.device, dtype=q.dtype)[:, None, None, :]
        return (q * cos) + (rotate_half(q) * sin)


class QueryOnlyRoPERetriever(nn.Module):
    """
    Sparse retrieval applied only to selected positions.

    For each batch row, one retrieval position attends backward over prior
    token states using RoPE Q/K. This is O(T) per retrieved position,
    not O(T^2).

    This module is the key successful retrieval primitive from the experiments.
    """

    def __init__(
        self,
        dim: int,
        retrieval_dim: int,
        retrieval_heads: int,
        block_size: int,
        dropout: float = 0.0,
        rope_base: float = 10000.0,
        rope_legacy_layout: bool = False,
    ):
        super().__init__()
        if retrieval_dim % retrieval_heads != 0:
            raise ValueError("retrieval_dim must be divisible by retrieval_heads")
        self.dim = int(dim)
        self.retrieval_dim = int(retrieval_dim)
        self.retrieval_heads = int(retrieval_heads)
        self.head_dim = retrieval_dim // retrieval_heads
        if self.head_dim % 2 != 0:
            raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")

        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, retrieval_dim)
        self.k = nn.Linear(dim, retrieval_dim)
        self.v = nn.Linear(dim, retrieval_dim)
        self.out = nn.Linear(retrieval_dim, dim)
        self.gate = nn.Linear(dim * 2, dim)
        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryCache(
            self.head_dim,
            max_seq_len=block_size + 8,
            base=rope_base,
            legacy_layout=rope_legacy_layout,
        )

    def forward(self, x: torch.Tensor, retrieval_positions: torch.Tensor) -> torch.Tensor:
        """Update x: [B,T,D] at one position per row and return a new tensor.

        ``retrieval_positions`` has shape [B]; negative values count from the end
        (-1 is the last token). A position with no earlier tokens (position 0)
        has nothing to read and is left unchanged.

        Raises:
            ValueError: if ``retrieval_positions`` is not shaped [B].
        """
        batch, steps, _ = x.shape
        device = x.device
        if tuple(retrieval_positions.shape) != (batch,):
            raise ValueError(
                f"retrieval_positions must have shape [{batch}], "
                f"got {tuple(retrieval_positions.shape)}"
            )
        positions = retrieval_positions.to(device=device, dtype=torch.long)
        positions = torch.where(positions < 0, positions + steps, positions)
        bidx = torch.arange(batch, device=device)

        h = self.norm(x)

        k = self.k(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        v = self.v(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        k = self.rope.apply_sequence(k)

        qh = h[bidx, positions]  # [B,D]
        q = self.q(qh).view(batch, self.retrieval_heads, 1, self.head_dim)
        q = self.rope.apply_query_positions(q, positions)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B,H,1,T]

        # Strictly backward. The retrieval position cannot read itself.
        history = torch.arange(steps, device=device)[None, :] < positions[:, None]  # [B,T]
        scores = scores.masked_fill(~history[:, None, None, :], float("-inf"))
        # A row with no history would be all -inf and softmax to NaN, which would
        # then spread through the whole sequence. Give such rows finite scores
        # here and discard their read below so they receive no update.
        has_history = positions > 0  # [B]
        scores = scores.masked_fill(~has_history[:, None, None, None], 0.0)

        att = F.softmax(scores, dim=-1)
        att = self.dropout(att)

        read = (att @ v).transpose(1, 2).contiguous().view(batch, self.retrieval_dim)
        read = self.out(read)
        read = read.masked_fill(~has_history[:, None], 0.0)

        old = x[bidx, positions]
        gate = torch.sigmoid(self.gate(torch.cat([qh, read], dim=-1)))
        new = old + gate * read

        out = x.clone()
        out[bidx, positions] = new
        return out


# -----------------------------
# Main model
# -----------------------------


class PGSMSparseRoPELM(nn.Module):
    """
    Parallel Geometric State Model with optional query-only sparse RoPE retrieval.

    Forward API:

        logits, loss = model(input_ids, labels=None, retrieval_positions=None)

    input_ids:
        LongTensor [B,T]

    labels:
        LongTensor [B,T], optional. ``labels[:, t]`` is the target for
        ``logits[:, t]``; no shift is applied internally, so pass
        already-shifted next-token targets. Use config.ignore_index for
        ignored positions.

    retrieval_positions:
        Optional LongTensor [B] (a scalar or [B,1] tensor is also accepted).
        If supplied, sparse retrieval is applied exactly at these positions;
        negative values count from the end. If omitted, config controls
        whether to auto-detect query-token positions, use final token, or
        skip retrieval.
    """

    def __init__(self, config: PGSMConfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [
                ParallelGeometricBlock(
                    dim=config.dim,
                    hidden=config.hidden,
                    kernel_size=config.kernel_size,
                    dropout=config.dropout,
                )
                for _ in range(config.layers)
            ]
        )

        self.retriever: Optional[QueryOnlyRoPERetriever]
        if config.use_sparse_retrieval:
            self.retriever = QueryOnlyRoPERetriever(
                dim=config.dim,
                retrieval_dim=config.retrieval_dim,
                retrieval_heads=config.retrieval_heads,
                block_size=config.block_size,
                dropout=config.retrieval_dropout,
                rope_base=config.rope_base,
                rope_legacy_layout=config.rope_legacy_layout,
            )
        else:
            self.retriever = None

        self.post_retrieval_block: Optional[ParallelGeometricBlock]
        if config.use_sparse_retrieval and config.use_post_retrieval_block:
            self.post_retrieval_block = ParallelGeometricBlock(
                dim=config.dim,
                hidden=config.hidden,
                kernel_size=config.kernel_size,
                dropout=config.dropout,
            )
        else:
            self.post_retrieval_block = None

        self.final_norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.apply(lambda module: init_pgsm_weights(module, std=config.init_std))

        # Tie after init so both names refer to the embedding's initialized weight.
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight

    @property
    def block_size(self) -> int:
        return self.config.block_size

    @property
    def vocab_size(self) -> int:
        return self.config.vocab_size

    def num_parameters(self, trainable_only: bool = True) -> int:
        return count_parameters(self, trainable_only=trainable_only)

    def _resolve_retrieval_positions(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Pick one retrieval position per row, or None to skip retrieval.

        Priority: explicit ``retrieval_positions`` > query-token detection (if
        enabled) > final token (if enabled) > no retrieval.

        Raises:
            ValueError: if explicit positions do not provide one entry per row.
        """
        if not self.config.use_sparse_retrieval:
            return None

        batch, steps = input_ids.shape
        device = input_ids.device

        if retrieval_positions is not None:
            positions = torch.as_tensor(retrieval_positions, dtype=torch.long, device=device)
            if positions.dim() == 0:
                # One position for the whole batch.
                return positions.expand(batch)
            positions = positions.reshape(-1)
            if positions.numel() != batch:
                raise ValueError(
                    f"retrieval_positions must provide one position per batch row "
                    f"({batch}), got {positions.numel()}"
                )
            return positions

        if (
            self.config.auto_retrieve_on_query_token
            and self.config.query_token_id is not None
        ):
            return _positions_from_query_tokens(input_ids, self.config.query_token_id)

        if self.config.retrieve_at_last_token_if_unspecified:
            return torch.full((batch,), steps - 1, dtype=torch.long, device=device)

        return None

    def encode(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return final-normed hidden states [B,T,D] for ``input_ids`` [B,T].

        Raises:
            ValueError: on non-2D input or input longer than config.block_size.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must have shape [batch, steps]")
        if input_ids.size(1) > self.config.block_size:
            raise ValueError(
                f"Input length {input_ids.size(1)} exceeds config.block_size={self.config.block_size}"
            )

        x = self.token_emb(input_ids)

        for block in self.blocks:
            x = block(x)

        positions = self._resolve_retrieval_positions(input_ids, retrieval_positions)
        if positions is not None:
            if self.retriever is None:
                raise RuntimeError("retriever is None but retrieval positions were resolved")
            x = self.retriever(x, positions)
            if self.post_retrieval_block is not None:
                x = self.post_retrieval_block(x)

        return self.final_norm(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits [B,T,V], loss or None)``.

        The loss is mean cross-entropy of ``logits[:, t]`` against
        ``labels[:, t]`` (no internal shift), skipping ``config.ignore_index``.
        """
        x = self.encode(input_ids, retrieval_positions=retrieval_positions)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=self.config.ignore_index,
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Simple generation helper.

        Appends ``max_new_tokens`` tokens to ``input_ids`` and returns the
        result. ``temperature <= 0`` selects greedy decoding; ``top_k`` of None
        or <= 0 disables top-k filtering. The model runs in eval mode during
        generation and its previous train/eval mode is restored afterwards.

        For normal generation, sparse retrieval is not automatically applied
        unless config.retrieve_at_last_token_if_unspecified=True or
        query-token detection is enabled. Training modules can provide their
        own generation loop if they need custom retrieval-position behavior.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = input_ids[:, -self.config.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if temperature <= 0:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        k = min(int(top_k), logits.size(-1))
                        kth_value = torch.topk(logits, k).values[:, [-1]]
                        logits = logits.masked_fill(logits < kth_value, float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_id], dim=1)
        finally:
            self.train(was_training)

        return input_ids


# -----------------------------
# Convenience factory
# -----------------------------


def build_pgsm_model(
    size: str = "small",
    vocab_size: int = 256,
    block_size: int = 1024,
    **overrides: Any,
) -> PGSMSparseRoPELM:
    """Build a model from a named preset (tiny/small/medium/large).

    Raises:
        ValueError: for an unknown preset name.
    """
    size = size.lower().strip()
    presets = {
        "tiny": PGSMConfig.tiny,
        "small": PGSMConfig.small,
        "medium": PGSMConfig.medium,
        "large": PGSMConfig.large,
    }
    try:
        make_config = presets[size]
    except KeyError:
        raise ValueError(
            f"Unknown model size: {size!r}. Use tiny, small, medium, or large."
        ) from None
    cfg = make_config(vocab_size=vocab_size, block_size=block_size, **overrides)
    return PGSMSparseRoPELM(cfg)


__all__ = [
    "PGSMConfig",
    "PGSMSparseRoPELM",
    "build_pgsm_model",
    "count_parameters",
    "CausalDepthwiseConv",
    "ParallelGeometricBlock",
    "QueryOnlyRoPERetriever",
    "RotaryCache",
    "init_pgsm_weights",
    "rotate_half",
]