File size: 11,335 Bytes
4aaae80
8514223
 
 
 
 
 
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
31c93b1
 
4aaae80
31c93b1
8514223
31c93b1
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
31c93b1
6affa2f
3f78ea8
6affa2f
31c93b1
 
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
6a14ec5
 
 
 
 
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""M01 - Node identity: Ed25519 key management.

Spec: docs/M01-identity.md §3.1
Impl-ref: impl_ref.md §5

Keys stored in keys_dir (default ~/.hearthnet/keys/).
Sign/verify via PyNaCl Ed25519. canonical_json() for deterministic signing.
"""

from __future__ import annotations

import base64
import json
import os
import stat
from dataclasses import dataclass
from pathlib import Path
from typing import Any

try:
    import nacl.exceptions
    import nacl.signing

    _NACL_AVAILABLE = True
except ImportError:  # pragma: no cover
    _NACL_AVAILABLE = False

# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

# Node IDs are plain strings in one of two forms:
#   short: "ed25519:XXXX-XXXX-XXXX-X===" - base32 of the first 8 key bytes
#          (lossy; cannot be turned back into key bytes)
#   full:  "ed25519:<b64url>" - the whole 32-byte public key, base64url, no padding
NodeID = str
# Signatures are rendered as "ed25519:<b64url>".
Signature = str


class IdentityError(Exception):
    """Single exception type for every identity-layer failure.

    Attributes:
        code: short machine-readable error code, e.g. ``"keys_missing"``.
        reason: optional human-readable detail; ``""`` when not supplied.

    The exception message (``str(exc)``) is ``reason`` when it is non-empty,
    otherwise ``code``.
    """

    def __init__(self, code: str, reason: str = "") -> None:
        self.code = code
        self.reason = reason
        super().__init__(reason if reason else code)


@dataclass(frozen=True)
class KeyPair:
    """Immutable bundle of an Ed25519 key pair and its two node-ID renderings."""

    # Private key object (nacl.signing.SigningKey when PyNaCl is present).
    signing_key: Any
    # Public key object (nacl.signing.VerifyKey when PyNaCl is present).
    verify_key: Any
    # Lossy, dash-grouped base32 form of the public key.
    node_id_short: str
    # Full base64url form of the public key.
    node_id_full: str


# ---------------------------------------------------------------------------
# ID helpers
# ---------------------------------------------------------------------------


def short_node_id(verify_key_bytes: bytes) -> str:
    """Return the short node ID for a public key.

    The first 8 bytes of ``verify_key_bytes`` are base32-encoded (13 symbols
    plus ``=`` padding, 16 characters in all) and joined in dash-separated
    groups of four, giving ``'ed25519:XXXX-XXXX-XXXX-X==='``. Only a prefix of
    the key is used, so the result cannot be decoded back to the key.
    """
    encoded = base64.b32encode(verify_key_bytes[:8]).decode("ascii")
    groups = [encoded[start : start + 4] for start in range(0, len(encoded), 4)]
    return "ed25519:" + "-".join(groups)


def full_node_id(verify_key_bytes: bytes) -> str:
    """Return the full node ID: ``'ed25519:'`` + base64url of every key byte, unpadded."""
    encoded = base64.urlsafe_b64encode(verify_key_bytes).decode("ascii")
    return "ed25519:" + encoded.rstrip("=")


def parse_node_id(node_id: str) -> bytes:
    """Decode a full node ID (``'ed25519:<base64url>'``) to its 32 raw key bytes.

    The payload must be the canonical base64url encoding of exactly 32 bytes
    (the unpadded form, or the same string with its correct ``=`` padding).

    Raises:
        ValueError: if the ``ed25519:`` prefix is missing, the ID is in the
            short (base32, dash-grouped) form, the payload is not valid
            base64url, it does not decode to 32 bytes, or it is a
            non-canonical spelling of the decoded bytes.
    """
    import re

    prefix = "ed25519:"
    if not node_id.startswith(prefix):
        raise ValueError(f"node_id must start with 'ed25519:': {node_id!r}")
    payload = node_id[len(prefix) :]
    # Short form is b32-with-dashes: groups of [A-Z2-7=]{1,4} separated by '-'
    # e.g. "SQ2J-OH7E-LCMU-Y===". It only encodes a key prefix, so it cannot
    # be decoded. Full form is 43-char base64url (no '=' padding).
    if re.fullmatch(r"[A-Z2-7=]{1,4}(-[A-Z2-7=]{1,4}){1,}", payload):
        raise ValueError("Short node IDs cannot be decoded to raw bytes; use full form.")
    # Restore the padding stripped by full_node_id before decoding.
    padded = payload + "=" * (-len(payload) % 4)
    raw = base64.urlsafe_b64decode(padded)  # binascii.Error is a ValueError
    if len(raw) != 32:
        raise ValueError(f"Expected 32 bytes, got {len(raw)}")
    # urlsafe_b64decode is lenient. It silently discards non-alphabet
    # characters, accepts the standard '+' '/' alphabet, and ignores the unused
    # trailing bits of the last symbol. Many different strings would otherwise
    # map to the same key, so only the canonical encoding is accepted.
    canonical = base64.urlsafe_b64encode(raw).decode("ascii")
    if payload not in (canonical, canonical.rstrip("=")):
        raise ValueError(f"node_id payload is not canonical base64url: {node_id!r}")
    return raw


# ---------------------------------------------------------------------------
# Canonical JSON
# ---------------------------------------------------------------------------


def canonical_json(obj: Any) -> bytes:
    """Serialise *obj* to the deterministic byte form used for signing.

    Keys are sorted, separators carry no whitespace, non-ASCII text is kept as
    raw characters, and number tokens are normalised by ``_strip_trailing_zeros``
    (which leaves string contents alone). The result is UTF-8 encoded.
    """
    text = json.dumps(
        obj,
        ensure_ascii=False,
        separators=(",", ":"),
        sort_keys=True,
    )
    normalised = _strip_trailing_zeros(text)
    return normalised.encode("utf-8")


def _strip_trailing_zeros(s: str) -> str:
    """Remove trailing zeros from JSON numbers without touching string values."""
    import re

    # Match JSON numbers (integers, floats, exponent forms) that appear outside strings
    # We parse character-by-character to skip string literals.
    out: list[str] = []
    i = 0
    n = len(s)
    while i < n:
        c = s[i]
        if c == '"':
            # Scan to end of string, respecting escapes
            out.append(c)
            i += 1
            while i < n:
                ch = s[i]
                out.append(ch)
                if ch == "\\":
                    i += 1
                    if i < n:
                        out.append(s[i])
                elif ch == '"':
                    i += 1
                    break
                i += 1
        else:
            # Look for a number token
            m = re.match(r"-?(?:0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?", s[i:])
            if m and (m.group(1) or m.group(2)):
                num_str = m.group(0)
                # Parse and reformat
                try:
                    val = float(num_str)
                    # If it represents an integer value, emit as integer
                    if val == int(val) and "e" not in num_str.lower():
                        out.append(str(int(val)))
                    else:
                        # Strip trailing zeros from decimal part
                        formatted = f"{val:g}"
                        out.append(formatted)
                except (ValueError, OverflowError):
                    out.append(num_str)
                i += len(num_str)
            else:
                out.append(c)
                i += 1
    return "".join(out)


# ---------------------------------------------------------------------------
# Signing / Verification
# ---------------------------------------------------------------------------


def sign_payload(payload: dict, kp: KeyPair) -> dict:
    """Sign *payload* with *kp* and return a new dict that carries the signature.

    Any existing ``signature`` key is dropped first. The signature covers
    ``canonical_json`` of the remaining fields and is stored under
    ``"signature"`` as ``"ed25519:<base64url, no padding>"``. The input dict is
    not modified.

    Raises:
        IdentityError: ``keys_invalid`` if PyNaCl is unavailable;
            ``sign_failed`` if the signing key raises while signing.
    """
    if not _NACL_AVAILABLE:
        raise IdentityError("keys_invalid", reason="PyNaCl not installed")
    body = {key: value for key, value in payload.items() if key != "signature"}
    message = canonical_json(body)
    try:
        sig_bytes = kp.signing_key.sign(message).signature
    except Exception as exc:
        raise IdentityError("sign_failed", reason=str(exc)) from exc
    encoded = base64.urlsafe_b64encode(sig_bytes).rstrip(b"=").decode("ascii")
    signed_payload = dict(body)
    signed_payload["signature"] = "ed25519:" + encoded
    return signed_payload


def verify_payload(payload: dict, vk: Any) -> bool:  # vk: nacl.signing.VerifyKey
    """Verify the ``signature`` field of *payload* against *vk*.

    The field must be ``"ed25519:"`` followed by the canonical base64url
    encoding of the signature bytes (with or without its ``=`` padding). It is
    checked over ``canonical_json`` of every other field.

    Returns:
        True on success. Failure is never signalled by returning False.

    Raises:
        IdentityError: ``keys_invalid`` if PyNaCl is unavailable;
            ``verify_failed`` if the signature is missing, not a string,
            badly or non-canonically encoded, or does not verify.
    """
    if not _NACL_AVAILABLE:
        raise IdentityError("keys_invalid", reason="PyNaCl not installed")
    raw_sig = payload.get("signature", "")
    # Without the isinstance check, a non-string value (None, bytes, int)
    # would raise AttributeError from .startswith and escape the
    # IdentityError contract.
    if not isinstance(raw_sig, str) or not raw_sig.startswith("ed25519:"):
        raise IdentityError("verify_failed", reason="signature field missing or malformed")
    sig_text = raw_sig[len("ed25519:") :]
    try:
        # Restore the padding stripped by sign_payload before decoding.
        sig_bytes = base64.urlsafe_b64decode(sig_text + "=" * (-len(sig_text) % 4))
    except Exception as exc:
        raise IdentityError("verify_failed", reason=f"bad signature encoding: {exc}") from exc
    # urlsafe_b64decode discards non-alphabet characters and ignores unused
    # trailing bits, so many strings decode to the same bytes. Require the
    # canonical encoding so each signature has exactly one accepted spelling.
    canonical = base64.urlsafe_b64encode(sig_bytes).decode("ascii")
    if sig_text not in (canonical, canonical.rstrip("=")):
        raise IdentityError(
            "verify_failed", reason="bad signature encoding: not canonical base64url"
        )
    unsigned = {k: v for k, v in payload.items() if k != "signature"}
    raw = canonical_json(unsigned)
    try:
        vk.verify(raw, sig_bytes)
    except nacl.exceptions.BadSignatureError as exc:
        raise IdentityError("verify_failed", reason="signature verification failed") from exc
    except Exception as exc:
        raise IdentityError("verify_failed", reason=str(exc)) from exc
    return True


def verify_payload_with_node_id(payload: dict, expected_node_id_full: str) -> bool:
    """Verify *payload* using the public key embedded in a full node ID.

    Returns:
        True when ``verify_payload`` accepts the signature.

    Raises:
        IdentityError: ``keys_invalid`` if PyNaCl is unavailable or the decoded
            bytes are rejected by ``VerifyKey``; ``bad_node_id`` if
            ``parse_node_id`` raises ValueError; otherwise whatever
            ``verify_payload`` raises.
    """
    if not _NACL_AVAILABLE:
        raise IdentityError("keys_invalid", reason="PyNaCl not installed")
    try:
        key_bytes = parse_node_id(expected_node_id_full)
    except ValueError as exc:
        raise IdentityError("bad_node_id", reason=str(exc)) from exc
    try:
        verify_key = nacl.signing.VerifyKey(key_bytes)
    except Exception as exc:
        raise IdentityError("keys_invalid", reason=str(exc)) from exc
    return verify_payload(payload, verify_key)


# ---------------------------------------------------------------------------
# Key I/O
# ---------------------------------------------------------------------------


def generate() -> KeyPair:
    """Create a brand-new Ed25519 KeyPair.

    The 32-byte seed comes from ``os.urandom``. Both node-ID forms are derived
    from the resulting public key.

    Raises:
        IdentityError: ``keys_invalid`` if PyNaCl is not installed.
    """
    if not _NACL_AVAILABLE:
        raise IdentityError("keys_invalid", reason="PyNaCl not installed")
    signing_key = nacl.signing.SigningKey(os.urandom(32))
    public = signing_key.verify_key
    public_bytes = bytes(public)
    return KeyPair(
        signing_key=signing_key,
        verify_key=public,
        node_id_short=short_node_id(public_bytes),
        node_id_full=full_node_id(public_bytes),
    )


def save(kp: KeyPair, keys_dir: Path) -> None:
    """Persist *kp* to *keys_dir* as ``device.ed25519`` and ``device.pub``.

    Each file holds the unpadded base64url encoding of the key bytes
    (``bytes(kp.signing_key)`` / ``bytes(kp.verify_key)``) plus a newline.
    ``keys_dir`` and any missing parents are created. Existing files are
    overwritten.

    The private key file is created with mode 0600 up front, and an existing
    file is tightened to 0600 before the secret is written. The key is
    therefore never stored in a file with looser permissions.
    """
    keys_dir.mkdir(parents=True, exist_ok=True)
    priv_path = keys_dir / "device.ed25519"
    pub_path = keys_dir / "device.pub"
    owner_rw = stat.S_IRUSR | stat.S_IWUSR  # 0600
    priv_data = base64.urlsafe_b64encode(bytes(kp.signing_key)).rstrip(b"=") + b"\n"
    # Passing the mode to os.open applies it at creation time; umask can only
    # remove bits. Writing first and chmod-ing afterwards leaves a window in
    # which a new file has default, often world-readable, permissions.
    # O_BINARY (Windows only; 0 elsewhere) prevents newline translation.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(priv_path, flags, owner_rw)
    with os.fdopen(fd, "wb") as fh:
        # os.open ignores the mode for a pre-existing file, so tighten it
        # explicitly before the secret is written into it.
        os.chmod(priv_path, owner_rw)
        fh.write(priv_data)
    pub_data = base64.urlsafe_b64encode(bytes(kp.verify_key)).rstrip(b"=") + b"\n"
    pub_path.write_bytes(pub_data)


def load(keys_dir: Path) -> KeyPair:
    """Read the KeyPair stored as ``device.ed25519`` + ``device.pub`` in *keys_dir*.

    Both files must exist, but only the private key file is decoded. The public
    key and both node-ID forms are derived from it.

    Raises:
        IdentityError: ``keys_invalid`` if PyNaCl is missing or the private key
            cannot be decoded; ``keys_missing`` if either file is absent;
            ``keys_permissions`` (POSIX only) if the private key file's
            owner/group/other bits are anything other than 0600 or 0400.
    """
    if not _NACL_AVAILABLE:
        raise IdentityError("keys_invalid", reason="PyNaCl not installed")
    priv_path = keys_dir / "device.ed25519"
    pub_path = keys_dir / "device.pub"
    if not (priv_path.exists() and pub_path.exists()):
        raise IdentityError("keys_missing", reason=f"Key files not found in {keys_dir}")
    # POSIX mode bits are not meaningful on Windows: NTFS files commonly report
    # 0o666 regardless of ACLs. stat.S_IMODE does not raise there, so the
    # platform has to be checked explicitly.
    if os.name == "posix":
        perms = stat.S_IMODE(priv_path.stat().st_mode)
        # Only the low nine permission bits matter; setuid/setgid/sticky are ignored.
        if (perms & 0o777) not in (0o600, 0o400):
            raise IdentityError(
                "keys_permissions",
                reason=f"Private key {priv_path} has unsafe permissions {oct(perms)}",
            )
    try:
        encoded = priv_path.read_text().strip()
        encoded += "=" * (-len(encoded) % 4)  # restore padding stripped on save
        signing_key = nacl.signing.SigningKey(base64.urlsafe_b64decode(encoded))
    except Exception as exc:
        raise IdentityError("keys_invalid", reason=str(exc)) from exc
    public = signing_key.verify_key
    public_bytes = bytes(public)
    return KeyPair(
        signing_key=signing_key,
        verify_key=public,
        node_id_short=short_node_id(public_bytes),
        node_id_full=full_node_id(public_bytes),
    )


def load_or_generate(keys_dir: Path) -> KeyPair:
    """Return the KeyPair stored in *keys_dir*, creating and saving one if absent.

    Only the presence of ``device.ed25519`` decides the path taken. When it
    exists, ``load`` handles everything else; otherwise a fresh pair is
    generated, saved and returned.
    """
    if not (keys_dir / "device.ed25519").exists():
        fresh = generate()
        save(fresh, keys_dir)
        return fresh
    return load(keys_dir)