File size: 14,769 Bytes
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
 
 
 
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
 
 
 
 
 
4cd8837
4aaae80
3f78ea8
 
4aaae80
 
 
 
4cd8837
 
3f78ea8
4cd8837
 
 
4aaae80
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
 
 
 
 
 
4cd8837
4aaae80
 
 
 
 
 
 
4cd8837
 
 
 
 
 
4aaae80
 
 
 
4cd8837
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
from __future__ import annotations

import sqlite3
import threading
import time
from dataclasses import dataclass
from pathlib import Path


@dataclass(eq=True, frozen=True)
class Thread:
    """Immutable snapshot of a single chat thread.

    Although the dataclass is frozen, ``members`` is a plain ``list``; the
    generated ``__hash__`` therefore raises ``TypeError`` when called.
    """

    thread_id: str  # identifier of the thread
    name: str  # display name of the thread
    members: list[str]  # member ids; order is not guaranteed by the store
    created_at: float  # creation timestamp as carried by the creation event
    archived: bool  # True once the thread has been archived
    e2e_enabled: bool  # whether end-to-end encryption was flagged at creation


@dataclass(eq=True, frozen=True)
class ThreadMessage:
    """Immutable snapshot of one message posted to a chat thread.

    Every field is hashable (``delivered_to`` is a ``frozenset``), so
    instances can be used as dict keys or set members.
    """

    event_id: str  # id of the event that carried the message
    thread_id: str  # thread the message was posted to
    sender: str  # id of the sending member
    content: str  # message body
    sent_at: float  # send timestamp as carried by the event
    delivered_to: frozenset[str]  # member ids recorded as having received it


class ThreadViewStore:
    """Materialised view of thread state from chat.thread.* events.

    Uses SQLite when a ``db_path`` is given and the database can be opened and
    initialised; otherwise falls back to in-memory dicts. Every public method
    serialises access through one ``threading.Lock``.

    Handled event types (see :meth:`apply`): ``chat.thread.created``,
    ``chat.thread.message.sent``, ``chat.thread.member.added``,
    ``chat.thread.member.removed`` and ``chat.thread.archived``; anything else
    is ignored.

    NOTE(review): nothing in this class writes delivery records or read
    receipts (the ``delivered_to`` / ``read_receipts`` tables and the
    ``_read_receipts`` dict). So ``ThreadMessage.delivered_to`` stays empty and
    :meth:`unread_count` treats the member as never having read anything,
    unless another component populates them. TODO confirm.
    """

    def __init__(self, db_path: str | Path | None = None) -> None:
        """Create the store.

        Args:
            db_path: SQLite database path (``":memory:"`` also works). When
                falsy, or when the database cannot be opened and initialised,
                the in-memory fallback structures are used instead.
        """
        self._lock = threading.Lock()
        self._db: sqlite3.Connection | None = None

        # In-memory fallback structures (used only while self._db is None)
        self._threads: dict[str, dict] = {}  # thread_id -> thread data
        self._members: dict[str, set[str]] = {}  # thread_id -> set of member_ids
        self._messages: dict[str, dict] = {}  # event_id -> message data
        self._msg_by_thread: dict[str, list[str]] = {}  # thread_id -> [event_id, ...]
        # read receipts: thread_id -> {member_id -> last_read_ts}
        self._read_receipts: dict[str, dict[str, float]] = {}

        if db_path:
            conn: sqlite3.Connection | None = None
            try:
                conn = sqlite3.connect(str(db_path), check_same_thread=False)
                conn.execute("PRAGMA journal_mode=WAL")
                self._db = conn
                self._init_schema()
            except Exception:
                # Deliberate best-effort: fall back to the in-memory store, but
                # close a half-initialised connection instead of leaking it.
                if conn is not None:
                    conn.close()
                self._db = None

    def _init_schema(self) -> None:
        """Create the tables and lookup indexes if they do not exist yet."""
        assert self._db is not None
        self._db.executescript("""
            CREATE TABLE IF NOT EXISTS threads (
                thread_id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                created_at REAL NOT NULL,
                archived INTEGER NOT NULL DEFAULT 0,
                e2e_enabled INTEGER NOT NULL DEFAULT 0
            );
            CREATE TABLE IF NOT EXISTS thread_members (
                thread_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                PRIMARY KEY (thread_id, member_id)
            );
            CREATE TABLE IF NOT EXISTS thread_messages (
                event_id TEXT PRIMARY KEY,
                thread_id TEXT NOT NULL,
                sender TEXT NOT NULL,
                content TEXT NOT NULL,
                sent_at REAL NOT NULL
            );
            CREATE TABLE IF NOT EXISTS delivered_to (
                event_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                PRIMARY KEY (event_id, member_id)
            );
            CREATE TABLE IF NOT EXISTS read_receipts (
                thread_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                last_read_ts REAL NOT NULL,
                PRIMARY KEY (thread_id, member_id)
            );
            -- list_threads filters memberships by member_id alone
            CREATE INDEX IF NOT EXISTS idx_thread_members_member
                ON thread_members (member_id);
            -- get_messages / unread_count filter by thread and range on sent_at
            CREATE INDEX IF NOT EXISTS idx_thread_messages_thread_sent
                ON thread_messages (thread_id, sent_at);
        """)
        self._db.commit()

    # ── Apply event ───────────────────────────────────────────────────────────

    def apply(self, event: dict) -> None:
        """Fold one event into the view; unknown event types are ignored.

        Args:
            event: Event envelope. Reads ``event_type``, ``payload``,
                ``author`` and ``event_id``. A missing or ``None`` payload is
                treated as an empty dict.
        """
        etype = event.get("event_type", "")
        # ``or {}`` also covers an explicit ``"payload": None``.
        payload = event.get("payload") or {}
        author = event.get("author", "")
        event_id = event.get("event_id", "")

        with self._lock:
            if etype == "chat.thread.created":
                self._apply_thread_created(event_id, payload, author)
            elif etype == "chat.thread.message.sent":
                self._apply_message_sent(event_id, payload, author)
            elif etype == "chat.thread.member.added":
                self._apply_member_added(payload)
            elif etype == "chat.thread.member.removed":
                self._apply_member_removed(payload)
            elif etype == "chat.thread.archived":
                self._apply_archived(payload)

    def _apply_thread_created(self, event_id: str, payload: dict, author: str) -> None:
        """Record a new thread; the author is always added as a member.

        The thread row itself is written only once; repeated creation events
        for the same id do not overwrite name/created_at/e2e flag.
        """
        thread_id = payload.get("thread_id", event_id)
        members: list[str] = list(payload.get("members", []))
        if author and author not in members:
            members.append(author)
        name = payload.get("name", "")
        created_at = payload.get("created_at", time.time())
        e2e_enabled = bool(payload.get("e2e_enabled", False))

        if self._db is not None:
            # Single transaction: committed on success, rolled back on error.
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO threads (thread_id, name, created_at, archived, e2e_enabled) VALUES (?,?,?,0,?)",
                    (thread_id, name, created_at, int(e2e_enabled)),
                )
                self._db.executemany(
                    "INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
                    [(thread_id, m) for m in members],
                )
        else:
            if thread_id not in self._threads:
                self._threads[thread_id] = {
                    "thread_id": thread_id,
                    "name": name,
                    "created_at": created_at,
                    "archived": False,
                    "e2e_enabled": e2e_enabled,
                }
                # Merge instead of overwrite: member.added / message.sent events
                # may have arrived before this one, and the SQLite path keeps them.
                self._members.setdefault(thread_id, set()).update(members)
                self._msg_by_thread.setdefault(thread_id, [])

    def _apply_message_sent(self, event_id: str, payload: dict, author: str) -> None:
        """Record a message; duplicates (same event_id) are ignored.

        ``sender`` defaults to the event author and ``sent_at`` to the current
        time when the payload omits them.
        """
        thread_id = payload.get("thread_id", "")
        sender = payload.get("sender", author)
        content = payload.get("content", "")
        sent_at = payload.get("sent_at", time.time())

        if self._db is not None:
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO thread_messages (event_id, thread_id, sender, content, sent_at) VALUES (?,?,?,?,?)",
                    (event_id, thread_id, sender, content, sent_at),
                )
        else:
            if event_id not in self._messages:
                self._messages[event_id] = {
                    "event_id": event_id,
                    "thread_id": thread_id,
                    "sender": sender,
                    "content": content,
                    "sent_at": sent_at,
                    "delivered_to": set(),
                }
                self._msg_by_thread.setdefault(thread_id, []).append(event_id)

    def _apply_member_added(self, payload: dict) -> None:
        """Add a member to a thread; no-op if thread_id or member_id is empty."""
        thread_id = payload.get("thread_id", "")
        member_id = payload.get("member_id", "")
        if not thread_id or not member_id:
            return

        if self._db is not None:
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
                    (thread_id, member_id),
                )
        else:
            self._members.setdefault(thread_id, set()).add(member_id)

    def _apply_member_removed(self, payload: dict) -> None:
        """Remove a member from a thread; no-op if either id is empty or unknown."""
        thread_id = payload.get("thread_id", "")
        member_id = payload.get("member_id", "")
        if not thread_id or not member_id:
            return

        if self._db is not None:
            with self._db:
                self._db.execute(
                    "DELETE FROM thread_members WHERE thread_id=? AND member_id=?",
                    (thread_id, member_id),
                )
        else:
            self._members.get(thread_id, set()).discard(member_id)

    def _apply_archived(self, payload: dict) -> None:
        """Mark a thread archived; ignored if the thread does not exist yet."""
        thread_id = payload.get("thread_id", "")
        if not thread_id:
            return

        if self._db is not None:
            with self._db:
                self._db.execute("UPDATE threads SET archived=1 WHERE thread_id=?", (thread_id,))
        else:
            thread = self._threads.get(thread_id)
            if thread is not None:
                thread["archived"] = True

    # ── SQLite helpers (call with self._lock held) ─────────────────────────────

    def _db_member_ids(self, thread_id: str) -> list[str]:
        """Return the member ids stored in SQLite for *thread_id*."""
        assert self._db is not None
        rows = self._db.execute(
            "SELECT member_id FROM thread_members WHERE thread_id=?", (thread_id,)
        ).fetchall()
        return [r[0] for r in rows]

    def _db_delivered_to(self, event_id: str) -> frozenset[str]:
        """Return the member ids recorded in SQLite as having received *event_id*."""
        assert self._db is not None
        rows = self._db.execute(
            "SELECT member_id FROM delivered_to WHERE event_id=?", (event_id,)
        ).fetchall()
        return frozenset(r[0] for r in rows)

    # ── Queries ───────────────────────────────────────────────────────────────

    def get_thread(self, thread_id: str) -> Thread | None:
        """Return a snapshot of *thread_id*, or ``None`` if it was never created.

        Member order in the returned ``Thread`` is unspecified.
        """
        with self._lock:
            if self._db is not None:
                row = self._db.execute(
                    "SELECT thread_id, name, created_at, archived, e2e_enabled FROM threads WHERE thread_id=?",
                    (thread_id,),
                ).fetchone()
                if not row:
                    return None
                return Thread(
                    thread_id=row[0],
                    name=row[1],
                    members=self._db_member_ids(thread_id),
                    created_at=row[2],
                    archived=bool(row[3]),
                    e2e_enabled=bool(row[4]),
                )
            t = self._threads.get(thread_id)
            if not t:
                return None
            return Thread(
                thread_id=t["thread_id"],
                name=t["name"],
                members=list(self._members.get(thread_id, set())),
                created_at=t["created_at"],
                archived=t["archived"],
                e2e_enabled=t["e2e_enabled"],
            )

    def list_threads(self, member_id: str) -> list[Thread]:
        """Return every created thread *member_id* belongs to, newest first.

        Archived threads are included. Membership without a creation event
        does not produce an entry.
        """
        with self._lock:
            if self._db is not None:
                rows = self._db.execute(
                    """SELECT t.thread_id, t.name, t.created_at, t.archived, t.e2e_enabled
                       FROM threads t
                       JOIN thread_members tm ON t.thread_id=tm.thread_id
                       WHERE tm.member_id=?
                       ORDER BY t.created_at DESC""",
                    (member_id,),
                ).fetchall()
                return [
                    Thread(
                        thread_id=row[0],
                        name=row[1],
                        members=self._db_member_ids(row[0]),
                        created_at=row[2],
                        archived=bool(row[3]),
                        e2e_enabled=bool(row[4]),
                    )
                    for row in rows
                ]
            results = []
            for tid, member_set in self._members.items():
                if member_id not in member_set:
                    continue
                t = self._threads.get(tid)
                if t:
                    results.append(
                        Thread(
                            thread_id=t["thread_id"],
                            name=t["name"],
                            members=list(member_set),
                            created_at=t["created_at"],
                            archived=t["archived"],
                            e2e_enabled=t["e2e_enabled"],
                        )
                    )
            results.sort(key=lambda x: x.created_at, reverse=True)
            return results

    def get_messages(
        self,
        thread_id: str,
        since: float | None = None,
        limit: int = 50,
    ) -> list[ThreadMessage]:
        """Return up to *limit* messages of *thread_id*, oldest first.

        Args:
            thread_id: Thread to read.
            since: If given, only messages with ``sent_at`` strictly greater
                than this value are returned.
            limit: Maximum number of messages to return.

        Messages with equal ``sent_at`` are returned in insertion order on both
        backends.
        """
        with self._lock:
            if self._db is not None:
                # rowid tie-break mirrors the in-memory path's stable sort.
                if since is not None:
                    rows = self._db.execute(
                        "SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
                        "WHERE thread_id=? AND sent_at>? ORDER BY sent_at ASC, rowid ASC LIMIT ?",
                        (thread_id, since, limit),
                    ).fetchall()
                else:
                    rows = self._db.execute(
                        "SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
                        "WHERE thread_id=? ORDER BY sent_at ASC, rowid ASC LIMIT ?",
                        (thread_id, limit),
                    ).fetchall()
                return [
                    ThreadMessage(
                        event_id=row[0],
                        thread_id=row[1],
                        sender=row[2],
                        content=row[3],
                        sent_at=row[4],
                        delivered_to=self._db_delivered_to(row[0]),
                    )
                    for row in rows
                ]
            msgs = []
            for eid in self._msg_by_thread.get(thread_id, []):
                m = self._messages.get(eid)
                if m and (since is None or m["sent_at"] > since):
                    msgs.append(
                        ThreadMessage(
                            event_id=m["event_id"],
                            thread_id=m["thread_id"],
                            sender=m["sender"],
                            content=m["content"],
                            sent_at=m["sent_at"],
                            delivered_to=frozenset(m["delivered_to"]),
                        )
                    )
            msgs.sort(key=lambda x: x.sent_at)
            return msgs[:limit]

    def unread_count(self, thread_id: str, member_id: str) -> int:
        """Count messages in *thread_id* that *member_id* has not read.

        A message counts when it was sent by someone other than *member_id*
        and its ``sent_at`` is strictly later than the member's last read
        receipt (``0.0`` when no receipt is stored).
        """
        with self._lock:
            if self._db is not None:
                row = self._db.execute(
                    "SELECT last_read_ts FROM read_receipts WHERE thread_id=? AND member_id=?",
                    (thread_id, member_id),
                ).fetchone()
                last_read = row[0] if row else 0.0
                count = self._db.execute(
                    "SELECT COUNT(*) FROM thread_messages WHERE thread_id=? AND sent_at>? AND sender!=?",
                    (thread_id, last_read, member_id),
                ).fetchone()[0]
                return int(count)
            last_read = self._read_receipts.get(thread_id, {}).get(member_id, 0.0)
            count = 0
            for eid in self._msg_by_thread.get(thread_id, []):
                m = self._messages.get(eid)
                if m and m["sent_at"] > last_read and m["sender"] != member_id:
                    count += 1
            return count