Spaces:
Running on Zero
Running on Zero
File size: 3,709 Bytes
"""SSE writer/reader helpers."""
from __future__ import annotations
import asyncio
import contextlib
import json
from collections.abc import AsyncIterator
def encode_sse_frame(data: dict, event: str | None = None) -> str:
    """Render *data* as a single SSE frame string.

    The payload is serialised as compact JSON (``separators=(',', ':')``).
    ``json.dumps`` escapes embedded newlines, so the payload stays on one
    ``data:`` line. When *event* is truthy, an ``event:`` line is emitted
    first. The frame always ends with an empty line, i.e. ``\\n\\n``.

    Args:
        data: JSON-serialisable mapping to send as the frame payload.
        event: Optional event tag; omitted when ``None`` or empty.

    Returns:
        The frame text, terminated by ``"\\n\\n"``.

    Raises:
        TypeError: if *data* is not JSON-serialisable (from ``json.dumps``).
    """
    # NOTE(review): *event* is not escaped; a tag containing a newline would
    # split the frame. Confirm callers only pass trusted, single-line tags.
    encoded = json.dumps(data, separators=(',', ':'))
    header = [f"event: {event}"] if event else []
    # The two trailing empty strings produce the "\n\n" frame terminator on join.
    return "\n".join([*header, f"data: {encoded}", "", ""])
async def parse_sse_stream(lines: AsyncIterator[str]) -> AsyncIterator[dict]:
    """Yield the decoded JSON payload of every ``data:`` line in *lines*.

    Per the SSE wire format, a single space after the field colon is
    optional, so both ``data: {...}`` and ``data:{...}`` are accepted (the
    same prefix ``SseReader`` recognises). Payloads that are not valid JSON
    are skipped. All other lines (``event:``, ``:`` comments, blanks) are
    ignored.

    Args:
        lines: Async iterator of individual stream lines.

    Yields:
        Whatever ``json.loads`` returns for each valid payload (normally a dict).
    """
    async for line in lines:
        if not line.startswith("data:"):
            continue
        payload = line[5:]
        # Drop exactly one optional leading space, as the SSE spec prescribes.
        if payload.startswith(" "):
            payload = payload[1:]
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError:
            continue
        # Yield outside the try so exceptions thrown into the generator at
        # this point are never mistaken for a decode failure.
        yield decoded
# ---------------------------------------------------------------------------
# Frame — typed SSE frame (X01 §3.2)
# ---------------------------------------------------------------------------
class Frame:
    """One SSE frame: a JSON payload, an optional event tag, and its wire text.

    Spec: X01-transport §3.2 — wire format is ``data: <json>\\n\\n``
    with optional ``event: <tag>\\n`` prefix.

    Attributes are fixed by ``__slots__``:
    ``data`` (payload mapping), ``event`` (tag or ``None``) and ``raw``
    (the encoded frame string, computed once at construction).
    """

    __slots__ = ("data", "event", "raw")

    def __init__(self, data: dict, event: str | None = None) -> None:
        self.data, self.event = data, event
        # Pre-render the wire representation once, at construction time.
        self.raw = encode_sse_frame(self.data, self.event)

    def __repr__(self) -> str:
        # ``raw`` is derived from the other two fields, so it is left out here.
        return f"Frame(event={self.event!r}, data={self.data!r})"
# ---------------------------------------------------------------------------
# SseReader — parse an HTTP SSE response stream (X01 §3.2)
# ---------------------------------------------------------------------------
class SseReader:
    """Turn a line-oriented SSE response stream into ``Frame`` objects.

    Example with httpx::

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, ...) as resp:
                reader = SseReader(resp.aiter_lines())
                async for frame in reader:
                    handle(frame)

    Parsing rules:

    - Each ``data:`` line immediately yields one Frame; multi-line data
      fields are not joined.
    - The most recent ``event:`` tag is attached to that frame and then
      cleared. A blank line also clears a pending tag.
    - Any other line (``:`` comments, ``id:``, ``retry:`` ...) is ignored.
    - A payload that is not valid JSON is wrapped as ``{"raw": <payload>}``.
    """

    def __init__(self, lines: AsyncIterator[str]) -> None:
        # Source of individual stream lines, consumed by ``__aiter__``.
        self._lines = lines

    async def __aiter__(self) -> AsyncIterator[Frame]:
        pending_event: str | None = None
        async for line in self._lines:
            if line.startswith("data:"):
                payload = line[5:].strip()
                try:
                    parsed = json.loads(payload)
                except json.JSONDecodeError:
                    # Keep non-JSON payloads rather than dropping them.
                    parsed = {"raw": payload}
                yield Frame(parsed, pending_event)
                pending_event = None
            elif line.startswith("event:"):
                pending_event = line[6:].strip()
            elif line.strip() == "":
                # Blank line terminates the current event block.
                pending_event = None
class SseWriter:
    """Async iterator that yields SSE-formatted strings pushed via ``send``.

    Usage: call ``start()`` once, then ``send()`` for each frame and
    ``close()`` when finished. Iterating the writer yields the encoded
    frames in FIFO order. Whenever no frame arrives within
    ``keepalive_interval`` seconds, an SSE comment (``": keepalive\\n\\n"``)
    is yielded instead. Frames already queued when ``close()`` is called are
    still delivered before iteration ends.

    Args:
        keepalive_interval: Seconds to wait for a frame before emitting a
            keepalive comment. Defaults to 0.5.
    """

    _KEEPALIVE = ": keepalive\n\n"

    def __init__(self, keepalive_interval: float = 0.5) -> None:
        self._queue: asyncio.Queue | None = None
        self._done = False
        self.keepalive_interval = keepalive_interval

    async def start(self) -> None:
        """Create a fresh frame queue; ``send`` drops frames until this runs."""
        self._queue = asyncio.Queue()

    async def send(self, data: dict, event: str | None = None) -> None:
        """Encode *data* (with optional *event* tag) and enqueue the frame.

        Silently ignored if ``start()`` has not been called yet.
        """
        if self._queue is not None:
            await self._queue.put(encode_sse_frame(data, event))

    def close(self) -> None:
        """Mark the stream finished.

        This does not wake an iterator blocked on an empty queue. The iterator
        notices at its next timeout, up to ``keepalive_interval`` later.
        """
        self._done = True

    async def __aiter__(self):
        # Create the queue lazily so iterating before start() still waits for
        # frames instead of failing on a missing queue.
        if self._queue is None:
            await self.start()
        while True:
            queue = self._queue
            if self._done:
                # Flush frames sent before close() so trailing output is not lost.
                while not queue.empty():
                    yield queue.get_nowait()
                return
            try:
                frame = await asyncio.wait_for(
                    queue.get(), timeout=self.keepalive_interval
                )
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is what wait_for raises on every version
                # (an alias of the builtin TimeoutError on 3.11+).
                if not self._done:
                    yield self._KEEPALIVE
                continue
            yield frame
|