from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Tuple

import torch

from model import PTDConfig, PTDQwen2ForCausalLM

try:
    from huggingface_hub import snapshot_download
except Exception:
    # Deliberately broad: a missing or broken huggingface_hub install only
    # disables loading by repo id; local package paths keep working.
    snapshot_download = None


def _resolve_repo_path(path_or_repo: str) -> Path:
    """Return a local directory holding the PTD package.

    If ``path_or_repo`` exists on disk it is returned as a ``Path``.
    Otherwise it is treated as a Hugging Face Hub repo id and the package
    files matching the allow-list below are fetched with
    ``snapshot_download``.

    Raises:
        RuntimeError: the path does not exist locally and
            ``huggingface_hub`` could not be imported.
    """
    p = Path(path_or_repo)
    if p.exists():
        return p
    if snapshot_download is None:
        raise RuntimeError(
            "huggingface_hub is required to load from repo id. "
            "Install it or pass a local package path."
        )
    local = snapshot_download(
        repo_id=path_or_repo,
        allow_patterns=[
            "*.json",
            "*.md",
            "*.pt",
            "model.py",
            "__init__.py",
            "hf_ptd_loader.py",
            "requirements.txt",
            ".gitattributes",
        ],
    )
    return Path(local)


# Accepted spellings (after strip/lower and removal of a "torch." prefix)
# mapped to the torch dtype they select.
_DTYPE_ALIASES: Dict[str, torch.dtype] = {
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
    "float16": torch.float16,
    "fp16": torch.float16,
    "half": torch.float16,
    "float32": torch.float32,
    "fp32": torch.float32,
    "float": torch.float32,
}


def _resolve_dtype(dtype: str | torch.dtype) -> torch.dtype:
    """Map a dtype name (or a ``torch.dtype``) to one of the supported dtypes.

    Supported: bfloat16, float16, float32. Names are case-insensitive,
    may carry surrounding whitespace or a ``torch.`` prefix, and may use the
    aliases in ``_DTYPE_ALIASES``.

    Raises:
        ValueError: the dtype is not one of the supported ones.
    """
    if isinstance(dtype, torch.dtype):
        if dtype in _DTYPE_ALIASES.values():
            return dtype
        raise ValueError(f"Unsupported dtype: {dtype}")
    key = dtype.strip().lower()
    if key.startswith("torch."):
        key = key[len("torch."):]
    try:
        return _DTYPE_ALIASES[key]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {dtype}") from None


# package_type -> weights file (inside the package dir) applied on top of
# the base model.
_PACKAGE_WEIGHTS: Dict[str, str] = {
    "router_only": "router_state.pt",
    "full_state": "ptd_model_state.pt",
}


def _require_key(cfg: Dict[str, Any], key: str, cfg_path: Path) -> Any:
    """Return ``cfg[key]``, raising a KeyError that names the config file."""
    try:
        return cfg[key]
    except KeyError:
        raise KeyError(
            f"Package config {cfg_path} is missing required key '{key}'"
        ) from None


def load_ptd_model(
    path_or_repo: str,
    *,
    device: str | None = None,
    dtype: str | torch.dtype = "bfloat16",
    keep_rate: float | None = None,
) -> Tuple[PTDQwen2ForCausalLM, Dict[str, Any]]:
    """Build a ``PTDQwen2ForCausalLM`` from a packaged PTD checkpoint.

    The package (local directory or Hub repo id) must contain
    ``ptd_package_config.json`` with the keys ``ptd_config`` (keyword
    arguments for ``PTDConfig``), ``base_model`` (passed to
    ``PTDQwen2ForCausalLM.from_pretrained``) and ``package_type``.
    ``recommended_keep_rate`` is optional.

    ``package_type`` selects the weights applied on top of the base model,
    both loaded with ``strict=True``:

    * ``"router_only"``: ``router_state.pt`` into ``model.routers``;
    * ``"full_state"``: ``ptd_model_state.pt`` into the whole model.

    Args:
        path_or_repo: Local package directory or Hugging Face repo id.
        device: Target device; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        dtype: ``"bfloat16"``, ``"float16"`` or ``"float32"`` (aliases and
            ``torch.dtype`` objects are accepted as well).
        keep_rate: Overrides the package's ``recommended_keep_rate``. If
            both are ``None``, ``set_keep_rate`` is not called.

    Returns:
        ``(model, package_cfg)``. The model is moved to ``device``/``dtype``
        and in eval mode. ``package_cfg`` is the parsed package config.

    Raises:
        FileNotFoundError: the package config or weights file is missing.
        KeyError: a required package-config key is missing.
        ValueError: unsupported ``package_type`` or ``dtype``.
        RuntimeError: a repo id was given but huggingface_hub is unavailable.
    """
    pkg_dir = _resolve_repo_path(path_or_repo)
    cfg_path = pkg_dir / "ptd_package_config.json"
    if not cfg_path.exists():
        raise FileNotFoundError(f"Missing package config: {cfg_path}")
    with cfg_path.open("r", encoding="utf-8") as f:
        package_cfg = json.load(f)

    # Validate everything cheap before from_pretrained, so a bad package
    # fails without first loading (or downloading) the full base model.
    package_type = _require_key(package_cfg, "package_type", cfg_path)
    if not isinstance(package_type, str) or package_type not in _PACKAGE_WEIGHTS:
        raise ValueError(f"Unsupported package type: {package_type}")
    weights_path = pkg_dir / _PACKAGE_WEIGHTS[package_type]
    if not weights_path.exists():
        raise FileNotFoundError(f"Missing package weights: {weights_path}")

    ptd_cfg = PTDConfig(**_require_key(package_cfg, "ptd_config", cfg_path))
    base_model = _require_key(package_cfg, "base_model", cfg_path)
    torch_dtype = _resolve_dtype(dtype)
    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    model = PTDQwen2ForCausalLM.from_pretrained(
        base_model,
        ptd_config=ptd_cfg,
        torch_dtype=torch_dtype,
    )

    state = torch.load(weights_path, map_location="cpu", weights_only=True)
    # Router-only packages carry just the router weights; full-state packages
    # cover every parameter of the model.
    target = model.routers if package_type == "router_only" else model
    target.load_state_dict(state, strict=True)
    # load_state_dict copied the tensors into the model; drop the extra CPU
    # copy before moving to the target device to keep peak memory down.
    del state

    target_keep = (
        keep_rate if keep_rate is not None else package_cfg.get("recommended_keep_rate")
    )
    if target_keep is not None:
        model.set_keep_rate(float(target_keep))

    model = model.to(device=resolved_device, dtype=torch_dtype)
    model.eval()
    return model, package_cfg