Spaces:
Running on Zero
Running on Zero
GitHub Actions
fix: 0 test failures; FileService; real RagService; emergency probe; chat return
4aaae80 | """M28 — Federated Learning / LoRA aggregation (experimental, Phase 3). | |
| FedAvg on LoRA adapter weight deltas. Each node trains locally; | |
| only adapter deltas (not raw data or full weights) are shared. | |
| Gated by config.research.federated_learning = True. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from typing import NewType | |
| RoundID = NewType("RoundID", str) | |
@dataclass
class RoundManifest:
    """Describes a federated learning round.

    Plain data record: values are stored exactly as given and no validation
    is performed here (for example, ``min_participants`` is not checked
    against ``max_participants``).

    The ``@dataclass`` decorator is required: the ``field(default_factory=...)``
    default below and keyword construction of this class both rely on the
    generated ``__init__``.
    """

    # --- identity (required) ---
    round_id: RoundID
    base_model_id: str
    coordinator_node_id: str
    community_id: str

    # --- LoRA / training hyper-parameters ---
    lora_rank: int = 16
    lora_alpha: float = 32.0
    learning_rate: float = 2e-4

    # --- participation limits ---
    min_participants: int = 2
    max_participants: int = 20
    round_timeout_seconds: int = 3600

    # Wall-clock creation time (time.time()), captured per instance.
    created_at: float = field(default_factory=time.time)
    # Empty by default; presumably filled in when the coordinator signs the
    # manifest -- TODO confirm against the signing code.
    coordinator_sig: bytes = b""
@dataclass
class ParticipantSubmission:
    """One participant's LoRA delta for a given round.

    Plain data record with no validation. The ``@dataclass`` decorator is
    required so that the annotated fields become constructor arguments and
    ``submitted_at`` receives a fresh timestamp per instance.
    """

    round_id: RoundID
    participant_node_id: str
    delta_bytes: bytes  # serialised LoRA state dict subset (safetensors format)
    # Presumably the number of local training samples, used as the FedAvg
    # weight -- TODO confirm once aggregation is implemented.
    num_samples: int
    # Wall-clock submission time (time.time()), captured per instance.
    submitted_at: float = field(default_factory=time.time)
    # Empty by default; presumably the participant's signature over the
    # submission -- TODO confirm against the signing code.
    participant_sig: bytes = b""
class FedLearnCoordinator:
    """Orchestrates a federated learning round.

    Experimental. Tracks round manifests and the submissions received for
    each round in memory only. Actual aggregation is not implemented here;
    the stub reports that it requires peft and torch.

    Intended to be used only when config.research.federated_learning = True;
    this class does not check that flag itself.
    """

    def __init__(self, keypair=None, bus=None) -> None:
        """Store the optional collaborators and start with no rounds.

        Args:
            keypair: Optional object whose ``node_id_short`` attribute, if
                present, is used as the coordinator node id of new rounds.
            bus: Optional object stored as-is; not used by any method of
                this class.
        """
        self._keypair = keypair
        self._bus = bus
        self._rounds: dict[RoundID, RoundManifest] = {}
        self._submissions: dict[RoundID, list[ParticipantSubmission]] = {}

    def create_round(
        self,
        base_model_id: str,
        community_id: str,
        **kwargs,
    ) -> RoundManifest:
        """Create, register and return a new round manifest.

        Args:
            base_model_id: Identifier of the base model for the round.
            community_id: Identifier of the community the round belongs to.
            **kwargs: Forwarded unchanged to ``RoundManifest`` (for example
                ``min_participants`` or ``lora_rank``).

        Returns:
            The newly created manifest, with a fresh UUID4 round id.
        """
        round_id = RoundID(str(uuid.uuid4()))
        # Falls back to "unknown" when no keypair was supplied or it has no
        # ``node_id_short`` attribute.
        coordinator_node_id = getattr(self._keypair, "node_id_short", "unknown")
        manifest = RoundManifest(
            round_id=round_id,
            base_model_id=base_model_id,
            coordinator_node_id=coordinator_node_id,
            community_id=community_id,
            **kwargs,
        )
        self._rounds[round_id] = manifest
        self._submissions[round_id] = []
        return manifest

    def submit(self, submission: ParticipantSubmission) -> bool:
        """Accept a participant's LoRA delta submission.

        Returns:
            ``True`` if the submission was stored; ``False`` if the round is
            unknown or already holds ``max_participants`` submissions.
        """
        manifest = self._rounds.get(submission.round_id)
        if manifest is None:
            return False

        accepted = self._submissions.setdefault(submission.round_id, [])
        # Enforce the cap declared on the manifest; without this check the
        # round would accept an unbounded number of submissions.
        if len(accepted) >= manifest.max_participants:
            return False

        # NOTE(review): repeated submissions from the same node are stored as
        # separate entries and each counts towards the participant totals --
        # confirm whether they should be deduplicated or replaced.
        accepted.append(submission)
        return True

    def aggregate(self, round_id: RoundID) -> bytes | None:
        """FedAvg: weighted average of submitted LoRA deltas (not implemented).

        Returns:
            ``None`` if the round is unknown or has fewer submissions than
            ``min_participants``.

        Raises:
            NotImplementedError: Whenever the round has enough submissions;
                actual aggregation requires peft+torch.
        """
        manifest = self._rounds.get(round_id)
        if manifest is None:
            return None

        received = self._submissions.get(round_id, [])
        if len(received) < manifest.min_participants:
            return None

        raise NotImplementedError(
            "FedLearnCoordinator.aggregate() requires peft and torch. "
            "This is an experimental Phase 3 feature (M28)."
        )

    def round_status(self, round_id: RoundID) -> dict:
        """Return a summary dict for a round.

        Returns:
            ``{"error": "not_found"}`` for an unknown round; otherwise a dict
            with the round id, base model id, number of stored submissions,
            the required minimum, and whether that minimum has been reached.
        """
        manifest = self._rounds.get(round_id)
        if manifest is None:
            return {"error": "not_found"}

        received_count = len(self._submissions.get(round_id, []))
        required = manifest.min_participants
        return {
            "round_id": round_id,
            "base_model_id": manifest.base_model_id,
            "participants": received_count,
            "min_required": required,
            "ready_to_aggregate": received_count >= required,
        }