| """Retriever index persistence — save/load BM25/Graph/Gematria/Hilbert |
| state to disk so server restart doesn't rebuild from scratch. |
| |
| Without this: 50k corpus → 60-120s startup (chunking + indexing 5+ retrievers). |
| With this: 60s on first run, then <5s on every subsequent restart (loads |
| indexes from pickle files). |
| |
| Storage layout: tau_rag/runtime/retriever_state/ |
| manifest.json — corpus fingerprint + file index |
| bm25.pkl — pickled state dict |
| graph.pkl |
| gematria.pkl |
| hilbert.npz — numpy ndarray (more efficient) |
| hilbert_index.bin — HNSW binary |
| |
| Cache invalidation: each save records the corpus fingerprint (count + |
| hash of doc_ids). On load, if fingerprint differs → skip load, force rebuild. |
| """ |
| from __future__ import annotations |
|
|
| import gzip |
| import json |
| import os |
| import pickle |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| |
# First two bytes of every gzip stream; the loader uses them to tell a
# compressed pickle from a raw one.
_GZIP_MAGIC = b"\x1f\x8b"


def _read_gzip_level(raw: Optional[str], default: int = 6) -> int:
    """Parse a gzip compression level from an environment string.

    Args:
        raw: The raw environment value, or None when the variable is unset.
        default: Level to use when *raw* is unset or unusable.

    Returns:
        The parsed level when it is an integer in -1..9 (the values zlib
        accepts); otherwise *default*. A bad value is reported and
        replaced instead of raising, so a typo in the environment cannot
        make importing this module fail.
    """
    if raw is None:
        return default
    try:
        level = int(raw)
    except (TypeError, ValueError):
        level = None
    if level is None or not -1 <= level <= 9:
        print(f"[persistence] ignoring invalid TAU_RAG_PICKLE_GZIP_LEVEL="
              f"{raw!r}; using {default}")
        return default
    return level


# Compression is on unless TAU_RAG_PICKLE_GZIP is exactly "0".
_USE_GZIP = os.environ.get("TAU_RAG_PICKLE_GZIP", "1") != "0"
_GZIP_LEVEL = _read_gzip_level(os.environ.get("TAU_RAG_PICKLE_GZIP_LEVEL"))
|
|
|
|
def _dump_pickle(obj: Any, path: Path) -> None:
    """Pickle *obj* to *path* atomically, gzip-compressed unless disabled.

    The payload is written to a sibling ``<name>.partial`` file first and
    only renamed over *path* once it is complete, so a failed or
    interrupted dump never truncates an existing *path*.

    Args:
        obj: Any picklable object.
        path: Destination file.

    Raises:
        Whatever pickling, file I/O or the final rename raises. The
        partial file is removed (best effort) before the error propagates.
    """
    partial = path.with_name(path.name + ".partial")
    try:
        if _USE_GZIP:
            with gzip.open(str(partial), "wb",
                           compresslevel=_GZIP_LEVEL) as f:
                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with partial.open("wb") as f:
                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        # Publish the finished file in a single rename.
        os.replace(partial, path)
    except BaseException:
        # Do not leave a half-written file behind; keep the original error.
        try:
            partial.unlink()
        except OSError:
            pass
        raise
|
|
|
|
def _load_pickle(path: Path) -> Any:
    """Unpickle *path*, transparently handling gzip-compressed files.

    The first two bytes are compared with the gzip magic number; a match
    means the file is read through ``gzip``, anything else is treated as
    a raw pickle (keeps uncompressed files from older versions loadable).

    NOTE: unpickling can execute arbitrary code — only point this at
    files this module wrote itself.
    """
    with path.open("rb") as sniff:
        compressed = sniff.read(2) == _GZIP_MAGIC
    opener = gzip.open if compressed else open
    with opener(str(path), "rb") as stream:
        return pickle.load(stream)
|
|
|
|
# Default state directory: "runtime/retriever_state" under the directory
# one level above the one containing this file.
_DEFAULT_DIR = Path(__file__).resolve().parent.parent.joinpath(
    "runtime", "retriever_state")
|
|
|
|
def _ensure_dir(path: Path) -> None:
    """Create directory *path* plus any missing parents.

    Does nothing when the directory already exists.
    """
    os.makedirs(path, exist_ok=True)
|
|
|
|
def save_retriever_state(retriever, path: Path) -> bool:
    """Persist one retriever's state to *path*; never raises.

    The state comes from ``retriever.state_dict()`` when the retriever has
    a callable ``state_dict``, otherwise from ``_extract_state_generic``.
    It is pickled to ``<path>.tmp`` and then renamed over *path*, so a
    crash mid-write never leaves a truncated *path*. If that rename fails
    with an ``OSError`` (e.g. the destination is immutable), the temp file
    is renamed to a sibling ``<path>.new`` instead so it can be swapped in
    by hand.

    Args:
        retriever: Object whose state should be saved.
        path: Destination pickle file.

    Returns:
        True when the state was written to *path* or to the ``.new``
        sibling; False when the state was ``None`` or any step failed
        (the error is printed).
    """
    try:
        state_fn = getattr(retriever, "state_dict", None)
        if callable(state_fn):
            state = state_fn()
        else:
            state = _extract_state_generic(retriever)
        if state is None:
            return False

        tmp = path.with_suffix(path.suffix + ".tmp")
        try:
            _dump_pickle(state, tmp)
        except BaseException:
            # A failed dump may have left a partial temp file; remove it so
            # failed saves do not accumulate garbage next to the real state.
            try:
                tmp.unlink(missing_ok=True)
            except OSError:
                pass
            raise

        try:
            tmp.replace(path)
        except OSError as e:  # PermissionError is a subclass of OSError
            new_sibling = path.with_suffix(path.suffix + ".new")
            try:
                tmp.replace(new_sibling)
                print(f"[persistence] couldn't overwrite {path.name} ({e}); "
                      f"wrote {new_sibling.name} instead — you can `mv` it manually")
            except Exception:
                tmp.unlink(missing_ok=True)
                raise
        return True
    except Exception as e:
        print(f"[persistence] save failed for {retriever.__class__.__name__}: {e}")
        return False
|
|
|
|
def _resolve_load_path(path: Path) -> Optional[Path]:
    """Choose which file to load for *path*.

    A sibling ``<path>.new`` wins when *path* itself is missing, when the
    sibling's mtime is strictly newer than *path*'s, or when the mtimes
    cannot be read. Otherwise *path* is used.

    Returns:
        The file to load, or None when neither file exists.
    """
    override = path.with_suffix(path.suffix + ".new")
    if not override.exists():
        return path if path.exists() else None
    if not path.exists():
        return override
    try:
        override_is_newer = override.stat().st_mtime > path.stat().st_mtime
    except OSError:
        return override
    if override_is_newer:
        return override
    return path if path.exists() else None
|
|
|
|
def load_retriever_state(retriever, path: Path) -> bool:
    """Restore one retriever from its pickle file; never raises.

    The file is chosen by ``_resolve_load_path`` (a newer ``<path>.new``
    sibling takes precedence). A dict containing the key ``"_chunks"`` is
    treated as an old-format pickle and rejected, because loading it
    would silently leave the retriever empty. Otherwise the state goes to
    ``retriever.load_state_dict`` when that is callable, or is copied
    attribute by attribute via ``_restore_state_generic``.

    Returns:
        True when the state was applied; False when no file exists, the
        pickle is old-format, or loading raised (the error is printed).
    """
    source = _resolve_load_path(path)
    if source is None:
        return False
    try:
        state = _load_pickle(source)

        # Old-format pickles stored private attributes such as "_chunks".
        is_legacy = isinstance(state, dict) and "_chunks" in state
        if is_legacy:
            print(f"[persistence] old-format pickle at {source.name} — "
                  f"ignoring, will rebuild")
            return False

        loader = getattr(retriever, "load_state_dict", None)
        if callable(loader):
            loader(state)
        else:
            _restore_state_generic(retriever, state)

        # Any name ending in ".pkl.new" also has the suffix ".new".
        if source.suffix == ".new":
            print(f"[persistence] loaded {source.name} (sibling override)")
        return True
    except Exception as e:
        print(f"[persistence] load failed for {retriever.__class__.__name__}: {e}")
        return False
|
|
|
|
| |
def _extract_state_generic(retriever) -> Dict[str, Any]:
    """Build a state dict from the retriever's instance attributes.

    The result always starts with ``"_class"`` (the class name), followed
    by every entry of ``retriever.__dict__`` that is kept. An attribute is
    dropped when its name starts with ``"_executor"``, when its value is
    callable, or when a trial ``pickle.dumps`` of the value fails.
    """
    def _is_storable(name: str, value: Any) -> bool:
        if name.startswith("_executor") or callable(value):
            return False
        try:
            pickle.dumps(value)
        except Exception:
            return False
        return True

    snapshot: Dict[str, Any] = {"_class": retriever.__class__.__name__}
    snapshot.update(
        (name, value)
        for name, value in retriever.__dict__.items()
        if _is_storable(name, value)
    )
    return snapshot
|
|
|
|
def _restore_state_generic(retriever, state: Dict[str, Any]) -> None:
    """Copy every entry of *state* onto *retriever* as an attribute.

    The ``"_class"`` marker is skipped. An attribute that cannot be set
    is ignored, so one failure does not stop the rest from being restored.
    """
    to_apply = ((key, value) for key, value in state.items()
                if key != "_class")
    for attr, value in to_apply:
        try:
            setattr(retriever, attr, value)
        except Exception:
            continue
|
|
|
|
| |
| |
| |
|
|
def fingerprint_corpus(pipe) -> str:
    """Return a stable fingerprint of the corpus held by *pipe*.

    Documents are read from ``pipe._indexed_docs``, falling back to
    ``pipe._docs``. The result is ``"empty"`` when there are none,
    otherwise ``"n=<count>;h=<hash>"`` where the hash is the first 12 hex
    digits of the MD5 of the sorted, ``|``-joined document ids (a missing
    or falsy id counts as the empty string). Document order does not
    affect the result.
    """
    import hashlib

    corpus = (getattr(pipe, "_indexed_docs", None)
              or getattr(pipe, "_docs", None)
              or [])
    count = len(corpus)
    if not count:
        return "empty"
    doc_ids = sorted(str(getattr(doc, "id", "") or "") for doc in corpus)
    digest = hashlib.md5("|".join(doc_ids).encode("utf-8")).hexdigest()
    return f"n={count};h={digest[:12]}"
|
|
|
|
def save_all_retrievers(pipe, base_dir: Optional[Path] = None) -> Dict[str, Any]:
    """Save every registered retriever plus the indexed docs.

    The manifest is the commit record for the whole directory, so the
    previous manifest is removed *before* any state file is overwritten
    and the new one is published atomically at the very end. A crash (or
    a killed daemon thread) mid-save therefore leaves no manifest, and
    the next start rebuilds instead of loading a mix of old and new files.

    Args:
        pipe: Pipeline exposing ``retrievers._retrievers`` (name -> retriever)
            and ``_indexed_docs`` / ``_docs``.
        base_dir: Target directory; defaults to ``_DEFAULT_DIR``.

    Returns:
        The manifest dict written to ``manifest.json``, or
        ``{"error": "no MultiRetriever"}`` when *pipe* has no retrievers
        container.

    Raises:
        OSError: if the directory cannot be created or the manifest
            cannot be written.
    """
    base_dir = base_dir or _DEFAULT_DIR
    _ensure_dir(base_dir)
    fp = fingerprint_corpus(pipe)
    multi = getattr(pipe, "retrievers", None)
    if multi is None or not hasattr(multi, "_retrievers"):
        return {"error": "no MultiRetriever"}

    manifest_path = base_dir / "manifest.json"
    # Invalidate the old commit record first: from here on the files on
    # disk no longer match what it describes.
    try:
        manifest_path.unlink(missing_ok=True)
    except OSError as e:
        print(f"[persistence] couldn't invalidate old manifest: {e}")

    manifest: Dict[str, Any] = {
        "fingerprint": fp,
        "saved_at": int(time.time()),
        "retrievers": {},
    }
    for name, r in multi._retrievers.items():
        # Save the wrapped retriever when `r` exposes one via `_inner`.
        inner = getattr(r, "_inner", r)
        cls = inner.__class__.__name__
        path = base_dir / f"{name}.pkl"
        ok = save_retriever_state(inner, path)
        manifest["retrievers"][name] = {
            "class": cls,
            "saved": ok,
            "size_bytes": path.stat().st_size if ok and path.exists() else 0,
        }

    # The loader restores pipe._indexed_docs from this file.
    docs_path = base_dir / "indexed_docs.pkl"
    try:
        docs = (getattr(pipe, "_indexed_docs", None)
                or getattr(pipe, "_docs", None) or [])
        _dump_pickle(docs, docs_path)
        manifest["indexed_docs"] = {
            "saved": True, "n_docs": len(docs),
            "size_bytes": docs_path.stat().st_size,
        }
    except Exception as e:
        manifest["indexed_docs"] = {"saved": False, "error": str(e)}

    # Publish the manifest atomically so readers never see truncated JSON.
    manifest_tmp = manifest_path.with_name(manifest_path.name + ".tmp")
    try:
        manifest_tmp.write_text(
            json.dumps(manifest, ensure_ascii=False, indent=2),
            encoding="utf-8")
        manifest_tmp.replace(manifest_path)
    except BaseException:
        try:
            manifest_tmp.unlink(missing_ok=True)
        except OSError:
            pass
        raise
    return manifest
|
|
|
|
def load_all_retrievers(pipe, base_dir: Optional[Path] = None,
                        expected_fingerprint: Optional[str] = None
                        ) -> Optional[Dict[str, Any]]:
    """Load every retriever's state and the indexed docs from *base_dir*.

    The load is all-or-nothing from the caller's point of view: None means
    "rebuild". Before any retriever is touched, the manifest is checked;
    if it records ``"saved": false`` for a registered retriever or for the
    indexed docs, the load is abandoned. Without this check a stale pickle
    left over from an earlier save could be loaded under the current
    fingerprint.

    Args:
        pipe: Pipeline exposing ``retrievers._retrievers``; on success its
            ``_indexed_docs`` and ``_indexed_ids`` are replaced.
        base_dir: Directory to read; defaults to ``_DEFAULT_DIR``.
        expected_fingerprint: When given, must equal the manifest's
            ``fingerprint``.

    Returns:
        The manifest plus ``"n_loaded"`` on success; None when the manifest
        is missing, unreadable or not a JSON object, the fingerprint
        differs, the manifest marks something as unsaved, or any file
        fails to load.
    """
    base_dir = base_dir or _DEFAULT_DIR
    manifest_path = base_dir / "manifest.json"
    if not manifest_path.exists():
        return None
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(manifest, dict):
        # Valid JSON but not an object (e.g. "[]"): treat as corrupt.
        return None
    if expected_fingerprint and manifest.get("fingerprint") != expected_fingerprint:
        return None

    multi = getattr(pipe, "retrievers", None)
    if multi is None or not hasattr(multi, "_retrievers"):
        return None

    # Trust the manifest's own record of what the last save wrote. Entries
    # missing from the manifest are still attempted (older manifests).
    recorded = manifest.get("retrievers")
    if not isinstance(recorded, dict):
        recorded = {}
    unsaved = [name for name in multi._retrievers
               if isinstance(recorded.get(name), dict)
               and recorded[name].get("saved") is False]
    docs_entry = manifest.get("indexed_docs")
    if isinstance(docs_entry, dict) and docs_entry.get("saved") is False:
        unsaved.append("indexed_docs")
    if unsaved:
        print(f"[persistence] aborting load: manifest marks {unsaved} "
              f"as not saved. Will rebuild.")
        return None

    n_total = len(multi._retrievers)
    n_loaded = 0
    failures: List[str] = []
    for name, r in multi._retrievers.items():
        inner = getattr(r, "_inner", r)
        path = base_dir / f"{name}.pkl"
        if load_retriever_state(inner, path):
            n_loaded += 1
        else:
            failures.append(name)

    # A partially restored set of retrievers is not usable.
    if n_loaded < n_total:
        print(f"[persistence] aborting load: only {n_loaded}/{n_total} "
              f"retrievers loaded (failed: {failures}). Will rebuild.")
        return None

    docs_path = base_dir / "indexed_docs.pkl"
    docs_resolved = _resolve_load_path(docs_path)
    if docs_resolved is None:
        print("[persistence] indexed_docs missing — will rebuild")
        return None
    try:
        docs = _load_pickle(docs_resolved)
        pipe._indexed_docs = docs
        doc_ids = (getattr(d, "id", None) for d in docs)
        pipe._indexed_ids = {doc_id for doc_id in doc_ids if doc_id}
        if docs_resolved.name.endswith(".pkl.new"):
            print(f"[persistence] loaded {docs_resolved.name} (sibling override)")
    except Exception as e:
        print(f"[persistence] indexed_docs load failed: {e}")
        return None
    return {**manifest, "n_loaded": n_loaded}
|
|
|
|
| |
| |
| |
| import threading as _threading |
| import time as _time |
|
|
# Debouncer state shared by schedule_save(), cancel_pending_save(),
# flush_pending_save() and debouncer_status().

# Held by schedule_save() and cancel_pending_save() while they swap the
# pending timer.
_DEBOUNCE_LOCK = _threading.Lock()

# The currently scheduled save timer, or None when nothing is pending.
_DEBOUNCE_TIMER: Optional[_threading.Timer] = None

# time.time() of the most recent completed save; 0.0 means "never".
_LAST_SAVE_AT: float = 0.0

# Manifest returned by the most recent completed save, if any.
_LAST_SAVE_RESULT: Optional[Dict[str, Any]] = None
|
|
|
|
def schedule_save(pipe, delay_seconds: float = 60.0,
                  min_docs: int = 100) -> bool:
    """Schedule a background save ``delay_seconds`` from now (debounced).

    Any timer that is still pending is cancelled and replaced, so a burst
    of calls within ``delay_seconds`` results in a single save at the end
    of the burst.

    Args:
        pipe: Pipeline to save; its ``_indexed_docs`` length is compared
            with *min_docs*.
        delay_seconds: Delay before the save runs.
        min_docs: Corpora smaller than this are not worth persisting.

    Returns:
        True if a save was scheduled; False if it was skipped because the
        corpus is too small or its size could not be determined.
    """
    global _DEBOUNCE_TIMER
    try:
        n_docs = len(getattr(pipe, "_indexed_docs", []) or [])
    except Exception:
        return False
    if n_docs < min_docs:
        return False

    def _run():
        global _DEBOUNCE_TIMER, _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] debounced save: {n_saved} retrievers "
                  f"in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
        except Exception as e:
            print(f"[persistence] debounced save failed: {e}")
        finally:
            # Clear the slot only if it still holds *this* timer. A newer
            # timer scheduled while this save was running must stay
            # registered, otherwise it could no longer be cancelled or
            # reported as pending.
            with _DEBOUNCE_LOCK:
                if _DEBOUNCE_TIMER is timer:
                    _DEBOUNCE_TIMER = None

    timer = _threading.Timer(delay_seconds, _run)
    timer.daemon = True
    with _DEBOUNCE_LOCK:
        previous = _DEBOUNCE_TIMER
        if previous is not None:
            try:
                previous.cancel()
            except Exception:
                pass
        _DEBOUNCE_TIMER = timer
        timer.start()
    return True
|
|
|
|
def cancel_pending_save() -> bool:
    """Cancel the registered debounce timer, if there is one.

    Returns:
        True when a timer was registered and has been cancelled and
        cleared; False when none was registered or cancelling it raised.
    """
    global _DEBOUNCE_TIMER
    with _DEBOUNCE_LOCK:
        timer = _DEBOUNCE_TIMER
        if timer is None:
            return False
        try:
            timer.cancel()
        except Exception:
            # Leave the timer registered, exactly as if nothing happened.
            return False
        _DEBOUNCE_TIMER = None
        return True
|
|
|
|
def flush_pending_save(pipe, timeout_s: float = 120.0) -> Dict[str, Any]:
    """Run the save now, waiting at most *timeout_s* seconds.

    Cancels any pending debounced timer first, then performs the save on
    a daemon thread so a stuck save cannot hang the caller (intended for
    graceful-shutdown handlers).

    Args:
        pipe: Pipeline to save.
        timeout_s: Maximum time to wait for the save to finish.

    Returns:
        The manifest on success; ``{"error": ...}`` when the save raised
        or did not finish in time. On timeout the returned dict is a
        separate object, so the still-running worker cannot modify it
        after it has been handed to the caller.
    """
    cancel_pending_save()

    outcome: Dict[str, Any] = {}
    completed = _threading.Event()

    def _run():
        global _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] flush_pending_save: {n_saved} retrievers "
                  f"saved in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
            outcome.update(manifest)
        except Exception as e:
            print(f"[persistence] flush_pending_save failed: {e}")
            outcome["error"] = str(e)
        finally:
            completed.set()

    th = _threading.Thread(target=_run, daemon=True,
                           name="persistence-flush")
    th.start()
    if not completed.wait(timeout_s):
        print("[persistence] flush timed out — partial save likely on disk")
        # Do not return `outcome`: the worker may still write to it.
        return {"error": f"flush timed out after {timeout_s}s"}
    return outcome
|
|
|
|
def debouncer_status() -> Dict[str, Any]:
    """Report the debouncer state (for /v1/admin/persistence/status).

    Returns:
        A dict with ``pending_save`` (True while a registered timer's
        ``finished`` event is not set), ``last_save_at`` (integer
        timestamp, 0 if never saved) and ``last_save_age_s`` (whole
        seconds since the last save, None if never saved).
    """
    # Snapshot the globals once: another thread may clear the timer while
    # this runs, and re-reading the global mid-expression could then hit
    # None and raise AttributeError.
    timer = _DEBOUNCE_TIMER
    last_save_at = _LAST_SAVE_AT

    finished = getattr(timer, "finished", None)
    pending = (timer is not None
               and finished is not None
               and not finished.is_set())
    return {
        "pending_save": bool(pending),
        "last_save_at": int(last_save_at) if last_save_at else 0,
        "last_save_age_s": (int(_time.time() - last_save_at)
                            if last_save_at else None),
    }
|
|
|
|
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "save_all_retrievers",
    "load_all_retrievers",
    "save_retriever_state",
    "load_retriever_state",
    "fingerprint_corpus",
    "schedule_save",
    "cancel_pending_save",
    "debouncer_status",
    "flush_pending_save",
]
|
|