valluvar-or-ai / train.py
NaveenKumar Namachivayam
feat: add Thirukkural Tamil text dataset with English translations
b817849
raw
history blame
5.43 kB
"""Training script for Thirukkural GPT model."""
import json
import math
import sys
import torch
from tqdm import tqdm
from generate import generate
from model import GPT, GPTConfig
def get_device() -> torch.device:
"""Get the best available device for training."""
if torch.backends.mps.is_available():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
def load_data(
filepath: str, block_size: int, batch_size: int, device: torch.device
) -> tuple:
"""Load and tokenize the dataset."""
with open(filepath, "r", encoding="utf-8") as f:
text = f.read()
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
tokens = torch.tensor([stoi[c] for c in text], dtype=torch.long)
print(f"Dataset: {len(tokens):,} chars, vocab size: {vocab_size}")
def get_batch(split_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
ix = torch.randint(len(split_tokens) - block_size - 1, (batch_size,))
x = torch.stack([split_tokens[i : i + block_size] for i in ix]).to(device)
y = torch.stack([split_tokens[i + 1 : i + block_size + 1] for i in ix]).to(device)
return x, y
n = int(0.9 * len(tokens))
get_train = lambda: get_batch(tokens[:n])
get_val = lambda: get_batch(tokens[n:])
return get_train, get_val, vocab_size, stoi, itos
def get_lr(
step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float
) -> float:
"""Calculate learning rate with warmup and cosine decay."""
if step < warmup_steps:
return max_lr * (step + 1) / warmup_steps
if step >= max_steps:
return min_lr
progress = (step - warmup_steps) / (max_steps - warmup_steps)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
def train(
data_path: str,
max_steps: int = 10000,
batch_size: int = 64,
n_layer: int = 8,
n_head: int = 8,
n_embd: int = 512,
block_size: int = 256,
) -> tuple[GPT, dict[str, int], dict[int, str]]:
"""Train a GPT model on the given dataset."""
device = get_device()
print(f"Using device: {device}")
get_train_batch, get_val_batch, vocab_size, stoi, itos = load_data(
data_path, block_size, batch_size, device
)
config = GPTConfig(
vocab_size=vocab_size,
block_size=block_size,
n_layer=n_layer,
n_head=n_head,
n_embd=n_embd,
)
model = GPT(config).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model: {n_layer}L/{n_head}H/{n_embd}D, {n_params / 1e6:.1f}M params")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
max_lr = 1e-3
min_lr = max_lr * 0.1
warmup_steps = 100
loss_log: dict = {"steps": [], "train": [], "val": []}
pbar = tqdm(range(max_steps), desc="Training")
for step in pbar:
# Validation loss
if step % 100 == 0:
model.eval()
with torch.no_grad():
val_losses = []
for _ in range(20):
x, y = get_val_batch()
_, loss = model(x, y)
val_losses.append(loss.item())
val_loss = sum(val_losses) / len(val_losses)
tqdm.write(f"Step {step:5d} | val loss: {val_loss:.4f}")
model.train()
# Update learning rate
lr = get_lr(step, warmup_steps, max_steps, max_lr, min_lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
# Training step
x, y = get_train_batch()
_, loss = model(x, y)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")
# Log loss
loss_log["steps"].append(step)
loss_log["train"].append(loss.item())
if step % 100 == 0:
loss_log["val"].append(val_loss)
# Generate sample
if step > 0 and step % 100 == 0:
model.eval()
sample = generate(
model, "கடவுள் வாழ்த்து", stoi, itos, max_new_tokens=150, temperature=0.8
)
tqdm.write(f"\n--- Step {step} sample ---\n{sample}\n---\n")
model.train()
# Save checkpoint
if step > 0 and step % 1000 == 0:
torch.save(
{
"step": step,
"model_state_dict": model.state_dict(),
"config": config,
"stoi": stoi,
"itos": itos,
},
f"checkpoint_{step}.pt",
)
# Save final checkpoint
torch.save(
{
"step": max_steps,
"model_state_dict": model.state_dict(),
"config": config,
"stoi": stoi,
"itos": itos,
},
"checkpoint_final.pt",
)
with open("loss_log.json", "w") as f:
json.dump(loss_log, f)
return model, stoi, itos
if __name__ == "__main__":
data_path = sys.argv[1] if len(sys.argv) > 1 else "../data/thirukkural.txt"
train(data_path)