"""Talkie 13B transformer — HuggingFace Transformers implementation.

A faithful port of the reference implementation at
https://github.com/talkie-lm/talkie that is compatible with the
``transformers`` ``PreTrainedModel`` API, so the model can be loaded with
``AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)``.

This fork adds three things to the upstream ``lewtun/talkie-1930-13b-it-hf``
port:

1. **KV cache support.** ``forward`` accepts and returns ``past_key_values``
   as a ``DynamicCache``, and ``prepare_inputs_for_generation`` slices
   ``input_ids`` down to the new tokens when a cache is present. Without
   this, every decode step reprocesses the full sequence (O(N²)), and any
   downstream loop that manages cache state itself (e.g. saklas's steering
   hook path) breaks. The cached path is mathematically equivalent to the
   no-cache path (up to floating-point differences between attention
   kernels); see the attention call site for causal-mask handling.

2. **``output_hidden_states`` / ``output_attentions``.** The standard
   transformers output fields are populated when requested, so probes,
   steering tools, and other interpretability work that pulls per-layer
   activations or attention maps get real tensors instead of ``None``.

3. **``attention_mask`` and ``position_ids``.** Both are honored: padding
   masks compose with the causal mask in SDPA, and explicit
   ``position_ids`` override the cache-derived RoPE positions. Left-padded
   batched generation works end to end.
"""

from __future__ import annotations

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

try:
    from .configuration_talkie import TalkieConfig
except ImportError:
    from configuration_talkie import TalkieConfig


# ---------------------------------------------------------------------------
# Small helper modules (match reference exactly)
# ---------------------------------------------------------------------------


class TalkieHeadGain(nn.Module):
    """Per-head scalar gain applied to queries."""

    def __init__(self, n_head: int):
        super().__init__()
        self.head_g = nn.Parameter(torch.ones(n_head))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, S, n_head, head_dim). One gain per head, broadcast over
        # batch, sequence and head_dim.
        return x * self.head_g.type_as(x).view(1, 1, -1, 1)


class TalkieWeightGain(nn.Module):
    """Scalar gain applied to the lm_head weight matrix."""

    def __init__(self):
        super().__init__()
        self.w_g = nn.Parameter(torch.ones(1))

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        return w * self.w_g.type_as(w)


class TalkieActGain(nn.Module):
    """Scalar activation gain with a configurable initial value."""

    def __init__(self, init_value: float):
        super().__init__()
        self.a_g = nn.Parameter(torch.ones(1) * init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.a_g.type_as(x)


# ---------------------------------------------------------------------------
# RoPE
# ---------------------------------------------------------------------------


def _apply_rotary_emb(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embeddings (half-rotation convention).

    ``x`` is ``(B, S, H, D)``. The first and second halves of the last axis
    are rotated as pairs. ``cos``/``sin`` must broadcast against
    ``(B, S, H, D // 2)``. The result is cast back to ``x``'s dtype.
    """
    assert x.ndim == 4  # (B, S, H, D)
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)


def _precompute_rotary_embeddings(
    seq_len: int, head_dim: int, base: float, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build RoPE cos/sin tables for positions ``[0, seq_len)``.

    Returns two bfloat16 tensors of shape ``[1, seq_len, 1, head_dim // 2]``.
    Frequencies are computed in float32 and then rounded to bfloat16.
    """
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    cos, sin = cos.bfloat16(), sin.bfloat16()
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    return cos, sin


# ---------------------------------------------------------------------------
# Attention & MLP
# ---------------------------------------------------------------------------


class TalkieSelfAttention(nn.Module):
    """Multi-head self-attention with QK-norm, per-head gain, and KV cache."""

    def __init__(self, config: TalkieConfig):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.head_dim
        n_state = config.hidden_size
        self.attn_query = nn.Linear(n_state, n_state, bias=False)
        self.attn_key = nn.Linear(n_state, n_state, bias=False)
        self.attn_value = nn.Linear(n_state, n_state, bias=False)
        self.attn_resid = nn.Linear(n_state, n_state, bias=False)
        self.head_gain = TalkieHeadGain(config.num_attention_heads)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: Tuple[torch.Tensor, torch.Tensor],
        past_key_values: Optional[Cache] = None,
        layer_idx: int = 0,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from the new tokens ``x`` over past (cached) and new keys.

        Args:
            x: ``(B, S, hidden)`` hidden states for the new tokens only.
            cos_sin: RoPE ``(cos, sin)`` already sliced or gathered for the
                new positions.
            past_key_values: optional HF cache. This layer's K/V are
                appended to slot ``layer_idx``.
            layer_idx: index of this layer's slot in the cache.
            attention_mask: optional 2D bool or 0/1 integer mask of shape
                ``[B, kv_len]``, where ``True``/1 means "attend".
            output_attentions: compute attention manually and return the
                softmax weights.

        Returns:
            ``(output, attn_weights)``. ``attn_weights`` is ``None`` unless
            ``output_attentions`` is set.

        Raises:
            TypeError: ``attention_mask`` is a floating-point tensor.
            ValueError: ``attention_mask`` is not 2D, or its length does not
                match the key/value length.
        """
        bsz, seq_len, _ = x.size()
        q = self.attn_query(x).view(bsz, seq_len, self.n_head, self.head_dim)
        k = self.attn_key(x).view(bsz, seq_len, self.n_head, self.head_dim)
        v = self.attn_value(x).view(bsz, seq_len, self.n_head, self.head_dim)

        # cos/sin are pre-sliced (or gathered) for the *new* positions only.
        # The caller resolves position offsets and any explicit
        # ``position_ids`` before getting here. Apply RoPE, QK-norm, and the
        # per-head Q gain in the natural ``[B, S, H, D]`` layout that the
        # helpers expect.
        cos, sin = cos_sin
        q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
        q = self.head_gain(q)

        # Transpose to ``[B, H, S, D]`` for cache + SDPA, the standard HF
        # Cache layout. ``DynamicCache.get_seq_length`` reads ``shape[-2]``,
        # so this layout makes the cache self-describing.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # ``cache.update`` appends k/v to layer_idx's slot and returns the
        # full (past + new) tensors. K/V are stored post-RoPE and
        # post-RMSNorm. With absolute RoPE positions, the cached K[i] equals
        # what a full-sequence forward would produce at position i, so
        # prefill-then-decode is concatenation-equivalent to a single
        # full-sequence forward.
        if past_key_values is not None:
            k, v = past_key_values.update(k, v, layer_idx)
        kv_len = k.size(2)

        attn_weights: Optional[torch.Tensor] = None

        # The fast SDPA path runs when no padding mask is supplied and the
        # caller doesn't ask for raw weights. It has three regimes:
        #
        #   prefill (q_len == k_len, no past): is_causal=True. The standard
        #     upper-left triangular mask is correct.
        #
        #   decode (q_len == 1 with past): is_causal=False. The single new
        #     query can attend to every past key (and itself), so no mask is
        #     needed. (is_causal=True would wrongly mask everything but k[0].)
        #
        #   multi-token re-feed (q_len > 1 with past): an explicit
        #     lower-right causal mask aligned to the END of k, because the q
        #     positions are the LAST q_len positions of the full sequence.
        if attention_mask is None and not output_attentions:
            if seq_len == kv_len:
                attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            elif seq_len == 1:
                attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
            else:
                offset = kv_len - seq_len
                mask = torch.ones(
                    seq_len, kv_len, dtype=torch.bool, device=x.device
                ).tril(diagonal=offset)
                attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        else:
            # Build a 4D bool mask combining causality with any padding mask
            # the caller passed. ``True`` = attend.
            if seq_len == 1:
                causal = torch.ones(1, kv_len, dtype=torch.bool, device=x.device)
            elif seq_len == kv_len:
                causal = torch.ones(
                    seq_len, kv_len, dtype=torch.bool, device=x.device
                ).tril()
            else:
                offset = kv_len - seq_len
                causal = torch.ones(
                    seq_len, kv_len, dtype=torch.bool, device=x.device
                ).tril(diagonal=offset)

            if attention_mask is not None:
                # Expected format: a bool or 0/1 integer mask covering the
                # full kv length. Float "additive" masks aren't supported:
                # converting via ``.to(torch.bool)`` would treat negative
                # additive entries as ``True`` and silently invert the mask.
                # Standard HF generate always passes the kv-length 2D form,
                # so this is rarely a problem in practice. Raise loudly when
                # a caller breaks it.
                if attention_mask.dtype.is_floating_point:
                    raise TypeError(
                        "attention_mask must be a bool or integer 0/1 mask, "
                        "not a float additive mask."
                    )
                # Move to the activations' device so a CPU-side mask composes
                # with the on-device causal mask.
                pad = attention_mask.to(device=x.device, dtype=torch.bool)
                if pad.dim() != 2:
                    raise ValueError(
                        f"attention_mask must be 2D [B, kv_len]; got shape {tuple(pad.shape)}"
                    )
                if pad.shape[-1] != kv_len:
                    raise ValueError(
                        f"attention_mask kv-length {pad.shape[-1]} does not match "
                        f"the cache's kv_len {kv_len}. Pass the full kv-length mask "
                        f"so cached positions can be masked correctly — front-padding "
                        f"with ones would wrongly mark left-padded prefill tokens as real."
                    )
                pad = pad[:, None, None, :]  # [B, 1, 1, kv_len]
                mask4d = pad & causal[None, None, :, :]
            else:
                mask4d = causal[None, None, :, :]

            if output_attentions:
                # Manual attention so we can return the softmax weights.
                # Slower than SDPA, opt-in only. Softmax runs in fp32 and is
                # cast back (same pattern as llama/qwen); otherwise bf16 row
                # sums drift by ~3e-3.
                scale = 1.0 / math.sqrt(self.head_dim)
                scores = torch.matmul(q, k.transpose(-2, -1)) * scale
                scores = scores.masked_fill(
                    ~mask4d, torch.finfo(scores.dtype).min
                )
                attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
                attn_out = torch.matmul(attn_weights, v)
            else:
                # Left-padded query rows (pad tokens at the start of a row)
                # have no attendable key at all. Some SDPA backends/versions
                # return NaN for such fully masked rows. That NaN would flow
                # into K/V at the pad positions in later layers and poison
                # real tokens via P @ V, since 0 * NaN = NaN. Let those rows
                # attend everywhere instead. Their output is discarded
                # garbage, and real-token rows are unchanged.
                sdpa_mask = mask4d | ~mask4d.any(dim=-1, keepdim=True)
                attn_out = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=sdpa_mask
                )

        y = attn_out.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
        return self.attn_resid(y), attn_weights


class TalkieMLP(nn.Module):
    """SwiGLU feed-forward network."""

    def __init__(self, config: TalkieConfig):
        super().__init__()
        n_state = config.hidden_size
        n_mlp = config.intermediate_size
        self.mlp_gate = nn.Linear(n_state, n_mlp, bias=False)
        self.mlp_linear = nn.Linear(n_state, n_mlp, bias=False)
        self.mlp_resid = nn.Linear(n_mlp, n_state, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp_resid(F.silu(self.mlp_gate(x)) * self.mlp_linear(x))


# ---------------------------------------------------------------------------
# Transformer block
# ---------------------------------------------------------------------------


class TalkieDecoderLayer(nn.Module):
    """Single transformer block with embedding skip connections and KV cache."""

    def __init__(self, config: TalkieConfig):
        super().__init__()
        # Initial value for the residual-branch gains: (2 * num_layers)^-0.5.
        gain_init = (2 * config.num_hidden_layers) ** -0.5
        self.attn = TalkieSelfAttention(config)
        self.attn_gain = TalkieActGain(gain_init)
        self.mlp = TalkieMLP(config)
        self.mlp_gain = TalkieActGain(gain_init)
        # The embedding skip gain starts at 0 (no contribution at init).
        self.embed_skip = TalkieActGain(0.0)

    def forward(
        self,
        e_x: torch.Tensor,
        x: torch.Tensor,
        cos_sin: Tuple[torch.Tensor, torch.Tensor],
        past_key_values: Optional[Cache] = None,
        layer_idx: int = 0,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Pre-norm attention + MLP residuals, then a gated embedding skip.

        ``e_x`` is the normalized token embedding of the same (new) tokens
        as ``x``. It is added back into the residual stream after every
        block.
        """
        attn_out, attn_weights = self.attn(
            F.rms_norm(x, (x.shape[-1],)),
            cos_sin,
            past_key_values=past_key_values,
            layer_idx=layer_idx,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        x = x + self.attn_gain(attn_out)
        x = x + self.mlp_gain(self.mlp(F.rms_norm(x, (x.shape[-1],))))
        x = x + self.embed_skip(e_x)
        return x, attn_weights


# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------


class TalkieModel(nn.Module):
    """Talkie 13B decoder stack (no lm_head)."""

    def __init__(self, config: TalkieConfig):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList(
            [TalkieDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        cos_sin: Tuple[torch.Tensor, torch.Tensor],
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack over ``input_ids`` (new tokens only).

        ``hidden_states`` (when requested) holds the input to every block
        plus the final RMS-normalized output. ``attentions`` (when
        requested) holds one weight tensor per block.
        """
        x = self.embed(input_ids)
        x = F.rms_norm(x, (x.shape[-1],))

        # In cached generation, e_x is computed from the new tokens only,
        # and so is the running x. The per-block residual
        # `x = x + g_skip * e_x` therefore adds tensors of the same length
        # throughout the loop. The cached K/V already carry the influence of
        # past e_x through their own attention computation, so the math is
        # equivalent to a full-sequence forward.
        e_x = x

        all_hidden_states: Tuple[torch.Tensor, ...] = ()
        all_attentions: Tuple[torch.Tensor, ...] = ()

        for L, block in enumerate(self.blocks):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)
            x, attn_weights = block(
                e_x,
                x,
                cos_sin,
                past_key_values=past_key_values,
                layer_idx=L,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            if output_attentions and attn_weights is not None:
                all_attentions = all_attentions + (attn_weights,)

        x = F.rms_norm(x, (x.shape[-1],))
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (x,)

        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states if output_hidden_states else None,
            attentions=all_attentions if output_attentions else None,
        )


# ---------------------------------------------------------------------------
# CausalLM wrapper (transformers-compatible)
# ---------------------------------------------------------------------------


class TalkieForCausalLM(PreTrainedModel, GenerationMixin):
    """Talkie 13B causal language model for HuggingFace Transformers.

    Load with::

        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            "a9lim/talkie-1930-13b-it-hf-cached", trust_remote_code=True
        )
    """

    config_class = TalkieConfig
    _no_split_modules = ["TalkieDecoderLayer"]
    _supports_cache_class = True
    supports_gradient_checkpointing = False

    def __init__(self, config: TalkieConfig):
        super().__init__(config)
        self.model = TalkieModel(config)
        self.lm_head = nn.Parameter(
            torch.zeros(config.vocab_size, config.hidden_size)
        )
        self.lm_head_gain = TalkieWeightGain()
        # RoPE cos/sin are computed lazily on the first forward pass so that
        # from_pretrained (which may construct on a meta device) works.
        self._rope_cos: torch.Tensor | None = None
        self._rope_sin: torch.Tensor | None = None
        self.post_init()

    def _ensure_rope(self, max_pos: int, device: torch.device) -> None:
        """(Re)build the RoPE tables if missing, too short, or on another device.

        The table is sized to at least ``config.max_position_embeddings`` so
        it is rebuilt rarely.
        """
        if (
            self._rope_cos is None
            or self._rope_cos.shape[1] < max_pos
            or self._rope_cos.device != device
        ):
            cos, sin = _precompute_rotary_embeddings(
                max(max_pos, self.config.max_position_embeddings),
                self.config.head_dim,
                self.config.rope_theta,
                device=device,
            )
            self._rope_cos = cos
            self._rope_sin = sin

    def _get_rope(
        self,
        seq_len: int,
        position_offset: int,
        device: torch.device,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return RoPE (cos, sin) for the requested positions.

        Fast path (``position_ids is None``): slice the precomputed table
        for the contiguous positions ``[position_offset,
        position_offset + seq_len)``, with no host sync.

        Explicit path (``position_ids`` provided): gather per token. This
        supports left-padded batched generation and any caller that drives
        position encoding directly. It costs one ``.item()`` (a host sync)
        to size the table.
        """
        if position_ids is None:
            self._ensure_rope(position_offset + seq_len, device)
            return (
                self._rope_cos[:, position_offset : position_offset + seq_len],
                self._rope_sin[:, position_offset : position_offset + seq_len],
            )

        # Index tensors must live on the same device as the tables.
        position_ids = position_ids.to(device)
        max_pos = int(position_ids.max().item()) + 1
        self._ensure_rope(max_pos, device)
        # Tables are [1, max_seq, 1, D/2]. Index [0] gives [max_seq, 1, D/2];
        # advanced indexing by [B, S] then yields [B, S, 1, D/2] (or
        # [S, 1, D/2] for 1D position_ids, which still broadcasts).
        cos = self._rope_cos[0, position_ids]
        sin = self._rope_sin[0, position_ids]
        return cos, sin

    def get_input_embeddings(self):
        return self.model.embed

    def set_input_embeddings(self, value):
        self.model.embed = value

    @staticmethod
    def _past_len(past_key_values: Optional[Cache]) -> int:
        """Number of cached positions, or 0 when there is no usable cache."""
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            return int(past_key_values.get_seq_length())
        return 0

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Trim ``input_ids`` to uncached tokens and derive padding-aware positions."""
        past_len = self._past_len(past_key_values)
        if past_len > 0 and past_len < input_ids.shape[1]:
            input_ids = input_ids[:, past_len:]

        # If the caller didn't supply position_ids but did supply a
        # (possibly left-padded) attention_mask, derive positions from the
        # cumulative real-token count. Padded positions get a placeholder of
        # 1. They are masked out in attention anyway, so the exact value
        # doesn't matter as long as it's in range.
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_len > 0:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": use_cache,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, Tuple]:
        """Compute next-token logits (and optionally the LM loss).

        Raises:
            ValueError: ``input_ids`` is missing (``inputs_embeds`` is not
                supported), or ``position_ids`` covers fewer positions than
                there are new tokens.
        """
        if input_ids is None:
            raise ValueError(
                "TalkieForCausalLM.forward requires input_ids; "
                "inputs_embeds is not supported."
            )
        _, seq_len = input_ids.shape

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Materialize a DynamicCache when caching is requested but the caller
        # hasn't supplied one (e.g. the first ``generate`` step). With no
        # explicit ``use_cache``, follow ``config.use_cache`` (HF convention)
        # and default to caching if the config doesn't define it.
        return_cache = (
            use_cache
            if use_cache is not None
            else getattr(self.config, "use_cache", True)
        )
        if return_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # HF generate may pass position_ids covering the entire sequence
        # (kv_len) rather than just the new tokens (q_len). Accept that by
        # keeping the trailing q_len positions, which line up with input_ids
        # post-cache. Shorter position_ids are a caller bug, and we don't try
        # to recover.
        if position_ids is not None and position_ids.shape[-1] != seq_len:
            if position_ids.shape[-1] < seq_len:
                raise ValueError(
                    f"position_ids covers {position_ids.shape[-1]} positions but "
                    f"input_ids has {seq_len} new tokens. Pass at least one "
                    f"position per new token."
                )
            position_ids = position_ids[..., -seq_len:]

        # When position_ids is omitted but a 2D bool / 0-1 integer
        # attention_mask with padding is supplied, derive positions from the
        # cumulative mask. This makes left-padded direct ``forward`` calls
        # (no GenerationMixin) just work. Padded positions get a placeholder
        # of 1, since they're masked out of attention anyway. Float masks are
        # skipped here because attention rejects them with a clear error.
        if (
            position_ids is None
            and attention_mask is not None
            and not attention_mask.dtype.is_floating_point
            and attention_mask.dim() == 2
            and (attention_mask == 0).any()
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if position_ids.shape[-1] > seq_len:
                position_ids = position_ids[..., -seq_len:]

        position_offset = self._past_len(past_key_values)
        cos_sin = self._get_rope(
            seq_len, position_offset, input_ids.device, position_ids=position_ids
        )

        outputs = self.model(
            input_ids,
            cos_sin,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = outputs.last_hidden_state
        logits = F.linear(hidden_states, self.lm_head_gain(self.lm_head)).float()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the logits (e.g.
            # under a device_map split).
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            if return_cache:
                output = output + (outputs.past_key_values,)
            if output_hidden_states:
                output = output + (outputs.hidden_states,)
            if output_attentions:
                output = output + (outputs.attentions,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values if return_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )