#!/usr/bin/env python3 """ Janus V6 Training Pipeline — google/gemma-4-26B-A4B-it (27B MoE) ================================================================== CORRECT model: 26B total, 4B active per token (128 experts, top-8) Standard HuggingFace + PEFT + bitsandbytes. NO Unsloth. Phase 0: Download model (if not cached) Phase 1: SFT with QLoRA (manual training loop) Phase 2: DPO with QLoRA (manual training loop) Phase 3: Merge LoRA → bf16 on CPU (uses swap) Phase 4: GGUF Q4_K_M via llama.cpp Phase 5: Ollama deploy + smoke test Usage: python3 v6_26b_pipeline.py --phase 0 # Download only python3 v6_26b_pipeline.py --phase 1 # SFT python3 v6_26b_pipeline.py --phase 2 # DPO (requires SFT adapter) python3 v6_26b_pipeline.py --phase 3 # Merge python3 v6_26b_pipeline.py --phase 4 # GGUF python3 v6_26b_pipeline.py --phase 5 # Ollama python3 v6_26b_pipeline.py --phase all # Run all phases python3 v6_26b_pipeline.py --phase diag # Architecture diagnostics only """ import os, sys, json, time, glob, re, random, subprocess, gc, argparse os.environ["TORCHDYNAMO_DISABLE"] = "1" os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch # ===== CONFIGURATION ===== MODEL_NAME = "google/gemma-4-26B-A4B-it" BASE_DIR = "/home/andriejus/janus_auto" ADAPTER_DIR = f"{BASE_DIR}/adapters/janus_v6" MERGED_DIR = f"{BASE_DIR}/merged_models/gemma4_janus_v6" GGUF_DIR = f"{BASE_DIR}/merged_models" LLAMA_CPP = f"{BASE_DIR}/llama.cpp" MAX_SEQ_LENGTH = 512 # SFT: safe for 24GB VRAM DPO_SEQ_LENGTH = 768 # DPO: longer for ReAct/tool-chains (only 2 fwd per pair with cache) LORA_R = 16 LORA_ALPHA = 16 # For 26B MoE: attention-only LoRA by default (safe for 24GB VRAM) # 128 experts × 30 layers × 3 proj = 11520 modules — too many for LoRA # Attention: 30 layers × 4 proj = 120 modules — manageable # MLP LoRA: gate/up/down_proj in router MLP (NOT expert params!) # adds ~2.6GB but improves tool-calling generation quality TARGET_MODULES_ATTENTION = ["q_proj", "k_proj", "v_proj", "o_proj"] TARGET_MODULES_FULL = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] INCLUDE_MLP_LORA = False # SFT: attention-only (safe) DPO_INCLUDE_MLP_LORA = True # DPO: add MLP for better tool output generation DPO_MLP_LORA_R = 8 # Lower rank for MLP — less risk of destabilizing MoE routing # SFT Config SFT_EPOCHS = 2 SFT_LR = 2e-5 SFT_BATCH_SIZE = 1 SFT_GRAD_ACCUM = 16 SFT_WARMUP_RATIO = 0.05 SFT_MAX_GRAD_NORM = 1.0 # DPO Config DPO_EPOCHS = 3 DPO_LR = 2e-6 DPO_BETA = 0.1 DPO_BATCH_SIZE = 1 DPO_GRAD_ACCUM = 16 DPO_HARD_TARGET_RATIO = 0.30 # Upsample HARD examples to 30% of DPO dataset # Data Files SFT_FILES = [ f"{BASE_DIR}/data/janus_full_training.jsonl", f"{BASE_DIR}/data/tool_calling_react_seed.jsonl", f"{BASE_DIR}/data/tool_calling_train.jsonl", f"{BASE_DIR}/data/janus_quality_train.jsonl", f"{BASE_DIR}/data/finetuning/react_tool_examples_300.jsonl", f"{BASE_DIR}/data/finetuning/tool_calling_examples.jsonl", f"{BASE_DIR}/data/training/react_hard40.jsonl", f"{BASE_DIR}/data/training/training_data.jsonl", ] DPO_FILE = f"{BASE_DIR}/data/training_queue/dpo_v6_full.jsonl" LOG_FILE = "/tmp/v6_training.log" CHECKPOINT_FILE = "/tmp/v6_checkpoint.json" # ===== LOGGING ===== def log(msg): ts = time.strftime("%Y-%m-%d %H:%M:%S") line = f"[{ts}] {msg}" print(line, flush=True) with open(LOG_FILE, "a") as f: f.write(line + "\n") def save_checkpoint(phase, step=0, extra=None): data = {"phase": phase, "step": step, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")} if extra: data.update(extra) with open(CHECKPOINT_FILE, "w") as f: json.dump(data, f) # ===== PHASE 0: DOWNLOAD MODEL ===== def phase_download(): log("=" * 60) log("PHASE 0: DOWNLOAD google/gemma-4-26B-A4B-it") log("=" * 60) from transformers import AutoTokenizer, AutoConfig # Check if already cached cache_dir = os.path.expanduser("~/.cache/huggingface/hub") model_cache = os.path.join(cache_dir, "models--google--gemma-4-26B-A4B-it") if os.path.exists(model_cache): # Check if complete snapshots = os.path.join(model_cache, "snapshots") if os.path.exists(snapshots) and os.listdir(snapshots): snap = os.listdir(snapshots)[0] snap_dir = os.path.join(snapshots, snap) safetensors = glob.glob(os.path.join(snap_dir, "*.safetensors")) if len(safetensors) >= 5: # 26B model has multiple shards total_size = sum(os.path.getsize(f) for f in safetensors) log(f" Model already cached: {len(safetensors)} shards, {total_size/1e9:.1f} GB") log(f" Location: {snap_dir}") return True log(" Model not in cache. Starting download (~52 GB)...") log(" This will take a while depending on internet speed.") # Download config first (fast) log(" Downloading config...") config = AutoConfig.from_pretrained(MODEL_NAME) log(f" Architecture: {config.architectures}") log(f" Model type: {config.model_type}") text_cfg = config.text_config log(f" Hidden size: {text_cfg.hidden_size}") log(f" Num layers: {text_cfg.num_hidden_layers}") log(f" Num experts: {text_cfg.num_experts}") log(f" Top-k experts: {text_cfg.top_k_experts}") log(f" MoE intermediate: {text_cfg.moe_intermediate_size}") # Download tokenizer (fast) log(" Downloading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) log(f" Vocab size: {tokenizer.vocab_size}") # Verify turn tokens exist for tok_name in ["", "", "<|turn>", ""]: tok_id = tokenizer.convert_tokens_to_ids(tok_name) if tok_id != tokenizer.unk_token_id: log(f" Token '{tok_name}' -> ID {tok_id}") else: log(f" Token '{tok_name}' -> NOT FOUND (UNK)") # Download full model weights using snapshot_download (no RAM needed) log(" Downloading model weights via snapshot_download (no model loading)...") from huggingface_hub import snapshot_download start = time.time() local_dir = snapshot_download( MODEL_NAME, ignore_patterns=["*.gguf", "*.bin"], # Only safetensors + configs ) elapsed = time.time() - start # Count safetensors safetensors = glob.glob(os.path.join(local_dir, "*.safetensors")) total_size = sum(os.path.getsize(f) for f in safetensors) if safetensors else 0 log(f" Download complete in {elapsed:.0f}s") log(f" Location: {local_dir}") log(f" Safetensors: {len(safetensors)} shards, {total_size/1e9:.1f} GB") log("PHASE 0 COMPLETE — Model downloaded and cached") save_checkpoint("download_complete") return True # ===== PHASE DIAG: ARCHITECTURE DIAGNOSTICS ===== def phase_diagnostics(): log("=" * 60) log("DIAGNOSTICS: 26B-A4B-it Architecture Analysis") log("=" * 60) from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma4ForConditionalGeneration log("Loading model in 4-bit for diagnostics...") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) model = Gemma4ForConditionalGeneration.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # 1. Top-level structure log("\n--- TOP-LEVEL CHILDREN ---") for n, m in model.named_children(): log(f" {n}: {type(m).__name__}") # 2. Module type census log("\n--- MODULE TYPE CENSUS ---") type_counts = {} for n, m in model.named_modules(): tname = type(m).__name__ type_counts[tname] = type_counts.get(tname, 0) + 1 for tname, count in sorted(type_counts.items(), key=lambda x: -x[1])[:20]: log(f" {tname}: {count}") # 3. Linear module names (for LoRA targeting) log("\n--- LINEAR MODULE NAMES (samples) ---") linear_names = set() for n, m in model.named_modules(): if "Linear" in type(m).__name__: # Extract last part of name (e.g., "q_proj", "gate_proj") last = n.rsplit(".", 1)[-1] if "." in n else n linear_names.add(last) for name in sorted(linear_names): log(f" {name}") # 4. Expert structure sample log("\n--- EXPERT STRUCTURE (layer 0) ---") for n, m in model.named_modules(): if "layer" in n and (".0." in n or "layers.0" in n) and "expert" in n.lower(): log(f" {n}: {type(m).__name__}") if hasattr(m, 'weight'): log(f" weight shape: {m.weight.shape}") # 5. ClippableLinear check log("\n--- ClippableLinear CHECK ---") clippable_count = sum(1 for n, m in model.named_modules() if "Clippable" in type(m).__name__) log(f" ClippableLinear modules: {clippable_count}") # 6. LoRA target estimate log("\n--- LORA MODULE COUNT ESTIMATE ---") attn_count = 0 mlp_count = 0 for n, m in model.named_modules(): if "Linear" in type(m).__name__: last = n.rsplit(".", 1)[-1] if last in ["q_proj", "k_proj", "v_proj", "o_proj"]: attn_count += 1 elif last in ["gate_proj", "up_proj", "down_proj"]: mlp_count += 1 log(f" Attention projections: {attn_count} (LoRA targets if attention-only)") log(f" MLP projections: {mlp_count} (includes MoE expert projections)") log(f" Total if all targeted: {attn_count + mlp_count}") # 7. VRAM usage after loading if torch.cuda.is_available(): alloc = torch.cuda.memory_allocated() / 1e9 reserved = torch.cuda.memory_reserved() / 1e9 log(f"\n--- VRAM AFTER Q4 LOAD ---") log(f" Allocated: {alloc:.2f} GB") log(f" Reserved: {reserved:.2f} GB") # 8. Test forward pass (sanity) log("\n--- FORWARD PASS TEST ---") test_messages = [ {"role": "user", "content": "Kas yra Docker?"}, {"role": "assistant", "content": "Docker yra konteinerizacijos platforma."} ] test_text = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=False) log(f" Chat template output (first 200 chars): {test_text[:200]}") test_ids = tokenizer.encode(test_text, add_special_tokens=False, return_tensors="pt").to("cuda") log(f" Token count: {test_ids.shape[1]}") with torch.no_grad(): mm_types = torch.zeros_like(test_ids) out = model(input_ids=test_ids, labels=test_ids.clone(), mm_token_type_ids=mm_types) log(f" Loss: {out.loss.item():.4f}") del model gc.collect() torch.cuda.empty_cache() log("\nDIAGNOSTICS COMPLETE") # ===== DATA LOADING ===== def text_to_messages(text): """Convert pre-formatted Gemma3/4 text to messages list. Handles both / and <|turn>/ formats. Preserves native Gemma role names (developer, model) and also handles standard HF role names (system, assistant). """ messages = [] # Role normalization for apply_chat_template compatibility ROLE_MAP = { "developer": "system", # Gemma4 developer → HF system "model": "assistant", # Gemma4 model → HF assistant "user": "user", "system": "system", "assistant": "assistant", } # Try Gemma3 format: role\ncontent pattern3 = r'(\w+)\n(.*?)' matches = re.findall(pattern3, text, re.DOTALL) if matches: for role, content in matches: mapped_role = ROLE_MAP.get(role, role) messages.append({"role": mapped_role, "content": content.strip()}) return messages if messages else None # Try Gemma4 format: <|turn>role\ncontent pattern4 = r'<\|turn>(\w+)\n(.*?)' matches = re.findall(pattern4, text, re.DOTALL) if matches: for role, content in matches: mapped_role = ROLE_MAP.get(role, role) messages.append({"role": mapped_role, "content": content.strip()}) return messages if messages else None return None def load_sft_data(tokenizer): """Load all SFT files, convert to Gemma4 format via apply_chat_template.""" all_texts = [] skipped = 0 for sf in SFT_FILES: if not os.path.exists(sf): log(f" SKIP (not found): {sf}") continue loaded = 0 with open(sf, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: d = json.loads(line) messages = None if "messages" in d: messages = d["messages"] elif "text" in d: messages = text_to_messages(d["text"]) elif "instruction" in d and "output" in d: inp = d.get("input", "") user_content = f"{d['instruction']}\n{inp}".strip() if inp else d["instruction"] messages = [ {"role": "user", "content": user_content}, {"role": "assistant", "content": d["output"]} ] elif "prompt" in d and "response" in d: messages = [ {"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["response"]} ] elif "instruction" in d and "response" in d: messages = [ {"role": "user", "content": d["instruction"]}, {"role": "assistant", "content": d["response"]} ] if messages and any(m.get("role") == "assistant" for m in messages): if all(m.get("content", "").strip() for m in messages): g4_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False ) all_texts.append(g4_text) loaded += 1 else: skipped += 1 except (json.JSONDecodeError, KeyError, TypeError): continue log(f" Loaded {os.path.basename(sf)}: {loaded} examples") log(f" Total SFT: {len(all_texts)} examples, skipped: {skipped}") return all_texts def load_dpo_data(tokenizer): """Load DPO pairs from dpo_v6_full.jsonl. Upsample HARD with augmentation.""" pairs = [] hard_pairs = [] easy_pairs = [] hard_raw = [] # Keep raw data for augmentation if not os.path.exists(DPO_FILE): log(f" DPO file not found: {DPO_FILE}") return pairs with open(DPO_FILE, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: d = json.loads(line) prompt = d["prompt"] chosen = d["chosen"] rejected = d["rejected"] difficulty = d.get("difficulty", "easy") # Format as full conversations chosen_msgs = [ {"role": "user", "content": prompt}, {"role": "assistant", "content": chosen} ] rejected_msgs = [ {"role": "user", "content": prompt}, {"role": "assistant", "content": rejected} ] chosen_text = tokenizer.apply_chat_template( chosen_msgs, tokenize=False, add_generation_prompt=False ) rejected_text = tokenizer.apply_chat_template( rejected_msgs, tokenize=False, add_generation_prompt=False ) prompt_text = tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True ) entry = { "prompt_text": prompt_text, "chosen_text": chosen_text, "rejected_text": rejected_text, } if difficulty == "hard": hard_pairs.append(entry) hard_raw.append({"prompt": prompt, "chosen": chosen, "rejected": rejected}) else: easy_pairs.append(entry) except (json.JSONDecodeError, KeyError): continue orig_hard = len(hard_pairs) orig_total = len(hard_pairs) + len(easy_pairs) orig_ratio = orig_hard / max(orig_total, 1) # Upsample HARD pairs to target ratio — with prompt augmentation if orig_hard > 0 and orig_ratio < DPO_HARD_TARGET_RATIO: target_hard = int(DPO_HARD_TARGET_RATIO * len(easy_pairs) / (1 - DPO_HARD_TARGET_RATIO)) # First: include all originals augmented_hard = list(hard_pairs) # Augmentation strategies: # 1. Prompt prefix variations (reduces surface overfitting) # 2. Tool corruption (trains reasoning correctness, not just length) augment_prefixes = [ "", # original (already included) "Prašau, ", # "Please, " "Padėk man: ", # "Help me: " "Reikia pagalbos — ", # "Need help — " "Skubiai: ", # "Urgently: " "Klausimas: ", # "Question: " "Ar galėtum ", # "Could you " ] # Tool corruption patterns for rejected responses # These teach the model to prefer correct tool usage over wrong tools tool_corruptions = [ # (pattern_to_find, replacement) — applied to rejected only ("gpu_status()", "docker_ps()"), ("docker_logs(", "disk_usage("), ("docker_restart(", "answer("), ("shell(", "http_get("), ("http_get(", "shell("), ("memory_usage()", "uptime()"), ] aug_idx = 1 # Start from prefix[1] (prefix[0] is original) corruption_idx = 0 while len(augmented_hard) < target_hard: for raw, orig_entry in zip(hard_raw, hard_pairs): if len(augmented_hard) >= target_hard: break # Alternate: prefix augmentation vs tool corruption if aug_idx % 3 == 0 and corruption_idx < len(tool_corruptions): # Tool corruption: corrupt the rejected response pattern, replacement = tool_corruptions[corruption_idx % len(tool_corruptions)] corrupted_rejected = raw["rejected"] if pattern in raw["chosen"]: # Create rejected by corrupting the chosen (wrong tool) corrupted_rejected = raw["chosen"].replace(pattern, replacement, 1) elif pattern in raw["rejected"]: # Double-corrupt: make rejected even worse corrupted_rejected = raw["rejected"].replace(pattern, replacement, 1) else: # Fallback: use prefix augmentation instead prefix = augment_prefixes[aug_idx % len(augment_prefixes)] if prefix: aug_prompt = prefix + raw["prompt"] else: augmented_hard.append(orig_entry) continue aug_chosen_msgs = [ {"role": "user", "content": aug_prompt}, {"role": "assistant", "content": raw["chosen"]} ] aug_rejected_msgs = [ {"role": "user", "content": aug_prompt}, {"role": "assistant", "content": raw["rejected"]} ] aug_entry = { "prompt_text": tokenizer.apply_chat_template( [{"role": "user", "content": aug_prompt}], tokenize=False, add_generation_prompt=True ), "chosen_text": tokenizer.apply_chat_template( aug_chosen_msgs, tokenize=False, add_generation_prompt=False ), "rejected_text": tokenizer.apply_chat_template( aug_rejected_msgs, tokenize=False, add_generation_prompt=False ), } augmented_hard.append(aug_entry) continue # Build tool-corrupted DPO pair (same prompt, correct chosen, corrupted rejected) tc_chosen_msgs = [ {"role": "user", "content": raw["prompt"]}, {"role": "assistant", "content": raw["chosen"]} ] tc_rejected_msgs = [ {"role": "user", "content": raw["prompt"]}, {"role": "assistant", "content": corrupted_rejected} ] tc_entry = { "prompt_text": tokenizer.apply_chat_template( [{"role": "user", "content": raw["prompt"]}], tokenize=False, add_generation_prompt=True ), "chosen_text": tokenizer.apply_chat_template( tc_chosen_msgs, tokenize=False, add_generation_prompt=False ), "rejected_text": tokenizer.apply_chat_template( tc_rejected_msgs, tokenize=False, add_generation_prompt=False ), } augmented_hard.append(tc_entry) corruption_idx += 1 else: # Prefix augmentation prefix = augment_prefixes[aug_idx % len(augment_prefixes)] if prefix: aug_prompt = prefix + raw["prompt"] aug_chosen_msgs = [ {"role": "user", "content": aug_prompt}, {"role": "assistant", "content": raw["chosen"]} ] aug_rejected_msgs = [ {"role": "user", "content": aug_prompt}, {"role": "assistant", "content": raw["rejected"]} ] aug_entry = { "prompt_text": tokenizer.apply_chat_template( [{"role": "user", "content": aug_prompt}], tokenize=False, add_generation_prompt=True ), "chosen_text": tokenizer.apply_chat_template( aug_chosen_msgs, tokenize=False, add_generation_prompt=False ), "rejected_text": tokenizer.apply_chat_template( aug_rejected_msgs, tokenize=False, add_generation_prompt=False ), } augmented_hard.append(aug_entry) else: # Original copy (simple repeat as fallback) augmented_hard.append(orig_entry) aug_idx += 1 pairs = easy_pairs + augmented_hard final_hard = len(augmented_hard) else: pairs = easy_pairs + hard_pairs final_hard = orig_hard final_ratio = final_hard / max(len(pairs), 1) log(f" Loaded DPO: {orig_total} pairs ({orig_hard} HARD = {100*orig_ratio:.1f}%)") if final_hard != orig_hard: log(f" Augmented HARD: {orig_hard} → {final_hard} ({100*final_ratio:.1f}% of {len(pairs)} total)") return pairs # ===== TOKENIZATION ===== def tokenize_and_mask(text, tokenizer, max_len): """Tokenize text and create labels: ONLY the LAST assistant response is trained. In multi-turn conversations, earlier assistant responses are masked (-100). This gives a cleaner signal — model learns from the final answer only. """ ids = tokenizer.encode(text, add_special_tokens=False) if len(ids) > max_len: ids = ids[:max_len] input_ids = torch.tensor(ids, dtype=torch.long) labels = input_ids.clone() # Find the turn token ID dynamically turn_token_id = tokenizer.convert_tokens_to_ids("<|turn>") if turn_token_id == tokenizer.unk_token_id: turn_token_id = tokenizer.convert_tokens_to_ids("") # Encode "model" to get its token IDs (after <|turn>) model_token_ids = tokenizer.encode("model", add_special_tokens=False) # Find ALL occurrences of <|turn>model pattern model_starts = [] for i in range(len(ids) - len(model_token_ids)): if ids[i] == turn_token_id: match = True for j, mt in enumerate(model_token_ids): if i + 1 + j >= len(ids) or ids[i + 1 + j] != mt: match = False break if match: model_starts.append(i) if model_starts: # Mask everything up to the LAST assistant response start last_model_start = model_starts[-1] mask_end = last_model_start + 1 + len(model_token_ids) # Skip past any newline token while mask_end < len(ids) and ids[mask_end] in tokenizer.encode("\n", add_special_tokens=False): mask_end += 1 # Skip truncated examples: if last assistant response has < 10 trainable tokens, # the sequence was likely truncated mid-answer → bad training signal trainable_tokens = len(ids) - mask_end if trainable_tokens < 10: labels[:] = -100 # Mark all as masked → will be filtered by all-masked check return input_ids, labels # Mask everything before the last assistant content labels[:mask_end] = -100 return input_ids, labels from torch.utils.data import Dataset, DataLoader class SFTDataset(Dataset): def __init__(self, all_input_ids, all_labels): self.all_input_ids = all_input_ids self.all_labels = all_labels def __len__(self): return len(self.all_input_ids) def __getitem__(self, idx): return { "input_ids": self.all_input_ids[idx], "labels": self.all_labels[idx], } # ===== MODEL LOADING (v20: manual NF4 expert quantization) ===== def _nf4_expert_forward(self, hidden_states, top_k_index, top_k_weights): """Patched forward for Gemma4TextExperts with per-expert NF4 + TOKEN-CENTRIC batching. SELECTIVE DEQUANT: only top-k active experts dequanted (~8/128). TOKEN-CENTRIC: instead of looping per-expert with torch.where each time, we flatten all (token, expert) pairs, group by expert, batch-dequant, and run batched matmul per expert group with pre-gathered tokens. Eliminates redundant torch.where + GPU stays saturated. Flow: 1. Flatten routing into (token_idx, expert_id, weight) triples 2. Group by expert → get token indices per expert 3. Batch-dequant gate_up & down for active experts 4. Per expert group: gather tokens → matmul → scatter results (minimal Python overhead — grouping is O(tokens*k), compute is GPU) """ import bitsandbytes.functional as BF final_hidden_states = torch.zeros_like(hidden_states) dtype = hidden_states.dtype num_tokens = hidden_states.shape[0] k = top_k_index.shape[1] # top-k value (typically 8) # Flatten routing: [num_tokens, k] → flat arrays flat_token_idx = torch.arange(num_tokens, device=hidden_states.device).unsqueeze(1).expand(-1, k).reshape(-1) flat_expert_idx = top_k_index.reshape(-1) # [num_tokens * k] flat_weights = top_k_weights.reshape(-1) # [num_tokens * k] # Group by expert — build dict: expert_id → (token_indices, weights) # Use torch ops for speed: sort by expert, then split sorted_expert, sort_order = flat_expert_idx.sort() sorted_token_idx = flat_token_idx[sort_order] sorted_weights = flat_weights[sort_order] # Find unique experts and their counts unique_experts, counts = sorted_expert.unique_consecutive(return_counts=True) if len(unique_experts) == 0: return final_hidden_states # Split sorted arrays into per-expert groups split_sizes = counts.tolist() token_groups = sorted_token_idx.split(split_sizes) weight_groups = sorted_weights.split(split_sizes) # Process each active expert: dequant + matmul (no torch.where needed!) for grp_idx, eidx_t in enumerate(unique_experts): eidx = eidx_t.item() if eidx >= self.num_experts: continue grp_token_idx = token_groups[grp_idx] grp_weights = weight_groups[grp_idx] # Dequant this expert's weights gu = BF.dequantize_4bit( self._expert_gate_up_nf4[eidx], self._expert_gate_up_qs[eidx] ).reshape(self._expert_gate_up_shape).to(dtype) dw = BF.dequantize_4bit( self._expert_down_nf4[eidx], self._expert_down_qs[eidx] ).reshape(self._expert_down_shape).to(dtype) # Gather tokens for this expert and compute MLP current_state = hidden_states[grp_token_idx] # [n_tokens_for_expert, hidden] gate_val, up_val = torch.nn.functional.linear(current_state, gu).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate_val) * up_val current_hidden_states = torch.nn.functional.linear(current_hidden_states, dw) # Apply routing weights and scatter back current_hidden_states = current_hidden_states * grp_weights.unsqueeze(-1) final_hidden_states.index_add_( 0, grp_token_idx, current_hidden_states.to(final_hidden_states.dtype) ) del gu, dw return final_hidden_states def quantize_experts(model): """Quantize Gemma4TextExperts to NF4 — PER-EXPERT for selective dequant. Instead of quantizing the full [128, ...] stacked tensor as one blob, we quantize each expert individually. This allows _nf4_expert_forward to dequantize only the top-k active experts (~8 out of 128). Memory: ~11.3 GiB NF4 total (same as stacked), but dequant at inference only touches 8/128 = 6% → ~16x speedup per layer. """ import types import bitsandbytes.functional as BF layer_count = 0 for name, module in model.named_modules(): if type(module).__name__ != 'Gemma4TextExperts': continue layer_count += 1 num_experts = module.gate_up_proj.data.shape[0] # --- Quantize gate_up_proj per-expert: [num_experts, 2*intermediate, hidden] --- gate_data = module.gate_up_proj.data # [128, 2*inter, hidden] expert_gate_up_shape = gate_data.shape[1:] # single expert shape expert_gate_up_nf4 = [] expert_gate_up_qs = [] for eidx in range(num_experts): e_flat = gate_data[eidx].reshape(-1).to('cuda:0').contiguous() nf4, qs = BF.quantize_4bit(e_flat, quant_type='nf4', compress_statistics=True) expert_gate_up_nf4.append(nf4) expert_gate_up_qs.append(qs) del e_flat del gate_data module._expert_gate_up_nf4 = expert_gate_up_nf4 module._expert_gate_up_qs = expert_gate_up_qs module._expert_gate_up_shape = expert_gate_up_shape # --- Quantize down_proj per-expert: [num_experts, hidden, intermediate] --- down_data = module.down_proj.data expert_down_shape = down_data.shape[1:] expert_down_nf4 = [] expert_down_qs = [] for eidx in range(num_experts): e_flat = down_data[eidx].reshape(-1).to('cuda:0').contiguous() nf4, qs = BF.quantize_4bit(e_flat, quant_type='nf4', compress_statistics=True) expert_down_nf4.append(nf4) expert_down_qs.append(qs) del e_flat del down_data module._expert_down_nf4 = expert_down_nf4 module._expert_down_qs = expert_down_qs module._expert_down_shape = expert_down_shape # Replace BF16 parameters with tiny placeholders to free CPU memory module.gate_up_proj = torch.nn.Parameter( torch.empty(1, device='cpu'), requires_grad=False ) module.down_proj = torch.nn.Parameter( torch.empty(1, device='cpu'), requires_grad=False ) # Patch forward to use NF4 dequantization module.forward = types.MethodType(_nf4_expert_forward, module) alloc_gb = torch.cuda.memory_allocated() / (1024**3) log(f" Layer {layer_count}/30: experts -> NF4. GPU: {alloc_gb:.2f} GiB") gc.collect() torch.cuda.empty_cache() log(f" Expert quantization complete: {layer_count} layers processed.") def load_model_q4(): """Load 26B model with manual NF4 expert quantization (v20). Strategy: Load entire model on CPU (BF16, no BnB), manually quantize expert nn.Parameter weights to NF4 on GPU one layer at a time, then move the whole model to GPU. This bypasses the BnB limitation where expert weights (stored as nn.Parameter instead of nn.Linear) are NOT quantized by BnB's automatic 4-bit loading. Memory budget (24 GiB RTX 3090): - Expert NF4 (30 layers): ~11.3 GiB - Attention BF16 (30 layers): ~2.4 GiB - Embedding + LM head BF16: ~2.9 GiB - CUDA context: ~0.5 GiB - Total model: ~17.1 GiB → fits with ~7 GiB headroom for training """ from transformers import AutoTokenizer, Gemma4ForConditionalGeneration log("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token log("v20: Loading 26B on CPU (BF16, no BnB auto-quant)...") log(" Needs ~52 GiB virtual memory (27 RAM + 33 swap). Will be slow.") model = Gemma4ForConditionalGeneration.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True, attn_implementation="eager", ) log(" CPU load complete.") try: import psutil ram = psutil.virtual_memory() swap = psutil.swap_memory() log(f" RAM: {ram.used/1e9:.1f}/{ram.total/1e9:.1f} GiB ({ram.percent}%)") log(f" Swap: {swap.used/1e9:.1f}/{swap.total/1e9:.1f} GiB ({swap.percent}%)") except ImportError: pass # Quantize expert weights to NF4 on GPU (one layer at a time) log("Quantizing expert weights to NF4 on GPU...") quantize_experts(model) if torch.cuda.is_available(): alloc = torch.cuda.memory_allocated() / 1e9 log(f" GPU after expert quantization: {alloc:.2f} GB") # Move all remaining modules (attention, norms, embedding, lm_head) to GPU log("Moving non-expert modules to GPU (BF16)...") model = model.to('cuda:0') if torch.cuda.is_available(): alloc = torch.cuda.memory_allocated() / 1e9 log(f" GPU after full model on CUDA: {alloc:.2f} GB") # Verify device placement meta_count = sum(1 for p in model.parameters() if p.device.type == 'meta') gpu_count = sum(1 for p in model.parameters() if p.device.type == 'cuda') cpu_count = sum(1 for p in model.parameters() if p.device.type == 'cpu') log(f" Params: GPU={gpu_count}, CPU={cpu_count}, meta={meta_count}") torch.cuda.empty_cache() gc.collect() return model, tokenizer def apply_lora(model): """Apply LoRA adapter to model. v20: no BnB workarounds needed.""" from peft import get_peft_model, LoraConfig, TaskType # Check for ClippableLinear (should be 0 for 26B) clippable = 0 for name, module in list(model.named_modules()): if type(module).__name__ == "Gemma4ClippableLinear": clippable += 1 parts = name.rsplit(".", 1) if len(parts) == 2: parent = model.get_submodule(parts[0]) setattr(parent, parts[1], module.linear) if clippable > 0: log(f" Unwrapped {clippable} ClippableLinear modules") # Freeze all parameters + enable gradient checkpointing for param in model.parameters(): param.requires_grad = False model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False} ) if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() target_modules = TARGET_MODULES_FULL if INCLUDE_MLP_LORA else TARGET_MODULES_ATTENTION log(f" LoRA target modules: {target_modules}") lora_config = LoraConfig( r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=target_modules, lora_dropout=0, bias="none", task_type=TaskType.CAUSAL_LM, ) model = get_peft_model(model, lora_config) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) log(f" Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)") if torch.cuda.is_available(): alloc = torch.cuda.memory_allocated() / 1e9 log(f" VRAM after LoRA: {alloc:.2f} GB") return model # ===== PHASE 1: SFT TRAINING ===== def phase_sft(): log("=" * 60) log("PHASE 1: SFT TRAINING (26B-A4B QLoRA)") log("=" * 60) model, tokenizer = load_model_q4() model = apply_lora(model) # Load data log("Loading SFT data...") texts = load_sft_data(tokenizer) if not texts: log("ERROR: No SFT data loaded!") return False random.shuffle(texts) # Tokenize all examples log(f"Tokenizing {len(texts)} examples (max_seq={MAX_SEQ_LENGTH})...") all_input_ids = [] all_labels = [] too_short = 0 all_masked = 0 for text in texts: ids, labels = tokenize_and_mask(text, tokenizer, MAX_SEQ_LENGTH) if len(ids) < 10: too_short += 1 continue # Skip examples where ALL labels are -100 (no valid target → NaN loss) if (labels == -100).all(): all_masked += 1 continue all_input_ids.append(ids) all_labels.append(labels) log(f" Tokenized: {len(all_input_ids)} examples, too short: {too_short}, all-masked: {all_masked}") # Dataset dataset = SFTDataset(all_input_ids, all_labels) def collate_fn(batch): return {k: v.unsqueeze(0) for k, v in batch[0].items()} dataloader = DataLoader(dataset, batch_size=SFT_BATCH_SIZE, shuffle=True, collate_fn=collate_fn) # Optimizer from torch.optim import AdamW optimizer = AdamW( [p for p in model.parameters() if p.requires_grad], lr=SFT_LR, weight_decay=0.01, betas=(0.9, 0.999), ) total_steps = len(dataloader) * SFT_EPOCHS warmup_steps = int(total_steps * SFT_WARMUP_RATIO) # LR scheduler with warmup def get_lr(step): if step < warmup_steps: return SFT_LR * step / max(warmup_steps, 1) progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) return SFT_LR * max(0.1, 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))) log(f"\n--- SFT TRAINING ---") log(f" Examples: {len(dataset)}") log(f" Epochs: {SFT_EPOCHS}") log(f" Batch size: {SFT_BATCH_SIZE}") log(f" Grad accumulation: {SFT_GRAD_ACCUM}") log(f" Effective batch: {SFT_BATCH_SIZE * SFT_GRAD_ACCUM}") log(f" Total steps: {total_steps}") log(f" Optimizer steps: {total_steps // SFT_GRAD_ACCUM}") log(f" LR: {SFT_LR}") log(f" Warmup steps: {warmup_steps}") model.train() global_step = 0 best_loss = float("inf") running_loss = 0 loss_count = 0 nan_count = 0 sft_dir = os.path.join(ADAPTER_DIR, "sft_final") os.makedirs(sft_dir, exist_ok=True) start_time = time.time() for epoch in range(SFT_EPOCHS): epoch_loss = 0 epoch_steps = 0 for batch_idx, batch in enumerate(dataloader): global_step += 1 input_ids = batch["input_ids"].to("cuda") labels_t = batch["labels"].to("cuda") mm_types = torch.zeros_like(input_ids) attn_mask = torch.ones_like(input_ids) outputs = model( input_ids=input_ids, labels=labels_t, mm_token_type_ids=mm_types, attention_mask=attn_mask, ) # FP32 loss to avoid BF16 precision issues loss = outputs.loss.float() / SFT_GRAD_ACCUM # Skip NaN/Inf batches (data issues, numerical overflow) if torch.isnan(loss) or torch.isinf(loss): nan_count += 1 if nan_count % 100 == 1: log(f" WARNING: NaN/Inf loss at step {global_step} (total NaN: {nan_count})") continue loss.backward() loss_val = outputs.loss.item() running_loss += loss_val loss_count += 1 epoch_loss += loss_val epoch_steps += 1 if global_step % SFT_GRAD_ACCUM == 0: # Update LR lr = get_lr(global_step) for pg in optimizer.param_groups: pg["lr"] = lr torch.nn.utils.clip_grad_norm_(model.parameters(), SFT_MAX_GRAD_NORM) optimizer.step() optimizer.zero_grad() opt_step = global_step // SFT_GRAD_ACCUM if opt_step % 50 == 0: avg_loss = running_loss / max(loss_count, 1) elapsed = time.time() - start_time steps_per_sec = global_step / elapsed eta = (total_steps - global_step) / steps_per_sec if steps_per_sec > 0 else 0 vram = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 log(f" Epoch {epoch+1}/{SFT_EPOCHS} | " f"Step {opt_step}/{total_steps//SFT_GRAD_ACCUM} | " f"Loss: {avg_loss:.4f} | " f"LR: {lr:.2e} | " f"VRAM: {vram:.1f}GB | " f"NaN: {nan_count} | " f"ETA: {eta/60:.0f}min") running_loss = 0 loss_count = 0 save_checkpoint("sft", opt_step, {"loss": avg_loss, "epoch": epoch+1}) avg_epoch_loss = epoch_loss / max(epoch_steps, 1) log(f" Epoch {epoch+1} complete — avg loss: {avg_epoch_loss:.4f}") if avg_epoch_loss < best_loss: best_loss = avg_epoch_loss log(f" New best loss! Saving adapter to {sft_dir}") model.save_pretrained(sft_dir) tokenizer.save_pretrained(sft_dir) total_time = time.time() - start_time log(f"\nSFT COMPLETE — Best loss: {best_loss:.4f}, Time: {total_time/60:.1f} min") save_checkpoint("sft_complete", extra={"best_loss": best_loss, "time_min": total_time/60}) # Save final if not already saved model.save_pretrained(sft_dir) tokenizer.save_pretrained(sft_dir) del model, optimizer gc.collect() torch.cuda.empty_cache() return True # ===== PHASE 2: DPO TRAINING ===== def phase_dpo(): log("=" * 60) log("PHASE 2: DPO TRAINING (26B-A4B QLoRA, cached ref logprobs)") log("=" * 60) sft_dir = os.path.join(ADAPTER_DIR, "sft_final") if not os.path.exists(sft_dir): log(f"ERROR: SFT adapter not found at {sft_dir}") log("Run phase 1 (SFT) first!") return False # Load base model + SFT adapter model, tokenizer = load_model_q4() from peft import PeftModel # Unwrap ClippableLinear before PEFT (same as apply_lora) clippable = 0 for name, module in list(model.named_modules()): if type(module).__name__ == "Gemma4ClippableLinear": clippable += 1 parts = name.rsplit(".", 1) if len(parts) == 2: parent = model.get_submodule(parts[0]) setattr(parent, parts[1], module.linear) if clippable > 0: log(f" Unwrapped {clippable} ClippableLinear modules for DPO") log(f"Loading SFT adapter from {sft_dir}...") model = PeftModel.from_pretrained(model, sft_dir, is_trainable=True) # Optionally add MLP LoRA for DPO (better tool-calling generation) if DPO_INCLUDE_MLP_LORA and not INCLUDE_MLP_LORA: log(f" Adding MLP LoRA targets for DPO phase (r={DPO_MLP_LORA_R}, safer for MoE)...") from peft import LoraConfig mlp_config = LoraConfig( r=DPO_MLP_LORA_R, lora_alpha=DPO_MLP_LORA_R, # alpha=r → multiplier=1.0 (stable) target_modules=["gate_proj", "up_proj", "down_proj"], lora_dropout=0, bias="none", ) model.add_adapter("mlp_dpo", mlp_config) model.base_model.set_adapter(["default", "mlp_dpo"]) # PEFT 0.18.1 fix extra_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) log(f" MLP LoRA added. Trainable now: {extra_trainable:,}") model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) model.enable_input_require_grads() trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) log(f" Trainable: {trainable:,}") # Load DPO data log("Loading DPO data...") pairs = load_dpo_data(tokenizer) if not pairs: log("ERROR: No DPO data loaded!") return False # ===== PRECOMPUTE REFERENCE LOGPROBS (1-time cost, ~2x DPO speedup) ===== log(f"Precomputing reference logprobs for {len(pairs)} pairs (seq_len={DPO_SEQ_LENGTH})...") log(f" This is a one-time cost that halves per-step compute.") def compute_ref_logprobs(text): """Compute reference (adapter-disabled) logprobs for one text.""" ids = tokenizer.encode(text, add_special_tokens=False) if len(ids) > DPO_SEQ_LENGTH: ids = ids[:DPO_SEQ_LENGTH] t = torch.tensor([ids], device="cuda") mm_types = torch.zeros_like(t) attn_mask = torch.ones_like(t) model.eval() with torch.no_grad(): model.disable_adapter_layers() out = model(input_ids=t, mm_token_type_ids=mm_types, attention_mask=attn_mask) model.enable_adapter_layers() logits = out.logits[:, :-1, :] targets = t[:, 1:] log_probs = torch.log_softmax(logits.float(), dim=-1) token_logprobs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1) return token_logprobs.mean().item() # .mean() not .sum() — prevents length bias ref_cache_start = time.time() ref_chosen_list = [] ref_rejected_list = [] for idx, pair in enumerate(pairs): ref_chosen_list.append(compute_ref_logprobs(pair["chosen_text"])) ref_rejected_list.append(compute_ref_logprobs(pair["rejected_text"])) if (idx + 1) % 100 == 0: elapsed = time.time() - ref_cache_start eta = elapsed / (idx + 1) * (len(pairs) - idx - 1) log(f" Ref cache: {idx+1}/{len(pairs)} ({elapsed/60:.1f}min elapsed, ETA {eta/60:.1f}min)") # Convert to pinned CPU tensors — faster GPU transfer, less overhead than Python lists ref_chosen_cache = torch.tensor(ref_chosen_list, dtype=torch.float32).pin_memory() ref_rejected_cache = torch.tensor(ref_rejected_list, dtype=torch.float32).pin_memory() del ref_chosen_list, ref_rejected_list ref_cache_time = time.time() - ref_cache_start log(f" Reference logprobs cached: {len(pairs)} pairs in {ref_cache_time/60:.1f} min") log(f" Cache: pinned CPU tensors ({ref_chosen_cache.nbytes * 2 / 1024:.1f} KB)") log(f" ⚡ Each DPO step now: 2 forward (policy only) instead of 4") model.train() # Shuffle pairs and ref caches in sync indices = list(range(len(pairs))) random.shuffle(indices) pairs = [pairs[i] for i in indices] idx_tensor = torch.tensor(indices, dtype=torch.long) ref_chosen_cache = ref_chosen_cache[idx_tensor] ref_rejected_cache = ref_rejected_cache[idx_tensor] from torch.optim import AdamW optimizer = AdamW( [p for p in model.parameters() if p.requires_grad], lr=DPO_LR, weight_decay=0.01, ) total_steps = len(pairs) * DPO_EPOCHS log(f"\n--- DPO TRAINING ---") log(f" Pairs: {len(pairs)}") log(f" Epochs: {DPO_EPOCHS}") log(f" Beta: {DPO_BETA}") log(f" Seq length: {DPO_SEQ_LENGTH}") log(f" Total steps: {total_steps}") log(f" Optimizer steps: {total_steps // DPO_GRAD_ACCUM}") log(f" MLP LoRA: {DPO_INCLUDE_MLP_LORA}") def get_policy_logprobs(text): """Get policy (with adapter) logprobs — the only forward needed per step.""" ids = tokenizer.encode(text, add_special_tokens=False) if len(ids) > DPO_SEQ_LENGTH: ids = ids[:DPO_SEQ_LENGTH] t = torch.tensor([ids], device="cuda") mm_types = torch.zeros_like(t) attn_mask = torch.ones_like(t) out = model(input_ids=t, mm_token_type_ids=mm_types, attention_mask=attn_mask) logits = out.logits[:, :-1, :] targets = t[:, 1:] log_probs = torch.log_softmax(logits.float(), dim=-1) token_logprobs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1) return token_logprobs.mean() # .mean() not .sum() — prevents length bias global_step = 0 running_loss = 0 loss_count = 0 best_loss = float("inf") nan_count = 0 dpo_dir = os.path.join(ADAPTER_DIR, "dpo_final") os.makedirs(dpo_dir, exist_ok=True) start_time = time.time() for epoch in range(DPO_EPOCHS): # Re-shuffle each epoch (with cache sync via tensor indexing) indices = list(range(len(pairs))) random.shuffle(indices) epoch_pairs = [pairs[i] for i in indices] idx_t = torch.tensor(indices, dtype=torch.long) # Async copy shuffled ref caches to GPU — non_blocking avoids CPU stall epoch_ref_chosen = ref_chosen_cache[idx_t].to("cuda", non_blocking=True) epoch_ref_rejected = ref_rejected_cache[idx_t].to("cuda", non_blocking=True) epoch_loss = 0 epoch_steps = 0 for i, pair in enumerate(epoch_pairs): global_step += 1 # Policy log probs (2 forward passes — WITH gradients) chosen_logp = get_policy_logprobs(pair["chosen_text"]) rejected_logp = get_policy_logprobs(pair["rejected_text"]) # Reference log probs (from cache — FREE!) ref_chosen_logp = epoch_ref_chosen[i] ref_rejected_logp = epoch_ref_rejected[i] # DPO loss chosen_rewards = DPO_BETA * (chosen_logp - ref_chosen_logp) rejected_rewards = DPO_BETA * (rejected_logp - ref_rejected_logp) loss = -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards) / DPO_GRAD_ACCUM # NaN guard if torch.isnan(loss) or torch.isinf(loss): nan_count += 1 continue loss.backward() loss_val = loss.item() * DPO_GRAD_ACCUM running_loss += loss_val loss_count += 1 epoch_loss += loss_val epoch_steps += 1 if global_step % DPO_GRAD_ACCUM == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() opt_step = global_step // DPO_GRAD_ACCUM if opt_step % 25 == 0: avg_loss = running_loss / max(loss_count, 1) elapsed = time.time() - start_time steps_per_sec = global_step / elapsed if elapsed > 0 else 0 eta = (total_steps - global_step) / steps_per_sec if steps_per_sec > 0 else 0 vram = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 log(f" Epoch {epoch+1}/{DPO_EPOCHS} | " f"Step {opt_step}/{total_steps//DPO_GRAD_ACCUM} | " f"DPO Loss: {avg_loss:.4f} | " f"NaN: {nan_count} | " f"VRAM: {vram:.1f}GB | " f"ETA: {eta/60:.0f}min") running_loss = 0 loss_count = 0 save_checkpoint("dpo", opt_step, {"loss": avg_loss, "epoch": epoch+1}) avg_epoch_loss = epoch_loss / max(epoch_steps, 1) log(f" DPO Epoch {epoch+1} complete — avg loss: {avg_epoch_loss:.4f}") if avg_epoch_loss < best_loss: best_loss = avg_epoch_loss log(f" New best DPO loss! Saving adapter to {dpo_dir}") model.save_pretrained(dpo_dir) tokenizer.save_pretrained(dpo_dir) total_time = time.time() - start_time log(f"\nDPO COMPLETE — Best loss: {best_loss:.4f}, NaN skipped: {nan_count}, Time: {total_time/60:.1f} min") save_checkpoint("dpo_complete", extra={"best_loss": best_loss, "time_min": total_time/60, "nan_count": nan_count}) model.save_pretrained(dpo_dir) tokenizer.save_pretrained(dpo_dir) del model, optimizer, ref_chosen_cache, ref_rejected_cache gc.collect() torch.cuda.empty_cache() return True # ===== PHASE 3: MERGE TO BF16 ===== def phase_merge(): log("=" * 60) log("PHASE 3: MERGE LoRA → BF16 (CPU, will use swap)") log("=" * 60) dpo_dir = os.path.join(ADAPTER_DIR, "dpo_final") sft_dir = os.path.join(ADAPTER_DIR, "sft_final") # Prefer DPO adapter, fall back to SFT adapter_dir = dpo_dir if os.path.exists(dpo_dir) else sft_dir if not os.path.exists(adapter_dir): log(f"ERROR: No adapter found! Run SFT/DPO first.") return False log(f" Using adapter: {adapter_dir}") log(f" Output: {MERGED_DIR}") log(f" WARNING: This needs ~54GB RAM. Will use swap (slow but works).") os.makedirs(MERGED_DIR, exist_ok=True) # Load base model in bf16 on CPU with low memory from transformers import Gemma4ForConditionalGeneration, AutoTokenizer from peft import PeftModel import psutil log(" Loading base model on CPU (bf16)...") base_model = Gemma4ForConditionalGeneration.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True, ) log(f" Base model loaded. RAM: {psutil.virtual_memory().percent}%") log(" Loading LoRA adapter...") model = PeftModel.from_pretrained(base_model, adapter_dir) log(f" Adapter loaded. RAM: {psutil.virtual_memory().percent}%") log(" Merging LoRA into base model...") model = model.merge_and_unload() log(f" Merged. RAM: {psutil.virtual_memory().percent}%") log(f" Saving merged model to {MERGED_DIR} (small shards to reduce peak RAM)...") model.save_pretrained(MERGED_DIR, max_shard_size="2GB", safe_serialization=True) tokenizer = AutoTokenizer.from_pretrained(adapter_dir) tokenizer.save_pretrained(MERGED_DIR) # Verify total_params = sum(p.numel() for p in model.parameters()) log(f" Saved! Total params: {total_params:,} ({total_params/1e9:.2f}B)") safetensors = glob.glob(os.path.join(MERGED_DIR, "*.safetensors")) total_size = sum(os.path.getsize(f) for f in safetensors) log(f" Files: {len(safetensors)} shards, {total_size/1e9:.1f} GB total") del model, base_model gc.collect() log("PHASE 3 COMPLETE — Merged model saved") save_checkpoint("merge_complete") return True # ===== PHASE 4: GGUF CONVERSION ===== def phase_gguf(): log("=" * 60) log("PHASE 4: GGUF Q4_K_M CONVERSION") log("=" * 60) if not os.path.exists(MERGED_DIR): log(f"ERROR: Merged model not found at {MERGED_DIR}") return False convert_script = os.path.join(LLAMA_CPP, "convert_hf_to_gguf.py") quantize_bin = os.path.join(LLAMA_CPP, "build", "bin", "llama-quantize") if not os.path.exists(convert_script): log(f"ERROR: convert_hf_to_gguf.py not found at {convert_script}") return False os.makedirs(GGUF_DIR, exist_ok=True) # Step 1: Convert to BF16 GGUF (native format, no precision loss) bf16_gguf = os.path.join(GGUF_DIR, "gemma4-janus-v6-bf16.gguf") q4_gguf = os.path.join(GGUF_DIR, "gemma4-janus-v6-q4_k_m.gguf") log(f" Converting merged HF model to BF16 GGUF...") cmd = [ "python3", convert_script, MERGED_DIR, "--outfile", bf16_gguf, "--outtype", "bf16", ] log(f" CMD: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600) if result.returncode != 0: log(f" STDERR: {result.stderr[-500:]}") log("ERROR: BF16 GGUF conversion failed!") return False bf16_size = os.path.getsize(bf16_gguf) / 1e9 log(f" BF16 GGUF: {bf16_size:.1f} GB") # Step 2: Quantize to Q4_K_M log(f" Quantizing BF16 → Q4_K_M...") if not os.path.exists(quantize_bin): quantize_bin = os.path.join(LLAMA_CPP, "llama-quantize") if not os.path.exists(quantize_bin): log(f" WARNING: llama-quantize not found. Skipping quantization.") log(f" You can manually run: llama-quantize {bf16_gguf} {q4_gguf} Q4_K_M") return True cmd = [quantize_bin, bf16_gguf, q4_gguf, "Q4_K_M"] log(f" CMD: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600) if result.returncode != 0: log(f" STDERR: {result.stderr[-500:]}") log("ERROR: Q4_K_M quantization failed!") return False q4_size = os.path.getsize(q4_gguf) / 1e9 log(f" Q4_K_M GGUF: {q4_size:.1f} GB") # Clean up BF16 intermediate to save disk (~50GB) log(f" Removing BF16 GGUF to save disk ({bf16_size:.1f} GB)...") os.remove(bf16_gguf) log("PHASE 4 COMPLETE — GGUF ready") save_checkpoint("gguf_complete", extra={"q4_size_gb": q4_size}) return True # ===== PHASE 5: OLLAMA DEPLOY ===== def phase_ollama(): log("=" * 60) log("PHASE 5: OLLAMA DEPLOY") log("=" * 60) q4_gguf = os.path.join(GGUF_DIR, "gemma4-janus-v6-q4_k_m.gguf") if not os.path.exists(q4_gguf): log(f"ERROR: GGUF not found at {q4_gguf}") return False # Create Modelfile modelfile_content = f"""FROM {q4_gguf} PARAMETER temperature 0.3 PARAMETER top_p 0.9 PARAMETER top_k 40 PARAMETER repeat_penalty 1.1 PARAMETER num_predict 4096 PARAMETER num_ctx 8192 PARAMETER num_gpu 99 PARAMETER stop "Observation:" PARAMETER stop "" PARAMETER stop "<|tool_response>" PARAMETER stop "" TEMPLATE \"\"\"{{{{- range .Messages }}}} {{{{ .Role }}}} {{{{ .Content }}}} {{{{- end }}}} model \"\"\" SYSTEM \"\"\"Tu esi Janus — DevOps AI asistentas. Tu valdo Docker konteinerius, Linux serverius ir GPU resursus. Naudok ReAct formatą: Thought → Action → Observation. Atsakyk trumpai ir tiksliai. Jei nežinai — naudok įrankius. Klaidos atveju — diagnozuok ir taisyk automatiškai.\"\"\" """ modelfile_path = os.path.join(BASE_DIR, "Modelfile.gemma4-janus-v6") with open(modelfile_path, "w") as f: f.write(modelfile_content) log(f" Modelfile written: {modelfile_path}") # Ensure Ollama is running result = subprocess.run(["systemctl", "is-active", "ollama"], capture_output=True, text=True) if result.stdout.strip() != "active": log(" Starting Ollama...") subprocess.run(["sudo", "systemctl", "start", "ollama"], check=True) import time time.sleep(5) # Create model log(" Creating Ollama model: gemma4-janus-v6...") result = subprocess.run( ["ollama", "create", "gemma4-janus-v6", "-f", modelfile_path], capture_output=True, text=True, timeout=600 ) if result.returncode != 0: log(f" STDERR: {result.stderr}") log("ERROR: Ollama model creation failed!") return False log(f" STDOUT: {result.stdout}") # Verify result = subprocess.run(["ollama", "list"], capture_output=True, text=True) log(f" Ollama models:\n{result.stdout}") # Smoke test log(" Running smoke test...") import urllib.request payload = json.dumps({ "model": "gemma4-janus-v6", "prompt": "Kiek yra 2+2?", "stream": False, }).encode() req = urllib.request.Request( "http://localhost:11434/api/generate", data=payload, headers={"Content-Type": "application/json"}, ) try: with urllib.request.urlopen(req, timeout=120) as resp: data = json.loads(resp.read()) response_text = data.get("response", "") eval_duration = data.get("eval_duration", 0) / 1e9 log(f" Response: {response_text[:200]}") log(f" Time: {eval_duration:.1f}s") except Exception as e: log(f" Smoke test error: {e}") log("PHASE 5 COMPLETE — Ollama model deployed") save_checkpoint("deploy_complete") return True # ===== MAIN ===== def main(): parser = argparse.ArgumentParser(description="Janus V6 Training Pipeline (26B-A4B)") parser.add_argument("--phase", required=True, choices=["0", "1", "2", "3", "4", "5", "all", "diag"], help="Phase to run: 0=download, 1=SFT, 2=DPO, 3=merge, 4=GGUF, 5=ollama, all, diag") parser.add_argument("--include-mlp", action="store_true", help="Include MLP projections in LoRA targets (needs more VRAM)") args = parser.parse_args() global INCLUDE_MLP_LORA if args.include_mlp: INCLUDE_MLP_LORA = True log("MLP LoRA ENABLED — targeting expert projections too") log("=" * 60) log("JANUS V6 TRAINING PIPELINE") log(f" Model: {MODEL_NAME}") log(f" Phase: {args.phase}") log(f" LoRA: r={LORA_R}, alpha={LORA_ALPHA}") log(f" SFT seq: {MAX_SEQ_LENGTH}, DPO seq: {DPO_SEQ_LENGTH}") log(f" SFT MLP LoRA: {INCLUDE_MLP_LORA}") log(f" DPO MLP LoRA: {DPO_INCLUDE_MLP_LORA} (r={DPO_MLP_LORA_R})") log(f" DPO HARD target: {DPO_HARD_TARGET_RATIO:.0%}") log(f" Adapter dir: {ADAPTER_DIR}") log(f" Merged dir: {MERGED_DIR}") log("=" * 60) phase = args.phase if phase in ("0", "all"): if not phase_download(): log("PHASE 0 FAILED") sys.exit(1) if phase == "diag": phase_diagnostics() return if phase in ("1", "all"): if not phase_sft(): log("PHASE 1 (SFT) FAILED") sys.exit(1) if phase in ("2", "all"): if not phase_dpo(): log("PHASE 2 (DPO) FAILED") sys.exit(1) if phase in ("3", "all"): if not phase_merge(): log("PHASE 3 (MERGE) FAILED") sys.exit(1) if phase in ("4", "all"): if not phase_gguf(): log("PHASE 4 (GGUF) FAILED") sys.exit(1) if phase in ("5", "all"): if not phase_ollama(): log("PHASE 5 (OLLAMA) FAILED") sys.exit(1) log("\n" + "=" * 60) log("ALL REQUESTED PHASES COMPLETE!") log("=" * 60) if __name__ == "__main__": main()