"""GPT-S-1.4M model implementation for Hugging Face transformers."""

import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_gpts3 import GPTS14MConfig


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale (no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for numerical stability, cast back to the input
        # dtype, then apply the learned scale.
        inv_rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * inv_rms).type_as(x) * self.weight


def build_rope_inv_freq(head_dim, theta=2500.0):
    """Return the ``head_dim // 2`` RoPE inverse frequencies ``theta^(-2i/head_dim)``."""
    return 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))


def precompute_freqs_cis(head_dim, seq_len, theta=2500.0):
    """Return a complex ``(seq_len, head_dim // 2)`` table of unit rotations.

    Row ``p`` holds ``exp(i * p * inv_freq)``. Each row depends only on its own
    position, so a longer table shares its prefix rows with a shorter one.
    """
    inv_freq = build_rope_inv_freq(head_dim, theta)
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(q, k, freqs_cis):
    """Rotate adjacent channel pairs of ``q`` and ``k`` by the per-position angles.

    Within this file, ``q``/``k`` are ``(B, heads, T, head_dim)`` and
    ``freqs_cis`` is ``(T, head_dim // 2)``. The math runs in float32/complex64
    and the results are cast back to the input dtypes.
    """
    q_complex = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    # Broadcast over the batch and head dimensions.
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)
    q_out = torch.view_as_real(q_complex * freqs_cis).flatten(-2)
    k_out = torch.view_as_real(k_complex * freqs_cis).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)


def _repeat_kv(x, n_rep):
    """Expand ``(B, kv_heads, L, D)`` to ``(B, kv_heads * n_rep, L, D)`` for GQA.

    Each key/value head is repeated ``n_rep`` times consecutively, so query head
    ``h`` is paired with key/value head ``h // n_rep``.
    """
    if n_rep == 1:
        return x
    b, kv_heads, length, dim = x.shape
    return x.unsqueeze(2).expand(b, kv_heads, n_rep, length, dim).reshape(b, kv_heads * n_rep, length, dim)


def _causal_mask(q_len, kv_len, device):
    """Boolean ``(q_len, kv_len)`` mask, True where attention is allowed.

    Bottom-right aligned: query ``i`` sits at absolute position
    ``kv_len - q_len + i`` and may see every key up to and including it.
    """
    return torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(diagonal=kv_len - q_len)


def _build_attn_mask(attention_mask, q_len, kv_len, device):
    """Return the ``(attn_mask, is_causal)`` arguments for SDPA.

    The ``q_len`` queries are the last positions of a ``kv_len``-long key
    sequence (cached prefix + current step). ``attention_mask``, if given, is a
    padding mask over all ``kv_len`` keys (nonzero = keep).
    """
    if attention_mask is None:
        if q_len == 1:
            # A single newest query may attend to every (cached) key.
            return None, False
        if q_len == kv_len:
            # No cached prefix: SDPA's built-in causal mask is exact here.
            return None, True
        # Multi-token step on top of a cache (chunked prefill, assisted or
        # speculative decoding). SDPA's ``is_causal`` is top-left aligned for
        # non-square shapes, so build the bottom-right mask explicitly.
        return _causal_mask(q_len, kv_len, device), False

    mask = attention_mask.to(torch.bool)[:, None, None, :]
    if q_len > 1:
        mask = mask & _causal_mask(q_len, kv_len, device)
    # A query with every key masked (e.g. a left-padding position) can make
    # some SDPA backends emit NaN. That NaN can then reach real tokens in later
    # layers via 0 * NaN in the value product. Let such rows see all keys
    # instead; they are padding, and every real query masks them out as keys.
    mask = mask | ~mask.any(dim=-1, keepdim=True)
    return mask, False


class GPTS14MAttention(nn.Module):
    """Grouped-query self-attention with RoPE, an optional KV cache and optional XSA projection."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        # Number of query heads sharing each key/value head.
        self.n_rep = self.n_head // self.n_kv_heads
        self.xsa_projection = config.xsa_projection
        self.q_proj = nn.Linear(config.hidden_size, self.n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, config.hidden_size, bias=False)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        """Attend ``x`` (``(B, T, hidden)``) over the cached prefix plus itself.

        Args:
            x: hidden states for the current step.
            freqs_cis: RoPE rotations for the ``T`` current positions.
            past_key_value: cache whose ``update`` appends this step's keys and
                values for ``self.layer_idx`` and returns the merged tensors.
                ``None`` disables caching.
            use_cache: unused here; kept for signature compatibility.
            attention_mask: optional ``(B, past_len + T)`` padding mask.

        Returns:
            ``(B, T, hidden)`` output of the output projection.
        """
        B, T, _ = x.size()
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v_current = v  # this step's values only, before merging with the cache
        q, k = apply_rotary_emb(q, k, freqs_cis)

        if past_key_value is not None:
            k, v = past_key_value.update(k, v, self.layer_idx)
        S = k.size(2)
        k = _repeat_kv(k, self.n_rep)
        v = _repeat_kv(v, self.n_rep)

        attn_mask, is_causal = _build_attn_mask(attention_mask, T, S, x.device)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)

        if self.xsa_projection:
            # Remove from each query's output its component along the
            # unit-normalized value vector of the same token.
            v_n = F.normalize(_repeat_kv(v_current, self.n_rep), dim=-1)
            y = y - (y * v_n).sum(dim=-1, keepdim=True) * v_n

        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
        return self.o_proj(y)


class GPTS14MSwiGLUMLP(nn.Module):
    """SwiGLU feed-forward: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.w_gate = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w_up = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.w_down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class GPTS14MBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual SwiGLU MLP."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = GPTS14MAttention(config, layer_idx)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = GPTS14MSwiGLUMLP(config)

    def forward(self, x, freqs_cis, past_key_value=None, use_cache=False, attention_mask=None):
        x = x + self.attn(self.ln_1(x), freqs_cis, past_key_value, use_cache, attention_mask=attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class GPTS14MPreTrainedModel(PreTrainedModel):
    """Base class wiring the config class and weight initialization into ``PreTrainedModel``."""

    config_class = GPTS14MConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize linear/embedding weights as N(0, hidden_size^-0.5) and norm scales to one."""
        std = self.config.hidden_size ** -0.5
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, RMSNorm):
            # Matches the constructor's default. This makes norm scales that are
            # missing from a checkpoint well-defined instead of left untouched.
            torch.nn.init.ones_(module.weight)


class GPTS14MForCausalLM(GPTS14MPreTrainedModel, GenerationMixin):
    """Decoder-only causal language model with an optional tied LM head."""

    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.hidden_size),
            h=nn.ModuleList([GPTS14MBlock(config, i) for i in range(config.num_hidden_layers)]),
            ln_f=RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
        ))
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer["wte"].weight
        # Lazily built RoPE table; see _get_freqs_cis.
        self._freqs_cis_cache = None
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer["wte"]

    def set_input_embeddings(self, value):
        self.transformer["wte"] = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build ``forward`` kwargs for one generation step.

        Once the cache holds ``past_len`` tokens, only the tokens not yet in the
        cache are fed. During assisted or speculative decoding that can be more
        than one. At least the last token is always kept.
        """
        if past_key_values is not None:
            past_len = past_key_values.get_seq_length()
            if past_len > 0:
                total_len = input_ids.shape[1]
                keep_from = past_len if total_len > past_len else total_len - 1
                input_ids = input_ids[:, keep_from:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": True,
        }

    def _get_freqs_cis(self, seq_len, device):
        """Return the first ``seq_len`` RoPE rows on ``device``.

        The table is cached on the module. When it is too short it grows to at
        least twice its previous length, so token-by-token decoding does not
        rebuild it on every step. Rows are position-wise, so over-allocating
        does not change the returned values.
        """
        cache = self._freqs_cis_cache
        if cache is None or cache.device != device or cache.size(0) < seq_len:
            reusable_len = cache.size(0) if cache is not None and cache.device == device else 0
            target_len = max(seq_len, 2 * reusable_len)
            cache = precompute_freqs_cis(self.config.head_dim, target_len, self.config.rope_theta).to(device)
            self._freqs_cis_cache = cache
        return cache[:seq_len]

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        past_key_values: Optional[DynamicCache] = None,
        use_cache=False,
        **kwargs,
    ):
        """Run the model on ``input_ids`` (``(B, T)``).

        Args:
            input_ids: token ids for the current step.
            attention_mask: optional padding mask covering the cached tokens and
                the current ones (``(B, past_len + T)``, nonzero = keep).
            labels: optional ``(B, T)`` targets. The loss is next-token cross
                entropy, with logits at position ``t`` scored against
                ``labels[t + 1]``; entries equal to -100 are ignored (the
                ``F.cross_entropy`` default).
            past_key_values: cache to extend. Used only when ``use_cache``; a
                new ``DynamicCache`` is created if it is ``None``.
            use_cache: whether to read and update the KV cache.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``), ``logits`` of
            shape ``(B, T, vocab)``, and the cache when ``use_cache``.
        """
        B, T = input_ids.size()
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0

        x = self.transformer["wte"](input_ids)
        # RoPE rows for absolute positions past_len .. past_len + T - 1.
        freqs_cis = self._get_freqs_cis(past_len + T, input_ids.device)[past_len:]
        for block in self.transformer["h"]:
            x = block(x, freqs_cis, past_key_values if use_cache else None, use_cache, attention_mask=attention_mask)
        x = self.transformer["ln_f"](x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None,
        )


# Backward-compatible aliases for code that expects older class names.
GPTS3ForCausalLM = GPTS14MForCausalLM
GPTX3ForCausalLM = GPTS14MForCausalLM