"""Talkie 13B transformer — HuggingFace Transformers implementation. A faithful port of the reference implementation at https://github.com/talkie-lm/talkie that is compatible with the ``transformers`` ``PreTrainedModel`` API so the model can be loaded with ``AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)``. """ from __future__ import annotations from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import GenerationMixin, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast try: from .configuration_talkie import TalkieConfig except ImportError: from configuration_talkie import TalkieConfig # --------------------------------------------------------------------------- # Small helper modules (match reference exactly) # --------------------------------------------------------------------------- class TalkieHeadGain(nn.Module): """Per-head scalar gain applied to queries.""" def __init__(self, n_head: int): super().__init__() self.head_g = nn.Parameter(torch.ones(n_head)) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, S, n_head, head_dim) return x * self.head_g.type_as(x).view(1, 1, -1, 1) class TalkieWeightGain(nn.Module): """Scalar gain applied to the lm_head weight matrix.""" def __init__(self): super().__init__() self.w_g = nn.Parameter(torch.ones(1)) def forward(self, w: torch.Tensor) -> torch.Tensor: return w * self.w_g.type_as(w) class TalkieActGain(nn.Module): """Scalar activation gain with a configurable initial value.""" def __init__(self, init_value: float): super().__init__() self.a_g = nn.Parameter(torch.ones(1) * init_value) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * self.a_g.type_as(x) # --------------------------------------------------------------------------- # RoPE # --------------------------------------------------------------------------- def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> torch.Tensor: """Apply rotary position embeddings (half-rotation convention).""" assert x.ndim == 4 # (B, S, H, D) d = x.shape[3] // 2 x1 = x[..., :d] x2 = x[..., d:] y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat([y1, y2], 3).type_as(x) def _precompute_rotary_embeddings( seq_len: int, head_dim: int, base: float, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor]: channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) inv_freq = 1.0 / (base ** (channel_range / head_dim)) t = torch.arange(seq_len, dtype=torch.float32, device=device) freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() cos, sin = cos.bfloat16(), sin.bfloat16() cos, sin = cos[None, :, None, :], sin[None, :, None, :] return cos, sin # --------------------------------------------------------------------------- # Attention & MLP # --------------------------------------------------------------------------- class TalkieSelfAttention(nn.Module): """Multi-head self-attention with QK-norm and per-head gain.""" def __init__(self, config: TalkieConfig): super().__init__() self.n_head = config.num_attention_heads self.head_dim = config.head_dim n_state = config.hidden_size self.attn_query = nn.Linear(n_state, n_state, bias=False) self.attn_key = nn.Linear(n_state, n_state, bias=False) self.attn_value = nn.Linear(n_state, n_state, bias=False) self.attn_resid = nn.Linear(n_state, n_state, bias=False) self.head_gain = TalkieHeadGain(config.num_attention_heads) def forward( self, x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: bsz, seq_len, _ = x.size() q = self.attn_query(x).view(bsz, seq_len, self.n_head, self.head_dim) k = self.attn_key(x).view(bsz, seq_len, self.n_head, self.head_dim) v = self.attn_value(x).view(bsz, seq_len, self.n_head, self.head_dim) cos, sin = cos_sin q, k = _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin) q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) q = self.head_gain(q) y = F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, ) y = y.transpose(1, 2).contiguous().view_as(x) return self.attn_resid(y) class TalkieMLP(nn.Module): """SwiGLU feed-forward network.""" def __init__(self, config: TalkieConfig): super().__init__() n_state = config.hidden_size n_mlp = config.intermediate_size self.mlp_gate = nn.Linear(n_state, n_mlp, bias=False) self.mlp_linear = nn.Linear(n_state, n_mlp, bias=False) self.mlp_resid = nn.Linear(n_mlp, n_state, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp_resid(F.silu(self.mlp_gate(x)) * self.mlp_linear(x)) # --------------------------------------------------------------------------- # Transformer block # --------------------------------------------------------------------------- class TalkieDecoderLayer(nn.Module): """Single transformer block with embedding skip connections.""" def __init__(self, config: TalkieConfig): super().__init__() gain_init = (2 * config.num_hidden_layers) ** -0.5 self.attn = TalkieSelfAttention(config) self.attn_gain = TalkieActGain(gain_init) self.mlp = TalkieMLP(config) self.mlp_gain = TalkieActGain(gain_init) self.embed_skip = TalkieActGain(0.0) def forward( self, e_x: torch.Tensor, x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: x = x + self.attn_gain(self.attn(F.rms_norm(x, (x.shape[-1],)), cos_sin)) x = x + self.mlp_gain(self.mlp(F.rms_norm(x, (x.shape[-1],)))) x = x + self.embed_skip(e_x) return x # --------------------------------------------------------------------------- # Full model # --------------------------------------------------------------------------- class TalkieModel(nn.Module): """Talkie 13B decoder stack (no lm_head).""" def __init__(self, config: TalkieConfig): super().__init__() self.config = config self.embed = nn.Embedding(config.vocab_size, config.hidden_size) self.blocks = nn.ModuleList( [TalkieDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) def forward( self, input_ids: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: x = self.embed(input_ids) x = F.rms_norm(x, (x.shape[-1],)) e_x = x for block in self.blocks: x = block(e_x, x, cos_sin) x = F.rms_norm(x, (x.shape[-1],)) return x # --------------------------------------------------------------------------- # CausalLM wrapper (transformers-compatible) # --------------------------------------------------------------------------- class TalkieForCausalLM(PreTrainedModel, GenerationMixin): """Talkie 13B causal language model for HuggingFace Transformers. Load with:: from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( "lewtun/talkie-1930-13b-it-hf", trust_remote_code=True ) """ config_class = TalkieConfig _no_split_modules = ["TalkieDecoderLayer"] def __init__(self, config: TalkieConfig): super().__init__(config) self.model = TalkieModel(config) self.lm_head = nn.Parameter( torch.zeros(config.vocab_size, config.hidden_size) ) self.lm_head_gain = TalkieWeightGain() # RoPE cos/sin are computed lazily on first forward pass so that # from_pretrained (which may construct on a meta device) works. self._rope_cos: torch.Tensor | None = None self._rope_sin: torch.Tensor | None = None self.post_init() def _get_rope( self, seq_len: int, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor]: """Return cached RoPE (cos, sin) tensors, recomputing if needed.""" if ( self._rope_cos is None or self._rope_cos.shape[1] < seq_len or self._rope_cos.device != device ): cos, sin = _precompute_rotary_embeddings( max(seq_len, self.config.max_position_embeddings), self.config.head_dim, self.config.rope_theta, device=device, ) self._rope_cos = cos self._rope_sin = sin return self._rope_cos[:, :seq_len], self._rope_sin[:, :seq_len] def get_input_embeddings(self): return self.model.embed def set_input_embeddings(self, value): self.model.embed = value def prepare_inputs_for_generation(self, input_ids, **kwargs): """Always pass the full accumulated input_ids (no KV cache).""" return {"input_ids": input_ids} def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, # unused, kept for API compat labels: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[CausalLMOutputWithPast, Tuple]: _, seq_len = input_ids.shape cos_sin = self._get_rope(seq_len, input_ids.device) hidden_states = self.model(input_ids, cos_sin) logits = F.linear(hidden_states, self.lm_head_gain(self.lm_head)).float() loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) return CausalLMOutputWithPast(loss=loss, logits=logits)