"""Retriever index persistence — save/load BM25/Graph/Gematria/Hilbert state to disk so server restart doesn't rebuild from scratch. Without this: 50k corpus → 60-120s startup (chunking + indexing 5+ retrievers). With this: 60s on first run, then <5s on every subsequent restart (loads indexes from pickle files). Storage layout: tau_rag/runtime/retriever_state/ manifest.json — corpus fingerprint + file index bm25.pkl — pickled state dict graph.pkl gematria.pkl hilbert.npz — numpy ndarray (more efficient) hilbert_index.bin — HNSW binary Cache invalidation: each save records the corpus fingerprint (count + hash of doc_ids). On load, if fingerprint differs → skip load, force rebuild. """ from __future__ import annotations import gzip import json import os import pickle import time from pathlib import Path from typing import Any, Dict, List, Optional # gzip magic bytes — used to auto-detect compressed files on load _GZIP_MAGIC = b"\x1f\x8b" # Compress level: 6 = balanced (fast enough, ~2-4× shrink on Hebrew text). # Set TAU_RAG_PICKLE_GZIP=0 to disable (debugging, profiling). _USE_GZIP = os.environ.get("TAU_RAG_PICKLE_GZIP", "1") != "0" _GZIP_LEVEL = int(os.environ.get("TAU_RAG_PICKLE_GZIP_LEVEL", "6")) def _dump_pickle(obj: Any, path: Path) -> None: """Pickle + optionally gzip-compress + write atomically.""" if _USE_GZIP: with gzip.open(str(path), "wb", compresslevel=_GZIP_LEVEL) as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) else: with path.open("wb") as f: pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) def _load_pickle(path: Path) -> Any: """Auto-detects gzip vs raw pickle by magic bytes — backward compat with uncompressed pickles from before this change.""" with path.open("rb") as f: head = f.read(2) if head == _GZIP_MAGIC: with gzip.open(str(path), "rb") as f: return pickle.load(f) with path.open("rb") as f: return pickle.load(f) _DEFAULT_DIR = (Path(__file__).resolve().parent.parent / "runtime" / "retriever_state") def _ensure_dir(path: Path) -> None: path.mkdir(parents=True, exist_ok=True) def save_retriever_state(retriever, path: Path) -> bool: """Persist a single retriever's state to a pickle file. Each retriever exposes `state_dict() -> dict` (or we extract attrs manually). Returns True on success, False on any error. Writes to `path.tmp` then atomically renames over `path`, so a crash mid-write never leaves a corrupt file. If the destination is immutable / unwritable, falls back to a sibling `.new` file so the user can manually swap them in. """ try: state = None if hasattr(retriever, "state_dict") and callable(retriever.state_dict): state = retriever.state_dict() else: state = _extract_state_generic(retriever) if state is None: return False tmp = path.with_suffix(path.suffix + ".tmp") _dump_pickle(state, tmp) try: tmp.replace(path) # atomic rename except (PermissionError, OSError) as e: # Destination locked (immutable flag, etc) — leave .new sibling new_sibling = path.with_suffix(path.suffix + ".new") try: tmp.replace(new_sibling) print(f"[persistence] couldn't overwrite {path.name} ({e}); " f"wrote {new_sibling.name} instead — you can `mv` it manually") except Exception: tmp.unlink(missing_ok=True) raise return True except Exception as e: print(f"[persistence] save failed for {retriever.__class__.__name__}: {e}") return False def _resolve_load_path(path: Path) -> Optional[Path]: """Pick the best file to load from. If a sibling `.pkl.new` exists AND is newer than `.pkl`, prefer the `.new` file. This handles the case where the original is immutable (couldn't be overwritten in save) — the `.new` was written next to it and is the actual fresh state. Returns the resolved path or None if neither exists. """ new_sibling = path.with_suffix(path.suffix + ".new") if new_sibling.exists(): if not path.exists(): return new_sibling try: if new_sibling.stat().st_mtime > path.stat().st_mtime: return new_sibling except OSError: return new_sibling return path if path.exists() else None def load_retriever_state(retriever, path: Path) -> bool: """Restore a retriever's state from a pickle file. Auto-prefers `.new` if it exists and is newer (handles the case where the original was immutable and the save fell back to a sibling). Detects old-format pickles (saved by `_extract_state_generic` before each retriever had its own state_dict()) and rejects them — those would silently leave the retriever empty. The expected new-format state_dict has a "_class" key plus retriever-specific fields like "chunks" / "hists" / "vecs_f16" without leading underscores. """ resolved = _resolve_load_path(path) if resolved is None: return False try: # _load_pickle auto-detects gzip vs raw pickle (backward compat) state = _load_pickle(resolved) # Old-format detection: the generic extractor pickled the full # __dict__ with leading-underscore keys. New format has plain keys. if isinstance(state, dict) and "_chunks" in state: print(f"[persistence] old-format pickle at {resolved.name} — " f"ignoring, will rebuild") return False if hasattr(retriever, "load_state_dict") and callable(retriever.load_state_dict): retriever.load_state_dict(state) else: _restore_state_generic(retriever, state) # Log when we used the .new fallback so it's visible in startup logs if resolved.suffix == ".new" or resolved.name.endswith(".pkl.new"): print(f"[persistence] loaded {resolved.name} (sibling override)") return True except Exception as e: print(f"[persistence] load failed for {retriever.__class__.__name__}: {e}") return False # Generic fallback — pickles all "_*" attrs of the retriever object def _extract_state_generic(retriever) -> Dict[str, Any]: state: Dict[str, Any] = {"_class": retriever.__class__.__name__} for k, v in retriever.__dict__.items(): # Skip non-picklable things (functions, locks, executors) if k.startswith("_executor"): continue if callable(v): continue try: pickle.dumps(v) # quick pickle test state[k] = v except Exception: pass return state def _restore_state_generic(retriever, state: Dict[str, Any]) -> None: for k, v in state.items(): if k == "_class": continue try: setattr(retriever, k, v) except Exception: pass # ============================================================================ # Top-level: persist/load all retrievers in a MultiRetriever at once # ============================================================================ def fingerprint_corpus(pipe) -> str: """Stable fingerprint of the indexed corpus.""" import hashlib docs = (getattr(pipe, "_indexed_docs", None) or getattr(pipe, "_docs", None) or []) n = len(docs) if n == 0: return "empty" ids = [str(getattr(d, "id", "") or "") for d in docs] h = hashlib.md5("|".join(sorted(ids)).encode("utf-8")).hexdigest()[:12] return f"n={n};h={h}" def save_all_retrievers(pipe, base_dir: Optional[Path] = None) -> Dict[str, Any]: """Save every registered retriever's state. Returns a manifest dict.""" base_dir = base_dir or _DEFAULT_DIR _ensure_dir(base_dir) fp = fingerprint_corpus(pipe) multi = getattr(pipe, "retrievers", None) if multi is None or not hasattr(multi, "_retrievers"): return {"error": "no MultiRetriever"} manifest = { "fingerprint": fp, "saved_at": int(time.time()), "retrievers": {}, } for name, r in multi._retrievers.items(): # The retriever may be wrapped (ReferenceAware) — unwrap to the inner inner = getattr(r, "_inner", r) cls = inner.__class__.__name__ path = base_dir / f"{name}.pkl" ok = save_retriever_state(inner, path) manifest["retrievers"][name] = { "class": cls, "saved": ok, "size_bytes": path.stat().st_size if ok and path.exists() else 0, } # Persist the indexed_docs list separately so we can skip the chunker # on next startup (the heavy part) docs_path = base_dir / "indexed_docs.pkl" try: docs = (getattr(pipe, "_indexed_docs", None) or getattr(pipe, "_docs", None) or []) _dump_pickle(docs, docs_path) manifest["indexed_docs"] = { "saved": True, "n_docs": len(docs), "size_bytes": docs_path.stat().st_size, } except Exception as e: manifest["indexed_docs"] = {"saved": False, "error": str(e)} # Write manifest (base_dir / "manifest.json").write_text( json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8") return manifest def load_all_retrievers(pipe, base_dir: Optional[Path] = None, expected_fingerprint: Optional[str] = None ) -> Optional[Dict[str, Any]]: """Try to load all retriever states. Returns manifest if successful; None if missing, fingerprint mismatch, or any other failure. """ base_dir = base_dir or _DEFAULT_DIR manifest_path = base_dir / "manifest.json" if not manifest_path.exists(): return None try: manifest = json.loads(manifest_path.read_text(encoding="utf-8")) except Exception: return None if expected_fingerprint and manifest.get("fingerprint") != expected_fingerprint: return None # corpus changed → caller should rebuild multi = getattr(pipe, "retrievers", None) if multi is None or not hasattr(multi, "_retrievers"): return None n_total = len(multi._retrievers) n_loaded = 0 failures: List[str] = [] for name, r in multi._retrievers.items(): inner = getattr(r, "_inner", r) path = base_dir / f"{name}.pkl" if load_retriever_state(inner, path): n_loaded += 1 else: failures.append(name) # If any retriever failed to load, abort — caller will rebuild from scratch. # A partial load would leave some retrievers empty and produce broken # results without any obvious symptom (search just returns nothing). if n_loaded < n_total: print(f"[persistence] aborting load: only {n_loaded}/{n_total} " f"retrievers loaded (failed: {failures}). Will rebuild.") return None # Restore indexed_docs (required — without it the pipeline can't return # canonical docs for /v1/query). Same .new-sibling preference as above. docs_path = base_dir / "indexed_docs.pkl" docs_resolved = _resolve_load_path(docs_path) if docs_resolved is None: print(f"[persistence] indexed_docs missing — will rebuild") return None try: docs = _load_pickle(docs_resolved) pipe._indexed_docs = docs pipe._indexed_ids = set(getattr(d, "id", None) for d in docs if getattr(d, "id", None)) if docs_resolved.name.endswith(".pkl.new"): print(f"[persistence] loaded {docs_resolved.name} (sibling override)") except Exception as e: print(f"[persistence] indexed_docs load failed: {e}") return None return {**manifest, "n_loaded": n_loaded} # ============================================================================ # Debounced background save — coalesces rapid add_documents calls # ============================================================================ import threading as _threading import time as _time _DEBOUNCE_LOCK = _threading.Lock() _DEBOUNCE_TIMER: Optional[_threading.Timer] = None _LAST_SAVE_AT: float = 0.0 _LAST_SAVE_RESULT: Optional[Dict[str, Any]] = None def schedule_save(pipe, delay_seconds: float = 60.0, min_docs: int = 100) -> bool: """Schedule a background save N seconds from now, coalescing any pending save. Subsequent calls reset the timer — so a burst of add_documents() calls within `delay_seconds` only saves once at the end of the burst. Returns True if a save was scheduled, False if skipped (corpus too small, or persistence module unavailable). """ global _DEBOUNCE_TIMER try: n_docs = len(getattr(pipe, "_indexed_docs", []) or []) except Exception: return False if n_docs < min_docs: return False with _DEBOUNCE_LOCK: if _DEBOUNCE_TIMER is not None: try: _DEBOUNCE_TIMER.cancel() except Exception: pass _DEBOUNCE_TIMER = None def _run(): global _DEBOUNCE_TIMER, _LAST_SAVE_AT, _LAST_SAVE_RESULT try: t0 = _time.time() manifest = save_all_retrievers(pipe) dt = _time.time() - t0 n_saved = sum(1 for r in manifest.get("retrievers", {}).values() if r.get("saved")) print(f"[persistence] debounced save: {n_saved} retrievers " f"in {dt:.1f}s") _LAST_SAVE_AT = _time.time() _LAST_SAVE_RESULT = manifest except Exception as e: print(f"[persistence] debounced save failed: {e}") finally: _DEBOUNCE_TIMER = None _DEBOUNCE_TIMER = _threading.Timer(delay_seconds, _run) _DEBOUNCE_TIMER.daemon = True _DEBOUNCE_TIMER.start() return True def cancel_pending_save() -> bool: """Cancel any pending debounced save. Returns True if one was cancelled.""" global _DEBOUNCE_TIMER with _DEBOUNCE_LOCK: if _DEBOUNCE_TIMER is not None: try: _DEBOUNCE_TIMER.cancel() _DEBOUNCE_TIMER = None return True except Exception: pass return False def flush_pending_save(pipe, timeout_s: float = 120.0) -> Dict[str, Any]: """Synchronously run the save NOW. Cancels any pending debounced timer first. Used by graceful shutdown handlers — guarantees the on-disk state matches the in-memory state when the server exits. Returns the manifest (or an error dict). Bounded by timeout_s to avoid hanging an exit indefinitely. """ global _DEBOUNCE_TIMER, _LAST_SAVE_AT, _LAST_SAVE_RESULT cancel_pending_save() result: Dict[str, Any] = {} completed = _threading.Event() def _run(): try: t0 = _time.time() manifest = save_all_retrievers(pipe) dt = _time.time() - t0 n_saved = sum(1 for r in manifest.get("retrievers", {}).values() if r.get("saved")) print(f"[persistence] flush_pending_save: {n_saved} retrievers " f"saved in {dt:.1f}s") global _LAST_SAVE_AT, _LAST_SAVE_RESULT _LAST_SAVE_AT = _time.time() _LAST_SAVE_RESULT = manifest result.update(manifest) except Exception as e: print(f"[persistence] flush_pending_save failed: {e}") result["error"] = str(e) finally: completed.set() th = _threading.Thread(target=_run, daemon=True, name="persistence-flush") th.start() if not completed.wait(timeout_s): result["error"] = f"flush timed out after {timeout_s}s" print(f"[persistence] flush timed out — partial save likely on disk") return result def debouncer_status() -> Dict[str, Any]: """Inspect the debouncer for /v1/admin/persistence/status.""" pending = (_DEBOUNCE_TIMER is not None and getattr(_DEBOUNCE_TIMER, "finished", None) is not None and not _DEBOUNCE_TIMER.finished.is_set()) return { "pending_save": bool(pending), "last_save_at": int(_LAST_SAVE_AT) if _LAST_SAVE_AT else 0, "last_save_age_s": (int(_time.time() - _LAST_SAVE_AT) if _LAST_SAVE_AT else None), } __all__ = [ "save_all_retrievers", "load_all_retrievers", "save_retriever_state", "load_retriever_state", "fingerprint_corpus", "schedule_save", "cancel_pending_save", "debouncer_status", "flush_pending_save", ]