Text Generation
Transformers
Safetensors
PyTorch
nemotron_h
nvidia
elastic
conversational
custom_code
jrd971000's picture
zero_shot_slicing: add NVFP4 support
3ce4d40
raw
history blame contribute delete
16.6 kB
# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3
"""
Script to prune hidden_size and intermediate_size of a HuggingFace checkpoint.
Supports FP8, BF16, and NVFP4 precision formats:
- FP8: Weights are stored as float8_e4m3fn (one value per element, no packing).
weight_scale and input_scale are per-tensor scalars — they are NOT pruned.
- BF16: Weights are stored as bfloat16. No scale tensors are present.
- NVFP4: Two fp4 values are packed into one uint8 along the in_features dim
(compression factor 2). weight_scale tensors are one scale per 16 values
(compression factor 16) — these ARE pruned. weight_scale_2 tensors are
per-tensor scalars and are NOT pruned. Modules listed under
hf_quant_config.json `quantization.exclude_modules` are stored unpacked
and follow the same layout as BF16.
In all cases, only actual weight/bias/norm tensors are pruned along the
appropriate axes.
Predefined size presets:
- 12B: target hidden_size=1920, target intermediate_size=960
- 23B: target hidden_size=2304, target intermediate_size=1600
Usage:
python zero_shot_slicing.py \\
--source-checkpoint /path/to/source \\
--target-checkpoint /path/to/target \\
--size 23B \\
--precision nvfp4
"""
import json
import argparse
import shutil
from pathlib import Path
from collections import defaultdict
from typing import Dict, Set
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
# Named pruning targets selectable with --size. Every label maps to the pair
# (target_hidden_size, target_intermediate_size).
SIZE_PRESETS: Dict[str, tuple] = {
    "12B": (1920, 960),
    "23B": (2304, 1600),
}
def load_exclude_modules(quant_config_path: Path) -> Set[str]:
    """Load the modules excluded from quantization (NVFP4 only).

    These modules are stored in full precision and follow the BF16 layout
    (no in_features packing).

    Args:
        quant_config_path: Path to the quantization config JSON file
            (the caller passes ``hf_quant_config.json``).

    Returns:
        The entries of ``quantization.exclude_modules`` as a set. An empty set
        is returned when the file does not exist, or when either the
        ``quantization`` section or the ``exclude_modules`` list is missing or
        ``null``.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(quant_config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        return set()

    # `or` guards against explicit nulls, which `.get(key, default)` lets through.
    quantization = config.get('quantization') or {}
    return set(quantization.get('exclude_modules') or [])
def get_in_features_compression(tensor_name: str, exclude_modules: Set[str], precision: str) -> int:
    """Return the in_features (dim 1) compression factor for a tensor.

    FP8 / BF16
    ----------
    Always 1 — values are stored 1:1 with no packing.

    NVFP4
    -----
    * weight            : 2  (two fp4 values packed into one uint8)
    * weight_scale      : 16 (one scale per 16 values)
    * exclude-list mod  : 1  (stored unquantized, BF16-style layout)

    weight_scale_2 is a per-tensor scalar and is filtered out by the caller
    before this function is reached.

    Exclude-list matching
    ---------------------
    An entry matches a tensor when the tensor name equals the entry or lives
    underneath it on a dotted-path boundary, so ``layers.1`` matches
    ``layers.1.mixer.weight`` but NOT ``layers.10.mixer.weight``. Entries
    containing ``*`` or ``?`` are treated as glob patterns that must match
    either the whole tensor name or the module part of it. Empty entries are
    ignored.

    Args:
        tensor_name: Full name of the tensor in the checkpoint.
        exclude_modules: Module names / patterns excluded from quantization.
        precision: One of "fp8", "bf16", "nvfp4".

    Returns:
        The factor by which dim 1 of the stored tensor is smaller than the
        logical in_features size.
    """
    if precision != "nvfp4":
        return 1
    if 'weight_scale' in tensor_name and 'weight_scale_2' not in tensor_name:
        return 16

    # Local import keeps this block self-contained; only wildcard entries use it.
    from fnmatch import fnmatchcase

    for entry in exclude_modules:
        module = entry.rstrip('.')
        if not module:
            # An empty entry would otherwise match every tensor.
            continue
        if '*' in module or '?' in module:
            if fnmatchcase(tensor_name, module) or fnmatchcase(tensor_name, module + '.*'):
                return 1
        elif tensor_name == module or tensor_name.startswith(module + '.'):
            return 1
    return 2
def prune_single_tensor(
    tensor: torch.Tensor,
    tensor_name: str,
    original_hidden_size: int,
    target_hidden_size: int,
    target_intermediate_size: int,
    precision: str,
    exclude_modules: Set[str],
) -> torch.Tensor:
    """Shrink one checkpoint tensor to its target shape by truncating trailing
    entries along whichever dimension(s) must get smaller.

    The tensor name decides what happens:

    * Returned untouched:
        - NVFP4: names containing 'weight_scale_2' (weight_scale IS pruned,
          using factor 16).
        - FP8 / BF16: names containing '_scale'.
        - names containing '.mixer.norm.'.
        - anything that is neither 1D nor 2D.
    * 1D tensors: truncated to target_hidden_size only when their length
      equals original_hidden_size.
    * up_proj / down_proj: the hidden dim is always truncated; the
      intermediate dim is truncated too unless the name contains
      'shared_expert', in which case that dim keeps its current size.
    * o_proj / out_proj: dim 0 is truncated when it equals original_hidden_size.
    * embeddings / embed_tokens / gate / lm_head: dim 1 is truncated when it
      equals the raw (uncompressed) original_hidden_size.
    * q_proj / k_proj / v_proj / in_proj: dim 1 is truncated when it equals
      original_hidden_size divided by the in_features compression factor.

    Returns the input object itself when nothing needs pruning, otherwise a
    sliced view of it.
    """
    # Scale tensors that stay as they are depend on the precision format.
    exempt_marker = 'weight_scale_2' if precision == "nvfp4" else '_scale'
    if exempt_marker in tensor_name or '.mixer.norm.' in tensor_name:
        return tensor

    dims = tensor.shape
    factor = get_in_features_compression(tensor_name, exclude_modules, precision)
    rank = len(dims)

    # Norm weights and biases.
    if rank == 1:
        if dims[0] == original_hidden_size and original_hidden_size > target_hidden_size:
            return tensor[:target_hidden_size]
        return tensor

    # Safety net: leave scalars and higher-dimensional tensors alone.
    if rank != 2:
        return tensor

    def name_has(*fragments: str) -> bool:
        return any(fragment in tensor_name for fragment in fragments)

    # MLP projections.
    is_up = 'up_proj' in tensor_name
    if is_up or 'down_proj' in tensor_name:
        shared = 'shared_expert' in tensor_name
        if is_up:
            # Expected layout: [intermediate_size, hidden_size / factor]
            rows = dims[0] if shared else target_intermediate_size
            cols = target_hidden_size // factor
        else:
            # Expected layout: [hidden_size, intermediate_size / factor]
            rows = target_hidden_size
            cols = dims[1] if shared else target_intermediate_size // factor

        pruned = tensor
        if pruned.shape[0] > rows:
            pruned = pruned[:rows, :]
        if pruned.shape[1] > cols:
            pruned = pruned[:, :cols]
        return pruned

    # Hidden size as it appears along a packed in_features axis.
    packed_original = original_hidden_size // factor
    packed_target = target_hidden_size // factor

    # Output projections: hidden_size is on dim 0, which is never packed.
    if name_has('o_proj', 'out_proj'):
        if dims[0] == original_hidden_size and original_hidden_size > target_hidden_size:
            return tensor[:target_hidden_size, :]
        return tensor

    # Embeddings, gate and lm_head: dim 1 is compared against the raw hidden_size.
    if name_has('embeddings', 'embed_tokens', 'gate', 'lm_head'):
        if dims[1] == original_hidden_size and original_hidden_size > target_hidden_size:
            return tensor[:, :target_hidden_size]
        return tensor

    # Input projections: dim 1 is hidden_size divided by the compression factor.
    if name_has('q_proj', 'k_proj', 'v_proj', 'in_proj'):
        if dims[1] == packed_original and packed_original > packed_target:
            return tensor[:, :packed_target]
        return tensor

    return tensor
def load_sharded_checkpoint(checkpoint_dir: Path) -> Dict[str, torch.Tensor]:
    """Read every tensor of a safetensors checkpoint into one dictionary.

    If ``model.safetensors.index.json`` is present, the tensors named in its
    ``weight_map`` are read shard by shard on the CPU. Without an index, a
    single ``model.safetensors`` file is loaded instead.

    Raises:
        FileNotFoundError: if neither the index nor the single file exists, or
            if a shard referenced by the index is missing.
    """
    index_file = checkpoint_dir / "model.safetensors.index.json"

    # Unsharded fallback.
    if not index_file.exists():
        unsharded = checkpoint_dir / "model.safetensors"
        if not unsharded.exists():
            raise FileNotFoundError(f"Could not find checkpoint index or single file in {checkpoint_dir}")
        return load_file(str(unsharded))

    with open(index_file, 'r') as handle:
        index = json.load(handle)
    weight_map = index.get("weight_map", {})

    # Group tensor names per shard so each shard file is opened exactly once.
    names_by_shard = defaultdict(list)
    for name, shard in weight_map.items():
        names_by_shard[shard].append(name)

    tensors = {}
    for shard, names in names_by_shard.items():
        shard_path = checkpoint_dir / shard
        if not shard_path.exists():
            raise FileNotFoundError(f"Checkpoint file not found: {shard_path}")
        with safe_open(str(shard_path), framework="pt", device="cpu") as reader:
            for name in names:
                tensors[name] = reader.get_tensor(name)
    return tensors
def save_sharded_checkpoint(
    state_dict: Dict[str, torch.Tensor],
    checkpoint_dir: Path,
    original_index_path: Path
):
    """Save a safetensors checkpoint, mirroring the source's file layout.

    * Source had an index (``original_index_path`` exists): every tensor is
      written to the shard the original ``weight_map`` assigned it to. Tensors
      missing from the map go to the first shard listed there (or to
      ``model-00001-of-00001.safetensors`` if the map is empty). A fresh
      ``model.safetensors.index.json`` is written with the new ``total_size``.
    * Source had no index (single-file checkpoint): all tensors are written to
      ``model.safetensors`` and no index file is produced.

    Args:
        state_dict: Tensor name -> tensor to save.
        checkpoint_dir: Output directory; created if it does not exist.
        original_index_path: Path of the source checkpoint's
            ``model.safetensors.index.json`` (may not exist).
    """
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    if original_index_path.exists():
        with open(original_index_path, 'r', encoding='utf-8') as f:
            original_weight_map = json.load(f).get("weight_map", {})
        write_index = True
        # Computed once instead of rebuilding a list of all values per tensor.
        default_shard = next(iter(original_weight_map.values()),
                             "model-00001-of-00001.safetensors")
    else:
        # The loader accepts a single-file checkpoint; keep that layout here
        # rather than failing on the missing index.
        original_weight_map = {}
        write_index = False
        default_shard = "model.safetensors"

    files_to_tensors = defaultdict(list)
    for tensor_name in state_dict:
        filename = original_weight_map.get(tensor_name, default_shard)
        files_to_tensors[filename].append(tensor_name)

    new_weight_map = {}
    total_size = 0
    for filename, tensor_names in files_to_tensors.items():
        # .contiguous() is a no-op for tensors that are already contiguous.
        file_dict = {name: state_dict[name].contiguous() for name in tensor_names}
        # Tag the file as a PyTorch checkpoint; some loaders inspect this
        # metadata entry before reading the tensors.
        save_file(file_dict, str(checkpoint_dir / filename), metadata={"format": "pt"})
        for tensor_name in tensor_names:
            new_weight_map[tensor_name] = filename
        total_size += sum(t.numel() * t.element_size() for t in file_dict.values())

    if write_index:
        with open(checkpoint_dir / "model.safetensors.index.json", 'w', encoding='utf-8') as f:
            json.dump({"metadata": {"total_size": total_size}, "weight_map": new_weight_map}, f, indent=4)
def prune_checkpoint(
    source_checkpoint: Path,
    target_checkpoint: Path,
    original_hidden_size: int,
    target_hidden_size: int,
    original_intermediate_size: int,
    target_intermediate_size: int,
    precision: str,
    exclude_modules: Set[str],
):
    """Prune a whole checkpoint and write the result to a new directory.

    Steps, in order:
      1. Load all tensors from ``source_checkpoint``.
      2. Run every tensor through ``prune_single_tensor`` and keep a
         contiguous copy of each one whose shape changed.
      3. Write ``config.json`` into ``target_checkpoint`` with hidden_size,
         intermediate_size and moe_intermediate_size set to the targets
         (only if the source has a config.json).
      4. Copy every other regular file of the source directory, except
         ``*.safetensors`` files and the safetensors index.
      5. Save the pruned tensors using the source's shard assignment.

    ``original_intermediate_size`` is only reported in the log output.
    """
    print(f"Precision: {precision.upper()}")
    if precision == "nvfp4" and exclude_modules:
        print(f"Loaded {len(exclude_modules)} exclude modules from hf_quant_config.json")

    print(f"Loading checkpoint from {source_checkpoint}...")
    state_dict = load_sharded_checkpoint(source_checkpoint)
    print(f"Loaded {len(state_dict)} tensors")

    print(f"Pruning hidden_size: {original_hidden_size} -> {target_hidden_size}")
    print(f"Pruning intermediate_size: {original_intermediate_size} -> {target_intermediate_size}")

    num_pruned = 0
    num_scales_kept = 0
    # Walk a snapshot of the names so entries can be replaced while looping.
    for name in list(state_dict):
        before = state_dict[name]
        after = prune_single_tensor(
            before,
            name,
            original_hidden_size,
            target_hidden_size,
            target_intermediate_size,
            precision,
            exclude_modules,
        )
        if after.shape == before.shape:
            if '_scale' in name:
                num_scales_kept += 1
            continue
        num_pruned += 1
        state_dict[name] = after.contiguous()

    print(f"\nPruned {num_pruned} tensors")
    if precision == "nvfp4":
        print(f"Left {num_scales_kept} weight_scale_2 tensors unchanged")
    elif precision == "fp8":
        print(f"Left {num_scales_kept} scale tensors unchanged")
    elif num_scales_kept > 0:
        print(f"Note: {num_scales_kept} scale tensors found (unexpected for {precision.upper()})")

    target_checkpoint.mkdir(parents=True, exist_ok=True)

    # Write the updated config first; the copy loop below skips config.json so
    # the new values are not overwritten.
    source_config = source_checkpoint / "config.json"
    if source_config.exists():
        with open(source_config, 'r') as src:
            config = json.load(src)
        config.update({
            'hidden_size': target_hidden_size,
            'intermediate_size': target_intermediate_size,
            'moe_intermediate_size': target_intermediate_size,
        })
        with open(target_checkpoint / "config.json", 'w') as dst:
            json.dump(config, dst, indent=2)
        print(f"Updated config.json: hidden_size={target_hidden_size}, intermediate_size={target_intermediate_size}, moe_intermediate_size={target_intermediate_size}")

    # Carry over tokenizer files, Python modules, etc. Weights, their index and
    # config.json are produced by this function and therefore not copied.
    regenerated = ('model.safetensors.index.json', 'config.json')
    for entry in source_checkpoint.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix == '.safetensors' or entry.name in regenerated:
            continue
        shutil.copy2(entry, target_checkpoint / entry.name)

    print(f"Saving pruned checkpoint to {target_checkpoint}...")
    save_sharded_checkpoint(state_dict, target_checkpoint, source_checkpoint / "model.safetensors.index.json")
    print("Done!")
def main():
    """Command-line entry point.

    Parses the arguments, resolves the target sizes from ``SIZE_PRESETS``,
    determines the original sizes (command-line values take precedence over
    the source ``config.json``), loads the NVFP4 exclude list when needed and
    runs ``prune_checkpoint``.

    Raises:
        ValueError: if the source checkpoint does not exist; if config.json is
            missing and the original sizes were not given; if the original
            hidden_size cannot be determined; or if the preset's hidden_size
            is larger than the original one.
    """
    parser = argparse.ArgumentParser(
        description="Prune HF checkpoint hidden_size and intermediate_size (supports FP8, BF16, NVFP4)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=f"Size presets: {', '.join(f'{k} (hidden={v[0]}, intermediate={v[1]})' for k, v in SIZE_PRESETS.items())}",
    )
    parser.add_argument("--source-checkpoint", type=str, required=True,
                        help="Path to source checkpoint directory")
    parser.add_argument("--target-checkpoint", type=str, required=True,
                        help="Path to target checkpoint directory")
    parser.add_argument("--precision", type=str, required=True, choices=["fp8", "bf16", "nvfp4"],
                        help="Checkpoint precision format: fp8, bf16, or nvfp4")
    parser.add_argument("--size", type=str, required=True, choices=list(SIZE_PRESETS.keys()),
                        help="Model size preset (sets target hidden/intermediate sizes)")
    parser.add_argument("--original-hidden-size", type=int, default=None,
                        help="Original hidden_size (auto-detected from config if not provided)")
    parser.add_argument("--original-intermediate-size", type=int, default=None,
                        help="Original intermediate_size (auto-detected from config if not provided)")
    args = parser.parse_args()

    target_hidden_size, target_intermediate_size = SIZE_PRESETS[args.size]
    source_checkpoint = Path(args.source_checkpoint)
    target_checkpoint = Path(args.target_checkpoint)
    if not source_checkpoint.exists():
        raise ValueError(f"Source checkpoint does not exist: {source_checkpoint}")

    # Load config to get original sizes
    config_path = source_checkpoint / "config.json"
    if config_path.exists():
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        original_hidden_size = args.original_hidden_size or config.get('hidden_size')
        original_intermediate_size = args.original_intermediate_size or config.get('intermediate_size')
    else:
        if args.original_hidden_size is None or args.original_intermediate_size is None:
            raise ValueError("config.json not found and original sizes not provided")
        original_hidden_size = args.original_hidden_size
        original_intermediate_size = args.original_intermediate_size

    # Fail early with a clear message instead of a TypeError deep inside the
    # pruning arithmetic when config.json carries no hidden_size.
    if original_hidden_size is None:
        raise ValueError(
            "Could not determine the original hidden_size: it is missing from "
            "config.json; pass --original-hidden-size"
        )
    # Pruning only ever shrinks tensors. A larger target would leave the
    # weights untouched while config.json is rewritten with the target size,
    # producing a checkpoint whose config does not match its tensors.
    if target_hidden_size > original_hidden_size:
        raise ValueError(
            f"Target hidden_size {target_hidden_size} (preset {args.size}) is larger than "
            f"the original hidden_size {original_hidden_size}; this script can only shrink"
        )
    # Intermediate size is only warned about: the value read from config.json
    # is informational and may not describe every MLP in the checkpoint.
    if original_intermediate_size is not None and target_intermediate_size > original_intermediate_size:
        print(
            f"Warning: target intermediate_size {target_intermediate_size} is larger than the "
            f"original intermediate_size {original_intermediate_size}; "
            "the written config.json may not match the tensors"
        )

    # NVFP4 needs the exclude-modules list to know which tensors are stored unpacked.
    exclude_modules: Set[str] = set()
    if args.precision == "nvfp4":
        exclude_modules = load_exclude_modules(source_checkpoint / "hf_quant_config.json")

    prune_checkpoint(
        source_checkpoint,
        target_checkpoint,
        original_hidden_size,
        target_hidden_size,
        original_intermediate_size,
        target_intermediate_size,
        args.precision,
        exclude_modules,
    )


if __name__ == "__main__":
    main()