metadata
language:
  - lt
  - en
license: other
license_name: gemma
license_link: https://ai.google.dev/gemma/terms
base_model: google/gemma-4-26B-A4B-it
tags:
  - gemma
  - moe
  - mixture-of-experts
  - qlora
  - nf4
  - fine-tuning
  - dpo
  - ops-agent
  - infrastructure
  - tool-calling
  - single-gpu
  - rtx-3090
datasets:
  - custom
pipeline_tag: text-generation
library_name: peft

Gemma 4 26B MoE Fine-Tuning on a Single RTX 3090

Training a 26-billion-parameter Mixture-of-Experts model on consumer hardware, and making it work.


The Challenge

Google's gemma-4-26B-A4B-it is a 25.81B-parameter MoE model with 128 experts per layer and top-8 routing across 30 transformer layers. Conventional wisdom says fine-tuning it takes multiple A100s or H100s. We did it on a single NVIDIA RTX 3090 (24 GB VRAM).

Model Details

Base model          google/gemma-4-26B-A4B-it
Total parameters    25.81B
Active per token    ~4B (top-8 of 128 experts)
PEFT method         LoRA (r=16, α=16) on attention (q_proj, k_proj, v_proj, o_proj)
Quantization        Custom per-expert NF4 (manual quantize/dequantize)
Training hardware   Single RTX 3090 (24 GB)
Peak VRAM           18.4 GB
Framework           PyTorch + PEFT 0.18.1 + bitsandbytes
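The PEFT row corresponds roughly to the following `LoraConfig`. Only `r`, `lora_alpha` and the four target modules come from the table; `task_type` is an assumption and `lora_dropout` is left at the PEFT default. With α = r, PEFT's scaling factor α/r is 1, so adapter updates are applied unscaled.

```python
from peft import LoraConfig

# r, lora_alpha and target_modules come from the Model Details table.
# task_type is an assumption; lora_dropout keeps the PEFT default.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,  # alpha / r = 1.0, so the LoRA update is not rescaled
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention only
    task_type="CAUSAL_LM",
)
```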

Training Results

SFT Phase: COMPLETE ✅

Metric              Value
Training time       687 min (11.45 hours)
Best loss           1.2146 (epoch 2 average)
Lowest step loss    1.1653 (step 700)
Total steps         750 (2 epochs)
NaN explosions      0
Training data       6,119 examples (multi-source)
Full SFT Loss Curve
Step  50 | Loss: 6.8887 | LR: 2.00e-5 | VRAM: 18.0GB
Step 100 | Loss: 2.8130 | LR: 1.96e-5 | VRAM: 18.4GB  (-59%)
Step 150 | Loss: 1.7163 | LR: 1.88e-5 | VRAM: 18.4GB  (-39%)
Step 200 | Loss: 1.5580 | LR: 1.75e-5 | VRAM: 18.4GB  (-9%)
Step 250 | Loss: 1.4741 | LR: 1.59e-5 | VRAM: 17.9GB  (-5.4%)
Step 300 | Loss: 1.3663 | LR: 1.40e-5 | VRAM: 18.3GB  (-7.3%)
Step 350 | Loss: 1.3284 | LR: 1.19e-5 | VRAM: 18.0GB  (-2.8%)
--- Epoch 1 avg: 2.3724 --- Saved! ---
Step 400 | Loss: 1.2911 | LR: 9.74e-6 | VRAM: 18.4GB  (-2.8%)
Step 450 | Loss: 1.2570 | LR: 7.56e-6 | VRAM: 18.0GB  (-2.6%)
Step 500 | Loss: 1.2168 | LR: 5.50e-6 | VRAM: 18.4GB  (-3.2%)
Step 550 | Loss: 1.2311 | LR: 3.66e-6 | VRAM: 18.4GB  (+1.2%)
Step 600 | Loss: 1.2355 | LR: 2.12e-6 | VRAM: 18.4GB  (+0.4%)
Step 650 | Loss: 1.2046 | LR: 2.00e-6 | VRAM: 18.3GB  (-2.5%)
Step 700 | Loss: 1.1653 | LR: 2.00e-6 | VRAM: 18.0GB  (-3.3%) ← lowest
Step 750 | Loss: 1.1688 | LR: 2.00e-6 | VRAM: 18.4GB  (+0.3%)
--- Epoch 2 avg: 1.2146 --- BEST! Saved! ---
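The percentages in the log are step-over-step changes in loss. The learning-rate column can also be reproduced closely by a schedule reconstructed from the log. That schedule is our assumption, not a documented setting: linear warmup over 37 steps (5% of 750), cosine decay from 2e-5, and a floor at 2e-6. It matches the logged values to within about 1%. The examples-per-step figure likewise assumes each of the 6,119 examples is seen once per epoch, which would suggest an effective batch of about 16.

```python
import math


def pct_change(prev: float, cur: float) -> float:
    """Step-over-step change in percent, as annotated in the log."""
    return (cur - prev) / prev * 100.0


def assumed_lr(step: int, peak: float = 2e-5, floor: float = 2e-6,
               warmup: int = 37, total: int = 750) -> float:
    """ASSUMED schedule reconstructed from the log (not a documented setting):
    linear warmup, cosine decay toward zero, clamped from below at `floor`."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)
    return max(floor, peak * 0.5 * (1.0 + math.cos(math.pi * progress)))


# Reproduce two annotations: step 100 vs 50, and step 550 vs 500.
print(f"{pct_change(6.8887, 2.8130):+.0f}%")  # -59%
print(f"{pct_change(1.2168, 1.2311):+.1f}%")  # +1.2%

# Compare a few logged learning rates with the assumed schedule.
logged = {100: 1.96e-5, 400: 9.74e-6, 600: 2.12e-6, 700: 2.00e-6}
for step, lr in logged.items():
    print(step, f"logged={lr:.2e}", f"assumed={assumed_lr(step):.2e}")

# 750 optimizer steps over 2 epochs = 375 steps per epoch.
# Assuming one pass over the data per epoch, that is ~16.3 examples per step.
print(6119 / (750 // 2))
```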

DPO Phase: IN PROGRESS 🔄

Metric              Value
Dataset             2,708 pairs (30% HARD-augmented)
Reference cache     Pre-computed in 79.3 min
Beta                0.1
Current loss        0.7308 (step 50 of 507, sweet spot)
NaN explosions      0

Key Technical Innovations

1. Custom NF4 Expert Quantization

Standard bitsandbytes load_in_4bit cannot quantize MoE expert weights, because they are stored as nn.Parameter tensors rather than nn.Linear modules. We therefore quantize each of the 11,520 expert matrices individually.
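A minimal sketch of what per-expert NF4 quantization involves, written in NumPy so it runs without a GPU. It follows the NF4 scheme from QLoRA (16 fixed normal-float levels, one absmax scale per block), but it is not the pipeline's code: the block size of 64, the float32 scales, the nibble packing and the names `quantize_nf4` / `dequantize_nf4` are assumptions. As a check on the count, 11,520 = 30 layers × 128 experts × 3, i.e. three weight matrices per expert.

```python
import numpy as np

# The 16 NF4 code values (normal-float levels from QLoRA, as used by bitsandbytes).
NF4_LEVELS = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=np.float32)


def quantize_nf4(weight: np.ndarray, block_size: int = 64):
    """Quantize one expert matrix to packed NF4 codes plus per-block absmax scales."""
    flat = weight.astype(np.float32).reshape(-1)
    assert flat.size % block_size == 0, "sketch assumes size divisible by block_size"
    blocks = flat.reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    absmax[absmax == 0] = 1.0  # avoid 0/0 for an all-zero block
    normed = blocks / absmax   # every value now lies in [-1, 1]
    # Index of the nearest NF4 level for every value (fits in 4 bits).
    codes = np.abs(normed[..., None] - NF4_LEVELS).argmin(axis=-1).astype(np.uint8)
    codes = codes.reshape(-1)
    # Pack two 4-bit codes per byte.
    packed = (codes[0::2] << 4) | codes[1::2]
    return packed, absmax.reshape(-1), weight.shape


def dequantize_nf4(packed: np.ndarray, absmax: np.ndarray, shape, block_size: int = 64):
    """Rebuild an approximate float32 matrix from packed codes and scales."""
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4
    codes[1::2] = packed & 0x0F
    blocks = NF4_LEVELS[codes].reshape(-1, block_size) * absmax[:, None]
    return blocks.reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32) * 0.02  # hypothetical expert matrix
packed, scales, shape = quantize_nf4(w)
w_hat = dequantize_nf4(packed, scales, shape)
print(packed.nbytes + scales.nbytes, "bytes vs", w.astype(np.float16).nbytes, "in 16-bit")  # 18432 vs 65536
```

Each block's largest-magnitude value maps exactly to ±1 and is recovered exactly; every other value is rounded to the nearest level, so its error is at most half the widest level gap (about 0.15) times the block's absmax.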

2. Token-Centric Expert Forward Pass

Instead of a Python for loop over all 128 experts in each layer, we group tokens by their routed expert using unique_consecutive and run each expert once on its batch of tokens, achieving a 2-4× speedup with saturated GPU utilization.
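The grouping idea can be sketched as follows. This is a simplified NumPy illustration, not the pipeline's kernel. Each expert is reduced to a single weight matrix, whereas the real experts are gated MLPs. On sorted expert ids, `np.unique(..., return_counts=True)` produces the same runs as `torch.unique_consecutive(..., return_counts=True)`. The remaining loop visits each active expert once with a contiguous slice of token indices, so no per-expert mask over all tokens is needed.

```python
import numpy as np


def moe_forward_token_centric(x, topk_ids, topk_weights, expert_weights):
    """Token-centric MoE dispatch (simplified NumPy sketch).

    x:              (num_tokens, d_model) hidden states
    topk_ids:       (num_tokens, k) routed expert index per token slot
    topk_weights:   (num_tokens, k) router weight per token slot
    expert_weights: (num_experts, d_model, d_model); one matrix per expert
                    stands in for the real gated MLP
    """
    num_tokens, k = topk_ids.shape
    flat_ids = topk_ids.reshape(-1)                  # one entry per (token, slot)
    flat_w = topk_weights.reshape(-1)
    token_idx = np.repeat(np.arange(num_tokens), k)  # owning token of each entry
    order = np.argsort(flat_ids, kind="stable")      # group entries by expert id
    sorted_ids = flat_ids[order]
    # Run-length encode the sorted ids (analogue of torch.unique_consecutive).
    experts, counts = np.unique(sorted_ids, return_counts=True)
    out = np.zeros_like(x)
    start = 0
    for e, n in zip(experts, counts):                # only experts that received tokens
        sel = order[start:start + n]
        toks = token_idx[sel]
        y = x[toks] @ expert_weights[e]              # one batched matmul per expert
        np.add.at(out, toks, y * flat_w[sel, None])  # scatter-add weighted outputs
        start += n
    return out
```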

3. DPO with CPU-Cached Reference Model

DPO normally requires 4 forward passes per pair (policy and reference, each on chosen and rejected). We pre-compute all reference logprobs once (79.3 min) and cache them on the CPU, so training needs only 2 forward passes: 2× faster, and small enough to fit in 24 GB.
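Because the reference model is frozen, its sequence logprobs for every chosen and rejected response can be computed once and looked up during training. Below is a minimal sketch of the DPO loss under that arrangement. It assumes the cache is a plain dictionary keyed by a hypothetical pair id and that sequence logprobs are per-token means (the `.mean()` change listed under Technical Challenges). The beta of 0.1 comes from the DPO table. If the policy starts from the same weights used to build the cache, the margin is initially zero and the loss is ln 2 ≈ 0.693, a useful baseline for reading the current DPO loss.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """DPO loss for one pair, given sequence logprobs.

    ref_chosen / ref_rejected come from the precomputed cache, so only the
    policy needs forward passes (2 instead of 4) during training.
    """
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return -log_sigmoid(beta * margin)


# Phase 1 (once, no gradients): hypothetical cache {pair_id: (ref_chosen, ref_rejected)}.
ref_cache = {"pair-0001": (-0.42, -0.61)}
# Phase 2 (training): look up reference values instead of running the reference model.
ref_c, ref_r = ref_cache["pair-0001"]
print(dpo_loss(-0.42, -0.61, ref_c, ref_r))  # ln 2 while policy == reference
```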

4. HARD Adversarial DPO Training

Thirteen categories of adversarial training pairs, including tool confusion, overthinking, hallucination guards, wrong parameters, and edge cases. Tool corruption augmentation teaches reasoning correctness, not just output quality.
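The sketch below shows one way to generate rejected responses for two of the named categories, tool confusion and wrong parameters, starting from a known-good tool call. The tool names, the `{"name", "arguments"}` call format and the corruption rules are hypothetical. The card does not publish its schema or augmentation code, and its tool corruption augmentation may work differently.

```python
import copy
import json
import random

# Hypothetical tool registry; the card does not publish its tool schema.
TOOLS = ["restart_service", "check_disk", "tail_logs", "scale_deployment"]


def corrupt_tool_call(call: dict, rng: random.Random) -> dict:
    """Return a plausible but wrong variant of a correct tool call."""
    bad = copy.deepcopy(call)  # never mutate the chosen example
    mode = rng.choice(["tool_confusion", "wrong_parameter", "missing_parameter"])
    if mode == "tool_confusion" or not bad["arguments"]:
        # Tool confusion: same arguments, different tool.
        bad["name"] = rng.choice([t for t in TOOLS if t != call["name"]])
    elif mode == "wrong_parameter":
        key = rng.choice(sorted(bad["arguments"]))
        old = bad["arguments"][key]
        # Change the value but keep the call syntactically valid.
        bad["arguments"][key] = old * 10 if isinstance(old, (int, float)) and old else f"{old}-x"
    else:
        # Drop an argument the correct call needed.
        del bad["arguments"][rng.choice(sorted(bad["arguments"]))]
    return bad


def make_pair(prompt: str, correct_call: dict, seed: int = 0) -> dict:
    """One DPO pair: the correct call is chosen, a corrupted call is rejected."""
    rng = random.Random(seed)
    return {
        "prompt": prompt,
        "chosen": json.dumps(correct_call),
        "rejected": json.dumps(corrupt_tool_call(correct_call, rng)),
    }
```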

5. Iterative DPO Pipeline

Post-training failure analysis automatically generates targeted HARD pairs for the next iteration, with scripts for confusion matrix analysis, failure classification, and dataset building.
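The confusion-matrix step can be illustrated with a small helper. It counts (expected tool, predicted tool) pairs from an evaluation run and ranks the off-diagonal entries, i.e. the mistakes most worth targeting with new HARD pairs. The result field names and the sample data are hypothetical.

```python
from collections import Counter


def tool_confusion(results):
    """Count (expected_tool, predicted_tool) pairs from evaluation results."""
    return Counter((r["expected"], r["predicted"]) for r in results)


def top_confusions(results, n=3):
    """Most frequent off-diagonal pairs: candidates for targeted HARD examples."""
    matrix = tool_confusion(results)
    errors = [(pair, c) for pair, c in matrix.items() if pair[0] != pair[1]]
    # Highest count first; ties broken alphabetically for a stable order.
    return sorted(errors, key=lambda item: (-item[1], item[0]))[:n]


# Hypothetical evaluation log.
results = [
    {"expected": "check_disk", "predicted": "check_disk"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "check_disk", "predicted": "tail_logs"},
    {"expected": "restart_service", "predicted": "scale_deployment"},
    {"expected": "restart_service", "predicted": "restart_service"},
]
for (expected, predicted), count in top_confusions(results):
    print(f"{expected} -> {predicted}: {count}")
```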

VRAM Budget

Expert NF4 (30 layers × 128 experts):      11.3 GiB
Attention BF16 (30 layers):                 2.4 GiB
Embedding + LM head BF16:                  2.9 GiB
CUDA context:                               0.5 GiB
────────────────────────────────────────────────────
Model total:                               17.1 GiB
LoRA weights (r=16):                       ~22 MB
Activations (seq=512, batch=1):           1-3 GiB
────────────────────────────────────────────────────
Training peak:                            18.4 GiB ✅
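The budget can be cross-checked with simple arithmetic. The sketch below uses a hypothetical `weights_gib` helper to show why the 25.81B parameters cannot be held in BF16 on a 24 GB card, and it confirms that the listed components sum to the model total. The 4.5 bits/weight figure assumes 4-bit codes plus one float32 absmax per 64-weight block; this card does not state the pipeline's actual block size or scale format.

```python
GIB = 2**30


def weights_gib(num_params: float, bits_per_param: float) -> float:
    """Storage for num_params weights at an average bit width, in GiB."""
    return num_params * bits_per_param / 8 / GIB


TOTAL_PARAMS = 25.81e9
NF4_BITS = 4 + 32 / 64  # assumed: 4-bit code + one fp32 absmax per 64 weights

print(f"all BF16: {weights_gib(TOTAL_PARAMS, 16):.1f} GiB")        # ~48.1 GiB, far over 24 GB
print(f"all NF4:  {weights_gib(TOTAL_PARAMS, NF4_BITS):.1f} GiB")  # ~13.5 GiB if every weight were NF4

# Components from the budget above (attention and embeddings stay in BF16).
budget_gib = {
    "experts_nf4": 11.3,
    "attention_bf16": 2.4,
    "embed_lm_head_bf16": 2.9,
    "cuda_context": 0.5,
}
print(f"model total: {sum(budget_gib.values()):.1f} GiB")  # 17.1 GiB
```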

Technical Challenges Solved

#   Problem                                     Solution
1   bitsandbytes can't quantize nn.Parameter    Manual per-expert NF4
2   MoE expert loop is pure Python (slow)       Token-centric batching: 2-4× speedup
3   DPO doubles VRAM (needs a reference model)  CPU-cached reference logprobs
4   NaN loss explosions                         All-masked filter + FP32 loss + NaN skip
5   DPO length bias (.sum() rewards)            Changed to .mean()
6   Multi-turn masking trained on all turns     Mask all except the LAST assistant response
7   PEFT set_adapter() rejects lists            Bypass via model.base_model.set_adapter()
8   26B model in 24 GB                          Gradient checkpointing + NF4 + attention LoRA
9   Ollama safetensors converter crashes        llama.cpp convert_hf_to_gguf.py pipeline
10  HARD overfitting                            Prefix augmentation + dynamic generation
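Rows 4 to 6 (the all-masked filter, the `.mean()` fix and last-turn masking) can be illustrated together. This is a minimal NumPy sketch, not the pipeline's code. It assumes per-token logprobs already aligned with the labels, the Hugging Face convention of -100 for ignored labels, and a hypothetical list of assistant token spans.

```python
import numpy as np

IGNORE = -100  # label value excluded from the loss (Hugging Face convention)


def last_assistant_labels(input_ids: np.ndarray, spans) -> np.ndarray:
    """Keep labels only for the LAST assistant turn.

    spans: (start, end) token offsets of assistant responses, in order
    (hypothetical format; real code derives them from the chat template).
    """
    labels = np.full_like(input_ids, IGNORE)
    if spans:
        start, end = spans[-1]
        labels[start:end] = input_ids[start:end]
    return labels


def sequence_logprob(token_logprobs: np.ndarray, labels: np.ndarray, reduce: str = "mean"):
    """Reduce per-token logprobs over unmasked positions, in FP32.

    Returns None for an all-masked sequence so the caller can skip it
    instead of producing NaN from 0 / 0.
    """
    mask = labels != IGNORE
    n = int(mask.sum())
    if n == 0:
        return None
    total = float(token_logprobs[mask].astype(np.float32).sum())
    return total / n if reduce == "mean" else total


# Same per-token quality, different lengths.
short, long_ = np.full(4, -0.5), np.full(8, -0.5)
print(sequence_logprob(short, np.zeros(4)), sequence_logprob(long_, np.zeros(8)))                # -0.5 -0.5
print(sequence_logprob(short, np.zeros(4), "sum"), sequence_logprob(long_, np.zeros(8), "sum"))  # -2.0 -4.0
```

With `.sum()`, a response twice as long scores twice as negative at the same per-token quality, which biases the DPO margin by length; `.mean()` removes that dependence.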

Usage

This is an SFT LoRA adapter. After DPO completes, a merged DPO adapter will also be uploaded.

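```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# The unquantized base (~52 GB in BF16) does not fit in 24 GB of VRAM;
# device_map="auto" (requires accelerate) spreads it over GPU and CPU memory.
base = AutoModelForCausalLM.from_pretrained("google/gemma-4-26B-A4B-it", device_map="auto")
# Attach the SFT LoRA adapter from this repository.
model = PeftModel.from_pretrained(base, "AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
tokenizer = AutoTokenizer.from_pretrained("AndriejusNak/gemma4-26b-moe-finetune", subfolder="sft_adapter")
```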

Note: Loading the model within 24 GB requires the custom NF4 quantization code from the training pipeline; see v6_26b_pipeline.py for the implementation.

Files

File                  Description
sft_adapter/          SFT LoRA adapter (r=16, attention-only)
v6_26b_pipeline.py    Full 6-phase training pipeline (~1700 lines)

License

Training pipeline code: MIT. Base model: Google Gemma License.