GitHub Actions
Quality improvements: Unicode chars, Token class, imports, type hints, formatting
3f78ea8
Raw
History Blame
14.8 kB
from __future__ import annotations
import sqlite3
import threading
import time
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class Thread:
thread_id: str
name: str
members: list[str]
created_at: float
archived: bool
e2e_enabled: bool
@dataclass(frozen=True)
class ThreadMessage:
event_id: str
thread_id: str
sender: str
content: str
sent_at: float
delivered_to: frozenset[str]
class ThreadViewStore:
"""Materialised view of thread state from chat.thread.* events.
Uses SQLite when available, falls back to in-memory dicts.
"""
def __init__(self, db_path: str | Path | None = None) -> None:
self._lock = threading.Lock()
self._db: sqlite3.Connection | None = None
# In-memory fallback structures
self._threads: dict[str, dict] = {} # thread_id -> thread data
self._members: dict[str, set[str]] = {} # thread_id -> set of member_ids
self._messages: dict[str, dict] = {} # event_id -> message data
self._msg_by_thread: dict[str, list[str]] = {} # thread_id -> [event_id, ...]
# read receipts: thread_id -> {member_id -> last_read_ts}
self._read_receipts: dict[str, dict[str, float]] = {}
if db_path:
try:
self._db = sqlite3.connect(str(db_path), check_same_thread=False)
self._db.execute("PRAGMA journal_mode=WAL")
self._init_schema()
except Exception:
self._db = None
def _init_schema(self) -> None:
assert self._db is not None
self._db.executescript("""
CREATE TABLE IF NOT EXISTS threads (
thread_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
created_at REAL NOT NULL,
archived INTEGER NOT NULL DEFAULT 0,
e2e_enabled INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS thread_members (
thread_id TEXT NOT NULL,
member_id TEXT NOT NULL,
PRIMARY KEY (thread_id, member_id)
);
CREATE TABLE IF NOT EXISTS thread_messages (
event_id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL,
sender TEXT NOT NULL,
content TEXT NOT NULL,
sent_at REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS delivered_to (
event_id TEXT NOT NULL,
member_id TEXT NOT NULL,
PRIMARY KEY (event_id, member_id)
);
CREATE TABLE IF NOT EXISTS read_receipts (
thread_id TEXT NOT NULL,
member_id TEXT NOT NULL,
last_read_ts REAL NOT NULL,
PRIMARY KEY (thread_id, member_id)
);
""")
self._db.commit()
# ── Apply event ───────────────────────────────────────────────────────────
def apply(self, event: dict) -> None:
etype = event.get("event_type", "")
payload = event.get("payload", {})
author = event.get("author", "")
event_id = event.get("event_id", "")
with self._lock:
if etype == "chat.thread.created":
self._apply_thread_created(event_id, payload, author)
elif etype == "chat.thread.message.sent":
self._apply_message_sent(event_id, payload, author)
elif etype == "chat.thread.member.added":
self._apply_member_added(payload)
elif etype == "chat.thread.member.removed":
self._apply_member_removed(payload)
elif etype == "chat.thread.archived":
self._apply_archived(payload)
def _apply_thread_created(self, event_id: str, payload: dict, author: str) -> None:
thread_id = payload.get("thread_id", event_id)
members: list[str] = list(payload.get("members", []))
if author and author not in members:
members.append(author)
name = payload.get("name", "")
created_at = payload.get("created_at", time.time())
e2e_enabled = bool(payload.get("e2e_enabled", False))
if self._db:
self._db.execute(
"INSERT OR IGNORE INTO threads (thread_id, name, created_at, archived, e2e_enabled) VALUES (?,?,?,0,?)",
(thread_id, name, created_at, int(e2e_enabled)),
)
for m in members:
self._db.execute(
"INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
(thread_id, m),
)
self._db.commit()
else:
if thread_id not in self._threads:
self._threads[thread_id] = {
"thread_id": thread_id,
"name": name,
"created_at": created_at,
"archived": False,
"e2e_enabled": e2e_enabled,
}
self._members[thread_id] = set(members)
self._msg_by_thread[thread_id] = []
def _apply_message_sent(self, event_id: str, payload: dict, author: str) -> None:
thread_id = payload.get("thread_id", "")
sender = payload.get("sender", author)
content = payload.get("content", "")
sent_at = payload.get("sent_at", time.time())
if self._db:
self._db.execute(
"INSERT OR IGNORE INTO thread_messages (event_id, thread_id, sender, content, sent_at) VALUES (?,?,?,?,?)",
(event_id, thread_id, sender, content, sent_at),
)
self._db.commit()
else:
if event_id not in self._messages:
self._messages[event_id] = {
"event_id": event_id,
"thread_id": thread_id,
"sender": sender,
"content": content,
"sent_at": sent_at,
"delivered_to": set(),
}
self._msg_by_thread.setdefault(thread_id, []).append(event_id)
def _apply_member_added(self, payload: dict) -> None:
thread_id = payload.get("thread_id", "")
member_id = payload.get("member_id", "")
if not thread_id or not member_id:
return
if self._db:
self._db.execute(
"INSERT OR IGNORE INTO thread_members (thread_id, member_id) VALUES (?,?)",
(thread_id, member_id),
)
self._db.commit()
else:
self._members.setdefault(thread_id, set()).add(member_id)
def _apply_member_removed(self, payload: dict) -> None:
thread_id = payload.get("thread_id", "")
member_id = payload.get("member_id", "")
if not thread_id or not member_id:
return
if self._db:
self._db.execute(
"DELETE FROM thread_members WHERE thread_id=? AND member_id=?",
(thread_id, member_id),
)
self._db.commit()
else:
self._members.get(thread_id, set()).discard(member_id)
def _apply_archived(self, payload: dict) -> None:
thread_id = payload.get("thread_id", "")
if not thread_id:
return
if self._db:
self._db.execute("UPDATE threads SET archived=1 WHERE thread_id=?", (thread_id,))
self._db.commit()
else:
if thread_id in self._threads:
t = self._threads[thread_id]
t["archived"] = True
# ── Queries ───────────────────────────────────────────────────────────────
def get_thread(self, thread_id: str) -> Thread | None:
with self._lock:
if self._db:
row = self._db.execute(
"SELECT thread_id, name, created_at, archived, e2e_enabled FROM threads WHERE thread_id=?",
(thread_id,),
).fetchone()
if not row:
return None
members_rows = self._db.execute(
"SELECT member_id FROM thread_members WHERE thread_id=?", (thread_id,)
).fetchall()
members = [r[0] for r in members_rows]
return Thread(
thread_id=row[0],
name=row[1],
members=members,
created_at=row[2],
archived=bool(row[3]),
e2e_enabled=bool(row[4]),
)
t = self._threads.get(thread_id)
if not t:
return None
return Thread(
thread_id=t["thread_id"],
name=t["name"],
members=list(self._members.get(thread_id, set())),
created_at=t["created_at"],
archived=t["archived"],
e2e_enabled=t["e2e_enabled"],
)
def list_threads(self, member_id: str) -> list[Thread]:
with self._lock:
if self._db:
rows = self._db.execute(
"""SELECT t.thread_id, t.name, t.created_at, t.archived, t.e2e_enabled
FROM threads t
JOIN thread_members tm ON t.thread_id=tm.thread_id
WHERE tm.member_id=?
ORDER BY t.created_at DESC""",
(member_id,),
).fetchall()
result = []
for row in rows:
thread_id = row[0]
members_rows = self._db.execute(
"SELECT member_id FROM thread_members WHERE thread_id=?", (thread_id,)
).fetchall()
members = [r[0] for r in members_rows]
result.append(
Thread(
thread_id=thread_id,
name=row[1],
members=members,
created_at=row[2],
archived=bool(row[3]),
e2e_enabled=bool(row[4]),
)
)
return result
results = []
for tid, member_set in self._members.items():
if member_id in member_set:
t = self._threads.get(tid)
if t:
results.append(
Thread(
thread_id=t["thread_id"],
name=t["name"],
members=list(member_set),
created_at=t["created_at"],
archived=t["archived"],
e2e_enabled=t["e2e_enabled"],
)
)
results.sort(key=lambda x: x.created_at, reverse=True)
return results
def get_messages(
self,
thread_id: str,
since: float | None = None,
limit: int = 50,
) -> list[ThreadMessage]:
with self._lock:
if self._db:
if since is not None:
rows = self._db.execute(
"SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
"WHERE thread_id=? AND sent_at>? ORDER BY sent_at ASC LIMIT ?",
(thread_id, since, limit),
).fetchall()
else:
rows = self._db.execute(
"SELECT event_id, thread_id, sender, content, sent_at FROM thread_messages "
"WHERE thread_id=? ORDER BY sent_at ASC LIMIT ?",
(thread_id, limit),
).fetchall()
messages = []
for row in rows:
eid = row[0]
delivered_rows = self._db.execute(
"SELECT member_id FROM delivered_to WHERE event_id=?", (eid,)
).fetchall()
delivered = frozenset(r[0] for r in delivered_rows)
messages.append(
ThreadMessage(
event_id=eid,
thread_id=row[1],
sender=row[2],
content=row[3],
sent_at=row[4],
delivered_to=delivered,
)
)
return messages
eids = self._msg_by_thread.get(thread_id, [])
msgs = []
for eid in eids:
m = self._messages.get(eid)
if m and (since is None or m["sent_at"] > since):
msgs.append(
ThreadMessage(
event_id=m["event_id"],
thread_id=m["thread_id"],
sender=m["sender"],
content=m["content"],
sent_at=m["sent_at"],
delivered_to=frozenset(m["delivered_to"]),
)
)
msgs.sort(key=lambda x: x.sent_at)
return msgs[:limit]
def unread_count(self, thread_id: str, member_id: str) -> int:
with self._lock:
if self._db:
row = self._db.execute(
"SELECT last_read_ts FROM read_receipts WHERE thread_id=? AND member_id=?",
(thread_id, member_id),
).fetchone()
last_read = row[0] if row else 0.0
count = self._db.execute(
"SELECT COUNT(*) FROM thread_messages WHERE thread_id=? AND sent_at>? AND sender!=?",
(thread_id, last_read, member_id),
).fetchone()[0]
return int(count)
last_read = self._read_receipts.get(thread_id, {}).get(member_id, 0.0)
eids = self._msg_by_thread.get(thread_id, [])
count = 0
for eid in eids:
m = self._messages.get(eid)
if m and m["sent_at"] > last_read and m["sender"] != member_id:
count += 1
return count