GitHub Actions
feat: federated RAG + comprehensive security fixes
6c38d43
Raw
History Blame
4.93 kB
"""TrOCR backend via Hugging Face Transformers (optional dependency)."""
from __future__ import annotations

import asyncio
import io
import logging
import threading
import time
from typing import Any
class TrocrBackend:
name = "trocr"
def __init__(
self,
model: str = "microsoft/trocr-large-handwritten",
device: str = "auto",
) -> None:
self._model_name = model
self._device = device
self._processor: Any = None
self._model: Any = None
self._loaded = False
@property
def supported_languages(self) -> list[str]:
# TrOCR is primarily English/handwriting; can be fine-tuned for others
return ["eng", "deu"]
def _resolve_device(self) -> str:
if self._device != "auto":
return self._device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"
def _load_model_sync(self) -> None:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel # type: ignore[import]
device = self._resolve_device()
self._processor = TrOCRProcessor.from_pretrained(self._model_name, revision="main") # nosec B615 - revision pinned
self._model = VisionEncoderDecoderModel.from_pretrained(self._model_name, revision="main") # nosec B615 - revision pinned
self._model.to(device)
self._loaded = True
async def _ensure_loaded(self) -> None:
if not self._loaded:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._load_model_sync)
def health(self) -> dict:
try:
import transformers # noqa: F401
except ImportError:
return {
"backend": self.name,
"status": "unavailable",
"reason": "transformers not installed",
}
return {"backend": self.name, "status": "ok", "model": self._model_name}
def _run_trocr_sync(self, image_bytes: bytes) -> tuple[str, float]:
import torch
from PIL import Image # type: ignore[import]
device = self._resolve_device()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
pixel_values = self._processor(images=image, return_tensors="pt").pixel_values.to(device)
with torch.no_grad():
generated_ids = self._model.generate(pixel_values, max_new_tokens=512)
text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return text, 1.0
async def ocr_image(
self,
image_bytes: bytes,
languages: list[str] | None = None,
) -> Any:
from hearthnet.services.ocr.backends.base import OcrBlock, OcrPageResult, OcrResult
await self._ensure_loaded()
t0 = time.monotonic()
loop = asyncio.get_running_loop()
text, confidence = await loop.run_in_executor(None, self._run_trocr_sync, image_bytes)
ms = int((time.monotonic() - t0) * 1000)
block = OcrBlock(text=text, confidence=confidence, bbox=None, language=None)
page = OcrPageResult(
page=1, blocks=[block], full_text=text, confidence_avg=confidence, ms=ms
)
return OcrResult(pages=[page], detected_languages=[], backend=self.name, ms=ms)
async def ocr_pdf(
self,
pdf_bytes: bytes,
pages: list[int] | None = None,
languages: list[str] | None = None,
) -> Any:
from hearthnet.services.ocr.backends.base import OcrResult
try:
from pdf2image import convert_from_bytes # type: ignore[import]
except ImportError:
from hearthnet.services.ocr.backends.base import OcrPageResult
return OcrResult(
pages=[OcrPageResult(page=1, blocks=[], full_text="", confidence_avg=0.0, ms=0)],
detected_languages=[],
backend=self.name,
ms=0,
)
t0 = time.monotonic()
images = convert_from_bytes(pdf_bytes, dpi=200)
page_results = []
for idx, img in enumerate(images, start=1):
if pages and idx not in pages:
continue
buf = io.BytesIO()
img.save(buf, format="PNG")
partial = await self.ocr_image(buf.getvalue(), languages)
from hearthnet.services.ocr.backends.base import OcrPageResult
old = partial.pages[0]
page_results.append(
OcrPageResult(
page=idx,
blocks=old.blocks,
full_text=old.full_text,
confidence_avg=old.confidence_avg,
ms=old.ms,
)
)
ms = int((time.monotonic() - t0) * 1000)
return OcrResult(pages=page_results, detected_languages=[], backend=self.name, ms=ms)