AndriejusNak committed on
Commit
3015835
·
verified ·
1 Parent(s): c527a1c

Add comprehensive model card

Files changed (1)
  1. README.md +175 -0
README.md ADDED
@@ -0,0 +1,175 @@
1
+ ---
2
+ language:
3
+ - lt
4
+ - en
5
+ license: other
6
+ license_name: gemma
7
+ license_link: https://ai.google.dev/gemma/terms
8
+ base_model: google/gemma-4-26B-A4B-it
9
+ tags:
10
+ - gemma
11
+ - moe
12
+ - mixture-of-experts
13
+ - qlora
14
+ - nf4
15
+ - fine-tuning
16
+ - dpo
17
+ - ops-agent
18
+ - infrastructure
19
+ - tool-calling
20
+ - single-gpu
21
+ - rtx-3090
22
+ datasets:
23
+ - custom
24
+ pipeline_tag: text-generation
25
+ library_name: peft
26
+ ---
27
+
28
+ # Gemma 4 26B MoE Fine-Tuning on a Single RTX 3090
29
+
30
+ > Training a 26-billion-parameter Mixture-of-Experts model on consumer hardware,
31
+ > and making it work.
32
+
33
+ [![GitHub](https://img.shields.io/badge/GitHub-Pipeline_Code-blue?logo=github)](https://github.com/AndriejusNak/gemma4-26b-moe-finetune)
34
+
35
+ ## The Challenge
36
+
37
+ Google's `gemma-4-26B-A4B-it` is a 25.81B-parameter MoE model with 30 transformer layers, each holding 128 experts with top-8 routing. Conventional wisdom says fine-tuning it requires multiple A100s or H100s. We did it on a single NVIDIA RTX 3090 (24 GB VRAM).
38
+
39
+ ## Model Details
40
+
41
+ | | |
42
+ |---|---|
43
+ | Base Model | [google/gemma-4-26B-A4B-it](https://huggingface.co/google/gemma-4-26B-A4B-it) |
44
+ | Total Parameters | 25.81B |
45
+ | Active per Token | ~4B (top-8 of 128 experts) |
46
+ | PEFT Method | LoRA (r=16, α=16) on attention (q,k,v,o_proj) |
47
+ | Quantization | Custom per-expert NF4 (manual quantize/dequantize) |
48
+ | Training Hardware | Single RTX 3090 24GB |
49
+ | Peak VRAM | 18.4 GB |
50
+ | Framework | PyTorch + PEFT 0.18.1 + bitsandbytes |
51
+
52
+ ## Training Results
53
+
54
+ ### SFT Phase: COMPLETE ✅
55
+
56
+ | Metric | Value |
57
+ |--------|-------|
58
+ | Training time | 687 min (11.45 hours) |
59
+ | Best loss | **1.2146** (epoch 2 average) |
60
+ | Lowest step loss | 1.1653 (step 700) |
61
+ | Total steps | 750 (2 epochs) |
62
+ | NaN explosions | **0** |
63
+ | Training data | 6,119 examples (multi-source) |
64
+
65
+ <details>
66
+ <summary>Full SFT Loss Curve</summary>
67
+
68
+ ```
69
+ Step 50 | Loss: 6.8887 | LR: 2.00e-5 | VRAM: 18.0GB
70
+ Step 100 | Loss: 2.8130 | LR: 1.96e-5 | VRAM: 18.4GB (-59%)
71
+ Step 150 | Loss: 1.7163 | LR: 1.88e-5 | VRAM: 18.4GB (-39%)
72
+ Step 200 | Loss: 1.5580 | LR: 1.75e-5 | VRAM: 18.4GB (-9%)
73
+ Step 250 | Loss: 1.4741 | LR: 1.59e-5 | VRAM: 17.9GB (-5.4%)
74
+ Step 300 | Loss: 1.3663 | LR: 1.40e-5 | VRAM: 18.3GB (-7.3%)
75
+ Step 350 | Loss: 1.3284 | LR: 1.19e-5 | VRAM: 18.0GB (-2.8%)
76
+ --- Epoch 1 avg: 2.3724 --- Saved! ---
77
+ Step 400 | Loss: 1.2911 | LR: 9.74e-6 | VRAM: 18.4GB (-2.8%)
78
+ Step 450 | Loss: 1.2570 | LR: 7.56e-6 | VRAM: 18.0GB (-2.6%)
79
+ Step 500 | Loss: 1.2168 | LR: 5.50e-6 | VRAM: 18.4GB (-3.2%)
80
+ Step 550 | Loss: 1.2311 | LR: 3.66e-6 | VRAM: 18.4GB (+1.2%)
81
+ Step 600 | Loss: 1.2355 | LR: 2.12e-6 | VRAM: 18.4GB (+0.4%)
82
+ Step 650 | Loss: 1.2046 | LR: 2.00e-6 | VRAM: 18.3GB (-2.5%)
83
+ Step 700 | Loss: 1.1653 | LR: 2.00e-6 | VRAM: 18.0GB (-3.3%) ← lowest
84
+ Step 750 | Loss: 1.1688 | LR: 2.00e-6 | VRAM: 18.4GB (+0.3%)
85
+ --- Epoch 2 avg: 1.2146 --- BEST! Saved! ---
86
+ ```
87
+ </details>
88
+
89
+ ### DPO Phase: IN PROGRESS 🔄
90
+
91
+ | Metric | Value |
92
+ |--------|-------|
93
+ | Dataset | 2,708 pairs (30% HARD augmented) |
94
+ | Reference cache | Pre-computed in 79.3 min |
95
+ | Beta | 0.1 |
96
+ | Current loss | 0.7308 (step 50/507, sweet spot) |
97
+ | NaN | **0** |
98
+
99
+ ## Key Technical Innovations
100
+
101
+ ### 1. Custom NF4 Expert Quantization
102
+
103
+ Standard `bitsandbytes` `load_in_4bit` **cannot quantize** MoE expert weights because they're stored as `nn.Parameter`, not `nn.Linear`. We quantize each of the 11,520 expert matrices (30 layers × 128 experts × 3 matrices per expert) individually.
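
To make the per-matrix step concrete, here is a minimal pure-Python sketch of blockwise NF4 quantize/dequantize. It is an illustration under stated assumptions, not the code in `v6_26b_pipeline.py`: the function names and the block size of 64 are assumptions, the 4-bit codes are kept as plain integers rather than packed two per byte, and a real implementation would operate on GPU tensors. The 16 level values are the NF4 code book published with QLoRA.

```python
import random

# NF4 code book: 16 levels in [-1, 1] (published with QLoRA, used by bitsandbytes).
NF4_LEVELS = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def quantize_nf4(weights, block_size=64):
    """Return (codes, scales): one 4-bit index per weight, one absmax per block."""
    codes, scales = [], []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        absmax = max(abs(w) for w in block) or 1.0  # all-zero block: avoid 0-division
        scales.append(absmax)
        for w in block:
            x = w / absmax  # normalised into [-1, 1]
            codes.append(min(range(16), key=lambda i: abs(NF4_LEVELS[i] - x)))
    return codes, scales


def dequantize_nf4(codes, scales, block_size=64):
    """Look up each code and rescale by its block's absmax."""
    return [NF4_LEVELS[c] * scales[i // block_size] for i, c in enumerate(codes)]


# Round-trip a flattened toy "expert matrix" (hypothetical data).
rng = random.Random(0)
weights = [rng.gauss(0.0, 0.02) for _ in range(256)]
codes, scales = quantize_nf4(weights)
restored = dequantize_nf4(codes, scales)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"max abs round-trip error: {max_err:.5f}")
```

Because each expert matrix is handled on its own, the same two functions can be applied in a loop over all expert parameters, dequantizing only the experts a batch actually routes to.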
104
+
105
+ ### 2. Token-Centric Expert Forward Pass
106
+
107
+ Instead of a Python `for` loop over 128 experts per layer, we batch tokens by their routed expert using `unique_consecutive`, achieving a 2-4× speedup with saturated GPU utilization.
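
The idea can be sketched without a GPU. In the toy below, `itertools.groupby` over sorted assignments stands in for sorting expert indices and calling `torch.unique_consecutive`; `expert_fn` is a hypothetical stand-in for an expert MLP, and the routing data is made up. It shows only that grouping by expert gives the same result as the per-expert loop while visiting only experts that received tokens.

```python
from itertools import groupby


def expert_fn(expert_id, x):
    """Toy stand-in for an expert MLP: scale the token vector by (expert_id + 1)."""
    return [(expert_id + 1) * v for v in x]


def moe_naive(tokens, routes, num_experts):
    """Loop over every expert and scan all tokens each time."""
    out = [[0.0] * len(t) for t in tokens]
    for e in range(num_experts):
        for t, route in enumerate(routes):
            for rid, w in route:
                if rid == e:
                    y = expert_fn(e, tokens[t])
                    out[t] = [o + w * v for o, v in zip(out[t], y)]
    return out


def moe_token_centric(tokens, routes):
    """Sort (expert, token, weight) assignments, then handle each expert's run once."""
    assignments = sorted(
        (rid, t, w) for t, route in enumerate(routes) for rid, w in route
    )
    out = [[0.0] * len(t) for t in tokens]
    for e, run in groupby(assignments, key=lambda a: a[0]):
        # In a tensor implementation this run would be one batched matmul.
        for _, t, w in run:
            y = expert_fn(e, tokens[t])
            out[t] = [o + w * v for o, v in zip(out[t], y)]
    return out


# Three tokens, top-2 routing over 128 experts (hypothetical routing weights).
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
routes = [[(0, 0.5), (2, 0.5)], [(2, 0.25), (1, 0.75)], [(1, 0.5), (0, 0.5)]]
assert moe_token_centric(tokens, routes) == moe_naive(tokens, routes, num_experts=128)
```

The naive version performs 128 outer iterations regardless of routing; the grouped version touches only the three experts that appear in `routes`.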
108
+
109
+ ### 3. DPO with CPU-Cached Reference Model
110
+
111
+ DPO normally requires 4 forward passes per pair (policy and reference, each on chosen and rejected). We pre-compute all reference logprobs once (79.3 min), then train with only the 2 policy forward passes: **2× faster**, and it fits in 24 GB.
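
A minimal sketch of how a cached reference plugs into the standard DPO objective, loss = -log σ(β · [(π_chosen - ref_chosen) - (π_rejected - ref_rejected)]). The cache layout (`ref_cache` keyed by a pair id) and the function names are assumptions for illustration; the log-prob values are made up.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO loss for one pair of sequence log-probabilities."""
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return -log_sigmoid(beta * margin)


# Phase 1 (run once): reference log-probs are computed and stored off-GPU.
# Hypothetical cache contents: pair id -> (ref_chosen, ref_rejected).
ref_cache = {"pair-0": (-10.0, -12.0)}


def training_step_loss(pair_id, policy_chosen, policy_rejected, beta=0.1):
    """Phase 2: only the two policy log-probs need a forward pass."""
    ref_chosen, ref_rejected = ref_cache[pair_id]
    return dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta)


# Before any update the policy equals the reference, so the loss is ln 2.
assert abs(training_step_loss("pair-0", -10.0, -12.0) - math.log(2)) < 1e-12
```

The ln 2 ≈ 0.693 starting point is a useful reference when reading DPO loss values early in training.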
112
+
113
+ ### 4. HARD Adversarial DPO Training
114
+
115
+ 13 categories of adversarial training pairs including tool confusion, overthinking, hallucination guards, wrong parameters, and edge cases. Tool corruption augmentation teaches reasoning correctness, not just output quality.
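
One way a tool-corruption augmentation could work is sketched below. Everything here is hypothetical: the tool names, the dictionary layout of a call, and the two corruption modes (swapped tool name, dropped argument) are illustrative assumptions, not the 13 categories or the actual generator in the pipeline.

```python
import copy
import random


def corrupt_tool_call(call, tool_names, rng):
    """Return a corrupted copy of a tool call to use as the 'rejected' side.

    Either swaps in a different tool name (tool confusion) or drops one
    argument (wrong parameters). The input call is left untouched. If there
    is no other tool and no argument to drop, the copy is returned unchanged.
    """
    bad = copy.deepcopy(call)
    other_tools = [n for n in tool_names if n != call["name"]]
    if other_tools and (not bad["arguments"] or rng.random() < 0.5):
        bad["name"] = rng.choice(other_tools)
    elif bad["arguments"]:
        key = rng.choice(sorted(bad["arguments"]))
        del bad["arguments"][key]
    return bad


# Hypothetical ops-agent tools and a correct ("chosen") call.
TOOLS = ["restart_service", "check_disk", "tail_logs"]
chosen = {"name": "restart_service", "arguments": {"service": "nginx", "host": "web-1"}}
rejected = corrupt_tool_call(chosen, TOOLS, random.Random(0))
pair = {"chosen": chosen, "rejected": rejected}
assert pair["chosen"] != pair["rejected"]
```

Pairing a correct call with a minimally corrupted one makes the preference signal depend on the tool choice and its parameters rather than on surface style.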
116
+
117
+ ### 5. Iterative DPO Pipeline
118
+
119
+ Post-training failure analysis automatically generates targeted HARD pairs for the next iteration, with scripts for confusion matrix analysis, failure classification, and dataset building.
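
As a rough illustration of the confusion-matrix step, the sketch below counts (expected tool, predicted tool) pairs from evaluation records and ranks the most frequent mistakes, which could then seed the next round of HARD pairs. The record fields and tool names are assumptions; the pipeline's own scripts may be organised differently.

```python
from collections import Counter


def tool_confusion(records):
    """Count (expected_tool, predicted_tool) pairs across evaluation records."""
    return Counter((r["expected"], r["predicted"]) for r in records)


def top_confusions(matrix, n=3):
    """Most frequent off-diagonal entries, i.e. the mistakes worth targeting next."""
    errors = {pair: count for pair, count in matrix.items() if pair[0] != pair[1]}
    return sorted(errors.items(), key=lambda kv: (-kv[1], kv[0]))[:n]


# Hypothetical evaluation results.
records = [
    {"expected": "check_disk", "predicted": "check_disk"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "restart_service", "predicted": "check_disk"},
]
for (expected, predicted), count in top_confusions(tool_confusion(records)):
    print(f"{expected} -> {predicted}: {count}")
```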
120
+
121
+ ## VRAM Budget
122
+
123
+ ```
124
+ Expert NF4 (30 layers × 128 experts): 11.3 GiB
125
+ Attention BF16 (30 layers): 2.4 GiB
126
+ Embedding + LM head BF16: 2.9 GiB
127
+ CUDA context: 0.5 GiB
128
+ ────────────────────────────────────────────────────
129
+ Model total: 17.1 GiB
130
+ LoRA weights (r=16): ~22 MB
131
+ Activations (seq=512, batch=1): 1-3 GiB
132
+ ────────────────────────────────────────────────────
133
+ Training peak: 18.4 GiB ✅
134
+ ```
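
The budget above can be checked with a few lines of arithmetic. The figures are copied from the table; the only assumption is that the card's 24 GB is treated as 24 GiB when computing headroom.

```python
# Component sizes in GiB, taken from the VRAM budget above.
budget_gib = {
    "expert_nf4": 11.3,
    "attention_bf16": 2.4,
    "embedding_lm_head_bf16": 2.9,
    "cuda_context": 0.5,
}

model_total = round(sum(budget_gib.values()), 1)   # resident model footprint
training_peak = 18.4                               # reported peak during training
training_overhead = round(training_peak - model_total, 1)  # LoRA + activations
headroom = round(24.0 - training_peak, 1)          # assumes 24 GB card == 24 GiB

assert model_total == 17.1
assert 1.0 <= training_overhead <= 3.0  # consistent with the 1-3 GiB activation estimate
print(f"model {model_total} GiB, overhead {training_overhead} GiB, headroom {headroom} GiB")
```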
135
+
136
+ ## Technical Challenges Solved
137
+
138
+ | # | Problem | Solution |
139
+ |---|---------|----------|
140
+ | 1 | bitsandbytes can't quantize nn.Parameter | Manual per-expert NF4 |
141
+ | 2 | MoE expert loop is pure Python (slow) | Token-centric batching (2-4× speedup) |
142
+ | 3 | DPO doubles VRAM (need ref model) | CPU-cached reference logprobs |
143
+ | 4 | NaN loss explosions | All-masked filter + FP32 loss + NaN skip |
144
+ | 5 | DPO length bias (.sum() rewards) | Changed to .mean() |
145
+ | 6 | Multi-turn masking trained on all turns | Mask all except LAST assistant response |
146
+ | 7 | PEFT set_adapter() rejects lists | Bypass via model.base_model.set_adapter() |
147
+ | 8 | 26B model in 24 GB | Gradient checkpointing + NF4 + attention LoRA |
148
+ | 9 | Ollama safetensors converter crashes | llama.cpp convert_hf_to_gguf.py pipeline |
149
+ | 10 | HARD overfitting | Prefix augmentation + dynamic generation |
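
Rows 4 to 6 of the table can be illustrated together. The sketch below is an assumption-laden toy, not the pipeline's code: it works on per-token role labels and per-token log-probs as plain lists, and the helper names are hypothetical. It builds a mask that keeps only the last assistant turn, returns `None` for all-masked examples so the caller can skip them, and contrasts `sum` with `mean` reduction.

```python
def last_assistant_mask(roles):
    """1 for tokens in the final assistant turn, 0 everywhere else."""
    mask = [0] * len(roles)
    i = len(roles) - 1
    while i >= 0 and roles[i] != "assistant":  # skip any trailing non-assistant tokens
        i -= 1
    while i >= 0 and roles[i] == "assistant":  # mark the last assistant run
        mask[i] = 1
        i -= 1
    return mask


def sequence_logprob(token_logprobs, mask, reduce="mean"):
    """Reduce masked token log-probs; None signals an all-masked example to skip."""
    kept = [lp for lp, m in zip(token_logprobs, mask) if m]
    if not kept:
        return None
    total = sum(kept)
    return total / len(kept) if reduce == "mean" else total


roles = ["user", "assistant", "assistant", "user", "assistant", "assistant", "assistant"]
assert last_assistant_mask(roles) == [0, 0, 0, 0, 1, 1, 1]

# Same per-token quality, different lengths: sum penalises the longer answer, mean does not.
short, long_ = [-1.0] * 2, [-1.0] * 8
assert sequence_logprob(short, [1] * 2, "sum") > sequence_logprob(long_, [1] * 8, "sum")
assert sequence_logprob(short, [1] * 2, "mean") == sequence_logprob(long_, [1] * 8, "mean")
```

With `sum`, a DPO reward difference can be driven by response length alone; `mean` removes that effect at the cost of ignoring legitimate length differences.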
150
+
151
+ ## Usage
152
+
153
+ This is an SFT LoRA adapter. After DPO completes, a merged DPO adapter will also be uploaded.
154
+
155
+ ```python
156
+ from peft import PeftModel
157
+ from transformers import AutoModelForCausalLM, AutoTokenizer
158
+
159
+ base = AutoModelForCausalLM.from_pretrained("google/gemma-4-26B-A4B-it")
160
+ model = PeftModel.from_pretrained(base, "AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
161
+ tokenizer = AutoTokenizer.from_pretrained("AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
162
+ ```
163
+
164
+ > **Note**: Loading requires the custom NF4 quantization code from the training pipeline. See `v6_26b_pipeline.py` for implementation.
165
+
166
+ ## Files
167
+
168
+ | File | Description |
169
+ |------|-------------|
170
+ | `sft_adapter/` | SFT LoRA adapter (r=16, attention-only) |
171
+ | `v6_26b_pipeline.py` | Full 6-phase training pipeline (~1700 lines) |
172
+
173
+ ## License
174
+
175
+ Training pipeline code: MIT. Base model: [Google Gemma License](https://ai.google.dev/gemma/terms).