M26: Distributed Inference
Spec version: v3.0 (experimental)
Depends on: X08 Tensor Transport, X06 WebSocket, M04 LLM, M16 Tokens, X03 Observability
Depended on by: Optional experimental.distributed_llm.chat backend in M04
1. Responsibility
Run a single LLM forward pass across multiple machines on the same LAN (or, in extremely careful setups, a single federation). Take a 7B model that doesn't fit on any one anchor's GPU and split it: anchor A holds layers 0–7, anchor B holds 8–15, anchor C holds 16–23, etc. A request orchestrator chains them and streams tokens back to the user.
This is a research module. It exists for two reasons:
- Resilience. When a community's biggest GPU breaks, the next-biggest fleet of GPUs can still serve mid-sized models cooperatively.
- Reach. A community of three households, each with consumer hardware, can collectively run a model none of them could run alone.
It is explicitly not for serving production user-facing LLM traffic at scale. The latency is worse than local inference (typically 2–4× per token), the orchestration is fragile (one shard offline = retry the whole pipeline), and the GPU memory savings come at significant complexity cost. Communities should default to local inference; this module exists for the cases where local isn't enough.
2. Non-goals (loud and clear)
- Large models. 70B-class models are out of scope. The math says you'd need ten 24 GB GPUs to host one, which is the wrong problem for a neighbourhood mesh to solve.
- Cross-WAN sharding. Inference across the public internet is uneconomical (latency, bandwidth). Limit to same LAN or same-VPN federation.
- Heterogeneous shards across model versions. All shards in a pipeline must serve the exact same model and weights checksum. No partial-model recovery.
- Replacing local inference. When `policy.research.enable = false`, this module is inert.
3. File layout
hearthnet/distributed_inference/
├── __init__.py
├── shard.py                    # Shard, ShardDescriptor, ShardServer
├── pipeline.py                 # Pipeline, PipelineOrchestrator
├── routing.py                  # Picks a set of shards that cover [0..N] layers
├── health.py                   # Heartbeats, failover detection
└── backends/
    ├── base.py
    ├── petals_like.py          # uses bigscience/petals client/server primitives
    └── small_model_layered.py  # custom impl for small models (≤ 3B)
4. Public API
4.1 Dataclasses
from dataclasses import dataclass
from datetime import datetime

# ShardID, NodeID, and Endpoint are assumed to be project types defined elsewhere.

@dataclass(frozen=True)
class ShardDescriptor:
    shard_id: ShardID              # "<model_id>:<lo>-<hi>"
    model_id: str                  # HF model id
    weights_sha256: str            # full model weights hash; shards must match
    layer_range: tuple[int, int]   # inclusive
    vram_required_mb: int
    max_concurrent_streams: int
    host: NodeID
    endpoint: Endpoint             # ws://...
    advertised_at: datetime

@dataclass
class Pipeline:
    pipeline_id: str
    model_id: str
    weights_sha256: str
    total_layers: int
    ordered_shards: list[ShardDescriptor]
    established_at: datetime

@dataclass
class ShardHealth:
    shard_id: ShardID
    online: bool
    last_seen: datetime
    p95_latency_ms: float
    queue_depth: int
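The `shard_id` format noted in the comment above, `<model_id>:<lo>-<hi>`, can be built and parsed mechanically. The following is a minimal sketch, assuming `ShardID` is a plain string; `make_shard_id` and `parse_shard_id` are hypothetical helper names and are not part of this spec.

```python
def make_shard_id(model_id: str, layer_range: tuple[int, int]) -> str:
    """Build "<model_id>:<lo>-<hi>" from a model id and an inclusive layer range."""
    lo, hi = layer_range
    if lo < 0 or hi < lo:
        raise ValueError(f"invalid inclusive layer range: {layer_range}")
    return f"{model_id}:{lo}-{hi}"


def parse_shard_id(shard_id: str) -> tuple[str, tuple[int, int]]:
    """Split a shard id back into (model_id, (lo, hi))."""
    # Split on the LAST colon so that model ids containing ':' still parse.
    model_id, _, span = shard_id.rpartition(":")
    lo_text, _, hi_text = span.partition("-")
    if not model_id or not lo_text.isdigit() or not hi_text.isdigit():
        raise ValueError(f"malformed shard id: {shard_id!r}")
    lo, hi = int(lo_text), int(hi_text)
    if hi < lo:
        raise ValueError(f"malformed shard id: {shard_id!r}")
    return model_id, (lo, hi)
```

Splitting on the last colon keeps the hyphen inside ids such as `Qwen/Qwen2.5-3B` from being confused with the range separator.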
4.2 ShardServer
class ShardServer:
    """Hosts one contiguous shard. Loaded on demand; lazy-evictable under memory pressure."""

    def __init__(self, descriptor: ShardDescriptor, model_loader: ModelLoader, settings: ShardSettings): ...

    async def start(self) -> None:
        # Load weights for the layer range; register `experimental.distributed_llm.shard.serve` on the bus
        ...

    async def forward(self, activations_in: TensorChunkStream) -> TensorChunkStream:
        """The hot path. Receives activations, runs layers, emits activations."""
        ...

    async def health(self) -> ShardHealth: ...

    async def evict(self) -> None:
        """Free VRAM; triggered by host memory manager."""
        ...
4.3 PipelineOrchestrator
class PipelineOrchestrator:
    """
    Chooses shards to cover the model's layers, opens streams to each, and
    pumps activations through them in order. Handles failover.
    """

    def __init__(
        self,
        bus: CapabilityBus,
        router: ShardRouter,
        health: ShardHealthTracker,
        observability: Observability,
    ): ...

    async def chat(
        self,
        request: LlmChatRequest,
        params: DistributedChatParams,
    ) -> AsyncIterator[StreamFrame]:
        # 1. Resolve a Pipeline covering all layers of the target model
        # 2. Open WS streams to each shard via X08 tensor transport
        # 3. For each token step:
        #      embedding → shard 0 → shard 1 → ... → shard N → token sample
        # 4. Yield `token_delta` frames; emit `shard_status` and `shard_failover` diagnostics
        # 5. On any shard failure, attempt re-routing once; if that fails and
        #    `params.fallback_to_local` is set, fall back to local inference;
        #    otherwise emit a `pipeline_aborted` frame (see 5.4)
        ...
4.4 ShardRouter
class ShardRouter:
    """
    Given a model_id and an `experimental.shard.advertised` event stream,
    pick a covering set of shards minimising:
    - total network hops
    - max per-shard queue depth
    - chance of overlap with the caller's own GPU (avoid self-as-shard)
    """

    def __init__(self, store: ShardStore, settings: RoutingSettings): ...

    async def pick(self, model_id: str, weights_sha256: str) -> Pipeline: ...

    async def repick(self, pipeline: Pipeline, exclude: set[ShardID]) -> Pipeline: ...
5. Behaviour
5.1 Shard advertisement and discovery
A node hosting a shard emits experimental.shard.advertised events into the community event log. The event carries ShardDescriptor fields plus a timestamp. Advertisements expire after DISTRIBUTED_SHARD_HEALTH_TIMEOUT_S * 4 (default 120s); shard hosts must re-advertise via heartbeat.
When a node opts out (policy.research.enable=false), it does not emit advertisements. Existing advertisements expire normally.
The shard store is a local read model built from these events, indexed by (model_id, weights_sha256, layer_range).
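As an illustration of the expiry rule and the index described above, here is a minimal in-memory sketch. `InMemoryShardStore`, its method names, and the choice to key entries additionally by host are assumptions made for this example; the real store is built from the event log and its interface is not specified here.

```python
from datetime import datetime, timedelta, timezone

# Assumption: mirrors DISTRIBUTED_SHARD_HEALTH_TIMEOUT_S with its 30 s default.
SHARD_HEALTH_TIMEOUT_S = 30
ADVERTISEMENT_TTL = timedelta(seconds=SHARD_HEALTH_TIMEOUT_S * 4)  # 120 s


class InMemoryShardStore:
    """Read model indexed by (model_id, weights_sha256, layer_range)."""

    def __init__(self) -> None:
        # index key -> {host: time of the most recent advertisement}
        self._ads: dict[tuple[str, str, tuple[int, int]], dict[str, datetime]] = {}

    def record(self, model_id, weights_sha256, layer_range, host, advertised_at):
        """Apply one advertisement event; a re-advertisement refreshes the timestamp."""
        key = (model_id, weights_sha256, layer_range)
        self._ads.setdefault(key, {})[host] = advertised_at

    def live(self, model_id, weights_sha256, now):
        """Return sorted (layer_range, host) pairs whose advertisement has not expired."""
        found = []
        for (mid, sha, layer_range), hosts in self._ads.items():
            if mid != model_id or sha != weights_sha256:
                continue
            for host, seen in hosts.items():
                if now - seen <= ADVERTISEMENT_TTL:
                    found.append((layer_range, host))
        return sorted(found)
```

A host that stops heartbeating simply drops out of `live()` once the TTL passes, which matches the "existing advertisements expire normally" behaviour for opted-out nodes.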
5.2 Pipeline construction
ShardRouter.pick:
- Filter advertisements to those matching `model_id` and `weights_sha256`.
- Greedy cover: starting from layer 0, pick the shard with the lowest queue depth that includes the next uncovered layer; advance the cursor; repeat. Returns failure if any layer is uncoverable.
- Prefer shards on the same LAN if possible (LAN advertisements have a lower "hop weight" metric attached by Discovery).
- Avoid sharding to self as the first shard: embedding + sampling should stay on the orchestrator.
Constructed pipelines are not persisted; they're per-call.
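The greedy cover step can be sketched as follows. This is a simplified illustration, not the router itself: `Ad` is a hypothetical stand-in for an advertisement, the hop-weight and self-avoidance preferences are left out, and the tie-break (prefer the shard that reaches furthest) is an assumption. How overlapping ranges are executed when a lower-queue overlapping shard wins is not specified by this section.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Ad:
    """Hypothetical, trimmed-down advertisement used only for this sketch."""
    shard_id: str
    layer_range: tuple[int, int]  # inclusive
    queue_depth: int


def greedy_cover(ads: list[Ad], total_layers: int) -> list[Ad]:
    """Pick shards in order so that layers 0..total_layers-1 are all covered."""
    chosen: list[Ad] = []
    cursor = 0  # next uncovered layer
    while cursor < total_layers:
        candidates = [a for a in ads if a.layer_range[0] <= cursor <= a.layer_range[1]]
        if not candidates:
            raise LookupError(f"shard_unavailable: no shard covers layer {cursor}")
        # Lowest queue depth first; on a tie, the shard reaching the highest layer.
        best = min(candidates, key=lambda a: (a.queue_depth, -a.layer_range[1]))
        chosen.append(best)
        cursor = best.layer_range[1] + 1  # strictly increases, so the loop terminates
    return chosen
```

With three equally loaded shards covering 0-7, 4-11, and 8-15 on a 16-layer model, the sketch selects 0-7 and then 8-15, leaving the overlapping 4-11 shard unused.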
5.3 Forward pass
Per token:
[orchestrator] embedding → [shard 0] layers 0..7 → [shard 1] layers 8..15 → ... → [orchestrator] sample
Activations flow as fp16 tensors by default (configurable to fp32 for debugging). Each hop is a WebSocket binary frame stream (see X08). The orchestrator interleaves token-N and token-N+1: as soon as shard 0 finishes token N, the orchestrator pushes token N+1's embedding into shard 0 while shard 1 is still processing N. At steady state, this pipeline parallelism brings the interval between emitted tokens close to the latency of the slowest shard, rather than the sum of all shard latencies.
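The effect of interleaving can be estimated with simple arithmetic. The sketch below is an idealised model, assuming fixed per-shard latencies (hop cost folded in) and that each token's embedding is ready as soon as shard 0 is free, as the paragraph above describes; the function names and the example latencies are illustrative only.

```python
def sequential_total_ms(stage_ms: list[float], tokens: int) -> float:
    """No interleaving: every token pays the full chain of shard latencies."""
    return tokens * sum(stage_ms)


def pipelined_total_ms(stage_ms: list[float], tokens: int) -> float:
    """Interleaved: the first token fills the pipeline, then one token
    completes per bottleneck interval (the slowest shard's latency)."""
    if tokens <= 0:
        return 0.0
    return sum(stage_ms) + (tokens - 1) * max(stage_ms)


def steady_state_interval_ms(stage_ms: list[float]) -> float:
    """Time between consecutive tokens once the pipeline is full."""
    return max(stage_ms)
```

For three shards taking 40, 55, and 35 ms, ten tokens cost 1300 ms without interleaving and 625 ms with it, and the steady-state interval is the 55 ms of the slowest shard.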
5.4 Failure handling
If a shard's stream errors or stalls past DISTRIBUTED_SHARD_HEALTH_TIMEOUT_S:
- The orchestrator emits a `shard_status` frame with `status:"degraded"`.
- Calls `router.repick(pipeline, exclude={failed_shard_id})`.
- If repick succeeds, opens a fresh stream to the replacement and emits a `shard_failover` frame. In-flight tokens are restarted (no mid-token recovery).
- If repick fails and `params.fallback_to_local`, the orchestrator silently restarts the call as a local-only `llm.chat@2.0` against any local model that matches.
- Else: emit a `pipeline_aborted` frame and return `shard_unavailable`.
DISTRIBUTED_FALLBACK_TO_LOCAL_AFTER_FAILURES (default 2): if failover happens that many times in one call, give up and fall back to local.
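The decision made after each shard failure can be sketched as a small pure function. `Action` and `on_shard_failure` are hypothetical names for this example. One case is not spelled out above and is an assumption here: when the failover budget is exhausted but `fallback_to_local` is off, the sketch aborts.

```python
from enum import Enum


class Action(Enum):
    FAILOVER = "shard_failover"           # open a stream to the replacement shard
    FALLBACK_LOCAL = "fallback_to_local"  # restart the call as local-only
    ABORT = "pipeline_aborted"            # emit pipeline_aborted, return shard_unavailable


# Mirrors DISTRIBUTED_FALLBACK_TO_LOCAL_AFTER_FAILURES (default 2).
FALLBACK_TO_LOCAL_AFTER_FAILURES = 2


def on_shard_failure(
    repick_succeeded: bool,
    fallback_to_local: bool,
    failovers_so_far: int,
    max_failovers: int = FALLBACK_TO_LOCAL_AFTER_FAILURES,
) -> Action:
    """Decide what the orchestrator does next after a shard errors or stalls."""
    budget_left = failovers_so_far < max_failovers
    if budget_left and repick_succeeded:
        return Action.FAILOVER
    if fallback_to_local:
        return Action.FALLBACK_LOCAL
    # Assumption: with no local fallback allowed, the call is aborted.
    return Action.ABORT
```

Keeping the decision free of I/O makes each branch of the failure policy easy to unit-test in isolation.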
5.5 Streaming and backpressure
Tensor-chunk streams use a window of TENSOR_FLOW_CONTROL_WINDOW chunks (default 16). Each chunk is at most TENSOR_CHUNK_BYTES (1 MB). If the downstream shard's send queue fills, the orchestrator pauses upstream until ACKs drain. See X08 §4.
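Windowed flow control of this kind can be sketched with a semaphore that holds one slot per unacknowledged chunk. This is an illustration only, not the X08 protocol: `split_chunks` and `send_windowed` are hypothetical names, `send` is assumed to be a coroutine that returns once its chunk has been ACKed, and "1 MB" is assumed to mean 1 MiB.

```python
import asyncio
from typing import Awaitable, Callable

TENSOR_CHUNK_BYTES = 1 << 20      # assumption: "1 MB" interpreted as 1 MiB
TENSOR_FLOW_CONTROL_WINDOW = 16   # default window from the text above


def split_chunks(payload: bytes, chunk_bytes: int = TENSOR_CHUNK_BYTES) -> list[bytes]:
    """Cut a serialized tensor into chunks of at most chunk_bytes each."""
    return [payload[i:i + chunk_bytes] for i in range(0, len(payload), chunk_bytes)]


async def send_windowed(
    chunks: list[bytes],
    send: Callable[[bytes], Awaitable[None]],
    window: int = TENSOR_FLOW_CONTROL_WINDOW,
) -> None:
    """Send chunks while keeping at most `window` of them unacknowledged."""
    slots = asyncio.Semaphore(window)

    async def send_one(chunk: bytes) -> None:
        try:
            await send(chunk)   # returns when the receiver has ACKed the chunk
        finally:
            slots.release()     # the ACK (or an error) frees a window slot

    tasks = []
    for chunk in chunks:
        await slots.acquire()   # blocks here while the window is full
        tasks.append(asyncio.create_task(send_one(chunk)))
    await asyncio.gather(*tasks)
```

The pause in `slots.acquire()` is where backpressure takes effect: the producer stops reading further chunks until earlier ones are acknowledged.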
5.6 Concurrency
A shard's max_concurrent_streams is honoured strictly. If the orchestrator's call would exceed it, the orchestrator picks a different shard (via router.repick) rather than queuing.
A shard's GPU memory budget is enforced by the shard host's own resource manager; a shard exceeding its budget gets evicted and re-advertises with vram_required_mb updated next time it loads.
5.7 Models supported
Phase 3 launches with two backend choices:
| Backend | Models | Notes |
|---|---|---|
| small_model_layered | Qwen2.5-{1.5B,3B,7B}, Llama-3.2-{1B,3B}, MiniCPM-3 | Custom HearthNet impl; PyTorch model surgery to expose per-layer forward |
| petals_like | (vendored from BigScience Petals) | Optional; only if user installs hearthnet[petals] extra |
The small_model_layered backend handles models up to roughly 7B parameters cleanly; beyond that the activation transport becomes the bottleneck.
5.8 Security boundary
A shard host receives activation tensors, which can leak information about the request (see 5.9) as well as training data residue. Treat activations as sensitive: do not log them, do not persist them, and do not retain them past the forward pass. Each call carries signed authentication; the caller's identity is recorded in metrics but not in logs of tensor contents.
A malicious shard could degrade outputs subtly. Detection is hard in general; the orchestrator does basic sanity checks (norm bounds, NaN/Inf detection) but cannot detect adversarial corruption. Communities should only enable distributed inference among members they trust as much as they trust the LLM service operator.
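The basic sanity checks mentioned above (NaN/Inf detection and a norm bound) might look like the following sketch. It works on a flat list of floats for clarity, whereas a real implementation would operate on tensors; `check_activations` and the `max_l2_norm` threshold are hypothetical and would need tuning per model.

```python
import math
from typing import Iterable, Optional


def check_activations(values: Iterable[float], max_l2_norm: float) -> Optional[str]:
    """Return None if the activations look sane, else a short reason string."""
    sum_of_squares = 0.0
    for v in values:
        if not math.isfinite(v):      # catches NaN, +Inf and -Inf
            return "non_finite"
        sum_of_squares += v * v
    if math.sqrt(sum_of_squares) > max_l2_norm:
        return "norm_out_of_bounds"
    return None
```

As the paragraph above notes, checks like these only catch gross corruption; a shard that perturbs activations within normal bounds passes unnoticed.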
5.9 Privacy threat surface
A shard sees the activations of every request routed through it. With effort, a shard host can reconstruct approximate input text (especially the prompt) from activations of intermediate layers. This is a real concern, not a theoretical one.
Mitigations (none perfect):
- Restrict participation to members at trust level `trusted` or higher.
- Mix activations with a small amount of noise at the orchestrator (research; not yet implemented).
- Use this module only for queries the requester would already trust the community with.
5.10 Observability
Per call, emit:
- distributed_inference.pipeline_construct_ms
- distributed_inference.first_token_ms
- distributed_inference.tokens_per_second
- distributed_inference.shard_latency_ms{shard_id} histograms
- distributed_inference.failovers_total
- distributed_inference.fallback_to_local_total
6. Errors
| Code | Cause |
|---|---|
| experimental_disabled | policy.research.enable=false |
| shard_unavailable | No shard covers a required layer range, or all candidates are at max concurrency |
| pipeline_stalled | No progress within timeout |
| weights_mismatch | A shard's advertised weights_sha256 differs from requested |
| bad_request | Unknown model, malformed pipeline params |
7. Configuration
[research.distributed_inference]
enabled = false
backend = "small_model_layered"
max_shards_per_request = 16
shard_health_timeout_seconds = 30
fallback_to_local = true
activation_dtype = "fp16" # "fp16" | "fp32"
allow_self_as_shard = false
max_concurrent_pipelines = 4
[research.distributed_inference.host]
serve_shards = false
shard_eviction_idle_seconds = 600
shard_max_vram_mb = 20000
8. Tests
8.1 Unit
- ShardRouter cover algorithm: 16-layer model + 3 advertised shards (0-7, 4-11, 8-15) → picks {0-7, 8-15}, ignores overlap shard
- Sanity bounds on activations: NaN injection triggers `pipeline_stalled` (via failed health check on subsequent chunk)
- Pipeline construction with weights mismatch → `weights_mismatch`
8.2 Integration (LAN)
- Two-node setup, 1.5B model split as 0-7 / 8-15; happy-path tokens/sec measured; baseline single-machine inference also measured; ratio reported (expect 0.4–0.6× local)
- Shard host kill mid-stream; failover to a third node; total call still succeeds; latency penalty bounded
- Concurrent two-pipeline test on three nodes; no deadlock; per-call latency degrades < 2×
8.3 Adversarial
- Malicious shard returns garbage activations: orchestrator's NaN/Inf detector catches the call; metric `distributed_inference.shard_corruption_detected_total` increments; pipeline aborts
- Slowloris shard (returns one chunk per second): `pipeline_stalled` after timeout; failover succeeds
8.4 Performance budget
- 3B model, 2-shard pipeline, RTX 5090 + RTX 4090: ≥ 8 tokens/sec sustained
- First-token latency ≤ 800ms
- Construction-to-first-byte ≤ 500ms
- Tensor-chunk overhead per hop ≤ 25ms p95
9. Cross-references
- Capability spec: CAPABILITY_CONTRACT_v3 §4.1–4.3
- Tensor transport: X08
- Base LLM service: M04
- Trust levels: M01
10. Open research questions
- Activation privacy. Can we add fast-to-compute noise that preserves inference accuracy but defeats activation-inversion attacks? Cite the Geiping et al. inversion paper as the threat baseline.
- Mid-token recovery. Currently a shard failure restarts the in-flight token. Could we use micro-checkpointing (every K tokens) to recover without a restart? Latency cost?
- Heterogeneous shards. Could a 4090 host the early layers (heavier compute per layer) and a 3060 the later ones, while remaining balanced? Probably yes; automated load assignment is the research question.
- Async pipeline. Currently the orchestrator interleaves at the token level. Could it interleave at the layer level (one shard processes token N+2 while another processes N+1) for higher throughput? In theory yes; coordination protocol unclear.
- Mixed local + distributed. When the orchestrator could host some layers itself (it has a GPU), should it? When? Currently `allow_self_as_shard=false`. A heuristic that considers compute headroom would be richer.
- Adversarial detection. Beyond NaN/Inf, can we cheaply detect activation tampering by comparing to a small "shadow inference" on a tiny model? Cost vs. benefit unclear.
- Pricing / incentive. A shard host pays in GPU time. A community-internal token-economy is explicitly out of scope (00-OVERVIEW §8). But a reputational signal ("this anchor served 4000 shard-tokens this week") could be helpful. Should it be a metric?
- Backend strategy. `petals_like` vs `small_model_layered`: which delivers better quality / latency / robustness for our target models? An honest A/B is the answer.