Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import GPT2Tokenizer | |
| from huggingface_hub import hf_hub_download | |
| import math | |
# ── Model Architecture (must match training exactly) ─────────
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) for per-head query/key tensors.

    The cos/sin angle tables are registered as buffers, so they are part of
    the state dict under ``cos_table`` / ``sin_table``.

    Args:
        head_dim: Feature size of one attention head. Features are rotated as
            two halves, so this is expected to be even.
        max_seq_len: Number of positions precomputed in the tables.
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        # One frequency per feature pair: 10000 ** (-2i / head_dim).
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # (max_seq_len, head_dim // 2) table of position * frequency angles.
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        self.register_buffer("cos_table", freqs.cos())
        self.register_buffer("sin_table", freqs.sin())

    @staticmethod
    def _rotate_half(x):
        """Return ``[-x2, x1]``, where ``x1``/``x2`` are the halves of the last dim.

        Declared static because ``forward`` calls it through ``self`` while it
        takes only the tensor; without the decorator that call raised TypeError.
        """
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

    def forward(self, x):
        """Rotate ``x`` by position-dependent angles.

        Args:
            x: Tensor whose dim 2 is the sequence axis and whose last dim is
                ``head_dim``. The attention layer in this file passes
                ``(batch, heads, seq, head_dim)``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        seq_len = x.shape[2]
        if seq_len > self.cos_table.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds RoPE table size {self.cos_table.size(0)}"
            )
        # Tables hold head_dim // 2 columns; duplicate them to cover both halves.
        cos = torch.cat([self.cos_table[:seq_len], self.cos_table[:seq_len]], dim=-1)
        sin = torch.cat([self.sin_table[:seq_len], self.sin_table[:seq_len]], dim=-1)
        # Broadcast the (seq, head_dim) tables over the two leading dims.
        return x * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings on queries and keys.

    The attribute and buffer names (``q_proj``, ``k_proj``, ``v_proj``,
    ``out_proj``, ``rope``, ``causal_mask``) are state-dict keys and must not
    be renamed, or checkpoint loading breaks.
    """

    def __init__(self, d_model, num_heads, context_length, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, context_length)
        self.dropout = nn.Dropout(dropout)
        # Additive mask: 0 on and below the diagonal, -inf for future positions.
        future = torch.ones(context_length, context_length).triu(diagonal=1).bool()
        additive = torch.zeros(context_length, context_length).masked_fill(future, float("-inf"))
        self.register_buffer("causal_mask", additive[None, None])

    def _split_heads(self, t, batch, seq_len):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, heads, seq, head_dim)``."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, d_model)`` and return the same shape."""
        batch, seq_len, width = x.shape
        q = self.rope(self._split_heads(self.q_proj(x), batch, seq_len))
        k = self.rope(self._split_heads(self.k_proj(x), batch, seq_len))
        v = self._split_heads(self.v_proj(x), batch, seq_len)
        logits = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        logits = logits + self.causal_mask[:, :, :seq_len, :seq_len]
        weights = self.dropout(logits.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged)
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: ``out(silu(gate(x)) * value(x))``, then dropout.

    Args:
        d_model: Input and output width.
        ffn_hidden_dim: Width of the gated hidden layer.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, d_model, ffn_hidden_dim, dropout):
        super().__init__()
        # Layers are created in the original order (gate, value, out), so
        # seeded initialisation and state-dict keys are unchanged.
        self.linear_gate = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_value = nn.Linear(d_model, ffn_hidden_dim, bias=False)
        self.linear_out = nn.Linear(ffn_hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = F.silu(self.linear_gate(x))
        hidden = gated * self.linear_value(x)
        projected = self.linear_out(hidden)
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """One decoder block: LayerNorm + causal attention, then LayerNorm + SwiGLU FFN."""

    def __init__(self, d_model, num_heads, ffn_hidden_dim, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, context_length, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, ffn_hidden_dim, dropout)

    def forward(self, x):
        # NOTE(review): the attention residual only feeds the FFN input; the
        # block returns x + ffn(ln2(x + attn(ln1(x)))) rather than the usual
        # pre-LN form h = x + attn(ln1(x)); return h + ffn(ln2(h)).
        # It is kept unchanged because inference must match the architecture the
        # checkpoint was trained with. Confirm against the training code before
        # "fixing" it.
        attended = x + self.attn(self.ln1(x))
        return x + self.ffn(self.ln2(attended))
class GPTModel(nn.Module):
    """Decoder-only language model: token embedding, transformer blocks, tied LM head.

    The keyword defaults are the hyperparameters this file previously
    hard-coded. ``vocab_size`` 50257 matches the GPT-2 tokenizer used for
    inference. Only pass other values for a checkpoint trained with the same
    configuration; otherwise ``load_state_dict`` will reject it.

    Args:
        vocab_size: Number of token ids (embedding rows / logits columns).
        d_model: Model width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        ffn_hidden_dim: Hidden width of each SwiGLU feed-forward layer.
        context_length: Maximum sequence length (causal-mask / RoPE table size).
        dropout: Dropout probability (inactive after ``model.eval()``).
    """

    def __init__(
        self,
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        ffn_hidden_dim=3072,
        context_length=512,
        dropout=0.1,
    ):
        super().__init__()
        # Plain int attribute (not in the state dict) so callers can read the limit.
        self.context_length = context_length
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, ffn_hidden_dim, context_length, dropout)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``.

        ``seq`` must not exceed ``context_length``; callers crop longer inputs.
        """
        h = self.token_embedding(x)
        for block in self.blocks:
            h = block(h)
        return self.lm_head(self.ln_final(h))
# ── Load model and tokenizer ─────────────────────────────────
# Use a GPU when the Space has one; the free tier is CPU-only, so this is
# normally "cpu".
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 defines no pad token; reuse EOS as the pad token.
tokenizer.pad_token = tokenizer.eos_token

print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="Nj-1111/gpt-152m-fineweb",  # ← your HF username/repo
    filename="pytorch_model.pt",
)

model = GPTModel().to(DEVICE)
# NOTE(review): torch.load unpickles the downloaded file. PyTorch >= 2.6
# defaults to weights_only=True, but older versions can execute arbitrary
# pickled code, so only load checkpoints from a repo you control.
ckpt = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # disable dropout for inference
print("Model loaded ✓")
# ── Generation function ──────────────────────────────────────
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return a copy of ``logits`` with every id in ``token_ids`` penalised.

    Uses the CTRL-style rule: positive scores are divided by ``penalty`` and
    non-positive ones are multiplied by it. This is a vectorised equivalent
    of looping over the unique seen ids. Duplicate ids scatter the same
    value, so the result is deterministic.
    """
    scores = logits.gather(1, token_ids)
    scores = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits.scatter(1, token_ids, scores)


def _top_k_filter(logits, k):
    """Keep the ``k`` largest logits per row and set the rest to ``-inf``.

    ``k <= 0`` or ``k >= vocab`` returns ``logits`` unchanged. Values tied
    with the k-th largest are all kept.
    """
    if k <= 0 or k >= logits.size(-1):
        return logits
    kth = torch.topk(logits, k).values[..., -1:]
    return logits.masked_fill(logits < kth, float("-inf"))


def generate_text(prompt, max_new_tokens, temperature, top_k, repetition_penalty):
    """Sample a continuation of ``prompt`` from the loaded model.

    Each step applies, in order: the repetition penalty over every id
    already in the sequence, temperature scaling, top-k filtering, and
    multinomial sampling. Generation stops after ``max_new_tokens`` tokens
    or when EOS is sampled.

    Args:
        prompt: Text to continue. ``None`` or whitespace-only prompts are rejected.
        max_new_tokens: Upper bound on generated tokens (converted with ``int``).
        temperature: Softmax temperature, floored at 1e-8 to avoid division by zero.
        top_k: How many highest-scoring tokens to keep (converted with ``int``).
            Values <= 0 disable the filter.
        repetition_penalty: Factor >= 1 that lowers the score of already-seen tokens.

    Returns:
        The decoded prompt plus continuation with special tokens removed, or a
        hint message when the prompt is empty.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    # Must equal the model's context length: the causal mask and RoPE tables
    # are sized for it.
    context_length = 512
    generated = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            logits = model(generated[:, -context_length:])[:, -1, :].float()
            logits = _apply_repetition_penalty(logits, generated, repetition_penalty)
            logits = logits / max(temperature, 1e-8)
            logits = _top_k_filter(logits, int(top_k))
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# ── Gradio Interface ─────────────────────────────────────────
# Each row: [prompt, max_new_tokens, temperature, top_k, repetition_penalty],
# in the same order as generate_text's parameters.
examples = [
    ["Quantum mechanics is the branch of physics that", 150, 0.8, 50, 1.3],
    ["The French Revolution began in 1789 because", 150, 0.8, 40, 1.3],
    ["DNA carries genetic information by", 150, 0.8, 50, 1.3],
    ["The solar system consists of eight planets", 150, 0.8, 40, 1.3],
    ["In mathematics, a prime number is", 150, 0.7, 30, 1.3],
    ["Climate change affects the environment by", 150, 0.8, 50, 1.3],
]

with gr.Blocks(title="GPT-152M Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 GPT-152M — Trained From Scratch

    A 152 million parameter language model built with raw PyTorch and trained on
    197M tokens of educational text (FineWeb-Edu). No pretrained weights were used.

    **Best results:** Use textbook-style prompts, not search queries.
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="e.g. Quantum mechanics is the branch of physics that",
                lines=3,
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            output_box = gr.Textbox(label="Generated Text", lines=8, interactive=False)
        with gr.Column(scale=1):
            max_tokens = gr.Slider(50, 300, value=150, step=10,
                                   label="Max new tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05,
                                    label="Temperature (higher = more creative)")
            top_k = gr.Slider(10, 100, value=50, step=5,
                              label="Top-k (lower = more focused)")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.3, step=0.05,
                                    label="Repetition penalty")

    # Shared by the examples and the button so both call generate_text identically.
    generation_inputs = [prompt_box, max_tokens, temperature, top_k, rep_penalty]

    gr.Examples(
        examples=examples,
        inputs=generation_inputs,
        outputs=output_box,
        fn=generate_text,
        cache_examples=True,
        label="Example prompts — click any to try",
    )
    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_box,
    )

    gr.Markdown("""
    ---
    **Model:** GPT-152M | **Dataset:** FineWeb-Edu (197M tokens) |
    **Hardware:** Free Kaggle T4 GPU (~8.5 hours) | **Framework:** PyTorch 2.9

    ⚠️ This model was trained for educational purposes.
    Outputs may be factually incorrect.
    """)

demo.launch()