File size: 3,268 Bytes
4cd8837
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ee8138
d6ca3a2
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations

import asyncio
import threading
import time

from hearthnet.services.rerank.backends.base import (
    RerankedDoc,
    RerankRequest,
    RerankResponse,
)


class BgeRerankerBackend:
    """Cross-encoder reranker using BAAI/bge-reranker models.

    The encoder is created lazily on the first ``rerank`` call. Both the
    load and the scoring run in the event loop's default executor so the
    loop is not blocked. A failed load is remembered: later calls return a
    zero-score response carrying the error instead of retrying.
    """

    name = "bge_reranker"

    def __init__(
        self,
        model_id: str = "BAAI/bge-reranker-v2-m3",
        device: str = "auto",
        max_batch: int = 32,
    ) -> None:
        """Store configuration; the model itself is not loaded here.

        Args:
            model_id: Model identifier handed to ``CrossEncoder``.
            device: Device string handed to ``CrossEncoder``; ``"auto"``
                resolves to ``"cuda"`` when torch reports CUDA, else ``"cpu"``.
            max_batch: Maximum number of (query, doc) pairs per ``predict``
                call. Values below 1 are clamped to 1.
        """
        self._model_id = model_id
        self._device = device
        # The batch size is used as a ``range`` step: 0 would raise at
        # request time and a negative step would silently score nothing.
        self._max_batch = max(1, max_batch)
        self._encoder = None
        self._loaded = False
        self._load_error: str | None = None
        # Serialises the lazy load, which runs on executor threads.
        self._load_lock = threading.Lock()

    def _load(self) -> bool:
        """Create the encoder once; return whether it is usable.

        Blocking. Safe to call from several threads: only one caller builds
        the encoder, the others wait and reuse the outcome.

        Returns:
            True if the encoder is loaded, False if loading failed (the
            reason is kept in ``self._load_error``).
        """
        # Lock-free fast paths for the already-decided cases.
        if self._loaded:
            return True
        if self._load_error is not None:
            return False

        with self._load_lock:
            # Re-check: another thread may have finished while we waited.
            if self._loaded:
                return True
            if self._load_error is not None:
                return False
            try:
                import torch  # type: ignore[import-untyped]
                from sentence_transformers import CrossEncoder  # type: ignore[import-untyped]

                device = self._device
                if device == "auto":
                    device = "cuda" if torch.cuda.is_available() else "cpu"

                self._encoder = CrossEncoder(self._model_id, device=device)
                self._device = device
                self._loaded = True
                return True
            except ImportError as exc:
                self._load_error = f"sentence_transformers not installed: {exc}"
                return False
            except Exception as exc:
                # Deliberately broad: any load failure is recorded and
                # reported through health()/response meta rather than raised.
                # Fall back to repr so an exception with an empty message
                # still leaves a non-empty error string.
                self._load_error = str(exc) or repr(exc)
                return False

    async def rerank(self, request: RerankRequest) -> RerankResponse:
        """Score ``request.docs`` against ``request.query`` and rank them.

        Args:
            request: Carries ``query``, ``docs`` (each with ``id`` and
                ``text``) and an optional ``top_k``.

        Returns:
            The ranked response. If the model could not be loaded, every doc
            is returned in its original order with score 0.0 and ``meta``
            holds the load error.
        """
        loop = asyncio.get_running_loop()

        # First use: importing torch and building the encoder is slow, so it
        # must not run on the event loop thread.
        if not self._loaded and self._load_error is None:
            await loop.run_in_executor(None, self._load)

        if not self._loaded:
            return RerankResponse(
                ranked=[RerankedDoc(id=d.id, score=0.0) for d in request.docs],
                meta={"error": self._load_error, "backend": self.name},
            )

        return await loop.run_in_executor(None, self._sync_rerank, request)

    def _sync_rerank(self, request: RerankRequest) -> RerankResponse:
        """Blocking scoring pass; expects the encoder to be loaded already.

        Returns:
            Docs sorted by descending score (truncated to ``top_k`` when
            set), plus meta with backend, model, elapsed ms and doc count.
        """
        t0 = time.monotonic()
        pairs = [[request.query, doc.text] for doc in request.docs]
        scores: list[float] = []

        # Score in chunks of at most ``_max_batch`` pairs.
        for i in range(0, len(pairs), self._max_batch):
            batch = pairs[i : i + self._max_batch]
            batch_scores = self._encoder.predict(batch)  # type: ignore[union-attr]
            scores.extend(float(s) for s in batch_scores)

        ranked = sorted(
            [
                RerankedDoc(id=doc.id, score=score)
                for doc, score in zip(request.docs, scores, strict=False)
            ],
            key=lambda x: x.score,
            reverse=True,
        )
        if request.top_k is not None:
            # A negative bound would slice from the end and drop the
            # lowest-scored docs; treat it as "return nothing" instead.
            ranked = ranked[: max(request.top_k, 0)]

        return RerankResponse(
            ranked=ranked,
            meta={
                "backend": self.name,
                "model": self._model_id,
                "ms": int((time.monotonic() - t0) * 1000),
                "doc_count": len(request.docs),
            },
        )

    def health(self) -> dict:
        """Report backend state without triggering a model load.

        Returns:
            Dict with ``backend``, ``model``, ``loaded``, ``available``
            (False once a load attempt has failed) and ``error``.
        """
        return {
            "backend": self.name,
            "model": self._model_id,
            "loaded": self._loaded,
            "available": self._load_error is None,
            "error": self._load_error,
        }