# Gemma 4 requires transformers >= 5.5.0 (model_type: gemma4).
# If your llmcompressor pins an older version, install with:
#   pip install llmcompressor
#   pip install "transformers>=5.5.0"
# (The version specifier is quoted so the shell does not treat ">" as a
# redirection.)

import os
from typing import Any, Dict, List, Tuple

from compressed_tensors.offload import dispatch_model
from datasets import concatenate_datasets, load_dataset
from transformers import AutoModelForImageTextToText, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.gptq import GPTQModifier

# (dataset config name, display name) pairs. Only the config name is used
# below (it is passed to ``load_dataset``); the display name is informational.
LANGUAGE_CONFIGS: List[Tuple[str, str]] = [
    ("go", "Go"),
    ("java", "Java"),
    ("javascript", "Javascript"),
    ("php", "PHP"),
    ("python", "Python"),
    ("ruby", "Ruby"),
]

# Columns that must hold a non-blank string for an example to be kept.
REQUIRED_TEXT_FIELDS: Tuple[str, ...] = (
    "language",
    "func_name",
    "func_documentation_string",
    "func_code_string",
)


def compute_language_sample_counts(total_samples: int, n_languages: int) -> List[int]:
    """Split ``total_samples`` as evenly as possible across ``n_languages``.

    The first ``total_samples % n_languages`` entries receive one extra
    sample, so the returned counts always sum to ``total_samples``.

    Args:
        total_samples: Total number of samples to distribute (must be >= 0).
        n_languages: Number of buckets to distribute over (must be > 0).

    Returns:
        A list of ``n_languages`` per-language sample counts.

    Raises:
        ValueError: If ``n_languages`` is not positive or ``total_samples``
            is negative.
    """
    # Fail loudly instead of surfacing a ZeroDivisionError or silently
    # producing negative counts (which would select an empty range later).
    if n_languages <= 0:
        raise ValueError(f"n_languages must be positive, got {n_languages}")
    if total_samples < 0:
        raise ValueError(f"total_samples must be non-negative, got {total_samples}")

    base, remainder = divmod(total_samples, n_languages)
    return [base + (1 if i < remainder else 0) for i in range(n_languages)]


def is_nonempty_string(value: Any) -> bool:
    """Return True when ``value`` is a ``str`` with non-whitespace content."""
    return isinstance(value, str) and bool(value.strip())


def has_required_code_fields(example: Dict[str, Any]) -> bool:
    """Return True when every field in ``REQUIRED_TEXT_FIELDS`` is a non-blank string."""
    return all(is_nonempty_string(example.get(field)) for field in REQUIRED_TEXT_FIELDS)


def build_text_for_code_example(example: Dict[str, Any]) -> str:
    """Render an example as a fenced code block tagged with its language.

    Reads the ``"language"`` and ``"func_code_string"`` keys; both are
    stripped of surrounding whitespace before formatting.
    """
    language = example["language"].strip()
    code = example["func_code_string"].strip()
    return f"```{language}\n{code}\n```"


def load_mixed_codesearchnet_dataset(
    dataset_name: str,
    split: str,
    total_samples: int,
    seed: int,
    workers: int,
):
    """Build a shuffled calibration set balanced across ``LANGUAGE_CONFIGS``.

    For each language config the split is loaded, filtered with
    ``has_required_code_fields``, shuffled with a per-language seed, and
    truncated to that language's share of ``total_samples``. The per-language
    subsets are then concatenated and shuffled once more.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Split name passed to ``load_dataset``.
        total_samples: Total number of samples across all languages.
        seed: Base shuffle seed; language ``idx`` uses ``seed + idx`` and the
            final mix uses ``seed``.
        workers: Requested worker count for filtering; clamped to the range
            ``[1, os.cpu_count()]`` before being passed as ``num_proc``.

    Returns:
        The concatenated, shuffled dataset.

    Raises:
        RuntimeError: If a language has fewer valid examples than required.
    """
    counts = compute_language_sample_counts(total_samples, len(LANGUAGE_CONFIGS))
    filter_workers = max(1, min(workers, os.cpu_count() or 1))

    subsets = []
    for idx, ((cfg, _display_name), target_count) in enumerate(
        zip(LANGUAGE_CONFIGS, counts)
    ):
        ds = load_dataset(dataset_name, cfg, split=split)
        ds = ds.filter(has_required_code_fields, num_proc=filter_workers)

        if len(ds) < target_count:
            raise RuntimeError(
                f"Not enough valid samples for language={cfg}. "
                f"required={target_count}, valid_available={len(ds)}"
            )

        # A distinct seed per language keeps the subsets independently shuffled
        # while remaining reproducible.
        ds = ds.shuffle(seed=seed + idx).select(range(target_count))
        subsets.append(ds)

    return concatenate_datasets(subsets).shuffle(seed=seed)


# Load model.
MODEL_ID = "google/gemma-4-31B-it"
DATASET_ID = "code_search_net"
DATASET_SPLIT = "train"
NUM_CALIBRATION_SAMPLES = 258  # 43 per language across the 6 configs above
MAX_SEQUENCE_LENGTH = 2048
DATASET_SEED = 42
DATASET_WORKERS = max(1, os.cpu_count() or 1)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Calibration is text-only, so the processor's underlying tokenizer is used.
tokenizer = getattr(processor, "tokenizer", None)
if tokenizer is None:
    raise RuntimeError("AutoProcessor does not expose tokenizer for calibration preprocessing.")
if tokenizer.pad_token is None:
    # Fall back to EOS as the padding token when none is configured.
    tokenizer.pad_token = tokenizer.eos_token

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp4 with per group 16 via GPTQ
#   * skip lm_head and every module whose name matches the vision tower or
#     the vision embedding projection patterns below
# NOTE(review): no audio-encoder pattern is listed in ``ignore``; if the
# checkpoint contains audio modules that must stay unquantized, add a pattern
# for them here -- confirm against the model's module names.
recipe = GPTQModifier(
    targets="Linear",
    scheme="NVFP4A16",
    ignore=[
        "lm_head",
        "re:.*vision_tower.*",
        "re:.*embed_vision.*",
    ],
)

raw_ds = load_mixed_codesearchnet_dataset(
    dataset_name=DATASET_ID,
    split=DATASET_SPLIT,
    total_samples=NUM_CALIBRATION_SAMPLES,
    seed=DATASET_SEED,
    workers=DATASET_WORKERS,
)


def preprocess_function(example: Dict[str, Any]) -> Dict[str, str]:
    """Map a raw example to a single ``"text"`` column for calibration."""
    return {"text": build_text_for_code_example(example)}


# Replace all original columns with the single rendered "text" column.
calib_ds = raw_ds.map(
    preprocess_function,
    batched=False,
    remove_columns=raw_ds.column_names,
)
calib_ds = calib_ds.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)

# Sanity check: preprocessing must not have dropped any sample.
if len(calib_ds) != NUM_CALIBRATION_SAMPLES:
    raise RuntimeError(
        "Calibration dataset size mismatch after preprocessing: "
        f"expected={NUM_CALIBRATION_SAMPLES}, actual={len(calib_ds)}"
    )

# Apply quantization.
oneshot(
    model=model,
    tokenizer=tokenizer,
    recipe=recipe,
    dataset=calib_ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

print("\n\n========== SAMPLE GENERATION ==============")
# NOTE(review): presumably places the quantized model on the available
# device(s) before generation -- confirm against compressed_tensors docs.
dispatch_model(model)
messages = [
    {"role": "user", "content": "Hello my name is"},
]
text = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
inputs = processor(text=text, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================\n\n")

# Save to disk in compressed-tensors format.
# The output directory is the last path component of MODEL_ID plus a suffix.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4A16-GPTQ"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)