Spaces:
Running on Zero
Running on Zero
GitHub Actions
Quality improvements: Unicode chars, Token class, imports, type hints, formatting
"""M05 Federated RAG — multi-node scatter-gather query with reranking.

Strategy mix:
    A — single-best routing (handled by plain rag.query)
    B — scatter-gather: fan out to remote peers, merge results
    C — local-first: return immediately when local confidence is high
    E — MoE routing: use moe.route to prioritise peers' results when merging

Capability: rag.federated_query v1.0
Spec: docs/M05-rag.md §9 (distributed query path)
"""
| from __future__ import annotations | |
| import hashlib | |
| import logging | |
| from typing import Any | |
| from hearthnet.bus.capability import CapabilityDescriptor, RouteRequest | |
_log: logging.Logger = logging.getLogger(__name__)

# Tunables; the letter in each comment names the strategy it drives.
_DEFAULT_K: int = 5  # chunks returned when the request omits "k"
_DEFAULT_CONFIDENCE: float = 0.5  # C: best local score needed to skip the peers
_DEFAULT_FANOUT_TIMEOUT: float = 4.0  # B: seconds allowed for each remote call
class FederatedRagService:
    """Registers ``rag.federated_query`` on the capability bus.

    A federated query runs these strategies in order:

    * C (local-first) — query ``rag.query`` through ``bus.call``; if the best
      score reaches the confidence threshold, answer from those chunks only.
    * E (MoE routing) — ask ``moe.route`` which peers to prefer. Preferred
      peers' results are merged ahead of the others. Their copies therefore
      survive deduplication and win score ties.
    * B (scatter-gather) — fan ``rag.query`` out to remote providers. Their
      chunks are merged with the local ones, deduplicated, sorted by score
      and reranked via ``rerank.text`` when more than ``k`` remain.

    Constructor args:
        bus — CapabilityBus (required; used for local and scatter-gather calls)
        corpus — corpus name filter; None = any corpus
        confidence_threshold — local score threshold for early return (C strategy)
        fanout_timeout — per-peer timeout in seconds (B strategy)
    """

    name = "rag.federated"
    version = "1.0"

    def __init__(
        self,
        bus: Any,
        *,
        corpus: str | None = None,
        confidence_threshold: float = _DEFAULT_CONFIDENCE,
        fanout_timeout: float = _DEFAULT_FANOUT_TIMEOUT,
    ) -> None:
        self._bus = bus
        self._corpus = corpus
        self._confidence = confidence_threshold
        self._fanout_timeout = fanout_timeout

    def capabilities(self) -> list[tuple]:
        """Return the (descriptor, handler, matcher) tuple for rag.federated_query v1.0.

        The descriptor's params carry ``corpus`` only when one is configured.
        """
        params: dict[str, Any] = {}
        if self._corpus:
            params["corpus"] = self._corpus
        return [
            (
                CapabilityDescriptor(
                    name="rag.federated_query",
                    version=(1, 0),
                    params=params,
                    max_concurrent=4,
                    idempotent=True,
                ),
                self.handle_federated_query,
                self._corpus_matches,
            ),
        ]

    def _corpus_matches(self, offered: dict, requested: dict) -> bool:
        """Return True unless both sides name a corpus and the names differ."""
        return (
            not requested.get("corpus")
            or not offered.get("corpus")
            or requested.get("corpus") == offered.get("corpus")
        )

    # ------------------------------------------------------------------
    # Main handler
    # ------------------------------------------------------------------

    async def handle_federated_query(self, req: RouteRequest) -> dict[str, Any]:
        """Federated query: local-first → scatter-gather → merge → rerank.

        Reads these keys from ``req.body["input"]``:

        * ``query`` (str)
        * ``k`` (int, default ``_DEFAULT_K``)
        * ``corpus`` (default: the service corpus)
        * ``confidence_threshold`` (default: the constructor value)

        Returns ``{"output": {"chunks": [...]}, "meta": {...}}`` holding at
        most ``k`` chunks, each tagged with ``source_node``. ``meta`` always
        carries ``corpus``, ``federated``, ``peers_asked`` and ``reranked``.
        """
        inp = req.body.get("input") or {}
        query: str = inp.get("query", "")
        k: int = int(inp.get("k", _DEFAULT_K))
        corpus: str | None = inp.get("corpus", self._corpus)
        threshold: float = float(inp.get("confidence_threshold", self._confidence))

        # No query, or no results wanted: skip all bus traffic. A negative k
        # must never reach the ``[:k]`` slices below, which would silently
        # drop chunks from the end instead of returning nothing.
        if not query or k <= 0:
            return self._response([], corpus, federated=False, peers_asked=0, reranked=False)

        # ── Strategy C: local-first ────────────────────────────────────────
        local_chunks, local_node_id, best_local_score = await self._query_local(query, k, corpus)
        _add_source(local_chunks, local_node_id)
        if local_chunks and best_local_score >= threshold:
            _log.debug("federated_query: local-first short-circuit score=%.3f", best_local_score)
            return self._response(
                local_chunks[:k], corpus, federated=False, peers_asked=0, reranked=False
            )

        # ── Strategy E: MoE — prioritise peers by topic ────────────────────
        peer_priority = await self._moe_peer_priority(query, corpus)

        # ── Strategy B: scatter-gather ─────────────────────────────────────
        query_body = {
            "input": {"query": query, "k": k * 2, "corpus": corpus},
            "params": {"corpus": corpus} if corpus else {},
        }
        all_results = await self._bus.call_all(
            "rag.query",
            (1, 0),
            query_body,
            include_local=False,  # local was already queried above
            timeout_seconds=self._fanout_timeout,
            max_providers=6,
        )
        peers_asked = len(all_results)
        if peer_priority:
            all_results = self._order_by_priority(all_results, peer_priority)

        # ── Merge local + remote, then deduplicate ─────────────────────────
        # Local chunks come first, so on duplicates the local copy survives.
        merged: list[dict[str, Any]] = list(local_chunks)
        for node_id, result in all_results:
            chunks = self._extract_chunks(result)
            _add_source(chunks, node_id)
            merged.extend(chunks)
        merged = _dedupe(merged)

        # ── Order: retrieval score, then optional rerank (M24) ─────────────
        # The baseline sort always runs, so ranks follow score even when the
        # reranker is skipped (len(merged) <= k) or unavailable. The sort is
        # stable, so equal scores keep the local-first / MoE-priority order.
        merged.sort(key=self._score, reverse=True)
        reranked = False
        if len(merged) > k:
            reranked = await self._rerank(query, merged, k)

        top = merged[:k]
        for rank, chunk in enumerate(top, start=1):
            chunk["rank"] = rank
        return self._response(
            top, corpus, federated=True, peers_asked=peers_asked, reranked=reranked
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    async def _query_local(
        self, query: str, k: int, corpus: str | None
    ) -> tuple[list[dict], str, float]:
        """Query ``rag.query`` via ``bus.call``; return (chunks, node_id, best_score).

        ``node_id`` is always ``bus.node_id_full``. On any error, the chunks
        are empty and the best score is 0.0, so the caller falls through to
        scatter-gather.

        NOTE(review): this assumes ``bus.call`` resolves to the local
        provider. Confirm this, because the scatter step passes
        ``include_local=False`` on that basis.
        """
        body: dict[str, Any] = {
            "input": {"query": query, "k": k, "corpus": corpus},
            "params": {"corpus": corpus} if corpus else {},
        }
        node_id = self._bus.node_id_full
        try:
            result = await self._bus.call("rag.query", (1, 0), body)
            chunks = self._extract_chunks(result)
            best = max((self._score(c) for c in chunks), default=0.0)
        except Exception as exc:  # best-effort: remote peers can still answer
            _log.debug("local rag.query failed: %s", exc)
            return [], node_id, 0.0
        return chunks, node_id, best

    async def _moe_peer_priority(self, query: str, corpus: str | None) -> list[str] | None:
        """Ask ``moe.route`` to rank expert peers; return their ids, or None on failure.

        Up to 4 candidates are requested. The corpus is passed as a tag when
        it is set.

        NOTE(review): the returned ``expert_id`` values are matched against
        the node ids from ``bus.call_all``. Confirm they share a namespace.
        """
        tags = [corpus] if corpus else []
        try:
            result = await self._bus.call(
                "moe.route",
                (1, 0),
                {"input": {"query": query, "top_k": 4, "tags": tags}},
            )
            candidates = (result.get("output") or {}).get("candidates") or []
            return [c["expert_id"] for c in candidates if "expert_id" in c]
        except Exception as exc:
            # MoE routing is optional; without it peers keep call_all's order.
            _log.debug("moe.route unavailable, using default peer order: %s", exc)
            return None

    async def _rerank(self, query: str, chunks: list[dict[str, Any]], k: int) -> bool:
        """Reorder *chunks* in place using ``rerank.text``; return True if applied.

        Chunks the reranker scored move to the front, sorted by descending
        reranker score, and their ``score`` is overwritten with that value.
        Chunks it did not score (it is only asked for ``top_k=k``) follow in
        their existing order.

        Reranker and retrieval scores are on different scales, so sorting
        them together could push an unscored chunk above the reranker's picks.
        That is why the two groups are kept separate.

        Returns False, leaving *chunks* untouched, when the call fails or
        yields no usable ids.
        """
        rerank_body = {
            "input": {
                "query": query,
                "docs": [{"id": str(i), "text": c.get("text", "")} for i, c in enumerate(chunks)],
                "top_k": k,
            }
        }
        try:
            rerank_result = await self._bus.call("rerank.text", (1, 0), rerank_body)
            ranked = (rerank_result.get("output") or {}).get("ranked") or []
            new_scores: dict[int, Any] = {}
            for entry in ranked:
                idx = int(entry["id"])
                if 0 <= idx < len(chunks):  # ignore ids we never sent
                    new_scores[idx] = entry["score"]
            order = sorted(new_scores, key=new_scores.__getitem__, reverse=True)
        except Exception as exc:
            _log.debug("rerank.text unavailable, falling back to score sort: %s", exc)
            return False
        if not order:
            return False
        for idx in order:
            chunks[idx]["score"] = new_scores[idx]
        picked = set(order)
        chunks[:] = [chunks[i] for i in order] + [
            c for i, c in enumerate(chunks) if i not in picked
        ]
        return True

    @staticmethod
    def _order_by_priority(results: Any, priority: list[str]) -> list[tuple[str, dict]]:
        """Sort (node_id, result) pairs by position in *priority*; unlisted peers go last.

        The sort is stable, so peers with equal priority keep their order.
        """
        position: dict[str, int] = {}
        for index, node_id in enumerate(priority):
            position.setdefault(node_id, index)  # first occurrence wins, like list.index
        unlisted = len(priority)
        return sorted(results, key=lambda item: position.get(item[0], unlisted))

    @staticmethod
    def _extract_chunks(result: Any) -> list[dict[str, Any]]:
        """Return ``result["output"]["chunks"]`` as a list; [] if any level is missing."""
        if not isinstance(result, dict):
            return []
        output = result.get("output") or {}
        return list(output.get("chunks") or [])

    @staticmethod
    def _score(chunk: dict[str, Any]) -> float:
        """Return the chunk's ``score``, treating a missing or None score as 0.0."""
        return chunk.get("score") or 0.0

    @staticmethod
    def _response(
        chunks: list[dict[str, Any]],
        corpus: str | None,
        *,
        federated: bool,
        peers_asked: int,
        reranked: bool,
    ) -> dict[str, Any]:
        """Build the handler's ``{"output": ..., "meta": ...}`` reply."""
        return {
            "output": {"chunks": chunks},
            "meta": {
                "corpus": corpus,
                "federated": federated,
                "peers_asked": peers_asked,
                "reranked": reranked,
            },
        }
| # --------------------------------------------------------------------------- | |
| # Utilities | |
| # --------------------------------------------------------------------------- | |
| def _add_source(chunks: list[dict], node_id: str) -> None: | |
| """Attach source_node provenance to each chunk in-place.""" | |
| for chunk in chunks: | |
| chunk.setdefault("source_node", node_id) | |
| def _dedupe(chunks: list[dict]) -> list[dict]: | |
| """Remove duplicate chunks (same doc_cid or same text fingerprint).""" | |
| seen: set[str] = set() | |
| out: list[dict] = [] | |
| for chunk in chunks: | |
| meta = chunk.get("metadata") or {} | |
| doc_cid = meta.get("doc_cid") or meta.get("source") | |
| if doc_cid: | |
| key = doc_cid | |
| else: | |
| text = chunk.get("text", "") | |
| key = hashlib.sha256(text.encode()).hexdigest() | |
| if key not in seen: | |
| seen.add(key) | |
| out.append(chunk) | |
| return out | |