import math

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import GPT2Tokenizer

# ── Model Architecture (must match training exactly) ─────────

# Number of positions the RoPE tables and causal mask are built for; the
# generation loop below also crops its input window to this length.
CONTEXT_LENGTH = 512


class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) in the "rotate half" layout.

    Cos/sin tables for positions ``0 .. max_seq_len - 1`` are precomputed once
    and registered as buffers so they move with the module across devices.
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        self.register_buffer("cos_table", freqs.cos())
        self.register_buffer("sin_table", freqs.sin())

    @staticmethod
    def _rotate_half(x):
        """Map the two halves (a, b) of the last dimension to (-b, a)."""
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

    def forward(self, x):
        """Apply position-dependent rotation to ``x``.

        ``x`` is 4-D with the sequence on dim 2; MultiHeadSelfAttention passes
        (batch, heads, seq, head_dim).
        """
        T = x.shape[2]
        # Each frequency is reused for both halves, matching _rotate_half.
        cos = torch.cat([self.cos_table[:T], self.cos_table[:T]], dim=-1)
        sin = torch.cat([self.sin_table[:T], self.sin_table[:T]], dim=-1)
        return x * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with RoPE applied to queries and keys."""

    def __init__(self, d_model, num_heads, context_length, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, context_length)
        self.dropout = nn.Dropout(dropout)
        # Additive causal mask: 0 on/below the diagonal, -inf above it.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        causal = torch.zeros(context_length, context_length)
        causal.masked_fill_(mask, float("-inf"))
        self.register_buffer("causal_mask", causal.unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, d_model), seq <= context_length."""
        B, T, C = x.shape
        # Split channels into heads -> (batch, heads, seq, head_dim).
        Q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        Q, K = self.rope(Q), self.rope(K)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores + self.causal_mask[:, :, :T, :T]
        w = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(w, V).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class SwiGLUFFN(nn.Module):
    """Feed-forward layer: linear_out(silu(gate(x)) * value(x)), then dropout."""

    def __init__(self, d_model, ffn_hidden_dim, dropout):
        super().__init__()
        self.linear_gate = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_value = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_out = nn.Linear(ffn_hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(
            self.linear_out(F.silu(self.linear_gate(x)) * self.linear_value(x))
        )


class TransformerBlock(nn.Module):
    """LayerNorm -> attention -> LayerNorm -> SwiGLU FFN, with residual add."""

    def __init__(self, d_model, num_heads, ffn_hidden_dim, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, context_length, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, ffn_hidden_dim, dropout)

    def forward(self, x):
        # NOTE(review): a standard pre-LN block returns h + ffn(ln2(h)) with
        # h = x + attn(ln1(x)). This expression returns x + ffn(ln2(h)), so the
        # attention output only reaches the result through the FFN. Kept verbatim
        # because it must match the trained checkpoint; confirm against the
        # training code before changing it.
        return x + self.ffn(self.ln2(x + self.attn(self.ln1(x))))


class GPTModel(nn.Module):
    """12-layer, 768-wide decoder over the 50257-token GPT-2 vocabulary."""

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(50257, 768)
        self.blocks = nn.ModuleList([
            TransformerBlock(768, 12, 3072, CONTEXT_LENGTH, 0.1) for _ in range(12)
        ])
        self.ln_final = nn.LayerNorm(768)
        self.lm_head = nn.Linear(768, 50257, bias=False)
        # Weight tying: output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):
        """Map token ids (batch, seq) to next-token logits (batch, seq, vocab)."""
        h = self.token_embedding(x)
        for block in self.blocks:
            h = block(h)
        return self.lm_head(self.ln_final(h))


# ── Load model and tokenizer ─────────────────────────────────
DEVICE = torch.device("cpu")  # Spaces free tier uses CPU

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="Nj-1111/gpt-152m-fineweb",  # ← your HF username/repo
    filename="pytorch_model.pt",
)
model = GPTModel().to(DEVICE)
# torch.load unpickles the downloaded file. weights_only=True (the default since
# PyTorch 2.6) restricts it to tensors/plain containers, so a tampered
# checkpoint cannot execute arbitrary code.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
del ckpt  # release whatever else the checkpoint holds; only the weights are needed
model.eval()
print("Model loaded ✓")


# ── Generation function ──────────────────────────────────────
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Penalize, in place, every distinct id in ``token_ids``; return ``logits``.

    ``logits`` is (1, vocab). Positive scores are divided by ``penalty`` and
    non-positive scores multiplied by it, so a penalty > 1 never raises a score.
    Each id is penalized once regardless of how often it occurs.
    """
    seen = torch.unique(token_ids)
    scores = logits[0, seen]
    logits[0, seen] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits


def _top_k_filter(logits, top_k):
    """Set logits below the k-th largest to -inf (ties with it are kept)."""
    k = max(1, min(int(top_k), logits.size(-1)))
    kth_largest = torch.topk(logits, k).values[:, -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))


def generate_text(prompt, max_new_tokens, temperature, top_k, repetition_penalty):
    """Sample a continuation of ``prompt`` and return prompt + continuation.

    Args:
        prompt: Text to continue. Empty/blank input returns a hint message.
        max_new_tokens: Upper bound on sampled tokens; stops early at EOS.
        temperature: Logit divisor, clamped to at least 1e-8.
        top_k: Number of highest-scoring tokens kept before sampling.
        repetition_penalty: Factor applied to tokens already in the sequence
            (prompt included).

    Returns:
        The decoded text with special tokens removed.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    generated = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            # The model only has RoPE/mask tables for CONTEXT_LENGTH positions.
            window = generated[:, -CONTEXT_LENGTH:]
            logits = model(window)[:, -1, :].float()
            logits = _apply_repetition_penalty(logits, generated[0], repetition_penalty)
            logits = logits / max(temperature, 1e-8)
            logits = _top_k_filter(logits, top_k)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# ── Gradio Interface ─────────────────────────────────────────
examples = [
    ["Quantum mechanics is the branch of physics that", 150, 0.8, 50, 1.3],
    ["The French Revolution began in 1789 because", 150, 0.8, 40, 1.3],
    ["DNA carries genetic information by", 150, 0.8, 50, 1.3],
    ["The solar system consists of eight planets", 150, 0.8, 40, 1.3],
    ["In mathematics, a prime number is", 150, 0.7, 30, 1.3],
    ["Climate change affects the environment by", 150, 0.8, 50, 1.3],
]

with gr.Blocks(title="GPT-152M Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧠 GPT-152M — Trained From Scratch

A 152 million parameter language model built with raw PyTorch and trained on
197M tokens of educational text (FineWeb-Edu). No pretrained weights were used.

**Best results:** Use textbook-style prompts, not search queries.
""")

    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="e.g. Quantum mechanics is the branch of physics that",
                lines=3,
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            output_box = gr.Textbox(label="Generated Text", lines=8, interactive=False)

        with gr.Column(scale=1):
            max_tokens = gr.Slider(50, 300, value=150, step=10, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature (higher = more creative)")
            top_k = gr.Slider(10, 100, value=50, step=5, label="Top-k (lower = more focused)")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="Repetition penalty")

    gr.Examples(
        examples=examples,
        inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
        outputs=output_box,
        fn=generate_text,
        cache_examples=True,
        label="Example prompts — click any to try",
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
        outputs=output_box,
    )

    gr.Markdown("""
---

**Model:** GPT-152M | **Dataset:** FineWeb-Edu (197M tokens) | **Hardware:** Free Kaggle T4 GPU (~8.5 hours) | **Framework:** PyTorch 2.9

⚠️ This model was trained for educational purposes. Outputs may be factually incorrect.
""")

demo.launch()