Spaces:
Running on Zero
Running on Zero
File size: 3,268 Bytes
4cd8837 4aaae80 4cd8837 4aaae80 4cd8837 8ee8138 d6ca3a2 4cd8837 4aaae80 4cd8837 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | from __future__ import annotations
import asyncio
import time
from hearthnet.services.rerank.backends.base import (
RerankedDoc,
RerankRequest,
RerankResponse,
)
class BgeRerankerBackend:
    """Cross-encoder reranker using BAAI/bge-reranker models.

    The ``sentence_transformers.CrossEncoder`` is constructed lazily on the
    first ``rerank`` call. If construction fails, the error message is kept in
    ``self._load_error`` and loading is never retried. Later calls then return
    a degraded response: docs in input order, score 0.0, and ``meta["error"]``
    set.

    Args:
        model_id: Model identifier passed to ``CrossEncoder``.
        device: ``"auto"`` resolves to ``"cuda"`` when
            ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``. Any
            other value is passed to ``CrossEncoder`` unchanged.
        max_batch: Maximum number of (query, doc) pairs scored per
            ``predict`` call. Must be >= 1.

    Raises:
        ValueError: If ``max_batch`` is less than 1.
    """

    name = "bge_reranker"

    def __init__(
        self,
        model_id: str = "BAAI/bge-reranker-v2-m3",
        device: str = "auto",
        max_batch: int = 32,
    ) -> None:
        import threading

        if max_batch < 1:
            # As a range() step, 0 raises and a negative value silently yields
            # no batches (an empty ranking), so reject it up front.
            raise ValueError(f"max_batch must be >= 1, got {max_batch}")
        self._model_id = model_id
        self._device = device
        self._max_batch = max_batch
        self._encoder = None
        self._loaded = False
        self._load_error: str | None = None
        # _load runs on executor threads. The lock stops concurrent first
        # calls from constructing the model more than once.
        self._load_lock = threading.Lock()

    def _load(self) -> bool:
        """Construct the cross-encoder once and return True when it is usable.

        This does the heavy imports and model construction synchronously, so
        ``rerank`` calls it from an executor thread. A failure is recorded in
        ``self._load_error`` and makes every later call return False at once.
        """
        if self._loaded:
            return True
        with self._load_lock:
            # Re-check under the lock: another thread may have finished (or
            # failed) loading while this one was waiting.
            if self._loaded:
                return True
            if self._load_error is not None:
                return False
            try:
                import torch  # type: ignore[import-untyped]
                from sentence_transformers import CrossEncoder  # type: ignore[import-untyped]

                device = self._device
                if device == "auto":
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                self._encoder = CrossEncoder(self._model_id, device=device)
            except ImportError as exc:
                self._load_error = f"sentence_transformers not installed: {exc}"
                return False
            except Exception as exc:
                # str(exc) can be "" for message-less exceptions. Keep the
                # stored error non-empty so it still reads as a failure.
                self._load_error = str(exc) or type(exc).__name__
                return False
            self._device = device
            self._loaded = True
            return True

    async def rerank(self, request: RerankRequest) -> RerankResponse:
        """Score ``request.docs`` against ``request.query``, highest score first.

        Both model loading and inference run in the event loop's default
        executor, so the loop itself is never blocked. ``request.top_k``
        (when not None) limits the result on both the normal and the degraded
        path.

        Returns:
            The ranked docs with ``meta`` describing the run. If the model
            could not be loaded, the docs come back in input order with score
            0.0 and ``meta["error"]`` holds the load error.
        """
        loop = asyncio.get_running_loop()
        if not self._loaded and self._load_error is None:
            await loop.run_in_executor(None, self._load)
        if not self._loaded:
            fallback = [RerankedDoc(id=d.id, score=0.0) for d in request.docs]
            if request.top_k is not None:
                fallback = fallback[: request.top_k]
            return RerankResponse(
                ranked=fallback,
                meta={"error": self._load_error, "backend": self.name},
            )
        return await loop.run_in_executor(None, self._sync_rerank, request)

    def _sync_rerank(self, request: RerankRequest) -> RerankResponse:
        """Run the cross-encoder over ``request.docs``; blocking.

        Only called after ``_load`` has succeeded. Equal scores keep their
        input order because ``sorted`` is stable.
        """
        t0 = time.monotonic()
        pairs = [[request.query, doc.text] for doc in request.docs]
        scores: list[float] = []
        for start in range(0, len(pairs), self._max_batch):
            batch = pairs[start : start + self._max_batch]
            # Pass batch_size explicitly so predict scores the whole chunk in
            # one batch instead of re-splitting it with its own default size.
            batch_scores = self._encoder.predict(  # type: ignore[union-attr]
                batch, batch_size=self._max_batch
            )
            scores.extend(float(s) for s in batch_scores)
        # Exactly one score per doc is expected. strict=True turns a mismatch
        # into an error instead of silently dropping docs.
        ranked = sorted(
            (
                RerankedDoc(id=doc.id, score=score)
                for doc, score in zip(request.docs, scores, strict=True)
            ),
            key=lambda d: d.score,
            reverse=True,
        )
        if request.top_k is not None:
            ranked = ranked[: request.top_k]
        return RerankResponse(
            ranked=ranked,
            meta={
                "backend": self.name,
                "model": self._model_id,
                "ms": int((time.monotonic() - t0) * 1000),
                "doc_count": len(request.docs),
            },
        )

    def health(self) -> dict:
        """Report load state.

        ``available`` is False only after a load attempt has failed.
        """
        return {
            "backend": self.name,
            "model": self._model_id,
            "loaded": self._loaded,
            "available": self._load_error is None,
            "error": self._load_error,
        }
|