File size: 4,896 Bytes
4cd8837
 
 
 
4aaae80
4cd8837
 
 
 
4aaae80
 
 
 
4cd8837
 
 
 
4aaae80
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import hashlib
import time
from dataclasses import dataclass, field


@dataclass(frozen=True)
class DhtContact:
    """Immutable description of a peer held in the routing table.

    Attributes:
        node_key: 32-byte SHA-256 digest of ``node_id``; the routing table
            measures XOR distance on this value.
        endpoint: Network address of the peer, formatted as ``"host:port"``.
        node_id: Human-readable identifier of the peer.
        last_seen: Monotonic timestamp of when the peer was last seen; a full
            bucket evicts the contact with the smallest value.
    """

    node_key: bytes
    endpoint: str
    node_id: str
    last_seen: float


@dataclass(frozen=True)
class DhtValue:
    """Immutable record for a value held in a node's local store.

    Attributes:
        key: Lookup key (arbitrary bytes).
        payload: Stored data. A ``dict`` is unhashable, so this field is
            excluded from the generated ``__hash__`` (it still takes part in
            ``==``). Without this, ``frozen=True`` would produce a ``__hash__``
            that raises ``TypeError`` on every call.
        expires_at: Absolute expiry time in Unix epoch seconds.
    """

    key: bytes
    payload: dict = field(hash=False)
    expires_at: int


def _xor_distance(a: bytes, b: bytes) -> int:
    """XOR metric over equal-length byte strings, returned as integer."""
    # Pad to same length if needed
    la, lb = len(a), len(b)
    if la < lb:
        a = a.ljust(lb, b"\x00")
    elif lb < la:
        b = b.ljust(la, b"\x00")
    result = 0
    for x, y in zip(a, b, strict=False):
        result = (result << 8) | (x ^ y)
    return result


def _bucket_index(own_key: bytes, target_key: bytes) -> int:
    """Return the bucket index for *target_key* relative to *own_key*.

    The index is the 0-based position of the most significant set bit of the
    XOR distance between the keys, so bucket ``i`` covers distances in
    ``[2**i, 2**(i + 1))``. For 32-byte keys the result lies in ``[0, 255]``.
    Identical keys (distance 0) also map to bucket 0.
    """
    distance = _xor_distance(own_key, target_key)
    # A zero distance has no set bit; bit_length() would give 0 and -1 would
    # be an invalid index, so it is mapped to bucket 0 explicitly.
    return distance.bit_length() - 1 if distance else 0


class RoutingTable:
    """Routing table of 256 buckets holding up to ``k`` contacts each.

    A contact is placed in the bucket returned by
    ``_bucket_index(own_key, contact.node_key)``: the position of the highest
    set bit of the XOR distance between the two keys (0-255 for 32-byte keys).
    """

    def __init__(self, own_key: bytes, k: int = 8) -> None:
        """Create an empty table.

        Args:
            own_key: This node's key; contacts carrying the same key are ignored.
            k: Per-bucket capacity. Also the default result size of
                ``find_closest``.
        """
        self._own_key = own_key
        self._k = k
        self._buckets: list[list[DhtContact]] = [[] for _ in range(256)]

    def add_contact(self, contact: DhtContact) -> None:
        """Insert *contact*, or refresh it if its ``node_key`` is already known.

        - A contact whose key equals this table's own key is ignored.
        - An existing entry with the same ``node_key`` is overwritten in place
          (its position in the bucket is kept).
        - Otherwise the contact is appended if the bucket has room. If the
          bucket is full, the entry with the smallest ``last_seen`` is replaced.
        """
        if contact.node_key == self._own_key:
            return
        idx = _bucket_index(self._own_key, contact.node_key)
        bucket = self._buckets[idx]

        # Replace existing entry for the same node_key
        for i, existing in enumerate(bucket):
            if existing.node_key == contact.node_key:
                bucket[i] = contact
                return

        if len(bucket) < self._k:
            bucket.append(contact)
        else:
            # Bucket full: evict the least recently seen contact. NOTE: the
            # evicted entry is replaced unconditionally; no liveness check
            # is performed here.
            oldest_idx = min(range(len(bucket)), key=lambda i: bucket[i].last_seen)
            bucket[oldest_idx] = contact

    def find_closest(self, key: bytes, k: int | None = None) -> list[DhtContact]:
        """Return up to *k* contacts closest (by XOR distance) to *key*.

        Args:
            key: Target key to measure distance against.
            k: Maximum number of contacts to return. When omitted, the table's
                configured bucket size is used. Previously this was a
                hard-coded 8 regardless of the table's ``k``.

        Returns:
            Contacts sorted by ascending XOR distance to *key*.
        """
        limit = self._k if k is None else k
        candidates = self.all_contacts()
        candidates.sort(key=lambda c: _xor_distance(c.node_key, key))
        return candidates[:limit]

    def size(self) -> int:
        """Return the total number of contacts across all buckets."""
        return sum(len(b) for b in self._buckets)

    def all_contacts(self) -> list[DhtContact]:
        """Return a new list of every contact, in bucket order."""
        contacts: list[DhtContact] = []
        for bucket in self._buckets:
            contacts.extend(bucket)
        return contacts


class KademliaNode:
    """Local Kademlia DHT node: routing table plus an in-memory value store.

    The node's routing key is the SHA-256 digest of ``node_id``, so the same
    identifier always yields the same 32-byte key. Stored values carry an
    absolute Unix-epoch expiry. They are dropped lazily when looked up after
    expiry, or in bulk by ``expire_stale``.
    """

    def __init__(self, node_id: str, k: int = 8, alpha: int = 3) -> None:
        """Create a node.

        Args:
            node_id: Human-readable identifier; hashed to derive ``node_key``.
            k: Bucket capacity of the routing table. Also the default number
                of contacts returned by ``find_closest``.
            alpha: Stored as ``self._alpha`` but not read anywhere in this
                class. Presumably lookup concurrency for a caller; confirm.
        """
        self._node_id = node_id
        self._k = k
        self._alpha = alpha

        # Deterministic 32-byte key from node_id
        self.node_key: bytes = hashlib.sha256(node_id.encode()).digest()

        self.routing_table = RoutingTable(own_key=self.node_key, k=k)
        self.local_store: dict[bytes, DhtValue] = {}

    # ── Value store ───────────────────────────────────────────────────────────

    def store(self, key: bytes, value: dict, ttl: int = 3600) -> None:
        """Store *value* under *key* for *ttl* seconds from now.

        The expiry is computed from whole seconds of wall-clock time. Any
        previous value for *key* is overwritten. *value* is stored by
        reference (not copied), so later mutation by the caller is visible.
        """
        expires_at = int(time.time()) + ttl
        self.local_store[key] = DhtValue(key=key, payload=value, expires_at=expires_at)

    def find_value(self, key: bytes) -> DhtValue | None:
        """Return the stored value for *key*, or ``None`` if absent or expired.

        An entry stays valid through the second equal to its ``expires_at``.
        An expired entry found here is deleted from the store.
        """
        entry = self.local_store.get(key)
        if entry is None:
            return None
        if int(time.time()) > entry.expires_at:
            del self.local_store[key]
            return None
        return entry

    # ── Routing ───────────────────────────────────────────────────────────────

    def find_closest(self, key: bytes, k: int | None = None) -> list[DhtContact]:
        """Return up to *k* known contacts closest to *key* by XOR distance.

        Args:
            key: Target key.
            k: Maximum number of contacts. When omitted, the node's configured
                ``k`` is used. Previously this was a hard-coded 8 regardless of
                the node's ``k``.
        """
        # Resolve the default here so the routing table always receives a
        # concrete int (passing None would turn the slice into "everything").
        limit = self._k if k is None else k
        return self.routing_table.find_closest(key, limit)

    def update_contact(self, contact: DhtContact) -> None:
        """Insert or refresh *contact* in the routing table."""
        self.routing_table.add_contact(contact)

    # ── Maintenance ───────────────────────────────────────────────────────────

    def expire_stale(self) -> int:
        """Remove expired values. Returns count of removed entries.

        Uses the same rule as ``find_value``: an entry expires once the
        current whole-second time is strictly greater than ``expires_at``.
        """
        now = int(time.time())
        # Collect first, then delete, to avoid mutating the dict while iterating.
        stale = [k for k, v in self.local_store.items() if now > v.expires_at]
        for k in stale:
            del self.local_store[k]
        return len(stale)