Spaces:
Running on Zero
Running on Zero
from __future__ import annotations

import asyncio
import threading
import time

from hearthnet.services.rerank.backends.base import (
    RerankedDoc,
    RerankRequest,
    RerankResponse,
)
class BgeRerankerBackend:
    """Cross-encoder reranker using BAAI/bge-reranker models.

    The model is loaded lazily on the first ``rerank`` call. If loading fails
    (missing dependency or any other exception), the error is remembered and
    never retried: every later ``rerank`` call returns the documents in input
    order with score ``0.0`` and the error reported in ``meta`` instead of
    raising.
    """

    name = "bge_reranker"

    def __init__(
        self,
        model_id: str = "BAAI/bge-reranker-v2-m3",
        device: str = "auto",
        max_batch: int = 32,
    ) -> None:
        """Configure the backend without loading the model.

        Args:
            model_id: Model identifier passed to ``CrossEncoder``.
            device: Torch device string; ``"auto"`` resolves to ``"cuda"`` when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            max_batch: Maximum number of (query, doc) pairs per ``predict`` call.

        Raises:
            ValueError: If ``max_batch`` is less than 1. It is used as a
                ``range`` step: 0 would raise at rerank time and a negative
                value would silently score (and return) no documents.
        """
        if max_batch < 1:
            raise ValueError(f"max_batch must be >= 1, got {max_batch}")
        self._model_id = model_id
        self._device = device
        self._max_batch = max_batch
        self._encoder = None
        self._loaded = False
        # None means "no failure so far"; any string (even "") means loading failed.
        self._load_error: str | None = None
        # _load runs in executor threads; this lock keeps two concurrent first
        # calls from both constructing the model.
        self._load_lock = threading.Lock()

    def _load(self) -> bool:
        """Load the cross-encoder once and report whether it is ready.

        A failure is recorded in ``_load_error`` and is not retried.

        Returns:
            True if the encoder is loaded, False if loading failed.
        """
        # Lock-free fast paths for the common already-decided cases.
        if self._loaded:
            return True
        if self._load_error is not None:
            return False
        with self._load_lock:
            # Re-check: another thread may have finished (or failed) while we waited.
            if self._loaded:
                return True
            if self._load_error is not None:
                return False
            try:
                import torch  # type: ignore[import-untyped]
                from sentence_transformers import CrossEncoder  # type: ignore[import-untyped]

                device = self._device
                if device == "auto":
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                self._encoder = CrossEncoder(self._model_id, device=device)
                self._device = device
                self._loaded = True
                return True
            except ImportError as exc:
                self._load_error = f"sentence_transformers not installed: {exc}"
                return False
            except Exception as exc:
                # Some exceptions stringify to ""; fall back to repr so the
                # failure is still recorded and reported by health().
                self._load_error = str(exc) or repr(exc)
                return False

    @staticmethod
    def _apply_top_k(ranked: list, top_k: int | None) -> list:
        """Truncate ``ranked`` to ``top_k`` entries; ``None`` keeps everything."""
        if top_k is None:
            return ranked
        return ranked[:top_k]

    async def rerank(self, request: RerankRequest) -> RerankResponse:
        """Score every doc against ``request.query`` and return them best-first.

        On load failure, returns the docs in input order with score 0.0
        (truncated to ``request.top_k`` like a normal response) and ``meta``
        containing ``error`` and ``backend``.
        """
        loop = asyncio.get_running_loop()
        # Constructing the CrossEncoder can be slow (it may need to fetch
        # weights), so the first load runs in the default executor rather than
        # blocking the event loop.
        ready = self._loaded or await loop.run_in_executor(None, self._load)
        if not ready:
            fallback = [RerankedDoc(id=d.id, score=0.0) for d in request.docs]
            return RerankResponse(
                ranked=self._apply_top_k(fallback, request.top_k),
                meta={"error": self._load_error, "backend": self.name},
            )
        return await loop.run_in_executor(None, self._sync_rerank, request)

    def _sync_rerank(self, request: RerankRequest) -> RerankResponse:
        """Blocking scoring path, run in an executor thread by ``rerank``.

        Requires the encoder to be loaded. ``meta["ms"]`` is wall-clock
        milliseconds spent scoring and sorting.
        """
        t0 = time.monotonic()
        pairs = [[request.query, doc.text] for doc in request.docs]
        scores: list[float] = []
        # Score in chunks of at most max_batch pairs.
        for start in range(0, len(pairs), self._max_batch):
            batch = pairs[start : start + self._max_batch]
            batch_scores = self._encoder.predict(batch)  # type: ignore[union-attr]
            scores.extend(float(s) for s in batch_scores)
        # sorted() is stable, so docs with equal scores keep their input order.
        ranked = sorted(
            [
                RerankedDoc(id=doc.id, score=score)
                for doc, score in zip(request.docs, scores, strict=False)
            ],
            key=lambda x: x.score,
            reverse=True,
        )
        return RerankResponse(
            ranked=self._apply_top_k(ranked, request.top_k),
            meta={
                "backend": self.name,
                "model": self._model_id,
                "ms": int((time.monotonic() - t0) * 1000),
                "doc_count": len(request.docs),
            },
        )

    def health(self) -> dict:
        """Report load state; ``available`` is False once a load has failed."""
        return {
            "backend": self.name,
            "model": self._model_id,
            "loaded": self._loaded,
            "available": self._load_error is None,
            "error": self._load_error,
        }