# Chords-Progressions-Transformer / x_transformer_2_3_1.py
# projectlosangeles's picture
# Upload 6 files
# 706c1fb verified
#===================================================================================================================
#
# X Transformer Python Module
#
# Partial x-transformers code with useful modifications, as a stand-alone Python module
#
# Version 10.0
#
# Original source code courtesy of lucidrains
# https://github.com/lucidrains/x-transformers
#
# Original source code retrieved on 04/30/2025
# Original version 2.3.1 / Commit 458bc12
#
# Project Los Angeles
# Tegridy Code 2026
#
#===================================================================================================================
#
# Critical dependencies
#
# !pip install torch
# !pip install einops
# !pip install einx
# !pip install numpy
# !pip install scikit-learn
# !pip install matplotlib
#
#===================================================================================================================
from __future__ import annotations
import os
# Module-level setup with process-wide side effects, executed at import time.
# NOTE(review): nothing visible in this module reads USE_FLASH_ATTENTION; it is set before
# `import torch`, presumably in case some library consults it -- confirm it has any effect.
os.environ['USE_FLASH_ATTENTION'] = '1'
import torch
# NOTE(review): this unconditional import of `torch.nn.attention` makes the
# torch >= 2.3 guard inside `Attend.__init__` moot -- older torch fails here already.
from torch.nn.attention import SDPBackend, sdpa_kernel
# Globally allow the flash scaled-dot-product-attention backend on CUDA.
torch.backends.cuda.enable_flash_sdp(True)
#==================================================================================================================================
# attend.py
#==================================================================================================================================
from functools import partial
from typing import Tuple, Callable
import torch
from torch.nn import Module
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass
from einops import rearrange, repeat, pack, unpack
#========================================================================================================================
# constants
@dataclass
class Intermediates:
    """Optional diagnostic tensors produced by one attention call.

    Every field defaults to None; callers fill only the ones they compute.
    """
    qk_similarities: Tensor | None = None
    pre_softmax_attn: Tensor | None = None
    post_softmax_attn: Tensor | None = None
    values: Tensor | None = None
    cached_kv: Tuple[Tensor, Tensor] | None = None
    layer_type: str | None = None

    def to_tuple(self):
        """Return (qk_similarities, pre_softmax_attn, post_softmax_attn)."""
        names = ('qk_similarities', 'pre_softmax_attn', 'post_softmax_attn')
        return tuple(getattr(self, name) for name in names)
# helpers
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None (`d` is never called)."""
    if val is None:
        return d
    return val
def at_most_one_of(*bools):
    """True when the integer values of the given flags sum to at most 1."""
    truthy = 0
    for flag in bools:
        truthy += int(flag)
    return truthy <= 1
def compact(arr):
    """Return a new list holding the elements of `arr` that are not None."""
    return [item for item in arr if item is not None]
@torch.jit.script
def softclamp(t: Tensor, value: float):
    # smoothly squash t into (-value, value) with a scaled tanh
    return torch.tanh(t / value) * value
def pack_one(t, pattern):
    """Pack a single tensor with einops `pack`; returns (packed_tensor, packed_shapes)."""
    single = [t]
    return pack(single, pattern)
def unpack_one(t, ps, pattern):
    """Inverse of `pack_one`: unpack `t` with shapes `ps` and return the first tensor."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
def once(fn):
    """Wrap single-argument `fn` so only its first call runs; later calls return None."""
    state = {'called': False}

    @wraps(fn)
    def inner(x):
        if state['called']:
            return None
        state['called'] = True
        return fn(x)

    return inner

# print a message only the first time this is called
print_once = once(print)
# selective attention
# https://arxiv.org/abs/2410.02703 - section 3.3
# it is a technique to allow each token to prevent itself from being attended to by future tokens
# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
def selective_attn(
    sim,
    sim_head_gate = None,
    no_mask_sos = True
):
    """Selective attention (arXiv 2410.02703, sec 3.3): subtract accumulated gates from logits.

    Args:
        sim: attention logits, indexed as (b, h, i, j) below (see `sim[:, 0]` and the
            final 'b i j -> b 1 i j' rearrange).
        sim_head_gate: (b, i, j) gate logits; defaults to the first head of `sim`.
        no_mask_sos: if True, key column `-i` (column 0 when i == j) is never gated.

    Returns:
        `sim` minus the accumulated gate penalty, broadcast over every head.
    """
    i, j, device = *sim.shape[-2:], sim.device
    sim_head_gate = default(sim_head_gate, sim[:, 0])
    gate = F.relu(sim_head_gate) # only positive
    if no_mask_sos:
        # clone before the in-place write -- presumably to keep the relu output
        # (needed by autograd) untouched
        gate = gate.clone()
        gate[..., -i] = 0.
    # a token never gates itself: zero the (right-aligned) diagonal ...
    eye = torch.eye(i, device = device)
    if j > i:
        # ... and, when there are more keys than queries, also the leading j - i columns
        eye = F.pad(eye, (j - i, 0), value = 1.)
    gate = (1. - eye) * gate
    # shift down one query row, so after the cumsum row r sums the gates of rows 0 .. r - 1
    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
    gate = gate.cumsum(dim = -2)
    return sim - rearrange(gate, 'b i j -> b 1 i j')
# alternative distance functions
def qk_l2_dist_squared(q, k):
    """Pairwise squared Euclidean distance between queries and keys.

    A 3-dim `k` (one key set shared by all heads) is first repeated across q's head dim.
    The result keeps q's leading dims and ends in an (i, j) distance matrix.
    """
    if k.ndim == 3:
        num_heads = q.shape[1]
        k = repeat(k, 'b j d -> b h j d', h = num_heads)
    q_flat, leading_shape = pack([q], '* i d')
    k_flat, _ = pack([k], '* j d')
    dist_sq = torch.cdist(q_flat, k_flat) ** 2
    return unpack(dist_sq, leading_shape, '* i j')[0]
# one-hot straight through softmax
def one_hot_straight_through(logits, temperature = 1.):
    """Hard one-hot attention forward, softmax gradient backward (straight-through).

    The forward value is a one-hot of the argmax along the last dim; gradients flow
    through softmax(logits / temperature).
    """
    winners = logits.argmax(dim = -1, keepdim = True)
    hard = torch.zeros_like(logits).scatter(-1, winners, 1.)
    soft = F.softmax(logits / temperature, dim = -1)
    return hard + soft - soft.detach()
# sparse topk attention - only keep topk attn logits for softmax
# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
def sparse_topk_attn(
    logits,
    sparse_topk,
    temperature = 1.,
    straight_through = False
):
    """Softmax restricted to the top-k logits of each row (last dim).

    Args:
        logits: attention logits; top-k is taken along the last dim.
        sparse_topk: number of logits to keep per row. It is clamped to the row
            length, so a k larger than the number of keys (e.g. early in cached
            decoding) keeps every key instead of raising.
        temperature: only used by the straight-through dense branch.
        straight_through: if True, the forward value is the sparse attention, while
            gradients flow through the dense softmax(orig_logits / temperature).

    Returns:
        Attention weights with the same shape as `logits`.
    """
    orig_logits = logits
    mask_value = -torch.finfo(logits.dtype).max
    # `topk` raises if k exceeds the size of the dim, so clamp it to the number of keys
    k = min(sparse_topk, logits.shape[-1])
    top_values, _ = logits.topk(k, dim = -1)
    # ties with the k-th value are kept; already-masked logits (== mask_value) never are
    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
    topk_attn = logits.softmax(dim = -1)
    if not straight_through:
        return topk_attn
    soft_attn = (orig_logits / temperature).softmax(dim = -1)
    return topk_attn.detach() + soft_attn - soft_attn.detach()
# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)
def create_causal_mask(i, j, device):
    """Boolean (i, j) mask that is True where a query may NOT attend.

    Queries are right-aligned against the keys: query row r may see keys 0 .. r + j - i.
    """
    full = torch.ones((i, j), device = device, dtype = torch.bool)
    return torch.triu(full, diagonal = j - i + 1)
def onnx_create_causal_mask(i, j, device):
    """Export-friendly equivalent of `create_causal_mask` that avoids `.triu`.

    Builds the strictly-upper (i, i) comparison mask, then left-pads j - i False
    columns (keys preceding the first query are always visible).
    """
    idx = torch.arange(i, device = device)
    future = idx[:, None] < idx[None, :]
    return F.pad(future, (j - i, 0), value = False)
# main class
class Attend(Module):
    """Attention kernel: turns (q, k, v) into an attention output.

    Two code paths:
      * `flash = True` -> `flash_attn`, built on `F.scaled_dot_product_attention`
        run inside the configured SDPA backend context manager.
      * otherwise      -> the explicit einsum path in `forward`, which additionally
        supports talking heads, sparse top-k / sigmoid / hard / custom attention
        functions, logit soft-clamping, selective attention, CoPE and a residual
        `prev_attn`.

    The constructor asserts against flash being combined with variants it cannot run.
    Both paths return `(out, Intermediates)`; the flash path returns an empty one.
    """
    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        pre_talking_heads = False,
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        scale = None,
        qk_norm = False,
        l2_distance = False,
        sigmoid = False,
        custom_attn_fn: Callable | None = None,
        flash = False,
        softclamp_logits = False,
        logit_softclamp_value = 50.,
        add_zero_kv = False,
        selective = False,
        hard = False,
        cope = None,
        onnxable = False,
        sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        # NOTE: `sdp_kwargs` is a shared mutable default; it is only read here, never mutated
        super().__init__()
        # explicit logit scale; None means q.shape[-1] ** -0.5 (see forward / flash_attn)
        self.scale = scale
        # causal related
        self.causal = causal
        # the onnx variant avoids `.triu`
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
        # attention type
        is_sparse_topk_attn = exists(sparse_topk)
        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
        assert not (flash and hard), 'hard attention not available for flash'
        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
        # choose the logits -> weights function used by the einsum path;
        # a custom_attn_fn takes precedence over every flag
        if exists(custom_attn_fn):
            self.attn_fn = custom_attn_fn
        elif sigmoid:
            self.attn_fn = F.sigmoid
        elif hard:
            self.attn_fn = one_hot_straight_through
        elif is_sparse_topk_attn:
            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
        else:
            softmax_fn = partial(F.softmax, dim = -1)
            # softmax runs in float32 unless qk_norm is set; forward casts the result back to the logits' dtype
            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
        # dropouts
        # `self.dropout` is the probability handed to SDPA (flash path); `attn_dropout` is used by the einsum path
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        # talking heads
        # 1x1 convs that mix across the head dimension (requires `heads`); dirac init makes each conv start as identity
        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
        if exists(self.pre_softmax_talking_heads):
            nn.init.dirac_(self.pre_softmax_talking_heads.weight)
        if exists(self.post_softmax_talking_heads):
            nn.init.dirac_(self.post_softmax_talking_heads.weight)
        if exists(self.pre_scale_post_talking_heads):
            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
        # selective attention
        assert not (flash and selective), 'selective attention cannot work on flash attention'
        assert not (selective and not causal), 'selective attention is designed for autoregressive'
        self.selective = selective
        # l2 distance attention
        self.l2_distance = l2_distance
        # add a key / value token composed of zeros
        # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
        self.add_zero_kv = add_zero_kv
        # soft clamp attention logit value
        if softclamp_logits:
            assert not flash, 'flash attention not compatible with logit softclamp value yet'
            assert logit_softclamp_value > 0.
        self.softclamp_logits = softclamp_logits
        self.logit_softclamp_value = logit_softclamp_value
        # contextual positional encoding
        self.cope = cope
        # flash attention
        self.flash = flash
        torch_version = version.parse(torch.__version__)
        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
        # torch 2.3 uses new backend and context manager
        # each truthy `enable_*` entry of sdp_kwargs is mapped to the matching SDPBackend
        if torch_version >= version.parse('2.3'):
            from torch.nn.attention import SDPBackend
            str_to_backend = dict(
                enable_flash = SDPBackend.FLASH_ATTENTION,
                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
                enable_math = SDPBackend.MATH,
                enable_cudnn = SDPBackend.CUDNN_ATTENTION
            )
            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
        else:
            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        """Attention through F.scaled_dot_product_attention inside `self.sdp_context_manager()`.

        The causal mask, the (4-dim) boolean key mask and any additive `attn_bias` are
        merged into a single `attn_mask` where needed. Output rows whose mask is
        entirely False are zeroed.

        Returns:
            (out, Intermediates()) -- no intermediates are collected on this path.
        """
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])
        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
        # handle maybe l2 distance
        # q' = [2q, |q|^2, -1] and k' = [k, -1, |k|^2], so q'.k' = -|q - k|^2 and SDPA
        # computes negative squared L2 distance logits
        if self.l2_distance:
            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
            k = F.pad(k, (0, 1), value = -1.)
            k = torch.cat((k, k_norm_sq), dim = -1)
            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
            q = torch.cat((2 * q, q_norm_sq), dim = -1)
            q = F.pad(q, (0, 1), value = -1.)
        # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
        # SDPA always multiplies by q.shape[-1] ** -0.5, so pre-scale q to obtain self.scale overall
        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)
        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L
        causal = self.causal
        # in the case of kv caching with one token (q_len == 1), just turn off causal masking
        # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
        if q_len == 1 and causal:
            causal = False
        # expand key padding mask
        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)
        # handle kv cache - this should be bypassable in updated flash attention 2
        # SDPA's is_causal is top-left aligned, so with more keys than queries build a right-aligned mask instead
        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False
        # manually handle causal mask, if another mask was given
        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            mask = mask & ~causal_mask
            causal = False
        # protect against an entire row being masked out
        row_is_entirely_masked = None
        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)
        # handle alibi positional bias
        # convert from bool to float
        if exists(attn_bias):
            attn_bias = attn_bias.expand(batch, heads, -1, -1)
            # if mask given, the mask would already contain the causal mask from above logic
            # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
            mask_value = -torch.finfo(q.dtype).max
            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device = device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False
            # scaled_dot_product_attention handles attn_mask either as bool or additive bias
            # make it an additive bias here
            mask = attn_bias
        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        with self.sdp_context_manager():
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )
        # for a row that is entirely masked out, should zero out the output of that row token
        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
        return out, Intermediates()
    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        Args:
            q: queries, used as (b, h, i, d).
            k, v: keys / values with k.shape[1] kv heads; a single kv head is squeezed
                and fewer kv heads than query heads are repeated to match.
            mask: optional boolean mask, True = attend; a 2-dim (b, j) key padding
                mask is reshaped to (b, 1, 1, j).
            attn_bias: optional additive bias on the logits (e.g. positional bias).
            prev_attn: optional residual logits added to the scaled similarities
                (not supported on the flash path).

        Returns:
            (out, Intermediates) with out shaped (b, h, i, d).
        """
        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
        scale = default(self.scale, q.shape[-1] ** -0.5)
        causal = self.causal
        # handle key padding mask
        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        # handle kv cached decoding
        if n == 1 and causal:
            causal = False
        # handle grouped multi-query attention
        if kv_heads == 1:
            k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
        elif kv_heads < heads:
            k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))
        # handle zero kv, as means for allowing network to attend to nothing
        # a zero key/value is prepended; mask and bias are padded so that slot is visible with zero bias
        if self.add_zero_kv:
            k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))
            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)
            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
        # squeezed (3-dim) keys / values share one head across all query heads
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
        if not self.l2_distance:
            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
        else:
            sim = -qk_l2_dist_squared(q, k)
        sim = sim * scale
        if exists(prev_attn):
            sim = sim + prev_attn
        qk_similarities = sim.clone()
        # computed from the scaled similarities before bias / masking; multiplied into post-softmax attention below
        if exists(self.pre_scale_post_talking_heads):
            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
        if exists(self.pre_softmax_talking_heads):
            sim = sim + self.pre_softmax_talking_heads(sim)
        if exists(attn_bias):
            sim = sim + attn_bias
        if self.softclamp_logits:
            sim = softclamp(sim, self.logit_softclamp_value)
        i, j, dtype = *sim.shape[-2:], sim.dtype
        mask_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)
        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            sim = sim.masked_fill(causal_mask, mask_value)
        row_is_entirely_masked = None
        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)
        if exists(self.cope):
            sim = sim + self.cope(q, sim)
        if self.selective:
            sim = selective_attn(sim)
        pre_softmax_attn = sim
        attn = self.attn_fn(sim)
        # undo the float32 softmax upcast
        attn = attn.type(dtype)
        post_softmax_attn = attn
        attn = self.attn_dropout(attn)
        if exists(self.post_softmax_talking_heads):
            attn = self.post_softmax_talking_heads(attn)
        if exists(self.pre_scale_post_talking_heads):
            attn = attn * pre_to_post_scale
        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )
        # rows that could attend to nothing produce zeros rather than a uniform average
        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
        return out, intermediates
#=================================================================================================================================
# x_transformers.py
#=================================================================================================================================
from typing import Callable
import math
from copy import deepcopy
from random import random, randrange
from packaging import version
import torch
from torch.amp import autocast
import torch.nn.functional as F
from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.nn import Module, ModuleList, ModuleDict
from functools import partial, wraps
from collections import namedtuple
from contextlib import nullcontext
from dataclasses import dataclass
import einx
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
# einstein notation
# b - batch
# n - sequence
# d - feature dimension
# h - attention heads
# i, j - sequence (source, target)
# constants
# default attention head dimension
# NOTE(review): its consumers are outside this section -- confirm how it is used
DEFAULT_DIM_HEAD = 64
@dataclass
class LayerIntermediates:
    """Optional intermediate results gathered across a stack of layers.

    Every field defaults to None; which fields are populated is decided by code
    outside this section of the file.
    """
    hiddens: list[Tensor] | None = None # all hiddens, before the final norm (in pre-norm architecture)
    last_hidden: Tensor | None = None # very last hidden after all attention layers, after the final norm
    attn_intermediates: list[Intermediates] | None = None # one Intermediates per attention call
    layer_hiddens: list[Tensor] | None = None
    attn_z_loss: Tensor | None = None # NOTE(review): presumably produced by calc_z_loss -- confirm at the call site
    mems: Tensor | None = None
    memory_tokens: Tensor | None = None
    logit_entropies: Tensor | None = None
# nn.Linear factory with the bias disabled by default (callers may still pass bias = True)
LinearNoBias = partial(nn.Linear, bias = False)
# helpers
def exists(val):
    """True unless `val` is None."""
    return val is not None
def default(val, d):
    """Return `val` unless it is None; otherwise return `d`, calling it first if it is callable."""
    if val is None:
        return d() if callable(d) else d
    return val
def identity(t, *args, **kwargs):
    """Return the first argument unchanged; any extra arguments are ignored."""
    del args, kwargs
    return t
def first(it, default = None):
    """Return `it[0]`, or `default` when `it` is empty (needs len() and indexing)."""
    if len(it) == 0:
        return default
    return it[0]
def is_empty(x):
    """True when `x` has length zero."""
    return not len(x)
def cast_tuple(val, depth = 1):
    """Return `val` if it is already a tuple, else a tuple of `depth` copies of it."""
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(depth))
def divisible_by(num, den):
    """True when `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0
def maybe(fn = None):
    """Wrap `fn` so that a None first argument is passed straight through.

    With no `fn`, the wrapper behaves like `identity` on non-None input.
    """
    fn = identity if fn is None else fn

    @wraps(fn)
    def inner(x, *args, **kwargs):
        if x is None:
            return None
        return fn(x, *args, **kwargs)

    return inner
def at_most_one_of(*bools):
    """True when the integer values of the arguments sum to at most 1."""
    return sum(int(flag) for flag in bools) <= 1
class always:
    """Callable that ignores its arguments and always returns the stored value."""
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        del args, kwargs
        return self.val
class not_equals:
    """Predicate object: calling it with `x` returns `x != val` (extra arguments ignored)."""
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        del args, kwargs
        return x != self.val
class equals:
    """Predicate object: calling it with `x` returns `x == val` (extra arguments ignored)."""
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        del args, kwargs
        return x == self.val
def Sequential(*modules):
    """Build an nn.Sequential from the given modules, silently skipping None entries."""
    kept = [module for module in modules if module is not None]
    return nn.Sequential(*kept)
# tensor helpers
def log(t, eps = 1e-20):
    """Natural log with inputs floored at `eps`, so zeros do not produce -inf."""
    return torch.log(torch.clamp(t, min = eps))
def max_neg_value(tensor):
    """Most negative finite value representable in `tensor`'s dtype."""
    info = torch.finfo(tensor.dtype)
    return -info.max
def l2norm(t, groups = 1):
    """L2-normalize the last dim of `t` independently within each of `groups` equal chunks."""
    grouped = rearrange(t, '... (g d) -> ... g d', g = groups)
    unit = F.normalize(grouped, p = 2, dim = -1)
    return rearrange(unit, '... g d -> ... (g d)')
def softclamp(t, value):
    """Smoothly bound `t` to (-value, value) as value * tanh(t / value)."""
    scaled = t / value
    return scaled.tanh() * value
def masked_mean(t, mask = None, dim = 1):
    """Mean of `t` along `dim`, counting only the positions selected by `mask`.

    `mask` is right-padded with singleton dims so it broadcasts against `t`; the
    per-row count is clamped to at least 1, so fully masked rows give 0.
    """
    if mask is None:
        return t.mean(dim = dim)
    extra_dims = t.ndim - mask.ndim
    mask = mask.reshape(*mask.shape, *([1] * extra_dims))
    total = (t * mask).sum(dim = dim)
    count = mask.sum(dim = dim).clamp(min = 1.)
    return total / count
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    """Constant-pad only dimension `dim` of `t` by (before, after) = `pad`."""
    if pad == (0, 0):
        return t
    # F.pad lists pads starting from the last dim, so skip every dim after `dim`
    trailing = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (0, 0) * trailing + tuple(pad), value = value)
def or_reduce(masks):
    """Elementwise OR over a non-empty iterable of masks."""
    acc, *remaining = masks
    for other in remaining:
        acc = acc | other
    return acc
# entropy
def calc_entropy(
    t: Tensor,
    is_prob = False
):
    """Shannon entropy (natural log) along the last dim.

    `t` is treated as logits and softmaxed first unless `is_prob` is True.
    """
    prob = t if is_prob else t.softmax(dim = -1)
    # floor at 1e-20 before the log so that 0 * log(0) stays finite
    log_prob = prob.clamp(min = 1e-20).log()
    return -(prob * log_prob).sum(dim = -1)
# auxiliary loss helpers
def calc_z_loss(
    pre_softmax_attns: list[Tensor],
    mask = None,
    weight = 1.
):
    """Z-loss on attention logits: square of the (layer-summed) logsumexp.

    The same loss is applied to the mixture-of-experts router logits in
    https://arxiv.org/abs/2202.08906; a footnote there mentions using it on attention
    logits for stability, and PaLM uses it as well.

    Args:
        pre_softmax_attns: logits, each reduced as (b, h, n, j) -> logsumexp over j.
        mask: optional (b, n) boolean mask of positions to average over.
        weight: multiplier on the final loss.
    """
    lse = 0.
    for logits in pre_softmax_attns:
        lse = lse + logits.logsumexp(dim = -1)
    per_token = reduce(torch.square(lse), 'b h n -> b n', 'sum')
    if mask is None:
        return per_token.mean() * weight
    masked_total = per_token[mask].sum()
    return masked_total / mask.sum().clamp(min = 1e-5) * weight
# init helpers
def init_zero_(layer):
    """Set a layer's weight, and its bias when present, to zero in place."""
    nn.init.constant_(layer.weight, 0.)
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0.)
# keyword argument helpers
def pick_and_pop(keys, d):
    """Remove `keys` from dict `d` (in place) and return them, with their values, as a new dict."""
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
def group_dict_by_key(cond, d):
    """Split `d` into (entries whose key satisfies `cond`, all remaining entries)."""
    matched, unmatched = {}, {}
    for key, value in d.items():
        target = matched if cond(key) else unmatched
        target[key] = value
    return matched, unmatched
def string_begins_with(prefix, str):
    """True when `str` starts with `prefix`; the argument order suits functools.partial."""
    begins = str.startswith(prefix)
    return begins
def group_by_key_prefix(prefix, d):
    """Split `d` into (entries whose key starts with `prefix`, the rest)."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)
def groupby_prefix_and_trim(prefix, d):
    """Split `d` by key prefix and strip the prefix from the matching keys.

    Returns (prefixed entries with the prefix removed, remaining entries).
    """
    with_prefix, rest = group_dict_by_key(partial(string_begins_with, prefix), d)
    cut = len(prefix)
    trimmed = {}
    for key, value in with_prefix.items():
        trimmed[key[cut:]] = value
    return trimmed, rest
# structured dropout, more effective than traditional attention dropouts
def dropout_seq(seq, mask, dropout):
    """Structured sequence dropout: keep a random subset of positions of every sequence.

    Args:
        seq: tensor whose first two dims are (batch, seq_len); positions are gathered along dim 1.
        mask: optional (batch, seq_len) boolean mask of valid positions (used with `~`), or None.
        dropout: fraction of positions to drop.

    Returns:
        (seq, mask) with dim 1 reduced to max(1, int((1 - dropout) * seq_len)).
        NOTE: kept positions come out in `topk` order over random scores, not in
        their original sequence order.
    """
    b, n, *_, device = *seq.shape, seq.device
    # one random score per position; the highest-scoring positions are kept
    logits = torch.randn(b, n, device = device)
    if exists(mask):
        # invalid positions get the lowest possible score so they are selected last
        mask_value = max_neg_value(logits)
        logits = logits.masked_fill(~mask, mask_value)
    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim = 1).indices
    batch_indices = arange(b, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')
    seq = seq[batch_indices, keep_indices]
    if exists(mask):
        # each row keeps ceil(valid_count * keep_prob) valid slots; valid positions
        # rank ahead of invalid ones, so those slots hold randomly chosen valid tokens
        seq_counts = mask.sum(dim = -1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
        mask = mask[batch_indices, keep_indices] & keep_mask
    return seq, mask
# activations
class ReluSquared(Module):
    """Squared ReLU activation: relu(x) ** 2."""
    def forward(self, x):
        positive = F.relu(x)
        return positive ** 2
# embedding
class TokenEmbedding(Module):
    """Token-id embedding table, optionally L2-normalizing each output vector."""
    def __init__(self, dim, num_tokens, l2norm_embed = False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        # ids are cast to long so integer tensors of any width are accepted
        embedded = self.emb(x.long())
        if self.l2norm_embed:
            return l2norm(embedded)
        return embedded

    def init_(self):
        """Re-initialize the table: Kaiming normal, or a tiny normal (std 1e-5) when L2-normalizing."""
        if not self.l2norm_embed:
            nn.init.kaiming_normal_(self.emb.weight)
            return
        nn.init.normal_(self.emb.weight, std=1e-5)
# positional embeddings
class AbsolutePositionalEmbedding(Module):
    """Learned absolute position embedding, scaled by dim ** -0.5 unless L2-normalized."""
    def __init__(self, dim, max_seq_len, l2norm_embed = False):
        super().__init__()
        self.scale = 1. if l2norm_embed else dim ** -0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Embed positions for x.shape[1] tokens (or the explicit `pos`).

        `seq_start_pos` shifts each sequence's positions back by its start offset,
        clamping to 0; the result then carries a batch dim.
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
        if pos is None:
            pos = arange(seq_len, device = x.device)
        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
        scaled = self.emb(pos) * self.scale
        return l2norm(scaled) if self.l2norm_embed else scaled
class ScaledSinusoidalEmbedding(Module):
    """Fixed sinusoidal position embedding with one learned scale (initialized to dim ** -0.5).

    The last dim of the output has size `dim`: sines in the first half, cosines in the second.
    """
    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert dim % 2 == 0
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
        half_dim = dim // 2
        freq_seq = arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """
        Args:
            x: only x.shape[1] (sequence length) and x.device are read.
            pos: optional positions; defaults to arange(seq_len).
            seq_start_pos: optional per-batch start offsets subtracted from the
                positions, which yields a batched (b, n) position tensor.

        Returns:
            (n, dim) for unbatched positions, (b, n, dim) for batched ones.
        """
        seq_len, device = x.shape[1], x.device
        if pos is None:
            pos = arange(seq_len, device = device)
        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]
        # '... i' instead of 'i' so the batched positions produced by seq_start_pos
        # work; the old 'i, j -> i j' equation raised whenever pos had a batch dim
        emb = einsum('... i, j -> ... i j', pos.type_as(self.inv_freq), self.inv_freq)
        emb = cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale
class RelativePositionBias(Module):
    """Bucketed relative position bias: a learned per-head scalar per distance bucket.

    forward(i, j) returns a (heads, i, j) bias for i queries right-aligned against j
    keys, multiplied by `scale`.
    """
    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        """Map relative positions (key - query) to bucket ids in [0, num_buckets).

        Small distances get their own bucket; larger ones share log-spaced buckets
        up to `max_distance`. Non-causal mode splits the buckets by direction.
        """
        dist = -relative_position
        bucket = 0
        if causal:
            # future keys (negative distance) all collapse onto distance 0
            dist = torch.max(dist, torch.zeros_like(dist))
        else:
            num_buckets //= 2
            bucket += (dist < 0).long() * num_buckets
            dist = torch.abs(dist)
        exact_limit = num_buckets // 2
        log_ratio = torch.log(dist.float() / exact_limit) / math.log(max_distance / exact_limit)
        large_bucket = exact_limit + (log_ratio * (num_buckets - exact_limit)).long()
        large_bucket = torch.min(large_bucket, torch.full_like(large_bucket, num_buckets - 1))
        bucket += torch.where(dist < exact_limit, dist, large_bucket)
        return bucket

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        device = self.device
        q_pos = arange(j - i, j, dtype = torch.long, device = device)
        k_pos = arange(j, dtype = torch.long, device = device)
        # rel_pos[a, b] = key position b minus query position a
        rel_pos = k_pos[None, :] - q_pos[:, None]
        buckets = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        per_head = self.relative_attention_bias(buckets)
        # (i, j, heads) -> (heads, i, j)
        return per_head.permute(2, 0, 1) * self.scale
class CoPE(Module):
    """
    Contextual positional encoding, Appendix B of https://arxiv.org/abs/2405.18719

    Positions are the reverse cumulative sum of sigmoid gates over the attention
    logits, clamped to `max_pos - 1`. The returned (b, h, i, j) tensor is the query's
    affinity with a learned embedding of those (fractional) positions, either
    linearly interpolated between integer positions or, with `soft_onehot`,
    softly assigned to them with temperature `soft_onehot_temp`.
    """
    def __init__ (
        self,
        dim,
        heads,
        max_pos,
        soft_onehot = False,
        talking_heads = False,
        soft_onehot_temp = 5e-2
    ):
        super().__init__()
        self.max_pos = max_pos
        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
        self.soft_onehot = soft_onehot
        self.soft_onehot_temp = soft_onehot_temp
        if not soft_onehot:
            return
        self.register_buffer('positions', arange(max_pos))

    def forward(self, query, attn_logits):
        """
        Args:
            query: (b, h, n, d) queries (fixed by the einsum below).
            attn_logits: (b, h, i, j) attention logits, with i == n.

        Returns:
            (b, h, i, j) positional logits to be added to the attention logits.
        """
        if self.talking_heads is not None:
            i, j = attn_logits.shape[-2:]
            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
            attn_logits = self.talking_heads(attn_logits)
            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
        # compute positions: how many "gated" keys lie between each key and the query
        gates = attn_logits.sigmoid()
        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
        pos = pos.clamp(max = self.max_pos - 1)
        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
        if self.soft_onehot:
            # distance of every fractional position to every integer position: (b, h, i, j, p).
            # The previous einx pattern 'i, j -> i j' could not accept the 4-dim `pos`.
            diff_pos = (pos[..., None] - self.positions).abs()
            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
        else:
            # interpolate from integer positions
            pos_ceil = pos.ceil().long()
            pos_floor = pos.floor().long()
            logits_ceil = logits_int.gather(-1, pos_ceil)
            logits_floor = logits_int.gather(-1, pos_floor)
            w = pos - pos_floor
            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
        return cope_pos_emb
class DynamicPositionBias(Module):
    """Relative position bias produced by a small MLP over (optionally log-scaled) distances.

    forward(i, j) returns a (heads, i, j) bias for i queries right-aligned against j keys.
    """
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance
        # NOTE(review): the input layer uses `LayerNorm` while hidden layers use `nn.LayerNorm`;
        # `LayerNorm` is not defined in this part of the file -- confirm it exists elsewhere.
        layers = [Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim) if norm else None,
            nn.SiLU()
        )]
        for _ in range(depth - 1):
            layers.append(Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))
        layers.append(nn.Linear(dim, heads))
        self.mlp = ModuleList(layers)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        device = self.device
        # query-minus-key distances, shifted by j - 1 so they index the 2j - 1 evaluated distances
        query_positions = arange(j - i, j, device = device)
        key_positions = arange(j, device = device)
        indices = query_positions[:, None] - key_positions[None, :] + (j - 1)
        # every possible distance -(j - 1) .. (j - 1) as a column of MLP inputs
        rel = arange(-j + 1, j, device = device).float()[:, None]
        if self.log_distance:
            rel = torch.sign(rel) * torch.log(rel.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
        for layer in self.mlp:
            rel = layer(rel)
        return rearrange(rel[indices], 'i j h -> h i j')
class AlibiPositionalBias(Module):
    """ALiBi linear distance bias: bias[h, i, j] = -|pos_j - pos_i| * slope[h].

    Only the first `heads` heads receive a bias; the output is zero-padded on the
    head dim up to `total_heads`.
    """
    def __init__(
        self,
        heads,
        total_heads = None,
        slopes: list[float] | None = None,
        **kwargs
    ):
        super().__init__()
        self.heads = heads
        self.total_heads = default(total_heads, heads)
        # NOTE: the default slopes are computed even when explicit `slopes` are given
        slopes = Tensor(default(slopes, self._get_slopes(heads)))
        slopes = rearrange(slopes, 'h -> h 1 1')
        self.register_buffer('slopes', slopes, persistent = False)
        # cache of the last full bias computed by `forward`; None until first use
        self.register_buffer('bias', None, persistent = False)
    @property
    def device(self):
        # buffers() skips None entries, so this yields the device of `slopes`
        return next(self.buffers()).device
    @staticmethod
    def _get_slopes(heads):
        """Geometric slopes from the ALiBi paper; non-power-of-2 head counts borrow
        every other slope of the next power of 2."""
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]
        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
    def forward_custom_pos(
        self,
        pos_i: Tensor,
        pos_j: Tensor | None = None
    ):
        """Bias from explicit positions (pos_j defaults to pos_i); not cached.

        A 3-dim (b, i, j) distance tensor gets a singleton head dim before the slopes apply.
        """
        h, device = self.total_heads, self.device
        pos_j = default(pos_j, pos_i)
        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
        if bias.ndim == 3:
            bias = rearrange(bias, 'b i j -> b 1 i j')
        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
        return bias
    def forward(self, i, j):
        """(total_heads, i, j) bias for i queries right-aligned against j keys, with caching."""
        h, device = self.total_heads, self.device
        # the bias depends only on key - query distance, so the bottom-right corner of a
        # larger cached bias is exactly the bias for a smaller (i, j)
        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]
        seq_arange = arange(j - i, j, device = device)
        context_arange = arange(j, device = device)
        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
        # re-registering replaces the cached buffer (kept non-persistent, so not saved)
        self.register_buffer('bias', bias, persistent = False)
        return self.bias
class DataDependentAlibi(Module):
    """ https://openreview.net/forum?id=q2Lnyegkr8

    Forgetting-transformer style bias: each head predicts a per-token log forget gate
    (log-sigmoid of a linear projection of x); the bias between positions i and j is
    the sum of those log gates between them.
    """
    def __init__(
        self,
        dim,
        heads,
        causal = True,
        bias_init = 5.,
        post_log_scale = 1.,
    ):
        super().__init__()
        self.causal = causal
        # bidirectional mode predicts a second set of gates for the reverse direction
        linear = nn.Linear(dim, heads * (1 if causal else 2))
        self.to_forget_gates = nn.Sequential(
            linear,
            Rearrange('b n h -> b h n'),
            nn.LogSigmoid()
        )
        # a large positive bias starts the gates near sigmoid(5), i.e. little forgetting
        nn.init.constant_(linear.bias, bias_init)
        self.post_log_scale = post_log_scale
    def forward(self, x):
        """x: (b, n, dim) -> (b, heads, n, n) additive bias."""
        bidirectional = not self.causal
        forget_gates = self.to_forget_gates(x) * self.post_log_scale
        forget_gates = forget_gates.cumsum(dim = -1)
        if bidirectional:
            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
        # out[i, j] = cumsum[i] - cumsum[j]: the log gates of positions j+1 .. i
        # NOTE(review): in causal mode the upper triangle is left unmasked here --
        # presumably the causal attention mask hides it
        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
        if bidirectional:
            # mirror image for keys after the query, taken from the second gate set
            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
        return forget_gates
class PerRowDataDependentAlibi(Module):
    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are derived from queries and keys with a small head dimension

    out[b, h, i, j] = post_log_scale * sum over k in [j, i) of logsigmoid(q_i . k_k * scale),
    and 0 for j >= i.
    """
    def __init__(
        self,
        dim,
        heads,
        causal = True,
        dim_head = 8,
        post_log_scale = 1.
    ):
        super().__init__()
        assert causal, 'bidirectional not supported yet'
        self.scale = dim_head ** -0.5
        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
        self.to_forget_gates = nn.Sequential(
            linear,
            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
        )
        self.post_log_scale = post_log_scale
    def forward(self, x):
        """x: (b, n, dim) -> (b, heads, n, n) additive bias."""
        q, k = self.to_forget_gates(x)
        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
        # mask out upper triangle + diagonal
        n = x.shape[-2]
        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
        # reverse cumsum
        # along the key dim: out[i, j] = sum of the (unmasked) gates at keys >= j
        forget_gates = forget_gates.flip(dims = (-1,))
        forget_gates = forget_gates.cumsum(dim = -1)
        forget_gates = forget_gates.flip(dims = (-1,))
        return forget_gates
class RotaryEmbedding(Module):
    """
    Generates rotary positional embedding (RoPE) angles, optionally with xpos length-extrapolation scales.

    Args:
        dim: number of rotated features (must be even)
        use_xpos: also return per-position xpos scales
        scale_base: xpos scale base (positions are divided by it before exponentiation)
        interpolation_factor: positions are divided by this factor (>= 1) - position interpolation
        base: rotary frequency base
        base_rescale_factor: NTK-aware rescaling of `base` for longer contexts
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        even_dims = arange(0, dim, 2)
        self.register_buffer('inv_freq', 1. / (base ** (even_dims.float() / dim)))

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if use_xpos:
            self.scale_base = scale_base
            self.register_buffer('scale', (even_dims + 0.4 * dim) / (1.4 * dim))
        else:
            self.register_buffer('scale', None)

    def forward_from_seq_len(self, seq_len):
        """Embeddings for positions 0 .. seq_len - 1."""
        positions = arange(seq_len, device = self.inv_freq.device)
        return self.forward(positions)

    @autocast('cuda', enabled = False)
    def forward(self, t):
        """
        Return (freqs, scale) for positions `t` ((n,) or (b, n)).

        freqs has shape (b, n, dim) with each frequency repeated for its feature pair; scale is the
        float 1. without xpos, else a tensor of the same shape as freqs.
        """
        max_pos = t.max() + 1

        if t.ndim == 1:
            t = rearrange(t, 'n -> 1 n')

        angles = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
        angles = angles / self.interpolation_factor
        freqs = rearrange(stack((angles, angles), dim = -1), '... d r -> ... (d r)')

        if not exists(self.scale):
            return freqs, 1.

        exponent = rearrange((t - (max_pos // 2)) / self.scale_base, '... n -> ... n 1')
        scale = self.scale ** exponent
        scale = rearrange(stack((scale, scale), dim = -1), '... d r -> ... (d r)')
        return freqs, scale
def rotate_half(x):
    """
    Rotate every adjacent feature pair of the last dimension by 90 degrees: (a, b) -> (-b, a).

    The last dimension must have even size.
    """
    pairs = x.unflatten(-1, (-1, 2))
    evens, odds = pairs.unbind(dim = -1)
    return torch.stack((-odds, evens), dim = -1).flatten(-2)
@autocast('cuda', enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """
    Apply rotary angles `freqs` (and optional xpos `scale`) to the first `rot_dim` features of `t`.

    Args:
        t: queries or keys; the sequence axis is t.shape[-2]. When `t` is 4-D (b, h, n, d), the
            3-D `freqs` / `scale` are given a singleton head axis so they broadcast over heads.
        freqs: (b, n_total, rot_dim) angles, as produced by RotaryEmbedding.forward. Only the last
            t.shape[-2] positions are used.
        scale: the scalar 1 / 1., or an xpos scale tensor shaped like `freqs` (also right-aligned).

    Returns:
        tensor with t's shape and dtype; features past `rot_dim` pass through unrotated.
    """
    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

    # right-align the embeddings with t's sequence (e.g. when decoding with a kv cache)
    freqs = freqs[:, -seq_len:, :]
    scale_is_tensor = isinstance(scale, torch.Tensor)
    if scale_is_tensor:
        scale = scale[:, -seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')
        # the xpos scale needs the same head axis, otherwise its batch axis would broadcast against heads
        if scale_is_tensor and scale.ndim == 3:
            scale = rearrange(scale, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = cat((t, t_unrotated), dim = -1)

    return out.type(orig_dtype)
# norms
class Scale(Module):
    """
    Wrap `fn` and multiply its output by the constant `value`.

    If `fn` returns a tuple, only its first element is scaled; the remaining elements pass through.
    """

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)

        if isinstance(out, tuple):
            head, *tail = out
            return (head * self.value, *tail)

        return out * self.value
class LayerNorm(Module):
    """
    Bias-less layernorm with a learned per-feature gain.

    Bias-less layernorm has been shown to be more stable; most newer models have moved towards rmsnorm,
    which is also bias-less.

    Args:
        dim: feature dimension
        unit_offset: store the gain as an offset from 1 (initialised to 0) instead of directly (initialised to 1)
    """

    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        super().__init__()
        self.unit_offset = unit_offset

        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.full((dim,), 1. - float(unit_offset)))

    def forward(self, x):
        effective_gamma = self.gamma + float(self.unit_offset)
        return self.ln(x) * effective_gamma
class AdaptiveLayerNorm(Module):
    """
    Affine-free layernorm whose per-feature gain is predicted from a conditioning tensor.

    The output is ln(x) * (to_gamma(condition) + 1). Because the projection weight is zero-initialised,
    the gain starts at 1 (assuming LinearNoBias adds no bias - not visible here).

    Args:
        dim: feature dimension of x
        dim_condition: feature dimension of the condition (defaults to dim)
    """

    def __init__(
        self,
        dim,
        dim_condition = None
    ):
        super().__init__()
        dim_condition = default(dim_condition, dim)

        self.ln = nn.LayerNorm(dim, elementwise_affine = False)

        self.to_gamma = LinearNoBias(dim_condition, dim)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, *, condition):
        # a per-sample (b, d) condition gets a singleton sequence axis to broadcast over tokens
        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        return self.ln(x) * (self.to_gamma(condition) + 1.)
class ScaleNorm(Module):
    """
    ScaleNorm: L2-normalise the last dimension, then multiply by sqrt(dim) and one learned scalar gain.

    Args:
        dim: feature dimension
        unit_offset: store the gain as an offset from 1 (initialised to 0) instead of directly (initialised to 1)
    """

    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5

        self.g = nn.Parameter(torch.full((1,), 1. - float(unit_offset)))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * (self.g + float(self.unit_offset))
class RMSNorm(Module):
    """
    RMSNorm computed as L2 normalisation times sqrt(dim), with a learned per-feature gain.

    Args:
        dim: feature dimension
        unit_offset: store the gain as an offset from 1 (initialised to 0) instead of directly (initialised to 1)
    """

    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        super().__init__()
        self.unit_offset = unit_offset
        self.scale = dim ** 0.5

        self.g = nn.Parameter(torch.full((dim,), 1. - float(unit_offset)))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * (self.g + float(self.unit_offset))
class AdaptiveRMSNorm(Module):
    """
    RMSNorm (L2 normalise * sqrt(dim)) whose per-feature gain is predicted from a conditioning tensor.

    The output is normalize(x) * sqrt(dim) * (to_gamma(condition) + 1), with a zero-initialised
    projection weight.

    Args:
        dim: feature dimension of x
        dim_condition: feature dimension of the condition (defaults to dim)
    """

    def __init__(
        self,
        dim,
        dim_condition = None
    ):
        super().__init__()
        self.scale = dim ** 0.5
        dim_condition = default(dim_condition, dim)

        self.to_gamma = LinearNoBias(dim_condition, dim)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, *, condition):
        # a per-sample (b, d) condition gets a singleton sequence axis to broadcast over tokens
        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        unit = F.normalize(x, dim = -1)
        gain = self.to_gamma(condition) + 1.
        return unit * self.scale * gain
class SimpleRMSNorm(Module):
    """
    Parameter-free RMSNorm: L2-normalise the last dimension and multiply by sqrt(dim).

    Extra keyword arguments are accepted and ignored, so it can be swapped in for the learnable norms.
    """

    def __init__(
        self,
        dim,
        **kwargs
    ):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale
class MultiheadRMSNorm(Module):
    """
    Parameter-free RMSNorm followed by a learned gain per head.

    The gain is stored zero-centred with shape (heads, 1, dim) and applied as (gamma + 1). It
    presumably expects inputs shaped (..., heads, seq, dim), such as the (b, h, n, d) attention
    output - verify at call sites.
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        normed = self.rmsnorm(x)
        return normed * (self.gamma + 1.)
class DynamicTanh(Module):
    """
    Dynamic Tanh (DyT), a normalisation-free replacement for layernorm - https://arxiv.org/abs/2503.10622

    Computes tanh(alpha * x) * gamma + beta, with a learned scalar alpha and per-feature gamma and beta.

    Args:
        dim: feature dimension (size of gamma and beta)
        init_alpha: initial value of the scalar pre-tanh scale alpha
        gamma: initial value of every element of the output gain
        beta: initial value of every element of the output bias
        unit_offset: store alpha as an offset from init_alpha and gamma as an offset from 1, so that
            both stored parameters start at or near zero (the offsets are added back in forward)
    """

    def __init__(
        self,
        dim,
        init_alpha = 1.,
        gamma = 1.,
        beta = 0.,
        unit_offset = False
    ):
        super().__init__()
        # float() so an integer init_alpha still produces a floating point (trainable) parameter
        self.pre_tanh_scale = nn.Parameter(torch.tensor(float(init_alpha)))

        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

        self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
        self.gamma_offset = float(unit_offset)

        nn.init.constant_(self.pre_tanh_scale, 0 if unit_offset else init_alpha)

        # honour the requested initial gain / bias; the stored gain is offset by gamma_offset
        nn.init.constant_(self.gamma, gamma - self.gamma_offset)
        nn.init.constant_(self.beta, beta)

    def forward(self, x):
        pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
        gamma = self.gamma + self.gamma_offset
        return (x * pre_tanh_scale).tanh() * gamma + self.beta
# residual and residual gates
class Residual(Module):
    """
    Plain residual connection: branch output `x` plus the (optionally scaled) residual.

    Args:
        dim: feature dimension, used for the optional learned per-feature residual scale
        scale_residual: learn a per-feature multiplier for the residual (initialised to 1)
        scale_residual_constant: fixed multiplier for the residual (skipped when equal to 1)
        **kwargs: accepted and ignored
    """

    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def prepare(self, residual):
        # (branch input, residual to carry, extra forward kwargs) - a plain residual needs no preparation
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        scaled = residual

        if exists(self.residual_scale):
            scaled = scaled * self.residual_scale

        constant = self.scale_residual_constant
        if constant != 1:
            scaled = scaled * constant

        return x + scaled
class GRUGating(Module):
    """
    Residual connection gated by a GRU cell.

    The branch output `x` is the GRU input and the (optionally scaled) residual is its hidden state;
    both are (b, n, d) and are folded to (b * n, d) because GRUCell takes 2-D inputs.
    """

    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def prepare(self, residual):
        # (branch input, residual to carry, extra forward kwargs)
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        flat_input = rearrange(x, 'b n d -> (b n) d')
        flat_hidden = rearrange(residual, 'b n d -> (b n) d')

        return self.gru(flat_input, flat_hidden).reshape_as(x)
# hyper connections
class HyperConnection(Module):
    def __init__(
        self,
        dim,
        *,
        layer_index,
        num_residual_streams,
        num_input_views = 1,
        tanh = True,
        **kwargs
    ):
        """
        https://arxiv.org/abs/2409.19606
        Appendix J - Algorithm 2, Dynamic only

        Hyper-connection residual. The hidden state is carried as `num_residual_streams` parallel streams,
        which the caller folds into the batch axis as (b * s, n, d) (see `prepare`). Learned static
        weights plus small input-dependent (dynamic) corrections decide how the streams are mixed into
        the branch input(s) and the new streams (width connection). They also decide how much of the
        branch output each stream receives (depth connection, `forward`).

        Args:
            dim: feature dimension
            layer_index: index of this layer; selects which stream initially feeds the branch input
            num_residual_streams: number of parallel residual streams s
            num_input_views: number of branch inputs v produced by `prepare`
            tanh: apply tanh to the dynamic weight logits (identity otherwise)
            **kwargs: accepted and ignored
        """
        super().__init__()

        self.act = nn.Tanh() if tanh else nn.Identity()

        self.norm = nn.LayerNorm(dim, bias = False)

        self.num_residual_streams = num_residual_streams
        self.layer_index = layer_index

        # static depth weights: each stream starts by receiving the full branch output
        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        # static width weights, shape (s, v + s): the first v columns build the branch input(s) - initially
        # just stream (layer_index % s) - and the trailing identity carries every stream through unchanged
        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
        init_alpha0[layer_index % num_residual_streams, :] = 1.

        self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

        # dynamic corrections start at zero and are further damped by a small learned scale
        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)

        self.num_input_views = num_input_views

        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

    def prepare(self, residuals):
        """
        Split (b * s, n, d) residual streams into branch input(s) and the streams to carry forward.

        Returns:
            branch_input: (b, n, d) when num_input_views == 1, else (v, b, n, d)
            residuals: (b, n, s, d) mixed streams, passed back into `forward`
            dict(beta = beta): per-token depth weights (b, n, s) for `forward`
        """
        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)

        normed = self.norm(residuals)

        # width weights alpha: (b, n, s, v + s) = damped dynamic part + static part
        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
        alpha = dynamic_alpha + self.static_alpha

        # depth weights beta: (b, n, s)
        dc_weight = self.act(normed @ self.dynamic_beta_fn)
        dynamic_beta = dc_weight * self.dynamic_beta_scale
        beta = dynamic_beta + self.static_beta

        # width connection

        # mix the s streams into v + s outputs: (b, n, v + s, d)
        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)

        views = self.num_input_views

        if views == 1:
            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
        else:
            branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
            # move the view axis to the front so the views can be unpacked
            branch_input = rearrange(branch_input, '... v d -> v ... d')

        return branch_input, residuals, dict(beta = beta)

    def forward(self, x, residuals, *, beta):
        """Add the branch output x (b, n, d), weighted per stream by beta, to each stream; refold to (b * s, n, d)."""
        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
        return rearrange(residuals, 'b n s d -> (b s) n d')
# LIMe - layer integrated memory (dynamic version)
class DynamicLIMe(Module):
    """
    LIMe - layer integrated memory, dynamic version.

    Each token predicts, from its current hidden state, a set of weights over the outputs of all
    `num_layers` previous layers, then mixes those hidden states with them. With `num_views > 1`, several
    independent mixtures are produced (stacked on a leading axis).

    Args:
        dim: feature dimension
        num_layers: number of layer hidden states that will be mixed
        num_views: number of independent mixtures to produce
        norm: apply RMSNorm before predicting the weights
        use_softmax: normalise the weights with softmax over layers (ReLU otherwise)
    """

    def __init__(
        self,
        dim,
        num_layers,
        num_views = 1,
        norm = True,
        use_softmax = True
    ):
        super().__init__()
        self.num_layers = num_layers
        self.multiple_views = num_views > 1

        maybe_norm = RMSNorm(dim) if norm else None

        self.to_weights = Sequential(
            maybe_norm,
            nn.Linear(dim, num_views * num_layers),
            Rearrange('... (views layers) -> views ... layers', views = num_views),
            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
        )

    def forward(
        self,
        x,
        hiddens
    ):
        """
        Mix `hiddens` (layers, b, n, d) - or a sequence of (b, n, d) tensors - using weights predicted from x (b, n, d).

        Returns (b, n, d), or (views, b, n, d) when num_views > 1.
        """
        if not is_tensor(hiddens):
            hiddens = stack(hiddens)

        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'

        layer_weights = self.to_weights(x)
        mixed = einsum('l b n d, v b n l -> v b n d', hiddens, layer_weights)

        return mixed if self.multiple_views else rearrange(mixed, '1 ... -> ...')
# token shifting
def shift(t, amount, mask = None):
    """
    Shift `t` along its sequence axis by `amount` positions, filling vacated slots with zeros.

    Assumes t is (batch, seq, features), so dim 1 and dim -2 are the same axis - TODO confirm at call sites.
    A positive amount moves tokens to later positions, a negative amount to earlier ones. A shift whose
    magnitude is >= the sequence length yields all zeros.

    Args:
        t: tensor to shift
        amount: signed shift in positions
        mask: optional boolean (batch, seq) mask; masked-out (False) positions are zeroed before
            shifting, so they do not leak into neighbouring tokens
    """
    if amount == 0:
        return t

    # clamp in both directions: F.pad cannot crop more positions than the axis has, which a negative
    # shift longer than the sequence would otherwise request
    seq_len = t.shape[1]
    amount = max(-seq_len, min(amount, seq_len))

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    # (amount, -amount): pad one end and crop the other, keeping the length unchanged
    return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
class ShiftTokens(Module):
    """
    Token shifting applied before `fn`.

    The feature dimension is split into len(shifts) chunks of size dim // len(shifts), and chunk i is
    shifted along the sequence by shifts[i]. Leftover features (when dim is not divisible) are left
    unshifted. A `mask` keyword argument, if passed, is forwarded to the shifts as well as to `fn`.
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        num_segments = len(self.shifts)

        chunks = x.split(x.shape[-1] // num_segments, dim = -1)
        to_shift, untouched = chunks[:num_segments], chunks[num_segments:]

        shifted = [shift(chunk, amount, mask = mask) for chunk, amount in zip(to_shift, self.shifts)]

        x = cat((*shifted, *untouched), dim = -1)
        return self.fn(x, **kwargs)
class FoldAxially(Module):
    """
    Run `fn` on an axially folded sequence.

    The sequence axis (dim 1) is right-padded to a multiple of `axial_dim`, and tokens that share the
    same position modulo `axial_dim` are grouped into their own batch row:
    rearrange("b (n axial_dim) ... -> (b axial_dim) n ..."). The first leaf of fn's output is unfolded
    back and trimmed to the original length; any other leaves are returned untouched.
    """

    def __init__(
        self,
        axial_dim,
        fn: Module
    ):
        super().__init__()
        self.fn = fn
        self.axial_dim = axial_dim

    def forward(
        self,
        x,
        **kwargs
    ):
        axial_dim = self.axial_dim

        if axial_dim == 1:
            return self.fn(x, **kwargs)

        seq_len = x.shape[1]
        padded_len = math.ceil(seq_len / axial_dim) * axial_dim

        x = pad_at_dim(x, (0, padded_len - seq_len), dim = 1)
        x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)

        out = self.fn(x, **kwargs)

        (primary, *others), spec = tree_flatten(out)

        primary = rearrange(primary, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
        primary = primary[:, :seq_len]

        return tree_unflatten((primary, *others), spec)
# post branch operator
class LayerScale(Module):
    """
    Multiply the output of `fn` by a learned per-feature gain.

    If `fn` returns something other than a tensor, its first element is scaled and the rest are passed
    through (returned as a tuple).

    Args:
        fn: wrapped module
        dim: feature dimension of the gain
        init_value: initial gain
        unit_offset: store the gain as an offset from 1, so the stored parameter is init_value - 1
    """

    def __init__(
        self,
        fn: Module,
        dim,
        init_value = 0.,
        unit_offset = False
    ):
        super().__init__()
        self.unit_offset = unit_offset
        self.fn = fn
        self.gamma = nn.Parameter(torch.full((dim,), init_value - float(unit_offset)))

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        gain = self.gamma + float(self.unit_offset)

        if not isinstance(out, Tensor):
            first, *rest = out
            return first * gain, *rest

        return out * gain
class AdaptiveLayerScale(Module):
    """
    Gate the output of `fn` with a sigmoid gain predicted from a conditioning tensor (ada-ln-zero style).

    The projection weight is zero-initialised and its bias is set to `init_bias_value`, so the gate
    starts at sigmoid(init_bias_value) for every input. Non-tensor outputs of `fn` have only their
    first element gated.

    Args:
        fn: wrapped module
        dim: feature dimension of the gate
        dim_condition: feature dimension of the condition (defaults to dim)
        init_bias_value: initial pre-sigmoid gate value
    """

    def __init__(
        self,
        fn: Module,
        dim,
        dim_condition = None,
        init_bias_value = -2.
    ):
        super().__init__()
        self.fn = fn

        dim_condition = default(dim_condition, dim)
        self.to_gamma = nn.Linear(dim_condition, dim)

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)

    def forward(self, x, *, condition, **kwargs):
        # a per-sample (b, d) condition gets a singleton sequence axis to broadcast over tokens
        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        out = self.fn(x, **kwargs)
        gate = self.to_gamma(condition).sigmoid()

        if not isinstance(out, Tensor):
            first, *rest = out
            return first * gate, *rest

        return out * gate
# skip connection combining
class ConcatCombine(Module):
    """
    Combine the current hidden state with the output of an earlier layer, e.g. for skip connections.

    The earlier output `prev_layers[prev_layer_ind]` is concatenated before `x` along the feature axis,
    and the result is projected from 2 * dim back to dim.
    """

    def __init__(self, dim, prev_layer_ind):
        super().__init__()
        self.prev_layer_ind = prev_layer_ind
        self.combine = LinearNoBias(dim * 2, dim)

    def forward(self, x, prev_layers: list[Tensor]):
        skip = prev_layers[self.prev_layer_ind]
        return self.combine(cat((skip, x), dim = -1))
# feedforward
class GLU(Module):
    """
    Gated linear unit.

    Projects to 2 * dim_out, splits the result into value and gate halves, and returns
    value * activation(gate), optionally times a learned per-feature multiplier (`mult_bias`).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim = -1)
        gated = value * self.act(gate)
        return gated * self.mult_bias
class FeedForward(Module):
    """
    Transformer feedforward block.

    The input is projected to `dim * mult` features with an activation (optionally as a GLU), then
    optionally layernormed, passed through dropout, and projected to `dim_out`, with optional trailing
    dropout.

    Args:
        dim: input dimension
        dim_out: output dimension (defaults to dim)
        mult: hidden expansion factor
        glu: use a gated linear unit for the input projection
        glu_mult_bias: learned per-feature multiplier inside the GLU
        swish: use SiLU as the activation
        relu_squared: use squared ReLU as the activation
        custom_activation: activation module to deep-copy; takes precedence over the flags above
        post_act_ln: layernorm after the activation
        dropout: dropout after the activation
        sublayer_dropout: dropout after the output projection (omitted when 0)
        no_bias: drop the biases of the plain (non-GLU) input projection and of the output projection
        zero_init_output: zero-initialise the output projection
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        glu = False,
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        custom_activation = None,
        post_act_ln = False,
        dropout = 0.,
        sublayer_dropout = 0.,
        no_bias = False,
        zero_init_output = False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # activation priority: custom > relu squared > swish (SiLU) > GELU
        if exists(custom_activation):
            activation = deepcopy(custom_activation)
        elif relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        if glu:
            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias = not no_bias),
                activation
            )

        # keep a handle on the output projection - it is not always the last module of self.ff
        project_out = nn.Linear(inner_dim, dim_out, bias = not no_bias)

        self.ff = Sequential(
            project_in,
            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            project_out,
            nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
        )

        # init output projection to 0
        # (with sublayer_dropout > 0 the last entry of self.ff is the trailing Dropout, which has no weight)
        if zero_init_output:
            init_zero_(project_out)

    def forward(self, x):
        return self.ff(x)
# attention. it is all we need
class Attention(Module):
    """
    Multi-head attention with a large menu of optional variants.

    Supported variants:
    - self or cross attention (`dim_context`)
    - grouped / multi-query attention (`kv_heads`, `one_kv_head`)
    - multi-latent attention (`use_latent_q`, `use_latent_kv`, `latent_rope_subheads`)
    - rotary embeddings, optionally on a subset of heads
    - memory key / values, qk-norm, alphafold2-style value gating, per-head gating
    - data dependent alibi, CoPE, resformer value residuals, LASER
    - an optional hybrid module whose output is normalised and mixed into the attention output

    The core attention computation is delegated to `Attend`.
    """

    def __init__(
        self,
        dim,
        dim_head = DEFAULT_DIM_HEAD,
        dim_context = None,
        heads = 8,
        causal = False,
        flash = False,
        pre_talking_heads = False,
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        num_mem_kv = 0,
        dropout = 0.,
        sublayer_dropout = 0.,
        on_attn = False,
        gate_value_heads = False,
        swiglu_values = False,
        gate_values = False,
        zero_init_output = False,
        hard = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        l2_distance = False,
        sigmoid = False,
        selective = False,
        custom_attn_fn: Callable | None = None,
        hybrid_module: Module | None = None,
        hybrid_mask_kwarg: str | None = None,
        hybrid_fold_axial_dim: int | None = None,
        hybrid_learned_mix = False,
        one_kv_head = False,
        kv_heads = None,
        value_dim_head = None,
        dim_out = None,
        add_zero_kv = False, # same as add_zero_attn in pytorch
        rotate_num_heads = None,
        data_dependent_alibi = False,
        data_dependent_alibi_per_row = False,
        data_dependent_alibi_per_row_dim_head = 8,
        data_dependent_alibi_kwargs: dict = dict(),
        use_cope = False,
        cope_max_pos = 16,
        cope_soft_onehot_pos = False,
        cope_talking_heads = False,
        softclamp_logits = False,
        logit_softclamp_value = 50.,
        learned_value_residual_mix = False,
        laser = False, # https://arxiv.org/abs/2411.03493v1
        laser_softclamp_value = 15.,
        qkv_receive_diff_residuals = False,
        use_latent_q = False,
        dim_latent_q = None,
        use_latent_kv = False,
        dim_latent_kv = None,
        latent_rope_subheads = None,
        onnxable = False,
        attend_sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        super().__init__()
        dim_kv = default(dim_context, dim)

        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'

        value_dim_head = default(value_dim_head, dim_head)
        kv_heads = default(kv_heads, heads)

        kv_heads = 1 if one_kv_head else kv_heads
        assert divisible_by(heads, kv_heads)

        self.kv_heads = kv_heads

        q_dim = dim_head * heads
        k_dim = dim_head * kv_heads
        v_dim = value_dim_head * kv_heads
        out_dim = value_dim_head * heads

        # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
        # for eventually supporting multi-latent attention (MLA)

        self.to_latent_q = None
        self.to_latent_kv = None
        self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)

        dim_q_input = dim
        dim_kv_input = dim_kv

        if use_latent_q:
            assert exists(dim_latent_q)
            self.to_latent_q = LinearNoBias(dim, dim_latent_q)
            dim_q_input = dim_latent_q

        if use_latent_kv:
            assert exists(dim_latent_kv)
            self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
            dim_kv_input = dim_latent_kv

        if exists(latent_rope_subheads):
            assert not exists(rotate_num_heads), '`rotate_num_heads` cannot be set when multi-latent attention is being used'
            rotate_num_heads = latent_rope_subheads

            k_dim = dim_head * (kv_heads - latent_rope_subheads)

            self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
            self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

        self.use_latent_q = use_latent_q
        self.use_latent_kv = use_latent_kv

        # query key projection

        self.to_q = LinearNoBias(dim_q_input, q_dim)
        self.to_k = LinearNoBias(dim_kv_input, k_dim)
        self.to_v = LinearNoBias(dim_kv_input, v_dim)

        # split and merge of attention heads

        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
        self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)

        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # whether qkv receives different residual stream combinations from hyper connections or lime

        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

        # enhancing gradients to attention through exponentiated values

        self.laser = laser
        self.laser_softclamp_value = laser_softclamp_value

        # add GLU gating for aggregated values, from alphafold2
        # (zero weight + large bias: gates start saturated near 1 with sigmoid)

        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, out_dim)
            self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 10)

        # add per head gating of the output values, from 'Attend to nothing' paper

        self.to_v_head_gate = None
        if gate_value_heads:
            self.to_v_head_gate = nn.Linear(dim, heads)
            nn.init.constant_(self.to_v_head_gate.weight, 0)
            nn.init.constant_(self.to_v_head_gate.bias, 10)

        # cosine sim attention

        self.qk_norm = qk_norm
        self.qk_norm_groups = qk_norm_groups
        self.qk_norm_scale = qk_norm_scale

        # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442

        self.qk_norm_dim_scale = qk_norm_dim_scale

        self.qk_norm_q_scale = self.qk_norm_k_scale = 1
        if qk_norm and qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
            self.qk_norm_k_scale = nn.Parameter(torch.ones(kv_heads, 1, dim_head))

        assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
        assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

        # contextual positional encoding
        # https://arxiv.org/html/2405.18719v2

        cope = None

        if use_cope:
            assert causal, 'CoPE was designed for causal attention'
            assert not flash, 'CoPE is not flash attention compatible'

            cope = CoPE(
                dim = dim_head,
                heads = heads,
                max_pos = cope_max_pos,
                talking_heads = cope_talking_heads,
                soft_onehot = cope_soft_onehot_pos
            )

        # data dependent alibi
        # https://openreview.net/forum?id=q2Lnyegkr8

        self.data_dependent_alibi = None

        if data_dependent_alibi:

            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)

            if data_dependent_alibi_per_row:
                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)

            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)

        # attend class - includes core attention algorithm + talking heads

        self.attend = Attend(
            heads = heads,
            causal = causal,
            pre_talking_heads = pre_talking_heads,
            post_talking_heads = post_talking_heads,
            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
            sparse_topk_straight_through = sparse_topk_straight_through,
            hard = hard,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
            l2_distance = l2_distance,
            sigmoid = sigmoid,
            selective = selective,
            custom_attn_fn = custom_attn_fn,
            add_zero_kv = add_zero_kv,
            flash = flash,
            softclamp_logits = softclamp_logits,
            logit_softclamp_value = logit_softclamp_value,
            cope = cope,
            onnxable = onnxable,
            sdp_kwargs = attend_sdp_kwargs
        )

        # head scaling

        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention

        self.sparse_topk = sparse_topk

        # add memory key / values

        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
            # memory values are concatenated with the regular values, so they need the value head dimension
            self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, value_dim_head))

        # maybe learned value residual mixer per token

        self.to_value_residual_mix = nn.Sequential(
            nn.Linear(dim, heads),
            nn.Sigmoid(),
            Rearrange('b n h -> b h n 1')
        ) if learned_value_residual_mix else always(0.5)

        # attention on attention

        self.attn_on_attn = on_attn

        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

        hybrid_mix = None
        hybrid_norms = None
        hybrid_module = maybe(deepcopy)(hybrid_module)

        if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
            hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)

        # forward needs the pair of norms (and optional mixer) whenever a hybrid module is present,
        # whether or not it is folded axially
        if exists(hybrid_module):
            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None

            hybrid_norms = ModuleList([
                MultiheadRMSNorm(dim_head, heads = heads),
                MultiheadRMSNorm(dim_head, heads = heads)
            ])

        self.hybrid_module = hybrid_module
        self.hybrid_norms = hybrid_norms
        self.hybrid_mix = hybrid_mix
        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

        # output dimension by default same as input, but can be overridden

        dim_out = default(dim_out, dim)
        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

        # sublayer dropout

        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None

        # the number of attention heads to rotate, for decoupled rope in multi-latent attention

        rotate_num_heads = default(rotate_num_heads, heads)

        assert 0 < rotate_num_heads <= heads
        is_partial_rotate_heads = rotate_num_heads < heads
        assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'

        self.rotate_num_heads = rotate_num_heads

        # whether parent can kv cache

        self.can_cache_kv = not selective

        # init output projection 0
        # with attention-on-attention, to_out is Sequential(linear, GLU) and has no weight itself - zero its linear,
        # which makes the GLU output zero as well

        if zero_init_output:
            init_zero_(self.to_out[0] if on_attn else self.to_out)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        attn_bias = None,
        rotary_pos_emb = None,
        context_rotary_pos_emb = None,
        pos = None, # for custom alibi positions
        prev_attn = None,
        mem = None,
        mem_mask = None,
        return_intermediates = False,
        cache: Intermediates | None = None,
        value_residual = None
    ):
        """
        Attend from `x` to itself or to `context`.

        `x` is a (b, n, d) sequence, or a stack of three such sequences (q, k, v inputs) when
        qkv_receive_diff_residuals is set. Masks are boolean with True meaning "attend". `mem` / `mem_mask`
        prepend extra key / value tokens. `cache` supplies previously cached keys / values, and
        `value_residual` enables the resformer value mix.

        Returns:
            out, or (out, intermediates) when return_intermediates is True; intermediates then carries
            the pre-mix values (`values`) and the keys / values to cache (`cached_kv`).
        """
        h, head_scale, num_mem_kv, device = self.heads, self.head_scale, self.num_mem_kv, x.device
        has_context, qkv_receive_diff_residuals, is_multi_latent_attn = exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

        # an interesting possibility with hyper connections
        # having queries, keys, values be routed from different layers

        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

        if qkv_receive_diff_residuals:
            assert x.ndim == 4 and x.shape[0] == 3

            q_input, k_input, v_input = x
        else:
            kv_input = default(context, x)
            q_input, k_input, v_input = x, kv_input, kv_input

        # batch size and query length come from the query sequence - when q, k, v are stacked on the
        # leading axis of x, x.shape[0] is that stacking axis (3), not the batch
        b, n = q_input.shape[0], q_input.shape[1]

        if exists(mem):
            k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
            v_input, _ = pack([mem, v_input], 'b * d')

        # multi-latent attention logic
        # https://arxiv.org/abs/2405.04434 - Deepseek-AI team

        k_sub_heads = None # the rotateable subheads of keys derived from base sequence

        if self.use_latent_q:
            q_input = self.to_latent_q(q_input)

        if is_multi_latent_attn:
            assert not qkv_receive_diff_residuals
            needs_k_sub_heads = exists(self.to_rotateable_k)

            latent_kv_input = self.to_latent_kv(k_input)

            if needs_k_sub_heads:
                rotateable_k = self.to_rotateable_k(k_input)
                k_sub_heads = self.split_rotateable_k_heads(rotateable_k)

            if exists(cache):
                cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)

                if exists(maybe_cached_k_sub_heads):
                    k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)

            if return_intermediates:
                cached_kv = (latent_kv_input, k_sub_heads)

            k_input = v_input = latent_kv_input

        # query, key, value projection

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        q = self.split_q_heads(q)
        k = self.split_k_heads(k)
        v = self.split_v_heads(v)

        # take care of decoupled rope from multi-latent attention

        if exists(k_sub_heads):
            k = cat((k, k_sub_heads), dim = 1)

        # if previous values passed in for residual, either invoke resformer

        orig_values = v

        # https://arxiv.org/abs/2410.17897v1

        if exists(value_residual):
            value_residual_mix = self.to_value_residual_mix(q_input)
            v = value_residual.lerp(v, value_residual_mix)

        # qk normalization
        # (the qk norm scale itself is handed to Attend at construction time)

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))

            q = q * self.qk_norm_q_scale
            k = k * self.qk_norm_k_scale

        # take care of caching

        if not is_multi_latent_attn:
            if exists(cache):
                ck, cv = cache.cached_kv

                if exists(mem):
                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
                    mv, v = unpack(v, mem_packed_shape, 'b h * d')

                k = cat((ck, k), dim = -2)
                v = cat((cv, v), dim = -2)

                if exists(mem):
                    k = cat((mk, k), dim = -2)
                    v = cat((mv, v), dim = -2)

            if return_intermediates:
                mem_len = mem.shape[-2] if exists(mem) else 0
                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

        if exists(rotary_pos_emb):
            rotate_num_heads = self.rotate_num_heads
            partial_rotate_heads = rotate_num_heads < h

            freqs, xpos_scale = rotary_pos_emb
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

            if partial_rotate_heads:
                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]

            q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

            if has_context:
                # override with `context_rotary_pos_emb` if provided

                freqs, xpos_scale = context_rotary_pos_emb
                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

            k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

            if partial_rotate_heads:
                q = cat((q_rest, q), dim = 1)
                k = cat((k_rest, k), dim = 1)

        input_mask = context_mask

        if not exists(input_mask) and not has_context:
            input_mask = mask

            if (exists(input_mask) or exists(mem_mask)) and exists(mem):
                seq_len, mem_len = n, mem.shape[-2]

                if not exists(mem_mask):
                    input_mask = pad_at_dim(input_mask, (mem_len, 0), dim = -1, value = True)
                elif not exists(input_mask):
                    input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
                else:
                    input_mask = cat((mem_mask, input_mask), dim = -1)

        # i, j determined for relative positional bias, excluding memory key / values

        i, j = tuple(t.shape[-2] for t in (q, k))

        # maybe append memory key / values

        if num_mem_kv > 0:
            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

            if self.qk_norm:
                # normalise memory keys exactly like the regular keys (same group count) so their norms match
                mem_k = l2norm(mem_k, groups = self.qk_norm_groups)
                mem_k = mem_k * self.qk_norm_k_scale

            k = cat((mem_k, k), dim = -2)
            v = cat((mem_v, v), dim = -2)

            if exists(input_mask):
                input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

        # determine masking

        mask_value = max_neg_value(q)
        masks = []
        final_attn_mask = None

        if exists(input_mask):
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        if exists(attn_mask):
            assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
            masks.append(~attn_mask)

        if exists(self.max_attend_past):
            range_q = arange(j - i, j, device = device)
            range_k = arange(j, device = device)
            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
            max_attend_past_mask = dist > self.max_attend_past
            max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
            masks.append(max_attend_past_mask)

        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        # prepare relative positional bias, if needed

        if exists(rel_pos):
            assert not exists(attn_bias)

            if exists(pos):
                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'
                # allow for custom positions to be passed in
                attn_bias = rel_pos.forward_custom_pos(pos)
            else:
                attn_bias = rel_pos(i, j)

            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0)) # handle memory key / values

        # prepare data dependent alibi from forgetting transformers paper, if needed

        if exists(self.data_dependent_alibi):
            attn_bias = self.data_dependent_alibi(x)

            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))

        if self.laser:
            v = softclamp(v, self.laser_softclamp_value)
            v = v.exp()

        # attention is all we need

        out, intermediates = self.attend(
            q, k, v,
            mask = final_attn_mask,
            attn_bias = attn_bias,
            prev_attn = prev_attn
        )

        # laser

        if self.laser:
            out = log(out)

        # store the values for resformer

        intermediates.values = orig_values

        # normformer scaling of heads

        if head_scale:
            out = out * self.head_scale_params

        # per head gating, from https://arxiv.org/abs/2306.12929

        if exists(self.to_v_head_gate):
            head_gate = self.to_v_head_gate(x)
            out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

        # if exists hybrid module, must do a normalization

        # hybrid module

        if exists(self.hybrid_module):

            # hybrid input

            hybrid_forward_kwargs = dict()

            if not self.causal and exists(self.hybrid_mask_kwarg):
                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}

            # hybrid forward

            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)

            # handle hybrid out

            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)

            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)

            if hybrid_out.ndim == 3:
                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

            out_norm, hybrid_out_norm = self.hybrid_norms

            out = out_norm(out)
            hybrid_out = hybrid_out_norm(hybrid_out)

            if exists(self.hybrid_mix):
                mix = self.hybrid_mix(x)
                mix = rearrange(mix, 'b n h -> b h n 1')
                out = out.lerp(hybrid_out, mix.sigmoid())
            else:
                out = 0.5 * (out + hybrid_out)

        # merge heads

        out = self.merge_heads(out)

        # alphafold2 styled gating of the values

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * self.to_v_gate_activation(gates)

        # combine the heads

        out = self.to_out(out)

        # maybe sublayer dropout

        out = maybe(self.sublayer_dropout)(out)

        if exists(mask):
            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)

        if not return_intermediates:
            return out

        intermediates.cached_kv = cached_kv

        return out, intermediates
class AttentionLayers(Module):
    """Configurable stack of self-attention ('a'), cross-attention ('c') and feedforward ('f') layers.

    `__init__` decides the sequence of layer types (from `depth` and a default block
    pattern, or from `custom_layers` / `par_ratio` / `sandwich_coef`) and builds, for each
    entry, a (norms, layer, residual) triple. `forward` runs those triples in
    `layers_execute_order`, which may repeat indices (used for weight tying).
    """
    def __init__(
        self,
        dim,
        depth = None,
        heads = 8,
        causal = False,
        cross_attend = False,
        only_cross = False,
        use_scalenorm = False,
        use_rmsnorm = False,
        use_dynamic_tanh = False,
        dynamic_tanh_init_alpha = 1.,
        use_simple_rmsnorm = False,
        use_adaptive_layernorm = False,
        use_adaptive_rmsnorm = False,
        use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
        norm_add_unit_offset = True,
        dim_condition = None,
        adaptive_condition_mlp = False,
        adaptive_condition_mlp_expansion = 4,
        alibi_pos_bias = False,
        alibi_num_heads = None,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        dynamic_pos_bias = False,
        dynamic_pos_bias_log_distance = False,
        dynamic_pos_bias_mlp_depth = 2,
        dynamic_pos_bias_norm = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
        rotary_xpos = False,
        rotary_interpolation_factor = 1.,
        rotary_xpos_scale_base = 512,
        rotary_base_rescale_factor = 1.,
        rotate_num_heads = None,
        weight_tie_layers = False,
        custom_layers: tuple[str, ...] | None = None,
        layers_execute_order: tuple[int, ...] | None = None,
        sandwich_coef = None,
        par_ratio = None,
        residual_attn = False,
        cross_residual_attn = False,
        macaron = False,
        pre_norm = True,
        pre_norm_has_final_norm = True,
        gate_residual = False,
        scale_residual = False,
        scale_residual_constant = 1.,
        shift_tokens = 0,
        sandwich_norm = False,
        softclamp_output = False,
        softclamp_output_value = 30.,
        zero_init_branch_output = False,
        layer_dropout = 0.,
        cross_attn_tokens_dropout = 0.,
        disable_abs_pos_emb = None,
        use_layerscale = False,
        layerscale_init_value = 0.,
        unet_skips = False,
        integrate_layers = False,
        layer_integrate_use_softmax = True,
        num_residual_streams = 1,
        qkv_receive_diff_residuals = False,
        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
        learned_reinject_input_gate = False,
        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
        learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
        rel_pos_kwargs: dict = dict(),
        residual_fn_kwargs: dict = dict(),
        **kwargs
    ):
        """Build the layer stack.

        Extra keyword arguments are routed by prefix via `groupby_prefix_and_trim`
        (which presumably strips the prefix, since `dim_head` is read un-prefixed):
        `ff_*` go to every `FeedForward`, `attn_*` to every `Attention`, and
        `cross_attn_*` additionally to cross-attention layers, overriding matching
        `attn_*` keys. Any other leftover keyword raises AssertionError.

        The layer-type sequence is `default_block * depth` unless `custom_layers`,
        `par_ratio` or `sandwich_coef` is given. `default_block` is ('a', 'f'),
        ('a', 'c', 'f') with `cross_attend`, or ('c', 'f') with `only_cross`, with an
        extra leading 'f' when `macaron` is set.
        """
        super().__init__()
        rotary_pos_emb = rotary_pos_emb or rotary_xpos
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
        cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
        # whatever is left after prefix routing is not a recognized option
        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
        self.dim = dim
        self.causal = causal
        self.layers = ModuleList([])
        # routing related
        # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
        # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
        qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
        # hyper connections
        assert num_residual_streams > 0
        has_hyper_connections = num_residual_streams > 1
        self.num_residual_streams = num_residual_streams
        # per-stream learned offset, added to the input in forward when there are multiple streams
        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
        assert not (has_hyper_connections and gate_residual)
        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
        # LIMe
        # NOTE(review): `hiddens_counter` is never read in this class
        hiddens_counter = 0
        self.layer_integrators = ModuleList([])
        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
        # positions related
        # absolute positional embedding is disabled by default when a relative scheme (t5 bias or rotary) is used
        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
        rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
        assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
        if rotary_emb_dim < 32:
            print('when training language model, rotary embedding dimension should be at least 32')
        assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
        # NOTE(review): the message below also mentions dynamic tanh, which this check does not cover
        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        # relative positional bias
        flash_attn = attn_kwargs.get('flash', False)
        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'
        self.rel_pos = None
        if rel_pos_bias:
            assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
        elif dynamic_pos_bias:
            assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
        elif alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm
        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
        self.cross_attend = cross_attend
        # determine norm
        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
        norm_need_condition = False
        dim_condition = default(dim_condition, dim)
        # when a conditioning MLP is used, adaptive modules see its expanded output width
        dim_condition_mult = 1
        if adaptive_condition_mlp:
            dim_condition_mult = adaptive_condition_mlp_expansion
        if use_scalenorm:
            norm_class = ScaleNorm
        elif use_rmsnorm:
            norm_class = RMSNorm
        elif use_simple_rmsnorm:
            norm_class = SimpleRMSNorm
        elif use_dynamic_tanh:
            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
        elif use_adaptive_layernorm:
            norm_need_condition = True
            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
        elif use_adaptive_rmsnorm:
            norm_need_condition = True
            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
        else:
            norm_class = LayerNorm
        # factory: each call creates a fresh norm module of width `dim`
        norm_fn = partial(norm_class, dim)
        if not norm_need_condition and norm_add_unit_offset:
            # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
            norm_fn = partial(norm_fn, unit_offset = True)
        self.norm_need_condition = norm_need_condition
        self.dim_condition = dim_condition
        # determine default block layer type order
        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')
        if macaron:
            default_block = ('f',) + default_block
        # determine post branch wrapper
        assert at_most_one_of(use_layerscale, use_adaptive_layerscale)
        post_branch_fn = None
        post_branch_fn_needs_condition = False
        if use_layerscale:
            post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
        elif use_adaptive_layerscale:
            post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
            post_branch_fn_needs_condition = True
        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
        if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
            post_branch_fn = partial(post_branch_fn, unit_offset = True)
        # setup mlp for conditioning
        self.need_condition = norm_need_condition or post_branch_fn_needs_condition
        self.adaptive_mlp = nn.Identity()
        if self.need_condition and adaptive_condition_mlp:
            self.adaptive_mlp = nn.Sequential(
                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                nn.SiLU()
            )
        # zero init
        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
        # setup weight tying, which is a special case of `layer_execute_order`
        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
        if weight_tie_layers:
            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
            assert not exists(layers_execute_order)
            # build a single block (depth = 1 below) and execute its layer indices `depth` times,
            # so every repetition reuses the same modules
            layers_execute_order = tuple(range(len(default_block))) * depth
            depth = 1
        # calculate layer block order
        # stays 1 for custom / par / sandwich layouts, so every layer counts as its own block
        # for the block_begin / block_ind / unet-skip bookkeeping below
        len_default_block = 1
        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
            layer_types = default_block * depth
            len_default_block = len(default_block)
        self.layer_types = layer_types
        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
        assert all([i < len(self.layer_types) for i in self.layers_execute_order])
        # number of self-attention modules; sizes the default mems / mem_masks lists in forward
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
        # set the depth
        depth = default(depth, len(self.layers_execute_order))
        self.depth = depth
        # stochastic depth
        self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
        # structured dropout for cross attending
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
        # calculate token shifting
        shift_tokens = cast_tuple(shift_tokens, len(layer_types))
        # optional soft clamping just before the final norm
        # used in gemma 2
        self.softclamp_output = softclamp_output
        self.softclamp_output_value = softclamp_output_value
        # final norm after the last layer, only for pre-norm stacks
        # (post-norm stacks instead normalize after every residual via `post_main_norm`)
        self.final_norm = norm_fn() if pre_norm else nn.Identity()
        # whether unet or not
        self.unet_skips = unet_skips
        num_skips = self.depth // len_default_block
        assert not (unet_skips and num_skips == 0), 'must have depth of at least 2 for unet skip connections'
        skip_indices = [i * len_default_block for i in range(num_skips)]
        self.skip_combines = ModuleList([])
        # whether there is reinjection of input at every layer
        self.reinject_input = reinject_input
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None
        # add the value from the first self attention block to all latter projected self attention values as a residual
        self.add_value_residual = add_value_residual
        is_first_self_attn = True
        is_first_cross_attn = True
        learned_value_residual_mix &= add_value_residual
        # iterate and construct layers
        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
            # `ind` is the index of each module - attention, feedforward, cross attention
            # but `block_ind` refers to the typical enumeration of a transformer block (attn + ff + [optional] cross attn)
            block_begin = divisible_by(ind, len_default_block)
            block_ind = ind // len_default_block
            # NOTE(review): `is_last_layer` is not read afterwards in this method
            is_last_layer = ind == (len(self.layer_types) - 1)
            # attention, cross attention, feedforward
            layer_qkv_receives_diff_view = layer_type == 'a' and qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
            if layer_type == 'a':
                # the first self-attention layer has no earlier value to mix with
                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = layer_qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                is_first_self_attn = False
            elif layer_type == 'c':
                # cross_attn_* options override same-named attn_* options
                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                is_first_cross_attn = False
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                # macaron: feedforward outputs are scaled by 0.5
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')
            if layer_shift_tokens > 0:
                shift_range_upper = layer_shift_tokens + 1
                # causal stacks only shift in non-negative offsets
                shift_range_lower = -layer_shift_tokens if not causal else 0
                layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
            if exists(post_branch_fn):
                layer = post_branch_fn(layer)
            layer_integrate = None
            if integrate_layers:
                # this layer can see the hiddens of itself and every earlier layer
                num_layer_hiddens = ind + 1
                layer_integrate_num_view = 3 if layer_qkv_receives_diff_view else 1
                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
            if has_hyper_connections:
                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
                if layer_type == 'a' and hyper_conn_produce_diff_views:
                    residual_fn = partial(residual_fn, num_input_views = 3)
            elif gate_residual:
                residual_fn = GRUGating
            else:
                residual_fn = Residual
            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)
            # handle unet skip connection
            skip_combine = None
            is_latter_half = block_begin and block_ind >= (self.depth / 2)
            if self.unet_skips and is_latter_half:
                # pop() takes skip indices from the end, pairing latter-half blocks with earlier blocks in reverse order
                skip_combine = ConcatCombine(dim, skip_indices.pop())
            # all normalizations of the layer
            pre_branch_norm = norm_fn() if pre_norm else None
            post_branch_norm = norm_fn() if sandwich_norm else None
            post_main_norm = norm_fn() if not pre_norm else None
            norms = ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])
            self.skip_combines.append(skip_combine)
            self.layer_integrators.append(layer_integrate)
            self.layers.append(ModuleList([
                norms,
                layer,
                residual
            ]))
        # determine whether can cache kv
        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        self_attn_kv_mask = None,
        mems = None,
        mem_masks = None,
        seq_start_pos: Tensor | None = None,
        cache: LayerIntermediates | None = None,
        cache_age = 1,
        return_hiddens = False,
        rotary_pos_emb = None,
        pos = None,
        context_pos = None,
        attn_bias = None,
        condition = None,
        in_attn_cond = None, # https://arxiv.org/abs/2105.04090
        layers_execute_order: tuple[int, ...] | None = None
    ):
        """Run the layer stack over `x`.

        Returns `x` after the final norm, or `(x, LayerIntermediates)` when
        `return_hiddens` is True.

        - `context` must be given iff the stack was built with `cross_attend`, and
          `condition` iff an adaptive norm / layerscale needs it (AssertionError otherwise).
        - `mems` / `mem_masks` are consumed one entry per executed self-attention layer,
          in order. The caller's lists are copied, not mutated.
        - With `cache`, only the last `cache_age` positions of `x` are processed and
          `context` is truncated to zero length (presumably because cross-attention keys /
          values come from the cache; confirm).
        - `layers_execute_order` overrides the order chosen at construction.
        """
        assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
        # handle condition
        if exists(condition):
            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
            assert condition.ndim in {2, 3}
            # a per-sample condition gets a singleton sequence axis
            if condition.ndim == 2:
                condition = rearrange(condition, 'b d -> b 1 d')
            condition = self.adaptive_mlp(condition)
        # setup maybe layernorm kwarg
        norm_kwargs = dict()
        if self.norm_need_condition:
            norm_kwargs.update(condition = condition)
        # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
        block_forward_kwargs = dict()
        if self.post_branch_fn_needs_condition:
            block_forward_kwargs.update(condition = condition)
        # initialize accums
        hiddens = []
        layer_hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None
        # copies, because entries are popped per self-attention layer below
        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
        mem_masks = mem_masks.copy() if exists(mem_masks) else [None] * self.num_attn_layers
        # handle left padded sequences
        if exists(seq_start_pos):
            # positions before each sample's start position are masked out as self-attention keys
            seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
            left_pad_mask = seq_arange >= seq_start_pos[..., None]
            if exists(self_attn_kv_mask):
                self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
            else:
                self_attn_kv_mask = left_pad_mask
        # rotary positions
        cross_attn_rotary_pos_emb = dict()
        if exists(self.rotary_pos_emb):
            if not exists(rotary_pos_emb):
                maybe_mem = first(mems, None) # todo - handle edge case where different layers get different memory lengths. don't think this will ever come up but who knows
                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0
                if not exists(pos):
                    # memory positions get negative indices, so the current tokens start at 0
                    pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len
                rotary_pos_emb = self.rotary_pos_emb(pos)
            # allow for rotary positions for context if provided
            if exists(context_pos):
                assert self.cross_attend
                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)
                cross_attn_rotary_pos_emb.update(
                    rotary_pos_emb = rotary_pos_emb,
                    context_rotary_pos_emb = context_rotary_pos_emb
                )
        # assume cached key / values
        attn_cache = []
        if exists(cache):
            assert self.causal and not any([*map(exists, (mask, attn_mask))])
            if exists(context):
                context = context[:, :0]
            if cache_age > 0:
                x = x[:, -cache_age:] # for spec decoding, may be greater than 1
            attn_cache = cache.attn_intermediates
        # attention layers consume cache entries in execution order (None once exhausted)
        iter_attn_cache = iter(attn_cache)
        # setup multistreams if needed
        streams = self.num_residual_streams
        is_multistream = streams > 1
        if is_multistream:
            # replicate the input into `streams` copies folded into the batch axis
            x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)
        # get layers to be executed
        layer_variables = (
            self.layer_types,
            self.skip_combines,
            self.layers,
            self.layer_dropouts,
            self.layer_integrators
        )
        # able to override the layers execution order on forward, for trying to depth extrapolate
        layers_execute_order = default(layers_execute_order, self.layers_execute_order)
        layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
        # derived input for reinjection if needed
        inp_inject = None
        if self.reinject_input:
            assert not exists(in_attn_cond)
            inp_inject = self.reinject_input_proj(x)
        elif exists(in_attn_cond):
            # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
            inp_inject = inp_inject * inp_inject_gate
        # store all hiddens for skips
        skip_hiddens = []
        # for value residuals
        first_self_attn_inter = None
        first_cross_attn_inter = None
        # go through the attention and feedforward layers
        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
            # NOTE(review): `is_last` is not read afterwards
            is_last = ind == (len(self.layers) - 1)
            # handle skip connections
            skip_hiddens.append(x)
            if exists(skip_combine):
                x = skip_combine(x, skip_hiddens)
            # layer dropout
            # NOTE(review): a skipped 'a' layer does not pop its mems / mem_masks entry, so later
            # attention layers receive shifted memories; confirm this is intended
            if self.training and layer_dropout > 0. and random() < layer_dropout:
                continue
            if layer_type == 'a':
                if return_hiddens:
                    hiddens.append(x)
                layer_mem = mems.pop(0) if mems else None
                layer_mem_mask = mem_masks.pop(0) if mem_masks else None
            if layer_type == 'c':
                if self.training and self.cross_attn_tokens_dropout > 0.:
                    context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
            x, inner_residual, residual_kwargs = residual_fn.prepare(x)
            layer_hiddens.append(x)
            if exists(layer_integrator):
                x = layer_integrator(x, layer_hiddens)
            pre_norm, post_branch_norm, post_main_norm = norm
            if self.need_condition:
                # bind the condition into each (possibly absent) norm
                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
            if exists(inp_inject):
                x = x + inp_inject
            if exists(pre_norm):
                x = pre_norm(x)
                # memories go through the same pre-norm as the current tokens
                if layer_type == 'a' and exists(layer_mem):
                    layer_mem = pre_norm(layer_mem)
            block = partial(block, **block_forward_kwargs)
            # handle maybe value residuals
            maybe_self_attn_value_residual = None
            maybe_cross_attn_value_residual = None
            if self.add_value_residual:
                if exists(first_self_attn_inter):
                    maybe_self_attn_value_residual = first_self_attn_inter.values
                if exists(first_cross_attn_inter):
                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
            # forward depending on layer type
            if layer_type == 'a':
                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
            elif layer_type == 'f':
                out = block(x)
            # store first self or cross attention intermediate for value residual
            if not exists(first_self_attn_inter) and layer_type == 'a':
                first_self_attn_inter = inter
            if not exists(first_cross_attn_inter) and layer_type == 'c':
                first_cross_attn_inter = inter
            if exists(post_branch_norm):
                out = post_branch_norm(out)
            x = residual_fn(out, inner_residual, **residual_kwargs)
            if layer_type in ('a', 'c') and return_hiddens:
                inter.layer_type = layer_type
                intermediates.append(inter)
            # residual attention: pass this layer's pre-softmax scores to the next layer of the same kind
            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn
            if exists(post_main_norm):
                x = post_main_norm(x)
        # record the output of the final layer as well
        if return_hiddens:
            layer_hiddens.append(x)
        if self.softclamp_output:
            x = softclamp(x, self.softclamp_output_value)
        final_norm = self.final_norm
        if self.need_condition:
            final_norm = maybe(partial)(final_norm, **norm_kwargs)
        # take care of multistreams if needed, use sum for now
        if is_multistream:
            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
        x = final_norm(x)
        if not return_hiddens:
            return x
        intermediates = LayerIntermediates(
            hiddens = hiddens,
            last_hidden = x,
            attn_intermediates = intermediates,
            layer_hiddens = layer_hiddens,
        )
        return x, intermediates
class Encoder(AttentionLayers):
    """Bidirectional (non-causal) `AttentionLayers`; causality cannot be overridden."""

    def __init__(self, **kwargs):
        causal_was_given = 'causal' in kwargs
        assert not causal_was_given, 'cannot set causality on encoder'
        super().__init__(**kwargs, causal = False)
class Decoder(AttentionLayers):
    """Causal `AttentionLayers`; causality is always on and cannot be overridden."""

    def __init__(self, **kwargs):
        causal_was_given = 'causal' in kwargs
        assert not causal_was_given, 'cannot set causality on decoder'
        super().__init__(**kwargs, causal = True)
class PrefixDecoder(AttentionLayers):
    """Decoder whose first `prefix_attn_len` columns are visible from every row.

    The layers are built with `causal = False`. Causality is instead supplied in
    `forward` as an explicit boolean `attn_mask`: a lower-triangular mask, OR-ed with
    an optional per-sample prefix mask, then AND-ed with any caller mask.
    """

    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = False, **kwargs)

    def forward(
        self,
        x,
        *args,
        attn_mask = None,
        prefix_attn_len = None,
        **kwargs
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        device = x.device

        # True strictly above the diagonal; its complement is the causal "allowed" mask
        future_positions = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
        allowed = ~future_positions

        if exists(prefix_attn_len):
            # an int prefix length applies to every sample in the batch
            if isinstance(prefix_attn_len, int):
                prefix_attn_len = torch.full((batch,), prefix_attn_len, device = device)

            # columns j < prefix length become visible from every row
            column_ids = arange(seq_len, device = device)
            prefix_visible = column_ids < rearrange(prefix_attn_len, 'b -> b 1 1 1')
            allowed = allowed | prefix_visible

        if exists(attn_mask):
            allowed = allowed & attn_mask

        return super().forward(x, *args, attn_mask = allowed, **kwargs)
class CrossAttender(AttentionLayers):
    """`AttentionLayers` configured with `cross_attend = True` and `only_cross = True`.

    With the default layout this yields ('c', 'f') blocks, i.e. no self-attention.
    """

    def __init__(self, **kwargs):
        super().__init__(only_cross = True, cross_attend = True, **kwargs)
class ViTransformerWrapper(Module):
    """Vision-transformer wrapper: patchify an image, embed the patches, run an encoder.

    Images are cut into non-overlapping `patch_size` x `patch_size` patches. Each
    flattened patch is normalized, projected to the encoder width and normalized again.
    Learned positional embeddings are added, optional learned register tokens are
    appended after the patches, and the sequence goes through `attn_layers`. Unless
    embeddings are requested, the (register-free) encoder output is mean-pooled over
    tokens and passed through `mlp_head`.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers: Encoder,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'

        dim = attn_layers.dim
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * (patch_size ** 2)

        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

        # extra learned tokens appended after the patches; removed again before pooling
        self.has_register_tokens = num_register_tokens > 0
        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.patch_to_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim)
        )

        if post_emb_norm:
            self.post_emb_norm = LayerNorm(dim)
        else:
            self.post_emb_norm = nn.Identity()

        self.dropout = nn.Dropout(emb_dropout)
        self.attn_layers = attn_layers

        if exists(num_classes):
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.mlp_head = nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        batch = img.shape[0]
        p = self.patch_size

        # (b, c, H, W) -> (b, num_patches, p * p * c)
        tokens = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.patch_to_embedding(tokens)

        num_patch_tokens = tokens.shape[1]
        tokens = tokens + self.pos_embedding[:, :num_patch_tokens]
        tokens = self.dropout(self.post_emb_norm(tokens))

        if self.has_register_tokens:
            registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
            tokens, packed_shapes = pack((tokens, registers), 'b * d')

        embed = self.attn_layers(tokens)

        if self.has_register_tokens:
            # keep only the patch tokens
            embed, _ = unpack(embed, packed_shapes, 'b * d')

        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)

        if return_embeddings or not exists(self.mlp_head):
            return embed

        logits = self.mlp_head(embed.mean(dim = -2))

        if return_logits_and_embeddings:
            return logits, embed

        return logits
class TransformerWrapper(Module):
def __init__(
    self,
    *,
    num_tokens,
    max_seq_len,
    attn_layers: AttentionLayers,
    embed_num_tokens: dict[str, int] = dict(),
    emb_dim = None,
    max_mem_len = 0,
    shift_mem_down = 0,
    emb_dropout = 0.,
    post_emb_norm = False,
    num_memory_tokens = None,
    memory_tokens_interspersed_every = None,
    tie_embedding = False,
    logits_dim = None,
    return_only_embed = False,
    num_output_heads = 1,
    use_abs_pos_emb = True,
    scaled_sinu_pos_emb = False,
    l2norm_embed = False,
    recycling = False, # from Jumper et al. - Alphafold2
    train_max_recycle_steps = 4, # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
    emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
    attn_z_loss_weight = 1e-4,
    average_pool_embed = False,
    use_cls_token = False,
    num_cls_tokens = 1,
    squeeze_out_last_dim = False,
    token_emb: TokenEmbedding | None = None,
    mixture_of_softmax = False,
    mixture_of_softmax_k = 4,
    sigsoftmax_logits = False,
    to_logits: Module | None = None,
):
    """Wrap `attn_layers` with token / positional embeddings and an output head.

    The model width `dim` is taken from `attn_layers.dim`. Embeddings live in
    `emb_dim` (default `dim`) and are projected to `dim` when the two differ.

    Output head selection, in priority order: no head when `return_only_embed`;
    a tied head reusing the token embedding weight when `tie_embedding`; one
    `LinearNoBias` per head when `num_output_heads > 1`; otherwise `to_logits`
    if given, else a single `LinearNoBias(dim, logits_dim)`.

    NOTE(review): `attn_z_loss_weight` is accepted here but not stored;
    `forward` takes its own `attn_z_loss_weight` argument.
    NOTE(review): `embed_num_tokens` uses a shared `dict()` default; it is only
    read here, never mutated.
    """
    super().__init__()
    dim = attn_layers.dim
    emb_dim = default(emb_dim, dim)
    self.emb_dim = emb_dim
    self.num_tokens = num_tokens
    self.num_cls_tokens = num_cls_tokens
    self.max_seq_len = max_seq_len
    self.max_mem_len = max_mem_len
    self.shift_mem_down = shift_mem_down
    self.l2norm_embed = l2norm_embed
    if not exists(token_emb):
        token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
    self.token_emb = token_emb
    # no absolute positions when max_seq_len is 0, when disabled here, or when the attention
    # layers disabled them (e.g. because they use rotary / relative positions)
    no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
    if no_abs_pos_emb:
        self.pos_emb = always(0)
    elif scaled_sinu_pos_emb:
        self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
    else:
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
    # additional embeddings - say type embedding from BERT
    self.embeds = None
    if len(embed_num_tokens) > 0:
        # `num_tokens` inside the comprehension is comprehension-local and does not rebind the argument
        self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(num_tokens, emb_dim) for name, num_tokens in embed_num_tokens.items()})
    # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
    self.emb_frac_gradient = emb_frac_gradient
    self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
    self.emb_dropout = nn.Dropout(emb_dropout)
    self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
    self.attn_layers = attn_layers
    # runs before the modules below are created, so it only affects token_emb / pos_emb
    self.init_()
    assert num_output_heads > 0
    assert at_most_one_of(average_pool_embed, use_cls_token)
    # maybe recycling
    self.recycling = recycling
    self.recycled_proj = LinearNoBias(dim, dim) if recycling else None
    self.train_max_recycle_steps = train_max_recycle_steps
    # classic cls token from the bert days
    self.cls_token = None
    if use_cls_token:
        self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
        nn.init.normal_(self.cls_token, std = 0.02)
    # whether to average pool the embed (`global average pool`)
    self.average_pool_embed = average_pool_embed
    # output type
    self.output_is_log_prob = mixture_of_softmax
    self.to_mixture = None
    self.combine_mixture = None
    if mixture_of_softmax:
        assert num_output_heads == 1
        self.to_mixture = Sequential(
            LinearNoBias(dim, dim * mixture_of_softmax_k),
            Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
        )
        self.combine_mixture = LinearNoBias(dim, mixture_of_softmax_k)
    # sig softmax
    self.sigsoftmax_logits = sigsoftmax_logits
    # output head, usually to logits of num_tokens
    logits_dim = default(logits_dim, num_tokens)
    self.has_multiple_heads = num_output_heads > 1
    if return_only_embed:
        self.to_logits = None
    elif tie_embedding:
        assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
        # NOTE(review): multiplies by the token embedding weight transposed; presumably this needs
        # the final hidden width to equal emb_dim, i.e. emb_dim == dim. Confirm.
        self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
    elif num_output_heads > 1:
        self.to_logits = ModuleList([LinearNoBias(dim, logits_dim) for _ in range(num_output_heads)])
    else:
        # a user-supplied `to_logits` is only honored on this single-head, untied path
        self.to_logits = LinearNoBias(dim, logits_dim) if not exists(to_logits) else to_logits
    # memory tokens (like [cls]) from Memory Transformers paper
    num_memory_tokens = default(num_memory_tokens, 0)
    self.num_memory_tokens = num_memory_tokens
    if num_memory_tokens > 0:
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
    self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
    # squeeze out last dimension if possible
    self.squeeze_out_last_dim = squeeze_out_last_dim
    # whether can do cached kv decoding
    self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
    # without absolute positional embeddings, cached decoding is not bounded by max_seq_len
    self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
def init_(self):
    """Apply special initialization to the token and positional embeddings.

    Calls `self.token_emb.init_()` when the token embedding defines one. When
    `l2norm_embed` is enabled, the learned positional table (`self.pos_emb.emb`)
    is re-drawn from a normal distribution with a tiny std (1e-5).

    Positional modules that have no learned `emb` table are skipped instead of
    raising `AttributeError`. This covers the `always(0)` placeholder used when
    absolute positions are disabled and, presumably, `ScaledSinusoidalEmbedding`
    (confirm against its definition).
    """
    if hasattr(self.token_emb, 'init_'):
        self.token_emb.init_()

    if not self.l2norm_embed:
        return

    # only a learned embedding table can be re-initialized
    pos_emb_table = getattr(self.pos_emb, 'emb', None)

    if pos_emb_table is not None:
        nn.init.normal_(pos_emb_table.weight, std = 1e-5)
def forward(
    self,
    x,
    return_embeddings = False,
    return_logits_and_embeddings = False,
    return_intermediates = False,
    return_embeddings_and_intermediates = False,
    return_logit_entropies = False,
    mask = None,
    return_mems = False,
    return_attn = False,
    mems = None,
    mem_masks = None,
    recycle_steps = None,
    pos = None,
    prepend_embeds = None,
    prepend_mask = None,
    embed_ids: dict[str, Tensor] = dict(),
    sum_embeds = None,
    return_attn_z_loss = False,
    attn_z_loss_weight = 1e-4,
    seq_start_pos = None,
    cache: LayerIntermediates | None = None,
    token_emb_kwargs = dict(),
    to_logits_kwargs = dict(),
    **kwargs,
):
    """
    Embed tokens, run them through `self.attn_layers`, and optionally project to logits.

    Args:
        x: token ids. May be None when `prepend_embeds` is given; an empty
            (batch, 0) long tensor is then created.
        return_embeddings: return the final hidden states instead of logits. This is
            forced on when the wrapper has no `to_logits` head, or when
            `return_embeddings_and_intermediates` is set.
        return_logits_and_embeddings: return `(logits, embeddings)`.
        return_intermediates: also return the intermediates from the attention layers.
        return_embeddings_and_intermediates: return `(embeddings, intermediates)`.
        return_logit_entropies: store `calc_entropy(logits)` on the intermediates and
            return them.
        mask: mask over input positions. It is padded for prepended, cls and memory tokens.
        return_mems: also return new memories. These are hidden states concatenated
            with `mems` and truncated to the last `self.max_mem_len` positions, detached.
        return_attn: also return the post-softmax attention maps.
        mems, mem_masks: memories and their masks, forwarded to the attention layers.
        recycle_steps: number of recycling passes. Only allowed (>1) when built with recycling.
        pos: long positions passed to `self.pos_emb`. A non-long tensor is instead
            used directly as an external positional embedding.
        prepend_embeds, prepend_mask: embeddings (and optional mask) concatenated in
            front of the token embeddings.
        embed_ids: extra id tensors looked up in `self.embeds[f'{name}_embed']` and summed in.
        sum_embeds: extra embeddings added to the token embeddings.
        return_attn_z_loss, attn_z_loss_weight: compute the attention z-loss and
            return intermediates.
        seq_start_pos: forwarded to the positional embedding and attention layers.
        cache: key/value cache forwarded to the attention layers.
        token_emb_kwargs, to_logits_kwargs: extra kwargs for the token embedding and logit head(s).
        **kwargs: forwarded to `self.attn_layers`.

    Returns:
        `out`, or `(out, new_mems)`, `(out, intermediates)` or `(out, attn_maps)`,
        depending on the flags above.
    """

    # if sequence is None, auto create an empty one if `prepend_embeds` was supplied

    if not exists(x):
        assert exists(prepend_embeds)
        x = prepend_embeds.new_empty((prepend_embeds.shape[0], 0), dtype = torch.long)

    # shapes and variables

    b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

    return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates

    # absolute positional embedding

    external_pos_emb = exists(pos) and pos.dtype != torch.long
    pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
    x = self.token_emb(x, **token_emb_kwargs) + pos_emb

    # add additional embeddings

    assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'

    if exists(self.embeds):
        assert len(embed_ids) == len(self.embeds)

        for name, embed_id in embed_ids.items():
            embed_key = f'{name}_embed'

            assert embed_key in self.embeds
            embed = self.embeds[embed_key](embed_id)

            x = x + embed

    # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

    if exists(sum_embeds):
        x = x + sum_embeds

    # post embedding norm, purportedly leads to greater stabilization

    x = self.post_emb_norm(x)

    # whether to append embeds, as in PaLI, for image embeddings

    if exists(prepend_embeds):
        prepend_seq, prepend_dim = prepend_embeds.shape[1:]
        assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

        x = cat((prepend_embeds, x), dim = -2)

        if exists(prepend_mask) or exists(mask):
            # build missing masks explicitly, so this does not depend on the
            # module-level `default` helper invoking callable fallbacks
            if not exists(mask):
                mask = torch.ones((b, n), device = device, dtype = torch.bool)

            if not exists(prepend_mask):
                prepend_mask = torch.ones((b, prepend_seq), device = device, dtype = torch.bool)

            mask = cat((prepend_mask, mask), dim = -1)

    # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

    if emb_frac_gradient < 1:
        assert emb_frac_gradient > 0
        x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

    # embedding dropout

    x = self.emb_dropout(x)

    x = self.project_emb(x)

    # maybe cls token

    if exists(self.cls_token):
        cls_tokens = repeat(self.cls_token, '... -> b ...', b = b)
        x, cls_packed_shape = pack([cls_tokens, x], 'b * d')

        if exists(mask):
            mask = F.pad(mask, (self.num_cls_tokens, 0), value = True)

    # maybe memory / register tokens

    if has_memory_tokens:
        mem_seq = x.shape[-2]
        mem_every = self.memory_tokens_interspersed_every

        if exists(mem_every):
            assert mem_every > 0
            assert isinstance(self.attn_layers, Decoder), 'only for decoder'
            next_seq_len = math.ceil(n / mem_every) * mem_every

            x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
            x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

        mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
        x, mem_packed_shape = pack((mem, x), 'b * d')

        # auto-handle masking after appending memory tokens
        if not exists(mem_every) and exists(mask):
            mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        if exists(mem_every):
            x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

    # handle maybe shifting of memories

    if self.shift_mem_down and exists(mems):
        mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
        mems = [*mems_r, *mems_l]

    # attention layers

    if not self.recycling:
        assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'

        # regular

        attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

    else:
        # recycling

        recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
        assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'

        for i in range(recycle_steps):
            first_step = i == 0
            last_step = i == (recycle_steps - 1)

            # only the final pass carries gradients
            context = nullcontext if last_step else torch.no_grad

            with context():
                maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.

                attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

    x = attended

    # handle memories post-attention

    if has_memory_tokens:
        if exists(mem_every):
            x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

        mem, x = unpack(x, mem_packed_shape, 'b * d')

        intermediates.memory_tokens = mem

        if exists(mem_every):
            x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

        x = x[:, :mem_seq]

    # global average pool

    if self.average_pool_embed:
        x = masked_mean(x, mask = orig_mask, dim = 1)

    if exists(self.cls_token):
        x, _ = unpack(x, cls_packed_shape, 'b * d')
        x = x.squeeze(1)  # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior

    # handle expansion to mixture if needed (for mixture of softmax)

    combine_mixture = None

    if exists(self.to_mixture):
        combine_mixture = self.combine_mixture(x).softmax(dim = -1)
        x = self.to_mixture(x)

    # projecting to logits - the post-processing steps below only make sense
    # (and only have a `logits` to work on) when logits were actually computed

    if not return_embeddings:
        if self.has_multiple_heads:
            logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
        else:
            logits = self.to_logits(x, **to_logits_kwargs)

        # maybe sig softmax

        if self.sigsoftmax_logits:
            logits = logits + logits.sigmoid().log()

        # handle maybe combine mixture

        if exists(combine_mixture):
            with autocast('cuda', enabled = False):
                prob = logits.softmax(dim = -1)
                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
                # tensor-native clamped log: the module-level name `log` is rebound to
                # `math.log` by `from math import ceil, log`, which cannot take a tensor
                logits = mos.clamp(min = 1e-20).log()

        # maybe squeeze out last dimension of logits

        if self.squeeze_out_last_dim:
            logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))

            if not self.has_multiple_heads:
                logits = first(logits)

    # different returns

    if return_logits_and_embeddings:
        out = (logits, x)
    elif return_embeddings_and_intermediates:
        out = (x, intermediates)
    elif return_embeddings:
        out = x
    else:
        out = logits

    # logit entropies

    if return_logit_entropies:
        intermediates.logit_entropies = calc_entropy(logits)
        return_intermediates = True

    # aux loss

    if return_attn_z_loss:
        pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
        intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
        return_intermediates = True

    if return_mems:
        hiddens = intermediates.hiddens
        new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
        new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

        if not return_intermediates:
            return out, new_mems

        intermediates.mems = new_mems

    if return_intermediates:
        return out, intermediates

    if return_attn:
        attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
        return out, attn_maps

    return out
class XTransformer(Module):
    """
    Encoder-decoder model built from two `TransformerWrapper`s.

    The encoder only produces embeddings (`return_only_embed = True`). The decoder
    cross-attends to them and is wrapped in `AutoregressiveWrapper`. Keyword arguments
    prefixed with `enc_` / `dec_` are routed, with the prefix stripped, to the encoder
    or decoder.
    """

    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()

        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'

        # wrapper-level options are pulled out of the per-side kwargs; whatever
        # remains is handed to the attention layers themselves
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        for option, fallback in (
            ('emb_dropout', 0),
            ('num_memory_tokens', None),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True),
        ):
            enc_transformer_kwargs[option] = enc_kwargs.pop(option, fallback)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        for option, fallback in (
            ('emb_dropout', 0),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True),
        ):
            dec_transformer_kwargs[option] = dec_kwargs.pop(option, fallback)

        # fraction of encoder tokens dropped (training only) before the decoder
        # cross-attends to them - a regulariser used e.g. in Perceiver AR
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            return_only_embed = True,
            attn_layers = Encoder(dim = dim, **enc_kwargs)
        )

        decoder_net = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        )

        if tie_token_emb:
            decoder_net.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(decoder_net, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        """Encode `seq_in`, then decode `seq_len` tokens starting from `seq_out_start`."""
        encoded = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
        return self.decoder.generate(
            seq_out_start,
            seq_len,
            context = encoded,
            context_mask = mask,
            **kwargs
        )

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
        """Encode `src` (optionally with prepended embeddings) and return the decoder output for `tgt`."""
        encoded = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)

        # the encoder output also covers the prepended positions, so widen the mask to match
        if exists(src_prepend_embeds) and exists(mask):
            num_prepended = src_prepend_embeds.shape[-2]
            mask = pad_at_dim(mask, (num_prepended, 0), dim = -1, value = True)

        if self.training and self.cross_attn_tokens_dropout > 0:
            encoded, mask = dropout_seq(encoded, mask, self.cross_attn_tokens_dropout)

        return self.decoder(tgt, context = encoded, context_mask = mask)
#=================================================================================================================================
# autoregressive_wrapper.py
#=================================================================================================================================
from math import ceil, log
from typing import Tuple, Callable
import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from einops import rearrange, pack, unpack
def exists(val):
    """Return True unless `val` is None (falsy values such as 0 or '' still exist)."""
    return False if val is None else True
def default(val, d):
    """
    Return `val` unless it is None, otherwise the fallback `d`.

    A callable fallback is invoked with no arguments and its result returned. This lets
    expensive defaults be deferred, as in the `default(mask, lambda: torch.ones(...))`
    call sites of the TransformerWrapper code above. Because this definition is the one
    bound at module level after the whole file loads, it must honour that contract;
    otherwise those call sites receive the lambda itself instead of a tensor.
    """
    if val is not None:
        return val
    return d() if callable(d) else d
def identity(t, *args, **kwargs):
    """Return `t` unchanged; any extra positional/keyword arguments are ignored."""
    del args, kwargs  # accepted only so this can stand in for a filter fn
    return t
def join(arr, delimiter = ', '):
    """Concatenate the strings in `arr`, separated by `delimiter` (default ', ')."""
    separator = delimiter
    return separator.join(arr)
def cast_tuple(t, length = 1):
    """
    Return `t` unchanged if it is already a tuple; otherwise repeat it `length` times in a new tuple.

    Note that lists are not treated as tuples: `cast_tuple([1])` is `([1],)`.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def eval_decorator(fn):
    """
    Decorate a module method so it runs with the module in eval mode.

    The module's previous `training` flag is restored afterwards, including when `fn`
    raises. `functools.wraps` keeps `fn`'s name and docstring on the wrapper.
    """
    @wraps(fn)
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return fn(self, *args, **kwargs)
        finally:
            # restore the caller's mode even if `fn` raised (e.g. a failed assert during generation)
            self.train(was_training)
    return inner
# for variable lengthed prefixes
def align_right(t, lens, pad_id = 0):
    """
    Right-align variable-length prompts stored left-aligned in a 2D tensor.

    Row i of `t` has its first `lens[i]` entries moved to the end of the row, and the
    freed positions on the left are filled with `pad_id`. The output shape equals `t.shape`.
    """
    batch, seq_len = t.shape
    device = t.device

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    pad_lens = seq_len - lens
    max_pad_len = pad_lens.amax()

    rows = torch.arange(batch, device = device, dtype = torch.long).unsqueeze(-1)
    cols = torch.arange(seq_len, device = device, dtype = torch.long)

    # left-pad every row by the largest needed shift, then gather each row at its own offset
    padded = F.pad(t, (max_pad_len, 0), value = pad_id)
    shift = (max_pad_len - pad_lens).unsqueeze(-1)

    return padded[rows, cols + shift]
# nucleus
def top_p(logits, thres = 0.9):
    """
    Nucleus filtering on (batch, vocab) logits.

    Keeps the highest-probability tokens up to and including the first one whose
    cumulative probability exceeds `thres`; every other logit becomes -inf.
    """
    desc_logits, desc_idx = torch.sort(logits, descending = True)
    cumulative = F.softmax(desc_logits, dim = -1).cumsum(dim = -1)

    # shift the removal mask right by one so the token that crosses `thres` is kept
    drop = F.pad(cumulative > thres, (1, -1), value = False)

    desc_logits[drop] = float('-inf')
    return desc_logits.scatter(1, desc_idx, desc_logits)
# topk
def top_k(logits, frac_num_tokens = 0.1, k = None):
    """
    Keep only the `k` largest logits per row (dim 1); all others become -inf.

    When `k` is None it is `ceil(frac_num_tokens * vocab)`. `k` is capped at the vocab size.
    """
    vocab = logits.shape[-1]
    fallback_k = ceil(frac_num_tokens * vocab)
    k = fallback_k if k is None else k
    k = min(k, vocab)

    kept_vals, kept_idx = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    return filtered.scatter_(1, kept_idx, kept_vals)
# top_a
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    """
    Top-a filtering: mask (to -inf) every token whose probability is below
    `max_prob ** min_p_pow * min_p_ratio`, computed per row.
    """
    probs = logits.softmax(dim = -1)
    threshold = torch.pow(probs.amax(dim = -1, keepdim = True), min_p_pow) * min_p_ratio
    return logits.masked_fill(probs < threshold, float('-inf'))
# min_p
# https://arxiv.org/abs/2407.01082
def min_p(logits, min_p = 0.1):
    """
    Min-p filtering: mask (to -inf) every token whose probability is below
    `min_p` times the row's maximum probability.
    """
    probs = logits.softmax(dim = -1)
    too_unlikely = probs < min_p * probs.amax(dim = -1, keepdim = True)
    return logits.masked_fill(too_unlikely, float('-inf'))
# filter logits functions dict[str -> Callable]
# lets `generate(..., filter_logits_fn = 'top_p')` select a filter by name
FILTER_LOGITS_FN = {
    'top_p': top_p,
    'top_k': top_k,
    'top_a': top_a,
    'min_p': min_p,
}
# contrastive decoding function
def contrastive_decode_fn(
    expert_logits,
    amateur_logits,
    alpha = 0.1,
    beta = 0.5
):
    """
    Contrastive decoding (Appendix A, Algorithm 2 of https://arxiv.org/abs/2309.09117).

    Returns `(1 + beta) * expert - beta * amateur`. Tokens whose expert logit is below
    `max(expert) + log(alpha)` are set to the most negative finite value of the expert dtype.
    """
    plausibility_floor = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
    contrasted = (1 + beta) * expert_logits - beta * amateur_logits

    implausible = expert_logits < plausibility_floor
    return contrasted.masked_fill(implausible, -torch.finfo(expert_logits.dtype).max)
# autoregressive wrapper class
class AutoregressiveWrapper(Module):
def __init__(
    self,
    net,
    ignore_index = -100,
    pad_value = 0,
    mask_prob = 0.,
    add_attn_z_loss = False
):
    """
    Wrap `net` for autoregressive training and sampling.

    Args:
        net: the wrapped network; its `max_seq_len` attribute is copied onto the wrapper.
        ignore_index: label index stored for use by the loss.
        pad_value: token used for padding and for masking output after EOS.
        mask_prob: stored masking probability; must be < 1.
        add_attn_z_loss: whether to add the attention z-loss.
    """
    super().__init__()

    self.net = net
    self.ignore_index = ignore_index
    self.pad_value = pad_value
    self.max_seq_len = net.max_seq_len

    # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
    assert mask_prob < 1.
    self.mask_prob = mask_prob

    # whether to add router z-loss
    self.add_attn_z_loss = add_attn_z_loss
@torch.inference_mode()
@eval_decorator
def generate(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    **kwargs
):
    """
    Sample up to `seq_len` tokens autoregressively after `prompts`.

    Args:
        prompts: prompt token ids; leading dims are packed into one batch dim and restored on return.
        seq_len: maximum number of tokens to sample.
        eos_token: if given, stop once every row contains it and replace everything after
            the first occurrence with `self.pad_value`.
        temperature: softmax temperature; 0 means greedy argmax.
        prompt_lens: per-row prompt lengths for left-aligned variable-length prompts.
        filter_logits_fn: filter callable or a key of `FILTER_LOGITS_FN`.
        restrict_to_max_seq_len: feed at most the last `self.max_seq_len` tokens to the net.
        amateur_model: model(s) for contrastive decoding; this disables `filter_logits_fn`.
        filter_kwargs: kwargs for `filter_logits_fn`.
        contrastive_decode_kwargs: kwargs (one dict per amateur) for `contrastive_decode_fn`.
        cache_kv: reuse the key/value cache between steps when the net supports it.
        return_prime: return the prompt together with the sampled tokens.
        verbose: print progress.
        **kwargs: forwarded to the network on every step.

    Returns:
        The sampled tokens (prefixed by the prompt if `return_prime`).
    """
    max_seq_len, greedy = self.max_seq_len, temperature == 0.

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches

    cache = None

    # if doing contrastive decoding, turn off filter automatically

    if exists(amateur_model):
        # a list, not the tuple `cast_tuple` returns, so wrapped amateurs can be unwrapped in place below
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # sampling up to seq_len

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
        else:
            # without this branch `x` would be unbound when restriction is disabled
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]

        # handle contrastive decoding, Li et al.
        # https://arxiv.org/abs/2210.15097

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # filter by top_k, top_p (nucleus), top_a, or custom

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        # (recomputed here so this also works when the loop never ran, e.g. seq_len == 0)
        is_eos_tokens = (out == eos_token)
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if return_prime:
        out = out[:, :]
    else:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
@torch.inference_mode()
@eval_decorator
def generate_masked(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    masked_token_ids: list[int] | Tensor | None = None,
    **kwargs
):
    """
    Same as `generate`, but the listed token ids can never be sampled.

    Args:
        masked_token_ids: token ids to forbid. Their logits are set to -1e9 after
            contrastive decoding and before filtering/sampling. Duplicates, negative ids
            and ids >= the logits' vocab size are ignored.
        (all other arguments behave as in `generate`)

    Returns:
        The sampled tokens (prefixed by the prompt if `return_prime`).
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # prepare masked token ids tensor (if any)

    if masked_token_ids is not None:
        if not torch.is_tensor(masked_token_ids):
            masked_token_ids = torch.tensor(masked_token_ids, dtype=torch.long, device=device)
        else:
            masked_token_ids = masked_token_ids.to(device=device, dtype=torch.long)

        # keep unique and non-negative; the vocab size is only known once logits exist,
        # so the upper bound is checked inside the sampling loop
        masked_token_ids = torch.unique(masked_token_ids)
        masked_token_ids = masked_token_ids[masked_token_ids >= 0]

    # handle variable lengthed prompts (prefixes)

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches

    cache = None

    # if doing contrastive decoding, turn off filter automatically

    if exists(amateur_model):
        # a list, not the tuple `cast_tuple` returns, so wrapped amateurs can be unwrapped in place below
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # sampling up to seq_len

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
        else:
            # without this branch `x` would be unbound when restriction is disabled
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]

        # handle contrastive decoding, Li et al.
        # https://arxiv.org/abs/2210.15097

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # apply masked token ids (after contrastive decoding, before filtering/sampling)

        if masked_token_ids is not None and masked_token_ids.numel() > 0:
            # safety: only index ids that fall inside the logits' vocab dimension
            vocab_size = logits.shape[-1]
            valid_masked = masked_token_ids[masked_token_ids < vocab_size]

            if valid_masked.numel() > 0:
                # a large finite negative rather than -inf, so fully masked rows do not softmax to NaN
                neg_inf = -1e9
                logits[:, valid_masked] = neg_inf

        # filter by top_k, top_p (nucleus), top_a, or custom

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        # (recomputed here so this also works when the loop never ran, e.g. seq_len == 0)
        is_eos_tokens = (out == eos_token)
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if return_prime:
        out = out[:, :]
    else:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
@torch.inference_mode()
@eval_decorator
def generate_biased(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    logit_bias: dict | Tensor | None = None, # <-- new parameter
    **kwargs
):
    """
    Autoregressive generation with optional additive logit bias.

    logit_bias:
        - dict[token_id -> float] OR
        - torch.Tensor of shape (vocab,) OR (batch, vocab)

    The bias is added after contrastive decoding and before filtering/sampling. For a
    dict, a dense float32 bias vector is built. The vocab size comes from common model
    attributes when available; otherwise it is taken from the first logits. Out-of-range
    ids raise IndexError. All other arguments behave as in `generate`.
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # handle filter logits fn given as string

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches

    cache = None

    # if doing contrastive decoding, turn off filter automatically

    if exists(amateur_model):
        # a list, not the tuple `cast_tuple` returns, so wrapped amateurs can be unwrapped in place below
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # -------------------------
    # Prepare logit_bias (robust vocab-size detection)
    # -------------------------

    prepared_bias = None
    lazy_build_bias_from_dict = None

    if exists(logit_bias):
        if isinstance(logit_bias, dict):
            # try to determine vocab size from model without using logits
            vocab_size = None

            # common places to find vocab size (best-effort; any failure falls back to the lazy path)
            try:
                if hasattr(self.net, "config") and getattr(self.net.config, "vocab_size", None) is not None:
                    vocab_size = int(self.net.config.vocab_size)
                elif getattr(self.net, "vocab_size", None) is not None:
                    vocab_size = int(self.net.vocab_size)
                else:
                    # try to infer from embedding / output projection weights
                    # huggingface style: get_output_embeddings() or embed_tokens or lm_head
                    get_out = getattr(self.net, "get_output_embeddings", None)
                    if callable(get_out) and get_out() is not None:
                        vocab_size = int(get_out().weight.shape[0])
                    elif hasattr(self.net, "embed_tokens"):
                        vocab_size = int(self.net.embed_tokens.weight.shape[0])
                    elif hasattr(self.net, "lm_head"):
                        vocab_size = int(self.net.lm_head.weight.shape[0])
            except Exception:
                vocab_size = None

            if vocab_size is not None:
                bias_vec = torch.zeros(int(vocab_size), device=device, dtype=torch.float32)
                for tok, val in logit_bias.items():
                    tok_i = int(tok)
                    if tok_i < 0 or tok_i >= vocab_size:
                        raise IndexError(f"logit_bias token id {tok_i} out of range for vocab size {vocab_size}")
                    bias_vec[tok_i] = float(val)
                prepared_bias = bias_vec
            else:
                # can't determine vocab size yet — build lazily after first logits are available
                lazy_build_bias_from_dict = {int(k): float(v) for k, v in logit_bias.items()}

        elif isinstance(logit_bias, torch.Tensor):
            prepared_bias = logit_bias.to(device=device, dtype=torch.float32)
        else:
            raise TypeError("logit_bias must be dict or torch.Tensor")

    # sampling up to seq_len

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
                'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
        else:
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # shape (batch, vocab)

        # If we couldn't build the bias earlier because vocab size was unknown,
        # build it now from the first logits tensor.
        if lazy_build_bias_from_dict is not None:
            vocab_size = logits.shape[-1]
            bias_vec = torch.zeros(vocab_size, device=device, dtype=torch.float32)
            for tok, val in lazy_build_bias_from_dict.items():
                if tok < 0 or tok >= vocab_size:
                    raise IndexError(f"logit_bias token id {tok} out of range for vocab size {vocab_size}")
                bias_vec[tok] = val
            prepared_bias = bias_vec
            lazy_build_bias_from_dict = None  # only build once

        # handle contrastive decoding, Li et al.
        # https://arxiv.org/abs/2210.15097

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # -------------------------
        # Apply logit bias if provided
        # -------------------------

        if exists(prepared_bias):
            # prepared_bias can be (vocab,) or (batch, vocab)
            if prepared_bias.dim() == 1:
                # broadcast to batch
                logits = logits + prepared_bias.unsqueeze(0)
            elif prepared_bias.dim() == 2:
                # expect shape (batch, vocab)
                if prepared_bias.shape[0] != logits.shape[0]:
                    raise ValueError("logit_bias tensor batch size must match logits batch size")
                logits = logits + prepared_bias
            else:
                raise ValueError("logit_bias tensor must be 1D (vocab,) or 2D (batch, vocab)")

        # filter by top_k, top_p (nucleus), top_a, or custom

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # mask out everything after the eos tokens
        # (recomputed here so this also works when the loop never ran, e.g. seq_len == 0)
        is_eos_tokens = (out == eos_token)
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if return_prime:
        out = out[:, :]
    else:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
@torch.inference_mode()
@eval_decorator
def generate_advanced(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new generation options ---
    logits_bias: dict | None = None,          # {token_id: bias_value}, bias_value is a number or Tensor
    masked_tokens: list | Tensor | None = None,  # token ids to forbid at every step
    # --- binary classifier mode ---
    binary_classifier: bool = False,          # if True, classify `batches` and return (preds, probs)
    classifier_model: Module | None = None,   # model used in classifier mode
    batches: list | None = None,              # iterable of input batches for classifier_model
    threshold: float = 0.5,                   # probability threshold for preds
    classifier_device: torch.device | None = None,
    # -----------------
    **kwargs
):
    """
    Sample up to ``seq_len`` tokens after ``prompts`` with optional logit biasing and
    token masking; or, when ``binary_classifier=True``, run a classifier instead.

    Parameters
    ----------
    prompts : Tensor
        Prompt token ids, packed to (batch, prompt_len) with einops ``pack``.
    seq_len : int
        Maximum number of tokens to sample.
    eos_token : int, optional
        Sampling stops once every row contains this token; positions after the
        first EOS of each row are replaced with ``self.pad_value``.
    temperature : float
        Softmax temperature; ``0.`` selects greedy (argmax) decoding.
    prompt_lens : Tensor, optional
        Per-row prompt lengths; prompts are right-aligned with ``align_right``.
    filter_logits_fn : str or callable
        Logit filter (a ``FILTER_LOGITS_FN`` key or a callable). Replaced by
        ``identity`` when contrastive decoding is active.
    restrict_to_max_seq_len : bool
        Feed only the last ``self.max_seq_len`` tokens and trim the kv cache to
        match. When False the whole sequence is fed to the network.
    amateur_model, contrastive_decode_kwargs
        One or more amateur models (with matching kwargs) for contrastive decoding.
    filter_kwargs : dict
        Extra keyword arguments for ``filter_logits_fn``.
    cache_kv : bool
        Reuse key/value caches between steps when the network supports it.
    return_prime : bool
        Return prompt + generated tokens instead of only the generated tokens.
    verbose : bool
        Print progress every 32 steps.
    logits_bias : dict, optional
        ``{token_id: bias}`` where bias is a number, a one-element tensor or a
        (batch,) tensor. It is added to that token's logit at every step.
    masked_tokens : list or Tensor, optional
        Token ids whose logits are set to -1e9 at every step. Out-of-range ids
        are ignored.
    binary_classifier, classifier_model, batches, threshold, classifier_device
        Classifier mode: for each batch, sigmoid ``classifier_model(batch)`` and
        threshold it.

    Returns
    -------
    Tensor
        Sampled tokens, unpacked to the leading shape of ``prompts``.
    tuple of lists
        In classifier mode, ``(preds, probs)`` as Python lists.
    """
    # ---- classifier mode: no generation at all ----
    if binary_classifier:
        assert classifier_model is not None, "classifier_model must be provided when binary_classifier=True"
        assert batches is not None, "batches (iterable of input tensors) must be provided when binary_classifier=True"
        device = classifier_device if classifier_device is not None else (prompts.device if exists(prompts) else torch.device('cpu'))
        all_probs = []
        all_preds = []
        classifier_model.eval()
        with torch.no_grad():
            for x in batches:
                x = x.to(device)
                # atleast_1d: squeeze() turns a batch of one into a 0-d tensor whose
                # tolist() is a bare float that list.extend() cannot consume
                logits = torch.atleast_1d(classifier_model(x).squeeze())
                probs = torch.sigmoid(logits)
                preds = (probs >= threshold).long()
                all_probs.extend(probs.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())
        return all_preds, all_probs

    # ---- normal generation path ----
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')
    b, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        # a list (not the tuple cast_tuple returns) so wrappers can be unwrapped in place
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
        assert len(amateur_model) == len(contrastive_decode_kwargs)
        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity
        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net
            module.eval()

    # ---- logit bias: token ids + a (batch, num_biased_tokens) value matrix, built once ----
    bias_ids = None
    bias_values = None
    if exists(logits_bias):
        assert isinstance(logits_bias, dict), "logits_bias must be a dict {token_id: bias_value}"
        if len(logits_bias) > 0:
            columns = []
            for bias_val in logits_bias.values():
                if isinstance(bias_val, torch.Tensor):
                    bias_val = bias_val.to(device = device, dtype = torch.float32)
                    if not (bias_val.dim() == 1 and bias_val.shape[0] == b):
                        # a single value shared by every row
                        bias_val = bias_val.reshape(1).expand(b)
                else:
                    bias_val = torch.full((b,), float(bias_val), device = device)
                columns.append(bias_val)
            bias_ids = torch.tensor([int(tok_id) for tok_id in logits_bias.keys()], device = device, dtype = torch.long)
            bias_values = torch.stack(columns, dim = -1)

    # ---- masked tokens: id tensor built once, range-checked against the vocab each step ----
    masked_idx = None
    if exists(masked_tokens):
        masked_idx = torch.as_tensor(masked_tokens, dtype = torch.long, device = device).reshape(-1)
        if masked_idx.numel() == 0:
            masked_idx = None

    is_eos_tokens = None

    # sampling up to seq_len
    for sl in range(seq_len):

        # default: feed the whole sequence when no length restriction applies
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # (batch, vocab)

        # handle contrastive decoding, Li et al.
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # ---- bias and masking, before filtering / softmax ----
        if exists(bias_ids):
            # index_add accumulates repeated ids, like adding biases one by one
            logits = logits.index_add(-1, bias_ids, bias_values.to(logits.dtype))

        if exists(masked_idx):
            valid_idx = masked_idx[(masked_idx >= 0) & (masked_idx < logits.shape[-1])]
            if valid_idx.numel() > 0:
                logits = logits.index_fill(-1, valid_idx, -1e9)

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # recomputed from the final `out` so seq_len == 0 also works
        is_eos_tokens = (out == eos_token)
        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
def compute_accuracy(self, logits, labels):
    """
    Token-level accuracy of argmax predictions, ignoring ``self.ignore_index`` labels.

    ``logits`` is reduced with argmax over its last axis. Predictions and labels
    are flattened, positions whose label equals ``self.ignore_index`` are dropped,
    and the fraction of matches is returned as a float32 tensor.
    """
    predicted = logits.argmax(dim=-1).reshape(-1)
    targets = labels.reshape(-1)

    # positions that actually count towards the score
    keep = targets != self.ignore_index  # self.pad_value would also be a valid choice
    kept_targets = targets[keep]

    hits = (predicted[keep] == kept_targets).sum().to(torch.float32)
    return hits / kept_targets.numel()
def forward(self, x, return_outputs = False, **kwargs):
    """
    Teacher-forced training step: predict ``x[:, 1:]`` from ``x[:, :-1]``.

    Optionally hides a random subset of input positions (never the first) from
    self-attention keys/values when ``self.mask_prob > 0``. It also optionally adds
    the network's attention z-loss.

    Returns ``(loss, acc)``, or ``(loss, acc, logits, cache)`` when
    ``return_outputs`` is True.
    """
    ignore_index = self.ignore_index
    add_attn_z_loss = self.add_attn_z_loss
    seq = x.shape[1]

    inp = x[:, :-1]
    target = x[:, 1:]
    # replace ignore_index positions in the input with pad tokens before the network sees them
    inp = torch.where(inp == ignore_index, self.pad_value, inp)

    if self.mask_prob > 0.:
        scores = torch.randn(inp.shape, device = x.device)
        # the lowest possible score guarantees position 0 is never chosen for masking
        scores[:, 0] = -torch.finfo(scores.dtype).max
        num_mask = min(int(seq * self.mask_prob), seq - 1)
        hidden_positions = scores.topk(num_mask, dim = -1).indices
        keep = ~torch.zeros_like(inp).scatter(1, hidden_positions, 1.).bool()
        kwargs.update(self_attn_kv_mask = keep)

    logits, cache = self.net(
        inp,
        return_intermediates = True,
        return_attn_z_loss = add_attn_z_loss,
        **kwargs
    )

    acc = self.compute_accuracy(logits, target)

    # networks that already emit log-probs need nll_loss instead of cross_entropy
    loss_fn = F.nll_loss if self.net.output_is_log_prob else F.cross_entropy
    loss = loss_fn(
        rearrange(logits, 'b n c -> b c n'),
        target,
        ignore_index = ignore_index
    )

    if add_attn_z_loss:
        loss = loss + cache.attn_z_loss

    if return_outputs:
        return loss, acc, logits, cache
    return loss, acc
@torch.inference_mode()
@eval_decorator
def generate_expert(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    # --- new controls ---
    token_type_ids: torch.LongTensor | None = None,    # [vocab]
    type_temperatures: dict | None = None,             # {type_id: temp}
    type_biases: dict | None = None,                   # {type_id: bias}
    repetition_window: int = 64,
    repetition_penalty_per_type: dict | None = None,   # {type_id: penalty_scale}
    rare_types: set | None = None,                     # e.g. {4, 5}
    rare_type_boost: float = 0.0,                      # small, e.g. 0.5
    entropy_threshold: float = 2.0,                    # when below, boost rare types
    # --- masked tokens option ---
    forbidden_token_ids: torch.LongTensor | torch.BoolTensor | None = None,
    forbidden_value: float = -1e9,
    **kwargs
):
    """
    Autoregressive sampling with optional type-aware logit shaping and token bans.

    The decoding loop supports the usual options: ``eos_token``, ``temperature``
    (``0.`` = greedy), ``prompt_lens``, ``filter_logits_fn`` / ``filter_kwargs``,
    ``restrict_to_max_seq_len``, contrastive decoding via ``amateur_model``,
    ``cache_kv``, ``return_prime`` and ``verbose``. After the logits of each step are
    computed, the following are applied in order, before filtering and the global
    temperature:

    1. ``type_biases``: ``{type_id: bias}`` is added to the logits of every token of
       that type.
    2. ``repetition_penalty_per_type``: ``{type_id: penalty_scale}``. For each row,
       ``freq`` is the share of the last ``repetition_window`` tokens that have this
       type. The logits of that type are scaled by ``s = 1 + freq * (penalty_scale - 1)``
       in a sign-aware way: positive logits are divided by ``s`` and negative logits
       are multiplied by it. So ``penalty_scale > 1`` always lowers them and
       ``< 1`` raises them.
    3. ``rare_types`` / ``rare_type_boost`` / ``entropy_threshold``: rows whose
       softmax entropy (nats) is below the threshold get ``+rare_type_boost`` on
       rare-type logits.
    4. ``type_temperatures``: ``{type_id: temp}`` divides the logits of that type.
    5. ``forbidden_token_ids``: either a LongTensor of ids (out-of-range ids are
       ignored) or a BoolTensor mask of shape [vocab] or [b, vocab]. Forbidden logits
       are set to ``forbidden_value``.

    Steps 1-4 require ``token_type_ids``, a LongTensor holding the type id of each
    vocabulary token.

    Returns the sampled tokens (prompt included when ``return_prime``), unpacked to
    the leading shape of ``prompts``. Positions after the first EOS are set to
    ``self.pad_value``.
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')
    b, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        # a list (not the tuple cast_tuple returns) so wrappers can be unwrapped in place
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
        assert len(amateur_model) == len(contrastive_decode_kwargs)
        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity
        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net
            module.eval()

    # ---- type-aware controls, precomputed once ----
    per_token_temp = None
    per_token_bias = None
    rare_type_mask = None
    rep_penalties = []  # (type_id, penalty_scale, [vocab] bool mask of that type)

    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

        if type_temperatures:
            per_token_temp = torch.ones_like(token_type_ids, dtype=torch.float32)
            for type_id, temp_val in type_temperatures.items():
                per_token_temp[token_type_ids == type_id] = float(temp_val)

        if type_biases:
            per_token_bias = torch.zeros_like(token_type_ids, dtype=torch.float32)
            for type_id, bias_val in type_biases.items():
                per_token_bias[token_type_ids == type_id] = float(bias_val)

        for type_id, penalty_scale in (repetition_penalty_per_type or {}).items():
            rep_penalties.append((type_id, penalty_scale, token_type_ids == type_id))

        if rare_types:
            rare_type_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
            for rt in rare_types:
                rare_type_mask |= (token_type_ids == rt)

    # ---- forbidden tokens ----
    # Bool masks are expanded to [b, vocab] here. Id lists are turned into a mask on
    # the first step, when the vocab size is known from the logits.
    forbidden_ids = None
    forbidden_mask_per_batch = None
    if forbidden_token_ids is not None:
        if forbidden_token_ids.dtype in (torch.int64, torch.int32):
            forbidden_ids = forbidden_token_ids.to(device = device, dtype = torch.long).reshape(-1)
        elif forbidden_token_ids.dtype == torch.bool:
            # could be [vocab] or [b, vocab]
            if forbidden_token_ids.dim() == 1:
                forbidden_mask_per_batch = forbidden_token_ids.to(device).unsqueeze(0).expand(b, -1)
            elif forbidden_token_ids.dim() == 2:
                assert forbidden_token_ids.shape[0] == b, "forbidden_token_ids batch dimension must match prompts batch size"
                forbidden_mask_per_batch = forbidden_token_ids.to(device)
            else:
                raise ValueError("forbidden_token_ids boolean mask must be 1D [vocab] or 2D [b, vocab]")
        else:
            raise TypeError("forbidden_token_ids must be LongTensor of ids or BoolTensor mask")

    # sampling up to seq_len
    for sl in range(seq_len):

        # default: feed the whole sequence when no length restriction applies
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
                'the network cannot use cached key values when decoding outside the max sequence length. ' \
                'most likely because you are using absolute positional embedding. ' \
                'you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]  # [b, vocab]

        # handle contrastive decoding
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(
                zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
            ):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, \
                    'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # --------- STRUCTURED LOGIT SHAPING (no training) ---------
        if token_type_ids is not None:

            # 1) per-token bias (type-aware), broadcast over [vocab]
            if per_token_bias is not None:
                logits = logits + per_token_bias

            # 2) repetition penalty per type, vectorized over the batch
            if repetition_window > 0 and len(rep_penalties) > 0 and out.shape[-1] > 0:
                recent_types = token_type_ids[out[:, -repetition_window:]]  # [b, w]
                for type_id, penalty_scale, type_mask in rep_penalties:
                    freq = (recent_types == type_id).float().mean(dim = -1, keepdim = True)  # [b, 1], in 0..1
                    scale = 1.0 + freq * (penalty_scale - 1.0)  # == 1 (no-op) when the type is absent
                    # sign-aware: dividing a negative logit would *raise* it
                    penalized = torch.where(logits > 0, logits / scale, logits * scale)
                    logits = torch.where(type_mask, penalized, logits)

            # 3) entropy-based rare-type boost (before global temperature)
            if rare_type_mask is not None and rare_type_boost > 0.0:
                probs_raw = F.softmax(logits, dim=-1)
                log_probs_raw = torch.log(probs_raw + 1e-9)
                entropy = -(probs_raw * log_probs_raw).sum(dim=-1)  # [b]

                low_entropy = entropy < entropy_threshold
                if low_entropy.any():
                    boost_vec = torch.zeros_like(logits)
                    boost_vec[:, rare_type_mask] = rare_type_boost
                    logits = torch.where(
                        low_entropy.unsqueeze(-1),
                        logits + boost_vec,
                        logits
                    )

            # 4) per-token temperature (smaller temp -> sharper for that type)
            if per_token_temp is not None:
                logits = logits / per_token_temp

        # --------- APPLY FORBIDDEN TOKEN MASK ---------
        if forbidden_ids is not None:
            vocab_size = logits.shape[-1]
            # drop (rather than clamp) out-of-range ids so no unrelated token is banned
            in_range = forbidden_ids[(forbidden_ids >= 0) & (forbidden_ids < vocab_size)]
            vocab_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
            vocab_mask[in_range] = True
            forbidden_mask_per_batch = vocab_mask.unsqueeze(0).expand(b, -1)
            forbidden_ids = None  # built once

        if forbidden_mask_per_batch is not None:
            assert forbidden_mask_per_batch.shape[0] == b and forbidden_mask_per_batch.shape[1] == logits.shape[-1], \
                "forbidden mask shape must be [b, vocab]"
            logits = logits.masked_fill(forbidden_mask_per_batch, float(forbidden_value))

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # recomputed from the final `out` so seq_len == 0 also works
        is_eos_tokens = (out == eos_token)
        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
#=================================================================================================================================
# Binary classifier functions
# https://github.com/lucidrains/x-transformers/pull/264
#=================================================================================================================================
class ClsInferenceDataset(Dataset):
    """
    Map-style dataset of token-id sequences for classifier inference.

    Each element of ``data_pairs`` is a single sequence: a list of ints or a 1D
    tensor. There are no labels. ``__getitem__`` returns the sequence as a new 1D
    ``torch.long`` tensor.
    """

    def __init__(self, data_pairs):
        # attribute name kept for compatibility; it holds sequences, not (seq, label) pairs
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        src_seq = self.data_pairs[idx]
        if isinstance(src_seq, torch.Tensor):
            # explicit copy: torch.tensor(<tensor>) works but emits a UserWarning
            return src_seq.detach().clone().to(torch.long)
        return torch.tensor(src_seq, dtype=torch.long)
def build_cls_model(num_tokens=18819,
                    max_seq_len=1024,
                    logits_dim=1,
                    use_cls_token=True,
                    squeeze_out_last_dim=True,
                    dim=1024,
                    depth=8,
                    heads=8,
                    attn_flash=True,
                    rotary_pos_emb=False,
                    device='cuda'
                    ):
    """
    Build an ``Encoder``-backed ``TransformerWrapper`` classifier and move it to ``device``.

    The encoder is configured by ``dim``/``depth``/``heads``/``attn_flash``/
    ``rotary_pos_emb``. The wrapper is configured by ``num_tokens``,
    ``max_seq_len``, ``logits_dim``, ``use_cls_token`` and ``squeeze_out_last_dim``.
    With the defaults, the model is meant to output a single logit per input
    sequence.
    """
    encoder = Encoder(
        dim=dim,
        depth=depth,
        heads=heads,
        attn_flash=attn_flash,
        rotary_pos_emb=rotary_pos_emb,
    )

    classifier = TransformerWrapper(
        num_tokens=num_tokens,
        max_seq_len=max_seq_len,
        logits_dim=logits_dim,
        use_cls_token=use_cls_token,
        squeeze_out_last_dim=squeeze_out_last_dim,
        attn_layers=encoder,
    )

    return classifier.to(device)
def load_cls_model(checkpoint_path, device='cuda', **model_kwargs):
    """
    Rebuild the classifier with ``build_cls_model`` and load a state dict into it.

    Parameters
    ----------
    checkpoint_path : str or path-like
        File containing the model's state dict.
    device : str or torch.device
        Target device; also used as ``map_location`` when loading.
    **model_kwargs
        Forwarded to ``build_cls_model``, so checkpoints trained with non-default
        hyper-parameters (``num_tokens``, ``dim``, ``depth``, ...) can be rebuilt.
        With no extra kwargs the default architecture is used, as before.

    Returns
    -------
    The model on ``device`` in eval mode.
    """
    model = build_cls_model(device=device, **model_kwargs)
    # weights_only=True restricts unpickling to tensors and plain containers, which
    # is all a state dict needs; it avoids running arbitrary pickled code
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device).eval()
    return model
def cls_predict(model,
                seqs,
                batch_size=8,
                threshold=0.5,
                seq_len=1024,
                pad_token=18818,
                device='cuda'
                ):
    """
    Run a binary sequence classifier over ``seqs`` in batches.

    Each sequence is truncated to ``seq_len`` tokens and right-padded with
    ``pad_token`` to the longest sequence in its batch. No attention mask is
    passed, so the model sees the pad tokens. Model outputs are squeezed,
    passed through a sigmoid and compared with ``threshold``.

    Returns
    -------
    (preds, probs) : tuple of two lists, in the order of ``seqs``
        preds - int 0/1 predictions (``prob >= threshold``)
        probs - float sigmoid probabilities
    """
    def collate_fn(batch):
        # batch: list of 1D long tensors produced by ClsInferenceDataset
        tensors = [s[:seq_len].detach().clone() for s in batch]
        max_len = min(seq_len, max(t.size(0) for t in tensors))
        padded = torch.full((len(tensors), max_len), pad_token, dtype=torch.long)
        for i, t in enumerate(tensors):
            padded[i, :t.size(0)] = t
        return padded

    loader = DataLoader(ClsInferenceDataset(seqs), batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_probs = []
    all_preds = []

    model.to(device)
    model.eval()

    with torch.inference_mode():
        for x in loader:
            x = x.to(device)  # [B, L] (truncated & padded)
            # atleast_1d: a batch of one squeezes to a 0-d tensor; keep it a 1-element vector
            logits = torch.atleast_1d(model(x).squeeze())
            probs = torch.sigmoid(logits)
            preds = (probs >= threshold).long()
            all_probs.extend(probs.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    return all_preds, all_probs
#=================================================================================================================================
# Sequences probabilities and scores functions
#=================================================================================================================================
import inspect
import math
from typing import Callable, Optional, Dict, Any, List, Tuple
import torch
import torch.nn.functional as F
def print_probs_scoring_guide():
    """Print the scoring guide text returned by ``probs_scoring_guide()``."""
    guide_text = probs_scoring_guide()
    print(guide_text)
def probs_scoring_guide():
    """
    Return this guide (a string) describing the result dictionary and metrics
    produced by generate_with_probs / score_sequences.

    Result dictionary
    -----------------
    result : dict
        A dictionary containing token-level and sequence-level scoring information.

    Keys
    ----
    tokens : torch.Tensor
        Tensor of token ids for each batch entry. Shape (batch, seq_len).
        - Meaning: Generated tokens (for generate_with_probs) or the original
          input sequences (for score_sequences).
        - Interpretation: Map ids to text with your tokenizer to inspect outputs.

    token_probs : List[List[float]]
        Per-batch lists of probabilities assigned to each chosen token at the time
        it was produced. Values in [0, 1].
        - Meaning: Softmax probability for the selected token at each step.
        - Interpretation: Higher → model more confident about that token. Do not
          multiply many token_probs directly (underflow risk); use log-probs.

    token_logprobs : List[List[float]]
        Per-batch lists of natural log probabilities (nats) for each chosen token:
        log p(x_t | x_<t).
        - Meaning: Numerically stable per-token log-probabilities.
        - Interpretation: Less negative = more likely. Sum these to get sequence_logprobs.

    token_scores : List[List[float]]
        Per-batch lists of token negative log-probabilities (NLL) computed as -log p.
        - Meaning: Token-level loss (positive).
        - Interpretation: Lower = model found token less surprising. Useful to spot spikes.

    sequence_logprobs : List[float]
        Sum of token log-probabilities for each sequence (nats): sum_t log p(x_t | x_<t).
        - Meaning: Canonical sequence score; additive and numerically stable.
        - Interpretation: Use this to compare sequences. Higher (less negative) is better.

    nll : List[float]
        Negative sequence log-probabilities (nats): -sequence_logprobs.
        - Meaning: Sequence-level negative log-likelihood (loss).
        - Interpretation: Lower NLL indicates a sequence the model finds more probable.

    sequence_probs : List[float]
        Numeric probabilities computed as exp(sequence_logprobs) (float64 when possible).
        - Meaning: Absolute probability of the full sequence.
        - Interpretation: Often underflows to 0.0 for realistic lengths; prefer sequence_logprobs.

    sequence_prob_display : List[str]
        Human-readable string for sequence probability. If numeric underflow occurs,
        this shows an approximate scientific form (e.g., "~10^-550.65").
        - Meaning: Readable magnitude of the sequence probability.
        - Interpretation: Use this for reporting instead of raw sequence_probs when it is 0.0.

    mask : torch.Tensor
        Boolean tensor indicating which positions were included in scoring.
        Shape (batch, scored_len). False for padded positions or tokens after the first EOS.
        - Meaning: Aligns token-level lists with original sequence positions.
        - Interpretation: Use to ignore padded or post-EOS tokens in aggregates.

    metadata : dict
        Miscellaneous run information such as:
        - prompt_len : int or list[int] — length of prompt tokens (if applicable)
        - generated_len : int — number of generated tokens (generate_with_probs)
        - temperature : float — sampling temperature used
        - seq_len : int — original sequence length (score_sequences)
        - Interpretation: Useful for reproducing runs and normalizing comparisons.

    metrics : dict
        Per-sequence derived diagnostics (under result["metrics"]["per_sequence"]).
        Each entry contains:
        - sequence_index : int
        - token_count : int
        - sequence_logprob_nats : float
            Sum of log-probs (nats). Primary canonical score.
        - sequence_log10 : float
            Log10 of the sequence probability (for display).
        - sequence_prob_display : str
            Human-friendly scientific display of the sequence probability.
        - avg_logprob_per_token_nats : float
            Average log-prob per token (nats): (1/T) * sum_t log p.
            - Interpretation: Normalizes for length; higher (less negative) is better.
        - avg_logprob_per_token_bits : float
            Average log-prob per token in bits (divide nats by ln(2)).
            - Interpretation: Intuitive unit; higher (less negative) = easier prediction.
        - geometric_mean_token_prob : str
            Geometric mean of token probabilities (display).
            - Interpretation: Typical per-token probability; quick sense of per-token confidence.
        - perplexity : float
            exp(-avg_logprob_per_token). Standard LM metric; lower is better.

    Notes
    -----
    - Use `sequence_logprobs` (or `nll`) as the authoritative score for comparisons and ranking.
    - Avoid relying on `sequence_probs` for comparisons because of floating-point underflow.
    - Prefer `avg_logprob_per_token_nats` or `perplexity` when comparing sequences of different lengths.
    - Token-level spikes in `token_scores` (large -log p) indicate surprising tokens and are useful
      for debugging prompts or model behavior.

    Examples
    --------
    # Example usage after calling generate_with_probs or score_sequences:
    res = generate_with_probs(...)
    print("Sequence logprob (nats):", res["sequence_logprobs"][0])
    print("Sequence prob (display):", res["sequence_prob_display"][0])
    print("Avg logprob/token (nats):", res["metrics"]["per_sequence"][0]["avg_logprob_per_token_nats"])
    print("Perplexity:", res["metrics"]["per_sequence"][0]["perplexity"])
    """
    # the docstring itself is the guide's content
    return inspect.getdoc(probs_scoring_guide)
# --- helpers ---
def _safe_exp64(logp: float) -> Tuple[float, str]:
lp64 = torch.tensor(logp, dtype=torch.float64)
try:
p64 = float(torch.exp(lp64).item())
except Exception:
p64 = 0.0
if p64 == 0.0:
log10_prob = float(lp64.item() / math.log(10.0))
display = f"~10^{log10_prob:.2f}"
else:
display = f"{p64:.6e}"
return p64, display
def _attach_metrics_to_result(result: Dict[str, Any]) -> Dict[str, Any]:
seq_logprobs: List[float] = result.get("sequence_logprobs", [])
token_logprobs: List[List[float]] = result.get("token_logprobs", [])
token_probs: List[List[float]] = result.get("token_probs", [])
metrics = {"per_sequence": []}
for i, seq_lp in enumerate(seq_logprobs):
toks_lp = token_logprobs[i] if i < len(token_logprobs) else []
token_count = len(toks_lp)
avg_lp = float(sum(toks_lp) / token_count) if token_count > 0 else 0.0
avg_lp_bits = avg_lp / math.log(2.0)
try:
perplexity = math.exp(-avg_lp)
except OverflowError:
perplexity = float("inf")
log10_prob = seq_lp / math.log(10.0)
seq_prob_display = result.get("sequence_prob_display", [None]*len(seq_logprobs))[i]
if seq_prob_display is None:
seq_prob_display = f"~10^{log10_prob:.2f}"
if token_count > 0:
try:
geom_mean = math.exp(avg_lp)
geom_mean_display = f"{geom_mean:.6e}"
except OverflowError:
geom_mean_display = f"exp({avg_lp:.3f})"
else:
geom_mean_display = "n/a"
metrics["per_sequence"].append({
"sequence_index": i,
"token_count": token_count,
"sequence_logprob_nats": float(seq_lp),
"sequence_log10": float(log10_prob),
"sequence_prob_display": seq_prob_display,
"avg_logprob_per_token_nats": float(avg_lp),
"avg_logprob_per_token_bits": float(avg_lp_bits),
"geometric_mean_token_prob": geom_mean_display,
"perplexity": float(perplexity)
})
result["metrics"] = metrics
return result
def _decode_token(tokenizer, tok_id: int) -> str:
if tokenizer is None:
return str(tok_id)
try:
if hasattr(tokenizer, "decode"):
return tokenizer.decode([tok_id])
if hasattr(tokenizer, "convert_ids_to_tokens"):
return tokenizer.convert_ids_to_tokens([tok_id])[0]
except Exception:
pass
return str(tok_id)
# ---------------------------
# generate_with_probs (with diff)
# ---------------------------
@torch.inference_mode()
def generate_with_probs(
    model,
    prompts: torch.Tensor,
    seq_len: int,
    eos_token: Optional[int] = None,
    temperature: float = 1.0,
    prompt_lens: Optional[torch.Tensor] = None,
    filter_logits_fn: Optional[Callable] = None,
    filter_kwargs: Optional[Dict[str, Any]] = None,
    pad_value: Optional[int] = None,
    tokenizer = None,
    print_table: bool = False,
    device: Optional[torch.device] = None,
    verbose: bool = True,
    include_top1: bool = True,
    **kwargs
) -> Dict[str, Any]:
    """
    Generate sequences from an autoregressive model while collecting per-token probabilities,
    log-probabilities, scores and an optional "diff" view comparing sampled tokens to the
    model's top-1 (greedy) tokens.

    Tokens are sampled one at a time and appended to a copy of `prompts` for up to `seq_len`
    steps. The model is put in eval mode for the call and its previous train/eval mode is
    restored before returning (also on error).

    Key behaviors
    - Runs under `torch.inference_mode()` (no gradients).
    - If `prompt_lens` is provided, each row keeps its last `prompt_lens[i]` tokens and every
      position in front of them is overwritten with `pad_value` (which is then required).
    - If `filter_logits_fn` is provided it is applied to the raw logits before temperature
      scaling and softmax.
    - If `temperature == 0.0` decoding is greedy (argmax of the raw logits); reported
      probabilities then come from the untempered (filtered) distribution. Temperature is
      expected to be >= 0.
    - If `include_top1` is True, the top-1 token (argmax of the raw logits) and its
      probability under the same distribution used for sampling are recorded at each step,
      together with whether the sampled token matched it.
    - If `eos_token` is provided, generation stops early once EVERY batch item has sampled an
      EOS. Per-token statistics are kept up to and including each row's first generated EOS
      (the EOS probability is part of the sequence likelihood); generated tokens after that
      EOS are replaced with `pad_value` when one is available. The EOS token itself is kept.
    - Log-probabilities are in natural log (nats); sequence sums use `math.fsum` on float64.

    Parameters
    - model: Object exposing `net` with signature
      `logits = model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`,
      where `logits` is a (batch, seq, vocab) tensor or a tuple/list whose first element is
      that tensor. Optional attributes `pad_value` and `max_seq_len` are honoured.
    - prompts (torch.Tensor): Integer tokens of shape (batch, prompt_len). Not modified.
    - seq_len (int): Maximum number of tokens to generate per example (excluding the prompt).
    - eos_token (Optional[int]): End-of-sequence token id.
    - temperature (float): Sampling temperature; `0.0` forces greedy decoding.
    - prompt_lens (Optional[torch.Tensor]): Per-row prompt lengths (tensor or ints).
    - filter_logits_fn (Optional[Callable]): `(logits, **filter_kwargs) -> logits`.
    - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments for `filter_logits_fn`.
    - pad_value (Optional[int]): Padding token id; defaults to `model.pad_value` if present.
    - tokenizer: Optional tokenizer used to render tokens in diffs and printed tables.
    - print_table (bool): If True, print a per-token table for every batch item.
    - device (Optional[torch.device]): Device to run on. Defaults to `prompts.device`.
    - verbose (bool): If True, print progress messages.
    - include_top1 (bool): If True, record top-1 statistics and build the `diff` view.
    - **kwargs: Forwarded to `model.net`.

    Returns
    A dictionary with:
    - "tokens" (torch.Tensor): Generated tokens (batch, generated_len) on CPU.
    - "token_probs" / "token_logprobs" / "token_scores" (List[List[float]]): Per-token
      probability, log-probability (nats) and score (= -logprob) of each sampled token.
    - "sequence_logprobs" (List[float]): Sum of token log-probabilities per sequence.
    - "sequence_probs" (List[float]) and "sequence_prob_display" (List[str]): Sequence
      probability and a human-friendly rendering (approximate 10^x form when it underflows).
    - "nll" (List[float]): Negative log-likelihood per sequence.
    - "metadata" (dict): "prompt_len", "generated_len", "temperature".
    - "diff" (List[List[Dict]]): Per batch item, entries for positions where the sampled token
      differed from the top-1 token, with keys "pos", "token", "token_str", "token_prob",
      "top1_token", "top1_token_str", "top1_prob" and "match" (always False). Empty lists when
      `include_top1` is False.
    - If `include_top1`: "top1_tokens", "top1_token_probs", "top1_token_logprobs",
      "top1_matches".
    - "metrics": per-sequence summary (token count, log10 probability, average log-prob in
      nats and bits, geometric-mean token probability, perplexity).

    Notes
    - Probabilities are clamped at 1e-45 before the log to avoid -inf.
    - Only `logits[:, -1, :]` (the last position) is used at each step.
    - Intended for analysis and debugging, not maximal-throughput sampling.

    Example (conceptual)
    >>> res = generate_with_probs(model, prompts, seq_len=20, temperature=0.8, tokenizer=tok)
    >>> print(res["metrics"]["per_sequence"][0]["perplexity"])
    """
    if filter_kwargs is None:
        filter_kwargs = {}
    if device is None:
        device = prompts.device
    if pad_value is None:
        pad_value = getattr(model, "pad_value", None)

    # Remember the caller's train/eval mode so it can be restored on exit.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        prompts_in = prompts.to(device)
        b, t = prompts_in.shape

        if prompt_lens is not None:
            if pad_value is None:
                raise ValueError(
                    "generate_with_probs: pad_value (or model.pad_value) is required when prompt_lens is given"
                )
            aligned = torch.full_like(prompts_in, pad_value)
            for i in range(b):
                plen = prompt_lens[i]
                plen = int(plen.item()) if isinstance(plen, torch.Tensor) else int(plen)
                if plen > 0:
                    aligned[i, -plen:] = prompts_in[i, -plen:]
            prompts_in = aligned

        out = prompts_in.clone()

        token_probs: List[List[float]] = [[] for _ in range(b)]
        token_logprobs: List[List[float]] = [[] for _ in range(b)]
        token_scores: List[List[float]] = [[] for _ in range(b)]
        top1_tokens: List[List[int]] = [[] for _ in range(b)]
        top1_token_probs: List[List[float]] = [[] for _ in range(b)]
        top1_token_logprobs: List[List[float]] = [[] for _ in range(b)]
        top1_matches: List[List[bool]] = [[] for _ in range(b)]

        greedy = (temperature == 0.0)
        # Greedy decoding reports probabilities of the untempered distribution.
        temp_divisor = 1.0 if greedy else temperature
        max_seq_len = getattr(model, "max_seq_len", None)
        # finished[i] becomes True once row i has sampled eos_token.
        finished = [False] * b

        if verbose:
            print("Generating sequence of max length:", seq_len)

        for sl in range(seq_len):
            x = out if max_seq_len is None else out[:, -max_seq_len:]
            logits_out = model.net(x, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)
            logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out
            logits = logits[:, -1, :]

            # One filtered/tempered distribution serves sampling and top-1 statistics.
            filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temp_divisor, dim=-1)

            if greedy:
                sample = logits.argmax(dim=-1, keepdim=True)
            else:
                sample = torch.multinomial(probs, 1)

            picked_probs = probs.gather(1, sample).squeeze(1)
            picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64)
            out = torch.cat((out, sample), dim=-1)

            # One host transfer per tensor per step instead of one per batch item.
            sampled_ids = sample.squeeze(1).tolist()
            step_probs = picked_probs.tolist()
            step_logprobs = picked_logprobs.tolist()
            for i in range(b):
                lp = float(step_logprobs[i])
                token_probs[i].append(float(step_probs[i]))
                token_logprobs[i].append(lp)
                token_scores[i].append(-lp)

            if include_top1:
                top1_ids = logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
                top1_p = probs.gather(1, top1_ids).squeeze(1)
                top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64)
                step_t1_ids = top1_ids.squeeze(1).tolist()
                step_t1_p = top1_p.tolist()
                step_t1_lp = top1_lp.tolist()
                for i in range(b):
                    tid = int(step_t1_ids[i])
                    top1_tokens[i].append(tid)
                    top1_token_probs[i].append(float(step_t1_p[i]))
                    top1_token_logprobs[i].append(float(step_t1_lp[i]))
                    top1_matches[i].append(int(sampled_ids[i]) == tid)

            if verbose and (sl % 32 == 0):
                print(f"{sl} / {seq_len}")

            if eos_token is not None:
                finished = [done or int(tok) == eos_token for done, tok in zip(finished, sampled_ids)]
                # Stop only when EVERY batch item has produced an EOS.
                if finished and all(finished):
                    if verbose:
                        print('Model called the end of sequence at:', sl, '/', seq_len)
                    break

        gen = out[:, t:].cpu()
        if eos_token is not None:
            for i in range(b):
                # Only the generated span is searched; an EOS inside the prompt is ignored.
                eos_positions = (gen[i] == eos_token).nonzero(as_tuple=False)
                if eos_positions.numel() == 0:
                    continue
                # Keep the first EOS (and its statistics); drop everything generated after it.
                keep = int(eos_positions[0].item()) + 1
                token_probs[i] = token_probs[i][:keep]
                token_logprobs[i] = token_logprobs[i][:keep]
                token_scores[i] = token_scores[i][:keep]
                if include_top1:
                    top1_tokens[i] = top1_tokens[i][:keep]
                    top1_token_probs[i] = top1_token_probs[i][:keep]
                    top1_token_logprobs[i] = top1_token_logprobs[i][:keep]
                    top1_matches[i] = top1_matches[i][:keep]
                if pad_value is not None and keep < gen.shape[1]:
                    gen[i, keep:] = pad_value

        sequence_logprobs: List[float] = [float(math.fsum(lps)) for lps in token_logprobs]
        sequence_probs: List[float] = []
        sequence_prob_display: List[str] = []
        nll: List[float] = []
        for lp in sequence_logprobs:
            pnum, disp = _safe_exp64(lp)
            sequence_probs.append(pnum)
            sequence_prob_display.append(disp)
            nll.append(-lp)

        result = {
            "tokens": gen,
            "token_probs": token_probs,
            "token_logprobs": token_logprobs,
            "token_scores": token_scores,
            "sequence_logprobs": sequence_logprobs,
            "sequence_probs": sequence_probs,
            "sequence_prob_display": sequence_prob_display,
            "nll": nll,
            "metadata": {
                "prompt_len": t,
                "generated_len": gen.shape[1],
                "temperature": temperature
            }
        }
        if include_top1:
            result.update({
                "top1_tokens": top1_tokens,
                "top1_token_probs": top1_token_probs,
                "top1_token_logprobs": top1_token_logprobs,
                "top1_matches": top1_matches
            })

        # Diff view: positions where the sampled token differs from the top-1 token.
        diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)]
        if include_top1:
            for i in range(b):
                for pos, (sample_tok, sample_p, t1_tok, t1_p, match) in enumerate(zip(
                    [int(x) for x in gen[i].tolist()],
                    token_probs[i],
                    top1_tokens[i],
                    top1_token_probs[i],
                    top1_matches[i]
                )):
                    if not match:
                        diff_all[i].append({
                            "pos": pos,
                            "token": sample_tok,
                            "token_str": _decode_token(tokenizer, sample_tok),
                            "token_prob": sample_p,
                            "top1_token": int(t1_tok),
                            "top1_token_str": _decode_token(tokenizer, int(t1_tok)),
                            "top1_prob": t1_p,
                            "match": bool(match)
                        })
        result["diff"] = diff_all
        result = _attach_metrics_to_result(result)

        if print_table:
            for i in range(b):
                print("="*110)
                print(f"Batch {i} (prompt_len={t})")
                print("-"*110)
                print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match")
                print("-"*110)
                cum_logp = 0.0
                for idx, (p, lp, sc) in enumerate(zip(token_probs[i], token_logprobs[i], token_scores[i])):
                    cum_logp += lp
                    tok_id = int(gen[i, idx].item()) if idx < gen.shape[1] else -1
                    tok_display = _decode_token(tokenizer, tok_id)
                    if include_top1:
                        t1_id = top1_tokens[i][idx]
                        t1_p = top1_token_probs[i][idx]
                        match_mark = "*" if top1_matches[i][idx] else " "
                        print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}){match_mark}")
                    else:
                        print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}")
                print("-"*110)
                print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}")
                m = result["metrics"]["per_sequence"][i]
                print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}")
                if result["diff"][i]:
                    print("DIFF (sampled != top1) positions:")
                    for d in result["diff"][i]:
                        print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}")
                else:
                    print("No diffs: sampled tokens matched top1 at every step.")
                print("="*110)
        return result
    finally:
        model.train(was_training)
# ---------------------------
# score_sequences (with diff)
# ---------------------------
@torch.inference_mode()
def score_sequences(
    model,
    sequences: torch.Tensor,
    prompt_lens: Optional[torch.Tensor] = None,
    eos_token: Optional[int] = None,
    pad_value: Optional[int] = None,
    filter_logits_fn: Optional[Callable] = None,
    filter_kwargs: Optional[Dict[str, Any]] = None,
    tokenizer = None,
    print_table: bool = False,
    device: Optional[torch.device] = None,
    verbose: bool = False,
    include_top1: bool = True,
    **kwargs
) -> Dict[str, Any]:
    """
    Compute per-token and per-sequence likelihood statistics for full sequences under an
    autoregressive model, optionally comparing each target token to the model's top-1
    prediction and producing a diff of mismatches.

    A single forward pass over `sequences[:, :-1]` gives the next-token distribution at every
    position. The probability and log-probability of each actual target
    `sequences[:, 1:]` is extracted. The model is put in eval mode for the call and its
    previous train/eval mode is restored before returning (also on error).

    Key behaviors
    - Runs under `torch.inference_mode()` (no gradients).
    - Scores positions 1..(seq_len-1); the first token of each row is context only.
    - If `filter_logits_fn` is provided it is applied to the logits before softmax.
    - If `pad_value` is provided, targets equal to `pad_value` are masked out.
    - If `eos_token` is provided, the first EOS among the scored targets is kept and every
      target after it is masked out. An EOS in the context-only first token is ignored.
    - If `include_top1` is True, top-1 ids/probabilities are recorded and each scored target
      is compared with the top-1 prediction (`top1_matches` holds bools).

    Parameters
    - model: Object exposing `net` with signature
      `logits = model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`,
      where `logits` is a (batch, seq, vocab) tensor or a tuple/list whose first element is
      that tensor.
    - sequences (torch.Tensor): Integer tokens of shape (batch, seq_len).
    - prompt_lens (Optional[torch.Tensor]): Per-row prompt lengths; metadata only.
    - eos_token (Optional[int]): End-of-sequence token id (see above).
    - pad_value (Optional[int]): Padding token id; padded targets are not scored.
    - filter_logits_fn (Optional[Callable]): `(logits, **filter_kwargs) -> logits`.
    - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments for `filter_logits_fn`.
    - tokenizer: Optional tokenizer used to render tokens in diffs and printed tables.
    - print_table (bool): If True, print a per-token table for every batch item.
    - device (Optional[torch.device]): Device to run on. Defaults to `sequences.device`.
    - verbose (bool): Currently unused; kept for API compatibility.
    - include_top1 (bool): If True, record top-1 statistics and build the `diff` view.
    - **kwargs: Forwarded to `model.net`.

    Returns
    A dictionary with:
    - "tokens" (torch.Tensor): The input `sequences` on CPU.
    - "token_probs" / "token_logprobs" / "token_scores" (List[List[float]]): Per scored
      target: probability, log-probability (nats) and score (= -logprob); masked positions
      removed.
    - "sequence_logprobs" (List[float]): Sum of log-probabilities over scored targets.
    - "sequence_probs" (List[float]) and "sequence_prob_display" (List[str]): Sequence
      probability and a human-friendly rendering.
    - "nll" (List[float]): Negative log-likelihood per sequence.
    - "mask" (torch.BoolTensor): (batch, seq_len-1) on CPU; True where a target was scored.
    - "diff" (List[List[Dict]]): Per batch item, entries for scored positions where the target
      differed from the top-1, with keys "pos" (index among scored positions), "token",
      "token_str", "token_prob", "top1_token", "top1_token_str", "top1_prob" and "match"
      (always False).
    - "metadata" (dict): "prompt_len", "seq_len", "scored_len_per_batch" (plus "scored_len"
      = 0 when seq_len < 2).
    - If `include_top1`: "top1_tokens", "top1_token_probs", "top1_token_logprobs",
      "top1_matches".
    - "metrics": per-sequence summary (token count, log10 probability, average log-prob in
      nats and bits, geometric-mean token probability, perplexity).

    Notes
    - When seq_len < 2 nothing is scored; a result with the same keys and empty lists is
      returned and the model is not called.
    - Probabilities are clamped at 1e-45 before the log to avoid -inf.
    - Raises ValueError if `model.net` returns logits that are not 3-dimensional.

    Example (conceptual)
    >>> res = score_sequences(model, sequences, pad_value=0, eos_token=2, tokenizer=tok)
    >>> print(res["metrics"]["per_sequence"][0]["avg_logprob_per_token_nats"])
    """
    if filter_kwargs is None:
        filter_kwargs = {}
    if device is None:
        device = sequences.device
    prompt_len_meta = None if prompt_lens is None else (
        prompt_lens.tolist() if isinstance(prompt_lens, torch.Tensor) else prompt_lens
    )

    # Remember the caller's train/eval mode so it can be restored on exit.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        sequences = sequences.to(device)
        b, L = sequences.shape

        if L < 2:
            # Nothing to score; return the same structure as the normal path
            # (separate list objects per key so callers can mutate them safely).
            result = {
                "tokens": sequences.cpu(),
                "token_probs": [[] for _ in range(b)],
                "token_logprobs": [[] for _ in range(b)],
                "token_scores": [[] for _ in range(b)],
                "sequence_probs": [1.0 for _ in range(b)],
                "sequence_prob_display": [f"{1.0:.6e}" for _ in range(b)],
                "sequence_logprobs": [0.0 for _ in range(b)],
                "nll": [0.0 for _ in range(b)],
                "mask": torch.zeros((b, 0), dtype=torch.bool),
                "diff": [[] for _ in range(b)],
                "metadata": {
                    "prompt_len": prompt_len_meta,
                    "seq_len": L,
                    "scored_len": 0,
                    "scored_len_per_batch": [0 for _ in range(b)],
                },
            }
            if include_top1:
                result.update({
                    "top1_tokens": [[] for _ in range(b)],
                    "top1_token_probs": [[] for _ in range(b)],
                    "top1_token_logprobs": [[] for _ in range(b)],
                    "top1_matches": [[] for _ in range(b)],
                })
            return _attach_metrics_to_result(result)

        inputs = sequences[:, :-1]
        targets = sequences[:, 1:]
        logits_out = model.net(inputs, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)
        logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out
        if logits.dim() != 3:
            raise ValueError(f"Expected logits with shape (b, seq, vocab), got {logits.shape}")
        filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
        probs = F.softmax(filtered_logits, dim=-1)

        picked_probs = probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
        picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64)

        mask = torch.ones_like(targets, dtype=torch.bool)
        if pad_value is not None:
            mask = mask & (targets != pad_value)
        if eos_token is not None:
            is_eos = (targets == eos_token).long()
            # Count of EOS targets strictly before each position: > 0 means past the first EOS.
            eos_before = is_eos.cumsum(dim=-1) - is_eos
            mask = mask & (eos_before == 0)

        # Move everything needed for per-row bookkeeping to the host once.
        mask_cpu = mask.cpu()
        targets_cpu = targets.cpu()
        picked_probs_cpu = picked_probs.cpu()
        picked_logprobs_cpu = picked_logprobs.cpu()
        if include_top1:
            top1_ids = probs.argmax(dim=-1)  # (b, seq)
            top1_p = probs.gather(-1, top1_ids.unsqueeze(-1)).squeeze(-1)
            top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64)
            top1_ids_cpu = top1_ids.cpu()
            top1_p_cpu = top1_p.cpu()
            top1_lp_cpu = top1_lp.cpu()

        token_probs: List[List[float]] = []
        token_logprobs: List[List[float]] = []
        token_scores: List[List[float]] = []
        kept_targets_all: List[List[int]] = []
        sequence_logprobs: List[float] = []
        sequence_probs: List[float] = []
        sequence_prob_display: List[str] = []
        nll: List[float] = []
        top1_tokens: List[List[int]] = [[] for _ in range(b)]
        top1_token_probs: List[List[float]] = [[] for _ in range(b)]
        top1_token_logprobs: List[List[float]] = [[] for _ in range(b)]
        top1_matches: List[List[bool]] = [[] for _ in range(b)]
        diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)]

        for i in range(b):
            row_mask = mask_cpu[i]
            kept_targets = [int(x) for x in targets_cpu[i][row_mask].tolist()]
            kept_probs = [float(x) for x in picked_probs_cpu[i][row_mask].tolist()]
            kept_logps = [float(x) for x in picked_logprobs_cpu[i][row_mask].tolist()]
            kept_targets_all.append(kept_targets)
            token_probs.append(kept_probs)
            token_logprobs.append(kept_logps)
            token_scores.append([-lp for lp in kept_logps])

            if include_top1:
                kept_t1_ids = [int(x) for x in top1_ids_cpu[i][row_mask].tolist()]
                kept_t1_ps = [float(x) for x in top1_p_cpu[i][row_mask].tolist()]
                kept_t1_lps = [float(x) for x in top1_lp_cpu[i][row_mask].tolist()]
                matches = [tgt == t1 for tgt, t1 in zip(kept_targets, kept_t1_ids)]
                top1_tokens[i] = kept_t1_ids
                top1_token_probs[i] = kept_t1_ps
                top1_token_logprobs[i] = kept_t1_lps
                top1_matches[i] = matches
                # Diff entries where the target differs from the top-1 prediction.
                for pos_idx, (tgt, tgt_p, t1, t1_p, match) in enumerate(
                    zip(kept_targets, kept_probs, kept_t1_ids, kept_t1_ps, matches)
                ):
                    if not match:
                        diff_all[i].append({
                            "pos": pos_idx,
                            "token": tgt,
                            "token_str": _decode_token(tokenizer, tgt),
                            "token_prob": tgt_p,
                            "top1_token": t1,
                            "top1_token_str": _decode_token(tokenizer, t1),
                            "top1_prob": t1_p,
                            "match": bool(match)
                        })

            seq_lp = float(math.fsum(kept_logps))
            pnum, disp = _safe_exp64(seq_lp)
            sequence_logprobs.append(seq_lp)
            sequence_probs.append(pnum)
            sequence_prob_display.append(disp)
            nll.append(-seq_lp)

        result = {
            "tokens": sequences.cpu(),
            "token_probs": token_probs,
            "token_logprobs": token_logprobs,
            "token_scores": token_scores,
            "sequence_logprobs": sequence_logprobs,
            "sequence_probs": sequence_probs,
            "sequence_prob_display": sequence_prob_display,
            "nll": nll,
            "mask": mask_cpu,
            "diff": diff_all,
            "metadata": {
                "prompt_len": prompt_len_meta,
                "seq_len": L,
                "scored_len_per_batch": [int(n) for n in mask_cpu.sum(dim=-1).tolist()]
            }
        }
        if include_top1:
            result.update({
                "top1_tokens": top1_tokens,
                "top1_token_probs": top1_token_probs,
                "top1_token_logprobs": top1_token_logprobs,
                "top1_matches": top1_matches
            })
        result = _attach_metrics_to_result(result)

        if print_table:
            for i in range(b):
                print("=" * 120)
                header = f"Batch {i} (seq_len={L})"
                if prompt_lens is not None:
                    header += f" prompt_len={int(prompt_lens[i].item()) if isinstance(prompt_lens[i], torch.Tensor) else prompt_lens[i]}"
                print(header)
                print("-" * 120)
                print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match")
                print("-" * 120)
                cum_logp = 0.0
                for pos_idx, (tok_id, p, lp) in enumerate(
                    zip(kept_targets_all[i], token_probs[i], token_logprobs[i])
                ):
                    cum_logp += lp
                    sc = -lp
                    if include_top1:
                        t1_id = top1_tokens[i][pos_idx]
                        t1_p = top1_token_probs[i][pos_idx]
                        match = top1_matches[i][pos_idx]
                        print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}) | {match}")
                    else:
                        print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}")
                print("-" * 120)
                print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}")
                m = result["metrics"]["per_sequence"][i]
                print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}")
                if result["diff"][i]:
                    print("DIFF (target != top1) positions:")
                    for d in result["diff"][i]:
                        print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}")
                else:
                    print("No diffs: target tokens matched top1 at every scored position.")
                print("=" * 120)
        return result
    finally:
        model.train(was_training)
#=================================================================================================================================
# ETA functions
#=================================================================================================================================
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
def calculate_eta(
hours_until_done: float,
*,
tz: str = "America/Los_Angeles",
now: datetime | None = None,
return_dict: bool = False,
verbose: bool = True,
):
"""
Compute an ETA timestamp based on the current time (or a provided time)
in a specified timezone.
Parameters
----------
hours_until_done : float
Number of hours remaining until completion.
tz : str, optional
IANA timezone name (default: "America/Los_Angeles").
now : datetime or None, optional
If provided, use this datetime as the starting point.
If None, the current time in the given timezone is used.
return_dict : bool, optional
If True, return a dictionary with ETA components.
verbose : bool, optional
If True, print a formatted ETA string.
Returns
-------
datetime or dict
ETA as a datetime object or a dictionary (if return_dict=True).
Examples
--------
# Simple ETA 5.5 hours from now
calculate_eta(5.5)
# ETA using a custom starting time in Tokyo
from datetime import datetime
calculate_eta(
12,
tz="Asia/Tokyo",
now=datetime(2026, 1, 29, 8, 30),
)
# Get ETA as a dict without printing
info = calculate_eta(3, verbose=False, return_dict=True)
print(info["pretty"])
"""
# Resolve timezone
zone = ZoneInfo(tz)
# Determine current time
current_time = now.astimezone(zone) if now else datetime.now(zone)
# Compute ETA
eta = current_time + timedelta(hours=hours_until_done)
# Format for printing
pretty = eta.strftime("ETA: %A, %B %d %Y @ %H:%M")
if verbose:
print(pretty)
if return_dict:
return {
"eta_datetime": eta,
"year": eta.year,
"month": eta.month,
"day": eta.day,
"hour": eta.hour,
"minute": eta.minute,
"second": eta.second,
"timezone": tz,
"pretty": pretty,
}
return eta
def calculate_training_run_eta(
num_epochs: int,
num_steps_per_epoch: int,
sec_per_iter: float,
*,
cost_per_hr: float = 0.0,
tz: str = "America/Los_Angeles",
now: datetime | None = None,
return_dict: bool = False,
verbose: bool = True,
):
"""
Compute ETA and cost for a full training run based on:
- number of epochs
- number of steps per epoch
- seconds per iteration
- optional cost per hour of compute
Prints:
- start time
- ETA timestamp
- per-epoch runtime (h/m/s)
- total runtime (h/m/s)
- cost per epoch
- total run cost
Returns:
datetime or dict (if return_dict=True)
Examples:
# 2 epochs, 7770 steps each, 15.07 sec/iter, $5.3 per/hr
calculate_training_run_eta(
num_epochs=2,
num_steps_per_epoch=7771,
cost_per_hr=5.3,
sec_per_iter=15.07,
)
# Get structured info without printing
info = calculate_training_run_eta(
3, 1000, 0.5,
verbose=False,
return_dict=True
)
print(info["eta_str"])
"""
zone = ZoneInfo(tz)
start_time = now.astimezone(zone) if now else datetime.now(zone)
# Core calculations
total_iters = num_epochs * num_steps_per_epoch
total_seconds = total_iters * sec_per_iter
epoch_seconds = num_steps_per_epoch * sec_per_iter
eta = start_time + timedelta(seconds=total_seconds)
# Formatting helpers
def fmt(seconds: float) -> str:
seconds = int(seconds)
h = seconds // 3600
m = (seconds % 3600) // 60
s = seconds % 60
return f"{h}h {m}m {s}s"
# Cost calculations
total_hours = total_seconds / 3600
epoch_hours = epoch_seconds / 3600
cost_epoch = epoch_hours * cost_per_hr
cost_total = total_hours * cost_per_hr
# Pretty strings
start_str = start_time.strftime("%A, %B %d %Y @ %H:%M")
eta_str = eta.strftime("%A, %B %d %Y @ %H:%M")
if verbose:
print(f"Start Time: {start_str}")
print(f"ETA: {eta_str}")
print(f"Per Epoch: {fmt(epoch_seconds)}")
print(f"Total Run: {fmt(total_seconds)}")
print(f"Cost/Epoch: ${cost_epoch:,.2f}")
print(f"Cost/Run: ${cost_total:,.2f}")
if return_dict:
return {
"start_time": start_time,
"eta": eta,
"start_str": start_str,
"eta_str": eta_str,
"epoch_seconds": epoch_seconds,
"total_seconds": total_seconds,
"epoch_runtime_hms": fmt(epoch_seconds),
"total_runtime_hms": fmt(total_seconds),
"epoch_hours": epoch_hours,
"total_hours": total_hours,
"cost_per_hr": cost_per_hr,
"cost_epoch": cost_epoch,
"cost_total": cost_total,
"timezone": tz,
}
return eta
#=================================================================================================================================
# Autoregressive embeddings retrieval functions
#=================================================================================================================================
import torch
import torch.nn.functional as F
from typing import Optional, Dict, List, Union, Set
import numpy as np
from tqdm import tqdm
from contextlib import nullcontext
#===================================================================================================================
# Advanced Embeddings Retrieval Function for Autoregressive X-Transformers
#===================================================================================================================
def get_embeddings(
    model,
    inputs: torch.Tensor,
    pooling: str = 'mean',
    mask: Optional[torch.Tensor] = None,
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    layer_index: int = -1,
    normalize: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
    pad_idx: int = 18819,
    use_amp: bool = True,
    verbose: bool = True,
    _max_concat_tokens: Optional[int] = None,
) -> np.ndarray:
    """
    Get embeddings for a single batch of inputs.

    The wrapped network is called with ``return_intermediates=True`` and the
    hidden states of ``layer_index`` are pooled. If anything in that path fails
    (forward error, unexpected output, no usable intermediates), the raw token
    embeddings are pooled instead and a warning is printed when ``verbose``.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model. If it has a ``net`` attribute, that
        inner network is called; otherwise ``model`` itself is called.
    inputs : torch.Tensor
        Input token sequences of shape (batch, seq_len). A 1-D tensor is
        treated as a batch of one.
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    mask : Optional[torch.Tensor]
        Boolean mask, True for valid tokens. Auto-generated from ``pad_idx`` if
        None. It is passed to the model and, when its shape matches
        ``inputs``, also restricts which positions are pooled. Positions equal
        to ``pad_idx`` are always excluded from pooling.
    token_ids : Optional[List[int]]
        Token IDs to include. Works independently or with token_weights.
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping:
        - 'mean': weights for weighted average
        - 'concat': priority scores for selection when limiting count
        - If provided WITHOUT token_ids: keys become the filter
        - If provided WITH token_ids: only tokens in BOTH are used (intersection)
    layer_index : int
        Which layer's hidden states to use (-1 for last)
    normalize : bool
        L2-normalize output embeddings
    device : Optional[torch.device]
        Device for inference: a ``torch.device`` or a string such as 'cuda'.
        Defaults to the device of the model's first parameter.
    dtype : torch.dtype
        Dtype for autocast (only used on CUDA when ``use_amp`` is True)
    pad_idx : int
        Padding token index
    use_amp : bool
        Use automatic mixed precision
    verbose : bool
        Print warnings and info
    _max_concat_tokens : Optional[int]
        Internal: pre-computed max tokens for concat mode

    Returns
    -------
    np.ndarray
        Embeddings array:
        - 'mean': (batch, dim)
        - 'concat': (batch, max_tokens * dim)

    Raises
    ------
    ValueError
        If ``pooling`` is not 'mean' or 'concat', or ``inputs`` is not 1-D/2-D.
    """
    # Validate before running the (possibly expensive) forward pass.
    if pooling not in ('mean', 'concat'):
        raise ValueError(f"Unknown pooling strategy: {pooling}. Use 'mean' or 'concat'")

    model.eval()
    # torch.device() accepts both strings and devices; `.type` is needed below.
    device = next(model.parameters()).device if device is None else torch.device(device)

    inputs = inputs.to(device)
    if inputs.ndim == 1:
        inputs = inputs.unsqueeze(0)
    if inputs.ndim != 2:
        raise ValueError(f"inputs must have shape (batch, seq_len), got {tuple(inputs.shape)}")

    non_pad = inputs != pad_idx
    if mask is None:
        mask = non_pad
    else:
        mask = mask.to(device)
        if mask.dtype != torch.bool:
            mask = mask.bool()
        if mask.ndim == 1:
            # A 1-D mask belongs to a 1-D input that was promoted to a batch of one above.
            mask = mask.unsqueeze(0)
    model_mask = mask if mask.ndim == 2 else mask.squeeze()

    net_model = model.net if hasattr(model, 'net') else model
    amp_ctx = (
        torch.amp.autocast(device_type='cuda', dtype=dtype)
        if use_amp and device.type == 'cuda'
        else nullcontext()
    )

    try:
        with torch.no_grad(), amp_ctx:
            output = net_model(inputs, mask=model_mask, return_intermediates=True)
        intermediates = output[1] if isinstance(output, tuple) and len(output) == 2 else None
        hidden = _extract_hidden_states(intermediates, layer_index, verbose=verbose)
        if hidden is None:
            raise ValueError("Could not extract hidden states")
    except Exception as e:
        # Deliberate best-effort fallback: pool the raw token embeddings instead.
        if verbose:
            print(f"Warning: Could not extract hidden states, using token embeddings. Error: {e}")
        # Inference only: no autograd graph is needed for the fallback lookup either.
        with torch.no_grad():
            hidden = _get_token_embeddings(net_model, inputs)

    # Honour the caller's mask for pooling as well as for the forward pass, while
    # still never pooling pad positions. Shapes that cannot be aligned fall back
    # to the pad-based mask.
    seq_mask = non_pad
    if model_mask.shape == non_pad.shape:
        seq_mask = non_pad & model_mask
    # Multiplying by a float mask also promotes half-precision activations to float32.
    hidden = hidden * seq_mask.unsqueeze(-1).float()

    effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
    if pooling == 'mean':
        emb = _mean_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, verbose=verbose)
    else:
        emb = _concat_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights,
                              max_tokens=_max_concat_tokens, verbose=verbose)

    if normalize:
        emb = F.normalize(emb, p=2, dim=-1)
    return emb.detach().cpu().numpy()
#===================================================================================================================
# Batched Processing Function
#===================================================================================================================
def get_embeddings_batched(
    model,
    sequences: List[List[int]],
    pooling: str = 'mean',
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    max_seq_len: int = 8192,
    pad_idx: int = 18819,
    batch_size: int = 8,
    use_amp: bool = True,
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = True,
    show_progress: bool = True,
    normalize: bool = False,
) -> np.ndarray:
    """
    Process multiple sequences in TRUE batches for memory efficiency.

    Each batch is padded with ``pad_idx`` to the longest sequence in that batch
    (capped at ``max_seq_len``) and passed to ``get_embeddings``. For 'concat'
    pooling, all sequences are pre-scanned so every batch produces the same
    output width.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model
    sequences : List[List[int]]
        List of token sequences (lists, tuples or 1-D arrays of ints)
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    token_ids : Optional[List[int]]
        Token IDs to include
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping
    max_seq_len : int
        Maximum sequence length
    pad_idx : int
        Padding token index
    batch_size : int
        Batch size for processing (must be positive)
    use_amp : bool
        Use automatic mixed precision
    dtype : torch.dtype
        Dtype for autocast
    verbose : bool
        Print messages (per-batch warnings are only printed for the first batch)
    show_progress : bool
        Show tqdm progress bar (only when ``verbose`` is also True)
    normalize : bool
        L2-normalize output embeddings

    Returns
    -------
    np.ndarray
        Embeddings array with consistent dimensions. An empty ``sequences``
        list yields an empty ``(0, 0)`` float32 array.

    Raises
    ------
    ValueError
        If ``batch_size`` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    model.eval()
    num_sequences = len(sequences)
    if verbose:
        print(f"Processing {num_sequences} sequences in batches of {batch_size}...")
    if num_sequences == 0:
        # np.concatenate([]) would raise; return an empty result instead.
        return np.empty((0, 0), dtype=np.float32)

    # For concat mode: pre-scan to find max matching tokens across ALL sequences,
    # so every batch yields the same (max_tokens * dim) width.
    max_concat_tokens = None
    if pooling == 'concat':
        effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
        max_concat_tokens = _scan_max_matching_tokens(sequences, effective_token_ids, pad_idx, max_seq_len)
        if max_concat_tokens == 0:
            # Checked first (and regardless of verbose) so the placeholder is
            # actually applied; otherwise concat would produce zero-width output.
            if verbose:
                print("Warning: No sequences contain matching token IDs, using 1 token placeholder")
            max_concat_tokens = 1
        elif verbose:
            print(f"Auto-detected max matching tokens: {max_concat_tokens}")

    num_batches = (num_sequences + batch_size - 1) // batch_size
    show_bar = show_progress and verbose
    batch_iterator = tqdm(range(num_batches), desc="Extracting embeddings") if show_bar else range(num_batches)

    all_embeddings = []
    for batch_idx in batch_iterator:
        start_idx = batch_idx * batch_size
        batch_sequences = sequences[start_idx:start_idx + batch_size]
        max_len_in_batch = min(max_seq_len, max(len(seq) for seq in batch_sequences))

        padded_batch = []
        for seq in batch_sequences:
            # list() lets tuples / 1-D arrays be padded like lists.
            row = list(seq[:max_len_in_batch])
            row.extend([pad_idx] * (max_len_in_batch - len(row)))
            padded_batch.append(row)
        batch_inputs = torch.tensor(padded_batch, dtype=torch.long)

        batch_embeddings = get_embeddings(
            model,
            batch_inputs,
            pooling=pooling,
            token_ids=token_ids,
            token_weights=token_weights,
            pad_idx=pad_idx,
            use_amp=use_amp,
            dtype=dtype,
            verbose=verbose and batch_idx == 0,
            normalize=normalize,
            _max_concat_tokens=max_concat_tokens,
        )
        all_embeddings.append(batch_embeddings)

    final_embeddings = np.concatenate(all_embeddings, axis=0)
    if verbose:
        print(f"Final embeddings shape: {final_embeddings.shape}")
    return final_embeddings
#===================================================================================================================
# Helper Functions
#===================================================================================================================
def _compute_effective_token_ids(token_ids: Optional[List[int]], token_weights: Optional[Dict[int, float]]) -> Optional[Set[int]]:
    """
    Resolve which token IDs take part in pooling.

    - token_ids=None,  token_weights=None  -> None (all valid tokens)
    - token_ids=[...], token_weights=None  -> set(token_ids)
    - token_ids=None,  token_weights={...} -> keys of token_weights
    - token_ids=[...], token_weights={...} -> INTERSECTION of both; if the
      intersection is empty, a UserWarning is issued and set(token_ids) is used.

    This makes token_weights act as a filter when provided, not just as weights.
    The warning goes through ``warnings.warn`` so that repeated calls (one per
    batch) are reported once instead of printing unconditionally every time.
    """
    if token_ids is None and token_weights is None:
        return None
    if token_ids is None:
        # Only token_weights provided: its keys are the filter.
        return set(token_weights.keys())

    ids = set(token_ids)
    if token_weights is None:
        return ids

    overlap = ids.intersection(token_weights)
    if overlap:
        return overlap

    import warnings
    # Fall back to the more permissive token_ids filter.
    warnings.warn("token_ids and token_weights have no overlap. Using token_ids only.", stacklevel=2)
    return ids
def _scan_max_matching_tokens(sequences: List[List[int]],
                              token_ids: Optional[Set[int]],
                              pad_idx: int,
                              max_seq_len: int) -> int:
    """
    Return the largest per-sequence count of tokens that concat pooling will keep.

    Each sequence is first truncated to ``max_seq_len``. With a ``token_ids``
    filter, only tokens that are in the filter and differ from ``pad_idx`` are
    counted. Without a filter, the truncated length itself is used. An empty
    ``sequences`` list gives 0.
    """
    if token_ids is not None:
        best = 0
        for sequence in sequences:
            hits = len([tok for tok in sequence[:max_seq_len] if tok in token_ids and tok != pad_idx])
            best = hits if hits > best else best
        return best

    if not sequences:
        return 0
    return max(min(len(sequence), max_seq_len) for sequence in sequences)
def _extract_hidden_states(intermediates, layer_index: int = -1, verbose: bool = True):
    """
    Pick one layer's activations out of an intermediates container.

    Sources are tried in order: ``layer_hiddens``, then ``hiddens``, then the
    ``values`` field of ``attn_intermediates[layer_index]``. The first source
    that exists, is not None and is non-empty wins. Returns None (printing a
    warning when ``verbose``) if no source yields anything.
    """
    if intermediates is None:
        if verbose:
            print("Warning: intermediates is None")
        return None

    for attr_name in ('layer_hiddens', 'hiddens'):
        stack = getattr(intermediates, attr_name, None)
        if stack is not None and len(stack) > 0:
            return stack[layer_index]

    # NOTE(review): attention `values` may not share the (batch, seq, dim)
    # layout of hidden states; callers broadcast against (batch, seq, 1) - confirm.
    attn_stack = getattr(intermediates, 'attn_intermediates', None)
    if attn_stack is not None and len(attn_stack) > 0:
        candidate = getattr(attn_stack[layer_index], 'values', None)
        if candidate is not None:
            return candidate

    if verbose:
        print("Warning: Could not find layer_hiddens in intermediates")
    return None
def _get_token_embeddings(net_model, inputs: torch.Tensor):
    """
    Look up raw token embeddings, bypassing the transformer layers.

    Tries ``net_model.token_emb.emb``, then ``net_model.token_emb``, then
    ``net_model.emb``, calling the first one found on ``inputs``.

    Raises:
        ValueError: if none of these attributes exist.
    """
    if not hasattr(net_model, 'token_emb'):
        if hasattr(net_model, 'emb'):
            return net_model.emb(inputs)
        raise ValueError("Could not find embedding layer in model")

    token_emb = net_model.token_emb
    # Some wrappers keep the actual nn.Embedding under `.emb`.
    lookup = token_emb.emb if hasattr(token_emb, 'emb') else token_emb
    return lookup(inputs)
def _mean_pooling(
    hidden: torch.Tensor,
    inputs: torch.Tensor,
    mask: torch.Tensor,
    token_ids: Optional[Set[int]],
    token_weights: Optional[Dict[int, float]],
    verbose: bool = True
) -> torch.Tensor:
    """
    Mean pooling with token ID filtering and weighted averaging.

    Parameters
    ----------
    hidden : torch.Tensor
        Activations of shape (batch, seq_len, dim).
    inputs : torch.Tensor
        Token ids of shape (batch, seq_len).
    mask : torch.Tensor
        Boolean mask, True for positions that may be pooled.
    token_ids : Optional[Set[int]]
        If given, only positions holding one of these ids are pooled. A
        sequence with no matching position falls back to all of its valid
        positions. This is decided per sequence, so the result does not depend
        on which other sequences share the batch.
    token_weights : Optional[Dict[int, float]]
        Optional per-id weights for a weighted average. Valid positions whose id
        is not a key get weight 1.0. Explicit weights (including 0.0) are used
        as given.
    verbose : bool
        Print a warning when the fallback is used. Does not affect results.

    Returns
    -------
    torch.Tensor
        Pooled embeddings of shape (batch, dim). Rows with zero total weight
        are all zeros.
    """
    if mask.ndim > 2:
        mask = mask.squeeze()

    effective_mask = mask.clone()
    if token_ids is not None:
        id_tensor = torch.tensor(sorted(token_ids), dtype=inputs.dtype, device=inputs.device)
        effective_mask = effective_mask & torch.isin(inputs, id_tensor)
        no_match = ~effective_mask.any(dim=1)
        if bool(no_match.any()):
            if verbose:
                print(f"Warning: No tokens match filter in {int(no_match.sum())} of "
                      f"{no_match.shape[0]} sequence(s), falling back to all valid tokens")
            effective_mask = torch.where(no_match.unsqueeze(1), mask, effective_mask)

    if token_weights is not None:
        # Default weight 1.0 on every pooled position, then apply explicit weights.
        # Starting from the mask (rather than filling "still zero" slots later)
        # keeps an explicit weight of 0.0 from being overwritten with 1.0.
        weights = effective_mask.to(torch.float32)
        for token_id, weight in token_weights.items():
            weights = weights.masked_fill((inputs == token_id) & effective_mask, float(weight))
        weighted_sum = (hidden * weights.unsqueeze(-1)).sum(dim=1)
        total_weight = weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
        return weighted_sum / total_weight

    keep = effective_mask.unsqueeze(-1).to(torch.float32)
    counts = keep.sum(dim=1).clamp(min=1e-9)
    return (hidden * keep).sum(dim=1) / counts
def _concat_pooling(
    hidden: torch.Tensor,
    inputs: torch.Tensor,
    mask: torch.Tensor,
    token_ids: Optional[Set[int]],
    token_weights: Optional[Dict[int, float]],
    max_tokens: Optional[int],
    verbose: bool = True
) -> torch.Tensor:
    """
    Concat pooling with token ID filtering and weight-based priority selection.

    For each sequence, the candidate positions are the valid positions (and,
    if ``token_ids`` is given, only those holding one of those ids). At most
    ``max_tokens`` of them are kept:
    - with ``token_weights`` and more candidates than slots: the highest-weight
      candidates, ordered by descending weight (ties keep sequence order);
      ids missing from ``token_weights`` weigh 1.0;
    - otherwise: the first ``max_tokens`` candidates in sequence order.
    Kept embeddings are concatenated and zero-padded to ``max_tokens * dim``.

    ``max_tokens`` of None or < 1 is treated as 1, so the output never has
    zero width. ``verbose`` is accepted for interface parity and is unused.

    Returns
    -------
    torch.Tensor
        Shape (batch, max_tokens * dim), same dtype as ``hidden``.
    """
    batch_size, _, dim = hidden.shape
    if max_tokens is None or max_tokens < 1:
        max_tokens = 1

    if token_ids is not None:
        id_tensor = torch.tensor(sorted(token_ids), dtype=inputs.dtype, device=inputs.device)
        candidates = mask & torch.isin(inputs, id_tensor)
    else:
        candidates = mask

    rows = []
    for i in range(batch_size):
        positions = candidates[i].nonzero(as_tuple=True)[0]
        if token_weights is not None and positions.numel() > max_tokens:
            priorities = [token_weights.get(tok, 1.0) for tok in inputs[i, positions].tolist()]
            # Python's sort is stable with reverse=True, so equal weights keep sequence order.
            order = sorted(range(len(priorities)), key=lambda j: priorities[j], reverse=True)[:max_tokens]
            positions = positions[torch.tensor(order, dtype=torch.long, device=positions.device)]
        else:
            positions = positions[:max_tokens]

        chosen = hidden[i, positions, :]
        missing = max_tokens - chosen.shape[0]
        if missing > 0:
            chosen = torch.cat([chosen, chosen.new_zeros(missing, dim)], dim=0)
        rows.append(chosen.reshape(-1))

    return torch.stack(rows, dim=0)
#=================================================================================================================================
# Non-Autoregressive Encoder Embeddings Retrieval Functions
#=================================================================================================================================
def get_enc_embeddings(
    model,
    sequences: List[List[int]],
    seq_len: Optional[int] = 3072,
    seq_pad_idx: int = 385,
    batch_size: int = 64,
    save_every_num_batches: int = -1,
    save_file_path: str = "saved_embeddings.npy",
    device: Optional[torch.device] = None,
    normalize: bool = False,
    pooling: str = "auto",  # "auto" | "mean" | "weighted_mean"
    token_type_weights: Optional[Tuple[float, float, float]] = None,  # (onset_w, duration_w, pitch_w)
    use_bfloat16: bool = True,  # prefer bfloat16 autocast when possible
    return_dtype: str = "float32",  # "float32" or "float16" for returned embeddings
    return_numpy: bool = False,
    verbose: bool = True,
    show_progress_bar: bool = True
) -> Union[Tensor, np.ndarray]:
    """
    Compute embeddings for a list of token sequences using a PyTorch model with autocast,
    pooling, optional normalization, and periodic saving.

    Sequences are processed in batches. Each batch is padded/truncated with `pad_and_mask`.
    The model runs in eval mode under `torch.inference_mode()`, inside an autocast context
    when one could be created. The model is called as
    `model(x, return_embeddings=True, mask=mask)` and must return either:
      - a 2-D tensor `(B, D)` of already-pooled embeddings, or
      - a 3-D tensor `(B, L, D)` of per-token embeddings, which is pooled according to `pooling`.

    Args:
        model (torch.nn.Module):
            Model used to compute embeddings. It is moved to `device` (or kept on the device of
            its first parameter when `device` is None) and set to `eval()`.
        sequences (List[List[int]]):
            Token id sequences. May be empty, in which case an empty `(0, 0)` result is returned.
        seq_len (Optional[int], default=3072):
            Upper bound for truncation. Each batch is padded to min(seq_len, longest sequence
            in the batch). If None, the longest sequence in the batch is used.
        seq_pad_idx (int, default=385):
            Token id used for padding positions.
        batch_size (int, default=64):
            Number of sequences per forward pass. Must be positive.
        save_every_num_batches (int, default=-1):
            If > 0, all embeddings collected so far are written to `save_file_path` with `np.save`
            every `save_every_num_batches` batches. Non-positive disables periodic saving.
        save_file_path (str, default="saved_embeddings.npy"):
            Target path for periodic saves.
        device (Optional[torch.device or str], default=None):
            Device to run on; strings such as "cuda" are accepted.
        normalize (bool, default=False):
            If True, L2-normalize each embedding (in float32).
        pooling (str, default="auto"):
            Used only when the model returns per-token embeddings:
              - "auto" or "mean": masked mean pooling.
              - "weighted_mean": weighted mean by token type using `token_type_weights`.
            Any other value raises `ValueError` (only when the model output is 3-D).
        token_type_weights (Optional[Tuple[float, float, float]], default=None):
            `(onset_w, duration_w, pitch_w)` for "weighted_mean"; None means `(1.0, 1.0, 1.0)`.
            Token types: onset [0, 127], duration [128, 255], pitch [256, 383].
        use_bfloat16 (bool, default=True):
            If True, tries bfloat16 autocast on the device, then the device's default autocast
            dtype, then no autocast. If False, the device's default autocast dtype is used
            (e.g. float16 on CUDA); autocast is NOT disabled.
        return_dtype (str, default="float32"):
            "float32" or "float16". Embeddings are normalized in float32 and optionally cast
            to float16 before collection.
        return_numpy (bool, default=False):
            Return a NumPy array instead of a CPU `torch.Tensor`.
        verbose (bool, default=True):
            Print progress and diagnostics via `tqdm.write`.
        show_progress_bar (bool, default=True):
            Display a tqdm progress bar.

    Returns:
        Union[torch.Tensor, numpy.ndarray]:
            `(N, D)` embeddings on CPU, dtype float32 or float16 per `return_dtype`.

    Raises:
        AssertionError: if `return_dtype` is not "float32" or "float16".
        ValueError: if `batch_size` is not positive, the model output is neither 2-D nor 3-D,
            or `pooling` is unsupported for a 3-D output.
        RuntimeError: if the model returns None.

    Notes:
        - Periodic-save failures are reported when `verbose=True` but do not abort processing.
        - For reproducible numerics across devices, ensure the device supports the requested
          autocast dtype and control randomness externally.

    Example:
        >>> embs = get_enc_embeddings(model, sequences, seq_len=1024, batch_size=32, pooling="mean",
        ...                           normalize=True, return_dtype="float32", return_numpy=False)
    """
    assert return_dtype in ("float32", "float16"), "return_dtype must be 'float32' or 'float16'"
    if batch_size <= 0:
        # range(0, n, 0) raises an opaque error and a negative step silently yields
        # no batches (an empty result); reject both explicitly.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # torch.device() also accepts strings such as "cuda"; `.type` is needed below.
    model_device = next(model.parameters()).device if device is None else torch.device(device)
    model.to(model_device)
    model.eval()

    all_embs: List[Tensor] = []
    # Integer ceil-division (no dependency on a module-level `math` import).
    total_batches = (len(sequences) + batch_size - 1) // batch_size
    if verbose:
        tqdm.write(
            f"[get_embeddings_bf16] sequences={len(sequences)}, batch_size={batch_size}, "
            f"batches={total_batches}, device={model_device}, seq_len={seq_len}, pooling={pooling}"
        )

    # Prepare autocast context using torch.amp.autocast.
    autocast_ctx = None
    if use_bfloat16:
        try:
            autocast_ctx = torch.amp.autocast(device_type=model_device.type, dtype=torch.bfloat16)
        except Exception:
            try:
                autocast_ctx = torch.amp.autocast(device_type=model_device.type)
            except Exception:
                autocast_ctx = None
    else:
        # Device-default autocast dtype (float16 on CUDA, bfloat16 on CPU).
        try:
            autocast_ctx = torch.amp.autocast(device_type=model_device.type)
        except Exception:
            autocast_ctx = None

    with torch.inference_mode():
        batch_iter = range(0, len(sequences), batch_size)
        pbar = tqdm(batch_iter, disable=not show_progress_bar, total=total_batches, desc="Embedding batches")
        for batch_idx, i in enumerate(pbar):
            batch_seqs = sequences[i : i + batch_size]
            # x: (B, L) LongTensor token ids, mask: (B, L) boolean
            x, mask = pad_and_mask(batch_seqs, pad_idx=seq_pad_idx, seq_len=seq_len, device=model_device, verbose=verbose)

            if autocast_ctx is not None:
                with autocast_ctx:
                    out = model(x, return_embeddings=True, mask=mask)
            else:
                out = model(x, return_embeddings=True, mask=mask)

            if out is None:
                raise RuntimeError("model returned None for embeddings. Check forward flags.")

            if out.dim() == 2:
                # Already pooled: (B, D)
                emb = out
            elif out.dim() == 3:
                # Per-token embeddings: (B, L, D)
                if pooling in ("mean", "auto"):
                    emb = masked_mean_pool(out, mask, dim=1, verbose=verbose)
                elif pooling == "weighted_mean":
                    emb = masked_weighted_mean_pool(out, mask, token_ids=x, token_type_weights=token_type_weights, dim=1, verbose=verbose)
                else:
                    raise ValueError(f"unsupported pooling: {pooling}")
            else:
                raise ValueError(f"unexpected embedding tensor shape: {out.shape}")

            # float32 for stable normalization/indexing.
            if emb.dtype != torch.float32:
                emb = emb.float()
            if normalize:
                emb = F.normalize(emb, p=2, dim=-1)
            if return_dtype == "float16":
                emb = emb.half()

            all_embs.append(emb.cpu())

            if verbose:
                pbar.set_postfix({"batch": batch_idx + 1, "emb_shape": f"{emb.shape}", "dtype": str(emb.dtype)})

            # Periodic checkpoint of everything collected so far.
            if save_every_num_batches > 0 and (batch_idx + 1) % save_every_num_batches == 0:
                try:
                    concatenated = torch.cat(all_embs, dim=0).numpy()
                    np.save(save_file_path, concatenated)
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] saved {concatenated.shape[0]} embeddings to {save_file_path}")
                except Exception as e:
                    # Best effort: a failed save must not abort the run.
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] warning: failed to save embeddings: {e}")

    if len(all_embs) == 0:
        empty = torch.empty((0, 0), dtype=(torch.float16 if return_dtype == "float16" else torch.float32))
        if verbose:
            tqdm.write("[get_embeddings_bf16] no embeddings were produced; returning empty tensor")
        return empty.numpy() if return_numpy else empty

    result = torch.cat(all_embs, dim=0)  # (N, D) on CPU
    if verbose:
        tqdm.write(f"[get_embeddings_bf16] finished: total_embeddings={result.shape[0]}, dim={result.shape[1]}, dtype={result.dtype}")
    return result.numpy() if return_numpy else result
###################################################################################
def masked_mean_pool(
    token_embeddings: Tensor,
    mask: Tensor,
    dim: int = 1,
    eps: float = 1e-9,
    verbose: bool = True,
) -> Tensor:
    """
    Compute a masked mean pooling over a specified dimension.

    Averages `token_embeddings` along `dim`, ignoring positions where `mask` is False.
    Sums are accumulated in at least float32 and the result is cast back to the input
    dtype. In half precision, `eps` (1e-9) would underflow to 0 in float16, so fully
    masked rows would become 0/0 = NaN. Half-precision token counts also stop being
    exact above 256 (bfloat16) / 2048 (float16).

    Args:
        token_embeddings: Tensor of shape (B, L, D) or similar where `dim` indexes
            the sequence length. Embeddings dtype can be float16/float32/bfloat16.
        mask: Boolean tensor of shape broadcastable to the sequence dimension
            (e.g., (B, L)). True indicates valid tokens; False indicates padding.
        dim: Dimension along which to pool (default: 1, the sequence length).
        eps: Small value to avoid division by zero when a row has zero valid tokens
            (such rows pool to zeros).
        verbose: If True, prints a short summary about the pooling operation.

    Returns:
        Tensor of pooled embeddings with the sequence dimension removed, typically
        shape (B, D). The returned dtype matches `token_embeddings.dtype`.
    """
    out_dtype = token_embeddings.dtype
    # float16/bfloat16 -> float32; float64 stays float64.
    acc_dtype = torch.promote_types(out_dtype, torch.float32)
    mask_f = mask.to(acc_dtype)  # (B, L)
    summed = (token_embeddings.to(acc_dtype) * mask_f.unsqueeze(-1)).sum(dim=dim)  # (B, D)
    counts = mask_f.sum(dim=dim).clamp_min(eps).unsqueeze(-1)  # (B, 1)
    pooled = (summed / counts).to(out_dtype)  # (B, D)
    if verbose:
        # tqdm.write keeps the message from breaking active progress bars.
        valid_counts = counts.squeeze(-1)
        tqdm.write(
            f"[masked_mean_pool] pooled shape={pooled.shape}, "
            f"counts min={valid_counts.min().item():.3f}, max={valid_counts.max().item():.3f}"
        )
    return pooled
###################################################################################
def masked_weighted_mean_pool(
    token_embs: Tensor,
    valid_mask: Tensor,
    token_ids: Optional[Tensor] = None,
    token_type_weights: Optional[Tuple[float, float, float]] = None,
    dim: int = 1,
    verbose: bool = False,
) -> Tensor:
    """
    Weighted mean pooling across tokens, with weights chosen by token type.

    If `token_ids` is provided, each position's type is inferred from its id:
        - onset:    token_id in [0, 127]
        - duration: token_id in [128, 255]
        - pitch:    token_id in [256, 383]
    `token_type_weights` is (onset_w, duration_w, pitch_w) and defaults to (1.0, 1.0, 1.0).
    Ids outside these ranges get weight 1.0 and padding positions get weight 0.
    Each token embedding is multiplied by its weight, and the sum is divided by the
    total weight of the sequence (clamped at 1e-6).

    Without `token_ids`, this falls back to `masked_mean_pool`.

    Accumulation happens in at least float32 and the result is cast back to
    `token_embs.dtype`, matching `masked_mean_pool`.

    Args:
        token_embs: (B, L, D) embeddings.
        valid_mask: (B, L) mask, True for real tokens.
        token_ids: Optional (B, L) token ids on the same device.
        token_type_weights: Optional (onset_w, duration_w, pitch_w).
        dim: Axis of `token_embs` holding the sequence (1 or -2); weights are
            always summed over their own axis 1.
        verbose: Print a note when falling back to masked_mean_pool.

    Returns:
        (B, D) pooled embeddings in `token_embs.dtype`.
    """
    B, L, D = token_embs.shape
    if token_ids is None:
        if verbose:
            tqdm.write("[masked_weighted_mean_pool] token_ids is None, falling back to masked_mean_pool")
        return masked_mean_pool(token_embs, valid_mask, dim=dim, verbose=verbose)

    onset_w, duration_w, pitch_w = (1.0, 1.0, 1.0) if token_type_weights is None else token_type_weights

    out_dtype = token_embs.dtype
    acc_dtype = torch.promote_types(out_dtype, torch.float32)
    valid = valid_mask.to(torch.bool)

    type_rules = (
        ((token_ids >= 0) & (token_ids < 128), onset_w),
        ((token_ids >= 128) & (token_ids < 256), duration_w),
        ((token_ids >= 256) & (token_ids < 384), pitch_w),
    )
    # Per-token scalar weights (B, L); the type ranges are disjoint, so order is irrelevant.
    w = torch.ones((B, L), device=token_embs.device, dtype=acc_dtype)
    for type_mask, type_weight in type_rules:
        if type_weight != 1.0:
            w = w.masked_fill(type_mask & valid, float(type_weight))
    w = w * valid.to(acc_dtype)  # padding contributes nothing

    denom = w.sum(dim=1, keepdim=True).clamp(min=1e-6)  # (B, 1)
    summed = (token_embs.to(acc_dtype) * w.unsqueeze(-1)).sum(dim=dim)  # (B, D)
    return (summed / denom).to(out_dtype)
###################################################################################
def pad_and_mask(
    sequences: List[List[int]],
    pad_idx: int = 385,
    seq_len: Optional[int] = None,
    device: Optional[torch.device] = None,
    verbose: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Pad and create a boolean mask for a batch of integer token sequences.

    Converts variable-length integer sequences into a padded LongTensor and a boolean
    mask of valid positions. The output width is the batch's longest sequence, capped
    at `seq_len` when given. Longer sequences are truncated.

    The batch is assembled on the CPU and moved to `device` in a single transfer.
    Building per-row tensors directly on an accelerator costs one host-to-device copy
    per row.

    Args:
        sequences: List of token id sequences (lists, tuples or 1-D arrays of ints).
        pad_idx: Integer token id used for padding positions (default: 385).
        seq_len: Optional cap on the padded length. If None, the maximum length in
            `sequences` is used.
        device: Optional torch.device for the returned tensors. If None, tensors are
            created on the default device.
        verbose: If True, shows a small progress bar while processing sequences
            and prints a summary.

    Returns:
        A tuple (x, mask):
            - x: LongTensor of shape (B, T) containing padded token ids.
            - mask: BoolTensor of shape (B, T) where True indicates a real token.
    """
    # len() rather than truthiness so array-like batches are accepted.
    if len(sequences) == 0:
        empty = torch.empty((0, 0), dtype=torch.long, device=device)
        empty_mask = torch.empty((0, 0), dtype=torch.bool, device=device)
        return empty, empty_mask

    lengths = [len(s) for s in sequences]
    batch_max = max(lengths)
    # seq_len only caps the width; a shorter batch is not padded up to it.
    target_len = batch_max if seq_len is None else min(seq_len, batch_max)
    b = len(sequences)

    if target_len == 0:
        x = torch.full((b, 0), pad_idx, dtype=torch.long, device=device)
        mask = torch.zeros((b, 0), dtype=torch.bool, device=device)
        return x, mask

    x = torch.full((b, target_len), pad_idx, dtype=torch.long)
    rows = tqdm(sequences, desc="Pad & mask") if verbose else sequences
    for i, seq in enumerate(rows):
        n = min(len(seq), target_len)
        if n > 0:
            x[i, :n] = torch.tensor(seq[:n], dtype=torch.long)

    kept = torch.tensor([min(n, target_len) for n in lengths], dtype=torch.long)
    # Position j of row i is real iff j < that row's kept length.
    mask = torch.arange(target_len).unsqueeze(0) < kept.unsqueeze(1)

    if device is not None:
        x = x.to(device)
        mask = mask.to(device)

    if verbose:
        tqdm.write(
            f"[pad_and_mask] batch_size={b}, target_len={target_len}, "
            f"min_len={min(lengths)}, max_len={max(lengths)}"
        )
    return x, mask
#=================================================================================================================================
# Embeddings similarity comparison functions
#=================================================================================================================================
import torch
import torch.nn.functional as F
from tqdm import tqdm
from typing import Optional, Union, Tuple
def topk_cosine_neighbors(embeddings: torch.Tensor,
                          k: int = 10,
                          key_embeddings: Optional[torch.Tensor] = None,
                          row_batch: Optional[int] = None,
                          col_batch: Optional[int] = None,
                          device: Optional[Union[str, torch.device]] = None,
                          normalize: bool = True,
                          dtype: Optional[torch.dtype] = None,
                          show_progress: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    For each query embedding, find the indices and similarities of its top-k neighbors
    from a set of key embeddings, sorted by descending similarity.

    Supports both self-similarity (single array, excludes self) and pairwise
    retrieval (two arrays, no exclusion). The similarity matrix is computed in
    (row_batch x col_batch) tiles that are merged into a running top-k per query,
    so the full N_q x N_k matrix is never materialized.

    Args:
        embeddings (torch.Tensor): Query embeddings, shape (N_q, D).
        k (int): How many neighbors to return.
        key_embeddings (torch.Tensor, optional): Database/Key embeddings, shape (N_k, D).
            If None, defaults to 'embeddings' (self-search).
        row_batch (int, optional): Number of query rows to process at once. Auto-tuned if None.
        col_batch (int, optional): Number of key columns to process at once. Auto-tuned if None.
        device (str or torch.device, optional): Target device. If None, uses embeddings.device.
        normalize (bool): If True, L2-normalize queries and keys so scores are cosine
            similarities. Pass False only if the inputs are already normalized
            (otherwise the scores are raw dot products).
        dtype (torch.dtype, optional): Compute dtype (e.g., torch.float16, torch.bfloat16).
            If None, uses embeddings.dtype.
        show_progress (bool): Show tqdm progress bar.

    Returns:
        top_idx (torch.Tensor): shape (N_q, k), int32 indices of nearest neighbors
            (indices into key_embeddings). Flattened to shape (N_q,) when k == 1.
        top_sim (torch.Tensor): shape (N_q, k), float32 similarities.
            Flattened to shape (N_q,) when k == 1.

    Raises:
        ValueError: if query/key dimensions differ or k is out of range
            (self-search requires k < N; pairwise requires k <= N_k).
    """
    # 1. Determine Search Mode (Self vs. Pairwise)
    is_self_search = (key_embeddings is None)

    # 2. Device & Dtype Setup
    if device is None:
        device = embeddings.device
    else:
        device = torch.device(device)

    if dtype is None:
        dtype = embeddings.dtype
    else:
        assert dtype.is_floating_point, "dtype must be a floating point type"

    # Move and cast; contiguous layout for efficient matmul.
    query_embeddings = embeddings.to(device=device, dtype=dtype).contiguous()
    if is_self_search:
        # Share one tensor rather than converting/copying the same data twice.
        key_embeddings = query_embeddings
    else:
        key_embeddings = key_embeddings.to(device=device, dtype=dtype).contiguous()

    N_q, D = query_embeddings.shape
    N_k, D_k = key_embeddings.shape
    if D != D_k:
        raise ValueError(f"Query and Key embeddings must have same dimension. Got {D} and {D_k}")

    # Validation
    if k < 1:
        raise ValueError(f"k must be >= 1; got {k}")
    if is_self_search:
        if k >= N_q:
            raise ValueError(f"For self-search, k must be < N (to exclude self). Got N={N_q}, k={k}")
    else:
        if k > N_k:
            raise ValueError(f"For pairwise search, k must be <= N_k. Got N_k={N_k}, k={k}")

    # 3. Auto-tune batch sizes based on device type
    if row_batch is None:
        if device.type == 'cuda':
            row_batch = 16384
        elif device.type == 'mps':
            row_batch = 8192
        else:
            row_batch = 4096  # CPU
    if col_batch is None:
        if device.type == 'cuda':
            col_batch = 16384
        elif device.type == 'mps':
            col_batch = 8192
        else:
            col_batch = 4096  # CPU

    # Clamp to the data size, but never to 0: range() rejects a zero step
    # (e.g. a pairwise search with an empty query set).
    row_batch = max(1, min(row_batch, N_q))
    col_batch = max(1, min(col_batch, N_k))

    # 4. Optional Normalization
    if normalize:
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        if is_self_search:
            # F.normalize returns a NEW tensor, so the keys must be re-pointed at the
            # normalized queries; otherwise scores would be q_hat . k (ranked by key norm).
            key_embeddings = query_embeddings
        else:
            key_embeddings = F.normalize(key_embeddings, p=2, dim=1)

    # 5. Result tensors (float32 similarities for output precision)
    top_sim = torch.empty((N_q, k), dtype=torch.float32, device=device)
    top_idx = torch.empty((N_q, k), dtype=torch.int, device=device)

    # Reusable buffers: [running top-k | block top-k] for merging, and one similarity tile.
    merge_sim_buffer = torch.empty((row_batch, 2 * k), dtype=dtype, device=device)
    merge_idx_buffer = torch.empty((row_batch, 2 * k), dtype=torch.int, device=device)
    sim_buffer = torch.empty((row_batch, col_batch), dtype=dtype, device=device)

    # Sentinel for masked/padded entries (most negative finite value of the dtype).
    min_val = -torch.finfo(dtype).max

    # 6. Inference Context
    with torch.no_grad():
        iterator = range(0, N_q, row_batch)
        if show_progress:
            desc = "Query Batches" if not is_self_search else "Row Batches"
            iterator = tqdm(iterator, desc=desc, leave=True)

        for i in iterator:
            i_end = min(i + row_batch, N_q)
            rb = i_end - i
            rows = query_embeddings[i:i_end]  # (rb, D)

            # Running top-k for this row batch, carried across column batches.
            curr_sim = torch.full((rb, k), min_val, dtype=dtype, device=device)
            curr_idx = torch.full((rb, k), -1, dtype=torch.int, device=device)

            for j in range(0, N_k, col_batch):
                j_end = min(j + col_batch, N_k)
                cb = j_end - j
                cols = key_embeddings[j:j_end]  # (cb, D)

                # sim_block shape: (rb, cb), written into the reusable buffer.
                sim_block = sim_buffer[:rb, :cb]
                torch.matmul(rows, cols.T, out=sim_block)

                # Mask self-similarity ONLY if self-search: global row i+r equals
                # global column j+c exactly when c = r + (i - j).
                if is_self_search:
                    offset = i - j
                    r_start = max(0, -offset)
                    r_end = min(rb, cb - offset)
                    if r_start < r_end:
                        r_range = torch.arange(r_start, r_end, dtype=torch.long, device=device)
                        c_range = r_range + offset
                        sim_block[r_range, c_range] = min_val

                # Top-k within the block; pad with sentinels when fewer than k keys remain.
                if cb >= k:
                    candidates = sim_block
                else:
                    pad_vals = torch.full((rb, k - cb), min_val, dtype=dtype, device=device)
                    candidates = torch.cat([sim_block, pad_vals], dim=1)
                blk_s, blk_p = torch.topk(candidates, k, dim=1, largest=True, sorted=True)
                blk_i = blk_p + j
                # Padding and masked self entries carry the sentinel; mark their indices invalid.
                blk_i[blk_s == min_val] = -1

                # Merge: [curr (k) | block (k)] -> topk(2k) -> keep k
                merge_sim_buffer[:rb, :k] = curr_sim
                merge_sim_buffer[:rb, k:2 * k] = blk_s
                merge_idx_buffer[:rb, :k] = curr_idx
                merge_idx_buffer[:rb, k:2 * k] = blk_i
                curr_sim, top_p = torch.topk(merge_sim_buffer[:rb, :2 * k], k, dim=1, largest=True, sorted=True)
                curr_idx = torch.gather(merge_idx_buffer[:rb, :2 * k], dim=1, index=top_p)

            # Write results (float32 for consistency across compute dtypes)
            top_sim[i:i_end] = curr_sim.to(torch.float32)
            top_idx[i:i_end] = curr_idx

    # 7. Post-processing return format: k == 1 returns flat vectors.
    if k == 1:
        return top_idx.view(-1), top_sim.view(-1)
    return top_idx, top_sim
#=================================================================================================================================
# Embeddings visualization functions
#=================================================================================================================================
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances
def plot_emb_cosine_similarity(embeddings,
                               clip=2.0,
                               gamma=0.55,
                               cmap="inferno",
                               figsize=(20, 20),
                               dpi=300,
                               output_fname='embeddings_similarity_plot.png',
                               return_sims=False
                               ):
    """
    Produces a crisp, high-contrast cosine similarity heatmap, saves it to
    `output_fname`, and shows it.

    Args:
        embeddings: 2-D array-like (n_embeddings, dim) accepted by
            sklearn.metrics.pairwise_distances (presumably a NumPy array or a CPU
            tensor, as in the example below).
        clip: percentile clipping applied to both tails of the colour scale
            (1–5 recommended).
        gamma: exponent of the sign-preserving contrast curve
            sign(s) * |s| ** gamma (0.4–0.8 recommended).
        cmap: matplotlib colormap name.
        figsize: figure size in inches.
        dpi: figure resolution.
        output_fname: path the figure is saved to before being shown.
        return_sims: if True, return the plotted matrix.

    Returns:
        The gamma-corrected similarity matrix (NOT the raw cosine similarities)
        when `return_sims` is True, otherwise None.

    -----------
    Use Example
    -----------
    tok_emb = model.net.token_emb.emb.weight.detach().cpu()
    plot_emb_cosine_similarity(tok_emb)
    """
    # 1. Compute cosine similarity (sklearn returns cosine *distance* = 1 - similarity)
    cos_dist = pairwise_distances(embeddings, metric="cosine")
    cos_sim = 1 - cos_dist

    # 2. Gamma correction for contrast (sign-preserving so negative similarities stay negative)
    sim = np.sign(cos_sim) * (np.abs(cos_sim) ** gamma)

    # 3. Percentile clipping of the colour range to remove flat tails
    vmin, vmax = np.percentile(sim, [clip, 100 - clip])

    # 4. Plot
    plt.figure(figsize=figsize, dpi=dpi)
    plt.imshow(sim, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Embeddings Pairwise Cosine Similarity")
    plt.xlabel("Embedding Index")
    plt.ylabel("Embeddings Index")
    plt.tight_layout()
    plt.savefig(output_fname)
    plt.show()

    if return_sims:
        return sim
#=================================================================================================================================
# Fine-tuning functions
#=================================================================================================================================
def unfreeze_last_n_blocks_and_norms(model,
                                     n_last=2,
                                     verbose=True
                                     ):
    """
    Freeze `model`, then unfreeze the output head (`model.net.to_logits`), the last
    `n_last` entries of `model.net.attn_layers.layers` (all of their parameters,
    including any norms inside them), and `model.net.attn_layers.final_norm` if present.

    2-3 unfrozen layers usually produce good results. Default is 2.

    Args:
        model: module exposing `.net.to_logits` and `.net.attn_layers.layers`
            (presumably an AutoregressiveWrapper around a TransformerWrapper — confirm).
        n_last: number of trailing entries of `attn_layers.layers` to unfreeze.
            Values <= 0 unfreeze none of them.
        verbose: print parameter counts.

    Returns:
        (model, optim): the same model with `requires_grad` updated, and an Adam
        optimizer with two param groups — non-head trainable params at lr 1e-5
        and head params at lr 5e-5.
    """
    # freeze everything first
    for p in model.parameters():
        p.requires_grad = False

    # unfreeze head
    head_params = list(model.net.to_logits.parameters())
    for p in head_params:
        p.requires_grad = True

    # unfreeze last n blocks (block.parameters() includes any norms nested inside).
    # The guard matters: list[-0:] is the WHOLE list, so n_last=0 would unfreeze every layer.
    if n_last > 0:
        layers = model.net.attn_layers.layers  # ModuleList of blocks
        for block in list(layers)[-n_last:]:
            for p in block.parameters():
                p.requires_grad = True

    # unfreeze final norm if it has parameters
    final_norm = getattr(model.net.attn_layers, "final_norm", None)
    if final_norm is not None:
        for p in final_norm.parameters():
            p.requires_grad = True

    # verify counts
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"Trainable params {trainable:,} / {total:,}")

    # group trainable params by identity (avoids tensor equality comparisons)
    head_param_ids = {id(p) for p in head_params}
    pretrained_params = []
    head_only = []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in head_param_ids:
            head_only.append(p)
        else:
            pretrained_params.append(p)

    # sanity checks
    trainable = sum(p.numel() for p in pretrained_params) + sum(p.numel() for p in head_only)
    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert trainable == total_trainable, "Mismatch in grouped trainable params"

    if verbose:
        print(f"Pretrained params: {sum(p.numel() for p in pretrained_params):,}")
        print(f"Head params: {sum(p.numel() for p in head_only):,}")
        print(f"Total trainable: {total_trainable:,}")

    # Lower LR for pretrained layers, higher LR for the head.
    optim = torch.optim.Adam([
        {"params": pretrained_params, "lr": 1e-5},
        {"params": head_only, "lr": 5e-5}
    ])
    return model, optim
def unfreeze_last_n_blocks_and_norms_full(model,
                                          n_last_encoder=1,
                                          n_last_decoder=2,
                                          verbose=True
                                          ):
    """
    Freeze entire XTransformer, then unfreeze:
    - Last `n_last_encoder` entries of `model.encoder.attn_layers.layers`
      (all parameters in those blocks, including any LayerNorms inside them)
    - Last `n_last_decoder` entries of `model.decoder.net.attn_layers.layers`
    - `attn_layers.final_norm` of encoder/decoder, if present — only for a side
      whose n_last is > 0
    - Decoder's output head (`model.decoder.net.to_logits`)

    Returns:
        (model, optim): the same model with `requires_grad` updated, and an Adam
        optimizer with non-head trainable params at lr 1e-5 and head params at lr 5e-5.
    """
    # 1. Freeze everything
    for p in model.parameters():
        p.requires_grad = False

    # 2. Unfreeze decoder head
    head_params = list(model.decoder.net.to_logits.parameters())
    for p in head_params:
        p.requires_grad = True

    # 3. Unfreeze the last N blocks + final norm of one TransformerWrapper-like module.
    def _unfreeze_last_blocks(transformer_wrapper, n_last):
        # Guard also prevents list[-0:] (the whole list) from unfreezing every block.
        if n_last <= 0:
            return
        attn_layers = transformer_wrapper.attn_layers
        for block in list(attn_layers.layers)[-n_last:]:
            # block.parameters() already covers every norm nested inside the block,
            # so no separate norm-type scan is needed.
            for p in block.parameters():
                p.requires_grad = True
        final_norm = getattr(attn_layers, 'final_norm', None)
        if final_norm is not None:
            for p in final_norm.parameters():
                p.requires_grad = True

    # 4. Apply to encoder and decoder
    _unfreeze_last_blocks(model.encoder, n_last_encoder)
    _unfreeze_last_blocks(model.decoder.net, n_last_decoder)  # .net: decoder is a wrapper

    # 5. Parameter grouping by identity
    head_param_ids = {id(p) for p in head_params}
    pretrained_params = []
    head_only = []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in head_param_ids:
            head_only.append(p)
        else:
            pretrained_params.append(p)

    # Sanity check
    trainable = sum(p.numel() for p in pretrained_params) + sum(p.numel() for p in head_only)
    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert trainable == total_trainable, "Mismatch in grouped trainable params"

    if verbose:
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable params {trainable:,} / {total:,}")
        print(f"Pretrained (enc/dec): {sum(p.numel() for p in pretrained_params):,}")
        print(f"Head: {sum(p.numel() for p in head_only):,}")
        print(f"Total trainable: {total_trainable:,}")

    # Optimizer: lower LR for pretrained layers, higher LR for the head.
    optim = torch.optim.Adam([
        {"params": pretrained_params, "lr": 1e-5},
        {"params": head_only, "lr": 5e-5}
    ])
    return model, optim
#=================================================================================================================================
# Merging functions
#=================================================================================================================================
def merge_encoder_and_decoder(model,
                              encoder_ckpt,
                              decoder_ckpt,
                              print_keys=False,
                              verbose=True
                              ):
    """
    Load separately trained encoder and decoder weights into a combined model.

    Key mapping:
    - every encoder-checkpoint key `k` is copied to `'encoder.' + k`;
    - decoder-checkpoint keys starting with `'net.'` are copied to `'decoder.' + k`
      (i.e. `'decoder.net.…'`); other decoder keys are used unchanged.
    Mapped keys that do not exist in `model.state_dict()` are ignored. Each
    checkpoint may be a raw state dict or a dict holding it under `'state_dict'`.

    Args:
        model: the combined model to load into.
        encoder_ckpt: path to the encoder checkpoint.
        decoder_ckpt: path to the decoder checkpoint.
        print_keys: print the key sets of all three state dicts.
        verbose: print progress and how many checkpoint keys were matched.

    Returns:
        `model` with merged weights on success; otherwise the IncompatibleKeys
        result of the non-strict load.
    """
    if verbose:
        print('=' * 70)
        print('Merging...')
        print('=' * 70)

    if print_keys:
        print('=' * 70)
        print('Merged model keys:', model.state_dict().keys())
        print('=' * 70)

    if verbose:
        print('=' * 70)
        print('Loading encoder model...')
        print('=' * 70)

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    enc_ckpt = torch.load(encoder_ckpt, map_location='cpu')
    enc_pre_sd = enc_ckpt.get('state_dict', enc_ckpt)

    if print_keys:
        print('=' * 70)
        print('Encoder model keys:', enc_pre_sd.keys())
        print('=' * 70)

    if verbose:
        print('=' * 70)
        print('Loading decoder model...')
        print('=' * 70)

    dec_ckpt = torch.load(decoder_ckpt, map_location='cpu')
    dec_pre_sd = dec_ckpt.get('state_dict', dec_ckpt)

    if print_keys:
        print('=' * 70)
        print('Decoder model keys', dec_pre_sd.keys())
        print('=' * 70)

    if verbose:
        print('=' * 70)
        print('Prepping merged model...')
        print('=' * 70)

    model_new_sd = model.state_dict()

    enc_matched = 0
    for old_key, tensor in enc_pre_sd.items():
        new_key = 'encoder.' + old_key
        if new_key in model_new_sd:
            model_new_sd[new_key] = tensor
            enc_matched += 1

    dec_matched = 0
    for old_key, tensor in dec_pre_sd.items():
        # Remap only the leading 'net.' prefix; str.replace would also rewrite
        # any 'net.' occurring later inside the key.
        new_key = 'decoder.' + old_key if old_key.startswith('net.') else old_key
        if new_key in model_new_sd:
            model_new_sd[new_key] = tensor
            dec_matched += 1

    if verbose:
        # Zero matches means the checkpoints' key layout does not fit this model.
        print(f'Matched encoder keys: {enc_matched}/{len(enc_pre_sd)}')
        print(f'Matched decoder keys: {dec_matched}/{len(dec_pre_sd)}')

    if verbose:
        print('=' * 70)
        print('Final integrity check...')
        print('=' * 70)

    # Non-strict load reports missing/unexpected keys (size mismatches still raise).
    incompat = model.load_state_dict(model_new_sd, strict=False)

    if verbose:
        print("Missing keys:   ", incompat.missing_keys)
        print("Unexpected keys:", incompat.unexpected_keys)

    try:
        if verbose:
            print('=' * 70)
            print('Loading merged model...')
        model.load_state_dict(model_new_sd, strict=True)
        if verbose:
            print('Done!')
            print('=' * 70)
        return model

    except RuntimeError:
        if verbose:
            print('Failed to create merged model!')
            print('=' * 70)
        return incompat
#=================================================================================================================================
# Boundary Classifier functions
#=================================================================================================================================
import os
import time
from collections import deque
from math import ceil
import numpy as np
import torch
from torch.utils.data import Dataset
from torch import nn
from typing import List, Tuple, Optional, Callable, Sequence
class BoundaryDataset(Dataset):
    """Map-style dataset pairing per-item input sequences with label sequences.

    Indexing returns a 3-tuple ``(inputs, labels, None)``; the trailing ``None``
    is a placeholder third field.
    """

    def __init__(self, inputs_list, labels_list):
        # Stored as given (no copy, no validation).
        self.inputs, self.labels = inputs_list, labels_list

    def __len__(self):
        # Dataset size is taken from the inputs list.
        return len(self.inputs)

    def __getitem__(self, idx):
        sample_inputs = self.inputs[idx]
        sample_labels = self.labels[idx]
        return sample_inputs, sample_labels, None
class BoundaryClassifier(nn.Module):
    """Encoder backbone followed by an MLP head producing `num_labels` logits.

    The backbone is a TransformerWrapper around a rotary-position, flash-attention
    Encoder; `forward` feeds its embeddings (return_embeddings=True) through the head.
    Output is presumably per-token logits of shape (batch, seq, num_labels) — confirm
    against the backbone's return shape.
    """

    def __init__(self, num_tokens: int, max_seq_len: int, dim: int = 512,
                 depth: int = 12, heads: int = 16, num_labels: int = 2,
                 pad_token_id: int = 384, dropout: float = 0.1):
        super().__init__()
        # Kept as an attribute; forward() itself does not read it.
        self.pad_token_id = pad_token_id

        encoder = Encoder(dim=dim, depth=depth, heads=heads,
                          rotary_pos_emb=True, attn_flash=True)
        self.backbone = TransformerWrapper(
            num_tokens=num_tokens,
            max_seq_len=max_seq_len,
            attn_layers=encoder,
        )

        # LayerNorm, then two (Linear -> GELU -> Dropout) stages, then the output
        # projection. Built in this exact order so Sequential indices stay 0..7.
        head_layers = [nn.LayerNorm(dim)]
        for _ in range(2):
            head_layers.extend([nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout)])
        head_layers.append(nn.Linear(dim, num_labels))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attn_mask=None):
        embeddings = self.backbone(input_ids, mask=attn_mask, return_embeddings=True)
        return self.classifier(embeddings)
class FocalLoss(nn.Module):
    """
    Multi-class focal loss over (B, N, C) logits with an ignore index.

    loss_i = alpha[t_i] * (1 - p_i) ** gamma * (-log p_i), averaged over valid
    positions. A position is valid when its target != ignore_index and, if a
    mask is given, the mask is truthy there. With gamma=0 and alpha=None this
    equals cross-entropy with `ignore_index`.

    Args:
        gamma: focusing exponent.
        alpha: optional per-class weights of length C (tensor or sequence); moved
            to the logits' device on use.
        ignore_index: target value excluded from the loss.
    """
    def __init__(self, gamma=2.0, alpha=None, ignore_index=384):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # per-class weights of shape [num_classes], or None
        self.ignore_index = ignore_index

    def forward(self, logits, targets, mask=None):
        B, N, C = logits.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        logits_flat = logits.reshape(B * N, C)
        targets_flat = targets.reshape(B * N)

        # 1. Valid positions: not ignore_index and (optionally) inside the mask.
        valid_mask = targets_flat != self.ignore_index
        if mask is not None:
            valid_mask = valid_mask & mask.reshape(-1).bool()

        if not valid_mask.any():
            # Zero that stays connected to the graph, so loss.backward() still works.
            return logits_flat[:0].sum()

        # 2. Numerically stable log-probabilities.
        log_probs = torch.log_softmax(logits_flat, dim=-1)

        # 3. Clamp targets into [0, C-1] before indexing: ignored targets (e.g. 384)
        #    would otherwise index out of range; their loss is zeroed below.
        targets_clamped = targets_flat.clamp(0, C - 1)

        # 4. log p_t of the target class, gathered once.
        log_p_t = log_probs.gather(1, targets_clamped.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()

        # 5. Focal-weighted NLL.
        loss = (1.0 - p_t) ** self.gamma * (-log_p_t)

        # 6. Class weights (alpha), indexed with the same clamped targets.
        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, device=loss.device)
            loss = alpha[targets_clamped] * loss

        # 7. Zero out invalid positions and average over valid ones.
        loss = loss * valid_mask.float()
        return loss.sum() / valid_mask.sum().clamp(min=1.0)
# Type alias for the optional status callback accepted by filter_balanced_sequences:
# either None or a callable that takes one message string and returns nothing.
Logger = Optional[Callable[[str], None]]
def filter_balanced_sequences(
    sequences: List[List[int]],
    token_types: List[List[int]],
    tol: float = 0.1,
    min_len: int = 1,
    max_len: Optional[int] = None,
    balance_target: float = 0.5,
    return_indices: bool = False,
    verbose: int = 2,
    logger: Logger = None,
    progress_chunk: int = 50000
) -> Tuple[List[List[int]], List[List[int]], Optional[List[int]]]:
    """
    Filter sequence/token-type pairs to those whose token-type distribution is near-balanced,
    with verbosity and lightweight progress reporting.

    A pair is kept when its token_types list is non-empty, satisfies the length
    bounds, and |mean(token_types) - balance_target| <= tol. Proportions are
    computed in float64 so pairs exactly on the tolerance boundary are treated
    symmetrically (e.g. 2/5 and 3/5 around 0.5 with tol=0.1 are both kept).

    Parameters
    ----------
    sequences : List[List[int]]
        Token sequences (not used for balance computation).
    token_types : List[List[int]]
        Binary token-type lists (0/1) corresponding to sequences.
    tol : float
        Allowed absolute deviation from balance_target.
    min_len : int
        Minimum token_types length to consider.
    max_len : Optional[int]
        Maximum token_types length to consider.
    balance_target : float
        Target proportion of 1s (0..1).
    return_indices : bool
        If True, also return kept indices.
    verbose : int
        0 = no printing, 1 = concise, 2 = detailed chunk diagnostics.
    logger : callable or None
        If provided, called with every status string (independently of `verbose`);
        exceptions raised by the logger are swallowed.
    progress_chunk : int
        Emit chunk updates every `progress_chunk` items when verbose >= 2.

    Returns
    -------
    filtered_sequences, filtered_token_types, indices_or_none
        The third element is the list of kept indices when return_indices is True,
        otherwise None.

    Raises
    ------
    ValueError
        If `sequences` and `token_types` differ in length.
    """
    def _log(msg: str):
        if logger:
            try:
                logger(msg)
            except Exception:
                # Best-effort logging: a failing logger must not abort filtering.
                pass
        if verbose >= 1:
            print(msg)

    if len(sequences) != len(token_types):
        raise ValueError("`sequences` and `token_types` must have the same length.")

    n = len(token_types)
    if n == 0:
        _log("Input empty: nothing to do.")
        return [], [], ([] if return_indices else None)

    start_all = time.perf_counter()
    _log(f"Starting filter: {n} pairs; tol={tol}; target={balance_target}; min_len={min_len}; max_len={max_len}")

    # Single pass over token_types collecting lengths and counts of 1s,
    # with optional chunked progress diagnostics.
    lengths = np.empty(n, dtype=np.int32)
    counts = np.empty(n, dtype=np.int32)

    t0 = time.perf_counter()
    for i, tlist in enumerate(token_types):
        lengths[i] = len(tlist)
        counts[i] = sum(tlist)
        if verbose >= 2 and (i + 1) % progress_chunk == 0:
            elapsed = time.perf_counter() - t0
            _log(f"  scanned {i+1}/{n} token_types (elapsed {elapsed:.2f}s)")
    scan_time = time.perf_counter() - t0
    _log(f"Scanned counts and lengths in {scan_time:.2f}s")

    # Length constraints (empty lists are always excluded)
    nonzero_mask = lengths > 0
    mask = nonzero_mask.copy()
    if min_len > 1:
        mask &= (lengths >= min_len)
    if max_len is not None:
        mask &= (lengths <= max_len)

    candidates = int(mask.sum())
    _log(f"Candidates after length filtering: {candidates}/{n}")

    if candidates == 0:
        _log("No candidates after length filtering. Exiting.")
        return [], [], ([] if return_indices else None)

    # Proportions in float64: float32 rounding makes boundary cases asymmetric
    # (float32(0.6) - 0.5 > 0.1 while 0.5 - float32(0.4) <= 0.1).
    lengths_f = lengths.astype(np.float64)
    proportions = np.empty_like(lengths_f)
    proportions[mask] = counts[mask].astype(np.float64) / lengths_f[mask]
    proportions[~mask] = -1.0  # sentinel; these entries are excluded by `mask` anyway

    # Balanced criterion
    balance_mask = np.abs(proportions - float(balance_target)) <= float(tol)
    final_mask = mask & balance_mask

    kept = int(final_mask.sum())
    elapsed_total = time.perf_counter() - start_all
    _log(f"Kept {kept}/{n} sequences (elapsed total {elapsed_total:.2f}s)")

    if kept == 0:
        _log("No sequences met the balance criterion. Exiting.")
        return [], [], ([] if return_indices else None)

    keep_idx = np.nonzero(final_mask)[0].tolist()

    # The only place the filtered lists are materialized.
    t_build = time.perf_counter()
    filtered_sequences = [sequences[i] for i in keep_idx]
    filtered_token_types = [token_types[i] for i in keep_idx]
    build_time = time.perf_counter() - t_build
    _log(f"Built filtered lists: {kept} items (build time {build_time:.2f}s)")

    # Optional summary stats of the kept proportions
    if verbose >= 1:
        kept_props = proportions[final_mask]
        mean_prop = float(np.mean(kept_props))
        std_prop = float(np.std(kept_props))
        min_prop = float(np.min(kept_props))
        max_prop = float(np.max(kept_props))
        _log(f"Kept proportions stats: mean={mean_prop:.4f}, std={std_prop:.4f}, min={min_prop:.4f}, max={max_prop:.4f}")

    if return_indices:
        return filtered_sequences, filtered_token_types, keep_idx
    else:
        return filtered_sequences, filtered_token_types, None
def compute_class_counts_from_list(labels_list: Sequence[Sequence[int]],
                                   num_labels: int = 2,
                                   pad_idx: int = 384) -> torch.LongTensor:
    """Count occurrences of each label 0..num_labels-1 across all sequences.

    If `pad_idx` is itself a valid label index, that class's count has the number
    of pad occurrences subtracted (floored at 0); since both tallies count the
    same value, that class ends up at 0. An empty `labels_list` yields zeros.

    Returns:
        int64 tensor of shape (num_labels,).
    """
    if len(labels_list) == 0:
        return torch.zeros(num_labels, dtype=torch.long)

    per_class = [
        sum(seq.count(label) for seq in labels_list)
        for label in range(num_labels)
    ]

    pad_is_a_label = 0 <= pad_idx < num_labels
    if pad_is_a_label:
        pad_total = sum(seq.count(pad_idx) for seq in labels_list)
        per_class[pad_idx] = max(0, per_class[pad_idx] - pad_total)

    return torch.tensor(per_class, dtype=torch.long)
def compute_class_weights(
    labels_list,
    num_labels=2,
    pad_idx=384,
    smoothing=0.0,
    power=1.0,
    max_ratio=50.0
):
    """
    Inverse-frequency class weights that keep the imbalance visible.

    Steps: per-class counts (floored at 1), inverse, optional additive
    `smoothing` (only when > 0), optional exponent `power` (only when != 1),
    scale so the smallest weight is 1.0, then cap each weight at `max_ratio`.

    Returns:
        float tensor of shape (num_labels,).
    """
    counts = compute_class_counts_from_list(labels_list, num_labels, pad_idx).float()

    # Floor at 1 so classes absent from the data don't divide by zero.
    weights = 1.0 / counts.clamp(min=1.0)

    if smoothing > 0:
        weights = weights + smoothing
    if power != 1.0:
        weights = weights ** power

    # Relative scale: the most frequent class gets weight 1.0.
    weights = weights / weights.min()

    return weights.clamp(max=max_ratio)
def collate_fn_from_lists(batch, pad_token_id=384):
    """
    Collate ``(inputs, labels, ...)`` samples into right-padded batch tensors.

    Only the first two fields of each sample are used. T is the longest input
    sequence in the batch (0 for an empty batch).

    Returns:
        input_ids: (B, T) long, inputs padded with `pad_token_id`.
        labels: (B, T) long, labels padded with `pad_token_id`. Labels longer
            than their input are truncated; shorter ones are padded rather than
            raising a shape-mismatch error.
        attn_mask: (B, T) bool, True on each sample's first len(inputs) positions.
    """
    input_seqs = [list(sample[0]) for sample in batch]
    label_seqs = [list(sample[1]) for sample in batch]

    lengths = [len(s) for s in input_seqs]
    max_len = max(lengths, default=0)
    B = len(input_seqs)

    input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long)
    labels = torch.full((B, max_len), pad_token_id, dtype=torch.long)
    attn_mask = torch.zeros((B, max_len), dtype=torch.bool)

    for i, (xseq, yseq, length) in enumerate(zip(input_seqs, label_seqs, lengths)):
        if length == 0:
            continue
        input_ids[i, :length] = torch.tensor(xseq, dtype=torch.long)
        lab = yseq[:length]
        if lab:
            labels[i, :len(lab)] = torch.tensor(lab, dtype=torch.long)
        # Mask from length, not token content, so real tokens equal to the pad id stay visible.
        attn_mask[i, :length] = True

    return input_ids, labels, attn_mask
#=================================================================================================================================
# This is the end of x_transformer_2_3_1 Python module
#=================================================================================================================================