# Script: compute a "refusal direction" from hidden-state differences between
# two prompt sets and project it out of selected weight matrices.
import gc
import json
import os
import random

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- CONFIGURATION ---
MODEL_ID = "google/gemma-4-31B-it"  # Adjust if your local path differs
SAVE_PATH = "./gemma-4-31b-abliterated"
BATCH_SIZE = 4  # Keep this low to survive the 31B hidden state extraction
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_SAMPLES = 256  # Upper bound on prompts drawn from each dataset
SEED = 0  # Fixed seed so the sampled prompt subsets are reproducible

print(f"[*] Initializing Gemma 4 31B Abliteration Protocol on {DEVICE}...")

# --- 1. LOAD MODEL & TOKENIZER ---
print("[*] Loading Model and Tokenizer (bfloat16)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Let accelerate distribute the 62GB across your GPUs
)

# --- 2. DATA PREPARATION ---
print("[*] Downloading HuggingFace datasets...")

# Load the datasets
harmful_dataset = load_dataset('mlabonne/harmful_behaviors')
harmless_dataset = load_dataset('mlabonne/harmless_alpaca')

# Extract the raw text prompts.
# The column is materialised into a plain list because random sampling needs
# a real sequence, and the sample size is capped at the dataset length so a
# short dataset does not raise ValueError.
_sampler = random.Random(SEED)  # local RNG: leaves the global random state alone
_harmful_texts = list(harmful_dataset['train']['text'])
_harmless_texts = list(harmless_dataset['train']['text'])
raw_harmful = _sampler.sample(_harmful_texts, min(NUM_SAMPLES, len(_harmful_texts)))
raw_harmless = _sampler.sample(_harmless_texts, min(NUM_SAMPLES, len(_harmless_texts)))


def format_gemma4_prompts(instructions):
    """Render each instruction as a chat-formatted prompt string.

    Every instruction becomes the user turn of a two-message conversation
    (a fixed system message followed by the user message). The conversation
    is rendered to text, not token ids, with the module-level tokenizer's
    chat template, and the generation prompt is appended.

    Args:
        instructions: Iterable of raw instruction strings.

    Returns:
        A list of prompt strings, one per instruction, in input order.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    # The tokenizer's template inserts all the control tokens.
    return [
        tokenizer.apply_chat_template(
            [system_turn, {"role": "user", "content": inst}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for inst in instructions
    ]


print("[*] Formatting prompts with native Gemma 4 Chat Templates...")
harmful_prompts = format_gemma4_prompts(raw_harmful)
harmless_prompts = format_gemma4_prompts(raw_harmless)

# --- 3.
# HIDDEN STATE EXTRACTION (VRAM SAFE) ---
def get_hidden_states(prompts, batch_size=BATCH_SIZE):
    """Collect the last-real-token hidden state of every layer for each prompt.

    Prompts are tokenized and run through the module-level ``model`` in
    batches with ``output_hidden_states=True``. For each returned hidden-state
    tensor, the vector at each row's last non-padding position is copied to
    CPU as float32.

    Args:
        prompts: Sequence of prompt strings (already chat-formatted).
        batch_size: Number of prompts per forward pass.

    Returns:
        A CPU float32 tensor of shape [num_states, len(prompts), hidden_dim],
        where num_states is the length of ``outputs.hidden_states``.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if len(prompts) == 0:
        raise ValueError("get_hidden_states() needs at least one prompt")

    print(f"[*] Extracting hidden states (Batches of {batch_size})...")
    all_hidden_states = []
    bos = getattr(tokenizer, "bos_token", None)

    for i in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[i:i + batch_size]

        # If the chat template already put BOS at the start of the text,
        # tokenizing with special tokens enabled would add a second one.
        template_has_bos = bool(bos) and all(p.startswith(bos) for p in batch)
        inputs = tokenizer(
            batch,
            padding=True,
            return_tensors="pt",
            add_special_tokens=not template_has_bos,
        ).to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Index of the last real token in each row. Position -1 is only
        # correct with left padding; reading it from the attention mask works
        # for either padding side. argmax on the flipped mask gives the
        # distance of the last 1 from the end of the row.
        mask = inputs["attention_mask"]
        last_idx = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)

        # Slice each layer on its own device and move it to CPU right away.
        # With device_map="auto" the per-layer tensors can live on different
        # devices, so stacking them directly can fail; slicing first also
        # avoids copying every full-sequence state.
        per_layer = []
        for layer_state in outputs.hidden_states:
            rows = torch.arange(layer_state.shape[0], device=layer_state.device)
            cols = last_idx.to(layer_state.device)
            per_layer.append(layer_state[rows, cols, :].cpu().float())

        # [num_states, batch, dim]
        all_hidden_states.append(torch.stack(per_layer))

        del inputs, outputs, per_layer, mask, last_idx
        torch.cuda.empty_cache()
        gc.collect()

    # Concatenate along the batch dimension: [num_states, total_prompts, hidden_dim]
    return torch.cat(all_hidden_states, dim=1)


print("\n[*] Processing Harmful Vector Space...")
harmful_states = get_hidden_states(harmful_prompts)
print("[*] Processing Harmless Vector Space...")
harmless_states = get_hidden_states(harmless_prompts)

# --- 4. DYNAMIC LAYER HUNTING ---
print("\n[*] Hunting for the Refusal Vector...")
mean_harmful = harmful_states.mean(dim=1)
mean_harmless = harmless_states.mean(dim=1)
refusal_directions = mean_harmful - mean_harmless

# Find the state index with the highest magnitude. State 0 is skipped, so the
# argmax result is shifted back by one to index into the full tensor.
magnitudes = torch.norm(refusal_directions[1:], dim=1)
peak_state_idx = torch.argmax(magnitudes).item() + 1
print(f"[+] Peak Refusal Mass detected at state index: {peak_state_idx}")

# Normalize the refusal vector
refusal_vector = refusal_directions[peak_state_idx]
_refusal_norm = torch.norm(refusal_vector)
if not torch.isfinite(_refusal_norm) or _refusal_norm.item() == 0.0:
    # Dividing by a zero/NaN norm would silently write NaNs into the weights later.
    raise ValueError("Refusal direction has zero or non-finite norm; cannot normalize")
refusal_vector = (refusal_vector / _refusal_norm).to(DEVICE).to(torch.bfloat16)

# --- 5.
# ORTHOGONAL PROJECTION (THE ABLITERATION) ---
# FIX 1: Safely navigate the Gemma 4 Multimodal Config
# (use the nested text_config when present, else the top-level config)
_text_config = getattr(model.config, 'text_config', model.config)
num_layers = _text_config.num_hidden_layers

# FIX 2: Correct the off-by-one mapping (State index 60 comes from Layer 59)
target_layer_idx = peak_state_idx - 1
print(f"\n[*] Applying Orthogonal Projection starting at Layer {target_layer_idx}...")


# FIX 3: Bulletproof dynamic layer discovery for Multimodal models
def get_transformer_layers(model_obj, target_len):
    """Locate the decoder layer stack inside ``model_obj``.

    Walks ``named_modules()`` and returns the first ``torch.nn.ModuleList``
    whose qualified name ends with ``layers`` and whose length equals
    ``target_len``.

    Args:
        model_obj: The loaded model (any ``torch.nn.Module``).
        target_len: Expected number of layers in the stack.

    Returns:
        The matching ``ModuleList``, or ``model_obj.model.layers`` if nothing
        matches.

    Raises:
        AttributeError: If nothing matches and ``model_obj.model.layers``
            does not exist.
    """
    for name, module in model_obj.named_modules():
        if not name.endswith('layers'):
            continue
        if isinstance(module, torch.nn.ModuleList) and len(module) == target_len:
            return module
    return model_obj.model.layers  # Fallback


def _project_out_direction(weight, direction):
    """Remove ``direction`` from the output space of ``weight`` in place.

    Applies W <- W - v (v^T W), which equals v_col @ (v_row @ W). The update
    is computed in float32 on the weight's own device and written back in
    the weight's original dtype.

    Args:
        weight: Parameter of shape [out_features, in_features].
        direction: 1-D unit vector of length out_features.

    Raises:
        RuntimeError: If the weight is on the ``meta`` device (offloaded), so
            there is no real tensor to edit.
        ValueError: If ``direction`` does not match the weight's output size.
    """
    if weight.device.type == "meta":
        raise RuntimeError(
            "Weight is offloaded (meta device); load the model fully into "
            "memory before editing weights"
        )
    if direction.dim() != 1 or direction.shape[0] != weight.shape[0]:
        raise ValueError(
            f"Direction of shape {tuple(direction.shape)} does not match "
            f"weight output dimension {weight.shape[0]}"
        )
    # Under device_map="auto" each layer can sit on a different device, so the
    # direction is moved to wherever this weight lives.
    v = direction.detach().to(device=weight.device, dtype=torch.float32)
    w32 = weight.data.to(torch.float32)
    w32 = w32 - torch.outer(v, v @ w32)
    weight.data.copy_(w32.to(weight.dtype))


transformer_layers = get_transformer_layers(model, num_layers)

# Abliterate the target layer and up to 4 subsequent layers (capped safely by num_layers)
with torch.no_grad():
    for layer_idx in range(target_layer_idx, min(target_layer_idx + 5, num_layers)):
        print(f" -> Abliterating Layer {layer_idx}...")
        layer = transformer_layers[layer_idx]
        _project_out_direction(layer.self_attn.o_proj.weight, refusal_vector)
        _project_out_direction(layer.mlp.down_proj.weight, refusal_vector)

# --- 6. CRYSTALLIZATION ---
print(f"\n[*] Abliteration Complete. Saving uncensored weights to {SAVE_PATH}...")
model.save_pretrained(SAVE_PATH)
tokenizer.save_pretrained(SAVE_PATH)
print("[+] SUCCESS: The 31B Teacher is ready to wake up.")