G3nadh commited on
Commit
71496e2
ยท
verified ยท
1 Parent(s): d14cdd8

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ loss_curves.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - gemma3
5
+ - language-model
6
+ - pre-training
7
+ - from-scratch
8
+ - tinystories
9
+ - transformer
10
+ - multi-query-attention
11
+ - sliding-window-attention
12
+ - rope
13
+ language:
14
+ - en
15
+ datasets:
16
+ - roneneldan/TinyStories
17
+ metrics:
18
+ - perplexity
19
+ pipeline_tag: text-generation
20
+ ---
21
+
22
+ # Gemma 3 270M โ€” Pre-trained from Scratch on TinyStories
23
+
24
+ A custom implementation of the **Gemma 3 architecture** (scaled to 164.6M parameters), pre-trained from scratch on the TinyStories dataset.
25
+
26
+ ## ๐Ÿ“Š Results
27
+
28
+ | Metric | Value |
29
+ |--------|-------|
30
+ | **Best Val Loss** | 1.7845 |
31
+ | **Perplexity** | 5.96 |
32
+ | **Best Iteration** | 13,000 |
33
+ | **Parameters** | 164.6M |
34
+
35
+ ![Training Loss Curves](loss_curves.png)
36
+
37
+ ## ๐Ÿ—๏ธ Architecture
38
+
39
+ This model implements the **complete Gemma 3 architecture** with all modern innovations:
40
+
41
+ | Component | Specification |
42
+ |-----------|--------------|
43
+ | Layers | 18 (15 sliding + 3 full attention) |
44
+ | Embedding Dim | 640 |
45
+ | Attention Heads | 4 (Multi-Query, 1 KV group) |
46
+ | Head Dimension | 256 |
47
+ | FFN Hidden | 2,048 (GeGLU activation) |
48
+ | Context Length | 32,768 tokens |
49
+ | Vocabulary | 50,257 (GPT-2 BPE) |
50
+
51
+ ### Key Features
52
+ - **Sliding Window Attention** (w=512): O(nร—w) instead of O(nยฒ), 64ร— cheaper
53
+ - **Multi-Query Attention**: All query heads share 1 K,V head โ€” 4ร— less KV cache
54
+ - **RoPE with Dual Bases**: 10K (local patterns) + 1M (long-range dependencies)
55
+ - **QK Normalization**: RMSNorm on Q,K vectors before attention
56
+ - **Gemma-style RMSNorm**: (1 + weight) scaling for stable initialization
57
+ - **GeGLU Feed-Forward**: Gated GELU activation with 3.2ร— expansion
58
+
59
+ ### Layer Type Pattern
60
+ ```
61
+ Layers 1-5: Sliding Attention (local, base=10K)
62
+ Layer 6: Full Attention (global, base=1M)
63
+ Layers 7-11: Sliding Attention (local, base=10K)
64
+ Layer 12: Full Attention (global, base=1M)
65
+ Layers 13-17: Sliding Attention (local, base=10K)
66
+ Layer 18: Full Attention (global, base=1M)
67
+ ```
68
+
69
+ ## ๐Ÿ“– Training
70
+
71
+ - **Dataset**: TinyStories (2.1M stories, 471M tokens)
72
+ - **Tokenizer**: GPT-2 BPE via tiktoken (50,257 vocab)
73
+ - **Optimizer**: AdamW (ฮฒ1=0.9, ฮฒ2=0.95, ฮต=1e-9, weight_decay=0.1)
74
+ - **Learning Rate**: 1e-4 โ†’ 5e-5 (cosine decay with 1K step warmup)
75
+ - **Precision**: bfloat16 mixed precision
76
+ - **Hardware**: NVIDIA A100 40GB (Google Colab Pro)
77
+ - **Gradient Clipping**: max_norm=0.5
78
+
79
+ ## ๐Ÿ’ป Usage
80
+ ```python
81
+ import torch
82
+ import tiktoken
83
+
84
+ # Load model (you'll need the model class definition)
85
+ model = Gemma3Model(config)
86
+ state_dict = torch.load("pytorch_model.bin", map_location="cpu")
87
+ model.load_state_dict(state_dict)
88
+ model.eval()
89
+
90
+ # Tokenize
91
+ enc = tiktoken.get_encoding("gpt2")
92
+ prompt = "Once upon a time"
93
+ input_ids = torch.tensor([enc.encode_ordinary(prompt)])
94
+
95
+ # Generate
96
+ with torch.no_grad():
97
+ output = model.generate(input_ids, max_new_tokens=200, temperature=0.7)
98
+ print(enc.decode(output[0].tolist()))
99
+ ```
100
+
101
+ ## ๐Ÿ“ Sample Outputs
102
+
103
+ **Prompt**: "Once upon a time, there was a little cat named Mittens"
104
+
105
+ **Temperature 0.7**: *Mittens was very hungry and wanted to eat some food. She went outside
106
+ to find some grass to eat. Mittens saw a big tree and decided to climb it. She climbed up and
107
+ up until she reached the top. As she was in the tree, she saw a small bird with a broken wing.
108
+ Mittens knew just what to do. She took the bird to her mom and asked for help.*
109
+
110
+ ## ๐Ÿ™ Credits
111
+
112
+ - **Architecture Reference**: Vizuara Team - Raj ([Tutorial](https://youtu.be/bLDlwcl6hbA))
113
+ - **Dataset**: TinyStories by Ronen Eldan & Yuanzhi Li
114
+ - **Tokenizer**: OpenAI tiktoken (GPT-2 BPE)
115
+
116
+ ## ๐Ÿ“„ License
117
+
118
+ Apache 2.0
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "Gemma3Custom",
3
+ "vocab_size": 50257,
4
+ "context_length": 32768,
5
+ "emb_dim": 640,
6
+ "n_layers": 18,
7
+ "n_heads": 4,
8
+ "head_dim": 256,
9
+ "hidden_dim": 2048,
10
+ "n_kv_groups": 1,
11
+ "qk_norm": true,
12
+ "query_pre_attn_scalar": 256,
13
+ "rope_base": 1000000.0,
14
+ "rope_local_base": 10000.0,
15
+ "sliding_window": 512,
16
+ "dtype": "bfloat16",
17
+ "total_parameters": 164631936
18
+ }
layer_types.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "layer_types": [
3
+ "sliding_attention",
4
+ "sliding_attention",
5
+ "sliding_attention",
6
+ "sliding_attention",
7
+ "sliding_attention",
8
+ "full_attention",
9
+ "sliding_attention",
10
+ "sliding_attention",
11
+ "sliding_attention",
12
+ "sliding_attention",
13
+ "sliding_attention",
14
+ "full_attention",
15
+ "sliding_attention",
16
+ "sliding_attention",
17
+ "sliding_attention",
18
+ "sliding_attention",
19
+ "sliding_attention",
20
+ "full_attention"
21
+ ]
22
+ }
loss_curves.png ADDED

Git LFS Details

  • SHA256: d316ec76582a2a1da379c9d24a174cb2eabb7b602333a49e237f09aea84b77b2
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
lr_schedule.png ADDED
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f93bc6fbab3eb74705e5a412b0b44eb031fdd1af4203bb1f6fbd8f92878c5f2c
3
+ size 329341959
training_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_iters": 60000,
3
+ "batch_size": 32,
4
+ "block_size": 128,
5
+ "gradient_accumulation_steps": 4,
6
+ "learning_rate": 0.0001,
7
+ "min_lr": 5e-05,
8
+ "warmup_steps": 1000,
9
+ "beta1": 0.9,
10
+ "beta2": 0.95,
11
+ "weight_decay": 0.1,
12
+ "gradient_clip_norm": 0.5,
13
+ "best_val_loss": 1.7845207452774048,
14
+ "best_iteration": 13000,
15
+ "perplexity": 5.96,
16
+ "dataset": "roneneldan/TinyStories",
17
+ "tokenizer": "gpt2 (tiktoken)",
18
+ "hardware": "NVIDIA A100 40GB"
19
+ }