NaveenKumar Namachivayam commited on
Commit
b817849
·
1 Parent(s): ba05e77

feat: add Thirukkural Tamil text dataset with English translations

Browse files

- Add complete Thirukkural text (7046 lines) with Tamil verses and English translations
- Include all 133 chapters covering virtue, wealth, and love
- Format each kural with Tamil original, transliteration, and English couplet translation
- Organize by sections: domestic virtue, ascetic virtue, royalty, love, and more

data/thirukkural_clean.txt ADDED
The diff for this file is too large to render. See raw diff
 
hf-space/README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Valluvar or AI?
3
+ emoji: 🕉️
4
+ colorFrom: orange
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.x
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Valluvar or AI? 🕉️
14
+
15
+ An AI that writes new Thirukkurals in the style of Thiruvalluvar.
16
+
17
+ ## Features
18
+
19
+ - **Generate Kural**: Enter a Tamil theme and get a bilingual couplet
20
+ - **Valluvar or AI Quiz**: Can you tell which is original and which is AI-generated?
21
+ - **Temperature Control**: Adjust creativity from coherent (0.5) to wild (2.0)
22
+
23
+ ## Model
24
+
25
+ - **Architecture**: GPT (8L/8H/512D, 25.4M params)
26
+ - **Training Data**: Thirukkural (1330 kurals + English translations)
27
+ - **Tokenization**: Character-level
28
+
29
+ ## Examples
30
+
31
+ **Traditional themes work great:**
32
+ - `கடவுள் வாழ்த்து` (Praise of God) ✅
33
+ - `அரசியல்` (Politics/Governance) ✅
34
+ - `நட்பு` (Friendship) ✅
35
+
36
+ **Modern topics don't work:**
37
+ - `விஞ்ஞானம்` (Science) ❌
38
+ - `கணிதம்` (Mathematics) ❌
39
+
40
+ The model learned Thiruvalluvar's form and traditional themes, but not modern concepts.
41
+
42
+ ## How to Use
43
+
44
+ 1. Enter a Tamil word or theme
45
+ 2. Adjust temperature (0.8 recommended)
46
+ 3. Click Generate
47
+ 4. See if the model memorized or created something new!
hf-space/app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio app for Thirukkural GPT - Valluvar or AI."""
2
+ import random
3
+ import re
4
+
5
+ import gradio as gr
6
+ import torch
7
+
8
+ from model import GPT, GPTConfig
9
+
10
+
11
+ def load_model():
12
+ """Load the trained model and tokenizer."""
13
+ # Allow GPTConfig for safe loading
14
+ from model import GPTConfig
15
+ torch.serialization.add_safe_globals([GPTConfig])
16
+ checkpoint = torch.load("checkpoint_final.pt", map_location="cpu", weights_only=True)
17
+ config = checkpoint["config"]
18
+ stoi = checkpoint["stoi"]
19
+ itos = checkpoint["itos"]
20
+
21
+ model = GPT(config)
22
+ model.load_state_dict(checkpoint["model_state_dict"])
23
+ model.eval()
24
+
25
+ return model, stoi, itos
26
+
27
+
28
+ def generate(model, prompt, stoi, itos, max_new_tokens=200, temperature=0.8, device="cpu"):
29
+ """Generate text from prompt."""
30
+ model = model.to(device)
31
+
32
+ # Encode prompt
33
+ prompt_tokens = [stoi.get(c, stoi.get(" ", 0)) for c in prompt]
34
+ idx = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
35
+
36
+ # Generate
37
+ with torch.no_grad():
38
+ for _ in range(max_new_tokens):
39
+ # Crop to block size
40
+ idx_cond = idx[:, -model.config.block_size :]
41
+
42
+ # Get predictions
43
+ logits, _ = model(idx_cond)
44
+ logits = logits[:, -1, :] / temperature
45
+
46
+ # Sample
47
+ probs = torch.softmax(logits, dim=-1)
48
+ idx_next = torch.multinomial(probs, num_samples=1)
49
+
50
+ # Append
51
+ idx = torch.cat((idx, idx_next), dim=1)
52
+
53
+ # Decode
54
+ tokens = idx[0].tolist()
55
+ result = "".join([itos.get(t, "") for t in tokens])
56
+ return result
57
+
58
+
59
+ def is_real_kural(text, original_text):
60
+ """Check if generated text exists in original kurals.
61
+
62
+ A kural is considered "real" if:
63
+ 1. The Tamil couplet (2 lines) exists in original
64
+ 2. The English translation matches
65
+ """
66
+ lines = text.strip().split("\n")
67
+
68
+ # Get Tamil lines (contain Tamil Unicode)
69
+ tamil_lines = [l.strip() for l in lines if re.search(r"[\u0B80-\u0BFF]", l)]
70
+ # Get English lines (no Tamil, just text)
71
+ english_lines = [l.strip() for l in lines if l.strip() and not re.search(r"[\u0B80-\u0BFF]", l)]
72
+
73
+ if len(tamil_lines) < 2:
74
+ return False
75
+
76
+ # Check if Tamil couplet exists in original
77
+ first_tamil = tamil_lines[0]
78
+ second_tamil = tamil_lines[1] if len(tamil_lines) > 1 else ""
79
+
80
+ # A true kural needs both Tamil lines to exist consecutively
81
+ tamil_couplet = first_tamil + "\n" + second_tamil
82
+ if tamil_couplet not in original_text:
83
+ return False
84
+
85
+ # Also check that English lines roughly match (at least one should exist)
86
+ if english_lines:
87
+ first_english = english_lines[0]
88
+ # Check if this English translation exists near the Tamil
89
+ return first_english in original_text
90
+
91
+ return True
92
+
93
+
94
+ # Load model and data
95
+ print("Loading model...")
96
+ model, stoi, itos = load_model()
97
+ print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")
98
+
99
+ # Load original text for verification
100
+ with open("thirukkural_clean.txt", "r", encoding="utf-8") as f:
101
+ ORIGINAL_TEXT = f.read()
102
+
103
+
104
+ def generate_kural(prompt, temperature, max_tokens):
105
+ """Generate and format kural with proper structure."""
106
+ # Generate with higher token count to ensure complete kural
107
+ output_raw = generate(model, prompt, stoi, itos, int(max_tokens) + 100, temperature)
108
+
109
+ # Extract first complete kural from generated text
110
+ lines = output_raw.strip().split("\n")
111
+
112
+ # Find the first proper kural (skip headers, get 2 Tamil + 2 English lines)
113
+ tamil_lines = []
114
+ english_lines = []
115
+
116
+ for line in lines:
117
+ line = line.strip()
118
+ if not line or " - " in line:
119
+ continue
120
+ # Skip short Tamil headers (1-2 words)
121
+ if re.search(r"[\u0B80-\u0BFF]", line) and len(line.split()) <= 2 and not re.search(r"[a-zA-Z]", line):
122
+ continue
123
+
124
+ if re.search(r"[\u0B80-\u0BFF]", line):
125
+ if len(tamil_lines) < 2:
126
+ tamil_lines.append(line)
127
+ elif line and len(english_lines) < 2:
128
+ english_lines.append(line)
129
+
130
+ # Build formatted output
131
+ formatted_lines = []
132
+ if tamil_lines:
133
+ formatted_lines.extend(tamil_lines[:2])
134
+ if english_lines:
135
+ formatted_lines.extend(english_lines[:2])
136
+
137
+ output = "\n".join(formatted_lines) if formatted_lines else format_kural(output_raw)
138
+
139
+ # Check if real or AI
140
+ is_real = is_real_kural(output_raw, ORIGINAL_TEXT)
141
+ source = "📖 Original Thirukkural" if is_real else "🤖 AI Generated"
142
+
143
+ return output, source
144
+
145
+
146
+ def format_kural(text):
147
+ """Format kural text with proper structure (2 Tamil + 2 English lines)."""
148
+ lines = text.strip().split("\n")
149
+
150
+ # Skip headers: lines with " - " OR short single Tamil words (chapter names)
151
+ def is_header(line):
152
+ # Headers have " - " or are short Tamil-only phrases (1-3 words)
153
+ if " - " in line:
154
+ return True
155
+ # Check if it's a short Tamil phrase (likely a chapter title)
156
+ if re.search(r"[\u0B80-\u0BFF]", line) and len(line.split()) <= 3:
157
+ # And no English words
158
+ if not re.search(r"[a-zA-Z]", line):
159
+ return True
160
+ return False
161
+
162
+ content_lines = [l.strip() for l in lines if l.strip() and not is_header(l)]
163
+
164
+ # Classify lines
165
+ tamil_lines = [l for l in content_lines if re.search(r"[\u0B80-\u0BFF]", l)]
166
+ english_lines = [l for l in content_lines if l and not re.search(r"[\u0B80-\u0BFF]", l)]
167
+
168
+ # Build proper 4-line kural
169
+ formatted = []
170
+
171
+ # Tamil couplet (2 lines)
172
+ if len(tamil_lines) >= 2:
173
+ formatted.extend(tamil_lines[:2])
174
+ elif len(tamil_lines) == 1:
175
+ formatted.append(tamil_lines[0])
176
+ formatted.append("") # Placeholder
177
+
178
+ # English translation (2 lines)
179
+ if len(english_lines) >= 2:
180
+ formatted.extend(english_lines[:2])
181
+ elif len(english_lines) == 1:
182
+ formatted.append(english_lines[0])
183
+ formatted.append("")
184
+
185
+ return "\n".join(formatted)
186
+
187
+
188
+ def valluvar_or_ai_quiz():
189
+ """Generate a quiz: one real, one AI."""
190
+ # Get random real kural - find a proper 4-line kural
191
+ lines = ORIGINAL_TEXT.strip().split("\n")
192
+
193
+ # Find a random valid kural (2 Tamil + 2 English lines)
194
+ attempts = 0
195
+ real_kural = ""
196
+ while attempts < 100:
197
+ idx = random.randint(0, len(lines) - 4)
198
+ chunk = lines[idx:idx+4]
199
+ tamil_count = sum(1 for l in chunk if re.search(r"[\u0B80-\u0BFF]", l))
200
+ english_count = sum(1 for l in chunk if l.strip() and not re.search(r"[\u0B80-\u0BFF]", l))
201
+ if tamil_count == 2 and english_count == 2:
202
+ real_kural = "\n".join(chunk).strip()
203
+ break
204
+ attempts += 1
205
+
206
+ # Fallback if no proper kural found
207
+ if not real_kural:
208
+ real_kural = "அகர முதல எழுத்தெல்லாம் ஆதி\nபகவன் முதற்றே உலகு\n'A' leads letters; the Ancient Lord\nLeads and lords the entire world"
209
+
210
+ # Generate AI kural with random prompt
211
+ prompts = ["கடவுள் வாழ்த்து", "நட்பு", "அறன்", "வான் சிறப்பு", "அரசியல்"]
212
+ prompt = random.choice(prompts)
213
+ ai_kural_raw = generate(model, prompt, stoi, itos, 150, 0.8)
214
+ ai_kural = format_kural(ai_kural_raw)
215
+
216
+ # Format real kural too
217
+ real_kural = format_kural(real_kural)
218
+
219
+ # Shuffle
220
+ kurals = [("A", real_kural, True), ("B", ai_kural, False)]
221
+ random.shuffle(kurals)
222
+
223
+ return (
224
+ f"## Option A\n```\n{kurals[0][1]}\n```\n\n---\n\n## Option B\n```\n{kurals[1][1]}\n```",
225
+ kurals[0][2],
226
+ kurals[1][2],
227
+ "A" if kurals[0][2] else "B",
228
+ )
229
+
230
+
231
+ # Gradio Interface
232
+ with gr.Blocks(title="Valluvar or AI?") as demo:
233
+ gr.Markdown("# 🕉️ Valluvar or AI?")
234
+ gr.Markdown(
235
+ "An AI that writes new Thirukkurals in the style of Thiruvalluvar. "
236
+ "Enter a Tamil theme to generate bilingual wisdom."
237
+ )
238
+
239
+ with gr.Tab("✨ Generate Kural"):
240
+ with gr.Row():
241
+ with gr.Column():
242
+ prompt = gr.Textbox(
243
+ label="Theme (Tamil)",
244
+ placeholder="e.g., கடவுள் வாழ்த்து, நட்பு, அரசியல்",
245
+ value="கடவுள் வாழ்த்து",
246
+ )
247
+ temperature = gr.Slider(
248
+ minimum=0.1,
249
+ maximum=2.0,
250
+ value=0.8,
251
+ step=0.1,
252
+ label="Temperature (Creativity)",
253
+ )
254
+ max_tokens = gr.Slider(
255
+ minimum=50,
256
+ maximum=400,
257
+ value=200,
258
+ step=50,
259
+ label="Max Tokens",
260
+ )
261
+ generate_btn = gr.Button("Generate", variant="primary")
262
+
263
+ with gr.Column():
264
+ output = gr.Textbox(
265
+ label="Generated Kural",
266
+ lines=10,
267
+ )
268
+ source = gr.Textbox(label="Source")
269
+
270
+ generate_btn.click(
271
+ fn=generate_kural,
272
+ inputs=[prompt, temperature, max_tokens],
273
+ outputs=[output, source],
274
+ )
275
+
276
+ # Quick theme buttons
277
+ gr.Markdown("### Quick Themes")
278
+ with gr.Row():
279
+ themes = [
280
+ "கடவுள் வாழ்த்து",
281
+ "வான் சிறப்பு",
282
+ "நட்பு",
283
+ "அரசியல்",
284
+ "அறன் வலியுறுத்தல்",
285
+ ]
286
+ for theme in themes:
287
+ btn = gr.Button(theme, size="sm")
288
+ btn.click(lambda t=theme: t, outputs=prompt)
289
+
290
+ with gr.Tab("🎯 Valluvar or AI? Quiz"):
291
+ gr.Markdown("Can you tell which is the original Thirukkural and which is AI-generated?")
292
+
293
+ quiz_output = gr.Markdown()
294
+ with gr.Row():
295
+ guess_a = gr.Button("Option A is Real", variant="secondary")
296
+ guess_b = gr.Button("Option B is Real", variant="secondary")
297
+ quiz_result = gr.Markdown()
298
+ new_quiz_btn = gr.Button("New Quiz", variant="primary")
299
+
300
+ # Store answers
301
+ a_is_real = gr.State()
302
+ b_is_real = gr.State()
303
+ correct_answer = gr.State()
304
+
305
+ def check_answer(guess, a_real, b_real, correct):
306
+ if guess == correct:
307
+ return "✅ Correct! You identified the original Thirukkural."
308
+ return "❌ Wrong! The original Thirukkural was: " + correct
309
+
310
+ new_quiz_btn.click(
311
+ fn=valluvar_or_ai_quiz,
312
+ outputs=[quiz_output, a_is_real, b_is_real, correct_answer],
313
+ )
314
+
315
+ guess_a.click(
316
+ fn=lambda a, b, c: check_answer("A", a, b, c),
317
+ inputs=[a_is_real, b_is_real, correct_answer],
318
+ outputs=quiz_result,
319
+ )
320
+
321
+ guess_b.click(
322
+ fn=lambda a, b, c: check_answer("B", a, b, c),
323
+ inputs=[a_is_real, b_is_real, correct_answer],
324
+ outputs=quiz_result,
325
+ )
326
+
327
+ with gr.Tab("📊 About"):
328
+ gr.Markdown(
329
+ f"""
330
+ ## Model Details
331
+
332
+ - **Architecture:** GPT ({model.config.n_layer}L/{model.config.n_head}H/{model.config.n_embd}D)
333
+ - **Parameters:** {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M
334
+ - **Vocabulary:** {len(stoi)} characters (Tamil + English)
335
+ - **Training Data:** Thirukkural (1330 kurals with English translations)
336
+ - **Tokenization:** Character-level
337
+
338
+ ## Training
339
+
340
+ - Steps: 10,000
341
+ - Device: Apple MPS (Mac Mini)
342
+ - Time: ~5 hours
343
+ - Final Loss: ~1.5
344
+
345
+ ## Capabilities
346
+
347
+ - ✅ Generate authentic Tamil couplets (2 lines × 4 words)
348
+ - ✅ Produce coherent English translations
349
+ - ✅ Handle traditional themes (virtue, politics, love)
350
+ - ❌ Modern topics (science, technology) - not in training data
351
+
352
+ ## Examples of AI vs Original
353
+
354
+ The model sometimes generates exact memorized kurals from the 1330,
355
+ and sometimes creates entirely new ones in Thiruvalluvar's style.
356
+
357
+ Built with ❤️ using PyTorch and Gradio.
358
+ """
359
+ )
360
+
361
+
362
+ if __name__ == "__main__":
363
+ demo.launch()
hf-space/model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPT Model Architecture for Thirukkural Training."""
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ @dataclass
10
+ class GPTConfig:
11
+ """Configuration for GPT model."""
12
+
13
+ vocab_size: int = 65 # Will be set dynamically based on dataset
14
+ block_size: int = 256 # Max sequence length
15
+ n_layer: int = 6 # Number of transformer blocks
16
+ n_head: int = 6 # Number of attention heads
17
+ n_embd: int = 384 # Embedding dimension
18
+
19
+
20
+ class CausalSelfAttention(nn.Module):
21
+ """Multi-head causal self-attention layer."""
22
+
23
+ def __init__(self, config: GPTConfig):
24
+ super().__init__()
25
+ assert config.n_embd % config.n_head == 0
26
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
27
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
28
+ self.n_head = config.n_head
29
+ self.n_embd = config.n_embd
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ B, T, C = x.shape
33
+ qkv = self.c_attn(x)
34
+ q, k, v = qkv.split(self.n_embd, dim=2)
35
+
36
+ head_dim = C // self.n_head
37
+ q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
38
+ k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
39
+ v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
40
+
41
+ y = torch.nn.functional.scaled_dot_product_attention(
42
+ q, k, v, is_causal=True
43
+ )
44
+
45
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
46
+ return self.c_proj(y)
47
+
48
+
49
+ class MLP(nn.Module):
50
+ """Feed-forward network with GELU activation."""
51
+
52
+ def __init__(self, config: GPTConfig):
53
+ super().__init__()
54
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
55
+ self.gelu = nn.GELU(approximate="tanh")
56
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
57
+
58
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ x = self.c_fc(x)
60
+ x = self.gelu(x)
61
+ return self.c_proj(x)
62
+
63
+
64
+ class Block(nn.Module):
65
+ """Transformer block with attention and MLP."""
66
+
67
+ def __init__(self, config: GPTConfig):
68
+ super().__init__()
69
+ self.ln_1 = nn.LayerNorm(config.n_embd)
70
+ self.attn = CausalSelfAttention(config)
71
+ self.ln_2 = nn.LayerNorm(config.n_embd)
72
+ self.mlp = MLP(config)
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ x = x + self.attn(self.ln_1(x))
76
+ x = x + self.mlp(self.ln_2(x))
77
+ return x
78
+
79
+
80
+ class GPT(nn.Module):
81
+ """GPT language model."""
82
+
83
+ def __init__(self, config: GPTConfig):
84
+ super().__init__()
85
+ self.config = config
86
+ self.transformer = nn.ModuleDict(
87
+ dict(
88
+ wte=nn.Embedding(config.vocab_size, config.n_embd),
89
+ wpe=nn.Embedding(config.block_size, config.n_embd),
90
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
91
+ ln_f=nn.LayerNorm(config.n_embd),
92
+ )
93
+ )
94
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
95
+ self.transformer.wte.weight = self.lm_head.weight
96
+
97
+ def forward(
98
+ self, idx: torch.Tensor, targets: torch.Tensor | None = None
99
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
100
+ B, T = idx.shape
101
+ pos = torch.arange(0, T, device=idx.device)
102
+
103
+ tok_emb = self.transformer.wte(idx)
104
+ pos_emb = self.transformer.wpe(pos)
105
+ x = tok_emb + pos_emb
106
+
107
+ for block in self.transformer.h:
108
+ x = block(x)
109
+
110
+ x = self.transformer.ln_f(x)
111
+ logits = self.lm_head(x)
112
+
113
+ loss = None
114
+ if targets is not None:
115
+ loss = nn.functional.cross_entropy(
116
+ logits.view(-1, logits.size(-1)), targets.view(-1)
117
+ )
118
+ return logits, loss
hf-space/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=2.0.0
2
+ gradio>=4.0.0
hf-space/thirukkural_clean.txt ADDED
The diff for this file is too large to render. See raw diff
 
train.py CHANGED
@@ -61,11 +61,11 @@ def get_lr(
61
 
62
  def train(
63
  data_path: str,
64
- max_steps: int = 5000,
65
  batch_size: int = 64,
66
- n_layer: int = 6,
67
- n_head: int = 6,
68
- n_embd: int = 384,
69
  block_size: int = 256,
70
  ) -> tuple[GPT, dict[str, int], dict[int, str]]:
71
  """Train a GPT model on the given dataset."""
 
61
 
62
  def train(
63
  data_path: str,
64
+ max_steps: int = 10000,
65
  batch_size: int = 64,
66
+ n_layer: int = 8,
67
+ n_head: int = 8,
68
+ n_embd: int = 512,
69
  block_size: int = 256,
70
  ) -> tuple[GPT, dict[str, int], dict[int, str]]:
71
  """Train a GPT model on the given dataset."""