File size: 8,786 Bytes
4cd8837
4aaae80
4cd8837
 
 
d6ca3a2
4cd8837
 
 
 
4aaae80
 
 
4cd8837
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
3f78ea8
 
 
 
4cd8837
4aaae80
 
 
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6ca3a2
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6ca3a2
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
d6ca3a2
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""WebSocket upgrade for bidirectional streaming (X06)."""

from __future__ import annotations

import asyncio
import contextlib
import json
import logging
import time
import uuid
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)

# Optional websockets import (client-side only)
try:
    import websockets  # type: ignore[import]

    HAS_WEBSOCKETS = True
except ImportError:
    websockets = None  # type: ignore[assignment]
    HAS_WEBSOCKETS = False

# Optional FastAPI/Starlette WebSocket import (server-side)
WebSocket: Any
WebSocketDisconnect: Any
WebSocketState: Any

try:
    from starlette.websockets import (  # type: ignore[import]
        WebSocket,
        WebSocketDisconnect,
        WebSocketState,
    )

    HAS_STARLETTE_WS = True
except ImportError:
    WebSocket = None  # type: ignore[assignment]
    WebSocketDisconnect = None  # type: ignore[assignment]
    WebSocketState = None  # type: ignore[assignment]
    HAS_STARLETTE_WS = False


# ── Dataclasses ───────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class WsClientFrame:
    """Immutable record of one JSON frame received from a WebSocket client.

    ``WebSocketSession.receive_frame`` builds these: ``type`` is the value of the
    frame's ``"type"`` key and ``data`` holds every other top-level key.
    """

    # Frame discriminator; the protocol uses "ack", "tool_result" and "cancel".
    type: str
    # All top-level keys of the frame except "type".
    data: dict[str, Any]


# ── Server side ───────────────────────────────────────────────────────────────


class WebSocketSession:
    """Wraps a Starlette/FastAPI WebSocket from the server's perspective.

    Outbound frames are JSON objects shaped ``{"event": ..., "data": ..., "seq": ...}``.
    Inbound frames must be JSON objects and are parsed into ``WsClientFrame``.
    """

    def __init__(self, ws: Any, keypair: Any = None) -> None:
        """Bind the session to *ws* and assign it a fresh identity.

        Args:
            ws: Connection object; the session awaits its ``send_text``,
                ``receive_text`` and ``close`` methods.
            keypair: Stored as ``_keypair``; no method of this class reads it.

        Raises:
            ValueError: If *ws* is ``None``.
        """
        if ws is None:
            raise ValueError("ws must be a non-None WebSocket object")
        self._ws = ws
        self._keypair = keypair
        self.session_id: str = str(uuid.uuid4())
        self.connected_at: float = time.time()
        # Last sequence number handed out by send_event().
        self._seq: int = 0

    async def send_event(
        self,
        event: str,
        data: dict,
        seq: int | None = None,
    ) -> None:
        """Send a JSON frame to the client.

        Args:
            event: Event name placed under the ``"event"`` key.
            data: Payload placed under the ``"data"`` key; must be JSON-serializable.
            seq: Explicit sequence number. When ``None`` the session counter is
                incremented and used; an explicit value leaves the counter untouched.

        Raises:
            Exception: Whatever ``ws.send_text`` raises is logged at debug level and
                re-raised so callers can detect a failed delivery.
        """
        if seq is None:
            self._seq += 1
            seq = self._seq
        frame = json.dumps({"event": event, "data": data, "seq": seq})
        try:
            await self._ws.send_text(frame)
        except Exception as exc:
            logger.debug("WebSocketSession.send_event error: %s", exc)
            raise

    async def receive_frame(self) -> WsClientFrame | None:
        """Receive and parse one inbound JSON frame.

        Returns:
            The parsed frame, or ``None`` when the receive fails (treated as a
            disconnect), the text is not valid JSON, or the JSON value is not an
            object. A missing ``"type"`` key yields ``type == ""``.
        """
        try:
            raw = await self._ws.receive_text()
        except Exception as exc:
            # Any receive failure is reported to the caller as a disconnect.
            logger.debug("WebSocketSession.receive_frame error: %s", exc)
            return None
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("WebSocketSession: malformed JSON from client")
            return None
        # Valid JSON is not necessarily an object ("[1]", "42", '"x"'); only a
        # dict can carry a "type" key, so reject anything else instead of crashing.
        if not isinstance(obj, dict):
            logger.warning("WebSocketSession: non-object JSON frame from client")
            return None
        # Strip the type key; the rest is the frame's data.
        data = dict(obj)
        frame_type = data.pop("type", "")
        return WsClientFrame(type=frame_type, data=data)

    async def send_ack(self, up_to: int) -> None:
        """Send a server-to-client ACK frame.

        The frame reuses the current sequence number without advancing it. Send
        failures are logged at debug level and not propagated.

        Args:
            up_to: Value placed under ``data["up_to"]``.
        """
        frame = json.dumps({"event": "ack", "data": {"up_to": up_to}, "seq": self._seq})
        try:
            await self._ws.send_text(frame)
        except Exception as exc:
            logger.debug("WebSocketSession.send_ack error: %s", exc)

    async def close(self, code: int = 1000) -> None:
        """Close the WebSocket with the given close code, ignoring any error."""
        with contextlib.suppress(Exception):
            await self._ws.close(code=code)


# ── Client side ───────────────────────────────────────────────────────────────


class WebSocketClient:
    """Client-side WebSocket wrapper. Requires the `websockets` library."""

    def __init__(self, base_url: str, keypair: Any = None) -> None:
        """Prepare a client for *base_url*; no connection is opened yet.

        Args:
            base_url: Server URL. A leading ``http://`` / ``https://`` scheme is
                converted to ``ws://`` / ``wss://`` and trailing slashes are removed.
            keypair: Stored as ``_keypair``; no method of this class reads it.

        Raises:
            ImportError: If the ``websockets`` package is not installed.
        """
        if not HAS_WEBSOCKETS:
            raise ImportError("Install websockets: pip install websockets")
        self._base_url = self._to_ws_url(base_url)
        self._keypair = keypair
        self._conn: Any = None  # websockets.WebSocketClientProtocol

    @staticmethod
    def _to_ws_url(base_url: str) -> str:
        """Return *base_url* without trailing slashes and with its scheme mapped to ws(s).

        Only the leading scheme is rewritten, so an ``http://`` that appears later
        in the URL (for example inside a query parameter) is left intact.
        """
        trimmed = base_url.rstrip("/")
        scheme, sep, rest = trimmed.partition("://")
        ws_scheme = {"http": "ws", "https": "wss"}.get(scheme.lower())
        if not sep or ws_scheme is None:
            return trimmed
        return f"{ws_scheme}://{rest}"

    async def connect(self, path: str) -> None:
        """Establish a WebSocket connection to *path* on the server.

        Any connection already held by this client is closed first so its socket
        is not leaked.

        Raises:
            ImportError: If the ``websockets`` package is not installed.
        """
        if not HAS_WEBSOCKETS:
            raise ImportError("Install websockets: pip install websockets")
        if self._conn is not None:
            await self.close()
        url = f"{self._base_url}/{path.lstrip('/')}"
        self._conn = await websockets.connect(url)  # type: ignore[union-attr]

    async def stream(self, event_iterator: Any) -> AsyncIterator[dict]:
        """
        Send frames from *event_iterator* to the server and yield parsed
        server frames until the connection closes.

        Items of *event_iterator* (an async iterable) are JSON-encoded and sent by
        a background task; a failure in that task is logged at debug level and
        stops sending without interrupting the receive side. Server frames that
        are not valid JSON are logged and skipped.

        Raises:
            RuntimeError: If ``connect()`` has not been called.
        """
        if self._conn is None:
            raise RuntimeError("Not connected. Call connect() first.")
        # Bind the connection once so a concurrent close() cannot swap it out
        # from under the sender task.
        conn = self._conn

        async def _sender() -> None:
            try:
                async for item in event_iterator:
                    await conn.send(json.dumps(item))
            except Exception as exc:
                logger.debug("WebSocketClient._sender error: %s", exc)

        sender_task = asyncio.create_task(_sender())
        try:
            async for raw in conn:
                # Parse before yielding: keeping `yield` outside the try means an
                # exception raised by the consumer is never mistaken for bad JSON.
                try:
                    message = json.loads(raw)
                except json.JSONDecodeError:
                    logger.warning("WebSocketClient: malformed JSON from server")
                    continue
                yield message
        finally:
            # Stop sending once the receive side ends or the consumer stops iterating.
            sender_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await sender_task

    async def send_tool_result(self, tool_call_id: str, body: dict) -> None:
        """Send a tool result frame mid-stream.

        Raises:
            RuntimeError: If the client is not connected.
        """
        if self._conn is None:
            raise RuntimeError("Not connected.")
        frame = json.dumps({"type": "tool_result", "tool_call_id": tool_call_id, "body": body})
        await self._conn.send(frame)

    async def cancel(self) -> None:
        """Send a cancel frame to the server; a no-op when not connected.

        Send errors are suppressed (best effort).
        """
        if self._conn is None:
            return
        with contextlib.suppress(Exception):
            await self._conn.send(json.dumps({"type": "cancel"}))

    async def close(self) -> None:
        """Close the WebSocket connection gracefully.

        Errors raised while closing are logged at debug level and not propagated;
        the client always ends up disconnected.
        """
        if self._conn is None:
            return
        try:
            await self._conn.close()
        except Exception as exc:
            logger.debug("WebSocketClient.close error: %s", exc)
        finally:
            self._conn = None


# ── PubSub fanout ─────────────────────────────────────────────────────────────


class WebsocketPubSub:
    """
    In-process publish/subscribe for WebSocket sessions.

    subscribe/unsubscribe are synchronous; publish is async and fans out to all
    sessions registered for the topic.
    """

    def __init__(self) -> None:
        """Create an empty registry mapping each topic to its subscribed sessions."""
        self._subscriptions: dict[str, set[WebSocketSession]] = {}
        self._lock = asyncio.Lock()

    def subscribe(self, topic: str, ws_session: WebSocketSession) -> None:
        """Register *ws_session* to receive messages on *topic*."""
        self._subscriptions.setdefault(topic, set()).add(ws_session)

    def unsubscribe(self, topic: str, ws_session: WebSocketSession) -> None:
        """Remove *ws_session* from *topic*; unknown topics or sessions are ignored."""
        subscribers = self._subscriptions.get(topic)
        if subscribers is None:
            return
        subscribers.discard(ws_session)
        # Drop empty topics so the registry does not grow without bound.
        if not subscribers:
            del self._subscriptions[topic]

    async def publish(self, topic: str, event: str, data: dict) -> int:
        """
        Fan-out *event*/*data* to all sessions subscribed to *topic*.

        Sessions are sent to one after another. A session whose ``send_event``
        raises is treated as disconnected and unsubscribed from *topic*. A payload
        that cannot be JSON-encoded is a caller error, not a dead connection: it
        is logged and nothing is sent, and no session is unsubscribed.

        Returns the number of sessions whose send completed without error
        (0 when there are no subscribers or the payload is not serializable).
        """
        async with self._lock:
            # Snapshot so subscribe/unsubscribe during the sends cannot disturb iteration.
            sessions = list(self._subscriptions.get(topic, ()))
        if not sessions:
            return 0

        # Validate once up front. Otherwise every session's send would fail on the
        # same bad payload and all healthy subscribers would be evicted below.
        try:
            json.dumps({"event": event, "data": data})
        except (TypeError, ValueError) as exc:
            logger.error(
                "WebsocketPubSub.publish: payload for topic %r is not JSON-serializable: %s",
                topic,
                exc,
            )
            return 0

        dead: list[WebSocketSession] = []
        delivered = 0
        for session in sessions:
            try:
                await session.send_event(event, data)
            except Exception:
                dead.append(session)
            else:
                delivered += 1

        # Clean up disconnected sessions
        if dead:
            async with self._lock:
                for session in dead:
                    self.unsubscribe(topic, session)

        return delivered