File size: 17,136 Bytes
3be54c6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 | """Retriever index persistence — save/load BM25/Graph/Gematria/Hilbert
state to disk so server restart doesn't rebuild from scratch.
Without this: 50k corpus → 60-120s startup (chunking + indexing 5+ retrievers).
With this: 60s on first run, then <5s on every subsequent restart (loads
indexes from pickle files).
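Storage layout (default dir: tau_rag/runtime/retriever_state/):
manifest.json — corpus fingerprint + per-file save results
<name>.pkl — one pickled state dict per retriever registered in the
MultiRetriever, named by its registry key (e.g. bm25.pkl, graph.pkl)
indexed_docs.pkl — the indexed document list, so startup can skip chunking
<name>.pkl.new — fallback written when <name>.pkl couldn't be overwritten
All .pkl files are gzip-compressed unless TAU_RAG_PICKLE_GZIP=0.
(Nothing in this module writes hilbert.npz / hilbert_index.bin.)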
Cache invalidation: each save records the corpus fingerprint (doc count +
hash of the sorted doc ids). On load, if the caller passes an expected
fingerprint and it differs → skip load; the caller should rebuild.
"""
from __future__ import annotations
import gzip
import json
import os
import pickle
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
# gzip magic bytes — used to auto-detect compressed files on load
_GZIP_MAGIC = b"\x1f\x8b"
# Compress level: 6 = balanced (fast enough, ~2-4× shrink on Hebrew text).
# Set TAU_RAG_PICKLE_GZIP=0 to disable (debugging, profiling).
_USE_GZIP = os.environ.get("TAU_RAG_PICKLE_GZIP", "1") != "0"
_GZIP_LEVEL = int(os.environ.get("TAU_RAG_PICKLE_GZIP_LEVEL", "6"))
def _dump_pickle(obj: Any, path: Path) -> None:
    """Pickle ``obj`` to ``path`` (gzip-compressed unless disabled), atomically.

    The bytes go to a temporary file in the same directory whose name is
    unique per process and thread. That file is then moved over ``path``
    with ``os.replace``. Readers therefore never see a half-written
    ``path``, and concurrent writers of the same path never interleave
    bytes.

    On any failure the temporary file is removed, the exception propagates,
    and ``path`` keeps its previous contents.

    Compression follows the module-level ``_USE_GZIP`` / ``_GZIP_LEVEL``
    settings; ``_load_pickle`` auto-detects either format.
    """
    path = Path(path)
    tmp = path.with_name(
        f"{path.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        if _USE_GZIP:
            with gzip.open(str(tmp), "wb", compresslevel=_GZIP_LEVEL) as f:
                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with tmp.open("wb") as f:
                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp, path)  # atomic on the same filesystem
    except BaseException:
        # Never leave a partial temp file behind; re-raise the original error.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
def _load_pickle(path: Path) -> Any:
    """Unpickle ``path``, transparently handling gzip compression.

    The first two bytes are compared with ``_GZIP_MAGIC``. gzip data is
    decompressed on the fly; anything else is read as a raw pickle. This
    keeps uncompressed files written before gzip support loadable.

    NOTE: ``pickle.load`` can execute arbitrary code — only point this at
    files this service wrote itself.
    """
    with path.open("rb") as fh:
        compressed = fh.read(2) == _GZIP_MAGIC
        fh.seek(0)
        if not compressed:
            return pickle.load(fh)
        with gzip.GzipFile(fileobj=fh, mode="rb") as gz:
            return pickle.load(gz)
# Default on-disk location: <dir above this module's directory>/runtime/
# retriever_state (used when callers pass no explicit base_dir).
_DEFAULT_DIR = Path(__file__).resolve().parent.parent.joinpath(
    "runtime", "retriever_state")
def _ensure_dir(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
def _unlink_quietly(path: Path) -> None:
    """Best-effort delete of ``path``; a missing file or OS error is ignored."""
    try:
        path.unlink()
    except OSError:
        pass


def save_retriever_state(retriever, path: Path) -> bool:
    """Persist a single retriever's state to a pickle file.

    Where the state comes from:
    - ``retriever.state_dict()`` when the retriever has a callable
      ``state_dict``;
    - otherwise ``_extract_state_generic``, which snapshots its instance
      attributes.

    Returns True on success and False on any error. Errors are printed,
    never raised.

    How the file is written:
    - The pickle goes to ``<path>.tmp`` and is then renamed over ``path``,
      so a crash mid-write never leaves a corrupt ``path``.
    - If the rename fails (destination immutable or unwritable), the temp
      file is moved to a sibling ``<path>.new`` instead, so the user can
      swap it in manually. ``_resolve_load_path`` prefers that sibling on
      load while it is the newer file. This fallback still counts as
      success.
    - After a successful direct write, any leftover ``<path>.new`` from an
      earlier fallback is deleted. Otherwise it could shadow the fresh
      ``path`` if its mtime ever looked newer (e.g. after a clock
      adjustment).
    - A failed dump removes its partial ``.tmp`` file.
    """
    try:
        tmp = path.with_suffix(path.suffix + ".tmp")
        new_sibling = path.with_suffix(path.suffix + ".new")
        if hasattr(retriever, "state_dict") and callable(retriever.state_dict):
            state = retriever.state_dict()
        else:
            state = _extract_state_generic(retriever)
        if state is None:
            return False
        try:
            _dump_pickle(state, tmp)
        except Exception:
            _unlink_quietly(tmp)  # don't leave a partial .tmp behind
            raise
        try:
            tmp.replace(path)  # atomic rename
        except OSError as e:
            # Destination locked (immutable flag, etc) — leave .new sibling
            try:
                tmp.replace(new_sibling)
            except Exception:
                _unlink_quietly(tmp)
                raise
            print(f"[persistence] couldn't overwrite {path.name} ({e}); "
                  f"wrote {new_sibling.name} instead — you can `mv` it manually")
        else:
            # `path` is now the freshest copy; an older fallback is obsolete.
            _unlink_quietly(new_sibling)
        return True
    except Exception as e:
        print(f"[persistence] save failed for {retriever.__class__.__name__}: {e}")
        return False
def _resolve_load_path(path: Path) -> Optional[Path]:
"""Pick the best file to load from.
If a sibling `<name>.pkl.new` exists AND is newer than `<name>.pkl`,
prefer the `.new` file. This handles the case where the original
is immutable (couldn't be overwritten in save) — the `.new` was
written next to it and is the actual fresh state.
Returns the resolved path or None if neither exists.
"""
new_sibling = path.with_suffix(path.suffix + ".new")
if new_sibling.exists():
if not path.exists():
return new_sibling
try:
if new_sibling.stat().st_mtime > path.stat().st_mtime:
return new_sibling
except OSError:
return new_sibling
return path if path.exists() else None
def load_retriever_state(retriever, path: Path) -> bool:
    """Restore one retriever's state from ``path`` or its ``.new`` sibling.

    ``_resolve_load_path`` picks ``<path>.new`` over ``path`` when the
    sibling is the only file, or the newer one. Such a sibling is left by a
    save that couldn't overwrite ``path``. The pickle may be gzip-compressed
    or raw.

    A dict containing a ``"_chunks"`` key is treated as an old-format pickle
    and rejected. Those are full ``__dict__`` dumps from
    ``_extract_state_generic``, made before retrievers had their own
    state_dict(); restoring one would silently leave the retriever empty.

    Otherwise the state goes to ``retriever.load_state_dict`` when that is
    callable, else to ``_restore_state_generic``.

    Returns True when the state was applied. Returns False when no file
    exists, the pickle is old-format, or anything raises (the error is
    printed).
    """
    source = _resolve_load_path(path)
    if source is None:
        return False
    try:
        state = _load_pickle(source)  # gzip vs raw auto-detected
        is_old_format = isinstance(state, dict) and "_chunks" in state
        if is_old_format:
            print(f"[persistence] old-format pickle at {source.name} — "
                  f"ignoring, will rebuild")
            return False
        loader = getattr(retriever, "load_state_dict", None)
        if callable(loader):
            loader(state)
        else:
            _restore_state_generic(retriever, state)
        if source.suffix == ".new":
            # Make the sibling override visible in startup logs
            print(f"[persistence] loaded {source.name} (sibling override)")
        return True
    except Exception as e:
        print(f"[persistence] load failed for {retriever.__class__.__name__}: {e}")
        return False
# Generic fallback — pickles all "_*" attrs of the retriever object
def _extract_state_generic(retriever) -> Dict[str, Any]:
state: Dict[str, Any] = {"_class": retriever.__class__.__name__}
for k, v in retriever.__dict__.items():
# Skip non-picklable things (functions, locks, executors)
if k.startswith("_executor"):
continue
if callable(v):
continue
try:
pickle.dumps(v) # quick pickle test
state[k] = v
except Exception:
pass
return state
def _restore_state_generic(retriever, state: Dict[str, Any]) -> None:
for k, v in state.items():
if k == "_class":
continue
try:
setattr(retriever, k, v)
except Exception:
pass
# ============================================================================
# Top-level: persist/load all retrievers in a MultiRetriever at once
# ============================================================================
def fingerprint_corpus(pipe) -> str:
    """Return a stable fingerprint of the pipeline's indexed corpus.

    Docs come from ``pipe._indexed_docs``, falling back to ``pipe._docs``.
    The result is ``"empty"`` when there are none. Otherwise it is
    ``"n=<count>;h=<12 hex chars>"``, where the hash is the md5 of the
    sorted doc ids joined with ``"|"``. A missing or falsy ``id`` counts as
    ``""``.

    Only ids are hashed: editing a doc's text while keeping its id does not
    change the fingerprint. md5 serves purely as a checksum here.
    """
    import hashlib

    corpus = (getattr(pipe, "_indexed_docs", None)
              or getattr(pipe, "_docs", None)
              or [])
    count = len(corpus)
    if not count:
        return "empty"
    doc_ids = sorted(str(getattr(doc, "id", "") or "") for doc in corpus)
    digest = hashlib.md5("|".join(doc_ids).encode("utf-8")).hexdigest()
    return f"n={count};h={digest[:12]}"
# Serializes save_all_retrievers(). Timer.cancel() cannot stop a debounced
# save that is already running, so without this a flush_pending_save() or a
# newer debounced save could write the same files from two threads at once.
_SAVE_ALL_LOCK = threading.Lock()


def _publish_manifest(base_dir: Path, manifest: Dict[str, Any]) -> None:
    """Atomically write manifest.json, or retract it after a partial save."""
    manifest_path = base_dir / "manifest.json"
    if manifest.get("complete"):
        tmp = manifest_path.with_name(manifest_path.name + ".tmp")
        tmp.write_text(json.dumps(manifest, ensure_ascii=False, indent=2),
                       encoding="utf-8")
        tmp.replace(manifest_path)  # readers never see a half-written file
        return
    # A component failed, so its file on disk may still describe an older
    # corpus. With no manifest, load_all_retrievers() returns None and the
    # caller rebuilds instead of silently mixing stale and fresh state.
    try:
        manifest_path.unlink()
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"[persistence] incomplete save and could not remove "
              f"{manifest_path.name}: {e}")
    print("[persistence] incomplete save — manifest not written; "
          "next load will rebuild")


def save_all_retrievers(pipe, base_dir: Optional[Path] = None) -> Dict[str, Any]:
    """Save every registered retriever's state plus the indexed docs.

    Each ``name -> retriever`` entry of ``pipe.retrievers._retrievers`` is
    unwrapped (``._inner`` when present) and saved to
    ``<base_dir>/<name>.pkl`` via ``save_retriever_state``. The pipeline's
    ``_indexed_docs`` (or ``_docs``) go to ``indexed_docs.pkl``.

    ``manifest.json`` is written only when every component saved. After any
    failure an existing manifest is removed instead, so the next load
    rebuilds rather than pairing the new fingerprint with stale files.
    Calls are serialized by a module-level lock.

    Args:
        pipe: pipeline object; reads ``retrievers`` and ``_indexed_docs`` /
            ``_docs``.
        base_dir: target directory (default ``_DEFAULT_DIR``); created if
            missing.

    Returns:
        The manifest dict. It is returned even when incomplete, with
        ``"complete": False`` and per-file ``"saved"`` flags. Returns
        ``{"error": "no MultiRetriever"}`` when ``pipe`` has no
        MultiRetriever.
    """
    base_dir = base_dir or _DEFAULT_DIR
    with _SAVE_ALL_LOCK:
        _ensure_dir(base_dir)
        fp = fingerprint_corpus(pipe)
        multi = getattr(pipe, "retrievers", None)
        if multi is None or not hasattr(multi, "_retrievers"):
            return {"error": "no MultiRetriever"}
        manifest: Dict[str, Any] = {
            "fingerprint": fp,
            "saved_at": int(time.time()),
            "retrievers": {},
        }
        for name, r in multi._retrievers.items():
            # The retriever may be wrapped (ReferenceAware) — unwrap to the inner
            inner = getattr(r, "_inner", r)
            path = base_dir / f"{name}.pkl"
            ok = save_retriever_state(inner, path)
            manifest["retrievers"][name] = {
                "class": inner.__class__.__name__,
                "saved": ok,
                "size_bytes": path.stat().st_size if ok and path.exists() else 0,
            }
        # Persist the indexed_docs list separately so we can skip the chunker
        # on next startup (the heavy part)
        docs_path = base_dir / "indexed_docs.pkl"
        try:
            docs = (getattr(pipe, "_indexed_docs", None)
                    or getattr(pipe, "_docs", None) or [])
            _dump_pickle(docs, docs_path)
            manifest["indexed_docs"] = {
                "saved": True, "n_docs": len(docs),
                "size_bytes": docs_path.stat().st_size,
            }
        except Exception as e:
            manifest["indexed_docs"] = {"saved": False, "error": str(e)}
        manifest["complete"] = bool(
            manifest["indexed_docs"]["saved"]
            and all(entry["saved"]
                    for entry in manifest["retrievers"].values()))
        _publish_manifest(base_dir, manifest)
        return manifest
def load_all_retrievers(pipe, base_dir: Optional[Path] = None,
                        expected_fingerprint: Optional[str] = None
                        ) -> Optional[Dict[str, Any]]:
    """Try to restore every retriever plus the indexed docs from disk.

    Order of operations:
    1. Read ``manifest.json``; it must be a JSON object.
    2. If ``expected_fingerprint`` is given (truthy) and differs from the
       manifest's, give up.
    3. Load ``indexed_docs.pkl``, preferring a newer ``.new`` sibling.
    4. Load each registered retriever's ``<name>.pkl``.

    The docs are loaded before any retriever is touched, so a missing or
    corrupt docs file fails fast. ``pipe._indexed_docs`` and
    ``pipe._indexed_ids`` are assigned only after every retriever loaded.

    Returns the manifest plus ``"n_loaded"`` on success. Returns None if
    any of the following holds; the caller should then rebuild:
    - the manifest is missing or corrupt;
    - the fingerprint does not match;
    - ``pipe`` has no MultiRetriever;
    - the docs are missing or corrupt;
    - any retriever fails to load.

    NOTE(review): retrievers restored before a later one fails keep their
    loaded state; presumably the caller's rebuild overwrites them.
    """
    base_dir = base_dir or _DEFAULT_DIR
    manifest_path = base_dir / "manifest.json"
    if not manifest_path.exists():
        return None
    try:
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(manifest, dict):
        print(f"[persistence] {manifest_path.name} is not a JSON object — "
              f"will rebuild")
        return None
    if expected_fingerprint and manifest.get("fingerprint") != expected_fingerprint:
        return None  # corpus changed → caller should rebuild
    multi = getattr(pipe, "retrievers", None)
    if multi is None or not hasattr(multi, "_retrievers"):
        return None

    # indexed_docs is required (without it the pipeline can't return
    # canonical docs for /v1/query). Load it first so a missing/corrupt file
    # fails before any retriever is mutated. Same .new-sibling preference.
    docs_resolved = _resolve_load_path(base_dir / "indexed_docs.pkl")
    if docs_resolved is None:
        print("[persistence] indexed_docs missing — will rebuild")
        return None
    try:
        docs = _load_pickle(docs_resolved)
        indexed_ids = set(getattr(d, "id", None) for d in docs
                          if getattr(d, "id", None))
    except Exception as e:
        print(f"[persistence] indexed_docs load failed: {e}")
        return None

    n_total = len(multi._retrievers)
    n_loaded = 0
    failures: List[str] = []
    for name, r in multi._retrievers.items():
        inner = getattr(r, "_inner", r)
        if load_retriever_state(inner, base_dir / f"{name}.pkl"):
            n_loaded += 1
        else:
            failures.append(name)
    # A partial load would leave some retrievers empty and produce broken
    # results without any obvious symptom (search just returns nothing).
    if n_loaded < n_total:
        print(f"[persistence] aborting load: only {n_loaded}/{n_total} "
              f"retrievers loaded (failed: {failures}). Will rebuild.")
        return None

    try:
        pipe._indexed_docs = docs
        pipe._indexed_ids = indexed_ids
    except Exception as e:
        print(f"[persistence] indexed_docs load failed: {e}")
        return None
    if docs_resolved.name.endswith(".pkl.new"):
        print(f"[persistence] loaded {docs_resolved.name} (sibling override)")
    return {**manifest, "n_loaded": n_loaded}
# ============================================================================
# Debounced background save — coalesces rapid add_documents calls
# ============================================================================
# Debouncer state, shared at module level by every caller.
import time as _time
import threading as _threading

# Held by schedule_save() / cancel_pending_save() while they replace or
# cancel the pending timer.
_DEBOUNCE_LOCK = _threading.Lock()
# The pending debounced-save Timer, or None when nothing is scheduled.
_DEBOUNCE_TIMER: Optional[_threading.Timer] = None
# time.time() of the last completed save; 0.0 means "never".
_LAST_SAVE_AT: float = 0.0
# Manifest returned by the last completed save, if any.
_LAST_SAVE_RESULT: Optional[Dict[str, Any]] = None
def schedule_save(pipe, delay_seconds: float = 60.0,
                  min_docs: int = 100) -> bool:
    """Schedule a debounced background save ``delay_seconds`` from now.

    Any pending (not yet started) save is cancelled and replaced. A burst of
    add_documents() calls within ``delay_seconds`` therefore saves only
    once, at the end of the burst. A save that is already running cannot be
    cancelled (``Timer.cancel`` only affects waiting timers).

    Returns True if a save was scheduled. Returns False if it was skipped,
    either because ``pipe._indexed_docs`` holds fewer than ``min_docs``
    docs, or because its size could not be determined.
    """
    global _DEBOUNCE_TIMER
    try:
        n_docs = len(getattr(pipe, "_indexed_docs", []) or [])
    except Exception:
        return False
    if n_docs < min_docs:
        return False

    def _run() -> None:
        global _DEBOUNCE_TIMER, _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] debounced save: {n_saved} retrievers "
                  f"in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
        except Exception as e:
            print(f"[persistence] debounced save failed: {e}")
        finally:
            # Clear the slot only if it still holds *this* timer. A newer
            # schedule_save() may have installed a replacement while this
            # save ran; clobbering it would orphan that timer (uncancellable
            # and invisible to debouncer_status()).
            with _DEBOUNCE_LOCK:
                if _DEBOUNCE_TIMER is timer:
                    _DEBOUNCE_TIMER = None

    with _DEBOUNCE_LOCK:
        if _DEBOUNCE_TIMER is not None:
            try:
                _DEBOUNCE_TIMER.cancel()
            except Exception:
                pass
        timer = _threading.Timer(delay_seconds, _run)
        timer.daemon = True
        _DEBOUNCE_TIMER = timer
        timer.start()
    return True
def cancel_pending_save() -> bool:
    """Cancel the pending debounced save, if one is registered.

    Returns True when a timer was registered and its ``cancel()`` call
    succeeded; the slot is then cleared. Returns False otherwise.
    ``Timer.cancel()`` has no effect on a save whose action is already
    running.
    """
    global _DEBOUNCE_TIMER
    with _DEBOUNCE_LOCK:
        pending = _DEBOUNCE_TIMER
        if pending is None:
            return False
        try:
            pending.cancel()
        except Exception:
            return False
        _DEBOUNCE_TIMER = None
        return True
def flush_pending_save(pipe, timeout_s: float = 120.0) -> Dict[str, Any]:
    """Synchronously run a full save now; meant for graceful shutdown.

    Any pending (not yet started) debounced timer is cancelled first. The
    save itself runs on a daemon thread, and this call waits at most
    ``timeout_s`` seconds for it.

    Returns the save's manifest, or ``{"error": ...}`` if the save raised.
    On timeout it returns a fresh ``{"error": "flush timed out ..."}``
    dict. The worker keeps running in the background, and the returned
    dict is never mutated by it afterwards.
    """
    cancel_pending_save()
    result: Dict[str, Any] = {}
    completed = _threading.Event()

    def _run() -> None:
        global _LAST_SAVE_AT, _LAST_SAVE_RESULT
        try:
            t0 = _time.time()
            manifest = save_all_retrievers(pipe)
            dt = _time.time() - t0
            n_saved = sum(1 for r in manifest.get("retrievers", {}).values()
                          if r.get("saved"))
            print(f"[persistence] flush_pending_save: {n_saved} retrievers "
                  f"saved in {dt:.1f}s")
            _LAST_SAVE_AT = _time.time()
            _LAST_SAVE_RESULT = manifest
            result.update(manifest)
        except Exception as e:
            print(f"[persistence] flush_pending_save failed: {e}")
            result["error"] = str(e)
        finally:
            completed.set()

    th = _threading.Thread(target=_run, daemon=True,
                           name="persistence-flush")
    th.start()
    if not completed.wait(timeout_s):
        print("[persistence] flush timed out — partial save likely on disk")
        # The worker still owns `result` and may fill it in later; hand the
        # caller an independent dict instead of that shared one.
        return {"error": f"flush timed out after {timeout_s}s"}
    return result
def debouncer_status() -> Dict[str, Any]:
    """Inspect the debouncer for /v1/admin/persistence/status.

    ``pending_save`` is True while a registered timer's ``finished`` event
    is unset. That covers a timer that is still waiting and one whose save
    is currently running.

    The globals are read once into locals because another thread may reset
    ``_DEBOUNCE_TIMER`` to None at any moment. Re-reading it between checks
    could otherwise raise AttributeError.
    """
    timer = _DEBOUNCE_TIMER
    finished = getattr(timer, "finished", None) if timer is not None else None
    pending = finished is not None and not finished.is_set()
    last_at = _LAST_SAVE_AT
    return {
        "pending_save": bool(pending),
        "last_save_at": int(last_at) if last_at else 0,
        "last_save_age_s": (int(_time.time() - last_at)
                            if last_at else None),
    }
# Public API of this module (what `from ... import *` exports).
__all__ = [
    "save_all_retrievers",
    "load_all_retrievers",
    "save_retriever_state",
    "load_retriever_state",
    "fingerprint_corpus",
    "schedule_save",
    "cancel_pending_save",
    "debouncer_status",
    "flush_pending_save",
]
|