| import os |
|
|
| import torch |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoProcessor |
|
|
| from clip_head import CreativeScorer |
|
|
# Lazily-initialised singletons shared by every caller of get_model();
# both stay None until get_model() first populates them.
_model: CreativeScorer | None
_processor: AutoProcessor | None
_model = _processor = None
|
|
|
|
def _download_state_dict(hf_repo: str, hf_token: str | None) -> dict[str, torch.Tensor]:
    """Download the CreativeScorer weights from *hf_repo* as a CPU state dict.

    Tries ``model.safetensors`` first. On any failure (missing file, network
    error, ``safetensors`` not installed, ...) it falls back to
    ``pytorch_model.bin``. If the fallback also fails, that exception
    propagates with the safetensors failure attached as ``__context__``.
    """
    try:
        weights_path = hf_hub_download(
            repo_id=hf_repo,
            filename="model.safetensors",
            token=hf_token,
        )
        from safetensors.torch import load_file

        state_dict = load_file(weights_path, device="cpu")
        print("[model_loader] Weights loaded from safetensors", flush=True)
    except Exception as e:
        # Deliberately broad: any safetensors problem triggers the .bin fallback.
        # The fallback runs inside this handler so a second failure keeps the
        # first one as its __context__.
        print(f"[model_loader] safetensors failed: {e}, trying .bin", flush=True)
        weights_path = hf_hub_download(
            repo_id=hf_repo,
            filename="pytorch_model.bin",
            token=hf_token,
        )
        # SECURITY: .bin files are pickles downloaded from a repo chosen via an
        # environment variable; weights_only=True stops unpickling from
        # executing arbitrary code (a plain state dict of tensors still loads).
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    return state_dict


def get_model() -> tuple[CreativeScorer, AutoProcessor]:
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The first call:
      1. reads the Hugging Face repo id from ``HF_MODEL_REPO`` and an
         optional access token from ``HF_TOKEN``;
      2. loads the ``google/siglip2-base-patch16-224`` processor;
      3. instantiates ``CreativeScorer`` and loads its weights from the repo
         (``model.safetensors``, falling back to ``pytorch_model.bin``);
      4. puts the model in eval mode.

    Later calls return the cached objects. The module-level cache is written
    only after every step has succeeded. A failed load therefore leaves the
    cache empty, and the next call retries instead of returning a model with
    uninitialised weights.

    NOTE(review): there is no locking here; concurrent first calls may each
    perform a full load. Confirm that callers are single-threaded or serialised.

    Returns:
        ``(model, processor)``.

    Raises:
        KeyError: if ``HF_MODEL_REPO`` is not set.
        Exception: any error raised while downloading or loading; it is
            printed together with its traceback and then re-raised unchanged.
    """
    global _model, _processor
    if _model is not None:
        return _model, _processor

    try:
        hf_repo = os.environ["HF_MODEL_REPO"]
        hf_token = os.getenv("HF_TOKEN")
        print(f"[model_loader] Loading from repo: {hf_repo}", flush=True)

        processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
        print("[model_loader] Processor loaded", flush=True)

        model = CreativeScorer()
        print("[model_loader] Model instantiated", flush=True)

        model.load_state_dict(_download_state_dict(hf_repo, hf_token))
        model.eval()
        print("[model_loader] Model ready", flush=True)
    except Exception as e:
        import traceback

        print(f"[model_loader] FATAL: {e}", flush=True)
        print(traceback.format_exc(), flush=True)
        raise

    # Publish to the cache only after a fully successful load.
    _model, _processor = model, processor
    return _model, _processor
|
|