"""Cross-encoder reranker.

Uses `sentence-transformers` CrossEncoder if available. Falls back to a simple
length-normalized overlap scorer (Jaccard similarity over word tokens) so the
pipeline never breaks.
"""
from __future__ import annotations

import logging
import re
from typing import List, Optional

from ..core.types import Query, Retrieved

_logger = logging.getLogger(__name__)

_TOKEN_RE = re.compile(r"\w+", re.UNICODE)


def _overlap_score(q: str, d: str) -> float:
    """Return the Jaccard similarity of the lowercased word tokens of *q* and *d*.

    Tokens are maximal ``\\w+`` runs. The result is in ``[0.0, 1.0]``; it is
    ``0.0`` when either text contains no tokens.
    """
    qt = {t.lower() for t in _TOKEN_RE.findall(q)}
    dt = {t.lower() for t in _TOKEN_RE.findall(d)}
    if not qt or not dt:
        return 0.0
    return len(qt & dt) / len(qt | dt)


class CrossEncoderReranker:
    """Pairwise reranker with graceful fallback.

    The cross-encoder model is loaded lazily on the first ``rerank`` call. If
    it cannot be loaded, or scoring fails, candidates are scored with
    ``_overlap_score`` instead.
    """

    name = "cross_encoder"

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> None:
        """Store the model name; nothing is loaded until the first rerank."""
        self._model_name = model_name
        self._model = None
        self._loaded = False

    def _lazy_load(self) -> None:
        """Try to load the cross-encoder once; leave ``_model`` as None on failure."""
        if self._loaded:
            return
        try:
            from sentence_transformers import CrossEncoder  # type: ignore

            self._model = CrossEncoder(self._model_name)
        except Exception:
            # Deliberately broad: any failure here (package missing, model
            # construction error) must degrade to the fallback scorer rather
            # than break the pipeline. Log it so the degradation is visible.
            _logger.warning(
                "Could not load cross-encoder %r; falling back to token-overlap scoring",
                self._model_name,
                exc_info=True,
            )
            self._model = None
        # Set even on failure so the load is attempted at most once.
        self._loaded = True

    def _score(self, query_text: str, doc_texts: List[str]) -> List[float]:
        """Score each document against the query, one float per document."""
        if self._model is not None:
            pairs = [(query_text, doc) for doc in doc_texts]
            try:
                scores = [float(s) for s in self._model.predict(pairs)]
                if len(scores) != len(pairs):
                    # A mismatch would make zip() silently drop candidates.
                    raise ValueError(
                        f"expected {len(pairs)} scores, got {len(scores)}"
                    )
                return scores
            except Exception:
                _logger.warning(
                    "Cross-encoder scoring failed; falling back to token-overlap scoring",
                    exc_info=True,
                )
        return [_overlap_score(query_text, doc) for doc in doc_texts]

    def rerank(
        self,
        query: Query,
        candidates: List[Retrieved],
        k: int,
    ) -> List[Retrieved]:
        """Rescore *candidates* against *query* and return the top *k*.

        Args:
            query: The query; its ``text`` is paired with each candidate.
            candidates: Retrieved items; each ``chunk.text`` is scored.
            k: Maximum number of results. Non-positive values yield an empty
                list; ``None`` keeps every candidate.

        Returns:
            New ``Retrieved`` objects sorted by descending score, with
            ``retriever="reranked"``, 1-based ``rank``, and the previous
            retriever name and score recorded in ``extra``.
        """
        # A negative k would otherwise slice as "all but the last |k|".
        if not candidates or (k is not None and k <= 0):
            return []
        self._lazy_load()

        scores = self._score(query.text, [r.chunk.text for r in candidates])

        # sorted() is stable, so ties keep their incoming candidate order.
        reordered = sorted(
            zip(candidates, scores), key=lambda t: t[1], reverse=True
        )[:k]

        return [
            Retrieved(
                chunk=r.chunk,
                score=float(s),
                retriever="reranked",
                rank=rank,
                # `extra` may be unset/None on some candidates; treat as empty.
                extra={
                    **(r.extra or {}),
                    "prev_retriever": r.retriever,
                    "prev_score": r.score,
                },
            )
            for rank, (r, s) in enumerate(reordered, start=1)
        ]