Nj-1111 committed on
Commit
58deda5
·
verified ·
1 Parent(s): 114b38b

create app_v1

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import GPT2Tokenizer
6
+ from huggingface_hub import hf_hub_download
7
+ import math
8
+
9
+ # ── Model Architecture (must match training exactly) ─────────
10
+
11
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding applied over the last dimension of a tensor.

    Cosine and sine values for every position in ``range(max_seq_len)`` are
    computed once at construction and stored as the buffers ``cos_table`` and
    ``sin_table`` (one row per position, one column per frequency).

    Args:
        head_dim: Size of the last dimension of the tensors to be rotated.
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        exponents = torch.arange(0, head_dim, 2).float() / head_dim
        inv_freq = 1.0 / (10000.0 ** exponents)
        positions = torch.arange(max_seq_len).float()
        # angles[p, i] = p * inv_freq[i]
        angles = torch.outer(positions, inv_freq)
        self.register_buffer("cos_table", angles.cos())
        self.register_buffer("sin_table", angles.sin())

    @staticmethod
    def _rotate_half(x):
        """Swap the two halves of the last dimension, negating the new first half."""
        mid = x.shape[-1] // 2
        front, back = x[..., :mid], x[..., mid:]
        return torch.cat([-back, front], dim=-1)

    def forward(self, x):
        """Rotate ``x`` by its position; the sequence length is ``x.shape[2]``."""
        seq_len = x.shape[2]
        cos_half = self.cos_table[:seq_len]
        sin_half = self.sin_table[:seq_len]
        # Duplicate the tables so they span the full last dimension, then add
        # two leading singleton axes so they broadcast against ``x``.
        cos = torch.cat([cos_half, cos_half], dim=-1)[None, None]
        sin = torch.cat([sin_half, sin_half], dim=-1)[None, None]
        return x * cos + self._rotate_half(x) * sin
29
+
30
+
31
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings on Q and K.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        context_length: Largest sequence length the causal mask and the rotary
            tables are built for.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads, context_length, dropout):
        super().__init__()
        # Fail early: a non-divisible width would otherwise only surface as an
        # opaque ``view`` size error on the first forward pass.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionalEmbedding(self.head_dim, context_length)
        self.dropout = nn.Dropout(dropout)
        # Additive mask: 0 on and below the diagonal, -inf strictly above it,
        # so softmax gives zero weight to future positions.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        causal = torch.zeros(context_length, context_length)
        causal.masked_fill_(mask, float("-inf"))
        self.register_buffer("causal_mask", causal.unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return the same shape."""
        B, T, C = x.shape
        # Split the feature axis into heads: (B, T, C) -> (B, heads, T, head_dim).
        Q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        Q, K = self.rope(Q), self.rope(K)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Crop the precomputed mask to the current sequence length.
        scores = scores + self.causal_mask[:, :, :T, :T]
        w = self.dropout(F.softmax(scores, dim=-1))
        # Merge the heads back: (B, heads, T, head_dim) -> (B, T, C).
        out = torch.matmul(w, V).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
58
+
59
+
60
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: ``out(silu(gate(x)) * value(x))`` plus dropout.

    Args:
        d_model: Width of the input and output features.
        ffn_hidden_dim: Width of the gated hidden representation.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, d_model, ffn_hidden_dim, dropout):
        super().__init__()

        def bias_free(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False)

        self.gate = bias_free(d_model, ffn_hidden_dim)
        self.value = bias_free(d_model, ffn_hidden_dim)
        self.out = bias_free(ffn_hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the gated projection to ``x`` and return a tensor of the same width."""
        activated_gate = F.silu(self.gate(x))
        hidden = activated_gate * self.value(x)
        projected = self.out(hidden)
        return self.dropout(projected)
70
+
71
+
72
class TransformerBlock(nn.Module):
    """One decoder block: layer-normed self-attention followed by a layer-normed FFN.

    Args:
        d_model: Width of the residual stream.
        num_heads: Number of attention heads.
        ffn_hidden_dim: Hidden width of the feed-forward layer.
        context_length: Largest sequence length the attention layer supports.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model, num_heads, ffn_hidden_dim, context_length, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, context_length, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = SwiGLUFFN(d_model, ffn_hidden_dim, dropout)

    def forward(self, x):
        """Run attention then the FFN over ``x`` and return the updated tensor."""
        attended = x + self.attn(self.ln1(x))
        # NOTE: the outer residual adds the block *input* ``x``, not
        # ``attended``. The file states the architecture must match training
        # exactly, so this wiring is kept as-is.
        return x + self.ffn(self.ln2(attended))
82
+
83
+
84
class GPTModel(nn.Module):
    """Decoder-only language model: embedding, transformer blocks, tied LM head.

    All arguments are keyword-only and default to the configuration the
    published checkpoint was built with, so ``GPTModel()`` is unchanged.

    Args:
        vocab_size: Number of token ids in the embedding table / output logits.
        d_model: Width of the residual stream.
        num_layers: Number of stacked ``TransformerBlock`` modules.
        num_heads: Attention heads per block.
        ffn_hidden_dim: Hidden width of each block's feed-forward layer.
        context_length: Largest sequence length the blocks are built for.
        dropout: Dropout probability used inside every block.
    """

    def __init__(
        self,
        *,
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        ffn_hidden_dim=3072,
        context_length=512,
        dropout=0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, ffn_hidden_dim, context_length, dropout)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):
        """Map token ids ``x`` to next-token logits over the vocabulary."""
        h = self.token_embedding(x)
        for block in self.blocks:
            h = block(h)
        return self.lm_head(self.ln_final(h))
100
+
101
+
102
+ # ── Load model and tokenizer ─────────────────────────────────
103
+
104
DEVICE = torch.device("cpu")  # Spaces free tier uses CPU
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Use the end-of-text token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

print("Downloading model weights...")
model_path = hf_hub_download(
    repo_id="neelbose11/gpt-152m-fineweb",  # ← your HF username/repo
    filename="pytorch_model.pt",
)

model = GPTModel().to(DEVICE)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the repository is ever compromised. If the checkpoint
# holds only tensors and plain containers, consider passing
# ``weights_only=True`` — confirm against the checkpoint contents first.
ckpt = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(ckpt["model_state_dict"])
# Inference only: eval mode turns the dropout layers off.
model.eval()
print("Model loaded ✓")
119
+
120
+
121
+ # ── Generation function ──────────────────────────────────────
122
+
123
def _penalize_seen_tokens(logits, seen_ids, penalty):
    """Make every token id in ``seen_ids`` less likely, editing ``logits`` in place."""
    seen = logits[0, seen_ids]
    # Positive logits are divided and non-positive ones multiplied, so the
    # penalty always pushes the score downwards.
    logits[0, seen_ids] = torch.where(seen > 0, seen / penalty, seen * penalty)


def _sample_top_k(logits, k):
    """Sample one token id from the ``k`` highest-scoring entries of ``logits``."""
    kth_best = torch.topk(logits, k).values[:, -1:]
    filtered = logits.masked_fill(logits < kth_best, -1e9)
    probs = torch.softmax(filtered, dim=-1).clamp(min=0)
    probs = probs / probs.sum()
    return torch.multinomial(probs, num_samples=1)


def generate_text(prompt, max_new_tokens, temperature, top_k, repetition_penalty):
    """Continue ``prompt`` with sampled tokens and return the decoded text.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only input
            yields a short instruction message instead of a generation.
        max_new_tokens: Upper bound on the number of tokens to append.
        temperature: Divisor applied to the logits; floored at ``1e-8``.
        top_k: Number of highest-scoring tokens to sample from; clamped to
            ``[1, vocabulary size]``.
        repetition_penalty: Factor used to down-weight tokens that already
            appear in the prompt or the generation so far.

    Returns:
        The prompt plus the generated continuation, with special tokens
        removed, or ``"Please enter a prompt."`` for blank input.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    generated = input_ids.clone()
    # Loop-invariant: guard against a zero temperature once, up front.
    safe_temperature = max(temperature, 1e-8)

    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            # Feed only the most recent 512 tokens (the context length the
            # model is constructed with).
            window = generated[:, -512:]
            logits = model(window)[:, -1, :].float()

            _penalize_seen_tokens(logits, torch.unique(generated[0]), repetition_penalty)

            logits = logits / safe_temperature
            # At least one candidate must survive, otherwise top-k is empty.
            k = max(1, min(int(top_k), logits.size(-1)))
            next_token = _sample_top_k(logits, k)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated[0], skip_special_tokens=True)
153
+
154
+
155
+ # ── Gradio Interface ─────────────────────────────────────────
156
+
157
# Example rows for the demo, one list per example, ordered as
# [prompt, max new tokens, temperature, top-k, repetition penalty].
# Every example uses 150 new tokens and a repetition penalty of 1.3.
examples = [
    [text, 150, temp, k, 1.3]
    for text, temp, k in (
        ("Quantum mechanics is the branch of physics that", 0.8, 50),
        ("The French Revolution began in 1789 because", 0.8, 40),
        ("DNA carries genetic information by", 0.8, 50),
        ("The solar system consists of eight planets", 0.8, 40),
        ("In mathematics, a prime number is", 0.7, 30),
        ("Climate change affects the environment by", 0.8, 50),
    )
]
165
+
166
+ with gr.Blocks(title="GPT-152M Demo", theme=gr.themes.Soft()) as demo:
167
+ gr.Markdown("""
168
+ # 🧠 GPT-152M β€” Trained From Scratch
169
+
170
+ A 152 million parameter language model built with raw PyTorch and trained on
171
+ 197M tokens of educational text (FineWeb-Edu). No pretrained weights were used.
172
+
173
+ **Best results:** Use textbook-style prompts, not search queries.
174
+ """)
175
+
176
+ with gr.Row():
177
+ with gr.Column(scale=2):
178
+ prompt_box = gr.Textbox(
179
+ label="Prompt",
180
+ placeholder="e.g. Quantum mechanics is the branch of physics that",
181
+ lines=3
182
+ )
183
+ generate_btn = gr.Button("Generate", variant="primary", size="lg")
184
+ output_box = gr.Textbox(label="Generated Text", lines=8, interactive=False)
185
+
186
+ with gr.Column(scale=1):
187
+ max_tokens = gr.Slider(50, 300, value=150, step=10,
188
+ label="Max new tokens")
189
+ temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05,
190
+ label="Temperature (higher = more creative)")
191
+ top_k = gr.Slider(10, 100, value=50, step=5,
192
+ label="Top-k (lower = more focused)")
193
+ rep_penalty = gr.Slider(1.0, 2.0, value=1.3, step=0.05,
194
+ label="Repetition penalty")
195
+
196
+ gr.Examples(
197
+ examples=examples,
198
+ inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
199
+ outputs=output_box,
200
+ fn=generate_text,
201
+ cache_examples=True,
202
+ label="Example prompts β€” click any to try"
203
+ )
204
+
205
+ generate_btn.click(
206
+ fn=generate_text,
207
+ inputs=[prompt_box, max_tokens, temperature, top_k, rep_penalty],
208
+ outputs=output_box
209
+ )
210
+
211
+ gr.Markdown("""
212
+