"""FastAPI app.
Run: uvicorn tau_rag.api.fastapi_app:app --reload
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
try:
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
except Exception as e: # pragma: no cover
raise RuntimeError(
"FastAPI not installed. `pip install tau-rag[api]`."
) from e
from ..core.config import Config
from ..core.types import Document, Query, Strategy
from ..pipeline import Pipeline
from .errors import (
ErrorCode, Limits, build_error_body,
validate_query_text, validate_doc_list, validate_k,
)
# The ASGI application. Its ``version`` is the value reported as the API
# version; NOTE(review): the same "2.0.0" literal may also be repeated in
# endpoint code below — prefer reading ``app.version`` to avoid drift.
app = FastAPI(title="TAU-RAG", version="2.0.0")
# ---------------------------------------------------- CORS + security
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from .security import cors_config_from_env, apply_security_headers  # noqa: E402
# CORS settings come from the environment; the middleware is installed only
# when at least one allowed origin is configured.
_cors_cfg = cors_config_from_env()
if _cors_cfg["allow_origins"]:
    # NOTE(review): Starlette treats the most recently registered middleware
    # as the outermost layer, so ``@app.middleware("http")`` handlers
    # registered later in this module run *before* CORSMiddleware. Responses
    # they short-circuit (401/429/503, possibly CORS preflight OPTIONS on
    # protected paths) would then carry no CORS headers — confirm intended.
    app.add_middleware(CORSMiddleware, **_cors_cfg)
@app.middleware("http")
async def _security_headers_middleware(request, call_next):
    """Stamp the configured security headers onto every downstream response.

    The handler's response is produced first; ``apply_security_headers`` then
    mutates its header collection in place before it is passed back up.
    NOTE(review): responses generated by middleware that wraps this one (e.g.
    early 401/429 returns from middleware registered later) never reach this
    code and so presumably lack these headers — confirm.
    """
    downstream = await call_next(request)
    apply_security_headers(downstream.headers)
    return downstream
# ---------------------------------------------------- global error handlers
def _rid_from(request) -> Optional[str]:
return getattr(getattr(request, "state", None), "request_id", None)
@app.exception_handler(HTTPException)
async def _http_exc_handler(request, exc: HTTPException):
    """Render an ``HTTPException`` as the canonical error envelope.

    The HTTP status is mapped to an ``ErrorCode`` (unknown statuses become
    ``INTERNAL_ERROR``). A string ``detail`` becomes the message; a dict
    ``detail`` is passed through as structured details with a generic
    message. Headers from the exception are preserved and ``X-Request-ID``
    is added when the request carries one.

    NOTE(review): registered for ``fastapi.HTTPException``; Starlette's own
    ``HTTPException`` (e.g. router 404/405) may not be routed here — confirm.
    """
    status_to_code = {
        400: ErrorCode.VALIDATION_ERROR,
        401: ErrorCode.UNAUTHORIZED,
        403: ErrorCode.ADMIN_REQUIRED,
        404: ErrorCode.NOT_FOUND,
        413: ErrorCode.PAYLOAD_TOO_LARGE,
        422: ErrorCode.VALIDATION_ERROR,
        429: ErrorCode.RATE_LIMITED,
    }
    canonical = status_to_code.get(exc.status_code, ErrorCode.INTERNAL_ERROR)
    rid = _rid_from(request)

    detail = exc.detail
    if isinstance(detail, str):
        message, details = detail, None
    elif isinstance(detail, dict):
        message, details = "request failed", detail
    else:
        message, details = "request failed", None

    payload = build_error_body(canonical, message, rid, details)
    out_headers = dict(exc.headers or {})
    if rid:
        out_headers["X-Request-ID"] = rid
    return JSONResponse(
        status_code=exc.status_code, content=payload, headers=out_headers,
    )
@app.exception_handler(RequestValidationError)
async def _validation_exc_handler(request, exc: RequestValidationError):
    """Render request-validation failures as a 422 canonical error envelope.

    ``exc.errors()`` is passed through ``jsonable_encoder`` before being
    embedded: with pydantic v2 the error dicts can contain non-JSON values
    (exception objects under ``ctx``, raw ``bytes`` under ``input``), which
    would otherwise make ``JSONResponse`` itself raise while rendering the
    error and turn a 422 into a 500. FastAPI's built-in handler does the same.
    ``X-Request-ID`` is echoed when the request has one.
    """
    from fastapi.encoders import jsonable_encoder

    rid = _rid_from(request)
    body = build_error_body(
        ErrorCode.VALIDATION_ERROR,
        "request failed validation",
        rid,
        details={"errors": jsonable_encoder(exc.errors())},
    )
    headers = {"X-Request-ID": rid} if rid else {}
    return JSONResponse(status_code=422, content=body, headers=headers)
@app.exception_handler(Exception)
async def _unhandled_exc_handler(request, exc: Exception):
    """Last-resort handler: turn any uncaught exception into a 500 envelope.

    The message is ``"<ExceptionType>: <str(exc)>"`` truncated to 300 chars.
    NOTE(review): this echoes internal exception text to the client, which
    can leak implementation details — consider a generic message plus the
    request id, with the full text logged server-side.
    """
    rid = _rid_from(request)
    summary = f"{type(exc).__name__}: {exc}"[:300]
    payload = build_error_body(ErrorCode.INTERNAL_ERROR, summary, rid)
    extra_headers = {"X-Request-ID": rid} if rid else {}
    return JSONResponse(status_code=500, content=payload, headers=extra_headers)
# ----------------------------------------------------------------- middleware
from fastapi import Request
from fastapi.responses import JSONResponse
from ..middleware import get_cache, get_limiter
from ..middleware.auth import get_auth
from ..middleware.ratelimit import RateLimitExceeded
from ..middleware.observability import (
get_obs, generate_request_id, RequestLog, _hash_prefix,
)
from ..middleware.maintenance import get_maintenance
from ..middleware.pii_redaction import get_pii_redactor
from ..middleware.slow_queries import get_slow_tracker, SlowRecord
from ..middleware.quota import get_quota_tracker
from ..middleware.idempotency import get_idempotency_store
from ..middleware.request_timeout import get_timeout_guard
import asyncio as _asyncio
import time as _time
@app.middleware("http")
async def auth_and_ratelimit_middleware(request: Request, call_next):
    """Per-request gatekeeper and request logger for the whole app.

    In order, this middleware:
    assigns/propagates ``X-Request-ID``; optionally captures (and
    PII-redacts) the POST body for replay; rejects non-admin traffic during
    maintenance mode (503); enforces admin scope on ``/v1/admin/`` and API-key
    auth on protected paths (401); charges the per-key daily quota (429);
    applies the rate limiter on protected paths (429); replays cached
    responses for repeated ``Idempotency-Key`` POSTs; dispatches to the route
    with an optional wall-clock timeout (504); and records every outcome via
    ``_log`` (observability log + slow-query tracker).

    NOTE(review): this middleware is registered after CORSMiddleware and the
    security-headers middleware. In Starlette the last-registered middleware
    is outermost, so responses returned early here likely bypass both (no
    CORS or security headers; CORS preflight OPTIONS to protected paths may
    get 401 when auth is required) — confirm.
    """
    t0 = _time.time()
    path = request.url.path
    # Paths subject to API-key auth and the rate limiter. NOTE: /v1/search is
    # not listed, so it is neither auth-checked nor rate-limited here.
    protected = (path.startswith("/v1/generate") or path.startswith("/v1/chat")
                 or path.startswith("/v1/documents"))
    admin_only = path.startswith("/v1/admin/")
    api_key = request.headers.get("x-api-key")
    # X-Request-ID — honor client-supplied or generate our own
    request_id = request.headers.get("x-request-id") or generate_request_id()
    # Stash on request.state so handlers can correlate if they want
    request.state.request_id = request_id
    # v1.99 — optional body capture for replay. Opt-in via env var
    # TAU_RAG_OBS_CAPTURE_BODY=1. Only captures bodies for the replay-
    # able endpoints (search/generate/chat) to keep log size bounded
    # and avoid picking up admin request bodies with keys.
    # ``_os`` is bound by ``import os as _os`` further down this module; it is
    # resolved at request time, after module import has completed.
    captured_body: Optional[str] = None
    _REPLAY_CAPTURE_PATHS = (
        "/v1/search", "/v1/generate", "/v1/chat",
    )
    if (_os.environ.get("TAU_RAG_OBS_CAPTURE_BODY") == "1"
            and request.method == "POST"
            and any(path.startswith(p) for p in _REPLAY_CAPTURE_PATHS)):
        try:
            raw = await request.body()
            # Truncate to 4KB — real queries are under 1KB, legal
            # texts rarely reach 2KB. Very-long payloads get flagged
            # but replay won't work on them (acceptable).
            if raw is not None:
                captured_body = raw[:4096].decode("utf-8", errors="replace")
                # v2.8 — PII redaction. When TAU_RAG_PII_REDACT=1 (or
                # admin flipped via endpoint), scrub Israeli IDs, phone
                # numbers, emails, and CC-like digit runs from the
                # captured text BEFORE it hits the observability log,
                # JSONL file, stdout, or SSE tail. No-op when disabled.
                captured_body = get_pii_redactor().redact(captured_body)
            # Put the already-read body back on the request so the
            # downstream handler still sees it (overrides a private
            # Starlette attribute).
            async def _receive() -> Dict[str, Any]:
                return {"type": "http.request", "body": raw,
                        "more_body": False}
            request._receive = _receive  # type: ignore[attr-defined]
        except Exception:
            # Capture is best-effort; a failure just means no body is logged.
            captured_body = None

    def _log(status: int, error: Optional[str] = None) -> None:
        # Record one RequestLog entry for this request (latency measured from
        # t0), attaching the error label and captured body when present.
        extra: Dict[str, Any] = {}
        if error:
            extra["error"] = error
        if captured_body is not None:
            extra["body"] = captured_body
        latency_ms = (_time.time() - t0) * 1000.0
        get_obs().record(RequestLog(
            ts=_time.time(),
            request_id=request_id,
            method=request.method,
            path=path,
            status=status,
            latency_ms=latency_ms,
            key_hash_prefix=_hash_prefix(api_key),
            client_ip=(request.client.host if request.client else None),
            user_agent=request.headers.get("user-agent"),
            event_type="request",
            extra=extra,
        ))
        # v2.9 — also feed the slow-query tracker. No-op if threshold=0
        # or if the request was fast enough. Kept off the observability
        # log hot-path: cheap dict append in the tracker.
        try:
            get_slow_tracker().maybe_record(SlowRecord(
                ts=_time.time(),
                request_id=request_id,
                method=request.method,
                path=path,
                status=status,
                latency_ms=latency_ms,
                error=error,
            ))
        except Exception:
            pass

    # 1. Auth check (only if TAU_RAG_REQUIRE_AUTH is set OR admin path)
    auth = get_auth()
    # v2.7 — maintenance / drain mode. Admin traffic always flows (so
    # operators can turn it off again); everyone else gets 503 +
    # Retry-After. Check happens AFTER auth object is available so we
    # can ask ``is_admin(key)`` but BEFORE rate limiting — otherwise a
    # drained pod would count rejected requests against the limiter,
    # polluting stats.
    # v2.11 — k8s probes (/livez, /readyz) must always reach the
    # handler so the probe reflects true readiness. Drain is ONE of
    # several reasons a pod might be unready; the probe itself (via
    # the readiness registry's ``not_draining`` check) signals it.
    # Blocking the probe at middleware level would mask other
    # unreadiness signals during drain.
    _PROBE_PATHS = ("/livez", "/readyz")
    maint = get_maintenance()
    if (maint.is_enabled() and not admin_only
            and not auth.is_admin(api_key)
            and path not in _PROBE_PATHS):
        snap = maint.snapshot()
        _log(503, error="maintenance")
        return JSONResponse(
            status_code=503,
            headers={
                "Retry-After": str(int(snap["retry_after"])),
                "X-Request-ID": request_id,
            },
            content=build_error_body(
                ErrorCode.INTERNAL_ERROR,
                "service is in maintenance mode",
                request_id=request_id,
                details={
                    "reason": snap["reason"],
                    "retry_after": int(snap["retry_after"]),
                    "maintenance_since_sec": round(
                        snap["duration_sec"], 2),
                },
            ),
        )
    if admin_only:
        # Admin paths always require an admin key, regardless of
        # auth.required.
        if not auth.is_admin(api_key):
            _log(401, error="admin_required")
            return JSONResponse(
                status_code=401,
                headers={"X-Request-ID": request_id},
                content=build_error_body(
                    ErrorCode.ADMIN_REQUIRED,
                    "admin scope required",
                    request_id=request_id,
                    details={"hint": "pass X-API-Key with admin scope"},
                ),
            )
    elif protected and auth.required:
        # Mutating methods need "write" scope; everything else (including
        # OPTIONS/HEAD) is checked for "read".
        scope = "write" if request.method in ("POST", "PUT", "DELETE") else "read"
        if not auth.validate(api_key, scope=scope):
            _log(401, error="unauthorized")
            return JSONResponse(
                status_code=401,
                headers={"X-Request-ID": request_id},
                content=build_error_body(
                    ErrorCode.UNAUTHORIZED,
                    "missing or invalid X-API-Key",
                    request_id=request_id,
                    details={"required_scope": scope},
                ),
            )
    # 2. v2.12 — Per-API-key daily quota. Runs BEFORE rate limit so
    # the per-second limiter doesn't deduct tokens for requests that
    # will be rejected anyway. Skipped for:
    # - unauthenticated paths (no key to meter)
    # - whitelisted clients (same as rate limiter whitelist)
    # - admin-only paths (admin already auth'd; no quota)
    # Only applies when the key actually has a quota configured —
    # unquotaed keys are unlimited (same as pre-v2.12 behavior).
    # NOTE: the condition below meters ANY non-admin path that carries a key
    # (including unprotected ones), and quota is incremented before the rate
    # limiter runs, so a request later rejected with 429 by the limiter has
    # still consumed one unit of daily quota.
    if api_key and not admin_only:
        client_ip = request.client.host if request.client else None
        if (api_key not in get_limiter().whitelist
                and client_ip not in get_limiter().whitelist):
            qc = get_quota_tracker().check_and_increment(api_key)
            if not qc.ok:
                _log(429, error="quota_exceeded")
                return JSONResponse(
                    status_code=429,
                    headers={
                        "Retry-After": str(qc.reset_in_sec),
                        "X-Request-ID": request_id,
                        "X-Quota-Limit": str(qc.limit),
                        "X-Quota-Used": str(qc.used),
                    },
                    content=build_error_body(
                        ErrorCode.RATE_LIMITED,
                        f"daily quota exceeded: {qc.used}/{qc.limit}",
                        request_id=request_id,
                        details={
                            "quota": "daily",
                            "limit": qc.limit,
                            "used": qc.used,
                            "reset_in_sec": qc.reset_in_sec,
                            "key_prefix": qc.key_prefix,
                        },
                    ),
                )
    # 3. Rate limit (skip admin — already auth'd). Keyed by API key, falling
    # back to the client IP (or "unknown") for anonymous callers.
    if protected:
        try:
            client_key = (
                api_key
                or (request.client.host if request.client else "unknown")
            )
            # v1.73 — pass path so the limiter can apply per-endpoint overrides
            get_limiter().acquire(client_key, path=path)
        except RateLimitExceeded as e:
            _log(429, error="rate_limited")
            return JSONResponse(
                status_code=429,
                headers={"Retry-After": f"{e.retry_after:.1f}",
                         "X-Request-ID": request_id},
                content=build_error_body(
                    ErrorCode.RATE_LIMITED,
                    f"rate limit exceeded for {e.key!r}",
                    request_id=request_id,
                    details={"retry_after": round(e.retry_after, 3),
                             "key": e.key},
                ),
            )
    # 4. v2.13 — Idempotency-Key check BEFORE dispatch. Scoped to
    # (api_key_prefix, idempotency_key) so two clients using the same
    # key don't collide. POST only, whitelisted paths only.
    # NOTE(review): lookup (here) and store (after dispatch) are separate
    # steps, so two concurrent requests with the same key can both miss
    # and both execute.
    _IDEMPOTENT_PATHS = ("/v1/generate", "/v1/chat", "/v1/search")
    idem_key = request.headers.get("idempotency-key")
    idem_eligible = (
        idem_key
        and request.method == "POST"
        and any(path.startswith(p) for p in _IDEMPOTENT_PATHS)
    )
    if idem_eligible:
        idem_scope = _hash_prefix(api_key) or (
            request.client.host if request.client else "anon")
        cached = get_idempotency_store().get(idem_scope, idem_key)
        if cached is not None:
            # Replay the stored response without dispatching to the route.
            _log(cached.status, error="idempotent_hit")
            headers = {
                "X-Request-ID": request_id,
                "X-Idempotency-Hit": "1",
                "X-Idempotency-Key": idem_key,
                **cached.headers_extra,
            }
            return JSONResponse(
                status_code=cached.status,
                content=cached.body,
                headers=headers,
            )
    # 5. Dispatch + log
    # v2.4 — set request_id on the tracer's thread-local so all pipeline
    # spans (v1.27) created during handler execution get auto-tagged.
    # NOTE(review): if this really is a thread-local, sync ``def`` routes run
    # in a worker thread and may not see it, and concurrent async requests on
    # the event-loop thread can overwrite each other's id; a
    # contextvars.ContextVar would be safer — confirm tracer implementation.
    try:
        from ..observability.tracing import get_tracer
        _t = get_tracer()
        _t.set_request_id(request_id)
    except Exception:
        _t = None
    # v2.14 — optional wall-clock timeout enforcement. 0 (default) =
    # disabled, preserving pre-v2.14 behavior.
    _guard = get_timeout_guard()
    _guard.record_request()
    try:
        if _guard.is_enabled():
            try:
                response = await _asyncio.wait_for(
                    call_next(request),
                    timeout=_guard.timeout_ms / 1000.0,
                )
            except _asyncio.TimeoutError:
                _guard.record_timeout()
                _log(504, error="request_timeout")
                return JSONResponse(
                    status_code=504,
                    headers={"X-Request-ID": request_id},
                    content=build_error_body(
                        ErrorCode.INTERNAL_ERROR,
                        f"request exceeded {_guard.timeout_ms:.0f}ms timeout",
                        request_id=request_id,
                        details={"timeout_ms": _guard.timeout_ms,
                                 "path": path},
                    ),
                )
        else:
            response = await call_next(request)
    except Exception as e:
        # Log the failure, then let the app's exception handlers render it.
        _log(500, error=f"{type(e).__name__}: {e}")
        raise
    finally:
        # Always clear the tracer's request id, whatever the outcome.
        if _t is not None:
            try:
                _t.set_request_id(None)
            except Exception:
                pass
    response.headers["X-Request-ID"] = request_id
    # v2.13 — cache successful responses for idempotency replay.
    # Consumes body_iterator and reconstructs a response so downstream
    # still sees the full body. Only JSON-decodable bodies are stored.
    # NOTE(review): ``body_bytes += chunk`` fails if the iterator yields
    # str chunks; any exception after the iterator has been (partly)
    # consumed is swallowed and the original, now-drained response is
    # returned. ``dict(response.headers)`` also collapses repeated headers
    # such as multiple Set-Cookie values.
    if idem_eligible and 200 <= response.status_code < 300:
        try:
            body_bytes = b""
            async for chunk in response.body_iterator:  # type: ignore[attr-defined]
                body_bytes += chunk
            import json as _json
            try:
                body_json = _json.loads(body_bytes.decode("utf-8"))
            except Exception:
                body_json = None
            if body_json is not None:
                get_idempotency_store().set(
                    idem_scope, idem_key,
                    response.status_code, body_json,
                )
            from fastapi.responses import Response as _Resp
            response = _Resp(
                content=body_bytes,
                status_code=response.status_code,
                headers=dict(response.headers),
                media_type=response.media_type,
            )
        except Exception:
            pass
    _log(response.status_code)
    return response
# ----------------------------------------------------------------- startup
def _pipeline_from_env() -> Pipeline:
    """Build the shared :class:`Pipeline` from the ``TAU_RAG_PRESET`` env var.

    Lets the same container image run different flavors. Recognized presets
    are ``mock``, ``default``, ``hebrew_legal``, ``no_llm`` (used when the
    variable is unset) and ``hebrew_dense``.

    An unrecognized preset still falls back to ``Config.mock`` (unchanged
    behavior), but a warning is now printed so a typo in deployment config no
    longer silently runs the mock pipeline.

    Returns:
        A pipeline built via ``Pipeline.from_config`` for the chosen preset.
    """
    import os

    presets = {
        "mock": Config.mock,
        "default": Config.default,
        "hebrew_legal": Config.hebrew_legal,
        "no_llm": Config.no_llm,
        "hebrew_dense": Config.hebrew_dense,
    }
    preset = os.environ.get("TAU_RAG_PRESET", "no_llm")
    factory = presets.get(preset)
    if factory is None:
        # Keep the historical fallback, but make it visible in startup logs.
        print(f"[tau-rag] unknown TAU_RAG_PRESET={preset!r}; falling back to "
              f"'mock' (known presets: {', '.join(sorted(presets))})")
        factory = Config.mock
    return Pipeline.from_config(factory())
# One shared pipeline. Swap the config for production.
# Built at import time; every route below reads this module-level instance.
_pipeline: Pipeline = _pipeline_from_env()
# Auto-restore from snapshot if TAU_RAG_SNAPSHOT_PATH is set
# (and the file already exists). Failures are printed, not raised, so a bad
# snapshot never prevents the app from starting.
import os as _os
_snap_path = _os.environ.get("TAU_RAG_SNAPSHOT_PATH")
if _snap_path and _os.path.exists(_snap_path):
    try:
        _restore_summary = _pipeline.load_snapshot(_snap_path, replace=False)
        print(f"[tau-rag] restored from snapshot: {_restore_summary}")
    except Exception as _e:
        print(f"[tau-rag] snapshot restore failed: {_e}")
# Periodic auto-snapshot — fires every N seconds as a crash-proofing measure.
# Enabled only when both TAU_RAG_SNAPSHOT_PATH and a positive
# TAU_RAG_SNAPSHOT_INTERVAL (seconds) are set. Each save is audited via the
# on_save callback. Setup errors (e.g. a non-numeric interval) are printed
# and otherwise ignored.
from ..snapshot import AutoSnapshotter, set_autosnapshotter, get_autosnapshotter  # noqa: E402
_snap_interval = _os.environ.get("TAU_RAG_SNAPSHOT_INTERVAL")
if _snap_path and _snap_interval:
    try:
        _iv = float(_snap_interval)
        if _iv > 0:
            _auto = AutoSnapshotter(
                _pipeline, _snap_path, interval_sec=_iv,
                on_save=lambda s: get_obs().audit(
                    "snapshot.auto_periodic", **s),
            )
            _auto.start()
            # Registered globally so the shutdown hook can stop it.
            set_autosnapshotter(_auto)
            print(f"[tau-rag] periodic auto-snapshot every {_iv}s → {_snap_path}")
    except Exception as _e:
        print(f"[tau-rag] periodic snapshot setup failed: {_e}")
# Periodic metrics history sampler (v1.78) — optional, enabled by env var.
# TAU_RAG_METRICS_HISTORY_INTERVAL_SEC (> 0) enables it;
# TAU_RAG_METRICS_HISTORY_CAPACITY bounds stored samples (default 720,
# clamped to at least 10). Setup errors are printed and otherwise ignored.
try:
    _metrics_iv_raw = _os.environ.get("TAU_RAG_METRICS_HISTORY_INTERVAL_SEC")
    if _metrics_iv_raw:
        _metrics_iv = float(_metrics_iv_raw)
        if _metrics_iv > 0:
            # NOTE: get_metrics_history is imported but not used in this
            # block.
            from ..middleware import (
                MetricsHistory, MetricsHistorySampler,
                get_metrics_history, set_metrics_sampler,
            )
            _mcap = int(_os.environ.get(
                "TAU_RAG_METRICS_HISTORY_CAPACITY", "720"))
            # Replace the default history with one sized from env.
            _h = MetricsHistory(max_samples=max(10, _mcap))
            from ..middleware import set_metrics_history
            set_metrics_history(_h)
            _sampler = MetricsHistorySampler(
                _h, interval_s=_metrics_iv, sample_on_start=True,
            )
            _sampler.start()
            # Registered globally so the shutdown hook can stop it.
            set_metrics_sampler(_sampler)
            print(f"[tau-rag] metrics history sampler every "
                  f"{_metrics_iv}s cap={_mcap}")
except Exception as _e:
    print(f"[tau-rag] metrics history sampler setup failed: {_e}")
# Background analytics retention scheduler (v1.93) — optional, enabled
# by TAU_RAG_ANALYTICS_TTL_DAYS.
# TTL is given in days (converted to seconds below); the prune cadence comes
# from TAU_RAG_ANALYTICS_PRUNE_INTERVAL_SEC (default 3600 s). Setup errors
# are printed and otherwise ignored so they never block startup.
try:
    _ttl_raw = _os.environ.get("TAU_RAG_ANALYTICS_TTL_DAYS")
    if _ttl_raw:
        _ttl_days = float(_ttl_raw)
        if _ttl_days > 0:
            _prune_iv = float(_os.environ.get(
                "TAU_RAG_ANALYTICS_PRUNE_INTERVAL_SEC", "3600"))
            from ..middleware import (
                AnalyticsRetentionScheduler,
                set_retention_scheduler,
            )
            def _analytics_prune_cb(summary):
                # Presumably invoked by the scheduler with a summary dict
                # after each prune pass: audits passes that removed rows, or
                # that report an error. Audit failures are swallowed here.
                try:
                    if summary.get("total_removed", 0) > 0:
                        get_obs().audit(
                            "analytics.prune.auto",
                            ttl_seconds=summary.get("ttl_seconds"),
                            total_removed=summary.get("total_removed"),
                        )
                    elif summary.get("error"):
                        get_obs().audit(
                            "analytics.prune.auto.error",
                            error=summary["error"],
                        )
                except Exception:
                    pass
            _retention = AnalyticsRetentionScheduler(
                ttl_seconds=_ttl_days * 86400.0,
                interval_s=_prune_iv,
                on_prune=_analytics_prune_cb,
            )
            _retention.start()
            # Registered globally so the shutdown hook can stop it.
            set_retention_scheduler(_retention)
            print(f"[tau-rag] analytics retention scheduler: "
                  f"ttl={_ttl_days}d interval={_prune_iv}s")
except Exception as _e:
    print(f"[tau-rag] retention scheduler setup failed: {_e}")
# Background alert evaluator (v1.81) — optional, enabled by env var.
# TAU_RAG_ALERT_EVAL_INTERVAL_SEC (> 0) sets the evaluation period. Rules
# come from get_alert_store() and are evaluated against get_metrics_history()
# (presumably the history installed by the sampler block above when that is
# enabled — confirm). Setup errors are printed and otherwise ignored.
try:
    _alert_iv_raw = _os.environ.get("TAU_RAG_ALERT_EVAL_INTERVAL_SEC")
    if _alert_iv_raw:
        _alert_iv = float(_alert_iv_raw)
        if _alert_iv > 0:
            from ..middleware import (
                AlertScheduler, get_alert_store, get_metrics_history,
                set_alert_scheduler,
            )
            def _alert_fire_cb(verdict):
                # Audit each fired alert; audit failures are swallowed here.
                try:
                    get_obs().audit(
                        "alert.fired",
                        rule=verdict["rule"],
                        reason=verdict["reason"],
                        latest_value=verdict["latest_value"],
                        n_samples=verdict["n_samples"],
                    )
                except Exception:
                    pass
            _asched = AlertScheduler(
                get_alert_store(), get_metrics_history(),
                interval_s=_alert_iv, on_fire=_alert_fire_cb,
                evaluate_on_start=True,
            )
            _asched.start()
            # Registered globally so the shutdown hook can stop it.
            set_alert_scheduler(_asched)
            print(f"[tau-rag] alert scheduler every {_alert_iv}s")
except Exception as _e:
    print(f"[tau-rag] alert scheduler setup failed: {_e}")
# Auto-warmup on startup if env requests it (v1.56).
# Runs synchronously at import time, so startup waits for warmup to finish.
# Failures are printed and the app still starts (unwarmed).
if _os.environ.get("TAU_RAG_WARMUP") == "1":
    try:
        _fn = getattr(_pipeline, "warmup", None)
        if callable(_fn):
            _fn()
        # Flag consulted by /readyz?require_warmed=1; set even when the
        # pipeline exposes no warmup() method.
        _pipeline._warmed = True  # type: ignore[attr-defined]
        print("[tau-rag] auto-warmup complete")
    except Exception as _e:
        print(f"[tau-rag] auto-warmup failed: {_e}")
# Auto-snapshot on shutdown — pairs with auto-restore on startup above.
@app.on_event("shutdown")
def _save_snapshot_on_shutdown() -> None:
    """Stop background services, then write a final snapshot.

    Runs on FastAPI's ``shutdown`` event. Each background service (periodic
    auto-snapshotter, metrics sampler, alert scheduler, analytics retention
    scheduler) is stopped best-effort: a failure is printed and shutdown
    continues. Previously an exception from the auto-snapshotter's
    ``stop()`` aborted the hook before the final save, and failures of the
    other services were silently swallowed.

    The final snapshot is written only when ``TAU_RAG_SNAPSHOT_PATH`` is set;
    a failed save is printed rather than raised.
    """
    # Stop the periodic snapshot thread first so it can't race the final save.
    _stop_background_service(
        "periodic auto-snapshotter", get_autosnapshotter, set_autosnapshotter)
    # v1.78 — metrics history sampler
    try:
        from ..middleware import get_metrics_sampler, set_metrics_sampler
    except Exception as exc:
        print(f"[tau-rag] metrics sampler shutdown skipped: {exc}")
    else:
        _stop_background_service(
            "metrics sampler", get_metrics_sampler, set_metrics_sampler)
    # v1.81 — alert scheduler
    try:
        from ..middleware import get_alert_scheduler, set_alert_scheduler
    except Exception as exc:
        print(f"[tau-rag] alert scheduler shutdown skipped: {exc}")
    else:
        _stop_background_service(
            "alert scheduler", get_alert_scheduler, set_alert_scheduler)
    # v1.93 — analytics retention scheduler
    try:
        from ..middleware import (
            get_retention_scheduler, set_retention_scheduler,
        )
    except Exception as exc:
        print(f"[tau-rag] retention scheduler shutdown skipped: {exc}")
    else:
        _stop_background_service(
            "analytics retention scheduler",
            get_retention_scheduler, set_retention_scheduler)

    path = _os.environ.get("TAU_RAG_SNAPSHOT_PATH")
    if not path:
        return
    try:
        summary = _pipeline.save_snapshot(path)
        print(f"[tau-rag] shutdown snapshot saved: {summary}")
        get_obs().audit("snapshot.auto_save_on_shutdown", **summary)
    except Exception as _e:
        print(f"[tau-rag] shutdown snapshot failed: {_e}")


def _stop_background_service(label: str, getter, setter) -> None:
    """Stop the service returned by ``getter()`` (if any), then ``setter(None)``.

    Best-effort: any exception is printed instead of raised so one
    misbehaving service cannot block the rest of shutdown. The registration
    is cleared only after ``stop()`` returns successfully.
    """
    try:
        service = getter()
        if service:
            service.stop()
            setter(None)
    except Exception as exc:
        print(f"[tau-rag] failed to stop {label}: {exc}")
# ------------------------------------------------------------------ schemas
class SearchRequest(BaseModel):
    # Pydantic body schema for a search query. Field semantics are inferred
    # from names — NOTE(review): confirm against the handler that uses it
    # (not visible in this section).
    query: str
    # Candidate count, and how many survive reranking (presumably).
    k: int = 10
    rerank_k: int = 5
    # Retrieval strategy name; "hybrid" by default.
    strategy: str = "hybrid"
    # Language code; defaults to Hebrew ("he").
    lang: str = "he"
    # Free-form filter mapping; pydantic copies the default per instance.
    filters: Dict[str, Any] = {}
class DocumentRequest(BaseModel):
    # One document in a POST /v1/documents body; converted 1:1 into a
    # core ``Document`` by ``add_documents``.
    id: str
    text: str
    # Arbitrary per-document metadata; defaults to empty.
    metadata: Dict[str, Any] = {}
class DocumentsRequest(BaseModel):
    # Request body for POST /v1/documents: a batch of documents to index.
    documents: List[DocumentRequest]
# ------------------------------------------------------------------ routes
# Static landing page served by ``root()`` at "/" via HTMLResponse.
# NOTE(review): as it appears here the literal contains no HTML markup (no
# tags/links) even though it is served as text/html — confirm nothing was
# lost.
_PLAYGROUND_HTML = """
TAU-RAG
🔎 TAU-RAG — Hebrew legal RAG
Pipeline alive at /v1/*. Swagger UI:
/docs · ReDoc: /redoc
📄 Documents
❓ Single query
💬 Chat
"""
@app.get("/", response_class=None)
def root():
    """Serve the static playground landing page."""
    from fastapi.responses import HTMLResponse as _Html
    page = _PLAYGROUND_HTML
    return _Html(page)
@app.get("/favicon.ico")
def favicon():
    """Answer the browser's automatic favicon request with an empty 204.

    This avoids a noisy 404 every time the playground page is opened.
    """
    from fastapi.responses import Response as _Empty
    return _Empty(status_code=204)
@app.get("/health")
def health():
    """Always-200 health check reporting the API version.

    The version is read from ``app.version`` (the value passed to
    ``FastAPI(...)``) instead of a duplicated literal, so it cannot drift from
    the application metadata.
    """
    return {"ok": True, "version": app.version}
@app.get("/v1/version")
def version_manifest():
    """Build + runtime version info. Unauthenticated so anyone (including
    deploy scripts, monitoring, and teammates debugging) can check what's
    actually running. Does not expose secrets, just structural metadata.

    Robustness: malformed numeric env vars no longer raise (they are reported
    as unset), and a failure to resolve the span exporter is reported as
    ``"unknown"`` instead of failing the whole endpoint with a 500.
    """
    import platform as _plat
    import sys as _sys
    env = _os.environ.get

    # Pipeline structure
    retr_multi = getattr(_pipeline, "retrievers", None)
    retriever_members = (
        sorted(getattr(retr_multi, "retrievers", {}).keys())
        if retr_multi is not None else []
    )
    cfg = _pipeline.config
    preset = env("TAU_RAG_PRESET", "unknown")
    # Build info — keep it safe to serialize
    try:
        import fastapi as _fastapi
        fastapi_v = getattr(_fastapi, "__version__", "unknown")
    except Exception:
        fastapi_v = "unknown"
    # Git metadata — optional; silent fallback if not in a git checkout.
    # NOTE(review): this spawns up to three git processes per call on an
    # unauthenticated endpoint; consider caching the result.
    git_info = _version_git_info()
    # Enabled feature flags (from env) — helps debug "is it off/on in prod?"
    features = {
        "auth_required": env("TAU_RAG_REQUIRE_AUTH") == "1",
        "auto_warmup": env("TAU_RAG_WARMUP") == "1",
        "snapshot_path": bool(env("TAU_RAG_SNAPSHOT_PATH")),
        # None when unset, zero, or malformed.
        "snapshot_interval": _version_env_float("TAU_RAG_SNAPSHOT_INTERVAL") or None,
        "synonyms_path": bool(env("TAU_RAG_SYNONYMS_PATH")),
        "hsts": env("TAU_RAG_HSTS") == "1",
        "csp": bool(env("TAU_RAG_CSP")),
        "cors_origins": int(bool(env("TAU_RAG_CORS_ORIGINS"))),
        "log_stdout": env("TAU_RAG_LOG_STDOUT") == "1",
        "log_file": bool(env("TAU_RAG_LOG_PATH")),
        "audit_webhook": bool(env("TAU_RAG_AUDIT_WEBHOOK_URL")),
        "endpoint_rate_limits": bool(env("TAU_RAG_ENDPOINT_RATE_LIMITS")),
        "audit_export": True,
        "log_stream": True,
        "key_rotation": True,
        "snapshot_diff": True,
        "metrics_history": bool(env("TAU_RAG_METRICS_HISTORY_INTERVAL_SEC")),
        "webhook_circuit_breaker": True,
        "alert_rules": True,
        "alert_scheduler": bool(env("TAU_RAG_ALERT_EVAL_INTERVAL_SEC")),
        "doc_stats": True,
        "retriever_attribution": True,
        "cocitation_graph": True,
        "content_health": True,
        "content_health_ui": True,
        "eval_latency_gate": True,
        "content_health_history": True,
        "query_fingerprints": True,
        "preset_promote_candidates": True,
        "preset_auto_promote": True,
        "analytics_retention": True,
        "analytics_retention_scheduler": bool(env("TAU_RAG_ANALYTICS_TTL_DAYS")),
        "doc_freshness": True,
        "doc_update_priorities": True,
        "query_doc_affinity": True,
        "analytics_dump_restore": True,
        "query_analytics_ui": True,
        "query_replay": True,
        "replay_body_capture": env("TAU_RAG_OBS_CAPTURE_BODY") == "1",
        "v2_stable_api": True,
        "about_endpoint": True,
        "semantic_cache": env("TAU_RAG_SEMANTIC_CACHE") == "1",
        # Float value when set and parseable; False otherwise.
        "graph_cocitation_boost": _version_env_boost("TAU_RAG_GRAPH_COCITATION_BOOST"),
        "query_doc_boost": _version_env_boost("TAU_RAG_QUERY_DOC_BOOST"),
        "request_spans": True,
        "span_timeline_ui": True,
        "limiter_backend_protocol": True,
        "maintenance_mode": True,
        "pii_redaction": True,
        "pii_redaction_enabled": env("TAU_RAG_PII_REDACT") == "1",
        "slow_query_detection": True,
        "readiness_registry": True,
        "daily_quota": True,
        "idempotency_key": True,
        "request_timeout": True,
        "log_rotation": True,
        "span_exporter_protocol": True,
        "span_exporter_type": _version_span_exporter_type(),
    }
    return {
        "version": app.version,
        "preset": preset,
        "pipeline": {
            "retriever_members": retriever_members,
            "generator_provider": getattr(cfg.generation, "provider", "unknown"),
            "fusion_method": getattr(cfg.fusion, "method", "unknown"),
            "rerank_method": (getattr(cfg.rerank, "method", None)
                              if getattr(cfg, "rerank", None) else None),
            "verifier": type(_pipeline.verifier).__name__,
            "chunker": getattr(_pipeline, "_chunker_last", "fixed"),
        },
        "build": {
            "python": _sys.version.split()[0],
            "platform": _plat.platform(),
            "fastapi": fastapi_v,
        },
        "git": git_info,
        "features": features,
    }


def _version_env_float(name: str) -> Optional[float]:
    """Return env var *name* parsed as a float, or ``None`` if unset/empty/malformed."""
    raw = _os.environ.get(name)
    if not raw:
        return None
    try:
        return float(raw)
    except ValueError:
        return None


def _version_env_boost(name: str):
    """Boost env var as a float when set and parseable, else ``False``."""
    value = _version_env_float(name)
    return value if value is not None else False


def _version_git_info() -> Dict[str, Any]:
    """Best-effort git metadata for the process's current working directory.

    Returns ``{"available": False}`` when git is missing, the cwd is not a
    checkout, or any command fails or exceeds its 1 s timeout.
    """
    import subprocess as _sp

    def _git(*args: str) -> str:
        return _sp.check_output(
            ["git", *args], stderr=_sp.DEVNULL, timeout=1,
        ).decode().strip()

    try:
        commit = _git("rev-parse", "HEAD")
        branch = _git("rev-parse", "--abbrev-ref", "HEAD")
        dirty = _git("status", "--porcelain")
    except Exception:
        return {"available": False}
    return {
        "commit": commit,
        "commit_short": commit[:8],
        "branch": branch,
        "dirty": bool(dirty),
    }


def _version_span_exporter_type() -> str:
    """Class name of the active span exporter, or ``"unknown"`` if unresolvable."""
    try:
        from ..observability.span_exporters import get_span_exporter
        return type(get_span_exporter()).__name__
    except Exception:
        return "unknown"
# ------------------------------------------------------ ops-ready endpoints
from .metrics import render_prometheus, check_readiness # noqa: E402
@app.get("/livez", response_class=None)
def livez():
    """Liveness probe: a plain-text 200 "ok" whenever the process can respond."""
    from fastapi.responses import PlainTextResponse as _Text
    return _Text(content="ok", status_code=200)
@app.get("/readyz")
def readyz(require_warmed: bool = False):
    """Readiness probe — 503 + detail if pipeline isn't ready.

    Two sources must both pass:

    * ``check_readiness`` on the pipeline; with ``?require_warmed=1`` it also
      fails until ``POST /v1/admin/warmup`` has run (useful for deployment
      gating).
    * (v2.11) the pluggable readiness registry, so plugins and
      operator-registered checks (Redis, S3, drain mode, ...) participate.

    On failure a 503 ``HTTPException`` carries the pipeline detail (only when
    the pipeline check failed), the registry checks, and pass/fail counts.
    """
    pipeline_ok, pipeline_detail = check_readiness(
        _pipeline, require_warmed=bool(require_warmed))
    from ..middleware.readiness import get_readiness_registry
    registry = get_readiness_registry().evaluate()

    if pipeline_ok and registry["ready"]:
        return {"ok": True, "detail": pipeline_detail,
                "checks": registry["checks"]}

    failure = {
        "detail": None if pipeline_ok else pipeline_detail,
        "checks": registry["checks"],
        "n_passed": registry["n_passed"],
        "n_failed": registry["n_failed"],
    }
    raise HTTPException(status_code=503, detail=failure)
@app.get("/v1/admin/readiness")
def admin_readiness():
    """Full readiness report (v2.11), always returned with status 200.

    Unlike /readyz this never raises, so dashboards can display health
    without tripping k8s routing. The payload is the registry's evaluation:
    every registered check's state plus the overall ``ready`` flag.
    """
    from ..middleware.readiness import get_readiness_registry as _registry
    report = _registry().evaluate()
    return report
@app.post("/v1/admin/warmup")
def admin_warmup(request: Request):
    """Pre-load heavy components (embedders, tokenizers, adapters).

    Calls ``_pipeline.warmup()`` when present, then sets
    ``pipeline._warmed`` so ``/readyz?require_warmed=1`` starts passing, and
    records a ``pipeline.warmup`` audit event with the elapsed milliseconds.
    Any failure (including in the audit call) becomes a 500 whose detail is
    truncated to 200 characters.

    NOTE(review): ``actor_key`` is the raw ``X-API-Key`` header; unless
    ``audit()`` hashes it internally, the secret lands in the audit log —
    consider ``_hash_prefix``.
    """
    started = _time.time()
    try:
        warm = getattr(_pipeline, "warmup", None)
        if callable(warm):
            warm()
        _pipeline._warmed = True  # type: ignore[attr-defined]
        elapsed_ms = round((_time.time() - started) * 1000.0, 2)
        get_obs().audit(
            "pipeline.warmup",
            actor_key=request.headers.get("x-api-key"),
            request_id=getattr(request.state, "request_id", None),
            elapsed_ms=elapsed_ms,
        )
        return {"warmed": True, "elapsed_ms": elapsed_ms}
    except Exception as exc:
        message = f"warmup failed: {type(exc).__name__}: {exc}"[:200]
        raise HTTPException(status_code=500, detail=message)
@app.get("/metrics", response_class=None)
def metrics():
    """Prometheus exposition format — scrape me every 15s.

    Aggregates observability, cache, and limiter stats plus API-key counts
    (active vs revoked). Keys are counted in a single pass, so the counts stay
    correct even if ``list_keys()`` returns a one-shot iterator; previously
    the second pass over such an iterator would silently report 0 revoked.
    The version label comes from ``app.version`` rather than a duplicated
    literal.
    """
    from fastapi.responses import PlainTextResponse

    active = revoked = 0
    for key_info in get_auth().list_keys():
        if key_info.get("revoked"):
            revoked += 1
        else:
            active += 1
    body = render_prometheus(
        obs_stats=get_obs().stats(),
        cache_stats=get_cache().stats(),
        limiter_stats=get_limiter().stats(),
        auth_keys=active,
        auth_keys_revoked=revoked,
        version=app.version,
    )
    return PlainTextResponse(body, media_type="text/plain; version=0.0.4")
@app.post("/v1/documents")
def add_documents(req: DocumentsRequest):
    """Validate and index a batch of documents.

    ``validate_doc_list`` signals an oversized batch with ``OverflowError``
    (mapped to 413) and malformed input with ``ValueError`` (mapped to 422).
    Returns the number of chunks the pipeline added and the document count.
    """
    try:
        validate_doc_list(req.documents)
    except OverflowError as exc:
        raise HTTPException(status_code=413, detail=str(exc))
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc))
    batch = [
        Document(id=item.id, text=item.text, metadata=item.metadata)
        for item in req.documents
    ]
    chunk_count = _pipeline.add_documents(batch)
    return {"added_chunks": chunk_count, "documents": len(batch)}
# ---- document lifecycle endpoints ----------------------------------------
class DocumentBody(BaseModel):
    # Single-document body schema for the document lifecycle endpoints.
    # NOTE(review): field-for-field identical to DocumentRequest and not
    # referenced in the lines visible here — consider reusing one model.
    id: str
    text: str
    # Arbitrary per-document metadata; defaults to empty.
    metadata: Dict[str, Any] = {}
@app.get("/v1/documents")
def list_documents(
    request: Request,
    q: Optional[str] = None,
    limit: int = 50,
    offset: int = 0,
    preview_chars: int = 200,
):
    """List or search indexed documents.

    Query params:
    * ``q`` — case-insensitive substring over text + id
    * ``limit`` / ``offset`` — pagination (default 50 / 0); a ``limit``
      outside 1..10000 falls back to 50 and a negative ``offset`` to 0
    * ``preview_chars`` — first N chars returned per doc (default 200)
    * ``metadata.<key>=<value>`` — filter on flat metadata keys (one param
      per key)

    Backward-compat: the response keeps the v1.38 ``count`` key, mirroring
    ``matched``.
    """
    prefix = "metadata."
    meta_filter: Dict[str, str] = {
        name[len(prefix):]: value
        for name, value in request.query_params.multi_items()
        if name.startswith(prefix)
    }
    if not 1 <= limit <= 10_000:
        limit = 50
    offset = max(offset, 0)
    result = _pipeline.search_documents(
        q=q,
        metadata=meta_filter if meta_filter else None,
        limit=limit,
        offset=offset,
        preview_chars=preview_chars,
    )
    result["count"] = result["matched"]
    return result
@app.get("/v1/documents/export", response_class=None)
def export_documents(request: Request):
    """Export the indexed corpus as JSONL, one ``{id, text, metadata}`` per line.

    Supports the same filters as ``GET /v1/documents``:
    * ``?q=`` — case-insensitive substring over text and id
    * ``?metadata.<key>=<value>`` — flat metadata filter (one param per key;
      values are compared as strings)

    Unlike the listing endpoint there is no pagination — export is
    all-or-none. The response is ``application/x-ndjson`` with a download
    filename of ``tau-rag-documents.jsonl`` and an ``X-Document-Count``
    header.
    """
    from fastapi.responses import PlainTextResponse
    import json as _json

    params = request.query_params
    q = params.get("q")
    prefix = "metadata."
    wanted: Dict[str, str] = {
        name[len(prefix):]: value
        for name, value in params.multi_items()
        if name.startswith(prefix)
    }

    _pipeline._ensure_doc_log()
    needle = (q or "").strip().lower()

    def _keep(doc) -> bool:
        # Substring match on text first, falling back to the id.
        if (needle
                and needle not in (doc.text or "").lower()
                and needle not in (doc.id or "").lower()):
            return False
        return all(
            str((doc.metadata or {}).get(mk)) == str(mv)
            for mk, mv in wanted.items()
        )

    lines: List[str] = [
        _json.dumps(
            {"id": doc.id, "text": doc.text, "metadata": doc.metadata or {}},
            ensure_ascii=False,
        )
        for doc in _pipeline._indexed_docs
        if _keep(doc)
    ]
    # Newline-terminated records; an empty export is an empty body.
    body = "".join(line + "\n" for line in lines)
    return PlainTextResponse(
        body,
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition":
                'attachment; filename="tau-rag-documents.jsonl"',
            "X-Document-Count": str(len(lines)),
        },
    )
@app.get("/v1/documents/stats")
def index_stats():
    """Corpus-level stats: doc count, text-length distribution, metadata
    value histogram, metadata coverage, and retriever set. Delegates
    entirely to ``Pipeline.index_stats()``."""
    stats = _pipeline.index_stats()
    return stats
@app.get("/v1/documents/duplicates")
def admin_duplicates():
    """Report groups of documents that share normalized content.

    Groups come from ``Pipeline.find_duplicates()`` (a hash -> members
    mapping) and are returned largest-first as ``{hash, members}`` entries,
    together with ``n_groups``, ``n_duplicate_docs`` and ``total_docs``.
    Declared *before* ``/v1/documents/{doc_id}`` so FastAPI matches
    ``/duplicates`` as a fixed path, not a doc id.
    """
    groups = _pipeline.find_duplicates()
    # Stable sort: equal-sized groups keep find_duplicates() order.
    ordered = sorted(groups.items(), key=lambda item: len(item[1]), reverse=True)
    pretty = [{"hash": digest, "members": members} for digest, members in ordered]
    duplicate_docs = sum(len(members) for _, members in ordered)
    total = len(_pipeline.list_documents())
    return {
        "n_groups": len(pretty),
        "n_duplicate_docs": duplicate_docs,
        "total_docs": total,
        "groups": pretty,
    }
# ---- Per-document citation stats (v1.82) --------------------------------
# These routes live under /v1/admin/documents/..., so they cannot be
# shadowed by /v1/documents/{doc_id}. Keep them ahead of any parameterized
# /v1/admin/documents/{...} route so the fixed paths still match first.
@app.get("/v1/admin/documents/stats/summary")
def admin_docs_stats_summary():
    """Global rollup of per-document citation stats (v1.82).

    Returns the doc-stats store's ``summary()``: number of docs tracked,
    total retrieved/cited, global cite_rate and persistence path."""
    from ..middleware import get_doc_stats
    store = get_doc_stats()
    return store.summary()
@app.get("/v1/admin/documents/stats/top-cited")
def admin_docs_stats_top_cited(n: int = 10):
    """Return the ``n`` most-cited documents (v1.82) as ``{"top": [...]}``.

    Each row carries ``{doc_id, n_retrieved, n_cited, cite_rate,
    first_seen_at, ...}`` as produced by the doc-stats store."""
    from ..middleware import get_doc_stats
    rows = get_doc_stats().top_cited(n=int(n))
    return {"top": rows}
@app.get("/v1/admin/documents/stats/unused")
def admin_docs_stats_unused(
    min_retrieved: int = 1,
    max_cite_rate: float = 0.0,
):
    """List retrieval false-positives (v1.82).

    A doc qualifies when it was retrieved at least ``min_retrieved`` times
    while its cite_rate stayed at or below ``max_cite_rate``. The default
    of 0 means "never cited"."""
    from ..middleware import get_doc_stats
    threshold_retrieved = int(min_retrieved)
    threshold_rate = float(max_cite_rate)
    rows = get_doc_stats().unused(
        min_retrieved=threshold_retrieved,
        max_cite_rate=threshold_rate,
    )
    return {"unused": rows}
@app.post("/v1/admin/documents/stats/reset")
def admin_docs_stats_reset(request: Request):
    """Clear every per-document counter (v1.82).

    Emits a ``doc.stats.reset`` audit event with the pre-reset counts and
    returns ``{"reset": True, "before": <summary before clearing>}``."""
    from ..middleware import get_doc_stats
    store = get_doc_stats()
    snapshot = store.summary()
    store.clear()
    get_obs().audit(
        "doc.stats.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_docs=snapshot["n_docs"],
        prev_total_cited=snapshot["total_cited"],
    )
    return {"reset": True, "before": snapshot}
# ---- Per-retriever attribution (v1.83) ----------------------------------
@app.get("/v1/admin/retrievers/stats")
def admin_retriever_stats():
    """Per-retriever attribution stats (v1.83).

    Returns the store's ``summary()`` plus ``all_stats()``, the latter
    ranked by n_cited_contributions."""
    from ..middleware import get_retriever_attribution
    attribution = get_retriever_attribution()
    rollup = attribution.summary()
    per_retriever = attribution.all_stats()
    return {"summary": rollup, "stats": per_retriever}
@app.get("/v1/admin/retrievers/stats/ranking")
def admin_retriever_ranking():
    """Retrievers ordered by cite_rate × log(1 + n_contributions) (v1.83).

    The log factor weights precision by sample size, so rare-but-perfect
    retrievers don't outrank workhorses."""
    from ..middleware import get_retriever_attribution
    ordered = get_retriever_attribution().ranking()
    return {"ranking": ordered}
@app.post("/v1/admin/retrievers/stats/reset")
def admin_retriever_stats_reset(request: Request):
    """Clear the per-retriever counters (v1.83).

    Emits a ``retriever.stats.reset`` audit event and returns the summary
    captured just before clearing."""
    from ..middleware import get_retriever_attribution
    attribution = get_retriever_attribution()
    snapshot = attribution.summary()
    attribution.clear()
    get_obs().audit(
        "retriever.stats.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_retrievers=snapshot["n_retrievers"],
        prev_total_cited=snapshot["total_cited"],
    )
    return {"reset": True, "before": snapshot}
# ---- Co-citation graph (v1.84) ------------------------------------------
@app.get("/v1/admin/documents/cocitation/summary")
def admin_cocitation_summary():
    """Co-citation graph rollup (v1.84).

    Covers n_events (responses citing two or more docs), n_pairs, n_docs
    and total_count."""
    from ..middleware import get_cocitation
    graph = get_cocitation()
    return graph.summary()
@app.get("/v1/admin/documents/cocitation/top")
def admin_cocitation_top(n: int = 20):
    """Return the ``n`` most frequent co-citation pairs (v1.84)."""
    from ..middleware import get_cocitation
    pairs = get_cocitation().top_pairs(n=int(n))
    return {"top": pairs}
@app.post("/v1/admin/documents/cocitation/reset")
def admin_cocitation_reset(request: Request):
    """Clear the co-citation graph (v1.84).

    Emits a ``cocitation.reset`` audit event and returns the summary
    captured before clearing."""
    from ..middleware import get_cocitation
    graph = get_cocitation()
    snapshot = graph.summary()
    graph.clear()
    get_obs().audit(
        "cocitation.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_pairs=snapshot["n_pairs"],
        prev_total_count=snapshot["total_count"],
    )
    return {"reset": True, "before": snapshot}
@app.get("/v1/admin/content/health")
def admin_content_health(
    top_n: int = 5,
    unused_min_retrieved: int = 3,
):
    """Consolidated corpus health report (v1.85).

    Merges v1.82 doc stats, v1.83 retriever attribution and v1.84
    co-citation into a single answer, plus two cross-cutting lists:

    * ``dead_docs`` — indexed docs the doc-stats store has no row for
      (never retrieved).
    * ``isolated_docs`` — indexed docs that have a doc-stats row but never
      appear in any co-citation pair.

    Only docs that are *currently indexed* count as "touched". The stats
    store can still hold rows for docs that were later removed from the
    index; counting those would push ``coverage`` above 1.0 and list
    deleted docs as isolated.

    Query params:

    * ``top_n`` — rows to include in the top-cited / top-pairs / ranking
      sections (default 5; negative values are treated as 0).
    * ``unused_min_retrieved`` — passed to ``doc_stats.unused()`` to
      filter retrieval false-positives.

    ``score`` is the geometric mean of ``coverage`` (fraction of indexed
    docs ever retrieved), ``cite_rate`` (global doc-level cite rate) and
    ``connectivity`` (fraction of touched docs that are not isolated). It
    is 0 whenever any of the three is 0.
    """
    from ..middleware import (
        get_doc_stats, get_retriever_attribution, get_cocitation,
    )

    # A negative value would make the ``[:top_n]`` slice below mean
    # "all but the last |top_n|".
    top_n = max(0, int(top_n))

    doc_store = get_doc_stats()
    ra_store = get_retriever_attribution()
    cc_store = get_cocitation()
    doc_summary = doc_store.summary()
    ra_summary = ra_store.summary()
    cc_summary = cc_store.summary()

    # Corpus set — every indexed doc id.
    all_docs = {d["id"] for d in _pipeline.list_documents()}
    # Every doc the stats store has a row for (huge n == "all rows").
    tracked_docs = {d["doc_id"] for d in doc_store.top_cited(n=10 ** 9)}
    touched_docs = tracked_docs & all_docs

    # Dead docs: indexed, but no counter ever fired for them.
    dead = sorted(all_docs - touched_docs)

    # Isolated docs: touched, but never in a response with >= 2 sources.
    partnered = set()
    for pair in cc_store.top_pairs(n=10 ** 9):
        partnered.add(pair["a"])
        partnered.add(pair["b"])
    isolated = sorted(touched_docs - partnered)

    n_touched = len(touched_docs)
    coverage = n_touched / max(1, len(all_docs))
    cite_rate = doc_summary.get("global_cite_rate", 0.0)
    connectivity = (n_touched - len(isolated)) / n_touched if n_touched else 0.0

    # Equal-weight geometric mean: any dimension collapsing to 0 drags the
    # whole score to 0, giving operators a single knob to watch.
    if coverage > 0 and cite_rate > 0 and connectivity > 0:
        score = (coverage * cite_rate * connectivity) ** (1 / 3)
    else:
        score = 0.0

    return {
        "score": round(score, 4),
        "coverage": round(coverage, 4),
        "cite_rate": round(cite_rate, 4),
        "connectivity": round(connectivity, 4),
        "corpus": {
            "n_indexed": len(all_docs),
            "n_touched": n_touched,
            "n_dead": len(dead),
            "n_isolated": len(isolated),
        },
        "top_cited": doc_store.top_cited(n=top_n),
        "top_noisy": doc_store.unused(
            min_retrieved=int(unused_min_retrieved),
            max_cite_rate=0.0,
        ),
        "retrievers": {
            "summary": ra_summary,
            "ranking": ra_store.ranking()[:top_n],
        },
        "cocitation": {
            "summary": cc_summary,
            "top_pairs": cc_store.top_pairs(n=top_n),
        },
        "dead_docs": dead,
        "isolated_docs": isolated,
    }
# ---- Query fingerprint analytics (v1.89) --------------------------------
@app.get("/v1/admin/queries/stats/summary")
def admin_query_stats_summary():
    """Query fingerprint rollup (v1.89).

    Covers n_unique_queries, n_events (total) and avg_sources per query."""
    from ..middleware import get_query_stats
    store = get_query_stats()
    return store.summary()
@app.get("/v1/admin/queries/stats/top")
def admin_query_stats_top(n: int = 10):
    """Return the ``n`` most frequent queries (v1.89) as ``{"top": [...]}``."""
    from ..middleware import get_query_stats
    rows = get_query_stats().top(n=int(n))
    return {"top": rows}
@app.get("/v1/admin/queries/stats/recent")
def admin_query_stats_recent(
    since: Optional[float] = None,
    n: int = 10,
):
    """Recently seen queries (v1.89), newest first.

    With ``since`` (Unix ts), only queries whose last hit is at or after
    that time are returned. Without it, the N most-recently-seen queries
    are returned regardless of age."""
    from ..middleware import get_query_stats
    limit = int(n)
    rows = get_query_stats().recent(since=since, n=limit)
    return {"recent": rows}
@app.get("/v1/admin/queries/promote-candidates")
def admin_query_promote_candidates(
    min_count: int = 3,
    min_sources: float = 0.0,
    max_avg_latency_ms: Optional[float] = None,
    n: int = 20,
):
    """Return query fingerprints that are strong candidates for
    promotion to saved presets (v1.90).

    Heuristic: a query is a good preset candidate when it has been asked
    often enough to justify the saved-search slot, returns a useful number
    of sources on average, doesn't already have a preset (same canonical
    text), and (optionally) isn't slower than some threshold.

    Query params:

    * ``min_count`` — minimum observed occurrences (default 3).
    * ``min_sources`` — minimum ``avg_sources`` per response. Filters
      out popular queries that find nothing useful (0 = no filter).
    * ``max_avg_latency_ms`` — optional ceiling on average latency. Omit
      it to disable this filter.
    * ``n`` — cap on rows returned; ``n <= 0`` returns no rows.

    A candidate row is a ``QueryStats.to_dict()`` plus a derived
    ``suggested_preset_name`` that ops can accept as-is.
    """
    from ..middleware import get_query_stats
    from ..middleware.query_stats import _canonicalize
    from ..presets import get_preset_store

    query_store = get_query_stats()
    preset_store = get_preset_store()

    cap = int(n)
    count_floor = int(min_count)
    sources_floor = float(min_sources)
    latency_ceiling = (
        None if max_avg_latency_ms is None else float(max_avg_latency_ms)
    )

    # Index the presets by canonical query text once, so each candidate
    # check below is an O(1) set lookup.
    existing_canonical: set = {
        _canonicalize(p.get("query", "")) for p in preset_store.list_all()
    }

    candidates: List[Dict[str, Any]] = []
    for row in query_store.top(n=10 ** 9):
        # Check the cap before appending, so that n <= 0 yields nothing.
        if len(candidates) >= cap:
            break
        if row["count"] < count_floor:
            continue
        if row["avg_sources"] < sources_floor:
            continue
        if latency_ceiling is not None and row["avg_latency_ms"] > latency_ceiling:
            continue
        if _canonicalize(row["sample"]) in existing_canonical:
            continue
        candidates.append({
            **row,
            "suggested_preset_name": _suggest_preset_name(row["sample"]),
            "already_preset": False,
        })
    return {
        "candidates": candidates,
        "n_candidates": len(candidates),
        "min_count": int(min_count),
        "min_sources": float(min_sources),
        "n_existing_presets": len(existing_canonical),
    }
class PresetPromoteRequest(BaseModel):
    # Request body for POST /v1/admin/queries/promote (v1.91).
    # Explicit list of fingerprints to promote. Whenever this is not None
    # (even an empty list) the endpoint runs in explicit mode and the
    # auto-mode filters below are ignored — fingerprints take precedence.
    fingerprints: Optional[List[str]] = None
    # Auto-mode: pick candidates via filters (same as v1.90 endpoint).
    # ``limit`` caps how many matching queries are promoted.
    min_count: int = 3
    min_sources: float = 0.0
    max_avg_latency_ms: Optional[float] = None
    limit: int = 20
    # Common preset knobs — applied to every created preset
    k: int = 10
    rerank_k: int = 5
    strategy: str = "hybrid"
    lang: str = "he"
    # Naming
    name_prefix: str = "" # optional prefix (e.g. "promoted-")
    # Safety
    dry_run: bool = False # preview without creating
@app.post("/v1/admin/queries/promote")
def admin_queries_promote(req: PresetPromoteRequest, request: Request):
    """Auto-promote query-stats candidates to saved presets (v1.91).

    Two modes:

    * **Explicit**: pass ``fingerprints=[...]`` to promote specific
      queries by their v1.89 fingerprints. Used whenever ``fingerprints``
      is not null; the filter fields are then ignored.
    * **Filtered**: pass the same filter params as
      ``/v1/admin/queries/promote-candidates`` (v1.90) and we'll promote
      the top ``limit`` that match (``limit <= 0`` promotes nothing).

    In both modes we skip queries whose canonical text already has a
    preset. Name collisions are resolved by appending a ``-2``, ``-3``
    suffix. ``dry_run=True`` returns the planned actions without touching
    the preset store.

    Returns::

        {
          created: [{name, query, fingerprint, count}, ...],
          skipped: [{fingerprint, reason}, ...],
          n_created: int,
          n_skipped: int,
          dry_run: bool,
        }

    One ``preset.auto_promoted`` audit event is emitted per created
    preset so the change flows through the webhook (v1.71).
    """
    from ..middleware import get_query_stats
    from ..middleware.query_stats import _canonicalize
    from ..presets import get_preset_store, QueryPreset

    query_store = get_query_stats()
    preset_store = get_preset_store()

    # Pre-index existing presets by canonical query AND by name so we can
    # skip duplicates in both dimensions.
    existing_canonical: set = set()
    existing_names: set = set()
    for p in preset_store.list_all():
        existing_canonical.add(_canonicalize(p.get("query", "")))
        existing_names.add(p.get("name", ""))

    # ---- pick candidates
    rows: List[Dict[str, Any]] = []
    if req.fingerprints is not None:
        # Explicit mode: unknown fingerprints are reported as skipped below.
        for fp in req.fingerprints:
            s = query_store.get(fp)
            if s is None:
                rows.append({"fingerprint": fp, "_missing": True})
            else:
                rows.append(s.to_dict())
    else:
        # Filter mode — mirrors the v1.90 candidates logic.
        limit = int(req.limit)
        for row in query_store.top(n=10 ** 9):
            # Check the cap before appending, so that limit <= 0 selects
            # nothing.
            if len(rows) >= limit:
                break
            if row["count"] < int(req.min_count):
                continue
            if row["avg_sources"] < float(req.min_sources):
                continue
            if (req.max_avg_latency_ms is not None
                    and row["avg_latency_ms"] > float(req.max_avg_latency_ms)):
                continue
            if _canonicalize(row["sample"]) in existing_canonical:
                continue
            rows.append(row)

    # ---- plan (and, unless dry_run, execute)
    created: List[Dict[str, Any]] = []
    skipped: List[Dict[str, Any]] = []
    used_names = set(existing_names)
    for row in rows:
        fp = row.get("fingerprint", "")
        if row.get("_missing"):
            skipped.append({"fingerprint": fp,
                            "reason": "fingerprint not found in query_stats"})
            continue
        sample = row.get("sample", "")
        canonical = _canonicalize(sample)
        if not canonical:
            skipped.append({"fingerprint": fp,
                            "reason": "empty query after canonicalization"})
            continue
        if canonical in existing_canonical:
            skipped.append({"fingerprint": fp,
                            "reason": "already a preset (same canonical)"})
            continue

        base_name = _suggest_preset_name(sample)
        if req.name_prefix:
            base_name = f"{req.name_prefix}{base_name}"
        # Dedupe against existing names and names planned earlier in this
        # request.
        name = base_name
        suffix = 2
        while name in used_names:
            name = f"{base_name}-{suffix}"
            suffix += 1

        if not req.dry_run:
            try:
                preset_store.put(QueryPreset(
                    name=name, query=sample,
                    k=int(req.k), rerank_k=int(req.rerank_k),
                    strategy=req.strategy, lang=req.lang,
                    notes=f"auto-promoted from traffic (fp={fp}, "
                          f"count={row.get('count')})",
                ))
            except Exception as e:
                # Best-effort per row: report the failure and keep going.
                skipped.append({"fingerprint": fp,
                                "reason": f"put failed: "
                                          f"{type(e).__name__}: {e}"})
                continue
            get_obs().audit(
                "preset.auto_promoted",
                actor_key=request.headers.get("x-api-key"),
                request_id=getattr(request.state, "request_id", None),
                name=name, fingerprint=fp,
                count=row.get("count"),
                avg_sources=row.get("avg_sources"),
            )
        # Reserve name + canonical even on dry_run so the plan is exact.
        used_names.add(name)
        existing_canonical.add(canonical)
        created.append({
            "name": name,
            "query": sample,
            "fingerprint": fp,
            "count": row.get("count"),
        })
    return {
        "created": created,
        "skipped": skipped,
        "n_created": len(created),
        "n_skipped": len(skipped),
        "dry_run": bool(req.dry_run),
    }
def _suggest_preset_name(sample: str) -> str:
"""Turn a raw user query into a safe preset id: lowercase, ASCII
where possible, hyphens for whitespace, strip punctuation, cap
length. Keeps non-ASCII runs (Hebrew letters) as-is when they
can't be transliterated, so the result is still recognizable."""
import re
s = (sample or "").strip().lower()
# Drop typical punctuation
s = re.sub(r"[\?\!\.\,\:\;\(\)\[\]\{\}\"'`]", "", s)
s = re.sub(r"\s+", "-", s)
# Clip to 48 chars — leaves room for a namespace prefix
if len(s) > 48:
s = s[:48].rstrip("-")
return s or "preset"
@app.post("/v1/admin/queries/stats/reset")
def admin_query_stats_reset(request: Request):
    """Clear the query fingerprint store (v1.89).

    Emits a ``query.stats.reset`` audit event and returns the summary
    captured before clearing."""
    from ..middleware import get_query_stats
    fingerprints = get_query_stats()
    snapshot = fingerprints.summary()
    fingerprints.clear()
    get_obs().audit(
        "query.stats.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_unique=snapshot["n_unique_queries"],
        prev_n_events=snapshot["n_events"],
    )
    return {"reset": True, "before": snapshot}
class AnalyticsPruneRequest(BaseModel):
    # Request body for POST /v1/admin/analytics/prune (v1.92).
    # TTL: supply one of these. If both are set, older_than_seconds wins.
    older_than_days: Optional[float] = None
    older_than_seconds: Optional[float] = None
    # Which stores to prune. Every flag defaults to True, so omitting a
    # flag prunes that store; send false explicitly to skip it.
    doc_stats: bool = True
    retriever_attribution: bool = True
    cocitation: bool = True
    query_stats: bool = True
@app.get("/v1/admin/analytics/prune/scheduler")
def admin_retention_scheduler_status():
    """Report the background retention scheduler state (v1.93).

    When no scheduler is configured this returns
    ``{"enabled": False, "is_running": False}``; otherwise the scheduler's
    own ``status()`` merged behind ``"enabled": True``."""
    from ..middleware import get_retention_scheduler
    scheduler = get_retention_scheduler()
    if scheduler is not None:
        return {"enabled": True, **scheduler.status()}
    return {"enabled": False, "is_running": False}
@app.post("/v1/admin/analytics/prune")
def admin_analytics_prune(
    req: AnalyticsPruneRequest,
    request: Request,
):
    """Prune stale entries across the analytics stores (v1.92).

    Removes rows whose last-activity timestamp is older than the TTL.
    Pass ``older_than_days`` OR ``older_than_seconds`` (seconds wins if
    both are set). The TTL must be a positive, finite number; NaN/inf are
    rejected up front, before any store is touched. Per-store flags let
    ops target a subset (e.g. prune only ``query_stats`` while keeping doc
    history).

    Returns per-store ``{n_removed, n_remaining_after}`` and emits an
    ``analytics.prune`` audit event.

    Raises:
        HTTPException(400): TTL missing, non-positive or non-finite.
    """
    import math

    from ..middleware import (
        get_doc_stats, get_retriever_attribution,
        get_cocitation, get_query_stats,
    )

    # Resolve the TTL in seconds.
    if req.older_than_seconds is not None:
        ttl_s = float(req.older_than_seconds)
    elif req.older_than_days is not None:
        ttl_s = float(req.older_than_days) * 86400.0
    else:
        raise HTTPException(
            status_code=400,
            detail={"error": "either older_than_days or "
                             "older_than_seconds is required"},
        )
    if ttl_s <= 0:
        raise HTTPException(
            status_code=400,
            detail={"error": "TTL must be positive"},
        )
    # NaN slips past ``<= 0``. Non-finite TTLs would also make the JSON
    # response unrenderable (NaN/inf are not valid JSON) after the prune
    # had already run.
    if not math.isfinite(ttl_s):
        raise HTTPException(
            status_code=400,
            detail={"error": "TTL must be a finite number"},
        )

    # (result key, enabled flag, store getter, summary key for remaining rows)
    targets = (
        ("doc_stats", req.doc_stats, get_doc_stats, "n_docs"),
        ("retriever_attribution", req.retriever_attribution,
         get_retriever_attribution, "n_retrievers"),
        ("cocitation", req.cocitation, get_cocitation, "n_pairs"),
        ("query_stats", req.query_stats, get_query_stats,
         "n_unique_queries"),
    )
    results: Dict[str, Dict[str, int]] = {}
    total_removed = 0
    for key, enabled, get_store, remaining_key in targets:
        if not enabled:
            continue
        store = get_store()
        n = store.prune(ttl_s)
        results[key] = {
            "n_removed": n,
            "n_remaining_after": store.summary()[remaining_key],
        }
        total_removed += n

    get_obs().audit(
        "analytics.prune",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        older_than_seconds=ttl_s,
        total_removed=total_removed,
    )
    return {
        "ttl_seconds": ttl_s,
        "total_removed": total_removed,
        "per_store": results,
    }
@app.get("/v1/about")
def about():
    """Public architectural overview (v2.0). Non-admin — no key
    required. Useful for clients, docs generators and CI checks.

    Covers the layers of tau-rag and points at the relevant primitives.
    ``version`` is read from the FastAPI app so the two cannot drift apart.
    """
    return {
        "name": "tau-rag",
        "version": app.version,
        "tagline": "Unified Hebrew-legal RAG with structure-preserving "
                   "verification + TAU-Ω signals",
        "layers": {
            "retrieval": {
                "retrievers": ["bm25", "gematria", "hilbert", "graph"],
                "fusion": "rank-based or weighted",
                "rerank": "optional cross-encoder or score-based",
                "chunker": "fixed | sentence | legal_hebrew",
            },
            "observability_stack": {
                "push": "webhook + breaker (v1.71/79)",
                "batch": "/v1/admin/audit/export (v1.74)",
                "pull_stream": "/v1/admin/logs/stream SSE (v1.75)",
                "history": "metrics + content health (v1.78/88)",
                "alerts": "rules + scheduler (v1.80/81)",
            },
            "content_analytics": {
                "doc_stats": "v1.82",
                "retriever_attribution": "v1.83",
                "cocitation": "v1.84",
                "query_stats": "v1.89",
                "doc_freshness": "v1.94",
                "query_doc_affinity": "v1.96",
            },
            "analytics_cross_cuts": {
                "content_health": "v1.85/86",
                "update_priorities": "v1.95",
                "query_analytics_ui": "v1.98",
                "dump_restore": "v1.97",
            },
            "debugging": {
                "request_ids": "X-Request-ID on every response",
                "replay": "v1.99 re-execute by request_id",
            },
        },
        "patterns": {
            "side_channel_stores":
                "singleton+inject pattern; pipeline hook silent-fail; "
                "admin CRUD; persistence opt-in",
            "daemons":
                "AutoSnapshotter / MetricsHistorySampler / "
                "AlertScheduler / AnalyticsRetentionScheduler — "
                "start/stop/is_running/status + Event.wait + silent-fail",
            "quiet_on_zero":
                "schedulers emit audits only on state change",
            "html_dashboards":
                "inline CSS, zero JS, zero CDN, escape-safe, "
                "meta-refresh for wall screens",
        },
        "stability": {
            "api_stability": "v2.0 marks /v1/* as stable — additive "
                             "changes only; breaking changes → /v2/*",
            "deprecation_policy": "6-month notice; features.* flags "
                                  "track active capabilities",
        },
        # NOTE(review): hand-maintained figures; they will drift unless
        # updated alongside new endpoints/tests.
        "counts": {
            "endpoints": "80+",
            "tests": "1096+",
            "side_channels": 6,
            "daemons": 4,
            "html_dashboards": 4,
        },
    }
@app.get("/v1/admin/requests/{request_id}/spans/ui", response_class=None)
def admin_request_spans_ui(request_id: str, refresh: int = 0):
    """HTML timeline view of a request's spans (v2.5).

    Renders the v2.4 span data as a gantt-style bar chart for quick
    operator inspection, in the same design language as the v1.86 / v1.98
    dashboards. Data comes from the JSON endpoint's function, so a missing
    request_id surfaces as the same 404."""
    from fastapi.responses import HTMLResponse
    from .span_timeline_ui import render_span_timeline

    payload = admin_request_spans(request_id)
    page = render_span_timeline(
        request_id=payload["request_id"],
        n_spans=payload["n_spans"],
        total_ms=payload["total_ms"],
        spans=payload["spans"],
        refresh_sec=int(refresh or 0),
    )
    return HTMLResponse(page)
@app.get("/v1/admin/requests/{request_id}/spans")
def admin_request_spans(request_id: str):
    """Return in-memory trace spans for a specific request_id (v2.4).

    Pipeline stages (understand, retrieve, fuse, rerank, generate,
    verify, ...) each open a span, and middleware tags every span with
    the current request_id; this endpoint pulls them back by that id.

    Returns::

        {
          "request_id": str,
          "n_spans": int,
          "total_ms": float (longest root-span duration; 0.0 if none),
          "spans": [{name, trace_id, span_id, parent_id,
                     duration_ms, attrs}, ...],
        }

    Spans without ``end_ts`` report ``duration_ms`` 0.0. Raises 404 when
    the tracer has no spans for the id.

    Useful for:
      * seeing where time went inside a slow request
      * correlating a user complaint with what actually executed
      * diagnosing retriever-specific failures per query
    """
    from ..observability.tracing import get_tracer

    spans = get_tracer().spans_for_request_id(request_id)
    if not spans:
        raise HTTPException(
            status_code=404,
            detail={"error": "no spans found for request_id",
                    "hint": "spans are in-memory — oldest get evicted "
                            "past the 5000-span cap"},
        )

    rows = []
    longest_root_ms = 0.0
    for span in spans:
        if span.end_ts:
            elapsed_ms = (span.end_ts - span.start_ts) * 1000.0
        else:
            elapsed_ms = 0.0
        is_root = span.parent_id is None
        if is_root and elapsed_ms > longest_root_ms:
            longest_root_ms = elapsed_ms
        rows.append({
            "name": span.name,
            "trace_id": span.trace_id,
            "span_id": span.span_id,
            "parent_id": span.parent_id,
            "duration_ms": round(elapsed_ms, 2),
            "attrs": span.attrs,
        })
    return {
        "request_id": request_id,
        "n_spans": len(rows),
        "total_ms": round(longest_root_ms, 2),
        "spans": rows,
    }
@app.post("/v1/admin/replay/{request_id}")
def admin_replay(request_id: str, request: Request):
    """Re-execute a previously-logged request against the current
    pipeline (v1.99). Requires body capture to have been on at the
    time of the original request (``TAU_RAG_OBS_CAPTURE_BODY=1``).

    Returns::

        {
          "request_id": original id,
          "replay_request_id": new id,
          "path": path recorded for the original request
                  (e.g. /v1/search | /v1/generate | /v1/chat),
          "original_body": captured body exactly as stored in the obs log,
          "query": parsed query text,
          "replay": {
            "sources": [doc_id, ...],
            "answer": str|None,
            "omega": float|None,
            "passed": bool|None,
            "timing_ms": dict,
          },
          "note": human-readable comparison hint,
        }

    Raises ``HTTPException``:

    * 404 when the request_id is not in the obs ring buffer.
    * 400 when the captured body is missing, is not a JSON object, has no
      string ``query``, or has a non-integer ``k``/``rerank_k``.

    Useful for:
      * regression debug — "did our new chunker break this query?"
      * eval gold augmentation — turn a real user query into a gold case.
      * postmortem analysis — replay after a bad deploy to prove harm.
    """
    import json as _json

    # Find the original row in the obs ring buffer. Iterating reversed()
    # checks the newest entries first, assuming tail() is oldest-first.
    row = None
    for entry in reversed(get_obs().tail(n=10 ** 9, event_type="request")):
        if entry.get("request_id") == request_id:
            row = entry
            break
    if row is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "request_id not found in obs log",
                    "hint": "ensure the request was recorded — "
                            "obs log is a ring buffer; older rows "
                            "may have been evicted"},
        )

    body_txt = (row.get("extra") or {}).get("body")
    if not body_txt:
        raise HTTPException(
            status_code=400,
            detail={"error": "no captured body on this request",
                    "hint": "set TAU_RAG_OBS_CAPTURE_BODY=1 and re-run "
                            "the original request to enable replay"},
        )
    try:
        payload = _json.loads(body_txt)
    except (TypeError, ValueError, RecursionError) as e:
        raise HTTPException(
            status_code=400,
            detail={"error": "captured body is not valid JSON",
                    "detail": f"{type(e).__name__}: {e}"}) from e
    # Valid JSON can still be a list/str/number; we need field access.
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400,
            detail={"error": "captured body is not a JSON object"})

    path = row.get("path") or ""
    q_text = payload.get("query")
    if not q_text:
        raise HTTPException(status_code=400,
                            detail={"error": "captured body has no 'query' field"})
    if not isinstance(q_text, str):
        raise HTTPException(status_code=400,
                            detail={"error": "captured body 'query' is not a string"})

    # Unknown or non-string strategies fall back to hybrid.
    raw_strategy = payload.get("strategy") or "hybrid"
    try:
        strategy = Strategy(str(raw_strategy).lower())
    except ValueError:
        strategy = Strategy.HYBRID

    try:
        k = int(payload.get("k") or 10)
        rerank_k = int(payload.get("rerank_k") or 5)
    except (TypeError, ValueError) as e:
        raise HTTPException(
            status_code=400,
            detail={"error": "captured body has a non-integer 'k' or 'rerank_k'",
                    "detail": f"{type(e).__name__}: {e}"}) from e

    q = Query(
        text=q_text,
        lang=payload.get("lang") or "he",
        filters=payload.get("filters") or {},
        strategy=strategy,
        k=k,
        rerank_k=rerank_k,
    )

    # A fresh id for the replay, recorded in the audit event below so the
    # new run can be correlated with the original.
    replay_id = generate_request_id()
    resp = _pipeline.run(q)

    omega = None
    try:
        omega = float(resp.signals.omega) if resp.signals else None
    except (AttributeError, TypeError, ValueError):
        # Best-effort: responses without usable signals report None.
        omega = None
    verif = getattr(resp, "verification", None)

    get_obs().audit(
        "replay.executed",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        original_request_id=request_id,
        replay_request_id=replay_id,
        path=path,
    )
    return {
        "request_id": request_id,
        "replay_request_id": replay_id,
        "path": path,
        "original_body": body_txt,
        "query": q_text,
        "replay": {
            "sources": list(resp.sources or []),
            "answer": resp.answer,
            "omega": omega,
            "passed": (bool(getattr(verif, "passed", False))
                       if verif else None),
            "timing_ms": dict(resp.timing_ms or {}),
        },
        "note": ("compare 'replay.sources' to whatever the original "
                 "response had — differences reveal pipeline drift "
                 "since the original run"),
    }
@app.get("/v1/admin/queries/ui", response_class=None)
def admin_query_analytics_ui(
    top_n: int = 10,
    matrix_queries: int = 6,
    matrix_docs: int = 6,
    min_count: int = 3,
    refresh: int = 0,
):
    """HTML dashboard for query analytics (v1.98).

    Merges v1.89 (query_stats), v1.90 (promote candidates) and v1.96
    (query × doc affinity) into one visual page, in the same design
    language as the v1.86 content-health UI.

    Query params:

    * ``top_n`` — rows to show in 'top queries' (default 10).
    * ``matrix_queries`` — rows in the affinity heatmap (default 6).
    * ``matrix_docs`` — columns in the affinity heatmap (default 6).
    * ``min_count`` — promote-candidate threshold (default 3).
    * ``refresh`` — auto-refresh seconds (0 = off).
    """
    from fastapi.responses import HTMLResponse
    from ..middleware import (
        get_query_stats, get_query_doc_affinity,
    )
    from ..middleware.query_stats import _canonicalize
    from ..presets import get_preset_store
    from .query_analytics_ui import render_query_analytics_ui

    q_store = get_query_stats()
    qda = get_query_doc_affinity()
    preset_store = get_preset_store()

    summary = q_store.summary()
    top_qs = q_store.top(n=int(top_n))

    # Promote candidates — same shape as the v1.90 endpoint, but with a
    # fixed avg_sources >= 1.0 floor (v1.90 defaults to 0.0) and at most
    # 10 rows.
    existing_canonical = {
        _canonicalize(p.get("query", "")) for p in preset_store.list_all()
    }
    promote = []
    for row in q_store.top(n=10 ** 9):
        if row["count"] < int(min_count):
            continue
        if row.get("avg_sources", 0.0) < 1.0:
            continue
        if _canonicalize(row["sample"]) in existing_canonical:
            continue
        promote.append({
            **row,
            "suggested_preset_name": _suggest_preset_name(row["sample"]),
        })
        if len(promote) >= 10:
            break

    # Affinity matrix. Fetch each shown query's doc rows once and reuse
    # them for both the column vote and the cell values.
    matrix_q_rows = q_store.top(n=int(matrix_queries))
    affinity_rows = [
        (qrow["fingerprint"],
         qda.top_docs_for_query(qrow["fingerprint"], n=10 ** 9))
        for qrow in matrix_q_rows
    ]

    # Columns: docs with the highest total count across the shown queries.
    # sorted() is stable, so ties keep first-seen order.
    doc_votes: Dict[str, int] = {}
    for _fp, docs in affinity_rows:
        for r in docs:
            doc_votes[r["doc_id"]] = doc_votes.get(r["doc_id"], 0) + r["count"]
    ranked = sorted(doc_votes.items(), key=lambda kv: -kv[1])
    top_doc_ids = [doc_id for doc_id, _ in ranked][:int(matrix_docs)]
    shown_docs = set(top_doc_ids)

    # (fingerprint, doc_id) -> count for the rendered subset only.
    matrix_pairs: Dict[tuple, int] = {}
    for fp, docs in affinity_rows:
        for r in docs:
            if r["doc_id"] in shown_docs:
                matrix_pairs[(fp, r["doc_id"])] = r["count"]

    html = render_query_analytics_ui(
        summary=summary,
        top_queries=top_qs,
        promote_candidates=promote,
        matrix_queries=matrix_q_rows,
        matrix_docs=top_doc_ids,
        matrix_pairs=matrix_pairs,
        refresh_sec=int(refresh or 0),
    )
    return HTMLResponse(html)
@app.get("/v1/admin/content/health/ui", response_class=None)
def admin_content_health_ui(
    top_n: int = 5,
    unused_min_retrieved: int = 3,
    refresh: int = 0,
):
    """HTML dashboard for the corpus health report (v1.86).

    Renders the same data as ``/v1/admin/content/health`` (v1.85) as a
    self-contained page. ``?refresh=N`` opts into an HTML meta-refresh
    every N seconds, which is handy for a wall screen."""
    from fastapi.responses import HTMLResponse
    from .content_health_ui import render_content_health_ui

    report = admin_content_health(
        top_n=top_n,
        unused_min_retrieved=unused_min_retrieved,
    )
    refresh_sec = int(refresh or 0)
    return HTMLResponse(render_content_health_ui(report, refresh_sec=refresh_sec))
# ---- Doc freshness tracking (v1.94) -------------------------------------
@app.get("/v1/admin/documents/freshness/summary")
def admin_doc_freshness_summary():
    """Doc freshness rollup (v1.94).

    Covers n_docs, oldest/newest added_at, median age and total
    modifications."""
    from ..middleware import get_doc_freshness
    tracker = get_doc_freshness()
    return tracker.summary()
@app.get("/v1/admin/documents/freshness/stale")
def admin_doc_freshness_stale(older_than_days: float = 90.0):
    """Docs whose last activity (modified or added) is older than
    ``older_than_days`` (v1.94).

    Rows are ordered oldest first, so a content audit can start at the
    top. The threshold is echoed back in the response."""
    from ..middleware import get_doc_freshness
    days = float(older_than_days)
    stale_rows = get_doc_freshness().stale(older_than_days=days)
    return {"older_than_days": days, "stale": stale_rows}
@app.post("/v1/admin/documents/freshness/reset")
def admin_doc_freshness_reset(request: Request):
    """Clear the freshness side-channel store (v1.94).

    Use this after a large corpus reload, when the old timestamps are
    meaningless. Emits a ``doc.freshness.reset`` audit event and returns
    the summary captured before clearing."""
    from ..middleware import get_doc_freshness
    tracker = get_doc_freshness()
    snapshot = tracker.summary()
    tracker.clear()
    get_obs().audit(
        "doc.freshness.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_docs=snapshot["n_docs"],
    )
    return {"reset": True, "before": snapshot}
# ---- Query × doc affinity (v1.96) ---------------------------------------
# ---- Unified analytics dump/restore (v1.97) -----------------------------
# Schema version of the /v1/admin/analytics/dump payload. The dump endpoint
# documents this as a single integer so restore can refuse incompatible
# schemas. Bump it whenever the dump format changes.
_ANALYTICS_DUMP_VERSION = 1
@app.get("/v1/admin/analytics/dump")
def admin_analytics_dump():
    """Single-call snapshot of all 6 side-channel analytics stores
    (v1.97). Returns JSON with one key per store — versioned payload
    for migration and offline analysis.
    Stores included:
    * v1.82 doc_stats
    * v1.83 retriever_attribution
    * v1.84 cocitation
    * v1.89 query_stats
    * v1.94 doc_freshness
    * v1.96 query_doc_affinity
    The ``version`` field signals the dump format — kept as a single
    integer so restore can refuse incompatible schemas.

    Returns:
        dict with ``version`` (``_ANALYTICS_DUMP_VERSION``), ``exported_at``
        (``time.time()`` sampled just before the freshness walk) and one
        section per store. Row-based sections hold ``rows`` + ``n_rows``;
        the two pair-based sections hold ``pairs`` + ``n_pairs`` +
        ``n_events`` (the latter read from each store's ``summary()``).

    NOTE(review): completeness relies on ``top_cited(n=10**9)``,
    ``top_pairs(n=10**9)``, ``top(n=10**9)``, ``all_stats()`` and
    ``stale(older_than_days=0)`` returning every row, and on the private
    ``_by_query`` index of the affinity store — confirm against the
    middleware implementations.
    """
    # Local import, like the other analytics endpoints in this module.
    from ..middleware import (
        get_doc_stats, get_retriever_attribution,
        get_cocitation, get_query_stats, get_doc_freshness,
        get_query_doc_affinity,
    )
    # doc_stats: every tracked doc + raw counters
    doc_store = get_doc_stats()
    doc_rows = []
    # Walk via top_cited(n=10**9) — it returns all rows
    for row in doc_store.top_cited(n=10 ** 9):
        doc_rows.append({
            "doc_id": row["doc_id"],
            "n_retrieved": row.get("n_retrieved", 0),
            "n_cited": row.get("n_cited", 0),
            "first_seen_at": row.get("first_seen_at"),
            "last_retrieved_at": row.get("last_retrieved_at"),
            "last_cited_at": row.get("last_cited_at"),
        })
    # retriever_attribution: export every key except ``cite_rate``
    # (presumably derived from the raw counters, so it is not persisted).
    ra = get_retriever_attribution()
    ra_rows = [
        {k: v for k, v in row.items() if k != "cite_rate"}
        for row in ra.all_stats()
    ]
    # cocitation: pair rows are exported exactly as top_pairs() returns
    # them; restore reads "a", "b", "count" and "last_seen" from each.
    cc = get_cocitation()
    cc_pairs = cc.top_pairs(n=10 ** 9)
    # query_stats: raw counters/sums only (averages can be re-derived).
    qs = get_query_stats()
    qs_rows = []
    for row in qs.top(n=10 ** 9):
        qs_rows.append({
            "fingerprint": row["fingerprint"],
            "sample": row["sample"],
            "count": row["count"],
            "first_seen_at": row["first_seen_at"],
            "last_seen_at": row["last_seen_at"],
            "sum_sources": row["sum_sources"],
            "sum_latency_ms": row["sum_latency_ms"],
        })
    # doc_freshness
    fs = get_doc_freshness()
    fs_rows = []
    # No top() here, walk via data dict — use all stale(older_than_days=0)
    # which gives every row sorted oldest-first; take as-is.
    # ``now`` is reused below as ``exported_at``.
    import time as _t
    now = _t.time()
    for row in fs.stale(older_than_days=0, now=now):
        fs_rows.append({
            "doc_id": row["doc_id"],
            "added_at": row["added_at"],
            "last_modified_at": row.get("last_modified_at"),
            "n_modifications": row.get("n_modifications", 0),
        })
    # query_doc_affinity
    qda = get_query_doc_affinity()
    qda_pairs: List[Dict[str, Any]] = []
    qda_summary = qda.summary()
    # Use the inverted index: for each query fingerprint, enumerate
    # its docs. Cheap — no O(N*M) scan.
    # list(...) snapshots the keys before the nested iteration.
    for fp in list(qda._by_query.keys()):  # noqa: SLF001
        for row in qda.top_docs_for_query(fp, n=10 ** 9):
            qda_pairs.append({
                "fingerprint": fp,
                "doc_id": row["doc_id"],
                "count": row["count"],
                "last_seen": row.get("last_seen"),
            })
    return {
        "version": _ANALYTICS_DUMP_VERSION,
        "exported_at": now,
        "doc_stats": {
            "rows": doc_rows,
            "n_rows": len(doc_rows),
        },
        "retriever_attribution": {
            "rows": ra_rows,
            "n_rows": len(ra_rows),
        },
        "cocitation": {
            "pairs": cc_pairs,
            "n_pairs": len(cc_pairs),
            "n_events": cc.summary().get("n_events", 0),
        },
        "query_stats": {
            "rows": qs_rows,
            "n_rows": len(qs_rows),
        },
        "doc_freshness": {
            "rows": fs_rows,
            "n_rows": len(fs_rows),
        },
        "query_doc_affinity": {
            "pairs": qda_pairs,
            "n_pairs": len(qda_pairs),
            "n_events": qda_summary.get("n_events", 0),
        },
    }
# Body of POST /v1/admin/analytics/restore. Comments (not a docstring) on
# purpose: a class docstring would become the OpenAPI schema description.
class AnalyticsRestoreRequest(BaseModel):
    # Payload expected to come from GET /v1/admin/analytics/dump; its
    # ``version`` key is checked against _ANALYTICS_DUMP_VERSION.
    dump: Dict[str, Any]
    replace: bool = True  # default: wipe before restore; False merges
def _restore_opt_float(value: Any) -> Optional[float]:
    """Return ``None`` for ``None``, else ``float(value)``.

    Raises ``TypeError`` / ``ValueError`` for non-numeric input, which the
    restore endpoint turns into HTTP 400.
    """
    return None if value is None else float(value)


def _restore_min_ts(a: Any, b: Any) -> Any:
    """Earlier of two optional timestamps; ``None`` means "unknown"."""
    if a is None:
        return b
    if b is None:
        return a
    return min(a, b)


def _restore_max_ts(a: Any, b: Any) -> Any:
    """Later of two optional timestamps; ``None`` means "unknown"."""
    if a is None:
        return b
    if b is None:
        return a
    return max(a, b)


def _restore_section(dump: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return ``dump[name]`` as a dict (missing / null → ``{}``)."""
    part = dump.get(name) or {}
    if not isinstance(part, dict):
        raise ValueError(f"'{name}' must be an object")
    return part


def _restore_items(dump: Dict[str, Any], name: str,
                   key: str) -> List[Dict[str, Any]]:
    """Return ``dump[name][key]`` as a list of dicts (missing → ``[]``)."""
    items = _restore_section(dump, name).get(key) or []
    if not isinstance(items, list) or not all(
            isinstance(item, dict) for item in items):
        raise ValueError(f"'{name}.{key}' must be a list of objects")
    return items


def _restore_id(item: Dict[str, Any], field: str) -> Optional[str]:
    """Return ``item[field]`` when truthy (it must then be a string).

    Falsy / missing identifiers return ``None`` so the caller can skip the
    row, matching the original "skip rows without an id" behaviour.
    """
    value = item.get(field)
    if not value:
        return None
    if not isinstance(value, str):
        raise ValueError(f"'{field}' must be a string")
    return value


def _parse_analytics_dump(dump: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and normalise every section of a v1.97 dump.

    Runs before any store is touched, so a malformed payload is rejected up
    front instead of leaving some stores wiped and others restored. Rows
    without their identifying key and pairs with ``count <= 0`` are
    skipped. Row dicts use exactly the constructor keyword names of the
    matching store record class.

    Raises:
        TypeError, ValueError: a section, identifier, counter or timestamp
            has the wrong type.
    """
    opt = _restore_opt_float

    doc_stats = []
    for row in _restore_items(dump, "doc_stats", "rows"):
        did = _restore_id(row, "doc_id")
        if did is None:
            continue
        doc_stats.append({
            "doc_id": did,
            "n_retrieved": int(row.get("n_retrieved", 0)),
            "n_cited": int(row.get("n_cited", 0)),
            "first_seen_at": opt(row.get("first_seen_at")),
            "last_retrieved_at": opt(row.get("last_retrieved_at")),
            "last_cited_at": opt(row.get("last_cited_at")),
        })

    retrievers = []
    for row in _restore_items(dump, "retriever_attribution", "rows"):
        nm = _restore_id(row, "name")
        if nm is None:
            continue
        retrievers.append({
            "name": nm,
            "n_contributed": int(row.get("n_contributed", 0)),
            "n_doc_contributions": int(row.get("n_doc_contributions", 0)),
            "n_cited_contributions": int(
                row.get("n_cited_contributions", 0)),
            "first_seen_at": opt(row.get("first_seen_at")),
            "last_seen_at": opt(row.get("last_seen_at")),
        })

    cocitation = []
    for pair in _restore_items(dump, "cocitation", "pairs"):
        count = int(pair.get("count", 0))
        a = _restore_id(pair, "a")
        b = _restore_id(pair, "b")
        if a is None or b is None or count <= 0:
            continue
        cocitation.append({"a": a, "b": b, "count": count,
                           "last_seen": opt(pair.get("last_seen"))})

    query_stats = []
    for row in _restore_items(dump, "query_stats", "rows"):
        fp = _restore_id(row, "fingerprint")
        if fp is None:
            continue
        query_stats.append({
            "fingerprint": fp,
            "sample": row.get("sample", ""),
            "count": int(row.get("count", 0)),
            "first_seen_at": opt(row.get("first_seen_at")),
            "last_seen_at": opt(row.get("last_seen_at")),
            "sum_sources": int(row.get("sum_sources", 0)),
            "sum_latency_ms": float(row.get("sum_latency_ms", 0.0)),
        })

    freshness = []
    for row in _restore_items(dump, "doc_freshness", "rows"):
        did = _restore_id(row, "doc_id")
        if did is None:
            continue
        freshness.append({
            "doc_id": did,
            # ``or None``: 0 / "" / missing all mean "unknown"; a brand-new
            # record still gets 0.0 as before.
            "added_at": opt(row.get("added_at") or None),
            "last_modified_at": opt(row.get("last_modified_at")),
            "n_modifications": int(row.get("n_modifications", 0)),
        })

    affinity = []
    for pair in _restore_items(dump, "query_doc_affinity", "pairs"):
        count = int(pair.get("count", 0))
        fp = _restore_id(pair, "fingerprint")
        did = _restore_id(pair, "doc_id")
        if fp is None or did is None or count <= 0:
            continue
        affinity.append({"fingerprint": fp, "doc_id": did, "count": count,
                         "last_seen": opt(pair.get("last_seen"))})

    return {
        "doc_stats": doc_stats,
        "retriever_attribution": retrievers,
        "cocitation": cocitation,
        "cocitation_n_events": int(
            _restore_section(dump, "cocitation").get("n_events") or 0),
        "query_stats": query_stats,
        "doc_freshness": freshness,
        "query_doc_affinity": affinity,
        "query_doc_affinity_n_events": int(
            _restore_section(dump, "query_doc_affinity").get("n_events")
            or 0),
    }


@app.post("/v1/admin/analytics/restore")
def admin_analytics_restore(req: AnalyticsRestoreRequest, request: Request):
    """Rebuild the 6 analytics stores from a v1.97 dump.

    ``replace=True`` (default) wipes each store before loading, which gives
    exact-match state after restore. ``replace=False`` merges on top of
    existing data. When a doc_id / retriever name / fingerprint / pair
    already exists:
      * counters (hit counts, sums, n_modifications, pair counts) add up;
      * "first" timestamps (first_seen_at, added_at) keep the earliest value;
      * "last" timestamps (last_*_at, last_modified_at, last_seen) keep the
        latest value;
      * query_stats keeps the sample text already in the store.
    Merge mode is useful for aggregating traffic across prod nodes.

    The whole payload is validated before any store is modified, so a
    malformed dump is rejected without leaving some stores wiped. Rows
    missing their identifying key and pairs with ``count <= 0`` are
    skipped. ``totals`` reports the rows/pairs actually loaded.

    Raises:
        HTTPException 400: ``version`` differs from
            ``_ANALYTICS_DUMP_VERSION`` (schema compatibility gate), or a
            section / field has the wrong type.
    """
    from ..middleware import (
        get_doc_stats, set_doc_stats, DocumentStatsStore,
        get_retriever_attribution, set_retriever_attribution,
        RetrieverAttributionStore,
        get_cocitation, set_cocitation, CoCitationStore,
        get_query_stats, set_query_stats, QueryStatsStore,
        get_doc_freshness, set_doc_freshness, DocFreshnessStore,
        get_query_doc_affinity, set_query_doc_affinity,
        QueryDocAffinityStore,
    )
    from ..middleware.doc_stats import DocumentStats as _DS
    from ..middleware.retriever_attribution import RetrieverStats as _RS
    from ..middleware.cocitation import _pair_key as _pk
    from ..middleware.query_stats import QueryStats as _QS
    from ..middleware.doc_freshness import DocFreshness as _DF

    dump = req.dump or {}
    ver = dump.get("version")
    if ver != _ANALYTICS_DUMP_VERSION:
        raise HTTPException(
            status_code=400,
            detail={"error": "version mismatch",
                    "expected": _ANALYTICS_DUMP_VERSION,
                    "got": ver},
        )
    try:
        parsed = _parse_analytics_dump(dump)
    except (TypeError, ValueError) as e:
        raise HTTPException(
            status_code=400,
            detail={"error": "invalid dump", "reason": str(e)},
        ) from e

    merge = not req.replace
    totals: Dict[str, int] = {}
    # Records are seeded directly (not via record() N times — N can be huge
    # on real dumps). NOTE(review): merging assumes DocumentStats,
    # RetrieverStats and DocFreshness expose their constructor fields as
    # mutable attributes, as QueryStats does.

    # doc_stats
    if req.replace:
        set_doc_stats(DocumentStatsStore())
    doc_store = get_doc_stats()
    for row in parsed["doc_stats"]:
        did = row["doc_id"]
        cur = doc_store._data.get(did) if merge else None  # noqa: SLF001
        if cur is None:
            doc_store._data[did] = _DS(**row)  # noqa: SLF001
            continue
        cur.n_retrieved += row["n_retrieved"]
        cur.n_cited += row["n_cited"]
        cur.first_seen_at = _restore_min_ts(cur.first_seen_at,
                                            row["first_seen_at"])
        cur.last_retrieved_at = _restore_max_ts(cur.last_retrieved_at,
                                                row["last_retrieved_at"])
        cur.last_cited_at = _restore_max_ts(cur.last_cited_at,
                                            row["last_cited_at"])
    totals["doc_stats"] = len(parsed["doc_stats"])

    # retriever_attribution
    if req.replace:
        set_retriever_attribution(RetrieverAttributionStore())
    ra = get_retriever_attribution()
    for row in parsed["retriever_attribution"]:
        nm = row["name"]
        cur = ra._data.get(nm) if merge else None  # noqa: SLF001
        if cur is None:
            ra._data[nm] = _RS(**row)  # noqa: SLF001
            continue
        cur.n_contributed += row["n_contributed"]
        cur.n_doc_contributions += row["n_doc_contributions"]
        cur.n_cited_contributions += row["n_cited_contributions"]
        cur.first_seen_at = _restore_min_ts(cur.first_seen_at,
                                            row["first_seen_at"])
        cur.last_seen_at = _restore_max_ts(cur.last_seen_at,
                                           row["last_seen_at"])
    totals["retriever_attribution"] = len(parsed["retriever_attribution"])

    # cocitation: bump pair counts via direct state access; replaying
    # record() would add one n_event per iteration and skew the counter.
    if req.replace:
        set_cocitation(CoCitationStore())
    cc = get_cocitation()
    for pair in parsed["cocitation"]:
        a, b = pair["a"], pair["b"]
        k = _pk(a, b)
        cc._pairs[k] = int(cc._pairs.get(k, 0)) + pair["count"]  # noqa: SLF001
        cc._partners[a].add(b)  # noqa: SLF001
        cc._partners[b].add(a)  # noqa: SLF001
        if pair["last_seen"] is not None:
            cc._last_seen[k] = _restore_max_ts(  # noqa: SLF001
                cc._last_seen.get(k), pair["last_seen"])  # noqa: SLF001
    if parsed["cocitation_n_events"]:
        cc._n_events = cc._n_events + parsed["cocitation_n_events"]  # noqa: SLF001
    totals["cocitation_pairs"] = len(parsed["cocitation"])

    # query_stats
    if req.replace:
        set_query_stats(QueryStatsStore())
    qs = get_query_stats()
    for row in parsed["query_stats"]:
        fp = row["fingerprint"]
        cur = qs._data.get(fp) if merge else None  # noqa: SLF001
        if cur is None:
            qs._data[fp] = _QS(**row)  # noqa: SLF001
            continue
        cur.count += row["count"]
        cur.sum_sources += row["sum_sources"]
        cur.sum_latency_ms += row["sum_latency_ms"]
        cur.first_seen_at = _restore_min_ts(cur.first_seen_at,
                                            row["first_seen_at"])
        cur.last_seen_at = _restore_max_ts(cur.last_seen_at,
                                           row["last_seen_at"])
    totals["query_stats"] = len(parsed["query_stats"])

    # doc_freshness
    if req.replace:
        set_doc_freshness(DocFreshnessStore())
    fs = get_doc_freshness()
    for row in parsed["doc_freshness"]:
        did = row["doc_id"]
        cur = fs._data.get(did) if merge else None  # noqa: SLF001
        if cur is None:
            fs._data[did] = _DF(  # noqa: SLF001
                doc_id=did,
                added_at=(row["added_at"] if row["added_at"] is not None
                          else 0.0),
                last_modified_at=row["last_modified_at"],
                n_modifications=row["n_modifications"],
            )
            continue
        # An existing 0.0 added_at is treated as "unknown".
        if row["added_at"] is not None and (
                not cur.added_at or row["added_at"] < cur.added_at):
            cur.added_at = row["added_at"]
        cur.last_modified_at = _restore_max_ts(cur.last_modified_at,
                                               row["last_modified_at"])
        cur.n_modifications += row["n_modifications"]
    totals["doc_freshness"] = len(parsed["doc_freshness"])

    # query_doc_affinity
    if req.replace:
        set_query_doc_affinity(QueryDocAffinityStore())
    qda = get_query_doc_affinity()
    for pair in parsed["query_doc_affinity"]:
        fp, did = pair["fingerprint"], pair["doc_id"]
        k = (fp, did)
        qda._pairs[k] = int(qda._pairs.get(k, 0)) + pair["count"]  # noqa: SLF001
        qda._by_query[fp].add(did)  # noqa: SLF001
        qda._by_doc[did].add(fp)  # noqa: SLF001
        if pair["last_seen"] is not None:
            qda._last_seen[k] = _restore_max_ts(  # noqa: SLF001
                qda._last_seen.get(k), pair["last_seen"])  # noqa: SLF001
    if parsed["query_doc_affinity_n_events"]:
        qda._n_events = (  # noqa: SLF001
            qda._n_events + parsed["query_doc_affinity_n_events"])  # noqa: SLF001
    totals["query_doc_affinity_pairs"] = len(parsed["query_doc_affinity"])

    get_obs().audit(
        "analytics.restore",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        version=ver, replace=bool(req.replace),
        totals=totals,
    )
    return {
        "restored": True,
        "replace": bool(req.replace),
        "version": ver,
        "totals": totals,
    }
@app.get("/v1/admin/queries/affinity/summary")
def admin_query_doc_affinity_summary():
    """Summarise the query × doc affinity matrix (v1.96).

    Passes through the affinity store's ``summary()``: n_events, n_pairs,
    n_queries, n_docs and total_count.
    """
    from ..middleware import get_query_doc_affinity

    affinity = get_query_doc_affinity()
    return affinity.summary()
@app.get("/v1/admin/queries/{fingerprint}/top-docs")
def admin_query_top_docs(fingerprint: str, n: int = 10):
    """List the docs a query fingerprint most often cites (v1.96).

    ``fingerprint`` is the v1.89 canonical fingerprint. Callers holding raw
    query text must fingerprint it first (``_fingerprint(canonical)``) or
    use the v1.89 lookup endpoints. The v1.89 sample text is attached when
    query_stats still knows the fingerprint, otherwise ``None``.
    """
    from ..middleware import get_query_doc_affinity, get_query_stats

    top_docs = get_query_doc_affinity().top_docs_for_query(
        fingerprint, n=int(n))
    known = get_query_stats().get(fingerprint)
    sample = None if not known else known.sample
    return {"fingerprint": fingerprint, "sample": sample, "top_docs": top_docs}
@app.get("/v1/admin/documents/{doc_id}/top-queries")
def admin_doc_top_queries(doc_id: str, n: int = 10):
    """Which queries lead to this doc being cited? (v1.96).

    Returns fingerprints + counts. Each row carries the fingerprint's
    v1.89 sample text, or ``None`` when query_stats no longer knows it.
    Rows are copied before enrichment, so the dicts handed back by the
    affinity store are never mutated (they may be store-owned — unknown
    from here).
    """
    from ..middleware import get_query_doc_affinity, get_query_stats

    rows = get_query_doc_affinity().top_queries_for_doc(doc_id, n=int(n))
    qs = get_query_stats()
    enriched = []
    for row in rows:
        known = qs.get(row["fingerprint"])
        enriched.append({**row, "sample": known.sample if known else None})
    return {"doc_id": doc_id, "top_queries": enriched}
@app.post("/v1/admin/queries/affinity/reset")
def admin_query_doc_affinity_reset(request: Request):
    """Clear the query × doc affinity matrix and audit the reset (v1.96).

    Responds with the summary captured before clearing.
    """
    from ..middleware import get_query_doc_affinity

    affinity = get_query_doc_affinity()
    snapshot = affinity.summary()
    affinity.clear()
    actor = request.headers.get("x-api-key")
    rid = getattr(request.state, "request_id", None)
    get_obs().audit(
        "query_doc_affinity.reset",
        actor_key=actor,
        request_id=rid,
        prev_n_pairs=snapshot["n_pairs"],
        prev_total_count=snapshot["total_count"],
    )
    return {"reset": True, "before": snapshot}
@app.get("/v1/admin/documents/update-priorities")
def admin_doc_update_priorities(
    n: int = 20,
    min_cited: int = 1,
    older_than_days: float = 0.0,
    alpha: float = 1.0,
):
    """Rank docs by "needs update" priority (v1.95) — cross-cut of
    v1.82 doc_stats × v1.94 doc_freshness.

    Priority score = ``n_cited * (age_days ** alpha)``. ``age_days`` is
    measured from ``last_modified_at`` (falling back to ``added_at``) and
    clamped at 0. High traffic plus stale content gives a high score.
    Rarely cited docs are not worth the review, and recently modified docs
    don't need it yet. ``alpha`` tunes how heavily age dominates (α>1 = age
    matters more; α<1 = traffic matters more).

    Filters:
    * ``min_cited`` — minimum citation count (default 1 — skip
      cold docs entirely; they're content-audit's problem, not update's).
    * ``older_than_days`` — minimum age in days (default 0 — let
      the caller decide what "stale" means).
    * ``n`` — cap on rows returned (``n_candidates`` is the count after
      the cap).
    * ``alpha`` — exponent on age_days in the score. 1.0 is linear
      (balanced); raise it to prioritise stale content harder.

    If ``age_days ** alpha`` cannot be computed (overflow, or 0 raised to
    a negative ``alpha``), the row falls back to linear ``age_days``.

    Raises:
        HTTPException 422: ``alpha`` or ``older_than_days`` is NaN or
            infinite; the response could not be JSON-encoded.
    """
    import math
    import time as _t

    from ..middleware import get_doc_stats, get_doc_freshness

    exponent = float(alpha)
    min_age = float(older_than_days)
    if not (math.isfinite(exponent) and math.isfinite(min_age)):
        raise HTTPException(
            status_code=422,
            detail="alpha and older_than_days must be finite numbers",
        )
    min_cites = int(min_cited)

    docs = get_doc_stats()
    fresh = get_doc_freshness()
    now = _t.time()
    # Walk doc_stats and join freshness onto it. Docs with no traffic
    # record are noise here; we want docs BOTH stores know about.
    rows: List[Dict[str, Any]] = []
    for doc_row in docs.top_cited(n=10 ** 9):
        did = doc_row["doc_id"]
        if doc_row["n_cited"] < min_cites:
            continue
        f = fresh.get(did)
        if f is None:
            continue
        ref_ts = f.last_modified_at or f.added_at
        age_days = max(0.0, (now - ref_ts) / 86400.0)
        if age_days < min_age:
            continue
        try:
            aged = age_days ** exponent
        except (OverflowError, ValueError, ZeroDivisionError):
            # ZeroDivisionError: 0.0 ** (negative alpha) for a doc touched
            # "now". One pathological row must not 500 the whole ranking.
            aged = age_days
        score = float(doc_row["n_cited"]) * aged
        rows.append({
            "doc_id": did,
            "n_cited": doc_row["n_cited"],
            "n_retrieved": doc_row["n_retrieved"],
            "cite_rate": doc_row.get("cite_rate", 0.0),
            "added_at": f.added_at,
            "last_modified_at": f.last_modified_at,
            "age_days": round(age_days, 2),
            "n_modifications": f.n_modifications,
            "priority_score": round(score, 2),
        })
    rows.sort(key=lambda r: r["priority_score"], reverse=True)
    rows = rows[:max(0, int(n))]
    return {
        "n_candidates": len(rows),
        "n": int(n),
        "min_cited": min_cites,
        "older_than_days": min_age,
        "alpha": exponent,
        "candidates": rows,
    }
@app.get("/v1/documents/{doc_id}/freshness")
def get_document_freshness(doc_id: str):
    """Return one doc's freshness record (v1.94).

    The record carries added_at, last_modified_at, n_modifications, age_s
    and age_days. Responds 404 when the doc has no freshness record.
    """
    from ..middleware import get_doc_freshness

    record = get_doc_freshness().get(doc_id)
    if record is not None:
        return record.to_dict()
    raise HTTPException(
        status_code=404,
        detail={"error": "no freshness record for doc", "doc_id": doc_id},
    )
@app.get("/v1/documents/{doc_id}/related")
def get_document_related(doc_id: str, n: int = 10):
    """Return the docs most often co-cited with ``doc_id`` (v1.84).

    This is an empirical, purely behavioural notion of "related", derived
    from actual traffic.
    """
    from ..middleware import get_cocitation

    limit = int(n)
    neighbours = get_cocitation().related(doc_id, n=limit)
    return {"doc_id": doc_id, "related": neighbours}
@app.get("/v1/documents/{doc_id}/stats")
def get_document_stats(doc_id: str):
    """Return one doc's citation + retrieval counters (v1.82).

    Responds 404 when the doc has never been tracked.
    """
    from ..middleware import get_doc_stats

    stats = get_doc_stats().get(doc_id)
    if stats is not None:
        return stats.to_dict()
    raise HTTPException(
        status_code=404,
        detail={"error": "no stats for doc", "doc_id": doc_id},
    )
@app.get("/v1/documents/{doc_id}/chunks")
def get_document_chunks(doc_id: str, chunker: Optional[str] = None):
    """Return the chunks the retrievers index for this doc.

    The chunker is re-run on demand. Pass ``?chunker=sentence`` (or another
    chunker name) to preview an alternative chunking without touching the
    index. Responds 404 for an unknown document.
    """
    if _pipeline.get_document(doc_id) is None:
        raise HTTPException(status_code=404, detail="document not found")
    pieces = _pipeline.get_chunks(doc_id, chunker=chunker)
    label = chunker if chunker else getattr(_pipeline, "_chunker_last", "fixed")
    return {
        "doc_id": doc_id,
        "n_chunks": len(pieces),
        "chunker": label,
        "chunks": pieces,
    }
@app.get("/v1/documents/{doc_id}")
def get_document(doc_id: str):
    """Return a stored document's id, text and metadata (404 if unknown)."""
    doc = _pipeline.get_document(doc_id)
    if doc is None:
        raise HTTPException(status_code=404, detail="document not found")
    metadata = doc.metadata or {}
    return {"id": doc.id, "text": doc.text, "metadata": metadata}
@app.put("/v1/documents/{doc_id}")
def replace_document(doc_id: str, body: DocumentBody):
    """Replace an existing document's text and metadata.

    Raises:
        HTTPException 422: the body ``id`` differs from the URL ``doc_id``,
            or ``validate_doc_list`` rejects the body with ``ValueError``.
        HTTPException 413: ``validate_doc_list`` reports the body as too
            large (``OverflowError``).
        HTTPException 404: no document with ``doc_id`` exists.
    """
    if body.id != doc_id:
        raise HTTPException(status_code=422,
                            detail={"path_id_mismatch":
                                    {"url": doc_id, "body": body.id}})
    try:
        # Reuse the doc-size validator
        validate_doc_list([body])
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        # Same OverflowError→413 / ValueError→422 mapping the query
        # validators use in the other endpoints of this module.
        raise HTTPException(status_code=422, detail=str(e))
    doc = Document(id=body.id, text=body.text, metadata=body.metadata)
    ok = _pipeline.replace_document(doc)
    if not ok:
        raise HTTPException(status_code=404, detail="document not found")
    return {"replaced": True, "id": doc_id}
@app.delete("/v1/documents/{doc_id}")
def delete_document(doc_id: str):
    """Delete one document; responds 404 when the pipeline doesn't have it."""
    if not _pipeline.delete_document(doc_id):
        raise HTTPException(status_code=404, detail="document not found")
    return {"deleted": True, "id": doc_id}
@app.delete("/v1/documents")
def clear_documents(request: Request):
    """Remove every document from the pipeline and audit the wipe.

    ``removed`` is whatever ``_pipeline.clear_documents()`` returns.
    """
    removed = _pipeline.clear_documents()
    headers = request.headers
    get_obs().audit(
        "documents.clear",
        actor_key=headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        removed=removed,
    )
    return {"cleared": True, "removed": removed}
# ---- bulk ingest (JSONL + CSV streaming) ------------------------------
import csv as _csv # noqa: E402
import io as _io # noqa: E402
import json # noqa: E402
def _parse_jsonl(text: str):
"""Yield (row_num, doc_dict_or_error) tuples for a JSONL payload.
Blank lines and `#` comment lines are skipped silently."""
for i, line in enumerate(text.splitlines(), start=1):
line = line.strip()
if not line or line.startswith("#"):
continue
try:
obj = json.loads(line)
if not isinstance(obj, dict):
raise ValueError("row is not a JSON object")
yield i, obj
except Exception as e:
yield i, {"__error__": str(e)}
def _parse_csv(text: str):
"""Yield (row_num, doc_dict_or_error) tuples for a CSV payload.
Expects columns: id, text, and optional metadata columns merged into
a single metadata dict."""
reader = _csv.DictReader(_io.StringIO(text))
if reader.fieldnames is None or "id" not in reader.fieldnames \
or "text" not in reader.fieldnames:
raise HTTPException(
status_code=400,
detail={"csv_missing_columns": "required: 'id' and 'text'"},
)
meta_cols = [c for c in reader.fieldnames if c not in ("id", "text")]
for i, row in enumerate(reader, start=2): # row 1 is header
try:
metadata = {c: row[c] for c in meta_cols if row.get(c) not in (None, "")}
yield i, {"id": row["id"], "text": row["text"],
"metadata": metadata}
except Exception as e:
yield i, {"__error__": str(e)}
@app.post("/v1/documents/bulk")
async def bulk_ingest_documents(request: Request):
    """Bulk ingest — JSONL (one ``{"id","text","metadata"}`` per line) or
    CSV (columns: id, text, [any other] → metadata). Partial success
    semantics: each row is parsed and validated independently, successes
    are indexed into the pipeline, and failures are reported with row
    numbers.

    Content-Type:
    * ``text/csv`` → CSV
    * anything else (e.g. ``application/x-ndjson``, ``application/jsonl``)
      → JSONL (default)

    Per-row rules:
    * ``text`` must be a non-empty string of at most
      ``Limits.max_doc_text_len`` characters;
    * ``metadata``, if present and non-empty, must be a JSON object;
    * ``id`` must be a string or integer; it is stored as a string and
      falls back to ``row-<n>`` when missing or empty;
    * at most ``Limits.max_docs_per_batch`` rows are accepted per call.
    """
    ct = (request.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
    raw = (await request.body()).decode("utf-8", errors="replace")
    iterator = _parse_csv(raw) if ct == "text/csv" else _parse_jsonl(raw)
    accepted: List[Document] = []
    errors: List[Dict[str, Any]] = []
    row_n = 0
    for row_num, obj in iterator:
        row_n += 1
        if "__error__" in obj:
            errors.append({"row": row_num, "error": obj["__error__"]})
            continue
        text = obj.get("text")
        if not isinstance(text, str) or not text:
            errors.append({"row": row_num,
                           "error": "missing or empty 'text'"})
            continue
        if len(text) > Limits.max_doc_text_len:
            errors.append({"row": row_num,
                           "error": f"text exceeds {Limits.max_doc_text_len} chars"})
            continue
        metadata = obj.get("metadata") or {}
        if not isinstance(metadata, dict):
            errors.append({"row": row_num,
                           "error": "'metadata' must be a JSON object"})
            continue
        raw_id = obj.get("id")
        if raw_id is not None and not isinstance(raw_id, (str, int)):
            errors.append({"row": row_num,
                           "error": "'id' must be a string"})
            continue
        if len(accepted) >= Limits.max_docs_per_batch:
            errors.append({"row": row_num,
                           "error": f"batch cap reached (max {Limits.max_docs_per_batch})"})
            continue
        accepted.append(Document(
            id=str(raw_id or f"row-{row_num}"),
            text=text,
            metadata=metadata,
        ))
    chunks = _pipeline.add_documents(accepted) if accepted else 0
    get_obs().audit(
        "documents.bulk_ingest",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        content_type=ct, rows_total=row_n,
        accepted=len(accepted), errors=len(errors),
    )
    return {
        "accepted": [d.id for d in accepted],
        "errors": errors,
        "added_chunks": chunks,
        "rows_total": row_n,
    }
@app.post("/v1/search")
def search(req: SearchRequest):
    """Retrieval-only endpoint: per-retriever hit lists, no generation.

    Validation mirrors the generate endpoints (413 oversized, 422 invalid,
    400 unknown strategy). Each hit reports doc id, chunk id, score, rank
    and the first 300 characters of the chunk text.
    """
    try:
        validate_query_text(req.query)
        validate_k(req.k)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    try:
        strategy = Strategy(req.strategy)
    except ValueError:
        raise HTTPException(status_code=400,
                            detail={"bad_strategy": req.strategy})
    query = Query(
        text=req.query, lang=req.lang, filters=req.filters,
        strategy=strategy, k=req.k, rerank_k=req.rerank_k,
    )
    hits_by_retriever = _pipeline.retrievers.search_per_retriever(query, req.k)

    def _hit(r):
        chunk = r.chunk
        return {"doc": chunk.doc_id, "chunk": chunk.chunk_id,
                "score": r.score, "rank": r.rank, "text": chunk.text[:300]}

    return {
        "per_retriever": {
            name: [_hit(r) for r in hits]
            for name, hits in hits_by_retriever.items()
        }
    }
# ---- batch query (v1.54) -------------------------------------------------
# One entry of a POST /v1/batch request; batch_query validates each item
# (query text, k, strategy) independently. Comments rather than a
# docstring: a class docstring would become the OpenAPI schema description.
class BatchQueryItem(BaseModel):
    query: str
    k: int = 10
    rerank_k: int = 5
    strategy: str = "hybrid"  # must name a Strategy value; checked per item
    lang: str = "he"
    # NOTE: the shared {} default relies on pydantic copying mutable field
    # defaults per instance.
    filters: Dict[str, Any] = {}
# Body of POST /v1/batch; the endpoint rejects more than
# Limits.max_docs_per_batch items with HTTP 413.
class BatchQueryRequest(BaseModel):
    queries: List[BatchQueryItem]
@app.post("/v1/batch")
def batch_query(req: BatchQueryRequest, request: Request):
    """Run many queries in a single HTTP call — useful for eval runners,
    benchmarks, and bulk re-indexing workflows.

    Cap: ``Limits.max_docs_per_batch`` items per call (reusing the doc-limit
    env knob); larger batches are rejected with 413. Each item is validated
    independently. Per-item failures are returned in ``errors`` with the
    status the single-query endpoints would use: 413 (oversized query),
    422 (invalid query / k), 400 (unknown strategy), 500 (pipeline
    failure). Items call ``_pipeline.run`` directly, so the /v1/generate
    response cache is not consulted. Any rate limiting is applied by
    middleware to the HTTP request as a whole.

    ``total_ms`` / ``avg_ms`` are measured with a monotonic clock.
    """
    import time as _t

    if not req.queries:
        return {"n": 0, "results": [], "errors": [], "total_ms": 0}
    if len(req.queries) > Limits.max_docs_per_batch:
        raise HTTPException(
            status_code=413,
            detail=f"too many queries — max {Limits.max_docs_per_batch} per batch",
        )
    t0 = _t.perf_counter()
    results: List[Dict[str, Any]] = []
    errors: List[Dict[str, Any]] = []
    for i, item in enumerate(req.queries, start=1):
        try:
            validate_query_text(item.query)
            validate_k(item.k)
        except OverflowError as e:
            errors.append({"index": i, "error": str(e), "status": 413})
            continue
        except ValueError as e:
            errors.append({"index": i, "error": str(e), "status": 422})
            continue
        try:
            strategy = Strategy(item.strategy)
        except ValueError:
            errors.append({"index": i,
                           "error": f"bad_strategy: {item.strategy}",
                           "status": 400})
            continue
        try:
            q = Query(text=item.query, lang=item.lang, filters=item.filters,
                      strategy=strategy, k=item.k, rerank_k=item.rerank_k)
            resp = _pipeline.run(q)
            try:
                omega = float(resp.signals.omega) if resp.signals else None
            except Exception:
                omega = None
            verif = getattr(resp, "verification", None)
            results.append({
                "index": i,
                "query": item.query,
                "answer": resp.answer or "",
                "sources": list(resp.sources or []),
                "omega": omega,
                "passed": bool(getattr(verif, "passed", False)) if verif else None,
            })
        except HTTPException as e:
            errors.append({"index": i, "error": str(e.detail),
                           "status": e.status_code})
        except Exception as e:
            # Best-effort: one failing item must not sink the batch.
            errors.append({"index": i,
                           "error": f"{type(e).__name__}: {e}"[:240],
                           "status": 500})
    total_ms = (_t.perf_counter() - t0) * 1000.0
    get_obs().audit(
        "batch.query",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        n=len(req.queries), errors=len(errors),
        total_ms=round(total_ms, 2),
    )
    return {
        "n": len(req.queries),
        "results": results,
        "errors": errors,
        "total_ms": round(total_ms, 2),
        "avg_ms": round(total_ms / max(1, len(req.queries)), 2),
    }
@app.post("/v1/generate/stream")
def generate_stream(req: SearchRequest, request: Request):
    """Server-Sent Events version of /v1/generate.

    Emits events in order:
        event: retrieved  data: {"doc_ids": [...], "count": N}
        event: answer     data: {"chunk": "word "}   (repeated)
        event: done       data: {"answer", "sources", "omega", "passed",
                                 "verification"}
        event: error      data: {"code", "message"}  (on failure)

    Flow: runs the full pipeline synchronously (it's ~ms for extractive),
    then streams the staged output. Each SSE event is an
    ``event: <name>`` line, a ``data: <json>`` line and a blank line.
    Validation failures are raised before streaming starts (413 / 422 /
    400). Once streaming has begun, failures arrive as an ``error`` event.

    Clients:
    * browser: ``fetch()`` and read ``response.body`` incrementally — the
      native ``EventSource`` API only issues GET requests, so it cannot
      call this POST endpoint.
    * curl --no-buffer -N -H 'Content-Type: application/json'
      -d '{"query":"..."}' http://localhost:8000/v1/generate/stream
    * SDK: for ev in client.stream_query("..."): ...
    """
    # NOTE: the route previously passed ``response_class=None``, which makes
    # FastAPI's OpenAPI generation fail. Returning a StreamingResponse
    # directly needs no response_class.
    try:
        validate_query_text(req.query)
        validate_k(req.k)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    try:
        strategy = Strategy(req.strategy)
    except ValueError:
        raise HTTPException(status_code=400,
                            detail={"bad_strategy": req.strategy})
    from fastapi.responses import StreamingResponse
    import json as _json

    def _sse(event: str, data: Any) -> str:
        # One SSE frame: event line, data line, blank-line terminator.
        return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n"

    def _event_gen():
        try:
            q = Query(text=req.query, lang=req.lang, filters=req.filters,
                      strategy=strategy, k=req.k, rerank_k=req.rerank_k)
            resp = _pipeline.run(q)
            # Stage 1: retrieval results (unique doc ids, first-seen order)
            retrieved = []
            seen = set()
            for c in getattr(resp, "retrieved", []) or []:
                did = getattr(getattr(c, "chunk", None), "doc_id", None)
                if did and did not in seen:
                    retrieved.append(did)
                    seen.add(did)
            yield _sse("retrieved",
                       {"doc_ids": retrieved, "count": len(retrieved)})
            # Stage 2: answer streamed word-by-word (runs of spaces are
            # collapsed here; the exact text is in the "done" event).
            answer = resp.answer or ""
            for w in answer.split(" "):
                if not w:
                    continue
                yield _sse("answer", {"chunk": w + " "})
            # Stage 3: final envelope
            try:
                omega = float(resp.signals.omega) if resp.signals else None
            except Exception:
                omega = None
            verif = getattr(resp, "verification", None)
            yield _sse("done", {
                "answer": answer,
                "sources": list(resp.sources or []),
                "omega": omega,
                "passed": bool(getattr(verif, "passed", False)) if verif else None,
                "verification": (verif.to_dict() if hasattr(verif, "to_dict")
                                 else getattr(verif, "__dict__", None)),
            })
        except Exception as e:
            # The response has already started when this generator runs, so
            # failures can only be reported in-band.
            yield _sse("error", {
                "code": "internal_error",
                "message": f"{type(e).__name__}: {e}"[:240],
            })

    return StreamingResponse(
        _event_gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # nginx — flush immediately
        },
    )
@app.post("/v1/generate")
def generate(req: SearchRequest):
    """Run the full pipeline for one query, serving repeats from the cache.

    Validation: 413 for oversized input, 422 for invalid query / k, 400 for
    an unknown strategy. The returned dict is the pipeline response's
    ``to_dict()`` plus ``_cache`` = ``"hit"`` or ``"miss"``. The cached copy
    itself never carries that marker.
    """
    try:
        validate_query_text(req.query)
        validate_k(req.k)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    try:
        strategy = Strategy(req.strategy)
    except ValueError:
        raise HTTPException(status_code=400,
                            detail={"bad_strategy": req.strategy})
    cache = get_cache()
    key_text = f"{req.query}|{req.strategy}|{req.k}|{req.rerank_k}"
    cache_key = cache.make_key(key_text, req.lang, req.filters)
    hit = cache.get(cache_key)
    if hit is not None:
        return {**hit, "_cache": "hit"}
    fresh = _pipeline.run(Query(
        text=req.query, lang=req.lang, filters=req.filters,
        strategy=strategy, k=req.k, rerank_k=req.rerank_k,
    )).to_dict()
    cache.put(cache_key, fresh)
    return {**fresh, "_cache": "miss"}
# ---- saved query presets (v1.66) ----------------------------------------
# Body of PUT /v1/queries/{name}; the preset's name comes from the URL
# path. Fields mirror the generate parameters (filters are not stored —
# running a preset uses empty filters). Comments, not a docstring, so the
# OpenAPI schema is unchanged.
class QueryPresetBody(BaseModel):
    query: str
    k: int = 10
    rerank_k: int = 5
    strategy: str = "hybrid"
    lang: str = "he"
    notes: str = ""  # free-form annotation stored with the preset
@app.get("/v1/queries")
def list_query_presets():
    """List every saved query preset with a count.

    Unauthenticated by design: saved queries are treated as public.
    """
    from ..presets import get_preset_store

    everything = get_preset_store().list_all()
    return {"count": len(everything), "presets": everything}
@app.get("/v1/queries/{name}")
def get_query_preset(name: str):
    """Return one saved preset; 404 ``{"preset_not_found": name}`` if absent."""
    from ..presets import get_preset_store

    preset = get_preset_store().get(name)
    if preset is not None:
        return preset.to_dict()
    raise HTTPException(status_code=404,
                        detail={"preset_not_found": name})
@app.put("/v1/queries/{name}")
def save_query_preset(name: str, body: QueryPresetBody, request: Request):
    """Create or overwrite the preset ``name`` and audit the write.

    The query is validated exactly like ``POST /v1/generate``, including
    the strategy. An invalid preset is rejected here rather than failing
    every later ``/v1/queries/{name}/run``.

    Raises:
        HTTPException 413: query text too large.
        HTTPException 422: invalid query text or ``k``.
        HTTPException 400: ``strategy`` is not a known Strategy value.
    """
    from ..presets import QueryPreset, get_preset_store
    try:
        validate_query_text(body.query)
        validate_k(body.k)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    try:
        Strategy(body.strategy)
    except ValueError:
        raise HTTPException(status_code=400,
                            detail={"bad_strategy": body.strategy})
    preset = QueryPreset(
        name=name, query=body.query, k=body.k, rerank_k=body.rerank_k,
        strategy=body.strategy, lang=body.lang, notes=body.notes,
    )
    get_preset_store().put(preset)
    get_obs().audit(
        "query_preset.put",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        name=name,
    )
    return preset.to_dict()
@app.delete("/v1/queries/{name}")
def delete_query_preset(name: str, request: Request):
    """Delete a saved preset (404 if absent) and audit the removal."""
    from ..presets import get_preset_store

    if not get_preset_store().remove(name):
        raise HTTPException(status_code=404,
                            detail={"preset_not_found": name})
    actor = request.headers.get("x-api-key")
    rid = getattr(request.state, "request_id", None)
    get_obs().audit(
        "query_preset.remove",
        actor_key=actor,
        request_id=rid,
        name=name,
    )
    return {"removed": True, "name": name}
@app.post("/v1/queries/{name}/run")
def run_query_preset(name: str):
    """Execute a saved preset through ``generate`` with its stored parameters.

    Equivalent to ``POST /v1/generate`` with the preset's query, k,
    rerank_k, strategy and lang, and empty filters. Returns the full
    /generate response; 404 when the preset does not exist.
    """
    from ..presets import get_preset_store

    preset = get_preset_store().get(name)
    if preset is None:
        raise HTTPException(status_code=404,
                            detail={"preset_not_found": name})
    return generate(SearchRequest(
        query=preset.query, k=preset.k, rerank_k=preset.rerank_k,
        strategy=preset.strategy, lang=preset.lang, filters={},
    ))
@app.post("/v1/generate/timings")
def generate_timings(req: SearchRequest):
    """Run the query but return only the per-stage latency breakdown + Ω —
    no answer text, no candidate list. For ops/profiling workflows that need
    to know WHERE time is spent, not WHAT was returned.

    Response shape::

        {
          "query": "...",
          "timings_ms": {...},   # copy of the pipeline's ``timing_ms``
          "omega": 0.67,         # None when signals are missing/unreadable
          "n_sources": 2,
          "passed": true         # None when there is no verification
        }

    The stage names in ``timings_ms`` are whatever the pipeline records.
    Unlike ``/v1/generate``, this endpoint never consults the response
    cache, so every call measures a real run.

    Raises:
        HTTPException 413 / 422: query or ``k`` fails validation.
        HTTPException 400: unknown strategy.
    """
    try:
        validate_query_text(req.query)
        validate_k(req.k)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    try:
        strategy = Strategy(req.strategy)
    except ValueError:
        raise HTTPException(status_code=400,
                            detail={"bad_strategy": req.strategy})
    q = Query(
        text=req.query, lang=req.lang, filters=req.filters,
        strategy=strategy, k=req.k, rerank_k=req.rerank_k,
    )
    resp = _pipeline.run(q)
    try:
        omega = float(resp.signals.omega) if resp.signals else None
    except Exception:
        omega = None
    # Same defensive access pattern as /v1/batch and /v1/generate/stream.
    verif = getattr(resp, "verification", None)
    return {
        "query": req.query,
        "timings_ms": dict(getattr(resp, "timing_ms", None) or {}),
        "omega": omega,
        "n_sources": len(resp.sources or []),
        "passed": bool(getattr(verif, "passed", False)) if verif else None,
    }
@app.get("/v1/admin/stats")
def admin_stats():
    """Aggregate operational stats from cache, rate limiter, observability,
    pipeline cache size and the audit webhook dispatcher."""
    from ..middleware import get_webhook_dispatcher

    report = {}
    report["cache"] = get_cache().stats()
    report["rate_limiter"] = get_limiter().stats()
    report["observability"] = get_obs().stats()
    report["cached_queries"] = len(_pipeline.cache)
    report["webhook"] = get_webhook_dispatcher().stats()
    return report
@app.get("/v1/admin/webhook")
def admin_webhook_stats():
    """Report the audit webhook dispatcher state (v1.71), including its
    circuit-breaker status (v1.79)."""
    from ..middleware import get_webhook_dispatcher

    dispatcher = get_webhook_dispatcher()
    return dispatcher.stats()
# ---- alert rules (v1.80) -------------------------------------------------
# Body of PUT /v1/admin/alerts/{name}. NOTE: ``name`` is required by the
# schema but ignored by that endpoint — the URL path name wins. Invalid
# combinations surface as ValueError from AlertRule / the store, which the
# endpoint maps to HTTP 400. Comments (not a docstring) keep the OpenAPI
# schema unchanged.
class AlertRuleRequest(BaseModel):
    name: str
    metric: str
    op: str
    threshold: float
    window_s: float = 300.0  # presumably seconds of metric history; confirm
    cooldown_s: float = 600.0  # presumably seconds between firings; confirm
    enabled: bool = True
    description: str = ""
@app.get("/v1/admin/alerts")
def admin_alerts_list():
    """Return every configured alert rule, serialised via ``to_dict()`` (v1.80)."""
    from ..middleware import get_alert_store

    serialised = [rule.to_dict() for rule in get_alert_store().list_all()]
    return {"rules": serialised}
@app.put("/v1/admin/alerts/{name}")
def admin_alerts_put(name: str, req: AlertRuleRequest, request: Request):
    """Create or update an alert rule and audit the change.

    The ``name`` from the URL path is used and the body's ``name`` is
    ignored, consistent with PUT semantics. A ``ValueError`` from building
    or storing the rule becomes HTTP 400.
    """
    from ..middleware import get_alert_store, AlertRule
    try:
        threshold = float(req.threshold)
        candidate = AlertRule(
            name=name,
            metric=req.metric, op=req.op,
            threshold=threshold,
            window_s=float(req.window_s),
            cooldown_s=float(req.cooldown_s),
            enabled=bool(req.enabled),
            description=req.description,
        )
        stored = get_alert_store().put(candidate)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    get_obs().audit(
        "alert.rule.put",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        name=name, metric=req.metric, op=req.op,
        threshold=threshold,
    )
    return stored.to_dict()
@app.delete("/v1/admin/alerts/{name}")
def admin_alerts_delete(name: str, request: Request):
    """Delete the alert rule ``name``; 404 when no such rule exists.

    Successful deletions are audited as ``alert.rule.delete``.
    """
    from ..middleware import get_alert_store
    if not get_alert_store().delete(name):
        raise HTTPException(
            status_code=404,
            detail={"error": "alert rule not found", "name": name},
        )
    get_obs().audit(
        "alert.rule.delete",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        name=name,
    )
    return {"deleted": True, "name": name}
@app.get("/v1/admin/alerts/scheduler")
def admin_alerts_scheduler_status():
    """Describe the background alert evaluator (v1.81).

    When no scheduler is configured this reports it as disabled and not
    running; otherwise the scheduler's own ``status()`` is merged in.
    """
    from ..middleware import get_alert_scheduler
    scheduler = get_alert_scheduler()
    if scheduler is not None:
        return {"enabled": True, **scheduler.status()}
    return {"enabled": False, "is_running": False}
@app.post("/v1/admin/alerts/evaluate")
def admin_alerts_evaluate(request: Request):
    """Evaluate every enabled alert rule once against the metrics history.

    Each firing rule produces an ``alert.fired`` audit event, which flows
    through the webhook dispatcher (v1.71). The response lists all
    verdicts plus fired / suppressed / evaluated counts.
    """
    from ..middleware import (
        get_alert_store, get_metrics_history, evaluate_all,
    )
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    def _audit_fired(verdict):
        # Passed to evaluate_all as ``on_fire``; records one firing rule.
        get_obs().audit(
            "alert.fired",
            actor_key=actor_key,
            request_id=request_id,
            rule=verdict["rule"],
            reason=verdict["reason"],
            latest_value=verdict["latest_value"],
            n_samples=verdict["n_samples"],
        )
    results = evaluate_all(
        get_alert_store(), get_metrics_history(), on_fire=_audit_fired,
    )
    fired = len([v for v in results if v["fired"]])
    suppressed = len([v for v in results if v["suppressed"]])
    return {
        "verdicts": results,
        "n_fired": fired,
        "n_suppressed": suppressed,
        "n_evaluated": len(results),
    }
@app.post("/v1/admin/webhook/breaker/reset")
def admin_webhook_breaker_reset(request: Request):
    """Force the audit-webhook circuit breaker back to CLOSED (v1.79).

    Calls ``breaker.reset()`` so operators can recover from a fixed
    downstream outage without waiting for the cooldown probe. The prior
    state and failure count are audited as ``webhook.breaker.reset``; the
    response carries the breaker stats from before and after the reset.
    """
    from ..middleware import get_webhook_dispatcher
    dispatcher = get_webhook_dispatcher()
    state_before = dispatcher.breaker.stats()
    dispatcher.breaker.reset()
    state_after = dispatcher.breaker.stats()
    get_obs().audit(
        "webhook.breaker.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_state=state_before["state"],
        prev_fail_count=state_before["fail_count"],
    )
    return {"reset": True, "before": state_before, "after": state_after}
@app.get("/v1/admin/metrics/history")
def admin_metrics_history(
    since: Optional[float] = None,
    until: Optional[float] = None,
    metric: Optional[str] = None,
    limit: int = 1000,
):
    """Return sampled time-series metrics (v1.78).
    Query params:
    * ``since`` / ``until`` — Unix timestamps bounding the window.
    * ``metric`` — dotted path (e.g. ``obs.p95_ms``, ``cache.hit_rate``,
      ``limiter.denied``) to project each sample to
      ``{ts, value}``. Omit for full samples.
    * ``limit`` — cap on rows returned, keeping the newest (default 1000).
      ``0`` or a negative value means "no cap".
    The sampler is off by default; enable via
    ``TAU_RAG_METRICS_HISTORY_INTERVAL_SEC=10`` at server start, or
    call ``MetricsHistorySampler(h, interval_s=10.0).start()`` directly.
    """
    from ..middleware import get_metrics_history, get_metrics_sampler
    h = get_metrics_history()
    rows = h.history(since=since, until=until, metric=metric)
    cap = int(limit or 0)
    # Only a positive cap may truncate. A negative value must never reach
    # the slice: ``rows[-cap:]`` would then be ``rows[abs(cap):]``, which
    # drops the *oldest* |cap| rows instead of capping the result.
    if cap > 0 and len(rows) > cap:
        rows = rows[-cap:]  # keep the newest ``cap`` rows
    sampler = get_metrics_sampler()
    return {
        "samples": rows,
        "count": len(rows),
        "capacity": h.capacity(),
        "metric": metric,
        "sampler": (sampler.status() if sampler else
                    {"is_running": False, "interval_s": None}),
    }
@app.post("/v1/admin/metrics/history/sample")
def admin_metrics_history_sample_now(request: Request):
    """Capture one metrics sample immediately and return it.

    Handy for tests and "capture before change" workflows. The capture is
    audited as ``metrics.sample`` with the sample's timestamp.
    """
    from ..middleware import get_metrics_history
    sample = get_metrics_history().sample()
    get_obs().audit(
        "metrics.sample",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        ts=sample["ts"],
    )
    return sample
@app.post("/v1/admin/cache/clear")
def admin_cache_clear(request: Request):
    """Empty both the shared response cache and the pipeline's query cache.

    Audited as ``cache.clear``; returns ``{"cleared": True}``.
    """
    get_cache().clear()
    _pipeline.cache.clear()
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    get_obs().audit("cache.clear", actor_key=actor_key, request_id=request_id)
    return {"cleared": True}
@app.get("/v1/admin/logs")
def admin_logs(n: int = 50, event_type: Optional[str] = None):
    """Return the last ``n`` entries of the in-memory request/audit log.

    Pass ``event_type=audit`` for admin actions only, or
    ``event_type=request`` for HTTP traffic only.
    """
    entries = get_obs().tail(n=n, event_type=event_type)
    return {"logs": entries}
@app.get("/v1/admin/audit/export", response_class=None)
def admin_audit_export(
    since: Optional[float] = None,
    until: Optional[float] = None,
    event_type: Optional[str] = None,
    format: str = "jsonl",
    limit: int = 10_000,
):
    """Download the observability ring buffer as a JSONL or JSON attachment.

    Query params:
    * ``since`` — Unix timestamp; keep entries with ``ts >= since``
    * ``until`` — Unix timestamp; keep entries with ``ts < until``
    * ``event_type`` — ``audit`` | ``request`` (omit = both)
    * ``format`` — ``json`` for a pretty-printed array; any other value
      yields newline-delimited JSON (the default, ``jsonl``)
    * ``limit`` — how many of the newest entries to fetch before time
      filtering (at least 1; default 10,000)

    The response carries ``Content-Disposition`` (``tau-rag-audit.<ext>``)
    and ``X-Entry-Count`` headers, so it can be piped to compliance
    storage, fed to a SIEM, or spot-checked with jq.

    NOTE(review): the time window is applied *after* fetching the newest
    ``limit`` entries, so matching entries older than that slice are never
    exported — confirm this is acceptable for wide windows.
    """
    from fastapi.responses import PlainTextResponse
    import json as _json
    entries = get_obs().tail(n=max(1, int(limit)), event_type=event_type)
    def _ts(entry):
        # Entries without a timestamp sort as ts == 0.
        return float(entry.get("ts") or 0)
    # ``tail`` returns entries newest-last; filtering preserves that order.
    if since is not None:
        entries = [e for e in entries if _ts(e) >= float(since)]
    if until is not None:
        entries = [e for e in entries if _ts(e) < float(until)]
    if format == "json":
        payload = _json.dumps(entries, ensure_ascii=False, indent=2)
        media_type, ext = "application/json", "json"
    else:
        # One JSON document per line, every line newline-terminated.
        payload = "".join(
            _json.dumps(e, ensure_ascii=False) + "\n" for e in entries
        )
        media_type, ext = "application/x-ndjson", "jsonl"
    return PlainTextResponse(
        payload,
        media_type=media_type,
        headers={
            "Content-Disposition": f'attachment; filename="tau-rag-audit.{ext}"',
            "X-Entry-Count": str(len(entries)),
        },
    )
@app.get("/v1/admin/logs/stream", response_class=None)
def admin_logs_stream(
    event_type: Optional[str] = None,
    heartbeat_s: float = 15.0,
    replay_last: int = 0,
    max_events: int = 0,
    max_heartbeats: int = 0,
):
    """Live SSE tail of the observability log (v1.75).
    Query params:
    * ``event_type`` — ``audit`` | ``request`` (omit = both)
    * ``heartbeat_s`` — emit a ``:heartbeat`` comment every N seconds
                        so proxies/load balancers don't idle-kill the
                        connection (default 15s; floored at 0.05s).
    * ``replay_last`` — on connect, emit the last N buffered entries
                        before live-tailing (0 = live only).
    * ``max_events`` — stop after this many ``log`` events (0 =
                       unbounded), counting replayed and live events.
                       Useful for bounded tails and deterministic testing.
    * ``max_heartbeats`` — stop after this many heartbeat ticks (0 =
                           unbounded). Useful for "give me whatever's
                           there within N seconds, then close".
    Event names:
    * ``log`` — new entry (data is the same dict as ``/v1/admin/logs``)
    * (heartbeat is a ``:`` SSE comment line, per the SSE spec)
    Client disconnect frees the subscriber queue; drop-oldest on a slow
    reader so the request path is never back-pressured.

    NOTE(review): a heartbeat is only written when the queue stays empty
    for a full ``heartbeat_s``. With an ``event_type`` filter and steady
    non-matching traffic, entries are dequeued and skipped without writing
    anything, so the connection can sit byte-idle far longer than
    ``heartbeat_s`` — confirm proxies won't cut it in that case.
    """
    from fastapi.responses import StreamingResponse
    import json as _json
    import queue as _q
    obs = get_obs()
    # Subscribe before streaming starts so entries logged during the replay
    # phase are queued rather than missed.
    # NOTE(review): an entry logged between subscribe() and the replay
    # tail() can appear twice (once replayed, once live). Also, the
    # subscription is only released in _gen's ``finally``; if the response
    # body is never iterated, that cleanup presumably waits for garbage
    # collection — confirm.
    sub = obs.subscribe(maxsize=256)
    def _sse(event: str, data: Any) -> str:
        # One SSE frame: named event, single-line JSON payload, blank line.
        return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n"
    def _gen():
        emitted_logs = 0
        heartbeats = 0
        try:
            # Replay tail first (filtered by event_type if set)
            if replay_last and replay_last > 0:
                for row in obs.tail(n=int(replay_last), event_type=event_type):
                    yield _sse("log", row)
                    emitted_logs += 1
                    if max_events and emitted_logs >= max_events:
                        return
            # Live loop — block up to heartbeat_s for next entry; if
            # nothing arrived emit a comment keep-alive.
            hb = max(0.05, float(heartbeat_s))
            while True:
                try:
                    row = sub.get(timeout=hb)
                except _q.Empty:
                    heartbeats += 1
                    yield ": heartbeat\n\n"
                    if max_heartbeats and heartbeats >= max_heartbeats:
                        return
                    continue
                # The subscription sees every entry, so the event_type
                # filter is applied here (tail() did it for the replay).
                if event_type and row.get("event_type") != event_type:
                    continue
                yield _sse("log", row)
                emitted_logs += 1
                if max_events and emitted_logs >= max_events:
                    return
        except GeneratorExit:
            # Generator closed (e.g. client went away): end quietly and let
            # ``finally`` release the subscription.
            pass
        finally:
            obs.unsubscribe(sub)
    return StreamingResponse(
        _gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
@app.get("/v1/admin/ui", response_class=None)
def admin_ui(refresh: int = 0):
    """Render the unified read-only admin dashboard as HTML.

    ``refresh`` is forwarded to the renderer as ``refresh_sec``; a falsy
    value is passed as 0.
    """
    from fastapi.responses import HTMLResponse
    from .admin_ui import render_admin_ui
    from .metrics import check_readiness
    _is_ready, readiness_detail = check_readiness(_pipeline)
    # Gather every panel's data in one call to the template renderer.
    page = render_admin_ui(
        cache_stats=get_cache().stats(),
        limiter_stats=get_limiter().stats(),
        obs_stats=get_obs().stats(),
        recent_requests=get_obs().tail(n=20, event_type="request"),
        recent_audits=get_obs().tail(n=20, event_type="audit"),
        keys=get_auth().list_keys(),
        documents=_pipeline.list_documents(),
        readiness=readiness_detail,
        refresh_sec=int(refresh or 0),
    )
    return HTMLResponse(page)
# ---- snapshot / restore --------------------------------------------------
class SnapshotSaveRequest(BaseModel):
    """Body for ``POST /v1/admin/snapshot/save``."""
    path: Optional[str] = None  # falsy → _default_snapshot_path()
    rotate: int = 0 # v1.67: keep last N rotated generations
class SnapshotLoadRequest(BaseModel):
    """Body for ``POST /v1/admin/snapshot/load``."""
    path: Optional[str] = None  # falsy → _default_snapshot_path()
    replace: bool = False  # forwarded to Pipeline.load_snapshot(replace=...)
    generation: int = 0 # v1.67: load a specific rotated generation (0 = current)
def _default_snapshot_path() -> str:
return _os.environ.get("TAU_RAG_SNAPSHOT_PATH") or "runtime/snapshot.jsonl"
# ---- eval harness endpoint ----------------------------------------------
class EvalRequest(BaseModel):
    """Body for ``POST /v1/admin/eval``."""
    cases: List[Dict[str, Any]]  # raw gold cases, parsed with GoldCase.from_dict
    k: int = 5  # forwarded to run_eval(k=...)
    # e.g. {"recall@5": 0.7}; checked via report.fail_below (None = no gating)
    thresholds: Optional[Dict[str, float]] = None
@app.post("/v1/admin/eval")
def admin_eval(req: EvalRequest, request: Request):
    """Run the pipeline against a gold set inline; return aggregate metrics.
    Body: ``{"cases": [{id, query, expected_doc_ids, expected_claims?, lang?}, ...],
             "k": 5, "thresholds": {"recall@5": 0.7, ...}}``

    A case that ``GoldCase.from_dict`` rejects (``KeyError`` / ``TypeError``
    / ``ValueError``) yields HTTP 422 naming its 1-based row instead of an
    unhandled 500. The run is audited as ``eval.run``; ``passed`` is True
    when ``report.fail_below(thresholds)`` reports no failures.
    """
    from ..eval import GoldCase, run_eval
    cases = []
    for row, raw_case in enumerate(req.cases, start=1):
        try:
            cases.append(GoldCase.from_dict(raw_case))
        except (KeyError, TypeError, ValueError) as e:
            # Client-supplied input: report which row is malformed.
            raise HTTPException(
                status_code=422,
                detail={"error": "invalid gold case", "row": row,
                        "reason": f"{type(e).__name__}: {e}"},
            ) from e
    report = run_eval(_pipeline, cases, k=req.k)
    failures = report.fail_below(req.thresholds or {})
    get_obs().audit(
        "eval.run",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        n_cases=report.n_cases,
        aggregate=report.aggregate,
        failures=failures,
    )
    return {
        "n_cases": report.n_cases,
        "aggregate": report.aggregate,
        "latency_ms": report.latency_ms,
        "omega": report.omega,
        "per_case": [c.to_dict() for c in report.per_case],
        "failures": failures,
        "passed": len(failures) == 0,
    }
@app.post("/v1/admin/snapshot/save")
def admin_snapshot_save(req: SnapshotSaveRequest, request: Request):
    """Persist the pipeline state to a snapshot file.

    Writes to ``req.path`` when given, else ``_default_snapshot_path()``.
    ``req.rotate`` (v1.67) is forwarded as the number of rotated
    generations to keep. The summary from ``save_snapshot`` is audited as
    ``snapshot.save`` and returned.
    """
    target = req.path if req.path else _default_snapshot_path()
    keep = int(req.rotate or 0)
    summary = _pipeline.save_snapshot(target, rotate=keep)
    get_obs().audit(
        "snapshot.save",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        **summary,
    )
    return summary
@app.get("/v1/admin/snapshot/history")
def admin_snapshot_history(path: Optional[str] = None, max_gens: int = 10):
    """List rotated snapshot generations found on disk.

    ``path`` falls back to the env-configured snapshot path; ``max_gens``
    is forwarded to ``list_snapshot_history``.
    """
    from ..snapshot import list_snapshot_history
    base_path = path if path else _default_snapshot_path()
    generations = list_snapshot_history(base_path, max_gens=max_gens)
    return {"base_path": base_path, "generations": generations}
class SnapshotDiffRequest(BaseModel):
    """Body for ``POST /v1/admin/snapshot/diff``."""
    a: str  # path to snapshot A (the "before")
    b: str  # path to snapshot B (the "after")
    include_details: bool = False  # expand modified[] to {id, lens, hashes}
@app.post("/v1/admin/snapshot/diff")
def admin_snapshot_diff(req: SnapshotDiffRequest, request: Request):
    """Compare two snapshot files (v1.77).

    Returns ``diff_snapshots``' result: added / removed / modified doc IDs,
    per-snapshot metadata and a ``same_fingerprint`` boolean. Either path
    missing on disk yields 404 (``a`` is checked first). The diff is
    audited as ``snapshot.diff`` with the three change counts.

    Use cases:
    * CI: "what's new in this PR?" (diff main vs branch)
    * QA: "did we accidentally delete docs?" (diff yesterday vs today)
    * compliance: "what changed between quarters?"
    """
    from ..snapshot import diff_snapshots
    import os as _os
    for which, snap_path in zip(("a", "b"), (req.a, req.b)):
        if _os.path.exists(snap_path):
            continue
        raise HTTPException(
            status_code=404,
            detail={"error": "snapshot not found",
                    "which": which, "path": snap_path},
        )
    result = diff_snapshots(
        req.a, req.b, include_details=bool(req.include_details),
    )
    # n_added / n_removed / n_modified, in that order.
    change_counts = {
        f"n_{kind}": len(result[kind])
        for kind in ("added", "removed", "modified")
    }
    get_obs().audit(
        "snapshot.diff",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        a=req.a, b=req.b,
        **change_counts,
    )
    return result
# ---- runtime config (v1.51) ---------------------------------------------
class ConfigUpdateRequest(BaseModel):
    """Body for ``POST /v1/admin/config``: dotted config key → new value."""
    # Only keys listed in `_TUNABLE` (and passing its validator) are applied;
    # every other key is reported back as rejected. Add a key to `_TUNABLE`
    # before exposing it here.
    updates: Dict[str, Any]
# (path-in-Config, validator_fn, human_description)
_TUNABLE: Dict[str, Any] = {
"verify.min_omega": (
lambda v: isinstance(v, (int, float)) and 0.0 <= float(v) <= 1.0,
"Minimum Ω signal for response.passed; 0.55 default",
),
"verify.min_citation_coverage": (
lambda v: isinstance(v, (int, float)) and 0.0 <= float(v) <= 1.0,
"Fraction of answer claims that must be cited; 0.8 default",
),
}
def _config_to_dict(cfg) -> Dict[str, Any]:
"""Recursively convert the nested Config dataclass to a plain dict."""
from dataclasses import is_dataclass, asdict
if is_dataclass(cfg):
return asdict(cfg)
return dict(cfg)
def _apply_config_update(cfg, key_path: str, value: Any) -> None:
"""Set ``cfg.a.b.c = value`` for a dotted ``key_path``."""
parts = key_path.split(".")
obj = cfg
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)
@app.get("/v1/admin/config")
def admin_get_config():
    """Return the live effective configuration plus the tunable keys.

    ``tunable`` maps each whitelisted dotted key to its description.
    """
    live = _config_to_dict(_pipeline.config)
    tunable_docs = {}
    for key, (_validator, description) in _TUNABLE.items():
        tunable_docs[key] = {"description": description}
    return {"config": live, "tunable": tunable_docs}
@app.post("/v1/admin/config")
def admin_update_config(req: ConfigUpdateRequest, request: Request):
    """Apply whitelisted config changes at runtime.

    Each key must be in ``_TUNABLE`` and pass its validator. Keys that are
    unknown, fail validation, or raise while being set are reported under
    ``rejected``. If anything was applied, both the pipeline query cache
    and the shared response cache are cleared so later requests use the
    new thresholds. The outcome is audited as ``config.update``.
    """
    accepted: Dict[str, Any] = {}
    refused: List[Dict[str, Any]] = []
    for key, candidate in req.updates.items():
        spec = _TUNABLE.get(key)
        if spec is None:
            refused.append({"key": key, "reason": "not in whitelist"})
            continue
        is_valid, _description = spec
        if not is_valid(candidate):
            refused.append({"key": key, "reason": "validation failed",
                            "value": candidate})
            continue
        try:
            _apply_config_update(_pipeline.config, key, candidate)
        except Exception as exc:
            refused.append({"key": key,
                            "reason": f"{type(exc).__name__}: {exc}"})
        else:
            accepted[key] = candidate
    changed = bool(accepted)
    if changed:
        # Cached answers were computed under the old thresholds.
        _pipeline.cache.clear()
        get_cache().clear()
    get_obs().audit(
        "config.update",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        applied=accepted, rejected=refused,
    )
    return {"applied": accepted, "rejected": refused,
            "cache_cleared": changed}
# ---- Hebrew synonyms CRUD (v1.52) --------------------------------------
class SynonymBody(BaseModel):
    """Body for ``POST /v1/admin/synonyms``: one canonical term and variants."""
    canonical: str
    variants: List[str]
def _maybe_autosave_synonyms() -> None:
"""If TAU_RAG_SYNONYMS_PATH is set, persist current dict to disk."""
path = _os.environ.get("TAU_RAG_SYNONYMS_PATH")
if not path:
return
try:
from ..core.hebrew_synonyms import save_synonyms_jsonl
save_synonyms_jsonl(path)
except Exception as _e:
print(f"[tau-rag] synonym autosave failed: {_e}")
@app.get("/v1/admin/synonyms")
def admin_list_synonyms(q: Optional[str] = None):
    """List synonym entries, optionally filtered.

    ``?q=`` (whitespace-stripped) keeps entries whose canonical term or
    any variant contains it as a substring.
    """
    from ..core.hebrew_synonyms import list_synonyms
    synonyms = list_synonyms()
    if q:
        needle = q.strip()
        def _matches(item):
            canonical, variants = item
            return needle in canonical or any(needle in v for v in variants)
        synonyms = dict(filter(_matches, synonyms.items()))
    return {"count": len(synonyms), "synonyms": synonyms}
class SynonymBulkRequest(BaseModel):
    """Body for ``POST /v1/admin/synonyms/bulk``."""
    entries: List[Dict[str, Any]]  # [{canonical, variants}, ...]
    replace: bool = False  # True: clear all synonyms before adding entries
@app.post("/v1/admin/synonyms/bulk")
def admin_bulk_synonyms(req: SynonymBulkRequest, request: Request):
    """Add many synonym entries at once; rows look like {canonical, variants}.

    With ``replace=True`` the existing table is cleared first. Failing rows
    are reported by 1-based row number and do not stop the batch. Caches
    are purged, the table is autosaved when configured, and the batch is
    audited as ``synonyms.bulk``.
    """
    from ..core.hebrew_synonyms import (
        add_synonym, clear_synonyms as _clear,
    )
    if req.replace:
        _clear()
    failures: List[Dict[str, Any]] = []
    added = 0
    for row_no, entry in enumerate(req.entries, start=1):
        try:
            add_synonym(entry["canonical"], list(entry.get("variants") or []))
        except Exception as exc:
            failures.append({"row": row_no, "error": str(exc)})
        else:
            added += 1
    # Cached answers may reflect the previous expansion table.
    _pipeline.cache.clear()
    get_cache().clear()
    _maybe_autosave_synonyms()
    get_obs().audit(
        "synonyms.bulk",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        added=added, errors=len(failures), replaced=req.replace,
    )
    return {"added": added, "errors": failures, "replaced": req.replace}
@app.get("/v1/admin/synonyms/export", response_class=None)
def admin_export_synonyms():
    """Download the synonym table as JSONL (one ``{canonical, variants}``
    object per newline-terminated line)."""
    from fastapi.responses import PlainTextResponse
    from ..core.hebrew_synonyms import list_synonyms
    import json as _json
    payload = "".join(
        _json.dumps({"canonical": canonical, "variants": variants},
                    ensure_ascii=False) + "\n"
        for canonical, variants in list_synonyms().items()
    )
    disposition = 'attachment; filename="synonyms.jsonl"'
    return PlainTextResponse(
        payload,
        media_type="application/x-ndjson",
        headers={"Content-Disposition": disposition},
    )
@app.post("/v1/admin/synonyms")
def admin_add_synonym(body: SynonymBody, request: Request):
    """Add a synonym entry, or extend an existing entry's variant list.

    A ``ValueError`` from ``add_synonym`` becomes HTTP 422. On success the
    caches are purged, the table is autosaved when configured, and the
    change is audited as ``synonyms.add``.
    """
    from ..core.hebrew_synonyms import add_synonym
    try:
        entry = add_synonym(body.canonical, body.variants)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc))
    # Cached answers may have used the stale expansion.
    _pipeline.cache.clear()
    get_cache().clear()
    _maybe_autosave_synonyms()
    variant_count = len(entry["variants"])
    get_obs().audit(
        "synonyms.add",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        canonical=body.canonical, n_variants=variant_count,
    )
    return entry
@app.delete("/v1/admin/synonyms/{canonical}")
def admin_delete_synonym(canonical: str, request: Request):
    """Remove the synonym entry for ``canonical``; 404 if it is absent.

    On success the caches are purged, the table is autosaved when
    configured, and the removal is audited as ``synonyms.remove``.
    """
    from ..core.hebrew_synonyms import remove_synonym
    if not remove_synonym(canonical):
        raise HTTPException(status_code=404, detail={"canonical": canonical})
    _pipeline.cache.clear()
    get_cache().clear()
    _maybe_autosave_synonyms()
    get_obs().audit(
        "synonyms.remove",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        canonical=canonical,
    )
    return {"removed": True, "canonical": canonical}
@app.get("/v1/admin/snapshot/status")
def admin_snapshot_status():
    """Report the periodic auto-snapshotter's state.

    When no auto-snapshotter is configured, returns a hint naming the
    environment variables that enable it.
    """
    auto = get_autosnapshotter()
    if auto is not None:
        return {"enabled": True, **auto.status()}
    hint = ("set TAU_RAG_SNAPSHOT_PATH and "
            "TAU_RAG_SNAPSHOT_INTERVAL= to enable")
    return {"enabled": False, "hint": hint}
@app.post("/v1/admin/snapshot/load")
def admin_snapshot_load(req: SnapshotLoadRequest, request: Request):
    """Load a snapshot, or one of its rotated generations, into the pipeline.

    ``req.path`` defaults to ``_default_snapshot_path()``. ``req.generation``
    selects the file: 0 = current snapshot, N > 0 = a rotated backup
    (v1.67), resolved through ``snapshot._gen_path``. A missing file gives
    404. ``req.replace`` is forwarded to ``Pipeline.load_snapshot``.

    Returns the pipeline's load summary plus ``generation`` and
    ``path_loaded``; the load is audited as ``snapshot.load``.
    """
    from ..snapshot import _gen_path
    from pathlib import Path as _P
    base_path = req.path or _default_snapshot_path()
    generation = int(req.generation or 0)
    resolved = str(_gen_path(_P(base_path), generation))
    if not _os.path.exists(resolved):
        raise HTTPException(
            status_code=404,
            detail={"snapshot_not_found": resolved,
                    "generation": generation},
        )
    summary = _pipeline.load_snapshot(resolved, replace=req.replace)
    summary["generation"] = generation
    summary["path_loaded"] = resolved
    # Drop keys that are noisy (warnings) or added above. Also drop keys
    # that would collide with audit()'s explicit keyword arguments below:
    # a duplicate keyword raises TypeError at call time.
    excluded = {"warnings", "generation", "path_loaded",
                "path", "actor_key", "request_id"}
    get_obs().audit(
        "snapshot.load",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        path=resolved,
        **{k: v for k, v in summary.items() if k not in excluded},
    )
    return summary
# ---- API key management (admin-only endpoints) ---------------------------
class APIKeyCreateRequest(BaseModel):
    """Body for ``POST /v1/admin/keys``."""
    label: str  # human-readable name stored with the key
    scopes: List[str] = ["read", "write"]  # pydantic copies this default per instance
@app.post("/v1/admin/keys")
def admin_create_key(req: APIKeyCreateRequest, request: Request):
    """Mint a new API key with ``req.label`` and ``req.scopes``.

    The raw key is returned along with a warning that it cannot be
    retrieved later. Creation is audited as ``key.create``; the raw key
    itself is not part of the audit record.
    """
    raw_key = get_auth().create(label=req.label, scopes=req.scopes)
    get_obs().audit(
        "key.create",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        label=req.label,
        scopes=req.scopes,
    )
    response = {"api_key": raw_key, "label": req.label, "scopes": req.scopes}
    response["warning"] = "save this key now — it cannot be retrieved later"
    return response
@app.get("/v1/admin/keys")
def admin_list_keys():
    """Return ``{"keys": [...]}`` as reported by the auth store's ``list_keys()``."""
    listing = get_auth().list_keys()
    return {"keys": listing}
@app.delete("/v1/admin/keys/{hash_prefix}")
def admin_revoke_key(hash_prefix: str, request: Request):
    """Revoke the API key identified by ``hash_prefix``; 404 if not found.

    Revocations are audited as ``key.revoke``.
    """
    if not get_auth().revoke(hash_prefix):
        raise HTTPException(status_code=404, detail="key not found")
    get_obs().audit(
        "key.revoke",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        target_prefix=hash_prefix,
    )
    return {"revoked": True, "hash_prefix": hash_prefix}
class APIKeyRotateRequest(BaseModel):
    """Body for ``POST /v1/admin/keys/{hash_prefix}/rotate``."""
    grace_seconds: float = 300.0  # how long the old key keeps working after rotation
@app.post("/v1/admin/keys/{hash_prefix}/rotate")
def admin_rotate_key(
    hash_prefix: str,
    req: APIKeyRotateRequest,
    request: Request,
):
    """Rotate an API key with a grace period (v1.76).

    Issues a new key with the same label and scopes. The old key keeps
    working until ``grace_seconds`` elapse, giving clients a window to
    update their config. An unknown or already-revoked key gives 404.
    Rotation is audited as ``key.rotate``.
    """
    outcome = get_auth().rotate(hash_prefix, grace_seconds=req.grace_seconds)
    if outcome is None:
        raise HTTPException(
            status_code=404,
            detail={"error": "key not found or already revoked",
                    "hash_prefix": hash_prefix},
        )
    get_obs().audit(
        "key.rotate",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_prefix=outcome["old_prefix"],
        new_prefix=outcome["new_prefix"],
        grace_seconds=outcome["grace_seconds"],
    )
    return dict(
        outcome,
        warning="save the new key now — it cannot be retrieved later",
    )
# ---- v2.7 Maintenance / drain mode -------------------------------------
class MaintenanceOnRequest(BaseModel):
    """Body for ``POST /v1/admin/maintenance/on``."""
    reason: str = ""  # free-text note, recorded in the audit log
    retry_after: int = 30  # forwarded to enable(); presumably the Retry-After seconds
@app.post("/v1/admin/maintenance/on")
def admin_maintenance_on(req: MaintenanceOnRequest, request: Request):
    """Enable maintenance / drain mode (v2.7).

    Until it is turned off, non-admin requests get 503 with
    ``Retry-After``; admin callers, including this endpoint, still pass.
    Audited as ``maintenance.on``; returns the new maintenance snapshot.
    """
    from ..middleware.maintenance import get_maintenance
    maintenance = get_maintenance()
    maintenance.enable(reason=req.reason, retry_after=req.retry_after)
    get_obs().audit(
        "maintenance.on",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        reason=req.reason,
        retry_after=req.retry_after,
    )
    return {"ok": True, **maintenance.snapshot()}
@app.post("/v1/admin/maintenance/off")
def admin_maintenance_off(request: Request):
    """Clear maintenance / drain mode (v2.7).

    Audits ``maintenance.off`` with how long maintenance had been on
    (``duration_sec`` from the pre-disable snapshot, rounded to 2 decimals,
    or None when the snapshot carries no numeric duration). Returns
    ``{"ok": True}`` merged with the post-disable snapshot.
    """
    from ..middleware.maintenance import get_maintenance
    m = get_maintenance()
    snap_before = m.snapshot()
    m.disable()
    # disable() has already run at this point, so a malformed duration must
    # not turn the request into a 500 that also skips the audit record.
    # NOTE(review): a missing/None duration presumably happens when "off"
    # is called while maintenance was not on — not verifiable from here.
    raw_duration = snap_before.get("duration_sec")
    duration = (round(raw_duration, 2)
                if isinstance(raw_duration, (int, float)) else None)
    get_obs().audit(
        "maintenance.off",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        duration_sec=duration,
    )
    return {"ok": True, **m.snapshot()}
@app.get("/v1/admin/maintenance")
def admin_maintenance_status():
    """Return the current maintenance / drain-mode snapshot (v2.7)."""
    from ..middleware.maintenance import get_maintenance
    state = get_maintenance().snapshot()
    return state
# ---- v2.8 PII redaction -------------------------------------------------
class PIIToggleRequest(BaseModel):
    """Body for ``POST /v1/admin/pii_redaction/toggle``."""
    enabled: bool = True  # desired redaction state, passed to set_enabled()
@app.post("/v1/admin/pii_redaction/toggle")
def admin_pii_toggle(req: PIIToggleRequest, request: Request):
    """Turn PII redaction on or off at runtime (v2.8).

    The startup default comes from the ``TAU_RAG_PII_REDACT`` env var; this
    lets ops flip it without a restart (e.g. on discovering bodies were
    logged unredacted). Audited as ``pii_redaction.toggle``; returns the
    redactor stats.
    """
    from ..middleware.pii_redaction import get_pii_redactor
    redactor = get_pii_redactor()
    redactor.set_enabled(req.enabled)
    get_obs().audit(
        "pii_redaction.toggle",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        enabled=req.enabled,
    )
    return {"ok": True, **redactor.stats()}
@app.get("/v1/admin/pii_redaction/stats")
def admin_pii_stats():
    """Return the PII redactor's counters: how many IDs / phones / emails /
    card numbers were scrubbed since startup or the last reset."""
    from ..middleware.pii_redaction import get_pii_redactor
    redactor = get_pii_redactor()
    return redactor.stats()
@app.post("/v1/admin/pii_redaction/reset")
def admin_pii_reset(request: Request):
    """Zero the per-kind PII counters; the enabled flag is left unchanged.

    Audited as ``pii_redaction.reset``; returns the refreshed stats.
    """
    from ..middleware.pii_redaction import get_pii_redactor
    redactor = get_pii_redactor()
    redactor.reset()
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    get_obs().audit(
        "pii_redaction.reset", actor_key=actor_key, request_id=request_id,
    )
    return {"ok": True, **redactor.stats()}
# ---- v2.9 Slow-query detection -----------------------------------------
class SlowThresholdRequest(BaseModel):
    """Body for ``POST /v1/admin/slow_queries/threshold``."""
    ms: float  # new slow-query threshold in milliseconds; 0 disables
@app.get("/v1/admin/slow_queries")
def admin_slow_queries(n: int = 20):
    """Top-N slowest requests plus per-path aggregates and summary stats.

    ``n`` is clamped to [1, 100] and bounds the length of ``top``.
    """
    from ..middleware.slow_queries import get_slow_tracker
    size = min(max(1, int(n)), 100)
    tracker = get_slow_tracker()
    return {
        "stats": tracker.stats(),
        "top": tracker.top_n(size),
        "by_path": tracker.by_path(),
    }
@app.post("/v1/admin/slow_queries/threshold")
def admin_slow_threshold(req: SlowThresholdRequest, request: Request):
    """Change the slow-query threshold (ms) at runtime; 0 disables.

    Audited as ``slow_queries.threshold`` with old and new values, where the
    new value is read back from the tracker after the update.
    """
    from ..middleware.slow_queries import get_slow_tracker
    tracker = get_slow_tracker()
    previous_ms = tracker.threshold_ms
    tracker.set_threshold(req.ms)
    get_obs().audit(
        "slow_queries.threshold",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_ms=previous_ms,
        new_ms=tracker.threshold_ms,
    )
    return {"ok": True, **tracker.stats()}
@app.post("/v1/admin/slow_queries/reset")
def admin_slow_reset(request: Request):
    """Empty the slow-query ring buffer and per-path aggregates.

    Audited as ``slow_queries.reset``; returns the tracker stats.
    """
    from ..middleware.slow_queries import get_slow_tracker
    tracker = get_slow_tracker()
    tracker.reset()
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    get_obs().audit(
        "slow_queries.reset", actor_key=actor_key, request_id=request_id,
    )
    return {"ok": True, **tracker.stats()}
# ---- v2.12 Daily quota -------------------------------------------------
class QuotaSetRequest(BaseModel):
    """Body for ``POST /v1/admin/quotas``."""
    key_prefix: str  # hash prefix identifying the API key
    limit: int  # daily request cap; 0 = unlimited
@app.get("/v1/admin/quotas")
def admin_quotas_get():
    """Return all quota state: limits, usage, day cursor and reset timer."""
    from ..middleware.quota import get_quota_tracker
    tracker = get_quota_tracker()
    return tracker.stats()
@app.post("/v1/admin/quotas")
def admin_quota_set(req: QuotaSetRequest, request: Request):
    """Set the daily quota for a key, identified by hash prefix.

    ``limit=0`` means unlimited (equivalent to deleting the quota). Usage
    counters are left alone, so an operator can raise the cap mid-day
    without wiping the meter. Audited as ``quota.set``.
    """
    from ..middleware.quota import get_quota_tracker
    tracker = get_quota_tracker()
    tracker.set_quota(req.key_prefix, req.limit)
    get_obs().audit(
        "quota.set",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        key_prefix=req.key_prefix,
        limit=req.limit,
    )
    return dict(ok=True, key_prefix=req.key_prefix, limit=req.limit)
@app.delete("/v1/admin/quotas/{key_prefix}")
def admin_quota_clear(key_prefix: str, request: Request):
    """Drop quota enforcement for a key entirely.

    The key goes back to unlimited and its usage counter is zeroed. Gives
    404 when no quota was set for ``key_prefix``. Audited as ``quota.clear``.
    """
    from ..middleware.quota import get_quota_tracker
    tracker = get_quota_tracker()
    if not tracker.clear_quota(key_prefix):
        raise HTTPException(
            status_code=404,
            detail={"error": "no quota set for this key_prefix",
                    "key_prefix": key_prefix},
        )
    get_obs().audit(
        "quota.clear",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        key_prefix=key_prefix,
    )
    return {"ok": True, "key_prefix": key_prefix, "removed": True}
@app.post("/v1/admin/quotas/reset")
def admin_quota_reset_all(request: Request):
    """Wipe all quota state, limits and usage alike.

    Intended for testing and incident recovery. Audited as
    ``quota.reset_all``; returns the (now empty) tracker stats.
    """
    from ..middleware.quota import get_quota_tracker
    tracker = get_quota_tracker()
    tracker.reset()
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    get_obs().audit(
        "quota.reset_all", actor_key=actor_key, request_id=request_id,
    )
    return {"ok": True, **tracker.stats()}
# ---- v2.13 Idempotency -------------------------------------------------
@app.get("/v1/admin/idempotency/stats")
def admin_idempotency_stats():
    """Return the idempotency store's stats (v2.13)."""
    from ..middleware.idempotency import get_idempotency_store
    store = get_idempotency_store()
    return store.stats()
@app.post("/v1/admin/idempotency/reset")
def admin_idempotency_reset(request: Request):
    """Reset the idempotency store (v2.13).

    Audited as ``idempotency.reset``; returns the store's stats afterwards.
    """
    from ..middleware.idempotency import get_idempotency_store
    get_idempotency_store().reset()
    actor_key = request.headers.get("x-api-key")
    request_id = getattr(request.state, "request_id", None)
    get_obs().audit(
        "idempotency.reset", actor_key=actor_key, request_id=request_id,
    )
    stats_after = get_idempotency_store().stats()
    return {"ok": True, **stats_after}
class IdemTTLRequest(BaseModel):
    """Body for ``POST /v1/admin/idempotency/ttl``."""
    ttl_sec: float  # forwarded to the store's set_ttl(); presumably seconds
@app.post("/v1/admin/idempotency/ttl")
def admin_idempotency_ttl(req: IdemTTLRequest, request: Request):
    """Set the idempotency store's TTL (v2.13).

    Audited as ``idempotency.ttl``; returns the store's stats.
    """
    from ..middleware.idempotency import get_idempotency_store
    store = get_idempotency_store()
    store.set_ttl(req.ttl_sec)
    get_obs().audit(
        "idempotency.ttl",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        ttl_sec=req.ttl_sec,
    )
    return {"ok": True, **store.stats()}
# ---- v2.14 Request timeout ---------------------------------------------
class TimeoutRequest(BaseModel):
    """Body for ``POST /v1/admin/request_timeout``."""
    timeout_ms: float  # wall-clock request timeout in milliseconds; 0 disables
@app.get("/v1/admin/request_timeout")
def admin_request_timeout_stats():
    """Return the request-timeout guard's stats (v2.14)."""
    from ..middleware.request_timeout import get_timeout_guard
    guard = get_timeout_guard()
    return guard.stats()
@app.post("/v1/admin/request_timeout")
def admin_request_timeout_set(req: TimeoutRequest, request: Request):
    """Set the wall-clock request timeout in ms; 0 disables it.

    Audited as ``request_timeout.set`` with old and new values, where the
    new value is read back from the guard after the update.
    """
    from ..middleware.request_timeout import get_timeout_guard
    guard = get_timeout_guard()
    previous_ms = guard.timeout_ms
    guard.set_timeout_ms(req.timeout_ms)
    get_obs().audit(
        "request_timeout.set",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_ms=previous_ms,
        new_ms=guard.timeout_ms,
    )
    return {"ok": True, **guard.stats()}
@app.post("/v1/admin/request_timeout/reset")
def admin_request_timeout_reset(request: Request):
    """Call the request-timeout guard's ``reset()`` (v2.14).

    Audited as ``request_timeout.reset``, matching every other mutating
    admin reset endpoint (cache, PII, slow queries, quotas, idempotency,
    webhook breaker). Returns ``{"ok": True}`` merged with the guard stats.
    """
    from ..middleware.request_timeout import get_timeout_guard
    g = get_timeout_guard()
    g.reset()
    get_obs().audit(
        "request_timeout.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **g.stats()}
@app.get("/v1/signals/latest")
def latest_signals():
    """Return the signals of the last entry in the pipeline cache.

    "Last" follows the cache's iteration order (presumably the most recent
    insertion — confirm against the cache type). Responds
    ``{"empty": True}`` when the cache is empty.
    """
    cache = _pipeline.cache
    if cache:
        newest = list(cache.values())[-1]
        return newest.signals.to_dict()
    return {"empty": True}
# ---- Chat / multi-turn ---------------------------------------------------
class ChatRequest(BaseModel):
    """Body for ``POST /v1/chat`` and ``POST /v1/chat/stream``."""
    query: str  # the user's turn; validated with validate_query_text
    session_id: str  # conversation key for the memory store
    lang: str = "he"  # language hint passed to follow-up detection / the pipeline
@app.post("/v1/chat/stream", response_class=None)
def chat_stream(req: ChatRequest, request: Request):
    """Server-Sent Events version of /v1/chat — same conversational context
    as /v1/chat (follow-up expansion, session history) but streaming.
    Event order:
      event: followup   data: {"is_followup": bool, "expanded_query": "..."}
      event: retrieved  data: {"doc_ids": [...], "count": N}
      event: answer     data: {"chunk": "word "}      (repeated)
      event: done       data: {session_id, n_turns, answer, sources,
                               omega, passed, verification}
      event: error      data: {"code","message"}      (on failure)

    Query validation runs before the stream opens, so oversize input is an
    ordinary HTTP 413 and other invalid input an HTTP 422. Failures after
    that point can only be reported in-stream as an ``error`` event.
    """
    try:
        validate_query_text(req.query)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    from fastapi.responses import StreamingResponse
    import json as _json
    def _sse(event: str, data: Any) -> str:
        # One SSE frame: named event, single-line JSON payload, blank line.
        return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n"
    def _event_gen():
        try:
            # Detect follow-up + expand (same logic as run_conversation)
            # NOTE(review): this is computed independently of
            # run_conversation below; if that logic diverges, the
            # ``followup`` event may disagree with what was actually run.
            from ..memory import expand_followup, get_store, is_followup
            store = get_store()
            session = store.get_or_create(req.session_id)
            # A follow-up needs prior turns to expand against.
            followup = bool(is_followup(req.query, lang=req.lang)) and bool(session.turns)
            expanded = expand_followup(req.query, session) if followup else req.query
            yield _sse("followup", {
                "is_followup": followup,
                "expanded_query": expanded if followup else None,
            })
            # Run the pipeline (extractive — ~ms)
            resp = _pipeline.run_conversation(req.query, req.session_id,
                                              lang=req.lang)
            # Stage 1: retrieval results — unique doc IDs in first-seen order
            retrieved = []
            seen = set()
            for c in getattr(resp, "retrieved", []) or []:
                did = getattr(getattr(c, "chunk", None), "doc_id", None)
                if did and did not in seen:
                    retrieved.append(did)
                    seen.add(did)
            yield _sse("retrieved",
                       {"doc_ids": retrieved, "count": len(retrieved)})
            # Stage 2: answer streamed word-by-word. Splitting is on single
            # spaces and empty tokens are skipped, so concatenated chunks
            # may not reproduce the text byte-for-byte; ``done`` carries the
            # exact answer.
            answer = resp.answer or ""
            for w in answer.split(" "):
                if not w:
                    continue
                yield _sse("answer", {"chunk": w + " "})
            # Stage 3: final envelope — includes session_id + n_turns for client
            try:
                omega = float(resp.signals.omega) if resp.signals else None
            except Exception:
                omega = None
            verif = getattr(resp, "verification", None)
            # Re-read session — run_conversation added a new turn
            session = store.get_or_create(req.session_id)
            yield _sse("done", {
                "session_id": req.session_id,
                "n_turns": len(session.turns),
                "answer": answer,
                "sources": list(resp.sources or []),
                "omega": omega,
                "passed": bool(getattr(verif, "passed", False)) if verif else None,
                "verification": (verif.to_dict() if hasattr(verif, "to_dict")
                                 else getattr(verif, "__dict__", None)),
            })
        except Exception as e:
            # Streaming has started, so failures are reported in-band.
            # NOTE(review): the exception type and message (truncated to
            # 240 chars) reach the client — confirm they cannot carry
            # internal details.
            yield _sse("error", {
                "code": "internal_error",
                "message": f"{type(e).__name__}: {e}"[:240],
            })
    return StreamingResponse(
        _event_gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
@app.post("/v1/chat")
def chat(req: ChatRequest):
    """Conversational query — carries prior-turn context for follow-ups.

    The query goes through the same ``validate_query_text`` check as
    ``/v1/chat/stream``, so both chat endpoints accept and reject the same
    inputs.

    Returns:
        ``resp.to_dict()`` of the pipeline's ``run_conversation`` result.

    Raises:
        HTTPException: 413 when ``validate_query_text`` raises OverflowError,
            422 when it raises ValueError.
    """
    try:
        validate_query_text(req.query)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    resp = _pipeline.run_conversation(
        query_text=req.query,
        session_id=req.session_id,
        lang=req.lang,
    )
    return resp.to_dict()
@app.get("/v1/sessions/{session_id}")
def session_info(session_id: str):
    """Return the stored state (``to_dict()``) of one session.

    Responds 404 ("session not found") when the store has no such session.
    """
    from ..memory import get_store

    session = get_store().get(session_id)
    if session is not None:
        return session.to_dict()
    raise HTTPException(status_code=404, detail="session not found")
@app.delete("/v1/sessions/{session_id}")
def session_drop(session_id: str):
    """Delete one session; ``dropped`` echoes whatever the store's ``drop`` returned."""
    from ..memory import get_store

    store = get_store()
    return {"dropped": store.drop(session_id)}
@app.get("/v1/sessions")
def sessions_list(
    details: bool = False,
    min_turns: int = 0,
    limit: int = 500,
):
    """List session ids. ``?details=1`` returns the richer per-session view
    (created_at, last_activity_ts, idle_sec, ttl_remaining_sec, n_turns,
    last_query, last_sources) in the order produced by ``store.summaries()``
    (intended to be most-recent-activity-first).

    In detail mode, ``?min_turns=N`` keeps rows with at least N turns and a
    positive ``?limit=`` caps the row count. Both are ignored when
    ``details`` is false, which returns every id from ``store.list_ids()``.
    """
    from ..memory import get_store

    store = get_store()
    if details:
        rows = [
            summary
            for summary in store.summaries()
            if summary["n_turns"] >= min_turns
        ]
        if limit and limit > 0:
            rows = rows[:limit]
        return {"count": len(rows), "sessions": rows}
    return {"sessions": store.list_ids()}
@app.post("/v1/sessions/gc")
def admin_sessions_gc(request: Request):
    """Run the session store's TTL + capacity eviction immediately.

    The summary from ``gc_now()`` (``{before, expired, overflow_evicted,
    after}``) is audit-logged together with the caller's API key and
    request id, then returned unchanged.
    """
    from ..memory import get_store

    result = get_store().gc_now()
    audit_ctx = {
        "actor_key": request.headers.get("x-api-key"),
        "request_id": getattr(request.state, "request_id", None),
    }
    get_obs().audit("sessions.gc", **audit_ctx, **result)
    return result
@app.delete("/v1/sessions")
def admin_sessions_drop_all(request: Request):
    """Remove every session from the store and audit-log how many went.

    Returns ``{"dropped": N}`` where N is the value ``drop_all()`` returned.
    """
    from ..memory import get_store

    removed = get_store().drop_all()
    headers = request.headers
    get_obs().audit(
        "sessions.drop_all",
        actor_key=headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        dropped=removed,
    )
    return {"dropped": removed}
def get_pipeline() -> Pipeline:
    """Return the module-level pipeline instance that the endpoints use."""
    current = _pipeline
    return current
def set_pipeline(new: Pipeline) -> None:
    """Hook for tests / production to swap in a real pipeline.

    Rebinds the module-level ``_pipeline``. The endpoints look ``_pipeline``
    up when they run rather than capturing it, so requests handled after
    the swap use ``new``.

    Args:
        new: the pipeline instance to install.
    """
    global _pipeline
    _pipeline = new