# M26 - Distributed Inference
**Spec version:** v3.0, *experimental*
**Depends on:** [X08 Tensor Transport](../cross-cutting/X08-tensor-transport.md), [X06 WebSocket](../../phase-2/cross-cutting/X06-websocket.md), [M04 LLM](../../modules/M04-llm.md), [M16 Tokens](../../phase-2/modules/M16-tokens.md), [X03 Observability](../../cross-cutting/X03-observability.md)
**Depended on by:** Optional `experimental.distributed_llm.chat` backend in M04
---
## 1. Responsibility
Run a single LLM forward pass across multiple machines on the same LAN (or, in extremely careful setups, a single federation). Take a 7B model that doesn't fit on any one anchor's GPU and split it: anchor A holds layers 0-7, anchor B holds 8-15, anchor C holds 16-23, etc. A request orchestrator chains them and streams tokens back to the user.
This is a **research module**. It exists for two reasons:
1. **Resilience.** When a community's biggest GPU breaks, the next-biggest fleet of GPUs can still serve mid-sized models cooperatively.
2. **Reach.** A community of three households, each with consumer hardware, can collectively run a model none of them could run alone.
It is explicitly **not** for serving production user-facing LLM traffic at scale. The latency is worse than local inference (typically 2-4× per token), the orchestration is fragile (one shard offline = retry the whole pipeline), and the GPU memory savings come at significant complexity cost. Communities should default to local inference; this module exists for the cases where local isn't enough.
---
## 2. Non-goals (loud and clear)
- **Large models.** 70B-class models are out of scope. The math says you'd need ten 24 GB GPUs to host one, which is the wrong problem for a neighbourhood mesh to solve.
- **Cross-WAN sharding.** Inference across the public internet is uneconomical (latency, bandwidth). Limit to same LAN or same-VPN federation.
- **Heterogeneous shards across model versions.** All shards in a pipeline must serve the **exact same model and weights checksum**. No partial-model recovery.
- **Replacing local inference.** When `policy.research.enable = false`, this module is inert.
---
## 3. File layout
```
hearthnet/distributed_inference/
├── __init__.py
├── shard.py                    # Shard, ShardDescriptor, ShardServer
├── pipeline.py                 # Pipeline, PipelineOrchestrator
├── routing.py                  # Picks a set of shards that cover [0..N] layers
├── health.py                   # Heartbeats, failover detection
└── backends/
    ├── base.py
    ├── petals_like.py          # uses bigscience/petals client/server primitives
    └── small_model_layered.py  # custom impl for small models (≤ 3B)
```
---
## 4. Public API
### 4.1 Dataclasses
```python
from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class ShardDescriptor:
    shard_id: ShardID              # "<model_id>:<lo>-<hi>"
    model_id: str                  # HF model id
    weights_sha256: str            # full model weights hash; shards must match
    layer_range: tuple[int, int]   # inclusive
    vram_required_mb: int
    max_concurrent_streams: int
    host: NodeID
    endpoint: Endpoint             # ws://...
    advertised_at: datetime


@dataclass
class Pipeline:
    pipeline_id: str
    model_id: str
    weights_sha256: str
    total_layers: int
    ordered_shards: list[ShardDescriptor]
    established_at: datetime


@dataclass
class ShardHealth:
    shard_id: ShardID
    online: bool
    last_seen: datetime
    p95_latency_ms: float
    queue_depth: int
```
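The `shard_id` comment above gives the format `<model_id>:<lo>-<hi>`. A minimal sketch of building and parsing such an id follows. It assumes model ids never contain a colon (true of Hugging Face ids, but an assumption for anything else). `make_shard_id` and `parse_shard_id` are hypothetical helper names, not part of the spec.

```python
def make_shard_id(model_id: str, layer_range: tuple[int, int]) -> str:
    """Build "<model_id>:<lo>-<hi>" from an inclusive layer range."""
    lo, hi = layer_range
    if lo < 0 or hi < lo:
        raise ValueError(f"invalid layer range: {layer_range}")
    return f"{model_id}:{lo}-{hi}"


def parse_shard_id(shard_id: str) -> tuple[str, tuple[int, int]]:
    """Split a shard id back into (model_id, (lo, hi))."""
    # Split on the last colon so slashes and hyphens in the model id survive.
    model_id, _, span = shard_id.rpartition(":")
    lo_text, _, hi_text = span.partition("-")
    if not model_id or not lo_text.isdigit() or not hi_text.isdigit():
        raise ValueError(f"malformed shard id: {shard_id!r}")
    return model_id, (int(lo_text), int(hi_text))
```

Splitting from the right keeps ids such as `Qwen/Qwen2.5-3B:8-15` unambiguous even though the model id itself contains a hyphen.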
### 4.2 `ShardServer`
```python
class ShardServer:
    """Hosts one contiguous shard. Loaded on demand; lazy-evictable under memory pressure."""

    def __init__(self, descriptor: ShardDescriptor, model_loader: ModelLoader, settings: ShardSettings): ...

    async def start(self) -> None:
        # Load weights for the layer range; register `experimental.distributed_llm.shard.serve` on the bus
        ...

    async def forward(self, activations_in: TensorChunkStream) -> TensorChunkStream:
        """The hot path. Receives activations, runs layers, emits activations."""
        ...

    async def health(self) -> ShardHealth: ...

    async def evict(self) -> None:
        """Free VRAM; triggered by host memory manager."""
        ...
```
### 4.3 `PipelineOrchestrator`
```python
class PipelineOrchestrator:
    """
    Chooses shards to cover the model's layers, opens streams to each, and
    pumps activations through them in order. Handles failover.
    """

    def __init__(
        self,
        bus: CapabilityBus,
        router: ShardRouter,
        health: ShardHealthTracker,
        observability: Observability,
    ): ...

    async def chat(
        self,
        request: LlmChatRequest,
        params: DistributedChatParams,
    ) -> AsyncIterator[StreamFrame]:
        # 1. Resolve a Pipeline covering all layers of the target model
        # 2. Open WS streams to each shard via X08 tensor transport
        # 3. For each token step:
        #      embedding → shard 0 → shard 1 → ... → shard N → token sample
        # 4. Yield `token_delta` frames; emit `shard_status` and `shard_failover` diagnostics
        # 5. On any shard failure, attempt re-routing once; if that fails and
        #    `params.fallback_to_local`, fall back to local inference;
        #    otherwise emit a `pipeline_aborted` frame
        ...
```
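Step 3 of `chat` can be illustrated with a toy loop. This is a sketch only: `FakeShard`, `generate`, and `collect` are hypothetical names, the "layers" just add one to a single-element activation, and the sampling rule is arbitrary. It shows the order of operations (embed on the orchestrator, pass through each shard in turn, sample on the orchestrator), not real inference or X08 transport.

```python
import asyncio
from typing import AsyncIterator

Activation = list[float]


class FakeShard:
    """Stand-in for a remote shard: applies one toy transform per layer."""

    def __init__(self, lo: int, hi: int):
        self.lo, self.hi = lo, hi  # inclusive layer range

    async def forward(self, x: Activation) -> Activation:
        await asyncio.sleep(0)  # placeholder for the network round trip
        for _ in range(self.lo, self.hi + 1):
            x = [v + 1.0 for v in x]  # toy "layer": add one
        return x


async def generate(shards: list[FakeShard], prompt_token: int, max_tokens: int) -> AsyncIterator[int]:
    token = prompt_token
    for _ in range(max_tokens):
        x = [float(token)]          # toy embedding, kept on the orchestrator
        for shard in shards:        # shard 0, then shard 1, ..., then shard N
            x = await shard.forward(x)
        token = int(x[0]) % 100     # toy sampling, kept on the orchestrator
        yield token                 # stands in for a `token_delta` frame


async def collect(shards: list[FakeShard], prompt_token: int, max_tokens: int) -> list[int]:
    return [t async for t in generate(shards, prompt_token, max_tokens)]
```

With two shards covering layers 0-7 and 8-15, `asyncio.run(collect([FakeShard(0, 7), FakeShard(8, 15)], 1, 3))` passes each token through 16 toy layers before sampling.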
### 4.4 `ShardRouter`
```python
class ShardRouter:
    """
    Given a model_id and an `experimental.shard.advertised` event stream,
    pick a covering set of shards minimising:
      - total network hops
      - max per-shard queue depth
      - chance of overlap with the caller's own GPU (avoid self-as-shard)
    """

    def __init__(self, store: ShardStore, settings: RoutingSettings): ...

    async def pick(self, model_id: str, weights_sha256: str) -> Pipeline: ...

    async def repick(self, pipeline: Pipeline, exclude: set[ShardID]) -> Pipeline: ...
```
---
## 5. Behaviour
### 5.1 Shard advertisement and discovery
A node hosting a shard emits `experimental.shard.advertised` events into the community event log. The event carries `ShardDescriptor` fields plus a timestamp. Advertisements expire after `DISTRIBUTED_SHARD_HEALTH_TIMEOUT_S * 4` (default 120s); shard hosts must re-advertise via heartbeat.
When a node opts out (`policy.research.enable=false`), it does not emit advertisements. Existing advertisements expire normally.
The shard store is a local read model built from these events, indexed by `(model_id, weights_sha256, layer_range)`.
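A minimal sketch of such a read model with expiry, assuming advertisements carry a numeric timestamp in seconds and that a heartbeat is simply a repeated advertisement. `ToyShardStore` and its method names are hypothetical; the real `ShardStore` interface is not defined in this spec.

```python
from dataclasses import dataclass, field

HEALTH_TIMEOUT_S = 30                   # default shard health timeout
ADVERT_TTL_S = HEALTH_TIMEOUT_S * 4     # 120 s, as stated above

# Index key: (model_id, weights_sha256, layer_range)
Key = tuple[str, str, tuple[int, int]]


@dataclass
class ToyShardStore:
    # Maps index key -> {host: time of latest advertisement, in seconds}
    _ads: dict[Key, dict[str, float]] = field(default_factory=dict)

    def on_advertised(self, key: Key, host: str, advertised_at: float) -> None:
        """Record or refresh an advertisement; a heartbeat calls this again."""
        self._ads.setdefault(key, {})[host] = advertised_at

    def live_hosts(self, key: Key, now: float) -> list[str]:
        """Hosts whose latest advertisement for `key` has not yet expired."""
        hosts = self._ads.get(key, {})
        return sorted(h for h, t in hosts.items() if now - t < ADVERT_TTL_S)
```

Expiry is evaluated lazily at read time here, so a node that stops advertising simply drops out of `live_hosts` once its last advertisement is 120 s old.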
### 5.2 Pipeline construction
`ShardRouter.pick`:
1. Filter advertisements to those matching `model_id` and `weights_sha256`.
2. Greedy cover: starting from layer 0, pick the shard with the lowest queue depth that includes the next uncovered layer; advance the cursor; repeat. Returns failure if any layer is uncoverable.
3. Prefer shards on the same LAN if possible (LAN advertisements have a lower "hop weight" metric attached by Discovery).
4. Avoid sharding to **self** as the first shard: embedding + sampling should stay on the orchestrator.
Constructed pipelines are not persisted; they're per-call.
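The greedy cover in step 2 can be sketched as follows. `Ad` is a hypothetical, trimmed-down stand-in for `ShardDescriptor`, and the tie-break (prefer the shard that reaches furthest when queue depths are equal) is an assumption rather than something the spec states. Steps 3 and 4 (LAN preference, avoiding self) are left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Ad:
    """Hypothetical, trimmed-down stand-in for ShardDescriptor."""
    shard_id: str
    lo: int              # first layer held (inclusive)
    hi: int              # last layer held (inclusive)
    queue_depth: int = 0


def greedy_cover(ads: list[Ad], total_layers: int) -> list[Ad]:
    """Cover layers 0..total_layers-1; raise LookupError if a layer is uncoverable."""
    chosen: list[Ad] = []
    cursor = 0  # next uncovered layer
    while cursor < total_layers:
        candidates = [a for a in ads if a.lo <= cursor <= a.hi]
        if not candidates:
            raise LookupError(f"no shard covers layer {cursor}")
        # Lowest queue depth first; ties go to the shard that reaches furthest
        # (the tie-break is an assumption of this sketch).
        best = min(candidates, key=lambda a: (a.queue_depth, -a.hi))
        chosen.append(best)
        cursor = best.hi + 1
    return chosen
```

The sketch follows the text literally in accepting a shard whose range starts before the cursor; whether such a shard would then run only its not-yet-covered layers is not specified here.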
### 5.3 Forward pass
Per token:
```
[orchestrator] embedding → [shard 0] layers 0..7 → [shard 1] layers 8..15 → ... → [orchestrator] sample
```
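Activations flow as **fp16 tensors** by default (configurable to fp32 for debugging). Each hop is a WebSocket binary frame stream (see [X08](../cross-cutting/X08-tensor-transport.md)). The orchestrator interleaves token-N and token-N+1: as soon as shard 0 finishes token N, the orchestrator pushes token N+1's embedding into shard 0 while shard 1 is still processing N. With this pipeline parallelism, the steady-state interval between tokens approaches the latency of the slowest shard rather than the sum of all shard latencies. The overlap requires token N+1's embedding to be available before token N has left the pipeline; that holds for prompt prefill, but not for autoregressive decoding within a single sequence, where token N+1 is only known once token N has been sampled.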
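To make the steady-state behaviour concrete, here is a toy timing model. It assumes fixed per-shard latencies and that every input is ready as soon as shard 0 is free, so it ignores any dependency between consecutive tokens. `finish_times` is a hypothetical helper, not part of the module.

```python
def finish_times(stage_ms: list[float], n_items: int) -> list[float]:
    """Completion time of each item as it leaves the last stage of a linear pipeline."""
    free_at = [0.0] * len(stage_ms)    # when each stage next becomes idle
    done: list[float] = []
    for _ in range(n_items):
        t = 0.0                         # item is ready immediately (assumption)
        for s, cost in enumerate(stage_ms):
            start = max(t, free_at[s])  # wait until the stage is idle
            t = start + cost
            free_at[s] = t
        done.append(t)
    return done
```

For stage latencies of 10, 30, and 20 ms, `finish_times([10, 30, 20], 4)` gives completions 30 ms apart, matching the slowest stage, whereas running each item through an idle pipeline would cost the full 60 ms every time.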
### 5.4 Failure handling
If a shard's stream errors or stalls past `DISTRIBUTED_SHARD_HEALTH_TIMEOUT_S`:
1. The orchestrator emits a `shard_status` frame with `status:"degraded"`.
2. Calls `router.repick(pipeline, exclude={failed_shard_id})`.
3. If repick succeeds, opens a fresh stream to the replacement and emits a `shard_failover` frame. **In-flight tokens are restarted** (no mid-token recovery).
4. If repick fails and `params.fallback_to_local` is set, the orchestrator silently restarts the call as a local-only `llm.chat@2.0` against any local model that matches.
5. Otherwise, it emits a `pipeline_aborted` frame and returns `shard_unavailable`.
`DISTRIBUTED_FALLBACK_TO_LOCAL_AFTER_FAILURES` (default 2): if failover happens that many times in one call, give up and fall back to local.
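The decision logic above can be sketched as a pure function. Two points are assumptions of this sketch rather than statements of the spec: the failure counter includes the failure currently being handled, and when the failure limit is reached with `fallback_to_local` unset the call aborts. `Action` and `on_shard_failure` are hypothetical names.

```python
from enum import Enum


class Action(Enum):
    FAILOVER = "shard_failover"            # reroute and restart in-flight tokens
    FALLBACK_LOCAL = "fallback_to_local"   # restart as a local-only call
    ABORT = "pipeline_aborted"             # surface `shard_unavailable`


def on_shard_failure(
    failures_in_call: int,      # failures so far, including this one (assumption)
    repick_ok: bool,            # did router.repick find a replacement?
    fallback_to_local: bool,    # params.fallback_to_local
    max_failures: int = 2,      # DISTRIBUTED_FALLBACK_TO_LOCAL_AFTER_FAILURES
) -> Action:
    """Decide what the orchestrator does after one shard errors or stalls."""
    give_up = failures_in_call >= max_failures or not repick_ok
    if not give_up:
        return Action.FAILOVER
    return Action.FALLBACK_LOCAL if fallback_to_local else Action.ABORT
```

Keeping the decision separate from the streaming code makes the failover policy easy to unit-test without standing up real shards.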
### 5.5 Streaming and backpressure
Tensor-chunk streams use a window of `TENSOR_FLOW_CONTROL_WINDOW` chunks (default 16). Each chunk is at most `TENSOR_CHUNK_BYTES` (1 MB). If the downstream shard's send queue fills, the orchestrator pauses upstream until ACKs drain. See [X08 §4](../cross-cutting/X08-tensor-transport.md).
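A minimal sketch of the chunking and window accounting, assuming 1 MB means 2^20 bytes and that each ACK frees exactly one slot. `split_chunks` and `SendWindow` are hypothetical names; the actual wire protocol is defined in X08, not here.

```python
TENSOR_CHUNK_BYTES = 1024 * 1024    # 1 MB per chunk (taken as 1 MiB here)
TENSOR_FLOW_CONTROL_WINDOW = 16     # max unacknowledged chunks


def split_chunks(payload: bytes, chunk_bytes: int = TENSOR_CHUNK_BYTES) -> list[bytes]:
    """Cut a serialized tensor into chunks no larger than chunk_bytes."""
    return [payload[i:i + chunk_bytes] for i in range(0, len(payload), chunk_bytes)]


class SendWindow:
    """Tracks in-flight chunks; the sender pauses while the window is full."""

    def __init__(self, size: int = TENSOR_FLOW_CONTROL_WINDOW):
        self.size = size
        self.in_flight = 0

    def can_send(self) -> bool:
        return self.in_flight < self.size

    def on_send(self) -> None:
        if not self.can_send():
            raise RuntimeError("window full: wait for ACKs before sending")
        self.in_flight += 1

    def on_ack(self) -> None:
        # One ACK frees one slot; never go below zero on a stray ACK.
        self.in_flight = max(0, self.in_flight - 1)
```

A sender would check `can_send()` before each chunk and otherwise wait for `on_ack()`, which is what bounds the unacknowledged data to the window size times the chunk size.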
### 5.6 Concurrency
A shard's `max_concurrent_streams` is honoured strictly. If the orchestrator's call would exceed it, the orchestrator picks a different shard (via `router.repick`) rather than queuing.
A shard's GPU memory budget is enforced by the shard host's own resource manager; a shard exceeding its budget gets evicted and re-advertises with `vram_required_mb` updated next time it loads.
### 5.7 Models supported
Phase 3 launches with two backend choices:
| Backend | Models | Notes |
|---------|--------|-------|
| `small_model_layered` | Qwen2.5-{1.5B,3B,7B}, Llama-3.2-{1B,3B}, MiniCPM-3 | Custom HearthNet impl; PyTorch model surgery to expose per-layer forward |
| `petals_like` | (vendored from BigScience Petals) | Optional; only if user installs `hearthnet[petals]` extra |
The `small_model_layered` backend handles models up to roughly 7B parameters cleanly; beyond that the activation transport becomes the bottleneck.
### 5.8 Security boundary
A shard host receives activation tensors which **leak training data residue**. Treat activations as sensitive: do not log them, do not persist them, and do not retain them past the forward pass. Each call uses signed per-call authentication; the caller's identity is recorded in metrics but not in logs of tensor contents.
A malicious shard could degrade outputs subtly. Detection is hard in general; the orchestrator does **basic sanity checks** (norm bounds, NaN/Inf detection) but cannot detect adversarial corruption. Communities should only enable distributed inference among members they trust as much as they trust the LLM service operator.
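The basic sanity checks could look like the sketch below, which uses plain Python floats for clarity. `activations_look_sane` and the `max_l2_norm` threshold are hypothetical; the spec does not say which norm or bound is used.

```python
import math


def activations_look_sane(values: list[float], max_l2_norm: float) -> bool:
    """Reject NaN/Inf entries and vectors whose L2 norm exceeds a bound."""
    if not all(math.isfinite(v) for v in values):
        return False
    norm = math.sqrt(sum(v * v for v in values))
    return norm <= max_l2_norm
```

In practice the bound would have to be calibrated per model and per layer range, since activation magnitudes differ between layers.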
### 5.9 Privacy threat surface
A shard sees the activations of every request routed through it. With effort, a shard host can reconstruct approximate input text (especially the prompt) from activations of intermediate layers. This is **a real concern, not a theoretical one**.
Mitigations (none perfect):
- Restrict participation to members at trust level `trusted` or higher.
- Mix activations with a small amount of noise at the orchestrator (research; not yet implemented).
- Use this module only for queries the requester would already trust the community with.
### 5.10 Observability
Per call, emit:
- `distributed_inference.pipeline_construct_ms`
- `distributed_inference.first_token_ms`
- `distributed_inference.tokens_per_second`
- `distributed_inference.shard_latency_ms{shard_id}` histograms
- `distributed_inference.failovers_total`
- `distributed_inference.fallback_to_local_total`
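A sketch of deriving the timing metrics from timestamps, assuming the orchestrator records a monotonic clock reading at request start, at pipeline readiness, and at each emitted token. `CallTimings` and `summarize` are hypothetical, and measuring tokens per second from request start (rather than from the first token) is an assumption.

```python
from dataclasses import dataclass


@dataclass
class CallTimings:
    """Monotonic timestamps (seconds) captured by the orchestrator for one call."""
    request_at: float
    pipeline_ready_at: float
    token_times: list[float]   # arrival time of each token_delta frame


def summarize(t: CallTimings) -> dict[str, float]:
    """Compute the per-call timing metrics listed above."""
    metrics = {
        "distributed_inference.pipeline_construct_ms":
            (t.pipeline_ready_at - t.request_at) * 1000,
    }
    if t.token_times:
        metrics["distributed_inference.first_token_ms"] = (
            t.token_times[0] - t.request_at
        ) * 1000
        elapsed = t.token_times[-1] - t.request_at
        if elapsed > 0:
            metrics["distributed_inference.tokens_per_second"] = (
                len(t.token_times) / elapsed
            )
    return metrics
```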
---
## 6. Errors
| Code | Cause |
|------|-------|
| `experimental_disabled` | `policy.research.enable=false` |
| `shard_unavailable` | No shard covers a required layer range, or all candidates are at max concurrency |
| `pipeline_stalled` | No progress within timeout |
| `weights_mismatch` | A shard's advertised `weights_sha256` differs from requested |
| `bad_request` | Unknown model, malformed pipeline params |
---
## 7. Configuration
```toml
[research.distributed_inference]
enabled = false
backend = "small_model_layered"
max_shards_per_request = 16
shard_health_timeout_seconds = 30
fallback_to_local = true
activation_dtype = "fp16" # "fp16" | "fp32"
allow_self_as_shard = false
max_concurrent_pipelines = 4
[research.distributed_inference.host]
serve_shards = false
shard_eviction_idle_seconds = 600
shard_max_vram_mb = 20000
```
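One way to validate the `[research.distributed_inference]` table after parsing is sketched below. `DistributedInferenceSettings` and `load_settings` are hypothetical names; the field names and defaults mirror the TOML above, and the two checks shown are illustrative rather than an exhaustive list of what the module enforces.

```python
from dataclasses import dataclass

ALLOWED_DTYPES = {"fp16", "fp32"}


@dataclass(frozen=True)
class DistributedInferenceSettings:
    enabled: bool = False
    backend: str = "small_model_layered"
    max_shards_per_request: int = 16
    shard_health_timeout_seconds: int = 30
    fallback_to_local: bool = True
    activation_dtype: str = "fp16"
    allow_self_as_shard: bool = False
    max_concurrent_pipelines: int = 4

    def __post_init__(self) -> None:
        if self.activation_dtype not in ALLOWED_DTYPES:
            raise ValueError(f"activation_dtype must be one of {sorted(ALLOWED_DTYPES)}")
        if self.shard_health_timeout_seconds <= 0:
            raise ValueError("shard_health_timeout_seconds must be positive")

    @property
    def advertisement_ttl_seconds(self) -> int:
        # Section 5.1: advertisements expire after four health timeouts.
        return self.shard_health_timeout_seconds * 4


def load_settings(table: dict) -> DistributedInferenceSettings:
    """Build settings from the parsed [research.distributed_inference] table."""
    known = {k: v for k, v in table.items() if k != "host"}  # `host` is a sub-table
    return DistributedInferenceSettings(**known)
```

The `table` argument could come from the standard-library `tomllib.load` (Python 3.11+), indexed down to `["research"]["distributed_inference"]`.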
---
## 8. Tests
### 8.1 Unit
- ShardRouter cover algorithm: 16-layer model + 3 advertised shards (0-7, 4-11, 8-15) → picks {0-7, 8-15}, ignores overlap shard
- Sanity bounds on activations: NaN injection triggers `pipeline_stalled` (via failed health check on subsequent chunk)
- Pipeline construction with weights mismatch → `weights_mismatch`
### 8.2 Integration (LAN)
- Two-node setup, 1.5B model split as 0-7 / 8-15; happy-path tokens/sec measured; baseline single-machine inference also measured; ratio reported (expect 0.4-0.6× local)
- Shard host kill mid-stream; failover to a third node; total call still succeeds; latency penalty bounded
- Concurrent two-pipeline test on three nodes; no deadlock; per-call latency degrades < 2×
### 8.3 Adversarial
- Malicious shard returns garbage activations: orchestrator's NaN/Inf detector catches the call; metric `distributed_inference.shard_corruption_detected_total` increments; pipeline aborts
- Slowloris shard (returns one chunk per second): `pipeline_stalled` after timeout; failover succeeds
### 8.4 Performance budget
- 3B model, 2-shard pipeline, RTX 5090 + RTX 4090: ≥ 8 tokens/sec sustained
- First-token latency ≤ 800ms
- Construction-to-first-byte ≤ 500ms
- Tensor-chunk overhead per hop ≤ 25ms p95
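These budgets interact. The sketch below works out how much per-token time is left for GPU compute. It assumes a 2-shard pipeline has three hops (orchestrator to shard 0, shard 0 to shard 1, shard 1 back to the orchestrator), that no hops overlap, and that every hop costs the full p95 overhead. `compute_budget_ms` is a hypothetical helper.

```python
def compute_budget_ms(min_tokens_per_s: float, hops: int, hop_overhead_ms: float) -> float:
    """Per-token time left for GPU compute after transport overhead."""
    per_token_ms = 1000.0 / min_tokens_per_s   # 8 tokens/s means 125 ms per token
    return per_token_ms - hops * hop_overhead_ms
```

Under those assumptions, `compute_budget_ms(8, 3, 25)` leaves 50 ms per token for both shards' compute combined, so transport overhead consumes more than half of the per-token budget.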
---
## 9. Cross-references
- Capability spec: [CAPABILITY_CONTRACT_v3 §4.1-4.3](../CAPABILITY_CONTRACT_v3.md)
- Tensor transport: [X08](../cross-cutting/X08-tensor-transport.md)
- Base LLM service: [M04](../../modules/M04-llm.md)
- Trust levels: [M01](../../modules/M01-identity.md)
---
## 10. Open research questions
1. **Activation privacy.** Can we add fast-to-compute noise that preserves inference accuracy but defeats activation-inversion attacks? Cite the Geiping et al. inversion paper as the threat baseline.
2. **Mid-token recovery.** Currently a shard failure restarts the in-flight token. Could we use micro-checkpointing (every K tokens) to recover without a restart? Latency cost?
3. **Heterogeneous shards.** Could a 4090 host the early layers (heavier compute per layer) and a 3060 the later ones, while remaining balanced? Probably yes; automated load assignment is the research question.
4. **Async pipeline.** Currently the orchestrator interleaves at the token level. Could it interleave at the layer level (one shard processes token N+2 while another processes N+1) for higher throughput? In theory yes; coordination protocol unclear.
5. **Mixed local + distributed.** When the orchestrator could host some layers itself (it has a GPU), should it? When? Currently `allow_self_as_shard=false`. A heuristic that considers compute headroom would be richer.
6. **Adversarial detection.** Beyond NaN/Inf, can we cheaply detect activation tampering by comparing to a small "shadow inference" on a tiny model? Cost vs. benefit unclear.
7. **Pricing / incentive.** A shard host pays in GPU time. A community-internal token-economy is explicitly out of scope (00-OVERVIEW §8). But a *reputational* signal ("this anchor served 4000 shard-tokens this week") could be helpful. Should it be a metric?
8. **Backend strategy.** `petals_like` vs `small_model_layered`: which delivers better quality / latency / robustness for our target models? An honest A/B is the answer.