"""FastAPI app. Run: uvicorn tau_rag.api.fastapi_app:app --reload """ from __future__ import annotations from typing import Any, Dict, List, Optional try: from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from pydantic import BaseModel except Exception as e: # pragma: no cover raise RuntimeError( "FastAPI not installed. `pip install tau-rag[api]`." ) from e from ..core.config import Config from ..core.types import Document, Query, Strategy from ..pipeline import Pipeline from .errors import ( ErrorCode, Limits, build_error_body, validate_query_text, validate_doc_list, validate_k, ) app = FastAPI(title="TAU-RAG", version="2.0.0") # ---------------------------------------------------- CORS + security from fastapi.middleware.cors import CORSMiddleware # noqa: E402 from .security import cors_config_from_env, apply_security_headers # noqa: E402 _cors_cfg = cors_config_from_env() if _cors_cfg["allow_origins"]: app.add_middleware(CORSMiddleware, **_cors_cfg) @app.middleware("http") async def _security_headers_middleware(request, call_next): response = await call_next(request) apply_security_headers(response.headers) return response # ---------------------------------------------------- global error handlers def _rid_from(request) -> Optional[str]: return getattr(getattr(request, "state", None), "request_id", None) @app.exception_handler(HTTPException) async def _http_exc_handler(request, exc: HTTPException): # Map FastAPI's status codes to our canonical codes code_map = { 400: ErrorCode.VALIDATION_ERROR, 401: ErrorCode.UNAUTHORIZED, 403: ErrorCode.ADMIN_REQUIRED, 404: ErrorCode.NOT_FOUND, 413: ErrorCode.PAYLOAD_TOO_LARGE, 422: ErrorCode.VALIDATION_ERROR, 429: ErrorCode.RATE_LIMITED, } code = code_map.get(exc.status_code, ErrorCode.INTERNAL_ERROR) detail = exc.detail message = detail if isinstance(detail, str) else "request failed" details = detail if isinstance(detail, dict) else None body = build_error_body(code, message, 
_rid_from(request), details) headers = dict(exc.headers or {}) rid = _rid_from(request) if rid: headers["X-Request-ID"] = rid return JSONResponse(status_code=exc.status_code, content=body, headers=headers) @app.exception_handler(RequestValidationError) async def _validation_exc_handler(request, exc: RequestValidationError): body = build_error_body( ErrorCode.VALIDATION_ERROR, "request failed validation", _rid_from(request), details={"errors": exc.errors()}, ) headers = {"X-Request-ID": _rid_from(request)} if _rid_from(request) else {} return JSONResponse(status_code=422, content=body, headers=headers) @app.exception_handler(Exception) async def _unhandled_exc_handler(request, exc: Exception): body = build_error_body( ErrorCode.INTERNAL_ERROR, f"{type(exc).__name__}: {exc}"[:300], _rid_from(request), ) headers = {"X-Request-ID": _rid_from(request)} if _rid_from(request) else {} return JSONResponse(status_code=500, content=body, headers=headers) # ----------------------------------------------------------------- middleware from fastapi import Request from fastapi.responses import JSONResponse from ..middleware import get_cache, get_limiter from ..middleware.auth import get_auth from ..middleware.ratelimit import RateLimitExceeded from ..middleware.observability import ( get_obs, generate_request_id, RequestLog, _hash_prefix, ) from ..middleware.maintenance import get_maintenance from ..middleware.pii_redaction import get_pii_redactor from ..middleware.slow_queries import get_slow_tracker, SlowRecord from ..middleware.quota import get_quota_tracker from ..middleware.idempotency import get_idempotency_store from ..middleware.request_timeout import get_timeout_guard import asyncio as _asyncio import time as _time @app.middleware("http") async def auth_and_ratelimit_middleware(request: Request, call_next): t0 = _time.time() path = request.url.path protected = (path.startswith("/v1/generate") or path.startswith("/v1/chat") or path.startswith("/v1/documents")) admin_only = 
path.startswith("/v1/admin/") api_key = request.headers.get("x-api-key") # X-Request-ID — honor client-supplied or generate our own request_id = request.headers.get("x-request-id") or generate_request_id() # Stash on request.state so handlers can correlate if they want request.state.request_id = request_id # v1.99 — optional body capture for replay. Opt-in via env var # TAU_RAG_OBS_CAPTURE_BODY=1. Only captures bodies for the replay- # able endpoints (search/generate/chat) to keep log size bounded # and avoid picking up admin request bodies with keys. captured_body: Optional[str] = None _REPLAY_CAPTURE_PATHS = ( "/v1/search", "/v1/generate", "/v1/chat", ) if (_os.environ.get("TAU_RAG_OBS_CAPTURE_BODY") == "1" and request.method == "POST" and any(path.startswith(p) for p in _REPLAY_CAPTURE_PATHS)): try: raw = await request.body() # Truncate to 4KB — real queries are under 1KB, legal # texts rarely reach 2KB. Very-long payloads get flagged # but replay won't work on them (acceptable). if raw is not None: captured_body = raw[:4096].decode("utf-8", errors="replace") # v2.8 — PII redaction. When TAU_RAG_PII_REDACT=1 (or # admin flipped via endpoint), scrub Israeli IDs, phone # numbers, emails, and CC-like digit runs from the # captured text BEFORE it hits the observability log, # JSONL file, stdout, or SSE tail. No-op when disabled. captured_body = get_pii_redactor().redact(captured_body) # Put the already-read body back on the request so the # downstream handler still sees it. 
async def _receive() -> Dict[str, Any]: return {"type": "http.request", "body": raw, "more_body": False} request._receive = _receive # type: ignore[attr-defined] except Exception: captured_body = None def _log(status: int, error: Optional[str] = None) -> None: extra: Dict[str, Any] = {} if error: extra["error"] = error if captured_body is not None: extra["body"] = captured_body latency_ms = (_time.time() - t0) * 1000.0 get_obs().record(RequestLog( ts=_time.time(), request_id=request_id, method=request.method, path=path, status=status, latency_ms=latency_ms, key_hash_prefix=_hash_prefix(api_key), client_ip=(request.client.host if request.client else None), user_agent=request.headers.get("user-agent"), event_type="request", extra=extra, )) # v2.9 — also feed the slow-query tracker. No-op if threshold=0 # or if the request was fast enough. Kept off the observability # log hot-path: cheap dict append in the tracker. try: get_slow_tracker().maybe_record(SlowRecord( ts=_time.time(), request_id=request_id, method=request.method, path=path, status=status, latency_ms=latency_ms, error=error, )) except Exception: pass # 1. Auth check (only if TAU_RAG_REQUIRE_AUTH is set OR admin path) auth = get_auth() # v2.7 — maintenance / drain mode. Admin traffic always flows (so # operators can turn it off again); everyone else gets 503 + # Retry-After. Check happens AFTER auth object is available so we # can ask ``is_admin(key)`` but BEFORE rate limiting — otherwise a # drained pod would count rejected requests against the limiter, # polluting stats. # v2.11 — k8s probes (/livez, /readyz) must always reach the # handler so the probe reflects true readiness. Drain is ONE of # several reasons a pod might be unready; the probe itself (via # the readiness registry's ``not_draining`` check) signals it. # Blocking the probe at middleware level would mask other # unreadiness signals during drain. 
_PROBE_PATHS = ("/livez", "/readyz") maint = get_maintenance() if (maint.is_enabled() and not admin_only and not auth.is_admin(api_key) and path not in _PROBE_PATHS): snap = maint.snapshot() _log(503, error="maintenance") return JSONResponse( status_code=503, headers={ "Retry-After": str(int(snap["retry_after"])), "X-Request-ID": request_id, }, content=build_error_body( ErrorCode.INTERNAL_ERROR, "service is in maintenance mode", request_id=request_id, details={ "reason": snap["reason"], "retry_after": int(snap["retry_after"]), "maintenance_since_sec": round( snap["duration_sec"], 2), }, ), ) if admin_only: if not auth.is_admin(api_key): _log(401, error="admin_required") return JSONResponse( status_code=401, headers={"X-Request-ID": request_id}, content=build_error_body( ErrorCode.ADMIN_REQUIRED, "admin scope required", request_id=request_id, details={"hint": "pass X-API-Key with admin scope"}, ), ) elif protected and auth.required: scope = "write" if request.method in ("POST", "PUT", "DELETE") else "read" if not auth.validate(api_key, scope=scope): _log(401, error="unauthorized") return JSONResponse( status_code=401, headers={"X-Request-ID": request_id}, content=build_error_body( ErrorCode.UNAUTHORIZED, "missing or invalid X-API-Key", request_id=request_id, details={"required_scope": scope}, ), ) # 2. v2.12 — Per-API-key daily quota. Runs BEFORE rate limit so # the per-second limiter doesn't deduct tokens for requests that # will be rejected anyway. Skipped for: # - unauthenticated paths (no key to meter) # - whitelisted clients (same as rate limiter whitelist) # - admin-only paths (admin already auth'd; no quota) # Only applies when the key actually has a quota configured — # unquotaed keys are unlimited (same as pre-v2.12 behavior). 
if api_key and not admin_only: client_ip = request.client.host if request.client else None if (api_key not in get_limiter().whitelist and client_ip not in get_limiter().whitelist): qc = get_quota_tracker().check_and_increment(api_key) if not qc.ok: _log(429, error="quota_exceeded") return JSONResponse( status_code=429, headers={ "Retry-After": str(qc.reset_in_sec), "X-Request-ID": request_id, "X-Quota-Limit": str(qc.limit), "X-Quota-Used": str(qc.used), }, content=build_error_body( ErrorCode.RATE_LIMITED, f"daily quota exceeded: {qc.used}/{qc.limit}", request_id=request_id, details={ "quota": "daily", "limit": qc.limit, "used": qc.used, "reset_in_sec": qc.reset_in_sec, "key_prefix": qc.key_prefix, }, ), ) # 3. Rate limit (skip admin — already auth'd) if protected: try: client_key = ( api_key or (request.client.host if request.client else "unknown") ) # v1.73 — pass path so the limiter can apply per-endpoint overrides get_limiter().acquire(client_key, path=path) except RateLimitExceeded as e: _log(429, error="rate_limited") return JSONResponse( status_code=429, headers={"Retry-After": f"{e.retry_after:.1f}", "X-Request-ID": request_id}, content=build_error_body( ErrorCode.RATE_LIMITED, f"rate limit exceeded for {e.key!r}", request_id=request_id, details={"retry_after": round(e.retry_after, 3), "key": e.key}, ), ) # v2.13 — Idempotency-Key check BEFORE dispatch. Scoped to # (api_key_prefix, idempotency_key) so two clients using the same # key don't collide. POST only, whitelisted paths only. 
_IDEMPOTENT_PATHS = ("/v1/generate", "/v1/chat", "/v1/search") idem_key = request.headers.get("idempotency-key") idem_eligible = ( idem_key and request.method == "POST" and any(path.startswith(p) for p in _IDEMPOTENT_PATHS) ) if idem_eligible: idem_scope = _hash_prefix(api_key) or ( request.client.host if request.client else "anon") cached = get_idempotency_store().get(idem_scope, idem_key) if cached is not None: _log(cached.status, error="idempotent_hit") headers = { "X-Request-ID": request_id, "X-Idempotency-Hit": "1", "X-Idempotency-Key": idem_key, **cached.headers_extra, } return JSONResponse( status_code=cached.status, content=cached.body, headers=headers, ) # 3. Dispatch + log # v2.4 — set request_id on the tracer's thread-local so all pipeline # spans (v1.27) created during handler execution get auto-tagged. try: from ..observability.tracing import get_tracer _t = get_tracer() _t.set_request_id(request_id) except Exception: _t = None # v2.14 — optional wall-clock timeout enforcement. 0 (default) = # disabled, preserving pre-v2.14 behavior. _guard = get_timeout_guard() _guard.record_request() try: if _guard.is_enabled(): try: response = await _asyncio.wait_for( call_next(request), timeout=_guard.timeout_ms / 1000.0, ) except _asyncio.TimeoutError: _guard.record_timeout() _log(504, error="request_timeout") return JSONResponse( status_code=504, headers={"X-Request-ID": request_id}, content=build_error_body( ErrorCode.INTERNAL_ERROR, f"request exceeded {_guard.timeout_ms:.0f}ms timeout", request_id=request_id, details={"timeout_ms": _guard.timeout_ms, "path": path}, ), ) else: response = await call_next(request) except Exception as e: _log(500, error=f"{type(e).__name__}: {e}") raise finally: if _t is not None: try: _t.set_request_id(None) except Exception: pass response.headers["X-Request-ID"] = request_id # v2.13 — cache successful responses for idempotency replay. # Consumes body_iterator and reconstructs a response so downstream # still sees the full body. 
if idem_eligible and 200 <= response.status_code < 300: try: body_bytes = b"" async for chunk in response.body_iterator: # type: ignore[attr-defined] body_bytes += chunk import json as _json try: body_json = _json.loads(body_bytes.decode("utf-8")) except Exception: body_json = None if body_json is not None: get_idempotency_store().set( idem_scope, idem_key, response.status_code, body_json, ) from fastapi.responses import Response as _Resp response = _Resp( content=body_bytes, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, ) except Exception: pass _log(response.status_code) return response # ----------------------------------------------------------------- startup def _pipeline_from_env() -> Pipeline: """Pick config via env var TAU_RAG_PRESET so the same container image can run different flavors (no_llm, hebrew_dense, mock).""" import os preset = os.environ.get("TAU_RAG_PRESET", "no_llm") cfg = { "mock": Config.mock, "default": Config.default, "hebrew_legal": Config.hebrew_legal, "no_llm": Config.no_llm, "hebrew_dense": Config.hebrew_dense, }.get(preset, Config.mock)() return Pipeline.from_config(cfg) # One shared pipeline. Swap the config for production. _pipeline: Pipeline = _pipeline_from_env() # Auto-restore from snapshot if TAU_RAG_SNAPSHOT_PATH is set import os as _os _snap_path = _os.environ.get("TAU_RAG_SNAPSHOT_PATH") if _snap_path and _os.path.exists(_snap_path): try: _restore_summary = _pipeline.load_snapshot(_snap_path, replace=False) print(f"[tau-rag] restored from snapshot: {_restore_summary}") except Exception as _e: print(f"[tau-rag] snapshot restore failed: {_e}") # Periodic auto-snapshot — fires every N seconds as a crash-proofing measure. 
from ..snapshot import (  # noqa: E402
    AutoSnapshotter,
    get_autosnapshotter,
    set_autosnapshotter,
)

# Periodic auto-snapshot. It needs BOTH a snapshot path and a
# TAU_RAG_SNAPSHOT_INTERVAL in seconds. Non-positive intervals are ignored,
# and setup errors are printed instead of raised so startup always proceeds.
_snap_interval = _os.environ.get("TAU_RAG_SNAPSHOT_INTERVAL")
if _snap_path and _snap_interval:
    try:
        _iv = float(_snap_interval)
        if _iv > 0:
            _auto = AutoSnapshotter(
                _pipeline,
                _snap_path,
                interval_sec=_iv,
                # Each periodic save's summary dict is forwarded to the audit log.
                on_save=lambda s: get_obs().audit("snapshot.auto_periodic", **s),
            )
            _auto.start()
            set_autosnapshotter(_auto)
            print(f"[tau-rag] periodic auto-snapshot every {_iv}s → {_snap_path}")
    except Exception as _e:
        print(f"[tau-rag] periodic snapshot setup failed: {_e}")

# Metrics history sampler (v1.78). It is enabled by a positive
# TAU_RAG_METRICS_HISTORY_INTERVAL_SEC. The process-wide history is replaced
# with one sized by TAU_RAG_METRICS_HISTORY_CAPACITY (default 720, floor 10).
try:
    _metrics_iv_raw = _os.environ.get("TAU_RAG_METRICS_HISTORY_INTERVAL_SEC")
    if _metrics_iv_raw:
        _metrics_iv = float(_metrics_iv_raw)
        if _metrics_iv > 0:
            from ..middleware import (
                MetricsHistory,
                MetricsHistorySampler,
                get_metrics_history,
                set_metrics_sampler,
            )
            _mcap = int(
                _os.environ.get("TAU_RAG_METRICS_HISTORY_CAPACITY", "720")
            )
            _h = MetricsHistory(max_samples=max(10, _mcap))
            from ..middleware import set_metrics_history
            set_metrics_history(_h)
            _sampler = MetricsHistorySampler(
                _h,
                interval_s=_metrics_iv,
                sample_on_start=True,
            )
            _sampler.start()
            set_metrics_sampler(_sampler)
            print(f"[tau-rag] metrics history sampler every {_metrics_iv}s cap={_mcap}")
except Exception as _e:
    print(f"[tau-rag] metrics history sampler setup failed: {_e}")

# Background analytics retention scheduler (v1.93) — optional, enabled
# by TAU_RAG_ANALYTICS_TTL_DAYS.
# Analytics retention: a positive TAU_RAG_ANALYTICS_TTL_DAYS starts a pruning
# scheduler that runs every TAU_RAG_ANALYTICS_PRUNE_INTERVAL_SEC seconds
# (default 3600).
try:
    _ttl_raw = _os.environ.get("TAU_RAG_ANALYTICS_TTL_DAYS")
    if _ttl_raw:
        _ttl_days = float(_ttl_raw)
        if _ttl_days > 0:
            _prune_iv = float(
                _os.environ.get("TAU_RAG_ANALYTICS_PRUNE_INTERVAL_SEC", "3600")
            )
            from ..middleware import (
                AnalyticsRetentionScheduler,
                set_retention_scheduler,
            )

            def _analytics_prune_cb(summary):
                """Audit a prune pass that removed rows, or one that reported
                an error. Never raises, so auditing cannot break the scheduler."""
                try:
                    n_removed = summary.get("total_removed", 0)
                    if n_removed > 0:
                        get_obs().audit(
                            "analytics.prune.auto",
                            ttl_seconds=summary.get("ttl_seconds"),
                            total_removed=n_removed,
                        )
                        return
                    if summary.get("error"):
                        get_obs().audit(
                            "analytics.prune.auto.error",
                            error=summary["error"],
                        )
                except Exception:
                    pass

            _retention = AnalyticsRetentionScheduler(
                ttl_seconds=_ttl_days * 86400.0,
                interval_s=_prune_iv,
                on_prune=_analytics_prune_cb,
            )
            _retention.start()
            set_retention_scheduler(_retention)
            print(f"[tau-rag] analytics retention scheduler: ttl={_ttl_days}d interval={_prune_iv}s")
except Exception as _e:
    print(f"[tau-rag] retention scheduler setup failed: {_e}")

# Background alert evaluator (v1.81). A positive
# TAU_RAG_ALERT_EVAL_INTERVAL_SEC starts it over the shared alert store and
# metrics history; fired alerts are written to the audit log.
try:
    _alert_iv_raw = _os.environ.get("TAU_RAG_ALERT_EVAL_INTERVAL_SEC")
    if _alert_iv_raw:
        _alert_iv = float(_alert_iv_raw)
        if _alert_iv > 0:
            from ..middleware import (
                AlertScheduler,
                get_alert_store,
                get_metrics_history,
                set_alert_scheduler,
            )

            def _alert_fire_cb(verdict):
                """Audit a fired alert verdict; never raises."""
                try:
                    payload = {
                        field: verdict[field]
                        for field in ("rule", "reason", "latest_value", "n_samples")
                    }
                    get_obs().audit("alert.fired", **payload)
                except Exception:
                    pass

            _asched = AlertScheduler(
                get_alert_store(),
                get_metrics_history(),
                interval_s=_alert_iv,
                on_fire=_alert_fire_cb,
                evaluate_on_start=True,
            )
            _asched.start()
            set_alert_scheduler(_asched)
            print(f"[tau-rag] alert scheduler every {_alert_iv}s")
except Exception as _e:
    print(f"[tau-rag] alert scheduler setup failed: {_e}")

# Auto-warmup on startup if env requests it (v1.56).
if _os.environ.get("TAU_RAG_WARMUP") == "1":
    try:
        _fn = getattr(_pipeline, "warmup", None)
        if callable(_fn):
            _fn()
        # Set even when the pipeline has no warmup(); /readyz?require_warmed
        # only checks this flag.
        _pipeline._warmed = True  # type: ignore[attr-defined]
        print("[tau-rag] auto-warmup complete")
    except Exception as _e:
        print(f"[tau-rag] auto-warmup failed: {_e}")


def _stop_background_worker(label: str, getter, setter) -> None:
    """Stop and unregister one background worker without ever raising.

    ``getter()`` returns the registered worker, or a falsy value when none
    is running. After ``worker.stop()`` the registration is cleared with
    ``setter(None)``. Failures are printed instead of propagated, so one
    misbehaving worker cannot abort the rest of shutdown.
    """
    try:
        worker = getter()
        if worker:
            worker.stop()
            setter(None)
    except Exception as exc:
        print(f"[tau-rag] shutdown: stopping {label} failed: {exc}")


# Auto-snapshot on shutdown — pairs with auto-restore on startup above.
@app.on_event("shutdown")
def _save_snapshot_on_shutdown() -> None:
    """Stop background workers, then save a final snapshot if configured.

    The periodic snapshot thread is stopped first so it cannot race the
    final save. Each worker stop is isolated, so the final snapshot is
    attempted even if a stop fails. Snapshot failures are printed, not
    raised.
    """
    # Stop the periodic thread first so we don't race with the final save
    _stop_background_worker(
        "auto-snapshotter", get_autosnapshotter, set_autosnapshotter)

    # v1.78 — metrics history sampler
    try:
        from ..middleware import get_metrics_sampler, set_metrics_sampler
    except ImportError:
        pass
    else:
        _stop_background_worker(
            "metrics sampler", get_metrics_sampler, set_metrics_sampler)

    # v1.81 — alert scheduler
    try:
        from ..middleware import get_alert_scheduler, set_alert_scheduler
    except ImportError:
        pass
    else:
        _stop_background_worker(
            "alert scheduler", get_alert_scheduler, set_alert_scheduler)

    # v1.93 — analytics retention scheduler
    try:
        from ..middleware import (
            get_retention_scheduler,
            set_retention_scheduler,
        )
    except ImportError:
        pass
    else:
        _stop_background_worker(
            "analytics retention scheduler",
            get_retention_scheduler,
            set_retention_scheduler,
        )

    path = _os.environ.get("TAU_RAG_SNAPSHOT_PATH")
    if not path:
        return
    try:
        summary = _pipeline.save_snapshot(path)
    except Exception as _e:
        print(f"[tau-rag] shutdown snapshot failed: {_e}")
        return
    print(f"[tau-rag] shutdown snapshot saved: {summary}")
    # Audit separately so an audit failure is not misreported as a failed save.
    try:
        get_obs().audit("snapshot.auto_save_on_shutdown", **summary)
    except Exception as _e:
        print(f"[tau-rag] shutdown snapshot audit failed: {_e}")


# ------------------------------------------------------------------ schemas
class SearchRequest(BaseModel):
    """Body of a search request."""

    query: str
    k: int = 10
    rerank_k: int = 5
    strategy: str = "hybrid"
    lang: str = "he"
    filters: Dict[str, Any] = {}


class DocumentRequest(BaseModel):
    """One document to index: id, raw text and free-form metadata."""

    id: str
    text: str
    metadata: Dict[str, Any] = {}


class DocumentsRequest(BaseModel):
    """Batch body for ``POST /v1/documents``."""

    documents: List[DocumentRequest]


# ------------------------------------------------------------------ routes
# Concrete response classes are needed at decoration time: FastAPI's OpenAPI
# generator asserts that every route has a response_class, and
# ``response_class=None`` trips that assert and breaks /docs.
from fastapi.responses import HTMLResponse, PlainTextResponse  # noqa: E402

_PLAYGROUND_HTML = """ TAU-RAG

🔎 TAU-RAG — Hebrew legal RAG

Pipeline alive at /v1/*. Swagger UI: /docs · ReDoc: /redoc

📄 Documents
❓ Single query
💬 Chat

Add documents


Ask a question


Chat with session memory

session
"""


@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the HTML playground page."""
    return HTMLResponse(_PLAYGROUND_HTML)


@app.get("/favicon.ico")
def favicon():
    # Silent 204 to stop the noisy browser 404
    from fastapi.responses import Response
    return Response(status_code=204)


@app.get("/health")
def health():
    """Trivial health check with the hard-coded API version."""
    return {"ok": True, "version": "2.0.0"}


@app.get("/v1/version")
def version_manifest():
    """Build + runtime version info.

    Unauthenticated, so anyone (deploy scripts, monitoring, teammates
    debugging) can check what's actually running. Does not expose secrets,
    only structural metadata. Note: shells out to ``git`` up to three times
    (1s timeout each) on every call.
    """
    import platform as _plat
    import sys as _sys
    import subprocess as _sp

    # Pipeline structure
    retr_multi = getattr(_pipeline, "retrievers", None)
    retriever_members = (
        sorted(getattr(retr_multi, "retrievers", {}).keys())
        if retr_multi is not None else []
    )
    cfg = _pipeline.config
    preset = _os.environ.get("TAU_RAG_PRESET", "unknown")

    # Build info — keep it safe to serialize
    try:
        import fastapi as _fastapi
        fastapi_v = getattr(_fastapi, "__version__", "unknown")
    except Exception:
        fastapi_v = "unknown"

    # Git metadata — optional; silent fallback if not in a git checkout
    git_info: Dict[str, Any] = {}
    try:
        commit = _sp.check_output(
            ["git", "rev-parse", "HEAD"],
            stderr=_sp.DEVNULL, timeout=1).decode().strip()
        branch = _sp.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            stderr=_sp.DEVNULL, timeout=1).decode().strip()
        dirty = _sp.check_output(
            ["git", "status", "--porcelain"],
            stderr=_sp.DEVNULL, timeout=1).decode().strip()
        git_info = {
            "commit": commit,
            "commit_short": commit[:8],
            "branch": branch,
            "dirty": bool(dirty),
        }
    except Exception:
        git_info = {"available": False}

    # Enabled feature flags (from env) — helps debug "is it off/on in prod?"
    features = {
        "auth_required": _os.environ.get("TAU_RAG_REQUIRE_AUTH") == "1",
        "auto_warmup": _os.environ.get("TAU_RAG_WARMUP") == "1",
        "snapshot_path": bool(_os.environ.get("TAU_RAG_SNAPSHOT_PATH")),
        "snapshot_interval": float(
            _os.environ.get("TAU_RAG_SNAPSHOT_INTERVAL") or 0) or None,
        "synonyms_path": bool(_os.environ.get("TAU_RAG_SYNONYMS_PATH")),
        "hsts": _os.environ.get("TAU_RAG_HSTS") == "1",
        "csp": bool(_os.environ.get("TAU_RAG_CSP")),
        "cors_origins": int(bool(_os.environ.get("TAU_RAG_CORS_ORIGINS"))),
        "log_stdout": _os.environ.get("TAU_RAG_LOG_STDOUT") == "1",
        "log_file": bool(_os.environ.get("TAU_RAG_LOG_PATH")),
        "audit_webhook": bool(_os.environ.get("TAU_RAG_AUDIT_WEBHOOK_URL")),
        "endpoint_rate_limits": bool(
            _os.environ.get("TAU_RAG_ENDPOINT_RATE_LIMITS")),
        "audit_export": True,
        "log_stream": True,
        "key_rotation": True,
        "snapshot_diff": True,
        "metrics_history": bool(
            _os.environ.get("TAU_RAG_METRICS_HISTORY_INTERVAL_SEC")),
        "webhook_circuit_breaker": True,
        "alert_rules": True,
        "alert_scheduler": bool(
            _os.environ.get("TAU_RAG_ALERT_EVAL_INTERVAL_SEC")),
        "doc_stats": True,
        "retriever_attribution": True,
        "cocitation_graph": True,
        "content_health": True,
        "content_health_ui": True,
        "eval_latency_gate": True,
        "content_health_history": True,
        "query_fingerprints": True,
        "preset_promote_candidates": True,
        "preset_auto_promote": True,
        "analytics_retention": True,
        "analytics_retention_scheduler": bool(
            _os.environ.get("TAU_RAG_ANALYTICS_TTL_DAYS")),
        "doc_freshness": True,
        "doc_update_priorities": True,
        "query_doc_affinity": True,
        "analytics_dump_restore": True,
        "query_analytics_ui": True,
        "query_replay": True,
        "replay_body_capture": (
            _os.environ.get("TAU_RAG_OBS_CAPTURE_BODY") == "1"),
        "v2_stable_api": True,
        "about_endpoint": True,
        "semantic_cache": (
            _os.environ.get("TAU_RAG_SEMANTIC_CACHE") == "1"),
        "graph_cocitation_boost": (
            float(_os.environ.get("TAU_RAG_GRAPH_COCITATION_BOOST") or 0.0)
            if _os.environ.get("TAU_RAG_GRAPH_COCITATION_BOOST") else False),
        "query_doc_boost": (
            float(_os.environ.get("TAU_RAG_QUERY_DOC_BOOST") or 0.0)
            if _os.environ.get("TAU_RAG_QUERY_DOC_BOOST") else False),
        "request_spans": True,
        "span_timeline_ui": True,
        "limiter_backend_protocol": True,
        "maintenance_mode": True,
        "pii_redaction": True,
        # Reflects the env var only, not a runtime toggle.
        "pii_redaction_enabled": (
            _os.environ.get("TAU_RAG_PII_REDACT") == "1"),
        "slow_query_detection": True,
        "readiness_registry": True,
        "daily_quota": True,
        "idempotency_key": True,
        "request_timeout": True,
        "log_rotation": True,
        "span_exporter_protocol": True,
        "span_exporter_type": type(
            __import__("tau_rag.observability.span_exporters",
                       fromlist=["get_span_exporter"])
            .get_span_exporter()).__name__,
    }

    return {
        "version": "2.0.0",
        "preset": preset,
        "pipeline": {
            "retriever_members": retriever_members,
            "generator_provider": getattr(cfg.generation, "provider", "unknown"),
            "fusion_method": getattr(cfg.fusion, "method", "unknown"),
            "rerank_method": (getattr(cfg.rerank, "method", None)
                              if getattr(cfg, "rerank", None) else None),
            "verifier": type(_pipeline.verifier).__name__,
            "chunker": getattr(_pipeline, "_chunker_last", "fixed"),
        },
        "build": {
            "python": _sys.version.split()[0],
            "platform": _plat.platform(),
            "fastapi": fastapi_v,
        },
        "git": git_info,
        "features": features,
    }


# ------------------------------------------------------ ops-ready endpoints
from .metrics import render_prometheus, check_readiness  # noqa: E402


@app.get("/livez", response_class=PlainTextResponse)
def livez():
    """Liveness probe — 200 if the process can answer."""
    return PlainTextResponse("ok", status_code=200)


@app.get("/readyz")
def readyz(require_warmed: bool = False):
    """Readiness probe — 503 + detail if the pipeline isn't ready.

    Pass ``?require_warmed=1`` to also fail until ``POST /v1/admin/warmup``
    has been invoked (useful for deployment gating).

    v2.11 — also consults the pluggable readiness registry so plugins and
    operator-registered checks (Redis, S3, etc.) participate. A failing
    registry check flips /readyz to 503.
    """
    ok, detail = check_readiness(_pipeline, require_warmed=bool(require_warmed))
    # v2.11 — also check pluggable registry (ties v2.7 drain mode and any
    # operator-registered checks)
    from ..middleware.readiness import get_readiness_registry
    reg_result = get_readiness_registry().evaluate()
    if not ok or not reg_result["ready"]:
        # Merge detail from both sources
        body = {
            "detail": detail if not ok else None,
            "checks": reg_result["checks"],
            "n_passed": reg_result["n_passed"],
            "n_failed": reg_result["n_failed"],
        }
        raise HTTPException(status_code=503, detail=body)
    return {"ok": True, "detail": detail, "checks": reg_result["checks"]}


@app.get("/v1/admin/readiness")
def admin_readiness():
    """Full readiness report (v2.11).

    Always returns 200, unlike /readyz, so dashboards can show health
    without tripping k8s routing. Contains every registered check's current
    state plus an overall ``ready`` bool.
    """
    from ..middleware.readiness import get_readiness_registry
    return get_readiness_registry().evaluate()


@app.post("/v1/admin/warmup")
def admin_warmup(request: Request):
    """Pre-load heavy components (embedders, tokenizers, adapters).

    Sets the ``pipeline._warmed`` flag so ``/readyz?require_warmed=1``
    starts passing. Any failure is returned as a 500 with the (truncated)
    error text.
    """
    import time as _t
    t0 = _t.time()
    try:
        fn = getattr(_pipeline, "warmup", None)
        if callable(fn):
            fn()
        _pipeline._warmed = True  # type: ignore[attr-defined]
        elapsed = round((_t.time() - t0) * 1000.0, 2)
        # NOTE(review): the raw X-API-Key is passed as actor_key; confirm
        # audit() hashes it before persisting.
        get_obs().audit(
            "pipeline.warmup",
            actor_key=request.headers.get("x-api-key"),
            request_id=getattr(request.state, "request_id", None),
            elapsed_ms=elapsed,
        )
        return {"warmed": True, "elapsed_ms": elapsed}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"warmup failed: {type(e).__name__}: {e}"[:200],
        ) from e


@app.get("/metrics", response_class=PlainTextResponse)
def metrics():
    """Prometheus exposition format — scrape me every 15s."""
    keys = get_auth().list_keys()
    revoked = sum(1 for k in keys if k.get("revoked"))
    active = len(keys) - revoked
    body = render_prometheus(
        obs_stats=get_obs().stats(),
        cache_stats=get_cache().stats(),
        limiter_stats=get_limiter().stats(),
        auth_keys=active,
        auth_keys_revoked=revoked,
        version="2.0.0",
    )
    return PlainTextResponse(body, media_type="text/plain; version=0.0.4")


@app.post("/v1/documents")
def add_documents(req: DocumentsRequest):
    """Validate and index a batch of documents.

    Oversized batches (``OverflowError`` from ``validate_doc_list``) map to
    413, other validation failures (``ValueError``) to 422.
    """
    try:
        validate_doc_list(req.documents)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    docs = [Document(id=d.id, text=d.text, metadata=d.metadata)
            for d in req.documents]
    n = _pipeline.add_documents(docs)
    return {"added_chunks": n, "documents": len(docs)}


# ---- document lifecycle endpoints ----------------------------------------
class DocumentBody(BaseModel):
    """A single document: id, raw text and free-form metadata."""

    id: str
    text: str
    metadata: Dict[str, Any] = {}


def _metadata_filter_from(request: Request) -> Dict[str, str]:
    """Collect ``metadata.<key>=<value>`` query params into ``{key: value}``.

    If the same key is repeated, the last value wins.
    """
    prefix = "metadata."
    return {
        k[len(prefix):]: v
        for k, v in request.query_params.multi_items()
        if k.startswith(prefix)
    }


@app.get("/v1/documents")
def list_documents(
    request: Request,
    q: Optional[str] = None,
    limit: int = 50,
    offset: int = 0,
    preview_chars: int = 200,
):
    """List or search indexed documents.

    Query params:
      * ``q`` — substring (case-insensitive) over text + id
      * ``limit`` / ``offset`` — pagination (default 50 / 0). A limit
        outside 1..10000 is reset to 50; a negative offset to 0.
      * ``preview_chars`` — first N chars returned per doc (default 200)
      * ``metadata.<key>=<value>`` — filter on flat metadata keys (repeatable)

    Backward-compat: with no query params, returns the v1.38 summary shape
    ``{documents: [...], count: N}``.
    """
    meta_filter = _metadata_filter_from(request)
    if limit < 1 or limit > 10_000:
        limit = 50
    if offset < 0:
        offset = 0
    result = _pipeline.search_documents(
        q=q,
        metadata=meta_filter or None,
        limit=limit,
        offset=offset,
        preview_chars=preview_chars,
    )
    # Back-compat: keep `count` key that v1.38 clients expect
    result["count"] = result["matched"]
    return result


@app.get("/v1/documents/export", response_class=PlainTextResponse)
def export_documents(request: Request):
    """Export the full indexed corpus as JSONL (one ``{id,text,metadata}``
    per line).

    Supports the same filters as ``GET /v1/documents``:
      * ``?q=<text>`` — case-insensitive substring over text + id
      * ``?metadata.<key>=<value>`` — flat metadata filter (repeatable);
        values are compared as strings

    Returns ``application/x-ndjson`` with a download filename so browsers
    save it as ``tau-rag-documents.jsonl``, plus ``X-Document-Count``.
    """
    import json as _json

    q = request.query_params.get("q")
    meta_filter = _metadata_filter_from(request)

    # Iterate through the full doc-log, applying the same filters as
    # search_documents() but without the limit cap — export is all-or-none.
    _pipeline._ensure_doc_log()
    qn = (q or "").strip().lower()
    lines: List[str] = []
    for d in _pipeline._indexed_docs:
        if qn:
            hay = (d.text or "").lower()
            if qn not in hay and qn not in (d.id or "").lower():
                continue
        if meta_filter:
            meta = d.metadata or {}
            if not all(str(meta.get(mk)) == str(mv)
                       for mk, mv in meta_filter.items()):
                continue
        lines.append(_json.dumps({
            "id": d.id,
            "text": d.text,
            "metadata": d.metadata or {},
        }, ensure_ascii=False))
    body = "\n".join(lines) + ("\n" if lines else "")
    return PlainTextResponse(
        body,
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition": 'attachment; filename="tau-rag-documents.jsonl"',
            "X-Document-Count": str(len(lines)),
        },
    )


@app.get("/v1/documents/stats")
def index_stats():
    """Corpus-level stats: doc count, text-length distribution, metadata
    value histogram, metadata coverage, and retriever set.

    Safe on large corpora (no full-text scan)."""
    return _pipeline.index_stats()


@app.get("/v1/documents/duplicates")
def admin_duplicates():
    """Scan the index for documents sharing normalized content (collapsed
    whitespace, case-folded, sha256'd).

    Returns ``{groups: [{hash, members}], n_groups, n_duplicate_docs,
    total_docs}``, with groups sorted largest first.

    Declared *before* ``/v1/documents/{doc_id}`` so FastAPI matches
    ``/duplicates`` as a fixed path, not a doc id."""
    groups = _pipeline.find_duplicates()
    pretty = [
        {"hash": h, "members": members}
        for h, members in sorted(groups.items(), key=lambda kv: -len(kv[1]))
    ]
    n_dup_docs = sum(len(g["members"]) for g in pretty)
    total = len(_pipeline.list_documents())
    return {
        "n_groups": len(pretty),
        "n_duplicate_docs": n_dup_docs,
        "total_docs": total,
        "groups": pretty,
    }


# ---- Per-document citation stats (v1.82) --------------------------------
# Placed BEFORE /v1/documents/{doc_id} so FastAPI matches these fixed
# paths first — same ordering trick as /duplicates above.
# NOTE(review): these routes live under /v1/admin/documents/..., so they
# cannot collide with /v1/documents/{doc_id}; ordering only matters if a
# /v1/admin/documents/{...} route exists — confirm.
@app.get("/v1/admin/documents/stats/summary")
def admin_docs_stats_summary():
    """Per-document stats rollup (v1.82).

    Returns the doc-stats store summary: n_docs tracked, total
    retrieved/cited, global cite_rate, persistence path.
    """
    from ..middleware import get_doc_stats

    doc_store = get_doc_stats()
    return doc_store.summary()


@app.get("/v1/admin/documents/stats/top-cited")
def admin_docs_stats_top_cited(n: int = 10):
    """Most-cited documents, highest count first (v1.82).

    Each row carries ``{doc_id, n_retrieved, n_cited, cite_rate,
    first_seen_at, ...}``; at most ``n`` rows are returned.
    """
    from ..middleware import get_doc_stats

    limit = int(n)
    rows = get_doc_stats().top_cited(n=limit)
    return {"top": rows}


@app.get("/v1/admin/documents/stats/unused")
def admin_docs_stats_unused(
    min_retrieved: int = 1,
    max_cite_rate: float = 0.0,
):
    """Documents retrieved often but rarely (or never) cited (v1.82).

    A doc qualifies when it was retrieved at least ``min_retrieved``
    times and its cite_rate is at or below ``max_cite_rate`` (default
    0 -> never cited). Handy for spotting retrieval false-positives.
    """
    from ..middleware import get_doc_stats

    threshold_retrieved = int(min_retrieved)
    threshold_rate = float(max_cite_rate)
    found = get_doc_stats().unused(
        min_retrieved=threshold_retrieved,
        max_cite_rate=threshold_rate,
    )
    return {"unused": found}


@app.post("/v1/admin/documents/stats/reset")
def admin_docs_stats_reset(request: Request):
    """Clear every per-doc counter and emit an audit event (v1.82).

    The pre-reset summary is returned under ``before``.
    """
    from ..middleware import get_doc_stats

    doc_store = get_doc_stats()
    snapshot = doc_store.summary()
    doc_store.clear()
    caller_key = request.headers.get("x-api-key")
    rid = getattr(request.state, "request_id", None)
    get_obs().audit(
        "doc.stats.reset",
        actor_key=caller_key,
        request_id=rid,
        prev_n_docs=snapshot["n_docs"],
        prev_total_cited=snapshot["total_cited"],
    )
    return {"reset": True, "before": snapshot}


# ---- Per-retriever attribution (v1.83) ----------------------------------

@app.get("/v1/admin/retrievers/stats")
def admin_retriever_stats():
    """Store summary plus per-retriever rows, ranked by
    n_cited_contributions (v1.83)."""
    from ..middleware import get_retriever_attribution

    attribution = get_retriever_attribution()
    summary = attribution.summary()
    stats = attribution.all_stats()
    return {"summary": summary, "stats": stats}


@app.get("/v1/admin/retrievers/stats/ranking")
def admin_retriever_ranking():
    """Retrievers ordered by cite_rate × log(1 + n_contributions) (v1.83).

    The log term weights precision by sample size so a rarely-used but
    perfect retriever does not outrank a high-volume workhorse.
    """
    from ..middleware import get_retriever_attribution

    ranked = get_retriever_attribution().ranking()
    return {"ranking": ranked}


@app.post("/v1/admin/retrievers/stats/reset")
def admin_retriever_stats_reset(request: Request):
    """Clear per-retriever counters and emit an audit event (v1.83).

    The pre-reset summary is returned under ``before``.
    """
    from ..middleware import get_retriever_attribution

    attribution = get_retriever_attribution()
    snapshot = attribution.summary()
    attribution.clear()
    caller_key = request.headers.get("x-api-key")
    rid = getattr(request.state, "request_id", None)
    get_obs().audit(
        "retriever.stats.reset",
        actor_key=caller_key,
        request_id=rid,
        prev_n_retrievers=snapshot["n_retrievers"],
        prev_total_cited=snapshot["total_cited"],
    )
    return {"reset": True, "before": snapshot}


# ---- Co-citation graph (v1.84) ------------------------------------------

@app.get("/v1/admin/documents/cocitation/summary")
def admin_cocitation_summary():
    """Co-citation graph rollup (v1.84): n_events (responses citing at
    least two docs), n_pairs, n_docs, total_count."""
    from ..middleware import get_cocitation

    graph = get_cocitation()
    return graph.summary()
@app.get("/v1/admin/documents/cocitation/top")
def admin_cocitation_top(n: int = 20):
    """Top ``n`` most-common co-citation pairs (v1.84)."""
    from ..middleware import get_cocitation
    return {"top": get_cocitation().top_pairs(n=int(n))}


@app.post("/v1/admin/documents/cocitation/reset")
def admin_cocitation_reset(request: Request):
    """Wipe the co-citation graph and emit an audit event (v1.84).

    The pre-reset summary is returned under ``before``.
    """
    from ..middleware import get_cocitation
    store = get_cocitation()
    before = store.summary()
    store.clear()
    get_obs().audit(
        "cocitation.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_pairs=before["n_pairs"],
        prev_total_count=before["total_count"],
    )
    return {"reset": True, "before": before}


@app.get("/v1/admin/content/health")
def admin_content_health(
    top_n: int = 5,
    unused_min_retrieved: int = 3,
):
    """Consolidated corpus health report (v1.85).

    Merges v1.82 doc stats, v1.83 retriever attribution and v1.84
    co-citation into a single answer. Cross-cutting insights:
    ``dead_docs`` (indexed but never tracked by doc stats) and
    ``isolated_docs`` (tracked but never part of a co-citation pair).

    Query params:

    * ``top_n`` — rows to include in top-cited / top-pairs / ranking
      sections (default 5).
    * ``unused_min_retrieved`` — passed to doc_stats.unused() to filter
      retrieval false-positives.
    """
    from ..middleware import (
        get_doc_stats, get_retriever_attribution, get_cocitation,
    )

    doc_store = get_doc_stats()
    ra_store = get_retriever_attribution()
    cc_store = get_cocitation()
    doc_summary = doc_store.summary()
    ra_summary = ra_store.summary()
    cc_summary = cc_store.summary()

    # Corpus set — every indexed doc id.
    all_docs = {d["id"] for d in _pipeline.list_documents()}
    # Every doc the stats store tracks, regardless of count. This may
    # include docs that have since been removed from the index.
    touched_docs = {d["doc_id"] for d in doc_store.top_cited(n=10 ** 9)}

    # Dead docs: indexed but never tracked by the doc-stats store.
    dead = sorted(all_docs - touched_docs)

    # Isolated docs: tracked, but never co-cited with anyone (they never
    # appeared in a response with ≥2 sources).
    partnered = set()
    for pair in cc_store.top_pairs(n=10 ** 9):
        partnered.add(pair["a"])
        partnered.add(pair["b"])
    isolated = sorted(touched_docs - partnered)

    # Derived corpus-level health score.
    #   * coverage     = fraction of indexed docs that were ever retrieved
    #   * cite_rate    = global doc-level cite rate
    #   * connectivity = fraction of touched docs that are non-isolated
    n_all = max(1, len(all_docs))
    n_touched = len(touched_docs)
    # Intersect with the index: stale tracked ids (deleted docs) must not
    # count, otherwise coverage can exceed 1.0 and inflate the score.
    coverage = len(touched_docs & all_docs) / n_all
    cite_rate = doc_summary.get("global_cite_rate", 0.0)
    connectivity = (
        (n_touched - len(isolated)) / max(1, n_touched) if n_touched else 0.0
    )
    # Equal-weight geometric mean — any dimension collapsing to 0 drags
    # the whole score down. Gives operators a single knob to watch.
    if coverage > 0 and cite_rate > 0 and connectivity > 0:
        score = (coverage * cite_rate * connectivity) ** (1 / 3)
    else:
        score = 0.0

    return {
        "score": round(score, 4),
        "coverage": round(coverage, 4),
        "cite_rate": round(cite_rate, 4),
        "connectivity": round(connectivity, 4),
        "corpus": {
            "n_indexed": len(all_docs),
            "n_touched": n_touched,
            "n_dead": len(dead),
            "n_isolated": len(isolated),
        },
        "top_cited": doc_store.top_cited(n=int(top_n)),
        "top_noisy": doc_store.unused(
            min_retrieved=int(unused_min_retrieved),
            max_cite_rate=0.0,
        ),
        "retrievers": {
            "summary": ra_summary,
            "ranking": ra_store.ranking()[:int(top_n)],
        },
        "cocitation": {
            "summary": cc_summary,
            "top_pairs": cc_store.top_pairs(n=int(top_n)),
        },
        "dead_docs": dead,
        "isolated_docs": isolated,
    }


# ---- Query fingerprint analytics (v1.89) --------------------------------

@app.get("/v1/admin/queries/stats/summary")
def admin_query_stats_summary():
    """Rollup: n_unique_queries, n_events (total), avg_sources per
    query (v1.89)."""
    from ..middleware import get_query_stats
    return get_query_stats().summary()


@app.get("/v1/admin/queries/stats/top")
def admin_query_stats_top(n: int = 10):
    """Top ``n`` queries by frequency (v1.89)."""
    from ..middleware import get_query_stats
    return {"top": get_query_stats().top(n=int(n))}


@app.get("/v1/admin/queries/stats/recent")
def admin_query_stats_recent(
    since: Optional[float] = None,
    n: int = 10,
):
    """Queries whose last hit was ≥ ``since`` (Unix ts), newest first
    (v1.89). Omit ``since`` to get the most-recently-seen N regardless
    of age."""
    from ..middleware import get_query_stats
    return {"recent": get_query_stats().recent(since=since, n=int(n))}


@app.get("/v1/admin/queries/promote-candidates")
def admin_query_promote_candidates(
    min_count: int = 3,
    min_sources: float = 0.0,
    max_avg_latency_ms: Optional[float] = None,
    n: int = 20,
):
    """Return query fingerprints that are strong candidates for promotion
    to saved presets (v1.90).

    Heuristic: a query is a good preset candidate when it has been asked
    often enough to justify the saved-search slot, returns a useful
    number of sources on average, does not already have a preset, and
    (optionally) is not slower than some threshold.

    Query params:

    * ``min_count`` — minimum observed occurrences (default 3).
    * ``min_sources`` — minimum ``avg_sources`` per response. Filters
      out popular queries that find nothing useful (0 = no filter).
    * ``max_avg_latency_ms`` — optional ceiling on average latency.
      Omit to not filter.
    * ``n`` — cap on rows returned; ``n <= 0`` returns no rows.

    A candidate row is a ``QueryStats.to_dict()`` plus a derived
    ``suggested_preset_name`` that ops can accept as-is.
    """
    from ..middleware import get_query_stats
    from ..middleware.query_stats import _canonicalize
    from ..presets import get_preset_store

    query_store = get_query_stats()
    preset_store = get_preset_store()

    # Pre-index the presets by canonical query text so the per-candidate
    # "already a preset?" check is a set lookup.
    existing_canonical: set = set()
    for p in preset_store.list_all():
        existing_canonical.add(_canonicalize(p.get("query", "")))

    # Hoist the filter conversions out of the scan.
    limit = int(n)
    count_floor = int(min_count)
    sources_floor = float(min_sources)
    latency_ceiling = (
        float(max_avg_latency_ms) if max_avg_latency_ms is not None else None
    )

    candidates: List[Dict[str, Any]] = []
    for row in query_store.top(n=10 ** 9):
        # Check the cap before appending so n <= 0 yields no rows.
        if len(candidates) >= limit:
            break
        if row["count"] < count_floor:
            continue
        if row["avg_sources"] < sources_floor:
            continue
        if (latency_ceiling is not None
                and row["avg_latency_ms"] > latency_ceiling):
            continue
        if _canonicalize(row["sample"]) in existing_canonical:
            continue
        candidates.append({
            **row,
            # Short, slug-style preset name derived from the sample.
            "suggested_preset_name": _suggest_preset_name(row["sample"]),
            "already_preset": False,
        })

    return {
        "candidates": candidates,
        "n_candidates": len(candidates),
        "min_count": int(min_count),
        "min_sources": float(min_sources),
        "n_existing_presets": len(existing_canonical),
    }


class PresetPromoteRequest(BaseModel):
    # Explicit list of fingerprints to promote. Mutually exclusive with
    # the auto-mode filters below — if both are set, fingerprints take
    # precedence and the filters are ignored.
    fingerprints: Optional[List[str]] = None
    # Auto-mode: pick candidates via filters (same as v1.90 endpoint)
    min_count: int = 3
    min_sources: float = 0.0
    max_avg_latency_ms: Optional[float] = None
    limit: int = 20
    # Common preset knobs — applied to every created preset
    k: int = 10
    rerank_k: int = 5
    strategy: str = "hybrid"
    lang: str = "he"
    # Naming
    name_prefix: str = ""  # optional prefix (e.g. "promoted-")
    # Safety
    dry_run: bool = False  # preview without creating


@app.post("/v1/admin/queries/promote")
def admin_queries_promote(req: PresetPromoteRequest, request: Request):
    """Auto-promote query-stats candidates to saved presets (v1.91).

    Two modes:

    * **Explicit**: pass ``fingerprints=[...]`` to promote specific
      queries by their v1.89 fingerprints.
    * **Filtered**: pass the same filter params as
      ``/v1/admin/queries/promote-candidates`` (v1.90) and we promote
      the top ``limit`` that match (``limit <= 0`` selects none).

    In both modes we skip queries whose canonical text already has a
    preset, and we deduplicate name collisions by appending a ``-2``,
    ``-3``, ... suffix. ``dry_run=True`` returns the planned actions
    without touching the preset store.

    Returns::

        {
          created: [{name, query, fingerprint, count}, ...],
          skipped: [{fingerprint, reason}, ...],
          dry_run: bool,
          n_created: int,
          n_skipped: int,
        }

    One ``preset.auto_promoted`` audit event is emitted per preset
    actually written (none in dry-run mode).
    """
    from ..middleware import get_query_stats
    from ..middleware.query_stats import _canonicalize
    from ..presets import get_preset_store, QueryPreset

    query_store = get_query_stats()
    preset_store = get_preset_store()

    # Pre-index existing presets by canonical query AND by name so we can
    # skip duplicates in both dimensions.
    existing_canonical: set = set()
    existing_names: set = set()
    for p in preset_store.list_all():
        existing_canonical.add(_canonicalize(p.get("query", "")))
        existing_names.add(p.get("name", ""))

    # ---- pick candidates
    rows: List[Dict[str, Any]] = []
    if req.fingerprints is not None:
        # Explicit mode: unknown fingerprints are carried through with a
        # marker so they show up in `skipped`.
        for fp in req.fingerprints:
            s = query_store.get(fp)
            if s is None:
                rows.append({"fingerprint": fp, "_missing": True})
            else:
                rows.append(s.to_dict())
    else:
        # Filter mode — mirrors the v1.90 candidate logic.
        limit = int(req.limit)
        count_floor = int(req.min_count)
        sources_floor = float(req.min_sources)
        latency_ceiling = (
            float(req.max_avg_latency_ms)
            if req.max_avg_latency_ms is not None else None
        )
        for row in query_store.top(n=10 ** 9):
            # Check the cap before appending so limit <= 0 selects none.
            if len(rows) >= limit:
                break
            if row["count"] < count_floor:
                continue
            if row["avg_sources"] < sources_floor:
                continue
            if (latency_ceiling is not None
                    and row["avg_latency_ms"] > latency_ceiling):
                continue
            if _canonicalize(row["sample"]) in existing_canonical:
                continue
            rows.append(row)

    # ---- plan
    created: List[Dict[str, Any]] = []
    skipped: List[Dict[str, Any]] = []
    used_names = set(existing_names)
    for row in rows:
        fp = row.get("fingerprint", "")
        if row.get("_missing"):
            skipped.append({"fingerprint": fp,
                            "reason": "fingerprint not found in query_stats"})
            continue
        sample = row.get("sample", "")
        canonical = _canonicalize(sample)
        if not canonical:
            skipped.append({"fingerprint": fp,
                            "reason": "empty query after canonicalization"})
            continue
        if canonical in existing_canonical:
            # Also catches duplicate fingerprints within this request,
            # since each promoted canonical is added below.
            skipped.append({"fingerprint": fp,
                            "reason": "already a preset (same canonical)"})
            continue

        base_name = _suggest_preset_name(sample)
        if req.name_prefix:
            base_name = f"{req.name_prefix}{base_name}"
        # Dedupe against already-used names.
        name = base_name
        suffix = 2
        while name in used_names:
            name = f"{base_name}-{suffix}"
            suffix += 1

        if not req.dry_run:
            try:
                preset_store.put(QueryPreset(
                    name=name,
                    query=sample,
                    k=int(req.k),
                    rerank_k=int(req.rerank_k),
                    strategy=req.strategy,
                    lang=req.lang,
                    notes=f"auto-promoted from traffic (fp={fp}, "
                          f"count={row.get('count')})",
                ))
            except Exception as e:
                # Best-effort per preset: report the failure, keep going.
                skipped.append({"fingerprint": fp,
                                "reason": f"put failed: "
                                          f"{type(e).__name__}: {e}"})
                continue
            get_obs().audit(
                "preset.auto_promoted",
                actor_key=request.headers.get("x-api-key"),
                request_id=getattr(request.state, "request_id", None),
                name=name,
                fingerprint=fp,
                count=row.get("count"),
                avg_sources=row.get("avg_sources"),
            )

        used_names.add(name)
        existing_canonical.add(canonical)
        created.append({
            "name": name,
            "query": sample,
            "fingerprint": fp,
            "count": row.get("count"),
        })

    return {
        "created": created,
        "skipped": skipped,
        "n_created": len(created),
        "n_skipped": len(skipped),
        "dry_run": bool(req.dry_run),
    }


def _suggest_preset_name(sample: str) -> str:
    """Turn a raw user query into a slug-style preset name.

    Lowercases, drops common punctuation, joins words with single
    hyphens, trims leading/trailing hyphens and caps the length at 48
    characters. Non-ASCII letters (e.g. Hebrew) are kept as-is so the
    result stays recognizable. Falls back to ``"preset"`` when nothing
    usable is left.
    """
    import re
    s = (sample or "").lower()
    # Drop typical punctuation first, so whitespace left behind by it is
    # trimmed/collapsed below instead of becoming stray hyphens.
    s = re.sub(r"[\?\!\.\,\:\;\(\)\[\]\{\}\"'`]", "", s)
    s = re.sub(r"\s+", "-", s.strip())
    # "a - b" would otherwise become "a---b"; also trim edge hyphens.
    s = re.sub(r"-{2,}", "-", s).strip("-")
    # Clip to 48 chars — leaves room for a namespace prefix.
    if len(s) > 48:
        s = s[:48].rstrip("-")
    return s or "preset"


@app.post("/v1/admin/queries/stats/reset")
def admin_query_stats_reset(request: Request):
    """Wipe the query fingerprint store and emit an audit event (v1.89).

    The pre-reset summary is returned under ``before``.
    """
    from ..middleware import get_query_stats
    store = get_query_stats()
    before = store.summary()
    store.clear()
    get_obs().audit(
        "query.stats.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        prev_n_unique=before["n_unique_queries"],
        prev_n_events=before["n_events"],
    )
    return {"reset": True, "before": before}


class AnalyticsPruneRequest(BaseModel):
    # TTL: one of these is required (seconds wins when both are set).
    older_than_days: Optional[float] = None
    older_than_seconds: Optional[float] = None
    # Which stores to prune. Set a flag to False to skip that store.
    doc_stats: bool = True
    retriever_attribution: bool = True
    cocitation: bool = True
    query_stats: bool = True


@app.get("/v1/admin/analytics/prune/scheduler")
def admin_retention_scheduler_status():
    """Report the background retention scheduler state (v1.93)."""
    from ..middleware import get_retention_scheduler
    sched = get_retention_scheduler()
    if sched is None:
        return {"enabled": False, "is_running": False}
    return {"enabled": True, **sched.status()}


@app.post("/v1/admin/analytics/prune")
def admin_analytics_prune(
    req: AnalyticsPruneRequest,
    request: Request,
):
    """Prune stale entries across the analytics stores (v1.92).

    Each enabled store's ``prune(ttl_seconds)`` is called. Pass
    ``older_than_days`` OR ``older_than_seconds``; one is required
    (seconds wins when both are given). Per-store flags let ops target
    a subset (e.g. prune only ``query_stats`` while keeping doc history).

    Returns per-store ``{n_removed, n_remaining_after}`` and emits one
    ``analytics.prune`` audit event.

    Raises:
        HTTPException(400): no TTL given, or TTL not a positive number.
    """
    from ..middleware import (
        get_doc_stats, get_retriever_attribution,
        get_cocitation, get_query_stats,
    )

    # Resolve TTL seconds
    if req.older_than_seconds is not None:
        ttl_s = float(req.older_than_seconds)
    elif req.older_than_days is not None:
        ttl_s = float(req.older_than_days) * 86400.0
    else:
        raise HTTPException(
            status_code=400,
            detail={"error": "either older_than_days or "
                             "older_than_seconds is required"},
        )
    # `not > 0` (rather than `<= 0`) also rejects NaN.
    if not ttl_s > 0:
        raise HTTPException(
            status_code=400,
            detail={"error": "TTL must be positive"},
        )

    # (result key, enabled flag, store getter, summary key for the
    #  remaining-row count) — order fixes the order of `per_store`.
    plan = (
        ("doc_stats", req.doc_stats, get_doc_stats, "n_docs"),
        ("retriever_attribution", req.retriever_attribution,
         get_retriever_attribution, "n_retrievers"),
        ("cocitation", req.cocitation, get_cocitation, "n_pairs"),
        ("query_stats", req.query_stats, get_query_stats,
         "n_unique_queries"),
    )

    results: Dict[str, Dict[str, int]] = {}
    total_removed = 0
    for key, enabled, get_store, remaining_key in plan:
        if not enabled:
            continue
        store = get_store()
        n = store.prune(ttl_s)
        results[key] = {
            "n_removed": n,
            "n_remaining_after": store.summary()[remaining_key],
        }
        total_removed += n

    get_obs().audit(
        "analytics.prune",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        older_than_seconds=ttl_s,
        total_removed=total_removed,
    )
    return {
        "ttl_seconds": ttl_s,
        "total_removed": total_removed,
        "per_store": results,
    }


@app.get("/v1/about")
def about():
    """Public architectural overview (v2.0).

    Non-admin — no key required. Useful for clients, docs generators,
    CI checks. Returns a static description of the tau-rag layers and
    points at the relevant primitives.
    """
    return {
        "name": "tau-rag",
        "version": "2.0.0",
        "tagline": "Unified Hebrew-legal RAG with structure-preserving "
                   "verification + TAU-Ω signals",
        "layers": {
            "retrieval": {
                "retrievers": ["bm25", "gematria", "hilbert", "graph"],
                "fusion": "rank-based or weighted",
                "rerank": "optional cross-encoder or score-based",
                "chunker": "fixed | sentence | legal_hebrew",
            },
            "observability_stack": {
                "push": "webhook + breaker (v1.71/79)",
                "batch": "/v1/admin/audit/export (v1.74)",
                "pull_stream": "/v1/admin/logs/stream SSE (v1.75)",
                "history": "metrics + content health (v1.78/88)",
                "alerts": "rules + scheduler (v1.80/81)",
            },
            "content_analytics": {
                "doc_stats": "v1.82",
                "retriever_attribution": "v1.83",
                "cocitation": "v1.84",
                "query_stats": "v1.89",
                "doc_freshness": "v1.94",
                "query_doc_affinity": "v1.96",
            },
            "analytics_cross_cuts": {
                "content_health": "v1.85/86",
                "update_priorities": "v1.95",
                "query_analytics_ui": "v1.98",
                "dump_restore": "v1.97",
            },
            "debugging": {
                "request_ids": "X-Request-ID on every response",
                "replay": "v1.99 re-execute by request_id",
            },
        },
        "patterns": {
            "side_channel_stores": "singleton+inject pattern; pipeline hook silent-fail; "
                                   "admin CRUD; persistence opt-in",
            "daemons": "AutoSnapshotter / MetricsHistorySampler / "
                       "AlertScheduler / AnalyticsRetentionScheduler — "
                       "start/stop/is_running/status + Event.wait + silent-fail",
            "quiet_on_zero": "schedulers emit audits only on state change",
            "html_dashboards": "inline CSS, zero JS, zero CDN, escape-safe, "
                               "meta-refresh for wall screens",
        },
        "stability": {
            "api_stability": "v2.0 marks /v1/* as stable — additive "
                             "changes only; breaking changes → /v2/*",
            "deprecation_policy": "6-month notice; features.* flags "
                                  "track active capabilities",
        },
        "counts": {
            "endpoints": "80+",
            "tests": "1096+",
            "side_channels": 6,
            "daemons": 4,
            "html_dashboards": 4,
        },
    }


# NOTE(review): response_class=None may trip FastAPI's OpenAPI schema
# generation (it expects a response class) — confirm /openapi.json works.
@app.get("/v1/admin/requests/{request_id}/spans/ui", response_class=None)
def admin_request_spans_ui(request_id: str, refresh: int = 0):
    """HTML timeline view of a request's spans (v2.5).

    Renders the v2.4 span data as a gantt-style bar chart for quick
    operator inspection. Same design language as the v1.86 / v1.98
    dashboards. ``refresh`` > 0 is passed through as the page's
    auto-refresh interval in seconds.
    """
    from fastapi.responses import HTMLResponse
    from .span_timeline_ui import render_span_timeline
    # Reuse the JSON endpoint's data gathering (raises 404 when the
    # request id has no spans).
    data = admin_request_spans(request_id)
    html = render_span_timeline(
        request_id=data["request_id"],
        n_spans=data["n_spans"],
        total_ms=data["total_ms"],
        spans=data["spans"],
        refresh_sec=int(refresh or 0),
    )
    return HTMLResponse(html)


@app.get("/v1/admin/requests/{request_id}/spans")
def admin_request_spans(request_id: str):
    """Return in-memory trace spans for a specific request_id (v2.4).

    Pipeline stages (understand, retrieve, fuse, rerank, generate,
    verify, ...) each open a span; middleware tags every span with the
    current request_id. This endpoint pulls them back by that id.

    Returns::

        {
          "request_id": str,
          "n_spans": int,
          "total_ms": float (longest root-span duration),
          "spans": [{name, trace_id, span_id, parent_id,
                     duration_ms, attrs}, ...],
        }

    Raises HTTPException(404) when no spans are held for the id.

    Useful for:
      * seeing where time went inside a slow request
      * correlating a user complaint with what actually executed
      * diagnosing retriever-specific failures per query
    """
    from ..observability.tracing import get_tracer
    spans = get_tracer().spans_for_request_id(request_id)
    if not spans:
        raise HTTPException(
            status_code=404,
            detail={"error": "no spans found for request_id",
                    "hint": "spans are in-memory — oldest get evicted "
                            "past the 5000-span cap"},
        )
    out = []
    root_total_ms = 0.0
    for s in spans:
        # Unfinished spans (no end_ts) report 0 ms.
        dur_ms = (s.end_ts - s.start_ts) * 1000.0 if s.end_ts else 0.0
        if s.parent_id is None and dur_ms > root_total_ms:
            root_total_ms = dur_ms
        out.append({
            "name": s.name,
            "trace_id": s.trace_id,
            "span_id": s.span_id,
            "parent_id": s.parent_id,
            "duration_ms": round(dur_ms, 2),
            "attrs": s.attrs,
        })
    return {
        "request_id": request_id,
        "n_spans": len(out),
        "total_ms": round(root_total_ms, 2),
        "spans": out,
    }
@app.post("/v1/admin/replay/{request_id}") def admin_replay(request_id: str, request: Request): """Re-execute a previously-logged request against the current pipeline (v1.99). Requires body capture to have been on at the time of the original request (``TAU_RAG_OBS_CAPTURE_BODY=1``). Returns:: { "request_id": original id, "replay_request_id": new id, "path": /v1/search | /v1/generate | /v1/chat, "original_body": captured body (truncated to 4KB), "query": parsed query text, "replay": { "sources": [doc_id, ...], "answer": str|None, "passed": bool|None, "omega": float|None, "timing_ms": dict, }, "note": optional human-readable comparison hint. } Useful for: * regression debug — "did our new chunker break this query?" * eval gold augmentation — turn a real user query into a gold case. * postmortem analysis — replay after a bad deploy to prove harm. """ import json as _json # Find the original row in the obs ring buffer row = None for entry in reversed(get_obs().tail(n=10 ** 9, event_type="request")): if entry.get("request_id") == request_id: row = entry break if row is None: raise HTTPException( status_code=404, detail={"error": "request_id not found in obs log", "hint": "ensure the request was recorded — " "obs log is a ring buffer; older rows " "may have been evicted"}, ) body_txt = (row.get("extra") or {}).get("body") if not body_txt: raise HTTPException( status_code=400, detail={"error": "no captured body on this request", "hint": "set TAU_RAG_OBS_CAPTURE_BODY=1 and re-run " "the original request to enable replay"}, ) try: payload = _json.loads(body_txt) except Exception as e: raise HTTPException( status_code=400, detail={"error": "captured body is not valid JSON", "detail": f"{type(e).__name__}: {e}"}) path = row.get("path") or "" # Support the 3 replayable endpoints from ..core.types import Query, Strategy q_text = payload.get("query") if not q_text: raise HTTPException(status_code=400, detail={"error": "captured body has no 'query' field"}) strategy_name = 
(payload.get("strategy") or "hybrid").lower() try: strategy = Strategy(strategy_name) except Exception: strategy = Strategy.HYBRID q = Query( text=q_text, lang=payload.get("lang") or "he", filters=payload.get("filters") or {}, strategy=strategy, k=int(payload.get("k") or 10), rerank_k=int(payload.get("rerank_k") or 5), ) # Generate a fresh replay_request_id and mark this as a replay in # obs so the new run is traceable. replay_id = generate_request_id() resp = _pipeline.run(q) # Extract the interesting bits omega = None try: omega = float(resp.signals.omega) if resp.signals else None except Exception: pass verif = getattr(resp, "verification", None) get_obs().audit( "replay.executed", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), original_request_id=request_id, replay_request_id=replay_id, path=path, ) return { "request_id": request_id, "replay_request_id": replay_id, "path": path, "original_body": body_txt, "query": q_text, "replay": { "sources": list(resp.sources or []), "answer": resp.answer, "omega": omega, "passed": (bool(getattr(verif, "passed", False)) if verif else None), "timing_ms": dict(resp.timing_ms or {}), }, "note": ("compare 'replay.sources' to whatever the original " "response had — differences reveal pipeline drift " "since the original run"), } @app.get("/v1/admin/queries/ui", response_class=None) def admin_query_analytics_ui( top_n: int = 10, matrix_queries: int = 6, matrix_docs: int = 6, min_count: int = 3, refresh: int = 0, ): """HTML dashboard for query analytics (v1.98). Merges v1.89 (query_stats), v1.90 (promote candidates), and v1.96 (query × doc affinity) into one visual page. Same design language as the v1.86 content-health UI. Query params: * ``top_n`` — rows to show in 'top queries' (default 10). * ``matrix_queries`` — rows in the affinity heatmap (default 6). * ``matrix_docs`` — cols in the affinity heatmap (default 6). * ``min_count`` — promote-candidate threshold (default 3). 
* ``refresh`` — auto-refresh seconds (0 = off). """ from fastapi.responses import HTMLResponse from ..middleware import ( get_query_stats, get_query_doc_affinity, ) from ..middleware.query_stats import _canonicalize from ..presets import get_preset_store from .query_analytics_ui import render_query_analytics_ui q_store = get_query_stats() qda = get_query_doc_affinity() preset_store = get_preset_store() summary = q_store.summary() top_qs = q_store.top(n=int(top_n)) # Promote candidates — reuse same logic as v1.90 endpoint existing_canonical = set() for p in preset_store.list_all(): existing_canonical.add(_canonicalize(p.get("query", ""))) promote = [] for row in q_store.top(n=10 ** 9): if row["count"] < int(min_count): continue if row.get("avg_sources", 0.0) < 1.0: continue canon = _canonicalize(row["sample"]) if canon in existing_canonical: continue promote.append({ **row, "suggested_preset_name": _suggest_preset_name(row["sample"]), }) if len(promote) >= 10: break # Build affinity matrix grid matrix_q_rows = q_store.top(n=int(matrix_queries)) # Pick the top docs across the shown queries doc_votes: Dict[str, int] = {} for q in matrix_q_rows: for r in qda.top_docs_for_query(q["fingerprint"], n=10 ** 9): doc_votes[r["doc_id"]] = doc_votes.get(r["doc_id"], 0) + r["count"] top_doc_ids = [d for d, _ in sorted(doc_votes.items(), key=lambda kv: -kv[1])] top_doc_ids = top_doc_ids[:int(matrix_docs)] # Pre-compute the (fp, doc_id) → count map for the rendered subset matrix_pairs: Dict[tuple, int] = {} for q in matrix_q_rows: for r in qda.top_docs_for_query(q["fingerprint"], n=10 ** 9): if r["doc_id"] in top_doc_ids: matrix_pairs[(q["fingerprint"], r["doc_id"])] = r["count"] html = render_query_analytics_ui( summary=summary, top_queries=top_qs, promote_candidates=promote, matrix_queries=matrix_q_rows, matrix_docs=top_doc_ids, matrix_pairs=matrix_pairs, refresh_sec=int(refresh or 0), ) return HTMLResponse(html) @app.get("/v1/admin/content/health/ui", response_class=None) def 
admin_content_health_ui( top_n: int = 5, unused_min_retrieved: int = 3, refresh: int = 0, ): """HTML dashboard for the corpus health report (v1.86). Same data as ``/v1/admin/content/health`` (v1.85) but rendered as a self- contained page. ``?refresh=N`` opts into an HTML meta-refresh every N seconds — handy to leave open on a wall screen.""" from fastapi.responses import HTMLResponse from .content_health_ui import render_content_health_ui health = admin_content_health( top_n=top_n, unused_min_retrieved=unused_min_retrieved, ) html = render_content_health_ui( health, refresh_sec=int(refresh or 0), ) return HTMLResponse(html) # ---- Doc freshness tracking (v1.94) ------------------------------------- @app.get("/v1/admin/documents/freshness/summary") def admin_doc_freshness_summary(): """Rollup of doc freshness — n_docs, oldest/newest added_at, median age, total modifications (v1.94).""" from ..middleware import get_doc_freshness return get_doc_freshness().summary() @app.get("/v1/admin/documents/freshness/stale") def admin_doc_freshness_stale(older_than_days: float = 90.0): """Docs whose last activity (modified or added) is older than ``older_than_days`` (v1.94). Oldest-first ordering so content audit can start at the top.""" from ..middleware import get_doc_freshness return { "older_than_days": float(older_than_days), "stale": get_doc_freshness().stale( older_than_days=float(older_than_days)), } @app.post("/v1/admin/documents/freshness/reset") def admin_doc_freshness_reset(request: Request): """Wipe the freshness side-channel store + audit. 
Use after a large corpus reload when old timestamps are meaningless (v1.94).""" from ..middleware import get_doc_freshness store = get_doc_freshness() before = store.summary() store.clear() get_obs().audit( "doc.freshness.reset", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), prev_n_docs=before["n_docs"], ) return {"reset": True, "before": before} # ---- Query × doc affinity (v1.96) --------------------------------------- # ---- Unified analytics dump/restore (v1.97) ----------------------------- _ANALYTICS_DUMP_VERSION = 1 @app.get("/v1/admin/analytics/dump") def admin_analytics_dump(): """Single-call snapshot of all 6 side-channel analytics stores (v1.97). Returns JSON with one key per store — versioned payload for migration and offline analysis. Stores included: * v1.82 doc_stats * v1.83 retriever_attribution * v1.84 cocitation * v1.89 query_stats * v1.94 doc_freshness * v1.96 query_doc_affinity The ``version`` field signals the dump format — kept as a single integer so restore can refuse incompatible schemas. 
""" from ..middleware import ( get_doc_stats, get_retriever_attribution, get_cocitation, get_query_stats, get_doc_freshness, get_query_doc_affinity, ) # doc_stats: every tracked doc + raw counters doc_store = get_doc_stats() doc_rows = [] # Walk via top_cited(n=10**9) — it returns all rows for row in doc_store.top_cited(n=10 ** 9): doc_rows.append({ "doc_id": row["doc_id"], "n_retrieved": row.get("n_retrieved", 0), "n_cited": row.get("n_cited", 0), "first_seen_at": row.get("first_seen_at"), "last_retrieved_at": row.get("last_retrieved_at"), "last_cited_at": row.get("last_cited_at"), }) # retriever_attribution ra = get_retriever_attribution() ra_rows = [ {k: v for k, v in row.items() if k != "cite_rate"} for row in ra.all_stats() ] # cocitation cc = get_cocitation() cc_pairs = cc.top_pairs(n=10 ** 9) # query_stats qs = get_query_stats() qs_rows = [] for row in qs.top(n=10 ** 9): qs_rows.append({ "fingerprint": row["fingerprint"], "sample": row["sample"], "count": row["count"], "first_seen_at": row["first_seen_at"], "last_seen_at": row["last_seen_at"], "sum_sources": row["sum_sources"], "sum_latency_ms": row["sum_latency_ms"], }) # doc_freshness fs = get_doc_freshness() fs_rows = [] # No top() here, walk via data dict — use all stale(older_than_days=0) # which gives every row sorted oldest-first; take as-is. import time as _t now = _t.time() for row in fs.stale(older_than_days=0, now=now): fs_rows.append({ "doc_id": row["doc_id"], "added_at": row["added_at"], "last_modified_at": row.get("last_modified_at"), "n_modifications": row.get("n_modifications", 0), }) # query_doc_affinity qda = get_query_doc_affinity() qda_pairs: List[Dict[str, Any]] = [] qda_summary = qda.summary() # Use the inverted index: for each query fingerprint, enumerate # its docs. Cheap — no O(N*M) scan. 
for fp in list(qda._by_query.keys()): # noqa: SLF001 for row in qda.top_docs_for_query(fp, n=10 ** 9): qda_pairs.append({ "fingerprint": fp, "doc_id": row["doc_id"], "count": row["count"], "last_seen": row.get("last_seen"), }) return { "version": _ANALYTICS_DUMP_VERSION, "exported_at": now, "doc_stats": { "rows": doc_rows, "n_rows": len(doc_rows), }, "retriever_attribution": { "rows": ra_rows, "n_rows": len(ra_rows), }, "cocitation": { "pairs": cc_pairs, "n_pairs": len(cc_pairs), "n_events": cc.summary().get("n_events", 0), }, "query_stats": { "rows": qs_rows, "n_rows": len(qs_rows), }, "doc_freshness": { "rows": fs_rows, "n_rows": len(fs_rows), }, "query_doc_affinity": { "pairs": qda_pairs, "n_pairs": len(qda_pairs), "n_events": qda_summary.get("n_events", 0), }, } class AnalyticsRestoreRequest(BaseModel): dump: Dict[str, Any] replace: bool = True # default: wipe before restore @app.post("/v1/admin/analytics/restore") def admin_analytics_restore(req: AnalyticsRestoreRequest, request: Request): """Rebuild the 6 analytics stores from a v1.97 dump. ``replace=True`` (default) wipes each store before loading — gives exact-match state after restore. ``replace=False`` merges on top of existing data (fingerprints / doc_ids collide → counters SUM). Useful for aggregating traffic across prod nodes. Refuses to restore from dumps whose ``version`` doesn't match the current ``_ANALYTICS_DUMP_VERSION`` — schema compatibility gate. 
""" from ..middleware import ( get_doc_stats, set_doc_stats, DocumentStatsStore, get_retriever_attribution, set_retriever_attribution, RetrieverAttributionStore, get_cocitation, set_cocitation, CoCitationStore, get_query_stats, set_query_stats, QueryStatsStore, get_doc_freshness, set_doc_freshness, DocFreshnessStore, get_query_doc_affinity, set_query_doc_affinity, QueryDocAffinityStore, ) dump = req.dump or {} ver = dump.get("version") if ver != _ANALYTICS_DUMP_VERSION: raise HTTPException( status_code=400, detail={"error": "version mismatch", "expected": _ANALYTICS_DUMP_VERSION, "got": ver}, ) totals: Dict[str, int] = {} # doc_stats if req.replace: set_doc_stats(DocumentStatsStore()) doc_store = get_doc_stats() doc_rows = (dump.get("doc_stats") or {}).get("rows") or [] for row in doc_rows: did = row.get("doc_id") if not did: continue # Directly seed inner state (avoids driving up counters via # record() N times when N can be huge on real dumps) from ..middleware.doc_stats import DocumentStats as _DS doc_store._data[did] = _DS( # noqa: SLF001 doc_id=did, n_retrieved=int(row.get("n_retrieved", 0)), n_cited=int(row.get("n_cited", 0)), first_seen_at=row.get("first_seen_at"), last_retrieved_at=row.get("last_retrieved_at"), last_cited_at=row.get("last_cited_at"), ) totals["doc_stats"] = len(doc_rows) # retriever_attribution if req.replace: set_retriever_attribution(RetrieverAttributionStore()) ra = get_retriever_attribution() from ..middleware.retriever_attribution import RetrieverStats as _RS for row in (dump.get("retriever_attribution") or {}).get("rows") or []: nm = row.get("name") if not nm: continue ra._data[nm] = _RS( # noqa: SLF001 name=nm, n_contributed=int(row.get("n_contributed", 0)), n_doc_contributions=int(row.get("n_doc_contributions", 0)), n_cited_contributions=int(row.get("n_cited_contributions", 0)), first_seen_at=row.get("first_seen_at"), last_seen_at=row.get("last_seen_at"), ) totals["retriever_attribution"] = len( (dump.get("retriever_attribution") or 
{}).get("rows") or []) # cocitation — replay via record() (preserves partner index) if req.replace: set_cocitation(CoCitationStore()) cc = get_cocitation() n_cc = 0 for pair in (dump.get("cocitation") or {}).get("pairs") or []: count = int(pair.get("count", 0)) a = pair.get("a"); b = pair.get("b") if not a or not b or count <= 0: continue # Bump the pair count ``count`` times via direct state access — # replay would create a new n_events per iteration which skews # the counter. from ..middleware.cocitation import _pair_key as _pk k = _pk(a, b) cc._pairs[k] = int(cc._pairs.get(k, 0)) + count # noqa: SLF001 cc._partners[a].add(b) # noqa: SLF001 cc._partners[b].add(a) # noqa: SLF001 ls = pair.get("last_seen") if ls is not None: cc._last_seen[k] = float(ls) # noqa: SLF001 n_cc += 1 n_events_cc = int((dump.get("cocitation") or {}).get("n_events", 0)) if n_events_cc: cc._n_events = cc._n_events + n_events_cc # noqa: SLF001 totals["cocitation_pairs"] = n_cc # query_stats if req.replace: set_query_stats(QueryStatsStore()) qs = get_query_stats() from ..middleware.query_stats import QueryStats as _QS for row in (dump.get("query_stats") or {}).get("rows") or []: fp = row.get("fingerprint") if not fp: continue existing = qs._data.get(fp) # noqa: SLF001 if existing and not req.replace: # Merge: add counts; keep earliest first_seen_at; latest # last_seen_at; sum sources/latency. 
existing.count += int(row.get("count", 0)) existing.sum_sources += int(row.get("sum_sources", 0)) existing.sum_latency_ms += float( row.get("sum_latency_ms", 0.0)) if (row.get("first_seen_at") is not None and (existing.first_seen_at is None or row["first_seen_at"] < existing.first_seen_at)): existing.first_seen_at = row["first_seen_at"] if (row.get("last_seen_at") is not None and (existing.last_seen_at is None or row["last_seen_at"] > existing.last_seen_at)): existing.last_seen_at = row["last_seen_at"] else: qs._data[fp] = _QS( # noqa: SLF001 fingerprint=fp, sample=row.get("sample", ""), count=int(row.get("count", 0)), first_seen_at=row.get("first_seen_at"), last_seen_at=row.get("last_seen_at"), sum_sources=int(row.get("sum_sources", 0)), sum_latency_ms=float(row.get("sum_latency_ms", 0.0)), ) totals["query_stats"] = len( (dump.get("query_stats") or {}).get("rows") or []) # doc_freshness if req.replace: set_doc_freshness(DocFreshnessStore()) fs = get_doc_freshness() from ..middleware.doc_freshness import DocFreshness as _DF for row in (dump.get("doc_freshness") or {}).get("rows") or []: did = row.get("doc_id") if not did: continue fs._data[did] = _DF( # noqa: SLF001 doc_id=did, added_at=float(row.get("added_at") or 0.0), last_modified_at=row.get("last_modified_at"), n_modifications=int(row.get("n_modifications", 0)), ) totals["doc_freshness"] = len( (dump.get("doc_freshness") or {}).get("rows") or []) # query_doc_affinity if req.replace: set_query_doc_affinity(QueryDocAffinityStore()) qda = get_query_doc_affinity() n_qda = 0 for pair in (dump.get("query_doc_affinity") or {}).get("pairs") or []: fp = pair.get("fingerprint"); did = pair.get("doc_id") count = int(pair.get("count", 0)) if not fp or not did or count <= 0: continue k = (fp, did) qda._pairs[k] = int(qda._pairs.get(k, 0)) + count # noqa: SLF001 qda._by_query[fp].add(did) # noqa: SLF001 qda._by_doc[did].add(fp) # noqa: SLF001 ls = pair.get("last_seen") if ls is not None: qda._last_seen[k] = float(ls) # 
noqa: SLF001 n_qda += 1 n_events_qda = int( (dump.get("query_doc_affinity") or {}).get("n_events", 0)) if n_events_qda: qda._n_events = qda._n_events + n_events_qda # noqa: SLF001 totals["query_doc_affinity_pairs"] = n_qda get_obs().audit( "analytics.restore", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), version=ver, replace=bool(req.replace), totals=totals, ) return { "restored": True, "replace": bool(req.replace), "version": ver, "totals": totals, } @app.get("/v1/admin/queries/affinity/summary") def admin_query_doc_affinity_summary(): """Rollup of query × doc affinity matrix: n_events, n_pairs, n_queries, n_docs, total_count (v1.96).""" from ..middleware import get_query_doc_affinity return get_query_doc_affinity().summary() @app.get("/v1/admin/queries/{fingerprint}/top-docs") def admin_query_top_docs(fingerprint: str, n: int = 10): """Which docs does this query most often cite? (v1.96). ``fingerprint`` is the v1.89 canonical fingerprint. For a text query, run it through ``_fingerprint(canonical)`` first — or use v1.89 lookup endpoints.""" from ..middleware import get_query_doc_affinity, get_query_stats store = get_query_doc_affinity() rows = store.top_docs_for_query(fingerprint, n=int(n)) # Bonus: include the sample text from v1.89 if known qs_row = get_query_stats().get(fingerprint) return { "fingerprint": fingerprint, "sample": qs_row.sample if qs_row else None, "top_docs": rows, } @app.get("/v1/admin/documents/{doc_id}/top-queries") def admin_doc_top_queries(doc_id: str, n: int = 10): """Which queries lead to this doc being cited? (v1.96). 
Returns fingerprints + counts, with each fingerprint's sample text attached if still known to v1.89's query_stats store.""" from ..middleware import get_query_doc_affinity, get_query_stats store = get_query_doc_affinity() rows = store.top_queries_for_doc(doc_id, n=int(n)) # Enrich with query text samples from v1.89 qs = get_query_stats() for row in rows: r = qs.get(row["fingerprint"]) row["sample"] = r.sample if r else None return {"doc_id": doc_id, "top_queries": rows} @app.post("/v1/admin/queries/affinity/reset") def admin_query_doc_affinity_reset(request: Request): """Wipe the query × doc affinity matrix + audit (v1.96).""" from ..middleware import get_query_doc_affinity store = get_query_doc_affinity() before = store.summary() store.clear() get_obs().audit( "query_doc_affinity.reset", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), prev_n_pairs=before["n_pairs"], prev_total_count=before["total_count"], ) return {"reset": True, "before": before} @app.get("/v1/admin/documents/update-priorities") def admin_doc_update_priorities( n: int = 20, min_cited: int = 1, older_than_days: float = 0.0, alpha: float = 1.0, ): """Rank docs by "needs update" priority (v1.95) — cross-cut of v1.82 doc_stats × v1.94 doc_freshness. Priority score = ``n_cited * (age_days ** alpha)``. High traffic plus stale content → high score. Docs get flagged from both sides: low cite = not worth the review; recently-modified = doesn't need review yet. ``alpha`` tunes how heavily age dominates (α>1 = age matters more; α<1 = traffic matters more). Filters: * ``min_cited`` — minimum citation count (default 1 — skip cold docs entirely, they're content-audit's problem, not update's). * ``older_than_days`` — minimum age in days (default 0 — let the caller decide what ''stale'' means). * ``n`` — cap on rows returned. * ``alpha`` — exponent on age_days in the score. 1.0 is linear (balanced); raise to prioritize stale content harder. 
""" from ..middleware import get_doc_stats, get_doc_freshness import time as _t docs = get_doc_stats() fresh = get_doc_freshness() now = _t.time() # Walk the smaller side (whichever has fewer entries) and join. # We go through doc_stats (usually <= corpus size) because stale # docs with zero traffic are noise here — we want things BOTH # sides know about. rows: List[Dict[str, Any]] = [] for doc_row in docs.top_cited(n=10 ** 9): did = doc_row["doc_id"] if doc_row["n_cited"] < int(min_cited): continue f = fresh.get(did) if f is None: continue ref_ts = f.last_modified_at or f.added_at age_days = max(0.0, (now - ref_ts) / 86400.0) if age_days < float(older_than_days): continue try: aged = age_days ** float(alpha) except (OverflowError, ValueError): aged = age_days score = float(doc_row["n_cited"]) * aged rows.append({ "doc_id": did, "n_cited": doc_row["n_cited"], "n_retrieved": doc_row["n_retrieved"], "cite_rate": doc_row.get("cite_rate", 0.0), "added_at": f.added_at, "last_modified_at": f.last_modified_at, "age_days": round(age_days, 2), "n_modifications": f.n_modifications, "priority_score": round(score, 2), }) rows.sort(key=lambda r: r["priority_score"], reverse=True) rows = rows[:max(0, int(n))] return { "n_candidates": len(rows), "n": int(n), "min_cited": int(min_cited), "older_than_days": float(older_than_days), "alpha": float(alpha), "candidates": rows, } @app.get("/v1/documents/{doc_id}/freshness") def get_document_freshness(doc_id: str): """Per-doc freshness: added_at, last_modified_at, n_modifications, age_s, age_days (v1.94).""" from ..middleware import get_doc_freshness row = get_doc_freshness().get(doc_id) if row is None: raise HTTPException( status_code=404, detail={"error": "no freshness record for doc", "doc_id": doc_id}, ) return row.to_dict() @app.get("/v1/documents/{doc_id}/related") def get_document_related(doc_id: str, n: int = 10): """Docs most commonly co-cited with ``doc_id`` in actual traffic (v1.84). 
Empirical 'related' — purely behavioural.""" from ..middleware import get_cocitation related = get_cocitation().related(doc_id, n=int(n)) return {"doc_id": doc_id, "related": related} @app.get("/v1/documents/{doc_id}/stats") def get_document_stats(doc_id: str): """Per-document citation + retrieval counters (v1.82).""" from ..middleware import get_doc_stats row = get_doc_stats().get(doc_id) if row is None: raise HTTPException( status_code=404, detail={"error": "no stats for doc", "doc_id": doc_id}, ) return row.to_dict() @app.get("/v1/documents/{doc_id}/chunks") def get_document_chunks(doc_id: str, chunker: Optional[str] = None): """Return the chunks the retrievers actually index for this doc. Re-runs the configured chunker on-demand; pass ``?chunker=sentence`` to preview alternative chunkings without changing the index.""" d = _pipeline.get_document(doc_id) if d is None: raise HTTPException(status_code=404, detail="document not found") chunks = _pipeline.get_chunks(doc_id, chunker=chunker) return { "doc_id": doc_id, "n_chunks": len(chunks), "chunker": chunker or getattr(_pipeline, "_chunker_last", "fixed"), "chunks": chunks, } @app.get("/v1/documents/{doc_id}") def get_document(doc_id: str): d = _pipeline.get_document(doc_id) if d is None: raise HTTPException(status_code=404, detail="document not found") return {"id": d.id, "text": d.text, "metadata": d.metadata or {}} @app.put("/v1/documents/{doc_id}") def replace_document(doc_id: str, body: DocumentBody): if body.id != doc_id: raise HTTPException(status_code=422, detail={"path_id_mismatch": {"url": doc_id, "body": body.id}}) try: # Reuse the doc-size validator validate_doc_list([body]) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) doc = Document(id=body.id, text=body.text, metadata=body.metadata) ok = _pipeline.replace_document(doc) if not ok: raise HTTPException(status_code=404, detail="document not found") return {"replaced": True, "id": doc_id} 
@app.delete("/v1/documents/{doc_id}") def delete_document(doc_id: str): ok = _pipeline.delete_document(doc_id) if not ok: raise HTTPException(status_code=404, detail="document not found") return {"deleted": True, "id": doc_id} @app.delete("/v1/documents") def clear_documents(request: Request): n = _pipeline.clear_documents() get_obs().audit( "documents.clear", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), removed=n, ) return {"cleared": True, "removed": n} # ---- bulk ingest (JSONL + CSV streaming) ------------------------------ import csv as _csv # noqa: E402 import io as _io # noqa: E402 import json # noqa: E402 def _parse_jsonl(text: str): """Yield (row_num, doc_dict_or_error) tuples for a JSONL payload. Blank lines and `#` comment lines are skipped silently.""" for i, line in enumerate(text.splitlines(), start=1): line = line.strip() if not line or line.startswith("#"): continue try: obj = json.loads(line) if not isinstance(obj, dict): raise ValueError("row is not a JSON object") yield i, obj except Exception as e: yield i, {"__error__": str(e)} def _parse_csv(text: str): """Yield (row_num, doc_dict_or_error) tuples for a CSV payload. 
Expects columns: id, text, and optional metadata columns merged into a single metadata dict.""" reader = _csv.DictReader(_io.StringIO(text)) if reader.fieldnames is None or "id" not in reader.fieldnames \ or "text" not in reader.fieldnames: raise HTTPException( status_code=400, detail={"csv_missing_columns": "required: 'id' and 'text'"}, ) meta_cols = [c for c in reader.fieldnames if c not in ("id", "text")] for i, row in enumerate(reader, start=2): # row 1 is header try: metadata = {c: row[c] for c in meta_cols if row.get(c) not in (None, "")} yield i, {"id": row["id"], "text": row["text"], "metadata": metadata} except Exception as e: yield i, {"__error__": str(e)} @app.post("/v1/documents/bulk") async def bulk_ingest_documents(request: Request): """Bulk ingest — JSONL (one ``{"id","text","metadata"}`` per line) or CSV (columns: id, text, [any other] → metadata). Partial success semantics: each row parsed+validated independently, successes indexed into the pipeline, failures reported with row numbers. 
Content-Type: * ``application/x-ndjson`` or ``application/jsonl`` → JSONL * ``text/csv`` → CSV * anything else → JSONL (default) """ ct = (request.headers.get("content-type") or "").split(";", 1)[0].strip().lower() raw = (await request.body()).decode("utf-8", errors="replace") if ct in ("text/csv",): iterator = _parse_csv(raw) else: iterator = _parse_jsonl(raw) # Enforce per-batch size limit from v1.35 from .errors import Limits accepted: List[Document] = [] errors: List[Dict[str, Any]] = [] row_n = 0 for row_num, obj in iterator: row_n += 1 if "__error__" in obj: errors.append({"row": row_num, "error": obj["__error__"]}) continue text = obj.get("text") if not isinstance(text, str) or not text: errors.append({"row": row_num, "error": "missing or empty 'text'"}) continue if len(text) > Limits.max_doc_text_len: errors.append({"row": row_num, "error": f"text exceeds {Limits.max_doc_text_len} chars"}) continue if len(accepted) >= Limits.max_docs_per_batch: errors.append({"row": row_num, "error": f"batch cap reached (max {Limits.max_docs_per_batch})"}) continue accepted.append(Document( id=obj.get("id") or f"row-{row_num}", text=text, metadata=obj.get("metadata") or {}, )) chunks = _pipeline.add_documents(accepted) if accepted else 0 get_obs().audit( "documents.bulk_ingest", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), content_type=ct, rows_total=row_n, accepted=len(accepted), errors=len(errors), ) return { "accepted": [d.id for d in accepted], "errors": errors, "added_chunks": chunks, "rows_total": row_n, } @app.post("/v1/search") def search(req: SearchRequest): try: validate_query_text(req.query) validate_k(req.k) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) try: strategy = Strategy(req.strategy) except ValueError: raise HTTPException(status_code=400, detail={"bad_strategy": req.strategy}) q = Query( 
text=req.query, lang=req.lang, filters=req.filters, strategy=strategy, k=req.k, rerank_k=req.rerank_k, ) per = _pipeline.retrievers.search_per_retriever(q, req.k) return { "per_retriever": { name: [{"doc": r.chunk.doc_id, "chunk": r.chunk.chunk_id, "score": r.score, "rank": r.rank, "text": r.chunk.text[:300]} for r in lst] for name, lst in per.items() } } # ---- batch query (v1.54) ------------------------------------------------- class BatchQueryItem(BaseModel): query: str k: int = 10 rerank_k: int = 5 strategy: str = "hybrid" lang: str = "he" filters: Dict[str, Any] = {} class BatchQueryRequest(BaseModel): queries: List[BatchQueryItem] @app.post("/v1/batch") def batch_query(req: BatchQueryRequest, request: Request): """Run many queries in a single HTTP call — useful for eval runners, benchmarks, and bulk re-indexing workflows. Cap: ``Limits.max_docs_per_batch`` items per call (reusing the doc-limit env knob). Each item is validated independently; per-item errors are returned alongside successful responses. Cache + rate-limit apply to the overall request, not per-item (so a single admin call can sweep many queries without tripping the limiter). 
""" from .errors import Limits if not req.queries: return {"n": 0, "results": [], "errors": [], "total_ms": 0} if len(req.queries) > Limits.max_docs_per_batch: raise HTTPException( status_code=413, detail=f"too many queries — max {Limits.max_docs_per_batch} per batch", ) import time as _t t0 = _t.time() results: List[Dict[str, Any]] = [] errors: List[Dict[str, Any]] = [] for i, item in enumerate(req.queries, start=1): try: validate_query_text(item.query) validate_k(item.k) strategy = Strategy(item.strategy) q = Query(text=item.query, lang=item.lang, filters=item.filters, strategy=strategy, k=item.k, rerank_k=item.rerank_k) resp = _pipeline.run(q) try: omega = float(resp.signals.omega) if resp.signals else None except Exception: omega = None verif = getattr(resp, "verification", None) results.append({ "index": i, "query": item.query, "answer": resp.answer or "", "sources": list(resp.sources or []), "omega": omega, "passed": bool(getattr(verif, "passed", False)) if verif else None, }) except HTTPException as e: errors.append({"index": i, "error": str(e.detail), "status": e.status_code}) except Exception as e: errors.append({"index": i, "error": f"{type(e).__name__}: {e}"[:240], "status": 500}) total_ms = (_t.time() - t0) * 1000.0 get_obs().audit( "batch.query", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), n=len(req.queries), errors=len(errors), total_ms=round(total_ms, 2), ) return { "n": len(req.queries), "results": results, "errors": errors, "total_ms": round(total_ms, 2), "avg_ms": round(total_ms / max(1, len(req.queries)), 2), } @app.post("/v1/generate/stream", response_class=None) def generate_stream(req: SearchRequest, request: Request): """Server-Sent Events version of /v1/generate. 
Emits events in order: event: retrieved data: {"doc_ids": [...], "count": N} event: answer data: {"chunk": "word "} (repeated) event: done data: {"answer","sources","signals","verification", "passed","omega"} event: error data: {"code","message"} (on failure) Flow: runs the full pipeline synchronously (it's ~ms for extractive) and streams the staged output. Each SSE event is . Clients: * browser: EventSource('/v1/generate/stream' ... POST) * curl --no-buffer -N -H 'Content-Type: application/json' \ -d '{"query":"..."}' http://localhost:8000/v1/generate/stream * SDK: for ev in client.stream_query("..."): ... """ try: validate_query_text(req.query) validate_k(req.k) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) try: strategy = Strategy(req.strategy) except ValueError: raise HTTPException(status_code=400, detail={"bad_strategy": req.strategy}) from fastapi.responses import StreamingResponse import json as _json def _sse(event: str, data: Any) -> str: return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n" def _event_gen(): try: q = Query(text=req.query, lang=req.lang, filters=req.filters, strategy=strategy, k=req.k, rerank_k=req.rerank_k) resp = _pipeline.run(q) # Stage 1: retrieval results retrieved = [] seen = set() for c in getattr(resp, "retrieved", []) or []: did = getattr(getattr(c, "chunk", None), "doc_id", None) if did and did not in seen: retrieved.append(did) seen.add(did) yield _sse("retrieved", {"doc_ids": retrieved, "count": len(retrieved)}) # Stage 2: answer streamed word-by-word answer = resp.answer or "" words = answer.split(" ") for w in words: if not w: continue yield _sse("answer", {"chunk": w + " "}) # Stage 3: final envelope try: omega = float(resp.signals.omega) if resp.signals else None except Exception: omega = None verif = getattr(resp, "verification", None) yield _sse("done", { "answer": answer, "sources": 
list(resp.sources or []), "omega": omega, "passed": bool(getattr(verif, "passed", False)) if verif else None, "verification": (verif.to_dict() if hasattr(verif, "to_dict") else getattr(verif, "__dict__", None)), }) except Exception as e: yield _sse("error", { "code": "internal_error", "message": f"{type(e).__name__}: {e}"[:240], }) return StreamingResponse( _event_gen(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", # nginx — flush immediately }, ) @app.post("/v1/generate") def generate(req: SearchRequest): try: validate_query_text(req.query) validate_k(req.k) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) try: strategy = Strategy(req.strategy) except ValueError: raise HTTPException(status_code=400, detail={"bad_strategy": req.strategy}) # Cache hit short-circuit cache = get_cache() cache_key = cache.make_key( f"{req.query}|{req.strategy}|{req.k}|{req.rerank_k}", req.lang, req.filters, ) cached = cache.get(cache_key) if cached is not None: cached = dict(cached) cached["_cache"] = "hit" return cached q = Query( text=req.query, lang=req.lang, filters=req.filters, strategy=strategy, k=req.k, rerank_k=req.rerank_k, ) out = _pipeline.run(q).to_dict() cache.put(cache_key, out) out = dict(out); out["_cache"] = "miss" return out # ---- saved query presets (v1.66) ---------------------------------------- class QueryPresetBody(BaseModel): query: str k: int = 10 rerank_k: int = 5 strategy: str = "hybrid" lang: str = "he" notes: str = "" @app.get("/v1/queries") def list_query_presets(): """List all saved query presets. 
Unauthenticated (queries are public).""" from ..presets import get_preset_store presets = get_preset_store().list_all() return {"count": len(presets), "presets": presets} @app.get("/v1/queries/{name}") def get_query_preset(name: str): from ..presets import get_preset_store p = get_preset_store().get(name) if p is None: raise HTTPException(status_code=404, detail={"preset_not_found": name}) return p.to_dict() @app.put("/v1/queries/{name}") def save_query_preset(name: str, body: QueryPresetBody, request: Request): from ..presets import QueryPreset, get_preset_store try: validate_query_text(body.query) validate_k(body.k) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) preset = QueryPreset( name=name, query=body.query, k=body.k, rerank_k=body.rerank_k, strategy=body.strategy, lang=body.lang, notes=body.notes, ) get_preset_store().put(preset) get_obs().audit( "query_preset.put", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), name=name, ) return preset.to_dict() @app.delete("/v1/queries/{name}") def delete_query_preset(name: str, request: Request): from ..presets import get_preset_store ok = get_preset_store().remove(name) if not ok: raise HTTPException(status_code=404, detail={"preset_not_found": name}) get_obs().audit( "query_preset.remove", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), name=name, ) return {"removed": True, "name": name} @app.post("/v1/queries/{name}/run") def run_query_preset(name: str): """Execute a saved preset — equivalent to ``POST /v1/generate`` with the stored parameters. 
Returns the full /generate response.""" from ..presets import get_preset_store p = get_preset_store().get(name) if p is None: raise HTTPException(status_code=404, detail={"preset_not_found": name}) req = SearchRequest(query=p.query, k=p.k, rerank_k=p.rerank_k, strategy=p.strategy, lang=p.lang, filters={}) return generate(req) @app.post("/v1/generate/timings") def generate_timings(req: SearchRequest): """Run the query but return only the per-stage latency breakdown + Ω — no answer text, no candidate list. For ops/profiling workflows that need to know WHERE time is spent, not WHAT was returned. Response shape: { "query": "...", "timings_ms": {understand, retrieve, fuse, rerank, generate, verify, signals, total}, "omega": 0.67, "n_sources": 2, "cache": "miss" } """ try: validate_query_text(req.query) validate_k(req.k) except OverflowError as e: raise HTTPException(status_code=413, detail=str(e)) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) try: strategy = Strategy(req.strategy) except ValueError: raise HTTPException(status_code=400, detail={"bad_strategy": req.strategy}) q = Query( text=req.query, lang=req.lang, filters=req.filters, strategy=strategy, k=req.k, rerank_k=req.rerank_k, ) resp = _pipeline.run(q) try: omega = float(resp.signals.omega) if resp.signals else None except Exception: omega = None return { "query": req.query, "timings_ms": dict(resp.timing_ms or {}), "omega": omega, "n_sources": len(resp.sources or []), "passed": (bool(resp.verification.passed) if resp.verification else None), } @app.get("/v1/admin/stats") def admin_stats(): from ..middleware import get_webhook_dispatcher return { "cache": get_cache().stats(), "rate_limiter": get_limiter().stats(), "observability": get_obs().stats(), "cached_queries": len(_pipeline.cache), "webhook": get_webhook_dispatcher().stats(), } @app.get("/v1/admin/webhook") def admin_webhook_stats(): """Report the audit webhook dispatcher state (v1.71) including the circuit breaker status 
(v1.79).""" from ..middleware import get_webhook_dispatcher return get_webhook_dispatcher().stats() # ---- alert rules (v1.80) ------------------------------------------------- class AlertRuleRequest(BaseModel): name: str metric: str op: str threshold: float window_s: float = 300.0 cooldown_s: float = 600.0 enabled: bool = True description: str = "" @app.get("/v1/admin/alerts") def admin_alerts_list(): """List all configured alert rules (v1.80).""" from ..middleware import get_alert_store return {"rules": [r.to_dict() for r in get_alert_store().list_all()]} @app.put("/v1/admin/alerts/{name}") def admin_alerts_put(name: str, req: AlertRuleRequest, request: Request): """Create or update an alert rule. ``name`` in the path wins over the body — consistent with PUT semantics.""" from ..middleware import get_alert_store, AlertRule try: rule = AlertRule( name=name, metric=req.metric, op=req.op, threshold=float(req.threshold), window_s=float(req.window_s), cooldown_s=float(req.cooldown_s), enabled=bool(req.enabled), description=req.description, ) rule = get_alert_store().put(rule) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) get_obs().audit( "alert.rule.put", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), name=name, metric=req.metric, op=req.op, threshold=float(req.threshold), ) return rule.to_dict() @app.delete("/v1/admin/alerts/{name}") def admin_alerts_delete(name: str, request: Request): from ..middleware import get_alert_store ok = get_alert_store().delete(name) if not ok: raise HTTPException(status_code=404, detail={"error": "alert rule not found", "name": name}) get_obs().audit( "alert.rule.delete", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), name=name, ) return {"deleted": True, "name": name} @app.get("/v1/admin/alerts/scheduler") def admin_alerts_scheduler_status(): """Report the background alert-evaluator state (v1.81).""" from 
..middleware import get_alert_scheduler sched = get_alert_scheduler() if sched is None: return {"enabled": False, "is_running": False} return {"enabled": True, **sched.status()} @app.post("/v1/admin/alerts/evaluate") def admin_alerts_evaluate(request: Request): """Run every enabled rule once against the current metrics history. Returns a list of verdicts; rules that fire also emit an ``alert.fired`` audit event (which flows through the webhook dispatcher per v1.71).""" from ..middleware import ( get_alert_store, get_metrics_history, evaluate_all, ) actor = request.headers.get("x-api-key") rid = getattr(request.state, "request_id", None) def _on_fire(verdict): get_obs().audit( "alert.fired", actor_key=actor, request_id=rid, rule=verdict["rule"], reason=verdict["reason"], latest_value=verdict["latest_value"], n_samples=verdict["n_samples"], ) verdicts = evaluate_all( get_alert_store(), get_metrics_history(), on_fire=_on_fire, ) return { "verdicts": verdicts, "n_fired": sum(1 for v in verdicts if v["fired"]), "n_suppressed": sum(1 for v in verdicts if v["suppressed"]), "n_evaluated": len(verdicts), } @app.post("/v1/admin/webhook/breaker/reset") def admin_webhook_breaker_reset(request: Request): """Manually force the audit-webhook circuit breaker back to CLOSED and clear its failure counters (v1.79). 
Useful after fixing a downstream outage without waiting for the cooldown probe.""" from ..middleware import get_webhook_dispatcher disp = get_webhook_dispatcher() before = disp.breaker.stats() disp.breaker.reset() after = disp.breaker.stats() get_obs().audit( "webhook.breaker.reset", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), prev_state=before["state"], prev_fail_count=before["fail_count"], ) return {"reset": True, "before": before, "after": after} @app.get("/v1/admin/metrics/history") def admin_metrics_history( since: Optional[float] = None, until: Optional[float] = None, metric: Optional[str] = None, limit: int = 1000, ): """Return sampled time-series metrics (v1.78). Query params: * ``since`` / ``until`` — Unix timestamps bounding the window. * ``metric`` — dotted path (e.g. ``obs.p95_ms``, ``cache.hit_rate``, ``limiter.denied``) to project each sample to ``{ts, value}``. Omit for full samples. * ``limit`` — hard cap on rows returned (default 1000). The sampler is off by default; enable via ``TAU_RAG_METRICS_HISTORY_INTERVAL_SEC=10`` at server start, or call ``MetricsHistorySampler(h, interval_s=10.0).start()`` directly. """ from ..middleware import get_metrics_history, get_metrics_sampler h = get_metrics_history() rows = h.history(since=since, until=until, metric=metric) # Enforce upper bound — take the newest ``limit`` rows. if limit and len(rows) > int(limit): rows = rows[-int(limit):] sampler = get_metrics_sampler() return { "samples": rows, "count": len(rows), "capacity": h.capacity(), "metric": metric, "sampler": (sampler.status() if sampler else {"is_running": False, "interval_s": None}), } @app.post("/v1/admin/metrics/history/sample") def admin_metrics_history_sample_now(request: Request): """Force one immediate sample (useful for tests and 'capture before change' workflows). 
Returns the sample that was just captured.""" from ..middleware import get_metrics_history row = get_metrics_history().sample() get_obs().audit( "metrics.sample", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), ts=row["ts"], ) return row @app.post("/v1/admin/cache/clear") def admin_cache_clear(request: Request): get_cache().clear() _pipeline.cache.clear() get_obs().audit( "cache.clear", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), ) return {"cleared": True} @app.get("/v1/admin/logs") def admin_logs(n: int = 50, event_type: Optional[str] = None): """Tail the in-memory request/audit log. `event_type=audit` to see just admin actions, `event_type=request` for HTTP traffic.""" return {"logs": get_obs().tail(n=n, event_type=event_type)} @app.get("/v1/admin/audit/export", response_class=None) def admin_audit_export( since: Optional[float] = None, until: Optional[float] = None, event_type: Optional[str] = None, format: str = "jsonl", limit: int = 10_000, ): """Export the observability ring buffer as JSONL (default) or JSON. Query params: * ``since`` — Unix timestamp, inclusive lower bound on entry.ts * ``until`` — Unix timestamp, exclusive upper bound * ``event_type`` — ``audit`` | ``request`` (omit = both) * ``format`` — ``jsonl`` (default, streamable) or ``json`` * ``limit`` — hard cap on rows returned (default 10,000) Returns a file attachment with ``Content-Disposition``; intended to be piped to compliance storage, fed to a SIEM, or spot-checked with jq. 
""" from fastapi.responses import PlainTextResponse import json as _json rows = get_obs().tail(n=max(1, int(limit)), event_type=event_type) # Apply time filters (tail already returns newest-last) if since is not None: rows = [r for r in rows if float(r.get("ts") or 0) >= float(since)] if until is not None: rows = [r for r in rows if float(r.get("ts") or 0) < float(until)] if format == "json": body = _json.dumps(rows, ensure_ascii=False, indent=2) media = "application/json" suffix = "json" else: # default jsonl body = "\n".join(_json.dumps(r, ensure_ascii=False) for r in rows) if rows: body += "\n" media = "application/x-ndjson" suffix = "jsonl" filename = f"tau-rag-audit.{suffix}" return PlainTextResponse( body, media_type=media, headers={ "Content-Disposition": f'attachment; filename="{filename}"', "X-Entry-Count": str(len(rows)), }, ) @app.get("/v1/admin/logs/stream", response_class=None) def admin_logs_stream( event_type: Optional[str] = None, heartbeat_s: float = 15.0, replay_last: int = 0, max_events: int = 0, max_heartbeats: int = 0, ): """Live SSE tail of the observability log (v1.75). Query params: * ``event_type`` — ``audit`` | ``request`` (omit = both) * ``heartbeat_s`` — emit a ``:heartbeat`` comment every N seconds so proxies/load balancers don't idle-kill the connection (default 15s). * ``replay_last`` — on connect, emit the last N buffered entries before live-tailing (0 = live only). * ``max_events`` — stop after this many ``log`` events (0 = unbounded). Useful for bounded tails and deterministic testing. * ``max_heartbeats`` — stop after this many heartbeat ticks (0 = unbounded). Useful for "give me whatever's there within N seconds, then close". Event names: * ``log`` — new entry (data is the same dict as ``/v1/admin/logs``) * (heartbeat is a ``:`` SSE comment line, per the SSE spec) Client disconnect frees the subscriber queue; drop-oldest on a slow reader so the request path is never back-pressured. 
""" from fastapi.responses import StreamingResponse import json as _json import queue as _q obs = get_obs() sub = obs.subscribe(maxsize=256) def _sse(event: str, data: Any) -> str: return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n" def _gen(): emitted_logs = 0 heartbeats = 0 try: # Replay tail first (filtered by event_type if set) if replay_last and replay_last > 0: for row in obs.tail(n=int(replay_last), event_type=event_type): yield _sse("log", row) emitted_logs += 1 if max_events and emitted_logs >= max_events: return # Live loop — block up to heartbeat_s for next entry; if # nothing arrived emit a comment keep-alive. hb = max(0.05, float(heartbeat_s)) while True: try: row = sub.get(timeout=hb) except _q.Empty: heartbeats += 1 yield ": heartbeat\n\n" if max_heartbeats and heartbeats >= max_heartbeats: return continue if event_type and row.get("event_type") != event_type: continue yield _sse("log", row) emitted_logs += 1 if max_events and emitted_logs >= max_events: return except GeneratorExit: pass finally: obs.unsubscribe(sub) return StreamingResponse( _gen(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", }, ) @app.get("/v1/admin/ui", response_class=None) def admin_ui(refresh: int = 0): """Unified read-only admin dashboard (HTML).""" from fastapi.responses import HTMLResponse from .admin_ui import render_admin_ui from .metrics import check_readiness _, ready_detail = check_readiness(_pipeline) html = render_admin_ui( cache_stats = get_cache().stats(), limiter_stats = get_limiter().stats(), obs_stats = get_obs().stats(), recent_requests = get_obs().tail(n=20, event_type="request"), recent_audits = get_obs().tail(n=20, event_type="audit"), keys = get_auth().list_keys(), documents = _pipeline.list_documents(), readiness = ready_detail, refresh_sec = int(refresh or 0), ) return HTMLResponse(html) # ---- snapshot / restore -------------------------------------------------- class 
SnapshotSaveRequest(BaseModel): path: Optional[str] = None rotate: int = 0 # v1.67: keep last N rotated generations class SnapshotLoadRequest(BaseModel): path: Optional[str] = None replace: bool = False generation: int = 0 # v1.67: load a specific rotated generation def _default_snapshot_path() -> str: return _os.environ.get("TAU_RAG_SNAPSHOT_PATH") or "runtime/snapshot.jsonl" # ---- eval harness endpoint ---------------------------------------------- class EvalRequest(BaseModel): cases: List[Dict[str, Any]] k: int = 5 thresholds: Optional[Dict[str, float]] = None @app.post("/v1/admin/eval") def admin_eval(req: EvalRequest, request: Request): """Run the pipeline against a gold set inline; return aggregate metrics. Body: ``{"cases": [{id, query, expected_doc_ids, expected_claims?, lang?}, ...], "k": 5, "thresholds": {"recall@5": 0.7, ...}}`` """ from ..eval import GoldCase, run_eval cases = [GoldCase.from_dict(c) for c in req.cases] report = run_eval(_pipeline, cases, k=req.k) failures = report.fail_below(req.thresholds or {}) get_obs().audit( "eval.run", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), n_cases=report.n_cases, aggregate=report.aggregate, failures=failures, ) return { "n_cases": report.n_cases, "aggregate": report.aggregate, "latency_ms": report.latency_ms, "omega": report.omega, "per_case": [c.to_dict() for c in report.per_case], "failures": failures, "passed": len(failures) == 0, } @app.post("/v1/admin/snapshot/save") def admin_snapshot_save(req: SnapshotSaveRequest, request: Request): path = req.path or _default_snapshot_path() summary = _pipeline.save_snapshot(path, rotate=int(req.rotate or 0)) get_obs().audit( "snapshot.save", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), **summary, ) return summary @app.get("/v1/admin/snapshot/history") def admin_snapshot_history(path: Optional[str] = None, max_gens: int = 10): """List rotated snapshot 
generations on disk. ``path`` defaults to the env-configured snapshot path.""" from ..snapshot import list_snapshot_history base = path or _default_snapshot_path() return { "base_path": base, "generations": list_snapshot_history(base, max_gens=max_gens), } class SnapshotDiffRequest(BaseModel): a: str # path to snapshot A (the "before") b: str # path to snapshot B (the "after") include_details: bool = False # expand modified[] to {id, lens, hashes} @app.post("/v1/admin/snapshot/diff") def admin_snapshot_diff(req: SnapshotDiffRequest, request: Request): """Compare two snapshots (v1.77). Returns added / removed / modified doc IDs + per-snapshot metadata + a ``same_fingerprint`` boolean. Use cases: * CI: "what's new in this PR?" (diff main vs branch) * QA: "did we accidentally delete docs?" (diff yesterday vs today) * compliance: "what changed between quarters?" """ from ..snapshot import diff_snapshots import os as _os for label, p in (("a", req.a), ("b", req.b)): if not _os.path.exists(p): raise HTTPException( status_code=404, detail={"error": "snapshot not found", "which": label, "path": p}, ) result = diff_snapshots(req.a, req.b, include_details=bool(req.include_details)) get_obs().audit( "snapshot.diff", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), a=req.a, b=req.b, n_added=len(result["added"]), n_removed=len(result["removed"]), n_modified=len(result["modified"]), ) return result # ---- runtime config (v1.51) --------------------------------------------- class ConfigUpdateRequest(BaseModel): # Whitelist of tunable fields. Add to `_TUNABLE` below before exposing. 
updates: Dict[str, Any] # (path-in-Config, validator_fn, human_description) _TUNABLE: Dict[str, Any] = { "verify.min_omega": ( lambda v: isinstance(v, (int, float)) and 0.0 <= float(v) <= 1.0, "Minimum Ω signal for response.passed; 0.55 default", ), "verify.min_citation_coverage": ( lambda v: isinstance(v, (int, float)) and 0.0 <= float(v) <= 1.0, "Fraction of answer claims that must be cited; 0.8 default", ), } def _config_to_dict(cfg) -> Dict[str, Any]: """Recursively convert the nested Config dataclass to a plain dict.""" from dataclasses import is_dataclass, asdict if is_dataclass(cfg): return asdict(cfg) return dict(cfg) def _apply_config_update(cfg, key_path: str, value: Any) -> None: """Set ``cfg.a.b.c = value`` for a dotted ``key_path``.""" parts = key_path.split(".") obj = cfg for part in parts[:-1]: obj = getattr(obj, part) setattr(obj, parts[-1], value) @app.get("/v1/admin/config") def admin_get_config(): """Return the live effective configuration + list of tunable keys.""" return { "config": _config_to_dict(_pipeline.config), "tunable": { k: {"description": desc} for k, (_, desc) in _TUNABLE.items() }, } @app.post("/v1/admin/config") def admin_update_config(req: ConfigUpdateRequest, request: Request): """Update whitelisted config values at runtime. 
Clears the query cache so subsequent requests use the new thresholds.""" applied: Dict[str, Any] = {} rejected: List[Dict[str, Any]] = [] for key, new_val in req.updates.items(): entry = _TUNABLE.get(key) if entry is None: rejected.append({"key": key, "reason": "not in whitelist"}) continue validator, _desc = entry if not validator(new_val): rejected.append({"key": key, "reason": "validation failed", "value": new_val}) continue try: _apply_config_update(_pipeline.config, key, new_val) applied[key] = new_val except Exception as e: rejected.append({"key": key, "reason": f"{type(e).__name__}: {e}"}) # Clear query cache so future calls use the new threshold if applied: _pipeline.cache.clear() get_cache().clear() get_obs().audit( "config.update", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), applied=applied, rejected=rejected, ) return {"applied": applied, "rejected": rejected, "cache_cleared": bool(applied)} # ---- Hebrew synonyms CRUD (v1.52) -------------------------------------- class SynonymBody(BaseModel): canonical: str variants: List[str] def _maybe_autosave_synonyms() -> None: """If TAU_RAG_SYNONYMS_PATH is set, persist current dict to disk.""" path = _os.environ.get("TAU_RAG_SYNONYMS_PATH") if not path: return try: from ..core.hebrew_synonyms import save_synonyms_jsonl save_synonyms_jsonl(path) except Exception as _e: print(f"[tau-rag] synonym autosave failed: {_e}") @app.get("/v1/admin/synonyms") def admin_list_synonyms(q: Optional[str] = None): """List synonym entries. ``?q=`` filters by substring in canonical or variants.""" from ..core.hebrew_synonyms import list_synonyms all_syn = list_synonyms() if q: qn = q.strip() all_syn = { k: v for k, v in all_syn.items() if qn in k or any(qn in x for x in v) } return {"count": len(all_syn), "synonyms": all_syn} class SynonymBulkRequest(BaseModel): entries: List[Dict[str, Any]] # [{canonical, variants}, ...] 
replace: bool = False @app.post("/v1/admin/synonyms/bulk") def admin_bulk_synonyms(req: SynonymBulkRequest, request: Request): """Bulk add/replace synonyms. Each row: {canonical, variants: [...]}.""" from ..core.hebrew_synonyms import ( add_synonym, clear_synonyms as _clear, ) if req.replace: _clear() added = 0 errors: List[Dict[str, Any]] = [] for i, e in enumerate(req.entries, start=1): try: add_synonym(e["canonical"], list(e.get("variants") or [])) added += 1 except Exception as ex: errors.append({"row": i, "error": str(ex)}) _pipeline.cache.clear() get_cache().clear() _maybe_autosave_synonyms() get_obs().audit( "synonyms.bulk", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), added=added, errors=len(errors), replaced=req.replace, ) return {"added": added, "errors": errors, "replaced": req.replace} @app.get("/v1/admin/synonyms/export", response_class=None) def admin_export_synonyms(): """Export the current synonyms as JSONL — one ``{canonical, variants}`` per line.""" from fastapi.responses import PlainTextResponse from ..core.hebrew_synonyms import list_synonyms import json as _json lines = [ _json.dumps({"canonical": k, "variants": v}, ensure_ascii=False) for k, v in list_synonyms().items() ] body = "\n".join(lines) + ("\n" if lines else "") return PlainTextResponse(body, media_type="application/x-ndjson", headers={ "Content-Disposition": 'attachment; filename="synonyms.jsonl"', }) @app.post("/v1/admin/synonyms") def admin_add_synonym(body: SynonymBody, request: Request): """Add a synonym entry or extend an existing one's variant list.""" from ..core.hebrew_synonyms import add_synonym try: result = add_synonym(body.canonical, body.variants) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) # Purge caches — old answers may have used stale expansion _pipeline.cache.clear() get_cache().clear() _maybe_autosave_synonyms() get_obs().audit( "synonyms.add", 
actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), canonical=body.canonical, n_variants=len(result["variants"]), ) return result @app.delete("/v1/admin/synonyms/{canonical}") def admin_delete_synonym(canonical: str, request: Request): from ..core.hebrew_synonyms import remove_synonym ok = remove_synonym(canonical) if not ok: raise HTTPException(status_code=404, detail={"canonical": canonical}) _pipeline.cache.clear() get_cache().clear() _maybe_autosave_synonyms() get_obs().audit( "synonyms.remove", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), canonical=canonical, ) return {"removed": True, "canonical": canonical} @app.get("/v1/admin/snapshot/status") def admin_snapshot_status(): """Report the periodic auto-snapshotter state (if configured).""" auto = get_autosnapshotter() if auto is None: return { "enabled": False, "hint": ("set TAU_RAG_SNAPSHOT_PATH and " "TAU_RAG_SNAPSHOT_INTERVAL= to enable"), } return {"enabled": True, **auto.status()} @app.post("/v1/admin/snapshot/load") def admin_snapshot_load(req: SnapshotLoadRequest, request: Request): base_path = req.path or _default_snapshot_path() # Resolve which generation to load: 0 = current, N>0 = rotated backup from ..snapshot import _gen_path from pathlib import Path as _P resolved = str(_gen_path(_P(base_path), int(req.generation or 0))) if not _os.path.exists(resolved): raise HTTPException( status_code=404, detail={"snapshot_not_found": resolved, "generation": int(req.generation or 0)}, ) summary = _pipeline.load_snapshot(resolved, replace=req.replace) summary["generation"] = int(req.generation or 0) summary["path_loaded"] = resolved get_obs().audit( "snapshot.load", actor_key=request.headers.get("x-api-key"), request_id=getattr(request.state, "request_id", None), path=resolved, **{k: v for k, v in summary.items() if k not in ("warnings", "generation", "path_loaded")}, ) return summary # ---- API key management 
# ---- API key management (admin-only endpoints) ---------------------------

# Body for POST /v1/admin/keys: a human label plus the scopes to grant.
class APIKeyCreateRequest(BaseModel):
    label: str
    scopes: List[str] = ["read", "write"]


@app.post("/v1/admin/keys")
def admin_create_key(req: APIKeyCreateRequest, request: Request):
    """Create an API key through the auth store and return its raw value.

    Records a ``key.create`` audit event (label + scopes) and includes a
    warning in the response that the raw key cannot be retrieved later.
    """
    new_key = get_auth().create(label=req.label, scopes=req.scopes)
    caller = request.headers.get("x-api-key")
    req_id = getattr(request.state, "request_id", None)
    get_obs().audit(
        "key.create",
        actor_key=caller,
        request_id=req_id,
        label=req.label,
        scopes=req.scopes,
    )
    return dict(
        api_key=new_key,
        label=req.label,
        scopes=req.scopes,
        warning="save this key now — it cannot be retrieved later",
    )


@app.get("/v1/admin/keys")
def admin_list_keys():
    """Return the auth store's ``list_keys()`` result under ``keys``."""
    return dict(keys=get_auth().list_keys())


@app.delete("/v1/admin/keys/{hash_prefix}")
def admin_revoke_key(hash_prefix: str, request: Request):
    """Revoke the key identified by ``hash_prefix``.

    Responds 404 ("key not found") when the auth store reports nothing was
    revoked; otherwise records a ``key.revoke`` audit event.
    """
    if get_auth().revoke(hash_prefix):
        get_obs().audit(
            "key.revoke",
            actor_key=request.headers.get("x-api-key"),
            request_id=getattr(request.state, "request_id", None),
            target_prefix=hash_prefix,
        )
        return dict(revoked=True, hash_prefix=hash_prefix)
    raise HTTPException(status_code=404, detail="key not found")


# Body for POST /v1/admin/keys/{hash_prefix}/rotate.
class APIKeyRotateRequest(BaseModel):
    grace_seconds: float = 300.0


@app.post("/v1/admin/keys/{hash_prefix}/rotate")
def admin_rotate_key(
    hash_prefix: str,
    req: APIKeyRotateRequest,
    request: Request,
):
    """Rotate an API key with a grace period (v1.76).

    A new key is generated with the same label+scopes. The old key stays
    valid until ``grace_seconds`` elapses and then stops working, so
    clients should switch their config during that window. Responds 404
    when the auth store returns ``None`` (unknown or already revoked).
    """
    rotated = get_auth().rotate(hash_prefix, grace_seconds=req.grace_seconds)
    if rotated is None:
        missing = {"error": "key not found or already revoked",
                   "hash_prefix": hash_prefix}
        raise HTTPException(status_code=404, detail=missing)
    get_obs().audit(
        "key.rotate",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_prefix=rotated["old_prefix"],
        new_prefix=rotated["new_prefix"],
        grace_seconds=rotated["grace_seconds"],
    )
    return {
        **rotated,
        "warning": "save the new key now — it cannot be retrieved later",
    }


# ---- v2.7 Maintenance / drain mode -------------------------------------

# Body for POST /v1/admin/maintenance/on.
class MaintenanceOnRequest(BaseModel):
    reason: str = ""
    retry_after: int = 30


@app.post("/v1/admin/maintenance/on")
def admin_maintenance_on(req: MaintenanceOnRequest, request: Request):
    """Switch maintenance / drain mode on (v2.7).

    Non-admin requests get 503 + ``Retry-After`` until it is switched off;
    admin callers (this endpoint included) always flow through. Records a
    ``maintenance.on`` audit event and returns the new snapshot.
    """
    from ..middleware.maintenance import get_maintenance
    mode = get_maintenance()
    mode.enable(reason=req.reason, retry_after=req.retry_after)
    get_obs().audit(
        "maintenance.on",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        reason=req.reason,
        retry_after=req.retry_after,
    )
    return {"ok": True, **mode.snapshot()}


@app.post("/v1/admin/maintenance/off")
def admin_maintenance_off(request: Request):
    """Switch maintenance / drain mode off (v2.7).

    The ``maintenance.off`` audit event carries ``duration_sec`` from the
    snapshot taken just before disabling, rounded to 2 decimals.
    """
    from ..middleware.maintenance import get_maintenance
    mode = get_maintenance()
    prior = mode.snapshot()
    mode.disable()
    # NOTE(review): round() assumes duration_sec is numeric even when the
    # mode was already off — confirm snapshot() never reports None here.
    elapsed = round(prior["duration_sec"], 2)
    get_obs().audit(
        "maintenance.off",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        duration_sec=elapsed,
    )
    return {"ok": True, **mode.snapshot()}


@app.get("/v1/admin/maintenance")
def admin_maintenance_status():
    """Report the current maintenance / drain state (v2.7)."""
    from ..middleware.maintenance import get_maintenance
    return get_maintenance().snapshot()
# ---- v2.8 PII redaction -------------------------------------------------

# Body for POST /v1/admin/pii_redaction/toggle.
class PIIToggleRequest(BaseModel):
    enabled: bool = True


@app.post("/v1/admin/pii_redaction/toggle")
def admin_pii_toggle(req: PIIToggleRequest, request: Request):
    """Enable or disable PII redaction at runtime (v2.8).

    Default is driven by ``TAU_RAG_PII_REDACT`` env at startup; this
    endpoint lets ops toggle without a restart (useful when you realize
    bodies were going to the log unredacted). Emits a
    ``pii_redaction.toggle`` audit event and returns the redactor stats.
    """
    from ..middleware.pii_redaction import get_pii_redactor
    r = get_pii_redactor()
    r.set_enabled(req.enabled)
    get_obs().audit(
        "pii_redaction.toggle",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        enabled=req.enabled,
    )
    return {"ok": True, **r.stats()}


@app.get("/v1/admin/pii_redaction/stats")
def admin_pii_stats():
    """PII redactor counters: how many IDs/phones/emails/CCs have been
    scrubbed since startup (or last reset)."""
    from ..middleware.pii_redaction import get_pii_redactor
    return get_pii_redactor().stats()


@app.post("/v1/admin/pii_redaction/reset")
def admin_pii_reset(request: Request):
    """Zero the per-kind counters. Does not change enabled state."""
    from ..middleware.pii_redaction import get_pii_redactor
    r = get_pii_redactor()
    r.reset()
    get_obs().audit(
        "pii_redaction.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **r.stats()}


# ---- v2.9 Slow-query detection -----------------------------------------

# Body for POST /v1/admin/slow_queries/threshold.
class SlowThresholdRequest(BaseModel):
    ms: float


@app.get("/v1/admin/slow_queries")
def admin_slow_queries(n: int = 20):
    """Top-N slowest requests + per-path aggregates + summary stats.

    ``n`` bounds the ``top`` list length (default 20, clamped to 1..100).
    """
    from ..middleware.slow_queries import get_slow_tracker
    n = min(100, max(1, int(n)))
    t = get_slow_tracker()
    return {
        "stats": t.stats(),
        "top": t.top_n(n),
        "by_path": t.by_path(),
    }


@app.post("/v1/admin/slow_queries/threshold")
def admin_slow_threshold(req: SlowThresholdRequest, request: Request):
    """Set the slow-query threshold in ms at runtime. 0 disables.

    The audit event records both the previous and the resulting threshold.
    """
    from ..middleware.slow_queries import get_slow_tracker
    t = get_slow_tracker()
    old = t.threshold_ms
    t.set_threshold(req.ms)
    get_obs().audit(
        "slow_queries.threshold",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_ms=old,
        new_ms=t.threshold_ms,
    )
    return {"ok": True, **t.stats()}


@app.post("/v1/admin/slow_queries/reset")
def admin_slow_reset(request: Request):
    """Clear the ring buffer and per-path aggregates."""
    from ..middleware.slow_queries import get_slow_tracker
    t = get_slow_tracker()
    t.reset()
    get_obs().audit(
        "slow_queries.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **t.stats()}


# ---- v2.12 Daily quota -------------------------------------------------

# Body for POST /v1/admin/quotas.
class QuotaSetRequest(BaseModel):
    key_prefix: str
    limit: int


@app.get("/v1/admin/quotas")
def admin_quotas_get():
    """Dump all quota state: limits, usage, day cursor, reset timer."""
    from ..middleware.quota import get_quota_tracker
    return get_quota_tracker().stats()


@app.post("/v1/admin/quotas")
def admin_quota_set(req: QuotaSetRequest, request: Request):
    """Set a daily quota for a key (by hash prefix).

    ``limit=0`` = unlimited (also equivalent to deleting the quota). Usage
    counters are NOT reset — operator can raise the cap mid-day without
    wiping the meter.
    """
    from ..middleware.quota import get_quota_tracker
    t = get_quota_tracker()
    t.set_quota(req.key_prefix, req.limit)
    get_obs().audit(
        "quota.set",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        key_prefix=req.key_prefix,
        limit=req.limit,
    )
    return {"ok": True, "key_prefix": req.key_prefix, "limit": req.limit}


@app.delete("/v1/admin/quotas/{key_prefix}")
def admin_quota_clear(key_prefix: str, request: Request):
    """Remove quota enforcement for a key entirely (back to unlimited and
    zero-out its usage counter). 404 if no quota was set for the prefix."""
    from ..middleware.quota import get_quota_tracker
    t = get_quota_tracker()
    removed = t.clear_quota(key_prefix)
    if not removed:
        raise HTTPException(
            status_code=404,
            detail={"error": "no quota set for this key_prefix",
                    "key_prefix": key_prefix},
        )
    get_obs().audit(
        "quota.clear",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        key_prefix=key_prefix,
    )
    return {"ok": True, "key_prefix": key_prefix, "removed": True}


@app.post("/v1/admin/quotas/reset")
def admin_quota_reset_all(request: Request):
    """Wipe all quota state — limits AND usage. Testing / incident
    recovery. Audit-logged."""
    from ..middleware.quota import get_quota_tracker
    t = get_quota_tracker()
    t.reset()
    get_obs().audit(
        "quota.reset_all",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **t.stats()}


# ---- v2.13 Idempotency -------------------------------------------------

@app.get("/v1/admin/idempotency/stats")
def admin_idempotency_stats():
    """Return the idempotency store's stats."""
    from ..middleware.idempotency import get_idempotency_store
    return get_idempotency_store().stats()


@app.post("/v1/admin/idempotency/reset")
def admin_idempotency_reset(request: Request):
    """Reset the idempotency store (audit-logged) and return its stats."""
    from ..middleware.idempotency import get_idempotency_store
    get_idempotency_store().reset()
    get_obs().audit(
        "idempotency.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **get_idempotency_store().stats()}


# Body for POST /v1/admin/idempotency/ttl.
class IdemTTLRequest(BaseModel):
    ttl_sec: float


@app.post("/v1/admin/idempotency/ttl")
def admin_idempotency_ttl(req: IdemTTLRequest, request: Request):
    """Set the idempotency store's TTL (seconds); audit-logged."""
    from ..middleware.idempotency import get_idempotency_store
    s = get_idempotency_store()
    s.set_ttl(req.ttl_sec)
    get_obs().audit(
        "idempotency.ttl",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        ttl_sec=req.ttl_sec,
    )
    return {"ok": True, **s.stats()}


# ---- v2.14 Request timeout ---------------------------------------------

# Body for POST /v1/admin/request_timeout.
class TimeoutRequest(BaseModel):
    timeout_ms: float


@app.get("/v1/admin/request_timeout")
def admin_request_timeout_stats():
    """Return the request-timeout guard's stats."""
    from ..middleware.request_timeout import get_timeout_guard
    return get_timeout_guard().stats()


@app.post("/v1/admin/request_timeout")
def admin_request_timeout_set(req: TimeoutRequest, request: Request):
    """Set wall-clock request timeout in ms. 0 disables.

    The audit event records both the previous and the resulting timeout.
    """
    from ..middleware.request_timeout import get_timeout_guard
    g = get_timeout_guard()
    old = g.timeout_ms
    g.set_timeout_ms(req.timeout_ms)
    get_obs().audit(
        "request_timeout.set",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        old_ms=old,
        new_ms=g.timeout_ms,
    )
    return {"ok": True, **g.stats()}


@app.post("/v1/admin/request_timeout/reset")
def admin_request_timeout_reset(request: Request):
    """Reset the request-timeout guard and return its stats.

    Emits a ``request_timeout.reset`` audit event, matching the other
    admin reset endpoints (PII, slow queries, quotas, idempotency).
    """
    from ..middleware.request_timeout import get_timeout_guard
    g = get_timeout_guard()
    g.reset()
    get_obs().audit(
        "request_timeout.reset",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
    )
    return {"ok": True, **g.stats()}


@app.get("/v1/signals/latest")
def latest_signals():
    """Signals of the most recently inserted pipeline-cache entry.

    Returns ``{"empty": True}`` when the cache is empty.
    """
    if not _pipeline.cache:
        return {"empty": True}
    last = list(_pipeline.cache.values())[-1]
    return last.signals.to_dict()


# ---- Chat / multi-turn ---------------------------------------------------

# Body shared by POST /v1/chat and POST /v1/chat/stream.
class ChatRequest(BaseModel):
    query: str
    session_id: str
    lang: str = "he"


def _validate_chat_query(text: str) -> None:
    """Run ``validate_query_text`` on a chat query, mapping failures to HTTP.

    ``OverflowError`` -> 413, ``ValueError`` -> 422. Shared by ``/v1/chat``
    and ``/v1/chat/stream`` so both endpoints enforce the same input checks.
    """
    try:
        validate_query_text(text)
    except OverflowError as e:
        raise HTTPException(status_code=413, detail=str(e)) from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e


@app.post("/v1/chat/stream", response_class=None)
def chat_stream(req: ChatRequest, request: Request):
    """Server-Sent Events version of /v1/chat — same conversational context
    as /v1/chat (follow-up expansion, session history) but streaming.

    Event order:
      event: followup   data: {"is_followup": bool, "expanded_query": "..."}
      event: retrieved  data: {"doc_ids": [...], "count": N}
      event: answer     data: {"chunk": "word "}   (repeated)
      event: done       data: {session_id, n_turns, answer, sources,
                               omega, passed, verification}
      event: error      data: {"code","message"}  (on failure)

    Input validation happens before the stream starts, so invalid queries
    get a normal 413/422 response rather than an SSE ``error`` event.
    """
    _validate_chat_query(req.query)

    from fastapi.responses import StreamingResponse
    import json as _json

    def _sse(event: str, data: Any) -> str:
        return f"event: {event}\ndata: {_json.dumps(data, ensure_ascii=False)}\n\n"

    def _event_gen():
        try:
            # Detect follow-up + expand (same logic as run_conversation)
            from ..memory import expand_followup, get_store, is_followup
            store = get_store()
            session = store.get_or_create(req.session_id)
            followup = (bool(is_followup(req.query, lang=req.lang))
                        and bool(session.turns))
            expanded = expand_followup(req.query, session) if followup else req.query
            yield _sse("followup", {
                "is_followup": followup,
                "expanded_query": expanded if followup else None,
            })

            # Run the pipeline (extractive — ~ms)
            resp = _pipeline.run_conversation(req.query, req.session_id, lang=req.lang)

            # Stage 1: retrieval results — unique doc ids, first-seen order
            retrieved = []
            seen = set()
            for c in getattr(resp, "retrieved", []) or []:
                did = getattr(getattr(c, "chunk", None), "doc_id", None)
                if did and did not in seen:
                    retrieved.append(did)
                    seen.add(did)
            yield _sse("retrieved", {"doc_ids": retrieved, "count": len(retrieved)})

            # Stage 2: answer streamed word-by-word (split on single spaces;
            # empty pieces from repeated spaces are skipped)
            answer = resp.answer or ""
            for w in answer.split(" "):
                if not w:
                    continue
                yield _sse("answer", {"chunk": w + " "})

            # Stage 3: final envelope — includes session_id + n_turns for client
            try:
                omega = float(resp.signals.omega) if resp.signals else None
            except Exception:
                omega = None
            verif = getattr(resp, "verification", None)
            # Re-read session — run_conversation added a new turn
            session = store.get_or_create(req.session_id)
            yield _sse("done", {
                "session_id": req.session_id,
                "n_turns": len(session.turns),
                "answer": answer,
                "sources": list(resp.sources or []),
                "omega": omega,
                "passed": bool(getattr(verif, "passed", False)) if verif else None,
                "verification": (verif.to_dict() if hasattr(verif, "to_dict")
                                 else getattr(verif, "__dict__", None)),
            })
        except Exception as e:
            yield _sse("error", {
                "code": "internal_error",
                "message": f"{type(e).__name__}: {e}"[:240],
            })

    return StreamingResponse(
        _event_gen(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.post("/v1/chat")
def chat(req: ChatRequest):
    """Conversational query — carries prior-turn context for follow-ups.

    The query goes through the same ``validate_query_text`` check as
    ``/v1/chat/stream`` (OverflowError -> 413, ValueError -> 422) before
    the pipeline runs.
    """
    _validate_chat_query(req.query)
    resp = _pipeline.run_conversation(
        query_text=req.query,
        session_id=req.session_id,
        lang=req.lang,
    )
    return resp.to_dict()


@app.get("/v1/sessions/{session_id}")
def session_info(session_id: str):
    """Return one session's ``to_dict()``; 404 if the store has no such id."""
    from ..memory import get_store
    s = get_store().get(session_id)
    if s is None:
        raise HTTPException(status_code=404, detail="session not found")
    return s.to_dict()


@app.delete("/v1/sessions/{session_id}")
def session_drop(session_id: str):
    """Drop one session; ``dropped`` is whatever the store's ``drop`` returns."""
    from ..memory import get_store
    dropped = get_store().drop(session_id)
    return {"dropped": dropped}


@app.get("/v1/sessions")
def sessions_list(
    details: bool = False,
    min_turns: int = 0,
    limit: int = 500,
):
    """List session ids.

    ``?details=1`` returns the richer per-session view (created_at,
    last_activity_ts, idle_sec, ttl_remaining_sec, n_turns, last_query,
    last_sources) sorted by most-recent-activity-first. ``?min_turns=N``
    filters; ``?limit=`` caps (values <= 0 mean no cap). ``min_turns`` and
    ``limit`` only apply when ``details`` is set.
    """
    from ..memory import get_store
    store = get_store()
    if not details:
        return {"sessions": store.list_ids()}
    rows = [r for r in store.summaries() if r["n_turns"] >= min_turns]
    if limit and limit > 0:
        rows = rows[:limit]
    return {"count": len(rows), "sessions": rows}


@app.post("/v1/sessions/gc")
def admin_sessions_gc(request: Request):
    """Force immediate TTL + capacity eviction.

    Returns ``{before, expired, overflow_evicted, after}``.
    """
    from ..memory import get_store
    summary = get_store().gc_now()
    get_obs().audit(
        "sessions.gc",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        **summary,
    )
    return summary


@app.delete("/v1/sessions")
def admin_sessions_drop_all(request: Request):
    """Drop every session. Returns ``{dropped: N}``. Audit-logged."""
    from ..memory import get_store
    n = get_store().drop_all()
    get_obs().audit(
        "sessions.drop_all",
        actor_key=request.headers.get("x-api-key"),
        request_id=getattr(request.state, "request_id", None),
        dropped=n,
    )
    return {"dropped": n}


def get_pipeline() -> Pipeline:
    """Return the module-level pipeline instance used by the routes."""
    return _pipeline


def set_pipeline(new: Pipeline) -> None:
    """Hook for tests / production to swap in a real pipeline."""
    global _pipeline
    _pipeline = new