"""
Small Language Model (SLM) - Transformer from Scratch
======================================================
A decoder-only (GPT-style) Transformer architecture for the Indonesian language.
Built from scratch using PyTorch.

Author: Jekardah AI Lab
"""
|
|
| import math |
| import json |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass, asdict |
| from typing import Optional |
|
|
|
|
@dataclass
class SLMConfig:
    """Hyperparameters of the SLM model, with JSON save/load helpers."""

    vocab_size: int = 32000       # rows of the token embedding / width of the logits
    embed_dim: int = 256          # width of the residual stream
    num_heads: int = 4            # attention heads; must divide embed_dim evenly
    num_layers: int = 4           # number of stacked TransformerBlock modules
    ffn_dim: int = 512            # hidden width of the feed-forward network
    max_seq_len: int = 128        # length of the RoPE tables and of the generation window
    dropout: float = 0.1          # dropout probability used in attention and the FFN
    layer_norm_eps: float = 1e-5  # epsilon handed to every RMSNorm

    def save(self, path: str):
        """Write all fields to ``path`` as indented JSON.

        Args:
            path: Destination file; it is created or overwritten.
        """
        # Pin the encoding so the file does not depend on the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "SLMConfig":
        """Build a config from a JSON file such as the one written by ``save``.

        Args:
            path: JSON file whose top-level object maps field names to values.

        Returns:
            A new ``SLMConfig``; fields missing from the file keep their defaults.

        Raises:
            TypeError: If the file contains a key that is not a config field.
        """
        with open(path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Cheaper than LayerNorm: the input is only rescaled by the reciprocal of its
    root mean square over the last dimension (no mean subtraction, no bias) and
    then multiplied by a learned per-feature gain.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Args:
            dim: Size of the last (feature) dimension of the inputs.
            eps: Constant added to the mean square before the reciprocal sqrt.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # gain, initialised to 1
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised over its last dimension and scaled by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
|
|
|
|
class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embedding (RoPE) cosine/sine tables.

    A modern technique used by LLaMA, Qwen, and others. The tables for the
    first ``max_seq_len`` positions are precomputed once; longer requests are
    computed on demand instead of being silently truncated.
    """

    def __init__(self, dim: int, max_seq_len: int = 128, base: float = 10000.0):
        """
        Args:
            dim: Per-head feature size; one frequency is created per pair of
                features, i.e. ``len(range(0, dim, 2))`` frequencies.
            max_seq_len: Number of positions to precompute.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Buffers stay persistent (the default) so that existing checkpoints,
        # which contain these keys, still load with strict key matching.
        cos_cached, sin_cached = self._build_tables(max_seq_len)
        self.register_buffer("cos_cached", cos_cached)
        self.register_buffer("sin_cached", sin_cached)

    def _build_tables(self, seq_len: int):
        """Return ``(cos, sin)`` of ``position * inv_freq`` for ``seq_len`` positions."""
        inv_freq = self.inv_freq.float()
        positions = torch.arange(seq_len, device=inv_freq.device).float()
        angles = torch.outer(positions, inv_freq)
        return angles.cos(), angles.sin()

    def forward(self, seq_len: int):
        """Return the cosine and sine tables for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions required.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, number_of_frequencies)``.
        """
        if seq_len > self.cos_cached.shape[0]:
            # Slicing the cache would return fewer rows than requested and
            # break broadcasting downstream; compute the longer tables instead.
            # The cached buffers are left untouched so state_dict shapes are stable.
            cos, sin = self._build_tables(seq_len)
            return cos.to(self.cos_cached.dtype), sin.to(self.sin_cached.dtype)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the given rotary cosine/sine tables.

    Feature ``i`` of the first half of the last dimension is paired with
    feature ``i`` of the second half, and each pair is rotated by the angle
    stored for its position.

    Args:
        x: Tensor whose dimension 2 is the sequence axis and whose last
            dimension is the feature axis (used here as
            ``(batch, heads, seq, head_dim)``).
        cos: Cosine table indexed by position along dim 0; only the first
            ``x.shape[2]`` rows are used. Its last dimension is duplicated to
            cover both feature halves.
        sin: Sine table, same layout as ``cos``.

    Returns:
        Tensor of the same shape as ``x``.
    """
    seq_len = x.shape[2]
    half = x.shape[-1] // 2
    first_half, second_half = x[..., :half], x[..., half:]

    # Trim to the sequence length, repeat across both halves, and add two
    # leading singleton axes so the tables broadcast over batch and heads.
    cos_full = torch.cat((cos[:seq_len], cos[:seq_len]), dim=-1)[None, None]
    sin_full = torch.cat((sin[:seq_len], sin[:seq_len]), dim=-1)[None, None]

    # (a, b) -> (-b, a): the quarter-turn counterpart of each feature pair.
    quarter_turn = torch.cat((-second_half, first_half), dim=-1)
    return x * cos_full + quarter_turn * sin_full
|
|
|
|
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings.

    When no mask is supplied a causal mask is built, so each token attends only
    to itself and to earlier tokens (autoregressive).
    """

    def __init__(self, config: SLMConfig):
        """Create the projections, RoPE tables and attention dropout from ``config``."""
        super().__init__()
        assert config.embed_dim % config.num_heads == 0, \
            "embed_dim harus bisa dibagi num_heads"

        dim = config.embed_dim
        self.num_heads = config.num_heads
        self.head_dim = dim // config.num_heads
        self.embed_dim = dim

        # Query / key / value projections, all bias-free.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

        # Mixes the concatenated heads back into the model width.
        self.out_proj = nn.Linear(dim, dim, bias=False)

        # Rotary tables are sized per head, not per model width.
        self.rope = RotaryPositionalEncoding(self.head_dim, config.max_seq_len)

        # Dropout applied to the attention probabilities.
        self.attn_dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Input of shape ``(batch, seq_len, embed_dim)``.
            mask: Optional boolean tensor, ``True`` where attention is
                forbidden. It receives two leading singleton axes before being
                applied to the ``(batch, heads, seq_len, seq_len)`` scores.
                Defaults to a strict upper-triangular (causal) mask.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Position information goes into queries and keys only.
        cos, sin = self.rope(seq_len)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # Scaled dot-product scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is None:
            mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
        scores = scores.masked_fill(mask[None, None], float('-inf'))

        probs = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back together.
        context = torch.matmul(probs, v)
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out_proj(merged)
|
|
|
|
class FeedForward(nn.Module):
    """
    Feed-forward network with a SwiGLU activation.

    A modern technique used by LLaMA, Mistral, and others. Computes
    ``down_proj(silu(gate_proj(x)) * up_proj(x))`` followed by dropout.
    """

    def __init__(self, config: SLMConfig):
        """Create the three bias-free projections and the output dropout."""
        super().__init__()
        model_dim, hidden_dim = config.embed_dim, config.ffn_dim
        self.gate_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the SwiGLU transform of ``x``; the last dimension keeps its size."""
        # SiLU-activated gate multiplies the linear "up" branch element-wise.
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
|
|
|
|
class TransformerBlock(nn.Module):
    """
    One Transformer block: attention then FFN, each applied to an
    RMS-normalised input (pre-norm) and added back through a residual connection.
    """

    def __init__(self, config: SLMConfig):
        """Create the attention and feed-forward sub-layers and their norms."""
        super().__init__()
        self.attention = MultiHeadSelfAttention(config)
        self.feed_forward = FeedForward(config)
        self.attn_norm = RMSNorm(config.embed_dim, config.layer_norm_eps)
        self.ffn_norm = RMSNorm(config.embed_dim, config.layer_norm_eps)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run both sub-layers on ``x``.

        Args:
            x: Input hidden states.
            mask: Optional attention mask, forwarded unchanged to the attention layer.

        Returns:
            Hidden states of the same shape as ``x``.
        """
        # Attention sub-layer with residual.
        attn_out = self.attention(self.attn_norm(x), mask)
        x = x + attn_out

        # Feed-forward sub-layer with residual.
        ffn_out = self.feed_forward(self.ffn_norm(x))
        return x + ffn_out
|
|
|
|
class SmallLM(nn.Module):
    """
    Small Language Model (SLM): a GPT-style, decoder-only Transformer.

    Architecture:
        Token embedding (positions are injected by RoPE inside attention)
        -> N x TransformerBlock (attention + FFN)
        -> RMSNorm
        -> Output linear layer (next-token logits), weight-tied to the embedding
    """

    def __init__(self, config: "SLMConfig"):
        """Build all layers from ``config`` and initialise their weights."""
        super().__init__()
        self.config = config

        # Token embedding table.
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)

        # Stack of Transformer blocks.
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])

        # Final normalisation before the output projection.
        self.norm = RMSNorm(config.embed_dim, config.layer_norm_eps)

        # Output projection to vocabulary logits.
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)

        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02); zero any bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token IDs, shape (batch, seq_len)

        Returns:
            Logits, shape (batch, seq_len, vocab_size)
        """
        x = self.token_embedding(input_ids)

        # Causal mask: True strictly above the diagonal marks future positions.
        seq_len = input_ids.shape[1]
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=input_ids.device, dtype=torch.bool),
            diagonal=1
        )

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Return ``logits`` with tokens outside the top-k / nucleus set at -inf."""
        if top_k > 0:
            # Smallest logit that is still among the k largest, per row.
            kth_value = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, float('-inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Probability mass strictly before each token; a token is dropped
            # once the tokens ranked above it already cover top_p.
            mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
            remove = mass_before >= top_p
            # Always keep the best token so top_p <= 0 cannot empty the distribution.
            remove[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
            # Undo the sort: write each value back to its vocabulary position.
            logits = torch.full_like(logits, float('-inf')).scatter(
                1, sorted_indices, sorted_logits
            )

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 40,
        top_p: float = 0.9,
        eos_token_id: Optional[int] = 3,
    ) -> torch.Tensor:
        """
        Autoregressive text generation.

        Args:
            input_ids: Starting token IDs, shape (batch, seq_len)
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more deterministic);
                0 selects the most likely token at every step (greedy)
            top_k: Top-k sampling; values <= 0 disable it
            top_p: Nucleus (top-p) sampling; values >= 1.0 disable it
            eos_token_id: Token ID that ends a sequence, or None to always
                generate ``max_new_tokens`` tokens. Once a sequence has
                produced it, that sequence is padded with this ID until every
                sequence in the batch has finished.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError("temperature must be >= 0")

        # Generate in eval mode (dropout off) but restore the caller's mode
        # afterwards so a training loop is not silently left in eval mode.
        was_training = self.training
        self.eval()
        try:
            finished = torch.zeros(
                input_ids.shape[0], dtype=torch.bool, device=input_ids.device
            )

            for _ in range(max_new_tokens):
                # Keep only the most recent max_seq_len tokens as context.
                idx_cond = input_ids[:, -self.config.max_seq_len:]

                # Logits for the last position only.
                logits = self(idx_cond)[:, -1, :]

                if temperature == 0:
                    # Greedy decoding; avoids dividing by zero.
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = self._filter_logits(logits / temperature, top_k, top_p)
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)

                if eos_token_id is not None:
                    # Sequences that already ended keep emitting EOS.
                    next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
                    finished = finished | (next_token.squeeze(1) == eos_token_id)

                input_ids = torch.cat([input_ids, next_token], dim=1)

                if eos_token_id is not None and bool(finished.all()):
                    break
        finally:
            self.train(was_training)

        return input_ids

    def save_pretrained(self, directory: str):
        """Save model weights and config.

        Writes ``config.json`` and ``model.safetensors`` into ``directory``,
        creating the directory if needed.
        """
        os.makedirs(directory, exist_ok=True)

        self.config.save(os.path.join(directory, "config.json"))

        # Imported lazily so the module can be used without safetensors installed.
        from safetensors.torch import save_file

        # lm_head.weight is the same tensor as token_embedding.weight (tied in
        # __init__), so it is stored once and re-linked in from_pretrained.
        state_dict = {
            k: v for k, v in self.state_dict().items() if k != "lm_head.weight"
        }
        save_file(state_dict, os.path.join(directory, "model.safetensors"))

        print(f"💾 Model saved to: {directory}")

    @classmethod
    def from_pretrained(cls, directory: str, device: str = "cpu") -> "SmallLM":
        """Load model from directory.

        Args:
            directory: Folder containing ``config.json`` and ``model.safetensors``.
            device: Device the model is moved to after loading.

        Returns:
            The loaded model, in eval mode.
        """
        config = SLMConfig.load(os.path.join(directory, "config.json"))
        model = cls(config)

        from safetensors.torch import load_file
        state_dict = load_file(os.path.join(directory, "model.safetensors"))

        # Checkpoints written by save_pretrained omit the tied output weight.
        if "lm_head.weight" not in state_dict and "token_embedding.weight" in state_dict:
            state_dict["lm_head.weight"] = state_dict["token_embedding.weight"]
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        print(f"✅ Model loaded from: {directory}")
        print(f" Parameters: {model.count_parameters():,}")
        return model
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, and push one
    # random batch through it to confirm the output shape.
    config = SLMConfig()
    model = SmallLM(config)

    print(f"🧠 SmallLM Architecture")
    print(f" Parameters: {model.count_parameters():,}")
    print(f" Config: {config}")

    # Random token IDs: batch of 2 sequences, 32 tokens each.
    batch_shape = (2, 32)
    dummy_input = torch.randint(0, config.vocab_size, batch_shape)
    logits = model(dummy_input)
    print(f"\n Input shape: {dummy_input.shape}")
    print(f" Output shape: {logits.shape}")
    print(f" ✅ Forward pass works!")
|
|