# Page header captured from the repository file viewer (not part of the module):
# GitHub Actions
# fix: 0 test failures; FileService; real RagService; emergency probe; chat return
# 4aaae80
# Raw
# History Blame
# 4.9 kB
from __future__ import annotations
import hashlib
import time
from dataclasses import dataclass
@dataclass(frozen=True)
class DhtContact:
    """Immutable description of one remote peer held in the routing table.

    Attributes:
        node_key: Key compared by XOR distance; a 32-byte SHA-256 digest of
            ``node_id``.
        endpoint: Network address of the peer, written as ``"host:port"``.
        node_id: Human-readable identifier of the peer.
        last_seen: Timestamp of the latest sighting (a monotonic clock
            reading); lower values mean the peer was seen longer ago.
    """

    node_key: bytes
    endpoint: str
    node_id: str
    last_seen: float
@dataclass(frozen=True)
class DhtValue:
    """Immutable record for one entry of a node's local value store.

    The record itself cannot be reassigned, but ``payload`` is an ordinary
    dict and is stored by reference.

    Attributes:
        key: Lookup key (arbitrary bytes).
        payload: The stored data.
        expires_at: Expiry time in Unix epoch seconds.
    """

    key: bytes
    payload: dict
    expires_at: int
def _xor_distance(a: bytes, b: bytes) -> int:
"""XOR metric over equal-length byte strings, returned as integer."""
# Pad to same length if needed
la, lb = len(a), len(b)
if la < lb:
a = a.ljust(lb, b"\x00")
elif lb < la:
b = b.ljust(la, b"\x00")
result = 0
for x, y in zip(a, b, strict=False):
result = (result << 8) | (x ^ y)
return result
def _bucket_index(own_key: bytes, target_key: bytes) -> int:
    """Map a target key to its Kademlia bucket index relative to ``own_key``.

    The index is the 0-based position of the most significant set bit of the
    XOR distance between the two keys, so it lies in [0, 255] when both keys
    are at most 32 bytes long. Identical keys (distance 0) map to bucket 0.
    """
    distance = _xor_distance(own_key, target_key)
    # A zero distance has bit_length 0, which would give -1; clamp it to 0.
    return max(distance.bit_length() - 1, 0)
class RoutingTable:
    """Fixed array of 256 buckets, each holding at most ``k`` contacts (8 by default)."""

    def __init__(self, own_key: bytes, k: int = 8) -> None:
        """Create an empty table for the node identified by ``own_key``.

        Args:
            own_key: Key of the local node; bucket indices are computed
                relative to it.
            k: Maximum number of contacts kept in a single bucket.
        """
        self._own_key = own_key
        self._k = k
        self._buckets: list[list[DhtContact]] = [[] for _ in range(256)]

    def add_contact(self, contact: DhtContact) -> None:
        """Insert ``contact`` or refresh the entry that shares its node key.

        A contact whose key equals the local key is ignored. If the target
        bucket is already full, the entry with the smallest ``last_seen`` is
        overwritten.
        """
        new_key = contact.node_key
        if new_key == self._own_key:
            return

        slot = self._buckets[_bucket_index(self._own_key, new_key)]

        # Position of an existing entry for this node key, if there is one.
        position = next(
            (pos for pos, known in enumerate(slot) if known.node_key == new_key),
            None,
        )
        if position is None:
            if len(slot) < self._k:
                slot.append(contact)
                return
            # Bucket is full: evict the least recently seen contact.
            position = min(range(len(slot)), key=lambda pos: slot[pos].last_seen)
        slot[position] = contact

    def find_closest(self, key: bytes, k: int = 8) -> list[DhtContact]:
        """Return up to ``k`` contacts ordered by ascending XOR distance to ``key``."""
        ranked = sorted(
            self.all_contacts(),
            key=lambda peer: _xor_distance(peer.node_key, key),
        )
        return ranked[:k]

    def size(self) -> int:
        """Return the total number of contacts across all buckets."""
        return sum(map(len, self._buckets))

    def all_contacts(self) -> list[DhtContact]:
        """Return a new list with every contact, in bucket order."""
        return [peer for bucket in self._buckets for peer in bucket]
class KademliaNode:
    """Local Kademlia DHT node combining a routing table with an in-memory value store."""

    def __init__(self, node_id: str, k: int = 8, alpha: int = 3) -> None:
        """Set up an empty node.

        Args:
            node_id: Human-readable identifier; its SHA-256 digest becomes
                the node key.
            k: Bucket capacity handed to the routing table.
            alpha: Stored on the instance; not read by any method of this class.
        """
        self._node_id = node_id
        self._k = k
        self._alpha = alpha
        # The key is derived deterministically, so the same node_id always
        # produces the same 32-byte node key.
        self.node_key: bytes = hashlib.sha256(node_id.encode()).digest()
        self.routing_table = RoutingTable(own_key=self.node_key, k=k)
        self.local_store: dict[bytes, DhtValue] = {}

    # ── Value store ───────────────────────────────────────────────────────────
    def store(self, key: bytes, value: dict, ttl: int = 3600) -> None:
        """Save ``value`` under ``key``, expiring ``ttl`` seconds from now.

        Any entry already stored under ``key`` is overwritten.
        """
        deadline = int(time.time()) + ttl
        record = DhtValue(key=key, payload=value, expires_at=deadline)
        self.local_store[key] = record

    def find_value(self, key: bytes) -> DhtValue | None:
        """Return the live entry for ``key``, or ``None`` if absent or expired.

        An expired entry is removed from the store as a side effect.
        """
        record = self.local_store.get(key)
        if record is not None and int(time.time()) > record.expires_at:
            del self.local_store[key]
            return None
        return record

    # ── Routing ───────────────────────────────────────────────────────────────
    def find_closest(self, key: bytes, k: int = 8) -> list[DhtContact]:
        """Delegate to the routing table: up to ``k`` contacts nearest to ``key``."""
        return self.routing_table.find_closest(key, k)

    def update_contact(self, contact: DhtContact) -> None:
        """Delegate to the routing table: insert or refresh ``contact``."""
        self.routing_table.add_contact(contact)

    # ── Maintenance ───────────────────────────────────────────────────────────
    def expire_stale(self) -> int:
        """Drop every expired value and return how many were removed."""
        cutoff = int(time.time())
        # Collect the keys first; the dict must not change size while iterated.
        expired = [
            stored_key
            for stored_key, record in self.local_store.items()
            if cutoff > record.expires_at
        ]
        for stored_key in expired:
            del self.local_store[stored_key]
        return len(expired)