"""Relay client — joins a relay hub, polls its mailbox, and does RPC over it. This is the NAT-bound counterpart to :mod:`hearthnet.transport.relay_hub`. A local node that cannot accept inbound connections uses a :class:`RelayClient` to: 1. **join** the hub and register the other members' capabilities locally (so the bus can route ``llm.chat`` / ``rag.query`` / ``chat.deliver`` to them); 2. run a background **poll loop** that drains its mailbox and: * dispatches inbound ``request`` envelopes to the local bus, then ships the ``response`` back through the hub; * resolves pending outbound calls when their ``response`` arrives; * applies ``roster`` gossip so newly-joined peers become routable (all-to-all); 3. send outbound calls via :meth:`call_remote`, correlating request/response by id. :class:`RelayStrategy` adapts a :class:`RelayClient` to the :class:`~hearthnet.bus.transport.DeliveryStrategy` protocol so :class:`~hearthnet.bus.transport.CompositeTransport` can use it as a fallback. """ from __future__ import annotations import asyncio import contextlib import logging import time import uuid from typing import Any from hearthnet.bus import BusError from hearthnet.bus.capability import RouteRequest from hearthnet.bus.transport import NOT_HANDLED from hearthnet.discovery.peers import PeerRecord, PeerRegistry from hearthnet.types import Endpoint _log = logging.getLogger(__name__) # How long an outbound relayed call waits for its response before failing. RELAY_CALL_TIMEOUT_SECONDS = 30.0 def _parse_version(raw: Any) -> tuple[int, int]: parts = str(raw or "1.0").split(".") if len(parts) < 2: parts.append("0") return (int(parts[0]), int(parts[1])) class RelayClient: """Connects a local node to a relay hub for all-to-all messaging over NAT.""" def __init__( self, relay_url: str, *, node_id: str, display_name: str, community_id: str, bus: Any, peers: PeerRegistry, token: str | None = None, poll_timeout: float = 25.0, ) -> None: self._base = relay_url.rstrip("/") self._node_id = node_id self._display_name = display_name self._community_id = community_id self._bus = bus self._peers = peers self._token = token self._poll_timeout = poll_timeout self._client: Any = None self._members: set[str] = set() self._pending: dict[str, asyncio.Future] = {} self._poll_task: asyncio.Task | None = None self._running = False # node_id of the hub's own in-process node, learned from the join # response. That node is directly reachable at the relay base URL. self._hub_node_id: str | None = None @property def members(self) -> set[str]: return set(self._members) # ------------------------------------------------------------------ # Lifecycle # ------------------------------------------------------------------ async def join(self) -> dict[str, Any]: """Join the hub, register the returned roster, and start polling.""" import httpx if self._client is None: self._client = httpx.AsyncClient(timeout=60.0) caps = sorted({e.descriptor.name for e in self._bus.registry.all_local()}) payload = { "node_id": self._node_id, "display_name": self._display_name, "community_id": self._community_id, "capabilities": caps, } if self._token: payload["token"] = self._token resp = await self._client.post(f"{self._base}/relay/v1/join", json=payload) resp.raise_for_status() data = resp.json() if data.get("error"): raise BusError(str(data["error"]), str(data.get("message", ""))) # The hub's own in-process node is directly reachable over HTTP at the # relay base URL. Record it so _apply_roster can give it a direct-HTTP # endpoint (bypasses the mailbox poll loop, robust across event loops). self._hub_node_id = data.get("hub_node_id") self._apply_roster(data.get("roster", [])) # The hub node may be absent from the roster (or present only as a stale # entry from a previous deployment). Register it authoritatively via its # manifest with a direct-HTTP endpoint so chat/RPC to the hub always # works regardless of roster state or the caller's event loop. await self._register_hub_direct() self._running = True if self._poll_task is None or self._poll_task.done(): self._poll_task = asyncio.create_task(self._poll_loop(), name="relay-poll") return data async def close(self) -> None: self._running = False if self._poll_task is not None: self._poll_task.cancel() with contextlib.suppress(asyncio.CancelledError, Exception): await self._poll_task self._poll_task = None for fut in self._pending.values(): if not fut.done(): fut.cancel() self._pending.clear() if self._client is not None: with contextlib.suppress(Exception): await self._client.aclose() self._client = None # ------------------------------------------------------------------ # Outbound RPC (used by RelayStrategy) # ------------------------------------------------------------------ async def call_remote(self, node_id: str, req: RouteRequest) -> Any: """Deliver *req* to *node_id* via the hub and await its response. Returns :data:`NOT_HANDLED` if *node_id* is not a known relay member, so the composite transport can try other strategies. """ if node_id not in self._members or self._client is None: return NOT_HANDLED correlation_id = uuid.uuid4().hex loop = asyncio.get_event_loop() fut: asyncio.Future = loop.create_future() self._pending[correlation_id] = fut envelope = { "kind": "request", "from": self._node_id, "correlation_id": correlation_id, "capability": req.capability, "version": f"{req.version_req[0]}.{req.version_req[1]}", "body": {"params": req.body.get("params", {}), "input": req.body.get("input", {})}, } try: sent = await self._send(node_id, envelope) except Exception as exc: self._pending.pop(correlation_id, None) raise BusError("partition", f"relay send failed: {exc}") from exc if sent.get("error"): self._pending.pop(correlation_id, None) raise BusError(str(sent["error"]), str(sent.get("message", ""))) try: return await asyncio.wait_for(fut, timeout=RELAY_CALL_TIMEOUT_SECONDS) except TimeoutError as exc: self._pending.pop(correlation_id, None) raise BusError("timeout", f"relay call to {node_id} timed out") from exc # ------------------------------------------------------------------ # Internals # ------------------------------------------------------------------ async def _send(self, to: str, envelope: dict[str, Any]) -> dict[str, Any]: resp = await self._client.post( f"{self._base}/relay/v1/send", json={"to": to, "envelope": envelope} ) resp.raise_for_status() return resp.json() async def _poll_loop(self) -> None: while self._running: try: resp = await self._client.get( f"{self._base}/relay/v1/poll", params={"node_id": self._node_id, "timeout": self._poll_timeout}, ) resp.raise_for_status() data = resp.json() except asyncio.CancelledError: raise except Exception as exc: _log.debug("relay poll error: %s", exc) await asyncio.sleep(2.0) continue if data.get("error") == "not_joined": with contextlib.suppress(Exception): await self.join() continue for envelope in data.get("envelopes", []): await self._handle_envelope(envelope) async def _handle_envelope(self, envelope: dict[str, Any]) -> None: kind = envelope.get("kind") if kind == "request": await self._serve_request(envelope) elif kind == "response": self._resolve_response(envelope) elif kind == "roster": self._apply_roster(envelope.get("members", [])) async def _serve_request(self, envelope: dict[str, Any]) -> None: from_node = envelope.get("from", "") correlation_id = envelope.get("correlation_id", "") req = RouteRequest( capability=envelope.get("capability", ""), version_req=_parse_version(envelope.get("version", "1.0")), body=envelope.get("body", {}), caller=from_node, trace_id=correlation_id or uuid.uuid4().hex, deadline_ms=int((time.monotonic() + RELAY_CALL_TIMEOUT_SECONDS) * 1000), ) response: dict[str, Any] = { "kind": "response", "from": self._node_id, "correlation_id": correlation_id, } try: response["result"] = await self._bus.handle_call(req, local_only=True) except BusError as exc: response["error"] = exc.code response["message"] = str(exc) except Exception as exc: # report any handler failure back to the caller response["error"] = "internal_error" response["message"] = str(exc) if from_node: with contextlib.suppress(Exception): await self._send(from_node, response) def _resolve_response(self, envelope: dict[str, Any]) -> None: correlation_id = envelope.get("correlation_id", "") fut = self._pending.pop(correlation_id, None) if fut is None or fut.done(): return if envelope.get("error"): fut.set_exception( BusError(str(envelope["error"]), str(envelope.get("message", ""))) ) else: fut.set_result(envelope.get("result", {})) def _apply_roster(self, members: list[dict[str, Any]]) -> None: for member in members: node_id = member.get("node_id") if not node_id or node_id == self._node_id: continue self._members.add(node_id) # The hub's own node is directly reachable over HTTP at the relay # base URL — give it a direct http/https endpoint so the composite # transport's direct-HTTP path serves it via /bus/v1/call. This is # robust across event loops (no mailbox poll-loop future needed). # All other peers are NAT-bound: mark them with a "relay" endpoint so # the direct-HTTP path skips them and the relay strategy delivers. if node_id == self._hub_node_id: endpoint = self._direct_http_endpoint() else: endpoint = Endpoint(transport="relay", host=self._base, port=0) record = PeerRecord( node_id_full=node_id, display_name=member.get("display_name", node_id[:20]), community_id=member.get("community_id", self._community_id), endpoints=[endpoint], source="relay", ) self._peers.upsert(record) manifest = { "node_id": node_id, "capabilities": [{"name": name} for name in member.get("capabilities", [])], } with contextlib.suppress(Exception): self._bus.registry.update_from_peer_manifest(record, manifest) def _direct_http_endpoint(self) -> Endpoint: """Build a direct http/https Endpoint from the relay base URL.""" from urllib.parse import urlparse parsed = urlparse(self._base) scheme = parsed.scheme or "https" host = parsed.hostname or self._base port = parsed.port or (443 if scheme == "https" else 80) return Endpoint(transport=scheme, host=host, port=port) async def _register_hub_direct(self) -> None: """Register the hub's own node with a direct-HTTP endpoint. The hub (e.g. the HF Space) serves its bus at ``{base}/bus/v1/call`` and is always directly reachable at the relay base URL. We fetch its manifest to learn its authoritative node id + capabilities, then register every capability with a direct http/https endpoint. This guarantees chat/RPC to the hub works even when the hub is missing from the roster or only present as a stale entry, and regardless of which event loop issues the call. """ if self._base.split("://", 1)[0] not in ("http", "https"): return try: resp = await self._client.get(f"{self._base}/manifest") resp.raise_for_status() manifest = resp.json() except Exception: return hub_id = manifest.get("node_id") if not hub_id or hub_id == self._node_id: return # Prefer the manifest's node id as the hub identity. self._hub_node_id = hub_id self._members.add(hub_id) record = PeerRecord( node_id_full=hub_id, display_name=manifest.get("display_name", hub_id[:20]), community_id=manifest.get("community_id", self._community_id), endpoints=[self._direct_http_endpoint()], source="relay", ) self._peers.upsert(record) with contextlib.suppress(Exception): self._bus.registry.update_from_peer_manifest(record, manifest) class RelayStrategy: """Adapts a :class:`RelayClient` to the bus ``DeliveryStrategy`` protocol.""" name = "relay" def __init__(self, client: RelayClient) -> None: self._client = client async def try_deliver(self, node_id: str, req: RouteRequest) -> Any: return await self._client.call_remote(node_id, req)