File size: 6,246 Bytes
4cd8837
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
21c2afa
 
 
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
21c2afa
 
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
4cd8837
 
 
 
 
6c38d43
4aaae80
4cd8837
6c38d43
4cd8837
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from __future__ import annotations

import time
from typing import TYPE_CHECKING

from hearthnet.services.image.backends.base import ImageDescription

if TYPE_CHECKING:
    pass

# Translates the ``mode`` strings accepted by ``Florence2Backend.describe``
# into the task-prompt tokens passed to the Florence-2 processor.
_TASK_MAP: dict[str, str] = dict(
    caption="<CAPTION>",
    detailed_caption="<DETAILED_CAPTION>",
    ocr="<OCR>",
    object_detection="<OD>",
)

# Allowlist of model IDs that may be loaded. The backend loads checkpoints with
# trust_remote_code=True, so accepting arbitrary IDs would allow remote code
# execution.
_APPROVED_MODELS: set[str] = {
    "microsoft/Florence-2-base",
    "microsoft/Florence-2-large",
}


class Florence2Backend:
    """Vision backend using Microsoft Florence-2.

    The model and processor are loaded lazily on the first ``describe`` call.
    A failed load is remembered in ``_load_error`` and is not retried; every
    later call then returns an "unavailable" description instead of raising.
    """

    name = "florence2"

    def __init__(
        self,
        model: str = "microsoft/Florence-2-large",
        device: str = "auto",
    ) -> None:
        """Record the configuration; no model weights are loaded here.

        Args:
            model: Model ID to load. Must be a member of ``_APPROVED_MODELS``.
            device: Target device string, or ``"auto"`` to pick CUDA when
                ``torch.cuda.is_available()`` and CPU otherwise.

        Raises:
            ValueError: If ``model`` is not in the approved list.
        """
        if model not in _APPROVED_MODELS:
            raise ValueError(
                f"Model '{model}' not in approved list. "
                f"Approved models: {', '.join(sorted(_APPROVED_MODELS))}"
            )
        self._model_id = model
        self._device = device
        self._processor = None
        self._model = None
        self._loaded = False
        self._load_error: str | None = None

    def _load(self) -> bool:
        """Load the processor and model once.

        Returns:
            True when the model is ready, False when loading failed (now or on
            an earlier call). The failure reason is kept in ``_load_error``.
        """
        if self._loaded:
            return True
        # Compare against None (as ``health`` does) so that an error with an
        # empty message still counts as a recorded failure.
        if self._load_error is not None:
            return False
        try:
            import torch  # type: ignore[import-untyped]
            from transformers import (  # type: ignore[import-untyped]
                AutoModelForCausalLM,
                AutoProcessor,
            )

            device = self._device
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"

            # NOTE(review): "main" is a branch name rather than an immutable
            # commit, so the remote code run via trust_remote_code can change
            # upstream. Consider pinning a reviewed commit hash.
            processor = AutoProcessor.from_pretrained(  # nosec B615 - revision pinned to main
                self._model_id, trust_remote_code=True, revision="main"
            )
            model = AutoModelForCausalLM.from_pretrained(  # nosec B615 - revision pinned to main
                self._model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                trust_remote_code=True,
                revision="main",
            ).to(device)
        except ImportError as exc:
            self._load_error = f"transformers/torch not installed: {exc}"
            return False
        except Exception as exc:
            # Some exceptions stringify to ""; fall back to the type name so
            # the recorded error is never empty.
            self._load_error = str(exc) or type(exc).__name__
            return False

        # Commit state only after both loads succeeded, so a failure never
        # leaves a processor without a model.
        self._processor = processor
        self._model = model
        self._device = device
        self._loaded = True
        return True

    def _run_task_parsed(self, image, task_prompt: str):
        """Run one Florence-2 task prompt and return the post-processed result.

        Args:
            image: Image object exposing ``width`` and ``height``; it is passed
                to the processor as ``images=``.
            task_prompt: Florence-2 task token such as ``"<CAPTION>"``.

        Returns:
            The value stored under ``task_prompt`` in the processor's
            post-processed output, or ``""`` when that key is absent.
        """
        import torch  # type: ignore[import-untyped]

        # Cast floating-point inputs to the model dtype as well as moving them:
        # on CUDA the model is loaded in float16 and float32 pixel values would
        # not match.
        inputs = self._processor(text=task_prompt, images=image, return_tensors="pt").to(
            self._device, self._model.dtype
        )
        with torch.no_grad():
            generated_ids = self._model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                do_sample=False,
            )
        generated_text = self._processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed = self._processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        # parsed is typically {task_prompt: <result>}
        return parsed.get(task_prompt, "")

    def _run_task(self, image, task_prompt: str) -> str:
        """Run a single Florence-2 task prompt and return raw text result."""
        return str(self._run_task_parsed(image, task_prompt))

    @staticmethod
    def _labels_from_detection(result) -> list[str]:
        """Extract the object labels from an ``<OD>`` result.

        Accepts the unwrapped shape ``{"bboxes": [...], "labels": [...]}`` and
        the wrapped shape ``{"<OD>": {"bboxes": [...], "labels": [...]}}``.
        Anything else yields an empty list.
        """
        if not isinstance(result, dict):
            return []
        labels = result.get("labels")
        if labels is None:
            for value in result.values():
                if isinstance(value, dict) and "labels" in value:
                    labels = value["labels"]
                    break
        if not isinstance(labels, (list, tuple)):
            return []
        return [str(label) for label in labels]

    async def describe(
        self,
        image_bytes: bytes,
        mode: str = "caption",
    ) -> ImageDescription:
        """Describe an encoded image.

        Args:
            image_bytes: Encoded image data, opened with PIL and converted to
                RGB.
            mode: ``"ocr"`` fills ``ocr_text`` and uses its first 200
                characters as the caption; ``"object_detection"`` fills
                ``objects`` and also produces a caption; any other value is
                looked up in ``_TASK_MAP`` and falls back to ``"<CAPTION>"``.

        Returns:
            An ``ImageDescription``. Failures are not raised: an unavailable
            backend or a processing error is reported in the ``caption`` field.
        """
        # NOTE(review): model loading and inference run synchronously inside
        # this coroutine; nothing here is awaited.
        t0 = time.monotonic()

        if not self._load():
            return ImageDescription(
                caption=f"[florence2 unavailable: {self._load_error}]",
                tags=[],
                objects=[],
                ocr_text=None,
                backend=self.name,
                ms=0,
            )

        try:
            import io

            from PIL import Image as PILImage  # type: ignore[import-untyped]

            pil_image = PILImage.open(io.BytesIO(image_bytes)).convert("RGB")

            caption = ""
            tags: list[str] = []
            objects: list[str] = []
            ocr_text: str | None = None

            if mode == "ocr":
                raw = self._run_task(pil_image, "<OCR>")
                ocr_text = raw
                caption = raw[:200] if raw else ""
            elif mode == "object_detection":
                # Read labels from the parsed structure directly; a str() /
                # literal_eval round-trip is lossy and previously picked the
                # bboxes list instead of the labels.
                detection = self._run_task_parsed(pil_image, "<OD>")
                caption = self._run_task(pil_image, "<CAPTION>")
                objects = self._labels_from_detection(detection)
            else:
                caption = self._run_task(pil_image, _TASK_MAP.get(mode, "<CAPTION>"))

            elapsed_ms = int((time.monotonic() - t0) * 1000)
            return ImageDescription(
                caption=caption,
                tags=tags,
                objects=objects,
                ocr_text=ocr_text,
                backend=self.name,
                ms=elapsed_ms,
            )
        except Exception as exc:
            # Boundary handler: report the failure in the result rather than
            # propagating it to the caller.
            return ImageDescription(
                caption=f"[florence2 error: {exc}]",
                tags=[],
                objects=[],
                ocr_text=None,
                backend=self.name,
                ms=int((time.monotonic() - t0) * 1000),
            )

    def health(self) -> dict:
        """Return a status snapshot without triggering a model load.

        ``available`` is True until a load attempt has recorded an error.
        """
        available = self._load_error is None
        return {
            "backend": self.name,
            "model": self._model_id,
            "loaded": self._loaded,
            "available": available,
            "error": self._load_error,
        }