GitHub Actions
feat: federated RAG + comprehensive security fixes
6c38d43
Raw
History Blame Contribute Delete
6.25 kB
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from hearthnet.services.image.backends.base import ImageDescription
if TYPE_CHECKING:
pass
_TASK_MAP = {
"caption": "<CAPTION>",
"detailed_caption": "<DETAILED_CAPTION>",
"ocr": "<OCR>",
"object_detection": "<OD>",
}
# Allowlist of approved model IDs to prevent RCE via trust_remote_code
_APPROVED_MODELS = {
"microsoft/Florence-2-large",
"microsoft/Florence-2-base",
}
class Florence2Backend:
    """Vision backend using Microsoft Florence-2.

    The processor and model are loaded lazily on the first call that needs
    them. A failed load is remembered in ``_load_error`` and is not retried;
    ``describe`` then returns a placeholder description instead of raising.
    """

    name = "florence2"

    def __init__(
        self,
        model: str = "microsoft/Florence-2-large",
        device: str = "auto",
    ) -> None:
        """Validate the model ID and record the configuration (no loading yet).

        Args:
            model: Hugging Face model ID; must be a member of ``_APPROVED_MODELS``.
            device: Torch device string, or ``"auto"`` to pick CUDA when
                available and CPU otherwise (resolved at load time).

        Raises:
            ValueError: If ``model`` is not in the approved list.
        """
        if model not in _APPROVED_MODELS:
            raise ValueError(
                f"Model '{model}' not in approved list. "
                f"Approved models: {', '.join(sorted(_APPROVED_MODELS))}"
            )
        self._model_id = model
        self._device = device
        self._processor = None
        self._model = None
        self._loaded = False
        self._load_error: str | None = None

    def _load(self) -> bool:
        """Load the processor and model once; return whether they are usable.

        Returns:
            True if the model is (already) loaded, False if this or an earlier
            attempt failed. The failure reason is stored in ``_load_error``.
        """
        if self._loaded:
            return True
        # Compare against None (as ``health`` does) so that a recorded failure
        # is never mistaken for "not attempted yet" and retried on every call.
        if self._load_error is not None:
            return False
        try:
            import torch  # type: ignore[import-untyped]
            from transformers import (  # type: ignore[import-untyped]
                AutoModelForCausalLM,
                AutoProcessor,
            )

            device = self._device
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"

            # NOTE(review): "main" is a branch name, i.e. a moving reference.
            # Together with trust_remote_code=True this runs whatever code the
            # branch currently holds; a commit SHA would be an actual pin.
            processor = AutoProcessor.from_pretrained(  # nosec B615
                self._model_id, trust_remote_code=True, revision="main"
            )
            model = AutoModelForCausalLM.from_pretrained(  # nosec B615
                self._model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                trust_remote_code=True,
                revision="main",
            ).to(device)
        except ImportError as exc:
            self._load_error = f"transformers/torch not installed: {exc}"
            return False
        except Exception as exc:
            # Never store an empty message: it must stay distinguishable
            # from "no error".
            self._load_error = str(exc) or type(exc).__name__
            return False

        # Commit state only after everything succeeded, so a failed load does
        # not leave a processor behind without a model.
        self._processor = processor
        self._model = model
        self._device = device
        self._loaded = True
        return True

    def _run_task(self, image, task_prompt: str) -> str:
        """Run a single Florence-2 task prompt and return the result as text.

        Args:
            image: Image object exposing ``width`` and ``height`` (``describe``
                passes an RGB PIL image).
            task_prompt: Florence-2 task token such as ``"<CAPTION>"``.

        Returns:
            ``str()`` of the post-processed entry for ``task_prompt``; an empty
            string if the post-processed result has no such entry.
        """
        import torch  # type: ignore[import-untyped]

        inputs = self._processor(text=task_prompt, images=image, return_tensors="pt").to(
            self._device
        )
        pixel_values = inputs["pixel_values"]
        # The model is loaded in float16 on CUDA while the processor emits
        # float32 pixels; align the dtype so the vision tower does not fail
        # with an input/weight dtype mismatch.
        model_dtype = getattr(self._model, "dtype", None)
        if model_dtype is not None:
            pixel_values = pixel_values.to(model_dtype)

        with torch.no_grad():
            generated_ids = self._model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=pixel_values,
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )
        generated_text = self._processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed = self._processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        # parsed is typically {task_prompt: <result>}; only the inner result
        # is returned, stringified.
        return str(parsed.get(task_prompt, ""))

    @staticmethod
    def _extract_labels(raw: str) -> list[str]:
        """Extract object labels from the stringified ``<OD>`` result.

        ``_run_task`` returns the already-unwrapped result, i.e. the repr of
        ``{'bboxes': [...], 'labels': [...]}``. The wrapped form
        ``{'<OD>': {'bboxes': [...], 'labels': [...]}}`` is accepted as well.

        Args:
            raw: Text produced by ``_run_task`` for the ``<OD>`` task.

        Returns:
            The labels as strings, or an empty list if ``raw`` cannot be
            parsed or contains no label list.
        """
        import ast

        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
        if not isinstance(parsed, dict):
            return []

        labels = parsed.get("labels")
        if labels is None:
            # Wrapped form: the payload sits under the single task key.
            inner = next(iter(parsed.values()), None)
            labels = inner.get("labels") if isinstance(inner, dict) else None
        if not isinstance(labels, (list, tuple)):
            return []
        return [str(label) for label in labels]

    async def describe(
        self,
        image_bytes: bytes,
        mode: str = "caption",
    ) -> ImageDescription:
        """Describe an image according to ``mode``.

        Args:
            image_bytes: Encoded image data, opened with PIL.
            mode: One of the ``_TASK_MAP`` keys. ``"ocr"`` fills ``ocr_text``
                and uses its first 200 characters as caption;
                ``"object_detection"`` fills ``objects`` and additionally
                generates a caption; any other value runs the mapped task
                (unknown modes fall back to ``"<CAPTION>"``).

        Returns:
            An ``ImageDescription``. Failures are not raised: if the model is
            unavailable or processing fails, the caption carries a bracketed
            ``[florence2 ...]`` message instead.

        NOTE(review): model loading and inference run synchronously inside
        this coroutine (nothing is awaited), so the caller's event loop is
        blocked for the duration — consider offloading to a worker thread.
        """
        t0 = time.monotonic()
        if not self._load():
            return ImageDescription(
                caption=f"[florence2 unavailable: {self._load_error}]",
                tags=[],
                objects=[],
                ocr_text=None,
                backend=self.name,
                ms=0,
            )
        try:
            import io

            from PIL import Image as PILImage  # type: ignore[import-untyped]

            pil_image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB")

            caption = ""
            tags: list[str] = []
            objects: list[str] = []
            ocr_text: str | None = None

            if mode == "ocr":
                ocr_text = self._run_task(pil_image, "<OCR>")
                caption = ocr_text[:200]
            elif mode == "object_detection":
                detections = self._run_task(pil_image, "<OD>")
                caption = self._run_task(pil_image, "<CAPTION>")
                objects = self._extract_labels(detections)
            else:
                caption = self._run_task(pil_image, _TASK_MAP.get(mode, "<CAPTION>"))

            elapsed_ms = int((time.monotonic() - t0) * 1000)
            return ImageDescription(
                caption=caption,
                tags=tags,
                objects=objects,
                ocr_text=ocr_text,
                backend=self.name,
                ms=elapsed_ms,
            )
        except Exception as exc:
            # Deliberate best-effort boundary: report the failure in-band
            # rather than propagating it to the caller.
            return ImageDescription(
                caption=f"[florence2 error: {exc}]",
                tags=[],
                objects=[],
                ocr_text=None,
                backend=self.name,
                ms=int((time.monotonic() - t0) * 1000),
            )

    def health(self) -> dict:
        """Return a status snapshot without triggering a model load.

        Returns:
            Dict with ``backend``, ``model``, ``loaded``, ``available`` (True
            while no load error has been recorded) and ``error``.
        """
        return {
            "backend": self.name,
            "model": self._model_id,
            "loaded": self._loaded,
            "available": self._load_error is None,
            "error": self._load_error,
        }