"""Persistent SQLite-backed cache for cross-startup state.

Without this, every server restart rebuilds caches from scratch:
  • overruled detection cache (~30s on 5k docs)
  • judge stats (~10s)
  • outcome stats (~10s)
  • citation network (~1s)
  • per-doc structurer (lazy)

With this, caches survive restarts. Each entry is stored under its cache
name together with a corpus fingerprint (document count + a hash of all doc
ids; see `fingerprint_corpus`). When the fingerprint changes, the stored
entry is treated as stale: `get` returns None, the caller rebuilds, and the
next `set` overwrites the old row.

By default, a single SQLite file at tau_rag/runtime/cache_store.db.

Usage:
    from tau_rag.storage import get_cache_store
    cache = get_cache_store()
    cache.set('overruled', fingerprint, payload_dict)
    payload = cache.get('overruled', fingerprint)  # None if stale or missing
"""
from __future__ import annotations

import gzip
import hashlib
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional

# <directory containing this package>/runtime/cache_store.db
_DEFAULT_PATH = (Path(__file__).resolve().parent.parent
                 / "runtime" / "cache_store.db")

_singleton: Optional["PersistentCache"] = None
_singleton_lock = threading.Lock()


def get_cache_store() -> "PersistentCache":
    """Return the process-wide `PersistentCache` stored at `_DEFAULT_PATH`.

    The instance is created lazily on first call. The check / lock / re-check
    sequence ensures only one instance is constructed even when several
    threads race on the first call.
    """
    global _singleton
    if _singleton is None:
        with _singleton_lock:
            if _singleton is None:
                _singleton = PersistentCache(_DEFAULT_PATH)
    return _singleton


class PersistentCache:
    """Thin SQLite KV store: one row per cache_name, tagged with a fingerprint.

    Schema::

        CREATE TABLE caches (
            cache_name  TEXT PRIMARY KEY,
            fingerprint TEXT NOT NULL,
            payload_gz  BLOB NOT NULL,    -- gzip-compressed UTF-8 JSON
            ts          INTEGER NOT NULL  -- unix seconds of the last write
        );

    Stores ONE row per cache_name. Writing a new value overwrites the old
    one; old fingerprints are not kept (they're stale anyway).

    Payloads round-trip through JSON, so only plain JSON data comes back
    unchanged: tuples come back as lists, dict keys come back as strings,
    and any other non-JSON object is stored as ``str(obj)``.

    Each thread uses its own SQLite connection (kept in a
    ``threading.local``), opened in autocommit mode with WAL journaling.

    Every operation except construction is best-effort. Failures are
    printed with a ``[cache]`` prefix and never raised to the caller.
    """

    def __init__(self, path) -> None:
        """Open (creating if needed) the SQLite file at *path*.

        Creates missing parent directories and ensures the ``caches`` table
        exists. Unlike the other methods, this is not best-effort: errors
        from ``mkdir`` or ``sqlite3`` propagate to the caller.
        """
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._tls = threading.local()  # holds this thread's connection
        self._init_schema()

    def _conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use.

        By default ``sqlite3`` refuses to use a connection from a thread
        other than its creator (``check_same_thread``), hence one connection
        per thread. ``isolation_level=None`` means autocommit, so each
        statement issued by this class commits on its own.
        """
        conn = getattr(self._tls, "conn", None)
        if conn is not None:
            return conn
        conn = sqlite3.connect(str(self.path), timeout=30.0,
                               isolation_level=None)
        try:
            # WAL lets readers proceed while another connection writes;
            # synchronous=NORMAL is the setting SQLite recommends with WAL.
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
        except Exception:
            # Don't leak a half-configured connection; the next call retries.
            conn.close()
            raise
        self._tls.conn = conn
        return conn

    def _init_schema(self) -> None:
        """Create the ``caches`` table if it does not exist yet."""
        self._conn().execute("""
            CREATE TABLE IF NOT EXISTS caches (
                cache_name TEXT PRIMARY KEY,
                fingerprint TEXT NOT NULL,
                payload_gz BLOB NOT NULL,
                ts INTEGER NOT NULL
            )
        """)

    def get(self, cache_name: str, fingerprint: str) -> Optional[Any]:
        """Return the cached payload if it matches *fingerprint*, else None.

        Also returns None when either argument is empty, the entry is
        missing, or reading/decoding fails (failures are printed). A payload
        that was stored as JSON ``null`` cannot be told apart from a miss.
        """
        if not cache_name or not fingerprint:
            return None
        try:
            row = self._conn().execute(
                "SELECT fingerprint, payload_gz FROM caches "
                "WHERE cache_name = ?",
                (cache_name,),
            ).fetchone()
        except Exception as e:
            print(f"[cache] load failed for {cache_name}: {e}")
            return None
        if row is None:
            return None
        stored_fp, blob = row
        if stored_fp != fingerprint:
            # Fingerprint mismatch — corpus changed, payload is stale.
            return None
        try:
            return json.loads(gzip.decompress(blob).decode("utf-8"))
        except Exception as e:
            print(f"[cache] corrupt payload for {cache_name}: {e}")
            return None

    def set(self, cache_name: str, fingerprint: str, payload: Any) -> None:
        """Store *payload* under *cache_name*, replacing any previous entry.

        A no-op when either argument is empty. Serialization or database
        failures are printed, never raised.
        """
        if not cache_name or not fingerprint:
            return
        try:
            blob = gzip.compress(
                json.dumps(payload, ensure_ascii=False, default=str)
                .encode("utf-8"),
                compresslevel=6,
            )
            self._conn().execute(
                "INSERT OR REPLACE INTO caches "
                "(cache_name, fingerprint, payload_gz, ts) VALUES (?, ?, ?, ?)",
                (cache_name, fingerprint, blob, int(time.time())),
            )
        except Exception as e:
            # Persistent cache is best-effort — never crash the caller.
            print(f"[cache] save failed for {cache_name}: {e}")

    def invalidate(self, cache_name: str) -> None:
        """Delete the entry for *cache_name* (best-effort)."""
        try:
            self._conn().execute(
                "DELETE FROM caches WHERE cache_name = ?", (cache_name,))
        except Exception as e:
            print(f"[cache] invalidate failed for {cache_name}: {e}")

    def clear_all(self) -> None:
        """Delete every entry (best-effort)."""
        try:
            self._conn().execute("DELETE FROM caches")
        except Exception as e:
            print(f"[cache] clear_all failed: {e}")

    def stats(self) -> Dict[str, Any]:
        """Summarize the stored entries.

        Returns ``{"n_caches", "entries", "db_bytes"}``:

        - Each entry has ``name``, ``fingerprint`` (first 40 chars),
          ``size_bytes`` (compressed payload size) and ``ts``.
        - ``db_bytes`` is the size of the main DB file only; the WAL/SHM
          side files are not counted.

        On any failure, returns ``{"error": <message>}`` instead.
        """
        try:
            rows = self._conn().execute(
                "SELECT cache_name, fingerprint, LENGTH(payload_gz), ts "
                "FROM caches ORDER BY cache_name").fetchall()
            return {
                "n_caches": len(rows),
                "entries": [
                    {"name": name, "fingerprint": fp[:40],
                     "size_bytes": size, "ts": ts}
                    for name, fp, size, ts in rows
                ],
                "db_bytes": (self.path.stat().st_size
                             if self.path.exists() else 0),
            }
        except Exception as e:
            return {"error": str(e)}


def fingerprint_corpus(pipe) -> str:
    """Return a fingerprint that changes when the set of doc ids changes.

    Reads ``pipe._indexed_docs``, falling back to ``pipe._docs`` when the
    former is missing or empty. Returns ``"empty"`` when there are no docs.
    Otherwise returns ``"n=<count>;h=<hash>"``, where the hash is the first
    12 hex chars of the md5 of the sorted doc ids joined with ``"|"``.

    Notes:

    - The fingerprint does not depend on doc order.
    - A doc whose ``id`` is missing or falsy contributes an empty string.
    - Content edits that leave every id unchanged are NOT detected.
    """
    docs = (getattr(pipe, "_indexed_docs", None)
            or getattr(pipe, "_docs", None)
            or [])
    n = len(docs)
    if n == 0:
        return "empty"
    ids = sorted(str(getattr(d, "id", "") or "") for d in docs)
    h = hashlib.md5("|".join(ids).encode("utf-8")).hexdigest()[:12]
    return f"n={n};h={h}"


__all__ = [
    "PersistentCache", "get_cache_store", "fingerprint_corpus",
]