# legal-eye / tau_rag/storage/retriever_persistence.py
# Author: Legal-i
# Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
# Commit: 3be54c6 (verified)
"""Retriever index persistence — save/load BM25/Graph/Gematria/Hilbert
state to disk so server restart doesn't rebuild from scratch.
Without this: 50k corpus → 60-120s startup (chunking + indexing 5+ retrievers).
With this: 60s on first run, then <5s on every subsequent restart (loads
indexes from pickle files).
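Storage layout: tau_rag/runtime/retriever_state/
manifest.json — corpus fingerprint + per-file save status and sizes
indexed_docs.pkl — pickled list of indexed docs (lets restart skip the chunker)
bm25.pkl — pickled state dict (one `<name>.pkl` per registered retriever)
graph.pkl
gematria.pkl
hilbert.npz — numpy ndarray (more efficient)
hilbert_index.bin — HNSW binary
Pickles are gzip-compressed unless TAU_RAG_PICKLE_GZIP=0. A `<name>.pkl.new`
sibling appears when the original file could not be overwritten; it is
preferred on load when it is newer.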
Cache invalidation: each save records the corpus fingerprint (count +
hash of doc_ids). On load, if fingerprint differs → skip load, force rebuild.
"""
from __future__ import annotations
import gzip
import json
import os
import pickle
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
# gzip magic bytes — used to auto-detect compressed files on load
_GZIP_MAGIC = b"\x1f\x8b"


def _gzip_level_from_env(default: int = 6) -> int:
    """Return the gzip compression level from TAU_RAG_PICKLE_GZIP_LEVEL.

    zlib only accepts levels 0-9. A non-integer value used to crash the
    import. An out-of-range value made every gzip write fail, and that
    error is swallowed by the save path, so nothing was ever persisted.
    Both cases now fall back to *default* with a warning.
    """
    raw = os.environ.get("TAU_RAG_PICKLE_GZIP_LEVEL", str(default))
    try:
        level = int(raw)
    except ValueError:
        level = -1  # forces the fallback below
    if not 0 <= level <= 9:
        print(f"[persistence] invalid TAU_RAG_PICKLE_GZIP_LEVEL={raw!r}; "
              f"using {default}")
        return default
    return level


# Compress level: 6 = balanced (fast enough, ~2-4× shrink on Hebrew text).
# Set TAU_RAG_PICKLE_GZIP=0 to disable (debugging, profiling).
_USE_GZIP = os.environ.get("TAU_RAG_PICKLE_GZIP", "1") != "0"
_GZIP_LEVEL = _gzip_level_from_env()
def _dump_pickle(obj: Any, path: Path) -> None:
    """Pickle *obj* to *path*, optionally gzip-compressed, replacing it atomically.

    The payload is first written to a ``<name>.part`` sibling, flushed and
    fsync'ed, then moved over *path* with ``os.replace``. Readers therefore
    see either the previous file or the complete new one, never a truncated
    pickle. On any failure the partial sibling is removed, the error is
    re-raised, and *path* is left untouched.

    Compression follows the module settings ``_USE_GZIP`` / ``_GZIP_LEVEL``.
    ``_load_pickle`` detects either format by its magic bytes.
    """
    part = path.with_name(path.name + ".part")
    try:
        with part.open("wb") as raw:
            if _USE_GZIP:
                # The GzipFile must be closed (trailer written) before fsync.
                with gzip.GzipFile(fileobj=raw, mode="wb",
                                   compresslevel=_GZIP_LEVEL) as gz:
                    pickle.dump(obj, gz, protocol=pickle.HIGHEST_PROTOCOL)
            else:
                pickle.dump(obj, raw, protocol=pickle.HIGHEST_PROTOCOL)
            raw.flush()
            os.fsync(raw.fileno())
        os.replace(part, path)
    except BaseException:
        # Never leave a half-written .part file behind.
        part.unlink(missing_ok=True)
        raise
def _load_pickle(path: Path) -> Any:
    """Read a pickle written by `_dump_pickle`, compressed or not.

    The first two bytes are compared against the gzip magic number, so
    plain (uncompressed) pickles written before gzip support still load.
    Note: unpickling runs arbitrary code, so only use this on state files
    this service wrote itself.
    """
    with path.open("rb") as probe:
        is_gzipped = probe.read(2) == _GZIP_MAGIC
    opener = gzip.open if is_gzipped else open
    with opener(str(path), "rb") as stream:
        return pickle.load(stream)
# Default state directory: <grandparent of this file>/runtime/retriever_state
# (i.e. tau_rag/runtime/retriever_state for tau_rag/storage/this_file.py).
_DEFAULT_DIR = Path(__file__).resolve().parents[1] / "runtime" / "retriever_state"
def _ensure_dir(path: Path) -> None:
    """Create *path* and any missing parents; a no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
def save_retriever_state(retriever, path: Path) -> bool:
    """Persist a single retriever's state to a pickle file.

    The state comes from ``retriever.state_dict()`` when the retriever
    provides one, otherwise from `_extract_state_generic`. Returns True on
    success (including the ``.new`` fallback below) and False on any error.
    Errors are printed, never raised.

    The pickle is written to ``<path>.tmp`` and then renamed over *path*, so
    a crash mid-write never leaves a corrupt *path*. If *path* cannot be
    replaced (immutable flag, permissions), the fresh state is moved to a
    ``<path>.new`` sibling instead. The user can ``mv`` it over manually,
    and the loader prefers it when it is newer. After a successful
    overwrite, any leftover ``.new`` sibling is stale and is deleted.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    new_sibling = path.with_suffix(path.suffix + ".new")
    try:
        if hasattr(retriever, "state_dict") and callable(retriever.state_dict):
            state = retriever.state_dict()
        else:
            state = _extract_state_generic(retriever)
        if state is None:
            return False
        try:
            _dump_pickle(state, tmp)
        except BaseException:
            # Don't leave a half-written .tmp behind.
            tmp.unlink(missing_ok=True)
            raise
        try:
            tmp.replace(path)  # atomic rename
        except OSError as e:  # PermissionError is a subclass of OSError
            # Destination locked (immutable flag, etc) — leave .new sibling
            try:
                tmp.replace(new_sibling)
            except Exception:
                tmp.unlink(missing_ok=True)
                raise
            print(f"[persistence] couldn't overwrite {path.name} ({e}); "
                  f"wrote {new_sibling.name} instead — you can `mv` it manually")
            return True
        # *path* is now the freshest state; a .new left by an earlier
        # fallback save must not shadow it on the next load.
        try:
            new_sibling.unlink(missing_ok=True)
        except OSError:
            pass
        return True
    except Exception as e:
        print(f"[persistence] save failed for {retriever.__class__.__name__}: {e}")
        return False
def _resolve_load_path(path: Path) -> Optional[Path]:
    """Decide which file to load for *path*.

    The ``<name>.new`` sibling (written when *path* could not be
    overwritten) wins in three cases:
    - *path* does not exist;
    - the sibling's mtime is strictly newer;
    - the two mtimes cannot be compared.
    Otherwise *path* is returned if it exists. Returns None when neither
    file exists.
    """
    sibling = path.with_suffix(path.suffix + ".new")
    original_exists = path.exists()
    if sibling.exists():
        if not original_exists:
            return sibling
        try:
            sibling_is_newer = sibling.stat().st_mtime > path.stat().st_mtime
        except OSError:
            sibling_is_newer = True  # can't compare → trust the fallback file
        if sibling_is_newer:
            return sibling
    return path if original_exists else None
def load_retriever_state(retriever, path: Path) -> bool:
    """Restore a retriever's state from a pickle file.

    Loads ``<path>.new`` instead when `_resolve_load_path` prefers it (the
    fallback written when *path* could not be overwritten). Returns True on
    success. Returns False when no file exists, the pickle is rejected, or
    any error occurs; errors are printed, not raised.

    Pickles produced by `_extract_state_generic` hold the raw ``__dict__``
    (e.g. a ``"_chunks"`` key). Fed to a ``load_state_dict()`` that expects
    plain keys, such a dict would silently leave the retriever empty, so it
    is rejected for retrievers that define ``load_state_dict``. Retrievers
    without one are restored by `_restore_state_generic`, which is exactly
    the inverse of that format, so for them it is accepted.

    Security: unpickling runs arbitrary code; only load files this service
    wrote.
    """
    resolved = _resolve_load_path(path)
    if resolved is None:
        return False
    try:
        # _load_pickle auto-detects gzip vs raw pickle (backward compat)
        state = _load_pickle(resolved)
        has_loader = (hasattr(retriever, "load_state_dict")
                      and callable(retriever.load_state_dict))
        if has_loader:
            if isinstance(state, dict) and "_chunks" in state:
                print(f"[persistence] old-format pickle at {resolved.name} — "
                      f"ignoring, will rebuild")
                return False
            retriever.load_state_dict(state)
        else:
            _restore_state_generic(retriever, state)
        # _resolve_load_path returns either *path* or its .new sibling, so
        # any other result means the sibling override was used.
        if resolved != path:
            print(f"[persistence] loaded {resolved.name} (sibling override)")
        return True
    except Exception as e:
        print(f"[persistence] load failed for {retriever.__class__.__name__}: {e}")
        return False
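# Generic fallback for retrievers without state_dict()/load_state_dict():
# snapshots every picklable, non-callable instance attribute (public and
# "_*" alike), skipping "_executor*" attributes.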
def _extract_state_generic(retriever) -> Dict[str, Any]:
    """Snapshot a retriever's instance attributes into a picklable dict.

    The result always starts with ``"_class"`` (the retriever's class
    name). Every entry of ``retriever.__dict__`` is copied except:
    - names starting with ``"_executor"``;
    - callables;
    - values that fail a trial ``pickle.dumps`` (e.g. thread locks).
    The trial pickle serialises every kept value once more, so large
    indexes pay the pickling cost twice.
    """
    def _picklable(value: Any) -> bool:
        try:
            pickle.dumps(value)  # quick pickle test
        except Exception:
            return False
        return True

    snapshot: Dict[str, Any] = {"_class": retriever.__class__.__name__}
    for attr, value in retriever.__dict__.items():
        if attr.startswith("_executor") or callable(value):
            continue
        if _picklable(value):
            snapshot[attr] = value
    return snapshot
def _restore_state_generic(retriever, state: Dict[str, Any]) -> None:
    """Inverse of `_extract_state_generic`.

    Sets every saved attribute back on *retriever*, skipping the
    ``"_class"`` marker. Attributes whose assignment raises are skipped
    silently.
    """
    restorable = ((key, value) for key, value in state.items() if key != "_class")
    for attr, value in restorable:
        try:
            setattr(retriever, attr, value)
        except Exception:
            continue
# ============================================================================
# Top-level: persist/load all retrievers in a MultiRetriever at once
# ============================================================================
def fingerprint_corpus(pipe) -> str:
    """Stable fingerprint of the indexed corpus.

    Reads ``pipe._indexed_docs``, falling back to ``pipe._docs`` when that
    is missing or empty.
    - With no documents, returns ``"empty"``.
    - Otherwise returns ``"n=<count>;h=<hash>"``. The hash is the first 12
      hex digits of the MD5 of the sorted, ``|``-joined document ids, so it
      does not depend on document order.
    Documents without an ``id`` contribute an empty string. MD5 serves only
    as a cheap change detector here, not for security.
    """
    import hashlib

    docs = getattr(pipe, "_indexed_docs", None) or getattr(pipe, "_docs", None) or []
    count = len(docs)
    if not count:
        return "empty"
    joined = "|".join(sorted(str(getattr(doc, "id", "") or "") for doc in docs))
    digest = hashlib.md5(joined.encode("utf-8")).hexdigest()
    return f"n={count};h={digest[:12]}"
# Serialises whole-directory saves. The debounced timer and
# flush_pending_save can otherwise run save_all_retrievers concurrently and
# race on the same .tmp files and on manifest.json.
_SAVE_LOCK = threading.Lock()


def _file_size(path: Optional[Path]) -> int:
    """Size of *path* in bytes, or 0 if it is None or cannot be stat'ed."""
    if path is None:
        return 0
    try:
        return path.stat().st_size
    except OSError:
        return 0


def _save_pickle_with_fallback(obj: Any, path: Path) -> Path:
    """Write *obj* to *path* through a ``.tmp`` sibling and an atomic rename.

    This mirrors the retriever save path. If *path* cannot be replaced, the
    data is left in ``<path>.new``, which the loader prefers when newer, and
    that path is returned. Otherwise any stale ``.new`` is removed and
    *path* is returned. Other failures re-raise after the temp file is
    removed.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    new_sibling = path.with_suffix(path.suffix + ".new")
    try:
        _dump_pickle(obj, tmp)
        try:
            tmp.replace(path)
        except OSError:
            tmp.replace(new_sibling)
            return new_sibling
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    try:
        new_sibling.unlink(missing_ok=True)
    except OSError:
        pass
    return path


def save_all_retrievers(pipe, base_dir: Optional[Path] = None) -> Dict[str, Any]:
    """Save every registered retriever's state plus the indexed docs.

    Writes ``<name>.pkl`` per entry of ``pipe.retrievers._retrievers``
    (unwrapping ``_inner`` wrappers), then ``indexed_docs.pkl``, and last
    ``manifest.json``. The manifest holds the corpus fingerprint and the
    per-file save status and sizes.

    The old manifest is removed before any pickle is touched, and the new
    one is written atomically at the end. An interrupted save therefore
    leaves no manifest, and the next load rebuilds instead of mixing old and
    new state under the old fingerprint. Saves are serialised by a module
    lock.

    Returns the manifest dict, or ``{"error": "no MultiRetriever"}`` when
    *pipe* has no retriever registry. Errors writing the manifest propagate.
    """
    base_dir = base_dir or _DEFAULT_DIR
    _ensure_dir(base_dir)
    multi = getattr(pipe, "retrievers", None)
    if multi is None or not hasattr(multi, "_retrievers"):
        return {"error": "no MultiRetriever"}

    with _SAVE_LOCK:
        manifest_path = base_dir / "manifest.json"
        try:
            manifest_path.unlink(missing_ok=True)
        except OSError:
            pass  # e.g. immutable manifest — the final write will surface it

        manifest: Dict[str, Any] = {
            "fingerprint": fingerprint_corpus(pipe),
            "saved_at": int(time.time()),
            "retrievers": {},
        }
        for name, r in multi._retrievers.items():
            # The retriever may be wrapped (ReferenceAware) — unwrap to the inner
            inner = getattr(r, "_inner", r)
            path = base_dir / f"{name}.pkl"
            ok = save_retriever_state(inner, path)
            # Report the file that will actually be loaded (.new fallback included).
            written = _resolve_load_path(path) if ok else None
            manifest["retrievers"][name] = {
                "class": inner.__class__.__name__,
                "saved": ok,
                "size_bytes": _file_size(written),
            }

        # Persist the indexed_docs list separately so we can skip the chunker
        # on next startup (the heavy part). Same atomic + .new fallback as
        # the retriever pickles, so docs never lag behind retriever state.
        docs_path = base_dir / "indexed_docs.pkl"
        docs = (getattr(pipe, "_indexed_docs", None)
                or getattr(pipe, "_docs", None) or [])
        try:
            written_docs = _save_pickle_with_fallback(docs, docs_path)
            if written_docs != docs_path:
                print(f"[persistence] couldn't overwrite {docs_path.name}; "
                      f"wrote {written_docs.name} instead")
            manifest["indexed_docs"] = {
                "saved": True, "n_docs": len(docs),
                "size_bytes": _file_size(written_docs),
            }
        except Exception as e:
            manifest["indexed_docs"] = {"saved": False, "error": str(e)}

        # Write the manifest last and atomically: it marks the save as complete.
        tmp_manifest = manifest_path.with_name(manifest_path.name + ".tmp")
        tmp_manifest.write_text(
            json.dumps(manifest, ensure_ascii=False, indent=2),
            encoding="utf-8")
        tmp_manifest.replace(manifest_path)
    return manifest
def load_all_retrievers(pipe, base_dir: Optional[Path] = None,
                        expected_fingerprint: Optional[str] = None
                        ) -> Optional[Dict[str, Any]]:
    """Try to load all retriever states plus the indexed docs.

    Returns the manifest, extended with ``"n_loaded"``, on success. Returns
    None when any of these holds:
    - the manifest is missing or unreadable;
    - the fingerprint differs from *expected_fingerprint* (when given);
    - any retriever or the docs list fails to load;
    - the manifest does not record them as saved in the last run.
    The last case prevents loading a stale pickle left over from an earlier
    save. On None the caller is expected to rebuild; note that retrievers
    loaded before the failure may already hold restored state.
    """
    base_dir = base_dir or _DEFAULT_DIR
    manifest_path = base_dir / "manifest.json"
    if not manifest_path.exists():
        return None
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(manifest, dict):
        return None
    if expected_fingerprint and manifest.get("fingerprint") != expected_fingerprint:
        return None  # corpus changed → caller should rebuild
    multi = getattr(pipe, "retrievers", None)
    if multi is None or not hasattr(multi, "_retrievers"):
        return None

    saved_entries = manifest.get("retrievers")
    if not isinstance(saved_entries, dict):
        saved_entries = {}
    n_total = len(multi._retrievers)
    n_loaded = 0
    failures: List[str] = []
    for name, r in multi._retrievers.items():
        entry = saved_entries.get(name)
        # A pickle whose last save failed (or that was not part of it) is
        # stale relative to this manifest's fingerprint — don't load it.
        if not (isinstance(entry, dict) and entry.get("saved")):
            failures.append(name)
            continue
        inner = getattr(r, "_inner", r)
        if load_retriever_state(inner, base_dir / f"{name}.pkl"):
            n_loaded += 1
        else:
            failures.append(name)
    # If any retriever failed to load, abort — caller will rebuild from scratch.
    # A partial load would leave some retrievers empty and produce broken
    # results without any obvious symptom (search just returns nothing).
    if n_loaded < n_total:
        print(f"[persistence] aborting load: only {n_loaded}/{n_total} "
              f"retrievers loaded (failed: {failures}). Will rebuild.")
        return None

    # Restore indexed_docs (required — without it the pipeline can't return
    # canonical docs for /v1/query). Same .new-sibling preference as above.
    docs_entry = manifest.get("indexed_docs")
    if not (isinstance(docs_entry, dict) and docs_entry.get("saved")):
        print("[persistence] indexed_docs not saved in last run — will rebuild")
        return None
    docs_path = base_dir / "indexed_docs.pkl"
    docs_resolved = _resolve_load_path(docs_path)
    if docs_resolved is None:
        print("[persistence] indexed_docs missing — will rebuild")
        return None
    try:
        docs = _load_pickle(docs_resolved)
        indexed_ids = {getattr(d, "id", None) for d in docs
                       if getattr(d, "id", None)}
        pipe._indexed_docs = docs
        pipe._indexed_ids = indexed_ids
        if docs_resolved != docs_path:
            print(f"[persistence] loaded {docs_resolved.name} (sibling override)")
    except Exception as e:
        print(f"[persistence] indexed_docs load failed: {e}")
        return None
    return {**manifest, "n_loaded": n_loaded}
# ============================================================================
# Debounced background save — coalesces rapid add_documents calls
# ============================================================================
import threading as _threading
import time as _time
# Debouncer state shared by schedule_save / cancel_pending_save /
# flush_pending_save / debouncer_status.
_LAST_SAVE_AT: float = 0.0  # time of the last successful save; 0.0 = never
_LAST_SAVE_RESULT: Optional[Dict[str, Any]] = None  # manifest from that save
_DEBOUNCE_TIMER: Optional[_threading.Timer] = None  # pending/running timer
_DEBOUNCE_LOCK = _threading.Lock()  # held while the timer is replaced/cancelled
def schedule_save(pipe, delay_seconds: float = 60.0,
                  min_docs: int = 100) -> bool:
    """Schedule a background save *delay_seconds* from now, coalescing any
    pending save.

    Each call cancels the previous not-yet-fired timer and starts a new one.
    A burst of add_documents() calls within *delay_seconds* therefore saves
    only once, at the end of the burst.

    Returns True if a save was scheduled. Returns False if it was skipped:
    the corpus (``pipe._indexed_docs``) has fewer than *min_docs* documents,
    or its size cannot be read.
    """
    global _DEBOUNCE_TIMER
    try:
        n_docs = len(getattr(pipe, "_indexed_docs", []) or [])
    except Exception:
        return False
    if n_docs < min_docs:
        return False

    def _run() -> None:
        global _DEBOUNCE_TIMER, _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] debounced save: {n_saved} retrievers "
                  f"in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
        except Exception as e:
            print(f"[persistence] debounced save failed: {e}")
        finally:
            # Clear the slot only if it still holds *this* timer. A
            # schedule_save() during this save has already installed a newer
            # timer, which must stay visible and cancellable.
            with _DEBOUNCE_LOCK:
                if _DEBOUNCE_TIMER is timer:
                    _DEBOUNCE_TIMER = None

    with _DEBOUNCE_LOCK:
        if _DEBOUNCE_TIMER is not None:
            # No-op if that timer already fired (its save is running).
            _DEBOUNCE_TIMER.cancel()
        timer = _threading.Timer(delay_seconds, _run)
        timer.daemon = True
        _DEBOUNCE_TIMER = timer
        timer.start()
    return True
def cancel_pending_save() -> bool:
    """Cancel the pending debounced save, if there is one.

    Returns True when a timer was cancelled and cleared. Returns False when
    there was nothing to cancel, or when cancelling raised; in that case the
    timer reference is left in place.
    """
    global _DEBOUNCE_TIMER
    with _DEBOUNCE_LOCK:
        pending = _DEBOUNCE_TIMER
        if pending is None:
            return False
        try:
            pending.cancel()
        except Exception:
            return False
        _DEBOUNCE_TIMER = None
        return True
def flush_pending_save(pipe, timeout_s: float = 120.0) -> Dict[str, Any]:
    """Synchronously run the save NOW, for graceful shutdown handlers.

    Any pending (not yet fired) debounced timer is cancelled first. The save
    then runs in a daemon thread, and this call waits at most *timeout_s*
    for it, so an exit can't hang indefinitely.

    Returns the saved manifest, or a dict with an ``"error"`` key when the
    save raised or timed out. On timeout the worker keeps running in the
    background. The returned dict is then a fresh one, so the worker's later
    writes cannot change what the caller already holds.
    """
    cancel_pending_save()
    result: Dict[str, Any] = {}
    completed = _threading.Event()

    def _run() -> None:
        global _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] flush_pending_save: {n_saved} retrievers "
                  f"saved in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
            result.update(manifest)
        except Exception as e:
            print(f"[persistence] flush_pending_save failed: {e}")
            result["error"] = str(e)
        finally:
            completed.set()

    worker = _threading.Thread(target=_run, daemon=True,
                               name="persistence-flush")
    worker.start()
    if not completed.wait(timeout_s):
        print("[persistence] flush timed out — partial save likely on disk")
        # Don't hand out `result`: the still-running worker may write into it.
        return {"error": f"flush timed out after {timeout_s}s"}
    return result
def debouncer_status() -> Dict[str, Any]:
    """Inspect the debouncer for /v1/admin/persistence/status.

    ``pending_save`` is True while a scheduled timer has neither been
    cancelled nor finished running its save. ``last_save_at`` is the epoch
    second of the last successful save (0 if none). ``last_save_age_s`` is
    the seconds since then (None if none).

    The globals are read once into locals. Another thread may clear
    ``_DEBOUNCE_TIMER`` concurrently, and re-reading it mid-expression could
    dereference None.
    """
    timer = _DEBOUNCE_TIMER
    last_at = _LAST_SAVE_AT
    finished = getattr(timer, "finished", None)  # None when timer is None
    pending = finished is not None and not finished.is_set()
    return {
        "pending_save": bool(pending),
        "last_save_at": int(last_at) if last_at else 0,
        "last_save_age_s": int(_time.time() - last_at) if last_at else None,
    }
# Public API of the persistence module (order kept for stability).
__all__ = [
    # one-shot save / load of every retriever
    "save_all_retrievers",
    "load_all_retrievers",
    # single-retriever primitives
    "save_retriever_state",
    "load_retriever_state",
    # cache invalidation
    "fingerprint_corpus",
    # debounced background saving
    "schedule_save",
    "cancel_pending_save",
    "debouncer_status",
    "flush_pending_save",
]