File size: 5,169 Bytes
9fb722e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38cba90
 
 
 
 
 
3f78ea8
38cba90
 
 
 
9fb722e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f78ea8
7573216
 
 
 
 
90a59b3
 
 
 
 
 
 
 
 
 
 
 
7573216
 
 
9fb722e
 
 
 
 
 
7573216
9fb722e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from hearthnet.bus.capability import CapabilityDescriptor, CapabilityEntry, Handler, ParamsPredicate
from hearthnet.discovery.peers import PeerRecord
from hearthnet.types import CapabilityName, Version


@dataclass(frozen=True)
class Diff:
    """Summary of registry changes, as returned by ``Registry.update_from_peer_manifest``.

    ``frozen=True`` only prevents rebinding the three attributes; the lists
    they hold are ordinary mutable lists.
    """

    # Entries that were stored in the registry by the operation.
    added: list[CapabilityEntry]
    # Entries that were dropped from the registry by the operation.
    removed: list[CapabilityEntry]
    # Presumably entries whose descriptor changed in place; the only producer
    # visible in this file always passes an empty list -- TODO confirm intent.
    updated: list[CapabilityEntry]


@dataclass(frozen=True)
class RegistryEvent:
    """Emitted by Registry when capabilities change (M03 §3.3).

    Attributes:
        kind: One of ``"added"``, ``"removed"`` or ``"updated"``. The field is
            a plain ``str``; nothing in this class validates the value.
        entry: The capability entry the event refers to.
    """

    # NOTE(review): no code in this part of the file constructs or emits a
    # RegistryEvent -- confirm where events are produced.
    kind: str
    entry: CapabilityEntry


class Registry:
    """In-memory index of local and remote capability entries.

    Entries are keyed by ``(node_id, capability name, version)``. Storing an
    entry under a key that already exists replaces the previous entry.
    """

    def __init__(self, our_node_id: str) -> None:
        """Create an empty registry owned by the node ``our_node_id``."""
        self.our_node_id = our_node_id
        self._entries: dict[tuple[str, CapabilityName, Version], CapabilityEntry] = {}

    def register_local(
        self,
        descriptor: CapabilityDescriptor,
        handler: Handler,
        params_compatible: ParamsPredicate | None = None,
    ) -> None:
        """Store a capability served by this node.

        Args:
            descriptor: Description of the capability; its name and version
                form the registry key together with ``our_node_id``.
            handler: Callable stored on the entry as its handler.
            params_compatible: Optional ``(offered, requested)`` predicate.
                When omitted (or falsy) a predicate that accepts every
                request is used.
        """
        self._entries[(self.our_node_id, descriptor.name, descriptor.version)] = CapabilityEntry(
            node_id=self.our_node_id,
            descriptor=descriptor,
            is_local=True,
            handler=handler,
            params_compatible=params_compatible or (lambda offered, requested: True),
        )

    def deregister_local(self, name: CapabilityName, version: Version) -> CapabilityEntry | None:
        """Remove this node's entry for ``name``/``version``.

        Returns:
            The removed entry, or ``None`` when no such entry was registered.
        """
        return self._entries.pop((self.our_node_id, name, version), None)

    @staticmethod
    def _remote_params_compatible(offered: dict[str, Any], requested: dict[str, Any]) -> bool:
        """Generic params check used for every remote entry.

        A general check is used for remote entries so that corpus/model/lang
        routing works across the mesh without needing to transfer Python
        callables over the wire.

        Rules, applied to each requested key:
          * a requested value of ``None`` places no constraint;
          * ``"model"`` must equal the offered ``"model"`` or be contained in
            the offered ``"models"`` catalogue, otherwise the check fails;
          * any other key only fails when the offer has that key with a
            different value (keys absent from the offer are accepted).
        """
        for key, value in requested.items():
            if value is None:
                continue
            if key == "model":
                # A capability may advertise a catalogue of models it serves
                # ("models") in addition to its primary ("model").
                catalogue = offered.get("models")
                if catalogue and value in catalogue:
                    continue
                if offered.get("model") == value:
                    continue
                return False
            if key in offered and offered[key] != value:
                return False
        return True

    @staticmethod
    def _descriptor_from_raw(raw: dict[str, Any]) -> CapabilityDescriptor:
        """Build a descriptor from one manifest ``capabilities`` item.

        ``name`` is required (``KeyError`` when missing). The remaining fields
        default to version ``"1.0"``, stability ``"stable"``, empty params and
        ``max_concurrent`` 1. Errors from version parsing or ``int()``
        conversion propagate to the caller.
        """
        name = raw["name"]
        version = _parse_version(raw.get("version", "1.0"))
        stability = raw.get("stability", "stable")
        params = dict(raw.get("params", {}))
        max_concurrent = int(raw.get("max_concurrent", 1))
        return CapabilityDescriptor(
            name=name,
            version=version,
            stability=stability,
            params=params,
            max_concurrent=max_concurrent,
        )

    def add_remote(self, peer: PeerRecord, descriptor: CapabilityDescriptor) -> CapabilityEntry:
        """Store a capability offered by ``peer`` and return the new entry.

        The entry's endpoint is the peer's first endpoint, or ``None`` when
        the peer lists none. An existing entry with the same peer id, name
        and version is replaced.
        """
        endpoint = peer.endpoints[0] if peer.endpoints else None
        entry = CapabilityEntry(
            node_id=peer.node_id_full,
            descriptor=descriptor,
            is_local=False,
            endpoint=endpoint,
            last_seen=peer.last_seen,
            params_compatible=self._remote_params_compatible,
        )
        self._entries[(peer.node_id_full, descriptor.name, descriptor.version)] = entry
        return entry

    def update_from_peer_manifest(self, peer: PeerRecord, manifest: dict[str, Any]) -> Diff:
        """Replace all remote entries of ``peer`` with those in ``manifest``.

        The manifest is parsed completely before the registry is touched, so
        a malformed capability item raises without discarding the entries the
        peer already had. A missing or ``None`` ``"capabilities"`` value is
        treated as an empty list.

        Returns:
            A ``Diff`` whose ``removed`` holds every remote entry the peer had
            before the call and whose ``added`` holds one entry per manifest
            item (in manifest order). ``updated`` is always empty: entries are
            replaced wholesale, never compared with their predecessors.

        Raises:
            KeyError: if a capability item has no ``"name"``.
            ValueError: if a version or ``max_concurrent`` cannot be parsed.
        """
        # Parse first: nothing below this line is expected to fail, which
        # keeps the swap all-or-nothing.
        descriptors = [
            self._descriptor_from_raw(raw) for raw in manifest.get("capabilities") or []
        ]

        stale = [
            (key, entry)
            for key, entry in self._entries.items()
            if entry.node_id == peer.node_id_full and not entry.is_local
        ]
        for key, _ in stale:
            self._entries.pop(key, None)

        added: list[CapabilityEntry] = []
        for descriptor in descriptors:
            added.append(self.add_remote(peer, descriptor))
        return Diff(added=added, removed=[entry for _, entry in stale], updated=[])

    def remove_peer(self, node_id: str) -> int:
        """Drop every remote entry belonging to ``node_id``.

        Local entries are never removed, even if their node id matches.

        Returns:
            The number of entries removed.
        """
        keys = [
            key
            for key, entry in self._entries.items()
            if entry.node_id == node_id and not entry.is_local
        ]
        for key in keys:
            self._entries.pop(key, None)
        return len(keys)

    def find(self, name: CapabilityName, version_req: Version) -> list[CapabilityEntry]:
        """Return local and remote entries named ``name`` whose version satisfies ``version_req``.

        Version matching is delegated to the module-level ``_compatible``.
        """
        return [
            entry
            for entry in self._entries.values()
            if entry.descriptor.name == name and _compatible(entry.descriptor.version, version_req)
        ]

    def all_local(self) -> list[CapabilityEntry]:
        """Return the entries flagged as local."""
        return [entry for entry in self._entries.values() if entry.is_local]

    def all_remote(self) -> list[CapabilityEntry]:
        """Return the entries not flagged as local."""
        return [entry for entry in self._entries.values() if not entry.is_local]

    def all(self) -> list[CapabilityEntry]:
        """Return a new list containing every entry in the registry."""
        return list(self._entries.values())


def _compatible(offered: Version, requested: Version) -> bool:
    return offered[0] == requested[0] and offered[1] >= requested[1]


def _parse_version(raw: str) -> Version:
    major, minor = raw.split(".", 1)
    return int(major), int(minor)