creative-intelligence-scorer / model_loader.py
PranavCR01
feat: swap backbone from CLIP to SigLIP 2 (google/siglip2-base-patch16-224)
233452b
Raw
History Blame
1.94 kB
import os
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor
from clip_head import CreativeScorer
# Process-wide cache, filled in lazily by get_model(); both start out empty.
_processor: AutoProcessor | None = None  # processor for google/siglip2-base-patch16-224
_model: CreativeScorer | None = None  # get_model() loads only while this is still None
def _download_state_dict(hf_repo: str, hf_token: str | None) -> dict[str, torch.Tensor]:
    """Download the scorer's weights from ``hf_repo`` and return them on CPU.

    Tries ``model.safetensors`` first. If anything on that path raises
    (download failure, missing ``safetensors`` package, unreadable file), it
    falls back to ``pytorch_model.bin``. An error from the ``.bin`` attempt
    propagates to the caller, chained to the safetensors error.
    """
    try:
        weights_path = hf_hub_download(
            repo_id=hf_repo,
            filename="model.safetensors",
            token=hf_token,
        )
        # Imported lazily so the .bin fallback still works without safetensors.
        from safetensors.torch import load_file

        state_dict = load_file(weights_path, device="cpu")
        print("[model_loader] Weights loaded from safetensors", flush=True)
        return state_dict
    except Exception as e:
        print(f"[model_loader] safetensors failed: {e}, trying .bin", flush=True)
        # Kept inside the except block so a .bin failure keeps the
        # safetensors error as its __context__ in the traceback.
        weights_path = hf_hub_download(
            repo_id=hf_repo,
            filename="pytorch_model.bin",
            token=hf_token,
        )
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run code on load.
        return torch.load(weights_path, map_location="cpu", weights_only=True)


def get_model() -> tuple[CreativeScorer, AutoProcessor]:
    """Return the cached (scorer, processor) pair, loading it on first use.

    The processor comes from ``google/siglip2-base-patch16-224``. The scorer's
    weights come from the Hugging Face repo named by the ``HF_MODEL_REPO``
    environment variable, using ``HF_TOKEN`` (if set) as the access token.
    The loaded pair is stored in module globals. Later calls return the same
    objects without reloading.

    Raises:
        KeyError: if ``HF_MODEL_REPO`` is not set.
        Exception: anything raised while loading the processor, building the
            model, or downloading/applying weights. The error and its
            traceback are printed, then re-raised. Nothing is cached on
            failure, so the next call retries from scratch.
    """
    global _model, _processor
    if _model is None:
        try:
            hf_repo = os.environ["HF_MODEL_REPO"]
            hf_token = os.getenv("HF_TOKEN")
            print(f"[model_loader] Loading from repo: {hf_repo}", flush=True)
            processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
            print("[model_loader] Processor loaded", flush=True)
            model = CreativeScorer()
            print("[model_loader] Model instantiated", flush=True)
            model.load_state_dict(_download_state_dict(hf_repo, hf_token))
            model.eval()
        except Exception as e:
            import traceback

            print(f"[model_loader] FATAL: {e}", flush=True)
            print(traceback.format_exc(), flush=True)
            raise
        # Publish only a fully loaded model. _model is the value checked
        # above, so assigning it last means no later call can be handed
        # a scorer whose weights were never loaded.
        _processor = processor
        _model = model
        print("[model_loader] Model ready", flush=True)
    return _model, _processor