| """ |
| Training Script - Small Language Model (SLM) |
| ============================================= |
| Author: Jekardah AI Lab |
| """ |
|
|
| import os |
| import re |
| import sys |
| import time |
| import math |
| import random |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
|
|
| from model import SmallLM, SLMConfig |
|
|
|
|
class TextDataset(Dataset):
    """Next-token-prediction windows over one flat stream of token ids.

    Sample ``i`` covers tokens ``[i * stride, i * stride + seq_len]``
    (``seq_len + 1`` tokens) and is returned as ``(inputs, targets)`` where
    ``targets`` is ``inputs`` shifted left by one position. With the default
    stride of ``seq_len // 2`` consecutive windows overlap by half.

    Args:
        all_ids: Flat sequence of integer token ids.
        seq_len: Length of each input (and target) tensor; must be >= 1.
        stride: Offset between consecutive window starts. Defaults to
            ``seq_len // 2`` (but at least 1); must be >= 1 if given.
    """

    def __init__(self, all_ids, seq_len=64, stride=None):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if stride is None:
            # seq_len == 1 would otherwise give a stride of 0.
            stride = max(1, seq_len // 2)
        elif stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.data = torch.tensor(all_ids, dtype=torch.long)
        self.seq_len = seq_len
        self.stride = stride
        # A start s is valid while s + seq_len + 1 <= len(all_ids), so there
        # are floor((len - seq_len - 1) / stride) + 1 windows. The old formula
        # lacked the "+ 1" and always dropped the last window.
        span = len(all_ids) - seq_len - 1
        self.n = span // stride + 1 if span >= 0 else 0

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for {self.n} samples")
        start = idx * self.stride
        chunk = self.data[start : start + self.seq_len + 1]
        return chunk[:-1], chunk[1:]
|
|
|
|
def _clean_kbbi_text(raw, min_len=20, min_alpha=0.4):
    """Split raw extracted KBBI text into cleaned candidate training lines."""
    raw = raw.replace("\f", " ")
    # Drop lines consisting only of digits (presumably page numbers).
    raw = re.sub(r'^\d+\s*$', '', raw, flags=re.MULTILINE)
    # Re-join words hyphenated across a line break.
    raw = re.sub(r'-\n\s*', '', raw)
    # A single newline is a soft wrap -> space; blank lines keep separating entries.
    raw = re.sub(r'(?<!\n)\n(?!\n)', ' ', raw)
    raw = re.sub(r' +', ' ', raw)

    lines = []
    for line in raw.split('\n'):
        line = line.strip()
        if len(line) < min_len:
            continue
        # Skip lines that are mostly digits / punctuation.
        alpha = sum(c.isalpha() for c in line) / max(len(line), 1)
        if alpha >= min_alpha:
            lines.append(line)
    return lines


def _load_general_texts(bpe_dir):
    """Return get_training_data() from bpe_dir's training_data module, or [] if unavailable."""
    if bpe_dir not in sys.path:
        sys.path.insert(0, bpe_dir)
    try:
        from training_data import get_training_data
    except ImportError:
        print(" General: training_data module not found, skipped")
        return []
    return list(get_training_data())


def load_and_tokenize(tokenizer, kbbi_path, max_texts=5000, eos_id=3, seed=42):
    """Collect training texts, shuffle and cap them, and encode into one token stream.

    Sources:
      * KBBI text: ``kbbi_raw.txt`` next to this script if it exists,
        otherwise ``kbbi_path``; cleaned by ``_clean_kbbi_text``.
      * Optional general texts from ``training_data.get_training_data()`` in
        the sibling ``bpe-tokenizer-id`` directory, each repeated 5 times.

    Args:
        tokenizer: Object exposing ``encode(str) -> list[int]``.
        kbbi_path: Fallback path of the raw KBBI text file.
        max_texts: Keep at most this many texts after shuffling
            (``None`` keeps all of them).
        eos_id: Id appended after every text as a separator (default 3,
            presumably the tokenizer's end-of-sequence id -- verify).
        seed: Seed of the private RNG used for shuffling. Unlike before, the
            global ``random`` state is not reseeded; the shuffle order is
            unchanged.

    Returns:
        list[int]: Concatenated token ids of all kept texts.
    """
    texts = []
    here = os.path.dirname(__file__)

    # A kbbi_raw.txt next to this script takes precedence over kbbi_path.
    local_kbbi = os.path.join(here, "kbbi_raw.txt")
    kbbi_file = local_kbbi if os.path.exists(local_kbbi) else kbbi_path

    if os.path.exists(kbbi_file):
        print(f"๐ Loading KBBI dari: {kbbi_file}")
        with open(kbbi_file, "r", encoding="utf-8", errors="ignore") as f:
            texts.extend(_clean_kbbi_text(f.read()))
        print(f" KBBI: {len(texts)} texts")
    else:
        print("โ ๏ธ KBBI tidak ditemukan! Jalankan: python extract_kbbi.py")

    bpe_dir = os.path.join(here, "..", "bpe-tokenizer-id")
    general = _load_general_texts(bpe_dir)
    if general:
        texts.extend(general * 5)
        print(f" General: {len(general)} (ร5)")

    random.Random(seed).shuffle(texts)
    if max_texts is not None and len(texts) > max_texts:
        texts = texts[:max_texts]
        print(f" โก Limited to {max_texts}")

    print("๐ Tokenizing...")
    all_ids = []
    for text in texts:
        all_ids.extend(tokenizer.encode(text))
        all_ids.append(eos_id)

    print(f" Tokens: {len(all_ids):,}")
    return all_ids
|
|
|
|
def train():
    """Train a SmallLM on KBBI (plus optional general) text and save the best epoch.

    Steps:
      1. Load the BPE tokenizer from ``../bpe-tokenizer-id/output``.
      2. Build one token stream with ``load_and_tokenize`` and cut it into
         overlapping next-token windows (``TextDataset``).
      3. Train with AdamW, gradient-norm clipping at 1.0 and a linear-warmup /
         cosine-decay learning rate for ``EPOCHS`` epochs or ``MAX_MIN``
         minutes of wall-clock time, whichever comes first.
      4. After every complete epoch, save model + tokenizer files to
         ``./output_slm`` when the mean epoch loss improved, then print a few
         sample generations.

    Raises:
        ValueError: if the corpus contains token ids outside the model's
            vocabulary, or is too small to form a single batch.
    """
    print("๐ Training SLM")
    print("=" * 60)

    # Hyperparameters.
    SEQ_LEN = 64
    BATCH = 16
    EPOCHS = 10
    LR = 1e-3
    MAX_MIN = 7   # wall-clock budget in minutes
    WARMUP = 30   # linear learning-rate warm-up steps

    config = SLMConfig(
        vocab_size=4000, embed_dim=128, num_heads=4,
        num_layers=2, ffn_dim=256, max_seq_len=SEQ_LEN, dropout=0.1,
    )
    print(f"โ๏ธ {config.embed_dim}d, {config.num_layers}L, {config.num_heads}H, vocab={config.vocab_size}")

    # The tokenizer lives in the sibling bpe-tokenizer-id project.
    bpe_dir = os.path.join(os.path.dirname(__file__), "..", "bpe-tokenizer-id")
    sys.path.insert(0, bpe_dir)
    from bpe_tokenizer import BPETokenizer
    tok_dir = os.path.join(bpe_dir, "output")
    tokenizer = BPETokenizer.from_pretrained(tok_dir)
    print(f" Vocab: {len(tokenizer.vocab):,}")

    # Data.
    kbbi_path = os.path.join(bpe_dir, "kbbi_raw.txt")
    all_ids = load_and_tokenize(tokenizer, kbbi_path, 5000)
    # Token ids become CrossEntropyLoss targets over config.vocab_size classes;
    # fail fast with a clear message instead of an opaque error mid-training.
    if all_ids and max(all_ids) >= config.vocab_size:
        raise ValueError(
            f"token id {max(all_ids)} is outside vocab_size={config.vocab_size}; "
            "tokenizer and SLMConfig disagree"
        )

    dataset = TextDataset(all_ids, SEQ_LEN)
    loader = DataLoader(dataset, batch_size=BATCH, shuffle=True, drop_last=True)
    print(f" Samples: {len(dataset):,} Batches: {len(loader):,}")
    if len(loader) == 0:
        # With drop_last=True a tiny corpus yields no batches at all; this used
        # to crash later with a NameError on the never-assigned loop index.
        raise ValueError(
            f"need at least {BATCH} samples for one batch, got {len(dataset)}"
        )

    # Model and optimisation.
    model = SmallLM(config)
    pc = model.count_parameters()
    print(f"\n๐ง Params: {pc:,} (~{pc*4/1024/1024:.1f} MB)")

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
    loss_fn = nn.CrossEntropyLoss()
    total_steps = len(loader) * EPOCHS

    def get_lr(step):
        # LambdaLR multiplier: linear warm-up to 1.0, then cosine decay to 0.
        if step < WARMUP:
            return step / WARMUP
        progress = (step - WARMUP) / max(total_steps - WARMUP, 1)
        return 0.5 * (1 + math.cos(math.pi * progress))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, get_lr)

    print(f"\n{'='*60}")
    print(f"๐๏ธ Go! ({total_steps} steps, max {MAX_MIN}min)")
    print(f"{'='*60}\n")

    best = float('inf')
    step = 0
    t0 = time.time()

    for ep in range(EPOCHS):
        model.train()
        ep_loss = 0.0
        n_batches = 0
        timed_out = False
        ep_t = time.time()

        for bi, (x, y) in enumerate(loader):
            mins = (time.time() - t0) / 60
            if mins > MAX_MIN:
                print(f"\n โฐ Time limit!")
                timed_out = True
                break

            logits = model(x)
            loss = loss_fn(logits.view(-1, config.vocab_size), y.view(-1))

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()

            loss_val = loss.item()
            ep_loss += loss_val
            n_batches += 1
            step += 1

            if step % 50 == 0 or bi == 0:
                ppl = math.exp(min(loss_val, 20))
                sps = (bi + 1) / max(time.time() - ep_t, 0.1)
                print(f" E{ep+1} S{step}/{total_steps} | "
                      f"L={loss_val:.3f} PPL={ppl:.0f} | "
                      f"{sps:.1f}st/s {mins:.1f}m")

        # Only an epoch actually interrupted by the time limit is discarded;
        # previously a fully finished epoch was also thrown away whenever the
        # clock passed the limit during its last batch.
        if timed_out:
            break

        avg = ep_loss / max(n_batches, 1)
        print(f"\n ๐ Epoch {ep+1}: L={avg:.3f} PPL={math.exp(min(avg,20)):.0f} "
              f"({time.time()-ep_t:.0f}s)")

        if avg < best:
            best = avg
            save_model(model, config, tok_dir)
            print(f" ๐พ Saved! (L={best:.3f})")

        print(f" ๐ Samples:")
        gen_samples(model, tokenizer)
        print()

    # No epoch completed within the budget: still save the current weights.
    if best == float('inf'):
        save_model(model, config, tok_dir)

    tt = time.time() - t0
    print("=" * 60)
    print(f"✅ Done! {tt:.0f}s ({tt/60:.1f}min)")
    print(f" Best: L={best:.3f} PPL={math.exp(min(best,20)):.0f}")

    # List what ended up in the output directory.
    out = "./output_slm"
    total = 0
    print(f"\n๐ฆ Files:")
    for name in sorted(os.listdir(out)):
        size = os.path.getsize(os.path.join(out, name))
        total += size
        print(f" {name:<30} {size:>10,} bytes")
    print(f" {'TOTAL':<30} {total:>10,} ({total/1024/1024:.1f} MB)")
|
|
|
|
def save_model(model, config, tok_dir, out="./output_slm"):
    """Save ``model`` to ``out`` and copy the tokenizer files next to it.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        config: Not used here; kept for call-site compatibility (presumably
            ``save_pretrained`` persists the model config itself -- verify).
        tok_dir: Directory containing the trained tokenizer files.
        out: Destination directory; defaults to ``./output_slm``.

    Returns:
        str: The output directory ``out``.
    """
    import shutil

    model.save_pretrained(out)
    tokenizer_files = (
        "vocab.json",
        "merges.txt",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "tokenizer.json",
    )
    # Best effort: copy whichever tokenizer files exist as regular files.
    for name in tokenizer_files:
        src = os.path.join(tok_dir, name)
        if os.path.isfile(src):
            shutil.copy2(src, os.path.join(out, name))
    return out
|
|
|
|
@torch.no_grad()
def gen_samples(model, tokenizer, n=3, prompts=None, rng=None):
    """Print a short generation for each of ``n`` randomly chosen prompts.

    Args:
        model: Model exposing ``generate(ids, max_new_tokens, temperature,
            top_k)`` plus ``eval()`` / ``train()``.
        tokenizer: Object with ``encode(str)`` and ``decode(list[int])``.
        n: Number of prompts to sample (capped at the number available).
        prompts: Candidate prompt strings; defaults to a fixed Indonesian list.
        rng: Object with a ``sample`` method; defaults to the ``random`` module.

    The model's previous train/eval mode is restored on exit. The old code
    always switched the model into train mode, even if it had been in eval
    mode before the call.
    """
    if prompts is None:
        prompts = ["indonesia", "pendidikan", "makan", "jakarta",
                   "teknologi", "kebudayaan", "ekonomi"]
    if rng is None:
        rng = random

    was_training = getattr(model, "training", True)
    model.eval()
    try:
        for p in rng.sample(prompts, min(n, len(prompts))):
            ids = tokenizer.encode(p)
            inp = torch.tensor([ids], dtype=torch.long)
            out = model.generate(inp, max_new_tokens=20, temperature=0.9, top_k=30)
            text = tokenizer.decode(out[0].tolist())
            print(f" \"{p}\" โ \"{text[:70]}\"")
    finally:
        if was_training:
            model.train()
|
|
|
|
# Script entry point: the training pipeline runs only when this file is
# executed directly, never as a side effect of importing it.
if "__main__" == __name__:
    train()
|
|