Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +118 -0
- config.json +18 -0
- layer_types.json +22 -0
- loss_curves.png +3 -0
- lr_schedule.png +0 -0
- pytorch_model.bin +3 -0
- training_config.json +19 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
loss_curves.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- gemma3
|
| 5 |
+
- language-model
|
| 6 |
+
- pre-training
|
| 7 |
+
- from-scratch
|
| 8 |
+
- tinystories
|
| 9 |
+
- transformer
|
| 10 |
+
- multi-query-attention
|
| 11 |
+
- sliding-window-attention
|
| 12 |
+
- rope
|
| 13 |
+
language:
|
| 14 |
+
- en
|
| 15 |
+
datasets:
|
| 16 |
+
- roneneldan/TinyStories
|
| 17 |
+
metrics:
|
| 18 |
+
- perplexity
|
| 19 |
+
pipeline_tag: text-generation
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
# Gemma 3 270M โ Pre-trained from Scratch on TinyStories
|
| 23 |
+
|
| 24 |
+
A custom implementation of the **Gemma 3 architecture** (scaled to 164.6M parameters), pre-trained from scratch on the TinyStories dataset.
|
| 25 |
+
|
| 26 |
+
## ๐ Results
|
| 27 |
+
|
| 28 |
+
| Metric | Value |
|
| 29 |
+
|--------|-------|
|
| 30 |
+
| **Best Val Loss** | 1.7845 |
|
| 31 |
+
| **Perplexity** | 5.96 |
|
| 32 |
+
| **Best Iteration** | 13,000 |
|
| 33 |
+
| **Parameters** | 164.6M |
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
## ๐๏ธ Architecture
|
| 38 |
+
|
| 39 |
+
This model implements the **complete Gemma 3 architecture** with all modern innovations:
|
| 40 |
+
|
| 41 |
+
| Component | Specification |
|
| 42 |
+
|-----------|--------------|
|
| 43 |
+
| Layers | 18 (15 sliding + 3 full attention) |
|
| 44 |
+
| Embedding Dim | 640 |
|
| 45 |
+
| Attention Heads | 4 (Multi-Query, 1 KV group) |
|
| 46 |
+
| Head Dimension | 256 |
|
| 47 |
+
| FFN Hidden | 2,048 (GeGLU activation) |
|
| 48 |
+
| Context Length | 32,768 tokens |
|
| 49 |
+
| Vocabulary | 50,257 (GPT-2 BPE) |
|
| 50 |
+
|
| 51 |
+
### Key Features
|
| 52 |
+
- **Sliding Window Attention** (w=512): O(nรw) instead of O(nยฒ), 64ร cheaper
|
| 53 |
+
- **Multi-Query Attention**: All query heads share 1 K,V head โ 4ร less KV cache
|
| 54 |
+
- **RoPE with Dual Bases**: 10K (local patterns) + 1M (long-range dependencies)
|
| 55 |
+
- **QK Normalization**: RMSNorm on Q,K vectors before attention
|
| 56 |
+
- **Gemma-style RMSNorm**: (1 + weight) scaling for stable initialization
|
| 57 |
+
- **GeGLU Feed-Forward**: Gated GELU activation with 3.2ร expansion
|
| 58 |
+
|
| 59 |
+
### Layer Type Pattern
|
| 60 |
+
```
|
| 61 |
+
Layers 1-5: Sliding Attention (local, base=10K)
|
| 62 |
+
Layer 6: Full Attention (global, base=1M)
|
| 63 |
+
Layers 7-11: Sliding Attention (local, base=10K)
|
| 64 |
+
Layer 12: Full Attention (global, base=1M)
|
| 65 |
+
Layers 13-17: Sliding Attention (local, base=10K)
|
| 66 |
+
Layer 18: Full Attention (global, base=1M)
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## ๐ Training
|
| 70 |
+
|
| 71 |
+
- **Dataset**: TinyStories (2.1M stories, 471M tokens)
|
| 72 |
+
- **Tokenizer**: GPT-2 BPE via tiktoken (50,257 vocab)
|
| 73 |
+
- **Optimizer**: AdamW (ฮฒ1=0.9, ฮฒ2=0.95, ฮต=1e-9, weight_decay=0.1)
|
| 74 |
+
- **Learning Rate**: 1e-4 โ 5e-5 (cosine decay with 1K step warmup)
|
| 75 |
+
- **Precision**: bfloat16 mixed precision
|
| 76 |
+
- **Hardware**: NVIDIA A100 40GB (Google Colab Pro)
|
| 77 |
+
- **Gradient Clipping**: max_norm=0.5
|
| 78 |
+
|
| 79 |
+
## ๐ป Usage
|
| 80 |
+
```python
|
| 81 |
+
import torch
|
| 82 |
+
import tiktoken
|
| 83 |
+
|
| 84 |
+
# Load model (you'll need the model class definition)
|
| 85 |
+
model = Gemma3Model(config)
|
| 86 |
+
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
|
| 87 |
+
model.load_state_dict(state_dict)
|
| 88 |
+
model.eval()
|
| 89 |
+
|
| 90 |
+
# Tokenize
|
| 91 |
+
enc = tiktoken.get_encoding("gpt2")
|
| 92 |
+
prompt = "Once upon a time"
|
| 93 |
+
input_ids = torch.tensor([enc.encode_ordinary(prompt)])
|
| 94 |
+
|
| 95 |
+
# Generate
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
output = model.generate(input_ids, max_new_tokens=200, temperature=0.7)
|
| 98 |
+
print(enc.decode(output[0].tolist()))
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## ๐ Sample Outputs
|
| 102 |
+
|
| 103 |
+
**Prompt**: "Once upon a time, there was a little cat named Mittens"
|
| 104 |
+
|
| 105 |
+
**Temperature 0.7**: *Mittens was very hungry and wanted to eat some food. She went outside
|
| 106 |
+
to find some grass to eat. Mittens saw a big tree and decided to climb it. She climbed up and
|
| 107 |
+
up until she reached the top. As she was in the tree, she saw a small bird with a broken wing.
|
| 108 |
+
Mittens knew just what to do. She took the bird to her mom and asked for help.*
|
| 109 |
+
|
| 110 |
+
## ๐ Credits
|
| 111 |
+
|
| 112 |
+
- **Architecture Reference**: Vizuara Team - Raj ([Tutorial](https://youtu.be/bLDlwcl6hbA))
|
| 113 |
+
- **Dataset**: TinyStories by Ronen Eldan & Yuanzhi Li
|
| 114 |
+
- **Tokenizer**: OpenAI tiktoken (GPT-2 BPE)
|
| 115 |
+
|
| 116 |
+
## ๐ License
|
| 117 |
+
|
| 118 |
+
Apache 2.0
|
config.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture": "Gemma3Custom",
|
| 3 |
+
"vocab_size": 50257,
|
| 4 |
+
"context_length": 32768,
|
| 5 |
+
"emb_dim": 640,
|
| 6 |
+
"n_layers": 18,
|
| 7 |
+
"n_heads": 4,
|
| 8 |
+
"head_dim": 256,
|
| 9 |
+
"hidden_dim": 2048,
|
| 10 |
+
"n_kv_groups": 1,
|
| 11 |
+
"qk_norm": true,
|
| 12 |
+
"query_pre_attn_scalar": 256,
|
| 13 |
+
"rope_base": 1000000.0,
|
| 14 |
+
"rope_local_base": 10000.0,
|
| 15 |
+
"sliding_window": 512,
|
| 16 |
+
"dtype": "bfloat16",
|
| 17 |
+
"total_parameters": 164631936
|
| 18 |
+
}
|
layer_types.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"layer_types": [
|
| 3 |
+
"sliding_attention",
|
| 4 |
+
"sliding_attention",
|
| 5 |
+
"sliding_attention",
|
| 6 |
+
"sliding_attention",
|
| 7 |
+
"sliding_attention",
|
| 8 |
+
"full_attention",
|
| 9 |
+
"sliding_attention",
|
| 10 |
+
"sliding_attention",
|
| 11 |
+
"sliding_attention",
|
| 12 |
+
"sliding_attention",
|
| 13 |
+
"sliding_attention",
|
| 14 |
+
"full_attention",
|
| 15 |
+
"sliding_attention",
|
| 16 |
+
"sliding_attention",
|
| 17 |
+
"sliding_attention",
|
| 18 |
+
"sliding_attention",
|
| 19 |
+
"sliding_attention",
|
| 20 |
+
"full_attention"
|
| 21 |
+
]
|
| 22 |
+
}
|
loss_curves.png
ADDED
|
Git LFS Details
|
lr_schedule.png
ADDED
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f93bc6fbab3eb74705e5a412b0b44eb031fdd1af4203bb1f6fbd8f92878c5f2c
|
| 3 |
+
size 329341959
|
training_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_iters": 60000,
|
| 3 |
+
"batch_size": 32,
|
| 4 |
+
"block_size": 128,
|
| 5 |
+
"gradient_accumulation_steps": 4,
|
| 6 |
+
"learning_rate": 0.0001,
|
| 7 |
+
"min_lr": 5e-05,
|
| 8 |
+
"warmup_steps": 1000,
|
| 9 |
+
"beta1": 0.9,
|
| 10 |
+
"beta2": 0.95,
|
| 11 |
+
"weight_decay": 0.1,
|
| 12 |
+
"gradient_clip_norm": 0.5,
|
| 13 |
+
"best_val_loss": 1.7845207452774048,
|
| 14 |
+
"best_iteration": 13000,
|
| 15 |
+
"perplexity": 5.96,
|
| 16 |
+
"dataset": "roneneldan/TinyStories",
|
| 17 |
+
"tokenizer": "gpt2 (tiktoken)",
|
| 18 |
+
"hardware": "NVIDIA A100 40GB"
|
| 19 |
+
}
|