Spaces:
Running on Zero
Running on Zero
File size: 4,931 Bytes
4cd8837 4aaae80 4cd8837 4aaae80 4cd8837 6c38d43 4cd8837 8ee8138 4cd8837 4aaae80 4cd8837 4aaae80 4cd8837 6c38d43 4cd8837 8ee8138 4cd8837 4aaae80 4cd8837 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | """TrOCR backend via Hugging Face Transformers (optional dependency)."""
from __future__ import annotations

import asyncio
import importlib
import io
import threading
import time
from typing import Any
class TrocrBackend:
    """OCR backend that runs a TrOCR vision-encoder-decoder model.

    The ``transformers`` processor and model are imported and loaded lazily on
    the first OCR call, so constructing the backend is cheap and does not
    require the optional dependencies. Inference additionally imports
    ``torch`` and ``PIL``; PDF input additionally needs ``pdf2image``.
    """

    name = "trocr"

    def __init__(
        self,
        model: str = "microsoft/trocr-large-handwritten",
        device: str = "auto",
        revision: str = "main",
    ) -> None:
        """Store configuration; nothing is loaded here.

        Args:
            model: Hugging Face model id passed to ``from_pretrained``.
            device: Torch device string (e.g. ``"cpu"``, ``"cuda"``), or
                ``"auto"`` to use CUDA when available and CPU otherwise.
            revision: Hub revision passed to ``from_pretrained``. The default
                ``"main"`` is a branch and can change over time; pass a commit
                SHA to pin the exact weights.
        """
        self._model_name = model
        self._device = device
        self._revision = revision
        self._processor: Any = None
        self._model: Any = None
        self._loaded = False
        # Loading runs in executor threads; this lock makes concurrent first
        # requests wait for a single load instead of each loading the model.
        self._load_lock = threading.Lock()

    @property
    def supported_languages(self) -> list[str]:
        """Language codes this backend advertises."""
        # TrOCR is primarily English/handwriting; can be fine-tuned for others
        return ["eng", "deu"]

    def _resolve_device(self) -> str:
        """Return the configured device, resolving ``"auto"`` via torch.

        Falls back to ``"cpu"`` when torch is not importable.
        """
        if self._device != "auto":
            return self._device
        try:
            import torch

            return "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            return "cpu"

    def _load_model_sync(self) -> None:
        """Load the processor and model onto the resolved device (blocking).

        Thread-safe: the first caller performs the load while later callers
        block on the lock and return once ``_loaded`` is set.
        """
        with self._load_lock:
            if self._loaded:
                return
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel  # type: ignore[import]

            device = self._resolve_device()
            # B615 is suppressed because the revision is passed explicitly; it
            # is only truly pinned when ``revision`` is set to a commit SHA.
            processor = TrOCRProcessor.from_pretrained(self._model_name, revision=self._revision)  # nosec B615
            model = VisionEncoderDecoderModel.from_pretrained(self._model_name, revision=self._revision)  # nosec B615
            model.to(device)
            # Publish fully-initialised objects first, flip the flag last.
            self._processor = processor
            self._model = model
            self._loaded = True

    async def _ensure_loaded(self) -> None:
        """Load the model in the default executor if not already loaded."""
        if not self._loaded:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_model_sync)

    def health(self) -> dict:
        """Report whether the modules needed for inference can be imported.

        Returns a dict with ``backend`` and ``status`` (``"ok"`` or
        ``"unavailable"``), plus ``reason`` when unavailable or ``model``
        when ok. ``transformers`` is checked first, then ``torch`` and
        ``PIL`` (Pillow), which ``_run_trocr_sync`` imports unconditionally.
        """
        for module_name, package_name in (
            ("transformers", "transformers"),
            ("torch", "torch"),
            ("PIL", "Pillow"),
        ):
            try:
                importlib.import_module(module_name)
            except ImportError:
                return {
                    "backend": self.name,
                    "status": "unavailable",
                    "reason": f"{package_name} not installed",
                }
        return {"backend": self.name, "status": "ok", "model": self._model_name}

    def _run_trocr_sync(self, image_bytes: bytes) -> tuple[str, float]:
        """Decode one image with the loaded model (blocking).

        Returns ``(text, confidence)``. Confidence is a fixed ``1.0``
        placeholder; no per-token scores are computed.
        """
        import torch
        from PIL import Image  # type: ignore[import]

        device = self._resolve_device()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        pixel_values = self._processor(images=image, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = self._model.generate(pixel_values, max_new_tokens=512)
        text = self._processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return text, 1.0

    async def ocr_image(
        self,
        image_bytes: bytes,
        languages: list[str] | None = None,
    ) -> Any:
        """OCR a single encoded image.

        Returns an ``OcrResult`` holding one page (``page=1``) with one block
        containing the decoded text. ``languages`` is accepted for interface
        compatibility but is not used by this backend.
        """
        from hearthnet.services.ocr.backends.base import OcrBlock, OcrPageResult, OcrResult

        await self._ensure_loaded()
        t0 = time.monotonic()
        loop = asyncio.get_running_loop()
        text, confidence = await loop.run_in_executor(None, self._run_trocr_sync, image_bytes)
        ms = int((time.monotonic() - t0) * 1000)
        block = OcrBlock(text=text, confidence=confidence, bbox=None, language=None)
        page = OcrPageResult(
            page=1, blocks=[block], full_text=text, confidence_avg=confidence, ms=ms
        )
        return OcrResult(pages=[page], detected_languages=[], backend=self.name, ms=ms)

    async def ocr_pdf(
        self,
        pdf_bytes: bytes,
        pages: list[int] | None = None,
        languages: list[str] | None = None,
    ) -> Any:
        """OCR a PDF page by page.

        Every page is rendered at 200 DPI with ``pdf2image`` and passed to
        ``ocr_image``. ``pages`` holds 1-based page numbers to keep; ``None``
        or an empty list processes all pages. If ``pdf2image`` is not
        installed, a single empty page is returned instead of raising.
        """
        from hearthnet.services.ocr.backends.base import OcrPageResult, OcrResult

        try:
            from pdf2image import convert_from_bytes  # type: ignore[import]
        except ImportError:
            # Deliberate best-effort fallback: callers get an empty result.
            return OcrResult(
                pages=[OcrPageResult(page=1, blocks=[], full_text="", confidence_avg=0.0, ms=0)],
                detected_languages=[],
                backend=self.name,
                ms=0,
            )
        t0 = time.monotonic()
        images = convert_from_bytes(pdf_bytes, dpi=200)
        # Truthiness check preserved: an empty ``pages`` also means "all".
        wanted = set(pages) if pages else None
        page_results = []
        for idx, img in enumerate(images, start=1):
            if wanted is not None and idx not in wanted:
                continue
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            partial = await self.ocr_image(buf.getvalue(), languages)
            old = partial.pages[0]
            # ocr_image always reports page=1; re-stamp with the real number.
            page_results.append(
                OcrPageResult(
                    page=idx,
                    blocks=old.blocks,
                    full_text=old.full_text,
                    confidence_avg=old.confidence_avg,
                    ms=old.ms,
                )
            )
        ms = int((time.monotonic() - t0) * 1000)
        return OcrResult(pages=page_results, detected_languages=[], backend=self.name, ms=ms)
|