#!/usr/bin/env python3
# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to prune hidden_size and intermediate_size of a HuggingFace checkpoint.

Supports FP8, BF16, and NVFP4 precision formats:

- FP8:   Weights are stored as float8_e4m3fn (one value per element, no packing).
         weight_scale and input_scale are per-tensor scalars — they are NOT pruned.
- BF16:  Weights are stored as bfloat16. No scale tensors are present.
- NVFP4: Two fp4 values are packed into one uint8 along the in_features dim
         (compression factor 2). weight_scale tensors are one scale per 16 values
         (compression factor 16) — these ARE pruned. weight_scale_2 tensors are
         per-tensor scalars and are NOT pruned. Modules listed under
         hf_quant_config.json `quantization.exclude_modules` are stored unpacked
         and follow the same layout as BF16.

In all cases, only actual weight/bias/norm tensors are pruned along the
appropriate axes.

Predefined size presets:
  - 12B: target hidden_size=1920, target intermediate_size=960
  - 23B: target hidden_size=2304, target intermediate_size=1600

Usage:
    python zero_shot_slicing.py \\
        --source-checkpoint /path/to/source \\
        --target-checkpoint /path/to/target \\
        --size 23B \\
        --precision nvfp4
"""

import argparse
import fnmatch
import json
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Dict, Set

import torch

# NOTE: safetensors is imported inside load_sharded_checkpoint() and
# save_sharded_checkpoint() so that the pure pruning helpers in this module can
# be imported (and unit-tested) without safetensors being installed.

# Predefined size presets: maps size label -> (target_hidden_size, target_intermediate_size)
SIZE_PRESETS = {
    "12B": (1920, 960),
    "23B": (2304, 1600),
}


def load_exclude_modules(quant_config_path: Path) -> Set[str]:
    """Load modules excluded from quantization (NVFP4 only).

    These modules are stored in full precision and follow the BF16 layout
    (no in_features packing).

    Args:
        quant_config_path: Path to ``hf_quant_config.json``.

    Returns:
        The entries of ``quantization.exclude_modules``; an empty set when the
        file or the key is missing.
    """
    if not quant_config_path.exists():
        return set()
    with open(quant_config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return set(config.get('quantization', {}).get('exclude_modules', []))


def _is_excluded(tensor_name: str, exclude_modules: Set[str]) -> bool:
    """Return True when *tensor_name* belongs to a module on the exclude list.

    An entry matches when it is a literal prefix of the tensor name, or when it
    is a glob pattern (``*``, ``?``, ``[...]``) that matches either the full
    tensor name or the owning module name (tensor name minus its last
    dot-separated component, e.g. ``.weight``).
    """
    module_name = tensor_name.rpartition('.')[0]
    for pattern in exclude_modules:
        if tensor_name.startswith(pattern):
            return True
        # fnmatchcase keeps matching case-sensitive on every platform.
        if fnmatch.fnmatchcase(tensor_name, pattern) or fnmatch.fnmatchcase(module_name, pattern):
            return True
    return False


def get_in_features_compression(tensor_name: str, exclude_modules: Set[str], precision: str) -> int:
    """Return the in_features (dim 1) compression factor.

    FP8 / BF16
    ----------
    Always 1 — values are stored 1:1 with no packing.

    NVFP4
    -----
    * weight           : 2  (two fp4 values packed into one uint8)
    * weight_scale     : 16 (one scale per 16 values)
    * exclude-list mod : 1  (stored unquantized, BF16-style layout)

    weight_scale_2 is a per-tensor scalar and is filtered out by the caller
    before this function is reached.

    Exclude-list entries are matched as literal prefixes and as glob patterns
    (see ``_is_excluded``).
    """
    if precision != "nvfp4":
        return 1
    # The scale check runs before the exclude-list check, as in the original order.
    if 'weight_scale' in tensor_name and 'weight_scale_2' not in tensor_name:
        return 16
    if _is_excluded(tensor_name, exclude_modules):
        return 1
    return 2


def prune_single_tensor(
    tensor: torch.Tensor,
    tensor_name: str,
    original_hidden_size: int,
    target_hidden_size: int,
    target_intermediate_size: int,
    precision: str,
    exclude_modules: Set[str],
) -> torch.Tensor:
    """Determine the correct target shape for *tensor_name* and prune by
    dropping the last elements along each dimension that needs shrinking.

    Rules
    -----
    * Tensors that are never touched:
        - FP8 / BF16: any tensor whose name contains '_scale' (per-tensor scalars).
        - NVFP4: weight_scale_2 only — weight_scale tensors ARE pruned (factor 16).
        - mixer.norm weights (Mamba internal norm, size != hidden_size).
    * For 1D tensors (layer-norms, biases): only prune if size == original_hidden_size.
    * For up_proj / down_proj (MLP):
        - shared_experts: only the hidden_size dim is pruned (their intermediate
          size is independent and must not be changed).
        - regular experts: both hidden_size and intermediate_size dims are pruned.
    * For all other 2D tensors the name determines which dim is hidden_size;
      the NVFP4 compression factor is applied to dim 1 (in_features) when
      applicable.

    Returns:
        The input tensor itself when nothing needs pruning, otherwise a sliced
        view of it (callers make it contiguous before saving).
    """
    # ---- tensors we never touch ----
    if precision == "nvfp4":
        if 'weight_scale_2' in tensor_name:
            return tensor
    else:
        if '_scale' in tensor_name:
            return tensor
    if '.mixer.norm.' in tensor_name:
        return tensor

    shape = tensor.shape
    compression = get_in_features_compression(tensor_name, exclude_modules, precision)

    # ---- 1D (norms / biases) ----
    # NOTE(review): any 1D tensor whose length happens to equal
    # original_hidden_size is pruned, whatever it represents.
    if len(shape) == 1:
        if shape[0] == original_hidden_size and original_hidden_size > target_hidden_size:
            return tensor[:target_hidden_size]
        return tensor

    if len(shape) != 2:
        return tensor  # safety: skip anything higher-dimensional

    # ---- MLP up_proj / down_proj ----
    if 'up_proj' in tensor_name or 'down_proj' in tensor_name:
        is_shared = 'shared_expert' in tensor_name
        if 'up_proj' in tensor_name:
            # Layout: [intermediate_size, hidden_size / compression]
            target_dim0 = shape[0] if is_shared else target_intermediate_size
            target_dim1 = target_hidden_size // compression
        else:  # down_proj
            # Layout: [hidden_size, intermediate_size / compression]
            target_dim0 = target_hidden_size
            target_dim1 = shape[1] if is_shared else target_intermediate_size // compression

        # Only ever shrink: a dimension already at or below target is left alone.
        result = tensor
        if result.shape[0] > target_dim0:
            result = result[:target_dim0, :]
        if result.shape[1] > target_dim1:
            result = result[:, :target_dim1]
        return result

    # ---- remaining 2D tensors: figure out which dim(s) carry hidden_size ----
    original_compressed = original_hidden_size // compression
    target_compressed = target_hidden_size // compression

    # dim 0 == hidden_size (out_features, never compressed): o_proj, out_proj
    if any(p in tensor_name for p in ['o_proj', 'out_proj']):
        result = tensor
        if shape[0] == original_hidden_size and original_hidden_size > target_hidden_size:
            result = result[:target_hidden_size, :]
        return result

    # embeddings, gate, lm_head — dim 1 is raw hidden_size (never NVFP4-quantized)
    if any(p in tensor_name for p in ['embeddings', 'embed_tokens', 'gate', 'lm_head']):
        if shape[1] == original_hidden_size and original_hidden_size > target_hidden_size:
            return tensor[:, :target_hidden_size]
        return tensor

    # dim 1 == hidden_size / compression (in_features): q_proj, k_proj, v_proj, in_proj
    if any(p in tensor_name for p in ['q_proj', 'k_proj', 'v_proj', 'in_proj']):
        if shape[1] == original_compressed and original_compressed > target_compressed:
            return tensor[:, :target_compressed]
        return tensor

    return tensor


def load_sharded_checkpoint(checkpoint_dir: Path) -> Dict[str, torch.Tensor]:
    """Load all tensors from a sharded safetensors checkpoint.

    Reads ``model.safetensors.index.json`` and loads every tensor listed in its
    ``weight_map`` onto the CPU. When no index exists, falls back to a single
    ``model.safetensors`` file.

    Raises:
        FileNotFoundError: If neither the index nor the single file exists, or
            if a shard named in the index is missing.
    """
    from safetensors import safe_open
    from safetensors.torch import load_file

    index_path = checkpoint_dir / "model.safetensors.index.json"

    if not index_path.exists():
        single_file = checkpoint_dir / "model.safetensors"
        if single_file.exists():
            return load_file(str(single_file))
        raise FileNotFoundError(f"Could not find checkpoint index or single file in {checkpoint_dir}")

    with open(index_path, 'r', encoding='utf-8') as f:
        weight_map = json.load(f).get("weight_map", {})

    # Group tensor names by shard so each shard file is opened exactly once.
    files_to_tensors = defaultdict(list)
    for tensor_name, filename in weight_map.items():
        files_to_tensors[filename].append(tensor_name)

    state_dict = {}
    for filename, tensor_names in files_to_tensors.items():
        file_path = checkpoint_dir / filename
        if not file_path.exists():
            raise FileNotFoundError(f"Checkpoint file not found: {file_path}")
        with safe_open(str(file_path), framework="pt", device="cpu") as f:
            for tensor_name in tensor_names:
                state_dict[tensor_name] = f.get_tensor(tensor_name)

    return state_dict


def save_sharded_checkpoint(
    state_dict: Dict[str, torch.Tensor],
    checkpoint_dir: Path,
    original_index_path: Path
):
    """Save sharded safetensors checkpoint, preserving original sharding structure.

    Each tensor is written to the shard file named for it in the original
    index; tensors missing from that index go to the index's first shard. A new
    ``model.safetensors.index.json`` with a recomputed ``total_size`` is
    written next to the shards.

    When *original_index_path* does not exist (single-file source checkpoint,
    which ``load_sharded_checkpoint`` accepts), everything is written to one
    ``model.safetensors`` and no index file is produced.
    """
    from safetensors.torch import save_file

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    has_index = original_index_path.exists()
    if has_index:
        with open(original_index_path, 'r', encoding='utf-8') as f:
            original_weight_map = json.load(f).get("weight_map", {})
        default_filename = next(iter(original_weight_map.values()), "model-00001-of-00001.safetensors")
    else:
        original_weight_map = {}
        default_filename = "model.safetensors"

    files_to_tensors = defaultdict(list)
    for tensor_name in state_dict:
        files_to_tensors[original_weight_map.get(tensor_name, default_filename)].append(tensor_name)

    new_weight_map = {}
    total_size = 0

    for filename, tensor_names in files_to_tensors.items():
        # contiguous() returns the tensor itself when it is already contiguous.
        file_dict = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(file_dict, str(checkpoint_dir / filename))
        for tensor_name in tensor_names:
            new_weight_map[tensor_name] = filename
        total_size += sum(t.numel() * t.element_size() for t in file_dict.values())

    if has_index:
        with open(checkpoint_dir / "model.safetensors.index.json", 'w', encoding='utf-8') as f:
            json.dump({"metadata": {"total_size": total_size}, "weight_map": new_weight_map}, f, indent=4)


def prune_checkpoint(
    source_checkpoint: Path,
    target_checkpoint: Path,
    original_hidden_size: int,
    target_hidden_size: int,
    original_intermediate_size: int,
    target_intermediate_size: int,
    precision: str,
    exclude_modules: Set[str],
):
    """Main function to prune the checkpoint.

    Loads every tensor from *source_checkpoint*, prunes each with
    ``prune_single_tensor``, writes an updated ``config.json``, copies all other
    non-model files, and saves the pruned tensors to *target_checkpoint*.

    *original_intermediate_size* is only used for the progress message.
    """
    print(f"Precision: {precision.upper()}")

    if precision == "nvfp4" and exclude_modules:
        print(f"Loaded {len(exclude_modules)} exclude modules from hf_quant_config.json")

    print(f"Loading checkpoint from {source_checkpoint}...")
    state_dict = load_sharded_checkpoint(source_checkpoint)
    print(f"Loaded {len(state_dict)} tensors")

    print(f"Pruning hidden_size: {original_hidden_size} -> {target_hidden_size}")
    print(f"Pruning intermediate_size: {original_intermediate_size} -> {target_intermediate_size}")

    pruned_count = 0
    unchanged_scales = 0  # every unchanged tensor whose name contains '_scale'

    # Only existing keys are reassigned, so the dict size never changes while iterating.
    for tensor_name, tensor in state_dict.items():
        original_shape = tensor.shape
        pruned_tensor = prune_single_tensor(
            tensor, tensor_name,
            original_hidden_size, target_hidden_size,
            target_intermediate_size,
            precision, exclude_modules,
        )
        if pruned_tensor.shape != original_shape:
            pruned_count += 1
            state_dict[tensor_name] = pruned_tensor.contiguous()
        elif '_scale' in tensor_name:
            unchanged_scales += 1

    print(f"\nPruned {pruned_count} tensors")
    if precision == "nvfp4":
        print(f"Left {unchanged_scales} weight_scale_2 tensors unchanged")
    elif precision == "fp8":
        print(f"Left {unchanged_scales} scale tensors unchanged")
    elif unchanged_scales > 0:
        print(f"Note: {unchanged_scales} scale tensors found (unexpected for {precision.upper()})")

    target_checkpoint.mkdir(parents=True, exist_ok=True)

    # Update config.json with new sizes (do this before copying to avoid overwriting)
    config_path = source_checkpoint / "config.json"
    if config_path.exists():
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        config['hidden_size'] = target_hidden_size
        config['intermediate_size'] = target_intermediate_size
        config['moe_intermediate_size'] = target_intermediate_size
        with open(target_checkpoint / "config.json", 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print(f"Updated config.json: hidden_size={target_hidden_size}, intermediate_size={target_intermediate_size}, moe_intermediate_size={target_intermediate_size}")

    # Copy all non-model files (everything except safetensors, index, and config.json)
    skipped_names = {'model.safetensors.index.json', 'config.json'}
    for item in source_checkpoint.iterdir():
        if not item.is_file():
            continue
        # Skip model safetensors files, index, and config.json (we already updated it)
        if item.suffix == '.safetensors' or item.name in skipped_names:
            continue
        # Copy everything else (tokenizer, Python modules, etc.)
        shutil.copy2(item, target_checkpoint / item.name)

    print(f"Saving pruned checkpoint to {target_checkpoint}...")
    save_sharded_checkpoint(state_dict, target_checkpoint, source_checkpoint / "model.safetensors.index.json")
    print("Done!")


def main():
    """Parse command-line arguments, resolve the original sizes, and run the pruning.

    Raises:
        ValueError: If the source checkpoint is missing, source and target are
            the same directory, the original hidden size cannot be determined,
            or the preset's hidden size exceeds the original hidden size.
    """
    preset_summary = ', '.join(
        f'{k} (hidden={v[0]}, intermediate={v[1]})' for k, v in SIZE_PRESETS.items()
    )
    parser = argparse.ArgumentParser(
        description="Prune HF checkpoint hidden_size and intermediate_size (supports FP8, BF16, NVFP4)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=f"Size presets: {preset_summary}",
    )
    parser.add_argument("--source-checkpoint", type=str, required=True,
                        help="Path to source checkpoint directory")
    parser.add_argument("--target-checkpoint", type=str, required=True,
                        help="Path to target checkpoint directory")
    parser.add_argument("--precision", type=str, required=True,
                        choices=["fp8", "bf16", "nvfp4"],
                        help="Checkpoint precision format: fp8, bf16, or nvfp4")
    parser.add_argument("--size", type=str, required=True,
                        choices=list(SIZE_PRESETS.keys()),
                        help="Model size preset (sets target hidden/intermediate sizes)")
    parser.add_argument("--original-hidden-size", type=int, default=None,
                        help="Original hidden_size (auto-detected from config if not provided)")
    parser.add_argument("--original-intermediate-size", type=int, default=None,
                        help="Original intermediate_size (auto-detected from config if not provided)")

    args = parser.parse_args()

    target_hidden_size, target_intermediate_size = SIZE_PRESETS[args.size]

    source_checkpoint = Path(args.source_checkpoint)
    target_checkpoint = Path(args.target_checkpoint)

    if not source_checkpoint.exists():
        raise ValueError(f"Source checkpoint does not exist: {source_checkpoint}")

    # Writing into the source directory would overwrite config.json and then
    # fail while copying files onto themselves.
    if source_checkpoint.resolve() == target_checkpoint.resolve():
        raise ValueError(f"Target checkpoint must differ from source checkpoint: {source_checkpoint}")

    # Load config to get original sizes
    config_path = source_checkpoint / "config.json"
    if config_path.exists():
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        original_hidden_size = args.original_hidden_size or config.get('hidden_size')
        original_intermediate_size = args.original_intermediate_size or config.get('intermediate_size')
    else:
        if args.original_hidden_size is None or args.original_intermediate_size is None:
            raise ValueError("config.json not found and original sizes not provided")
        original_hidden_size = args.original_hidden_size
        original_intermediate_size = args.original_intermediate_size

    # hidden_size drives every pruning decision; fail early with a clear message
    # instead of a TypeError deep inside prune_single_tensor.
    if original_hidden_size is None:
        raise ValueError("hidden_size not found in config.json and --original-hidden-size not provided")

    # Pruning only shrinks tensors; a larger target would leave the weights
    # untouched while config.json is rewritten with the (wrong) target size.
    if target_hidden_size > original_hidden_size:
        raise ValueError(
            f"Target hidden_size ({target_hidden_size}) exceeds original hidden_size ({original_hidden_size})"
        )

    # NVFP4 needs the exclude-modules list to know which tensors are stored unpacked.
    exclude_modules: Set[str] = set()
    if args.precision == "nvfp4":
        exclude_modules = load_exclude_modules(source_checkpoint / "hf_quant_config.json")

    prune_checkpoint(
        source_checkpoint, target_checkpoint,
        original_hidden_size, target_hidden_size,
        original_intermediate_size, target_intermediate_size,
        args.precision, exclude_modules,
    )


if __name__ == "__main__":
    main()