---
language:
- lt
- en
license: other
license_name: gemma
license_link: https://ai.google.dev/gemma/terms
base_model: google/gemma-4-26B-A4B-it
tags:
- gemma
- moe
- mixture-of-experts
- qlora
- nf4
- fine-tuning
- dpo
- ops-agent
- infrastructure
- tool-calling
- single-gpu
- rtx-3090
datasets:
- custom
pipeline_tag: text-generation
library_name: peft
---

# Gemma 4 26B MoE Fine-Tuning on a Single RTX 3090

> Training a 26-billion-parameter Mixture-of-Experts model on consumer hardware,
> and making it work.

[![GitHub](https://img.shields.io/badge/GitHub-Pipeline_Code-blue?logo=github)](https://github.com/AndriejusNak/gemma4-26b-moe-finetune)

## The Challenge

Google's `gemma-4-26B-A4B-it` is a 25.81B-parameter MoE model with 30 transformer layers, 128 experts per layer, and top-8 routing. Conventional wisdom says you need multiple A100s or H100s to fine-tune a model of this size. We did it on a single NVIDIA RTX 3090 (24 GB VRAM).

## Model Details

| | |
|---|---|
| Base Model | [google/gemma-4-26B-A4B-it](https://huggingface.co/google/gemma-4-26B-A4B-it) |
| Total Parameters | 25.81B |
| Active per Token | ~4B (top-8 of 128 experts) |
| PEFT Method | LoRA (r=16, α=16) on attention (q,k,v,o_proj) |
| Quantization | Custom per-expert NF4 (manual quantize/dequantize) |
| Training Hardware | Single RTX 3090 24GB |
| Peak VRAM | 18.4 GB |
| Framework | PyTorch + PEFT 0.18.1 + bitsandbytes |

## Training Results

### SFT Phase: COMPLETE ✅

| Metric | Value |
|--------|-------|
| Training time | 687 min (11.45 hours) |
| Best loss | **1.2146** (epoch 2 average) |
| Lowest step loss | 1.1653 (step 700) |
| Total steps | 750 (2 epochs) |
| NaN explosions | **0** |
| Training data | 6,119 examples (multi-source) |

<details>
<summary>Full SFT Loss Curve</summary>

```
Step  50 | Loss: 6.8887 | LR: 2.00e-5 | VRAM: 18.0GB
Step 100 | Loss: 2.8130 | LR: 1.96e-5 | VRAM: 18.4GB  (-59%)
Step 150 | Loss: 1.7163 | LR: 1.88e-5 | VRAM: 18.4GB  (-39%)
Step 200 | Loss: 1.5580 | LR: 1.75e-5 | VRAM: 18.4GB  (-9%)
Step 250 | Loss: 1.4741 | LR: 1.59e-5 | VRAM: 17.9GB  (-5.4%)
Step 300 | Loss: 1.3663 | LR: 1.40e-5 | VRAM: 18.3GB  (-7.3%)
Step 350 | Loss: 1.3284 | LR: 1.19e-5 | VRAM: 18.0GB  (-2.8%)
--- Epoch 1 avg: 2.3724 --- Saved! ---
Step 400 | Loss: 1.2911 | LR: 9.74e-6 | VRAM: 18.4GB  (-2.8%)
Step 450 | Loss: 1.2570 | LR: 7.56e-6 | VRAM: 18.0GB  (-2.6%)
Step 500 | Loss: 1.2168 | LR: 5.50e-6 | VRAM: 18.4GB  (-3.2%)
Step 550 | Loss: 1.2311 | LR: 3.66e-6 | VRAM: 18.4GB  (+1.2%)
Step 600 | Loss: 1.2355 | LR: 2.12e-6 | VRAM: 18.4GB  (+0.4%)
Step 650 | Loss: 1.2046 | LR: 2.00e-6 | VRAM: 18.3GB  (-2.5%)
Step 700 | Loss: 1.1653 | LR: 2.00e-6 | VRAM: 18.0GB  (-3.3%) ← lowest
Step 750 | Loss: 1.1688 | LR: 2.00e-6 | VRAM: 18.4GB  (+0.3%)
--- Epoch 2 avg: 1.2146 --- BEST! Saved! ---
```
</details>

### DPO Phase: IN PROGRESS 🔄

| Metric | Value |
|--------|-------|
| Dataset | 2,708 pairs (30% HARD augmented) |
| Reference cache | Pre-computed in 79.3 min |
| Beta | 0.1 |
| Current loss | 0.7308 (step 50/507, sweet spot) |
| NaN | **0** |

## Key Technical Innovations

### 1. Custom NF4 Expert Quantization

Standard `bitsandbytes` `load_in_4bit` **cannot quantize** MoE expert weights because they are stored as `nn.Parameter`, not `nn.Linear`. We quantize each of the 11,520 expert matrices (30 layers × 128 experts × 3 matrices per expert) individually.
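
To make the per-matrix quantization concrete, here is a minimal sketch of blockwise NF4 quantize/dequantize in plain Python. It is an illustration, not the pipeline's code: the function names are hypothetical, the block size of 64 is an assumption, plain lists stand in for tensors, and packing two 4-bit codes per byte is omitted. The code values are the NF4 levels as used by `bitsandbytes`; check them against the library before relying on them.

```python
from bisect import bisect_left

# The 16 NF4 code values, normalized to [-1, 1].
NF4_LEVELS = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def nearest_level(x):
    """Index of the NF4 level closest to x, where x is already scaled into [-1, 1]."""
    i = bisect_left(NF4_LEVELS, x)
    if i == 0:
        return 0
    if i == len(NF4_LEVELS):
        return len(NF4_LEVELS) - 1
    # Pick whichever neighbour is closer.
    return i if NF4_LEVELS[i] - x < x - NF4_LEVELS[i - 1] else i - 1


def quantize_nf4(weights, block_size=64):
    """Return (codes, scales): one 4-bit code per weight, one absmax scale per block."""
    codes, scales = [], []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        scale = max(abs(w) for w in block) or 1.0  # all-zero block: avoid dividing by zero
        scales.append(scale)
        codes.extend(nearest_level(w / scale) for w in block)
    return codes, scales


def dequantize_nf4(codes, scales, block_size=64):
    """Map each code back to its level and undo the per-block scaling."""
    return [NF4_LEVELS[c] * scales[i // block_size] for i, c in enumerate(codes)]


codes, scales = quantize_nf4([0.0, 0.5, -2.0, 1.0], block_size=4)
restored = dequantize_nf4(codes, scales, block_size=4)
```

In this scheme the largest-magnitude weight in each block and exact zeros are reproduced exactly; every other weight is rounded to the nearest level.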

### 2. Token-Centric Expert Forward Pass

Instead of a Python `for` loop over 128 experts per layer, we batch tokens by their routed expert using `unique_consecutive`, achieving a 2-4× speedup with saturated GPU utilization.
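
The grouping idea can be sketched in plain Python. This is an analogue under assumptions, not the pipeline's implementation: the function name is hypothetical, and `itertools.groupby` over sorted pairs stands in for the sort followed by `torch.unique_consecutive(..., return_counts=True)` that a tensor version would use.

```python
from itertools import groupby


def group_tokens_by_expert(topk_experts):
    """topk_experts[t] lists the expert ids the router chose for token t.

    Returns {expert_id: [token indices]} with entries only for experts
    that actually received tokens.
    """
    # Flatten to (expert, token) pairs and sort so each expert's tokens are contiguous.
    pairs = sorted((e, t) for t, experts in enumerate(topk_experts) for e in experts)
    # One run per expert in the sorted list, like unique_consecutive on sorted ids.
    return {e: [t for _, t in run] for e, run in groupby(pairs, key=lambda p: p[0])}


# Three tokens, top-2 routing (toy values).
routing = [[3, 7], [7, 1], [3, 1]]
batches = group_tokens_by_expert(routing)
print(batches)
```

Each expert then runs once on its whole batch of tokens, and experts that received no tokens are never visited.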

### 3. DPO with CPU-Cached Reference Model

DPO normally requires 4 forward passes per pair (policy and reference, each on the chosen and the rejected response). We pre-compute all reference logprobs once (79.3 min), then train with only the 2 policy forward passes, which is **2× faster** and fits in 24 GB.
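
As a minimal sketch of why the cache is enough: the standard DPO loss only needs the reference model's sequence log-probabilities, which stay fixed during training. The helper names and the cache layout below are hypothetical, and the numbers are toy values. Sequence scores are token means rather than sums, matching the length-bias fix listed under "Technical Challenges Solved".

```python
import math


def mean_logprob(token_logprobs):
    """Length-normalized sequence log-probability (mean over response tokens)."""
    return sum(token_logprobs) / len(token_logprobs)


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log(sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))) for one pair."""
    margin = beta * ((policy_chosen - policy_rejected) - (ref_chosen - ref_rejected))
    # -log(sigmoid(m)) == log(1 + exp(-m)); branch to keep exp() from overflowing.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Hypothetical cache: pair index -> (reference chosen, reference rejected) scores,
# filled once before training so the reference model never has to be loaded again.
ref_cache = {0: (mean_logprob([-1.0, -1.4]), mean_logprob([-1.3, -1.4]))}

ref_c, ref_r = ref_cache[0]
start_loss = dpo_loss(ref_c, ref_r, ref_c, ref_r)   # policy identical to reference
later_loss = dpo_loss(-1.0, -1.5, ref_c, ref_r)     # policy now favours the chosen answer
```

When the policy still equals the reference, the margin is zero and the loss is ln 2 ≈ 0.693. That value is the natural baseline for reading early DPO loss figures.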

### 4. HARD Adversarial DPO Training

13 categories of adversarial training pairs including tool confusion, overthinking, hallucination guards, wrong parameters, and edge cases. Tool corruption augmentation teaches reasoning correctness, not just output quality.

### 5. Iterative DPO Pipeline

Post-training failure analysis automatically generates targeted HARD pairs for the next iteration, with scripts for confusion matrix analysis, failure classification, and dataset building.

## VRAM Budget

```
Expert NF4 (30 layers × 128 experts):      11.3 GiB
Attention BF16 (30 layers):                 2.4 GiB
Embedding + LM head BF16:                   2.9 GiB
CUDA context:                               0.5 GiB
────────────────────────────────────────────────────
Model total:                               17.1 GiB
LoRA weights (r=16):                       ~22 MB
Activations (seq=512, batch=1):             1-3 GiB
────────────────────────────────────────────────────
Training peak:                             18.4 GiB ✅
```
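
The budget can be checked with a few lines of arithmetic. The sketch below only re-adds the figures from the table; treating the card's 24 GB as 24 GiB is an assumption.

```python
# Static components of the budget, in GiB, copied from the table above.
budget_gib = {
    "experts_nf4": 11.3,
    "attention_bf16": 2.4,
    "embedding_lm_head_bf16": 2.9,
    "cuda_context": 0.5,
}

model_total = sum(budget_gib.values())
training_peak = 18.4   # measured peak from the table
card_capacity = 24.0   # assumption: the RTX 3090's 24 GB counted as GiB

# What training adds on top of the static model (LoRA weights, activations, optimizer state).
training_overhead = training_peak - model_total
headroom = card_capacity - training_peak

print(f"model total:       {model_total:.1f} GiB")
print(f"training overhead: {training_overhead:.1f} GiB")
print(f"headroom:          {headroom:.1f} GiB")
```

The overhead comes out at about 1.3 GiB, which sits inside the 1-3 GiB activation estimate.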

## Technical Challenges Solved

| # | Problem | Solution |
|---|---------|----------|
| 1 | bitsandbytes can't quantize nn.Parameter | Manual per-expert NF4 |
| 2 | MoE expert loop is pure Python (slow) | Token-centric batching (2-4× speedup) |
| 3 | DPO doubles VRAM (need ref model) | CPU-cached reference logprobs |
| 4 | NaN loss explosions | All-masked filter + FP32 loss + NaN skip |
| 5 | DPO length bias (.sum() rewards) | Changed to .mean() |
| 6 | Multi-turn masking trained on all turns | Mask all except LAST assistant response |
| 7 | PEFT set_adapter() rejects lists | Bypass via model.base_model.set_adapter() |
| 8 | 26B model in 24 GB | Gradient checkpointing + NF4 + attention LoRA |
| 9 | Ollama safetensors converter crashes | llama.cpp convert_hf_to_gguf.py pipeline |
| 10 | HARD overfitting | Prefix augmentation + dynamic generation |
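
Rows 4 and 6 interact, which a small sketch makes visible. It assumes the assistant turn boundaries are already known as token ranges; the function names and the span format are hypothetical, and `-100` is the label value that PyTorch's cross-entropy ignores by default.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch cross-entropy by default


def mask_all_but_last_assistant(token_ids, assistant_spans):
    """Build labels that train only on the last assistant response.

    assistant_spans holds one (start, end) token range per assistant turn,
    with end exclusive.
    """
    labels = [IGNORE_INDEX] * len(token_ids)
    if assistant_spans:
        start, end = assistant_spans[-1]
        # Both slices clip at the sequence end, so truncated inputs stay aligned.
        labels[start:end] = token_ids[start:end]
    return labels


def has_trainable_tokens(labels):
    """False when every position is masked, so a mean loss would be 0/0 (NaN)."""
    return any(label != IGNORE_INDEX for label in labels)


token_ids = list(range(10))
spans = [(2, 4), (7, 10)]          # two assistant turns in a toy conversation

labels = mask_all_but_last_assistant(token_ids, spans)
truncated = mask_all_but_last_assistant(token_ids[:6], spans)  # last turn cut off

dataset = [labels, truncated]
kept = [example for example in dataset if has_trainable_tokens(example)]
```

If truncation removes the last assistant turn, every label is masked. Dropping such examples before training is one way to apply the all-masked filter.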

## Usage

This is an SFT LoRA adapter. After DPO completes, a merged DPO adapter will also be uploaded.

```python
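from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the unquantized base model. At full size this needs far more memory than a
# 24 GB GPU provides; see the note below for the NF4 path used in training.
base = AutoModelForCausalLM.from_pretrained("google/gemma-4-26B-A4B-it")
# Attach the SFT LoRA adapter stored in the repo's sft_adapter/ subfolder.
model = PeftModel.from_pretrained(base, "AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
# The tokenizer is loaded from the same subfolder.
tokenizer = AutoTokenizer.from_pretrained("AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")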
```

> **Note**: Loading requires the custom NF4 quantization code from the training pipeline. See `v6_26b_pipeline.py` for implementation.

## Files

| File | Description |
|------|-------------|
| `sft_adapter/` | SFT LoRA adapter (r=16, attention-only) |
| `v6_26b_pipeline.py` | Full 6-phase training pipeline (~1700 lines) |

## License

Training pipeline code: MIT. Base model: [Google Gemma License](https://ai.google.dev/gemma/terms).