Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import time | |
| import uuid | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from hearthnet.bus.capability import ( | |
| CapabilityDescriptor, | |
| CapabilityEntry, | |
| Handler, | |
| ParamsPredicate, | |
| RouteRequest, | |
| ) | |
| from hearthnet.bus.health import HealthTracker | |
| from hearthnet.bus.registry import Diff, RegistryEvent, Registry | |
| from hearthnet.bus.router import BusConfig, Router | |
| from hearthnet.types import CapabilityName, HearthNetError, Version | |
class BusError(HearthNetError):
    """Error raised by the capability bus and its in-memory transport.

    In this file it is always constructed with a short code string (for
    example ``"not_found"``, ``"not_implemented"`` or ``"partition"``)
    followed by a detail string; the constructor itself is inherited
    unchanged from ``HearthNetError``.
    """
class InMemoryTransport:
    """In-process transport that hands calls directly to registered buses.

    Buses are indexed by their ``node_id_full``. Calling a node means looking
    up its bus in that index and awaiting the bus's ``handle_call``.
    """

    def __init__(self) -> None:
        """Create a transport with no buses registered."""
        self._buses: dict[str, CapabilityBus] = {}

    def register(self, bus: CapabilityBus) -> None:
        """Make *bus* reachable under its ``node_id_full``.

        Registering a second bus with the same id replaces the first.
        """
        self._buses[bus.node_id_full] = bus

    async def call(self, node_id: str, req: RouteRequest) -> dict[str, Any]:
        """Deliver *req* to the bus registered as *node_id* and return its reply.

        Raises:
            BusError: with code ``"partition"`` when no bus is registered
                under *node_id*; the original ``KeyError`` is chained as the
                cause.
        """
        try:
            target = self._buses[node_id]
        except KeyError as missing:
            raise BusError("partition", f"node {node_id} is not reachable") from missing

        # Build a fresh request object carrying exactly these fields, so the
        # receiving bus does not share the caller's RouteRequest instance.
        forwarded_fields = (
            "capability",
            "version_req",
            "body",
            "caller",
            "trace_id",
            "session_id",
            "deadline_ms",
            "stream",
        )
        inbound = RouteRequest(**{name: getattr(req, name) for name in forwarded_fields})

        # local_only=True: the receiving bus is asked to serve the request
        # itself instead of handing it to its own transport.
        reply = await target.handle_call(inbound, local_only=True)
        return reply
@dataclass
class CallTraceEvent:
    """Record of one call routed through a bus.

    The ``@dataclass`` decorator is required: the bus builds these records
    positionally, in field order, which a plain annotated class would reject
    with ``TypeError``.
    """

    # Trace id carried by the request.
    trace_id: str
    # Capability that was requested.
    capability: CapabilityName
    # Node id of the caller (the request's ``caller``).
    from_node: str
    # Node id of the entry the call was routed to.
    to_node: str
    # "ok" on success, otherwise the error code of the raised HearthNetError.
    result: str
    # Elapsed time of the call in milliseconds.
    ms: float
@dataclass
class TopologySnapshot:
    """Point-in-time view of a bus: its peers, capabilities and recent calls.

    The ``@dataclass`` decorator is required: snapshots are built with
    keyword arguments for every field, which a plain annotated class would
    reject with ``TypeError``.
    """

    # Full node id of the bus that produced the snapshot.
    our_node_id: str
    # Peer descriptions supplied by the caller of ``topology_snapshot``.
    peers: list[dict[str, Any]]
    # Dict views of the capabilities registered locally.
    capabilities_local: list[dict[str, Any]]
    # Dict views of the capabilities known from other nodes.
    capabilities_remote: list[dict[str, Any]]
    # Sum of ``in_flight`` over every registry entry.
    in_flight_total: int
    # Most recent call traces, oldest first.
    traces: list[CallTraceEvent]
class CapabilityBus:
    """Per-node capability bus: registers capabilities and routes calls to them.

    The bus owns a ``Registry``, a ``HealthTracker`` and a ``Router`` and uses
    a transport to reach entries hosted on other nodes. Calls that succeed, or
    that fail with a ``HearthNetError``, are timed, reported to the health
    tracker and recorded as ``CallTraceEvent`` objects.
    """

    # Upper bound on retained call traces. Snapshots only expose the newest
    # few, so keeping every trace forever would only grow memory.
    _TRACE_HISTORY_LIMIT = 1000
    # Number of most recent traces included in a topology snapshot.
    _SNAPSHOT_TRACE_COUNT = 50

    def __init__(
        self,
        node_id_full: str,
        community_id: str,
        transport: InMemoryTransport | None = None,
        config: BusConfig | None = None,
    ) -> None:
        """Create the bus and register it with its transport.

        Args:
            node_id_full: Identifier of this node; also the key under which
                the bus registers itself with the transport.
            community_id: Stored on the bus; not otherwise used in this class.
            transport: Transport used for calls to other nodes. A private
                ``InMemoryTransport`` is created when omitted.
            config: Passed through to ``Router``.
        """
        self.node_id_full = node_id_full
        self.community_id = community_id
        self.registry = Registry(our_node_id=node_id_full)
        self.health = HealthTracker()
        self.router = Router(self.registry, config)
        # Explicit None check: a caller-supplied transport must be used even
        # if the object happens to evaluate as falsy.
        self.transport = transport if transport is not None else InMemoryTransport()
        self.transport.register(self)
        self._traces: list[CallTraceEvent] = []
        # Capabilities removed by deregister_internet_capabilities(), kept so
        # restore_internet_capabilities() can register them again.
        self._offline_stash: list[
            tuple[CapabilityDescriptor, Handler, ParamsPredicate | None]
        ] = []

    def register_capability(
        self,
        descriptor: CapabilityDescriptor,
        handler: Handler,
        params_compatible: ParamsPredicate | None = None,
    ) -> None:
        """Register a locally served capability with the registry."""
        self.registry.register_local(descriptor, handler, params_compatible)

    def register_service(self, service: Any) -> None:
        """Register every capability advertised by ``service.capabilities()``.

        Each item is ``(descriptor, handler)`` optionally followed by a
        params predicate; elements after the predicate are ignored.
        """
        for descriptor, handler, *rest in service.capabilities():
            self.register_capability(descriptor, handler, rest[0] if rest else None)

    async def call(
        self,
        capability: CapabilityName,
        version_req: Version,
        body: dict[str, Any],
        *,
        session_id: str | None = None,
    ) -> dict[str, Any]:
        """Issue a call originating from this node and return the result.

        Builds a ``RouteRequest`` with this node as caller, a fresh random
        trace id and a deadline 10 seconds ahead on the monotonic clock
        (expressed in milliseconds), then routes it via :meth:`handle_call`.
        """
        req = RouteRequest(
            capability=capability,
            version_req=version_req,
            body=body,
            caller=self.node_id_full,
            trace_id=uuid.uuid4().hex,
            session_id=session_id,
            deadline_ms=int((time.monotonic() + 10) * 1000),
        )
        return await self.handle_call(req)

    async def handle_call(self, req: RouteRequest, *, local_only: bool = False) -> dict[str, Any]:
        """Route *req* to a provider, invoke it and record the outcome.

        Requests carrying a ``session_id`` use the router's sticky routing;
        all others use plain routing. Local entries are served by their
        handler, remote entries through the transport.

        Args:
            req: The request to serve.
            local_only: When true, a remote entry is rejected instead of being
                forwarded through the transport.

        Raises:
            BusError: ``"not_found"`` when the router yields no entry or when
                *local_only* is set and the chosen entry is remote;
                ``"not_implemented"`` when a local entry has no handler.

        Only ``HearthNetError`` failures are recorded in health and traces;
        any other exception propagates unrecorded. The entry's ``in_flight``
        counter is always restored.
        """
        entry = self.router.route_sticky(req) if req.session_id else self.router.route(req)
        if entry is None:
            raise BusError("not_found", f"no provider for {req.capability}@{req.version_req}")
        started = time.monotonic()
        entry.in_flight += 1
        try:
            if entry.is_local:
                if entry.handler is None:
                    raise BusError("not_implemented", entry.descriptor.name)
                result = await entry.handler(req)
            elif local_only:
                raise BusError("not_found", f"remote entry cannot satisfy inbound {req.capability}")
            else:
                result = await self.transport.call(entry.node_id, req)
            self._record_outcome(req, entry, "ok", started, success=True)
            return result
        except HearthNetError as exc:
            self._record_outcome(req, entry, exc.code, started, success=False)
            raise
        finally:
            entry.in_flight -= 1

    def _record_outcome(
        self,
        req: RouteRequest,
        entry: CapabilityEntry,
        result: str,
        started: float,
        *,
        success: bool,
    ) -> None:
        """Report one finished call to the health tracker and the trace log."""
        elapsed = (time.monotonic() - started) * 1000
        self.health.record(entry, success=success, latency_ms=elapsed)
        traces = self._traces
        traces.append(
            CallTraceEvent(
                req.trace_id, req.capability, req.caller, entry.node_id, result, elapsed
            )
        )
        # Drop the oldest traces so a long-running bus does not accumulate
        # history without bound.
        overflow = len(traces) - self._TRACE_HISTORY_LIMIT
        if overflow > 0:
            del traces[:overflow]

    def deregister_internet_capabilities(self) -> int:
        """Deregister local capabilities whose params set ``requires_internet``.

        Removed entries that have a handler are stashed so that
        :meth:`restore_internet_capabilities` can bring them back.

        Returns:
            The number of internet-requiring entries processed.
        """
        removed = 0
        # Snapshot the entries first: deregistering while iterating the
        # registry's own collection could invalidate the iteration.
        for entry in list(self.registry.all_local()):
            if not entry.descriptor.params.get("requires_internet"):
                continue
            removed_entry = self.registry.deregister_local(
                entry.descriptor.name, entry.descriptor.version
            )
            if removed_entry and removed_entry.handler:
                self._offline_stash.append(
                    (
                        removed_entry.descriptor,
                        removed_entry.handler,
                        removed_entry.params_compatible,
                    )
                )
            # NOTE(review): counted once per internet-requiring entry, even
            # when it has no handler and therefore cannot be stashed -
            # confirm this matches the intended meaning of the return value.
            removed += 1
        return removed

    def restore_internet_capabilities(self) -> int:
        """Re-register every stashed capability, oldest first.

        Entries leave the stash only after they have been registered
        successfully, so if a registration raises, that entry and all later
        ones remain stashed for a later retry.

        Returns:
            The number of capabilities restored.
        """
        pending = self._offline_stash
        restored = 0
        try:
            for descriptor, handler, predicate in pending:
                self.register_capability(descriptor, handler, predicate)
                restored += 1
        finally:
            # Remove exactly the entries that made it back into the registry.
            del pending[:restored]
        return restored

    def topology_snapshot(self, peers: list[dict[str, Any]] | None = None) -> TopologySnapshot:
        """Return a snapshot of this bus's capabilities, load and recent traces.

        Args:
            peers: Peer descriptions to embed verbatim; defaults to an empty
                list.
        """
        return TopologySnapshot(
            our_node_id=self.node_id_full,
            peers=peers or [],
            capabilities_local=[_entry_view(entry) for entry in self.registry.all_local()],
            capabilities_remote=[_entry_view(entry) for entry in self.registry.all_remote()],
            in_flight_total=sum(entry.in_flight for entry in self.registry.all()),
            # Slicing already yields a new list, so the snapshot never
            # aliases the live trace log.
            traces=self._traces[-self._SNAPSHOT_TRACE_COUNT:],
        )
| def _entry_view(entry: CapabilityEntry) -> dict[str, Any]: | |
| return { | |
| "node_id": entry.node_id, | |
| "name": entry.descriptor.name, | |
| "version": entry.descriptor.version_str, | |
| "local": entry.is_local, | |
| "params": dict(entry.descriptor.params), | |
| "success_rate": entry.success_rate, | |
| "quarantined": entry.quarantined_until > time.monotonic(), | |
| } | |