|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from __future__ import annotations
|
|
|
| import os
|
| os.environ['USE_FLASH_ATTENTION'] = '1'
|
|
|
| import torch
|
| from torch.nn.attention import SDPBackend, sdpa_kernel
|
| torch.backends.cuda.enable_flash_sdp(True)
|
|
|
|
|
|
|
|
|
|
|
| from functools import partial
|
| from typing import Tuple, Callable
|
|
|
| import torch
|
| from torch.nn import Module
|
| from torch import nn, einsum, Tensor
|
| import torch.nn.functional as F
|
| from torch.utils.data import Dataset, DataLoader
|
|
|
| from collections import namedtuple
|
| from functools import wraps
|
| from packaging import version
|
| from dataclasses import dataclass
|
|
|
| from einops import rearrange, repeat, pack, unpack
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Intermediates:
    """Container for optional tensors captured while computing attention.

    Every field defaults to ``None``; producers fill in only what they have.
    """

    # similarity logits, attention logits before the attention function, and
    # attention weights after it
    qk_similarities: Tensor | None = None
    pre_softmax_attn: Tensor | None = None
    post_softmax_attn: Tensor | None = None
    # remaining slots are not populated in this part of the file
    values: Tensor | None = None
    cached_kv: Tuple[Tensor, Tensor] | None = None
    layer_type: str | None = None

    def to_tuple(self):
        """Return ``(qk_similarities, pre_softmax_attn, post_softmax_attn)``."""
        return self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn
|
|
|
|
|
|
|
def exists(val):
    """Return True unless `val` is None."""
    return not (val is None)
|
|
|
def default(val, d):
    """Return `val` when it is not None, otherwise the fallback `d` (returned as-is)."""
    if val is None:
        return d
    return val
|
|
|
def at_most_one_of(*bools):
    """Return True when the flags, each converted with `int`, sum to at most one."""
    num_set = sum([int(flag) for flag in bools])
    return num_set <= 1
|
|
|
def compact(arr):
    """Return a list of the items of `arr` that are not None (falsy non-None values are kept)."""
    return [item for item in arr if item is not None]
|
|
|
@torch.jit.script
def softclamp(t: Tensor, value: float):
    """Squash `t` through a tanh scaled by `value`: ``tanh(t / value) * value``."""
    squashed = torch.tanh(t / value)
    return squashed * value
|
|
|
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`; returns the (packed, packed_shapes) pair."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
|
|
|
def unpack_one(t, ps, pattern):
    """Inverse of `pack_one`: unpack `t` with einops `unpack` and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
|
|
|
def once(fn):
    """Wrap `fn` so that only its first invocation actually runs.

    The first call forwards all positional and keyword arguments to `fn` and
    returns its result. Every later call is a no-op that returns None.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # flip the flag before calling so a raising `fn` is still not retried
        called = True
        return fn(*args, **kwargs)

    return inner


# `print` that emits output only on its first call
print_once = once(print)
|
|
|
|
|
|
|
|
|
|
|
|
|
def selective_attn(
    sim,
    sim_head_gate = None,
    no_mask_sos = True
):
    """Subtract an accumulated, non-negative gate from the attention logits `sim`.

    When `sim_head_gate` is not given, ``sim[:, 0]`` is used as the gate
    source. The gate is passed through ReLU, zeroed where the padded identity
    mask is one, shifted down by one row, cumulatively summed over the row
    axis, and subtracted from `sim` (broadcast over dim 1).
    """
    device = sim.device
    i, j = sim.shape[-2:]

    sim_head_gate = default(sim_head_gate, sim[:, 0])

    # only positive gate values contribute
    gate = F.relu(sim_head_gate)

    if no_mask_sos:
        # clone first so the in-place write does not touch the caller's tensor
        gate = gate.clone()
        gate[..., -i] = 0.

    diagonal = torch.eye(i, device = device)

    if j > i:
        # extra leading key columns are filled with ones, so they are excluded below
        diagonal = F.pad(diagonal, (j - i, 0), value = 1.)

    gate = gate * (1. - diagonal)

    # shift rows down by one (prepend a zero row, drop the last row)
    gate = F.pad(gate, (0, 0, 1, -1), value = 0.)
    accumulated = gate.cumsum(dim = -2)

    return sim - rearrange(accumulated, 'b i j -> b 1 i j')
|
|
|
|
|
|
|
def qk_l2_dist_squared(q, k):
    """Return the squared `torch.cdist` distance between every query and key.

    A 3-D `k` is first repeated across ``q.shape[1]`` heads. Leading
    dimensions are packed into one for `torch.cdist` and restored afterwards.
    """
    if k.ndim == 3:
        num_heads = q.shape[1]
        k = repeat(k, 'b j d -> b h j d', h = num_heads)

    flat_q, lead_shape = pack_one(q, '* i d')
    flat_k, _ = pack_one(k, '* j d')

    dist_sq = torch.cdist(flat_q, flat_k) ** 2
    return unpack_one(dist_sq, lead_shape, '* i j')
|
|
|
|
|
|
|
def one_hot_straight_through(logits, temperature = 1.):
    """One-hot of the argmax over the last dim, with gradients taken from a softmax.

    The returned value is ``one_hot + soft - soft.detach()``, where `soft` is
    the softmax of ``logits / temperature``.
    """
    winners = logits.argmax(dim = -1, keepdim = True)
    hard = torch.zeros_like(logits).scatter(-1, winners, 1.)

    soft = torch.softmax(logits / temperature, dim = -1)
    return hard + soft - soft.detach()
|
|
|
|
|
|
|
|
|
def sparse_topk_attn(
    logits,
    sparse_topk,
    temperature = 1.,
    straight_through = False
):
    """Softmax over the last dim restricted to the `sparse_topk` largest logits.

    Entries below the k-th largest value, or already at the dtype's most
    negative finite value, are filled with that value before the softmax.
    With `straight_through`, the sparse result is detached and combined with
    ``soft - soft.detach()``, where `soft` is the softmax of
    ``logits / temperature``.
    """
    neg_max = -torch.finfo(logits.dtype).max

    kth_largest = logits.topk(sparse_topk, dim = -1).values[..., -1:]
    keep = (logits >= kth_largest) & (logits > neg_max)

    sparse_attn = logits.masked_fill(~keep, neg_max).softmax(dim = -1)

    if straight_through:
        soft = (logits / temperature).softmax(dim = -1)
        return sparse_attn.detach() + soft - soft.detach()

    return sparse_attn
|
|
|
|
|
|
|
|
|
def create_causal_mask(i, j, device):
    """Return an (i, j) boolean mask that is True strictly above diagonal ``j - i``."""
    all_true = torch.ones(i, j, dtype = torch.bool, device = device)
    return torch.triu(all_true, diagonal = j - i + 1)
|
|
|
def onnx_create_causal_mask(i, j, device):
    """Build the causal mask from an arange comparison instead of `triu`.

    The (i, i) mask is True where the row index is less than the column
    index. It is then left-padded with ``j - i`` False columns.
    """
    positions = torch.arange(i, device = device)
    upper = positions[:, None] < positions[None, :]
    return F.pad(upper, (j - i, 0), value = False)
|
|
|
|
|
|
|
class Attend(Module):
    """Attention operator: turns q/k into attention weights and aggregates v.

    Two execution paths exist. With ``flash = True`` the work is handed to
    `F.scaled_dot_product_attention` (see `flash_attn`). Otherwise the logits
    are materialised in `forward`, which enables the optional features
    (talking heads, soft-clamping, selective attention, CoPE, sparse top-k,
    sigmoid / hard attention, ...).
    """

    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        pre_talking_heads = False,
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        scale = None,
        qk_norm = False,
        l2_distance = False,
        sigmoid = False,
        custom_attn_fn: Callable | None = None,
        flash = False,
        softclamp_logits = False,
        logit_softclamp_value = 50.,
        add_zero_kv = False,
        selective = False,
        hard = False,
        cope = None,
        onnxable = False,
        sdp_kwargs: dict | None = None
    ):
        """Configure the attention variant.

        Args:
            sdp_kwargs: mapping of ``enable_*`` flags selecting the scaled
                dot-product attention backends. ``None`` (the default)
                enables flash, math and mem-efficient. A fresh dict is built
                per instance, so no mutable default is shared between
                instances.

        The remaining keyword arguments toggle the features named after
        them. Incompatible combinations are rejected with assertions.
        """
        super().__init__()

        if sdp_kwargs is None:
            sdp_kwargs = dict(
                enable_flash = True,
                enable_math = True,
                enable_mem_efficient = True
            )

        self.scale = scale

        # causal related

        self.causal = causal
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

        # attention type

        is_sparse_topk_attn = exists(sparse_topk)

        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
        assert not (flash and hard), 'hard attention not available for flash'
        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'

        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)

        if exists(custom_attn_fn):
            self.attn_fn = custom_attn_fn
        elif sigmoid:
            self.attn_fn = F.sigmoid
        elif hard:
            self.attn_fn = one_hot_straight_through
        elif is_sparse_topk_attn:
            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
        else:
            # plain softmax, computed in float32 unless qk_norm is set
            softmax_fn = partial(F.softmax, dim = -1)
            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn

        # dropouts

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads

        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'

        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None

        # every head-mixing conv is initialised with nn.init.dirac_

        if exists(self.pre_softmax_talking_heads):
            nn.init.dirac_(self.pre_softmax_talking_heads.weight)

        if exists(self.post_softmax_talking_heads):
            nn.init.dirac_(self.post_softmax_talking_heads.weight)

        if exists(self.pre_scale_post_talking_heads):
            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)

        # selective attention

        assert not (flash and selective), 'selective attention cannot work on flash attention'
        assert not (selective and not causal), 'selective attention is designed for autoregressive'
        self.selective = selective

        # l2 distance attention

        self.l2_distance = l2_distance

        # prepend an all-zero key / value token

        self.add_zero_kv = add_zero_kv

        # soft clamp of attention logits

        if softclamp_logits:
            assert not flash, 'flash attention not compatible with logit softclamp value yet'
            assert logit_softclamp_value > 0.

        self.softclamp_logits = softclamp_logits
        self.logit_softclamp_value = logit_softclamp_value

        # contextual positional encoding module (optional)

        self.cope = cope

        # flash attention

        self.flash = flash

        torch_version = version.parse(torch.__version__)
        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # torch 2.3+ selects backends through torch.nn.attention.sdpa_kernel,
        # older versions through torch.backends.cuda.sdp_kernel

        if torch_version >= version.parse('2.3'):
            from torch.nn.attention import SDPBackend

            str_to_backend = dict(
                enable_flash = SDPBackend.FLASH_ATTENTION,
                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
                enable_math = SDPBackend.MATH,
                enable_cudnn = SDPBackend.CUDNN_ATTENTION
            )

            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]

            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
        else:
            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        """Attention through `F.scaled_dot_product_attention`.

        Returns ``(out, Intermediates())``. The intermediates are empty on
        this path because the logits are never materialised.
        """
        batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device

        # a 3-D key / value (no head dim) is repeated across the query heads

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # l2 distance: append two feature columns to q and k so that their dot
        # product equals 2 q.k - |q|^2 - |k|^2, i.e. the negative squared distance

        if self.l2_distance:
            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
            k = F.pad(k, (0, 1), value = -1.)
            k = torch.cat((k, k_norm_sq), dim = -1)

            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
            q = torch.cat((2 * q, q_norm_sq), dim = -1)
            q = F.pad(q, (0, 1), value = -1.)

        # custom scale: pre-multiply q by (scale / dim ** -0.5) so that the
        # default dim ** -0.5 scaling nets out to self.scale

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        causal = self.causal

        # with a single query row causal masking is switched off

        if q_len == 1 and causal:
            causal = False

        # expand the key padding mask to the full (b, h, i, j) shape

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # more keys than queries: fold the causal mask into `mask` explicitly

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # an explicit mask plus causal: merge both into `mask` and clear the flag

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            mask = mask & ~causal_mask
            causal = False

        # remember query rows with no attendable key, they are zeroed in the output

        row_is_entirely_masked = None

        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)

        # attention bias: fold any mask into it and pass it as the attn_mask

        if exists(attn_bias):
            attn_bias = attn_bias.expand(batch, heads, -1, -1)

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device = device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            mask = attn_bias

        with self.sdp_context_manager():
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        # zero the output of rows that were entirely masked

        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, Intermediates()

    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        Returns ``(out, intermediates)``, where `intermediates` carries the
        qk similarities and the pre- / post-attention-function weights.
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        causal = self.causal

        # a 2-D mask is reshaped to broadcast over heads and query positions

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        # with a single query row causal masking is switched off

        if n == 1 and causal:
            causal = False

        # a single kv head is squeezed away, fewer kv heads are repeated up to `heads`

        if kv_heads == 1:
            k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
        elif kv_heads < heads:
            k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))

        # prepend an all-zero key / value and extend mask / bias to match

        if self.add_zero_kv:
            k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)

        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if not self.l2_distance:
            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
        else:
            sim = -qk_l2_dist_squared(q, k)

        sim = sim * scale

        if exists(prev_attn):
            sim = sim + prev_attn

        qk_similarities = sim.clone()

        if exists(self.pre_scale_post_talking_heads):
            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)

        if exists(self.pre_softmax_talking_heads):
            sim = sim + self.pre_softmax_talking_heads(sim)

        if exists(attn_bias):
            sim = sim + attn_bias

        if self.softclamp_logits:
            sim = softclamp(sim, self.logit_softclamp_value)

        i, j, dtype = *sim.shape[-2:], sim.dtype

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)

        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            sim = sim.masked_fill(causal_mask, mask_value)

        # remember query rows with no attendable key, they are zeroed in the output

        row_is_entirely_masked = None

        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)

        if exists(self.cope):
            sim = sim + self.cope(q, sim)

        if self.selective:
            sim = selective_attn(sim)

        pre_softmax_attn = sim

        attn = self.attn_fn(sim)

        # cast back to the logits' dtype (the attention function may upcast)
        attn = attn.type(dtype)

        post_softmax_attn = attn

        attn = self.attn_dropout(attn)

        if exists(self.post_softmax_talking_heads):
            attn = self.post_softmax_talking_heads(attn)

        if exists(self.pre_scale_post_talking_heads):
            attn = attn * pre_to_post_scale

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )

        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, intermediates
|
|
|
|
|
|
|
|
|
|
|
| from typing import Callable
|
|
|
| import math
|
| from copy import deepcopy
|
| from random import random, randrange
|
| from packaging import version
|
|
|
| import torch
|
| from torch.amp import autocast
|
| import torch.nn.functional as F
|
| from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
|
| from torch.utils._pytree import tree_flatten, tree_unflatten
|
| from torch.nn import Module, ModuleList, ModuleDict
|
|
|
| from functools import partial, wraps
|
| from collections import namedtuple
|
| from contextlib import nullcontext
|
| from dataclasses import dataclass
|
|
|
| import einx
|
| from einops.layers.torch import Rearrange
|
| from einops import rearrange, repeat, reduce, pack, unpack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# module-level default for the per-head dimension
DEFAULT_DIM_HEAD = 64


@dataclass
class LayerIntermediates:
    """Optional values collected while running a stack of layers.

    Every field defaults to ``None``.
    """

    hiddens: list[Tensor] | None = None
    last_hidden: Tensor | None = None
    attn_intermediates: list[Intermediates] | None = None
    layer_hiddens: list[Tensor] | None = None
    attn_z_loss: Tensor | None = None
    mems: Tensor | None = None
    memory_tokens: Tensor | None = None
    logit_entropies: Tensor | None = None


# `nn.Linear` constructor with the bias term disabled
LinearNoBias = partial(nn.Linear, bias = False)
|
|
|
|
|
|
|
def exists(val):
    """Return True unless `val` is None."""
    return not (val is None)
|
|
|
def default(val, d):
    """Return `val` if it is not None; otherwise `d`, calling it first when it is callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
|
|
|
def identity(t, *args, **kwargs):
    """Return the first argument unchanged; any extra arguments are ignored."""
    return t
|
|
|
def first(it, default = None):
    """Return ``it[0]``, or `default` when `it` has length zero."""
    if len(it) == 0:
        return default
    return it[0]
|
|
|
def is_empty(x):
    """Return True when ``len(x)`` is zero."""
    length = len(x)
    return length == 0
|
|
|
def cast_tuple(val, depth = 1):
    """Return `val` unchanged if it is a tuple, else a tuple holding `val` repeated `depth` times."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
|
|
|
def divisible_by(num, den):
    """Return whether `num` leaves no remainder when divided by `den`."""
    remainder = num % den
    return remainder == 0
|
|
|
def maybe(fn = None):
    """Make `fn` None-tolerant in its first argument.

    The returned wrapper hands back ``None`` untouched when its first argument
    is ``None``, and otherwise calls `fn` with all arguments. When `fn` is
    omitted, `identity` is wrapped instead.
    """
    wrapped = identity if fn is None else fn

    @wraps(wrapped)
    def inner(x, *args, **kwargs):
        if x is None:
            return x
        return wrapped(x, *args, **kwargs)

    return inner
|
|
|
def at_most_one_of(*bools):
    """Return True when the flags, each converted with `int`, sum to at most one."""
    num_set = sum(int(flag) for flag in bools)
    return num_set <= 1
|
|
|
class always():
    """Callable that ignores its arguments and returns the value given at construction."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        # arguments are accepted only so the object can stand in for any callable
        return self.val
|
|
|
class not_equals():
    """Callable predicate returning ``x != val`` for the stored `val`."""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # extra arguments are accepted and ignored
        return x != self.val
|
|
|
class equals():
    """Callable predicate returning ``x == val`` for the stored `val`."""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # extra arguments are accepted and ignored
        return x == self.val
|
|
|
def Sequential(*modules):
    """Build an `nn.Sequential` from the given modules, skipping any that are None."""
    present = [module for module in modules if module is not None]
    return nn.Sequential(*present)
|
|
|
|
|
|
|
def log(t, eps = 1e-20):
    """Natural log of `t` after clamping it from below at `eps`."""
    clamped = torch.clamp(t, min = eps)
    return torch.log(clamped)
|
|
|
def max_neg_value(tensor):
    """Return the most negative finite value representable in `tensor`'s dtype."""
    dtype_info = torch.finfo(tensor.dtype)
    return -dtype_info.max
|
|
|
def l2norm(t, groups = 1):
    """L2-normalise the last dimension of `t`, independently within `groups` equal chunks."""
    grouped = rearrange(t, '... (g d) -> ... g d', g = groups)
    normed = F.normalize(grouped, p = 2, dim = -1)
    return rearrange(normed, '... g d -> ... (g d)')
|
|
|
def softclamp(t, value):
    """Squash `t` through a tanh scaled by `value`: ``tanh(t / value) * value``."""
    squashed = (t / value).tanh()
    return squashed * value
|
|
|
def masked_mean(t, mask = None, dim = 1):
    """Mean of `t` over `dim`, counting only positions where `mask` is set.

    Without a mask this is a plain mean. Otherwise the mask gets trailing
    singleton dims until it has as many dims as `t`, and the denominator is
    clamped to at least one.
    """
    if mask is None:
        return t.mean(dim = dim)

    num_missing_dims = t.ndim - mask.ndim
    mask = mask.reshape(*mask.shape, *((1,) * num_missing_dims))

    total = (t * mask).sum(dim = dim)
    count = mask.sum(dim = dim).clamp(min = 1.)
    return total / count
|
|
|
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    """Pad `t` along one dimension with `pad` = (before, after) entries of `value`.

    `t` is returned as-is when `pad` equals ``(0, 0)``.
    """
    if pad == (0, 0):
        return t

    # F.pad lists pads from the last dim backwards, so count the dims after `dim`
    if dim < 0:
        dims_from_right = -dim - 1
    else:
        dims_from_right = t.ndim - dim - 1

    leading_zeros = (0, 0) * dims_from_right
    return F.pad(t, (*leading_zeros, *pad), value = value)
|
|
|
def or_reduce(masks):
    """Combine all items of `masks` with ``|``, left to right, and return the result."""
    combined, *remaining = masks
    for mask in remaining:
        # not in-place, so the first input is never mutated
        combined = combined | mask
    return combined
|
|
|
|
|
|
|
def calc_entropy(
    t: Tensor,
    is_prob = False
):
    """Entropy over the last dimension.

    `t` is used directly as probabilities when `is_prob` is set. Otherwise it
    is first passed through a softmax over the last dim. Probabilities are
    clamped at 1e-20 before the log.
    """
    if is_prob:
        prob = t
    else:
        prob = t.softmax(dim = -1)

    log_prob = prob.clamp(min = 1e-20).log()
    return -(prob * log_prob).sum(dim = -1)
|
|
|
|
|
|
|
def calc_z_loss(
    pre_softmax_attns: list[Tensor],
    mask = None,
    weight = 1.
):
    """Weighted squared log-sum-exp penalty on pre-softmax attention logits.

    The logsumexp over the last dim is summed across all given tensors,
    squared, and summed over the head axis ('b h n -> b n'). The result is
    averaged over all positions, or over the positions selected by `mask`.
    """
    # start from 0. so the first addition yields a fresh tensor
    lse = sum((attn.logsumexp(dim = -1) for attn in pre_softmax_attns), 0.)

    per_head = torch.square(lse)
    per_token = reduce(per_head, 'b h n -> b n', 'sum')

    if exists(mask):
        # the clamp keeps the division finite when nothing is selected
        masked_mean_loss = per_token[mask].sum() / mask.sum().clamp(min = 1e-5)
        return masked_mean_loss * weight

    return per_token.mean() * weight
|
|
|
|
|
|
|
def init_zero_(layer):
    """Fill `layer.weight` with zeros in place, and `layer.bias` too when it is not None."""
    nn.init.constant_(layer.weight, 0.)

    bias = layer.bias
    if bias is not None:
        nn.init.constant_(bias, 0.)
|
|
|
|
|
|
|
def pick_and_pop(keys, d):
    """Remove each of `keys` from dict `d` and return the removed entries as a new dict."""
    popped = tuple(d.pop(key) for key in keys)
    return dict(zip(keys, popped))
|
|
|
def group_dict_by_key(cond, d):
    """Split `d` into ``(matching, non_matching)`` dicts by the truthiness of ``cond(key)``."""
    matching, non_matching = dict(), dict()

    for key, value in d.items():
        bucket = matching if bool(cond(key)) else non_matching
        bucket[key] = value

    return matching, non_matching
|
|
|
def string_begins_with(prefix, str):
    """Return whether `str` starts with `prefix` (prefix comes first so it can be bound with `partial`)."""
    has_prefix = str.startswith(prefix)
    return has_prefix
|
|
|
def group_by_key_prefix(prefix, d):
    """Split `d` into ``(keys starting with prefix, all other keys)``."""
    has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(has_prefix, d)
|
|
|
def groupby_prefix_and_trim(prefix, d):
    """Split `d` by key prefix and strip the prefix from the matching keys.

    Returns ``(prefixed entries with the prefix removed, remaining entries)``.
    """
    has_prefix = partial(string_begins_with, prefix)
    prefixed, remaining = group_dict_by_key(has_prefix, d)

    cut = len(prefix)
    trimmed = {key[cut:]: value for key, value in prefixed.items()}
    return trimmed, remaining
|
|
|
|
|
|
|
def dropout_seq(seq, mask, dropout):
    """Randomly keep a subset of positions along dim 1 of `seq`.

    Random scores are drawn per position, and masked-out positions receive the
    most negative value so they are chosen last. The top
    ``max(1, int((1 - dropout) * n))`` positions are kept. Returns the
    gathered sequence and the correspondingly updated mask (None if no mask
    was given).
    """
    device = seq.device
    b, n = seq.shape[:2]

    scores = torch.randn(b, n, device = device)

    if exists(mask):
        scores = scores.masked_fill(~mask, max_neg_value(scores))

    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = scores.topk(num_keep, dim = 1).indices

    # column vector of batch indices, broadcast against keep_indices for gathering
    batch_indices = rearrange(arange(b, device = device), 'b -> b 1')

    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        # each row keeps at most ceil(valid_count * keep_prob) positions
        valid_counts = mask.sum(dim = -1)
        row_keep_counts = torch.ceil(valid_counts * keep_prob).int()
        within_quota = arange(num_keep, device = device) < rearrange(row_keep_counts, 'b -> b 1')

        mask = mask[batch_indices, keep_indices] & within_quota

    return seq, mask
|
|
|
|
|
|
|
class ReluSquared(Module):
    """Activation computing ``relu(x) ** 2``."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified ** 2
|
|
|
|
|
|
|
class TokenEmbedding(Module):
    """Token id to embedding lookup, optionally l2-normalised."""

    def __init__(self, dim, num_tokens, l2norm_embed = False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        """Look up embeddings for the ids in `x` (cast to long first)."""
        token_emb = self.emb(x.long())

        if self.l2norm_embed:
            return l2norm(token_emb)

        return token_emb

    def init_(self):
        """Re-initialise the weights: tiny normal when l2-normalising, Kaiming normal otherwise."""
        if self.l2norm_embed:
            nn.init.normal_(self.emb.weight, std=1e-5)
        else:
            nn.init.kaiming_normal_(self.emb.weight)
|
|
|
|
|
|
|
class AbsolutePositionalEmbedding(Module):
    """Learned absolute positional embedding.

    The lookup is scaled by ``dim ** -0.5`` (or by 1 when `l2norm_embed` is
    set, in which case the result is l2-normalised instead).
    """

    def __init__(self, dim, max_seq_len, l2norm_embed = False):
        super().__init__()
        self.scale = 1. if l2norm_embed else dim ** -0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return positional embeddings for the sequence length ``x.shape[1]``.

        `pos` overrides the default ``arange(seq_len)``. `seq_start_pos` is
        subtracted from the positions, and negative results are clamped to 0.
        """
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos) * self.scale

        if self.l2norm_embed:
            return l2norm(pos_emb)

        return pos_emb
|
|
|
class ScaledSinusoidalEmbedding(Module):
    """Sinusoidal positional embedding multiplied by a learned scalar.

    The scale parameter starts at ``dim ** -0.5``. The inverse frequencies are
    ``theta ** -(k / half_dim)`` for ``k`` in ``range(half_dim)`` and live in
    a non-persistent buffer.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert dim % 2 == 0, 'dim must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return the scaled sin / cos embedding for the given positions.

        Args:
            x: only ``x.shape[1]`` (sequence length) and ``x.device`` are read.
            pos: optional positions; defaults to ``arange(seq_len)``.
            seq_start_pos: optional per-sequence offsets subtracted from
                `pos`, which makes the positions (and the result) batched.

        Returns:
            Tensor of shape ``(*pos.shape, dim)``.
        """
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # ellipsis keeps any leading (batch) dims of `pos`; the previous
        # 'i, j -> i j' pattern raised once seq_start_pos made `pos` 2-D
        emb = einsum('... i, j -> ... i j', pos, self.inv_freq)
        emb = cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale
|
|
|
class RelativePositionBias(Module):
    """Learned per-head attention bias indexed by bucketed relative position."""

    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        """Map integer relative positions to bucket indices.

        Distances below ``num_buckets // 2`` (after halving the bucket count
        in the non-causal case) map to themselves. Larger ones are spaced
        logarithmically and capped at the last bucket. In the non-causal
        case the sign of the distance selects one of the two halves.
        """
        bucket = 0
        distance = -relative_position

        if causal:
            # negative distances collapse to 0
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            num_buckets //= 2
            bucket += (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_position = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_position * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        bucket += torch.where(is_small, distance, val_if_large)
        return bucket

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return the scaled bias of shape (heads, i, j) for the last `i` queries over `j` keys."""
        device = self.device

        q_pos = arange(j - i, j, dtype = torch.long, device = device)
        k_pos = arange(j, dtype = torch.long, device = device)
        rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)

        rp_bucket = self._relative_position_bucket(
            rel_pos,
            causal = self.causal,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )

        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j') * self.scale
|
|
|
class CoPE(Module):
    """
    Appendix B of https://arxiv.org/abs/2405.18719

    Produces a positional term for the attention logits. Positions are a
    reversed cumulative sum of sigmoid gates on the logits, and the term is
    read from query / position-embedding products at those positions.
    """

    def __init__(
        self,
        dim,
        heads,
        max_pos,
        soft_onehot = False,
        talking_heads = False,
        soft_onehot_temp = 5e-2
    ):
        super().__init__()
        self.max_pos = max_pos
        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))

        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
        self.soft_onehot = soft_onehot
        self.soft_onehot_temp = soft_onehot_temp

        # the integer position grid is only needed for the soft one-hot variant
        if soft_onehot:
            self.register_buffer('positions', arange(max_pos))

    def forward(self, query, attn_logits):
        """Return the positional term computed from `query` and `attn_logits`."""

        if self.talking_heads is not None:
            i, j = attn_logits.shape[-2:]
            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()

            attn_logits = self.talking_heads(attn_logits)

            # re-apply the causal mask after mixing heads
            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)

        # positions: reversed cumulative sum of the gates, capped at max_pos - 1

        gates = attn_logits.sigmoid()

        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
        pos = pos.clamp(max = self.max_pos - 1)

        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)

        if self.soft_onehot:
            # NOTE(review): the 'i, j -> i j' pattern is applied to `pos` as-is — confirm it matches pos's rank
            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
            return einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)

        # linear interpolation between the two neighbouring integer positions

        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)

        frac = pos - pos_floor
        return logits_ceil * frac + logits_floor * (1 - frac)
|
|
|
class DynamicPositionBias(Module):
    """Relative position bias produced by an MLP over scalar relative distances."""

    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        layers = ModuleList([])

        # input layer: scalar distance -> dim
        layers.append(Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim) if norm else None,
            nn.SiLU()
        ))

        # hidden layers
        for _ in range(depth - 1):
            layers.append(Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))

        # output layer: one bias value per head
        layers.append(nn.Linear(dim, heads))

        self.mlp = layers

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return a bias of shape (heads, i, j) for the last `i` queries over `j` keys."""
        device = self.device

        # index of each (query, key) relative distance in the table built below

        seq_arange = arange(j - i, j, device = device)
        context_arange = arange(j, device = device)
        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
        indices += (j - 1)

        # every relative distance from -(j - 1) to (j - 1) as MLP input

        pos = arange(-j + 1, j, device = device).float()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            # sign-preserving log(|x| + 1)
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # look up the bias for each (query, key) pair

        bias = pos[indices]
        return rearrange(bias, 'i j h -> h i j')
|
|
|
class AlibiPositionalBias(Module):
    """ALiBi-style bias: negative absolute distance times a per-head slope.

    Only the first `heads` heads receive a bias. The result is zero-padded
    along the head dim up to `total_heads`.
    """

    def __init__(
        self,
        heads,
        total_heads = None,
        slopes: list[int] | None = None,
        **kwargs
    ):
        super().__init__()
        self.heads = heads
        self.total_heads = default(total_heads, heads)

        slope_values = default(slopes, self._get_slopes(heads))
        slopes = rearrange(Tensor(slope_values), 'h -> h 1 1')

        self.register_buffer('slopes', slopes, persistent = False)
        # cache for the most recently computed bias, filled by `forward`
        self.register_buffer('bias', None, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    @staticmethod
    def _get_slopes(heads):
        """Return `heads` slopes forming geometric sequences (see the power-of-two helper)."""

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # otherwise take the nearest lower power of two, then top up with
        # every other slope of the next power of two
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        base_slopes = get_slopes_power_of_2(closest_power_of_2)
        extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
        return base_slopes + extra_slopes

    def forward_custom_pos(
        self,
        pos_i: Tensor,
        pos_j: Tensor | None = None
    ):
        """Compute the bias from explicit positions (`pos_j` defaults to `pos_i`). Nothing is cached."""
        h = self.total_heads

        pos_j = default(pos_j, pos_i)
        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()

        if bias.ndim == 3:
            bias = rearrange(bias, 'b i j -> b 1 i j')

        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        return pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)

    def forward(self, i, j):
        """Return the bias for the last `i` queries over `j` keys, reusing the cache when it is large enough."""
        h, device = self.total_heads, self.device

        cached = self.bias
        if exists(cached) and cached.shape[-1] >= j and cached.shape[-2] >= i:
            return cached[..., -i:, -j:]

        seq_arange = arange(j - i, j, device = device)
        context_arange = arange(j, device = device)
        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

        bias = bias * self.slopes
        num_heads_unalibied = h - bias.shape[-3]
        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)

        self.register_buffer('bias', bias, persistent = False)
        return self.bias
|
|
|
class DataDependentAlibi(Module):
    """ https://openreview.net/forum?id=q2Lnyegkr8 """

    def __init__(
        self,
        dim,
        heads,
        causal = True,
        bias_init = 5.,
        post_log_scale = 1.,
    ):
        super().__init__()

        self.causal = causal

        # the non-causal case needs a second set of gates for the reverse direction
        gate_sets = 1 if causal else 2
        linear = nn.Linear(dim, heads * gate_sets)

        self.to_forget_gates = nn.Sequential(
            linear,
            Rearrange('b n h -> b h n'),
            nn.LogSigmoid()
        )

        nn.init.constant_(linear.bias, bias_init)
        self.post_log_scale = post_log_scale

    def forward(self, x):
        """Return pairwise differences of cumulative log forget gates, shape 'b h i j'."""
        bidirectional = not self.causal

        forget_gates = self.to_forget_gates(x) * self.post_log_scale
        forget_gates = forget_gates.cumsum(dim = -1)

        if bidirectional:
            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)

        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)

        if not bidirectional:
            return forget_gates

        # lower triangle from the forward gates, upper triangle from the reversed ones
        forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
        return forget_gates.tril() + forget_gates_reversed.triu()
|
|
|
class PerRowDataDependentAlibi(Module):
    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """

    def __init__(
        self,
        dim,
        heads,
        causal = True,
        dim_head = 8,
        post_log_scale = 1.
    ):
        super().__init__()
        assert causal, 'bidirectional not supported yet'

        self.scale = dim_head ** -0.5

        # one projection yields both the gate queries and the gate keys
        to_qk = nn.Linear(dim, heads * dim_head * 2, bias = False)

        self.to_forget_gates = nn.Sequential(
            to_qk,
            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
        )

        self.post_log_scale = post_log_scale

    def forward(self, x):
        """Return reverse-cumulative log forget gates computed from scaled gate query / key products."""
        q, k = self.to_forget_gates(x)
        gate_logits = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

        forget_gates = F.logsigmoid(gate_logits) * self.post_log_scale

        # zero the diagonal and everything above it

        n = x.shape[-2]
        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()

        forget_gates = forget_gates.masked_fill(causal_mask, 0.)

        # cumulative sum running from the last column backwards

        reversed_gates = forget_gates.flip(dims = (-1,))
        summed = reversed_gates.cumsum(dim = -1)
        return summed.flip(dims = (-1,))
|
|
|
class RotaryEmbedding(Module):
    """Produces rotary frequencies, plus an optional xpos-style scale, for given positions."""

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()

        # rescale the frequency base by base_rescale_factor ** (dim / (dim - 2))

        base *= base_rescale_factor ** (dim / (dim - 2))

        exponents = arange(0, dim, 2).float() / dim
        inv_freq = 1. / (base ** exponents)
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if use_xpos:
            scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

            self.scale_base = scale_base
            self.register_buffer('scale', scale)
        else:
            # a None buffer marks "no scale"; forward then returns 1. for it
            self.register_buffer('scale', None)

    def forward_from_seq_len(self, seq_len):
        """Run `forward` on the positions ``arange(seq_len)``."""
        positions = arange(seq_len, device = self.inv_freq.device)
        return self.forward(positions)

    @autocast('cuda', enabled = False)
    def forward(self, t):
        """Return ``(freqs, scale)`` for positions `t`; `scale` is 1. when no scale buffer is set."""
        max_pos = t.max() + 1

        if t.ndim == 1:
            t = rearrange(t, 'n -> 1 n')

        freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor

        # duplicate each frequency so adjacent feature pairs share it
        freqs = stack((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, '... d r -> ... (d r)')

        if not exists(self.scale):
            return freqs, 1.

        # exponent is the position offset from the middle, in units of scale_base
        power = (t - (max_pos // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, '... n -> ... n 1')
        scale = stack((scale, scale), dim = -1)
        scale = rearrange(scale, '... d r -> ... (d r)')

        return freqs, scale
|
|
|
def rotate_half(x):
    """ treats the last axis as consecutive pairs (a, b) and maps each pair to (-b, a) """
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs.unbind(dim = -1)
    rotated = stack((-second, first), dim = -1)
    return rearrange(rotated, '... d r -> ... (d r)')
|
|
|
@autocast('cuda', enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """
    rotates the leading `freqs.shape[-1]` features of `t` by the given angles, leaving the rest untouched

    t     - tensor whose second to last axis is the sequence, e.g. (b, n, d) or (b, h, n, d)
    freqs - angles of shape (b, n, rot_dim), only the last `t.shape[-2]` positions are used
    scale - number, or tensor shaped like `freqs` (xpos scale), multiplied into both rotation terms

    returns a tensor with the dtype of `t`
    """
    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
    scale_is_tensor = isinstance(scale, torch.Tensor)

    # keep only the trailing positions, so a shorter `t` lines up with the end of the sequence

    freqs = freqs[:, -seq_len:, :]

    if scale_is_tensor:
        scale = scale[:, -seq_len:, :]

    # add the head axis for multi-head inputs
    # the scale needs it as well, otherwise a (b, n, d) scale with b > 1 is broadcast against the wrong axes

    if t.ndim == 4:
        if freqs.ndim == 3:
            freqs = rearrange(freqs, 'b n d -> b 1 n d')

        if scale_is_tensor and scale.ndim == 3:
            scale = rearrange(scale, 'b n d -> b 1 n d')

    # partial rotary embeddings, features past `rot_dim` pass through unrotated

    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = cat((t, t_unrotated), dim = -1)

    return out.type(orig_dtype)
|
|
|
|
|
|
|
class Scale(Module):
    """ wraps `fn` and multiplies its output by `value`, for tuple outputs only the first element is scaled """

    def __init__(self, value, fn):
        super().__init__()
        self.fn = fn
        self.value = value

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)

        if isinstance(result, tuple):
            # remaining tuple entries are passed through unchanged
            return (result[0] * self.value, *result[1:])

        return result * self.value
|
|
|
class LayerNorm(Module):
    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        """
        layernorm without a bias, only a learned per feature gain `gamma`

        with `unit_offset` the gain is stored as (gamma - 1), so it starts at zero and 1 is added back in forward
        """
        super().__init__()
        self.unit_offset = unit_offset

        self.ln = nn.LayerNorm(dim, elementwise_affine = False)

        # ones without the offset, zeros with it
        self.gamma = nn.Parameter(torch.ones(dim) - float(unit_offset))

    def forward(self, x):
        effective_gamma = self.gamma + float(self.unit_offset)
        return self.ln(x) * effective_gamma
|
|
|
class AdaptiveLayerNorm(Module):
    """
    layernorm whose gain is predicted from a conditioning tensor

    the gain projection starts at zero, so the initial gain is exactly 1
    """

    def __init__(
        self,
        dim,
        dim_condition = None
    ):
        super().__init__()

        # the condition is taken to have the same width as the input unless stated otherwise
        cond_dim = default(dim_condition, dim)

        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        self.to_gamma = LinearNoBias(cond_dim, dim)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, *, condition):
        # a (b, d) condition gets a singleton sequence axis so it broadcasts over positions

        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        gain = self.to_gamma(condition) + 1.
        return self.ln(x) * gain
|
|
|
class ScaleNorm(Module):
    """
    l2 normalizes the last axis, then multiplies by sqrt(dim) and a single learned scalar gain

    with `unit_offset` the gain is stored as (g - 1) and 1 is added back in forward
    """

    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        super().__init__()
        self.scale = dim ** 0.5
        self.unit_offset = unit_offset

        # one shared scalar, 1 without the offset and 0 with it
        self.g = nn.Parameter(torch.full((1,), 1. - float(unit_offset)))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gain = self.g + float(self.unit_offset)
        return unit * self.scale * gain
|
|
|
class RMSNorm(Module):
    """
    l2 normalizes the last axis, then multiplies by sqrt(dim) and a learned per feature gain

    with `unit_offset` the gain is stored as (g - 1) and 1 is added back in forward
    """

    def __init__(
        self,
        dim,
        unit_offset = False
    ):
        super().__init__()
        self.scale = dim ** 0.5
        self.unit_offset = unit_offset

        # per feature gain, ones without the offset and zeros with it
        self.g = nn.Parameter(torch.zeros(dim) + (1. - float(unit_offset)))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gain = self.g + float(self.unit_offset)
        return unit * self.scale * gain
|
|
|
class AdaptiveRMSNorm(Module):
    """
    rmsnorm (l2 normalize, times sqrt(dim)) whose gain is predicted from a conditioning tensor

    the gain projection starts at zero, so the initial gain is exactly 1
    """

    def __init__(
        self,
        dim,
        dim_condition = None
    ):
        super().__init__()
        self.scale = dim ** 0.5

        # the condition is taken to have the same width as the input unless stated otherwise
        cond_dim = default(dim_condition, dim)

        self.to_gamma = LinearNoBias(cond_dim, dim)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, *, condition):
        # a (b, d) condition gets a singleton sequence axis so it broadcasts over positions

        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        unit = F.normalize(x, dim = -1)
        gain = self.to_gamma(condition) + 1.
        return unit * self.scale * gain
|
|
|
class SimpleRMSNorm(Module):
    """ parameter free rmsnorm: l2 normalize the last axis and multiply by sqrt(dim), extra kwargs are ignored """

    def __init__(
        self,
        dim,
        **kwargs
    ):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale
|
|
|
class MultiheadRMSNorm(Module):
    """
    rmsnorm with a separate learned gain for every head

    the gain is stored as (gamma - 1) with shape (heads, 1, dim), so it starts out as the identity
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        per_head_gain = self.gamma + 1.
        return self.rmsnorm(x) * per_head_gain
|
|
|
class DynamicTanh(Module):
    """
    https://arxiv.org/abs/2503.10622

    computes tanh(x * alpha) * gamma + beta, with a learned scalar alpha and learned per feature gamma and beta

    dim         - size of the feature axis
    init_alpha  - initial value of the scalar applied before the tanh
    gamma       - initial value of the per feature gain
    beta        - initial value of the per feature bias
    unit_offset - store alpha and gamma as offsets from their initial effective values
                  (alpha parameter starts at 0, gamma parameter at gamma - 1), added back in forward
    """

    def __init__(
        self,
        dim,
        init_alpha = 1.,
        gamma = 1.,
        beta = 0.,
        unit_offset = False
    ):
        super().__init__()

        self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
        self.gamma_offset = float(unit_offset)

        # cast to float so an integer `init_alpha` still gives a floating point (trainable) parameter

        self.pre_tanh_scale = nn.Parameter(torch.tensor(0. if unit_offset else float(init_alpha)))

        # `gamma` and `beta` are honored as initial values, the defaults give ones and zeros

        self.gamma = nn.Parameter(torch.ones(dim) * (float(gamma) - self.gamma_offset))
        self.beta = nn.Parameter(torch.zeros(dim) + float(beta))

    def forward(self, x):
        pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
        gamma = self.gamma + self.gamma_offset
        return (x * pre_tanh_scale).tanh() * gamma + self.beta
|
|
|
|
|
|
|
class Residual(Module):
    """
    plain residual connection: branch output plus the (optionally rescaled) residual

    the residual can be scaled by a learned per feature vector and / or by a constant
    """

    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
        super().__init__()

        if scale_residual:
            self.residual_scale = nn.Parameter(torch.ones(dim))
        else:
            self.residual_scale = None

        self.scale_residual_constant = scale_residual_constant

    def prepare(self, residual):
        # returns (branch input, residual, extra kwargs for forward), nothing to transform here
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        learned_scale = self.residual_scale

        if learned_scale is not None:
            residual = residual * learned_scale

        constant = self.scale_residual_constant

        if constant != 1:
            residual = residual * constant

        return x + residual
|
|
|
class GRUGating(Module):
    """
    replaces the residual addition with a GRU cell

    the branch output is the cell input and the (optionally rescaled) residual is the hidden state
    """

    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

        if scale_residual:
            self.residual_scale = nn.Parameter(torch.ones(dim))
        else:
            self.residual_scale = None

    def prepare(self, residual):
        # returns (branch input, residual, extra kwargs for forward), nothing to transform here
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        learned_scale = self.residual_scale

        if learned_scale is not None:
            residual = residual * learned_scale

        # the cell works on 2d inputs, so batch and sequence are merged for the call

        flat_input = rearrange(x, 'b n d -> (b n) d')
        flat_hidden = rearrange(residual, 'b n d -> (b n) d')

        gated = self.gru(flat_input, flat_hidden)

        return gated.reshape_as(x)
|
|
|
|
|
|
|
class HyperConnection(Module):
    def __init__(
        self,
        dim,
        *,
        layer_index,
        num_residual_streams,
        num_input_views = 1,
        tanh = True,
        **kwargs
    ):
        """
        https://arxiv.org/abs/2409.19606
        Appendix J - Algorithm 2, Dynamic only

        residual streams arrive folded into the batch axis as ((b s), n, d)

        `prepare` mixes the streams into the branch input(s) and the updated streams,
        `forward` writes the branch output back onto every stream, weighted by beta
        """
        super().__init__()

        self.act = nn.Tanh() if tanh else nn.Identity()
        self.norm = nn.LayerNorm(dim, bias = False)

        self.layer_index = layer_index
        self.num_residual_streams = num_residual_streams
        self.num_input_views = num_input_views

        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

        # static alpha is (streams, views + streams)
        # the view columns select stream (layer_index % streams), the remaining columns are the identity

        view_columns = torch.zeros((num_residual_streams, num_input_views))
        view_columns[layer_index % num_residual_streams, :] = 1.

        stream_columns = torch.eye(num_residual_streams)

        self.static_alpha = nn.Parameter(cat([view_columns, stream_columns], dim = 1))

        # dynamic (input dependent) parts start at zero, with a small learned scale

        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
        self.dynamic_alpha_scale = nn.Parameter(torch.full((), 1e-2))

        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
        self.dynamic_beta_scale = nn.Parameter(torch.full((), 1e-2))

    def prepare(self, residuals):
        streams = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)

        normed = self.norm(streams)

        # alpha is (b, n, s, views + s) and beta is (b, n, s), each a dynamic term plus the static parameter

        alpha = self.act(normed @ self.dynamic_alpha_fn) * self.dynamic_alpha_scale + self.static_alpha
        beta = self.act(normed @ self.dynamic_beta_fn) * self.dynamic_beta_scale + self.static_beta

        # width connection

        mixed = einsum('... s t, ... s d -> ... t d', alpha, streams)

        # the first `views` rows feed the branch, the rest are the new residual streams

        num_views = self.num_input_views

        if num_views == 1:
            branch_input = mixed[..., 0, :]
        else:
            branch_input = rearrange(mixed[..., :num_views, :], '... v d -> v ... d')

        return branch_input, mixed[..., num_views:, :], dict(beta = beta)

    def forward(self, x, residuals, *, beta):
        # depth connection, then fold the streams back into the batch axis

        updated = einsum('b n d, b n s -> b n s d', x, beta) + residuals
        return rearrange(updated, 'b n s d -> (b s) n d')
|
|
|
|
|
|
|
class DynamicLIMe(Module):
    """
    input dependent weighted combination of the hidden states of `num_layers` layers

    per token weights over the layers are predicted from `x`, one set per view,
    and normalized with a softmax (or passed through a relu if `use_softmax` is False)
    """

    def __init__(
        self,
        dim,
        num_layers,
        num_views = 1,
        norm = True,
        use_softmax = True
    ):
        super().__init__()
        self.num_layers = num_layers
        self.multiple_views = num_views > 1

        maybe_norm = RMSNorm(dim) if norm else None
        to_logits = nn.Linear(dim, num_views * num_layers)
        split_views = Rearrange('... (views layers) -> views ... layers', views = num_views)
        activation = nn.Softmax(dim = -1) if use_softmax else nn.ReLU()

        self.to_weights = Sequential(maybe_norm, to_logits, split_views, activation)

    def forward(
        self,
        x,
        hiddens
    ):
        # a sequence of per layer tensors is stacked along a new leading layer axis

        if not is_tensor(hiddens):
            hiddens = stack(hiddens)

        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'

        weights = self.to_weights(x)

        combined = einsum('l b n d, v b n l -> v b n d', hiddens, weights)

        if not self.multiple_views:
            # single view, drop the leading view axis
            return rearrange(combined, '1 ... -> ...')

        return combined
|
|
|
|
|
|
|
def shift(t, amount, mask = None):
    """
    shifts `t` along its second to last (sequence) axis by `amount` positions, filling with zeros

    positive amounts move tokens towards later positions, negative amounts towards earlier ones
    if `mask` is given, positions where it is False are zeroed before shifting
    """
    if amount == 0:
        return t

    # clamp in both directions to the length of the axis that is padded,
    # a shift larger than the sequence would otherwise crop more than exists and error out

    seq_len = t.shape[-2]
    amount = max(-seq_len, min(amount, seq_len))

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return pad_at_dim(t, (amount, -amount), dim = -2, value = 0.)
|
|
|
class ShiftTokens(Module):
    """
    splits the feature axis into one chunk per entry of `shifts`, shifts each chunk along the
    sequence by its amount, then calls `fn` on the re-concatenated tensor

    chunks beyond the number of shifts (left over when the split is uneven) are passed through untouched
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)

        num_shifts = len(self.shifts)
        chunk_size = x.shape[-1] // num_shifts

        chunks = x.split(chunk_size, dim = -1)
        to_shift, untouched = chunks[:num_shifts], chunks[num_shifts:]

        shifted = [shift(chunk, amount, mask = mask) for chunk, amount in zip(to_shift, self.shifts)]

        return self.fn(cat((*shifted, *untouched), dim = -1), **kwargs)
|
|
|
class FoldAxially(Module):
    """
    runs `fn` on `axial_dim` interleaved sub-sequences of the input

    the sequence is padded to a multiple of `axial_dim`, the axial factor is folded into the batch,
    and the first tensor of the output of `fn` is unfolded and cropped back to the original length
    """

    def __init__(
        self,
        axial_dim,
        fn: Module
    ):
        super().__init__()
        self.axial_dim = axial_dim
        self.fn = fn

    def forward(
        self,
        x,
        **kwargs
    ):
        axial_dim = self.axial_dim

        # nothing to fold

        if axial_dim == 1:
            return self.fn(x, **kwargs)

        orig_len = x.shape[1]

        # pad on the right up to the next multiple of the axial dimension

        padded_len = math.ceil(orig_len / axial_dim) * axial_dim
        x = pad_at_dim(x, (0, padded_len - orig_len), dim = 1)

        folded = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)

        fn_out = self.fn(folded, **kwargs)

        # only the first leaf of the output is unfolded, the other leaves are returned as they came

        (first, *others), tree_spec = tree_flatten(fn_out)

        first = rearrange(first, '(b axial_dim) n ... -> b (n axial_dim) ...', axial_dim = axial_dim)
        first = first[:, :orig_len]

        return tree_unflatten((first, *others), tree_spec)
|
|
|
|
|
|
|
class LayerScale(Module):
    """
    multiplies the output of `fn` by a learned per feature gain

    the gain starts at `init_value`; with `unit_offset` it is stored as (gain - 1) and 1 is added back in forward
    for non tensor outputs only the first element is scaled
    """

    def __init__(
        self,
        fn: Module,
        dim,
        init_value = 0.,
        unit_offset = False
    ):
        super().__init__()
        self.unit_offset = unit_offset

        self.fn = fn
        self.gamma = nn.Parameter(torch.zeros(dim) + (init_value - float(unit_offset)))

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)

        gain = self.gamma + float(self.unit_offset)

        if not isinstance(result, Tensor):
            # remaining outputs are passed through unchanged
            first, *others = result
            return (first * gain, *others)

        return result * gain
|
|
|
class AdaptiveLayerScale(Module):
    """
    multiplies the output of `fn` by a sigmoid gate predicted from a conditioning tensor

    the projection weight starts at zero and the bias at `init_bias_value`,
    so the gate initially equals sigmoid(init_bias_value) everywhere
    for non tensor outputs only the first element is gated
    """

    def __init__(
        self,
        fn: Module,
        dim,
        dim_condition = None,
        init_bias_value = -2.
    ):
        super().__init__()
        self.fn = fn

        # the condition is taken to have the same width as the input unless stated otherwise
        cond_dim = dim if dim_condition is None else dim_condition

        self.to_gamma = nn.Linear(cond_dim, dim)

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)

    def forward(self, x, *, condition, **kwargs):
        # a (b, d) condition gets a singleton sequence axis so it broadcasts over positions

        if condition.ndim == 2:
            condition = rearrange(condition, 'b d -> b 1 d')

        result = self.fn(x, **kwargs)
        gate = self.to_gamma(condition).sigmoid()

        if not isinstance(result, Tensor):
            # remaining outputs are passed through unchanged
            first, *others = result
            return (first * gate, *others)

        return result * gate
|
|
|
|
|
|
|
class ConcatCombine(Module):
    """
    skip connection by concatenation: the output of an earlier layer, picked by index,
    is concatenated in front of the current features and projected back down to `dim`
    """

    def __init__(self, dim, prev_layer_ind):
        super().__init__()
        self.prev_layer_ind = prev_layer_ind
        self.combine = LinearNoBias(dim * 2, dim)

    def forward(self, x, prev_layers: list[Tensor]):
        earlier = prev_layers[self.prev_layer_ind]
        joined = cat((earlier, x), dim = -1)
        return self.combine(joined)
|
|
|
|
|
|
|
class GLU(Module):
    """
    gated linear unit: a single projection to twice `dim_out` is split into values and gates,
    and the values are multiplied by the activated gates

    with `mult_bias` the result is further multiplied by a learned per feature vector (initialized to ones)
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

        if mult_bias:
            self.mult_bias = nn.Parameter(torch.ones(dim_out))
        else:
            self.mult_bias = 1.

    def forward(self, x):
        values, gates = self.proj(x).chunk(2, dim = -1)
        gated = values * self.act(gates)
        return gated * self.mult_bias
|
|
|
class FeedForward(Module):
    """
    transformer feedforward block: project in (optionally gated), activation, optional layernorm,
    dropout, project out, optional sublayer dropout

    dim               - input feature dimension
    dim_out           - output feature dimension, defaults to `dim`
    mult              - hidden dimension is int(dim * mult)
    glu               - use a gated linear unit for the input projection
    glu_mult_bias     - forwarded to the GLU as `mult_bias`
    swish             - use SiLU instead of GELU
    relu_squared      - use squared relu, takes precedence over `swish`
    custom_activation - module copied and used as activation, takes precedence over the flags above
    post_act_ln       - layernorm after the activation
    dropout           - dropout on the hidden activations
    sublayer_dropout  - dropout on the block output, only added when greater than 0
    no_bias           - remove the bias of the non gated input projection and of the output projection
    zero_init_output  - zero initialize the output projection
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        glu = False,
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        custom_activation = None,
        post_act_ln = False,
        dropout = 0.,
        sublayer_dropout = 0.,
        no_bias = False,
        zero_init_output = False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if exists(custom_activation):
            activation = deepcopy(custom_activation)
        elif relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        if glu:
            project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias = not no_bias),
                activation
            )

        # keep a handle on the output projection, it is not the last layer when sublayer dropout is enabled

        project_out = nn.Linear(inner_dim, dim_out, bias = not no_bias)

        self.ff = Sequential(
            project_in,
            LayerNorm(inner_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            project_out,
            nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None
        )

        # init the output projection to 0
        # indexing `self.ff[-1]` would pick the trailing dropout whenever `sublayer_dropout > 0`

        if zero_init_output:
            init_zero_(project_out)

    def forward(self, x):
        return self.ff(x)
|
|
|
|
|
|
|
class Attention(Module):
    """
    multi-head attention layer: projects the input to queries / keys / values, applies the optional
    extras selected in the constructor, delegates the attention itself to `Attend`, and projects
    the merged heads to the output dimension

    `forward` returns the output tensor, or (output, intermediates) when `return_intermediates` is set,
    with `intermediates.cached_kv` and `intermediates.values` filled in for caching and value residuals
    """

    def __init__(
        self,
        dim,
        dim_head = DEFAULT_DIM_HEAD,
        dim_context = None,
        heads = 8,
        causal = False,
        flash = False,
        pre_talking_heads = False,
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        sparse_topk_straight_through = False,
        num_mem_kv = 0,
        dropout = 0.,
        sublayer_dropout = 0.,
        on_attn = False,
        gate_value_heads = False,
        swiglu_values = False,
        gate_values = False,
        zero_init_output = False,
        hard = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        l2_distance = False,
        sigmoid = False,
        selective = False,
        custom_attn_fn: Callable | None = None,
        hybrid_module: Module | None = None,
        hybrid_mask_kwarg: str | None = None,
        hybrid_fold_axial_dim: int | None = None,
        hybrid_learned_mix = False,
        one_kv_head = False,
        kv_heads = None,
        value_dim_head = None,
        dim_out = None,
        add_zero_kv = False,
        rotate_num_heads = None,
        data_dependent_alibi = False,
        data_dependent_alibi_per_row = False,
        data_dependent_alibi_per_row_dim_head = 8,
        data_dependent_alibi_kwargs: dict = dict(),
        use_cope = False,
        cope_max_pos = 16,
        cope_soft_onehot_pos = False,
        cope_talking_heads = False,
        softclamp_logits = False,
        logit_softclamp_value = 50.,
        learned_value_residual_mix = False,
        laser = False,
        laser_softclamp_value = 15.,
        qkv_receive_diff_residuals = False,
        use_latent_q = False,
        dim_latent_q = None,
        use_latent_kv = False,
        dim_latent_kv = None,
        latent_rope_subheads = None,
        onnxable = False,
        attend_sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        # note: the dict defaults in the signature are shared between instances, they are only read here

        super().__init__()
        dim_kv = default(dim_context, dim)

        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'

        value_dim_head = default(value_dim_head, dim_head)
        kv_heads = default(kv_heads, heads)

        kv_heads = 1 if one_kv_head else kv_heads
        assert divisible_by(heads, kv_heads)

        self.kv_heads = kv_heads

        q_dim = dim_head * heads
        k_dim = dim_head * kv_heads
        v_dim = value_dim_head * kv_heads
        out_dim = value_dim_head * heads

        # input dimensions of the q / k / v projections depend on whether latent q and kv are used
        # (multi-latent attention)

        self.to_latent_q = None
        self.to_latent_kv = None
        self.to_rotateable_k = None # key subheads projected directly from the base sequence, bypassing the latents

        dim_q_input = dim
        dim_kv_input = dim_kv

        if use_latent_q:
            assert exists(dim_latent_q)
            self.to_latent_q = LinearNoBias(dim, dim_latent_q)
            dim_q_input = dim_latent_q

        if use_latent_kv:
            assert exists(dim_latent_kv)

            # the latent kv is projected from the key / value input, which has the context dimension when a context is used
            self.to_latent_kv = LinearNoBias(dim_kv, dim_latent_kv)
            dim_kv_input = dim_latent_kv

        if exists(latent_rope_subheads):
            assert not exists(rotate_num_heads), '`rotate_num_heads` cannot be set when multi-latent attention is being used'
            rotate_num_heads = latent_rope_subheads

            # the rotateable subheads replace that many of the regular key heads

            k_dim = dim_head * (kv_heads - latent_rope_subheads)

            # also projected from the key / value input, hence `dim_kv`
            self.to_rotateable_k = LinearNoBias(dim_kv, dim_head * latent_rope_subheads)
            self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

        self.use_latent_q = use_latent_q
        self.use_latent_kv = use_latent_kv

        # query, key, value projections

        self.to_q = LinearNoBias(dim_q_input, q_dim)
        self.to_k = LinearNoBias(dim_kv_input, k_dim)
        self.to_v = LinearNoBias(dim_kv_input, v_dim)

        # split and merge of attention heads

        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
        self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)

        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # whether the input is a stack of three separate residuals for q, k, v

        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

        # laser: attend over exponentiated (soft clamped) values and take the log of the output

        self.laser = laser
        self.laser_softclamp_value = laser_softclamp_value

        # gating of the aggregated values, the bias of 10 makes a sigmoid gate start near 1

        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, out_dim)
            self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 10)

        # per head sigmoid gating of the output, also starting near 1

        self.to_v_head_gate = None
        if gate_value_heads:
            self.to_v_head_gate = nn.Linear(dim, heads)
            nn.init.constant_(self.to_v_head_gate.weight, 0)
            nn.init.constant_(self.to_v_head_gate.bias, 10)

        # l2 normalized queries and keys

        self.qk_norm = qk_norm
        self.qk_norm_groups = qk_norm_groups
        self.qk_norm_scale = qk_norm_scale

        # optional learned per head, per feature scales applied after the qk normalization

        self.qk_norm_dim_scale = qk_norm_dim_scale

        self.qk_norm_q_scale = self.qk_norm_k_scale = 1
        if qk_norm and qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
            self.qk_norm_k_scale = nn.Parameter(torch.ones(kv_heads, 1, dim_head))

        assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
        assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

        # contextual positional encoding

        cope = None

        if use_cope:
            assert causal, 'CoPE was designed for causal attention'
            assert not flash, 'CoPE is not flash attention compatible'

            cope = CoPE(
                dim = dim_head,
                heads = heads,
                max_pos = cope_max_pos,
                talking_heads = cope_talking_heads,
                soft_onehot = cope_soft_onehot_pos
            )

        # data dependent alibi

        self.data_dependent_alibi = None

        if data_dependent_alibi:

            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)

            if data_dependent_alibi_per_row:
                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)

            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)

        # the module that performs the attention itself

        self.attend = Attend(
            heads = heads,
            causal = causal,
            pre_talking_heads = pre_talking_heads,
            post_talking_heads = post_talking_heads,
            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
            sparse_topk_straight_through = sparse_topk_straight_through,
            hard = hard,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
            l2_distance = l2_distance,
            sigmoid = sigmoid,
            selective = selective,
            custom_attn_fn = custom_attn_fn,
            add_zero_kv = add_zero_kv,
            flash = flash,
            softclamp_logits = softclamp_logits,
            logit_softclamp_value = logit_softclamp_value,
            cope = cope,
            onnxable = onnxable,
            sdp_kwargs = attend_sdp_kwargs
        )

        # learned scale per head

        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention

        self.sparse_topk = sparse_topk

        # learned memory key / values
        # memory values use the value head dimension so they can be concatenated to the values in forward

        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, value_dim_head))

        # mix between the values and a passed in value residual, learned per token and head, or a constant 0.5

        self.to_value_residual_mix = nn.Sequential(
            nn.Linear(dim, heads),
            nn.Sigmoid(),
            Rearrange('b n h -> b h n 1')
        ) if learned_value_residual_mix else always(0.5)

        # attention on attention

        self.attn_on_attn = on_attn

        # hybrid module, run in parallel to attention and combined with its output

        hybrid_mix = None
        hybrid_norms = None
        hybrid_module = maybe(deepcopy)(hybrid_module)

        if exists(hybrid_module):
            # folding is optional, but `forward` needs the norms for any hybrid module

            if exists(hybrid_fold_axial_dim):
                hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)

            hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None

            hybrid_norms = ModuleList([
                MultiheadRMSNorm(dim_head, heads = heads),
                MultiheadRMSNorm(dim_head, heads = heads)
            ])

        self.hybrid_module = hybrid_module
        self.hybrid_norms = hybrid_norms
        self.hybrid_mix = hybrid_mix
        self.hybrid_mask_kwarg = hybrid_mask_kwarg

        # output projection, by default back to the input dimension

        dim_out = default(dim_out, dim)
        self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

        # sublayer dropout

        self.sublayer_dropout = nn.Dropout(sublayer_dropout) if sublayer_dropout > 0. else None

        # the number of attention heads to rotate, all of them unless stated otherwise

        rotate_num_heads = default(rotate_num_heads, heads)

        assert 0 < rotate_num_heads <= heads
        is_partial_rotate_heads = rotate_num_heads < heads
        assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'

        self.rotate_num_heads = rotate_num_heads

        # flag read by the parent to decide whether key / values can be cached

        self.can_cache_kv = not selective

        # init output projection 0
        # with `on_attn` the projection is the linear inside the (linear, GLU) pair, zero that layer rather than the container

        if zero_init_output:
            out_proj = self.to_out[0] if on_attn else self.to_out
            init_zero_(out_proj)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        attn_bias = None,
        rotary_pos_emb = None,
        context_rotary_pos_emb = None,
        pos = None,
        prev_attn = None,
        mem = None,
        mem_mask = None,
        return_intermediates = False,
        cache: Intermediates | None = None,
        value_residual = None
    ):
        """
        x                      - (b, n, dim), or a stack (3, b, n, dim) when `qkv_receive_diff_residuals` is set
        context                - optional sequence to derive keys / values from (cross attention)
        mask, context_mask     - boolean (b, n) / (b, m) masks, True for positions to keep
        attn_mask              - boolean mask over (i, j), with up to two leading (batch, head) axes
        rel_pos                - callable producing an attention bias from (i, j)
        attn_bias              - precomputed attention bias, mutually exclusive with `rel_pos`
        rotary_pos_emb         - tuple (freqs, xpos_scale) for queries (and keys without context)
        context_rotary_pos_emb - tuple (freqs, xpos_scale) for keys, read whenever a context is given together with `rotary_pos_emb`
        pos                    - custom positions, forwarded to `rel_pos.forward_custom_pos`
        prev_attn              - forwarded to `Attend`
        mem, mem_mask          - memory tokens prepended to the key / value input, and their mask
        cache                  - intermediates of a previous call, its `cached_kv` is prepended to the new key / values
        value_residual         - values to interpolate the current values with
        """
        h, head_scale, num_mem_kv, device = self.heads, self.head_scale, self.num_mem_kv, x.device
        has_context = exists(context)
        qkv_receive_diff_residuals = self.qkv_receive_diff_residuals
        is_multi_latent_attn = self.use_latent_kv

        # queries, keys, values may each come from a different residual stream

        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

        if qkv_receive_diff_residuals:
            assert x.ndim == 4 and x.shape[0] == 3

            q_input, k_input, v_input = x
        else:
            kv_input = default(context, x)
            q_input, k_input, v_input = x, kv_input, kv_input

        # batch and sequence length are read after unstacking, `x` itself may carry a leading axis of 3

        b, n = q_input.shape[:2]

        if exists(mem):
            k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
            v_input, _ = pack([mem, v_input], 'b * d')

        # multi-latent attention logic

        k_sub_heads = None # the rotateable key subheads derived from the base sequence

        if self.use_latent_q:
            q_input = self.to_latent_q(q_input)

        if is_multi_latent_attn:
            assert not qkv_receive_diff_residuals
            needs_k_sub_heads = exists(self.to_rotateable_k)

            latent_kv_input = self.to_latent_kv(k_input)

            if needs_k_sub_heads:
                rotateable_k = self.to_rotateable_k(k_input)
                k_sub_heads = self.split_rotateable_k_heads(rotateable_k)

            # here the cache holds the latents (and the key subheads), not the projected key / values

            if exists(cache):
                cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)

                if exists(maybe_cached_k_sub_heads):
                    k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)

            if return_intermediates:
                cached_kv = (latent_kv_input, k_sub_heads)

            k_input = v_input = latent_kv_input

        # query, key, value projection

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        q = self.split_q_heads(q)
        k = self.split_k_heads(k)
        v = self.split_v_heads(v)

        # the rotateable subheads become the last key heads

        if exists(k_sub_heads):
            k = cat((k, k_sub_heads), dim = 1)

        # keep the values as projected, they are returned on the intermediates

        orig_values = v

        # interpolate with the value residual, if one was passed in

        if exists(value_residual):
            value_residual_mix = self.to_value_residual_mix(q_input)
            v = value_residual.lerp(v, value_residual_mix)

        # qk normalization

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))

            q = q * self.qk_norm_q_scale
            k = k * self.qk_norm_k_scale

        # take care of caching (multi-latent attention caches its latents above instead)

        if not is_multi_latent_attn:
            if exists(cache):
                ck, cv = cache.cached_kv

                # cached key / values go between the memory tokens and the new tokens

                if exists(mem):
                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
                    mv, v = unpack(v, mem_packed_shape, 'b h * d')

                k = cat((ck, k), dim = -2)
                v = cat((cv, v), dim = -2)

                if exists(mem):
                    k = cat((mk, k), dim = -2)
                    v = cat((mv, v), dim = -2)

            if return_intermediates:
                # memory tokens are excluded from what gets cached
                mem_len = mem.shape[-2] if exists(mem) else 0
                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

        if exists(rotary_pos_emb):
            rotate_num_heads = self.rotate_num_heads
            partial_rotate_heads = rotate_num_heads < h

            freqs, xpos_scale = rotary_pos_emb
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

            # only the last `rotate_num_heads` heads are rotated

            if partial_rotate_heads:
                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]

            q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

            if has_context:
                # keys of a context use the rotary embeddings of the context

                freqs, xpos_scale = context_rotary_pos_emb
                _, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

            k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

            if partial_rotate_heads:
                q = cat((q_rest, q), dim = 1)
                k = cat((k_rest, k), dim = 1)

        # the key padding mask is the context mask, or `mask` for self attention

        input_mask = context_mask

        if not exists(input_mask) and not has_context:
            input_mask = mask

        # extend the mask over the memory tokens, whichever of the two masks is missing counts as all True

        if (exists(input_mask) or exists(mem_mask)) and exists(mem):
            seq_len, mem_len = n, mem.shape[-2]

            if not exists(mem_mask):
                input_mask = pad_at_dim(input_mask, (mem_len, 0), dim = -1, value = True)
            elif not exists(input_mask):
                input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
            else:
                input_mask = cat((mem_mask, input_mask), dim = -1)

        # i, j are taken before the memory key / values are added, for the relative positional bias

        i, j = tuple(t.shape[-2] for t in (q, k))

        # maybe prepend memory key / values

        if num_mem_kv > 0:
            mem_k, mem_v = tuple(repeat(t, 'h n d -> b h n d', b = b) for t in (self.mem_k, self.mem_v))

            if self.qk_norm:
                mem_k = l2norm(mem_k)
                mem_k = mem_k * self.qk_norm_k_scale

            k = cat((mem_k, k), dim = -2)
            v = cat((mem_v, v), dim = -2)

            if exists(input_mask):
                input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

        # determine masking, collected as "positions to mask out" and inverted at the end

        mask_value = max_neg_value(q)
        masks = []
        final_attn_mask = None

        if exists(input_mask):
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        if exists(attn_mask):
            assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
            masks.append(~attn_mask)

        if exists(self.max_attend_past):
            range_q = arange(j - i, j, device = device)
            range_k = arange(j, device = device)
            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
            max_attend_past_mask = dist > self.max_attend_past
            max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # memory key / values are never masked here
            masks.append(max_attend_past_mask)

        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        # prepare relative positional bias, if needed

        if exists(rel_pos):
            assert not exists(attn_bias)

            if exists(pos):
                assert isinstance(rel_pos, AlibiPositionalBias), 'only alibi allowed for custom positions at the moment'

                attn_bias = rel_pos.forward_custom_pos(pos)
            else:
                attn_bias = rel_pos(i, j)

            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0)) # account for the memory key / values

        # data dependent alibi, replaces any attention bias from above

        if exists(self.data_dependent_alibi):
            attn_bias = self.data_dependent_alibi(x)

            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))

        if self.laser:
            v = softclamp(v, self.laser_softclamp_value)
            v = v.exp()

        # attention

        out, intermediates = self.attend(
            q, k, v,
            mask = final_attn_mask,
            attn_bias = attn_bias,
            prev_attn = prev_attn
        )

        # laser, back to log space

        if self.laser:
            out = log(out)

        # store the values on the intermediates

        intermediates.values = orig_values

        # learned scaling of heads

        if head_scale:
            out = out * self.head_scale_params

        # per head gating

        if exists(self.to_v_head_gate):
            head_gate = self.to_v_head_gate(x)
            out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

        # hybrid module, both branches are normalized per head before they are combined

        if exists(self.hybrid_module):

            # the mask is only forwarded for non causal attention, under the configured keyword

            hybrid_forward_kwargs = dict()

            if not self.causal and exists(self.hybrid_mask_kwarg):
                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}

            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)

            # only the first tensor of the hybrid output is used

            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)

            if hybrid_out.ndim == 3:
                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

            out_norm, hybrid_out_norm = self.hybrid_norms

            out = out_norm(out)
            hybrid_out = hybrid_out_norm(hybrid_out)

            if exists(self.hybrid_mix):
                mix = self.hybrid_mix(x)
                mix = rearrange(mix, 'b n h -> b h n 1')
                out = out.lerp(hybrid_out, mix.sigmoid())
            else:
                out = 0.5 * (out + hybrid_out)

        # merge heads

        out = self.merge_heads(out)

        # gating of the aggregated values

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * self.to_v_gate_activation(gates)

        # output projection

        out = self.to_out(out)

        # maybe sublayer dropout

        out = maybe(self.sublayer_dropout)(out)

        # zero the output at masked out query positions

        if exists(mask):
            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)

        if not return_intermediates:
            return out

        intermediates.cached_kv = cached_kv

        return out, intermediates
|
|
|
class AttentionLayers(Module):
    """Configurable stack of self-attention ('a'), cross-attention ('c') and feedforward ('f') layers.

    Each entry of `self.layers` is a `ModuleList([norms, layer, residual])`, where
    `norms` holds the (pre-branch, post-branch, post-main) norms, any of which may
    be `None`. The layer layout is stored in `self.layer_types` and the order in
    which the layers are run in `self.layers_execute_order`.
    """

    def __init__(
        self,
        dim,
        depth = None,
        heads = 8,
        causal = False,
        cross_attend = False,
        only_cross = False,
        use_scalenorm = False,
        use_rmsnorm = False,
        use_dynamic_tanh = False,
        dynamic_tanh_init_alpha = 1.,
        use_simple_rmsnorm = False,
        use_adaptive_layernorm = False,
        use_adaptive_rmsnorm = False,
        use_adaptive_layerscale = False,
        norm_add_unit_offset = True,
        dim_condition = None,
        adaptive_condition_mlp = False,
        adaptive_condition_mlp_expansion = 4,
        alibi_pos_bias = False,
        alibi_num_heads = None,
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        dynamic_pos_bias = False,
        dynamic_pos_bias_log_distance = False,
        dynamic_pos_bias_mlp_depth = 2,
        dynamic_pos_bias_norm = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
        rotary_xpos = False,
        rotary_interpolation_factor = 1.,
        rotary_xpos_scale_base = 512,
        rotary_base_rescale_factor = 1.,
        rotate_num_heads = None,
        weight_tie_layers = False,
        custom_layers: tuple[str, ...] | None = None,
        layers_execute_order: tuple[int, ...] | None = None,
        sandwich_coef = None,
        par_ratio = None,
        residual_attn = False,
        cross_residual_attn = False,
        macaron = False,
        pre_norm = True,
        pre_norm_has_final_norm = True,
        gate_residual = False,
        scale_residual = False,
        scale_residual_constant = 1.,
        shift_tokens = 0,
        sandwich_norm = False,
        softclamp_output = False,
        softclamp_output_value = 30.,
        zero_init_branch_output = False,
        layer_dropout = 0.,
        cross_attn_tokens_dropout = 0.,
        disable_abs_pos_emb = None,
        use_layerscale = False,
        layerscale_init_value = 0.,
        unet_skips = False,
        integrate_layers = False,
        layer_integrate_use_softmax = True,
        num_residual_streams = 1,
        qkv_receive_diff_residuals = False,
        reinject_input = False,
        learned_reinject_input_gate = False,
        add_value_residual = False,
        learned_value_residual_mix = True,
        rel_pos_kwargs: dict | None = None,
        residual_fn_kwargs: dict | None = None,
        **kwargs
    ):
        """Build the layer stack.

        Args:
            dim: Model width handed to every layer, norm and residual module.
            depth: Number of repetitions of the default block. May be omitted when
                `custom_layers` is given; required for the default layout and for
                `weight_tie_layers`.
            heads, causal: Forwarded to each self-attention layer (`heads` also to
                cross-attention).
            cross_attend, only_cross: Select the default block:
                ('a', 'c', 'f'), ('c', 'f') or ('a', 'f').
            use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm,
            use_adaptive_layernorm, use_adaptive_rmsnorm: At most one may be set;
                picks the norm class (default `LayerNorm`). The adaptive variants
                make `forward` require `condition`.
            use_layerscale, use_adaptive_layerscale: At most one; wraps every
                layer in a post-branch scaling module.
            alibi_pos_bias, rel_pos_bias, dynamic_pos_bias: At most one; builds
                `self.rel_pos`. `rel_pos_bias` and `dynamic_pos_bias` are rejected
                together with `attn_flash`.
            rotary_pos_emb, rotary_xpos, rotary_*: Configure `self.rotary_pos_emb`;
                `rotary_xpos` implies `rotary_pos_emb` and requires `causal`.
            custom_layers: Explicit tuple of layer type strings.
            layers_execute_order: Indices into the layer list giving run order.
            sandwich_coef: Layout becomes `coef` 'a' layers, then
                `depth - coef` default blocks, then `coef` 'f' layers.
            par_ratio: Alternative layout computed from `depth` and this ratio.
            weight_tie_layers: Build one block and execute it `depth` times.
            pre_norm: Use a norm before each branch (and a final norm); when
                False a norm is applied after each residual instead.
            pre_norm_has_final_norm: When False, no final norm is created even if
                `pre_norm` is set.
            rel_pos_kwargs, residual_fn_kwargs: Extra keyword arguments for the
                positional-bias and residual modules; `None` means none.
            **kwargs: Only keys prefixed `ff_`, `attn_` or `cross_attn_` are
                accepted; the prefix is stripped and the rest is passed to the
                feedforward / attention / cross-attention constructors.

        Raises:
            AssertionError: On unrecognized kwargs or incompatible options.
            Exception: If a layer type is not one of 'a', 'c', 'f'.
        """
        super().__init__()

        # `None` defaults avoid sharing one dict instance across all constructions
        rel_pos_kwargs = default(rel_pos_kwargs, dict())
        residual_fn_kwargs = default(residual_fn_kwargs, dict())

        rotary_pos_emb = rotary_pos_emb or rotary_xpos

        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
        cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)

        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

        self.dim = dim
        self.causal = causal
        self.layers = ModuleList([])

        # integrating past layers implies q/k/v may receive different residual views

        qkv_receive_diff_residuals |= integrate_layers

        # multiple residual streams

        assert num_residual_streams > 0
        has_hyper_connections = num_residual_streams > 1

        self.num_residual_streams = num_residual_streams
        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

        assert not (has_hyper_connections and gate_residual)

        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers

        # per-layer integrators of previous layer hiddens

        self.layer_integrators = ModuleList([])

        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))

        # positions related

        self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

        rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)

        assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'

        # only relevant (and only printed) when rotary embeddings are actually built
        if rotary_pos_emb and rotary_emb_dim < 32:
            print('when training language model, rotary embedding dimension should be at least 32')

        assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None

        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

        # relative positional bias

        flash_attn = attn_kwargs.get('flash', False)
        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'

        self.rel_pos = None

        if rel_pos_bias:
            assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
        elif dynamic_pos_bias:
            assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
        elif alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'

        self.cross_attend = cross_attend

        # determine norm

        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'

        norm_need_condition = False
        dim_condition = default(dim_condition, dim)
        dim_condition_mult = 1

        if adaptive_condition_mlp:
            dim_condition_mult = adaptive_condition_mlp_expansion

        if use_scalenorm:
            norm_class = ScaleNorm
        elif use_rmsnorm:
            norm_class = RMSNorm
        elif use_simple_rmsnorm:
            norm_class = SimpleRMSNorm
        elif use_dynamic_tanh:
            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
        elif use_adaptive_layernorm:
            norm_need_condition = True
            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
        elif use_adaptive_rmsnorm:
            norm_need_condition = True
            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
        else:
            norm_class = LayerNorm

        norm_fn = partial(norm_class, dim)

        # the unit offset option is only passed to the non-conditioned norms

        if not norm_need_condition and norm_add_unit_offset:
            norm_fn = partial(norm_fn, unit_offset = True)

        self.norm_need_condition = norm_need_condition
        self.dim_condition = dim_condition

        # determine default block layer type order

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        # determine post branch wrapper

        assert at_most_one_of(use_layerscale, use_adaptive_layerscale)

        post_branch_fn = None
        post_branch_fn_needs_condition = False

        if use_layerscale:
            post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
        elif use_adaptive_layerscale:
            post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
            post_branch_fn_needs_condition = True

        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition

        if exists(post_branch_fn) and not post_branch_fn_needs_condition and norm_add_unit_offset:
            post_branch_fn = partial(post_branch_fn, unit_offset = True)

        # setup mlp for conditioning

        self.need_condition = norm_need_condition or post_branch_fn_needs_condition

        self.adaptive_mlp = nn.Identity()

        if self.need_condition and adaptive_condition_mlp:
            self.adaptive_mlp = nn.Sequential(
                LinearNoBias(dim_condition, dim_condition * dim_condition_mult),
                nn.SiLU()
            )

        # zero init

        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

        # setup weight tying, which is a special case of `layers_execute_order`

        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'

        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))

        if weight_tie_layers:
            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
            assert not exists(layers_execute_order)
            layers_execute_order = tuple(range(len(default_block))) * depth
            depth = 1

        # calculate layer block order

        len_default_block = 1

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
            layer_types = default_block * depth
            len_default_block = len(default_block)

        self.layer_types = layer_types
        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))

        assert all([i < len(self.layer_types) for i in self.layers_execute_order])

        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        # set the depth

        depth = default(depth, len(self.layers_execute_order))
        self.depth = depth

        # stochastic depth

        self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))

        # structured dropout for cross attending

        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        # per-layer token shift amounts

        shift_tokens = cast_tuple(shift_tokens, len(layer_types))

        # optional soft clamping just before the final norm

        self.softclamp_output = softclamp_output
        self.softclamp_output_value = softclamp_output_value

        # final norm - honours `pre_norm_has_final_norm`, which was previously ignored

        self.final_norm = norm_fn() if pre_norm and pre_norm_has_final_norm else nn.Identity()

        # whether unet or not

        self.unet_skips = unet_skips
        num_skips = self.depth // len_default_block

        assert not (unet_skips and num_skips == 0), 'must have depth of at least 2 for unet skip connections'

        skip_indices = [i * len_default_block for i in range(num_skips)]

        self.skip_combines = ModuleList([])

        # whether there is reinjection of input at every layer

        self.reinject_input = reinject_input
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None

        # add the value from the first self attention block to all latter projected self attention values as a residual

        self.add_value_residual = add_value_residual

        is_first_self_attn = True
        learned_value_residual_mix &= add_value_residual

        # iterate and construct layers

        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):

            # `ind` indexes individual layers, `block_ind` indexes whole default blocks

            block_begin = divisible_by(ind, len_default_block)
            block_ind = ind // len_default_block

            # attention, cross attention, feedforward

            layer_qkv_receives_diff_view = layer_type == 'a' and qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)

            if layer_type == 'a':
                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn

                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = layer_qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                is_first_self_attn = False

            elif layer_type == 'c':
                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})

            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)

            else:
                raise Exception(f'invalid layer type {layer_type}')

            if layer_shift_tokens > 0:
                shift_range_upper = layer_shift_tokens + 1
                shift_range_lower = -layer_shift_tokens if not causal else 0
                layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

            if exists(post_branch_fn):
                layer = post_branch_fn(layer)

            layer_integrate = None

            if integrate_layers:
                num_layer_hiddens = ind + 1
                layer_integrate_num_view = 3 if layer_qkv_receives_diff_view else 1

                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)

            if has_hyper_connections:
                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)

                if layer_type == 'a' and hyper_conn_produce_diff_views:
                    residual_fn = partial(residual_fn, num_input_views = 3)

            elif gate_residual:
                residual_fn = GRUGating
            else:
                residual_fn = Residual

            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

            # handle unet skip connection

            skip_combine = None
            is_latter_half = block_begin and block_ind >= (self.depth / 2)

            if self.unet_skips and is_latter_half:
                skip_combine = ConcatCombine(dim, skip_indices.pop())

            # all normalizations of the layer

            pre_branch_norm = norm_fn() if pre_norm else None
            post_branch_norm = norm_fn() if sandwich_norm else None
            post_main_norm = norm_fn() if not pre_norm else None

            norms = ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])

            self.skip_combines.append(skip_combine)

            self.layer_integrators.append(layer_integrate)

            self.layers.append(ModuleList([
                norms,
                layer,
                residual
            ]))

        # determine whether can cache kv

        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        self_attn_kv_mask = None,
        mems = None,
        mem_masks = None,
        seq_start_pos: Tensor | None = None,
        cache: LayerIntermediates | None = None,
        cache_age = 1,
        return_hiddens = False,
        rotary_pos_emb = None,
        pos = None,
        context_pos = None,
        attn_bias = None,
        condition = None,
        in_attn_cond = None,
        layers_execute_order: tuple[int, ...] | None = None
    ):
        """Run the layer stack over `x`.

        Args:
            x: Input embeddings; `x.shape[1]` is used as the sequence length.
            context, context_mask: Cross-attention source and its mask. `context`
                must be given exactly when the stack was built with `cross_attend`.
            mask, attn_mask, self_attn_kv_mask: Masks forwarded to the attention
                layers.
            mems, mem_masks: Optional per-self-attention-layer memories and masks
                (consumed front to back, one per 'a' layer).
            seq_start_pos: Per-sample start index; positions before it are masked
                out of the self-attention keys/values.
            cache: Intermediates from a previous call; only allowed for causal
                stacks without `mask` / `attn_mask`. Only the last `cache_age`
                positions of `x` are processed.
            return_hiddens: Also return a `LayerIntermediates`.
            rotary_pos_emb, pos, context_pos: Rotary embeddings or the positions
                to compute them from.
            attn_bias: Passed to the self-attention layers.
            condition: Conditioning tensor of rank 2 or 3; required exactly when
                an adaptive norm / layerscale is configured.
            in_attn_cond: Tensor added to the input of every layer; mutually
                exclusive with `reinject_input`.
            layers_execute_order: Overrides `self.layers_execute_order`.

        Returns:
            The output tensor, or `(output, LayerIntermediates)` when
            `return_hiddens` is set.
        """
        assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'

        # handle condition

        if exists(condition):
            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'

            assert condition.ndim in {2, 3}

            if condition.ndim == 2:
                condition = rearrange(condition, 'b d -> b 1 d')

            condition = self.adaptive_mlp(condition)

        # setup maybe norm kwarg

        norm_kwargs = dict()

        if self.norm_need_condition:
            norm_kwargs.update(condition = condition)

        # maybe post branch fn conditioning

        block_forward_kwargs = dict()

        if self.post_branch_fn_needs_condition:
            block_forward_kwargs.update(condition = condition)

        # initialize accums

        hiddens = []
        layer_hiddens = []
        intermediates = []

        prev_attn = None
        prev_cross_attn = None

        # copied so that popping below does not mutate the caller's lists
        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
        mem_masks = mem_masks.copy() if exists(mem_masks) else [None] * self.num_attn_layers

        # handle left padded sequences

        if exists(seq_start_pos):
            seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
            left_pad_mask = seq_arange >= seq_start_pos[..., None]

            if exists(self_attn_kv_mask):
                self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
            else:
                self_attn_kv_mask = left_pad_mask

        # rotary positions

        cross_attn_rotary_pos_emb = dict()

        if exists(self.rotary_pos_emb):
            if not exists(rotary_pos_emb):
                # memory length is taken from the first layer's memory only
                maybe_mem = first(mems, None)
                mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

                if not exists(pos):
                    pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

                rotary_pos_emb = self.rotary_pos_emb(pos)

            # allow for rotary positions for context if provided

            if exists(context_pos):
                assert self.cross_attend
                context_rotary_pos_emb = self.rotary_pos_emb(context_pos)

                cross_attn_rotary_pos_emb.update(
                    rotary_pos_emb = rotary_pos_emb,
                    context_rotary_pos_emb = context_rotary_pos_emb
                )

        # assume cached key / values

        attn_cache = []

        if exists(cache):
            assert self.causal and not any([*map(exists, (mask, attn_mask))])

            if exists(context):
                context = context[:, :0]

            if cache_age > 0:
                x = x[:, -cache_age:]

            attn_cache = cache.attn_intermediates

        iter_attn_cache = iter(attn_cache)

        # setup multistreams if needed

        streams = self.num_residual_streams
        is_multistream = streams > 1

        if is_multistream:
            x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

        # get layers to be executed

        layer_variables = (
            self.layer_types,
            self.skip_combines,
            self.layers,
            self.layer_dropouts,
            self.layer_integrators
        )

        # able to override the layers execution order on forward, for trying to depth extrapolate

        layers_execute_order = default(layers_execute_order, self.layers_execute_order)
        layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

        # derived input for reinjection if needed

        inp_inject = None

        if self.reinject_input:
            assert not exists(in_attn_cond)
            inp_inject = self.reinject_input_proj(x)

        elif exists(in_attn_cond):
            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
            inp_inject = inp_inject * inp_inject_gate

        # store all hiddens for skips

        skip_hiddens = []

        # for value residuals

        first_self_attn_inter = None
        first_cross_attn_inter = None

        # go through the attention and feedforward layers

        for layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator in zip(*layer_variables):

            # handle skip connections

            skip_hiddens.append(x)

            if exists(skip_combine):
                x = skip_combine(x, skip_hiddens)

            # layer dropout

            if self.training and layer_dropout > 0. and random() < layer_dropout:
                continue

            if layer_type == 'a':
                if return_hiddens:
                    hiddens.append(x)

                layer_mem = mems.pop(0) if mems else None
                layer_mem_mask = mem_masks.pop(0) if mem_masks else None

            if layer_type == 'c':
                if self.training and self.cross_attn_tokens_dropout > 0.:
                    context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

            x, inner_residual, residual_kwargs = residual_fn.prepare(x)

            layer_hiddens.append(x)

            if exists(layer_integrator):
                x = layer_integrator(x, layer_hiddens)

            pre_norm, post_branch_norm, post_main_norm = norm

            if self.need_condition:
                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)

            if exists(inp_inject):
                x = x + inp_inject

            if exists(pre_norm):
                x = pre_norm(x)

                if layer_type == 'a' and exists(layer_mem):
                    layer_mem = pre_norm(layer_mem)

            block = partial(block, **block_forward_kwargs)

            # handle maybe value residuals

            maybe_self_attn_value_residual = None
            maybe_cross_attn_value_residual = None

            if self.add_value_residual:
                if exists(first_self_attn_inter):
                    maybe_self_attn_value_residual = first_self_attn_inter.values

                if exists(first_cross_attn_inter):
                    maybe_cross_attn_value_residual = first_cross_attn_inter.values

            # forward depending on layer type

            if layer_type == 'a':
                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
            elif layer_type == 'f':
                out = block(x)

            # store first self or cross attention intermediate for value residual

            if not exists(first_self_attn_inter) and layer_type == 'a':
                first_self_attn_inter = inter

            if not exists(first_cross_attn_inter) and layer_type == 'c':
                first_cross_attn_inter = inter

            if exists(post_branch_norm):
                out = post_branch_norm(out)

            x = residual_fn(out, inner_residual, **residual_kwargs)

            if layer_type in ('a', 'c') and return_hiddens:
                inter.layer_type = layer_type
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if exists(post_main_norm):
                x = post_main_norm(x)

        if return_hiddens:
            layer_hiddens.append(x)

        if self.softclamp_output:
            x = softclamp(x, self.softclamp_output_value)

        final_norm = self.final_norm

        # `nn.Identity.forward` only accepts the input tensor, so conditioning
        # kwargs are bound only when an actual norm module is present
        if self.need_condition and not isinstance(final_norm, nn.Identity):
            final_norm = maybe(partial)(final_norm, **norm_kwargs)

        # take care of multistreams if needed, use sum for now

        if is_multistream:
            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

        x = final_norm(x)

        if not return_hiddens:
            return x

        intermediates = LayerIntermediates(
            hiddens = hiddens,
            last_hidden = x,
            attn_intermediates = intermediates,
            layer_hiddens = layer_hiddens,
        )

        return x, intermediates
|
|
|
class Encoder(AttentionLayers):
    """`AttentionLayers` with `causal` pinned to False."""

    def __init__(self, **kwargs):
        # causality is fixed by this subclass and may not be supplied by the caller
        assert not ('causal' in kwargs), 'cannot set causality on encoder'
        super().__init__(**kwargs, causal = False)
|
|
|
class Decoder(AttentionLayers):
    """`AttentionLayers` with `causal` pinned to True."""

    def __init__(self, **kwargs):
        # causality is fixed by this subclass and may not be supplied by the caller
        assert not ('causal' in kwargs), 'cannot set causality on decoder'
        super().__init__(**kwargs, causal = True)
|
|
|
class PrefixDecoder(AttentionLayers):
    """Non-causal `AttentionLayers` that receives an explicit attention mask.

    The mask built in `forward` is lower-triangular (causal), optionally opened
    up so that every query may attend to the first `prefix_attn_len` key
    positions of its sample.
    """

    def __init__(self, **kwargs):
        # the underlying stack is always built with causal = False; causality
        # comes solely from the mask constructed in `forward`
        assert not ('causal' in kwargs), 'cannot set causality on decoder'
        super().__init__(**kwargs, causal = False)

    def forward(
        self,
        x,
        *args,
        attn_mask = None,
        prefix_attn_len = None,
        **kwargs
    ):
        """Build the prefix/causal mask and delegate to `AttentionLayers.forward`.

        Args:
            x: Input whose first two dimensions are read as batch and sequence length.
            attn_mask: Optional boolean mask AND-ed into the constructed mask.
            prefix_attn_len: Either an int (same prefix length for every sample)
                or a tensor that is rearranged with 'b -> b 1 1 1'.
        """
        batch = x.shape[0]
        seq_len = x.shape[1]
        device = x.device

        # True on and below the diagonal: each position sees itself and the past
        forwarded_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

        if exists(prefix_attn_len):
            if isinstance(prefix_attn_len, int):
                prefix_attn_len = torch.full((batch,), prefix_attn_len, device = device)

            key_positions = arange(seq_len, device = device)
            prefix_lens = rearrange(prefix_attn_len, 'b -> b 1 1 1')

            # key positions inside the prefix are visible to every query
            forwarded_mask = forwarded_mask | (key_positions < prefix_lens)

        if exists(attn_mask):
            forwarded_mask = forwarded_mask & attn_mask

        return super().forward(x, *args, attn_mask = forwarded_mask, **kwargs)
|
|
|
class CrossAttender(AttentionLayers):
    """`AttentionLayers` built with `cross_attend = True` and `only_cross = True`."""

    def __init__(self, **kwargs):
        fixed_kwargs = dict(cross_attend = True, only_cross = True)
        super().__init__(**fixed_kwargs, **kwargs)
|
|
|
class ViTransformerWrapper(Module):
    """Image wrapper: patchify, embed, run `attn_layers`, mean-pool and apply a head."""

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers: Encoder,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        """
        Args:
            image_size: Side length of the image; must be divisible by `patch_size`.
            patch_size: Side length of each square patch.
            attn_layers: Layer stack; its `dim` attribute sets the embedding width.
            channels: Number of image channels.
            num_classes: Output size of the linear head; `None` installs `nn.Identity`.
            post_emb_norm: Apply a `LayerNorm` after adding positional embeddings.
            num_register_tokens: Number of learned tokens packed after the patch
                tokens before the layer stack and removed again afterwards.
            emb_dropout: Dropout probability applied to the embeddings.
        """
        super().__init__()
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'

        model_dim = attn_layers.dim
        patches_per_side = image_size // patch_size
        total_patches = patches_per_side ** 2
        flat_patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, total_patches, model_dim))

        self.has_register_tokens = num_register_tokens > 0

        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, model_dim))

        self.patch_to_embedding = nn.Sequential(
            LayerNorm(flat_patch_dim),
            nn.Linear(flat_patch_dim, model_dim),
            LayerNorm(model_dim)
        )

        if post_emb_norm:
            self.post_emb_norm = LayerNorm(model_dim)
        else:
            self.post_emb_norm = nn.Identity()

        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers

        if exists(num_classes):
            self.mlp_head = nn.Linear(model_dim, num_classes)
        else:
            self.mlp_head = nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        """
        Args:
            img: Image batch, rearranged with 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'.
            return_embeddings: Return the per-token output of the layer stack.
            return_logits_and_embeddings: Return `(logits, embeddings)`.

        Returns:
            Embeddings, logits, or `(logits, embeddings)` depending on the flags
            (at most one flag may be set).
        """
        batch = img.shape[0]
        patch = self.patch_size

        tokens = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch, p2 = patch)
        tokens = self.patch_to_embedding(tokens)

        num_tokens = tokens.shape[1]
        tokens = tokens + self.pos_embedding[:, :num_tokens]

        tokens = self.dropout(self.post_emb_norm(tokens))

        # register tokens ride along after the patch tokens and are dropped afterwards
        if self.has_register_tokens:
            registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
            tokens, packed_shapes = pack((tokens, registers), 'b * d')

        embed = self.attn_layers(tokens)

        if self.has_register_tokens:
            embed, _ = unpack(embed, packed_shapes, 'b * d')

        assert at_most_one_of(return_embeddings, return_logits_and_embeddings)

        # NOTE(review): `self.mlp_head` is always set to a module in `__init__`
        # (Linear or Identity), so the `not exists(...)` half never fires
        if not exists(self.mlp_head) or return_embeddings:
            return embed

        logits = self.mlp_head(embed.mean(dim = -2))

        if return_logits_and_embeddings:
            return logits, embed

        return logits
|
|
|
| class TransformerWrapper(Module):
|
def __init__(
    self,
    *,
    num_tokens,
    max_seq_len,
    attn_layers: AttentionLayers,
    embed_num_tokens: dict[str, int] = dict(),
    emb_dim = None,
    max_mem_len = 0,
    shift_mem_down = 0,
    emb_dropout = 0.,
    post_emb_norm = False,
    num_memory_tokens = None,
    memory_tokens_interspersed_every = None,
    tie_embedding = False,
    logits_dim = None,
    return_only_embed = False,
    num_output_heads = 1,
    use_abs_pos_emb = True,
    scaled_sinu_pos_emb = False,
    l2norm_embed = False,
    recycling = False,
    train_max_recycle_steps = 4,
    emb_frac_gradient = 1.,
    attn_z_loss_weight = 1e-4,
    average_pool_embed = False,
    use_cls_token = False,
    num_cls_tokens = 1,
    squeeze_out_last_dim = False,
    token_emb: TokenEmbedding | None = None,
    mixture_of_softmax = False,
    mixture_of_softmax_k = 4,
    sigsoftmax_logits = False,
    to_logits: Module | None = None,
):
    """Set up embeddings, the layer stack and the output projection.

    Args:
        num_tokens: Vocabulary size for the default token embedding; also the
            default `logits_dim`.
        max_seq_len: Stored on the module and passed to the absolute positional
            embedding; 0 disables absolute positional embeddings.
        attn_layers: Layer stack; provides `dim`, `disable_abs_pos_emb` and
            `can_cache_kv`.
        embed_num_tokens: Mapping name -> vocabulary size; one extra
            `nn.Embedding` is created per entry under the key `'{name}_embed'`.
        emb_dim: Embedding width (defaults to `attn_layers.dim`); a linear
            projection is added when it differs from the model width.
        token_emb: Optional replacement for the default `TokenEmbedding`.
        tie_embedding: Compute logits with the transposed token embedding weight.
        return_only_embed: Do not create any logits projection.
        num_output_heads: More than one creates a list of logit projections.
        to_logits: Optional custom logits module for the single-head case.
        use_cls_token, num_cls_tokens: Create learned CLS token parameters.
        mixture_of_softmax, mixture_of_softmax_k: Create the mixture projections.
        num_memory_tokens: Number of learned memory tokens (`None` means 0).

    NOTE(review): `attn_z_loss_weight` is accepted here but not referenced in
    this constructor.
    """
    super().__init__()

    model_dim = attn_layers.dim
    emb_dim = default(emb_dim, model_dim)

    self.emb_dim = emb_dim
    self.num_tokens = num_tokens
    self.num_cls_tokens = num_cls_tokens

    self.max_seq_len = max_seq_len
    self.max_mem_len = max_mem_len
    self.shift_mem_down = shift_mem_down

    self.l2norm_embed = l2norm_embed

    # token embedding - only built when the caller did not supply one

    if not exists(token_emb):
        token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

    self.token_emb = token_emb

    # absolute positional embedding

    if max_seq_len == 0:
        no_abs_pos_emb = True
    else:
        no_abs_pos_emb = not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

    if no_abs_pos_emb:
        pos_emb = always(0)
    elif scaled_sinu_pos_emb:
        pos_emb = ScaledSinusoidalEmbedding(emb_dim)
    else:
        pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

    self.pos_emb = pos_emb

    # additional named embeddings

    self.embeds = None

    if len(embed_num_tokens) > 0:
        extra_embeds = {
            f'{name}_embed': nn.Embedding(vocab_size, emb_dim)
            for name, vocab_size in embed_num_tokens.items()
        }
        self.embeds = ModuleDict(extra_embeds)

    # fraction of the gradient that should go to the embedding

    self.emb_frac_gradient = emb_frac_gradient

    if post_emb_norm:
        self.post_emb_norm = LayerNorm(emb_dim)
    else:
        self.post_emb_norm = nn.Identity()

    self.emb_dropout = nn.Dropout(emb_dropout)

    if emb_dim != model_dim:
        self.project_emb = nn.Linear(emb_dim, model_dim)
    else:
        self.project_emb = nn.Identity()

    self.attn_layers = attn_layers

    self.init_()

    assert num_output_heads > 0

    assert at_most_one_of(average_pool_embed, use_cls_token)

    # recycling

    self.recycling = recycling
    self.recycled_proj = LinearNoBias(model_dim, model_dim) if recycling else None

    self.train_max_recycle_steps = train_max_recycle_steps

    # classic cls token

    self.cls_token = None

    if use_cls_token:
        self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, model_dim))
        nn.init.normal_(self.cls_token, std = 0.02)

    # whether to average pool the embed

    self.average_pool_embed = average_pool_embed

    # output type

    self.output_is_log_prob = mixture_of_softmax

    self.to_mixture = None
    self.combine_mixture = None

    if mixture_of_softmax:
        assert num_output_heads == 1

        self.to_mixture = Sequential(
            LinearNoBias(model_dim, model_dim * mixture_of_softmax_k),
            Rearrange('... (k d) -> ... k d', k = mixture_of_softmax_k)
        )

        self.combine_mixture = LinearNoBias(model_dim, mixture_of_softmax_k)

    # sig softmax

    self.sigsoftmax_logits = sigsoftmax_logits

    # output head

    logits_dim = default(logits_dim, num_tokens)

    self.has_multiple_heads = num_output_heads > 1

    if return_only_embed:
        logits_head = None
    elif tie_embedding:
        assert isinstance(token_emb, TokenEmbedding), 'can only tie embedding if using `TokenEmbedding`'
        logits_head = lambda t: t @ self.token_emb.emb.weight.t()
    elif num_output_heads > 1:
        logits_head = ModuleList([LinearNoBias(model_dim, logits_dim) for _ in range(num_output_heads)])
    else:
        logits_head = to_logits if exists(to_logits) else LinearNoBias(model_dim, logits_dim)

    self.to_logits = logits_head

    # memory tokens

    num_memory_tokens = default(num_memory_tokens, 0)
    self.num_memory_tokens = num_memory_tokens

    if num_memory_tokens > 0:
        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, model_dim))

    self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

    # squeeze out last dimension if possible

    self.squeeze_out_last_dim = squeeze_out_last_dim

    # whether can do cached kv decoding

    self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
    self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
|
|
|
def init_(self):
    """Run custom weight initialisation for the token and positional embeddings."""
    # let the token embedding initialise itself when it knows how to
    if hasattr(self.token_emb, 'init_'):
        self.token_emb.init_()

    # the tiny-std re-initialisation below only applies with l2-normalised embeddings
    if not self.l2norm_embed:
        return

    # `always` is the constant stand-in used when there is no positional embedding
    if isinstance(self.pos_emb, always):
        return

    nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
|
|
|
def forward(
    self,
    x,
    return_embeddings = False,
    return_logits_and_embeddings = False,
    return_intermediates = False,
    return_embeddings_and_intermediates = False,
    return_logit_entropies = False,
    mask = None,
    return_mems = False,
    return_attn = False,
    mems = None,
    mem_masks = None,
    recycle_steps = None,
    pos = None,
    prepend_embeds = None,
    prepend_mask = None,
    embed_ids: dict[str, Tensor] = dict(),
    sum_embeds = None,
    return_attn_z_loss = False,
    attn_z_loss_weight = 1e-4,
    seq_start_pos = None,
    cache: LayerIntermediates | None = None,
    token_emb_kwargs = dict(),
    to_logits_kwargs = dict(),
    **kwargs,
):
    """
    Embed the input, run it through `self.attn_layers` and project to logits.

    Args:
        x: token ids; may be None when `prepend_embeds` is given, in which
            case an empty (batch, 0) long tensor is substituted.
        return_embeddings: return the attended embeddings instead of logits
            (forced on when the wrapper has no `to_logits`).
        return_logits_and_embeddings: return `(logits, embeddings)`.
        return_intermediates: additionally return the intermediates object
            produced by the attention layers.
        return_embeddings_and_intermediates: return `(embeddings, intermediates)`.
        return_logit_entropies: store `calc_entropy(logits)` on the
            intermediates and return them.
        mask: optional mask forwarded to the attention layers; it is extended
            here for prepended embeddings, cls tokens and memory tokens.
        return_mems / mems / mem_masks: memories passed to, and collected
            from, the attention layers.
        return_attn: additionally return the post-softmax attention maps.
        recycle_steps: number of recycling passes; only valid when the
            wrapper was built with recycling.
        pos: positions for the positional embedding, or (when its dtype is
            not long) a ready-made positional embedding to add.
        prepend_embeds / prepend_mask: embeddings (and their mask) to put in
            front of the token embeddings.
        embed_ids: ids for the extra embedding tables in `self.embeds`,
            keyed by the table name without the `_embed` suffix.
        sum_embeds: optional tensor added to the embeddings.
        return_attn_z_loss / attn_z_loss_weight: compute `calc_z_loss` over
            the pre-softmax attention and store it on the intermediates.
        seq_start_pos: forwarded to the positional embedding and layers.
        cache: cached intermediates forwarded to the attention layers.
        token_emb_kwargs / to_logits_kwargs: extra keyword arguments for the
            token embedding and the logit projection.
        **kwargs: forwarded to the attention layers.

    Returns:
        `out`, or `(out, new_mems)`, `(out, intermediates)` or
        `(out, attn_maps)` depending on the `return_*` flags, where `out` is
        the logits, the embeddings, or a tuple of both.
    """

    # with no token sequence, the prepended embeddings carry the whole input

    if not exists(x):
        assert exists(prepend_embeds)
        x = prepend_embeds.new_empty((prepend_embeds.shape[0], 0), dtype = torch.long)

    # shapes and variables

    b, n, device = x.shape[0], x.shape[1], x.device
    num_mems = self.num_memory_tokens
    has_memory_tokens = num_mems > 0
    emb_frac_gradient = self.emb_frac_gradient
    orig_mask = mask  # the caller's mask, before it is extended below

    return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates

    # positional embedding: a non-long `pos` is taken to be the embedding itself

    external_pos_emb = exists(pos) and pos.dtype != torch.long
    pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
    x = self.token_emb(x, **token_emb_kwargs) + pos_emb

    # additional embedding tables: either both configured and supplied, or neither

    assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'

    if exists(self.embeds):
        assert len(embed_ids) == len(self.embeds)

        for name, embed_id in embed_ids.items():
            embed_key = f'{name}_embed'

            assert embed_key in self.embeds
            embed = self.embeds[embed_key](embed_id)

            x = x + embed

    # maybe sum external embeddings

    if exists(sum_embeds):
        x = x + sum_embeds

    # post embedding norm

    x = self.post_emb_norm(x)

    # maybe prepend embeddings in front of the token embeddings

    if exists(prepend_embeds):
        prepend_seq, prepend_dim = prepend_embeds.shape[1:]
        assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

        x = cat((prepend_embeds, x), dim = -2)

        if exists(prepend_mask) or exists(mask):
            # build whichever half is missing explicitly: the `default` helper
            # in this file returns its fallback as-is, so passing it a lambda
            # would hand a function (not a tensor) to `cat`
            if not exists(mask):
                mask = torch.ones((b, n), device = device, dtype = torch.bool)

            if not exists(prepend_mask):
                prepend_mask = torch.ones((b, prepend_seq), device = device, dtype = torch.bool)

            mask = cat((prepend_mask, mask), dim = -1)

    # scale down the gradient reaching the embeddings (the forward value is unchanged)

    if emb_frac_gradient < 1:
        assert emb_frac_gradient > 0
        x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

    # embedding dropout and projection to the model dimension

    x = self.emb_dropout(x)

    x = self.project_emb(x)

    # maybe cls token(s), packed in front of the sequence

    if exists(self.cls_token):
        cls_tokens = repeat(self.cls_token, '... -> b ...', b = b)
        x, cls_packed_shape = pack([cls_tokens, x], 'b * d')

        if exists(mask):
            mask = F.pad(mask, (self.num_cls_tokens, 0), value = True)

    # maybe memory tokens

    if has_memory_tokens:
        mem_seq = x.shape[-2]
        mem_every = self.memory_tokens_interspersed_every

        if exists(mem_every):
            assert mem_every > 0
            assert isinstance(self.attn_layers, Decoder), 'only for decoder'
            next_seq_len = math.ceil(n / mem_every) * mem_every

            # pad to a multiple of `mem_every`, then fold each chunk into the batch
            x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
            x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

        mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
        x, mem_packed_shape = pack((mem, x), 'b * d')

        # the mask is only extended in the non-interspersed case
        if not exists(mem_every) and exists(mask):
            mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

        if exists(mem_every):
            x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

    # maybe rotate the memories by `shift_mem_down` positions

    if self.shift_mem_down and exists(mems):
        mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
        mems = [*mems_r, *mems_l]

    # attention layers (hiddens are always requested from the layers)

    if not self.recycling:
        assert not exists(recycle_steps) or recycle_steps == 1, 'you did not train with recycling'

        attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

    else:
        # recycling: a random number of steps while training, caller-supplied otherwise

        recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
        assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'

        for i in range(recycle_steps):
            first_step = i == 0
            last_step = i == (recycle_steps - 1)

            # only the last pass runs with gradients enabled
            context = nullcontext if last_step else torch.no_grad

            with context():
                maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.

                attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

    x = attended

    # split the memory tokens back off after attention

    if has_memory_tokens:
        if exists(mem_every):
            x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

        mem, x = unpack(x, mem_packed_shape, 'b * d')

        intermediates.memory_tokens = mem

        if exists(mem_every):
            x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

        # drop the padding added for interspersed memories
        x = x[:, :mem_seq]

    # global average pool

    if self.average_pool_embed:
        x = masked_mean(x, mask = orig_mask, dim = 1)

    # keep only the cls token positions

    if exists(self.cls_token):
        x, _ = unpack(x, cls_packed_shape, 'b * d')
        x = x.squeeze(1)

    # expansion for mixture of softmax

    combine_mixture = None

    if exists(self.to_mixture):
        combine_mixture = self.combine_mixture(x).softmax(dim = -1)
        x = self.to_mixture(x)

    # logits and their post-processing; all of it is skipped when only
    # embeddings were requested, since `logits` does not exist in that case

    if not return_embeddings:
        if self.has_multiple_heads:
            logits = tuple(fn(x, **to_logits_kwargs) for fn in self.to_logits)
        else:
            logits = self.to_logits(x, **to_logits_kwargs)

        # maybe sig softmax

        if self.sigsoftmax_logits:
            logits = logits + logits.sigmoid().log()

        # maybe combine the mixture of softmaxes into log probabilities

        if exists(combine_mixture):
            with autocast('cuda', enabled = False):
                prob = logits.softmax(dim = -1)
                mos = einsum('... k d, ... k -> ... d', prob, combine_mixture)
                # tensor log with a small clamp against log(0); the bare name
                # `log` is rebound to `math.log` by a later import in this file
                logits = mos.clamp(min = 1e-20).log()

        # maybe squeeze out a trailing dimension of size 1

        if self.squeeze_out_last_dim:
            logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))

            if not self.has_multiple_heads:
                logits = first(logits)

    # different returns

    if return_logits_and_embeddings:
        out = (logits, x)
    elif return_embeddings_and_intermediates:
        out = (x, intermediates)
    elif return_embeddings:
        out = x
    else:
        out = logits

    # logit entropies

    if return_logit_entropies:
        intermediates.logit_entropies = calc_entropy(logits)
        return_intermediates = True

    # attention z-loss

    if return_attn_z_loss:
        pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
        intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
        return_intermediates = True

    # new memories: previous memories concatenated with the hiddens, truncated and detached

    if return_mems:
        hiddens = intermediates.hiddens
        new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
        new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

        if not return_intermediates:
            return out, new_mems

        intermediates.mems = new_mems

    if return_intermediates:
        return out, intermediates

    if return_attn:
        attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
        return out, attn_maps

    return out
|
|
|
class XTransformer(Module):
    """
    Encoder / decoder pair built from two `TransformerWrapper`s.

    Keyword arguments prefixed with `enc_` / `dec_` are routed to the encoder /
    decoder respectively; the decoder cross attends to the encoder output and
    is wrapped in an `AutoregressiveWrapper`.
    """

    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()

        # split the keyword arguments by prefix (the prefix is trimmed off)
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert not ('dim' in enc_kwargs or 'dim' in dec_kwargs), 'dimension of either encoder or decoder must be set with `dim` keyword'

        # options consumed by the wrapper itself rather than by the attention layers
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)

        for key, fallback in (
            ('emb_dropout', 0),
            ('num_memory_tokens', None),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True),
        ):
            enc_transformer_kwargs[key] = enc_kwargs.pop(key, fallback)

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)

        for key, fallback in (
            ('emb_dropout', 0),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True),
        ):
            dec_transformer_kwargs[key] = dec_kwargs.pop(key, fallback)

        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        # the encoder only ever returns embeddings
        encoder_layers = Encoder(dim = dim, **enc_kwargs)

        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            return_only_embed = True,
            attn_layers = encoder_layers
        )

        decoder_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)

        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = decoder_layers
        )

        # optionally share one token embedding between encoder and decoder
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(
            self.decoder,
            ignore_index = ignore_index,
            pad_value = pad_value
        )

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        """Encode `seq_in`, then let the decoder generate `seq_len` tokens from `seq_out_start`."""
        encodings = self.encoder(
            seq_in,
            mask = mask,
            attn_mask = attn_mask,
            return_embeddings = True
        )

        return self.decoder.generate(
            seq_out_start,
            seq_len,
            context = encodings,
            context_mask = mask,
            **kwargs
        )

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
        """Encode `src` and return the decoder's output for `tgt` given that encoding."""
        enc = self.encoder(
            src,
            mask = mask,
            attn_mask = attn_mask,
            prepend_embeds = src_prepend_embeds,
            return_embeddings = True
        )

        # the encoder output is longer by the prepended embeddings; extend the mask to match
        if exists(src_prepend_embeds) and exists(mask):
            prepended_len = src_prepend_embeds.shape[-2]
            mask = pad_at_dim(mask, (prepended_len, 0), dim = -1, value = True)

        # while training, optionally drop encoder tokens before cross attention
        if self.training and self.cross_attn_tokens_dropout > 0:
            enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)

        return self.decoder(tgt, context = enc, context_mask = mask)
|
|
|
|
|
|
|
|
|
|
|
| from math import ceil, log
|
| from typing import Tuple, Callable
|
|
|
| import torch
|
| from torch import nn, Tensor
|
| from torch.nn import Module
|
| import torch.nn.functional as F
|
|
|
| from einops import rearrange, pack, unpack
|
|
|
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
|
|
|
def default(val, d):
    """Return `val` unless it is None, else `d` (returned as-is, never called)."""
    if exists(val):
        return val
    return d
|
|
|
def identity(t, *args, **kwargs):
    """Return `t` unchanged; any further arguments are accepted and ignored."""
    return t
|
|
|
def join(arr, delimiter = ', '):
    """Concatenate the strings in `arr`, separated by `delimiter`."""
    joined = delimiter.join(arr)
    return joined
|
|
|
def cast_tuple(t, length = 1):
    """Return `t` if it already is a tuple, else a tuple repeating `t` `length` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
|
|
|
def eval_decorator(fn):
    """
    Decorate a module method so that it runs with the module in eval mode.

    The module's previous `training` flag is restored afterwards, including
    when `fn` raises, so a failed call cannot leave the module stuck in eval
    mode.
    """
    @wraps(fn)
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return fn(self, *args, **kwargs)
        finally:
            # put the module back the way it was, on success and on error
            self.train(was_training)
    return inner
|
|
|
|
|
|
|
def align_right(t, lens, pad_id = 0):
    """
    Shift each row of `t` to the right so that its first `lens[i]` entries end
    at the last column; the vacated leading positions are filled with `pad_id`.

    `t` must be 2-D and `lens` a 1-D tensor with one length per row.
    """
    batch, seq_len = t.shape
    device = t.device

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    # how far each row has to move, and the largest such shift
    pad_lens = seq_len - lens
    max_pad_len = pad_lens.amax()

    # left-pad every row by the largest shift, then read a per-row window
    padded = F.pad(t, (max_pad_len, 0), value = pad_id)

    row_index = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    window_start = max_pad_len - pad_lens
    col_index = torch.arange(seq_len, device = device, dtype = torch.long) + window_start[..., None]

    return padded[row_index, col_index]
|
|
|
|
|
|
|
def top_p(logits, thres = 0.9):
    """
    Nucleus filtering: keep the most likely tokens whose cumulative probability
    first exceeds `thres` and set every other logit to -inf.

    Works along the last dimension, so `logits` may have any number of leading
    dimensions (including none).
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim = -1, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cum_probs > thres
    # shift right by one so the token that crosses the threshold is kept,
    # which also guarantees the top token always survives
    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)

    sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, float('-inf'))

    # undo the sort: every position receives its own (possibly masked) logit
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)
|
|
|
|
|
|
|
def top_k(logits, frac_num_tokens = 0.1, k = None):
    """
    Keep the `k` largest logits along the last dimension and set the rest to -inf.

    When `k` is not given it is `ceil(frac_num_tokens * vocab_size)`; `k` is
    capped at the vocabulary size. `logits` may have any number of leading
    dimensions (including none).
    """
    num_tokens = logits.shape[-1]

    if k is None:
        k = ceil(frac_num_tokens * num_tokens)

    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k, dim = -1)

    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered
|
|
|
|
|
|
|
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    """
    Set to -inf every logit whose probability is below
    `min_p_ratio * max_prob ** min_p_pow`, computed along the last dimension.
    """
    probs = logits.softmax(dim = -1)
    peak = probs.amax(dim = -1, keepdim = True)

    threshold = peak.pow(min_p_pow) * min_p_ratio
    too_unlikely = probs < threshold

    return torch.where(too_unlikely, float('-inf'), logits)
|
|
|
|
|
|
|
|
|
def min_p(logits, min_p = 0.1):
    """
    Set to -inf every logit whose probability is below `min_p` times the
    largest probability, computed along the last dimension.
    """
    probs = logits.softmax(dim = -1)
    peak = probs.amax(dim = -1, keepdim = True)

    threshold = min_p * peak
    too_unlikely = probs < threshold

    return torch.where(too_unlikely, float('-inf'), logits)
|
|
|
|
|
|
|
# registry of the built-in logit filters, so callers can select one by name
FILTER_LOGITS_FN = {
    'top_p': top_p,
    'top_k': top_k,
    'top_a': top_a,
    'min_p': min_p,
}
|
|
|
|
|
|
|
def contrastive_decode_fn(
    expert_logits,
    amateur_logits,
    alpha = 0.1,
    beta = 0.5
):
    """
    Contrastive decoding, Appendix A Algorithm 2 of
    https://arxiv.org/abs/2309.09117

    Returns `(1 + beta) * expert_logits - beta * amateur_logits`, with every
    position whose expert logit is below `log(alpha) + max(expert_logits)`
    filled with the most negative finite value of the expert dtype.
    """
    expert_peak = expert_logits.amax(dim = -1, keepdim = True)
    implausible = expert_logits < (log(alpha) + expert_peak)

    contrast = (1 + beta) * expert_logits - beta * amateur_logits

    mask_value = -torch.finfo(expert_logits.dtype).max
    return contrast.masked_fill(implausible, mask_value)
|
|
|
|
|
|
|
| class AutoregressiveWrapper(Module):
|
def __init__(
    self,
    net,
    ignore_index = -100,
    pad_value = 0,
    mask_prob = 0.,
    add_attn_z_loss = False
):
    """
    Wrap `net` for autoregressive use.

    Stores the wrapped network together with the padding / ignore ids, the
    token masking probability (which must be below 1) and the attention
    z-loss flag. `max_seq_len` is copied from `net`.
    """
    super().__init__()

    self.ignore_index = ignore_index
    self.pad_value = pad_value

    self.net = net
    self.max_seq_len = net.max_seq_len

    # a probability of 1 would mask every token
    assert mask_prob < 1.
    self.mask_prob = mask_prob

    self.add_attn_z_loss = add_attn_z_loss
|
|
|
@torch.inference_mode()
@eval_decorator
def generate(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    **kwargs
):
    """
    Autoregressively sample up to `seq_len` tokens after `prompts`.

    Args:
        prompts: token ids; leading dimensions are packed into one batch
            dimension and restored on return.
        seq_len: maximum number of tokens to generate.
        eos_token: when given, generation stops once every row contains it,
            and everything after the first occurrence in a row (prompt
            included) is replaced by `self.pad_value`.
        temperature: softmax temperature; `0.` selects greedy argmax decoding.
        prompt_lens: per-row prompt lengths; prompts are right-aligned with
            `align_right` and the start offsets are passed on as
            `seq_start_pos`.
        filter_logits_fn: a callable, or a key of `FILTER_LOGITS_FN`. It is
            replaced by `identity` when `amateur_model` is given.
        restrict_to_max_seq_len: feed at most the last `max_seq_len` tokens
            to the network (and trim the kv cache to match); otherwise the
            whole sequence is fed.
        amateur_model: one module or a tuple of modules used for contrastive
            decoding; `AutoregressiveWrapper`s are unwrapped to their `net`.
        filter_kwargs: keyword arguments for `filter_logits_fn`.
        contrastive_decode_kwargs: one dict (or a tuple of dicts, one per
            amateur) of keyword arguments for `contrastive_decode_fn`.
        cache_kv: reuse the network's cache between steps when it supports it.
        return_prime: include the prompt in the returned sequence.
        verbose: print progress.
        **kwargs: forwarded to the network(s).

    Returns:
        The generated token ids, with or without the prompt.
    """
    max_seq_len, greedy = self.max_seq_len, temperature == 0.

    # fold any leading dimensions into one batch dimension; `ps` restores them
    prompts, ps = pack([prompts], '* n')

    _, t = prompts.shape

    # resolve a filter given by name

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # variable length prompts are right-aligned

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    cache = None

    # contrastive decoding setup

    if exists(amateur_model):
        amateur_modules = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_modules) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_modules)
        filter_logits_fn = identity

        for module in amateur_modules:
            module.eval()

        # unwrap into a new list: `cast_tuple` yields a tuple, which does not
        # support item assignment
        amateur_model = [
            module.net if isinstance(module, AutoregressiveWrapper) else module
            for module in amateur_modules
        ]

    # sampling loop

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            # keep the cached keys / values in step with the truncated input
            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]
        else:
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        # only the prediction for the next token is needed
        logits = logits[:, -1]

        # contrast the expert logits against each amateur in turn

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # pick the next token

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        # stop early once every row contains an eos token
        if (out == eos_token).any(dim = -1).all():

            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)

            break

    if exists(eos_token):
        # computed here rather than in the loop so that it also exists when
        # the loop never ran (seq_len == 0)
        is_eos_tokens = (out == eos_token)

        # pad out everything strictly after the first eos token of each row
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
|
|
@torch.inference_mode()
@eval_decorator
def generate_masked(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    masked_token_ids: list[int] | Tensor | None = None,
    **kwargs
):
    """
    Autoregressive generation that can forbid a set of token ids.

    Behaves like `generate`, except that at every step the logits of the ids
    in `masked_token_ids` are overwritten with -1e9 before the next token is
    chosen.

    Args:
        prompts: token ids; leading dimensions are packed into one batch
            dimension and restored on return.
        seq_len: maximum number of tokens to generate.
        eos_token: when given, generation stops once every row contains it,
            and everything after the first occurrence in a row (prompt
            included) is replaced by `self.pad_value`.
        temperature: softmax temperature; `0.` selects greedy argmax decoding.
        prompt_lens: per-row prompt lengths; prompts are right-aligned with
            `align_right`.
        filter_logits_fn: a callable, or a key of `FILTER_LOGITS_FN`. It is
            replaced by `identity` when `amateur_model` is given.
        restrict_to_max_seq_len: feed at most the last `max_seq_len` tokens
            to the network; otherwise the whole sequence is fed.
        amateur_model: one module or a tuple of modules used for contrastive
            decoding; `AutoregressiveWrapper`s are unwrapped to their `net`.
        filter_kwargs: keyword arguments for `filter_logits_fn`.
        contrastive_decode_kwargs: one dict (or a tuple of dicts, one per
            amateur) of keyword arguments for `contrastive_decode_fn`.
        cache_kv: reuse the network's cache between steps when it supports it.
        return_prime: include the prompt in the returned sequence.
        verbose: print progress.
        masked_token_ids: ids that must not be generated. Duplicates, negative
            ids and ids at or beyond the vocabulary size are ignored.
        **kwargs: forwarded to the network(s).

    Returns:
        The generated token ids, with or without the prompt.
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    # fold any leading dimensions into one batch dimension; `ps` restores them
    prompts, ps = pack([prompts], '* n')

    _, t = prompts.shape

    # resolve a filter given by name

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # normalise the forbidden ids to a de-duplicated, non-negative long tensor

    if masked_token_ids is not None:
        if not torch.is_tensor(masked_token_ids):
            masked_token_ids = torch.tensor(masked_token_ids, dtype=torch.long, device=device)
        else:
            masked_token_ids = masked_token_ids.to(device=device, dtype=torch.long)

        masked_token_ids = torch.unique(masked_token_ids)
        masked_token_ids = masked_token_ids[masked_token_ids >= 0]

    # the vocabulary size is only known once logits exist, so the ids that are
    # actually in range are resolved on the first step and then reused
    valid_masked = None

    # variable length prompts are right-aligned

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    cache = None

    # contrastive decoding setup

    if exists(amateur_model):
        amateur_modules = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_modules) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_modules)
        filter_logits_fn = identity

        for module in amateur_modules:
            module.eval()

        # unwrap into a new list: `cast_tuple` yields a tuple, which does not
        # support item assignment
        amateur_model = [
            module.net if isinstance(module, AutoregressiveWrapper) else module
            for module in amateur_modules
        ]

    # sampling loop

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            # keep the cached keys / values in step with the truncated input
            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]
        else:
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        # only the prediction for the next token is needed
        logits = logits[:, -1]

        # contrast the expert logits against each amateur in turn

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # suppress the forbidden tokens

        if masked_token_ids is not None and masked_token_ids.numel() > 0:
            if valid_masked is None:
                vocab_size = logits.shape[-1]
                valid_masked = masked_token_ids[masked_token_ids < vocab_size]

            if valid_masked.numel() > 0:
                logits[:, valid_masked] = -1e9

        # pick the next token

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        # stop early once every row contains an eos token
        if (out == eos_token).any(dim = -1).all():

            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)

            break

    if exists(eos_token):
        # computed here rather than in the loop so that it also exists when
        # the loop never ran (seq_len == 0)
        is_eos_tokens = (out == eos_token)

        # pad out everything strictly after the first eos token of each row
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
|
|
@torch.inference_mode()
@eval_decorator
def generate_biased(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    logit_bias: dict | Tensor | None = None,
    **kwargs
):
    """
    Autoregressive generation with optional additive logit bias.

    Behaves like `generate`, except that `logit_bias` is added to the
    next-token logits at every step (after contrastive decoding, before
    greedy selection / filtering).

    logit_bias:
      - dict[token_id -> float] OR
      - torch.Tensor of shape (vocab,) OR (batch, vocab)

    For a dict, the vocabulary size is looked up on the network when possible
    (`config.vocab_size`, `vocab_size`, output embeddings, `embed_tokens`,
    `lm_head`); otherwise the bias vector is built from the first step's
    logits.

    Raises:
        TypeError: `logit_bias` is neither a dict nor a tensor.
        IndexError: a dict key is outside the vocabulary.
        ValueError: a tensor bias is not 1-D / 2-D, or its batch size does
            not match the logits.
    """

    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    # fold any leading dimensions into one batch dimension; `ps` restores them
    prompts, ps = pack([prompts], '* n')

    _, t = prompts.shape

    # resolve a filter given by name

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # variable length prompts are right-aligned

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    cache = None

    # contrastive decoding setup

    if exists(amateur_model):
        amateur_modules = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
        assert len(amateur_modules) == len(contrastive_decode_kwargs)
        amateur_caches = [None] * len(amateur_modules)
        filter_logits_fn = identity

        for module in amateur_modules:
            module.eval()

        # unwrap into a new list: `cast_tuple` yields a tuple, which does not
        # support item assignment
        amateur_model = [
            module.net if isinstance(module, AutoregressiveWrapper) else module
            for module in amateur_modules
        ]

    # logit bias preparation

    def infer_vocab_size():
        # best-effort lookup of the vocabulary size on the wrapped network;
        # any failure just means the size is taken from the logits later
        try:
            if hasattr(self.net, "config") and getattr(self.net.config, "vocab_size", None) is not None:
                return int(self.net.config.vocab_size)
            if getattr(self.net, "vocab_size", None) is not None:
                return int(self.net.vocab_size)

            get_out = getattr(self.net, "get_output_embeddings", None)
            if callable(get_out) and get_out() is not None:
                return int(get_out().weight.shape[0])
            if hasattr(self.net, "embed_tokens"):
                return int(self.net.embed_tokens.weight.shape[0])
            if hasattr(self.net, "lm_head"):
                return int(self.net.lm_head.weight.shape[0])
        except Exception:
            return None
        return None

    def build_bias_vector(bias_by_token, vocab_size):
        # dense (vocab,) float32 vector holding the per-token bias
        bias_vec = torch.zeros(vocab_size, device=device, dtype=torch.float32)
        for tok, val in bias_by_token.items():
            if tok < 0 or tok >= vocab_size:
                raise IndexError(f"logit_bias token id {tok} out of range for vocab size {vocab_size}")
            bias_vec[tok] = val
        return bias_vec

    prepared_bias = None
    lazy_build_bias_from_dict = None

    if exists(logit_bias):
        if isinstance(logit_bias, dict):
            bias_by_token = {int(k): float(v) for k, v in logit_bias.items()}
            vocab_size = infer_vocab_size()

            if vocab_size is not None:
                prepared_bias = build_bias_vector(bias_by_token, vocab_size)
            else:
                # vocabulary size unknown for now: build on the first step
                lazy_build_bias_from_dict = bias_by_token

        elif isinstance(logit_bias, torch.Tensor):
            prepared_bias = logit_bias.to(device=device, dtype=torch.float32)
        else:
            raise TypeError("logit_bias must be dict or torch.Tensor")

    # sampling loop

    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len
            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
                'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
            x = out[:, -max_seq_len:]

            # keep the cached keys / values in step with the truncated input
            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]
        else:
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        # only the prediction for the next token is needed
        logits = logits[:, -1]

        # deferred construction of the bias vector, now that the vocabulary
        # size can be read off the logits

        if lazy_build_bias_from_dict is not None:
            prepared_bias = build_bias_vector(lazy_build_bias_from_dict, logits.shape[-1])
            lazy_build_bias_from_dict = None

        # contrast the expert logits against each amateur in turn

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )
                amateur_logits = amateur_logits[:, -1]
                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # apply the additive bias

        if exists(prepared_bias):
            if prepared_bias.dim() == 1:
                # one bias vector shared by every row
                logits = logits + prepared_bias.unsqueeze(0)
            elif prepared_bias.dim() == 2:
                if prepared_bias.shape[0] != logits.shape[0]:
                    raise ValueError("logit_bias tensor batch size must match logits batch size")
                logits = logits + prepared_bias
            else:
                raise ValueError("logit_bias tensor must be 1D (vocab,) or 2D (batch, vocab)")

        # pick the next token

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        # stop early once every row contains an eos token
        if (out == eos_token).any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # computed here rather than in the loop so that it also exists when
        # the loop never ran (seq_len == 0)
        is_eos_tokens = (out == eos_token)

        # pad out everything strictly after the first eos token of each row
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
|
|
@torch.inference_mode()
@eval_decorator
def generate_advanced(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,

    logits_bias: dict | None = None,
    masked_tokens: list | Tensor | None = None,

    binary_classifier: bool = False,
    classifier_model: Module | None = None,
    batches: list | None = None,
    threshold: float = 0.5,
    classifier_device: torch.device | None = None,

    **kwargs
):
    """Sample tokens autoregressively with optional logit biasing and token masking.

    Two modes are supported:

    * ``binary_classifier=True``: no generation happens. ``classifier_model``
      is run over every tensor in ``batches``; the sigmoid of its output is
      thresholded and ``(all_preds, all_probs)`` is returned as two flat lists.
    * otherwise: up to ``seq_len`` tokens are appended to ``prompts`` by
      repeatedly calling ``self.net`` and sampling from the last position.

    Args:
        prompts: token ids; packed to ``(batch, n)`` with pattern ``'* n'``.
        seq_len: maximum number of tokens to generate.
        eos_token: if given, generation stops once every row contains it and
            everything after the first occurrence in a row is replaced with
            ``self.pad_value`` (the prompt is included in that search).
        temperature: softmax temperature; ``0.`` selects greedy argmax.
        prompt_lens: if given, prompts are right-aligned with ``align_right``
            and ``seq_start_pos`` is forwarded to the network.
        filter_logits_fn: callable or key of ``FILTER_LOGITS_FN``; replaced by
            ``identity`` when ``amateur_model`` is used.
        restrict_to_max_seq_len: feed only the last ``self.max_seq_len``
            tokens (and trim cached key/values accordingly).
        amateur_model: module(s) combined with the expert logits through
            ``contrastive_decode_fn``.
        filter_kwargs: keyword arguments for ``filter_logits_fn``.
        contrastive_decode_kwargs: one kwargs dict per amateur model.
        cache_kv: reuse the intermediates returned by the network as cache
            when the network reports ``can_cache_kv``.
        return_prime: keep the prompt in the returned tensor.
        verbose: print progress every 32 steps.
        logits_bias: ``{token_id: bias}``; ``bias`` is a number, a one-element
            tensor, or a 1D tensor of length ``batch``. Added to that token's
            logit at every step.
        masked_tokens: token ids whose logits are set to ``-1e9`` at every
            step; ids outside the vocabulary are ignored.
        binary_classifier, classifier_model, batches, threshold,
        classifier_device: see the classifier mode described above.
        **kwargs: forwarded to ``self.net`` and to every amateur model.

    Returns:
        ``(all_preds, all_probs)`` in classifier mode, otherwise the generated
        token tensor unpacked back to the leading shape of ``prompts``.
    """

    # ---- classifier-only mode: nothing below this branch is executed ----
    if binary_classifier:
        assert classifier_model is not None, "classifier_model must be provided when binary_classifier=True"
        assert batches is not None, "batches (iterable of input tensors) must be provided when binary_classifier=True"

        device = classifier_device if classifier_device is not None else (prompts.device if exists(prompts) else torch.device('cpu'))

        all_probs = []
        all_preds = []

        classifier_model.eval()
        with torch.no_grad():
            for x in batches:
                x = x.to(device)
                logits = classifier_model(x).squeeze()

                # a single-example batch squeezes down to a 0-dim tensor whose
                # .tolist() is a bare float; keep one axis so extend() works
                if logits.dim() == 0:
                    logits = logits.unsqueeze(0)

                probs = torch.sigmoid(logits)
                preds = (probs >= threshold).long()

                all_probs.extend(probs.cpu().tolist())
                all_preds.extend(preds.cpu().tolist())

        return all_preds, all_probs

    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # resolve a named logits filter

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # variable-length prompts are right-aligned

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    cache = None

    # contrastive decoding setup

    if exists(amateur_model):
        amateur_model = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        # build a new list instead of assigning into `amateur_model`: it may
        # be a tuple (see its annotation), which rejects item assignment
        amateur_nets = []
        for module in amateur_model:
            module.eval()
            amateur_nets.append(module.net if isinstance(module, AutoregressiveWrapper) else module)
        amateur_model = amateur_nets

    # per-token biases are loop invariant, so prepare them once

    prepared_biases = []
    if exists(logits_bias):
        assert isinstance(logits_bias, dict), "logits_bias must be a dict {token_id: bias_value}"

        for tok_id, bias_val in logits_bias.items():
            if isinstance(bias_val, torch.Tensor):
                if bias_val.dim() == 1 and bias_val.shape[0] == b:
                    bias_to_add = bias_val.to(device)
                else:
                    bias_to_add = bias_val.to(device).view(1).expand(b)
            else:
                bias_to_add = torch.tensor(float(bias_val), device=device).view(1).expand(b)

            prepared_biases.append((int(tok_id), bias_to_add))

    # masked token ids are loop invariant as well

    masked_idx = None
    if exists(masked_tokens):
        if isinstance(masked_tokens, torch.Tensor):
            masked_tokens = masked_tokens.tolist()
        else:
            masked_tokens = list(masked_tokens)

        if len(masked_tokens) > 0:
            masked_idx = torch.tensor(masked_tokens, device=device, dtype=torch.long)

    NEG_INF = -1e9

    # sampling loop

    for sl in range(seq_len):

        # default to the full sequence so `x` is defined even when
        # restrict_to_max_seq_len is False
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]

        # contrastive decoding against each amateur

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # additive per-token bias

        for tok_idx, bias_to_add in prepared_biases:
            logits[:, tok_idx] = logits[:, tok_idx] + bias_to_add

        # hard masking; ids outside the vocabulary are dropped, not clamped

        if masked_idx is not None:
            idx = masked_idx[(masked_idx >= 0) & (masked_idx < logits.shape[-1])]
            if idx.numel() > 0:
                logits.index_fill_(dim=-1, index=idx, value=NEG_INF)

        # sample the next token

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)
            break

    if exists(eos_token):
        # recomputed from the final `out` (same value the loop last saw) so
        # that seq_len == 0 does not hit an undefined name
        is_eos_tokens = (out == eos_token)

        # everything strictly after the first eos in a row becomes padding
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
|
|
def compute_accuracy(self, logits, labels):
    """Return the fraction of non-ignored labels matched by the argmax of ``logits``.

    Args:
        logits: scores whose last axis is the class axis; the argmax over
            that axis is flattened and compared with ``labels``.
        labels: target ids; flattened, entries equal to ``self.ignore_index``
            are excluded from both numerator and denominator.

    Returns:
        A 0-dim float32 tensor. When every label is ignored the result is
        ``0.`` rather than the NaN a 0/0 division would produce.
    """
    predictions = torch.argmax(logits, dim=-1).flatten()
    targets = labels.flatten()

    keep = targets != self.ignore_index
    predictions = predictions[keep]
    targets = targets[keep]

    num_valid = len(targets)

    # a fully ignored batch has nothing to score; avoid 0 / 0 -> NaN
    if num_valid == 0:
        return torch.zeros((), dtype=torch.float32, device=predictions.device)

    num_right = torch.sum(predictions == targets).type(torch.float32)

    return num_right / num_valid
|
|
|
def forward(self, x, return_outputs = False, **kwargs):
    """Compute the next-token loss and accuracy for a batch of token ids.

    ``x[:, :-1]`` is fed to ``self.net`` and ``x[:, 1:]`` is the target.
    Input positions equal to ``self.ignore_index`` are replaced with
    ``self.pad_value`` before the network sees them. When ``self.mask_prob``
    is positive a random boolean mask is built and handed to the network as
    ``self_attn_kv_mask`` (True where a position was NOT picked).

    Returns ``(loss, acc)``, or ``(loss, acc, logits, cache)`` when
    ``return_outputs`` is truthy.
    """
    ignore_index = self.ignore_index
    add_attn_z_loss = self.add_attn_z_loss
    seq = x.shape[1]

    inp = x[:, :-1]
    target = x[:, 1:]
    inp = torch.where(inp == ignore_index, self.pad_value, inp)

    if self.mask_prob > 0.:
        scores = torch.randn(inp.shape, device = x.device)
        # lowest possible score for column 0, so topk picks it last
        scores[:, 0] = -torch.finfo(scores.dtype).max
        num_mask = min(int(seq * self.mask_prob), seq - 1)
        picked = scores.topk(num_mask, dim = -1).indices
        picked_mask = torch.zeros_like(inp).scatter(1, picked, 1.).bool()
        kwargs.update(self_attn_kv_mask = ~picked_mask)

    logits, cache = self.net(
        inp,
        return_intermediates = True,
        return_attn_z_loss = add_attn_z_loss,
        **kwargs
    )

    acc = self.compute_accuracy(logits, target)

    # the loss expects the class axis second
    class_first_logits = rearrange(logits, 'b n c -> b c n')

    if self.net.output_is_log_prob:
        loss = F.nll_loss(class_first_logits, target, ignore_index = ignore_index)
    else:
        loss = F.cross_entropy(class_first_logits, target, ignore_index = ignore_index)

    if add_attn_z_loss:
        loss = loss + cache.attn_z_loss

    if return_outputs:
        return loss, acc, logits, cache

    return loss, acc
|
|
|
@torch.inference_mode()
@eval_decorator
def generate_expert(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,

    token_type_ids: torch.LongTensor | None = None,
    type_temperatures: dict | None = None,
    type_biases: dict | None = None,
    repetition_window: int = 64,
    repetition_penalty_per_type: dict | None = None,
    rare_types: set | None = None,
    rare_type_boost: float = 0.0,
    entropy_threshold: float = 2.0,

    forbidden_token_ids: torch.LongTensor | torch.BoolTensor | None = None,
    forbidden_value: float = -1e9,
    **kwargs
):
    """Sample tokens autoregressively with token-type aware logit shaping.

    On top of the plain sampling loop (network call, optional contrastive
    decoding against ``amateur_model``, filtering, temperature, EOS padding)
    the logits of every step are adjusted in this order:

    1. ``type_biases``: additive bias per token type.
    2. ``repetition_penalty_per_type``: for each row, the share ``freq`` of
       the last ``repetition_window`` tokens that have a given type scales a
       penalty ``1 + freq * (penalty_scale - 1)`` applied to all tokens of
       that type.
    3. ``rare_types`` / ``rare_type_boost``: rows whose softmax entropy is
       below ``entropy_threshold`` get ``rare_type_boost`` added to tokens of
       the rare types.
    4. ``type_temperatures``: logits are divided by a per-token temperature
       (1.0 for types that are not listed).
    5. ``forbidden_token_ids``: masked positions are set to ``forbidden_value``.

    Steps 1-4 only run when ``token_type_ids`` is given.

    Args:
        token_type_ids: tensor indexed by token id giving each token's type;
            it is broadcast against the vocabulary axis of the logits
            (presumably shape ``(vocab,)`` - confirm with callers).
        forbidden_token_ids: either an int32/int64 tensor of token ids or a
            boolean mask of shape ``(vocab,)`` or ``(batch, vocab)``. For the
            id form the vocabulary size is read from
            ``self.net.config.vocab_size`` if present, else from
            ``token_type_ids.shape[0]``; ids outside the vocabulary are
            ignored.
        Remaining arguments: prompts, lengths, EOS handling, filtering,
            caching and ``**kwargs`` forwarding; ``temperature == 0.`` selects
            greedy argmax.

    Returns:
        The generated tokens (prompt included only if ``return_prime``),
        unpacked back to the leading shape of ``prompts``.

    Raises:
        ValueError: boolean ``forbidden_token_ids`` that is not 1D or 2D.
        TypeError: ``forbidden_token_ids`` of an unsupported dtype.
    """
    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    b, t = prompts.shape

    # resolve a named logits filter

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # variable-length prompts are right-aligned

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    cache = None

    # contrastive decoding setup

    if exists(amateur_model):
        amateur_model = cast_tuple(amateur_model)
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        # build a new list instead of assigning into `amateur_model`: it may
        # be a tuple (see its annotation), which rejects item assignment
        amateur_nets = []
        for module in amateur_model:
            module.eval()
            amateur_nets.append(module.net if isinstance(module, AutoregressiveWrapper) else module)
        amateur_model = amateur_nets

    # token-type dependent, loop-invariant tensors

    per_token_temp = None
    per_token_bias = None
    per_type_rep_penalty = {}
    rare_type_mask = None
    type_vocab_masks = {}

    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

        if type_temperatures is not None and len(type_temperatures) > 0:
            per_token_temp = torch.ones_like(token_type_ids, dtype=torch.float32)
            for type_id, temp_val in type_temperatures.items():
                per_token_temp[token_type_ids == type_id] = float(temp_val)

        if type_biases is not None and len(type_biases) > 0:
            per_token_bias = torch.zeros_like(token_type_ids, dtype=torch.float32)
            for type_id, bias_val in type_biases.items():
                per_token_bias[token_type_ids == type_id] = float(bias_val)

        per_type_rep_penalty = repetition_penalty_per_type or {}

        # which vocabulary entries belong to each penalised type
        type_vocab_masks = {type_id: (token_type_ids == type_id) for type_id in per_type_rep_penalty}

        if rare_types is not None and len(rare_types) > 0:
            rare_type_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
            for rt in rare_types:
                rare_type_mask |= (token_type_ids == rt)

    # forbidden tokens -> boolean mask of shape [b, vocab]

    forbidden_mask_per_batch = None
    if forbidden_token_ids is not None:

        if forbidden_token_ids.dtype in (torch.int64, torch.int32):

            vocab_size = self.net.config.vocab_size if hasattr(self.net, 'config') else None

            if vocab_size is None and token_type_ids is not None:
                vocab_size = token_type_ids.shape[0]
            assert vocab_size is not None, "Cannot infer vocab size for forbidden_token_ids; provide a boolean mask instead."

            mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
            ids = forbidden_token_ids.to(device).reshape(-1).long()

            # drop out-of-range ids; clamping them would silently forbid
            # token 0 or the last token instead
            ids = ids[(ids >= 0) & (ids < vocab_size)]
            mask[ids] = True
            forbidden_mask_per_batch = mask.unsqueeze(0).expand(b, -1)

        elif forbidden_token_ids.dtype == torch.bool:

            if forbidden_token_ids.dim() == 1:
                forbidden_mask_per_batch = forbidden_token_ids.to(device).unsqueeze(0).expand(b, -1)
            elif forbidden_token_ids.dim() == 2:
                assert forbidden_token_ids.shape[0] == b, "forbidden_token_ids batch dimension must match prompts batch size"
                forbidden_mask_per_batch = forbidden_token_ids.to(device)
            else:
                raise ValueError("forbidden_token_ids boolean mask must be 1D [vocab] or 2D [b, vocab]")
        else:
            raise TypeError("forbidden_token_ids must be LongTensor of ids or BoolTensor mask")

    # sampling loop

    for sl in range(seq_len):

        # default to the full sequence so `x` is defined even when
        # restrict_to_max_seq_len is False
        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), (
                'the network cannot use cached key values when decoding outside the max sequence length. '
                'most likely because you are using absolute positional embedding. '
                'you can switch to rotary embeddings to resolve this issue'
            )

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]

        # contrastive decoding against each amateur

        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(
                zip(amateur_model, amateur_caches, contrastive_decode_kwargs)
            ):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, \
                    'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # token-type aware shaping

        if token_type_ids is not None:

            # 1. additive bias per type
            if per_token_bias is not None:
                logits = logits + per_token_bias

            # 2. repetition penalty per type, scaled by recent frequency
            if repetition_window > 0 and len(per_type_rep_penalty) > 0:

                recent = out[:, -repetition_window:].to(device)
                recent_types = token_type_ids[recent]

                for bi in range(b):
                    types_b = recent_types[bi]
                    if types_b.numel() == 0:
                        continue

                    for type_id, penalty_scale in per_type_rep_penalty.items():
                        freq = (types_b == type_id).float().mean().item()
                        if freq <= 0.0:
                            continue

                        penalty = 1.0 + freq * (penalty_scale - 1.0)
                        vocab_mask = type_vocab_masks[type_id]
                        selected = logits[bi, vocab_mask]

                        # sign-aware: dividing a negative logit would raise
                        # it, i.e. reward the token instead of penalising it
                        logits[bi, vocab_mask] = torch.where(
                            selected < 0,
                            selected * penalty,
                            selected / penalty
                        )

            # 3. boost rare types on low-entropy (over-confident) rows
            if rare_type_mask is not None and rare_type_boost > 0.0:

                probs_raw = F.softmax(logits, dim=-1)
                log_probs_raw = torch.log(probs_raw + 1e-9)
                entropy = -(probs_raw * log_probs_raw).sum(dim=-1)

                low_entropy = entropy < entropy_threshold
                if low_entropy.any():

                    boost_vec = torch.zeros_like(logits)
                    boost_vec[:, rare_type_mask] = rare_type_boost
                    logits = torch.where(
                        low_entropy.unsqueeze(-1),
                        logits + boost_vec,
                        logits
                    )

            # 4. per-type temperature
            if per_token_temp is not None:
                logits = logits / per_token_temp

        # 5. forbidden tokens

        if forbidden_mask_per_batch is not None:

            assert forbidden_mask_per_batch.shape[0] == b and forbidden_mask_per_batch.shape[1] == logits.shape[-1], \
                "forbidden mask shape must be [b, vocab]"

            logits = logits.masked_fill(forbidden_mask_per_batch, float(forbidden_value))

        # sample the next token

        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():

            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)

            break

    if exists(eos_token):
        # recomputed from the final `out` (same value the loop last saw) so
        # that seq_len == 0 does not hit an undefined name
        is_eos_tokens = (out == eos_token)

        # everything strictly after the first eos in a row becomes padding
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClsInferenceDataset(Dataset):
    """
    Map-style dataset serving token-id sequences for classifier inference.

    Every element of ``data_pairs`` is treated as one sequence of token ids
    and is returned as a ``torch.long`` tensor. No label is read or returned,
    despite the attribute name.
    """
    def __init__(self, data_pairs):
        # kept as-is; elements are converted lazily on access
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        return torch.tensor(self.data_pairs[idx], dtype=torch.long)
|
|
|
def build_cls_model(num_tokens=18819,
                    max_seq_len=1024,
                    logits_dim=1,
                    use_cls_token=True,
                    squeeze_out_last_dim=True,
                    dim=1024,
                    depth=8,
                    heads=8,
                    attn_flash=True,
                    rotary_pos_emb=False,
                    device='cuda'
                    ):

    """
    Build the classifier: a TransformerWrapper around an Encoder stack,
    moved to `device`. With the defaults it is configured for a single
    output logit (logits_dim=1, squeeze_out_last_dim=True).
    """

    # attention stack first, then the token-level wrapper around it
    encoder = Encoder(
        dim=dim,
        depth=depth,
        heads=heads,
        attn_flash=attn_flash,
        rotary_pos_emb=rotary_pos_emb
    )

    classifier = TransformerWrapper(
        num_tokens=num_tokens,
        max_seq_len=max_seq_len,
        logits_dim=logits_dim,
        use_cls_token=use_cls_token,
        squeeze_out_last_dim=squeeze_out_last_dim,
        attn_layers=encoder
    )

    return classifier.to(device)
|
|
|
def load_cls_model(checkpoint_path, device='cuda'):

    """
    Rebuild the default classifier architecture and load its weights.

    checkpoint_path: file holding a state dict accepted by
        `model.load_state_dict`.
    device: where the model is built and the weights are mapped to.

    Returns the model on `device`, switched to eval mode.
    """

    model = build_cls_model(device=device)

    # Only a state dict is consumed here, so restrict unpickling to tensors
    # and plain containers; a full pickle load can execute arbitrary code
    # from an untrusted checkpoint file.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device).eval()

    return model
|
|
|
def cls_predict(model,
                seqs,
                batch_size=8,
                threshold=0.5,
                seq_len=1024,
                pad_token=18818,
                device='cuda'
                ):

    """
    Run the classifier over `seqs` and threshold its sigmoid outputs.

    seqs: collection of token-id sequences (wrapped in ClsInferenceDataset).
    batch_size: number of sequences per forward pass.
    threshold: probabilities >= threshold are predicted as 1.
    seq_len: sequences are truncated to this many tokens.
    pad_token: id used to right-pad shorter sequences within a batch.
    device: device the model and the batches are moved to.

    Returns a tuple of two lists, in this order:
        - preds: int 0/1 predictions
        - probs: float probabilities
    """

    def collate_fn(batch):
        # truncate, then right-pad to the longest sequence of the batch
        clipped = [s[:seq_len].detach().clone() for s in batch]
        width = min(seq_len, max(c.size(0) for c in clipped))
        padded = torch.full((len(clipped), width), pad_token, dtype=torch.long)
        for row, c in enumerate(clipped):
            padded[row, :c.size(0)] = c
        return padded

    ds = ClsInferenceDataset(seqs)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_probs = []
    all_preds = []

    model.to(device)
    model.eval()

    with torch.inference_mode():
        for x in loader:

            x = x.to(device)

            logits = model(x).squeeze()

            probs = torch.sigmoid(logits)

            preds = (probs >= threshold).long()

            probs = probs.cpu().tolist()
            preds = preds.cpu().tolist()

            # a single-example batch squeezes to a scalar, whose tolist()
            # is a bare number rather than a list
            if isinstance(preds, list):
                all_probs.extend(probs)
                all_preds.extend(preds)
            else:
                all_probs.append(probs)
                all_preds.append(preds)

    return all_preds, all_probs
|
|
|
|
|
|
|
|
|
|
|
| import inspect
|
| import math
|
| from typing import Callable, Optional, Dict, Any, List, Tuple
|
| import torch
|
| import torch.nn.functional as F
|
|
|
def print_probs_scoring_guide():
    """Print the scoring guide text stored in the docstring of ``probs_scoring_guide``."""
    guide_text = inspect.getdoc(probs_scoring_guide)
    print(guide_text)
|
|
|
def probs_scoring_guide():

    """
    Return dictionary structure and metric descriptions for generate_with_probs / score_sequences.

    Returns
    -------
    result : dict
        A dictionary containing token-level and sequence-level scoring information.

    Keys
    ----
    tokens : torch.Tensor
        Tensor of token ids for each batch entry. Shape (batch, seq_len).
        - Meaning: Generated tokens (for generate_with_probs) or the original
          input sequences (for score_sequences).
        - Interpretation: Map ids to text with your tokenizer to inspect outputs.

    token_probs : List[List[float]]
        Per-batch lists of probabilities assigned to each chosen token at the time
        it was produced. Values in [0, 1].
        - Meaning: Softmax probability for the selected token at each step.
        - Interpretation: Higher → model more confident about that token. Do not
          multiply many token_probs directly (underflow risk); use log-probs.

    token_logprobs : List[List[float]]
        Per-batch lists of natural log probabilities (nats) for each chosen token:
        log p(x_t | x_<t).
        - Meaning: Numerically stable per-token log-probabilities.
        - Interpretation: Less negative = more likely. Sum these to get sequence_logprobs.

    token_scores : List[List[float]]
        Per-batch lists of token negative log-probabilities (NLL) computed as -log p.
        - Meaning: Token-level loss (positive).
        - Interpretation: Lower = model found token less surprising. Useful to spot spikes.

    sequence_logprobs : List[float]
        Sum of token log-probabilities for each sequence (nats): sum_t log p(x_t | x_<t).
        - Meaning: Canonical sequence score; additive and numerically stable.
        - Interpretation: Use this to compare sequences. Higher (less negative) is better.

    nll : List[float]
        Negative sequence log-probabilities (nats): -sequence_logprobs.
        - Meaning: Sequence-level negative log-likelihood (loss).
        - Interpretation: Lower NLL indicates a sequence the model finds more probable.

    sequence_probs : List[float]
        Numeric probabilities computed as exp(sequence_logprobs) (float64 when possible).
        - Meaning: Absolute probability of the full sequence.
        - Interpretation: Often underflows to 0.0 for realistic lengths; prefer sequence_logprobs.

    sequence_prob_display : List[str]
        Human-readable string for sequence probability. If numeric underflow occurs,
        this shows an approximate scientific form (e.g., "~10^-550.65").
        - Meaning: Readable magnitude of the sequence probability.
        - Interpretation: Use this for reporting instead of raw sequence_probs when it is 0.0.

    mask : torch.Tensor
        Boolean tensor indicating which positions were included in scoring.
        Shape (batch, scored_len). False for padded positions or tokens after the first EOS.
        - Meaning: Aligns token-level lists with original sequence positions.
        - Interpretation: Use to ignore padded or post-EOS tokens in aggregates.

    metadata : dict
        Miscellaneous run information such as:
        - prompt_len : int or list[int] — length of prompt tokens (if applicable)
        - generated_len : int — number of generated tokens (generate_with_probs)
        - temperature : float — sampling temperature used
        - seq_len : int — original sequence length (score_sequences)
        - Interpretation: Useful for reproducing runs and normalizing comparisons.

    metrics : dict
        Per-sequence derived diagnostics (under result["metrics"]["per_sequence"]).
        Each entry contains:
        - sequence_index : int
        - token_count : int
        - sequence_logprob_nats : float
            Sum of log-probs (nats). Primary canonical score.
        - sequence_log10 : float
            Log10 of the sequence probability (for display).
        - sequence_prob_display : str
            Human-friendly scientific display of the sequence probability.
        - avg_logprob_per_token_nats : float
            Average log-prob per token (nats): (1/T) * sum_t log p.
            - Interpretation: Normalizes for length; higher (less negative) is better.
        - avg_logprob_per_token_bits : float
            Average log-prob per token in bits (divide nats by ln(2)).
            - Interpretation: Intuitive unit; lower bits = easier prediction.
        - geometric_mean_token_prob : str
            Geometric mean of token probabilities (display).
            - Interpretation: Typical per-token probability; quick sense of per-token confidence.
        - perplexity : float
            exp(-avg_logprob_per_token). Standard LM metric; lower is better.

    Notes
    -----
    - Use `sequence_logprobs` (or `nll`) as the authoritative score for comparisons and ranking.
    - Avoid relying on `sequence_probs` for comparisons because of floating-point underflow.
    - Prefer `avg_logprob_per_token_nats` or `perplexity` when comparing sequences of different lengths.
    - Token-level spikes in `token_scores` (large -log p) indicate surprising tokens and are useful
      for debugging prompts or model behavior.

    Examples
    --------
    # Example usage after calling generate_with_probs or score_sequences:
    res = generate_with_probs(...)
    print("Sequence logprob (nats):", res["sequence_logprobs"][0])
    print("Sequence prob (display):", res["sequence_prob_display"][0])
    print("Avg logprob/token (nats):", res["metrics"]["per_sequence"][0]["avg_logprob_per_token_nats"])
    print("Perplexity:", res["metrics"]["per_sequence"][0]["perplexity"])
    """

    # the docstring above is the payload: it is returned with its common
    # leading indentation removed
    guide_text = inspect.getdoc(probs_scoring_guide)
    return guide_text
|
|
|
|
|
def _safe_exp64(logp: float) -> Tuple[float, str]:
    """Exponentiate a log-probability in float64 and build a display string.

    Returns ``(prob, display)``. When ``exp(logp)`` is exactly ``0.0`` the
    display is the approximate power-of-ten form ``"~10^<log10>"``; otherwise
    it is ``prob`` formatted in scientific notation with six decimals.
    """
    log_value = torch.tensor(logp, dtype=torch.float64)

    try:
        prob = float(torch.exp(log_value).item())
    except Exception:
        prob = 0.0

    if prob != 0.0:
        return prob, f"{prob:.6e}"

    # underflow: report the magnitude through the base-10 exponent instead
    exponent = float(log_value.item() / math.log(10.0))
    return prob, f"~10^{exponent:.2f}"
|
|
|
def _attach_metrics_to_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Derive per-sequence metrics and store them under ``result["metrics"]``.

    Reads ``"sequence_logprobs"``, ``"token_logprobs"`` and (optionally)
    ``"sequence_prob_display"`` from ``result``. One metrics dict is produced
    per entry of ``"sequence_logprobs"``; a missing token list counts as
    empty, and a missing or ``None`` display string falls back to the
    ``"~10^<log10>"`` form.

    For a sequence without tokens the average log-prob is ``0.0`` (so the
    perplexity is ``1.0``) and the geometric mean is reported as ``"n/a"``.

    Mutates and returns ``result``.
    """
    seq_logprobs: List[float] = result.get("sequence_logprobs", [])
    token_logprobs: List[List[float]] = result.get("token_logprobs", [])
    # looked up once; may be absent or shorter than seq_logprobs
    displays: List[Any] = result.get("sequence_prob_display") or []

    ln2 = math.log(2.0)
    ln10 = math.log(10.0)

    per_sequence = []
    for i, seq_lp in enumerate(seq_logprobs):
        toks_lp = token_logprobs[i] if i < len(token_logprobs) else []
        token_count = len(toks_lp)
        avg_lp = float(sum(toks_lp) / token_count) if token_count > 0 else 0.0

        try:
            perplexity = math.exp(-avg_lp)
        except OverflowError:
            perplexity = float("inf")

        log10_prob = seq_lp / ln10

        seq_prob_display = displays[i] if i < len(displays) else None
        if seq_prob_display is None:
            seq_prob_display = f"~10^{log10_prob:.2f}"

        if token_count > 0:
            try:
                geom_mean_display = f"{math.exp(avg_lp):.6e}"
            except OverflowError:
                geom_mean_display = f"exp({avg_lp:.3f})"
        else:
            geom_mean_display = "n/a"

        per_sequence.append({
            "sequence_index": i,
            "token_count": token_count,
            "sequence_logprob_nats": float(seq_lp),
            "sequence_log10": float(log10_prob),
            "sequence_prob_display": seq_prob_display,
            "avg_logprob_per_token_nats": float(avg_lp),
            "avg_logprob_per_token_bits": float(avg_lp / ln2),
            "geometric_mean_token_prob": geom_mean_display,
            "perplexity": float(perplexity),
        })

    result["metrics"] = {"per_sequence": per_sequence}
    return result
|
|
|
def _decode_token(tokenizer, tok_id: int) -> str:
    """Best-effort conversion of a token id to text.

    Uses ``tokenizer.decode([tok_id])`` when the tokenizer has ``decode``,
    otherwise the first item of ``tokenizer.convert_ids_to_tokens([tok_id])``.
    Without a tokenizer, without either method, or when the call raises, the
    id is returned as a string.
    """
    if tokenizer is not None:
        try:
            if hasattr(tokenizer, "decode"):
                return tokenizer.decode([tok_id])
            elif hasattr(tokenizer, "convert_ids_to_tokens"):
                pieces = tokenizer.convert_ids_to_tokens([tok_id])
                return pieces[0]
        except Exception:
            # deliberate: any tokenizer failure degrades to the raw id
            pass

    return str(tok_id)
|
|
|
|
|
|
|
|
|
| @torch.inference_mode()
|
| def generate_with_probs(
|
| model,
|
| prompts: torch.Tensor,
|
| seq_len: int,
|
| eos_token: Optional[int] = None,
|
| temperature: float = 1.0,
|
| prompt_lens: Optional[torch.Tensor] = None,
|
| filter_logits_fn: Optional[Callable] = None,
|
| filter_kwargs: Optional[Dict[str, Any]] = None,
|
| pad_value: Optional[int] = None,
|
| tokenizer = None,
|
| print_table: bool = False,
|
| device: Optional[torch.device] = None,
|
| verbose: bool = True,
|
| include_top1: bool = True,
|
| **kwargs
|
| ) -> Dict[str, Any]:
|
| """
|
| Generate sequences from an autoregressive model while collecting per-token probabilities,
|
| log-probabilities, scores and an optional "diff" view comparing sampled tokens to the
|
| model's top-1 (greedy) tokens.
|
|
|
| This function runs the model in inference mode and appends sampled tokens to the provided
|
| prompts until `seq_len` tokens have been generated (or until an `eos_token` ends all
|
| sequences). It supports temperature sampling, optional logits filtering, and returns
|
| detailed diagnostics useful for evaluation, debugging and analysis (per-token probs,
|
| cumulative sequence log-probabilities, NLL, perplexity, and a diff of sampled vs top-1).
|
|
|
| Key behaviors
|
| - Operates under `torch.inference_mode()` (no gradients).
|
| - If `prompt_lens` is provided, prompts are right-aligned into a padded buffer of the
|
| same shape as `prompts` before generation (useful when prompts are suffixes).
|
| - If `filter_logits_fn` is provided it is applied to raw logits before softmax.
|
| - If `temperature == 0.0` the function performs greedy decoding (argmax).
|
| - If `include_top1` is True, the function computes the top-1 token and its probability
|
| at each step (after optional filtering) and records whether the sampled token matched it.
|
| - If `eos_token` is provided, generation stops early when every batch item has produced
|
| an EOS; generated outputs after the first EOS are optionally padded with `pad_value`.
|
| - Returned numeric log-probabilities are in natural log (nats) and converted to float64
|
| for sequence-level aggregation to reduce numerical error.
|
|
|
| Parameters
|
| - model: A model object exposing a `net` callable with signature
|
| `logits = model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`.
|
| `logits` must be a tensor of shape (batch, seq, vocab) or a tuple/list whose first
|
| element is that tensor.
|
| - prompts (torch.Tensor): Integer token tensor of shape (batch, prompt_len) containing
|
| prompt tokens. Prompts are copied and extended in-place to produce generated sequences.
|
| - seq_len (int): Maximum number of tokens to generate per example (not counting prompt).
|
| - eos_token (Optional[int]): Token id that marks end-of-sequence. If provided, generation
|
| may stop early and outputs after the first EOS are optionally replaced with `pad_value`.
|
| - temperature (float): Sampling temperature. `0.0` forces greedy decoding.
|
| - prompt_lens (Optional[torch.Tensor]): Optional per-batch prompt lengths (int or tensor)
|
| used to right-align prompts into the generation buffer when prompts are suffixes.
|
| - filter_logits_fn (Optional[Callable]): Function applied to raw logits before softmax.
|
| Signature should accept `(logits, **filter_kwargs)` and return logits of same shape.
|
| - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments forwarded to `filter_logits_fn`.
|
| - pad_value (Optional[int]): Token id used to pad generated outputs after EOS (if any).
|
| - tokenizer: Optional tokenizer used to decode token ids for human-readable diffs and
|
| printed tables. If absent, token ids are stringified.
|
| - print_table (bool): If True, prints a human-readable table summarizing per-token stats.
|
| - device (Optional[torch.device]): Device to run generation on. Defaults to `prompts.device`.
|
| - verbose (bool): If True, prints progress messages during generation.
|
| - include_top1 (bool): If True, compute and return top-1 tokens, their probs/logprobs,
|
| and a `diff` structure listing positions where sampled != top-1.
|
| - **kwargs: Additional keyword arguments forwarded to `model.net`.
|
|
|
| Returns
|
| A dictionary with the following keys (types shown informally):
|
| - "tokens" (torch.Tensor): Generated tokens (batch, generated_len) as CPU tensor.
|
| - "token_probs" (List[List[float]]): Per-batch list of per-token sampling probabilities.
|
| - "token_logprobs" (List[List[float]]): Per-batch list of per-token log-probabilities (nats).
|
| - "token_scores" (List[List[float]]): Per-token scores (negative log-probabilities).
|
| - "sequence_logprobs" (List[float]): Sum of token log-probabilities per generated sequence.
|
| - "sequence_probs" (List[float]): Sequence probabilities (exp of sequence_logprobs) where
|
| numerically possible; extremely small values may be represented as 0.0.
|
| - "sequence_prob_display" (List[str]): Human-friendly display of sequence probability
|
| (either decimal or approximate 10^x form for tiny values).
|
| - "nll" (List[float]): Negative log-likelihood per sequence (i.e., -sequence_logprob).
|
| - "metadata" (dict): Contains "prompt_len", "generated_len", and "temperature".
|
| - "diff" (List[List[Dict]]): Per-batch list of dictionaries for positions where the sampled
|
| token differed from the top-1 token. Each dict contains:
|
| - "pos": position index within the generated span (0-based)
|
| - "token": sampled token id
|
| - "token_str": decoded sampled token (or id string)
|
| - "token_prob": sampled token probability
|
| - "top1_token": top-1 token id
|
| - "top1_token_str": decoded top-1 token
|
| - "top1_prob": top-1 probability
|
| - "match": boolean (always False for entries in diff)
|
| - If `include_top1` is True, additional keys are included:
|
| - "top1_tokens", "top1_token_probs", "top1_token_logprobs", "top1_matches"
|
|
|
| After the primary result is assembled the function attaches a "metrics" entry with:
|
| - "per_sequence": list of per-sequence metric dicts containing:
|
| - "sequence_index", "token_count", "sequence_logprob_nats", "sequence_log10",
|
| "sequence_prob_display", "avg_logprob_per_token_nats", "avg_logprob_per_token_bits",
|
| "geometric_mean_token_prob", "perplexity"
|
|
|
| Notes and caveats
|
| - Numerical stability: very small probabilities are clamped before log to avoid -inf;
|
| sequence probabilities that underflow are represented with an approximate 10^x string.
|
| - The function assumes the model's logits correspond to the next-token distribution for
|
| the last position of the provided input; it uses `logits[:, -1, :]` for sampling.
|
| - The function may raise exceptions if `model.net` returns tensors of unexpected shape
|
| or if device/dtype mismatches occur.
|
| - This function is intended for analysis and debugging; it is not optimized for maximal
|
| throughput in production sampling loops.
|
|
|
| Example (conceptual)
|
| >>> res = generate_with_probs(model, prompts, seq_len=20, temperature=0.8, tokenizer=tok)
|
| >>> print(res["metrics"]["per_sequence"][0]["perplexity"])
|
| """
|
| if filter_kwargs is None:
|
| filter_kwargs = {}
|
| if device is None:
|
| device = prompts.device
|
| if pad_value is None:
|
| pad_value = getattr(model, "pad_value", None)
|
|
|
| model.eval()
|
| with torch.inference_mode():
|
| prompts_in = prompts.to(device)
|
| b, t = prompts_in.shape
|
|
|
| if prompt_lens is not None:
|
| aligned = torch.full_like(prompts_in, pad_value)
|
| for i in range(b):
|
| L = int(prompt_lens[i].item()) if isinstance(prompt_lens[i], torch.Tensor) else int(prompt_lens[i])
|
| if L > 0:
|
| aligned[i, -L:] = prompts_in[i, -L:]
|
| prompts_in = aligned
|
|
|
| out = prompts_in.clone()
|
|
|
| token_probs: List[List[float]] = [[] for _ in range(b)]
|
| token_logprobs: List[List[float]] = [[] for _ in range(b)]
|
| token_scores: List[List[float]] = [[] for _ in range(b)]
|
| seq_logprob_tensors = [torch.tensor(0.0, dtype=torch.float64) for _ in range(b)]
|
|
|
| top1_tokens: List[List[int]] = [[] for _ in range(b)]
|
| top1_token_probs: List[List[float]] = [[] for _ in range(b)]
|
| top1_token_logprobs: List[List[float]] = [[] for _ in range(b)]
|
| top1_matches: List[List[bool]] = [[] for _ in range(b)]
|
|
|
| greedy = (temperature == 0.0)
|
|
|
| if verbose:
|
| print("Generating sequence of max length:", seq_len)
|
|
|
| for sl in range(seq_len):
|
| max_seq_len = getattr(model, "max_seq_len", None)
|
| x = out if max_seq_len is None else out[:, -max_seq_len:]
|
|
|
| logits_out = model.net(x, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)
|
| logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out
|
| logits = logits[:, -1, :]
|
|
|
|
|
| if include_top1:
|
| top1_ids = logits.argmax(dim=-1, keepdim=True)
|
| filtered_for_top1 = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
|
| probs_for_top1 = F.softmax(filtered_for_top1 / (temperature if temperature > 0 else 1.0), dim=-1)
|
| top1_p = probs_for_top1.gather(1, top1_ids).squeeze(1)
|
| top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64)
|
|
|
| if greedy:
|
| filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
|
| probs = F.softmax(filtered_logits / (temperature if temperature > 0 else 1.0), dim=-1)
|
| sample = logits.argmax(dim=-1, keepdim=True)
|
| else:
|
| filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
|
| probs = F.softmax(filtered_logits / temperature, dim=-1)
|
| sample = torch.multinomial(probs, 1)
|
|
|
| picked_probs = probs.gather(1, sample).squeeze(1)
|
| picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64)
|
|
|
| out = torch.cat((out, sample), dim=-1)
|
|
|
| for i in range(b):
|
| p = float(picked_probs[i].cpu().item())
|
| lp = float(picked_logprobs[i].cpu().item())
|
| token_probs[i].append(p)
|
| token_logprobs[i].append(lp)
|
| token_scores[i].append(-lp)
|
| seq_logprob_tensors[i] = seq_logprob_tensors[i] + torch.tensor(lp, dtype=torch.float64)
|
|
|
| if include_top1:
|
| tid = int(top1_ids[i].item())
|
| tp = float(top1_p[i].cpu().item())
|
| tlp = float(top1_lp[i].cpu().item())
|
| top1_tokens[i].append(tid)
|
| top1_token_probs[i].append(tp)
|
| top1_token_logprobs[i].append(tlp)
|
| top1_matches[i].append(int(sample[i].item()) == tid)
|
|
|
| if verbose and (sl % 32 == 0):
|
| print(f"{sl} / {seq_len}")
|
|
|
| if eos_token is not None:
|
| last_tokens = out[:, -1]
|
| if (last_tokens == eos_token).any(dim=-1).all():
|
| if verbose:
|
| print('Model called the end of sequence at:', sl, '/', seq_len)
|
| break
|
|
|
| gen = out[:, t:].cpu()
|
|
|
| if eos_token is not None:
|
| for i in range(b):
|
| seq_full = out[i].cpu()
|
| eos_positions = (seq_full == eos_token).nonzero(as_tuple=False)
|
| if eos_positions.numel() > 0:
|
| first_eos_idx = int(eos_positions[0].item())
|
| gen_len_before_eos = max(0, first_eos_idx - t)
|
| token_probs[i] = token_probs[i][:gen_len_before_eos]
|
| token_logprobs[i] = token_logprobs[i][:gen_len_before_eos]
|
| token_scores[i] = token_scores[i][:gen_len_before_eos]
|
| seq_logprob_tensors[i] = torch.tensor(sum(token_logprobs[i]), dtype=torch.float64)
|
| if include_top1:
|
| top1_tokens[i] = top1_tokens[i][:gen_len_before_eos]
|
| top1_token_probs[i] = top1_token_probs[i][:gen_len_before_eos]
|
| top1_token_logprobs[i] = top1_token_logprobs[i][:gen_len_before_eos]
|
| top1_matches[i] = top1_matches[i][:gen_len_before_eos]
|
| if pad_value is not None:
|
| start_mask = max(0, first_eos_idx - t)
|
| if start_mask < gen.shape[1]:
|
| gen[i, start_mask:] = pad_value
|
|
|
| sequence_logprobs: List[float] = [float(x.item()) for x in seq_logprob_tensors]
|
| sequence_probs: List[float] = []
|
| sequence_prob_display: List[str] = []
|
| nll: List[float] = []
|
|
|
| for lp in sequence_logprobs:
|
| pnum, disp = _safe_exp64(lp)
|
| sequence_probs.append(pnum)
|
| sequence_prob_display.append(disp)
|
| nll.append(-lp)
|
|
|
| result = {
|
| "tokens": gen,
|
| "token_probs": token_probs,
|
| "token_logprobs": token_logprobs,
|
| "token_scores": token_scores,
|
| "sequence_logprobs": sequence_logprobs,
|
| "sequence_probs": sequence_probs,
|
| "sequence_prob_display": sequence_prob_display,
|
| "nll": nll,
|
| "metadata": {
|
| "prompt_len": t,
|
| "generated_len": gen.shape[1],
|
| "temperature": temperature
|
| }
|
| }
|
|
|
| if include_top1:
|
| result.update({
|
| "top1_tokens": top1_tokens,
|
| "top1_token_probs": top1_token_probs,
|
| "top1_token_logprobs": top1_token_logprobs,
|
| "top1_matches": top1_matches
|
| })
|
|
|
|
|
| diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)]
|
| if include_top1:
|
| for i in range(b):
|
| for pos, (sample_tok, sample_p, t1_tok, t1_p, match) in enumerate(zip(
|
| [int(x) for x in gen[i].tolist()],
|
| token_probs[i],
|
| top1_tokens[i],
|
| top1_token_probs[i],
|
| top1_matches[i]
|
| )):
|
| if not match:
|
| diff_all[i].append({
|
| "pos": pos,
|
| "token": sample_tok,
|
| "token_str": _decode_token(tokenizer, sample_tok),
|
| "token_prob": sample_p,
|
| "top1_token": int(t1_tok),
|
| "top1_token_str": _decode_token(tokenizer, int(t1_tok)),
|
| "top1_prob": t1_p,
|
| "match": bool(match)
|
| })
|
| result["diff"] = diff_all
|
|
|
| result = _attach_metrics_to_result(result)
|
|
|
| if print_table:
|
| for i in range(b):
|
| print("="*110)
|
| print(f"Batch {i} (prompt_len={t})")
|
| print("-"*110)
|
| print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match")
|
| print("-"*110)
|
| cum_logp = 0.0
|
| for idx, (p, lp, sc) in enumerate(zip(token_probs[i], token_logprobs[i], token_scores[i])):
|
| cum_logp += lp
|
| tok_id = int(gen[i, idx].item()) if idx < gen.shape[1] else -1
|
| tok_display = _decode_token(tokenizer, tok_id)
|
| if include_top1:
|
| t1_id = top1_tokens[i][idx]
|
| t1_p = top1_token_probs[i][idx]
|
| match_mark = "*" if top1_matches[i][idx] else " "
|
| print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}){match_mark}")
|
| else:
|
| print(f"{idx:3d} | {tok_display:>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}")
|
| print("-"*110)
|
| print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}")
|
| m = result["metrics"]["per_sequence"][i]
|
| print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}")
|
| if result["diff"][i]:
|
| print("DIFF (sampled != top1) positions:")
|
| for d in result["diff"][i]:
|
| print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}")
|
| else:
|
| print("No diffs: sampled tokens matched top1 at every step.")
|
| print("="*110)
|
|
|
| return result
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def score_sequences(
    model,
    sequences: torch.Tensor,
    prompt_lens: Optional[torch.Tensor] = None,
    eos_token: Optional[int] = None,
    pad_value: Optional[int] = None,
    filter_logits_fn: Optional[Callable] = None,
    filter_kwargs: Optional[Dict[str, Any]] = None,
    tokenizer = None,
    print_table: bool = False,
    device: Optional[torch.device] = None,
    verbose: bool = False,
    include_top1: bool = True,
    **kwargs
) -> Dict[str, Any]:
    """
    Compute per-token and per-sequence likelihood statistics for given full sequences
    under an autoregressive model, optionally comparing each target token to the model's
    top-1 prediction and producing a diff of mismatches.

    The model is run once on `sequences[:, :-1]`; the probability and log-probability
    assigned to each actual next token (`sequences[:, 1:]`) is then extracted. Padding
    targets and targets located after the first EOS can be masked out of all per-token
    lists and sequence aggregates.

    Key behaviors
      - Runs under `torch.inference_mode()` (no gradients) and calls `model.eval()`.
      - Expects `sequences` shaped (batch, seq_len). Targets are positions 1..(seq_len-1);
        the first token of each sequence is context only.
      - If `filter_logits_fn` is provided it is applied to the model logits before softmax.
      - If `pad_value` is provided, positions whose target equals `pad_value` are masked out.
      - If `eos_token` is provided, every target located strictly after the first EOS of
        the full sequence is masked out (the EOS itself is still scored when it is a
        target). A sequence whose very first token is EOS therefore has no scored tokens.
      - If `include_top1` is True, top-1 ids/probabilities are computed and each scored
        position records whether the target matched the top-1 token.

    Parameters
      - model: Object exposing a `net` callable invoked as
        `model.net(tokens, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)`.
        The result must be a (batch, seq, vocab) logits tensor, or a tuple/list whose first
        element is that tensor.
      - sequences (torch.Tensor): Integer token tensor of shape (batch, seq_len).
      - prompt_lens (Optional[torch.Tensor]): Optional per-batch prompt lengths; only echoed
        into the metadata and table header (does not change scoring).
      - eos_token (Optional[int]): End-of-sequence token id (see masking rule above).
      - pad_value (Optional[int]): Padding token id; padded targets are not scored.
      - filter_logits_fn (Optional[Callable]): Called as `filter_logits_fn(logits, **filter_kwargs)`;
        must return logits of the same shape.
      - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments for `filter_logits_fn`.
      - tokenizer: Optional tokenizer handed to `_decode_token` for human-readable output.
      - print_table (bool): If True, prints a per-token table for every batch item.
      - device (Optional[torch.device]): Device to score on. Defaults to `sequences.device`.
      - verbose (bool): Currently unused; kept for interface compatibility.
      - include_top1 (bool): If True, also return top-1 statistics and populate `diff`.
      - **kwargs: Additional keyword arguments forwarded to `model.net`.

    Returns
      A dictionary with the following keys:
      - "tokens" (torch.Tensor): The input `sequences` as a CPU tensor.
      - "token_probs" / "token_logprobs" / "token_scores" (List[List[float]]): Per-batch lists
        for the scored targets (probability, log-probability in nats, negative log-probability).
      - "sequence_logprobs" (List[float]): Sum of log-probabilities over scored targets.
      - "sequence_probs" (List[float]) and "sequence_prob_display" (List[str]): As produced
        by `_safe_exp64` from the sequence log-probability.
      - "nll" (List[float]): Negative sequence log-probability.
      - "mask" (torch.BoolTensor): (batch, seq_len - 1) mask, True where the target was scored.
      - "diff" (List[List[Dict]]): For each batch item, one dict per scored position whose
        target differs from the top-1 token, with keys "pos" (index within the scored
        positions), "token", "token_str", "token_prob", "top1_token", "top1_token_str",
        "top1_prob", "match" (always False). Empty when `include_top1` is False.
      - "metadata" (dict): "prompt_len", "seq_len" and "scored_len_per_batch".
      - If `include_top1` is True: "top1_tokens", "top1_token_probs", "top1_token_logprobs",
        "top1_matches" (List[List[bool]]).
      - "metrics": whatever `_attach_metrics_to_result` adds to the result.

    Notes and caveats
      - If `seq_len < 2` nothing can be scored: a minimal result with empty per-token lists
        is returned without calling the model. That result carries no "metrics" entry; its
        metadata contains both "scored_len" (0, kept for backward compatibility) and
        "scored_len_per_batch".
      - Probabilities are clamped before the log to avoid -inf.
      - Raises ValueError if the logits are not 3-dimensional.

    Example (conceptual)
      >>> res = score_sequences(model, sequences, pad_value=0, eos_token=2, tokenizer=tok)
      >>> print(res["metrics"]["per_sequence"][0]["avg_logprob_per_token_nats"])
    """
    if filter_kwargs is None:
        filter_kwargs = {}
    if device is None:
        device = sequences.device

    if prompt_lens is None:
        prompt_len_meta = None
    elif isinstance(prompt_lens, torch.Tensor):
        prompt_len_meta = prompt_lens.tolist()
    else:
        prompt_len_meta = prompt_lens

    model.eval()
    # The decorator already provides inference mode, so no nested context is needed.
    sequences = sequences.to(device)
    b, L = sequences.shape

    if L < 2:
        # Every list is built separately so callers can mutate one entry without
        # silently changing the others.
        minimal: Dict[str, Any] = {
            "tokens": sequences.cpu(),
            "token_probs": [[] for _ in range(b)],
            "token_logprobs": [[] for _ in range(b)],
            "token_scores": [[] for _ in range(b)],
            "sequence_probs": [1.0 for _ in range(b)],
            "sequence_prob_display": [f"{1.0:.6e}" for _ in range(b)],
            "sequence_logprobs": [0.0 for _ in range(b)],
            "nll": [0.0 for _ in range(b)],
            "mask": torch.zeros((b, 0), dtype=torch.bool),
            "diff": [[] for _ in range(b)],
            "metadata": {
                "prompt_len": prompt_len_meta,
                "seq_len": L,
                "scored_len": 0,
                "scored_len_per_batch": [0 for _ in range(b)],
            },
        }
        if include_top1:
            minimal.update({
                "top1_tokens": [[] for _ in range(b)],
                "top1_token_probs": [[] for _ in range(b)],
                "top1_token_logprobs": [[] for _ in range(b)],
                "top1_matches": [[] for _ in range(b)],
            })
        return minimal

    inputs = sequences[:, :-1]
    targets = sequences[:, 1:]

    logits_out = model.net(inputs, return_intermediates=True, cache=None, seq_start_pos=None, **kwargs)
    logits = logits_out[0] if isinstance(logits_out, (tuple, list)) else logits_out

    if logits.dim() != 3:
        raise ValueError(f"Expected logits with shape (b, seq, vocab), got {logits.shape}")

    filtered_logits = logits if filter_logits_fn is None else filter_logits_fn(logits, **filter_kwargs)
    probs = F.softmax(filtered_logits, dim=-1)
    picked_probs = probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    picked_logprobs = torch.log(picked_probs.clamp_min(1e-45)).to(dtype=torch.float64)

    if include_top1:
        top1_ids = probs.argmax(dim=-1)
        top1_p = probs.gather(-1, top1_ids.unsqueeze(-1)).squeeze(-1)
        top1_lp = torch.log(top1_p.clamp_min(1e-45)).to(dtype=torch.float64)

    mask = torch.ones_like(picked_probs, dtype=torch.bool)
    if pad_value is not None:
        mask = mask & (targets != pad_value)

    if eos_token is not None:
        for i in range(b):
            eos_positions = (sequences[i] == eos_token).nonzero(as_tuple=False)
            if eos_positions.numel() > 0:
                first_eos = int(eos_positions[0].item())
                # Target index j is sequence position j + 1, so the positions after
                # the first EOS are exactly the target indices >= first_eos. This
                # also covers an EOS at position 0, where nothing may be scored.
                mask[i, first_eos:] = False

    token_probs: List[List[float]] = []
    token_logprobs: List[List[float]] = []
    token_scores: List[List[float]] = []
    sequence_logprobs: List[float] = []
    sequence_probs: List[float] = []
    sequence_prob_display: List[str] = []
    nll: List[float] = []

    top1_tokens: List[List[int]] = [[] for _ in range(b)]
    top1_token_probs: List[List[float]] = [[] for _ in range(b)]
    top1_token_logprobs: List[List[float]] = [[] for _ in range(b)]
    top1_matches: List[List[bool]] = [[] for _ in range(b)]

    diff_all: List[List[Dict[str, Any]]] = [[] for _ in range(b)]
    scored_targets: List[List[int]] = []

    for i in range(b):
        row_mask = mask[i]
        kept_probs = [float(x) for x in picked_probs[i][row_mask].cpu().tolist()]
        kept_logps = [float(x) for x in picked_logprobs[i][row_mask].cpu().tolist()]
        kept_targets = [int(x) for x in targets[i][row_mask].cpu().tolist()]

        token_probs.append(kept_probs)
        token_logprobs.append(kept_logps)
        token_scores.append([-lp for lp in kept_logps])
        scored_targets.append(kept_targets)

        if include_top1:
            kept_t1_ids = [int(x) for x in top1_ids[i][row_mask].cpu().tolist()]
            kept_t1_ps = [float(x) for x in top1_p[i][row_mask].cpu().tolist()]
            kept_t1_lps = [float(x) for x in top1_lp[i][row_mask].cpu().tolist()]
            top1_tokens[i] = kept_t1_ids
            top1_token_probs[i] = kept_t1_ps
            top1_token_logprobs[i] = kept_t1_lps
            # Real booleans, matching the declared List[List[bool]] type.
            top1_matches[i] = [tgt == t1 for tgt, t1 in zip(kept_targets, kept_t1_ids)]

            for pos_idx, (tgt, tgt_p, t1, t1_p, match) in enumerate(
                zip(kept_targets, kept_probs, kept_t1_ids, kept_t1_ps, top1_matches[i])
            ):
                if not match:
                    diff_all[i].append({
                        "pos": pos_idx,
                        "token": tgt,
                        "token_str": _decode_token(tokenizer, tgt),
                        "token_prob": tgt_p,
                        "top1_token": t1,
                        "top1_token_str": _decode_token(tokenizer, t1),
                        "top1_prob": t1_p,
                        "match": False,
                    })

        # Python floats are already float64; float() turns the int 0 of an empty sum into 0.0.
        seq_lp = float(sum(kept_logps))
        pnum, disp = _safe_exp64(seq_lp)
        sequence_logprobs.append(seq_lp)
        sequence_probs.append(pnum)
        sequence_prob_display.append(disp)
        nll.append(-seq_lp)

    result = {
        "tokens": sequences.cpu(),
        "token_probs": token_probs,
        "token_logprobs": token_logprobs,
        "token_scores": token_scores,
        "sequence_logprobs": sequence_logprobs,
        "sequence_probs": sequence_probs,
        "sequence_prob_display": sequence_prob_display,
        "nll": nll,
        "mask": mask.cpu(),
        "diff": diff_all,
        "metadata": {
            "prompt_len": prompt_len_meta,
            "seq_len": L,
            "scored_len_per_batch": [len(row) for row in token_probs],
        },
    }

    if include_top1:
        result.update({
            "top1_tokens": top1_tokens,
            "top1_token_probs": top1_token_probs,
            "top1_token_logprobs": top1_token_logprobs,
            "top1_matches": top1_matches,
        })

    result = _attach_metrics_to_result(result)

    if print_table:
        for i in range(b):
            print("=" * 120)
            header = f"Batch {i} (seq_len={L})"
            if prompt_lens is not None:
                header += f" prompt_len={int(prompt_lens[i].item()) if isinstance(prompt_lens[i], torch.Tensor) else prompt_lens[i]}"
            print(header)
            print("-" * 120)
            print(" idx | token | prob | logprob | cum_logp | token_nll | top1_token (p) | match")
            print("-" * 120)
            cum_logp = 0.0
            # Iterate the already-extracted per-token lists instead of pulling
            # single elements off the device for every row of the table.
            for pos_idx, (tok_id, p, lp) in enumerate(zip(scored_targets[i], token_probs[i], token_logprobs[i])):
                cum_logp += lp
                sc = -lp
                if include_top1:
                    t1_id = top1_tokens[i][pos_idx]
                    t1_p = top1_token_probs[i][pos_idx]
                    match = int(top1_matches[i][pos_idx])
                    print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f} | {_decode_token(tokenizer, t1_id):>12s} ({t1_p:5.3f}) | {match}")
                else:
                    print(f"{pos_idx:3d} | {_decode_token(tokenizer, tok_id):>12s} | {p:9.6f} | {lp:11.6f} | {cum_logp:12.6f} | {sc:10.6f}")
            print("-" * 120)
            print(f"Sequence logprob (nats): {result['sequence_logprobs'][i]:.6f} | Sequence prob: {result['sequence_prob_display'][i]} | NLL: {result['nll'][i]:.6f}")
            m = result["metrics"]["per_sequence"][i]
            print(f"Avg logprob/token: {m['avg_logprob_per_token_nats']:.6f} nats ({m['avg_logprob_per_token_bits']:.4f} bits) | Perplexity: {m['perplexity']:.6f}")
            if result["diff"][i]:
                print("DIFF (target != top1) positions:")
                for d in result["diff"][i]:
                    print(f" pos={d['pos']} token={d['token_str']}({d['token']}) p={d['token_prob']:.6f} | top1={d['top1_token_str']}({d['top1_token']}) p={d['top1_prob']:.6f}")
            else:
                print("No diffs: target tokens matched top1 at every scored position.")
            print("=" * 120)

    return result
|
|
|
|
|
|
|
|
|
|
|
| from datetime import datetime, timedelta
|
| from zoneinfo import ZoneInfo
|
|
|
| def calculate_eta(
|
| hours_until_done: float,
|
| *,
|
| tz: str = "America/Los_Angeles",
|
| now: datetime | None = None,
|
| return_dict: bool = False,
|
| verbose: bool = True,
|
| ):
|
|
|
| """
|
| Compute an ETA timestamp based on the current time (or a provided time)
|
| in a specified timezone.
|
|
|
| Parameters
|
| ----------
|
| hours_until_done : float
|
| Number of hours remaining until completion.
|
| tz : str, optional
|
| IANA timezone name (default: "America/Los_Angeles").
|
| now : datetime or None, optional
|
| If provided, use this datetime as the starting point.
|
| If None, the current time in the given timezone is used.
|
| return_dict : bool, optional
|
| If True, return a dictionary with ETA components.
|
| verbose : bool, optional
|
| If True, print a formatted ETA string.
|
|
|
| Returns
|
| -------
|
| datetime or dict
|
| ETA as a datetime object or a dictionary (if return_dict=True).
|
|
|
| Examples
|
| --------
|
|
|
| # Simple ETA 5.5 hours from now
|
| calculate_eta(5.5)
|
|
|
| # ETA using a custom starting time in Tokyo
|
| from datetime import datetime
|
| calculate_eta(
|
| 12,
|
| tz="Asia/Tokyo",
|
| now=datetime(2026, 1, 29, 8, 30),
|
| )
|
|
|
| # Get ETA as a dict without printing
|
| info = calculate_eta(3, verbose=False, return_dict=True)
|
| print(info["pretty"])
|
| """
|
|
|
|
|
| zone = ZoneInfo(tz)
|
|
|
|
|
| current_time = now.astimezone(zone) if now else datetime.now(zone)
|
|
|
|
|
| eta = current_time + timedelta(hours=hours_until_done)
|
|
|
|
|
| pretty = eta.strftime("ETA: %A, %B %d %Y @ %H:%M")
|
|
|
| if verbose:
|
| print(pretty)
|
|
|
| if return_dict:
|
| return {
|
| "eta_datetime": eta,
|
| "year": eta.year,
|
| "month": eta.month,
|
| "day": eta.day,
|
| "hour": eta.hour,
|
| "minute": eta.minute,
|
| "second": eta.second,
|
| "timezone": tz,
|
| "pretty": pretty,
|
| }
|
|
|
| return eta
|
|
|
| def calculate_training_run_eta(
|
| num_epochs: int,
|
| num_steps_per_epoch: int,
|
| sec_per_iter: float,
|
| *,
|
| cost_per_hr: float = 0.0,
|
| tz: str = "America/Los_Angeles",
|
| now: datetime | None = None,
|
| return_dict: bool = False,
|
| verbose: bool = True,
|
| ):
|
| """
|
| Compute ETA and cost for a full training run based on:
|
| - number of epochs
|
| - number of steps per epoch
|
| - seconds per iteration
|
| - optional cost per hour of compute
|
|
|
| Prints:
|
| - start time
|
| - ETA timestamp
|
| - per-epoch runtime (h/m/s)
|
| - total runtime (h/m/s)
|
| - cost per epoch
|
| - total run cost
|
|
|
| Returns:
|
| datetime or dict (if return_dict=True)
|
|
|
| Examples:
|
|
|
| # 2 epochs, 7770 steps each, 15.07 sec/iter, $5.3 per/hr
|
| calculate_training_run_eta(
|
| num_epochs=2,
|
| num_steps_per_epoch=7771,
|
| cost_per_hr=5.3,
|
| sec_per_iter=15.07,
|
| )
|
|
|
| # Get structured info without printing
|
| info = calculate_training_run_eta(
|
| 3, 1000, 0.5,
|
| verbose=False,
|
| return_dict=True
|
| )
|
| print(info["eta_str"])
|
| """
|
|
|
| zone = ZoneInfo(tz)
|
| start_time = now.astimezone(zone) if now else datetime.now(zone)
|
|
|
|
|
| total_iters = num_epochs * num_steps_per_epoch
|
| total_seconds = total_iters * sec_per_iter
|
| epoch_seconds = num_steps_per_epoch * sec_per_iter
|
|
|
| eta = start_time + timedelta(seconds=total_seconds)
|
|
|
|
|
| def fmt(seconds: float) -> str:
|
| seconds = int(seconds)
|
| h = seconds // 3600
|
| m = (seconds % 3600) // 60
|
| s = seconds % 60
|
| return f"{h}h {m}m {s}s"
|
|
|
|
|
| total_hours = total_seconds / 3600
|
| epoch_hours = epoch_seconds / 3600
|
|
|
| cost_epoch = epoch_hours * cost_per_hr
|
| cost_total = total_hours * cost_per_hr
|
|
|
|
|
| start_str = start_time.strftime("%A, %B %d %Y @ %H:%M")
|
| eta_str = eta.strftime("%A, %B %d %Y @ %H:%M")
|
|
|
| if verbose:
|
| print(f"Start Time: {start_str}")
|
| print(f"ETA: {eta_str}")
|
| print(f"Per Epoch: {fmt(epoch_seconds)}")
|
| print(f"Total Run: {fmt(total_seconds)}")
|
| print(f"Cost/Epoch: ${cost_epoch:,.2f}")
|
| print(f"Cost/Run: ${cost_total:,.2f}")
|
|
|
| if return_dict:
|
| return {
|
| "start_time": start_time,
|
| "eta": eta,
|
| "start_str": start_str,
|
| "eta_str": eta_str,
|
| "epoch_seconds": epoch_seconds,
|
| "total_seconds": total_seconds,
|
| "epoch_runtime_hms": fmt(epoch_seconds),
|
| "total_runtime_hms": fmt(total_seconds),
|
| "epoch_hours": epoch_hours,
|
| "total_hours": total_hours,
|
| "cost_per_hr": cost_per_hr,
|
| "cost_epoch": cost_epoch,
|
| "cost_total": cost_total,
|
| "timezone": tz,
|
| }
|
|
|
| return eta
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from typing import Optional, Dict, List, Union, Set
|
| import numpy as np
|
| from tqdm import tqdm
|
| from contextlib import nullcontext
|
|
|
|
|
|
|
|
|
|
|
def get_embeddings(
    model,
    inputs: torch.Tensor,
    pooling: str = 'mean',
    mask: Optional[torch.Tensor] = None,
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    layer_index: int = -1,
    normalize: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
    pad_idx: int = 18819,
    use_amp: bool = True,
    verbose: bool = True,
    _max_concat_tokens: Optional[int] = None,
) -> np.ndarray:
    """
    Get embeddings for a single batch of inputs.

    The network (`model.net` when present, otherwise `model`) is run with
    `return_intermediates=True`; hidden states are taken from the returned
    intermediates via `_extract_hidden_states`. If that fails for any
    reason, the function falls back to `_get_token_embeddings`. Hidden
    states at padding positions are zeroed before pooling.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model
    inputs : torch.Tensor
        Input token sequences of shape (batch, seq_len); a 1-D tensor is
        treated as a single sequence
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    mask : Optional[torch.Tensor]
        Boolean mask, True for valid tokens, passed to the network.
        Auto-generated from `pad_idx` if None. A mask that is not 2-D but
        has as many elements as `inputs` is reshaped to `inputs.shape`
    token_ids : Optional[List[int]]
        Token IDs to include. Works independently or with token_weights.
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping:
        - 'mean': weights for weighted average
        - 'concat': priority scores for selection when limiting count
        - If provided WITHOUT token_ids: keys become the filter
        - If provided WITH token_ids: only tokens in BOTH are used (intersection)
    layer_index : int
        Which layer's hidden states to use (-1 for last)
    normalize : bool
        L2-normalize output embeddings
    device : Optional[torch.device]
        Device for inference (a device string is accepted as well).
        Defaults to the device of the model's first parameter
    dtype : torch.dtype
        Dtype for autocast
    pad_idx : int
        Padding token index
    use_amp : bool
        Use automatic mixed precision (only applied on CUDA devices)
    verbose : bool
        Print warnings and info
    _max_concat_tokens : Optional[int]
        Internal: pre-computed max tokens for concat mode

    Returns
    -------
    np.ndarray
        Embeddings array:
        - 'mean': (batch, dim)
        - 'concat': (batch, max_tokens * dim)

    Raises
    ------
    ValueError
        If `inputs` is not 1-D or 2-D, or `pooling` is not 'mean'/'concat'.
    """
    model.eval()

    if device is None:
        device = next(model.parameters()).device
    else:
        # Accept "cuda:0"-style strings too; `.type` is read below.
        device = torch.device(device)

    inputs = inputs.to(device)
    if inputs.ndim == 1:
        inputs = inputs.unsqueeze(0)
    if inputs.ndim != 2:
        raise ValueError(
            f"inputs must have shape (batch, seq_len) or (seq_len,), got {tuple(inputs.shape)}"
        )

    if mask is None:
        mask = (inputs != pad_idx)
    else:
        mask = mask.to(device)

    if mask.dtype != torch.bool:
        mask = mask.bool()

    if mask.ndim != 2:
        if mask.numel() == inputs.numel():
            # A bare squeeze() would also drop a batch axis of size 1, and it
            # leaves a 1-D mask 1-D even though inputs were expanded to 2-D.
            mask = mask.reshape(inputs.shape)
        else:
            mask = mask.squeeze()

    net_model = model.net if hasattr(model, 'net') else model

    if use_amp and device.type == 'cuda':
        autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
    else:
        autocast_ctx = nullcontext()

    try:
        with torch.no_grad(), autocast_ctx:
            output = net_model(
                inputs,
                mask=mask,
                return_intermediates=True,
            )

            if isinstance(output, tuple) and len(output) == 2:
                _, intermediates = output
            else:
                intermediates = None

            hidden = _extract_hidden_states(intermediates, layer_index, verbose=verbose)

            if hidden is None:
                raise ValueError("Could not extract hidden states")

    except Exception as e:
        # Deliberate best-effort: any failure above (including the ValueError
        # raised just before) falls back to plain token embeddings.
        if verbose:
            print(f"Warning: Could not extract hidden states, using token embeddings. Error: {e}")
        hidden = _get_token_embeddings(net_model, inputs)

    # NOTE(review): pooling always uses the pad-derived mask; a caller-supplied
    # `mask` only reaches the network call above — confirm this is intended.
    seq_mask = (inputs != pad_idx)
    hidden = hidden * seq_mask.unsqueeze(-1).float()

    effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)

    if pooling == 'mean':
        emb = _mean_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights, verbose=verbose)
    elif pooling == 'concat':
        emb = _concat_pooling(hidden, inputs, seq_mask, effective_token_ids, token_weights,
                              max_tokens=_max_concat_tokens, verbose=verbose)
    else:
        raise ValueError(f"Unknown pooling strategy: {pooling}. Use 'mean' or 'concat'")

    if normalize:
        emb = F.normalize(emb, p=2, dim=-1)

    return emb.cpu().detach().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_embeddings_batched(
    model,
    sequences: List[List[int]],
    pooling: str = 'mean',
    token_ids: Optional[List[int]] = None,
    token_weights: Optional[Dict[int, float]] = None,
    max_seq_len: int = 8192,
    pad_idx: int = 18819,
    batch_size: int = 8,
    use_amp: bool = True,
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = True,
    show_progress: bool = True,
    normalize: bool = False,
) -> np.ndarray:
    """
    Embed many token sequences by running `get_embeddings` batch by batch.

    Each batch is truncated to `max_seq_len`, right-padded with `pad_idx`
    to the longest sequence in that batch, and handed to `get_embeddings`.
    The per-batch arrays are concatenated along axis 0.

    Parameters
    ----------
    model : AutoregressiveWrapper
        Your trained transformer model (put into eval mode here)
    sequences : List[List[int]]
        List of token sequences (list of lists)
    pooling : str
        Pooling strategy: 'mean' or 'concat'
    token_ids : Optional[List[int]]
        Token IDs to include
    token_weights : Optional[Dict[int, float]]
        Token ID to weight/priority mapping
    max_seq_len : int
        Maximum sequence length; longer sequences are truncated
    pad_idx : int
        Padding token index
    batch_size : int
        Batch size for processing (must be >= 1)
    use_amp : bool
        Use automatic mixed precision
    dtype : torch.dtype
        Dtype for autocast
    verbose : bool
        Print messages (per-batch messages are only enabled for batch 0)
    show_progress : bool
        Show tqdm progress bar (only when `verbose` is also True)
    normalize : bool
        L2-normalize output embeddings

    Returns
    -------
    np.ndarray
        Embeddings array with consistent dimensions. An empty `sequences`
        list yields an empty float32 array of shape (0, 0).

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1; got {batch_size}")

    model.eval()

    num_sequences = len(sequences)

    if verbose:
        print(f"Processing {num_sequences} sequences in batches of {batch_size}...")

    # Nothing to embed: np.concatenate would raise on an empty list.
    if num_sequences == 0:
        return np.empty((0, 0), dtype=np.float32)

    # For 'concat' pooling every batch must produce the same width, so the
    # token budget is computed once over the whole dataset up front.
    max_concat_tokens = None
    if pooling == 'concat':
        effective_token_ids = _compute_effective_token_ids(token_ids, token_weights)
        max_concat_tokens = _scan_max_matching_tokens(sequences, effective_token_ids, pad_idx, max_seq_len)
        if max_concat_tokens == 0:
            # The placeholder must apply regardless of `verbose`; only the
            # message is conditional.
            if verbose:
                print("Warning: No sequences contain matching token IDs, using 1 token placeholder")
            max_concat_tokens = 1
        elif verbose:
            print(f"Auto-detected max matching tokens: {max_concat_tokens}")

    batch_starts = range(0, num_sequences, batch_size)
    if show_progress and verbose:
        batch_iterator = tqdm(batch_starts, desc="Extracting embeddings")
    else:
        batch_iterator = batch_starts

    all_embeddings = []

    for batch_idx, start_idx in enumerate(batch_iterator):
        batch_sequences = sequences[start_idx:start_idx + batch_size]
        max_len_in_batch = min(max_seq_len, max(len(seq) for seq in batch_sequences))

        padded_batch = []
        for seq in batch_sequences:
            trimmed = list(seq[:max_len_in_batch])
            trimmed.extend([pad_idx] * (max_len_in_batch - len(trimmed)))
            padded_batch.append(trimmed)

        batch_inputs = torch.tensor(padded_batch, dtype=torch.long)

        batch_embeddings = get_embeddings(
            model,
            batch_inputs,
            pooling=pooling,
            token_ids=token_ids,
            token_weights=token_weights,
            pad_idx=pad_idx,
            use_amp=use_amp,
            dtype=dtype,
            verbose=verbose and batch_idx == 0,
            normalize=normalize,
            _max_concat_tokens=max_concat_tokens,
        )

        all_embeddings.append(batch_embeddings)

    final_embeddings = np.concatenate(all_embeddings, axis=0)

    if verbose:
        print(f"Final embeddings shape: {final_embeddings.shape}")

    return final_embeddings
|
|
|
|
|
|
|
|
|
|
|
| def _compute_effective_token_ids(token_ids: Optional[List[int]], token_weights: Optional[Dict[int, float]]) -> Optional[Set[int]]:
|
| """
|
| Compute effective token IDs with INTUITIVE logic:
|
|
|
| - token_ids=None, token_weights=None → None (all valid tokens)
|
| - token_ids=[...], token_weights=None → token_ids
|
| - token_ids=None, token_weights={...} → keys from token_weights
|
| - token_ids=[...], token_weights={...} → INTERSECTION (only tokens in BOTH)
|
|
|
| This ensures token_weights acts as a filter when provided, not just weights.
|
| """
|
| if token_ids is None and token_weights is None:
|
| return None
|
|
|
| token_ids_set = set(token_ids) if token_ids is not None else None
|
| weights_keys_set = set(token_weights.keys()) if token_weights is not None else None
|
|
|
| if token_ids_set is None and weights_keys_set is not None:
|
|
|
| return weights_keys_set
|
| elif token_ids_set is not None and weights_keys_set is None:
|
|
|
| return token_ids_set
|
| elif token_ids_set is not None and weights_keys_set is not None:
|
|
|
|
|
| intersection = token_ids_set & weights_keys_set
|
| if len(intersection) == 0:
|
|
|
| print(f"Warning: token_ids and token_weights have no overlap. Using token_ids only.")
|
| return token_ids_set
|
| return intersection
|
| else:
|
| return None
|
|
|
| def _scan_max_matching_tokens(sequences: List[List[int]],
|
| token_ids: Optional[Set[int]],
|
| pad_idx: int,
|
| max_seq_len: int) -> int:
|
| """
|
| Scan all sequences to find maximum number of tokens matching token_ids.
|
| """
|
| if token_ids is None:
|
| return max(min(len(seq), max_seq_len) for seq in sequences) if sequences else 0
|
|
|
| max_count = 0
|
| for seq in sequences:
|
| truncated = seq[:max_seq_len]
|
| count = sum(1 for tok in truncated if tok in token_ids and tok != pad_idx)
|
| max_count = max(max_count, count)
|
|
|
| return max_count
|
|
|
| def _extract_hidden_states(intermediates, layer_index: int = -1, verbose: bool = True):
|
| """Extract hidden states from LayerIntermediates object."""
|
| if intermediates is None:
|
| if verbose:
|
| print("Warning: intermediates is None")
|
| return None
|
|
|
| if hasattr(intermediates, 'layer_hiddens') and intermediates.layer_hiddens is not None:
|
| if len(intermediates.layer_hiddens) > 0:
|
| return intermediates.layer_hiddens[layer_index]
|
|
|
| if hasattr(intermediates, 'hiddens') and intermediates.hiddens is not None:
|
| if len(intermediates.hiddens) > 0:
|
| return intermediates.hiddens[layer_index]
|
|
|
| if hasattr(intermediates, 'attn_intermediates') and intermediates.attn_intermediates is not None:
|
| if len(intermediates.attn_intermediates) > 0:
|
| attn_int = intermediates.attn_intermediates[layer_index]
|
| if hasattr(attn_int, 'values') and attn_int.values is not None:
|
| return attn_int.values
|
|
|
| if verbose:
|
| print("Warning: Could not find layer_hiddens in intermediates")
|
|
|
| return None
|
|
|
|
|
| def _get_token_embeddings(net_model, inputs: torch.Tensor):
|
| """Get token embeddings directly from embedding layer."""
|
| if hasattr(net_model, 'token_emb'):
|
| if hasattr(net_model.token_emb, 'emb'):
|
| return net_model.token_emb.emb(inputs)
|
| else:
|
| return net_model.token_emb(inputs)
|
| elif hasattr(net_model, 'emb'):
|
| return net_model.emb(inputs)
|
| else:
|
| raise ValueError("Could not find embedding layer in model")
|
|
|
| def _mean_pooling(
|
| hidden: torch.Tensor,
|
| inputs: torch.Tensor,
|
| mask: torch.Tensor,
|
| token_ids: Optional[Set[int]],
|
| token_weights: Optional[Dict[int, float]],
|
| verbose: bool = True
|
| ) -> torch.Tensor:
|
| """
|
| Mean pooling with token ID filtering and weighted averaging.
|
| """
|
| batch_size, seq_len, dim = hidden.shape
|
| device = hidden.device
|
|
|
| if mask.ndim > 2:
|
| mask = mask.squeeze()
|
|
|
| effective_mask = mask.clone()
|
|
|
| if token_ids is not None:
|
| token_mask = torch.zeros_like(mask, dtype=torch.bool, device=device)
|
| for tid in token_ids:
|
| token_mask = token_mask | (inputs == tid)
|
| effective_mask = effective_mask & token_mask
|
|
|
| if verbose and effective_mask.sum() == 0:
|
| print(f"Warning: No tokens match filter, falling back to all valid tokens")
|
| effective_mask = mask
|
|
|
| if token_weights is not None:
|
| weights = torch.zeros_like(effective_mask, dtype=torch.float32, device=device)
|
|
|
| for token_id, weight in token_weights.items():
|
| id_mask = (inputs == token_id) & effective_mask
|
| weights = weights.masked_fill(id_mask, float(weight))
|
|
|
| weights = weights.masked_fill(effective_mask & (weights == 0), 1.0)
|
|
|
| weighted_hidden = hidden * weights.unsqueeze(-1)
|
| sum_weighted = weighted_hidden.sum(dim=1)
|
| sum_weights = weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
|
| return sum_weighted / sum_weights
|
| else:
|
| masked_hidden = hidden * effective_mask.unsqueeze(-1).float()
|
| sum_hidden = masked_hidden.sum(dim=1)
|
| count = effective_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
|
| return sum_hidden / count
|
|
|
| def _concat_pooling(
|
| hidden: torch.Tensor,
|
| inputs: torch.Tensor,
|
| mask: torch.Tensor,
|
| token_ids: Optional[Set[int]],
|
| token_weights: Optional[Dict[int, float]],
|
| max_tokens: Optional[int],
|
| verbose: bool = True
|
| ) -> torch.Tensor:
|
| """
|
| Concat pooling with token ID filtering and weight-based priority selection.
|
| """
|
| batch_size, seq_len, dim = hidden.shape
|
| device = hidden.device
|
|
|
| if max_tokens is None:
|
| max_tokens = 1
|
|
|
| output_dim = max_tokens * dim
|
|
|
| all_token_embs = []
|
|
|
| for i in range(batch_size):
|
| seq_mask = mask[i]
|
| seq_inputs = inputs[i]
|
|
|
| if token_ids is not None:
|
| matching_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)
|
| for tid in token_ids:
|
| matching_mask = matching_mask | ((seq_inputs == tid) & seq_mask)
|
| valid_indices = matching_mask.nonzero(as_tuple=True)[0]
|
| else:
|
| valid_indices = seq_mask.nonzero(as_tuple=True)[0]
|
|
|
| if len(valid_indices) == 0:
|
| emb = torch.zeros(dim, device=device)
|
| emb = F.pad(emb, (0, output_dim - dim))
|
| all_token_embs.append(emb)
|
| continue
|
|
|
| matching_embs = hidden[i, valid_indices, :]
|
|
|
| if token_weights is not None and len(valid_indices) > max_tokens:
|
| weights_list = []
|
| for idx in valid_indices:
|
| tok_id = seq_inputs[idx].item()
|
| weights_list.append(token_weights.get(tok_id, 1.0))
|
|
|
| sorted_pairs = sorted(zip(range(len(valid_indices)), weights_list),
|
| key=lambda x: x[1], reverse=True)
|
| top_indices = [valid_indices[p[0]] for p in sorted_pairs[:max_tokens]]
|
| matching_embs = hidden[i, torch.tensor(top_indices, device=device), :]
|
| elif len(valid_indices) > max_tokens:
|
| matching_embs = matching_embs[:max_tokens]
|
|
|
| if len(valid_indices) < max_tokens:
|
| padding_needed = max_tokens - len(valid_indices)
|
| padding = torch.zeros(padding_needed, dim, device=device)
|
| matching_embs = torch.cat([matching_embs, padding], dim=0)
|
|
|
| emb = matching_embs.reshape(-1)
|
| all_token_embs.append(emb)
|
|
|
| return torch.stack(all_token_embs, dim=0)
|
|
|
|
|
|
|
|
|
|
|
def get_enc_embeddings(
    model,
    sequences: List[List[int]],
    seq_len: Optional[int] = 3072,
    seq_pad_idx: int = 385,
    batch_size: int = 64,
    save_every_num_batches: int = -1,
    save_file_path: str = "saved_embeddings.npy",
    device: Optional[torch.device] = None,
    normalize: bool = False,
    pooling: str = "auto",
    token_type_weights: Optional[Tuple[float, float, float]] = None,
    use_bfloat16: bool = True,
    return_dtype: str = "float32",
    return_numpy: bool = False,
    verbose: bool = True,
    show_progress_bar: bool = True
) -> Union[Tensor, np.ndarray]:
    """
    Compute one embedding per token sequence with optional autocast, pooling,
    normalization and periodic saving.

    Sequences are processed in batches: each batch is padded/truncated by
    `pad_and_mask`, run through `model(x, return_embeddings=True, mask=mask)`
    under `torch.inference_mode()`, pooled if the model returned per-token
    embeddings, optionally L2-normalized, and collected on the CPU.

    The model output must be either
        - a 2-D tensor `(B, D)` of already pooled embeddings (used as is), or
        - a 3-D tensor `(B, L, D)` of per-token embeddings (pooled per `pooling`).

    Args:
        model (torch.nn.Module):
            Model used to compute embeddings. It is moved to `device` (or kept
            on its current parameter device when `device` is None) and put
            into `eval()` mode.
        sequences (List[List[int]]):
            Token id sequences. An empty list returns an empty `(0, 0)` result.
        seq_len (Optional[int], default=3072):
            Upper bound on sequence length passed to `pad_and_mask`. If None,
            the maximum sequence length in the current batch is used.
        seq_pad_idx (int, default=385):
            Token id used for padding positions.
        batch_size (int, default=64):
            Number of sequences per forward pass. Must be >= 1.
        save_every_num_batches (int, default=-1):
            If > 0, all embeddings collected so far are written to
            `save_file_path` every `save_every_num_batches` batches.
        save_file_path (str, default="saved_embeddings.npy"):
            Target path for `np.save` when periodic saving is enabled.
        device (Optional[torch.device], default=None):
            Device for the model and input tensors. A device string such as
            "cuda" is accepted as well.
        normalize (bool, default=False):
            If True, L2-normalize each embedding (performed in float32).
        pooling (str, default="auto"):
            Pooling for per-token outputs:
                - "auto" or "mean": masked mean pooling.
                - "weighted_mean": token-type weighted mean using `token_type_weights`.
            Any other value raises `ValueError`.
        token_type_weights (Optional[Tuple[float, float, float]], default=None):
            `(onset_w, duration_w, pitch_w)` forwarded to
            `masked_weighted_mean_pool` when `pooling="weighted_mean"`.
        use_bfloat16 (bool, default=True):
            If True, try bfloat16 autocast first. In every case the function
            falls back to the device's default autocast, and finally to no
            autocast, when the context cannot be created.
        return_dtype (str, default="float32"):
            "float32" or "float16"; the cast to float16 happens after
            normalization, just before a batch is collected.
        return_numpy (bool, default=False):
            Return a NumPy array instead of a CPU `torch.Tensor`.
        verbose (bool, default=True):
            Emit diagnostic messages via `tqdm.write` (also forwarded to the
            padding and pooling helpers).
        show_progress_bar (bool, default=True):
            Display a tqdm progress bar over batches.

    Returns:
        Union[torch.Tensor, numpy.ndarray]:
            Array of shape `(N, D)` in the requested dtype, where `N` is the
            number of input sequences and `D` the embedding size produced by
            the model.

    Raises:
        AssertionError:
            If `return_dtype` is not "float32" or "float16".
        RuntimeError:
            If the model returns `None`.
        ValueError:
            If `batch_size` is < 1, if the model output is neither 2-D nor
            3-D, or if `pooling` is unsupported.

    Notes:
        - Periodic saving re-concatenates everything collected so far; save
          failures are reported (when `verbose`) but never abort processing.

    Example:
        >>> embs = get_enc_embeddings(model, sequences, seq_len=1024, batch_size=32, pooling="mean",
        ...                           normalize=True, return_dtype="float32", return_numpy=False)
    """
    assert return_dtype in ("float32", "float16"), "return_dtype must be 'float32' or 'float16'"

    # A non-positive batch size would otherwise silently produce no batches.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1; got {batch_size}")

    # torch.device() also accepts strings; a plain str has no `.type`, which
    # previously disabled autocast without any message.
    model_device = next(model.parameters()).device if device is None else torch.device(device)
    model.to(model_device)
    model.eval()

    out_torch_dtype = torch.float16 if return_dtype == "float16" else torch.float32
    total_batches = math.ceil(len(sequences) / batch_size)

    if verbose:
        tqdm.write(
            f"[get_embeddings_bf16] sequences={len(sequences)}, batch_size={batch_size}, "
            f"batches={total_batches}, device={model_device}, seq_len={seq_len}, pooling={pooling}"
        )

    if not sequences:
        empty = torch.empty((0, 0), dtype=out_torch_dtype)
        if verbose:
            tqdm.write("[get_embeddings_bf16] no embeddings were produced; returning empty tensor")
        return empty.numpy() if return_numpy else empty

    # Autocast preference: bfloat16 (if requested) -> device default -> none.
    autocast_options = ([{"dtype": torch.bfloat16}] if use_bfloat16 else []) + [{}]
    autocast_ctx = None
    for options in autocast_options:
        try:
            autocast_ctx = torch.amp.autocast(device_type=model_device.type, **options)
            break
        except Exception:
            # Deliberate best effort: unsupported device/dtype combinations
            # simply fall through to the next option.
            continue

    all_embs: List[Tensor] = []

    with torch.inference_mode():
        pbar = tqdm(
            range(0, len(sequences), batch_size),
            disable=not show_progress_bar,
            total=total_batches,
            desc="Embedding batches",
        )
        for batch_idx, start in enumerate(pbar):
            batch_seqs = sequences[start : start + batch_size]
            x, mask = pad_and_mask(batch_seqs, pad_idx=seq_pad_idx, seq_len=seq_len, device=model_device, verbose=verbose)

            if autocast_ctx is not None:
                with autocast_ctx:
                    out = model(x, return_embeddings=True, mask=mask)
            else:
                out = model(x, return_embeddings=True, mask=mask)

            if out is None:
                raise RuntimeError("model returned None for embeddings. Check forward flags.")

            if out.dim() == 2:
                # Already pooled by the model.
                emb = out
            elif out.dim() == 3:
                if pooling in ("mean", "auto"):
                    emb = masked_mean_pool(out, mask, dim=1, verbose=verbose)
                elif pooling == "weighted_mean":
                    emb = masked_weighted_mean_pool(out, mask, token_ids=x, token_type_weights=token_type_weights, dim=1, verbose=verbose)
                else:
                    raise ValueError(f"unsupported pooling: {pooling}")
            else:
                raise ValueError(f"unexpected embedding tensor shape: {out.shape}")

            # Normalize in float32 for numerical stability, downcast afterwards.
            if emb.dtype != torch.float32:
                emb = emb.float()

            if normalize:
                emb = F.normalize(emb, p=2, dim=-1)

            if return_dtype == "float16":
                emb = emb.half()

            all_embs.append(emb.cpu())

            if verbose:
                pbar.set_postfix({"batch": batch_idx + 1, "emb_shape": f"{emb.shape}", "dtype": str(emb.dtype)})

            if save_every_num_batches > 0 and (batch_idx + 1) % save_every_num_batches == 0:
                try:
                    concatenated = torch.cat(all_embs, dim=0).numpy()
                    np.save(save_file_path, concatenated)
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] saved {concatenated.shape[0]} embeddings to {save_file_path}")
                except Exception as e:
                    # A failed checkpoint must not abort the embedding run.
                    if verbose:
                        tqdm.write(f"[get_embeddings_bf16] warning: failed to save embeddings: {e}")

    result = torch.cat(all_embs, dim=0)

    if verbose:
        tqdm.write(f"[get_embeddings_bf16] finished: total_embeddings={result.shape[0]}, dim={result.shape[1]}, dtype={result.dtype}")

    if return_numpy:
        return result.numpy()

    return result
|
|
|
|
|
|
|
def masked_mean_pool(
    token_embeddings: Tensor,
    mask: Tensor,
    dim: int = 1,
    eps: float = 1e-9,
    verbose: bool = True,
) -> Tensor:
    """
    Compute a masked mean pooling over a specified dimension.

    The mean of `token_embeddings` is taken along `dim`, ignoring positions
    where `mask` is False. For float16/bfloat16 inputs the sum and the token
    count are accumulated in float32, so long sequences do not lose precision
    and `eps` does not underflow to zero; the result is cast back to the
    input dtype.

    Args:
        token_embeddings: Tensor of shape (B, L, D) where `dim` indexes the
            sequence length. Embeddings dtype can be float16/float32/bfloat16.
        mask: Tensor of shape (B, L); True (non-zero) marks valid tokens,
            False marks padding.
        dim: Dimension along which to pool (default: 1, the sequence length).
        eps: Lower bound on the token count, avoiding division by zero when a
            row has no valid tokens (such rows pool to zeros).
        verbose: If True, writes a short summary via `tqdm.write`.

    Returns:
        Tensor of pooled embeddings with the sequence dimension removed,
        typically shape (B, D). The returned dtype matches
        `token_embeddings.dtype`.
    """
    out_dtype = token_embeddings.dtype
    low_precision = out_dtype in (torch.float16, torch.bfloat16)
    acc_dtype = torch.float32 if low_precision else out_dtype

    # Multiplying by a 0/1 mask is exact in any float dtype; only the
    # reduction needs the wider accumulator.
    mask_f = mask.to(out_dtype)
    summed = (token_embeddings * mask_f.unsqueeze(-1)).sum(dim=dim, dtype=acc_dtype)
    counts = mask.to(acc_dtype).sum(dim=dim).clamp_min(eps).unsqueeze(-1)
    pooled = (summed / counts).to(out_dtype)

    if verbose:
        valid_counts = counts.squeeze(-1)
        tqdm.write(
            f"[masked_mean_pool] pooled shape={pooled.shape}, "
            f"counts min={valid_counts.min().item():.3f}, max={valid_counts.max().item():.3f}"
        )

    return pooled
|
|
|
|
|
|
|
def masked_weighted_mean_pool(
    token_embs: Tensor,
    valid_mask: Tensor,
    token_ids: Optional[Tensor] = None,
    token_type_weights: Optional[Tuple[float, float, float]] = None,
    dim: int = 1,
    verbose: bool = False,
) -> Tensor:
    """
    Weighted mean pooling across tokens, weighted by token type.

    Token types are inferred from `token_ids` by range:
        - onset:    token_id in [0, 127]
        - duration: token_id in [128, 255]
        - pitch:    token_id in [256, 383]
    Valid tokens outside these ranges keep weight 1.0; masked-out tokens get
    weight 0. Each token embedding is multiplied by its weight and the sum is
    divided by the sum of weights per sequence (floored at 1e-6).

    For float16/bfloat16 inputs both sums are accumulated in float32 and the
    result is cast back to the input dtype.

    Args:
        token_embs: Per-token embeddings of shape (B, L, D).
        valid_mask: Mask of shape (B, L); True (non-zero) marks real tokens.
        token_ids: Token ids of shape (B, L). If None, the function falls back
            to `masked_mean_pool`.
        token_type_weights: (onset_w, duration_w, pitch_w). If None, defaults
            to (1.0, 1.0, 1.0).
        dim: Sequence dimension to reduce. The weights are laid out as (B, L),
            so only the default `dim=1` is meaningful.
        verbose: If True, reports the fallback via `tqdm.write`.

    Returns:
        Pooled embeddings of shape (B, D) in the dtype of `token_embs`.
    """
    if token_ids is None:
        if verbose:
            tqdm.write("[masked_weighted_mean_pool] token_ids is None, falling back to masked_mean_pool")
        return masked_mean_pool(token_embs, valid_mask, dim=dim, verbose=verbose)

    if token_type_weights is None:
        onset_w, duration_w, pitch_w = 1.0, 1.0, 1.0
    else:
        onset_w, duration_w, pitch_w = token_type_weights

    out_dtype = token_embs.dtype
    low_precision = out_dtype in (torch.float16, torch.bfloat16)
    acc_dtype = torch.float32 if low_precision else out_dtype

    valid = valid_mask.to(torch.bool)

    # Weight 1.0 on every valid token, 0.0 on padding; type weights that
    # differ from 1.0 then overwrite their id range.
    w = valid.to(out_dtype)
    type_ranges = ((0, 128, onset_w), (128, 256, duration_w), (256, 384, pitch_w))
    for low, high, type_w in type_ranges:
        if type_w != 1.0:
            in_range = (token_ids >= low) & (token_ids < high) & valid
            w = w.masked_fill(in_range, float(type_w))

    # Reduce numerator and denominator over the same axis, in the wide dtype.
    denom = w.sum(dim=dim, keepdim=True, dtype=acc_dtype).clamp(min=1e-6)
    summed = (token_embs * w.unsqueeze(-1)).sum(dim=dim, dtype=acc_dtype)

    return (summed / denom).to(out_dtype)
|
|
|
|
|
|
|
def pad_and_mask(
    sequences: List[List[int]],
    pad_idx: int = 385,
    seq_len: Optional[int] = None,
    device: Optional[torch.device] = None,
    verbose: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Turn variable-length token sequences into a padded batch plus validity mask.

    The common length T is the longest sequence in the batch, capped at
    `seq_len` when that is given (i.e. `min(seq_len, batch_max)`). Longer
    sequences are truncated to T, shorter ones are right-padded with
    `pad_idx`.

    Args:
        sequences: List of token id sequences (each a list of ints).
        pad_idx: Integer token id written into padding positions.
        seq_len: Optional cap on the batch length; None means no cap.
        device: Device on which the returned tensors are created
            (default device if None).
        verbose: If True, shows a tqdm bar over the sequences and writes a
            one-line summary.

    Returns:
        A tuple (x, mask):
            - x: LongTensor of shape (B, T) containing padded token ids.
            - mask: BoolTensor of shape (B, T), True where x holds a real token.
        An empty `sequences` list yields two (0, 0) tensors.
    """
    if not sequences:
        no_tokens = torch.empty((0, 0), dtype=torch.long, device=device)
        no_mask = torch.empty((0, 0), dtype=torch.bool, device=device)
        return no_tokens, no_mask

    lengths = [len(s) for s in sequences]
    longest = max(lengths)
    target_len = longest if seq_len is None else min(seq_len, longest)

    b = len(sequences)
    x = torch.full((b, target_len), pad_idx, dtype=torch.long, device=device)
    mask = torch.zeros((b, target_len), dtype=torch.bool, device=device)

    # Nothing to copy when every sequence is empty (or seq_len is 0).
    if target_len == 0:
        return x, mask

    rows = tqdm(sequences, disable=not verbose, desc="Pad & mask") if verbose else sequences

    for row, seq in enumerate(rows):
        if not seq:
            continue

        keep = min(len(seq), target_len)
        x[row, :keep] = torch.tensor(seq[:keep], dtype=torch.long, device=device)
        mask[row, :keep] = True

    if verbose:
        tqdm.write(
            f"[pad_and_mask] batch_size={b}, target_len={target_len}, "
            f"min_len={min(lengths)}, max_len={max(lengths)}"
        )

    return x, mask
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from tqdm import tqdm
|
| from typing import Optional, Union, Tuple
|
|
|
| def topk_cosine_neighbors(embeddings: torch.Tensor,
|
| k: int = 10,
|
| key_embeddings: Optional[torch.Tensor] = None,
|
| row_batch: Optional[int] = None,
|
| col_batch: Optional[int] = None,
|
| device: Optional[Union[str, torch.device]] = None,
|
| normalize: bool = True,
|
| dtype: Optional[torch.dtype] = None,
|
| show_progress: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
| """
|
| For each query embedding, find the indices and similarities of its top-k neighbors
|
| from a set of key embeddings, sorted by descending similarity.
|
|
|
| Supports both self-similarity (single array, excludes self) and pairwise
|
| retrieval (two arrays, no exclusion).
|
|
|
| Optimized for maximum speed and memory efficiency across CPU, CUDA, and MPS.
|
| Uses a streaming batched approach to handle datasets larger than GPU memory.
|
|
|
| Args:
|
| embeddings (torch.Tensor): Query embeddings, shape (N_q, D).
|
| k (int): How many neighbors to return.
|
| key_embeddings (torch.Tensor, optional): Database/Key embeddings, shape (N_k, D).
|
| If None, defaults to 'embeddings' (self-search).
|
| row_batch (int, optional): Number of query rows to process at once. Auto-tuned if None.
|
| col_batch (int, optional): Number of key columns to process at once. Auto-tuned if None.
|
| device (str or torch.device, optional): Target device. If None, uses embeddings.device.
|
| normalize (bool): If True, L2-normalize embeddings. Skip if already normalized.
|
| dtype (torch.dtype, optional): Compute dtype (e.g., torch.float16, torch.bfloat16).
|
| If None, uses embeddings.dtype.
|
| show_progress (bool): Show tqdm progress bar.
|
|
|
| Returns:
|
| top_idx (torch.Tensor): shape (N_q, k), int32 indices of nearest neighbors (indices into key_embeddings).
|
| top_sim (torch.Tensor): shape (N_q, k), float32 cosine similarities.
|
| """
|
|
|
|
|
| is_self_search = (key_embeddings is None)
|
| if is_self_search:
|
| key_embeddings = embeddings
|
|
|
|
|
| if device is None:
|
| device = embeddings.device
|
| else:
|
| device = torch.device(device)
|
|
|
|
|
| if dtype is None:
|
| dtype = embeddings.dtype
|
| else:
|
| assert dtype.is_floating_point, "dtype must be a floating point type"
|
|
|
|
|
|
|
| query_embeddings = embeddings.to(device=device, dtype=dtype).contiguous()
|
| key_embeddings = key_embeddings.to(device=device, dtype=dtype).contiguous()
|
|
|
| N_q, D = query_embeddings.shape
|
| N_k, D_k = key_embeddings.shape
|
|
|
| if D != D_k:
|
| raise ValueError(f"Query and Key embeddings must have same dimension. Got {D} and {D_k}")
|
|
|
|
|
| if k < 1:
|
| raise ValueError(f"k must be >= 1; got {k}")
|
|
|
| if is_self_search:
|
| if k >= N_q:
|
| raise ValueError(f"For self-search, k must be < N (to exclude self). Got N={N_q}, k={k}")
|
| else:
|
| if k > N_k:
|
| raise ValueError(f"For pairwise search, k must be <= N_k. Got N_k={N_k}, k={k}")
|
|
|
|
|
|
|
| if row_batch is None:
|
| if device.type == 'cuda':
|
| row_batch = 16384
|
| elif device.type == 'mps':
|
| row_batch = 8192
|
| else:
|
| row_batch = 4096
|
|
|
| if col_batch is None:
|
| if device.type == 'cuda':
|
| col_batch = 16384
|
| elif device.type == 'mps':
|
| col_batch = 8192
|
| else:
|
| col_batch = 4096
|
|
|
|
|
| row_batch = min(row_batch, N_q)
|
| col_batch = min(col_batch, N_k)
|
|
|
|
|
| if normalize:
|
|
|
| query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
|
|
|
|
| if not is_self_search:
|
| key_embeddings = F.normalize(key_embeddings, p=2, dim=1)
|
|
|
|
|
| top_sim = torch.empty((N_q, k), dtype=torch.float32, device=device)
|
| top_idx = torch.empty((N_q, k), dtype=torch.int, device=device)
|
|
|
|
|
|
|
| merge_sim_buffer = torch.empty((row_batch, 2 * k), dtype=dtype, device=device)
|
| merge_idx_buffer = torch.empty((row_batch, 2 * k), dtype=torch.int, device=device)
|
|
|
|
|
| sim_buffer = torch.empty((row_batch, col_batch), dtype=dtype, device=device)
|
|
|
|
|
| min_val = -torch.finfo(dtype).max
|
|
|
|
|
| with torch.no_grad():
|
| iterator = range(0, N_q, row_batch)
|
| if show_progress:
|
| desc = "Query Batches" if not is_self_search else "Row Batches"
|
| iterator = tqdm(iterator, desc=desc, leave=True)
|
|
|
| for i in iterator:
|
| i_end = min(i + row_batch, N_q)
|
| rb = i_end - i
|
|
|
| rows = query_embeddings[i:i_end]
|
|
|
|
|
|
|
| curr_sim = torch.full((rb, k), min_val, dtype=dtype, device=device)
|
| curr_idx = torch.full((rb, k), -1, dtype=torch.int, device=device)
|
|
|
| for j in range(0, N_k, col_batch):
|
| j_end = min(j + col_batch, N_k)
|
| cb = j_end - j
|
|
|
| cols = key_embeddings[j:j_end]
|
|
|
|
|
|
|
| sim_block = sim_buffer[:rb, :cb]
|
| torch.matmul(rows, cols.T, out=sim_block)
|
|
|
|
|
| if is_self_search:
|
| offset = i - j
|
| r_start = max(0, -offset)
|
| r_end = min(rb, cb - offset)
|
|
|
| if r_start < r_end:
|
|
|
| r_range = torch.arange(r_start, r_end, dtype=torch.long, device=device)
|
| c_range = r_range + offset
|
| sim_block[r_range, c_range] = min_val
|
|
|
|
|
| if cb >= k:
|
| blk_s, blk_p = torch.topk(sim_block, k, dim=1, largest=True, sorted=True)
|
| blk_i = blk_p + j
|
| else:
|
|
|
| pad_size = k - cb
|
| pad_vals = torch.full((rb, pad_size), min_val, dtype=dtype, device=device)
|
| sims_padded = torch.cat([sim_block, pad_vals], dim=1)
|
| blk_s, blk_p = torch.topk(sims_padded, k, dim=1, largest=True, sorted=True)
|
| blk_i = blk_p + j
|
|
|
| blk_i[blk_s == min_val] = -1
|
|
|
|
|
|
|
| merge_sim_buffer[:rb, :k] = curr_sim
|
| merge_sim_buffer[:rb, k:2*k] = blk_s
|
| merge_idx_buffer[:rb, :k] = curr_idx
|
| merge_idx_buffer[:rb, k:2*k] = blk_i
|
|
|
| curr_sim, top_p = torch.topk(merge_sim_buffer[:rb, :2*k], k, dim=1, largest=True, sorted=True)
|
| curr_idx = torch.gather(merge_idx_buffer[:rb, :2*k], dim=1, index=top_p)
|
|
|
|
|
| top_sim[i:i_end] = curr_sim.to(torch.float32)
|
| top_idx[i:i_end] = curr_idx
|
|
|
|
|
| if k == 1:
|
| return top_idx.view(-1), top_sim.view(-1)
|
|
|
| return top_idx, top_sim
|
|
|
|
|
|
|
|
|
|
|
| import numpy as np
|
| import matplotlib.pyplot as plt
|
| from sklearn.metrics import pairwise_distances
|
|
|
def plot_emb_cosine_similarity(embeddings,
                               clip=2.0,
                               gamma=0.55,
                               cmap="inferno",
                               figsize=(20, 20),
                               dpi=300,
                               output_fname='embeddings_similarity_plot.png',
                               return_sims=False
                               ):

    """
    Render, save and show a high-contrast heatmap of pairwise cosine similarities.

    The similarity matrix is computed as ``1 - cosine_distance`` between all rows
    of ``embeddings``, passed through a sign-preserving power transform
    (``sign(s) * |s| ** gamma``) and drawn with the colour range clipped to the
    ``[clip, 100 - clip]`` percentiles of the transformed values.

    Args:
        embeddings: 2-D array-like accepted by ``sklearn.metrics.pairwise_distances``.
        clip: Percentile clipped from each end of the colour range (1-5 recommended).
        gamma: Exponent of the contrast transform (0.4-0.8 recommended).
        cmap: Matplotlib colormap name.
        figsize: Figure size passed to ``plt.figure``.
        dpi: Figure resolution passed to ``plt.figure``.
        output_fname: Path the figure is written to with ``plt.savefig``.
        return_sims: If True, return the gamma-transformed similarity matrix.

    Returns:
        The transformed similarity matrix when ``return_sims`` is True, else None.

    -----------
    Use Example
    -----------

    tok_emb = model.net.token_emb.emb.weight.detach().cpu()

    plot_emb_cosine_similarity(tok_emb)
    """

    cos_sim = 1 - pairwise_distances(embeddings, metric="cosine")

    # Sign-preserving gamma curve: stretches small magnitudes for contrast.
    sim = np.sign(cos_sim) * (np.abs(cos_sim) ** gamma)

    # Percentile clipping keeps a few extreme cells from flattening the colour scale.
    vmin, vmax = np.percentile(sim, [clip, 100 - clip])

    fig = plt.figure(figsize=figsize, dpi=dpi)
    plt.imshow(sim, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Embeddings Pairwise Cosine Similarity")
    plt.xlabel("Embedding Index")
    plt.ylabel("Embeddings Index")
    plt.tight_layout()
    plt.savefig(output_fname)
    plt.show()

    # pyplot keeps every figure alive until it is closed; release this (large)
    # one once shown. In interactive mode show() does not block, so closing
    # here would dismiss the window immediately -- leave it open in that case.
    if not plt.isinteractive():
        plt.close(fig)

    if return_sims:
        return sim

    return None
|
|
|
|
|
|
|
|
|
|
|
def unfreeze_last_n_blocks_and_norms(model,
                                     n_last=2,
                                     verbose=True
                                     ):

    """
    Freeze the whole model, then unfreeze its tail for fine-tuning.

    Unfrozen parts:
    - the output head ``model.net.to_logits``
    - the last ``n_last`` entries of ``model.net.attn_layers.layers``
      (``n_last <= 0`` leaves every block frozen)
    - ``model.net.attn_layers.final_norm`` when that attribute exists

    2-3 unfrozen layers usually produce good results. Default is 2

    Args:
        model: Module exposing ``net.to_logits`` and ``net.attn_layers.layers``.
        n_last: Number of trailing blocks to unfreeze.
        verbose: Print parameter counts.

    Returns: configured model and optimizer. The optimizer is Adam with two
    parameter groups: unfrozen non-head parameters at lr 1e-5 and head
    parameters at lr 5e-5.
    """

    for p in model.parameters():
        p.requires_grad = False

    head_params = list(model.net.to_logits.parameters())
    for p in head_params:
        p.requires_grad = True

    attn_layers = model.net.attn_layers

    # Guard explicitly: list(layers)[-0:] is the *whole* list, so without this
    # check n_last=0 would unfreeze every block instead of none.
    if n_last > 0:
        for block in list(attn_layers.layers)[-n_last:]:
            for p in block.parameters():
                p.requires_grad = True

    final_norm = getattr(attn_layers, "final_norm", None)
    if final_norm is not None:
        for p in final_norm.parameters():
            p.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())

    if verbose:
        print(f"Trainable params {trainable:,} / {total:,}")

    # Split trainable parameters into head vs. everything else so each group
    # can get its own learning rate.
    head_param_ids = {id(p) for p in head_params}
    pretrained_params = []
    head_only = []

    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in head_param_ids:
            head_only.append(p)
        else:
            pretrained_params.append(p)

    pretrained_count = sum(p.numel() for p in pretrained_params)
    head_count = sum(p.numel() for p in head_only)

    # Internal sanity check on the grouping above, not input validation.
    assert pretrained_count + head_count == trainable, "Mismatch in grouped trainable params"

    if verbose:
        print(f"Pretrained params: {pretrained_count:,}")
        print(f"Head params: {head_count:,}")
        print(f"Total trainable: {trainable:,}")

    optim = torch.optim.Adam([
        {"params": pretrained_params, "lr": 1e-5},
        {"params": head_only, "lr": 5e-5}
    ])

    return model, optim
|
|
|
def unfreeze_last_n_blocks_and_norms_full(model,
                                          n_last_encoder=1,
                                          n_last_decoder=2,
                                          verbose=True
                                          ):

    """
    Freeze an entire encoder-decoder model, then unfreeze:
    - Last `n_last_encoder` entries of `model.encoder.attn_layers.layers`
      (all parameters inside those blocks, norms included)
    - Last `n_last_decoder` entries of `model.decoder.net.attn_layers.layers`
    - The `final_norm` of each side, if present -- only for a side whose
      `n_last_*` is > 0
    - Decoder's output head (`model.decoder.net.to_logits`)

    Args:
        model: Module exposing `encoder.attn_layers` and `decoder.net`.
        n_last_encoder: Trailing encoder blocks to unfreeze (<= 0: none).
        n_last_decoder: Trailing decoder blocks to unfreeze (<= 0: none).
        verbose: Print parameter counts.

    Returns:
        (model, optimizer). The optimizer is Adam with two parameter groups:
        unfrozen encoder/decoder parameters at lr 1e-5 and head parameters at
        lr 5e-5.
    """

    for p in model.parameters():
        p.requires_grad = False

    head_params = list(model.decoder.net.to_logits.parameters())
    for p in head_params:
        p.requires_grad = True

    def unfreeze_last_blocks(transformer_wrapper, n_last):
        # Unfreeze the trailing blocks (and final norm) of one transformer stack.
        if n_last <= 0:
            return

        attn_layers = transformer_wrapper.attn_layers

        # block.parameters() recurses into every submodule, so norm layers
        # nested in a block are covered without any type-specific handling.
        for block in list(attn_layers.layers)[-n_last:]:
            for p in block.parameters():
                p.requires_grad = True

        final_norm = getattr(attn_layers, 'final_norm', None)
        if final_norm is not None:
            for p in final_norm.parameters():
                p.requires_grad = True

    unfreeze_last_blocks(model.encoder, n_last_encoder)
    unfreeze_last_blocks(model.decoder.net, n_last_decoder)

    # Split trainable parameters into head vs. everything else so each group
    # can get its own learning rate.
    head_param_ids = {id(p) for p in head_params}
    pretrained_params = []
    head_only = []

    for p in model.parameters():
        if not p.requires_grad:
            continue
        if id(p) in head_param_ids:
            head_only.append(p)
        else:
            pretrained_params.append(p)

    pretrained_count = sum(p.numel() for p in pretrained_params)
    head_count = sum(p.numel() for p in head_only)
    trainable = pretrained_count + head_count
    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Internal sanity check on the grouping above, not input validation.
    assert trainable == total_trainable, "Mismatch in grouped trainable params"

    if verbose:
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable params {trainable:,} / {total:,}")
        print(f"Pretrained (enc/dec): {pretrained_count:,}")
        print(f"Head: {head_count:,}")
        print(f"Total trainable: {total_trainable:,}")

    optim = torch.optim.Adam([
        {"params": pretrained_params, "lr": 1e-5},
        {"params": head_only, "lr": 5e-5}
    ])

    return model, optim
|
|
|
|
|
|
|
|
|
|
|
def merge_encoder_and_decoder(model,
                              encoder_ckpt,
                              decoder_ckpt,
                              print_keys=False,
                              verbose=True
                              ):

    """
    Load separately trained encoder and decoder checkpoints into one model.

    Encoder checkpoint keys are mapped to ``'encoder.' + key``. Decoder
    checkpoint keys starting with ``'net.'`` are mapped to ``'decoder.' + key``
    (other decoder keys are looked up unchanged). Only keys that exist in
    ``model.state_dict()`` are copied; everything else is ignored.

    Args:
        model: Target module whose state dict uses ``encoder.*`` and
            ``decoder.net.*`` keys.
        encoder_ckpt: Path (or file-like) of the encoder checkpoint.
        decoder_ckpt: Path (or file-like) of the decoder checkpoint.
        print_keys: Print the key sets of the model and both checkpoints.
        verbose: Print progress banners and the missing/unexpected key report.

    Returns:
        ``model`` when the strict load succeeds; otherwise the result object
        of the non-strict ``load_state_dict`` call (it carries
        ``missing_keys`` / ``unexpected_keys``).

    Each checkpoint may be a raw state dict or a dict holding one under
    ``'state_dict'``.
    """

    rule = '=' * 70

    def _banner(message):
        # Print one status line framed by separator rules.
        print(rule)
        print(message)
        print(rule)

    def _read_state_dict(path):
        # NOTE(security): depending on the installed torch version's
        # `weights_only` default, torch.load may unpickle arbitrary objects.
        # Only pass trusted checkpoint files.
        ckpt = torch.load(path, map_location='cpu')
        return ckpt.get('state_dict', ckpt)

    if verbose:
        _banner('Merging...')

    if print_keys:
        print(rule)
        print('Merged model keys:', model.state_dict().keys())
        print(rule)

    if verbose:
        _banner('Loading encoder model...')

    enc_pre_sd = _read_state_dict(encoder_ckpt)

    if print_keys:
        print(rule)
        print('Encoder model keys:', enc_pre_sd.keys())
        print(rule)

    if verbose:
        _banner('Loading decoder model...')

    dec_pre_sd = _read_state_dict(decoder_ckpt)

    if print_keys:
        print(rule)
        print('Decoder model keys', dec_pre_sd.keys())
        print(rule)

    if verbose:
        _banner('Prepping merged model...')

    model_new_sd = model.state_dict()

    for old_key, tensor in enc_pre_sd.items():
        new_key = 'encoder.' + old_key
        if new_key in model_new_sd:
            model_new_sd[new_key] = tensor

    for old_key, tensor in dec_pre_sd.items():
        # Re-root only the leading 'net.'; a blanket str.replace would also
        # rewrite any later 'net.' substring inside the key.
        new_key = 'decoder.' + old_key if old_key.startswith('net.') else old_key
        if new_key in model_new_sd:
            model_new_sd[new_key] = tensor

    if verbose:
        _banner('Final integrity check...')

    incompat = model.load_state_dict(model_new_sd, strict=False)

    if verbose:
        print("Missing keys: ", incompat.missing_keys)
        print("Unexpected keys: ", incompat.unexpected_keys)

    try:
        if verbose:
            print(rule)
            print('Loading merged model...')

        model.load_state_dict(model_new_sd, strict=True)

        if verbose:
            print('Done!')
            print(rule)

        return model

    except Exception:
        # Best-effort contract kept from the original: report and hand back the
        # non-strict result. KeyboardInterrupt/SystemExit are no longer swallowed.
        if verbose:
            print('Failed to create merged model!')
            print(rule)

        return incompat
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import time
|
| from collections import deque
|
| from math import ceil
|
| import numpy as np
|
| import torch
|
| from torch.utils.data import Dataset
|
| from torch import nn
|
| from typing import List, Tuple, Optional, Callable, Sequence
|
|
|
class BoundaryDataset(Dataset):
    """Index-aligned pairs of input sequences and label sequences.

    Each item is a 3-tuple ``(inputs[idx], labels[idx], None)``; the third
    slot is always ``None``.
    """

    def __init__(self, inputs_list, labels_list):
        self.inputs, self.labels = inputs_list, labels_list

    def __len__(self):
        # Dataset size is driven by the inputs container alone.
        return len(self.inputs)

    def __getitem__(self, idx):
        sample = self.inputs[idx]
        target = self.labels[idx]
        return sample, target, None
|
|
|
class BoundaryClassifier(nn.Module):
    """Transformer encoder backbone followed by an MLP classification head.

    The backbone is a ``TransformerWrapper`` around an ``Encoder`` (rotary
    position embeddings, flash attention). The head is
    LayerNorm -> (Linear, GELU, Dropout) x 2 -> Linear(dim, num_labels).
    ``pad_token_id`` is only stored on the instance; this class does not read it.
    """

    def __init__(self, num_tokens: int, max_seq_len: int, dim: int = 512,
                 depth: int = 12, heads: int = 16, num_labels: int = 2,
                 pad_token_id: int = 384, dropout: float = 0.1):
        super().__init__()
        self.pad_token_id = pad_token_id

        encoder_stack = Encoder(
            dim=dim,
            depth=depth,
            heads=heads,
            rotary_pos_emb=True,
            attn_flash=True,
        )
        self.backbone = TransformerWrapper(
            num_tokens=num_tokens,
            max_seq_len=max_seq_len,
            attn_layers=encoder_stack,
        )

        # Build the head in the same module order as before so the Sequential
        # indices (and therefore state-dict keys) stay the same.
        head_layers = [nn.LayerNorm(dim)]
        for _ in range(2):
            head_layers += [nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout)]
        head_layers.append(nn.Linear(dim, num_labels))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attn_mask=None):
        """Return the head's output for the backbone embeddings of ``input_ids``.

        ``attn_mask`` is forwarded to the backbone as ``mask``.
        NOTE(review): output is presumably (batch, seq, num_labels) -- confirm
        against TransformerWrapper's ``return_embeddings`` contract.
        """
        return self.classifier(
            self.backbone(input_ids, mask=attn_mask, return_embeddings=True)
        )
|
|
|
class FocalLoss(nn.Module):
    """Focal loss over per-position class logits, with ignore-index support.

    Per position the loss is ``(1 - p_t) ** gamma * -log(p_t)`` where ``p_t``
    is the softmax probability of the target class, optionally scaled by a
    per-class ``alpha`` weight. The result is the mean over valid positions.

    Args:
        gamma: Focusing exponent.
        alpha: Optional per-class weights (tensor or sequence of length C),
            indexed by target class. Moved to the loss device/dtype on use.
        ignore_index: Target value marking positions excluded from the loss.
    """

    def __init__(self, gamma=2.0, alpha=None, ignore_index=384):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.ignore_index = ignore_index

    def forward(self, logits, targets, mask=None):
        """Compute the mean focal loss.

        Args:
            logits: Tensor of shape (B, N, C).
            targets: Integer tensor with B * N elements.
            mask: Optional tensor with B * N elements; positions where it is
                false/zero are excluded in addition to ``ignore_index``.

        Returns:
            Scalar tensor. When no position is valid the value is 0 and it
            stays attached to the autograd graph, so ``backward()`` works.
        """
        B, N, C = logits.shape

        # reshape (not view) so non-contiguous logits/targets are accepted.
        logits_flat = logits.reshape(B * N, C)
        targets_flat = targets.reshape(B * N)

        valid_mask = targets_flat != self.ignore_index
        if mask is not None:
            valid_mask = valid_mask & mask.reshape(-1).bool()

        log_probs = torch.log_softmax(logits_flat, dim=-1)

        # Ignored targets may lie outside [0, C); clamp so the gather is
        # legal. Their contribution is zeroed by valid_mask below.
        targets_clamped = targets_flat.clamp(0, C - 1).long()

        log_p_t = log_probs.gather(1, targets_clamped.unsqueeze(1)).squeeze(1)
        p_t = torch.exp(log_p_t)

        loss = (1.0 - p_t) ** self.gamma * (-log_p_t)

        if self.alpha is not None:
            # alpha is a plain attribute (not a registered buffer), so it does
            # not follow .to(device); align it with the loss here.
            alpha = torch.as_tensor(self.alpha, dtype=loss.dtype, device=loss.device)
            loss = alpha[targets_clamped] * loss

        loss = loss * valid_mask.float()

        # clamp(min=1) avoids 0/0 when every position is ignored; the
        # numerator is then exactly zero, giving a differentiable 0 loss.
        return loss.sum() / valid_mask.sum().clamp(min=1)
|
|
|
# Type alias for an optional status callback: a callable taking one message
# string and returning None, or None when no callback is supplied.
Logger = Optional[Callable[[str], None]]
|
|
|
def filter_balanced_sequences(
    sequences: List[List[int]],
    token_types: List[List[int]],
    tol: float = 0.1,
    min_len: int = 1,
    max_len: Optional[int] = None,
    balance_target: float = 0.5,
    return_indices: bool = False,
    verbose: int = 2,
    logger: Logger = None,
    progress_chunk: int = 50000
) -> Tuple[List[List[int]], List[List[int]], Optional[List[int]]]:

    """
    Keep only the sequence/token-type pairs whose share of 1-valued token types
    lies within `tol` of `balance_target`, reporting progress along the way.

    Parameters
    ----------
    sequences : List[List[int]]
        Token sequences; carried through untouched (never inspected).
    token_types : List[List[int]]
        Token-type lists aligned with `sequences`; the balance of an item is
        sum(token_types[i]) / len(token_types[i]).
    tol : float
        Maximum absolute distance from `balance_target` that is still accepted.
    min_len : int
        Smallest token_types length kept. Empty lists are always dropped.
    max_len : Optional[int]
        Largest token_types length kept; None means no upper bound.
    balance_target : float
        Desired proportion of 1s (0..1).
    return_indices : bool
        When True the third return value holds the kept positions.
    verbose : int
        0 = no printing, 1 = status lines, 2 = also per-chunk scan progress.
    logger : callable or None
        Receives every status string regardless of `verbose`; exceptions it
        raises are ignored.
    progress_chunk : int
        Scan progress is reported every `progress_chunk` items (verbose >= 2).

    Returns
    -------
    filtered_sequences, filtered_token_types, indices_or_none

    Raises
    ------
    ValueError
        If `sequences` and `token_types` differ in length.
    """

    def _emit(message: str):
        # Logger is best-effort: a failing callback must not abort filtering.
        if logger:
            try:
                logger(message)
            except Exception:
                pass
        if verbose >= 1:
            print(message)

    def _nothing():
        # Fresh containers on every early exit.
        return [], [], ([] if return_indices else None)

    if len(sequences) != len(token_types):
        raise ValueError("`sequences` and `token_types` must have the same length.")

    total = len(token_types)
    if total == 0:
        _emit("Input empty: nothing to do.")
        return _nothing()

    t_start = time.perf_counter()
    _emit(f"Starting filter: {total} pairs; tol={tol}; target={balance_target}; min_len={min_len}; max_len={max_len}")

    # Pass 1: per-item length and number of 1s.
    lengths = np.empty(total, dtype=np.int32)
    counts = np.empty(total, dtype=np.int32)
    report_chunks = verbose >= 2

    t_scan = time.perf_counter()
    for done, types in enumerate(token_types, start=1):
        lengths[done - 1] = len(types)
        counts[done - 1] = sum(types)

        if report_chunks and done % progress_chunk == 0:
            elapsed = time.perf_counter() - t_scan
            _emit(f" scanned {done}/{total} token_types (elapsed {elapsed:.2f}s)")

    scan_time = time.perf_counter() - t_scan
    _emit(f"Scanned counts and lengths in {scan_time:.2f}s")

    # Length filter; zero-length items can never qualify.
    eligible = lengths > 0
    if min_len > 1:
        eligible &= (lengths >= min_len)
    if max_len is not None:
        eligible &= (lengths <= max_len)

    n_candidates = int(eligible.sum())
    _emit(f"Candidates after length filtering: {n_candidates}/{total}")

    if n_candidates == 0:
        _emit("No candidates after length filtering. Exiting.")
        return _nothing()

    # Proportion of 1s; ineligible rows keep the -1.0 sentinel.
    proportions = np.full(total, -1.0, dtype=np.float32)
    proportions[eligible] = (
        counts[eligible].astype(np.float32) / lengths.astype(np.float32)[eligible]
    )

    within_tol = np.abs(proportions - float(balance_target)) <= float(tol)
    selected = eligible & within_tol
    kept = int(selected.sum())
    elapsed_total = time.perf_counter() - t_start
    _emit(f"Kept {kept}/{total} sequences (elapsed total {elapsed_total:.2f}s)")

    if kept == 0:
        _emit("No sequences met the balance criterion. Exiting.")
        return _nothing()

    keep_idx = np.flatnonzero(selected).tolist()

    t_build = time.perf_counter()
    filtered_sequences = list(map(sequences.__getitem__, keep_idx))
    filtered_token_types = list(map(token_types.__getitem__, keep_idx))
    build_time = time.perf_counter() - t_build

    _emit(f"Built filtered lists: {kept} items (build time {build_time:.2f}s)")

    if verbose >= 1:
        kept_props = proportions[selected]
        mean_prop = float(np.mean(kept_props))
        std_prop = float(np.std(kept_props))
        min_prop = float(np.min(kept_props))
        max_prop = float(np.max(kept_props))
        _emit(f"Kept proportions stats: mean={mean_prop:.4f}, std={std_prop:.4f}, min={min_prop:.4f}, max={max_prop:.4f}")

    return filtered_sequences, filtered_token_types, (keep_idx if return_indices else None)
|
|
|
def compute_class_counts_from_list(labels_list: Sequence[Sequence[int]],
                                   num_labels: int = 2,
                                   pad_idx: int = 384) -> torch.LongTensor:
    """Count occurrences of each label id in ``range(num_labels)``.

    Every sequence is queried with its ``.count`` method. When ``pad_idx``
    itself falls inside ``range(num_labels)``, its occurrences are subtracted
    from that class's total.

    Returns:
        Long tensor of shape (num_labels,); all zeros for an empty input.
    """
    if not len(labels_list):
        return torch.zeros(num_labels, dtype=torch.long)

    totals = [
        sum(seq.count(label) for seq in labels_list)
        for label in range(num_labels)
    ]

    pad_is_a_class = 0 <= pad_idx < num_labels
    if pad_is_a_class:
        pad_hits = sum(seq.count(pad_idx) for seq in labels_list)
        totals[pad_idx] = max(0, totals[pad_idx] - pad_hits)

    return torch.tensor(totals, dtype=torch.long)
|
|
|
def compute_class_weights(
    labels_list,
    num_labels=2,
    pad_idx=384,
    smoothing=0.0,
    power=1.0,
    max_ratio=50.0
):

    """
    Inverse-frequency class weights, scaled so the smallest weight is 1.

    Steps, in order:
    - class counts from `compute_class_counts_from_list`, floored at 1
    - reciprocal of the counts
    - add `smoothing` (only when > 0)
    - raise to `power` (only when != 1.0)
    - divide by the minimum weight
    - cap every weight at `max_ratio`

    Returns a float tensor of shape (num_labels,).
    """

    class_counts = compute_class_counts_from_list(labels_list, num_labels, pad_idx).float()

    # Floor at 1 so an absent class cannot cause a division by zero.
    weights = 1.0 / class_counts.clamp(min=1.0)

    if smoothing > 0:
        weights = weights + smoothing

    if power != 1.0:
        weights = weights.pow(power)

    # Relative scale: the lowest-weight class ends up with weight 1.
    weights = weights / weights.min()

    return weights.clamp(max=max_ratio)
|
|
|
def collate_fn_from_lists(batch, pad_token_id=384):
    """Pad a batch of (input_seq, label_seq, ...) samples into dense tensors.

    Only the first two entries of every sample are read. Rows are padded to
    the longest input sequence in the batch with ``pad_token_id``. A label
    sequence longer than its input is truncated to the input length; a
    shorter one is padded with ``pad_token_id``.

    Args:
        batch: Iterable of samples whose items 0 and 1 are integer sequences.
        pad_token_id: Fill value for both ``input_ids`` and ``labels``.

    Returns:
        (input_ids, labels, attn_mask): two long tensors and one bool tensor,
        each of shape (B, max_len); ``attn_mask`` is True on real input
        positions. An empty batch yields tensors of shape (0, 0).
    """
    input_seqs = [list(sample[0]) for sample in batch]
    label_seqs = [list(sample[1]) for sample in batch]

    batch_size = len(input_seqs)
    max_len = max((len(seq) for seq in input_seqs), default=0)

    input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
    labels = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
    attn_mask = torch.zeros((batch_size, max_len), dtype=torch.bool)

    for row, (xseq, yseq) in enumerate(zip(input_seqs, label_seqs)):
        length = len(xseq)
        if length == 0:
            continue

        input_ids[row, :length] = torch.tensor(xseq, dtype=torch.long)

        # Labels are aligned to the input: cut any overhang, and write only as
        # many positions as are available so short label lists stay padded
        # instead of raising a shape-mismatch error.
        yseq = yseq[:length]
        if yseq:
            labels[row, :len(yseq)] = torch.tensor(yseq, dtype=torch.long)

        attn_mask[row, :length] = True

    return input_ids, labels, attn_mask
|
|
|
|
|
|
|
| |