Spaces:
Running on Zero
Running on Zero
GitHub Actions
fix: 0 test failures; FileService; real RagService; emergency probe; chat return
4aaae80 | """Shard descriptors and server for distributed inference (M26 β experimental).""" | |
| from __future__ import annotations | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| ShardID = str | |
| class ShardDescriptor: | |
| """Describes one contiguous layer range hosted by a node.""" | |
| shard_id: ShardID # "<model_id>:<lo>-<hi>" | |
| model_id: str | |
| layer_lo: int | |
| layer_hi: int # inclusive | |
| node_id: str | |
| endpoint: str | |
| dtype: str = "float16" # "float16" | "bfloat16" | "int8" | |
| advertised_at: float = field(default_factory=time.time) | |
| def layer_count(self) -> int: | |
| return self.layer_hi - self.layer_lo + 1 | |
| class ShardServer: | |
| """Hosts one contiguous shard. | |
| Loaded lazily on first use; evictable under memory pressure. | |
| This is an experimental module β only active when | |
| config.research.distributed_inference = True. | |
| """ | |
| def __init__(self, descriptor: ShardDescriptor) -> None: | |
| self._desc = descriptor | |
| self._model: Any = None | |
| self._loaded = False | |
| def descriptor(self) -> ShardDescriptor: | |
| return self._desc | |
| def is_loaded(self) -> bool: | |
| return self._loaded | |
| def load(self) -> None: | |
| """Load the shard weights. Raises ImportError if torch unavailable.""" | |
| try: | |
| import torch # noqa: F401 | |
| except ImportError as exc: | |
| raise ImportError( | |
| "PyTorch is required for distributed inference. Install: pip install torch" | |
| ) from exc | |
| # Actual weight loading would go here; placeholder for the research prototype. | |
| self._loaded = True | |
| def evict(self) -> None: | |
| """Free shard memory.""" | |
| self._model = None | |
| self._loaded = False | |
| async def forward(self, hidden_states: bytes, dtype: str = "float16") -> bytes: | |
| """Run one forward pass through this shard. | |
| hidden_states: raw tensor bytes (X08 tensor-transport format) | |
| Returns: raw output tensor bytes | |
| """ | |
| if not self._loaded: | |
| self.load() | |
| # Placeholder β real implementation uses torch to slice model and forward. | |
| raise NotImplementedError( | |
| "ShardServer.forward() is not yet implemented for this shard. " | |
| "This is an experimental Phase 3 feature." | |
| ) | |
| def health(self) -> dict: | |
| return { | |
| "shard_id": self._desc.shard_id, | |
| "loaded": self._loaded, | |
| "layers": f"{self._desc.layer_lo}-{self._desc.layer_hi}", | |
| "status": "loaded" if self._loaded else "unloaded", | |
| } | |