SurpriseLensModel / pgsm_sparse_rope_lm.py
nilmeruo's picture
Upload 3 files
01b6330 verified
Raw
History Blame Contribute Delete
20.5 kB
#!/usr/bin/env python3
"""
pgsm_sparse_rope_lm.py

Reusable model module for the custom LLM architecture developed from the
long-memory experiments:

    Parallel Geometric State Model (PGSM)
    + optional query-only sparse RoPE retrieval head

Core design:
    - Fast attention-free local backbone.
    - Depthwise causal convolution for local state propagation.
    - Gated state mixing.
    - Gated MLP blocks.
    - Optional sparse retrieval only at selected query positions.
    - Retrieval dimension is configurable; experiments showed retrieval_dim=512
      was the first strong setting at block_size=1024 / distance=768.

This file is intentionally model-only. It does not include training loops,
datasets, benchmark code, or CLI handling. Import it from your training module.

Example:

    from pgsm_sparse_rope_lm import PGSMConfig, PGSMSparseRoPELM

    cfg = PGSMConfig.small(vocab_size=256, block_size=1024)
    model = PGSMSparseRoPELM(cfg)
    logits, loss = model(input_ids, labels)

For retrieval tasks where only specific answer/query positions should do sparse
long-range retrieval:

    logits, loss = model(input_ids, labels, retrieval_positions=answer_pos)

For normal causal LM pretraining, you can disable sparse retrieval
(use_sparse_retrieval=False) or, if your data marks query positions, use
automatic query-token detection (auto_retrieve_on_query_token=True together
with query_token_id).
"""
from __future__ import annotations
import math
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------
# Configuration
# -----------------------------
@dataclass(frozen=True)
class PGSMConfig:
    """
    Immutable hyperparameter bundle for the PGSM language model.

    Instances are frozen; derive variants with ``dataclasses.replace`` or by
    passing keyword overrides to one of the preset constructors
    (``tiny``, ``small``, ``medium``, ``large``).
    """

    # Vocabulary / sequence
    vocab_size: int = 256
    block_size: int = 1024

    # Backbone
    dim: int = 192
    layers: int = 3
    hidden: int = 384
    kernel_size: int = 17
    dropout: float = 0.0

    # Sparse retrieval
    use_sparse_retrieval: bool = True
    retrieval_dim: int = 512
    retrieval_heads: int = 4
    retrieval_dropout: float = 0.0

    # Retrieval positioning
    # If retrieval_positions is passed to forward(), that wins.
    # Otherwise, if query_token_id is set, positions matching it can be used.
    # Otherwise, retrieval can be skipped or applied to the final token.
    query_token_id: Optional[int] = None
    auto_retrieve_on_query_token: bool = False
    retrieve_at_last_token_if_unspecified: bool = False

    # Output / loss behavior
    tie_weights: bool = True
    use_post_retrieval_block: bool = True
    ignore_index: int = -100

    # Init
    init_std: float = 0.02

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain ``dict`` keyed by field name."""
        return asdict(self)

    @classmethod
    def tiny(
        cls,
        vocab_size: int = 256,
        block_size: int = 512,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Smallest preset; ``overrides`` replace any field after the preset is built."""
        preset = {
            "dim": 128,
            "layers": 3,
            "hidden": 256,
            "kernel_size": 17,
            "retrieval_dim": 256,
            "retrieval_heads": 4,
        }
        base = cls(vocab_size=vocab_size, block_size=block_size, **preset)
        return replace(base, **overrides)

    @classmethod
    def small(
        cls,
        vocab_size: int = 256,
        block_size: int = 1024,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Preset closest to the successful experiment (retrieval_dim=512)."""
        preset = {
            "dim": 192,
            "layers": 3,
            "hidden": 384,
            "kernel_size": 17,
            "retrieval_dim": 512,
            "retrieval_heads": 4,
        }
        base = cls(vocab_size=vocab_size, block_size=block_size, **preset)
        return replace(base, **overrides)

    @classmethod
    def medium(
        cls,
        vocab_size: int,
        block_size: int = 2048,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Mid-sized preset; ``vocab_size`` is required."""
        preset = {
            "dim": 384,
            "layers": 6,
            "hidden": 1024,
            "kernel_size": 21,
            "retrieval_dim": 768,
            "retrieval_heads": 8,
            "dropout": 0.0,
            "retrieval_dropout": 0.0,
        }
        base = cls(vocab_size=vocab_size, block_size=block_size, **preset)
        return replace(base, **overrides)

    @classmethod
    def large(
        cls,
        vocab_size: int,
        block_size: int = 4096,
        **overrides: Any,
    ) -> "PGSMConfig":
        """Largest preset; ``vocab_size`` is required."""
        preset = {
            "dim": 768,
            "layers": 12,
            "hidden": 2048,
            "kernel_size": 25,
            "retrieval_dim": 1024,
            "retrieval_heads": 8,
            "dropout": 0.0,
            "retrieval_dropout": 0.0,
        }
        base = cls(vocab_size=vocab_size, block_size=block_size, **preset)
        return replace(base, **overrides)
# -----------------------------
# Utility functions
# -----------------------------
def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
    """
    Return the total number of scalar parameters in ``module``.

    With ``trainable_only`` (the default) parameters whose ``requires_grad``
    is False are left out of the sum.
    """
    params = module.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
def init_pgsm_weights(module: nn.Module, std: float = 0.02) -> None:
    """
    Initialise one module in place; intended for use with ``nn.Module.apply``.

    ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, std^2) and
    Linear biases are zeroed. Every other module type is left untouched.
    """
    is_linear = isinstance(module, nn.Linear)
    if not (is_linear or isinstance(module, nn.Embedding)):
        return
    nn.init.normal_(module.weight, mean=0.0, std=std)
    if is_linear and module.bias is not None:
        nn.init.zeros_(module.bias)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Return the 90-degree partner of ``x`` used by rotary position embeddings.

    The last dimension is split into a first half ``a`` and a second half
    ``b`` and the result is ``cat(-b, a)``. Channel ``i`` is therefore paired
    with channel ``i + D/2``, which is the pairing required by cos/sin tables
    built as ``cat(freqs, freqs)`` (the layout RotaryCache constructs): both
    members of a pair then share one frequency, so
    ``x * cos + rotate_half(x) * sin`` is a true rotation.

    NOTE: the previous implementation paired adjacent channels (2i, 2i+1).
    Combined with the ``cat(freqs, freqs)`` tables that mixed two different
    frequencies inside each pair, so the transform was neither norm-preserving
    nor relative-position-dependent. Checkpoints trained with the old pairing
    will produce different retrieval numerics after this change.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    width = x.size(-1)
    if width % 2 != 0:
        raise ValueError("rotate_half requires an even last dimension")
    half = width // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def _positions_from_query_tokens(input_ids: torch.Tensor, query_token_id: int) -> torch.Tensor:
    """
    Return one retrieval position per batch row.

    If multiple query tokens exist, the last one is used.
    If none exist in a row, the final token (index ``steps - 1``) is used.

    Args:
        input_ids: integer tensor of shape [batch, steps].
        query_token_id: token id that marks a query position.

    Returns:
        Long tensor of shape [batch] on the same device as ``input_ids``.

    The search is fully vectorized: no Python loop over rows and no
    per-row ``nonzero`` call.
    """
    batch, steps = input_ids.shape
    device = input_ids.device
    fallback = steps - 1
    if steps == 0:
        # Nothing to search; mirror the "no match" result for every row.
        return torch.full((batch,), fallback, dtype=torch.long, device=device)

    matches = input_ids.eq(int(query_token_id))
    index = torch.arange(steps, dtype=torch.long, device=device)
    # Non-matching slots become -1, so the row-wise maximum is the index of
    # the last match, or -1 when the row contains no query token.
    last_match = torch.where(matches, index, torch.full_like(index, -1)).max(dim=1).values
    return last_match.masked_fill(last_match < 0, fallback)
# -----------------------------
# Backbone blocks
# -----------------------------
class CausalDepthwiseConv(nn.Module):
    """
    Depthwise (one filter per channel) causal 1-D convolution over time.

    The convolution is padded by ``kernel_size - 1`` and the output is cut
    back to the input length from the left, so the value at step ``t`` only
    sees inputs at steps ``<= t``. All time steps are computed in one call and
    no attention matrix is built.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.dim = int(dim)
        self.kernel_size = int(kernel_size)
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            groups=dim,  # groups == channels makes the convolution depthwise
            padding=kernel_size - 1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape [B,T,D] to a tensor of the same shape."""
        steps = x.size(1)
        channels_first = x.transpose(1, 2)  # Conv1d expects [B,D,T]
        padded_out = self.conv(channels_first)
        # Drop the trailing outputs produced by the right-hand padding.
        causal = padded_out[..., :steps]
        return causal.transpose(1, 2)
class ParallelGeometricBlock(nn.Module):
    """
    Attention-free residual block with a state-mixing branch and a gated MLP.

    Structure:
        x = x + dropout(tanh(mix(local_conv(norm(x)))) * sigmoid(gate(norm(x))))
        x = x + dropout(ff_out(silu(gate_half) * value_half))
    where ``value_half`` / ``gate_half`` are the two halves of ``ff_in(norm(x))``.
    """

    def __init__(self, dim: int, hidden: int, kernel_size: int, dropout: float = 0.0):
        super().__init__()
        # State-mixing branch.
        self.norm_state = nn.LayerNorm(dim)
        self.local_state = CausalDepthwiseConv(dim, kernel_size)
        self.state_mix = nn.Linear(dim, dim)
        self.state_gate = nn.Linear(dim, dim)
        self.drop_state = nn.Dropout(dropout)
        # Gated feed-forward branch; ff_in produces value and gate in one matmul.
        self.norm_ff = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, hidden * 2)
        self.ff_out = nn.Linear(hidden, dim)
        self.drop_ff = nn.Dropout(dropout)

    def _state_update(self, x: torch.Tensor) -> torch.Tensor:
        """Residual contribution of the gated local-state branch."""
        normed = self.norm_state(x)
        candidate = torch.tanh(self.state_mix(self.local_state(normed)))
        gate = torch.sigmoid(self.state_gate(normed))
        return self.drop_state(candidate * gate)

    def _ff_update(self, x: torch.Tensor) -> torch.Tensor:
        """Residual contribution of the gated MLP branch."""
        normed = self.norm_ff(x)
        value, gate = torch.chunk(self.ff_in(normed), 2, dim=-1)
        return self.drop_ff(self.ff_out(F.silu(gate) * value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual branches in order and return the updated tensor."""
        x = x + self._state_update(x)
        return x + self._ff_update(x)
# -----------------------------
# Sparse RoPE retrieval
# -----------------------------
class RotaryCache(nn.Module):
    """
    Precomputed RoPE cos/sin tables for tensors shaped [B,H,T,D] and query
    tensors shaped [B,H,1,D].

    The tables are stored as non-persistent buffers of shape
    [1, 1, max_seq_len, head_dim], built by concatenating the per-frequency
    angles with themselves, so channel ``i`` and channel ``i + head_dim/2``
    share a frequency.
    NOTE(review): ``rotate_half`` must pair channels the same way for the
    result to be a proper rotation -- confirm the two agree.
    """

    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        self.head_dim = int(head_dim)
        self.max_seq_len = int(max_seq_len)

        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (base ** exponents)
        timesteps = torch.arange(max_seq_len).float()
        # angles[t, i] = t * inv_freq[i]
        angles = timesteps[:, None] * inv_freq[None, :]
        # Duplicate so cos/sin span the full head_dim.
        angles = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos", angles.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin", angles.sin()[None, None, :, :], persistent=False)

    @staticmethod
    def _rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Combine ``x`` with its rotate_half partner using the given tables."""
        return (x * cos) + (rotate_half(x) * sin)

    def apply_sequence(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rotate every time step of ``x`` ([B,H,T,D]) by its own position 0..T-1.

        Raises:
            ValueError: if T is larger than the cached ``max_seq_len``.
        """
        steps = x.size(-2)
        if steps > self.max_seq_len:
            raise ValueError(
                f"Sequence length {steps} exceeds RoPE cache length {self.max_seq_len}. "
                "Increase config.block_size."
            )
        cos = self.cos[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        sin = self.sin[:, :, :steps, :].to(device=x.device, dtype=x.dtype)
        return self._rotate(x, cos, sin)

    def apply_query_positions(self, q: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Rotate one query per row (``q``: [B,H,1,D]) by ``positions`` ([B])."""
        cos_table = self.cos[0, 0]
        sin_table = self.sin[0, 0]
        # Gather one table row per batch element, then broadcast over heads.
        cos = cos_table[positions].to(device=q.device, dtype=q.dtype)[:, None, None, :]
        sin = sin_table[positions].to(device=q.device, dtype=q.dtype)[:, None, None, :]
        return self._rotate(q, cos, sin)
class QueryOnlyRoPERetriever(nn.Module):
    """
    Sparse retrieval applied only to selected positions.

    For each batch row, one retrieval position attends backward over prior token
    states using RoPE Q/K. This is O(T) per retrieved position, not O(T^2).
    Only the hidden state at the retrieval position is modified; every other
    position is returned unchanged.

    A row whose retrieval position has no earlier token (position <= 0) has
    nothing to read from. Such rows are returned unchanged instead of going
    through a softmax over an all-masked row, which would yield NaN.
    """

    def __init__(
        self,
        dim: int,
        retrieval_dim: int,
        retrieval_heads: int,
        block_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        if retrieval_dim % retrieval_heads != 0:
            raise ValueError("retrieval_dim must be divisible by retrieval_heads")
        self.dim = int(dim)
        self.retrieval_dim = int(retrieval_dim)
        self.retrieval_heads = int(retrieval_heads)
        self.head_dim = retrieval_dim // retrieval_heads
        if self.head_dim % 2 != 0:
            raise ValueError("retrieval_dim / retrieval_heads must be even for RoPE")
        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, retrieval_dim)
        self.k = nn.Linear(dim, retrieval_dim)
        self.v = nn.Linear(dim, retrieval_dim)
        self.out = nn.Linear(retrieval_dim, dim)
        self.gate = nn.Linear(dim * 2, dim)
        self.dropout = nn.Dropout(dropout)
        # The cache is sized slightly beyond block_size.
        self.rope = RotaryCache(self.head_dim, max_seq_len=block_size + 8)

    def forward(self, x: torch.Tensor, retrieval_positions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: hidden states, [B,T,D].
            retrieval_positions: one index per row, [B]. A single-element
                tensor is broadcast to every row.
                NOTE(review): indices are expected to lie in [0, T); range is
                not checked here to avoid a device synchronisation.

        Returns:
            A copy of ``x`` in which the row at each retrieval position has
            been replaced by ``old + gate * read``.

        Raises:
            ValueError: if the number of positions is neither 1 nor B.
        """
        batch, steps, _ = x.shape
        device = x.device

        positions = retrieval_positions.to(device=device, dtype=torch.long).reshape(-1)
        if positions.numel() == 1 and batch != 1:
            positions = positions.expand(batch)
        if positions.numel() != batch:
            raise ValueError(
                f"retrieval_positions must hold one index per batch row "
                f"(expected {batch}, got {positions.numel()})"
            )

        bidx = torch.arange(batch, device=device)
        h = self.norm(x)
        k = self.k(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        v = self.v(h).view(batch, steps, self.retrieval_heads, self.head_dim).transpose(1, 2)
        k = self.rope.apply_sequence(k)

        qh = h[bidx, positions]
        q = self.q(qh).view(batch, self.retrieval_heads, 1, self.head_dim)
        q = self.rope.apply_query_positions(q, positions)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B,H,1,T]

        # Strictly backward. The retrieval position cannot read itself.
        key_index = torch.arange(steps, device=device)[None, None, None, :]
        visible = key_index < positions[:, None, None, None]
        has_context = positions > 0  # [B]; False means no key is visible
        # Rows without context keep their raw (finite) scores so the softmax
        # and its gradient stay NaN-free; their result is discarded below.
        blocked = (~visible) & has_context[:, None, None, None]
        scores = scores.masked_fill(blocked, float("-inf"))
        att = F.softmax(scores, dim=-1)
        att = self.dropout(att)

        read = (att @ v).transpose(1, 2).contiguous().view(batch, self.retrieval_dim)
        read = self.out(read)
        old = x[bidx, positions]
        gate = torch.sigmoid(self.gate(torch.cat([qh, read], dim=-1)))
        update = gate * read
        # No prior tokens -> no update for that row.
        update = torch.where(has_context[:, None], update, torch.zeros_like(update))

        out = x.clone()
        out[bidx, positions] = old + update
        return out
# -----------------------------
# Main model
# -----------------------------
class PGSMSparseRoPELM(nn.Module):
    """
    Parallel Geometric State Model with optional query-only sparse RoPE retrieval.

    Forward API:
        logits, loss = model(input_ids, labels=None, retrieval_positions=None)

    input_ids:
        LongTensor [B,T]
    labels:
        LongTensor [B,T], optional.
        Compared position-wise against the logits (no shift is applied here),
        so pass targets already aligned to the position that predicts them.
        Use config.ignore_index for ignored positions.
    retrieval_positions:
        Optional LongTensor [B].
        If supplied, sparse retrieval is applied exactly at these positions.
        If omitted, config controls whether to auto-detect query-token positions,
        use final token, or skip retrieval.
    """

    def __init__(self, config: PGSMConfig):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList(
            [
                ParallelGeometricBlock(
                    dim=config.dim,
                    hidden=config.hidden,
                    kernel_size=config.kernel_size,
                    dropout=config.dropout,
                )
                for _ in range(config.layers)
            ]
        )

        self.retriever: Optional[QueryOnlyRoPERetriever]
        if config.use_sparse_retrieval:
            self.retriever = QueryOnlyRoPERetriever(
                dim=config.dim,
                retrieval_dim=config.retrieval_dim,
                retrieval_heads=config.retrieval_heads,
                block_size=config.block_size,
                dropout=config.retrieval_dropout,
            )
        else:
            self.retriever = None

        # Extra backbone block that runs after retrieval; it only exists when
        # retrieval itself is enabled.
        self.post_retrieval_block: Optional[ParallelGeometricBlock]
        if config.use_sparse_retrieval and config.use_post_retrieval_block:
            self.post_retrieval_block = ParallelGeometricBlock(
                dim=config.dim,
                hidden=config.hidden,
                kernel_size=config.kernel_size,
                dropout=config.dropout,
            )
        else:
            self.post_retrieval_block = None

        self.final_norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.apply(lambda module: init_pgsm_weights(module, std=config.init_std))
        # Tie after init so the shared matrix is the initialised embedding.
        if config.tie_weights:
            self.lm_head.weight = self.token_emb.weight

    @property
    def block_size(self) -> int:
        """Maximum sequence length accepted by ``encode``."""
        return self.config.block_size

    @property
    def vocab_size(self) -> int:
        """Size of the token vocabulary."""
        return self.config.vocab_size

    def num_parameters(self, trainable_only: bool = True) -> int:
        """Return the parameter count of this model (see ``count_parameters``)."""
        return count_parameters(self, trainable_only=trainable_only)

    def _resolve_retrieval_positions(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Decide where retrieval runs, or return None to skip it.

        Priority: retrieval disabled -> None; explicit positions; query-token
        detection (when enabled and a query token id is configured); final
        token (when enabled); otherwise None.
        """
        if not self.config.use_sparse_retrieval:
            return None
        if retrieval_positions is not None:
            return retrieval_positions.to(device=input_ids.device, dtype=torch.long)
        if (
            self.config.auto_retrieve_on_query_token
            and self.config.query_token_id is not None
        ):
            return _positions_from_query_tokens(input_ids, self.config.query_token_id)
        if self.config.retrieve_at_last_token_if_unspecified:
            return torch.full(
                (input_ids.size(0),),
                input_ids.size(1) - 1,
                dtype=torch.long,
                device=input_ids.device,
            )
        return None

    def encode(
        self,
        input_ids: torch.Tensor,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run embedding, backbone, optional retrieval and the final norm.

        Returns:
            Hidden states of shape [B,T,dim].

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than
                ``config.block_size``.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must have shape [batch, steps]")
        if input_ids.size(1) > self.config.block_size:
            raise ValueError(
                f"Input length {input_ids.size(1)} exceeds config.block_size={self.config.block_size}"
            )

        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x)

        positions = self._resolve_retrieval_positions(input_ids, retrieval_positions)
        if positions is not None:
            if self.retriever is None:
                raise RuntimeError("retriever is None but retrieval positions were resolved")
            x = self.retriever(x, positions)
            if self.post_retrieval_block is not None:
                x = self.post_retrieval_block(x)

        return self.final_norm(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        retrieval_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Return ``(logits, loss)``; ``loss`` is None when ``labels`` is None.

        The loss is the cross-entropy between ``logits[b, t]`` and
        ``labels[b, t]`` with ``config.ignore_index`` entries skipped.
        """
        x = self.encode(input_ids, retrieval_positions=retrieval_positions)
        logits = self.lm_head(x)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=self.config.ignore_index,
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Simple generation helper.

        For normal generation, sparse retrieval is not automatically applied unless
        config.retrieve_at_last_token_if_unspecified=True or query-token detection
        is enabled. Training modules can provide their own generation loop if they
        need custom retrieval-position behavior.

        ``temperature <= 0`` selects greedy decoding. ``top_k`` of None (or a
        non-positive value) disables top-k filtering.

        The model is switched to eval mode while sampling and its previous
        train/eval mode is restored on exit, so calling this in the middle of
        training does not silently leave dropout disabled.

        Returns:
            ``input_ids`` with ``max_new_tokens`` sampled tokens appended.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit in the model.
                idx_cond = input_ids[:, -self.config.block_size :]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if temperature <= 0:
                    next_id = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        # Everything below the k-th largest logit is removed.
                        logits = logits.masked_fill(logits < values[:, [-1]], float("-inf"))
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_id], dim=1)
        finally:
            self.train(was_training)
        return input_ids
# -----------------------------
# Convenience factory
# -----------------------------
def build_pgsm_model(
    size: str = "small",
    vocab_size: int = 256,
    block_size: int = 1024,
    **overrides: Any,
) -> PGSMSparseRoPELM:
    """
    Build a model from a named preset.

    ``size`` is matched case-insensitively after stripping whitespace and must
    be one of tiny, small, medium, or large. ``vocab_size`` and ``block_size``
    are always forwarded to the preset, so the preset's own ``block_size``
    default is not used here; remaining keyword arguments override individual
    config fields.

    Raises:
        ValueError: if ``size`` names no preset.
    """
    size = size.lower().strip()
    presets = {
        "tiny": PGSMConfig.tiny,
        "small": PGSMConfig.small,
        "medium": PGSMConfig.medium,
        "large": PGSMConfig.large,
    }
    if size not in presets:
        raise ValueError(f"Unknown model size: {size!r}. Use tiny, small, medium, or large.")
    cfg = presets[size](vocab_size=vocab_size, block_size=block_size, **overrides)
    return PGSMSparseRoPELM(cfg)
# Names exported by a wildcard import of this module. Helpers and building
# blocks not listed here remain importable by explicit name.
__all__ = [
"PGSMConfig",
"PGSMSparseRoPELM",
"build_pgsm_model",
"count_parameters",
]