Spaces:
Running on Zero
Running on Zero
GitHub Actions
fix: 0 test failures; FileService; real RagService; emergency probe; chat return
4aaae80 | """WebSocket upgrade for bidirectional streaming (X06).""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| from collections.abc import AsyncIterator | |
| from dataclasses import dataclass | |
| from typing import Any | |
| logger = logging.getLogger(__name__) | |
| # Optional websockets import (client-side only) | |
| try: | |
| import websockets # type: ignore[import] | |
| HAS_WEBSOCKETS = True | |
| except ImportError: | |
| websockets = None # type: ignore[assignment] | |
| HAS_WEBSOCKETS = False | |
| # Optional FastAPI/Starlette WebSocket import (server-side) | |
| try: | |
| from starlette.websockets import ( # type: ignore[import] | |
| WebSocket, | |
| WebSocketDisconnect, | |
| WebSocketState, | |
| ) | |
| HAS_STARLETTE_WS = True | |
| except ImportError: | |
| WebSocket = None # type: ignore[assignment] | |
| WebSocketDisconnect = None # type: ignore[assignment] | |
| WebSocketState = None # type: ignore[assignment] | |
| HAS_STARLETTE_WS = False | |
| # ββ Dataclasses βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class WsClientFrame: | |
| """A parsed frame received from a WebSocket client.""" | |
| type: str # "ack" | "tool_result" | "cancel" | |
| data: dict | |
| # ββ Server side βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class WebSocketSession: | |
| """Wraps a Starlette/FastAPI WebSocket from the server's perspective.""" | |
| def __init__(self, ws: Any, keypair: Any = None) -> None: | |
| if ws is None: | |
| raise ValueError("ws must be a non-None WebSocket object") | |
| self._ws = ws | |
| self._keypair = keypair | |
| self.session_id: str = str(uuid.uuid4()) | |
| self.connected_at: float = time.time() | |
| self._seq: int = 0 | |
| async def send_event( | |
| self, | |
| event: str, | |
| data: dict, | |
| seq: int | None = None, | |
| ) -> None: | |
| """Send a JSON frame to the client.""" | |
| if seq is None: | |
| self._seq += 1 | |
| seq = self._seq | |
| frame = json.dumps({"event": event, "data": data, "seq": seq}) | |
| try: | |
| await self._ws.send_text(frame) | |
| except Exception as exc: | |
| logger.debug("WebSocketSession.send_event error: %s", exc) | |
| raise | |
| async def receive_frame(self) -> WsClientFrame | None: | |
| """Receive and parse one inbound JSON frame. Returns None on disconnect.""" | |
| try: | |
| raw = await self._ws.receive_text() | |
| except Exception: | |
| return None | |
| try: | |
| obj = json.loads(raw) | |
| except json.JSONDecodeError: | |
| logger.warning("WebSocketSession: malformed JSON from client") | |
| return None | |
| frame_type = obj.get("type", "") | |
| # Strip type key, rest is data | |
| data = {k: v for k, v in obj.items() if k != "type"} | |
| return WsClientFrame(type=frame_type, data=data) | |
| async def send_ack(self, up_to: int) -> None: | |
| """Send a server-to-client ACK frame.""" | |
| frame = json.dumps({"event": "ack", "data": {"up_to": up_to}, "seq": self._seq}) | |
| try: | |
| await self._ws.send_text(frame) | |
| except Exception as exc: | |
| logger.debug("WebSocketSession.send_ack error: %s", exc) | |
| async def close(self, code: int = 1000) -> None: | |
| """Close the WebSocket with the given close code.""" | |
| try: | |
| await self._ws.close(code=code) | |
| except Exception: | |
| pass | |
| # ββ Client side βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class WebSocketClient: | |
| """Client-side WebSocket wrapper. Requires the `websockets` library.""" | |
| def __init__(self, base_url: str, keypair: Any = None) -> None: | |
| if not HAS_WEBSOCKETS: | |
| raise ImportError("Install websockets: pip install websockets") | |
| # Convert http(s) to ws(s) | |
| self._base_url = ( | |
| base_url.rstrip("/").replace("https://", "wss://").replace("http://", "ws://") | |
| ) | |
| self._keypair = keypair | |
| self._conn: Any = None # websockets.WebSocketClientProtocol | |
| async def connect(self, path: str) -> None: | |
| """Establish a WebSocket connection to *path* on the server.""" | |
| if not HAS_WEBSOCKETS: | |
| raise ImportError("Install websockets: pip install websockets") | |
| url = f"{self._base_url}/{path.lstrip('/')}" | |
| self._conn = await websockets.connect(url) # type: ignore[union-attr] | |
| async def stream(self, event_iterator: Any) -> AsyncIterator[dict]: | |
| """ | |
| Send frames from *event_iterator* to the server and yield parsed | |
| server frames until the connection closes. | |
| """ | |
| if self._conn is None: | |
| raise RuntimeError("Not connected. Call connect() first.") | |
| async def _sender() -> None: | |
| try: | |
| async for item in event_iterator: | |
| await self._conn.send(json.dumps(item)) | |
| except Exception as exc: | |
| logger.debug("WebSocketClient._sender error: %s", exc) | |
| sender_task = asyncio.create_task(_sender()) | |
| try: | |
| async for raw in self._conn: | |
| try: | |
| yield json.loads(raw) | |
| except json.JSONDecodeError: | |
| logger.warning("WebSocketClient: malformed JSON from server") | |
| finally: | |
| sender_task.cancel() | |
| try: | |
| await sender_task | |
| except asyncio.CancelledError: | |
| pass | |
| async def send_tool_result(self, tool_call_id: str, body: dict) -> None: | |
| """Send a tool result frame mid-stream.""" | |
| if self._conn is None: | |
| raise RuntimeError("Not connected.") | |
| frame = json.dumps({"type": "tool_result", "tool_call_id": tool_call_id, "body": body}) | |
| await self._conn.send(frame) | |
| async def cancel(self) -> None: | |
| """Send a cancel frame to the server.""" | |
| if self._conn is None: | |
| return | |
| try: | |
| await self._conn.send(json.dumps({"type": "cancel"})) | |
| except Exception: | |
| pass | |
| async def close(self) -> None: | |
| """Close the WebSocket connection gracefully.""" | |
| if self._conn is not None: | |
| try: | |
| await self._conn.close() | |
| except Exception: | |
| pass | |
| finally: | |
| self._conn = None | |
| # ββ PubSub fanout βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class WebsocketPubSub: | |
| """ | |
| In-process publish/subscribe for WebSocket sessions. | |
| subscribe/unsubscribe are synchronous; publish is async and fan-outs to all | |
| sessions registered for the topic. | |
| """ | |
| def __init__(self) -> None: | |
| self._subscriptions: dict[str, set[WebSocketSession]] = {} | |
| self._lock = asyncio.Lock() | |
| def subscribe(self, topic: str, ws_session: WebSocketSession) -> None: | |
| """Register *ws_session* to receive messages on *topic*.""" | |
| if topic not in self._subscriptions: | |
| self._subscriptions[topic] = set() | |
| self._subscriptions[topic].add(ws_session) | |
| def unsubscribe(self, topic: str, ws_session: WebSocketSession) -> None: | |
| """Remove *ws_session* from *topic*.""" | |
| if topic in self._subscriptions: | |
| self._subscriptions[topic].discard(ws_session) | |
| if not self._subscriptions[topic]: | |
| del self._subscriptions[topic] | |
| async def publish(self, topic: str, event: str, data: dict) -> int: | |
| """ | |
| Fan-out *event*/*data* to all sessions subscribed to *topic*. | |
| Returns the number of sessions that received the message. | |
| """ | |
| async with self._lock: | |
| sessions = list(self._subscriptions.get(topic, [])) | |
| dead: list[WebSocketSession] = [] | |
| delivered = 0 | |
| for session in sessions: | |
| try: | |
| await session.send_event(event, data) | |
| delivered += 1 | |
| except Exception: | |
| dead.append(session) | |
| # Clean up disconnected sessions | |
| if dead: | |
| async with self._lock: | |
| for session in dead: | |
| self.unsubscribe(topic, session) | |
| return delivered | |