from typing import Optional, Tuple, Union, Dict

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig


# --- 1. Configuration Class ---
class CinnabarLMConfig(PretrainedConfig):
    """Hyper-parameters for the CinnabarLM decoder-only language model.

    Args:
        vocab_size: Number of rows in the token embedding / LM head.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of transformer blocks.
        num_attention_heads: Attention heads per block; the per-head width is
            ``hidden_size // num_attention_heads``.
        intermediate_size: Width of the MLP hidden layer.
        max_position_embeddings: Number of positions for which rotary
            frequencies are precomputed.
        rope_theta: Base of the rotary position embedding frequencies.
        rms_norm_eps: Epsilon added inside every RMSNorm.
        initializer_range: Standard deviation used for weight initialisation.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "cinnabarlm"

    def __init__(
        self,
        vocab_size=4096,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=768,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rms_norm_eps=1e-6,
        initializer_range=0.02,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range


# --- 2. Functional Components ---
class RMSNorm(nn.Module):
    """Root-mean-square layer norm over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight``.

        Half-precision inputs are upcast to float32 for the statistic, because
        squaring in fp16/bf16 can overflow or lose precision; the normalised
        activations are cast back to the input dtype before the learned scale
        is applied. float32/float64 inputs are processed in their own dtype.
        """
        if x.dtype in (torch.float16, torch.bfloat16):
            work = x.float()
        else:
            work = x
        inv_rms = torch.rsqrt(work.pow(2).mean(-1, keepdim=True) + self.eps)
        return (work * inv_rms).type_as(x) * self.weight


def precompute_rope_freqs(dim, seq_len, theta=10000.0):
    """Build the complex rotary-embedding table.

    Args:
        dim: Per-head dimension; one frequency is produced per pair of channels.
        seq_len: Number of positions to precompute.
        theta: Frequency base.

    Returns:
        A complex tensor of shape ``(seq_len, dim // 2)`` whose entry
        ``[p, i]`` is the unit complex number ``exp(1j * p * freq_i)``.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Explicit float32 positions: do not rely on integer/float promotion in outer().
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers, so multiplying by them is a pure rotation.
    return torch.polar(torch.ones_like(angles), angles)
# --- 3. Attention & Blocks ---
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Args:
        config: Object exposing ``num_attention_heads`` and ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by the number of heads,
            or the resulting per-head width is odd (rotary embeddings rotate
            channels in pairs).
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.hidden_size = config.hidden_size
        if self.hidden_size % self.n_head != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.n_head})"
            )
        self.head_dim = self.hidden_size // self.n_head
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head dimension ({self.head_dim}) must be even to apply rotary embeddings"
            )
        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
        self.proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    @staticmethod
    def _apply_rope(t, rotations):
        """Rotate channel pairs of ``t`` by the complex factors in ``rotations``."""
        # view_as_complex rejects bfloat16, and fp16 would leave q/k in a
        # different dtype than v; do the rotation in float32 and cast back.
        work = t.float() if t.dtype in (torch.float16, torch.bfloat16) else t
        as_complex = torch.view_as_complex(work.reshape(*work.shape[:-1], -1, 2))
        return torch.view_as_real(as_complex * rotations).flatten(3).type_as(t)

    def forward(self, x, freqs_cis):
        """Attend causally over ``x``.

        Args:
            x: Tensor unpacked as ``(batch, seq_len, hidden_size)``.
            freqs_cis: Complex rotary table; its first ``seq_len`` rows are
                used and reshaped to broadcast over batch and heads.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If ``freqs_cis`` is not complex (for example after a
                dtype cast dropped its imaginary part).
            ValueError: If ``freqs_cis`` has fewer rows than ``seq_len``.
        """
        B, T, C = x.size()
        if not torch.is_complex(freqs_cis):
            raise TypeError(
                f"freqs_cis must be a complex tensor, got dtype {freqs_cis.dtype}"
            )
        if freqs_cis.size(0) < T:
            raise ValueError(
                f"sequence length {T} exceeds the {freqs_cis.size(0)} positions "
                "available in freqs_cis"
            )

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys only.
        rotations = freqs_cis[:T].reshape(1, 1, T, -1)
        q = self._apply_rope(q, rotations)
        k = self._apply_rope(k, rotations)

        # Scaled Dot-Product Attention (Flash Attention where available)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual branch."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
            # SiLU activation. NOTE: this is a plain two-layer MLP, not a gated
            # SwiGLU unit -- there is no separate gate projection.
            nn.SiLU(),
            nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
        )

    def forward(self, x, freqs_cis):
        """Apply attention then the MLP, adding each result back onto ``x``."""
        x = x + self.attn(self.norm1(x), freqs_cis)
        x = x + self.mlp(self.norm2(x))
        return x
# --- 4. Main Model Class ---
class CinnabarLMPreTrainedModel(PreTrainedModel):
    """Base class wiring CinnabarLM into the ``PreTrainedModel`` machinery."""

    config_class = CinnabarLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, initializer_range).

        Linear biases, when present, are zeroed. Other module types are left
        untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)


class CinnabarLMForCausalLM(CinnabarLMPreTrainedModel):
    """Decoder-only language model: embedding, a stack of ``Block``s, final norm, LM head."""

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # The class advertises supports_gradient_checkpointing; this flag is
        # what forward() consults to decide whether to checkpoint each layer.
        # NOTE(review): recent transformers releases look for this attribute in
        # gradient_checkpointing_enable() -- confirm against the pinned version.
        self.gradient_checkpointing = False

        # Precompute RoPE and register as buffer (won't be trained, but saved)
        freqs_cis = precompute_rope_freqs(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
            theta=config.rope_theta,
        )
        self.register_buffer("freqs_cis", freqs_cis)

        # Initialize weights
        self.post_init()

    def _rope_table(self):
        """Return the complex RoPE table, rebuilding it if a cast made it real."""
        freqs_cis = self.freqs_cis
        if not torch.is_complex(freqs_cis):
            # Module.to(dtype=...) also converts complex buffers, which drops
            # the imaginary part; recompute instead of rotating with bad data.
            freqs_cis = precompute_rope_freqs(
                self.config.hidden_size // self.config.num_attention_heads,
                self.config.max_position_embeddings,
                theta=self.config.rope_theta,
            ).to(self.token_embedding.weight.device)
            self.freqs_cis = freqs_cis
        return freqs_cis

    def _run_checkpointed(self, layer, x, freqs_cis):
        """Run ``layer`` under activation checkpointing."""
        # Prefer the function installed by gradient_checkpointing_enable() when
        # the installed transformers version provides one.
        checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
        if checkpoint_fn is not None:
            return checkpoint_fn(layer.__call__, x, freqs_cis)
        from torch.utils.checkpoint import checkpoint

        return checkpoint(layer, x, freqs_cis, use_reentrant=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids; the last dimension is the sequence.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted internally so position ``t`` predicts token ``t + 1``.
            **kwargs: Accepted for call-site compatibility and ignored. In
                particular an ``attention_mask`` is NOT applied; attention is
                purely causal.

        Returns:
            ``{"loss": ..., "logits": ...}`` where ``loss`` is ``None`` when
            no labels are supplied.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than the
                precomputed rotary table.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        freqs_cis = self._rope_table()
        seq_len = input_ids.size(-1)
        if seq_len > freqs_cis.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the {freqs_cis.size(0)} "
                "precomputed rotary positions (max_position_embeddings)"
            )

        x = self.token_embedding(input_ids)
        use_checkpointing = self.gradient_checkpointing and self.training
        for layer in self.layers:
            if use_checkpointing:
                x = self._run_checkpointed(layer, x, freqs_cis)
            else:
                x = layer(x, freqs_cis)
        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens predict the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return {
            "loss": loss,
            "logits": logits
        }