PEFT
Safetensors
English
Chinese
lora
distillation
svd
cross-architecture
adaptive-rank
gemma
llama
nemotron
Instructions to use win10/Nemotron2Gemma-AURORA-LoRA-27B-IT-0p95 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use win10/Nemotron2Gemma-AURORA-LoRA-27B-IT-0p95 with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Raw string: the Windows path contains "\t" and "\u", which are escape sequences in a normal literal.
base_model = AutoModelForCausalLM.from_pretrained(r"E:\text-generation-webui-1.14\user_data\models\google-gemma-3-27b-it-text")
model = PeftModel.from_pretrained(base_model, "win10/Nemotron2Gemma-AURORA-LoRA-27B-IT-0p95")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Universal SVD-LoRA Distillation v4.2.0 (Cross-Arch Innovations) — Plug-and-Play Truncated SVD (Low-VRAM) | |
| Key changes vs v3.7.16 (SVD-Optimized): | |
| ---------------------- | |
| 0. Adds AURORA-SVD (Adaptive Unrolled Residual-Order Range Augmentation SVD) as a distillation-centric low-rank backend. | |
| 0. Fixes DRF-SVD backend definition/indentation so --svd-mode drf works as intended. | |
| 0. Randomized SVD now uses QR-normalized subspace iteration (Halko et al., Alg. 4.4) for better numerical stability. | |
| 0. Adaptive rank selection is now measured against ||Δ||_F^2 and grows rank geometrically instead of always computing SVD at max_rank. | |
| 1. Adds SVD-LLM style full-covariance whitening (Cholesky) as a calibration mode ("cov") | |
| to replace diagonal-only RMS scaling when requested. | |
| 2. Restores v3.7.15-stable TIES merge semantics by default and applies the TIES Trim | |
| step to single-delta cases by default (use --no-ties-trim-single to disable). | |
| 3. Preserves the low-VRAM streaming/offload design (cpu/disk offload + prefetch). | |
| Output: | |
| ------- | |
| Writes a PEFT LoRA adapter (adapter_model.safetensors + adapter_config.json). | |
| License: Apache 2.0 | |
| """ | |
| import os | |
| import re | |
| import json | |
| import math | |
| import time | |
| import gc | |
| import argparse | |
| import warnings | |
| import threading | |
| import queue | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Tuple, Any, Set, Union, Iterator | |
| from enum import Enum | |
| from contextlib import contextmanager | |
| from pathlib import Path | |
| import tempfile | |
| import shutil | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from tqdm.auto import tqdm | |
| try: | |
| from safetensors.torch import load_file, save_file, safe_open | |
| except ImportError: | |
| raise RuntimeError("safetensors is required. pip install safetensors") | |
| try: | |
| from transformers import AutoConfig | |
| HAS_TRANSFORMERS = True | |
| except ImportError: | |
| HAS_TRANSFORMERS = False | |
| warnings.warn("transformers not found. Architecture auto-detection disabled.") | |
| # ============================================================================= | |
| # ENUMS & DATA CLASSES | |
| # ============================================================================= | |
class AttentionType(Enum):
    """Attention variants the script distinguishes; values are descriptive string ids."""

    MHA = "multi_head_attention"
    MQA = "multi_query_attention"
    GQA = "grouped_query_attention"
class MLPType(Enum):
    """Feed-forward (MLP) variants the script distinguishes; values are string ids."""

    STANDARD = "standard_ffn"
    SWIGLU = "swiglu"
    GEGLU = "geglu"
    MOE = "mixture_of_experts"
class OffloadStrategy(Enum):
    """Where tensors are parked when they are not on the compute device."""

    NONE = "none"
    CPU = "cpu"
    DISK = "disk"
    AUTO = "auto"
@dataclass
class ArchitectureInfo:
    """Detected architecture information.

    Declared as a dataclass so every field below becomes a keyword argument of
    the constructor and a per-instance attribute. Without the decorator the
    annotations would only define shared class attributes and
    ``ArchitectureInfo(hidden_size=...)`` would raise ``TypeError``.
    """

    # Block layout.
    attention_type: AttentionType = AttentionType.MHA
    mlp_type: MLPType = MLPType.STANDARD

    # Model dimensions.
    num_attention_heads: int = 32
    num_kv_heads: int = 32
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 32
    vocab_size: int = 32000

    # Mixture-of-experts sizes (both default to 0).
    num_experts: int = 0
    num_experts_per_tok: int = 0

    # Parameter-name components used to locate tensors.
    layer_prefix: str = "model.layers"
    attn_prefix: str = "self_attn"
    mlp_prefix: str = "mlp"
    q_proj_name: str = "q_proj"
    k_proj_name: str = "k_proj"
    v_proj_name: str = "v_proj"
    o_proj_name: str = "o_proj"
    gate_proj_name: str = "gate_proj"
    up_proj_name: str = "up_proj"
    down_proj_name: str = "down_proj"

    # Fused QKV projection flag and its parameter name.
    has_fused_qkv: bool = False
    qkv_proj_name: str = "qkv_proj"
@dataclass
class DistillConfig:
    """Configuration for Universal SVD-LoRA Distillation.

    v3.7 goals:
    - Truly low-cost, out-of-the-box SVD distillation via truncated / randomized SVD by default.
    - Preserve v3.2's low-VRAM, shard-wise processing and layer offloading characteristics.
    - Improve cross-architecture robustness via efficient SVD-based projection with bounded rank.

    The class is a dataclass: ``teacher_path``, ``student_path`` and
    ``output_path`` are required constructor arguments, every other field is
    optional and falls back to the default listed below.
    """

    # Paths
    teacher_path: str
    student_path: str
    output_path: str

    # LoRA rank controls
    rank_default: int = 64
    rank_attn: Optional[int] = None
    rank_mlp: Optional[int] = None
    lora_alpha: Optional[int] = None
    alpha_mode: str = "auto"  # auto | rank | fixed
    use_adaptive_rank: bool = True
    energy_threshold: float = 0.95
    min_rank: int = 8
    max_rank: int = 256

    # Delta regularization / merge helpers
    use_dare: bool = False
    dare_drop_rate: float = 0.3
    use_ties: bool = True
    ties_density: float = 0.3
    ties_trim_single: bool = True

    # Optional MoE blending (kept for backwards-compat; not used by default)
    moe_merge_method: str = "none"
    max_experts_to_blend: int = 2

    # Layer mapping schedule
    map_schedule: str = "linear"  # linear | sigmoid
    sigmoid_k: float = 0.15
    interp_mode: str = "schedule"  # schedule | lsq

    # Runtime / memory controls
    num_gpus: int = 1
    gpu_memory_fraction: float = 0.95
    use_cuda_streams: bool = True
    num_cuda_streams: int = 3
    use_mixed_precision: bool = True
    svd_batch_size: int = 1
    offload_strategy: str = "cpu"  # cpu | disk | none
    offload_dir: Optional[str] = None
    prefetch_layers: int = 2
    use_pinned_memory: bool = True
    max_cpu_memory_gb: float = 64.0

    # SVD execution / batching
    chunk_threshold_mb: int = 512
    chunk_size_mb: int = 128
    micro_batch_size: int = 8

    # Module selection
    include_pattern: str = "self_attn|mlp"
    exclude_pattern: str = ""
    include_embed_lm_head: bool = False

    # Cross-architecture projection (teacher tensor -> student tensor shape)
    use_svd_projection: bool = True
    projection_rank: int = 256
    projection_min_rank: int = 8
    projection_adaptive_rank: bool = True
    projection_energy_threshold: float = 0.99
    projection_randomized_iter: int = 2
    projection_randomized_oversamples: int = 8

    # Cross-arch subspace-delta controls
    subspace_delta: str = "off"  # off|auto|diag|full
    subspace_offdiag_threshold: float = 0.55
    subspace_denom: str = "core"  # core|full
    subspace_norm_match: bool = True

    # SVD backend selection
    # - auto: heuristic (full vs low-rank)
    # - full: torch.linalg.svd
    # - randomized: Halko Alg. 4.4 (stable subspace iteration)
    # - krylov: Block Krylov subspace (stores intermediate iterates)
    # - adaptive: Adaptive range finder (Halko Alg. 4.2) for energy-targeted rank
    svd_mode: str = "auto"  # auto | full | randomized | krylov | adaptive | drf | aurora
    svd_randomized_iter: int = 2
    svd_randomized_oversamples: int = 8
    svd_auto_min_dim: int = 1024
    svd_auto_full_rank_ratio: float = 0.6
    svd_auto_lowrank: str = "randomized"  # randomized | krylov | drf | aurora

    # DRF-SVD (Dual-Residual Folding SVD) knobs
    # - steps: number of residual-folding refinements (>=0)
    # - theta: residual scaling exponent in [0,1] (0=none, 1=full 1/σ weighting)
    # - resid_eps: numerical stabilizer for 1/σ^theta
    svd_drf_steps: int = 1
    svd_drf_theta: float = 0.5
    svd_drf_resid_eps: float = 1e-8

    # AURORA-SVD (Adaptive Unrolled Residual-Order Range Augmentation SVD) knobs
    # - steps: refinement rounds (>=0)
    # - order: 1=first-order residual; 2=adds second-order residual response via A/A^T (recommended)
    # - theta1/theta2: residual scaling exponents in [0,1]
    # - resid_eps: numerical stabilizer for scaling
    # - level2_keep: fraction of components kept for the level-2 term in (0,1]; 1.0 keeps all
    svd_aurora_steps: int = 1
    svd_aurora_order: int = 2
    svd_aurora_theta1: float = 0.5
    svd_aurora_theta2: float = 0.25
    svd_aurora_resid_eps: float = 1e-8
    svd_aurora_level2_keep: float = 1.0

    # Adaptive range finder (Alg. 4.2) knobs (used when svd_mode=adaptive)
    svd_adaptive_block_size: int = 32  # columns added per iteration
    svd_adaptive_n_test: int = 8  # test vectors for residual estimate

    # Delta stability guard (scales delta if ||delta||/||student|| exceeds this ratio; 0 disables)
    max_delta_ratio: float = 0.35

    # Calibration-aware SVD (optional; default: off)
    #
    # Purpose: collect activation statistics on the *student* model (calibration prompts)
    # and apply data-aware whitening before SVD so the truncated low-rank factors better
    # match data-distribution sensitivity.
    #
    # calibration_mode:
    # - "none": disable
    # - "rms": diagonal RMS whitening using per-channel activation RMS (input/output)
    # - "cov": SVD-LLM style full-covariance input whitening via Cholesky factor S
    #   of E[xx^T], using PiSSA on (Δ·S) and then unwhitening with S^{-1}.
    calibration_mode: str = "none"  # choices: none | rms | cov
    calib_data: Optional[str] = None  # .txt (1 prompt/line), .json/.jsonl, or Alpaca JSON (.json/.jsonl)
    calib_format: str = "auto"  # auto | txt | jsonl | json | alpaca
    calib_alpaca_template: str = "classic"  # classic | plain (only used when calib_format=alpaca or auto-detected)
    calib_alpaca_include_output: bool = False  # include output tokens in calibration forward (alpaca only)
    calib_max_samples: int = 128  # number of prompts (after filtering empties)
    calib_max_length: int = 256  # max tokens per prompt
    calib_padding: str = "longest"  # longest | max_length (padding strategy during tokenization)
    calib_use_attention_mask: bool = True  # exclude attention_mask==0 tokens from activation stats (padding tokens)
    calib_batch_size: int = 1  # batch size for forward passes
    calib_device: str = "auto"  # auto | cuda | cpu
    calib_dtype: str = "auto"  # auto | bf16 | fp16 | fp32
    calib_load: Optional[str] = None  # load precomputed stats (.safetensors)
    calib_save: Optional[str] = None  # save computed stats (.safetensors)
    calib_eps: float = 1e-6  # numerical stabilizer for whitening
    calib_collect_in: bool = True  # collect input RMS
    calib_collect_out: bool = True  # collect output RMS

    # Full-covariance (SVD-LLM) calibration controls (only used when calibration_mode=cov)
    #
    # Notes:
    # - cov whitening is applied on the *input* side (per SVD-LLM). For very large
    #   input dims (e.g., MLP down_proj), collection can be gated via calib_cov_max_dim.
    # - calibration can be chunked across layers to bound peak memory.
    calib_cov_max_dim: int = 8192  # skip covariance collection if in_features > this
    calib_cov_chunk_layers: int = 4  # number of layers per calibration pass (>=1)
    calib_cov_groups: str = "qkv,o,mlp"  # comma-list: qkv,o,mlp,down
    calib_cov_store_dtype: str = "fp16"  # fp16|bf16|fp32 (saved chol factor dtype)

    # Misc
    auto_adjust_dare: bool = True
    seed: int = 42
    verbose: bool = True
| # ============================================================================= | |
| # GPU MEMORY MANAGEMENT | |
| # ============================================================================= | |
class GPUMemoryManager:
    """Tracks a memory budget for one device.

    For CUDA devices the budget is ``memory_fraction`` of the device's total
    memory; for any other device type the totals are zero and every
    allocation query succeeds.
    """

    def __init__(self, device: Union[str, torch.device],
                 memory_fraction: float = 0.85,
                 verbose: bool = True):
        """Resolve ``device`` and, on CUDA, make it current and read its capacity."""
        self.device = device if not isinstance(device, str) else torch.device(device)
        self.memory_fraction = memory_fraction
        self.verbose = verbose

        if self.device.type != "cuda":
            self.total_memory = 0
            self.max_memory = 0
            return

        torch.cuda.set_device(self.device)
        props = torch.cuda.get_device_properties(self.device)
        self.total_memory = props.total_memory
        self.max_memory = int(self.total_memory * memory_fraction)

    def get_stats(self) -> Dict[str, float]:
        """Return total/used/free memory in GiB (all zeros off CUDA).

        ``free_gb`` is measured against the budget (``max_memory``), not
        against the physical total.
        """
        if self.device.type != "cuda":
            return {"total_gb": 0, "used_gb": 0, "free_gb": 0}

        in_use = torch.cuda.memory_allocated(self.device)
        stats = {}
        stats["total_gb"] = self.total_memory / (1024**3)
        stats["used_gb"] = in_use / (1024**3)
        stats["free_gb"] = (self.max_memory - in_use) / (1024**3)
        return stats

    def can_allocate(self, size_bytes: int) -> bool:
        """Return True when ``size_bytes`` more would stay strictly under the budget."""
        if self.device.type == "cuda":
            in_use = torch.cuda.memory_allocated(self.device)
            return (in_use + size_bytes) < self.max_memory
        return True

    def clear_cache(self):
        """Run the Python garbage collector, then release cached CUDA blocks if on CUDA."""
        gc.collect()
        if self.device.type != "cuda":
            return
        torch.cuda.empty_cache()
| # ============================================================================= | |
| # CUDA STREAM MANAGER | |
| # ============================================================================= | |
class CUDAStreamManager:
    """Owns a small pool of CUDA streams and hands them out round-robin.

    On non-CUDA devices the pool is empty and every method is a no-op.
    """

    def __init__(self, device: torch.device, num_streams: int = 3):
        """Create ``num_streams`` streams on ``device`` when it is a CUDA device."""
        self.device = device
        self.num_streams = num_streams
        self.streams = []
        self.current_stream_idx = 0
        if device.type == "cuda":
            for _ in range(num_streams):
                self.streams.append(torch.cuda.Stream(device=device))

    @contextmanager
    def stream_context(self, stream_idx: Optional[int] = None):
        """Context manager that runs the ``with`` body on one of the pooled streams.

        The method is a generator, so it must be wrapped by ``contextmanager``
        to be usable in a ``with`` statement; without the decorator the
        ``with`` would fail because a plain generator has no ``__enter__``.

        Args:
            stream_idx: explicit pool index; when None the current round-robin
                index is used.
        """
        if not self.streams:
            # No CUDA streams: run the body on the default execution path.
            yield
            return
        idx = stream_idx if stream_idx is not None else self.current_stream_idx
        # The round-robin cursor advances on every entry, explicit index or not.
        self.current_stream_idx = (self.current_stream_idx + 1) % self.num_streams
        with torch.cuda.stream(self.streams[idx]):
            yield

    def synchronize_current(self):
        """Synchronize the stream at the round-robin cursor.

        NOTE(review): the cursor is advanced when a context is entered, so this
        is the stream that will be handed out next, not the one used last —
        confirm that is what callers expect.
        """
        if self.streams:
            self.streams[self.current_stream_idx].synchronize()

    def synchronize_all(self):
        """Synchronize every stream in the pool."""
        for stream in self.streams:
            stream.synchronize()
| # ============================================================================= | |
| # LAYER OFFLOADER | |
| # ============================================================================= | |
class LayerOffloader:
    """Parks tensors in CPU memory or on disk and loads them to the device on demand.

    Tensors are addressed by string key. ``store`` places a tensor according
    to the strategy and priority, ``get`` returns it on ``device`` (optionally
    queueing other keys for background prefetch), and ``cleanup`` stops the
    worker and removes what this object wrote.
    """

    def __init__(self,
                 strategy: "OffloadStrategy",
                 device: torch.device,
                 offload_dir: Optional[str] = None,
                 use_pinned: bool = True,
                 max_cpu_gb: float = 32.0,
                 prefetch_count: int = 2,
                 verbose: bool = True):
        """Set up caches, the offload directory and (on CUDA) the prefetch worker.

        Args:
            strategy: offload policy; ``OffloadStrategy.NONE`` keeps tensors as given.
            device: target device returned by ``get``.
            offload_dir: directory for disk offload; a temporary one is created
                when omitted.
            use_pinned: pin CPU copies (only honoured when ``device`` is CUDA).
            max_cpu_gb: CPU budget; tensors beyond it spill to disk.
            prefetch_count: maximum keys queued for prefetch per ``get`` call.
            verbose: stored for callers; not read in this class.
        """
        self.strategy = strategy
        self.device = device
        self.use_pinned = use_pinned and device.type == "cuda"
        self.max_cpu_bytes = int(max_cpu_gb * 1024**3)
        self.prefetch_count = prefetch_count
        self.verbose = verbose

        self.gpu_cache: Dict[str, torch.Tensor] = {}
        self.cpu_cache: Dict[str, torch.Tensor] = {}
        self.disk_paths: Dict[str, str] = {}
        self.cpu_bytes_used = 0
        # Bytes charged per CPU-cached key, so replacing a key refunds the old charge.
        self._cpu_sizes: Dict[str, int] = {}
        # Every file this object wrote; lets cleanup() delete them individually.
        self._written_paths = set()

        if offload_dir:
            self.offload_dir = offload_dir
            # A pre-existing directory belongs to the caller and may hold other
            # files, so cleanup() must not delete it wholesale.
            self._owns_offload_dir = not os.path.isdir(offload_dir)
        else:
            self.offload_dir = tempfile.mkdtemp(prefix="distill_offload_")
            self._owns_offload_dir = True
        os.makedirs(self.offload_dir, exist_ok=True)

        self.prefetch_queue = queue.Queue()
        self.stop_prefetch = threading.Event()
        self.prefetch_thread = None
        self.transfer_stream = None
        if device.type == "cuda" and strategy != OffloadStrategy.NONE:
            self.transfer_stream = torch.cuda.Stream(device=device)
            self.prefetch_thread = threading.Thread(target=self._prefetch_worker, daemon=True)
            self.prefetch_thread.start()

    def _prefetch_worker(self):
        """Background loop: pull keys from the queue and load them until stopped."""
        while not self.stop_prefetch.is_set():
            try:
                key = self.prefetch_queue.get(timeout=0.1)
                if key not in self.gpu_cache:
                    self._load_to_gpu(key)
            except queue.Empty:
                continue

    def _to_pinned(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return a pinned copy of ``tensor`` when pinning is enabled, else ``tensor``."""
        if self.use_pinned and not tensor.is_pinned():
            pinned = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
            pinned.copy_(tensor)
            return pinned
        return tensor

    def _forget(self, key: str):
        """Drop every cached copy of ``key`` and refund its CPU byte charge."""
        self.gpu_cache.pop(key, None)
        if self.cpu_cache.pop(key, None) is not None:
            self.cpu_bytes_used -= self._cpu_sizes.pop(key, 0)
        self.disk_paths.pop(key, None)

    def store(self, key: str, tensor: torch.Tensor, priority: str = "cpu"):
        """Store ``tensor`` under ``key``.

        With strategy NONE the tensor is cached as given. Otherwise it goes to
        disk when ``priority == "disk"`` or when the CPU budget would be
        exceeded, and to CPU memory in all other cases. Any older copy of the
        same key is discarded first so a later ``get`` cannot return stale data
        and the CPU byte counter is not charged twice.
        """
        if self.strategy == OffloadStrategy.NONE:
            self.gpu_cache[key] = tensor
            return

        self._forget(key)
        size = tensor.numel() * tensor.element_size()
        if priority == "disk":
            self._store_to_disk(key, tensor)
        elif self.cpu_bytes_used + size <= self.max_cpu_bytes:
            cpu_tensor = tensor.cpu()
            self.cpu_cache[key] = self._to_pinned(cpu_tensor)
            self._cpu_sizes[key] = size
            self.cpu_bytes_used += size
        else:
            self._store_to_disk(key, tensor)

    def _store_to_disk(self, key: str, tensor: torch.Tensor):
        """Save a CPU copy of ``tensor`` to ``<offload_dir>/<sanitized key>.pt``."""
        # Anything outside word characters, '-' and '.' becomes '_' in the file name.
        safe_key = re.sub(r'[^\w\-.]', '_', key)
        path = os.path.join(self.offload_dir, f"{safe_key}.pt")
        torch.save(tensor.cpu(), path)
        self.disk_paths[key] = path
        self._written_paths.add(path)

    def _load_to_gpu(self, key: str) -> Optional[torch.Tensor]:
        """Return ``key`` on ``device``, loading from CPU cache or disk; None if unknown."""
        if key in self.gpu_cache:
            return self.gpu_cache[key]
        tensor = None
        if key in self.cpu_cache:
            tensor = self.cpu_cache[key]
        elif key in self.disk_paths:
            # NOTE: torch.load unpickles; only files written by _store_to_disk are read here.
            tensor = torch.load(self.disk_paths[key])
        if tensor is None:
            return None

        if self.transfer_stream:
            # Copy on the dedicated stream, then wait so the result is ready on return.
            with torch.cuda.stream(self.transfer_stream):
                gpu_tensor = tensor.to(self.device, non_blocking=True)
            self.transfer_stream.synchronize()
        else:
            gpu_tensor = tensor.to(self.device)
        self.gpu_cache[key] = gpu_tensor
        return gpu_tensor

    def get(self, key: str, prefetch_next: Optional[List[str]] = None) -> Optional[torch.Tensor]:
        """Return ``key`` on ``device`` and optionally queue upcoming keys for prefetch.

        Keys are only queued when a prefetch worker exists; otherwise nothing
        would ever drain the queue and it would grow with every call.
        """
        if prefetch_next and self.prefetch_thread is not None:
            for next_key in prefetch_next[:self.prefetch_count]:
                if next_key not in self.gpu_cache:
                    self.prefetch_queue.put(next_key)
        return self._load_to_gpu(key)

    def evict(self, key: str):
        """Drop the device copy of ``key`` (offloaded copies are kept)."""
        self.gpu_cache.pop(key, None)

    def evict_all_gpu(self):
        """Drop all device copies and release cached CUDA blocks when on CUDA."""
        self.gpu_cache.clear()
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    def cleanup(self):
        """Stop the prefetch worker, clear all caches and remove offloaded files.

        A directory created by this object is removed entirely. A directory
        that already existed when it was passed in is left in place; only the
        files written by this object are deleted from it.
        """
        self.stop_prefetch.set()
        if self.prefetch_thread:
            self.prefetch_thread.join(timeout=1.0)
        self.gpu_cache.clear()
        self.cpu_cache.clear()
        self._cpu_sizes.clear()
        self.cpu_bytes_used = 0
        self.disk_paths.clear()

        if self._owns_offload_dir:
            shutil.rmtree(self.offload_dir, ignore_errors=True)
        else:
            for path in self._written_paths:
                try:
                    os.remove(path)
                except OSError:
                    # Best effort, mirroring rmtree(ignore_errors=True).
                    pass
        self._written_paths.clear()
| # ============================================================================= | |
| # GPU-ACCELERATED SVD | |
| # ============================================================================= | |
| class GPUAcceleratedSVD: | |
| """SVD backend with GPU acceleration and low-memory low-rank fallbacks. | |
| Implemented low-rank backends: | |
| - Randomized SVD with numerically-stable subspace iteration (Halko et al., Alg. 4.4) | |
| - Block Krylov subspace SVD (stores intermediate iterates, improves accuracy per iteration) | |
| - Adaptive range finder (Halko et al., Alg. 4.2) for energy-targeted rank selection | |
| - DRF-SVD (Dual-Residual Folding SVD): residual-folding refinement for improved subspace capture | |
| """ | |
| def __init__( | |
| self, | |
| device: torch.device, | |
| memory_manager: GPUMemoryManager, | |
| stream_manager: Optional[CUDAStreamManager], | |
| use_mixed_precision: bool = True, | |
| chunk_threshold_mb: int = 512, | |
| chunk_size_mb: int = 128, | |
| force_cpu: bool = False, | |
| verbose: bool = False, | |
| svd_mode: str = "auto", | |
| randomized_iter: int = 2, | |
| randomized_oversamples: int = 8, | |
| auto_min_dim: int = 1024, | |
| auto_full_rank_ratio: float = 0.6, | |
| auto_lowrank: str = "randomized", # randomized|krylov | |
| adaptive_block_size: int = 32, | |
| adaptive_n_test: int = 8, | |
| drf_steps: int = 1, | |
| drf_theta: float = 0.5, | |
| drf_resid_eps: float = 1e-8, | |
| aurora_steps: int = 1, | |
| aurora_order: int = 2, | |
| aurora_theta1: float = 0.5, | |
| aurora_theta2: float = 0.25, | |
| aurora_resid_eps: float = 1e-8, | |
| aurora_level2_keep: float = 1.0, | |
| ): | |
| self.device = device | |
| self.memory_manager = memory_manager | |
| self.stream_manager = stream_manager | |
| self.use_mixed_precision = use_mixed_precision | |
| self.chunk_threshold_bytes = int(chunk_threshold_mb) * 1024 * 1024 | |
| self.chunk_size_bytes = int(chunk_size_mb) * 1024 * 1024 | |
| self.force_cpu = bool(force_cpu) | |
| self.verbose = bool(verbose) | |
| self.svd_mode = (svd_mode or "auto").lower() | |
| self.randomized_iter = int(randomized_iter) | |
| self.randomized_oversamples = int(randomized_oversamples) | |
| self.auto_min_dim = int(auto_min_dim) | |
| self.auto_full_rank_ratio = float(auto_full_rank_ratio) | |
| self.auto_lowrank = (auto_lowrank or "randomized").lower() | |
| self.adaptive_block_size = int(adaptive_block_size) | |
| self.adaptive_n_test = int(adaptive_n_test) | |
| # DRF-SVD knobs | |
| self.drf_steps = int(drf_steps) | |
| self.drf_theta = float(drf_theta) | |
| self.drf_resid_eps = float(drf_resid_eps) | |
| # AURORA-SVD knobs | |
| self.aurora_steps = int(aurora_steps) | |
| self.aurora_order = int(aurora_order) | |
| self.aurora_theta1 = float(aurora_theta1) | |
| self.aurora_theta2 = float(aurora_theta2) | |
| self.aurora_resid_eps = float(aurora_resid_eps) | |
| self.aurora_level2_keep = float(aurora_level2_keep) | |
| # Validate | |
| if self.svd_mode not in ("auto", "full", "randomized", "krylov", "adaptive", "drf", "aurora"): | |
| raise ValueError( | |
| f"Invalid svd_mode: {svd_mode} " | |
| f"(expected auto|full|randomized|krylov|adaptive|drf|aurora)" | |
| ) | |
| if self.auto_lowrank not in ("randomized", "krylov", "drf", "aurora"): | |
| raise ValueError("auto_lowrank must be randomized|krylov|drf|aurora") | |
| if self.randomized_iter < 0: | |
| raise ValueError("randomized_iter must be >= 0") | |
| if self.randomized_oversamples < 0: | |
| raise ValueError("randomized_oversamples must be >= 0") | |
| if not (0.0 < self.auto_full_rank_ratio <= 1.0): | |
| raise ValueError("auto_full_rank_ratio must be in (0, 1]") | |
| if self.adaptive_block_size <= 0: | |
| raise ValueError("adaptive_block_size must be > 0") | |
| if self.adaptive_n_test <= 0: | |
| raise ValueError("adaptive_n_test must be > 0") | |
| if self.drf_steps < 0: | |
| raise ValueError("drf_steps must be >= 0") | |
| if not (0.0 <= self.drf_theta <= 1.0): | |
| raise ValueError("drf_theta must be in [0, 1]") | |
| if self.drf_resid_eps <= 0: | |
| raise ValueError("drf_resid_eps must be > 0") | |
| # AURORA validation | |
| if self.aurora_steps < 0: | |
| raise ValueError("aurora_steps must be >= 0") | |
| if self.aurora_order not in (1, 2): | |
| raise ValueError("aurora_order must be 1 or 2") | |
| if not (0.0 <= self.aurora_theta1 <= 1.0): | |
| raise ValueError("aurora_theta1 must be in [0, 1]") | |
| if not (0.0 <= self.aurora_theta2 <= 1.0): | |
| raise ValueError("aurora_theta2 must be in [0, 1]") | |
| if self.aurora_resid_eps <= 0: | |
| raise ValueError("aurora_resid_eps must be > 0") | |
| if not (0.0 <= self.aurora_level2_keep <= 1.0): | |
| raise ValueError("aurora_level2_keep must be in [0, 1]") | |
| def _stream_context(self): | |
| if self.stream_manager is None: | |
| yield | |
| else: | |
| with self.stream_manager.stream_context(): | |
| yield | |
| # ------------------------------------------------------------------------- | |
| # Public APIs | |
| # ------------------------------------------------------------------------- | |
| def svd(self, tensor: torch.Tensor, rank: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Compute truncated SVD (U, S, Vh) with a bounded rank. | |
| Note: | |
| - svd_mode=adaptive is an energy-targeted algorithm; for fixed-rank calls it | |
| dispatches to the configured low-rank backend (auto_lowrank). | |
| - Use `svd_energy(...)` to run the Adaptive range finder (Alg. 4.2). | |
| """ | |
| if tensor is None: | |
| raise ValueError("tensor is None") | |
| if tensor.dim() != 2: | |
| raise ValueError(f"SVD expects a 2D tensor, got dim={tensor.dim()}") | |
| m, n = tensor.shape | |
| min_dim = min(m, n) | |
| r = int(rank) | |
| if r <= 0: | |
| raise ValueError(f"rank must be > 0, got {rank}") | |
| r = min(r, min_dim) | |
| # Estimate size in fp32 (we cast for numerical stability) | |
| tensor_bytes = int(m) * int(n) * 4 | |
| if self.force_cpu: | |
| return self._cpu_svd(tensor, r) | |
| mode = self.svd_mode | |
| # Heuristics for auto mode: decide full vs low-rank. | |
| if mode == "auto": | |
| if tensor_bytes > self.chunk_threshold_bytes: | |
| mode = self.auto_lowrank | |
| elif min_dim >= self.auto_min_dim and r < int(min_dim * self.auto_full_rank_ratio): | |
| mode = self.auto_lowrank | |
| else: | |
| mode = "full" | |
| # svd_mode=adaptive is only meaningful for `svd_energy`; for fixed-rank it becomes low-rank. | |
| if mode == "adaptive": | |
| mode = self.auto_lowrank | |
| if mode in ("randomized", "krylov", "drf", "aurora"): | |
| return self._lowrank_svd( | |
| tensor, | |
| r, | |
| algo=mode, | |
| n_oversamples=self.randomized_oversamples, | |
| n_iter=self.randomized_iter, | |
| ) | |
| # Full SVD path | |
| try: | |
| return self._gpu_svd(tensor, r) | |
| except RuntimeError as e: | |
| if self.verbose: | |
| print(f"[SVD] GPU full SVD failed ({type(e).__name__}: {e}); falling back to low-rank SVD.") | |
| return self._lowrank_svd( | |
| tensor, | |
| r, | |
| algo=self.auto_lowrank, | |
| n_oversamples=self.randomized_oversamples, | |
| n_iter=self.randomized_iter, | |
| ) | |
| def randomized_svd( | |
| self, | |
| tensor: torch.Tensor, | |
| rank: int, | |
| n_oversamples: Optional[int] = None, | |
| n_iter: Optional[int] = None, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Explicit low-rank SVD (randomized/Krylov) with per-call parameters. | |
| Dispatch: | |
| - If svd_mode is krylov, uses Block Krylov. | |
| - If svd_mode is auto and auto_lowrank is krylov, uses Block Krylov. | |
| - Otherwise uses randomized subspace iteration (Alg. 4.4). | |
| """ | |
| algo = "randomized" | |
| if self.svd_mode == "krylov": | |
| algo = "krylov" | |
| elif self.svd_mode == "drf": | |
| algo = "drf" | |
| elif self.svd_mode == "aurora": | |
| algo = "aurora" | |
| elif self.svd_mode == "auto": | |
| if self.auto_lowrank == "krylov": | |
| algo = "krylov" | |
| elif self.auto_lowrank == "drf": | |
| algo = "drf" | |
| elif self.auto_lowrank == "aurora": | |
| algo = "aurora" | |
| return self._lowrank_svd( | |
| tensor, | |
| int(rank), | |
| algo=algo, | |
| n_oversamples=self.randomized_oversamples if n_oversamples is None else int(n_oversamples), | |
| n_iter=self.randomized_iter if n_iter is None else int(n_iter), | |
| ) | |
def svd_energy(
    self,
    tensor: torch.Tensor,
    energy_threshold: float,
    min_rank: int,
    max_rank: int,
    block_size: Optional[int] = None,
    n_test: Optional[int] = None,
    n_iter: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Energy-targeted low-rank SVD using the adaptive range finder (Halko et al., Alg. 4.2).

    Parameters
    ----------
    energy_threshold:
        Target retained Frobenius energy fraction, clamped to [0, 1]. The range
        finder is given the residual tolerance eps = sqrt(1 - energy_threshold).
    min_rank / max_rank:
        Hard bounds on the number of basis vectors (max is also capped by the
        short side of the matrix; min is at least 1 and at most max).
    block_size:
        Basis vectors added per adaptive iteration (instance default when None).
    n_test:
        Random test vectors for residual estimation (instance default when None).
    n_iter:
        Subspace-iteration steps passed to the range finder (0 when None).

    Returns ``(U, S, Vh)`` of the matrix projected onto the found basis; the
    result has as many components as the basis has columns.

    Raises:
        ValueError: if ``tensor`` is None or not 2-D.
    """
    if tensor is None:
        raise ValueError("tensor is None")
    if tensor.dim() != 2:
        raise ValueError(f"SVD expects a 2D tensor, got dim={tensor.dim()}")

    with self._stream_context():
        mat = tensor.to(self.device, non_blocking=True).float()
        n_rows, n_cols = mat.shape
        rank_cap = min(int(max_rank), min(n_rows, n_cols))
        rank_floor = max(1, min(int(min_rank), rank_cap))

        energy = float(max(0.0, min(1.0, float(energy_threshold))))
        # Residual Frobenius tolerance: ||(I-QQ^T)A||_F <= eps ||A||_F
        residual_tol = math.sqrt(max(0.0, 1.0 - energy))

        basis = self._adaptive_range_finder(
            mat,
            eps=residual_tol,
            max_rank=rank_cap,
            block_size=self.adaptive_block_size if block_size is None else int(block_size),
            n_test=self.adaptive_n_test if n_test is None else int(n_test),
            n_iter=0 if n_iter is None else int(n_iter),
        )

        # Top the basis up to the minimum rank with fresh random directions,
        # projected away from what the basis already spans.
        if basis.shape[1] < rank_floor:
            missing = rank_floor - basis.shape[1]
            probe = torch.randn(n_cols, missing, device=self.device, dtype=torch.float32)
            sample = mat @ probe
            if basis.shape[1] > 0:
                sample = sample - basis @ (basis.t() @ sample)
            extra, _ = torch.linalg.qr(sample, mode="reduced")
            basis = torch.cat([basis, extra], dim=1)

        # Cap and re-orthonormalize.
        # NOTE(review): source indentation was lost; the QR is assumed to sit
        # inside this branch — confirm against the original file.
        if basis.shape[1] > rank_cap:
            basis = basis[:, :rank_cap]
            basis, _ = torch.linalg.qr(basis, mode="reduced")

        # SVD of the small projected matrix, lifted back through the basis.
        projected = basis.t() @ mat
        small_u, S, Vh = torch.linalg.svd(projected, full_matrices=False)
        U = basis @ small_u
        return U, S, Vh
| # ------------------------------------------------------------------------- | |
| # Full SVD backends | |
| # ------------------------------------------------------------------------- | |
| def _gpu_svd(self, x: torch.Tensor, r: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Full SVD on GPU, then truncate to r.""" | |
| with self._stream_context(): | |
| x = x.to(self.device, non_blocking=True).float() | |
| U, S, Vh = torch.linalg.svd(x, full_matrices=False) | |
| return U[:, :r], S[:r], Vh[:r, :] | |
| def _cpu_svd(self, x: torch.Tensor, r: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Full SVD on CPU, then truncate to r.""" | |
| x = x.cpu().float() | |
| U, S, Vh = torch.linalg.svd(x, full_matrices=False) | |
| return U[:, :r].to(self.device), S[:r].to(self.device), Vh[:r, :].to(self.device) | |
| # ------------------------------------------------------------------------- | |
| # Low-rank SVD backends | |
| # ------------------------------------------------------------------------- | |
| def _lowrank_svd( | |
| self, | |
| x: torch.Tensor, | |
| r: int, | |
| algo: str, | |
| n_oversamples: int = 8, | |
| n_iter: int = 2, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| if algo == "krylov": | |
| return self._krylov_svd(x, r, n_oversamples=n_oversamples, n_iter=n_iter) | |
| if algo == "drf": | |
| return self._drf_svd(x, r, n_oversamples=n_oversamples, n_iter=n_iter) | |
| if algo == "aurora": | |
| return self._aurora_svd(x, r, n_oversamples=n_oversamples, n_iter=n_iter) | |
| return self._randomized_svd(x, r, n_oversamples=n_oversamples, n_iter=n_iter) | |
def _randomized_svd(
    self,
    x: torch.Tensor,
    r: int,
    n_oversamples: int = 8,
    n_iter: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Randomized truncated SVD with QR-stabilized subspace iteration.

    A Gaussian test matrix with ``min(min(m, n), r + n_oversamples)`` columns
    sketches the range of ``x``. Each of the ``n_iter`` refinement rounds
    multiplies by ``x^T`` and then ``x``, orthonormalizing after every
    product (Halko et al., Algorithm 4.4). The SVD of the projected matrix
    ``Q^T x`` is lifted back and truncated to ``r`` triplets.

    Raises:
        ValueError: if the sketch width works out to zero or less.
    """
    with self._stream_context():
        mat = x.to(self.device, non_blocking=True).float()
        rows, cols = mat.shape
        smallest = min(rows, cols)
        rank = min(int(r), smallest)
        extra = max(0, int(n_oversamples))
        width = min(smallest, rank + extra)
        if width <= 0:
            raise ValueError("randomized_svd: computed k <= 0")

        # Range sketch: orthonormal basis for x @ Omega.
        probe = torch.randn(cols, width, device=self.device, dtype=torch.float32)
        basis = torch.linalg.qr(mat @ probe, mode="reduced")[0]

        # Subspace iteration; QR after every product keeps the basis well conditioned.
        for _ in range(max(0, int(n_iter))):
            co_basis = torch.linalg.qr(mat.transpose(0, 1) @ basis, mode="reduced")[0]
            basis = torch.linalg.qr(mat @ co_basis, mode="reduced")[0]

        # Small SVD in the captured subspace, then lift the left factor back.
        projected = basis.transpose(0, 1) @ mat
        core_left, sigma, right_h = torch.linalg.svd(projected, full_matrices=False)
        left = basis @ core_left
        return left[:, :rank], sigma[:rank], right_h[:rank, :]
def _krylov_svd(
    self,
    x: torch.Tensor,
    r: int,
    n_oversamples: int = 8,
    n_iter: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Block-Krylov truncated SVD.

    Runs the same QR-stabilized subspace iteration as the randomized backend
    but keeps the orthonormal block produced by every round. The blocks,
    which together span an approximation of
    ``span{A W, (A A^T) A W, ..., (A A^T)^q A W}`` (``W`` is the Gaussian test
    matrix, ``q = n_iter``), are concatenated and re-orthonormalized, and the
    SVD of ``x`` projected onto that larger basis is truncated to ``r``
    triplets. ``A A^T`` is never formed explicitly.

    Raises:
        ValueError: if the sketch width works out to zero or less.
    """
    with self._stream_context():
        mat = x.to(self.device, non_blocking=True).float()
        rows, cols = mat.shape
        smallest = min(rows, cols)
        rank = min(int(r), smallest)
        extra = max(0, int(n_oversamples))
        width = min(smallest, rank + extra)
        if width <= 0:
            raise ValueError("krylov_svd: computed k <= 0")

        probe = torch.randn(cols, width, device=self.device, dtype=torch.float32)
        current = torch.linalg.qr(mat @ probe, mode="reduced")[0]
        history = [current]

        for _ in range(max(0, int(n_iter))):
            # One stabilized power step; unlike plain subspace iteration the
            # resulting block is retained instead of overwritten.
            co_block = torch.linalg.qr(mat.transpose(0, 1) @ current, mode="reduced")[0]
            current = torch.linalg.qr(mat @ co_block, mode="reduced")[0]
            history.append(current)

        # Joint orthonormal basis of all stored blocks.
        span = torch.linalg.qr(torch.cat(history, dim=1), mode="reduced")[0]

        projected = span.transpose(0, 1) @ mat
        core_left, sigma, right_h = torch.linalg.svd(projected, full_matrices=False)
        left = span @ core_left
        return left[:, :rank], sigma[:rank], right_h[:rank, :]
def _drf_svd(
    self,
    x: torch.Tensor,
    r: int,
    n_oversamples: int = 8,
    n_iter: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """DRF-SVD (Dual-Residual Folding SVD): residual-folding refinement of a randomized SVD.

    Algorithm (one refinement round):
        1) Start from a rank ``k0 = min(min(m, n), r + n_oversamples)``
           approximation ``A ~ U diag(S) V^T`` obtained from ``_randomized_svd``.
        2) Form the residual blocks
               R_L = A V - U diag(S)
               R_R = A^T U - V diag(S)
           and divide column ``i`` by ``max(S_i, eps) ** theta``.
        3) Fold the residuals into the subspaces:
               Q = orth([U, R_L]),  P = orth([V, R_R])
        4) Take the SVD of the projected core ``Q^T A P`` and keep ``k0`` triplets.
        5) Repeat ``drf_steps`` times, then return the leading ``r`` triplets.

    Optional attributes read from ``self`` (defaults in parentheses):
    ``drf_steps`` (1), ``drf_theta`` (0.5, clamped to [0, 1]) and
    ``drf_resid_eps`` (1e-8, floored at 1e-30).

    Returns:
        ``(U, S, Vh)`` with shapes ``(m, r)``, ``(r,)`` and ``(r, n)``.

    Raises:
        ValueError: if the working rank ``k0`` works out to zero or less.
    """
    with self._stream_context():
        A = x.to(self.device, non_blocking=True).float()
        m, n = A.shape
        min_dim = min(m, n)
        r = min(int(r), min_dim)
        p = max(0, int(n_oversamples))
        # Internal working rank: keep (r + p) components for refinement, truncate to r at the end.
        k0 = min(min_dim, r + p)
        if k0 <= 0:
            raise ValueError("drf_svd: computed k0 <= 0")

        # Initial capture (oversamples=0 so that k0 is the true working rank).
        U, S, Vh = self._randomized_svd(A, k0, n_oversamples=0, n_iter=max(0, int(n_iter)))
        U = U.contiguous()
        S = S.contiguous()
        Vh = Vh.contiguous()

        steps = max(0, int(getattr(self, "drf_steps", 1)))
        theta = float(getattr(self, "drf_theta", 0.5))
        theta = float(max(0.0, min(1.0, theta)))
        eps = float(getattr(self, "drf_resid_eps", 1e-8))
        # Same floor as the AURORA backend: a non-positive eps would allow division by zero below.
        eps = float(max(1e-30, eps))

        def _reorthonormalize(basis: torch.Tensor) -> torch.Tensor:
            """Re-orthonormalize near-orthonormal columns without changing their orientation.

            A QR factorization is only unique up to the sign of each column of Q
            (and the matching row of R). Using the raw Q factor would flip
            columns of U and V independently and negate the affected singular
            triplets in ``U diag(S) Vh``. Multiplying by ``sign(diag(R))`` undoes
            the flip, so the result stays aligned with the input columns.
            """
            q_fac, r_fac = torch.linalg.qr(basis, mode="reduced")
            signs = torch.sign(torch.diagonal(r_fac))
            signs = torch.where(signs == 0, torch.ones_like(signs), signs)
            return q_fac * signs.unsqueeze(0)

        At = A.transpose(0, 1)
        for _ in range(steps):
            # Current right singular basis, shape (n, k0).
            V = Vh.transpose(0, 1).contiguous()

            # Residual blocks: R_L is (m, k0), R_R is (n, k0).
            R_L = A @ V - U * S.unsqueeze(0)
            R_R = At @ U - V * S.unsqueeze(0)

            # Scale residuals by 1 / S^theta to reduce dominance of the largest modes.
            if theta > 0.0:
                denom = torch.pow(torch.clamp(S, min=eps), theta).unsqueeze(0)
                R_L = R_L / denom
                R_R = R_R / denom

            # Fold residuals into the subspaces; a full QR of the concatenation
            # guarantees that Q and P are orthonormal even if the residual
            # blocks are rank-deficient.
            Q, _ = torch.linalg.qr(torch.cat([U, R_L], dim=1), mode="reduced")
            P, _ = torch.linalg.qr(torch.cat([V, R_R], dim=1), mode="reduced")

            # Rayleigh-Ritz: SVD of M = Q^T A P, evaluated as Q^T (A P).
            M = Q.transpose(0, 1) @ (A @ P)
            Uc, Sc, Vhc = torch.linalg.svd(M, full_matrices=False)

            # Keep k0 components for potential further refinement.
            kk = min(int(Sc.numel()), int(k0))
            if kk <= 0:
                break
            U = (Q @ Uc[:, :kk]).contiguous()
            S = Sc[:kk].contiguous()
            Vh = (Vhc[:kk, :] @ P.transpose(0, 1)).contiguous()

            # Drift guard (cheap; kk is small). Sign-preserving so that the
            # (u_i, s_i, v_i) triplets stay consistent.
            U = _reorthonormalize(U).contiguous()
            Vh = _reorthonormalize(Vh.transpose(0, 1).contiguous()).transpose(0, 1).contiguous()

        # Return the top-r slice.
        return U[:, :r], S[:r], Vh[:r, :]
def _aurora_svd(
    self,
    x: torch.Tensor,
    r: int,
    n_oversamples: int = 8,
    n_iter: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """AURORA-SVD (Adaptive Unrolled Residual-Order Range Augmentation SVD).

    Per refinement round:
        - Start from a working-rank approximation ``A ~ U diag(S) V^T`` with
          ``k0 = min(min(m, n), r + n_oversamples)`` triplets.
        - Form first-order residual blocks
              R_L = A V - U diag(S)
              R_R = A^T U - V diag(S)
          and divide column ``i`` by ``max(S_i, eps) ** theta1``.
        - If ``aurora_order == 2``, form the second-order responses
              T_L = A (scaled R_R),  T_R = A^T (scaled R_L)
          optionally restricted to the ``aurora_level2_keep`` fraction of
          components with the largest residual energy, and divide them by
          ``max(S_i, eps) ** theta2``.
        - Build enlarged bases Q = [U, Q1, Q2], P = [V, P1, P2] and take the
          SVD of the projected core ``Q^T A P`` (Rayleigh-Ritz), keeping ``k0``
          triplets.
    After ``aurora_steps`` rounds the leading ``r`` triplets are returned.

    Implementation notes:
        - Block-anchored orthonormalization: U / V are kept as the leading
          blocks and residual blocks are orthonormalized against them.
        - The orthonormalization helper projects twice, and falls back to a
          rank-revealing SVD when the projected block is rank-deficient or has
          more columns than the anchors' orthogonal complement. A plain
          unpivoted QR would fill such blocks with columns that are not
          orthogonal to the anchors, which makes Q / P non-orthonormal and
          inflates the Ritz values.
        - ``A @ V`` is reused when assembling ``A @ P`` because V is the
          explicit leading block of P.

    Attributes read from ``self``: ``use_mixed_precision`` and ``device``
    (required); ``aurora_steps`` (1), ``aurora_order`` (2), ``aurora_theta1``
    (0.5), ``aurora_theta2`` (0.25), ``aurora_resid_eps`` (1e-8) and
    ``aurora_level2_keep`` (1.0) (optional, defaults in parentheses).

    Raises:
        ValueError: if the working rank ``k0`` works out to zero or less.
    """
    with self._stream_context():
        A = x.to(self.device, non_blocking=True).float()
        m, n = A.shape
        min_dim = min(m, n)
        r = min(int(r), min_dim)
        p = max(0, int(n_oversamples))
        k0 = min(min_dim, r + p)
        if k0 <= 0:
            raise ValueError("aurora_svd: computed k0 <= 0")

        # Initial capture (oversamples=0 so that k0 is the true working rank).
        U, S, Vh = self._randomized_svd(A, k0, n_oversamples=0, n_iter=max(0, int(n_iter)))
        U = U.contiguous()
        S = S.contiguous()
        Vh = Vh.contiguous()

        steps = max(0, int(getattr(self, "aurora_steps", 1)))
        order = int(getattr(self, "aurora_order", 2))
        order = 1 if order <= 1 else 2
        theta1 = float(getattr(self, "aurora_theta1", 0.5))
        theta2 = float(getattr(self, "aurora_theta2", 0.25))
        theta1 = float(max(0.0, min(1.0, theta1)))
        theta2 = float(max(0.0, min(1.0, theta2)))
        eps = float(getattr(self, "aurora_resid_eps", 1e-8))
        eps = float(max(1e-30, eps))
        level2_keep = float(getattr(self, "aurora_level2_keep", 1.0))
        level2_keep = float(max(0.0, min(1.0, level2_keep)))

        use_amp = bool(self.use_mixed_precision) and (self.device.type == "cuda")
        amp_dtype = None
        if use_amp:
            # Prefer bf16 when supported; otherwise fp16.
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

        # Directions whose magnitude is below this fraction of the incoming
        # block norm are treated as numerically null.
        null_factor = 64.0 * torch.finfo(torch.float32).eps
        At = A.transpose(0, 1)

        def _mm(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
            """Matrix multiply with an optional AMP speed path; returns an fp32 tensor."""
            if use_amp:
                with torch.autocast(device_type="cuda", dtype=amp_dtype):
                    return (X @ Y).float()
            return X @ Y

        def _orth_block(block: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
            """Orthonormal columns spanning ``block`` projected off the orthonormal ``basis``.

            May return fewer columns than ``block`` has (possibly zero) when the
            projected block is rank-deficient or the complement of ``basis``
            is too small to hold them.
            """
            if block.numel() == 0:
                return block
            anchored = int(basis.shape[1]) if basis.dim() == 2 else 0
            room = int(block.shape[0]) - anchored  # dimension of the complement of span(basis)
            scale = torch.linalg.norm(block, ord="fro")
            if room <= 0 or scale < 1e-20:
                return block[:, :0]
            if anchored > 0:
                # Two Gram-Schmidt passes: the second removes the component that
                # rounding in the first pass leaves inside span(basis).
                for _ in range(2):
                    block = block - basis @ (basis.transpose(0, 1) @ block)
            tol = scale * null_factor
            if int(block.shape[1]) <= room:
                # Fast path: full column rank means the Q factor spans exactly
                # the projected block and is therefore orthogonal to `basis`.
                Qb, Rb = torch.linalg.qr(block, mode="reduced")
                if bool(torch.all(torch.diagonal(Rb).abs() > tol)):
                    return Qb
            # Rank-revealing fallback: keep only the numerically non-null
            # directions, never more than the complement can hold.
            Ub, Sb, _ = torch.linalg.svd(block, full_matrices=False)
            keep = min(int((Sb > tol).sum().item()), room)
            return Ub[:, :keep]

        for _ in range(steps):
            V = Vh.transpose(0, 1).contiguous()  # (n, k0)

            # First-order residuals.
            AV = _mm(A, V)
            ATU = _mm(At, U)
            R_L = AV - U * S.unsqueeze(0)  # (m, k0)
            R_R = ATU - V * S.unsqueeze(0)  # (n, k0)

            # Residual scaling for stability / tail emphasis.
            if theta1 > 0.0:
                denom1 = torch.pow(torch.clamp(S, min=eps), theta1).unsqueeze(0)
                R_Ls = R_L / denom1
                R_Rs = R_R / denom1
            else:
                R_Ls = R_L
                R_Rs = R_R

            # Build orthonormal bases with anchored blocks (U / V).
            Q_blocks: List[torch.Tensor] = [U]
            P_blocks: List[torch.Tensor] = [V]

            Q1 = _orth_block(R_Ls, U)
            if Q1.numel() != 0 and Q1.shape[1] > 0:
                Q_blocks.append(Q1)
            P1 = _orth_block(R_Rs, V)
            if P1.numel() != 0 and P1.shape[1] > 0:
                P_blocks.append(P1)

            if order >= 2:
                # Optional component subset for the level-2 term. Both the
                # "keep everything" settings and level2_keep <= 0 use all
                # components (idx stays None).
                idx = None
                if 0.0 < level2_keep < 1.0 and k0 > 1:
                    # Residual energy per component: ||R_L[:, i]||^2 + ||R_R[:, i]||^2.
                    eL = torch.sum(R_L.float() * R_L.float(), dim=0)
                    eR = torch.sum(R_R.float() * R_R.float(), dim=0)
                    e = (eL + eR).contiguous()
                    keep = max(1, int(round(level2_keep * k0)))
                    keep = min(keep, k0)
                    _, idx = torch.topk(e, k=keep, largest=True, sorted=False)

                if idx is None:
                    R_L2, R_R2, S2 = R_Ls, R_Rs, S
                else:
                    R_L2, R_R2, S2 = R_Ls[:, idx], R_Rs[:, idx], S[idx]

                # Second-order residual response (apply A / A^T).
                T_L = _mm(A, R_R2)
                T_R = _mm(At, R_L2)

                # Orthogonalize level-2 blocks against the current bases.
                Q_basis = torch.cat(Q_blocks, dim=1)
                P_basis = torch.cat(P_blocks, dim=1)
                T_L = T_L - Q_basis @ (Q_basis.transpose(0, 1) @ T_L)
                T_R = T_R - P_basis @ (P_basis.transpose(0, 1) @ T_R)

                # Optional additional scaling for the level-2 response.
                if theta2 > 0.0:
                    denom2 = torch.pow(torch.clamp(S2, min=eps), theta2).unsqueeze(0)
                    T_L = T_L / denom2
                    T_R = T_R / denom2

                Q2 = _orth_block(T_L, Q_basis)
                if Q2.numel() != 0 and Q2.shape[1] > 0:
                    Q_blocks.append(Q2)
                P2 = _orth_block(T_R, P_basis)
                if P2.numel() != 0 and P2.shape[1] > 0:
                    P_blocks.append(P2)

            # Final bases (orthonormal by construction of the blocks).
            Q = torch.cat(Q_blocks, dim=1)
            P = torch.cat(P_blocks, dim=1)

            # A @ P, reusing A @ V for the leading block.
            if len(P_blocks) == 1:
                AP = AV
            else:
                P_rest = torch.cat(P_blocks[1:], dim=1)
                AP = torch.cat([AV, _mm(A, P_rest)], dim=1)

            # Project and solve the small SVD (Rayleigh-Ritz).
            M = Q.transpose(0, 1) @ AP
            Uc, Sc, Vhc = torch.linalg.svd(M, full_matrices=False)
            kk = min(int(Sc.numel()), int(k0))
            if kk <= 0:
                break
            U = (Q @ Uc[:, :kk]).contiguous()
            S = Sc[:kk].contiguous()
            Vh = (Vhc[:kk, :] @ P.transpose(0, 1)).contiguous()

        return U[:, :r], S[:r], Vh[:r, :]
| # ------------------------------------------------------------------------- | |
| # Adaptive range finder (Alg. 4.2) | |
| # ------------------------------------------------------------------------- | |
def _adaptive_range_finder(
    self,
    x: torch.Tensor,
    eps: float,
    max_rank: int,
    block_size: int,
    n_test: int,
    n_iter: int = 0,
) -> torch.Tensor:
    """Adaptive randomized range finder (after Halko et al., Alg. 4.2).

    Grows an orthonormal basis ``Q`` for the range of ``x`` in blocks of
    ``block_size`` columns. A sketch ``H = (I - Q Q^T) x G`` with ``n_test``
    Gaussian test vectors tracks the residual; the loop stops as soon as
    ``||H||_F <= eps * ||x||_F`` or ``max_rank`` columns have been collected.

    NOTE(review): ``H`` aggregates ``n_test`` test vectors, so ``||H||_F`` is
    on the order of ``sqrt(n_test)`` times the true residual norm; the
    criterion is therefore conservative for ``n_test > 1``.

    Args:
        x: 2-D matrix; moved to ``self.device`` and cast to float32 like the
            other SVD backends do.
        eps: Relative tolerance (negative values are treated as 0).
        max_rank: Upper bound on the number of basis columns.
        block_size: Columns added per round (at least 1).
        n_test: Number of Gaussian test vectors in the sketch (at least 1).
        n_iter: QR-stabilized power-iteration rounds applied to each new block.

    Returns:
        Float32 matrix ``Q`` of shape ``(m, k)`` with orthonormal columns and
        ``0 <= k <= min(max_rank, m, n)``; ``k == 0`` for a (near-)zero input
        or a non-positive rank bound.
    """
    # Normalize device / dtype up front: the test matrices below are float32
    # on self.device, so a half/double or off-device input would otherwise
    # fail at the first matmul.
    x = x.to(self.device).float()
    m, n = x.shape
    max_r = min(int(max_rank), min(m, n))
    if max_r <= 0:
        return torch.zeros((m, 0), device=self.device, dtype=torch.float32)
    bs = max(1, int(block_size))
    nt = max(1, int(n_test))
    q = max(0, int(n_iter))
    eps = float(max(0.0, eps))

    normA = torch.linalg.norm(x, ord="fro")
    if normA < 1e-20:
        return torch.zeros((m, 0), device=self.device, dtype=torch.float32)

    xt = x.transpose(0, 1)
    G = torch.randn(n, nt, device=self.device, dtype=torch.float32)
    H = x @ G
    Q: Optional[torch.Tensor] = None
    k = 0

    # Current residual estimate.
    res = torch.linalg.norm(H, ord="fro")
    tol = eps * normA

    while (res > tol) and (k < max_r):
        b = min(bs, max_r - k)
        omega = torch.randn(n, b, device=self.device, dtype=torch.float32)
        Y = x @ omega  # (m, b)

        # Optional stabilization / spectrum boost for each added block.
        if q > 0:
            # Stable subspace iteration restricted to the block.
            Qi, _ = torch.linalg.qr(Y, mode="reduced")
            for _ in range(q):
                Qz, _ = torch.linalg.qr(xt @ Qi, mode="reduced")
                Qi, _ = torch.linalg.qr(x @ Qz, mode="reduced")
            Y = Qi

        # Orthogonalize against the current basis.
        if Q is not None and Q.shape[1] > 0:
            Y = Y - Q @ (Q.transpose(0, 1) @ Y)
        Qi, _ = torch.linalg.qr(Y, mode="reduced")

        # Second orthogonalization pass for numerical safety.
        if Q is not None and Q.shape[1] > 0:
            Qi = Qi - Q @ (Q.transpose(0, 1) @ Qi)
            Qi, _ = torch.linalg.qr(Qi, mode="reduced")

        if Qi.numel() == 0 or Qi.shape[1] == 0:
            break

        Q = Qi if Q is None else torch.cat([Q, Qi], dim=1)
        k = Q.shape[1]

        # Update the residual sketch: H <- (I - Qi Qi^T) H.
        H = H - Qi @ (Qi.transpose(0, 1) @ H)
        res = torch.linalg.norm(H, ord="fro")

    if Q is None:
        return torch.zeros((m, 0), device=self.device, dtype=torch.float32)

    # Final re-orthonormalization (cheap for small k, avoids drift).
    Q, _ = torch.linalg.qr(Q, mode="reduced")
    if Q.shape[1] > max_r:
        Q = Q[:, :max_r]
    return Q
| # ============================================================================= | |
| # HELPER FUNCTIONS | |
| # ============================================================================= | |
# Name fragments associated with normalization parameters.
# NOTE(review): presumably matched as substrings of weight keys - confirm at the call sites.
NORM_TOKENS = [
    "layernorm",
    "rmsnorm",
    "layer_norm",
    "rms_norm",
    "ln_",
    "norm",
]

# Name fragments associated with embedding / output-head parameters.
# NOTE(review): presumably matched as substrings of weight keys - confirm at the call sites.
EMBED_TOKENS = [
    "embed",
    "wte",
    "wpe",
    "lm_head",
    "output.weight",
]
def seed_all(seed: int):
    """Seed the NumPy and PyTorch random number generators.

    The PyTorch CPU generator is always seeded; all CUDA generators are
    seeded as well when CUDA is available.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def read_index_map(model_path: str) -> Dict[str, str]:
    """Return the ``tensor name -> weight file`` map for a model directory.

    Lookup order:
        1. ``model.safetensors.index.json``, then ``pytorch_model.bin.index.json``:
           the ``"weight_map"`` entry of the first index file that exists
           (``{}`` if the entry is missing or null).
        2. A single-file ``model.safetensors`` / ``pytorch_model.safetensors``:
           every tensor key in the file is mapped to that file name.

    Single-file ``.bin`` checkpoints are not inspected. An empty dict is
    returned when nothing usable is found.
    """
    for idx_file in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        path = os.path.join(model_path, idx_file)
        if os.path.exists(path):
            # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Guard against `"weight_map": null` so that callers always get a dict.
            return data.get("weight_map") or {}

    # Single-file models: only safetensors files can be enumerated here.
    for name in ("model", "pytorch_model"):
        filename = f"{name}.safetensors"
        path = os.path.join(model_path, filename)
        if os.path.exists(path):
            with safe_open(path, framework="pt") as f:
                return {k: filename for k in f.keys()}
    return {}
def scan_layers(keys: List[str]) -> Dict[str, List[int]]:
    """Collect the layer indices that occur under each known layer prefix.

    A key contributes when it contains ``<prefix>.<index>.`` with ``<prefix>``
    being ``model.layers``, ``transformer.h`` or ``decoder.layers``. The
    result maps every prefix that was seen to its sorted, de-duplicated list
    of indices.
    """
    seen: Dict[str, set] = {}
    for name in keys:
        hit = re.search(r'(model\.layers|transformer\.h|decoder\.layers)\.(\d+)\.', name)
        if hit is None:
            continue
        seen.setdefault(hit.group(1), set()).add(int(hit.group(2)))
    return {prefix: sorted(indices) for prefix, indices in seen.items()}
def split_key(key: str) -> Optional[Tuple[str, str, int, str]]:
    """Decompose a per-layer weight key.

    For a key of the form ``<prefix>.<index>.<rest>`` whose prefix is
    ``model.layers``, ``transformer.h`` or ``decoder.layers``, returns
    ``(prefix, first_component_of_rest, index, rest)``. Returns ``None`` for
    keys that do not match any of these layouts.
    """
    layouts = (
        r'^(model\.layers)\.(\d+)\.(.+)$',
        r'^(transformer\.h)\.(\d+)\.(.+)$',
        r'^(decoder\.layers)\.(\d+)\.(.+)$',
    )
    for layout in layouts:
        found = re.match(layout, key)
        if found is None:
            continue
        prefix, index_text, rest = found.groups()
        head = rest.split('.')[0]
        return prefix, head, int(index_text), rest
    return None
| # ============================================================================= | |
| # ARCHITECTURE DETECTION | |
| # ============================================================================= | |
def detect_architecture_from_config(model_path: str) -> ArchitectureInfo:
    """Populate an ``ArchitectureInfo`` from the model's transformers config.

    Returns a default ``ArchitectureInfo`` when transformers is unavailable.
    Any exception raised while loading or reading the config is turned into
    a warning; the fields assigned up to that point are kept.

    NOTE(review): the config is loaded with ``trust_remote_code=True``, which
    lets the model repository execute its own code - only use trusted paths.
    """
    info = ArchitectureInfo()
    if not HAS_TRANSFORMERS:
        return info

    try:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        # Plain size fields, each with a fallback for configs that lack the attribute.
        for field_name, fallback in (
            ("hidden_size", 4096),
            ("intermediate_size", 11008),
            ("num_hidden_layers", 32),
            ("num_attention_heads", 32),
            ("vocab_size", 32000),
        ):
            setattr(info, field_name, getattr(config, field_name, fallback))

        # Attention flavour follows from the number of key/value heads.
        kv_heads = getattr(config, "num_key_value_heads", None)
        if kv_heads is None:
            kv_heads = info.num_attention_heads
        info.num_kv_heads = kv_heads
        if kv_heads == info.num_attention_heads:
            info.attention_type = AttentionType.MHA
        elif kv_heads == 1:
            info.attention_type = AttentionType.MQA
        else:
            info.attention_type = AttentionType.GQA

        # Expert count: `num_local_experts` first, `num_experts` as the fallback.
        info.num_experts = getattr(config, "num_local_experts", 0)
        if info.num_experts == 0:
            info.num_experts = getattr(config, "num_experts", 0)
        info.num_experts_per_tok = getattr(config, "num_experts_per_tok", 2)

        # MLP flavour: MoE wins; otherwise inferred from the activation name.
        if info.num_experts > 0:
            info.mlp_type = MLPType.MOE
        elif hasattr(config, "hidden_act"):
            activation = config.hidden_act.lower() if isinstance(config.hidden_act, str) else ""
            if "silu" in activation or "swish" in activation:
                info.mlp_type = MLPType.SWIGLU
            elif "gelu" in activation:
                info.mlp_type = MLPType.GEGLU
    except Exception as e:
        warnings.warn(f"Config detection failed: {e}")

    return info
def detect_architecture_from_weights(keys: List[str]) -> ArchitectureInfo:
    """Infer architecture traits from checkpoint tensor names.

    MoE and fused-QKV detection only look at the first 500 keys (joined into
    one string); the expert count and the layer count are derived from all
    keys as ``max index + 1``.
    """
    info = ArchitectureInfo()
    sample = " ".join(keys[:500])

    if ".experts." in sample or "block_sparse_moe" in sample:
        info.mlp_type = MLPType.MOE
        expert_ids = set()
        for name in keys:
            hit = re.search(r'experts\.(\d+)', name)
            if hit:
                expert_ids.add(int(hit.group(1)))
        if expert_ids:
            info.num_experts = max(expert_ids) + 1

    if any(marker in sample for marker in ("c_attn", "query_key_value", "qkv_proj")):
        info.has_fused_qkv = True

    layer_ids = set()
    for name in keys:
        hit = re.search(r'layers\.(\d+)', name) or re.search(r'\.h\.(\d+)', name)
        if hit:
            layer_ids.add(int(hit.group(1)))
    if layer_ids:
        info.num_hidden_layers = max(layer_ids) + 1

    return info
| # ============================================================================= | |
| # DARE + TIES | |
| # ============================================================================= | |
def apply_dare(delta: torch.Tensor, drop_rate: float = 0.7,
               rescale: bool = True) -> torch.Tensor:
    """DARE (Drop And REscale) for a 2-D delta matrix.

    Each entry is kept when an independent uniform sample exceeds
    ``drop_rate``, and zeroed otherwise. When ``rescale`` is set and
    ``drop_rate < 1``, the survivors are divided by ``1 - drop_rate``.
    Non-2-D inputs and ``drop_rate <= 0`` return ``delta`` itself, untouched.
    """
    if delta.dim() != 2 or drop_rate <= 0:
        return delta

    keep_mask = torch.rand_like(delta.float()) > drop_rate
    survivors = delta * keep_mask
    if not rescale or not (drop_rate < 1.0):
        return survivors
    return survivors / (1.0 - drop_rate)
def apply_ties_single(delta: torch.Tensor, density: float = 0.3) -> torch.Tensor:
    """Apply the TIES "Trim" step to a single 2-D delta.

    Keeps the entries whose magnitude is at least the k-th largest magnitude,
    with ``k = clamp(int(numel * density), 1, numel)``, and zeroes the rest.
    Entries tied with the threshold are all kept, so slightly more than ``k``
    entries can survive.

    Non-2-D inputs, ``density >= 1`` and empty matrices are returned
    unchanged. The result has the shape and dtype of ``delta``; the
    thresholding itself is done in float32.
    """
    if delta.dim() != 2 or density >= 1.0:
        return delta

    # reshape (not view): the delta may be a non-contiguous view such as a transpose.
    flat = delta.float().reshape(-1)
    total = flat.numel()
    if total == 0:
        return delta

    magnitude = flat.abs()
    k = min(total, max(1, int(total * density)))
    # Threshold = k-th largest magnitude.
    threshold = torch.topk(magnitude, k, largest=True).values[-1]
    trimmed = torch.where(magnitude >= threshold, flat, torch.zeros_like(flat))
    return trimmed.reshape(delta.shape).to(delta.dtype)
def apply_ties(deltas: List[torch.Tensor], density: float = 0.3) -> torch.Tensor:
    """TIES-Merging: Trim, Elect Sign, Disjoint Merge.

    Steps:
        1. Trim: every delta keeps only the entries whose magnitude reaches
           its own k-th largest magnitude, ``k = clamp(int(numel * density), 1, numel)``.
        2. Elect sign: per entry, the sign of the sum of the trimmed signs
           (a zero sum elects ``+1``).
        3. Merge: per entry, average the trimmed values that agree with the
           elected sign. Trimmed (zero) entries count as agreeing, so they
           contribute 0 to the sum and 1 to the divisor.

    Deltas whose shape differs from the first delta are replaced by zeros.
    All arithmetic is carried out in float32 and only the final result is
    cast back to the first delta's dtype, so half-precision inputs neither
    lose precision nor overflow while accumulating.

    Returns:
        ``torch.zeros(1)`` for an empty list, the delta itself for a single
        element, otherwise the merged tensor (shape / dtype / device of the
        first delta).
    """
    if not deltas:
        return torch.zeros(1)
    if len(deltas) == 1:
        return deltas[0]

    device = deltas[0].device
    dtype = deltas[0].dtype
    shape = deltas[0].shape

    stacked = torch.stack(
        [
            d.float() if d.shape == shape
            else torch.zeros(shape, device=device, dtype=torch.float32)
            for d in deltas
        ],
        dim=0,
    )

    numel = stacked[0].numel()
    if numel == 0:
        return torch.zeros(shape, device=device, dtype=dtype)
    k = min(numel, max(1, int(numel * density)))

    # Trim each delta separately (one top-k at a time keeps peak memory low).
    flat_values = stacked.reshape(len(deltas), -1)
    flat_abs = flat_values.abs()
    trimmed = []
    for values, magnitude in zip(flat_values, flat_abs):
        threshold = torch.topk(magnitude, k, largest=True).values[-1]
        kept = torch.where(magnitude >= threshold, values, torch.zeros_like(values))
        trimmed.append(kept.reshape(shape))
    stacked_trimmed = torch.stack(trimmed, dim=0)

    # Elect the majority sign per entry; ties default to +1.
    elected_sign = torch.sign(torch.sign(stacked_trimmed).sum(dim=0))
    elected_sign = torch.where(elected_sign == 0, torch.ones_like(elected_sign), elected_sign)

    # Accumulate in float32 regardless of the input dtype.
    merged = torch.zeros(shape, device=device, dtype=torch.float32)
    counts = torch.zeros(shape, device=device, dtype=torch.float32)
    for t in stacked_trimmed:
        match = (torch.sign(t) == elected_sign) | (t == 0)
        merged += torch.where(match, t, torch.zeros_like(t))
        counts += match.float()

    counts = torch.clamp(counts, min=1)
    return (merged / counts).to(dtype)
def lsq_mix_weight_from_deltas(
    delta_floor: torch.Tensor,
    delta_ceil: torch.Tensor,
    fallback: float,
    eps: float = 1e-12,
) -> float:
    """Convex weight w in [0, 1] that minimizes ||(1 - w) D_floor + w D_ceil||_F.

    With d = D_ceil - D_floor, the unconstrained least-squares solution of
        argmin_w || D_floor + w d ||_F^2
    is
        w* = - <D_floor, d> / ||d||_F^2
    which is then clipped to [0, 1].

    ``fallback`` (converted to float) is returned when an input is None, the
    shapes differ, ``||d||_F^2`` is not finite or not greater than ``eps``,
    an intermediate value is not finite, or any exception occurs.
    """
    try:
        if delta_floor is None or delta_ceil is None:
            return float(fallback)
        if delta_floor.shape != delta_ceil.shape:
            return float(fallback)

        base = delta_floor.float()
        step = delta_ceil.float() - base

        step_energy = float(torch.sum(step * step).item())
        if not math.isfinite(step_energy) or not (step_energy > eps):
            return float(fallback)

        projection = float((-torch.sum(base * step)).item())
        if not math.isfinite(projection):
            return float(fallback)

        weight = projection / step_energy
        if not math.isfinite(weight):
            return float(fallback)

        # Clip to the convex range.
        weight = 0.0 if weight < 0.0 else (1.0 if weight > 1.0 else weight)
        return float(weight)
    except Exception:
        return float(fallback)
| # ============================================================================= | |
| # ADAPTIVE RANK | |
| # ============================================================================= | |
def compute_adaptive_rank(
    singular_values: torch.Tensor,
    energy_threshold: float = 0.95,
    min_rank: int = 8,
    max_rank: int = 256,
    total_energy: Optional[Union[float, torch.Tensor]] = None,
) -> int:
    """
    Pick the smallest rank whose cumulative squared singular values reach a
    target fraction of the Frobenius energy.

    Parameters
    ----------
    singular_values:
        1D tensor of (estimated) singular values, typically in descending order.
        Negative entries are clamped to zero.
    energy_threshold:
        Fraction of the reference energy that must be retained.
    min_rank / max_rank:
        Bounds on the result. ``min_rank`` is applied last, so it wins even if
        it exceeds ``max_rank`` or the number of supplied singular values.
    total_energy:
        Reference energy ``||A||_F^2`` of the original matrix (float or tensor).
        Preferred when ``singular_values`` is a truncated spectrum. When omitted,
        the sum of the supplied squared singular values is used, which treats the
        missing tail as zero and therefore overestimates the retained fraction.

    Returns
    -------
    ``max(0, min_rank)`` for an empty spectrum, ``max(1, min_rank)`` when the
    reference energy is below 1e-20, otherwise the rank described above (all
    supplied values are used if the target is never reached).
    """
    spectrum = singular_values.detach().float()
    available = spectrum.numel()
    if available == 0:
        return max(0, int(min_rank))

    # Clamp away tiny negatives caused by numerical noise before squaring.
    energies = torch.clamp(spectrum, min=0.0) ** 2

    if total_energy is None:
        reference = torch.sum(energies)
    elif torch.is_tensor(total_energy):
        reference = total_energy.to(device=spectrum.device, dtype=spectrum.dtype)
    else:
        reference = torch.tensor(float(total_energy), device=spectrum.device, dtype=spectrum.dtype)

    if reference < 1e-20:
        # Degenerate matrix: keep at least min_rank (and at least 1).
        return max(1, int(min_rank))

    reached = torch.cumsum(energies, dim=0) >= float(energy_threshold) * reference
    hit_positions = reached.nonzero(as_tuple=True)[0]
    if hit_positions.numel() > 0:
        needed = int(hit_positions[0].item() + 1)
    else:
        needed = int(available)

    return max(int(min_rank), min(needed, int(max_rank), int(available)))
| # ============================================================================= | |
| # IMPROVED DIMENSION PROJECTION (FIX #2) | |
| # ============================================================================= | |
| def svd_projection( | |
| src: torch.Tensor, | |
| target_shape: Tuple[int, ...], | |
| svd_engine: "GPUAcceleratedSVD", | |
| config: "DistillConfig", | |
| device: torch.device, | |
| ref: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """Project a tensor to `target_shape`. | |
| v3.7.6 fix (student-subspace alignment): | |
| The legacy (pool/crop/interpolate) projection loses too much information and can | |
| destabilize merging (gibberish / repetition), especially for cross-architecture | |
| distillation. | |
| If `ref` (student tensor) is provided and both `src/ref` are 2D matrices, we: | |
| 1) Compute a truncated SVD of teacher: Wt โ Ut diag(St) Vt^T | |
| 2) Compute a truncated SVD of student: Ws โ Us diag(Ss) Vs^T | |
| 3) Build teacher low-rank factors Bt, At (PiSSA-style: sqrt split) | |
| 4) Resize factors to the student's row/col sizes (Bt -> rows, At -> cols) | |
| 5) Project the resized teacher into the *student* singular subspaces: | |
| core = Us^T (Bt_rs @ At_rs) Vs | |
| Wproj = Us @ core @ Vs^T | |
| This explicitly converts teacher information into the student's low-rank subspace | |
| (the "common" space) and avoids the destructive "pool/crop" behavior. | |
| Notes: | |
| - For non-2D tensors, falls back to safe interpolation/crop. | |
| - For same-shape matrices, returns `src` unchanged. | |
| """ | |
| if tuple(src.shape) == tuple(target_shape): | |
| return src | |
| # -------- 1D vectors: linear interpolation -------- | |
| if src.dim() == 1 and len(target_shape) == 1: | |
| out_len = int(target_shape[0]) | |
| x = torch.nan_to_num(src.float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| if x.numel() == out_len: | |
| return x.to(dtype=src.dtype, device=src.device) | |
| x3 = x.view(1, 1, -1) | |
| y3 = F.interpolate(x3, size=out_len, mode="linear", align_corners=False) | |
| return y3.view(out_len).to(dtype=src.dtype, device=src.device) | |
| # -------- 2D matrices: student-subspace alignment (preferred) -------- | |
| if src.dim() == 2 and len(target_shape) == 2: | |
| tm, tn = int(target_shape[0]), int(target_shape[1]) | |
| m, n = int(src.shape[0]), int(src.shape[1]) | |
| # Bound the working rank | |
| k_cap = int(getattr(config, 'projection_rank', 256)) | |
| k_max = min(m, n, tm, tn, k_cap) | |
| if k_max <= 0: | |
| return torch.zeros((tm, tn), device=src.device, dtype=src.dtype) | |
| # Resize helpers | |
| def _resize_rows(mat: torch.Tensor, new_rows: int) -> torch.Tensor: | |
| if mat.shape[0] == new_rows: | |
| return mat | |
| # (rows, k) -> (1, k, rows) -> interpolate -> (new_rows, k) | |
| x = mat.transpose(0, 1).unsqueeze(0) | |
| y = F.interpolate(x, size=new_rows, mode="linear", align_corners=False) | |
| return y.squeeze(0).transpose(0, 1) | |
| def _resize_cols(mat: torch.Tensor, new_cols: int) -> torch.Tensor: | |
| if mat.shape[1] == new_cols: | |
| return mat | |
| # (k, cols) -> (1, k, cols) -> interpolate -> (k, new_cols) | |
| x = mat.unsqueeze(0) | |
| y = F.interpolate(x, size=new_cols, mode="linear", align_corners=False) | |
| return y.squeeze(0) | |
| # Fast fallback: if no ref matrix is provided | |
| if ref is None or (not isinstance(ref, torch.Tensor)) or ref.dim() != 2: | |
| # Teacher-only SVD factor resize (legacy), then reconstruct | |
| src_f = torch.nan_to_num(src.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| try: | |
| Ut, St, Vht = svd_engine.randomized_svd( | |
| src_f, | |
| k_max, | |
| n_oversamples=getattr(config, 'projection_randomized_oversamples', 8), | |
| n_iter=getattr(config, 'projection_randomized_iter', 2), | |
| ) | |
| if getattr(config, 'projection_adaptive_rank', True): | |
| k_t = compute_adaptive_rank( | |
| St, | |
| energy_threshold=getattr(config, 'projection_energy_threshold', 0.99), | |
| min_rank=getattr(config, 'projection_min_rank', 8), | |
| max_rank=k_max, | |
| ) | |
| k = max(1, min(k_t, k_max)) | |
| Ut, St, Vht = Ut[:, :k], St[:k], Vht[:k, :] | |
| else: | |
| k = k_max | |
| sqrtS = torch.sqrt(torch.clamp(St, min=0.0)).to(dtype=Ut.dtype) | |
| Bt = Ut * sqrtS.unsqueeze(0) # (m, k) | |
| At = sqrtS.unsqueeze(1) * Vht # (k, n) | |
| Bt_rs = _resize_rows(Bt, tm) | |
| At_rs = _resize_cols(At, tn) | |
| out = Bt_rs @ At_rs | |
| return out.to(dtype=src.dtype, device=src.device) | |
| except Exception: | |
| # Last resort: bilinear interpolation of the dense matrix | |
| dense = torch.nan_to_num(src.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| out = F.interpolate(dense.unsqueeze(0).unsqueeze(0), size=(tm, tn), mode="bilinear", align_corners=False) | |
| return out.squeeze(0).squeeze(0).to(dtype=src.dtype, device=src.device) | |
| # Ensure ref has the expected target shape | |
| ref = ref.to(device) | |
| if tuple(ref.shape) != (tm, tn): | |
| # We do not attempt to project ref; fallback to teacher-only projection | |
| return svd_projection(src, (tm, tn), svd_engine, config, device, ref=None) | |
| # Compute truncated SVDs (teacher + student) | |
| src_f = torch.nan_to_num(src.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| ref_f = torch.nan_to_num(ref.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| try: | |
| Ut, St, Vht = svd_engine.randomized_svd( | |
| src_f, | |
| k_max, | |
| n_oversamples=getattr(config, 'projection_randomized_oversamples', 8), | |
| n_iter=getattr(config, 'projection_randomized_iter', 2), | |
| ) | |
| Us, Ss, Vhs = svd_engine.randomized_svd( | |
| ref_f, | |
| k_max, | |
| n_oversamples=getattr(config, 'projection_randomized_oversamples', 8), | |
| n_iter=getattr(config, 'projection_randomized_iter', 2), | |
| ) | |
| if getattr(config, 'projection_adaptive_rank', True): | |
| k_t = compute_adaptive_rank( | |
| St, | |
| energy_threshold=getattr(config, 'projection_energy_threshold', 0.99), | |
| min_rank=getattr(config, 'projection_min_rank', 8), | |
| max_rank=k_max, | |
| ) | |
| k_s = compute_adaptive_rank( | |
| Ss, | |
| energy_threshold=getattr(config, 'projection_energy_threshold', 0.99), | |
| min_rank=getattr(config, 'projection_min_rank', 8), | |
| max_rank=k_max, | |
| ) | |
| k = max(1, min(k_t, k_s, k_max)) | |
| else: | |
| k = k_max | |
| Ut, St, Vht = Ut[:, :k], St[:k], Vht[:k, :] | |
| Us, Ss, Vhs = Us[:, :k], Ss[:k], Vhs[:k, :] | |
| # Teacher low-rank factors (PiSSA sqrt split) | |
| sqrtSt = torch.sqrt(torch.clamp(St, min=0.0)).to(dtype=Ut.dtype) | |
| Bt = Ut * sqrtSt.unsqueeze(0) # (m, k) | |
| At = sqrtSt.unsqueeze(1) * Vht # (k, n) | |
| # Resize teacher factors to student dims | |
| Bt_rs = _resize_rows(Bt, tm) # (tm, k) | |
| At_rs = _resize_cols(At, tn) # (k, tn) | |
| # Student right singular vectors | |
| Vs = Vhs.transpose(0, 1).contiguous() # (tn, k) | |
| # core = Us^T (Bt_rs @ At_rs) Vs (kรk) | |
| # Compute as (Us^T Bt_rs) @ (At_rs Vs) for efficiency. | |
| left = Us.transpose(0, 1) @ Bt_rs # (k, k) | |
| right = At_rs @ Vs # (k, k) | |
| core = left @ right # (k, k) | |
| out = Us @ core @ Vs.transpose(0, 1) # (tm, tn) | |
| return out.to(dtype=src.dtype, device=src.device) | |
| except Exception: | |
| dense = torch.nan_to_num(src.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0) | |
| out = F.interpolate(dense.unsqueeze(0).unsqueeze(0), size=(tm, tn), mode="bilinear", align_corners=False) | |
| return out.squeeze(0).squeeze(0).to(dtype=src.dtype, device=src.device) | |
| # -------- Generic fallback for other ranks/shapes -------- | |
| # Fallback: flatten + interpolate length-wise, then reshape/crop. | |
| tgt_numel = int(torch.tensor(target_shape).prod().item()) | |
| x = torch.nan_to_num(src.float().view(1, 1, -1), nan=0.0, posinf=0.0, neginf=0.0) | |
| y = F.interpolate(x, size=tgt_numel, mode="linear", align_corners=False) | |
| out = y.view(*target_shape) | |
| return out.to(dtype=src.dtype, device=src.device) | |
| # ============================================================================= | |
| # CROSS-ARCH SUBSPACE-DELTA (FIX: preserve student residual space) | |
| # ============================================================================= | |
def _resize_2d_rows_linear(mat: torch.Tensor, target_rows: int) -> torch.Tensor:
    """Linearly resample a 2D matrix along its row axis, keeping the column count.

    Inputs that are not 2D, or that already have ``target_rows`` rows, are
    returned unchanged (the same object). A non-positive ``target_rows``
    yields an empty slice with zero rows.
    """
    if mat.dim() != 2 or mat.shape[0] == target_rows:
        return mat
    if target_rows <= 0:
        return mat[:0]
    # F.interpolate resamples the last axis of a (batch, channels, length)
    # tensor, so expose the columns as channels and the rows as the length.
    as_signal = mat.t()[None]  # (1, cols, rows)
    resampled = F.interpolate(as_signal, size=target_rows, mode="linear", align_corners=False)
    return resampled[0].t().contiguous()
def _resize_2d_cols_linear(mat: torch.Tensor, target_cols: int) -> torch.Tensor:
    """Linearly resample a 2D matrix along its column axis, keeping the row count.

    Inputs that are not 2D, or that already have ``target_cols`` columns, are
    returned unchanged (the same object). A non-positive ``target_cols``
    yields an empty slice with zero columns.
    """
    if mat.dim() != 2 or mat.shape[1] == target_cols:
        return mat
    if target_cols <= 0:
        return mat[:, :0]
    # With a leading batch axis the rows act as channels and the columns as
    # the length axis that F.interpolate resamples.
    resampled = F.interpolate(mat[None], size=target_cols, mode="linear", align_corners=False)
    return resampled[0].contiguous()
def compute_student_svd_basis(
    student_mat: torch.Tensor,
    svd_engine: "GPUAcceleratedSVD",
    config: "DistillConfig",
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the truncated SVD basis of the student matrix that serves as the shared subspace.

    The working rank is capped by ``config.projection_rank`` (default 256) and by
    both matrix dimensions. When ``config.projection_adaptive_rank`` is truthy
    (the default) the rank is further chosen by ``compute_adaptive_rank``.

    Returns:
        (U, S, V) with U: (out, k), S: (k,), V: (in, k) where V = Vh^T.

    Raises:
        ValueError: if ``student_mat`` is not 2D or the rank cap is not positive.
    """
    if student_mat.dim() != 2:
        raise ValueError("student_mat must be 2D")

    rows, cols = (int(d) for d in student_mat.shape)
    rank_cap = min(int(getattr(config, "projection_rank", 256)), rows, cols)
    if rank_cap <= 0:
        raise ValueError("projection_rank too small for matrix")

    # Replace non-finite entries with zeros before factorising in float32.
    cleaned = torch.nan_to_num(student_mat.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0)
    U, S, Vh = svd_engine.randomized_svd(
        cleaned,
        rank_cap,
        n_oversamples=int(getattr(config, "projection_randomized_oversamples", 8)),
        n_iter=int(getattr(config, "projection_randomized_iter", 2)),
    )

    if getattr(config, "projection_adaptive_rank", True):
        keep = compute_adaptive_rank(
            S,
            energy_threshold=float(getattr(config, "projection_energy_threshold", 0.99)),
            min_rank=int(getattr(config, "projection_min_rank", 8)),
            max_rank=rank_cap,
        )
    else:
        keep = rank_cap
    # Never keep more components than were computed, and never fewer than one.
    keep = max(1, min(int(keep), int(S.numel()), rank_cap))

    left = U[:, :keep].contiguous()
    sigma = S[:keep].contiguous()
    right = Vh[:keep, :].contiguous().transpose(0, 1).contiguous()
    return left, sigma, right
def teacher_core_in_student_basis(
    teacher_mat: torch.Tensor,
    U_s: torch.Tensor,
    V_s: torch.Tensor,
    svd_engine: "GPUAcceleratedSVD",
    config: "DistillConfig",
    device: torch.device,
) -> torch.Tensor:
    """
    Express the teacher matrix in the student's SVD basis.

    The teacher is factorised with a truncated SVD, split PiSSA-style
    (each factor carries sqrt(S)), and the two factors are resized by linear
    interpolation to the student's row/column counts. The result is
        core_t = U_s^T * (B_resized @ A_resized) * V_s
    evaluated as (U_s^T B_resized) @ (A_resized V_s) so the full matrix is
    never formed.

    Returns:
        core_t: (k, k) float32 tensor, where k = U_s.shape[1]. All zeros when
        the usable rank is not positive.

    Raises:
        ValueError: if ``teacher_mat``, ``U_s`` or ``V_s`` is not 2D.
    """
    if teacher_mat.dim() != 2:
        raise ValueError("teacher_mat must be 2D")
    if not (U_s.dim() == 2 and V_s.dim() == 2):
        raise ValueError("U_s and V_s must be 2D")

    student_rows = int(U_s.shape[0])
    student_cols = int(V_s.shape[0])
    basis_rank = int(U_s.shape[1])

    teacher = torch.nan_to_num(teacher_mat.to(device).float(), nan=0.0, posinf=0.0, neginf=0.0)
    rank_cap = min(basis_rank, int(teacher.shape[0]), int(teacher.shape[1]))
    if rank_cap <= 0:
        return torch.zeros((basis_rank, basis_rank), device=device, dtype=torch.float32)

    U_t, S_t, Vh_t = svd_engine.randomized_svd(
        teacher,
        rank_cap,
        n_oversamples=int(getattr(config, "projection_randomized_oversamples", 8)),
        n_iter=int(getattr(config, "projection_randomized_iter", 2)),
    )

    if getattr(config, "projection_adaptive_rank", True):
        keep = compute_adaptive_rank(
            S_t,
            energy_threshold=float(getattr(config, "projection_energy_threshold", 0.99)),
            min_rank=int(getattr(config, "projection_min_rank", 8)),
            max_rank=rank_cap,
        )
    else:
        keep = rank_cap
    keep = max(1, min(int(keep), int(S_t.numel()), rank_cap))

    # sqrt-split of the singular values; the clamp floors them at 1e-10.
    root = torch.sqrt(torch.clamp(S_t[:keep].float(), min=1e-10)).to(dtype=U_t.dtype)
    left_factor = (U_t[:, :keep] * root.unsqueeze(0)).contiguous()    # (m_t, r)
    right_factor = (root.unsqueeze(1) * Vh_t[:keep, :]).contiguous()  # (r, n_t)

    left_factor = _resize_2d_rows_linear(left_factor, student_rows)    # (tm, r)
    right_factor = _resize_2d_cols_linear(right_factor, student_cols)  # (r, tn)

    core = (U_s.transpose(0, 1) @ left_factor) @ (right_factor @ V_s)  # (k, r) @ (r, k)
    return core.to(dtype=torch.float32)
def extract_lora_from_core_delta(
    delta_core: torch.Tensor,
    U_s: torch.Tensor,
    V_s: torch.Tensor,
    rank_hint: int,
    config: "DistillConfig",
    out_dtype: torch.dtype,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int]:
    """
    Turn a low-rank subspace delta into LoRA factors without forming the full matrix.

    The full delta is ``U_s @ delta_core @ V_s^T``. The small (k x k) core is
    decomposed with an exact SVD and the leading components are lifted through
    ``U_s`` / ``V_s`` into (out x r) and (r x in) factors, each carrying sqrt(S).

    Returns:
        (lora_A, lora_B, rank) with lora_A: (r, in) and lora_B: (out, r) in
        ``out_dtype``, or (None, None, 0) when the core is not 2D, is empty,
        is numerically zero, or its SVD fails.
    """
    nothing = (None, None, 0)

    if delta_core.dim() != 2:
        return nothing
    core_dim = int(delta_core.shape[0])
    if core_dim <= 0:
        return nothing

    core = torch.nan_to_num(delta_core.float(), nan=0.0, posinf=0.0, neginf=0.0)
    if torch.linalg.vector_norm(core) < 1e-8:
        return nothing

    # The core is small (k <= projection_rank), so an exact SVD is affordable.
    try:
        Uc, Sc, Vhc = torch.linalg.svd(core, full_matrices=False)
    except Exception:
        return nothing
    if Sc.numel() == 0 or torch.sum(Sc ** 2) < 1e-12:
        return nothing

    cfg_min = int(getattr(config, "min_rank", 8))
    cfg_max = int(getattr(config, "max_rank", 256))
    if getattr(config, "use_adaptive_rank", True):
        rank = compute_adaptive_rank(
            Sc,
            energy_threshold=float(getattr(config, "energy_threshold", 0.95)),
            min_rank=cfg_min,
            max_rank=min(cfg_max, core_dim),
        )
    else:
        rank = int(rank_hint)
    # Clamp to the configured bounds, then to what the core can actually provide.
    rank = max(cfg_min, min(rank, cfg_max, core_dim))
    rank = max(1, min(rank, core_dim, int(Sc.numel())))

    root = torch.sqrt(torch.clamp(Sc[:rank], min=1e-10))
    # Lift the core's singular vectors back to the full output/input spaces.
    lifted_left = (U_s.float() @ Uc[:, :rank]).contiguous()                       # (out, r)
    lifted_right = (Vhc[:rank, :] @ V_s.float().transpose(0, 1)).contiguous()    # (r, in)

    lora_B = (lifted_left * root.unsqueeze(0)).to(dtype=out_dtype).contiguous()
    lora_A = (root.unsqueeze(1) * lifted_right).to(dtype=out_dtype).contiguous()
    return lora_A, lora_B, rank
def _pissa_split_factors(
    U: torch.Tensor,
    S: torch.Tensor,
    Vh: torch.Tensor,
    rank: int,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Build (lora_A, lora_B, r) from the leading ``rank`` SVD triplets using the sqrt split."""
    # Slice all three factors to the same width so an engine that returns more
    # (or fewer) components than requested can never cause a silent broadcast.
    r = max(0, min(int(rank), int(S.numel()), int(U.shape[1]), int(Vh.shape[0])))
    sqrt_s = torch.sqrt(torch.clamp(S[:r], min=0.0)).to(dtype=U.dtype)
    lora_B = U[:, :r] * sqrt_s.unsqueeze(0)    # (out, r)
    lora_A = sqrt_s.unsqueeze(1) * Vh[:r, :]   # (r, in)
    return lora_A.float(), lora_B.float(), r


def extract_lora_pissa(
    delta: torch.Tensor,
    rank: int,
    svd_engine: "GPUAcceleratedSVD",
    config: "DistillConfig",
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int]:
    """
    Extract LoRA matrices using PiSSA (sqrt-split SVD factors).

    Rank selection:
      1) With ``config.use_adaptive_rank`` disabled, a single SVD is computed at
         ``rank`` clamped to [min_rank, max_rank].
      2) Otherwise the retained energy is measured against the *true* Frobenius
         energy ||Δ||_F^2 (not just the energy of the truncated spectrum).
         If ``config.svd_mode == "adaptive"`` the engine's ``svd_energy`` is
         tried first; on failure, or for any other mode, the SVD rank is grown
         geometrically from ``rank`` until the energy target is met or
         ``max_rank`` is reached. This avoids always computing the SVD at
         ``max_rank``.

    Args:
        delta: 2D weight delta (out_features, in_features).
        rank: Requested rank (fixed mode) or starting rank (adaptive mode).
        svd_engine: Object providing ``svd(matrix, k)`` and, for the adaptive
            fast path, ``svd_energy(...)``.
        config: Read via ``getattr`` for ``max_rank``, ``min_rank``,
            ``use_adaptive_rank``, ``energy_threshold``, ``svd_mode``,
            ``svd_adaptive_block_size``, ``svd_adaptive_n_test``,
            ``svd_randomized_iter`` and ``verbose``.

    Returns:
        (lora_A, lora_B, rank) as float32 tensors with lora_A: (r, in) and
        lora_B: (out, r); the returned rank always equals the factor width.
        (None, None, 0) when ``delta`` is not 2D, is numerically zero, or the
        SVD fails / yields no usable singular values.
    """
    if delta.dim() != 2:
        return None, None, 0
    out_features, in_features = delta.shape
    min_dim = min(out_features, in_features)

    # Sanitize numerical pathologies early.
    delta = torch.nan_to_num(delta, nan=0.0, posinf=1e6, neginf=-1e6)

    # Frobenius norm / energy (used for exact energy accounting in adaptive rank).
    delta_norm = torch.norm(delta.float())
    if delta_norm < 1e-10:
        return None, None, 0
    total_energy = delta_norm.float() ** 2  # ||Δ||_F^2

    # Respect hard bounds.
    max_rank = int(min(getattr(config, "max_rank", 256), min_dim))
    min_rank = int(max(1, min(getattr(config, "min_rank", 8), max_rank)))

    # Fixed-rank path: one SVD at the requested (bounded) rank.
    if not bool(getattr(config, "use_adaptive_rank", True)):
        actual_rank = int(max(min_rank, min(int(rank), max_rank)))
        try:
            U, S, Vh = svd_engine.svd(delta, actual_rank)
        except Exception as e:
            print(f"      SVD failed completely, skipping: {type(e).__name__}")
            return None, None, 0
        if S.numel() == 0 or torch.all(S < 1e-10):
            return None, None, 0
        return _pissa_split_factors(U, S, Vh, actual_rank)

    # Adaptive-rank path: grow the SVD rank until the retained energy meets the target.
    energy_threshold = float(getattr(config, "energy_threshold", 0.95))
    energy_threshold = float(max(0.0, min(1.0, energy_threshold)))
    target_energy = total_energy * energy_threshold

    # Fast-path: adaptive range finder (Halko et al., Alg. 4.2) + one projected SVD.
    if str(getattr(config, "svd_mode", "auto")).lower() == "adaptive":
        try:
            U, S, Vh = svd_engine.svd_energy(
                delta,
                energy_threshold=energy_threshold,
                min_rank=min_rank,
                max_rank=max_rank,
                block_size=int(getattr(config, "svd_adaptive_block_size", 32)),
                n_test=int(getattr(config, "svd_adaptive_n_test", 8)),
                n_iter=int(getattr(config, "svd_randomized_iter", 0)),
            )
        except Exception as e:
            # Deliberate best-effort: fall through to the geometric search below.
            if bool(getattr(config, "verbose", False)):
                print(f"      Adaptive SVD failed, falling back to geometric SVD: {type(e).__name__}")
        else:
            if S.numel() == 0 or torch.all(S < 1e-10):
                return None, None, 0
            actual_rank = compute_adaptive_rank(
                S,
                energy_threshold=energy_threshold,
                min_rank=min_rank,
                max_rank=max_rank,
                total_energy=total_energy,
            )
            return _pissa_split_factors(U, S, Vh, actual_rank)

    # Start from the user-provided rank (bounded), then grow geometrically as needed.
    r_try = int(max(min_rank, min(int(rank), max_rank)))
    r_try = max(1, min(r_try, max_rank))

    while True:
        try:
            U, S, Vh = svd_engine.svd(delta, r_try)
        except Exception as e:
            print(f"      SVD failed at rank {r_try}, skipping: {type(e).__name__}")
            return None, None, 0
        if S.numel() == 0 or torch.all(S < 1e-10):
            return None, None, 0

        # Retained energy for each prefix of the spectrum we actually computed.
        spectrum = torch.clamp(S.detach().float(), min=0.0)
        cum_energy = torch.cumsum(spectrum ** 2, dim=0)
        hit = cum_energy >= target_energy
        if bool(hit.any()):
            r_needed = int(hit.nonzero(as_tuple=True)[0][0].item() + 1)
            actual_rank = int(max(min_rank, min(r_needed, max_rank)))
            break

        # Not enough energy captured by the top-r_try singular values.
        if r_try >= max_rank:
            actual_rank = int(max_rank)
            break

        # Grow rank geometrically (fast convergence, bounded by max_rank).
        r_next = int(max(r_try + 1, math.ceil(r_try * 2.0)))
        r_try = int(min(max_rank, r_next))

    # actual_rank may be smaller than the last computed r_try; the helper slices.
    return _pissa_split_factors(U, S, Vh, actual_rank)
| # ============================================================================= | |
| # CALIBRATION-AWARE (DATA-AWARE) SVD (OPTIONAL) | |
| # ============================================================================= | |
| # | |
| # Goal: | |
| # Collect activation statistics on the STUDENT model (calibration prompts) and apply | |
| # data-aware whitening before SVD so the truncated low-rank factors better match the | |
| # target input distribution. | |
| # | |
| # Modes: | |
| # - rms: diagonal whitening with per-channel activation RMS (input/output). | |
| # - cov: SVD-LLM style full-covariance whitening on the input side using a Cholesky | |
| # factor S of E[xx^T]. PiSSA is applied on (Δ·S) and then unwhitened by S^{-1}. | |
| # | |
| # Notes: | |
| # - Optional; default disabled ("none"). | |
| # - Calibration only loads the STUDENT model; teacher weights are not needed. | |
| # - Stats are stored in safetensors (see save/load helpers below). | |
class CalibStatsEntry:
    """Calibration stats for a module or a deterministic shared-input group."""
    # All fields default to None, meaning that statistic is not available.
    # NOTE(review): no decorator is visible on this class here, so as written
    # these are class-level defaults rather than constructor arguments — confirm.

    # Diagonal (ASVD-style) stats
    in_rms: Optional[torch.Tensor] = None   # (in_features,)
    out_rms: Optional[torch.Tensor] = None  # (out_features,)
    # Full-covariance (SVD-LLM) stats
    in_chol: Optional[torch.Tensor] = None  # (in_features, in_features) lower-triangular
def _dtype_from_str(dtype_str: str, device: torch.device) -> torch.dtype:
    """Map a dtype name ("bf16" / "fp16" / "fp32", case-insensitive) to a torch dtype.

    Any other value (including an empty one) is treated as "auto": float32 off
    CUDA, and on CUDA bfloat16 when the bf16 support probe succeeds, float16
    otherwise.
    """
    named = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    requested = (dtype_str or "auto").lower()
    if requested in named:
        return named[requested]

    # auto
    if device.type != "cuda":
        return torch.float32
    try:
        bf16_ok = torch.cuda.is_bf16_supported()
    except Exception:
        # A failing probe is treated the same as "not supported".
        bf16_ok = False
    return torch.bfloat16 if bf16_ok else torch.float16
def _infer_calib_device(calib_device: str) -> torch.device:
    """Resolve the calibration device from a "cuda" / "cpu" / "auto" request.

    Only "cpu" (case-insensitive) forces the CPU. Every other value, including
    "cuda", "auto" and an empty request, selects ``cuda:0`` when CUDA is
    available and the CPU otherwise.
    """
    choice = (calib_device or "auto").lower()
    if choice == "cpu":
        return torch.device("cpu")
    # "cuda" and auto share the same availability-based rule.
    gpu_ok = torch.cuda.is_available()
    return torch.device("cuda:0") if gpu_ok else torch.device("cpu")
def _flatten_text_field(x: Any) -> str:
    """Best-effort conversion of a str / list / tuple / other value into one string.

    ``None`` becomes "", a string is returned untouched, and a list or tuple is
    joined with newlines after stripping each element and dropping ``None`` and
    blank entries. Anything else goes through ``str()``, with "" on failure.
    """
    if x is None:
        return ""
    if isinstance(x, str):
        return x
    if isinstance(x, (list, tuple)):
        stripped = (
            v.strip() if isinstance(v, str) else str(v).strip()
            for v in x
            if v is not None
        )
        return "\n".join(piece for piece in stripped if piece)
    try:
        return str(x)
    except Exception:
        return ""
def _looks_like_alpaca_record(obj: Any) -> bool:
    """Return True when ``obj`` is a dict carrying Alpaca-style keys.

    Two key patterns qualify: the classic one ("instruction" together with
    "input" or "output"), and the system variant ("system_prompt" or "system"
    together with "instruction" or "input").
    """
    if not isinstance(obj, dict):
        return False
    has_instruction = "instruction" in obj
    has_input = "input" in obj
    classic = has_instruction and (has_input or "output" in obj)
    with_system = ("system_prompt" in obj or "system" in obj) and (has_instruction or has_input)
    return classic or with_system
def _format_alpaca_prompt(obj: Dict[str, Any],
                          template: str = "classic",
                          include_output: bool = False) -> str:
    """
    Render an Alpaca-style record as a calibration prompt.

    Fields are looked up best-effort; within each group the first key that is
    present in the record is used:
      - system_prompt / system / system_message (optional)
      - instruction / prompt / query
      - input / context (optional)
      - output / response (appended only when include_output=True)

    ``template="plain"`` joins the non-empty parts with blank lines; any other
    value produces the classic "### Instruction / ### Input / ### Response" layout.
    """
    def _pick(*keys: str) -> str:
        # First present key wins, even if its value is empty.
        for key in keys:
            if key in obj:
                return _flatten_text_field(obj[key]).strip()
        return _flatten_text_field("").strip()

    system_text = _pick("system_prompt", "system", "system_message")
    instruction = _pick("instruction", "prompt", "query")
    extra_input = _pick("input", "context")
    answer = _pick("output", "response")
    wants_answer = bool(include_output and answer)

    layout = (template or "classic").lower().strip()
    if layout == "plain":
        candidates = (system_text, instruction, extra_input, answer if wants_answer else "")
        return "\n\n".join(part for part in candidates if part).strip()

    # classic
    prefix = (system_text + "\n\n") if system_text else ""
    prompt = (
        f"{prefix}### Instruction:\n{instruction}\n\n### Input:\n{extra_input}\n\n### Response:\n"
        if extra_input
        else f"{prefix}### Instruction:\n{instruction}\n\n### Response:\n"
    )
    return prompt + answer if wants_answer else prompt
def _stream_json_array_first_n(path: str, n: int, chunk_size: int = 65536) -> Optional[List[Any]]:
    """
    Streaming parse for a top-level JSON array file: return the first ``n`` items
    without loading the whole file.

    The file is read ``chunk_size`` characters at a time. Item separators and
    scalar values may straddle a chunk boundary: a leading comma is consumed
    before every decode attempt, and a value that ends exactly at the end of
    the buffer is only accepted once more input (or end of file) confirms it
    is complete.

    Args:
        path: UTF-8 text file to read.
        n: Maximum number of items to return; ``n <= 0`` yields an empty list.
        chunk_size: Characters per read (values below 1 are treated as 1).

    Returns:
        The decoded items (possibly fewer than ``n`` if the array is shorter,
        truncated or malformed), ``[]`` for an empty or whitespace-only file,
        or ``None`` if the file does not start with ``[`` or cannot be read.
    """
    try:
        step = max(1, int(chunk_size))
        decoder = json.JSONDecoder()
        items: List[Any] = []
        with open(path, "r", encoding="utf-8") as f:
            # Skip leading whitespace until the first significant character.
            buf = ""
            while not buf:
                chunk = f.read(step)
                if not chunk:
                    return []
                buf = chunk.lstrip()
            if not buf.startswith("["):
                # Not an array
                return None
            buf = buf[1:]

            while len(items) < n:
                buf = buf.lstrip()
                # The separator may arrive in a later chunk than the item it
                # follows, so strip it here rather than right after decoding.
                if buf.startswith(","):
                    buf = buf[1:].lstrip()
                if not buf:
                    chunk = f.read(step)
                    if not chunk:
                        return items
                    buf = chunk
                    continue
                if buf.startswith("]"):
                    break

                try:
                    obj, end = decoder.raw_decode(buf)
                    decoded = True
                except json.JSONDecodeError:
                    obj, end, decoded = None, 0, False

                # Either the value is incomplete, or it ran to the very end of
                # the buffer (a number could continue in the next chunk).
                if not decoded or end >= len(buf):
                    chunk = f.read(step)
                    if chunk:
                        buf += chunk
                        continue
                    if not decoded:
                        # End of file inside a malformed / truncated value.
                        return items

                items.append(obj)
                buf = buf[end:]
        return items
    except Exception:
        # Best-effort helper: callers treat None as "could not stream this file".
        return None
def _load_calib_texts(path: Optional[str],
                      max_samples: int,
                      fmt: str = "auto",
                      alpaca_template: str = "classic",
                      alpaca_include_output: bool = False) -> List[str]:
    """
    Load calibration texts.

    Supported formats:
      - txt: 1 prompt per line
      - jsonl: 1 JSON per line; supports {"text": "..."} or Alpaca records
      - json: JSON array/object; supports "text" or Alpaca records (streamed for arrays)
      - alpaca: force Alpaca parsing/formatting for .json/.jsonl
      - auto: pick jsonl / json / txt from the file extension

    Args:
        path: File to read. When empty, a small built-in prompt list is used.
        max_samples: Maximum number of texts to return; a value <= 0 means
            "no limit" in every branch.
        fmt: One of the formats above (case-insensitive).
        alpaca_template: Template name forwarded to ``_format_alpaca_prompt``.
        alpaca_include_output: Whether Alpaca outputs are appended to prompts.

    Returns:
        A list of prompt strings. For json/jsonl input that yields no text, or
        when reading fails, the built-in prompts are returned instead (with a
        printed notice) so calibration never hard-fails.
    """
    limit: Optional[int] = None
    if max_samples is not None and int(max_samples) > 0:
        limit = int(max_samples)

    def _capped(items: List[str]) -> List[str]:
        return list(items) if limit is None else list(items[:limit])

    def _is_full(items: List[str]) -> bool:
        return limit is not None and len(items) >= limit

    # Minimal built-in prompts (language-mixed) to keep the feature plug-and-play.
    builtin_prompts = [
        "你好，請用一句話介紹你自己。",
        "請解釋什麼是知識蒸餾（Knowledge Distillation）。",
        "請把下面句子翻譯成英文：人工智慧正在改變世界。",
        "Write a short paragraph about neural networks.",
        "Summarize the following in one sentence: Large language models learn patterns from data.",
        "List three key points about transformer attention.",
    ]

    def _fallback(reason: str) -> List[str]:
        # Still best-effort, but tell the user their file was not used.
        print(f"  [calib] {reason}; using built-in calibration prompts instead.")
        return _capped(builtin_prompts)

    if not path:
        return _capped(builtin_prompts)

    fmt = (fmt or "auto").lower().strip()
    p_lower = str(path).lower()
    if fmt == "auto":
        if p_lower.endswith(".jsonl"):
            fmt = "jsonl"
        elif p_lower.endswith(".json"):
            fmt = "json"
        else:
            fmt = "txt"

    # For alpaca, we accept both .json and .jsonl and decide the container by extension.
    force_alpaca = (fmt == "alpaca")
    if force_alpaca:
        fmt_under = "jsonl" if p_lower.endswith(".jsonl") else "json"
    else:
        fmt_under = fmt

    def _extract_text(obj: Any) -> Optional[str]:
        if obj is None:
            return None
        if isinstance(obj, str):
            s = obj.strip()
            return s if s else None
        if isinstance(obj, dict):
            if force_alpaca or _looks_like_alpaca_record(obj):
                t = _format_alpaca_prompt(obj, template=alpaca_template, include_output=alpaca_include_output)
                t = (t or "").strip()
                return t if t else None
            # json/jsonl generic: prefer "text"
            t = obj.get("text", None)
            if isinstance(t, str) and t.strip():
                return t.strip()
            # fallback: first string-ish value
            for v in obj.values():
                if isinstance(v, str) and v.strip():
                    return v.strip()
            return None
        if isinstance(obj, (list, tuple)):
            # list[str] -> join
            s = _flatten_text_field(obj).strip()
            return s if s else None
        # last resort
        try:
            s = str(obj).strip()
            return s if s else None
        except Exception:
            return None

    texts: List[str] = []
    try:
        if fmt_under == "txt":
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    texts.append(line)
                    if _is_full(texts):
                        break
            return _capped(texts)

        if fmt_under == "jsonl":
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except Exception:
                        # Skip malformed lines rather than abandoning the file.
                        continue
                    t = _extract_text(obj)
                    if t:
                        texts.append(t)
                        if _is_full(texts):
                            break
            if texts:
                return _capped(texts)
            return _fallback(f"no usable text found in {path!r}")

        if fmt_under == "json":
            # First try jsonl-style streaming (many datasets are .json but actually JSONL).
            sniffed: List[str] = []
            try:
                with open(path, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            obj = json.loads(line)
                        except Exception:
                            # not jsonl
                            sniffed = []
                            break
                        if isinstance(obj, list):
                            # A line that is itself a JSON array is a (minified)
                            # array file, not one JSONL record.
                            sniffed = []
                            break
                        t = _extract_text(obj)
                        if t:
                            sniffed.append(t)
                            if _is_full(sniffed):
                                break
            except Exception:
                # Discard partial results so later passes cannot duplicate them.
                sniffed = []
            if sniffed:
                return _capped(sniffed)

            # Streaming parse of a JSON array for the first N items. With no
            # limit there is nothing to gain from streaming, so load fully below.
            if limit is not None:
                recs = _stream_json_array_first_n(path, limit)
                if recs is not None:
                    for obj in recs:
                        t = _extract_text(obj)
                        if t:
                            texts.append(t)
                            if _is_full(texts):
                                break
                    if texts:
                        return _capped(texts)

            # Fallback: full json.load (may be heavy for huge files; use jsonl if possible)
            with open(path, "r", encoding="utf-8") as f:
                obj = json.load(f)

            # Unwrap common containers
            records: List[Any] = []
            if isinstance(obj, list):
                records = obj
            elif isinstance(obj, dict):
                for k in ("data", "train", "records", "samples", "items"):
                    v = obj.get(k, None)
                    if isinstance(v, list):
                        records = v
                        break
                if not records:
                    records = [obj]
            else:
                records = [obj]

            for r in records:
                t = _extract_text(r)
                if t:
                    texts.append(t)
                    if _is_full(texts):
                        break
            if texts:
                return _capped(texts)
            return _fallback(f"no usable text found in {path!r}")
    except Exception as e:
        # If load fails, fall back to built-in prompts (do not hard-fail).
        return _fallback(f"failed to load {path!r} ({type(e).__name__})")

    # Unrecognised format name: nothing was read.
    return _capped(texts)
class _RMSAccumulator:
    """Running per-channel sum of squares for a module's inputs and/or outputs.

    Sums are kept as float32 CPU tensors. ``finalize`` converts them to
    per-channel RMS values. A side that is not collected stays ``None``.
    """
    __slots__ = ("in_sumsq", "out_sumsq", "in_count", "out_count")

    def __init__(self, in_features: int, out_features: int, collect_in: bool, collect_out: bool):
        self.in_sumsq = torch.zeros(in_features, dtype=torch.float32) if collect_in else None
        self.out_sumsq = torch.zeros(out_features, dtype=torch.float32) if collect_out else None
        self.in_count = 0
        self.out_count = 0

    @staticmethod
    def _sumsq_and_count(values: torch.Tensor,
                         attention_mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, int]:
        """Return (per-channel sum of squares, number of contributing rows)."""
        # values: (..., features) -> (rows, features), non-finite entries zeroed
        flat = torch.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0).float()
        flat = flat.reshape(-1, flat.shape[-1])
        squared = flat * flat
        if attention_mask is not None and torch.is_tensor(attention_mask):
            mask = attention_mask.reshape(-1)
            # The mask is only applied when it lines up one-to-one with the rows.
            if mask.numel() == flat.shape[0]:
                mask = mask.to(device=flat.device, dtype=flat.dtype)
                return (squared * mask.unsqueeze(1)).sum(dim=0), int(mask.sum().item())
        return squared.sum(dim=0), int(flat.shape[0])

    def add_in(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Accumulate input activations ``x`` of shape (..., in_features)."""
        if self.in_sumsq is None:
            return
        sumsq, rows = self._sumsq_and_count(x, attention_mask)
        if rows <= 0:
            return
        self.in_sumsq += sumsq.detach().cpu()
        self.in_count += rows

    def add_out(self, y: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Accumulate output activations ``y`` of shape (..., out_features)."""
        if self.out_sumsq is None:
            return
        sumsq, rows = self._sumsq_and_count(y, attention_mask)
        if rows <= 0:
            return
        self.out_sumsq += sumsq.detach().cpu()
        self.out_count += rows

    def finalize(self, eps: float) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return (in_rms, out_rms), each clamped below at ``eps`` or ``None`` if not collected."""
        def _rms(sumsq: Optional[torch.Tensor], count: int) -> Optional[torch.Tensor]:
            if sumsq is None:
                return None
            # Guard the division when nothing was accumulated.
            return torch.sqrt(sumsq / float(max(1, count))).clamp(min=eps).contiguous()

        return _rms(self.in_sumsq, self.in_count), _rms(self.out_sumsq, self.out_count)
class _CovAccumulator:
    """Accumulate full input covariance: sum(x^T x) and token count.

    ``gram`` is a (dim, dim) float32 tensor on the device given at
    construction; ``count`` is the number of rows that contributed.
    """
    __slots__ = ("dim", "gram", "count")

    def __init__(self, dim: int, device: torch.device):
        self.dim = int(dim)
        self.gram = torch.zeros((self.dim, self.dim), device=device, dtype=torch.float32)
        self.count = 0

    def add(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> None:
        """Add activations ``x`` of shape (..., dim) to the running Gram matrix.

        Inputs that are ``None``, scalar, empty, or whose last dimension is not
        ``dim`` are ignored. A single 1-D vector is treated as one row. When
        ``attn_mask`` is a tensor with one entry per row, rows are scaled by it
        and the count is the mask sum; otherwise every row counts.
        """
        if x is None or x.dim() == 0 or x.numel() == 0:
            return
        if x.shape[-1] != self.dim:
            return
        if x.dim() != 2:
            # Collapses leading axes, and promotes a 1-D vector to (1, dim).
            x = x.reshape(-1, x.shape[-1])
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).to(dtype=torch.float32)

        rows = int(x.shape[0])
        if attn_mask is not None and torch.is_tensor(attn_mask):
            mask = attn_mask.reshape(-1)
            if mask.numel() == x.shape[0]:
                mask = mask.to(device=x.device, dtype=torch.float32)
                x = x * mask.unsqueeze(1)
                rows = int(torch.sum(mask).item())

        # x^T x, moved to the accumulator's device in case activations live elsewhere.
        self.gram += (x.transpose(0, 1) @ x).to(self.gram.device)
        self.count += max(0, rows)
def _parse_layer_index(name: str) -> Optional[int]:
    """Extract the numeric layer index from a module name, or None if there is none.

    Recognises the ``model.layers.N.``, ``transformer.h.N.`` and
    ``decoder.layers.N.`` naming patterns.
    """
    found = re.search(r'(?:model\.layers|transformer\.h|decoder\.layers)\.(\d+)\.', name)
    if found is None:
        return None
    digits = found.group(1)
    try:
        return int(digits)
    except Exception:
        return None
| def _cov_group_key_for_module(module_name: str, student_arch: ArchitectureInfo) -> Optional[str]: | |
| """Map a module name to a deterministic shared-input covariance key.""" | |
| if not module_name or module_name.count(".") < 2: | |
| return None | |
| leaf = module_name.split(".")[-1] | |
| base = ".".join(module_name.split(".")[:-1]) | |
| qn = getattr(student_arch, "q_proj_name", "q_proj") | |
| kn = getattr(student_arch, "k_proj_name", "k_proj") | |
| vn = getattr(student_arch, "v_proj_name", "v_proj") | |
| on = getattr(student_arch, "o_proj_name", "o_proj") | |
| gn = getattr(student_arch, "gate_proj_name", "gate_proj") | |
| un = getattr(student_arch, "up_proj_name", "up_proj") | |
| dn = getattr(student_arch, "down_proj_name", "down_proj") | |
| if leaf in (qn, kn, vn): | |
| return f"{base}.qkv_in" | |
| if leaf == on: | |
| return f"{base}.o_in" | |
| if leaf in (gn, un, "w1", "w3"): | |
| return f"{base}.up_in" | |
| if leaf in (dn, "w2"): | |
| return f"{base}.down_in" | |
| return None | |
| def _resolve_cov_entry_key(module_name: str, calib_stats: Dict[str, CalibStatsEntry], student_arch: ArchitectureInfo) -> Optional[str]: | |
| """Resolve the key to use for covariance whitening (exact module or shared-input group).""" | |
| if module_name in calib_stats and calib_stats[module_name].in_chol is not None: | |
| return module_name | |
| gk = _cov_group_key_for_module(module_name, student_arch) | |
| if gk and gk in calib_stats and calib_stats[gk].in_chol is not None: | |
| return gk | |
| return None | |
| def _safe_cholesky(cov: torch.Tensor, eps: float, max_tries: int = 6) -> torch.Tensor: | |
| """Robust Cholesky with escalating jitter.""" | |
| cov = (cov + cov.transpose(0, 1)) * 0.5 | |
| d = int(cov.shape[0]) | |
| eye = torch.eye(d, device=cov.device, dtype=cov.dtype) | |
| jitter = float(eps) | |
| for _ in range(max_tries): | |
| try: | |
| return torch.linalg.cholesky(cov + jitter * eye) | |
| except Exception: | |
| jitter *= 10.0 | |
| # Last-resort: attempt cholesky_ex, then jitter more | |
| L, info = torch.linalg.cholesky_ex(cov + jitter * eye) | |
| if int(info) == 0: | |
| return L | |
| raise RuntimeError("Cholesky failed (covariance not SPD even after jitter).") | |
def collect_calibration_stats_rms(student_path: str,
                                  student_arch: ArchitectureInfo,
                                  config: DistillConfig) -> Dict[str, CalibStatsEntry]:
    """Collect per-channel activation RMS stats on the STUDENT model.

    Loads the student model, attaches forward hooks to the targeted linear
    projections, runs the calibration texts through it and turns the
    accumulated squared activations into RMS vectors.

    Args:
        student_path: Path or identifier handed to ``from_pretrained``.
        student_arch: Supplies the projection leaf names to hook.
        config: Calibration options. Every option is read with ``getattr``
            and a default, so a config missing an optional field still
            works.

    Returns:
        ``{module_name: CalibStatsEntry(in_rms=..., out_rms=...)}``; empty
        when calibration is disabled, no texts were loaded or no module
        matched.

    Raises:
        RuntimeError: If transformers is unavailable or the model fails to
            load (the underlying error is chained).
    """
    import gc

    if getattr(config, "calibration_mode", "none") == "none":
        return {}
    if not HAS_TRANSFORMERS:
        raise RuntimeError("Calibration mode requires transformers installed.")
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = _infer_calib_device(getattr(config, "calib_device", "auto"))
    dtype = _dtype_from_str(getattr(config, "calib_dtype", "auto"), device)
    texts = _load_calib_texts(
        getattr(config, "calib_data", None),
        int(getattr(config, "calib_max_samples", 128)),
        getattr(config, "calib_format", "auto"),
        getattr(config, "calib_alpaca_template", "classic"),
        bool(getattr(config, "calib_alpaca_include_output", False)),
    )
    if not texts:
        return {}

    tok = AutoTokenizer.from_pretrained(student_path, trust_remote_code=True, use_fast=True)
    if tok.pad_token_id is None:
        # Make padding well-defined for batch calibration
        tok.pad_token = tok.eos_token

    # Load model with best-effort low-memory strategy
    load_kwargs = dict(trust_remote_code=True, torch_dtype=dtype, low_cpu_mem_usage=True)
    model = None
    try:
        if device.type == "cuda":
            # Prefer device_map="auto" (requires accelerate). Fall back to full cuda load.
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    student_path, device_map="auto", **load_kwargs
                )
            except Exception:
                model = AutoModelForCausalLM.from_pretrained(student_path, **load_kwargs).to(device)
        else:
            model = AutoModelForCausalLM.from_pretrained(student_path, **load_kwargs).to(device)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load student model for calibration: {type(e).__name__}: {e}"
        ) from e
    model.eval()

    # Select target linear modules for stats collection
    try:
        pat = getattr(config, "include_pattern", "self_attn|mlp") or "self_attn|mlp"
        include_re = re.compile(pat)
    except Exception:
        include_re = re.compile("self_attn|mlp")
    target_leaf_names = {
        student_arch.q_proj_name, student_arch.k_proj_name, student_arch.v_proj_name, student_arch.o_proj_name,
        student_arch.gate_proj_name, student_arch.up_proj_name, student_arch.down_proj_name,
        # Fallback names commonly seen in SwiGLU variants
        "w1", "w2", "w3",
    }

    # Options are resolved once, outside the hooks. The hooks swallow every
    # exception so a forward pass is never broken; reading a missing config
    # field inside them would therefore fail silently on every call and
    # leave the accumulators empty.
    collect_in = bool(getattr(config, "calib_collect_in", True))
    collect_out = bool(getattr(config, "calib_collect_out", True))
    use_attn_mask = bool(getattr(config, "calib_use_attention_mask", True))
    padding_mode = getattr(config, "calib_padding", "longest")

    acc: Dict[str, _RMSAccumulator] = {}
    mask_ref: Dict[str, Optional[torch.Tensor]] = {"mask": None}
    hooks: List[Any] = []

    def _should_hook(name: str, module: torch.nn.Module) -> bool:
        if not isinstance(module, torch.nn.Linear):
            return False
        if include_re is not None and (not include_re.search(name)):
            return False
        return name.split(".")[-1] in target_leaf_names

    def _usable_tensor(value: Any) -> Optional[torch.Tensor]:
        """Unwrap a one-level tuple/list and return it only if it is a non-scalar tensor."""
        if isinstance(value, (tuple, list)):
            value = value[0]
        if value is None or not torch.is_tensor(value) or value.dim() == 0:
            return None
        return value

    def _current_mask() -> Optional[torch.Tensor]:
        return mask_ref["mask"] if use_attn_mask else None

    def _make_hooks(accumulator: _RMSAccumulator):
        """Build the (pre, post) forward hooks feeding one accumulator."""

        def _pre_hook(_mod, inputs):
            try:
                x = _usable_tensor(inputs[0])
                if x is not None:
                    accumulator.add_in(x, _current_mask())
            except Exception:
                # Best effort: statistics must never break the forward pass.
                return

        def _fwd_hook(_mod, _inputs, output):
            try:
                y = _usable_tensor(output)
                if y is not None:
                    accumulator.add_out(y, _current_mask())
            except Exception:
                return

        return _pre_hook, _fwd_hook

    # Register hooks
    for name, module in model.named_modules():
        if not _should_hook(name, module):
            continue
        try:
            in_f = int(getattr(module, "in_features", 0))
            out_f = int(getattr(module, "out_features", 0))
            if in_f <= 0 or out_f <= 0:
                continue
            acc[name] = _RMSAccumulator(in_f, out_f, collect_in=collect_in, collect_out=collect_out)
        except Exception:
            continue
        pre_hook, fwd_hook = _make_hooks(acc[name])
        hooks.append(module.register_forward_pre_hook(pre_hook))
        hooks.append(module.register_forward_hook(fwd_hook))

    def _remove_hooks() -> None:
        for h in hooks:
            try:
                h.remove()
            except Exception:
                pass

    if not acc:
        # Nothing matched; cleanup
        _remove_hooks()
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {}

    # Determine input device
    try:
        input_device = next(model.parameters()).device
    except Exception:
        input_device = device

    bs = max(1, int(getattr(config, "calib_batch_size", 1)))
    max_len = max(8, int(getattr(config, "calib_max_length", 256)))
    padding = "max_length" if padding_mode == "max_length" else True

    with torch.inference_mode():
        for i in range(0, len(texts), bs):
            enc = tok(
                texts[i:i + bs],
                return_tensors="pt",
                padding=padding,
                truncation=True,
                max_length=max_len,
            )
            enc = {k: v.to(input_device) for k, v in enc.items()}
            # The hooks read the mask of the batch currently in flight.
            mask_ref["mask"] = enc.get("attention_mask", None)
            try:
                _ = model(**enc, use_cache=False)
            except Exception:
                # Skip problematic batch
                continue

    _remove_hooks()

    eps = float(getattr(config, "calib_eps", 1e-6))
    stats: Dict[str, CalibStatsEntry] = {}
    for name, a in acc.items():
        in_rms, out_rms = a.finalize(eps)
        stats[name] = CalibStatsEntry(in_rms=in_rms, out_rms=out_rms)

    # Release model memory
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return stats
def collect_calibration_stats_cov(student_path: str,
                                  student_arch: ArchitectureInfo,
                                  config: DistillConfig) -> Dict[str, CalibStatsEntry]:
    """Collect full-covariance (SVD-LLM) input whitening stats on the STUDENT model.

    Produces in_chol factors keyed by deterministic shared-input group keys (see
    _cov_group_key_for_module). These factors are consumed by
    extract_lora_pissa_calibrated when calib_mode=cov.

    One representative linear module per (layer, group) is hooked. Layers
    are processed in chunks of ``calib_cov_chunk_layers``; every chunk
    replays the whole calibration set, so only that chunk's Gram matrices
    are alive at once.

    Args:
        student_path: Path or identifier handed to ``from_pretrained``.
        student_arch: Supplies the projection leaf names.
        config: Calibration options, all read with ``getattr`` defaults.

    Returns:
        ``{group_key: CalibStatsEntry(in_chol=...)}``; empty when no texts
        were loaded or no module matched. Groups that saw no tokens or whose
        factorisation failed are omitted.

    Raises:
        RuntimeError: If transformers is unavailable or the model fails to
            load (the underlying error is chained).
    """
    import gc

    if not HAS_TRANSFORMERS:
        raise RuntimeError("Calibration mode requires transformers installed.")
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = _infer_calib_device(getattr(config, "calib_device", "auto"))
    dtype = _dtype_from_str(getattr(config, "calib_dtype", "auto"), device)
    texts = _load_calib_texts(
        getattr(config, "calib_data", None),
        int(getattr(config, "calib_max_samples", 128)),
        getattr(config, "calib_format", "auto"),
        getattr(config, "calib_alpaca_template", "classic"),
        bool(getattr(config, "calib_alpaca_include_output", False)),
    )
    if not texts:
        return {}

    tok = AutoTokenizer.from_pretrained(student_path, trust_remote_code=True, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    torch_dtype = dtype if device.type == "cuda" else torch.float32
    device_map = (
        "auto"
        if (device.type == "cuda" and getattr(config, "calib_device", "auto") == "auto")
        else None
    )
    try:
        model = AutoModelForCausalLM.from_pretrained(
            student_path,
            torch_dtype=torch_dtype,
            device_map=device_map,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        if device_map is None:
            # Without a device map the weights are loaded on the CPU; move
            # them to the requested device so an explicit CUDA request is
            # honoured (this is a no-op when the device is the CPU).
            model = model.to(device)
    except Exception as e:
        raise RuntimeError(f"Failed to load student model for calibration: {e}") from e
    model.eval()

    include_pat = getattr(config, "include_pattern", "")
    exclude_pat = getattr(config, "exclude_pattern", "")
    include_re = re.compile(include_pat) if include_pat else None
    exclude_re = re.compile(exclude_pat) if exclude_pat else None

    # Which covariance groups to collect
    groups_raw = str(getattr(config, "calib_cov_groups", "qkv,o,mlp")).split(",")
    want_groups = {g.strip().lower() for g in groups_raw if g.strip()}
    max_dim = int(getattr(config, "calib_cov_max_dim", 8192))
    chunk_layers = max(1, int(getattr(config, "calib_cov_chunk_layers", 4)))
    store_dtype_str = str(getattr(config, "calib_cov_store_dtype", "fp16")).lower()
    store_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(store_dtype_str, torch.float16)
    eps = float(getattr(config, "calib_eps", 1e-6))
    use_attn_mask = bool(getattr(config, "calib_use_attention_mask", True))

    # Build per-layer representative module selection.
    module_dict: Dict[str, torch.nn.Module] = dict(model.named_modules())
    qn = getattr(student_arch, "q_proj_name", "q_proj")
    kn = getattr(student_arch, "k_proj_name", "k_proj")
    vn = getattr(student_arch, "v_proj_name", "v_proj")
    on = getattr(student_arch, "o_proj_name", "o_proj")
    gn = getattr(student_arch, "gate_proj_name", "gate_proj")
    un = getattr(student_arch, "up_proj_name", "up_proj")
    dn = getattr(student_arch, "down_proj_name", "down_proj")
    # Within a group, earlier names are preferred as the representative.
    priorities: Dict[str, List[str]] = {
        "qkv": [qn, kn, vn],
        "o": [on],
        "mlp": [gn, un, "w1", "w3"],
        "down": [dn, "w2"],
    }
    # Requested groups that are actually known, in a fixed order so that
    # hook registration does not depend on set iteration order.
    active_groups = [g for g in priorities if g in want_groups]

    layer_sel: Dict[int, Dict[str, str]] = defaultdict(dict)
    layer_pri: Dict[int, Dict[str, int]] = defaultdict(dict)
    for name, mod in module_dict.items():
        if not isinstance(mod, torch.nn.Linear):
            continue
        if include_re and not include_re.search(name):
            continue
        if exclude_re and exclude_re.search(name):
            continue
        li = _parse_layer_index(name)
        if li is None:
            continue
        leaf = name.split(".")[-1]
        for g in active_groups:
            plist = priorities[g]
            if leaf not in plist:
                continue
            pidx = plist.index(leaf)
            if (g not in layer_sel[li]) or (pidx < layer_pri[li].get(g, 10**9)):
                layer_sel[li][g] = name
                layer_pri[li][g] = pidx

    layer_indices = sorted(layer_sel.keys())
    if not layer_indices:
        # Nothing matched; return empty.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {}

    # Determine the device for feeding input tensors.
    try:
        input_device = next(model.parameters()).device
    except Exception:
        input_device = device

    max_len = int(getattr(config, "calib_max_length", 256)) if getattr(config, "calib_max_length", None) else None
    # A batch size below 1 would make range() raise on a zero step.
    bs = max(1, int(getattr(config, "calib_batch_size", 1)))
    padding = (
        "max_length"
        if (getattr(config, "calib_padding", "longest") == "max_length" and max_len is not None)
        else True
    )

    stats: Dict[str, CalibStatsEntry] = {}
    mask_ref: Dict[str, Optional[torch.Tensor]] = {"mask": None}

    def _make_pre_hook(accumulator: _CovAccumulator):
        """Build a forward pre-hook that feeds the module input to ``accumulator``."""

        def _pre_hook(_mod, inputs):
            if not inputs:
                return
            x = inputs[0]
            if not torch.is_tensor(x):
                return
            m = mask_ref.get("mask", None) if use_attn_mask else None
            accumulator.add(x, m)

        return _pre_hook

    def _remove_hooks(handles: List[Any]) -> None:
        for h in handles:
            try:
                h.remove()
            except Exception:
                pass

    # Chunked passes to bound peak memory.
    for c0 in range(0, len(layer_indices), chunk_layers):
        chunk = layer_indices[c0:c0 + chunk_layers]
        acc: Dict[str, _CovAccumulator] = {}
        hooks: List[Any] = []

        # Register hooks for this chunk.
        for li in chunk:
            for g in active_groups:
                mod_name = layer_sel.get(li, {}).get(g, None)
                if not mod_name:
                    continue
                mod = module_dict.get(mod_name, None)
                if mod is None or not isinstance(mod, torch.nn.Linear):
                    continue
                in_dim = int(getattr(mod, "in_features", 0) or 0)
                if in_dim <= 0 or in_dim > max_dim:
                    continue
                gk = _cov_group_key_for_module(mod_name, student_arch)
                if not gk or gk in acc:
                    continue
                try:
                    dev = next(mod.parameters()).device
                except Exception:
                    dev = input_device
                acc[gk] = _CovAccumulator(in_dim, dev)
                hooks.append(mod.register_forward_pre_hook(_make_pre_hook(acc[gk])))

        if not acc:
            # Nothing to collect in this chunk.
            _remove_hooks(hooks)
            continue

        # Run calibration forward passes.
        with torch.inference_mode():
            for i in range(0, len(texts), bs):
                enc = tok(
                    texts[i:i + bs],
                    return_tensors="pt",
                    padding=padding,
                    truncation=True,
                    max_length=max_len,
                )
                enc = {k: v.to(input_device) for k, v in enc.items()}
                mask_ref["mask"] = enc.get("attention_mask", None)
                try:
                    _ = model(**enc, use_cache=False)
                except Exception:
                    # Skip problematic batch (including one whose hook raised).
                    continue

        # Cleanup hooks for this chunk.
        _remove_hooks(hooks)

        # Finalize and store cholesky factors.
        for gk, a in acc.items():
            if a.count <= 0:
                continue
            cov = a.gram / float(max(1, a.count))
            try:
                L = _safe_cholesky(cov, eps=eps)
            except Exception:
                continue
            stats[gk] = CalibStatsEntry(in_chol=L.detach().to(dtype=store_dtype).cpu().contiguous())

        # Release chunk accumulators.
        del acc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Release model memory
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return stats
def collect_calibration_stats(student_path: str,
                              student_arch: ArchitectureInfo,
                              config: DistillConfig) -> Dict[str, CalibStatsEntry]:
    """Run the calibration collector selected by ``config.calibration_mode``.

    The mode is compared case-insensitively: ``"rms"`` selects the RMS
    collector and ``"cov"`` the covariance collector. Any other value,
    including a missing attribute, yields an empty dict.
    """
    mode = str(getattr(config, "calibration_mode", "none")).lower()
    if mode not in ("rms", "cov"):
        return {}
    collector = collect_calibration_stats_rms if mode == "rms" else collect_calibration_stats_cov
    return collector(student_path, student_arch, config)
def save_calibration_stats(stats: Dict[str, CalibStatsEntry], path: str) -> None:
    """Write calibration statistics to a safetensors file.

    Each tensor that is set on an entry is stored under
    ``"<name>.in_rms"``, ``"<name>.out_rms"`` or ``"<name>.in_chol"``. The
    parent directory of ``path`` is created if needed. When no entry carries
    any tensor, no file is written.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    payload: Dict[str, torch.Tensor] = {}
    for module_key, entry in stats.items():
        for field in ("in_rms", "out_rms", "in_chol"):
            tensor = getattr(entry, field)
            if tensor is not None:
                payload[f"{module_key}.{field}"] = tensor.detach().cpu().contiguous()
    if payload:
        save_file(payload, path, metadata={"format": "ud_calib_stats_v2"})
def load_calibration_stats(path: str) -> Dict[str, CalibStatsEntry]:
    """Read calibration statistics back from a safetensors file.

    Tensor keys ending in ``.in_rms``, ``.out_rms`` or ``.in_chol`` are
    grouped by the name in front of the suffix; keys with any other ending
    are ignored. Fields absent from the file are left as ``None``.
    """
    suffix_to_field = (
        (".in_rms", "in_rms"),
        (".out_rms", "out_rms"),
        (".in_chol", "in_chol"),
    )
    grouped: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
    for tensor_key, tensor in load_file(path).items():
        for suffix, field in suffix_to_field:
            if tensor_key.endswith(suffix):
                grouped[tensor_key[:-len(suffix)]][field] = tensor
                break
    return {
        name: CalibStatsEntry(
            in_rms=fields.get("in_rms", None),
            out_rms=fields.get("out_rms", None),
            in_chol=fields.get("in_chol", None),
        )
        for name, fields in grouped.items()
    }
| def _delta_whiten_diag(delta: torch.Tensor, | |
| in_rms: Optional[torch.Tensor], | |
| out_rms: Optional[torch.Tensor], | |
| eps: float) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| """ | |
| Return whitened delta and the (clamped) scaling vectors used. | |
| delta: (out, in) | |
| in_rms: (in,) | |
| out_rms: (out,) | |
| """ | |
| dw = delta | |
| in_s = None | |
| out_s = None | |
| if out_rms is not None: | |
| out_s = out_rms.to(device=delta.device, dtype=torch.float32).clamp(min=eps) | |
| dw = dw / out_s.unsqueeze(1) | |
| if in_rms is not None: | |
| in_s = in_rms.to(device=delta.device, dtype=torch.float32).clamp(min=eps) | |
| dw = dw / in_s.unsqueeze(0) | |
| return dw, in_s, out_s | |
def extract_lora_pissa_calibrated(
    delta: torch.Tensor,
    rank: int,
    svd_engine: GPUAcceleratedSVD,
    config: DistillConfig,
    module_name: str,
    calib_stats: Dict[str, CalibStatsEntry],
    student_arch: ArchitectureInfo,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int]:
    """Calibration-aware PiSSA.

    - rms: the delta is divided by per-channel RMS vectors, factorised, and
      the factors are multiplied back by the same vectors.
    - cov: the delta is right-multiplied by the Cholesky factor S stored in
      the calibration entry, PiSSA is applied on (Δ·S), and the right
      factor is unwhitened with S^{-1}.

    Whenever the statistics needed for the selected mode are missing (or the
    mode is neither ``rms`` nor ``cov``), plain ``extract_lora_pissa`` is
    used.

    Returns:
        ``(lora_A, lora_B, actual_rank)`` in ``delta``'s dtype, or
        ``(None, None, 0)`` when the whitened factorisation yields nothing.
    """

    def _plain():
        return extract_lora_pissa(delta, rank, svd_engine, config)

    if not calib_stats:
        return _plain()

    mode = str(getattr(config, "calibration_mode", "none")).lower()
    eps = float(getattr(config, "calib_eps", 1e-6))
    out_dtype = delta.dtype

    if mode == "rms":
        entry = calib_stats.get(module_name, None)
        if entry is None or (entry.in_rms is None and entry.out_rms is None):
            return _plain()
        whitened, in_scale, out_scale = _delta_whiten_diag(delta.float(), entry.in_rms, entry.out_rms, eps)
        # Run PiSSA on whitened delta
        a_w, b_w, got_rank = extract_lora_pissa(whitened.to(out_dtype), rank, svd_engine, config)
        if a_w is None or b_w is None or got_rank <= 0:
            return None, None, 0
        # Unwhiten factors back to original space: Δ ≈ (D_out·B_w) (A_w·D_in)
        a_full = a_w.float()
        b_full = b_w.float()
        if out_scale is not None:
            b_full = b_full * out_scale.unsqueeze(1)
        if in_scale is not None:
            a_full = a_full * in_scale.unsqueeze(0)
        return a_full.to(dtype=out_dtype).contiguous(), b_full.to(dtype=out_dtype).contiguous(), got_rank

    if mode == "cov":
        entry_key = _resolve_cov_entry_key(module_name, calib_stats, student_arch)
        if not entry_key:
            return _plain()
        chol = calib_stats[entry_key].in_chol
        if chol is None:
            return _plain()
        # Ensure S is on the same device.
        chol = chol.to(device=delta.device, dtype=torch.float32)
        # PiSSA on (Δ·S)
        whitened = (delta.float() @ chol).to(dtype=out_dtype)
        a_w, b_w, got_rank = extract_lora_pissa(whitened, rank, svd_engine, config)
        if a_w is None or b_w is None or got_rank <= 0:
            return None, None, 0
        # Unwhiten on the right: A = A_w · S^{-1}
        a_white = a_w.float().contiguous()
        b_full = b_w.float().contiguous()
        chol_t = chol.transpose(0, 1)
        rhs = a_white.transpose(0, 1)
        # Solve (S^T) X^T = A_w^T  =>  X = A_w · S^{-1}
        try:
            solved = torch.linalg.solve_triangular(chol_t, rhs, upper=True)
        except Exception:
            # Fallback: generic solve (slower, but robust)
            solved = torch.linalg.solve(chol_t, rhs)
        a_full = solved.transpose(0, 1)
        return a_full.to(dtype=out_dtype).contiguous(), b_full.to(dtype=out_dtype).contiguous(), got_rank

    return _plain()
| # ============================================================================= | |
| # LAYER MAPPING | |
| # ============================================================================= | |
def teacher_idx_from_student_idx(s_idx: int, s_total: int, t_total: int,
                                 schedule: str = "sigmoid",
                                 k: float = 0.15) -> Tuple[int, float]:
    """Map student layer index to teacher layer index with interpolation weight.

    The student's relative depth ``s_idx / (s_total - 1)`` is used directly,
    or first passed through a logistic curve of steepness ``k`` when
    ``schedule == "sigmoid"``, and then scaled to the teacher's index range.

    Returns:
        ``(t_floor, interp_w)``: the lower teacher index, clamped to
        ``[0, t_total - 1]``, and the distance of the fractional position
        from that index. ``(0, 0.0)`` when either model has at most one
        layer.
    """
    if t_total <= 1 or s_total <= 1:
        return 0, 0.0

    progress = s_idx / (s_total - 1)
    if schedule == "sigmoid":
        # Relative depth is stretched to [-6, 6] before the logistic.
        progress = 1 / (1 + math.exp(-k * (12 * (progress - 0.5))))

    position = progress * (t_total - 1)
    lower = max(0, min(int(position), t_total - 1))
    return lower, position - lower
def create_layer_mapping(teacher_info: ArchitectureInfo,
                         student_info: ArchitectureInfo) -> Dict[str, str]:
    """Create mapping from student projection names to teacher projection names.

    The attention (q/k/v/o) and MLP (gate/up/down) projection names of both
    architectures are paired up; only pairs whose names differ are recorded.
    """
    projection_fields = (
        # attention
        "q_proj_name", "k_proj_name", "v_proj_name", "o_proj_name",
        # MLP
        "gate_proj_name", "up_proj_name", "down_proj_name",
    )
    name_pairs = [
        (getattr(student_info, field), getattr(teacher_info, field))
        for field in projection_fields
    ]
    return {
        student_name: teacher_name
        for student_name, teacher_name in name_pairs
        if student_name != teacher_name
    }
| # ============================================================================= | |
| # PEFT FORMAT OUTPUT | |
| # ============================================================================= | |
def to_peft_key(param_key: str) -> str:
    """Convert parameter key to PEFT-compatible format.

    A trailing ``.weight`` is dropped. Keys already starting with
    ``base_model.model.`` are returned as they are; keys starting with
    ``model.`` or ``transformer.`` get ``base_model.model.`` prepended; any
    other key gets ``base_model.model.model.`` prepended.
    """
    wrapper = "base_model.model."
    weight_suffix = ".weight"
    stem = param_key[:-len(weight_suffix)] if param_key.endswith(weight_suffix) else param_key

    if stem.startswith(wrapper):
        return stem
    if stem.startswith(("model.", "transformer.")):
        return wrapper + stem
    return wrapper + "model." + stem
def save_peft_adapter(lora_weights: Dict[str, torch.Tensor],
                      output_dir: str,
                      config: DistillConfig,
                      student_arch: ArchitectureInfo,
                      rank_info: Dict[str, int]):
    """Save LoRA weights in PEFT-compatible format.

    Writes ``adapter_model.safetensors`` and ``adapter_config.json`` into
    ``output_dir`` (created if missing) and prints a short summary.

    Args:
        lora_weights: LoRA tensors keyed by PEFT state-dict name; values
            that are not tensors are dropped.
        output_dir: Destination directory.
        config: Supplies ``student_path``, ``lora_alpha``, ``rank_default``
            and the optional ``alpha_mode``.
        student_arch: Its projection names are recognised as target modules
            in addition to the built-in list, so architectures with custom
            projection names still get a usable ``target_modules``.
        rank_info: Per-module rank. The most frequent rank becomes ``r`` and
            every entry is written to ``rank_pattern``.
    """
    os.makedirs(output_dir, exist_ok=True)
    weights_path = os.path.join(output_dir, "adapter_model.safetensors")

    # safetensors requires CPU + contiguous tensors (no views).
    safe_weights: Dict[str, torch.Tensor] = {}
    for k, v in lora_weights.items():
        if not isinstance(v, torch.Tensor):
            continue
        t = v.detach()
        if t.device.type != "cpu":
            t = t.to("cpu")
        # Ensure contiguous storage (also breaks view/reference sharing).
        safe_weights[k] = t.contiguous()
    save_file(safe_weights, weights_path)

    # Leaf names accepted as LoRA targets: the common built-ins plus
    # whatever the student architecture declares for its projections.
    known_targets = {
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
        'c_attn', 'c_proj', 'c_fc', 'query_key_value',
        'w1', 'w2', 'w3', 'fc1', 'fc2', 'dense',
    }
    for attr in ("q_proj_name", "k_proj_name", "v_proj_name", "o_proj_name",
                 "gate_proj_name", "up_proj_name", "down_proj_name"):
        arch_name = getattr(student_arch, attr, None)
        if isinstance(arch_name, str) and arch_name:
            known_targets.add(arch_name)

    target_modules = set()
    for key in safe_weights:
        module = re.sub(r'\.lora_[AB](?:\.[^\.]+)?\.weight$', '', key)
        module = re.sub(r'^base_model\.model\.', '', module)
        # The innermost recognised path component is the target.
        for part in reversed(module.split('.')):
            if part in known_targets:
                target_modules.add(part)
                break
    target_list = sorted(target_modules)

    if rank_info:
        rank_counts = defaultdict(int)
        for r in rank_info.values():
            rank_counts[r] += 1
        # Most frequent rank; on a tie the rank seen first wins.
        effective_rank = max(rank_counts.items(), key=lambda x: x[1])[0]
    else:
        effective_rank = config.rank_default

    def _rank_key_for_config(peft_key: str) -> str:
        # PEFT rank_pattern/alpha_pattern match base model module names (e.g. 'model.layers.0...'),
        # not the wrapped state_dict prefix (e.g. 'base_model.model.').
        prefix = "base_model.model."
        return peft_key[len(prefix):] if peft_key.startswith(prefix) else peft_key

    per_module_rank: Dict[str, int] = (
        {_rank_key_for_config(k): int(v) for k, v in rank_info.items()} if rank_info else {}
    )

    # alpha_pattern: keep per-module scaling stable when rank_pattern varies
    alpha_mode = getattr(config, "alpha_mode", "auto")
    if alpha_mode == "auto":
        alpha_mode = "rank" if config.lora_alpha is None else "fixed"
    alpha_pattern: Dict[str, int] = dict(per_module_rank) if alpha_mode == "rank" else {}
    rank_pattern: Dict[str, int] = dict(per_module_rank)

    adapter_config = {
        "alpha_pattern": alpha_pattern,
        "auto_mapping": None,
        "base_model_name_or_path": config.student_path,
        "bias": "none",
        "fan_in_fan_out": False,
        "inference_mode": True,
        "init_lora_weights": True,
        "layer_replication": None,
        "layers_pattern": None,
        "layers_to_transform": None,
        "loftq_config": {},
        "lora_alpha": config.lora_alpha or effective_rank,
        "lora_dropout": 0.0,
        "megatron_config": None,
        "megatron_core": "megatron.core",
        "modules_to_save": None,
        "peft_type": "LORA",
        "r": effective_rank,
        "rank_pattern": rank_pattern,
        "revision": None,
        "target_modules": target_list,
        "task_type": "CAUSAL_LM",
        "use_dora": False,
        "use_rslora": False,
    }
    config_path = os.path.join(output_dir, "adapter_config.json")
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(adapter_config, f, indent=2)

    print(f"\nโ Saved PEFT adapter to: {output_dir}")
    print(f" - Weights: {weights_path}")
    print(f" - Config: {config_path}")
    print(f" - Effective rank: {effective_rank}")
    print(f" - Target modules: {target_list}")
| # ============================================================================= | |
| # ACCELERATED DISTILLER | |
| # ============================================================================= | |
| class AcceleratedDistiller: | |
| """GPU-accelerated distillation with memory management.""" | |
def __init__(self, config: DistillConfig, device: torch.device,
             teacher_arch: ArchitectureInfo, student_arch: ArchitectureInfo):
    """Wire up the memory, offload and SVD helpers used during distillation.

    Args:
        config: Distillation options. ``dare_drop_rate`` on this object is
            lowered in place when cross-architecture auto-adjustment applies.
        device: Device handed to every helper.
        teacher_arch: Teacher architecture description.
        student_arch: Student architecture description.
    """
    self.config = config
    self.device = device
    self.teacher_arch = teacher_arch
    self.student_arch = student_arch

    # Cross-architecture means a different width or a different depth.
    self.is_cross_arch = (
        teacher_arch.hidden_size != student_arch.hidden_size or
        teacher_arch.num_hidden_layers != student_arch.num_hidden_layers
    )
    if self.is_cross_arch:
        if config.verbose:
            print(f"\nโ ๏ธ Cross-architecture distillation detected!")
            print(f" Teacher: {teacher_arch.num_hidden_layers} layers, {teacher_arch.hidden_size} hidden")
            print(f" Student: {student_arch.num_hidden_layers} layers, {student_arch.hidden_size} hidden")
        # FIX #3: Auto-adjust DARE for cross-architecture
        if config.auto_adjust_dare and config.use_dare:
            # Cap the drop rate at 0.1 to preserve more information.
            previous_rate = config.dare_drop_rate
            config.dare_drop_rate = min(0.1, previous_rate)
            if config.verbose and previous_rate != config.dare_drop_rate:
                print(f" Auto-adjusted DARE drop_rate: {previous_rate} โ {config.dare_drop_rate}")

    self.memory_manager = GPUMemoryManager(
        device, config.gpu_memory_fraction, config.verbose
    )

    # CUDA streams are only set up when requested and running on CUDA.
    if config.use_cuda_streams and device.type == "cuda":
        self.stream_manager = CUDAStreamManager(device, config.num_cuda_streams)
    else:
        self.stream_manager = None

    if config.offload_strategy != "auto":
        strategy = OffloadStrategy[config.offload_strategy.upper()]
    elif self.memory_manager.get_stats()["total_gb"] < 16:
        # "auto": offload to CPU when the reported total is under 16 GB.
        strategy = OffloadStrategy.CPU
    else:
        strategy = OffloadStrategy.NONE

    self.offloader = LayerOffloader(
        strategy=strategy,
        device=device,
        offload_dir=config.offload_dir,
        use_pinned=config.use_pinned_memory,
        max_cpu_gb=config.max_cpu_memory_gb,
        prefetch_count=config.prefetch_layers,
        verbose=config.verbose
    )

    def _opt(name, default):
        """Read an optional SVD setting from the config."""
        return getattr(config, name, default)

    svd_options = {
        "svd_mode": _opt('svd_mode', 'auto'),
        "randomized_iter": _opt('svd_randomized_iter', 2),
        "randomized_oversamples": _opt('svd_randomized_oversamples', 8),
        "auto_min_dim": _opt('svd_auto_min_dim', 1024),
        "auto_full_rank_ratio": _opt('svd_auto_full_rank_ratio', 0.6),
        "auto_lowrank": _opt('svd_auto_lowrank', 'randomized'),
        "adaptive_block_size": _opt('svd_adaptive_block_size', 32),
        "adaptive_n_test": _opt('svd_adaptive_n_test', 8),
        "drf_steps": int(_opt('svd_drf_steps', 1)),
        "drf_theta": float(_opt('svd_drf_theta', 0.5)),
        "drf_resid_eps": float(_opt('svd_drf_resid_eps', 1e-8)),
        "aurora_steps": int(_opt('svd_aurora_steps', 1)),
        "aurora_order": int(_opt('svd_aurora_order', 2)),
        "aurora_theta1": float(_opt('svd_aurora_theta1', 0.5)),
        "aurora_theta2": float(_opt('svd_aurora_theta2', 0.25)),
        "aurora_resid_eps": float(_opt('svd_aurora_resid_eps', 1e-8)),
        "aurora_level2_keep": float(_opt('svd_aurora_level2_keep', 1.0)),
    }
    self.svd_engine = GPUAcceleratedSVD(
        device=device,
        memory_manager=self.memory_manager,
        stream_manager=self.stream_manager,
        use_mixed_precision=config.use_mixed_precision,
        chunk_threshold_mb=config.chunk_threshold_mb,
        chunk_size_mb=config.chunk_size_mb,
        verbose=config.verbose,
        **svd_options,
    )

    self.lora_weights: Dict[str, torch.Tensor] = {}
    self.rank_info: Dict[str, int] = {}

    # MoE->Dense synthesis cache (teacher MoE -> student dense)
    self._teacher_moe_style: Optional[str] = None
    self._teacher_has_mlp_experts: bool = False
    self._teacher_has_block_sparse_moe: bool = False
    self._moe_dense_cache: Dict[Tuple[str, int, str], torch.Tensor] = {}

    self.stats = defaultdict(int)
    # Optional calibration statistics; presumably the mapping returned by
    # collect_calibration_stats -- confirm where this attribute is assigned.
    self.calib_stats = None
def load_tensors_with_prefetch(self,
                               keys: List[str],
                               model_folder: str,
                               weight_map: Dict[str, str],
                               next_keys: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """Read the requested tensors from their safetensors shards and move them to ``self.device``.

    Args:
        keys: Tensor names to load. Names missing from ``weight_map`` are
            ignored; duplicate names are read only once.
        model_folder: Directory that contains the shard files.
        weight_map: Maps a tensor name to the shard file (relative to
            ``model_folder``) that stores it.
        next_keys: Accepted for interface compatibility. This method does not
            read it, so no prefetching happens here.

    Returns:
        ``{name: tensor}`` for every tensor that could be read. A tensor whose
        read fails is left out of the result (best effort) instead of raising.
    """
    # Group names by shard so every shard file is opened exactly once.
    # dict.fromkeys() drops repeated names while keeping first-seen order, so a
    # caller that passes the same key twice does not trigger a second disk read.
    shards = defaultdict(list)
    for name in dict.fromkeys(keys):
        if name in weight_map:
            shards[weight_map[name]].append(name)

    verbose = bool(getattr(getattr(self, "config", None), "verbose", False))
    tensors: Dict[str, torch.Tensor] = {}
    for shard, shard_keys in shards.items():
        path = os.path.join(model_folder, shard)
        with safe_open(path, framework="pt") as f:
            # Materialise the shard's key list once instead of once per tensor.
            available = set(f.keys())
            for name in shard_keys:
                try:
                    t = f.get_tensor(name)
                    scale_key = f"{name}.scales"
                    if scale_key in available:
                        # A companion "<name>.scales" tensor is multiplied in
                        # (in fp32) and non-finite values are zeroed.
                        # NOTE(review): the source's indentation was lost; the
                        # nan_to_num call is assumed to belong to this branch.
                        scale = f.get_tensor(scale_key)
                        t = t.float() * scale.float()
                        t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
                    tensors[name] = t
                except Exception as exc:
                    # Best effort: skip the unreadable tensor, but say so when
                    # the run is verbose instead of hiding the failure.
                    if verbose:
                        print(f"  [load] skipped {name} ({type(exc).__name__}: {exc})")
                    continue

    if self.stream_manager:
        # Issue the host->device copies on the managed stream, then wait for them.
        with self.stream_manager.stream_context():
            for name in tensors:
                tensors[name] = tensors[name].to(self.device, non_blocking=True)
        self.stream_manager.synchronize_current()
    else:
        for name in tensors:
            tensors[name] = tensors[name].to(self.device)
    return tensors
| # ----------------------------- | |
| # MoE -> Dense support (teacher MoE, student dense) | |
| # ----------------------------- | |
| def _init_teacher_moe_style(self, teacher_weight_map: Dict[str, str]) -> None: | |
| """One-time detection of teacher MoE key layout.""" | |
| if self._teacher_moe_style is not None: | |
| return | |
| # Detect common HF MoE layouts | |
| self._teacher_has_mlp_experts = any(".mlp.experts." in k for k in teacher_weight_map.keys()) | |
| self._teacher_has_block_sparse_moe = any("block_sparse_moe.experts." in k for k in teacher_weight_map.keys()) | |
| if self._teacher_has_mlp_experts and self._teacher_has_block_sparse_moe: | |
| self._teacher_moe_style = "both" | |
| elif self._teacher_has_mlp_experts: | |
| self._teacher_moe_style = "mlp_experts" | |
| elif self._teacher_has_block_sparse_moe: | |
| self._teacher_moe_style = "block_sparse_moe" | |
| else: | |
| self._teacher_moe_style = "none" | |
| def _estimate_moe_expert_weights( | |
| self, | |
| layer_prefix: str, | |
| layer_idx: int, | |
| teacher_weight_map: Dict[str, str], | |
| num_experts: int, | |
| ) -> torch.Tensor: | |
| """Estimate per-expert routing weights using router matrix norms (data-free).""" | |
| # Default: uniform | |
| w = torch.full((num_experts,), 1.0 / max(1, num_experts), dtype=torch.float32) | |
| # Candidate router keys (common HF names) | |
| router_keys: List[str] = [] | |
| if self._teacher_has_mlp_experts: | |
| router_keys.append(f"{layer_prefix}.{layer_idx}.mlp.gate.weight") | |
| if self._teacher_has_block_sparse_moe: | |
| router_keys.append(f"{layer_prefix}.{layer_idx}.block_sparse_moe.gate.weight") | |
| # fallback variants | |
| if self._teacher_has_mlp_experts: | |
| router_keys.append(f"{layer_prefix}.{layer_idx}.mlp.router.weight") | |
| if self._teacher_has_block_sparse_moe: | |
| router_keys.append(f"{layer_prefix}.{layer_idx}.block_sparse_moe.router.weight") | |
| router_key = next((k for k in router_keys if k in teacher_weight_map), None) | |
| if router_key is None: | |
| return w | |
| try: | |
| t = self.load_tensors_with_prefetch([router_key], self.config.teacher_path, teacher_weight_map)[router_key] | |
| t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).float() | |
| if t.dim() == 2 and t.shape[0] == num_experts: | |
| scores = torch.linalg.vector_norm(t, ord=2, dim=1) | |
| # Softmax on scores for stability | |
| w = torch.softmax(scores, dim=0) | |
| except Exception: | |
| # Keep uniform | |
| pass | |
| return w | |
| def _synthesize_moe_dense_part( | |
| self, | |
| layer_prefix: str, | |
| layer_idx: int, | |
| part: str, | |
| teacher_weight_map: Dict[str, str], | |
| ) -> Optional[torch.Tensor]: | |
| """Synthesize a dense MLP matrix from teacher MoE experts for a single layer and part. | |
| part is one of: 'gate_proj', 'up_proj', 'down_proj' | |
| """ | |
| self._init_teacher_moe_style(teacher_weight_map) | |
| if self._teacher_moe_style in (None, "none"): | |
| return None | |
| num_experts = int(getattr(self.teacher_arch, "num_experts", 0) or 0) | |
| if num_experts <= 0: | |
| # Try infer from keys (robust fallback) | |
| max_e = -1 | |
| rx = re.compile(rf"^{re.escape(layer_prefix)}\.{layer_idx}\.(?:mlp\.experts|block_sparse_moe\.experts)\.(\d+)\.") | |
| for k in teacher_weight_map.keys(): | |
| m = rx.match(k) | |
| if m: | |
| max_e = max(max_e, int(m.group(1))) | |
| num_experts = max_e + 1 | |
| if num_experts <= 0: | |
| return None | |
| # Cache key | |
| cache_key = (layer_prefix, int(layer_idx), str(part)) | |
| if cache_key in self._moe_dense_cache: | |
| return self._moe_dense_cache[cache_key] | |
| # Map dense part -> possible expert param names | |
| part_candidates: List[str] = [] | |
| if part == "gate_proj": | |
| part_candidates = ["gate_proj", "w1"] | |
| elif part == "up_proj": | |
| part_candidates = ["up_proj", "w3"] | |
| elif part == "down_proj": | |
| part_candidates = ["down_proj", "w2"] | |
| else: | |
| return None | |
| # Estimate routing weights (data-free, router norm) | |
| w = self._estimate_moe_expert_weights(layer_prefix, layer_idx, teacher_weight_map, num_experts) | |
| # Top-k experts (use existing config knob max_experts_to_blend; if <=0, keep all) | |
| topk = int(self.config.max_experts_to_blend or 0) | |
| if topk <= 0: | |
| topk = num_experts | |
| topk = min(topk, num_experts) | |
| # Select experts | |
| if topk < num_experts: | |
| vals, idxs = torch.topk(w, k=topk, largest=True) | |
| sel_experts = idxs.tolist() | |
| sel_w = (vals / (vals.sum() + 1e-12)).tolist() | |
| else: | |
| sel_experts = list(range(num_experts)) | |
| sel_w = (w / (w.sum() + 1e-12)).tolist() | |
| # Build actual keys to load for selected experts | |
| expert_keys: List[str] = [] | |
| expert_key_for: Dict[int, str] = {} | |
| def _first_existing(cands: List[str]) -> Optional[str]: | |
| for kk in cands: | |
| if kk in teacher_weight_map: | |
| return kk | |
| return None | |
| for e in sel_experts: | |
| cands: List[str] = [] | |
| if self._teacher_has_mlp_experts: | |
| for pn in part_candidates: | |
| cands.append(f"{layer_prefix}.{layer_idx}.mlp.experts.{e}.{pn}.weight") | |
| if self._teacher_has_block_sparse_moe: | |
| for pn in part_candidates: | |
| cands.append(f"{layer_prefix}.{layer_idx}.block_sparse_moe.experts.{e}.{pn}.weight") | |
| kk = _first_existing(cands) | |
| if kk is None: | |
| continue | |
| expert_key_for[e] = kk | |
| expert_keys.append(kk) | |
| if not expert_keys: | |
| return None | |
| # Load expert weights (streamed accumulation) | |
| tensors = self.load_tensors_with_prefetch(expert_keys, self.config.teacher_path, teacher_weight_map) | |
| acc: Optional[torch.Tensor] = None | |
| total_w = 0.0 | |
| for e, we in zip(sel_experts, sel_w): | |
| kk = expert_key_for.get(e) | |
| if kk is None: | |
| continue | |
| t = tensors.get(kk, None) | |
| if t is None: | |
| continue | |
| # Accumulate in fp32 on CPU (keeps VRAM low, maximizes precision) | |
| t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0).float() | |
| if acc is None: | |
| acc = t.mul(float(we)) | |
| else: | |
| acc.add_(t, alpha=float(we)) | |
| total_w += float(we) | |
| if acc is None or total_w <= 0.0: | |
| return None | |
| # Normalize in case some experts were missing | |
| if abs(total_w - 1.0) > 1e-3: | |
| acc.div_(total_w) | |
| self._moe_dense_cache[cache_key] = acc | |
| return acc | |
def _maybe_synthesize_teacher_from_moe(
    self,
    layer_prefix: str,
    t_floor: int,
    t_ceil: int,
    mapped_rest: str,
    teacher_weight_map: Dict[str, str],
) -> Dict[str, torch.Tensor]:
    """Build dense stand-in teacher matrices when the teacher is MoE and the student is not.

    Only the ``mlp.gate_proj`` / ``mlp.up_proj`` / ``mlp.down_proj`` weights
    are handled. The result maps a synthetic ``<moe_synth>`` key for the floor
    layer (and for the ceil layer when it differs) to the blended matrix; a
    layer whose synthesis yields ``None`` is left out. An empty dict means
    "not applicable" or "nothing could be synthesized".
    """
    if self.teacher_arch.mlp_type != MLPType.MOE:
        return {}
    if self.student_arch.mlp_type == MLPType.MOE:
        return {}

    # Wrap in dots so the projection name has to match a whole key segment.
    dotted = f".{mapped_rest}."
    part = next(
        (name for name in ("gate_proj", "up_proj", "down_proj") if f".mlp.{name}." in dotted),
        None,
    )
    if part is None:
        return {}

    floor_idx = int(t_floor)
    ceil_idx = int(t_ceil)
    # (index as written into the output key, index used for synthesis)
    wanted = [(t_floor, floor_idx)]
    if ceil_idx != floor_idx:
        wanted.append((t_ceil, ceil_idx))

    synthesized: Dict[str, torch.Tensor] = {}
    for label, idx in wanted:
        dense = self._synthesize_moe_dense_part(layer_prefix, idx, part, teacher_weight_map)
        if dense is not None:
            synthesized[f"{layer_prefix}.{label}.<moe_synth>.mlp.{part}.weight"] = dense
    return synthesized
def process_layer(
    self,
    student_key: str,
    student_tensor: torch.Tensor,
    teacher_tensors: Dict[str, torch.Tensor],
    interp_w: float,
) -> Optional[Tuple[torch.Tensor, torch.Tensor, int]]:
    """Turn one student weight plus its mapped teacher weight(s) into a LoRA pair.

    Args:
        student_key: Student tensor name; used to choose the attention / MLP
            rank override and the calibration module name.
        student_tensor: The student weight.
        teacher_tensors: One (single layer) or two (floor, ceil — in insertion
            order) teacher tensors.
        interp_w: Mixing weight applied to the second (ceil) teacher tensor.

    Returns:
        ``(lora_A, lora_B, actual_rank)``, or ``None`` when there is no teacher
        tensor or the extraction yields nothing usable.
    """
    if not teacher_tensors:
        return None

    cfg = self.config
    teacher_keys = list(teacher_tensors.keys())
    student_on_device = student_tensor.to(self.device, non_blocking=True)
    target_shape = student_tensor.shape

    def _rank_hint():
        """Configured rank for this key: attention/MLP override if set, else the default."""
        is_attn = any(proj in student_key for proj in ("q_proj", "k_proj", "v_proj", "o_proj"))
        is_mlp = any(proj in student_key for proj in ("gate_proj", "up_proj", "down_proj"))
        if is_attn and cfg.rank_attn is not None:
            return cfg.rank_attn
        if is_mlp and cfg.rank_mlp is not None:
            return cfg.rank_mlp
        return cfg.rank_default

    def _clip_delta(delta: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Scale ``delta`` down when ||delta|| / ||reference|| exceeds config.max_delta_ratio."""
        max_ratio = getattr(cfg, "max_delta_ratio", 0.0)
        if not (max_ratio and max_ratio > 0):
            return delta
        denom = torch.norm(reference.float()) + 1e-8
        ratio = torch.norm(delta.float()) / denom
        if ratio > max_ratio:
            scale = float((max_ratio / (ratio + 1e-12)).clamp(max=1.0).item())
            delta = delta * scale
            print(
                f" [delta-clip] scaled by {scale:.4f} (ratio {float(ratio.item()):.4f} -> {max_ratio:.4f})"
            )
        return delta

    # ---------------------------------------------------------------------
    # Cross-architecture safe path:
    # When projection is needed (teacher matrix shape != student matrix shape),
    # do NOT compute delta = (projected_teacher - student_full).
    # That subtracts the student's residual space and pushes the weight toward
    # a low-rank matrix, which is a common cause of degenerate/infinite output.
    #
    # Instead:
    #   1) compute student's truncated SVD basis (U_s, S_s, V_s)
    #   2) compute teacher core in that basis: core_t = U_s^T Wt_resized V_s
    #   3) compute student core: core_s = diag(S_s)
    #   4) delta_core = core_t - core_s (only inside the shared subspace)
    #   5) extract LoRA directly from delta_core without forming the full delta
    # ---------------------------------------------------------------------
    try_subspace = (
        cfg.use_svd_projection
        and self.is_cross_arch
        and (getattr(cfg, "subspace_delta", "off") != "off")
        and student_on_device.dim() == 2
        and len(target_shape) == 2
        and any(tuple(t.shape) != tuple(target_shape) for t in teacher_tensors.values())
    )
    if try_subspace:
        try:
            # Student basis (common space)
            U_s, S_s, V_s = compute_student_svd_basis(
                student_on_device, self.svd_engine, cfg, self.device
            )
            diag_Ss = torch.diag(S_s).to(device=self.device, dtype=torch.float32)

            def _teacher_core(key: str) -> torch.Tensor:
                """Express one teacher tensor in the student's (U_s, V_s) basis."""
                return teacher_core_in_student_basis(
                    teacher_tensors[key], U_s, V_s, self.svd_engine, cfg, self.device
                )

            if len(teacher_keys) == 1:
                delta_core = _teacher_core(teacher_keys[0]) - diag_Ss
                if cfg.use_dare:
                    delta_core = apply_dare(delta_core, cfg.dare_drop_rate)
                if cfg.use_ties and getattr(cfg, "ties_trim_single", True):
                    delta_core = apply_ties_single(delta_core, cfg.ties_density)
            else:
                core_floor = _teacher_core(teacher_keys[0])
                core_ceil = _teacher_core(teacher_keys[1])
                delta_floor = core_floor - diag_Ss
                delta_ceil = core_ceil - diag_Ss
                mix_w = float(interp_w)
                if getattr(cfg, "interp_mode", "schedule") == "lsq":
                    mix_w = lsq_mix_weight_from_deltas(delta_floor, delta_ceil, mix_w)
                if cfg.use_ties:
                    deltas = [delta_floor * (1.0 - mix_w), delta_ceil * mix_w]
                    if cfg.use_dare:
                        deltas = [apply_dare(d, cfg.dare_drop_rate) for d in deltas]
                    delta_core = apply_ties(deltas, cfg.ties_density)
                else:
                    core_t = core_floor * (1.0 - mix_w) + core_ceil * mix_w
                    delta_core = core_t - diag_Ss
                    if cfg.use_dare:
                        delta_core = apply_dare(delta_core, cfg.dare_drop_rate)

            # Subspace-delta stabilization / mode control
            mode = getattr(cfg, "subspace_delta", "off")

            # Optional: match teacher core scale to student core scale
            # (the scale factor is clamped to [0.25, 4.0]).
            if getattr(cfg, "subspace_norm_match", True):
                core_s_norm = torch.norm(diag_Ss.float()) + 1e-8
                core_t = (delta_core + diag_Ss).to(dtype=torch.float32)
                core_t_norm = torch.norm(core_t) + 1e-8
                scale_t = (core_s_norm / core_t_norm).clamp(0.25, 4.0)
                core_t = core_t * scale_t
                delta_core = (core_t - diag_Ss).to(dtype=delta_core.dtype)

            # Optional: suppress off-diagonal energy. "diag" always keeps only
            # the diagonal; "auto" does so when the off-diagonal share of the
            # norm exceeds the configured threshold.
            if mode in ("diag", "auto"):
                diag_part = torch.diag(torch.diag(delta_core))
                off_part = delta_core - diag_part
                off_frac = float((torch.norm(off_part.float()) / (torch.norm(delta_core.float()) + 1e-8)).item())
                thr = float(getattr(cfg, "subspace_offdiag_threshold", 0.55))
                if mode == "diag" or (mode == "auto" and off_frac > thr):
                    delta_core = diag_part

            # Delta ratio clipping, measured against the full student weight
            # or the student core depending on config.subspace_denom.
            denom_mode = getattr(cfg, "subspace_denom", "core")
            delta_core = _clip_delta(
                delta_core, student_on_device if denom_mode == "full" else diag_Ss
            )

            # Output dtype: match student for plug-and-play inference
            out_dtype = student_on_device.dtype if cfg.use_mixed_precision else torch.float32
            lora_A, lora_B, actual_rank = extract_lora_from_core_delta(
                delta_core, U_s, V_s, int(_rank_hint()), cfg, out_dtype
            )
            if lora_A is None or lora_B is None or actual_rank <= 0:
                return None
            # A factor pair that does not fit the student weight makes the
            # layer be skipped (None); it does not drop to the legacy path.
            if lora_A.shape[1] != target_shape[1] or lora_B.shape[0] != target_shape[0]:
                return None
            return lora_A, lora_B, int(actual_rank)
        except Exception as exc:
            # Fall back to the legacy path below, but report it: that path is
            # exactly the one the note above warns about, so a silent fallback
            # would hide the reason for a degraded adapter.
            print(
                f" [subspace-fallback] {student_key}: {type(exc).__name__}: {exc}; using legacy projection path"
            )

    # ---------------------------------------------------------------------
    # Legacy path: teacher and student already share shape, or projection is off.
    # ---------------------------------------------------------------------
    deltas = []
    for teacher_key in teacher_keys:
        teacher_tensor = teacher_tensors[teacher_key].to(self.device)
        # Cross-architecture projection (legacy): only used when caller didn't
        # trigger the safer subspace path above.
        if cfg.use_svd_projection and self.is_cross_arch:
            teacher_tensor = svd_projection(
                teacher_tensor, target_shape, self.svd_engine, cfg, self.device, ref=student_on_device
            )
        else:
            teacher_tensor = self._project_tensor(teacher_tensor, target_shape)
        deltas.append(teacher_tensor.to(self.device) - student_on_device)

    # Apply interpolation or TIES between floor/ceil deltas
    mix_w = float(interp_w)
    if len(deltas) == 2 and getattr(cfg, "interp_mode", "schedule") == "lsq":
        mix_w = lsq_mix_weight_from_deltas(deltas[0], deltas[1], mix_w)

    if len(deltas) == 1:
        final_delta = deltas[0]
        if cfg.use_dare:
            final_delta = apply_dare(final_delta, cfg.dare_drop_rate)
        if cfg.use_ties and getattr(cfg, "ties_trim_single", True):
            final_delta = apply_ties_single(final_delta, cfg.ties_density)
    elif cfg.use_ties:
        weighted_deltas = [deltas[0] * (1 - mix_w), deltas[1] * mix_w]
        if cfg.use_dare:
            weighted_deltas = [apply_dare(d, cfg.dare_drop_rate) for d in weighted_deltas]
        final_delta = apply_ties(weighted_deltas, cfg.ties_density)
    else:
        final_delta = deltas[0] * (1 - mix_w) + deltas[1] * mix_w
        if cfg.use_dare:
            final_delta = apply_dare(final_delta, cfg.dare_drop_rate)

    # Final delta ratio clipping (after interpolation/TIES/DARE)
    final_delta = _clip_delta(final_delta, student_on_device)

    # LoRA extraction (PiSSA). Uses the calibrated variant when a calibration
    # mode is configured and calibration statistics are available.
    rank = _rank_hint()
    module_name = student_key[:-len(".weight")] if student_key.endswith(".weight") else student_key
    if getattr(cfg, "calibration_mode", "none") != "none" and self.calib_stats:
        lora_A, lora_B, actual_rank = extract_lora_pissa_calibrated(
            final_delta, rank, self.svd_engine, cfg, module_name, self.calib_stats, self.student_arch
        )
    else:
        lora_A, lora_B, actual_rank = extract_lora_pissa(final_delta, rank, self.svd_engine, cfg)
    if lora_A is None:
        return None
    return lora_A, lora_B, actual_rank
| def _project_tensor(self, src: torch.Tensor, | |
| target_shape: Tuple[int, ...]) -> torch.Tensor: | |
| """Simple projection (fallback).""" | |
| if src.shape == target_shape: | |
| return src | |
| if src.dim() == 1: | |
| out = torch.zeros(target_shape, device=src.device, dtype=src.dtype) | |
| copy_len = min(src.numel(), target_shape[0]) | |
| out[:copy_len] = src[:copy_len] | |
| return out | |
| if src.dim() == 2: | |
| out = torch.zeros(target_shape, device=src.device, dtype=src.dtype) | |
| copy_rows = min(src.shape[0], target_shape[0]) | |
| copy_cols = min(src.shape[1], target_shape[1]) | |
| out[:copy_rows, :copy_cols] = src[:copy_rows, :copy_cols] | |
| return out | |
| return torch.zeros(target_shape, device=src.device, dtype=src.dtype) | |
def run(self,
        student_keys: List[str],
        teacher_weight_map: Dict[str, str],
        student_weight_map: Dict[str, str],
        teacher_layers_map: Dict[str, List[int]],
        student_layers_map: Dict[str, List[int]],
        teacher_arch: ArchitectureInfo,
        student_arch: ArchitectureInfo,
        layer_name_mapping: Dict[str, str]) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """Run accelerated distillation over every eligible student weight.

    Student keys are filtered, grouped by shard and loaded in micro-batches.
    For each key the matching teacher layer(s) are loaded (or synthesized from
    MoE experts when no direct key exists) and handed to ``process_layer``.

    Args:
        student_keys: Candidate student tensor names.
        teacher_weight_map: Teacher tensor name -> shard file.
        student_weight_map: Student tensor name -> shard file.
        teacher_layers_map: Used only to derive the teacher layer count
            (longest list; 32 when the map is empty).
        student_layers_map: Same, for the student.
        teacher_arch: Teacher architecture info (read for the embedding warning).
        student_arch: Student architecture info (read for the embedding warning).
        layer_name_mapping: Student -> teacher substring replacements applied
            to the part of the key that follows the layer index.

    Returns:
        ``(self.lora_weights, self.rank_info)``: LoRA A/B tensors (on CPU)
        keyed by PEFT name, and the rank chosen for each PEFT module key.
    """
    s_layers = max(len(v) for v in student_layers_map.values()) if student_layers_map else 32
    t_layers = max(len(v) for v in teacher_layers_map.values()) if teacher_layers_map else 32

    # ---- 1) select the keys to process ----
    skip_embed_warning_shown = False
    processable_keys = []
    for k in student_keys:
        # Only ".weight" tensors are candidates (this also excludes biases).
        if not k.endswith(".weight"):
            continue
        if any(tok in k.lower() for tok in NORM_TOKENS):
            continue
        # Embeddings: opt-in only, and always skipped for cross-architecture runs.
        if any(tok in k for tok in EMBED_TOKENS):
            if not self.config.include_embed_lm_head:
                continue
            if self.is_cross_arch:
                if not skip_embed_warning_shown and self.config.verbose:
                    print(f"\n⚠️ Skipping embeddings for cross-architecture distillation")
                    print(f" (Teacher hidden={teacher_arch.hidden_size}, Student hidden={student_arch.hidden_size})")
                    skip_embed_warning_shown = True
                continue
        if self.config.include_pattern and not re.search(self.config.include_pattern, k):
            continue
        if self.config.exclude_pattern and re.search(self.config.exclude_pattern, k):
            continue
        processable_keys.append(k)

    # ---- 2) group by student shard so each shard is streamed in turn ----
    keys_by_shard = defaultdict(list)
    for k in processable_keys:
        if k in student_weight_map:
            keys_by_shard[student_weight_map[k]].append(k)

    # A non-positive micro-batch size would make range() raise; treat it as 1.
    step = max(1, int(self.config.micro_batch_size))

    pbar = tqdm(total=len(processable_keys),
                desc="Distilling",
                dynamic_ncols=True,
                disable=not self.config.verbose)
    try:
        for shard_file, shard_keys in sorted(keys_by_shard.items()):
            pbar.set_description(f"Processing {os.path.basename(shard_file)}")
            for i in range(0, len(shard_keys), step):
                micro_keys = shard_keys[i:i + step]
                next_keys = shard_keys[i + step:i + 2 * step]
                student_tensors = self.load_tensors_with_prefetch(
                    micro_keys, self.config.student_path, student_weight_map, next_keys
                )
                for sk in micro_keys:
                    pbar.update(1)
                    st = student_tensors.get(sk)
                    if st is None:
                        continue
                    seg = split_key(sk)
                    if seg is None:
                        continue
                    s_prefix, _token_name, s_idx, rest = seg

                    # Map the student layer onto a teacher floor layer plus a
                    # mixing weight; ceil is the next layer, clamped to the last.
                    t_floor, interp_w = teacher_idx_from_student_idx(
                        s_idx, s_layers, t_layers,
                        self.config.map_schedule, self.config.sigmoid_k
                    )
                    t_ceil = min(t_floor + 1, t_layers - 1)

                    mapped_rest = rest
                    for s_name, t_name in layer_name_mapping.items():
                        mapped_rest = mapped_rest.replace(s_name, t_name)

                    # The teacher key reuses the student's prefix.
                    key_floor = f"{s_prefix}.{t_floor}.{mapped_rest}"
                    key_ceil = f"{s_prefix}.{t_ceil}.{mapped_rest}"
                    # Floor and ceil coincide on the last teacher layer; keep
                    # one entry so the same tensor is not requested twice.
                    candidate_keys = [key_floor] if key_ceil == key_floor else [key_floor, key_ceil]
                    teacher_keys_to_load = [tk for tk in candidate_keys if tk in teacher_weight_map]

                    if teacher_keys_to_load:
                        teacher_tensors = self.load_tensors_with_prefetch(
                            teacher_keys_to_load, self.config.teacher_path, teacher_weight_map
                        )
                    else:
                        # No direct teacher key: try building one from MoE experts.
                        teacher_tensors = self._maybe_synthesize_teacher_from_moe(
                            s_prefix, int(t_floor), int(t_ceil), mapped_rest, teacher_weight_map
                        )
                    if not teacher_tensors:
                        continue

                    result = self.process_layer(sk, st, teacher_tensors, interp_w)
                    if result is not None:
                        lora_A, lora_B, actual_rank = result
                        peft_key = to_peft_key(sk)
                        self.lora_weights[f"{peft_key}.lora_A.default.weight"] = lora_A.detach().cpu().contiguous()
                        self.lora_weights[f"{peft_key}.lora_B.default.weight"] = lora_B.detach().cpu().contiguous()
                        self.rank_info[peft_key] = actual_rank
                        self.stats["processed"] += 1
                    # Drop the teacher tensors before the next key is loaded.
                    del teacher_tensors
                del student_tensors
                # Clear the allocator cache on every fourth micro-batch.
                if i % (step * 4) == 0:
                    self.memory_manager.clear_cache()
    finally:
        # Close the progress bar even when a layer raises.
        pbar.close()

    self.memory_manager.clear_cache()
    self.offloader.cleanup()
    return self.lora_weights, self.rank_info
| # ============================================================================= | |
| # MAIN | |
| # ============================================================================= | |
| def build_argparser() -> argparse.ArgumentParser: | |
| p = argparse.ArgumentParser( | |
| description="Universal SVD-LoRA Distillation v4.2.0 (Cross-Arch Innovations) โ Plug-and-Play Truncated SVD (Low-VRAM)", | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| epilog=""" | |
| v3.7 highlights: | |
| - Truncated / randomized SVD by default in 'auto' mode (avoids full SVD on large matrices). | |
| - Efficient SVD-based projection for cross-architecture distillation with bounded projection rank. | |
| - Keeps shard-wise streaming + offload (low VRAM) behavior from v3.2. | |
| Examples: | |
| # Cross-architecture distillation (e.g., 72B -> 36B) | |
| python %(prog)s --teacher /path/to/teacher --student /path/to/student --output /path/to/lora \ | |
| --adaptive-rank --min-rank 256 --max-rank 2048 --energy-threshold 0.95 \ | |
| --svd-mode auto --projection-rank 256 --svd-projection \ | |
| --ties --ties-density 0.3 --no-dare \ | |
| --include "self_attn|mlp" --offload-strategy cpu --mixed-precision | |
| """ | |
| ) | |
| # Required | |
| p.add_argument("--teacher", required=True, help="Teacher model folder (HF safetensors)") | |
| p.add_argument("--student", required=True, help="Student model folder (HF safetensors)") | |
| p.add_argument("--output", required=True, help="Output directory for PEFT adapter") | |
| # Rank controls | |
| p.add_argument("--rank", type=int, default=64, help="Default LoRA rank (if not using rank-attn/mlp)") | |
| p.add_argument("--rank-attn", type=int, default=None, help="Override LoRA rank for attention projections") | |
| p.add_argument("--rank-mlp", type=int, default=None, help="Override LoRA rank for MLP projections") | |
| p.add_argument("--lora-alpha", type=int, default=None, help="LoRA alpha; default=r (scale=1)") | |
| p.add_argument("--alpha-mode", type=str, default="auto", choices=["auto","rank","fixed"], help="LoRA alpha policy: auto=rank if --lora-alpha not set; rank=alpha_pattern=r (scale~1); fixed=global alpha only") | |
| p.add_argument("--adaptive-rank", action="store_true", default=True, help="Enable per-layer adaptive rank") | |
| p.add_argument("--no-adaptive-rank", action="store_false", dest="adaptive_rank") | |
| p.add_argument("--energy-threshold", type=float, default=0.95, help="Energy threshold for adaptive rank (0-1)") | |
| p.add_argument("--min-rank", type=int, default=8, help="Minimum rank when adaptive rank is enabled") | |
| p.add_argument("--max-rank", type=int, default=256, help="Maximum rank when adaptive rank is enabled") | |
| # Layer mapping | |
| p.add_argument("--map-schedule", choices=["linear", "sigmoid"], default="linear", | |
| help="Teacher->student layer mapping schedule") | |
| p.add_argument("--sigmoid-k", type=float, default=0.15, help="Sigmoid schedule slope") | |
| p.add_argument("--interp-mode", choices=["schedule", "lsq"], default="schedule", | |
| help="Teacher floor/ceil mix: schedule=use mapping schedule weight; lsq=per-matrix least-squares mix weight in student space (clamped 0..1).") | |
| # Merge / regularization | |
| # DARE (off by default in v3.7) | |
| p.add_argument("--dare", action="store_true", default=False, help="Enable DARE (dropout on delta)") | |
| p.add_argument("--no-dare", action="store_false", dest="dare") | |
| p.add_argument("--dare-drop-rate", type=float, default=0.3, help="DARE dropout probability") | |
| p.add_argument("--auto-adjust-dare", action="store_true", default=True, | |
| help="Auto-adjust DARE for cross-architecture (when DARE enabled)") | |
| p.add_argument("--no-auto-adjust-dare", action="store_false", dest="auto_adjust_dare") | |
| # TIES (on by default) | |
| p.add_argument("--ties", action="store_true", default=True, help="Enable TIES merging when combining deltas") | |
| p.add_argument("--no-ties", action="store_false", dest="ties") | |
| p.add_argument("--ties-density", type=float, default=0.3, help="Keep ratio for TIES trimming (0-1)") | |
| p.add_argument("--ties-trim-single", action="store_true", default=True, | |
| help="Apply the TIES Trim step to single-delta cases (default: on; matches v3.7.15).") | |
| p.add_argument("--no-ties-trim-single", action="store_false", dest="ties_trim_single", | |
| help="Disable TIES Trim for single-delta cases.") | |
| # Stability guard | |
| p.add_argument("--max-delta-ratio", type=float, default=0.35, | |
| help="If >0, scales delta when ||delta||/||student|| exceeds this value") | |
| # SVD backend controls | |
| svd_group = p.add_argument_group("SVD Options") | |
| svd_group.add_argument("--svd-mode", choices=["auto", "full", "randomized", "krylov", "adaptive", "drf", "aurora"], default="auto", | |
| help="SVD backend: auto|full|randomized|krylov|adaptive|drf|aurora") | |
| svd_group.add_argument("--svd-rand-iter", type=int, default=2, | |
| help="Low-rank iterations: randomized=subspace iters (Alg. 4.4), krylov=depth, adaptive=per-block iters") | |
| svd_group.add_argument("--svd-rand-oversamples", type=int, default=8, | |
| help="Low-rank oversamples (k = rank + oversamples) for randomized/krylov") | |
| svd_group.add_argument("--svd-auto-min-dim", type=int, default=1024, | |
| help="In auto mode, prefer low-rank SVD when min(m,n) >= this") | |
| svd_group.add_argument("--svd-auto-full-rank-ratio", type=float, default=0.6, | |
| help="In auto mode, use full SVD when rank >= ratio*min(m,n)") | |
| svd_group.add_argument("--svd-auto-lowrank", choices=["randomized", "krylov", "drf", "aurora"], default="randomized", | |
| help="In auto mode (when low-rank is selected), choose randomized, krylov, drf, or aurora") | |
| svd_group.add_argument("--svd-adaptive-block-size", type=int, default=32, | |
| help="(adaptive) columns added per adaptive iteration (Alg. 4.2)") | |
| svd_group.add_argument("--svd-adaptive-n-test", type=int, default=8, | |
| help="(adaptive) test vectors for residual estimate (Alg. 4.2)") | |
| svd_group.add_argument("--svd-drf-steps", type=int, default=1, | |
| help="(drf) Number of dual-residual folding refinement rounds (>=0).") | |
| svd_group.add_argument("--svd-drf-theta", type=float, default=0.5, | |
| help="(drf) Residual scaling exponent in [0,1]. 0=no scaling; 1=full 1/sigma scaling.") | |
| svd_group.add_argument("--svd-drf-resid-eps", type=float, default=1e-8, | |
| help="(drf) Numerical floor for sigma in residual scaling.") | |
| svd_group.add_argument("--svd-aurora-steps", type=int, default=1, | |
| help="(aurora) Number of AURORA refinement rounds (>=0).") | |
| svd_group.add_argument("--svd-aurora-order", type=int, choices=[1, 2], default=2, | |
| help="(aurora) Residual order: 1=first-order residual only; 2=adds second-order residual response via A/A^T.") | |
| svd_group.add_argument("--svd-aurora-theta1", type=float, default=0.5, | |
| help="(aurora) Level-1 residual scaling exponent in [0,1]. 0=no scaling; 1=full 1/sigma scaling.") | |
| svd_group.add_argument("--svd-aurora-theta2", type=float, default=0.25, | |
| help="(aurora) Level-2 residual scaling exponent in [0,1].") | |
| svd_group.add_argument("--svd-aurora-resid-eps", type=float, default=1e-8, | |
| help="(aurora) Numerical floor for sigma in residual scaling.") | |
| svd_group.add_argument("--svd-aurora-level2-keep", type=float, default=1.0, | |
| help="(aurora) Fraction (0..1] of components kept for the level-2 term; <1 keeps only the most residual-dominant components.") | |
| # Cross-architecture projection controls | |
| cross_group = p.add_argument_group("Cross-Architecture Projection") | |
| cross_group.add_argument("--svd-projection", action="store_true", default=True, | |
| help="Enable SVD-based teacher->student tensor projection") | |
| cross_group.add_argument("--no-svd-projection", action="store_false", dest="svd_projection") | |
| cross_group.add_argument("--projection-rank", type=int, default=256, | |
| help="Max rank used for SVD projection (bounded for cost control)") | |
| cross_group.add_argument("--projection-min-rank", type=int, default=8, | |
| help="Min rank for projection when projection-adaptive-rank is enabled") | |
| cross_group.add_argument("--projection-adaptive-rank", action="store_true", default=True, | |
| help="Enable energy-based adaptive rank for projection SVD") | |
| cross_group.add_argument("--no-projection-adaptive-rank", action="store_false", dest="projection_adaptive_rank") | |
| cross_group.add_argument("--projection-energy-threshold", type=float, default=0.99, | |
| help="Energy threshold for projection adaptive rank (0-1)") | |
| cross_group.add_argument("--projection-rand-iter", type=int, default=2, | |
| help="Randomized SVD power iterations for projection") | |
| cross_group.add_argument("--projection-rand-oversamples", type=int, default=8, | |
| help="Randomized SVD oversamples for projection") | |
| cross_group.add_argument("--subspace-delta", type=str, default="off", | |
| choices=["off", "auto", "diag", "full"], | |
| help="Cross-arch: subspace-delta mode. off=disable (use full-delta), diag=diagonal core only, full=full core, auto=diag if off-diagonal energy is high") | |
| cross_group.add_argument("--subspace-offdiag-threshold", type=float, default=0.55, | |
| help="subspace-delta auto: if off-diagonal Frobenius fraction > threshold, fall back to diag") | |
| cross_group.add_argument("--subspace-denom", type=str, default="core", choices=["core", "full"], | |
| help="subspace-delta delta-clip denominator: core=||diag(S_student)||, full=||W_student||") | |
| cross_group.add_argument("--no-subspace-norm-match", action="store_true", | |
| help="Disable subspace core norm matching (scale teacher core to student core norm)") | |
| # Performance / memory | |
| perf = p.add_argument_group("Performance / Memory") | |
| perf.add_argument("--num-gpus", type=int, default=1) | |
| perf.add_argument("--gpu-memory-fraction", type=float, default=0.95) | |
| perf.add_argument("--no-cuda-streams", action="store_true", help="Disable CUDA streams") | |
| perf.add_argument("--cuda-streams", type=int, default=3) | |
| perf.add_argument("--svd-batch-size", type=int, default=1) | |
| perf.add_argument("--micro-batch", type=int, default=8, help="Tensors per IO micro-batch") | |
| perf.add_argument("--prefetch-layers", type=int, default=2) | |
| perf.add_argument("--offload-strategy", choices=["cpu", "disk", "none"], default="cpu") | |
| perf.add_argument("--offload-dir", type=str, default=None) | |
| perf.add_argument("--pinned-memory", action="store_true", default=True) | |
| perf.add_argument("--no-pinned-memory", action="store_false", dest="pinned_memory") | |
| perf.add_argument("--max-cpu-memory", type=float, default=96.0) | |
| perf.add_argument("--chunk-threshold", type=int, default=512, help="MB: use randomized SVD above this size") | |
| perf.add_argument("--chunk-size", type=int, default=128, help="MB: chunk size (reserved for future use)") | |
| # Precision | |
| p.add_argument("--mixed-precision", action="store_true", default=True) | |
| p.add_argument("--no-mixed-precision", action="store_false", dest="mixed_precision") | |
| # Module selection | |
| p.add_argument("--include", type=str, default="self_attn|mlp", | |
| help="Regex include filter for tensor keys") | |
| p.add_argument("--exclude", type=str, default="", | |
| help="Regex exclude filter for tensor keys") | |
| p.add_argument("--include-embed", action="store_true", default=False, | |
| help="Include embedding and lm_head weights (only safe when shapes match)") | |
| # Calibration-aware SVD (optional) | |
| calib = p.add_argument_group("Calibration-aware SVD (optional)") | |
| calib.add_argument("--calib-mode", choices=["none", "rms", "cov"], default="none", | |
| help="Enable calibration-aware whitening before SVD. none=disable, rms=diagonal RMS whitening, cov=full-covariance (SVD-LLM-style) whitening.") | |
| calib.add_argument("--calib-data", type=str, default=None, | |
| help="Calibration prompts file: .txt (1 prompt/line) or .jsonl ({\"text\":...}). If omitted, a small built-in prompt set is used.") | |
| calib.add_argument("--calib-format", choices=["auto", "txt", "jsonl", "json", "alpaca"], default="auto", | |
| help="Calibration file format (auto uses extension).") | |
| calib.add_argument("--calib-alpaca-template", choices=["classic", "plain"], default="classic", | |
| help="Alpaca prompt template for calibration (only for calib-format=alpaca/auto).") | |
| calib.add_argument("--calib-alpaca-include-output", action="store_true", | |
| help="Include Alpaca output field in calibration forward (alpaca only).") | |
| calib.add_argument("--calib-max-samples", type=int, default=128, | |
| help="Max number of calibration prompts to use.") | |
| calib.add_argument("--calib-max-length", type=int, default=256, | |
| help="Max token length per prompt (tokenizer truncation).") | |
| calib.add_argument("--calib-padding", type=str, default="longest", choices=["longest","max_length"], | |
| help="Padding strategy for calibration tokenization (longest minimizes padding).") | |
| calib.add_argument("--calib-no-attention-mask", action="store_true", | |
| help="Disable attention_mask-based padding exclusion in calibration stats.") | |
| calib.add_argument("--calib-batch-size", type=int, default=1, | |
| help="Batch size for calibration forward passes.") | |
| calib.add_argument("--calib-device", choices=["auto", "cuda", "cpu"], default="auto", | |
| help="Device placement for calibration model forward (auto=cuda if available).") | |
| calib.add_argument("--calib-dtype", choices=["auto", "bf16", "fp16", "fp32"], default="auto", | |
| help="dtype for calibration forward (auto: bf16 if supported on cuda else fp16; cpu->fp32).") | |
| calib.add_argument("--calib-load", type=str, default=None, | |
| help="Load precomputed calibration stats (.safetensors). Skips forward calibration.") | |
| calib.add_argument("--calib-save", type=str, default=None, | |
| help="Save calibration stats (.safetensors) for reuse.") | |
| calib.add_argument("--calib-eps", type=float, default=1e-6, | |
| help="Epsilon for whitening (clamp min).") | |
| calib.add_argument("--calib-cov-chunk-layers", type=int, default=4, | |
| help="(cov mode) Number of layers per calibration pass (controls memory).") | |
| calib.add_argument("--calib-cov-max-dim", type=int, default=8192, | |
| help="(cov mode) Max in_features dimension to collect full covariance (skip larger dims).") | |
| calib.add_argument("--calib-cov-groups", type=str, default="qkv,mlp,o", | |
| help="(cov mode) Comma-separated groups to collect: qkv,o,mlp,down.") | |
| calib.add_argument("--calib-cov-store-dtype", choices=["fp16", "bf16", "fp32"], default="fp16", | |
| help="(cov mode) dtype for storing Cholesky factors in calib-save.") | |
| calib.add_argument("--calib-no-in", action="store_true", | |
| help="Disable input RMS collection (not recommended).") | |
| calib.add_argument("--calib-no-out", action="store_true", | |
| help="Disable output RMS collection (reduces overhead; input-only whitening).") | |
| # MoE (kept for compatibility) | |
| p.add_argument("--moe-method", type=str, default="none") | |
| p.add_argument("--max-experts", type=int, default=2) | |
| # Misc | |
| p.add_argument("--seed", type=int, default=1234) | |
| p.add_argument("--quiet", action="store_true") | |
| return p | |
def main() -> int:
    """Run the SVD-LoRA distillation command-line workflow end to end.

    Steps, in order:
      1. Parse the command line and copy every option into a ``DistillConfig``.
      2. Seed the RNGs and pick the compute device (``cuda:0`` if available,
         else CPU).
      3. Read the teacher/student weight indices and detect both architectures.
      4. Optionally load or collect calibration statistics.
      5. Run ``AcceleratedDistiller`` and write the adapter with
         ``save_peft_adapter``.

    Returns:
        ``0`` on success, ``1`` when the distiller produced no LoRA weights
        (nothing is saved in that case).
    """
    args = build_argparser().parse_args()

    config = DistillConfig(
        teacher_path=args.teacher,
        student_path=args.student,
        output_path=args.output,
        rank_default=args.rank,
        rank_attn=args.rank_attn,
        rank_mlp=args.rank_mlp,
        lora_alpha=args.lora_alpha,
        alpha_mode=args.alpha_mode,
        use_adaptive_rank=args.adaptive_rank,
        energy_threshold=args.energy_threshold,
        min_rank=args.min_rank,
        max_rank=args.max_rank,
        use_dare=args.dare,
        dare_drop_rate=args.dare_drop_rate,
        use_ties=args.ties,
        ties_density=args.ties_density,
        ties_trim_single=args.ties_trim_single,
        max_delta_ratio=args.max_delta_ratio,
        # Calibration-aware SVD
        calibration_mode=args.calib_mode,
        calib_data=args.calib_data,
        calib_format=args.calib_format,
        calib_alpaca_template=args.calib_alpaca_template,
        calib_alpaca_include_output=bool(args.calib_alpaca_include_output),
        calib_max_samples=args.calib_max_samples,
        calib_max_length=args.calib_max_length,
        calib_padding=args.calib_padding,
        calib_use_attention_mask=(not args.calib_no_attention_mask),
        calib_batch_size=args.calib_batch_size,
        calib_device=args.calib_device,
        calib_dtype=args.calib_dtype,
        calib_load=args.calib_load,
        calib_save=args.calib_save,
        calib_eps=args.calib_eps,
        calib_cov_chunk_layers=args.calib_cov_chunk_layers,
        calib_cov_max_dim=args.calib_cov_max_dim,
        calib_cov_groups=args.calib_cov_groups,
        calib_cov_store_dtype=args.calib_cov_store_dtype,
        calib_collect_in=(not args.calib_no_in),
        calib_collect_out=(not args.calib_no_out),
        moe_merge_method=args.moe_method,
        max_experts_to_blend=args.max_experts,
        map_schedule=args.map_schedule,
        sigmoid_k=args.sigmoid_k,
        interp_mode=args.interp_mode,
        # SVD backend
        svd_mode=args.svd_mode,
        svd_randomized_iter=args.svd_rand_iter,
        svd_randomized_oversamples=args.svd_rand_oversamples,
        svd_auto_min_dim=args.svd_auto_min_dim,
        svd_auto_full_rank_ratio=args.svd_auto_full_rank_ratio,
        svd_auto_lowrank=args.svd_auto_lowrank,
        svd_adaptive_block_size=args.svd_adaptive_block_size,
        svd_adaptive_n_test=args.svd_adaptive_n_test,
        svd_drf_steps=args.svd_drf_steps,
        svd_drf_theta=args.svd_drf_theta,
        svd_drf_resid_eps=args.svd_drf_resid_eps,
        svd_aurora_steps=args.svd_aurora_steps,
        svd_aurora_order=args.svd_aurora_order,
        svd_aurora_theta1=args.svd_aurora_theta1,
        svd_aurora_theta2=args.svd_aurora_theta2,
        svd_aurora_resid_eps=args.svd_aurora_resid_eps,
        svd_aurora_level2_keep=args.svd_aurora_level2_keep,
        # Cross-arch projection
        use_svd_projection=args.svd_projection,
        projection_rank=args.projection_rank,
        projection_min_rank=args.projection_min_rank,
        projection_adaptive_rank=args.projection_adaptive_rank,
        projection_energy_threshold=args.projection_energy_threshold,
        projection_randomized_iter=args.projection_rand_iter,
        projection_randomized_oversamples=args.projection_rand_oversamples,
        subspace_delta=args.subspace_delta,
        subspace_offdiag_threshold=args.subspace_offdiag_threshold,
        subspace_denom=args.subspace_denom,
        subspace_norm_match=(not args.no_subspace_norm_match),
        # Performance / memory
        num_gpus=args.num_gpus,
        gpu_memory_fraction=args.gpu_memory_fraction,
        use_cuda_streams=not args.no_cuda_streams,
        num_cuda_streams=args.cuda_streams,
        use_mixed_precision=args.mixed_precision,
        svd_batch_size=args.svd_batch_size,
        offload_strategy=args.offload_strategy,
        offload_dir=args.offload_dir,
        prefetch_layers=args.prefetch_layers,
        use_pinned_memory=args.pinned_memory,
        max_cpu_memory_gb=args.max_cpu_memory,
        chunk_threshold_mb=args.chunk_threshold,
        chunk_size_mb=args.chunk_size,
        micro_batch_size=args.micro_batch,
        # Module selection / misc
        include_pattern=args.include,
        exclude_pattern=args.exclude,
        include_embed_lm_head=args.include_embed,
        auto_adjust_dare=args.auto_adjust_dare,
        seed=args.seed,
        verbose=not args.quiet,
    )
    seed_all(config.seed)

    print("=" * 80)
    print("Universal SVD-LoRA Distillation v4.2.0 (Cross-Arch Innovations)")
    print("Plug-and-Play Truncated SVD (Low-VRAM) + Low-VRAM Offload")
    print("=" * 80)

    # --- Device selection -------------------------------------------------
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print(f"\n๐ฅ๏ธ GPU: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        try:
            torch.backends.cuda.preferred_linalg_library("cusolver")
        except Exception:
            # Deliberately best-effort: selecting the linalg backend is only a
            # preference, so a failure here must not abort the run.
            pass
    else:
        device = torch.device("cpu")
        print("\nโ ๏ธ No GPU available, using CPU")

    # --- Weight indices ---------------------------------------------------
    print("\n๐ Loading model indices...")
    teacher_weight_map = read_index_map(config.teacher_path)
    student_weight_map = read_index_map(config.student_path)
    teacher_keys = sorted(teacher_weight_map.keys())
    student_keys = sorted(student_weight_map.keys())
    print(f" Teacher: {len(teacher_keys)} tensors")
    print(f" Student: {len(student_keys)} tensors")

    # --- Architecture detection --------------------------------------------
    print("\n๐ Detecting architectures...")
    if HAS_TRANSFORMERS:
        teacher_arch = detect_architecture_from_config(config.teacher_path)
        student_arch = detect_architecture_from_config(config.student_path)
    else:
        teacher_arch = detect_architecture_from_weights(teacher_keys)
        student_arch = detect_architecture_from_weights(student_keys)

    # The weight-key scan overrides the teacher's MLP type when it reports MoE.
    # NOTE(review): the original indentation was lost in extraction; this
    # override is assumed to run for both branches above — confirm.
    teacher_arch_w = detect_architecture_from_weights(teacher_keys)
    if teacher_arch_w.mlp_type == MLPType.MOE:
        teacher_arch.mlp_type = MLPType.MOE
        teacher_arch.num_experts = teacher_arch_w.num_experts

    print(f"\n๐ Teacher: {teacher_arch.attention_type.value}, {teacher_arch.mlp_type.value}")
    print(f" Layers: {teacher_arch.num_hidden_layers}, Hidden: {teacher_arch.hidden_size}")
    if teacher_arch.num_experts > 0:
        print(f" Experts: {teacher_arch.num_experts}")
    print(f"\n๐ Student: {student_arch.attention_type.value}, {student_arch.mlp_type.value}")
    print(f" Layers: {student_arch.num_hidden_layers}, Hidden: {student_arch.hidden_size}")

    teacher_layers_map = scan_layers(teacher_keys)
    student_layers_map = scan_layers(student_keys)
    layer_mapping = create_layer_mapping(teacher_arch, student_arch)
    if layer_mapping and config.verbose:
        print(f"\n๐ Layer Name Mapping: {layer_mapping}")

    # --- Run summary ------------------------------------------------------
    print(f"\n๐ Starting GPU-accelerated distillation...")
    print(f" CUDA Streams: {config.num_cuda_streams if config.use_cuda_streams else 'Disabled'}")
    print(f" Mixed Precision: {config.use_mixed_precision}")
    print(f" Offload Strategy: {config.offload_strategy}")
    print(f" Prefetch Layers: {config.prefetch_layers}")
    print(f" DARE: {config.use_dare} (drop_rate={config.dare_drop_rate})")
    print(f" TIES: {config.use_ties} (density={config.ties_density}, trim_single={getattr(config, 'ties_trim_single', True)})")
    print(f" Adaptive Rank: {config.use_adaptive_rank} (threshold={config.energy_threshold})")
    print(f" Rank Range: {config.min_rank} - {config.max_rank}")
    print(f" SVD Projection: {config.use_svd_projection}")

    # --- Calibration-aware SVD (optional) ----------------------------------
    calib_stats: Dict[str, CalibStatsEntry] = {}
    if getattr(config, "calibration_mode", "none") != "none":
        print(f"\n๐งช Calibration-aware SVD: {config.calibration_mode}")
        if getattr(config, "calib_load", None):
            calib_stats = load_calibration_stats(config.calib_load)
            print(f" Loaded calibration stats: {len(calib_stats)} modules")
            print(f" From: {config.calib_load}")
        else:
            calib_stats = collect_calibration_stats(config.student_path, student_arch, config)
            print(f" Collected calibration stats: {len(calib_stats)} modules")
            # NOTE(review): the original indentation was lost in extraction;
            # saving is assumed to apply only to freshly collected stats —
            # confirm against the upstream file.
            if getattr(config, "calib_save", None):
                save_calibration_stats(calib_stats, config.calib_save)
                print(f" Saved calibration stats to: {config.calib_save}")

    # --- Distillation -------------------------------------------------------
    # perf_counter() is monotonic, so the measured interval cannot go negative
    # if the system clock is adjusted while the run is in progress.
    start_time = time.perf_counter()
    distiller = AcceleratedDistiller(config, device, teacher_arch, student_arch)
    # An empty stats dict is passed on as None.
    distiller.calib_stats = calib_stats if calib_stats else None
    lora_weights, rank_info = distiller.run(
        student_keys=student_keys,
        teacher_weight_map=teacher_weight_map,
        student_weight_map=student_weight_map,
        teacher_layers_map=teacher_layers_map,
        student_layers_map=student_layers_map,
        teacher_arch=teacher_arch,
        student_arch=student_arch,
        layer_name_mapping=layer_mapping
    )
    elapsed = time.perf_counter() - start_time

    if not lora_weights:
        print("โ CRITICAL: No LoRA weights were generated!")
        return 1

    save_peft_adapter(lora_weights, config.output_path, config, student_arch, rank_info)

    # --- Final report -------------------------------------------------------
    pair_count = len(lora_weights) // 2
    # Guard the division so a zero-length interval cannot raise
    # ZeroDivisionError after the adapter has already been written.
    throughput = pair_count / elapsed if elapsed > 0 else float("inf")
    print(f"\nโ Generated {pair_count} LoRA pairs in {elapsed:.1f}s")
    print(f" Throughput: {throughput:.1f} layers/sec")
    if rank_info:
        ranks = list(rank_info.values())
        print(f" Rank range: {min(ranks)}-{max(ranks)} (avg: {sum(ranks)/len(ranks):.1f})")
    print("\n๐ Distillation Complete!")
    print("=" * 80)
    return 0
# Script entry point: run main() and use its return value as the process exit
# status. SystemExit is raised directly rather than calling the `exit()`
# helper, which is injected by the `site` module for interactive use and is
# not available when that module is not loaded.
if __name__ == "__main__":
    raise SystemExit(main())