from __future__ import annotations import time from typing import TYPE_CHECKING from hearthnet.services.image.backends.base import ImageDescription if TYPE_CHECKING: pass _TASK_MAP = { "caption": "", "detailed_caption": "", "ocr": "", "object_detection": "", } # Allowlist of approved model IDs to prevent RCE via trust_remote_code _APPROVED_MODELS = { "microsoft/Florence-2-large", "microsoft/Florence-2-base", } class Florence2Backend: """Vision backend using Microsoft Florence-2.""" name = "florence2" def __init__( self, model: str = "microsoft/Florence-2-large", device: str = "auto", ) -> None: if model not in _APPROVED_MODELS: raise ValueError( f"Model '{model}' not in approved list. " f"Approved models: {', '.join(sorted(_APPROVED_MODELS))}" ) self._model_id = model self._device = device self._processor = None self._model = None self._loaded = False self._load_error: str | None = None def _load(self) -> bool: if self._loaded: return True if self._load_error: return False try: import torch # type: ignore[import-untyped] from transformers import ( # type: ignore[import-untyped] AutoModelForCausalLM, AutoProcessor, ) device = self._device if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" self._processor = AutoProcessor.from_pretrained( # nosec B615 - revision pinned to main self._model_id, trust_remote_code=True, revision="main" ) self._model = AutoModelForCausalLM.from_pretrained( # nosec B615 - revision pinned to main self._model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32, trust_remote_code=True, revision="main", ).to(device) self._device = device self._loaded = True return True except ImportError as exc: self._load_error = f"transformers/torch not installed: {exc}" return False except Exception as exc: self._load_error = str(exc) return False def _run_task(self, image, task_prompt: str) -> str: """Run a single Florence-2 task prompt and return raw text result.""" import torch # type: ignore[import-untyped] inputs = self._processor(text=task_prompt, images=image, return_tensors="pt").to( self._device ) with torch.no_grad(): generated_ids = self._model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3, do_sample=False, ) generated_text = self._processor.batch_decode(generated_ids, skip_special_tokens=False)[0] parsed = self._processor.post_process_generation( generated_text, task=task_prompt, image_size=(image.width, image.height), ) # parsed is typically {task_prompt: } raw = parsed.get(task_prompt, "") if isinstance(raw, dict): return str(raw) return str(raw) async def describe( self, image_bytes: bytes, mode: str = "caption", ) -> ImageDescription: t0 = time.monotonic() if not self._load(): return ImageDescription( caption=f"[florence2 unavailable: {self._load_error}]", tags=[], objects=[], ocr_text=None, backend=self.name, ms=0, ) try: import io from PIL import Image as PILImage # type: ignore[import-untyped] pil_image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB") task_key = _TASK_MAP.get(mode, "") caption = "" tags: list[str] = [] objects: list[str] = [] ocr_text: str | None = None if mode == "ocr": raw = self._run_task(pil_image, "") ocr_text = raw caption = raw[:200] if raw else "" elif mode == "object_detection": raw = self._run_task(pil_image, "") # raw is a string repr of dict like {'': {'bboxes': [...], 'labels': [...]}} # Try to extract labels cap_text = self._run_task(pil_image, "") caption = cap_text try: import ast parsed = ast.literal_eval(raw) if isinstance(parsed, dict): inner = next(iter(parsed.values()), {}) objects = inner.get("labels", []) if isinstance(inner, dict) else [] except Exception: objects = [] else: raw = self._run_task(pil_image, task_key) caption = raw elapsed_ms = int((time.monotonic() - t0) * 1000) return ImageDescription( caption=caption, tags=tags, objects=objects, ocr_text=ocr_text, backend=self.name, ms=elapsed_ms, ) except Exception as exc: return ImageDescription( caption=f"[florence2 error: {exc}]", tags=[], objects=[], ocr_text=None, backend=self.name, ms=int((time.monotonic() - t0) * 1000), ) def health(self) -> dict: available = self._load_error is None return { "backend": self.name, "model": self._model_id, "loaded": self._loaded, "available": available, "error": self._load_error, }