Add comprehensive model card
README.md
ADDED
@@ -0,0 +1,175 @@
---
language:
- lt
- en
license: other
license_name: gemma
license_link: https://ai.google.dev/gemma/terms
base_model: google/gemma-4-26B-A4B-it
tags:
- gemma
- moe
- mixture-of-experts
- qlora
- nf4
- fine-tuning
- dpo
- ops-agent
- infrastructure
- tool-calling
- single-gpu
- rtx-3090
datasets:
- custom
pipeline_tag: text-generation
library_name: peft
---

# Gemma 4 26B MoE Fine-Tuning on a Single RTX 3090

> Training a 26-billion-parameter Mixture-of-Experts model on consumer hardware,
> and making it work.

[GitHub repository](https://github.com/AndriejusNak/gemma4-26b-moe-finetune)

## The Challenge

Google's `gemma-4-26B-A4B-it` is a 25.81B-parameter MoE model with 30 transformer layers, each holding 128 experts with top-8 routing. Conventional wisdom says you need multiple A100s or H100s to fine-tune it. We did it on a single NVIDIA RTX 3090 (24 GB VRAM).

## Model Details

| Property | Value |
|---|---|
| Base Model | [google/gemma-4-26B-A4B-it](https://huggingface.co/google/gemma-4-26B-A4B-it) |
| Total Parameters | 25.81B |
| Active per Token | ~4B (top-8 of 128 experts) |
| PEFT Method | LoRA (r=16, α=16) on attention (`q_proj`, `k_proj`, `v_proj`, `o_proj`) |
| Quantization | Custom per-expert NF4 (manual quantize/dequantize) |
| Training Hardware | Single RTX 3090 24GB |
| Peak VRAM | 18.4 GB |
| Framework | PyTorch + PEFT 0.18.1 + bitsandbytes |

## Training Results

### SFT Phase: COMPLETE

| Metric | Value |
|--------|-------|
| Training time | 687 min (11.45 hours) |
| Best loss | **1.2146** (epoch 2 average) |
| Lowest step loss | 1.1653 (step 700) |
| Total steps | 750 (2 epochs) |
| NaN explosions | **0** |
| Training data | 6,119 examples (multi-source) |

<details>
<summary>Full SFT Loss Curve</summary>

```
Step  50 | Loss: 6.8887 | LR: 2.00e-5 | VRAM: 18.0GB
Step 100 | Loss: 2.8130 | LR: 1.96e-5 | VRAM: 18.4GB (-59%)
Step 150 | Loss: 1.7163 | LR: 1.88e-5 | VRAM: 18.4GB (-39%)
Step 200 | Loss: 1.5580 | LR: 1.75e-5 | VRAM: 18.4GB (-9%)
Step 250 | Loss: 1.4741 | LR: 1.59e-5 | VRAM: 17.9GB (-5.4%)
Step 300 | Loss: 1.3663 | LR: 1.40e-5 | VRAM: 18.3GB (-7.3%)
Step 350 | Loss: 1.3284 | LR: 1.19e-5 | VRAM: 18.0GB (-2.8%)
--- Epoch 1 avg: 2.3724 --- Saved! ---
Step 400 | Loss: 1.2911 | LR: 9.74e-6 | VRAM: 18.4GB (-2.8%)
Step 450 | Loss: 1.2570 | LR: 7.56e-6 | VRAM: 18.0GB (-2.6%)
Step 500 | Loss: 1.2168 | LR: 5.50e-6 | VRAM: 18.4GB (-3.2%)
Step 550 | Loss: 1.2311 | LR: 3.66e-6 | VRAM: 18.4GB (+1.2%)
Step 600 | Loss: 1.2355 | LR: 2.12e-6 | VRAM: 18.4GB (+0.4%)
Step 650 | Loss: 1.2046 | LR: 2.00e-6 | VRAM: 18.3GB (-2.5%)
Step 700 | Loss: 1.1653 | LR: 2.00e-6 | VRAM: 18.0GB (-3.3%) <- lowest
Step 750 | Loss: 1.1688 | LR: 2.00e-6 | VRAM: 18.4GB (+0.3%)
--- Epoch 2 avg: 1.2146 --- BEST! Saved! ---
```
</details>

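The percentages in parentheses at the end of each log line sit next to the VRAM column, but they match the relative change in loss since the previous logged step. A small sketch to check that reading against a few logged values (the helper name `pct_change` is ours, not from the training pipeline):

```python
# (step, loss) pairs copied from the SFT log above.
log = [(50, 6.8887), (100, 2.8130), (150, 1.7163), (200, 1.5580)]


def pct_change(prev, cur):
    """Relative change of `cur` with respect to `prev`, in percent."""
    return (cur - prev) / prev * 100.0


for (_, prev_loss), (step, loss) in zip(log, log[1:]):
    print(f"step {step}: {pct_change(prev_loss, loss):+.1f}%")
```

The printed values round to the -59%, -39% and -9% shown in the log.
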
### DPO Phase: IN PROGRESS

| Metric | Value |
|--------|-------|
| Dataset | 2,708 pairs (30% HARD augmented) |
| Reference cache | Pre-computed in 79.3 min |
| Beta | 0.1 |
| Current loss | 0.7308 (step 50/507, sweet spot) |
| NaN explosions | **0** |

## Key Technical Innovations

### 1. Custom NF4 Expert Quantization

Standard `bitsandbytes` `load_in_4bit` **cannot quantize** MoE expert weights because they are stored as `nn.Parameter`, not `nn.Linear`. We quantize each of the 11,520 expert matrices individually.

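To illustrate what a manual quantize/dequantize round trip involves, here is a minimal pure-Python sketch of blockwise absmax NF4 on a flat list of weights. It is not the code from `v6_26b_pipeline.py`: the function names and the block size of 64 are assumptions, and a real implementation would run on GPU tensors and pack two 4-bit codes per byte. The 16 level values are the NF4 code book published with QLoRA.

```python
# The 16 NF4 levels (normalized to [-1, 1]) from the QLoRA code book.
NF4_LEVELS = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def quantize_nf4(weights, blocksize=64):
    """Return (codes, absmaxes): one 4-bit code per weight, one scale per block."""
    codes, absmaxes = [], []
    for start in range(0, len(weights), blocksize):
        block = weights[start:start + blocksize]
        absmax = max(abs(w) for w in block) or 1.0  # guard against all-zero blocks
        absmaxes.append(absmax)
        for w in block:
            x = w / absmax  # scaled into [-1, 1]
            # Pick the index of the nearest NF4 level.
            codes.append(min(range(16), key=lambda i: abs(NF4_LEVELS[i] - x)))
    return codes, absmaxes


def dequantize_nf4(codes, absmaxes, blocksize=64):
    """Map each code back to its level and rescale by the block's absmax."""
    return [NF4_LEVELS[c] * absmaxes[i // blocksize] for i, c in enumerate(codes)]
```

Applied once per expert matrix, a routine of this shape needs only the raw weight values, so it does not care whether the weights live in an `nn.Linear` or a bare `nn.Parameter`.
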
### 2. Token-Centric Expert Forward Pass

Instead of a Python `for` loop over all 128 experts per layer, we batch tokens by their routed expert using `unique_consecutive`, achieving a 2-4× speedup with saturated GPU utilization.

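The grouping idea can be shown without a GPU. The sketch below is a plain-Python analog, not the pipeline's code: it sorts (expert, token) pairs so that equal expert ids become adjacent, then uses `itertools.groupby` where a tensor implementation would presumably call `torch.unique_consecutive` on the sorted ids. The function name is hypothetical.

```python
from itertools import groupby


def group_tokens_by_expert(topk_experts):
    """topk_experts[t] lists the expert ids routed for token t.

    Returns {expert_id: [token indices]}, containing only experts that
    were actually selected, so unused experts are never visited.
    """
    # Flatten to (expert, token) pairs and sort so equal expert ids are adjacent.
    pairs = sorted((e, t) for t, experts in enumerate(topk_experts) for e in experts)
    # Each run of identical expert ids becomes one batch of tokens.
    return {e: [t for _, t in run] for e, run in groupby(pairs, key=lambda p: p[0])}


# Three tokens with top-2 routing (toy values).
routing = [[3, 7], [7, 1], [3, 1]]
print(group_tokens_by_expert(routing))  # {1: [1, 2], 3: [0, 2], 7: [0, 1]}
```

Each expert then processes its tokens in one batched matrix multiply, and the loop runs over the selected experts only.
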
### 3. DPO with CPU-Cached Reference Model

DPO normally requires 4 forward passes per pair: (policy + reference) × (chosen + rejected). We pre-compute all reference logprobs once (79.3 min) and cache them on the CPU, then train with only the 2 policy forward passes. This is **2× faster** and fits in 24 GB.

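As a sketch of how cached reference values enter the loss, the function below evaluates the standard DPO objective for one pair from four scalar log-probabilities. Only `beta=0.1` comes from this card. The cache layout and all numeric values are made up for illustration.

```python
import math


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO loss, -log(sigmoid(beta * margin)), for a single pair."""
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    z = beta * margin
    # -log(sigmoid(z)) == softplus(-z), written in a numerically stable form.
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


# Hypothetical cache: example index -> (ref_chosen, ref_rejected), computed once.
ref_cache = {0: (-0.80, -0.95)}

# During training only the policy is run; the reference values are looked up.
ref_c, ref_r = ref_cache[0]
print(dpo_loss(policy_chosen=-0.70, policy_rejected=-1.10,
               ref_chosen=ref_c, ref_rejected=ref_r))
```

When the policy equals the reference the margin is zero and the loss is ln 2 ≈ 0.693, which is where an untrained DPO run starts.
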
### 4. HARD Adversarial DPO Training

The DPO data includes 13 categories of adversarial training pairs, among them tool confusion, overthinking, hallucination guards, wrong parameters, and edge cases. Tool corruption augmentation teaches reasoning correctness, not just output quality.

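The card does not describe how tool corruption is implemented. One plausible form, sketched here under that caveat, builds the rejected response by swapping the correct tool name for a different one while leaving the arguments intact. The tool names, the dictionary layout and both function names are hypothetical.

```python
import copy
import random

# Hypothetical tool inventory for an ops agent.
TOOLS = ["check_disk", "restart_service", "tail_logs"]


def corrupt_tool_call(call, tool_names, rng):
    """Return a copy of `call` whose tool name is replaced by a wrong one."""
    wrong = rng.choice([name for name in tool_names if name != call["name"]])
    bad = copy.deepcopy(call)  # never mutate the correct call
    bad["name"] = wrong
    return bad


def make_pair(prompt, good_call, tool_names, rng):
    """Build one preference pair: correct call as chosen, corrupted call as rejected."""
    return {
        "prompt": prompt,
        "chosen": good_call,
        "rejected": corrupt_tool_call(good_call, tool_names, rng),
    }


rng = random.Random(0)  # seeded so the augmentation is reproducible
good = {"name": "check_disk", "arguments": {"host": "web-1"}}
print(make_pair("Disk usage alert on web-1", good, TOOLS, rng))
```

Because chosen and rejected differ only in the tool decision, the preference signal targets that decision rather than surface wording.
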
### 5. Iterative DPO Pipeline

Post-training failure analysis automatically generates targeted HARD pairs for the next iteration, with scripts for confusion matrix analysis, failure classification, and dataset building.

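A confusion matrix analysis of tool choices can be as small as a counter over (expected, predicted) pairs. The sketch below is an assumption about the general shape of such a script, not the repository's code. The record fields and tool names are hypothetical.

```python
from collections import Counter


def tool_confusion(results):
    """Count (expected_tool, predicted_tool) pairs from evaluation records."""
    return Counter((r["expected"], r["predicted"]) for r in results)


def worst_confusions(confusion, k=3):
    """Most frequent off-diagonal cells, i.e. candidates for new HARD pairs."""
    errors = Counter({pair: n for pair, n in confusion.items() if pair[0] != pair[1]})
    return errors.most_common(k)


# Toy evaluation records.
results = [
    {"expected": "check_disk", "predicted": "check_disk"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "restart_service", "predicted": "check_disk"},
]
print(worst_confusions(tool_confusion(results)))
```

The top off-diagonal cells indicate which tool confusions the next round of HARD pairs should target.
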
## VRAM Budget

```
Expert NF4 (30 layers × 128 experts):   11.3 GiB
Attention BF16 (30 layers):              2.4 GiB
Embedding + LM head BF16:                2.9 GiB
CUDA context:                            0.5 GiB
--------------------------------------------------
Model total:                            17.1 GiB
LoRA weights (r=16):                     ~22 MB
Activations (seq=512, batch=1):          1-3 GiB
--------------------------------------------------
Training peak:                          18.4 GiB
```

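The totals in the budget can be checked with a few lines of arithmetic. The sketch below uses only the figures listed above. Treating the card's 24 GB as 24 GiB for the headroom line is an assumption.

```python
# Static model components, in GiB, as listed in the budget.
budget_gib = {
    "Expert NF4": 11.3,
    "Attention BF16": 2.4,
    "Embedding + LM head BF16": 2.9,
    "CUDA context": 0.5,
}

model_total = sum(budget_gib.values())
peak = 18.4  # reported training peak, GiB

print(f"model total:               {model_total:.1f} GiB")
print(f"overhead at training peak: {peak - model_total:.1f} GiB")
print(f"headroom on a 24 GiB card: {24 - peak:.1f} GiB")
```

The 1.3 GiB difference between the model total and the reported peak falls inside the 1-3 GiB range given for activations.
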
## Technical Challenges Solved

| # | Problem | Solution |
|---|---------|----------|
| 1 | bitsandbytes can't quantize `nn.Parameter` | Manual per-expert NF4 |
| 2 | MoE expert loop is pure Python (slow) | Token-centric batching (2-4× speedup) |
| 3 | DPO doubles VRAM (needs a reference model) | CPU-cached reference logprobs |
| 4 | NaN loss explosions | All-masked filter + FP32 loss + NaN skip |
| 5 | DPO length bias (`.sum()` rewards) | Changed to `.mean()` |
| 6 | Multi-turn masking trained on all turns | Mask all except LAST assistant response |
| 7 | PEFT `set_adapter()` rejects lists | Bypass via `model.base_model.set_adapter()` |
| 8 | 26B model in 24 GB | Gradient checkpointing + NF4 + attention LoRA |
| 9 | Ollama safetensors converter crashes | llama.cpp `convert_hf_to_gguf.py` pipeline |
| 10 | HARD overfitting | Prefix augmentation + dynamic generation |

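Rows 4 and 6 are related: if only the last assistant response is supervised, an example without any assistant turn ends up fully masked, and a mean loss over zero supervised tokens is undefined. The sketch below shows one way to build such labels and filter those examples. It assumes the common convention that label `-100` is ignored by the cross-entropy loss. The function names and the `(role, token_ids)` layout are hypothetical.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy by default


def build_labels(turns):
    """turns: list of (role, token_ids). Supervise only the last assistant turn."""
    last_assistant = max(
        (i for i, (role, _) in enumerate(turns) if role == "assistant"),
        default=None,
    )
    labels = []
    for i, (_, ids) in enumerate(turns):
        # Keep real token ids for the last assistant turn, mask everything else.
        labels.extend(ids if i == last_assistant else [IGNORE_INDEX] * len(ids))
    return labels


def has_supervised_tokens(labels):
    """False when every label is masked; such examples should be dropped."""
    return any(label != IGNORE_INDEX for label in labels)


turns = [("user", [1, 2]), ("assistant", [3]), ("user", [4]), ("assistant", [5, 6])]
print(build_labels(turns))  # [-100, -100, -100, -100, 5, 6]
```

Dropping examples for which `has_supervised_tokens` returns `False` before batching is one form the "all-masked filter" could take.
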
## Usage

This is an SFT LoRA adapter. After DPO completes, a merged DPO adapter will also be uploaded.

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base checkpoint the adapter was trained against.
base = AutoModelForCausalLM.from_pretrained("google/gemma-4-26B-A4B-it")
# Attach the SFT LoRA adapter stored in the repo's sft_adapter/ subfolder.
model = PeftModel.from_pretrained(base, "AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
# Tokenizer files are loaded from the same subfolder.
tokenizer = AutoTokenizer.from_pretrained("AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
```

> **Note**: Loading requires the custom NF4 quantization code from the training pipeline. See `v6_26b_pipeline.py` for the implementation.

## Files

| File | Description |
|------|-------------|
| `sft_adapter/` | SFT LoRA adapter (r=16, attention-only) |
| `v6_26b_pipeline.py` | Full 6-phase training pipeline (~1700 lines) |

## License

Training pipeline code: MIT. Base model: [Google Gemma License](https://ai.google.dev/gemma/terms).