"""Transformers-compatible wrapper around the core PTD Qwen2 causal LM.

``PTDQwen2ForCausalLM`` subclasses ``transformers.PreTrainedModel`` so it can be
loaded through the usual ``from_pretrained`` entry point, while the actual
network is ``.model.PTDQwen2ForCausalLM``, built from ``config.base_model`` and
then loaded with weights packaged alongside the config.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Optional

import torch
from transformers import PreTrainedModel

from .configuration_ptd_qwen2 import PTDQwen2Config
from .model import PTDConfig, PTDQwen2ForCausalLM as _CorePTDModel

try:
    from huggingface_hub import snapshot_download
except Exception:  # best-effort optional import; only needed to resolve hub repo ids
    snapshot_download = None


# Exact (lower-cased) dtype spellings accepted in addition to the substring
# matching in _resolve_torch_dtype.
_DTYPE_ALIASES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "half": torch.float16,
    "fp32": torch.float32,
    "float": torch.float32,
}

# device_map values that name a placement strategy rather than one concrete
# device; they cannot be handed to Module.to() directly.
_DEVICE_MAP_STRATEGIES = frozenset({"auto", "balanced", "balanced_low_0", "sequential"})


def _default_torch_dtype() -> torch.dtype:
    """Return bfloat16 when CUDA is available, otherwise float32."""
    return torch.bfloat16 if torch.cuda.is_available() else torch.float32


def _resolve_torch_dtype(torch_dtype: Any) -> torch.dtype:
    """Normalize a user-supplied dtype spec to a ``torch.dtype``.

    Args:
        torch_dtype: A ``torch.dtype``; ``None`` or ``"auto"`` (use a
            device-dependent default); or a string naming bfloat16, float16 or
            float32, e.g. ``"bfloat16"``, ``"torch.float16"``, ``"bf16"``,
            ``"fp16"``, ``"half"``, ``"fp32"``, ``"float"``.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: if the spec does not name one of the supported dtypes.
    """
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if torch_dtype is None:
        return _default_torch_dtype()
    s = str(torch_dtype).strip().lower()
    if s == "auto":
        # transformers uses "auto" for "let the loader pick".
        return _default_torch_dtype()
    if s in _DTYPE_ALIASES:
        return _DTYPE_ALIASES[s]
    # "bfloat16" must be tested before "float16" because it contains it.
    if "bfloat16" in s:
        return torch.bfloat16
    if "float16" in s:
        return torch.float16
    if "float32" in s:
        return torch.float32
    raise ValueError(f"Unsupported torch_dtype: {torch_dtype}")


class PTDQwen2ForCausalLM(PreTrainedModel):
    """``PreTrainedModel`` facade that delegates to an attached core PTD model.

    An instance created directly via ``__init__`` holds no network
    (``ptd_model is None``); use :meth:`from_pretrained`, which builds the core
    model, loads the packaged weights and attaches it as ``self.ptd_model``.
    """

    config_class = PTDQwen2Config
    main_input_name = "input_ids"
    # Opt out of transformers' Cache-class plumbing; presumably the core model
    # manages its own (possibly sparse) KV cache — see ptd_use_sparse_cache.
    _supports_cache_class = False

    def __init__(self, config: PTDQwen2Config) -> None:
        super().__init__(config)
        # Populated by from_pretrained(); None until then.
        self.ptd_model: Optional[_CorePTDModel] = None

    # The three hooks below report that dynamic / static / quantized Cache
    # objects are unsupported (presumably consulted by transformers' generation
    # utilities).
    def _supports_default_dynamic_cache(self) -> bool:
        return False

    def _supports_static_cache(self) -> bool:
        return False

    def _supports_quantized_cache(self) -> bool:
        return False

    @staticmethod
    def _resolve_repo_dir(
        path_or_repo: str,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        token: Any = None,
    ) -> Path:
        """Return a local directory containing the packaged model files.

        An existing filesystem path is returned unchanged. Otherwise
        ``path_or_repo`` is treated as a Hugging Face Hub repo id and fetched with
        ``snapshot_download``, restricted to ``*.json``, ``*.md``, ``*.pt``,
        ``*.py`` and ``.gitattributes``.

        Args:
            path_or_repo: Local path or Hub repo id.
            local_files_only: Forwarded to ``snapshot_download``.
            revision: Optional branch/tag/commit, forwarded when set.
            cache_dir: Optional cache directory, forwarded when set.
            token: Optional Hub token (or ``True``), forwarded when set.

        Raises:
            RuntimeError: if a download is needed but huggingface_hub is missing.
        """
        p = Path(path_or_repo)
        if p.exists():
            return p
        if snapshot_download is None:
            raise RuntimeError(
                "huggingface_hub is required to load from repo id. "
                "Install huggingface_hub or pass a local path."
            )
        # Forward only the options the caller actually set, so the default call
        # is identical to passing nothing.
        hub_kwargs = {
            name: value
            for name, value in (("revision", revision), ("cache_dir", cache_dir), ("token", token))
            if value is not None
        }
        local = snapshot_download(
            repo_id=path_or_repo,
            local_files_only=local_files_only,
            allow_patterns=[
                "*.json",
                "*.md",
                "*.pt",
                "*.py",
                ".gitattributes",
            ],
            **hub_kwargs,
        )
        return Path(local)

    @staticmethod
    def _resolve_device(device: Any, device_map: Any) -> Any:
        """Pick the single target device from ``device`` / ``device_map``.

        ``device`` wins when given. A ``device_map`` that names one concrete
        device (``"cpu"``, ``"cuda:1"``, a ``torch.device``) is used as-is;
        strategy names such as ``"auto"`` and non-string maps fall back to
        ``"cuda"`` when available, else ``"cpu"``.
        """
        if device is not None:
            return device
        if isinstance(device_map, torch.device):
            return device_map
        if isinstance(device_map, str) and device_map not in _DEVICE_MAP_STRATEGIES:
            return device_map
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _load_packaged_weights(core: _CorePTDModel, repo_dir: Path, package_type: str) -> None:
        """Load the weights shipped in ``repo_dir`` into ``core`` (strict).

        ``"full_state"`` loads ``ptd_model_state.pt`` into the whole model;
        ``"router_only"`` loads ``router_state.pt`` into ``core.routers``.

        Raises:
            ValueError: for any other ``package_type``.
        """
        if package_type == "full_state":
            state = torch.load(repo_dir / "ptd_model_state.pt", map_location="cpu", weights_only=True)
            core.load_state_dict(state, strict=True)
        elif package_type == "router_only":
            r_state = torch.load(repo_dir / "router_state.pt", map_location="cpu", weights_only=True)
            core.routers.load_state_dict(r_state, strict=True)
        else:
            raise ValueError(f"Unsupported package_type: {package_type}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args: Any, **kwargs: Any):
        """Build the wrapper and its core model, load packaged weights, return it.

        Recognised keyword arguments (all optional):
            keep_rate: Overrides ``config.recommended_keep_rate`` passed to
                ``core.set_keep_rate``.
            device: Explicit target device; takes precedence over ``device_map``.
            device_map: A single concrete device is used as the target; strategy
                names (``"auto"`` etc.) and dict maps fall back to CUDA when
                available, else CPU.
            local_files_only, revision, cache_dir, token / use_auth_token:
                Used when resolving a Hub repo id via ``snapshot_download``.
            torch_dtype / dtype: Any spec accepted by ``_resolve_torch_dtype``.

        Positional ``*args`` and any other keyword arguments are ignored.

        Raises:
            ValueError: for an unsupported dtype or ``config.package_type``.
            RuntimeError: if a Hub download is needed but huggingface_hub is missing.
        """
        keep_rate = kwargs.pop("keep_rate", None)
        device = kwargs.pop("device", None)
        device_map = kwargs.pop("device_map", None)
        local_files_only = bool(kwargs.pop("local_files_only", False))
        revision = kwargs.pop("revision", None)
        cache_dir = kwargs.pop("cache_dir", None)
        token = kwargs.pop("token", None)
        legacy_token = kwargs.pop("use_auth_token", None)
        if token is None:
            token = legacy_token
        # Also accept ``dtype``, the newer transformers spelling of ``torch_dtype``.
        dtype_spec = kwargs.pop("torch_dtype", None)
        alt_dtype_spec = kwargs.pop("dtype", None)
        torch_dtype = _resolve_torch_dtype(dtype_spec if dtype_spec is not None else alt_dtype_spec)

        repo_dir = cls._resolve_repo_dir(
            pretrained_model_name_or_path,
            local_files_only=local_files_only,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
        )
        config = PTDQwen2Config.from_pretrained(str(repo_dir), trust_remote_code=True)
        model = cls(config)

        # NOTE(review): local_files_only / revision / token are not forwarded to
        # the base-model load below, so it may still reach the network — confirm
        # whether _CorePTDModel.from_pretrained accepts them.
        ptd_cfg = PTDConfig(**config.ptd_config)
        core = _CorePTDModel.from_pretrained(
            config.base_model,
            ptd_config=ptd_cfg,
            torch_dtype=torch_dtype,
        )
        cls._load_packaged_weights(core, repo_dir, config.package_type)

        target_keep = config.recommended_keep_rate if keep_rate is None else float(keep_rate)
        core.set_keep_rate(target_keep)

        core = core.to(device=cls._resolve_device(device, device_map), dtype=torch_dtype)
        model.ptd_model = core
        # eval() on the wrapper also switches the attached core, so
        # model.training agrees with the submodule's mode.
        model.eval()
        return model

    def _ensure_loaded(self) -> _CorePTDModel:
        """Return the attached core model.

        Raises:
            RuntimeError: if ``from_pretrained`` has not attached one yet.
        """
        if self.ptd_model is None:
            raise RuntimeError("Model is not initialized. Use from_pretrained().")
        return self.ptd_model

    @property
    def device(self) -> torch.device:
        """Device reported by the attached core model."""
        return self._ensure_loaded().device

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Any = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ):
        """Delegate to the core model's forward.

        All arguments are passed through unchanged; ``ptd_use_sparse_cache``
        defaults to ``True`` unless the caller sets it explicitly.

        Raises:
            RuntimeError: if no core model is attached.
        """
        core = self._ensure_loaded()
        kwargs.setdefault("ptd_use_sparse_cache", True)
        return core(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Any = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ):
        """Assemble model inputs for one generation step.

        When ``past_key_values`` is set, only the last column of ``input_ids``
        (and of ``attention_mask``, if given) is kept. Remaining ``kwargs`` are
        merged into the returned dict unchanged.
        """
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
            # NOTE(review): HF decoders usually keep the full (past + current)
            # attention_mask here so padded cache positions stay masked. This
            # slice may be deliberate if the sparse cache holds fewer entries
            # than the sequence — confirm against the core model.
            if attention_mask is not None:
                attention_mask = attention_mask[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        model_inputs.update(kwargs)
        return model_inputs