from __future__ import annotations import time import uuid from dataclasses import dataclass from typing import Any from hearthnet.bus.capability import ( CapabilityDescriptor, CapabilityEntry, Handler, ParamsPredicate, RouteRequest, ) from hearthnet.bus.health import HealthTracker from hearthnet.bus.registry import Diff, RegistryEvent, Registry from hearthnet.bus.router import BusConfig, Router from hearthnet.types import CapabilityName, HearthNetError, Version class BusError(HearthNetError): pass class InMemoryTransport: def __init__(self) -> None: self._buses: dict[str, CapabilityBus] = {} def register(self, bus: CapabilityBus) -> None: self._buses[bus.node_id_full] = bus async def call(self, node_id: str, req: RouteRequest) -> dict[str, Any]: try: bus = self._buses[node_id] except KeyError as exc: raise BusError("partition", f"node {node_id} is not reachable") from exc inbound = RouteRequest( capability=req.capability, version_req=req.version_req, body=req.body, caller=req.caller, trace_id=req.trace_id, session_id=req.session_id, deadline_ms=req.deadline_ms, stream=req.stream, ) return await bus.handle_call(inbound, local_only=True) @dataclass(frozen=True) class CallTraceEvent: trace_id: str capability: CapabilityName from_node: str to_node: str result: str ms: float @dataclass(frozen=True) class TopologySnapshot: our_node_id: str peers: list[dict[str, Any]] capabilities_local: list[dict[str, Any]] capabilities_remote: list[dict[str, Any]] in_flight_total: int traces: list[CallTraceEvent] class CapabilityBus: def __init__( self, node_id_full: str, community_id: str, transport: InMemoryTransport | None = None, config: BusConfig | None = None, ) -> None: self.node_id_full = node_id_full self.community_id = community_id self.registry = Registry(our_node_id=node_id_full) self.health = HealthTracker() self.router = Router(self.registry, config) self.transport = transport or InMemoryTransport() self.transport.register(self) self._traces: list[CallTraceEvent] = [] self._offline_stash: list[tuple[CapabilityDescriptor, Handler, ParamsPredicate | None]] = [] def register_capability( self, descriptor: CapabilityDescriptor, handler: Handler, params_compatible: ParamsPredicate | None = None, ) -> None: self.registry.register_local(descriptor, handler, params_compatible) def register_service(self, service: Any) -> None: for item in service.capabilities(): descriptor, handler, *rest = item predicate = rest[0] if rest else None self.register_capability(descriptor, handler, predicate) async def call( self, capability: CapabilityName, version_req: Version, body: dict[str, Any], *, session_id: str | None = None, ) -> dict[str, Any]: req = RouteRequest( capability=capability, version_req=version_req, body=body, caller=self.node_id_full, trace_id=uuid.uuid4().hex, session_id=session_id, deadline_ms=int((time.monotonic() + 10) * 1000), ) return await self.handle_call(req) async def handle_call(self, req: RouteRequest, *, local_only: bool = False) -> dict[str, Any]: entry = self.router.route_sticky(req) if req.session_id else self.router.route(req) if entry is None: raise BusError("not_found", f"no provider for {req.capability}@{req.version_req}") started = time.monotonic() entry.in_flight += 1 try: if entry.is_local: if entry.handler is None: raise BusError("not_implemented", entry.descriptor.name) result = await entry.handler(req) elif local_only: raise BusError("not_found", f"remote entry cannot satisfy inbound {req.capability}") else: result = await self.transport.call(entry.node_id, req) elapsed = (time.monotonic() - started) * 1000 self.health.record(entry, success=True, latency_ms=elapsed) self._traces.append( CallTraceEvent( req.trace_id, req.capability, req.caller, entry.node_id, "ok", elapsed ) ) return result except HearthNetError as exc: elapsed = (time.monotonic() - started) * 1000 self.health.record(entry, success=False, latency_ms=elapsed) self._traces.append( CallTraceEvent( req.trace_id, req.capability, req.caller, entry.node_id, exc.code, elapsed ) ) raise finally: entry.in_flight -= 1 def deregister_internet_capabilities(self) -> int: removed = 0 for entry in list(self.registry.all_local()): if entry.descriptor.params.get("requires_internet"): removed_entry = self.registry.deregister_local( entry.descriptor.name, entry.descriptor.version ) if removed_entry and removed_entry.handler: self._offline_stash.append( ( removed_entry.descriptor, removed_entry.handler, removed_entry.params_compatible, ) ) removed += 1 return removed def restore_internet_capabilities(self) -> int: restored = 0 while self._offline_stash: descriptor, handler, predicate = self._offline_stash.pop(0) self.register_capability(descriptor, handler, predicate) restored += 1 return restored def topology_snapshot(self, peers: list[dict[str, Any]] | None = None) -> TopologySnapshot: return TopologySnapshot( our_node_id=self.node_id_full, peers=peers or [], capabilities_local=[_entry_view(entry) for entry in self.registry.all_local()], capabilities_remote=[_entry_view(entry) for entry in self.registry.all_remote()], in_flight_total=sum(entry.in_flight for entry in self.registry.all()), traces=list(self._traces[-50:]), ) def _entry_view(entry: CapabilityEntry) -> dict[str, Any]: return { "node_id": entry.node_id, "name": entry.descriptor.name, "version": entry.descriptor.version_str, "local": entry.is_local, "params": dict(entry.descriptor.params), "success_rate": entry.success_rate, "quarantined": entry.quarantined_until > time.monotonic(), }