Spaces:
Running on Zero
Running on Zero
| """TrOCR backend via Hugging Face Transformers (optional dependency).""" | |
| from __future__ import annotations | |
| import asyncio | |
| import io | |
| import time | |
| from typing import Any | |
class TrocrBackend:
    """OCR backend that runs a TrOCR vision-encoder-decoder model.

    The processor and model are loaded lazily on first use. Model loading,
    inference and PDF rasterization are all dispatched to the event loop's
    default executor so the coroutine methods do not block the loop.
    """

    name = "trocr"

    def __init__(
        self,
        model: str = "microsoft/trocr-large-handwritten",
        device: str = "auto",
        revision: str = "main",
    ) -> None:
        """Configure the backend without loading any model weights.

        Args:
            model: Hugging Face model id passed to ``from_pretrained``.
            device: Torch device string, or ``"auto"`` to choose CUDA when it
                is available and CPU otherwise.
            revision: Hub revision passed to ``from_pretrained``. ``"main"``
                is a moving branch; pass a commit SHA to pin exact weights.
        """
        self._model_name = model
        self._device = device
        self._revision = revision
        self._processor: Any = None
        self._model: Any = None
        self._loaded = False
        # Created lazily inside a running loop by _ensure_loaded, so
        # constructing the backend never requires an event loop.
        self._load_lock: asyncio.Lock | None = None

    def supported_languages(self) -> list[str]:
        """Return the language codes this backend advertises."""
        # TrOCR is primarily English/handwriting; can be fine-tuned for others
        return ["eng", "deu"]

    def _resolve_device(self) -> str:
        """Return the configured device, resolving ``"auto"`` to cuda/cpu.

        Falls back to ``"cpu"`` when torch cannot be imported.
        """
        if self._device != "auto":
            return self._device
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_model_sync(self) -> None:
        """Blocking: load the TrOCR processor and model and move the model to the device."""
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # type: ignore[import]

        device = self._resolve_device()
        # NOTE: the default revision "main" is a branch, not a pinned commit;
        # pass a commit SHA via the constructor for reproducible loads.
        self._processor = TrOCRProcessor.from_pretrained(
            self._model_name, revision=self._revision
        )  # nosec B615 - revision is configurable
        self._model = VisionEncoderDecoderModel.from_pretrained(
            self._model_name, revision=self._revision
        )  # nosec B615 - revision is configurable
        self._model.to(device)
        # Set last so a partially loaded backend is never reported as ready.
        self._loaded = True

    async def _ensure_loaded(self) -> None:
        """Load the model exactly once, even under concurrent first calls."""
        if self._loaded:
            return
        if self._load_lock is None:
            # No await between the check and the assignment, so coroutines on
            # the same loop cannot create two different locks.
            self._load_lock = asyncio.Lock()
        async with self._load_lock:
            # Re-check: another coroutine may have finished loading while this
            # one waited for the lock.
            if not self._loaded:
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, self._load_model_sync)

    def health(self) -> dict:
        """Report whether the runtime dependencies for this backend are importable.

        Does not load the model; only checks that ``transformers`` and
        ``torch`` (both required by inference) can be imported.
        """
        try:
            import transformers  # noqa: F401
        except ImportError:
            return {
                "backend": self.name,
                "status": "unavailable",
                "reason": "transformers not installed",
            }
        try:
            import torch  # noqa: F401
        except ImportError:
            return {
                "backend": self.name,
                "status": "unavailable",
                "reason": "torch not installed",
            }
        return {"backend": self.name, "status": "ok", "model": self._model_name}

    def _run_trocr_sync(self, image_bytes: bytes) -> tuple[str, float]:
        """Blocking: decode an encoded image and run TrOCR generation on it.

        Returns:
            ``(text, confidence)``. The confidence is a fixed placeholder of
            1.0; no model score is computed.
        """
        import torch
        from PIL import Image  # type: ignore[import]

        device = self._resolve_device()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        pixel_values = self._processor(images=image, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = self._model.generate(pixel_values, max_new_tokens=512)
        text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return text, 1.0

    async def ocr_image(
        self,
        image_bytes: bytes,
        languages: list[str] | None = None,
    ) -> Any:
        """OCR a single encoded image and return a one-page ``OcrResult``.

        ``languages`` is accepted for interface compatibility but is not used.
        The reported ``ms`` excludes model-loading time.
        """
        from hearthnet.services.ocr.backends.base import OcrBlock, OcrPageResult, OcrResult

        await self._ensure_loaded()
        t0 = time.monotonic()
        loop = asyncio.get_running_loop()
        text, confidence = await loop.run_in_executor(None, self._run_trocr_sync, image_bytes)
        ms = int((time.monotonic() - t0) * 1000)
        block = OcrBlock(text=text, confidence=confidence, bbox=None, language=None)
        page = OcrPageResult(
            page=1, blocks=[block], full_text=text, confidence_avg=confidence, ms=ms
        )
        return OcrResult(pages=[page], detected_languages=[], backend=self.name, ms=ms)

    async def ocr_pdf(
        self,
        pdf_bytes: bytes,
        pages: list[int] | None = None,
        languages: list[str] | None = None,
    ) -> Any:
        """Rasterize a PDF at 200 dpi and OCR each selected page.

        Args:
            pdf_bytes: Raw PDF content.
            pages: 1-based page numbers to process; ``None`` or empty means
                all pages.
            languages: Forwarded to ``ocr_image`` (which ignores it).

        Returns:
            An ``OcrResult`` with one page entry per processed page, keeping
            the original page numbers. If ``pdf2image`` is not installed, a
            single empty page is returned instead of raising.
        """
        from hearthnet.services.ocr.backends.base import OcrPageResult, OcrResult

        try:
            from pdf2image import convert_from_bytes  # type: ignore[import]
        except ImportError:
            # Deliberate best-effort fallback: empty result rather than an error.
            return OcrResult(
                pages=[OcrPageResult(page=1, blocks=[], full_text="", confidence_avg=0.0, ms=0)],
                detected_languages=[],
                backend=self.name,
                ms=0,
            )

        t0 = time.monotonic()
        wanted = set(pages) if pages else None
        first_page = 1
        images: list[Any] = []
        # When every requested page number is < 1 nothing can match; skip
        # rasterization entirely.
        if wanted is None or max(wanted) >= 1:
            convert_kwargs: dict[str, Any] = {"dpi": 200}
            if wanted is not None:
                # Only rasterize the span that contains the requested pages.
                first_page = max(1, min(wanted))
                convert_kwargs["first_page"] = first_page
                convert_kwargs["last_page"] = max(wanted)
            loop = asyncio.get_running_loop()
            # Rasterization is blocking (poppler subprocess); keep it off the loop.
            images = await loop.run_in_executor(
                None, lambda: convert_from_bytes(pdf_bytes, **convert_kwargs)
            )

        page_results = []
        for idx, img in enumerate(images, start=first_page):
            if wanted is not None and idx not in wanted:
                continue
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            single = await self.ocr_image(buf.getvalue(), languages)
            old = single.pages[0]
            page_results.append(
                OcrPageResult(
                    page=idx,
                    blocks=old.blocks,
                    full_text=old.full_text,
                    confidence_avg=old.confidence_avg,
                    ms=old.ms,
                )
            )
        ms = int((time.monotonic() - t0) * 1000)
        return OcrResult(pages=page_results, detected_languages=[], backend=self.name, ms=ms)