File size: 3,709 Bytes
31c93b1
4aaae80
31c93b1
 
 
d6ca3a2
31c93b1
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6ca3a2
31c93b1
 
 
38cba90
 
 
 
 
 
 
 
 
 
 
 
78cc96f
38cba90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
31c93b1
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""SSE writer/reader helpers."""

from __future__ import annotations

import asyncio
import contextlib
import json
from collections.abc import AsyncIterator


def encode_sse_frame(data: dict, event: str | None = None) -> str:
    """Serialize *data* into a single SSE frame string.

    The payload is compact JSON (no spaces after separators) on one
    ``data:`` line. When *event* is truthy, an ``event:`` line precedes it.
    The frame is terminated by an empty line, i.e. it ends in ``"\\n\\n"``.
    """
    payload = json.dumps(data, separators=(",", ":"))
    # An empty-string or None event produces no event line at all.
    header = f"event: {event}\n" if event else ""
    return f"{header}data: {payload}\n\n"


async def parse_sse_stream(lines: AsyncIterator[str]) -> AsyncIterator[dict]:
    """Yield the decoded JSON payload of every ``data:`` line in *lines*.

    In the SSE format the single space after ``data:`` is optional, so both
    ``data: {...}`` and ``data:{...}`` are accepted. Lines that are not
    ``data`` fields (``event:`` lines, ``:`` comments, blank separators) are
    skipped, and so are payloads that are not valid JSON. Each ``data:``
    line is decoded on its own; multi-line ``data`` fields are not joined.
    """
    async for line in lines:
        if not line.startswith("data:"):
            continue
        # json.loads ignores surrounding whitespace, so the optional space
        # after the colon (and any trailing "\r") needs no special handling.
        try:
            payload = json.loads(line[5:])
        except json.JSONDecodeError:
            continue  # best-effort: skip frames that are not JSON
        # Yield outside the try so errors thrown in by the consumer are
        # never mistaken for decode failures.
        yield payload


# ---------------------------------------------------------------------------
# Frame — typed SSE frame (X01 §3.2)
# ---------------------------------------------------------------------------


class Frame:
    """One SSE frame: its payload, an optional event tag, and its wire text.

    Spec: X01-transport §3.2 — wire format is ``data: <json>\\n\\n``
    with optional ``event: <tag>\\n`` prefix.
    """

    __slots__ = ("data", "event", "raw")

    def __init__(self, data: dict, event: str | None = None) -> None:
        # Store the inputs as given; ``raw`` holds the encoded frame text
        # produced by encode_sse_frame at construction time.
        self.event = event
        self.data = data
        self.raw = encode_sse_frame(data, event)

    def __repr__(self) -> str:
        return "Frame(event={!r}, data={!r})".format(self.event, self.data)


# ---------------------------------------------------------------------------
# SseReader — parse an HTTP SSE response stream (X01 §3.2)
# ---------------------------------------------------------------------------


class SseReader:
    """Turn a stream of SSE text lines into Frame objects.

    Every ``data:`` line becomes one Frame. Its text is decoded as JSON; if
    decoding fails, the stripped text is wrapped as ``{"raw": <text>}``. An
    ``event:`` line tags the next ``data:`` line only. Blank lines clear any
    pending tag, and other lines (comments, ``id:``, ``retry:``) are ignored.

    Typical usage with httpx::

        async with httpx.AsyncClient() as client:
            async with client.stream("POST", url, ...) as resp:
                reader = SseReader(resp.aiter_lines())
                async for frame in reader:
                    handle(frame)
    """

    def __init__(self, lines: AsyncIterator[str]) -> None:
        self._lines = lines

    async def __aiter__(self) -> AsyncIterator[Frame]:
        pending_event: str | None = None
        async for text in self._lines:
            if text.startswith("data:"):
                body = text[len("data:"):].strip()
                try:
                    payload = json.loads(body)
                except json.JSONDecodeError:
                    payload = {"raw": body}
                yield Frame(payload, pending_event)
                # An event tag applies to a single data line only.
                pending_event = None
            elif text.startswith("event:"):
                pending_event = text[len("event:"):].strip()
            elif text.strip() == "":
                pending_event = None  # blank line ends the current frame


class SseWriter:
    """Async iterator that yields SSE-formatted strings.

    Producers call ``send``, and a consumer iterates the writer with
    ``async for`` to receive encoded frames. While no frame arrives within
    *keepalive_interval* seconds, an SSE comment line (``": keepalive"``) is
    yielded instead. After ``close``, frames already queued are still
    delivered, then iteration stops.
    """

    # Sentinel queued by close() so a consumer blocked on get() wakes at once.
    _CLOSED = object()

    def __init__(self, keepalive_interval: float = 0.5) -> None:
        self._queue: asyncio.Queue | None = None
        self._done = False
        self._keepalive_interval = keepalive_interval

    async def start(self) -> None:
        """Create the frame queue. Frames sent before this are dropped."""
        self._queue = asyncio.Queue()

    async def send(self, data: dict, event: str | None = None) -> None:
        """Encode *data* as an SSE frame and queue it (a no-op before start)."""
        if self._queue is not None:
            await self._queue.put(encode_sse_frame(data, event))

    def close(self) -> None:
        """Ask iteration to stop once already-queued frames are delivered."""
        self._done = True
        if self._queue is not None:
            # Unbounded queue: put_nowait never raises QueueFull here.
            self._queue.put_nowait(self._CLOSED)

    async def __aiter__(self):
        # Iterating before start() previously crashed inside get() and spun
        # forever yielding keepalives; create the queue lazily instead.
        if self._queue is None:
            self._queue = asyncio.Queue()
        while True:
            queue = self._queue
            # Stop only when closed AND nothing is left to deliver.
            if self._done and queue.empty():
                break
            try:
                item = await asyncio.wait_for(
                    queue.get(), timeout=self._keepalive_interval
                )
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is what wait_for raises on every
                # version (an alias of TimeoutError on 3.11+).
                if self._done:
                    break
                yield ": keepalive\n\n"
                continue
            if item is self._CLOSED:
                break
            yield item