GitHub Actions
Quality improvements: Unicode chars, Token class, imports, type hints, formatting
3f78ea8
Raw
History Blame
11.3 kB
"""M01 - Node identity: Ed25519 key management.
Spec: docs/M01-identity.md §3.1
Impl-ref: impl_ref.md §5
Keys stored in keys_dir (default ~/.hearthnet/keys/).
Sign/verify via PyNaCl Ed25519. canonical_json() for deterministic signing.
"""
from __future__ import annotations
import base64
import json
import os
import stat
from dataclasses import dataclass
from pathlib import Path
from typing import Any
try:
import nacl.exceptions
import nacl.signing
_NACL_AVAILABLE = True
except ImportError: # pragma: no cover
_NACL_AVAILABLE = False
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
NodeID = str # "ed25519:XXXX-XXXX-XXXX-XXXX" (short) or "ed25519:<b64url>" (full)
Signature = str # "ed25519:<b64url>"
class IdentityError(Exception):
"""Raised for all identity-layer failures."""
def __init__(self, code: str, reason: str = "") -> None:
super().__init__(reason or code)
self.code = code
self.reason = reason
@dataclass(frozen=True)
class KeyPair:
signing_key: Any # nacl.signing.SigningKey
verify_key: Any # nacl.signing.VerifyKey
node_id_short: str
node_id_full: str
# ---------------------------------------------------------------------------
# ID helpers
# ---------------------------------------------------------------------------
def short_node_id(verify_key_bytes: bytes) -> str:
"""First 8 bytes base32, grouped in 4-char segments: 'ed25519:XXXX-XXXX-XXXX-XXXX'."""
raw = base64.b32encode(verify_key_bytes[:8]).decode("ascii")
grouped = "-".join(raw[i : i + 4] for i in range(0, len(raw), 4))
return f"ed25519:{grouped}"
def full_node_id(verify_key_bytes: bytes) -> str:
"""All 32 bytes base64url no-pad: 'ed25519:<b64>'."""
b64 = base64.urlsafe_b64encode(verify_key_bytes).rstrip(b"=").decode("ascii")
return f"ed25519:{b64}"
def parse_node_id(node_id: str) -> bytes:
"""Decode a full node_id to 32 bytes. Short form raises ValueError."""
import re
if not node_id.startswith("ed25519:"):
raise ValueError(f"node_id must start with 'ed25519:': {node_id!r}")
payload = node_id[len("ed25519:") :]
# Short form is b32-with-dashes: groups of [A-Z2-7=]{1,4} separated by '-'
# e.g. "SQ2J-OH7E-LCMU-Y===" — always shorter than 30 chars and matches this pattern.
# Full form is 43-char base64url (no '=' padding).
if re.fullmatch(r"[A-Z2-7=]{1,4}(-[A-Z2-7=]{1,4}){1,}", payload):
raise ValueError("Short node IDs cannot be decoded to raw bytes; use full form.")
# Add padding back for base64url decoding
padded = payload + "=" * (4 - len(payload) % 4 if len(payload) % 4 != 0 else 0)
raw = base64.urlsafe_b64decode(padded)
if len(raw) != 32:
raise ValueError(f"Expected 32 bytes, got {len(raw)}")
return raw
# ---------------------------------------------------------------------------
# Canonical JSON
# ---------------------------------------------------------------------------
def canonical_json(obj: Any) -> bytes:
"""Canonical JSON: sorted keys, no whitespace, numbers stripped of trailing zeros, UTF-8."""
serialised = json.dumps(obj, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
# Strip trailing zeros from numbers: 1.0 -> 1, 1.10 -> 1.1
# We post-process the JSON string carefully without breaking string contents.
result = _strip_trailing_zeros(serialised)
return result.encode("utf-8")
def _strip_trailing_zeros(s: str) -> str:
"""Remove trailing zeros from JSON numbers without touching string values."""
import re
# Match JSON numbers (integers, floats, exponent forms) that appear outside strings
# We parse character-by-character to skip string literals.
out: list[str] = []
i = 0
n = len(s)
while i < n:
c = s[i]
if c == '"':
# Scan to end of string, respecting escapes
out.append(c)
i += 1
while i < n:
ch = s[i]
out.append(ch)
if ch == "\\":
i += 1
if i < n:
out.append(s[i])
elif ch == '"':
i += 1
break
i += 1
else:
# Look for a number token
m = re.match(r"-?(?:0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?", s[i:])
if m and (m.group(1) or m.group(2)):
num_str = m.group(0)
# Parse and reformat
try:
val = float(num_str)
# If it represents an integer value, emit as integer
if val == int(val) and "e" not in num_str.lower():
out.append(str(int(val)))
else:
# Strip trailing zeros from decimal part
formatted = f"{val:g}"
out.append(formatted)
except (ValueError, OverflowError):
out.append(num_str)
i += len(num_str)
else:
out.append(c)
i += 1
return "".join(out)
# ---------------------------------------------------------------------------
# Signing / Verification
# ---------------------------------------------------------------------------
def sign_payload(payload: dict, kp: KeyPair) -> dict:
"""Return a copy of payload with 'signature' field added (signs over payload without signature)."""
if not _NACL_AVAILABLE:
raise IdentityError("keys_invalid", reason="PyNaCl not installed")
unsigned = {k: v for k, v in payload.items() if k != "signature"}
raw = canonical_json(unsigned)
try:
signed = kp.signing_key.sign(raw)
sig_bytes = signed.signature
except Exception as exc:
raise IdentityError("sign_failed", reason=str(exc)) from exc
sig_b64 = base64.urlsafe_b64encode(sig_bytes).rstrip(b"=").decode("ascii")
result = dict(unsigned)
result["signature"] = f"ed25519:{sig_b64}"
return result
def verify_payload(payload: dict, vk: Any) -> bool: # vk: nacl.signing.VerifyKey
"""Verify the 'signature' field of payload against vk. Returns True or raises IdentityError."""
if not _NACL_AVAILABLE:
raise IdentityError("keys_invalid", reason="PyNaCl not installed")
raw_sig = payload.get("signature", "")
if not raw_sig.startswith("ed25519:"):
raise IdentityError("verify_failed", reason="signature field missing or malformed")
sig_b64 = raw_sig[len("ed25519:") :]
padding = 4 - len(sig_b64) % 4
if padding != 4:
sig_b64 += "=" * padding
try:
sig_bytes = base64.urlsafe_b64decode(sig_b64)
except Exception as exc:
raise IdentityError("verify_failed", reason=f"bad signature encoding: {exc}") from exc
unsigned = {k: v for k, v in payload.items() if k != "signature"}
raw = canonical_json(unsigned)
try:
vk.verify(raw, sig_bytes)
except nacl.exceptions.BadSignatureError as exc:
raise IdentityError("verify_failed", reason="signature verification failed") from exc
except Exception as exc:
raise IdentityError("verify_failed", reason=str(exc)) from exc
return True
def verify_payload_with_node_id(payload: dict, expected_node_id_full: str) -> bool:
"""Verify payload signature using the public key encoded in expected_node_id_full."""
if not _NACL_AVAILABLE:
raise IdentityError("keys_invalid", reason="PyNaCl not installed")
try:
vk_bytes = parse_node_id(expected_node_id_full)
except ValueError as exc:
raise IdentityError("bad_node_id", reason=str(exc)) from exc
try:
vk = nacl.signing.VerifyKey(vk_bytes)
except Exception as exc:
raise IdentityError("keys_invalid", reason=str(exc)) from exc
return verify_payload(payload, vk)
# ---------------------------------------------------------------------------
# Key I/O
# ---------------------------------------------------------------------------
def generate() -> KeyPair:
"""Generate a fresh Ed25519 keypair using os.urandom."""
if not _NACL_AVAILABLE:
raise IdentityError("keys_invalid", reason="PyNaCl not installed")
seed = os.urandom(32)
sk = nacl.signing.SigningKey(seed)
vk = sk.verify_key
vk_bytes = bytes(vk)
return KeyPair(
signing_key=sk,
verify_key=vk,
node_id_short=short_node_id(vk_bytes),
node_id_full=full_node_id(vk_bytes),
)
def save(kp: KeyPair, keys_dir: Path) -> None:
"""Save signing key (chmod 0600) and verify key to keys_dir."""
keys_dir.mkdir(parents=True, exist_ok=True)
priv_path = keys_dir / "device.ed25519"
pub_path = keys_dir / "device.pub"
# Write private key (raw 32-byte seed, base64url encoded)
sk_bytes = bytes(kp.signing_key)
priv_path.write_bytes(base64.urlsafe_b64encode(sk_bytes).rstrip(b"=") + b"\n")
# Restrict permissions on POSIX
from contextlib import suppress
with suppress(AttributeError):
os.chmod(priv_path, stat.S_IRUSR | stat.S_IWUSR) # 0600
# Write public key
vk_bytes = bytes(kp.verify_key)
pub_path.write_bytes(base64.urlsafe_b64encode(vk_bytes).rstrip(b"=") + b"\n")
def load(keys_dir: Path) -> KeyPair:
"""Load KeyPair from device.ed25519 + device.pub in keys_dir."""
if not _NACL_AVAILABLE:
raise IdentityError("keys_invalid", reason="PyNaCl not installed")
priv_path = keys_dir / "device.ed25519"
pub_path = keys_dir / "device.pub"
if not priv_path.exists() or not pub_path.exists():
raise IdentityError("keys_missing", reason=f"Key files not found in {keys_dir}")
# Check permissions on POSIX only. POSIX mode bits are not meaningful on
# Windows (NTFS files commonly report 0o666 regardless of ACLs), so the
# check would raise false positives there. stat.S_IMODE does not raise on
# Windows, so an explicit os.name guard is required.
if os.name == "posix":
mode = oct(stat.S_IMODE(priv_path.stat().st_mode))
if not mode.endswith("600") and not mode.endswith("400"):
raise IdentityError(
"keys_permissions",
reason=f"Private key {priv_path} has unsafe permissions {mode}",
)
try:
sk_b64 = priv_path.read_text().strip()
padding = 4 - len(sk_b64) % 4
if padding != 4:
sk_b64 += "=" * padding
sk_bytes = base64.urlsafe_b64decode(sk_b64)
sk = nacl.signing.SigningKey(sk_bytes)
except IdentityError:
raise
except Exception as exc:
raise IdentityError("keys_invalid", reason=str(exc)) from exc
vk = sk.verify_key
vk_bytes = bytes(vk)
return KeyPair(
signing_key=sk,
verify_key=vk,
node_id_short=short_node_id(vk_bytes),
node_id_full=full_node_id(vk_bytes),
)
def load_or_generate(keys_dir: Path) -> KeyPair:
"""Load keys if present, otherwise generate and persist."""
priv_path = keys_dir / "device.ed25519"
if priv_path.exists():
return load(keys_dir)
kp = generate()
save(kp, keys_dir)
return kp