gpt-152m-demo / app.py
Nj-1111's picture
update
b7e7d5a verified
Raw
History Blame Contribute Delete
9.06 kB
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download
import math
# ── Model Architecture (must match training exactly) ─────────
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables.

    Frequencies follow ``1 / 10000 ** (2i / head_dim)``. The tables hold one
    angle per (position, frequency) pair, shape ``(max_seq_len, head_dim // 2)``
    for an even ``head_dim``. They are registered as buffers, so their names
    are state_dict keys and must match the checkpoint.
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (10000.0 ** exponents)
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        self.register_buffer("cos_table", angles.cos())
        self.register_buffer("sin_table", angles.sin())

    @staticmethod
    def _rotate_half(x):
        """Map the last-dim halves ``(a, b)`` to ``(-b, a)``."""
        half = x.size(-1) // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat([-second, first], dim=-1)

    def forward(self, x):
        """Rotate ``x`` by its position along dim 2.

        Assumes ``x`` is (batch, heads, seq, head_dim), as produced by
        MultiHeadSelfAttention, with ``seq <= max_seq_len``.
        """
        seq_len = x.size(2)
        # Each table half is tiled twice so it lines up with both halves of x.
        cos = self.cos_table[:seq_len].repeat(1, 2)[None, None]
        sin = self.sin_table[:seq_len].repeat(1, 2)[None, None]
        return x * cos + self._rotate_half(x) * sin
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Queries and keys are rotated by RoPE before the scaled dot product. A
    precomputed additive mask (0 on/below the diagonal, -inf above) blocks
    attention to later positions. Submodule and buffer names are state_dict
    keys and must match the checkpoint.
    """

    def __init__(self, d_model, num_heads, context_length, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, context_length)
        self.dropout = nn.Dropout(dropout)
        future = torch.ones(context_length, context_length).triu(diagonal=1).bool()
        additive = torch.zeros(context_length, context_length).masked_fill(future, float("-inf"))
        # Leading singleton dims let the mask broadcast over (batch, heads).
        self.register_buffer("causal_mask", additive[None, None])

    def _split_heads(self, t, batch, seq_len):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, d_model), with seq <= context_length."""
        batch, seq_len, width = x.shape
        q = self.rope(self._split_heads(self.q_proj(x), batch, seq_len))
        k = self.rope(self._split_heads(self.k_proj(x), batch, seq_len))
        v = self._split_heads(self.v_proj(x), batch, seq_len)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        scores = scores + self.causal_mask[:, :, :seq_len, :seq_len]
        weights = self.dropout(scores.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.out_proj(merged)
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: ``W_out(silu(W_gate x) * W_value x)`` followed by dropout.

    All projections are bias-free. Attribute names are state_dict keys and
    must match the checkpoint.
    """

    def __init__(self, d_model, ffn_hidden_dim, dropout):
        super().__init__()
        self.linear_gate = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_value = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_out = nn.Linear(ffn_hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.linear_gate(x))
        hidden = gate * self.linear_value(x)
        projected = self.linear_out(hidden)
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """Transformer block: LayerNorm -> causal self-attention, LayerNorm -> SwiGLU FFN.

    Submodule names are state_dict keys and must match the checkpoint.
    """

    def __init__(self, d_model, num_heads, ffn_hidden_dim, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, context_length, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, ffn_hidden_dim, dropout)

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        ffn_in = self.ln2(x + attn_out)
        # NOTE(review): the residual added back is the block INPUT `x`, not
        # `x + attn_out`. The attention output therefore reaches the next block
        # only through the FFN. A standard pre-LN block would be
        # `h = x + attn(ln1(x)); return h + ffn(ln2(h))`. Kept as-is because
        # the weights must match training exactly; confirm against the
        # training code before changing it.
        return x + self.ffn(ffn_in)
class GPTModel(nn.Module):
    """Decoder-only LM: token embedding -> transformer blocks -> LayerNorm -> tied LM head.

    The keyword-only arguments default to the architecture this app's
    checkpoint was trained with, so ``GPTModel()`` builds exactly that model.
    Other values allow differently sized variants without editing the class.
    Submodule names are state_dict keys and are unchanged by any choice of
    sizes.

    Args:
        vocab_size: Number of token ids (embedding rows / logits per position).
        d_model: Hidden width.
        num_heads: Attention heads per block; must divide ``d_model``.
        ffn_hidden_dim: Inner width of each SwiGLU FFN.
        context_length: Maximum sequence length. The RoPE tables and causal
            mask are sized to it, so longer inputs are not supported.
        dropout: Dropout probability used inside every block.
        num_layers: Number of transformer blocks.
    """

    def __init__(self, *, vocab_size=50257, d_model=768, num_heads=12,
                 ffn_hidden_dim=3072, context_length=512, dropout=0.1,
                 num_layers=12):
        super().__init__()
        # Plain attribute (not a parameter/buffer): lets callers crop inputs.
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, ffn_hidden_dim, context_length, dropout)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):
        """Map token ids (batch, seq) to logits (batch, seq, vocab_size)."""
        h = self.token_embedding(x)
        for block in self.blocks:
            h = block(h)
        return self.lm_head(self.ln_final(h))
# ── Load model and tokenizer ─────────────────────────────────
DEVICE = torch.device("cpu")  # Spaces free tier uses CPU

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="Nj-1111/gpt-152m-fineweb",  # ← your HF username/repo
    filename="pytorch_model.pt",
)

model = GPTModel().to(DEVICE)
# NOTE(review): torch.load unpickles the downloaded file. Before PyTorch 2.6
# the default is weights_only=False, which lets a tampered checkpoint run
# arbitrary code. Consider passing weights_only=True once the checkpoint is
# confirmed to hold only tensors and plain containers.
ckpt = torch.load(model_path, map_location=DEVICE)
# Accept either a training checkpoint ({"model_state_dict": ...}) or a bare
# state_dict saved directly from model.state_dict().
state_dict = ckpt.get("model_state_dict", ckpt)
model.load_state_dict(state_dict)
model.eval()  # disables dropout for inference
print("Model loaded ✓")
# ── Generation function ──────────────────────────────────────
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return ``logits`` with every already-seen token penalized.

    Positive logits are divided by ``penalty`` and non-positive ones are
    multiplied by it, so ``penalty > 1`` always makes a seen token less likely.
    ``logits`` is (1, vocab); ``token_ids`` is a 1-D tensor of ids.
    """
    seen = torch.unique(token_ids).unsqueeze(0)  # (1, n_seen)
    scores = torch.gather(logits, 1, seen)
    scores = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits.scatter(1, seen, scores)


def _apply_top_k(logits, top_k):
    """Return ``logits`` with everything below the k-th largest value set to -1e9.

    ``top_k`` is clamped to the vocabulary size. A value <= 0 disables the
    filter instead of failing on an empty ``topk`` result.
    """
    k = min(int(top_k), logits.size(-1))
    if k <= 0:
        return logits
    kth_largest = torch.topk(logits, k).values[:, -1:]
    return logits.masked_fill(logits < kth_largest, -1e9)


def generate_text(prompt, max_new_tokens, temperature, top_k, repetition_penalty):
    """Sample a continuation of ``prompt`` from the loaded model.

    Each step applies: repetition penalty -> temperature scaling -> top-k
    filtering -> multinomial sampling. Generation stops after
    ``max_new_tokens`` tokens or when the EOS token is sampled.

    Args:
        prompt: Input text. Empty, whitespace-only, or None yields a hint message.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Logit divisor (floored at 1e-8 to avoid division by zero).
        top_k: Number of highest-scoring tokens kept; <= 0 keeps all.
        repetition_penalty: Penalty for tokens already in the sequence (1.0 = off).

    Returns:
        The decoded prompt plus continuation with special tokens removed, or
        "Please enter a prompt." when the prompt is blank.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    # Matches the 512-token context the model's RoPE tables and causal mask
    # are built for; longer histories are cropped to their most recent tokens.
    context_window = 512
    generated = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            window = generated[:, -context_window:]
            logits = model(window)[:, -1, :].float()
            logits = _apply_repetition_penalty(logits, generated[0], repetition_penalty)
            logits = logits / max(temperature, 1e-8)
            logits = _apply_top_k(logits, top_k)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# ── Gradio Interface ─────────────────────────────────────────
# Example rows for gr.Examples, in the same order as the UI inputs:
# [prompt, max_new_tokens, temperature, top_k, repetition_penalty].
# Only prompt, temperature, and top_k vary; the other two are shared.
examples = [
    [prompt, 150, temperature, top_k, 1.3]
    for prompt, temperature, top_k in (
        ("Quantum mechanics is the branch of physics that", 0.8, 50),
        ("The French Revolution began in 1789 because", 0.8, 40),
        ("DNA carries genetic information by", 0.8, 50),
        ("The solar system consists of eight planets", 0.8, 40),
        ("In mathematics, a prime number is", 0.7, 30),
        ("Climate change affects the environment by", 0.8, 50),
    )
]
with gr.Blocks(title="GPT-152M Demo", theme=gr.themes.Soft()) as demo:
    # Page header and description.
    gr.Markdown("""
# 🧠 GPT-152M — Trained From Scratch
A 152 million parameter language model built with raw PyTorch and trained on
197M tokens of educational text (FineWeb-Edu). No pretrained weights were used.
**Best results:** Use textbook-style prompts, not search queries.
""")

    with gr.Row():
        # Left column: prompt, trigger button, and generated output.
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="e.g. Quantum mechanics is the branch of physics that",
                lines=3,
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            output_box = gr.Textbox(label="Generated Text", lines=8, interactive=False)
        # Right column: sampling controls, one per generate_text parameter.
        with gr.Column(scale=1):
            max_tokens = gr.Slider(50, 300, value=150, step=10,
                                   label="Max new tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05,
                                    label="Temperature (higher = more creative)")
            top_k = gr.Slider(10, 100, value=50, step=5,
                              label="Top-k (lower = more focused)")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.3, step=0.05,
                                    label="Repetition penalty")

    # The inputs order must match both the columns of `examples` and the
    # parameter order of generate_text. With cache_examples=True, Gradio
    # precomputes the example outputs, so each example shows one stored
    # random sample rather than a fresh generation.
    gr.Examples(
        examples=examples,
        inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
        outputs=output_box,
        fn=generate_text,
        cache_examples=True,
        label="Example prompts — click any to try",
    )
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
        outputs=output_box,
    )

    # Footer: provenance and disclaimer.
    gr.Markdown("""
---
**Model:** GPT-152M | **Dataset:** FineWeb-Edu (197M tokens) |
**Hardware:** Free Kaggle T4 GPU (~8.5 hours) | **Framework:** PyTorch 2.9
⚠️ This model was trained for educational purposes.
Outputs may be factually incorrect.
""")

demo.launch()