"""Training script for Thirukkural GPT model.""" import json import math import sys import torch from tqdm import tqdm from generate import generate from model import GPT, GPTConfig def get_device() -> torch.device: """Get the best available device for training.""" if torch.backends.mps.is_available(): return torch.device("mps") elif torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def load_data( filepath: str, block_size: int, batch_size: int, device: torch.device ) -> tuple: """Load and tokenize the dataset.""" with open(filepath, "r", encoding="utf-8") as f: text = f.read() chars = sorted(set(text)) vocab_size = len(chars) stoi = {c: i for i, c in enumerate(chars)} itos = {i: c for c, i in stoi.items()} tokens = torch.tensor([stoi[c] for c in text], dtype=torch.long) print(f"Dataset: {len(tokens):,} chars, vocab size: {vocab_size}") def get_batch(split_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ix = torch.randint(len(split_tokens) - block_size - 1, (batch_size,)) x = torch.stack([split_tokens[i : i + block_size] for i in ix]).to(device) y = torch.stack([split_tokens[i + 1 : i + block_size + 1] for i in ix]).to(device) return x, y n = int(0.9 * len(tokens)) get_train = lambda: get_batch(tokens[:n]) get_val = lambda: get_batch(tokens[n:]) return get_train, get_val, vocab_size, stoi, itos def get_lr( step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float ) -> float: """Calculate learning rate with warmup and cosine decay.""" if step < warmup_steps: return max_lr * (step + 1) / warmup_steps if step >= max_steps: return min_lr progress = (step - warmup_steps) / (max_steps - warmup_steps) return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) def train( data_path: str, max_steps: int = 10000, batch_size: int = 64, n_layer: int = 8, n_head: int = 8, n_embd: int = 512, block_size: int = 256, ) -> tuple[GPT, dict[str, int], dict[int, str]]: """Train a GPT model on the given dataset.""" device = get_device() print(f"Using device: {device}") get_train_batch, get_val_batch, vocab_size, stoi, itos = load_data( data_path, block_size, batch_size, device ) config = GPTConfig( vocab_size=vocab_size, block_size=block_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd, ) model = GPT(config).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"Model: {n_layer}L/{n_head}H/{n_embd}D, {n_params / 1e6:.1f}M params") optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) max_lr = 1e-3 min_lr = max_lr * 0.1 warmup_steps = 100 loss_log: dict = {"steps": [], "train": [], "val": []} pbar = tqdm(range(max_steps), desc="Training") for step in pbar: # Validation loss if step % 100 == 0: model.eval() with torch.no_grad(): val_losses = [] for _ in range(20): x, y = get_val_batch() _, loss = model(x, y) val_losses.append(loss.item()) val_loss = sum(val_losses) / len(val_losses) tqdm.write(f"Step {step:5d} | val loss: {val_loss:.4f}") model.train() # Update learning rate lr = get_lr(step, warmup_steps, max_steps, max_lr, min_lr) for param_group in optimizer.param_groups: param_group["lr"] = lr # Training step x, y = get_train_batch() _, loss = model(x, y) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}") # Log loss loss_log["steps"].append(step) loss_log["train"].append(loss.item()) if step % 100 == 0: loss_log["val"].append(val_loss) # Generate sample if step > 0 and step % 100 == 0: model.eval() sample = generate( model, "கடவுள் வாழ்த்து", stoi, itos, max_new_tokens=150, temperature=0.8 ) tqdm.write(f"\n--- Step {step} sample ---\n{sample}\n---\n") model.train() # Save checkpoint if step > 0 and step % 1000 == 0: torch.save( { "step": step, "model_state_dict": model.state_dict(), "config": config, "stoi": stoi, "itos": itos, }, f"checkpoint_{step}.pt", ) # Save final checkpoint torch.save( { "step": max_steps, "model_state_dict": model.state_dict(), "config": config, "stoi": stoi, "itos": itos, }, "checkpoint_final.pt", ) with open("loss_log.json", "w") as f: json.dump(loss_log, f) return model, stoi, itos if __name__ == "__main__": data_path = sys.argv[1] if len(sys.argv) > 1 else "../data/thirukkural.txt" train(data_path)