Spaces:
Running on Zero
Running on Zero
GitHub Actions
Quality improvements: Unicode chars, Token class, imports, type hints, formatting
"""Tests for Phase 1 + Phase 2 distributed RAG.

Phase 1: FederatedRagService (rag.federated_query) — local-first + scatter-gather + rerank.
Phase 2: CorpusReplicator — event-driven BLAKE3 blob replication.
"""
| from __future__ import annotations | |
| import asyncio | |
| import functools | |
| from typing import Any | |
| from unittest.mock import AsyncMock, MagicMock | |
| import pytest | |
def run(coro):
    """Run *coro* to completion on a reusable event loop and return its result.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` once the current loop
    has been cleared (for example after another test called ``asyncio.run``),
    and ``run_until_complete`` fails on a closed loop.  A fresh loop is
    installed in either case; an existing open loop is reused so that state
    created by earlier ``run`` calls stays bound to the same loop.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop in this thread (cleared, or never created).
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _make_rag_result(chunks: list[dict], corpus: str = "test") -> dict:
    """Wrap *chunks* in the ``output``/``meta`` envelope used for mocked rag.query replies."""
    output = {"chunks": chunks}
    meta = {"corpus": corpus}
    return {"output": output, "meta": meta}
def _chunk(text: str, score: float = 0.8, rank: int = 1, doc_cid: str | None = None) -> dict:
    """Build one retrieval chunk dict.

    A falsy *doc_cid* falls back to ``"cid:"`` followed by the first eight
    characters of *text*.
    """
    cid = doc_cid if doc_cid else f"cid:{text[:8]}"
    return dict(rank=rank, score=score, text=text, metadata={"doc_cid": cid})
| # --------------------------------------------------------------------------- | |
| # Phase 1 – FederatedRagService unit tests | |
| # --------------------------------------------------------------------------- | |
class TestFederatedRagService:
    """rag.federated_query: local-first, scatter-gather, merge, rerank."""

    @staticmethod
    def _make_request(**input_fields: Any) -> MagicMock:
        """Return a mock request whose ``body`` is ``{"input": input_fields}``."""
        req = MagicMock()
        req.body = {"input": input_fields}
        return req

    def _make_bus(
        self,
        local_chunks: list[dict] | None = None,
        remote_chunks: list[dict] | None = None,
        rerank_available: bool = False,
    ) -> MagicMock:
        """Build a mock bus with canned local and remote rag.query answers.

        ``bus.call`` answers ``rag.query`` with *local_chunks*, ``moe.route``
        with no candidates, ``rerank.text`` only when *rerank_available* is
        true, and raises for anything else.  ``bus.call_all`` answers
        ``rag.query`` with *remote_chunks* from a single peer, ``ed25519:peer-1``.
        """
        bus = MagicMock()
        bus.node_id_full = "ed25519:local-node"
        local_result = _make_rag_result(local_chunks or [])
        remote_result = _make_rag_result(remote_chunks or [])

        async def _call(cap, ver, body, **kw):
            if cap == "rag.query":
                return local_result
            if cap == "rerank.text" and rerank_available:
                docs = body["input"]["docs"]
                # Score drops by 0.1 per integer id, so lower ids rank higher.
                ranked = [{"id": d["id"], "score": 0.9 - int(d["id"]) * 0.1} for d in docs]
                return {"output": {"ranked": ranked}}
            if cap == "moe.route":
                return {"output": {"candidates": []}}
            raise Exception(f"not_found: {cap}")

        bus.call = AsyncMock(side_effect=_call)

        async def _call_all(cap, ver, body, **kw):
            if cap == "rag.query":
                return [("ed25519:peer-1", remote_result)]
            return []

        bus.call_all = AsyncMock(side_effect=_call_all)
        return bus

    def test_local_first_shortcircuit(self):
        """If local score >= threshold, returns without fan-out (C strategy)."""
        from hearthnet.services.rag.federated import FederatedRagService

        chunks = [_chunk("local knowledge", score=0.9, rank=1)]
        bus = self._make_bus(local_chunks=chunks)
        svc = FederatedRagService(bus, corpus="test", confidence_threshold=0.5)
        req = self._make_request(query="local knowledge", k=3, corpus="test")

        result = run(svc.handle_federated_query(req))

        assert result["meta"]["federated"] is False
        assert result["meta"]["peers_asked"] == 0
        assert len(result["output"]["chunks"]) >= 1
        # Should NOT have called call_all
        bus.call_all.assert_not_called()

    def test_scatter_gather_on_low_local_score(self):
        """When local score < threshold, fans out to peers (B strategy)."""
        from hearthnet.services.rag.federated import FederatedRagService

        local_chunks = [_chunk("weak local", score=0.2, rank=1)]
        remote_chunks = [_chunk("strong remote", score=0.95, rank=1)]
        bus = self._make_bus(local_chunks=local_chunks, remote_chunks=remote_chunks)
        svc = FederatedRagService(bus, corpus="test", confidence_threshold=0.5)
        req = self._make_request(query="remote knowledge", k=5, corpus="test")

        result = run(svc.handle_federated_query(req))

        assert result["meta"]["federated"] is True
        assert result["meta"]["peers_asked"] == 1
        texts = [c["text"] for c in result["output"]["chunks"]]
        assert "strong remote" in texts

    def test_provenance_attached_to_chunks(self):
        """Each chunk must carry source_node identifying which node answered."""
        from hearthnet.services.rag.federated import FederatedRagService

        local_chunks = [_chunk("local doc", score=0.2)]
        remote_chunks = [_chunk("remote doc", score=0.95)]
        bus = self._make_bus(local_chunks=local_chunks, remote_chunks=remote_chunks)
        svc = FederatedRagService(bus, corpus="test", confidence_threshold=0.5)
        req = self._make_request(query="docs", k=5, corpus="test")

        result = run(svc.handle_federated_query(req))

        for chunk in result["output"]["chunks"]:
            assert "source_node" in chunk, f"chunk missing source_node: {chunk}"

    def test_deduplication_by_doc_cid(self):
        """Same doc_cid from local + remote appears only once in merged output."""
        from hearthnet.services.rag.federated import FederatedRagService

        shared_cid = "cid:shared-doc"
        local_chunks = [_chunk("same text", score=0.2, doc_cid=shared_cid)]
        remote_chunks = [_chunk("same text", score=0.8, doc_cid=shared_cid)]
        bus = self._make_bus(local_chunks=local_chunks, remote_chunks=remote_chunks)
        # The threshold must exceed the local score (0.2): otherwise the
        # local-first short-circuit returns before any merge happens and the
        # deduplication check below would pass without being exercised.
        svc = FederatedRagService(bus, corpus="test", confidence_threshold=0.5)
        req = self._make_request(query="same", k=5, corpus="test")

        result = run(svc.handle_federated_query(req))

        # Guard against a vacuous pass: the remote copy must have been merged in.
        assert result["meta"]["federated"] is True
        cids = [c.get("metadata", {}).get("doc_cid") for c in result["output"]["chunks"]]
        assert cids.count(shared_cid) == 1, f"duplicate doc_cid in output: {cids}"

    def test_graceful_degradation_no_peers(self):
        """When no peers available, still returns local results (C-strategy fallback)."""
        from hearthnet.services.rag.federated import FederatedRagService

        local_chunks = [_chunk("only local", score=0.1)]
        bus = self._make_bus(local_chunks=local_chunks)
        bus.call_all = AsyncMock(return_value=[])  # no peers
        svc = FederatedRagService(bus, corpus="test", confidence_threshold=0.5)
        req = self._make_request(query="anything", k=5, corpus="test")

        result = run(svc.handle_federated_query(req))

        assert result["output"]["chunks"][0]["text"] == "only local"

    def test_empty_query_returns_empty(self):
        """Empty query returns empty chunks without errors."""
        from hearthnet.services.rag.federated import FederatedRagService

        bus = self._make_bus()
        svc = FederatedRagService(bus, corpus="test")
        req = self._make_request(query="", k=5)

        result = run(svc.handle_federated_query(req))

        assert result["output"]["chunks"] == []
| # --------------------------------------------------------------------------- | |
| # Phase 1 – bus.call_all() primitive unit tests | |
| # --------------------------------------------------------------------------- | |
class TestCallAll:
    """CapabilityBus.call_all() scatter-gather primitive."""

    def _make_bus_with_two_nodes(self):
        """Two in-process nodes sharing an InMemoryTransport.

        Each node registers a local ``test.echo`` handler, and a non-local
        registry entry for the other node's ``test.echo`` is inserted directly
        into ``bus.registry._entries``.  Returns ``(alpha, beta)``.
        """
        from hearthnet.bus import CapabilityBus, InMemoryTransport
        from hearthnet.bus.capability import CapabilityDescriptor, CapabilityEntry

        transport = InMemoryTransport()
        alpha = CapabilityBus("ed25519:alpha", "test-community", transport)
        beta = CapabilityBus("ed25519:beta", "test-community", transport)

        async def alpha_handler(req):
            return {"output": {"node": "alpha", "echo": req.body.get("input", {}).get("q")}}

        async def beta_handler(req):
            return {"output": {"node": "beta", "echo": req.body.get("input", {}).get("q")}}

        alpha.register_capability(
            CapabilityDescriptor(name="test.echo", version=(1, 0)),
            alpha_handler,
        )
        beta.register_capability(
            CapabilityDescriptor(name="test.echo", version=(1, 0)),
            beta_handler,
        )

        # Cross-register so each bus knows about the other
        for bus, remote_bus in ((alpha, beta), (beta, alpha)):
            entry = CapabilityEntry(
                descriptor=CapabilityDescriptor(name="test.echo", version=(1, 0)),
                handler=None,
                params_compatible=lambda o, r: True,
                node_id=remote_bus.node_id_full,
                is_local=False,
            )
            bus.registry._entries[f"test.echo@1.0:{remote_bus.node_id_full}"] = entry
        return alpha, beta

    def test_call_all_reaches_all_providers(self):
        """call_all fans out to all matching providers and returns all results."""
        from hearthnet.bus import CapabilityBus, InMemoryTransport
        from hearthnet.bus.capability import CapabilityDescriptor

        transport = InMemoryTransport()
        alpha = CapabilityBus("ed25519:alpha", "comm", transport)
        # NOTE(review): beta joins the transport but never registers "ping", so
        # this test only covers the single-provider case — confirm intent.
        beta = CapabilityBus("ed25519:beta", "comm", transport)

        async def echo(req):
            return {"output": {"from": alpha.node_id_full if req else "?"}}

        alpha.register_capability(CapabilityDescriptor(name="ping", version=(1, 0)), echo)

        results = run(alpha.call_all("ping", (1, 0), {}))

        assert len(results) == 1
        assert results[0][0] == "ed25519:alpha"

    def test_call_all_tolerates_partial_failure(self):
        """call_all returns successful results even when some providers fail."""
        from hearthnet.bus import CapabilityBus, InMemoryTransport
        from hearthnet.bus.capability import CapabilityDescriptor

        transport = InMemoryTransport()
        bus = CapabilityBus("ed25519:node", "comm", transport)

        async def ok_handler(req):
            return {"output": "ok"}

        # TODO(review): fail_handler is never registered with any bus, so no
        # provider actually fails in this test; wire it in through a second
        # provider once the registration path for that is confirmed.
        async def fail_handler(req):
            raise RuntimeError("deliberate failure")

        bus.register_capability(CapabilityDescriptor(name="test.cap", version=(1, 0)), ok_handler)

        results = run(bus.call_all("test.cap", (1, 0), {}, timeout_seconds=2.0))

        # At minimum the ok_handler result is present
        assert any(r[1].get("output") == "ok" for r in results)

    def test_call_all_empty_when_no_providers(self):
        """call_all returns empty list when nobody offers the capability."""
        from hearthnet.bus import CapabilityBus, InMemoryTransport

        transport = InMemoryTransport()
        bus = CapabilityBus("ed25519:node", "comm", transport)

        results = run(bus.call_all("nonexistent.cap", (1, 0), {}))

        assert results == []
| # --------------------------------------------------------------------------- | |
| # Phase 2 – CorpusReplicator unit tests | |
| # --------------------------------------------------------------------------- | |
class TestCorpusReplicator:
    """CorpusReplicator: event-driven BLAKE3 blob replication."""

    def _make_event(
        self,
        author: str,
        corpus: str = "test",
        doc_cid: str = "cid:doc1",
        blob_cid: str = "blob:abc",
        title: str = "Test Doc",
    ):
        """Return a mock event with ``author`` and a four-key ingest ``payload``."""
        evt = MagicMock()
        evt.author = author
        evt.payload = {
            "corpus": corpus,
            "doc_cid": doc_cid,
            "blob_cid": blob_cid,
            "title": title,
        }
        return evt

    @staticmethod
    def _no_peers() -> MagicMock:
        """Return a mock peer registry whose ``all()`` yields an empty list."""
        peers = MagicMock()
        peers.all.return_value = []
        return peers

    @staticmethod
    def _make_replicator(bus, transfer, peers, **extra):
        """Construct a CorpusReplicator for local node ``ed25519:local``.

        *extra* is forwarded untouched, so ``corpus_store_fn`` is passed only
        by the tests that set it and the constructor default applies otherwise.
        """
        from hearthnet.services.rag.replication import CorpusReplicator

        return CorpusReplicator(
            bus=bus,
            event_log=MagicMock(),
            transfer=transfer,
            peers=peers,
            local_node_id="ed25519:local",
            **extra,
        )

    def test_skips_own_events(self):
        """Replicator ignores events authored by local node."""
        bus = MagicMock()
        bus.call = AsyncMock(return_value={"output": {}})
        transfer = MagicMock()
        replicator = self._make_replicator(bus, transfer, self._no_peers())

        evt = self._make_event(author="ed25519:local")
        run(replicator._handle_event(evt))

        # Must NOT call bus.call (ingest) or transfer.fetch
        bus.call.assert_not_called()
        transfer.fetch.assert_not_called()

    def test_skips_known_docs(self):
        """If corpus_store_fn reports has_doc, replicator skips fetching."""
        bus = MagicMock()
        bus.call = AsyncMock()
        transfer = MagicMock()
        store_mock = MagicMock()
        store_mock.has_doc.return_value = True
        replicator = self._make_replicator(
            bus,
            transfer,
            self._no_peers(),
            corpus_store_fn=lambda corpus: store_mock,
        )

        evt = self._make_event(author="ed25519:peer")
        run(replicator._handle_event(evt))

        store_mock.has_doc.assert_called_once_with("cid:doc1")
        # Pin the documented contract: a known doc is neither fetched nor ingested.
        transfer.fetch.assert_not_called()
        bus.call.assert_not_called()

    def test_skips_when_no_blob_cid(self):
        """Without blob_cid in event payload, replicator cannot fetch."""
        bus = MagicMock()
        bus.call = AsyncMock()
        transfer = MagicMock()
        replicator = self._make_replicator(bus, transfer, self._no_peers())

        evt = MagicMock()
        evt.author = "ed25519:peer"
        evt.payload = {"corpus": "test", "doc_cid": "cid:doc1"}  # no blob_cid
        run(replicator._handle_event(evt))

        transfer.fetch.assert_not_called()
        bus.call.assert_not_called()

    def test_fetches_and_ingests_new_doc(self):
        """Replicator fetches blob and calls rag.ingest for unknown peer doc."""
        raw_text = b"replicated document text"

        bus = MagicMock()
        bus.call = AsyncMock(return_value={"output": {"chunks_indexed": 2, "was_duplicate": False}})

        manifest = MagicMock()
        manifest.cid = "blob:abc"
        transfer = MagicMock()
        transfer.fetch = AsyncMock(return_value=manifest)
        transfer.store = MagicMock()
        transfer.store.get = MagicMock(return_value=raw_text)

        # One known peer with a single HTTP endpoint; the fetch assertion
        # below expects it rendered as "http://192.168.1.2:7080".
        ep = MagicMock()
        ep.transport = "http"
        ep.host = "192.168.1.2"
        ep.port = 7080
        peer_rec = MagicMock()
        peer_rec.node_id = "ed25519:peer"
        peer_rec.endpoints = [ep]
        peers = MagicMock()
        peers.all.return_value = [peer_rec]

        store_mock = MagicMock()
        store_mock.has_doc.return_value = False
        replicator = self._make_replicator(
            bus,
            transfer,
            peers,
            corpus_store_fn=lambda corpus: store_mock,
        )

        evt = self._make_event(
            author="ed25519:peer",
            corpus="test",
            doc_cid="cid:doc1",
            blob_cid="blob:abc",
            title="Remote Doc",
        )
        run(replicator._handle_event(evt))

        transfer.fetch.assert_called_once_with("blob:abc", ["http://192.168.1.2:7080"])
        bus.call.assert_called_once()
        call_args = bus.call.call_args
        assert call_args[0][0] == "rag.ingest"
        ingest_input = call_args[0][2]["input"]
        assert ingest_input["text"] == raw_text.decode("utf-8")
        assert ingest_input["doc_cid"] == "cid:doc1"
        assert ingest_input["corpus"] == "test"
| # --------------------------------------------------------------------------- | |
| # Phase 1 – Integration: two-node mesh, federated query returns from both | |
| # --------------------------------------------------------------------------- | |
class TestFederatedIntegration:
    """Two in-memory nodes, each with different docs. Federated query merges both."""

    def test_two_nodes_federated_returns_from_both(self):
        """Alice and Bob each ingest one unique doc; Alice then runs a federated query."""
        from hearthnet.node import HearthNode, InMemoryNetwork

        net = InMemoryNetwork()
        alice = net.add_node("alice", "Alice", "ed25519:alice")
        bob = net.add_node("bob", "Bob", "ed25519:bob")
        for node in (alice, bob):
            node.install_demo_services(corpus="shared")
        net.mesh_discover()

        # One unique document per node, ingested in the order alice, bob.
        seed_docs = (
            (alice, "alice-doc", "Alice Doc", "alice unique knowledge about planets"),
            (bob, "bob-doc", "Bob Doc", "bob unique knowledge about stars"),
        )
        for node, doc_cid, title, text in seed_docs:
            ingest_body = {
                "params": {"corpus": "shared"},
                "input": {"doc_cid": doc_cid, "title": title, "text": text},
            }
            run(node.bus.call("rag.ingest", (1, 0), ingest_body))

        # Plain rag.query on Alice. The texts are extracted but not asserted
        # on: without replication Alice may or may not hold Bob's doc.
        local_body = {
            "params": {"corpus": "shared"},
            "input": {"query": "stars knowledge", "k": 5},
        }
        local_result = run(alice.bus.call("rag.query", (1, 0), local_body))
        local_texts = [hit["text"] for hit in local_result["output"]["chunks"]]

        # The federated query is meant to be able to reach Bob. The 0.0
        # threshold was intended to force fan-out.
        # NOTE(review): confirm how the service interprets an input-level
        # confidence_threshold; fan-out itself is not asserted below.
        federated_body = {
            "params": {"corpus": "shared"},
            "input": {
                "query": "knowledge",
                "k": 5,
                "corpus": "shared",
                "confidence_threshold": 0.0,
            },
        }
        federated_result = run(alice.bus.call("rag.federated_query", (1, 0), federated_body))
        all_texts = [hit["text"] for hit in federated_result["output"]["chunks"]]

        # At least one chunk comes back, and the federated flag is reported.
        assert len(all_texts) > 0
        assert "federated" in federated_result["meta"]

    def test_rag_service_emits_blob_cid_on_ingest(self):
        """Ingest through a RagService that was given a blob_store succeeds with a doc_cid."""
        import tempfile
        from pathlib import Path

        from hearthnet.blobs.store import BlobStore
        from hearthnet.services.rag.service import RagService
        from hearthnet.bus.capability import RouteRequest

        with tempfile.TemporaryDirectory() as tmp:
            svc = RagService(corpus="test", blob_store=BlobStore(Path(tmp) / "blobs"))

            ingest_input = {
                "text": "hello world document content",
                "title": "Test",
                "doc_cid": None,
            }
            req = MagicMock(spec=RouteRequest)
            req.body = {"input": ingest_input}

            output = run(svc.handle_ingest(req))["output"]

            # NOTE(review): nothing blob-specific is asserted here even though
            # the test name mentions blob_cid — confirm the expected output key.
            assert output["chunks_indexed"] >= 1
            assert output["was_duplicate"] is False
            assert output["doc_cid"] is not None