"""Model loading and text generation engine.""" import torch from model import GPT, GPTConfig from config import _CHECKPOINT_FILE def load_model() -> tuple[GPT, dict[str, int], dict[int, str]]: """Load the trained model and tokenizer.""" torch.serialization.add_safe_globals([GPTConfig]) checkpoint = torch.load(_CHECKPOINT_FILE, map_location="cpu", weights_only=True) config = checkpoint["config"] stoi = checkpoint["stoi"] itos = checkpoint["itos"] model = GPT(config) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() return model, stoi, itos def generate( model: GPT, prompt: str, stoi: dict[str, int], itos: dict[int, str], *, max_new_tokens: int = 200, temperature: float = 0.8, device: str = "cpu", seed: int | None = None, ) -> str: """Generate text from prompt.""" if seed is not None: torch.manual_seed(seed) model = model.to(device) prompt_tokens = [stoi.get(c, stoi.get(" ", 0)) for c in prompt] idx = torch.tensor([prompt_tokens], dtype=torch.long, device=device) with torch.no_grad(): for _ in range(max_new_tokens): idx_cond = idx[:, -model.config.block_size :] logits, _ = model(idx_cond) logits = logits[:, -1, :] / temperature probs = torch.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) idx = torch.cat((idx, idx_next), dim=1) tokens = idx[0].tolist() return "".join([itos.get(t, "") for t in tokens]) # --------------------------------------------------------------------------- # Eager-load model on first import # --------------------------------------------------------------------------- print("Loading model...") model, stoi, itos = load_model() print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")