"""Weighted-sum fusion. score_final(d) = Σ_L w_L · normalize(score_L(d)) Each retriever's scores are min-max normalized into [0,1] before weighting so different magnitudes (BM25 vs cosine) become comparable. """ from __future__ import annotations from typing import Dict, List, Optional from ..core.types import Retrieved def _minmax(vals: List[float]) -> List[float]: if not vals: return vals lo, hi = min(vals), max(vals) span = hi - lo if span == 0: return [1.0 for _ in vals] return [(v - lo) / span for v in vals] class WeightedFuser: """Normalize-then-weight fusion.""" name = "weighted" def __init__(self, weights: Optional[Dict[str, float]] = None) -> None: self.weights = weights or {} def _weight_for(self, retriever: str) -> float: return float(self.weights.get(retriever, 1.0)) def fuse( self, per_retriever_results: List[List[Retrieved]], top_n: int = 20, ) -> List[Retrieved]: scores: Dict[str, float] = {} best: Dict[str, Retrieved] = {} for retrieved_list in per_retriever_results: if not retrieved_list: continue retriever_name = retrieved_list[0].retriever w = self._weight_for(retriever_name) normed = _minmax([r.score for r in retrieved_list]) for r, ns in zip(retrieved_list, normed): cid = r.chunk.chunk_id scores[cid] = scores.get(cid, 0.0) + w * ns cur = best.get(cid) if cur is None or r.score > cur.score: best[cid] = r fused: List[Retrieved] = [] for cid, s in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_n]: base = best[cid] merged_extra = dict(base.extra or {}) merged_extra.setdefault("origin_retriever", base.retriever) merged_extra.setdefault("origin_score", base.score) fused.append(Retrieved( chunk=base.chunk, score=s, retriever="weighted", rank=len(fused) + 1, extra=merged_extra, )) return fused