Spaces:
Running on Zero
Running on Zero
GitHub Actions
Quality improvements: Unicode chars, Token class, imports, type hints, formatting
3f78ea8 | from __future__ import annotations | |
| import sqlite3 | |
| import threading | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
@dataclass
class Thread:
    """Snapshot of one chat thread, as built by ``ThreadViewStore`` queries.

    The class must be a dataclass: the store constructs it with keyword
    arguments (``Thread(thread_id=..., name=..., ...)``). Without the
    decorator every such call raises ``TypeError``.
    """

    thread_id: str  # unique thread identifier
    name: str  # thread name from the creation event ("" if it had none)
    members: list[str]  # member ids; order is not guaranteed by the store
    created_at: float  # creation timestamp (presumably epoch seconds -- TODO confirm)
    archived: bool  # True once a chat.thread.archived event was applied
    e2e_enabled: bool  # flag copied from the creation event's payload
@dataclass
class ThreadMessage:
    """Snapshot of one message posted to a thread.

    Must be a dataclass: ``ThreadViewStore.get_messages`` constructs it with
    keyword arguments, which a plain annotated class does not accept.
    """

    event_id: str  # id of the chat.thread.message.sent event (also the message id)
    thread_id: str  # thread the message belongs to
    sender: str  # payload "sender", falling back to the event author
    content: str  # message body as stored
    sent_at: float  # send timestamp (presumably epoch seconds -- TODO confirm)
    delivered_to: frozenset[str]  # member ids recorded as having received it
class ThreadViewStore:
    """Materialised view of thread state built from ``chat.thread.*`` events.

    Feed events in through :meth:`apply`. Read the resulting state with
    :meth:`get_thread`, :meth:`list_threads`, :meth:`get_messages` and
    :meth:`unread_count`.

    Two interchangeable backends hold the state:

    * SQLite, when ``db_path`` is given and the database can be opened and its
      schema created. The journal is switched to WAL mode.
    * Plain in-memory dicts otherwise, including when opening the database
      fails.

    The public methods defined here all run under ``self._lock``. The
    connection is opened with ``check_same_thread=False``, so one store can be
    used from several threads.
    """

    def __init__(self, db_path: str | Path | None = None) -> None:
        """Create an empty store.

        Args:
            db_path: SQLite database file that holds the view. If falsy
                (``None`` or ``""``), the in-memory backend is used. If the
                database cannot be opened or its schema cannot be created, a
                warning is logged and the in-memory backend is used instead.
        """
        self._lock = threading.Lock()
        self._db: sqlite3.Connection | None = None

        # In-memory backend (only consulted while self._db is None).
        self._threads: dict[str, dict] = {}  # thread_id -> thread data
        self._members: dict[str, set[str]] = {}  # thread_id -> set of member_ids
        self._messages: dict[str, dict] = {}  # event_id -> message data
        self._msg_by_thread: dict[str, list[str]] = {}  # thread_id -> [event_id, ...]
        # Read receipts: thread_id -> {member_id -> last_read_ts}.
        self._read_receipts: dict[str, dict[str, float]] = {}

        if not db_path:
            return
        conn: sqlite3.Connection | None = None
        try:
            conn = sqlite3.connect(str(db_path), check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            self._db = conn
            self._init_schema()
        except Exception:
            # Deliberate best-effort fallback to the in-memory backend. Close
            # the half-initialised connection instead of leaking it, and log
            # why SQLite is not in use rather than hiding the failure.
            import logging

            logging.getLogger(__name__).warning(
                "ThreadViewStore: cannot use SQLite at %r; falling back to in-memory state",
                str(db_path),
                exc_info=True,
            )
            self._db = None
            if conn is not None:
                conn.close()

    def _init_schema(self) -> None:
        """Create the tables and indexes if they do not exist yet (idempotent)."""
        assert self._db is not None
        self._db.executescript("""
            CREATE TABLE IF NOT EXISTS threads (
                thread_id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                created_at REAL NOT NULL,
                archived INTEGER NOT NULL DEFAULT 0,
                e2e_enabled INTEGER NOT NULL DEFAULT 0
            );
            CREATE TABLE IF NOT EXISTS thread_members (
                thread_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                PRIMARY KEY (thread_id, member_id)
            );
            CREATE TABLE IF NOT EXISTS thread_messages (
                event_id TEXT PRIMARY KEY,
                thread_id TEXT NOT NULL,
                sender TEXT NOT NULL,
                content TEXT NOT NULL,
                sent_at REAL NOT NULL
            );
            CREATE TABLE IF NOT EXISTS delivered_to (
                event_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                PRIMARY KEY (event_id, member_id)
            );
            CREATE TABLE IF NOT EXISTS read_receipts (
                thread_id TEXT NOT NULL,
                member_id TEXT NOT NULL,
                last_read_ts REAL NOT NULL,
                PRIMARY KEY (thread_id, member_id)
            );
            -- get_messages / unread_count filter on thread_id and sent_at.
            CREATE INDEX IF NOT EXISTS idx_thread_messages_thread_sent
                ON thread_messages (thread_id, sent_at);
            -- list_threads looks rows up by member_id, which is not the
            -- leading column of thread_members' primary key.
            CREATE INDEX IF NOT EXISTS idx_thread_members_member
                ON thread_members (member_id);
        """)
        self._db.commit()

    # ── Apply event ───────────────────────────────────────────────────────────

    def apply(self, event: dict) -> None:
        """Fold one event into the view.

        Handled ``event_type`` values are ``chat.thread.created``,
        ``chat.thread.message.sent``, ``chat.thread.member.added``,
        ``chat.thread.member.removed`` and ``chat.thread.archived``. Any other
        type is ignored. A missing or ``None`` ``payload`` is treated as an
        empty dict.
        """
        etype = event.get("event_type", "")
        # `or {}` also covers an explicit "payload": None, which would
        # otherwise crash on payload.get(...) in the handlers.
        payload = event.get("payload") or {}
        author = event.get("author", "")
        event_id = event.get("event_id", "")
        with self._lock:
            if etype == "chat.thread.created":
                self._apply_thread_created(event_id, payload, author)
            elif etype == "chat.thread.message.sent":
                self._apply_message_sent(event_id, payload, author)
            elif etype == "chat.thread.member.added":
                self._apply_member_added(payload)
            elif etype == "chat.thread.member.removed":
                self._apply_member_removed(payload)
            elif etype == "chat.thread.archived":
                self._apply_archived(payload)

    def _apply_thread_created(self, event_id: str, payload: dict, author: str) -> None:
        """Record a thread; the event author is always added as a member.

        ``thread_id`` defaults to the event id. Replaying a creation event is
        harmless: the thread row is written once, and its members are merged
        into the existing member set in both backends.
        """
        thread_id = payload.get("thread_id", event_id)
        members: list[str] = list(payload.get("members", []))
        if author and author not in members:
            members.append(author)
        name = payload.get("name", "")
        created_at = payload.get("created_at", time.time())
        e2e_enabled = bool(payload.get("e2e_enabled", False))

        if self._db is not None:
            # One transaction: committed on success, rolled back if any
            # insert fails, so a thread is never left half-written.
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO threads (thread_id, name, created_at, archived, e2e_enabled) VALUES (?,?,?,0,?)",
                    (thread_id, name, created_at, int(e2e_enabled)),
                )
                self._db.executemany(
                    "INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
                    [(thread_id, m) for m in members],
                )
            return

        if thread_id not in self._threads:
            self._threads[thread_id] = {
                "thread_id": thread_id,
                "name": name,
                "created_at": created_at,
                "archived": False,
                "e2e_enabled": e2e_enabled,
            }
        # Merge instead of overwrite. member.added / message.sent events may
        # have been applied for this thread before its creation event, and
        # the SQLite backend keeps them, so the in-memory one must too.
        self._members.setdefault(thread_id, set()).update(members)
        self._msg_by_thread.setdefault(thread_id, [])

    def _apply_message_sent(self, event_id: str, payload: dict, author: str) -> None:
        """Store a message. A repeated event id is ignored.

        ``sender`` defaults to the event author and ``sent_at`` defaults to
        the current time.
        """
        thread_id = payload.get("thread_id", "")
        sender = payload.get("sender", author)
        content = payload.get("content", "")
        sent_at = payload.get("sent_at", time.time())

        if self._db is not None:
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO thread_messages (event_id, thread_id, sender, content, sent_at) VALUES (?,?,?,?,?)",
                    (event_id, thread_id, sender, content, sent_at),
                )
            return

        if event_id in self._messages:
            return  # duplicate delivery of the same event
        self._messages[event_id] = {
            "event_id": event_id,
            "thread_id": thread_id,
            "sender": sender,
            "content": content,
            "sent_at": sent_at,
            "delivered_to": set(),
        }
        self._msg_by_thread.setdefault(thread_id, []).append(event_id)

    def _apply_member_added(self, payload: dict) -> None:
        """Add ``member_id`` to ``thread_id``; no-op if either is missing/empty."""
        thread_id = payload.get("thread_id", "")
        member_id = payload.get("member_id", "")
        if not thread_id or not member_id:
            return
        if self._db is not None:
            with self._db:
                self._db.execute(
                    "INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
                    (thread_id, member_id),
                )
            return
        self._members.setdefault(thread_id, set()).add(member_id)

    def _apply_member_removed(self, payload: dict) -> None:
        """Remove ``member_id`` from ``thread_id``; no-op if either is missing/empty."""
        thread_id = payload.get("thread_id", "")
        member_id = payload.get("member_id", "")
        if not thread_id or not member_id:
            return
        if self._db is not None:
            with self._db:
                self._db.execute(
                    "DELETE FROM thread_members WHERE thread_id=? AND member_id=?",
                    (thread_id, member_id),
                )
            return
        self._members.get(thread_id, set()).discard(member_id)

    def _apply_archived(self, payload: dict) -> None:
        """Mark ``thread_id`` archived. Unknown or empty thread ids are ignored."""
        thread_id = payload.get("thread_id", "")
        if not thread_id:
            return
        if self._db is not None:
            with self._db:
                self._db.execute("UPDATE threads SET archived=1 WHERE thread_id=?", (thread_id,))
            return
        thread = self._threads.get(thread_id)
        if thread is not None:
            thread["archived"] = True

    # ── Queries ───────────────────────────────────────────────────────────────

    def _db_members(self, thread_id: str) -> list[str]:
        """Return the member ids of *thread_id* from SQLite (caller holds the lock)."""
        assert self._db is not None
        rows = self._db.execute(
            "SELECT member_id FROM thread_members WHERE thread_id=?", (thread_id,)
        ).fetchall()
        return [r[0] for r in rows]

    def get_thread(self, thread_id: str) -> Thread | None:
        """Return a snapshot of *thread_id*, or ``None`` if it was never created."""
        with self._lock:
            if self._db is not None:
                row = self._db.execute(
                    "SELECT thread_id, name, created_at, archived, e2e_enabled FROM threads WHERE thread_id=?",
                    (thread_id,),
                ).fetchone()
                if row is None:
                    return None
                return Thread(
                    thread_id=row[0],
                    name=row[1],
                    members=self._db_members(thread_id),
                    created_at=row[2],
                    archived=bool(row[3]),
                    e2e_enabled=bool(row[4]),
                )
            t = self._threads.get(thread_id)
            if t is None:
                return None
            return Thread(
                thread_id=t["thread_id"],
                name=t["name"],
                members=list(self._members.get(thread_id, set())),
                created_at=t["created_at"],
                archived=t["archived"],
                e2e_enabled=t["e2e_enabled"],
            )

    def list_threads(self, member_id: str) -> list[Thread]:
        """Return the threads *member_id* belongs to, newest ``created_at`` first.

        Archived threads are included. Member ids seen in member.added events
        for threads that were never created are not listed.
        """
        with self._lock:
            if self._db is not None:
                rows = self._db.execute(
                    """SELECT t.thread_id, t.name, t.created_at, t.archived, t.e2e_enabled
                    FROM threads t
                    JOIN thread_members tm ON t.thread_id=tm.thread_id
                    WHERE tm.member_id=?
                    ORDER BY t.created_at DESC""",
                    (member_id,),
                ).fetchall()
                return [
                    Thread(
                        thread_id=row[0],
                        name=row[1],
                        members=self._db_members(row[0]),
                        created_at=row[2],
                        archived=bool(row[3]),
                        e2e_enabled=bool(row[4]),
                    )
                    for row in rows
                ]
            results = []
            for tid, member_set in self._members.items():
                if member_id not in member_set:
                    continue
                t = self._threads.get(tid)
                if t is None:
                    continue  # members recorded before (or without) creation
                results.append(
                    Thread(
                        thread_id=t["thread_id"],
                        name=t["name"],
                        members=list(member_set),
                        created_at=t["created_at"],
                        archived=t["archived"],
                        e2e_enabled=t["e2e_enabled"],
                    )
                )
            results.sort(key=lambda x: x.created_at, reverse=True)
            return results

    def get_messages(
        self,
        thread_id: str,
        since: float | None = None,
        limit: int = 50,
    ) -> list[ThreadMessage]:
        """Return messages of *thread_id* in ascending ``sent_at`` order.

        Args:
            thread_id: Thread whose messages are returned.
            since: If given, only messages with ``sent_at`` strictly greater
                than this value are returned.
            limit: Maximum number of messages. The *oldest* matching messages
                are kept. A negative value means "no limit" in both backends,
                as it does for an SQLite ``LIMIT``.
        """
        with self._lock:
            if self._db is not None:
                if since is not None:
                    rows = self._db.execute(
                        "SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
                        "WHERE thread_id=? AND sent_at>? ORDER BY sent_at ASC LIMIT ?",
                        (thread_id, since, limit),
                    ).fetchall()
                else:
                    rows = self._db.execute(
                        "SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
                        "WHERE thread_id=? ORDER BY sent_at ASC LIMIT ?",
                        (thread_id, limit),
                    ).fetchall()
                messages = []
                for eid, tid, sender, content, sent_at in rows:
                    delivered_rows = self._db.execute(
                        "SELECT member_id FROM delivered_to WHERE event_id=?", (eid,)
                    ).fetchall()
                    messages.append(
                        ThreadMessage(
                            event_id=eid,
                            thread_id=tid,
                            sender=sender,
                            content=content,
                            sent_at=sent_at,
                            delivered_to=frozenset(r[0] for r in delivered_rows),
                        )
                    )
                return messages
            msgs = []
            for eid in self._msg_by_thread.get(thread_id, []):
                m = self._messages.get(eid)
                if m and (since is None or m["sent_at"] > since):
                    msgs.append(
                        ThreadMessage(
                            event_id=m["event_id"],
                            thread_id=m["thread_id"],
                            sender=m["sender"],
                            content=m["content"],
                            sent_at=m["sent_at"],
                            delivered_to=frozenset(m["delivered_to"]),
                        )
                    )
            msgs.sort(key=lambda x: x.sent_at)
            # Mirror SQLite's negative-LIMIT semantics instead of letting a
            # negative slice silently drop messages from the end.
            return msgs if limit < 0 else msgs[:limit]

    def unread_count(self, thread_id: str, member_id: str) -> int:
        """Count messages in *thread_id* that *member_id* has not read yet.

        A message counts if its ``sent_at`` is after the member's read
        receipt (0.0 when there is none) and it was not sent by the member.
        """
        with self._lock:
            if self._db is not None:
                row = self._db.execute(
                    "SELECT last_read_ts FROM read_receipts WHERE thread_id=? AND member_id=?",
                    (thread_id, member_id),
                ).fetchone()
                last_read = row[0] if row else 0.0
                count = self._db.execute(
                    "SELECT COUNT(*) FROM thread_messages WHERE thread_id=? AND sent_at>? AND sender!=?",
                    (thread_id, last_read, member_id),
                ).fetchone()[0]
                return int(count)
            last_read = self._read_receipts.get(thread_id, {}).get(member_id, 0.0)
            count = 0
            for eid in self._msg_by_thread.get(thread_id, []):
                m = self._messages.get(eid)
                if m and m["sent_at"] > last_read and m["sender"] != member_id:
                    count += 1
            return count