"""Text generation from trained GPT checkpoint."""

import argparse

import torch

from model import GPT


@torch.no_grad()
def generate(
    model: GPT,
    prompt: str,
    stoi: dict[str, int],
    itos: dict[int, str],
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
) -> str:
    """Generate text from a prompt using the trained model.

    The prompt is encoded character by character with ``stoi``; characters
    that are not in the vocabulary are silently dropped. New tokens are then
    produced autoregressively, feeding the model at most
    ``model.config.block_size`` trailing tokens per step.

    Args:
        model: Trained GPT. It is put in eval mode for the duration of the
            call; its previous train/eval mode is restored afterwards.
        prompt: Starting text.
        stoi: Character -> token id mapping.
        itos: Token id -> character mapping.
        max_new_tokens: Number of tokens to append after the prompt.
        temperature: Divisor applied to the logits before sampling. ``0``
            selects greedy (argmax) decoding instead of sampling.
        top_k: If positive, sample only among the ``top_k`` highest-scoring
            tokens (clamped to the vocabulary size).

    Returns:
        The decoded prompt (in-vocabulary characters only) followed by the
        generated continuation.

    Raises:
        ValueError: If ``temperature`` is negative, or if no character of
            ``prompt`` is in ``stoi`` (there would be no last position to
            read logits from).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    tokens = [stoi[c] for c in prompt if c in stoi]
    if not tokens:
        raise ValueError(
            "prompt contains no characters from the model vocabulary"
        )

    device = next(model.parameters()).device
    idx = torch.tensor([tokens], dtype=torch.long, device=device)
    block_size = model.config.block_size

    # Remember the caller's mode so generating (e.g. mid-training) does not
    # leave dropout and friends permanently disabled.
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Crop the context to the longest sequence the model accepts.
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            logits = logits[:, -1, :]  # logits for the last position only

            if temperature == 0:
                # Greedy decoding; dividing by zero would yield inf/nan probs.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k > 0:
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k).values[:, -1:]
                    logits = logits.masked_fill(logits < kth_best, float("-inf"))
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            idx = torch.cat([idx, next_token], dim=1)
    finally:
        model.train(was_training)

    return "".join(itos[i] for i in idx[0].tolist())


def _parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Generate text from a trained GPT checkpoint"
    )
    parser.add_argument(
        "checkpoint", help="Path to checkpoint file (e.g. checkpoint_final.pt)"
    )
    parser.add_argument(
        "--prompt", default="கடவுள் வாழ்த்து", help="Starting text for generation"
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=200, help="Number of tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature (lower = more deterministic, 0 = greedy)",
    )
    parser.add_argument(
        "--top_k", type=int, default=40, help="Only sample from top-k most likely tokens"
    )
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device to load the model onto and generate with (e.g. cpu, cuda)",
    )
    return parser.parse_args()


def _load_checkpoint(
    path: str, device: str
) -> tuple[GPT, dict[str, int], dict[int, str]]:
    """Rebuild the model and its vocabulary maps from a saved checkpoint."""
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # which can execute code. Presumably required because the checkpoint holds
    # a non-tensor config object. Only load checkpoints from trusted sources.
    # map_location lets checkpoints saved on GPU load on CPU-only machines.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model = GPT(checkpoint["config"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    return model, checkpoint["stoi"], checkpoint["itos"]


def main() -> None:
    """CLI entry point: load a checkpoint, generate text, and print it."""
    args = _parse_args()

    # Seed before building the model, as before, so seeded runs reproduce.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    model, stoi, itos = _load_checkpoint(args.checkpoint, args.device)
    output = generate(
        model,
        args.prompt,
        stoi,
        itos,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )
    print(output)


if __name__ == "__main__":
    main()