# Source: legal-eye/tau_rag/scripts/serve_with_corpus.py
# Commit 3be54c6: "Initial deploy: legal-eye Hebrew legal RAG
# (17K corpus, verbatim-from-precedent)"
"""All-in-one launcher: load JSONL corpus → start uvicorn in SAME process.
This is the simplest end-to-end path:
1. Init the tau-rag pipeline (preset from $TAU_RAG_PRESET)
2. Stream the JSONL corpus and add_documents() in batches
3. Boot uvicorn — the pipeline is already populated, so /v1/query
can immediately retrieve real chunks
Usage:
cd tau_rag
PYTHONPATH=.. python3 -m scripts.serve_with_corpus \\
~/tau_corpus.jsonl
# With custom port / host:
PYTHONPATH=.. python3 -m scripts.serve_with_corpus \\
~/tau_corpus.jsonl --host 0.0.0.0 --port 8080
# Limit doc count for quick testing:
PYTHONPATH=.. python3 -m scripts.serve_with_corpus \\
~/tau_corpus.jsonl --max-docs 10000
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import List
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the launcher.

    ``--preset`` and ``--cors-origins`` default to the current values of
    ``$TAU_RAG_PRESET`` / ``$TAU_RAG_CORS_ORIGINS`` when those are set.
    """
    p = argparse.ArgumentParser(
        description="Load corpus + serve in one process")
    p.add_argument("jsonl", help="JSONL corpus file")
    p.add_argument("--host", default="127.0.0.1")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--batch-size", type=int, default=1000)
    p.add_argument("--max-docs", type=int, default=None,
                   help="cap docs for testing")
    p.add_argument("--preset",
                   default=os.environ.get("TAU_RAG_PRESET", "no_llm"),
                   help="pipeline preset")
    p.add_argument("--cors-origins",
                   default=os.environ.get("TAU_RAG_CORS_ORIGINS", "*"),
                   help="CORS allow_origins (default: *)")
    p.add_argument("--no-auth", action="store_true",
                   help="disable auth requirement (dev only)")
    return p


def _load_corpus(pipe, document_cls, src, batch_size, max_docs):
    """Stream the JSONL file *src* into ``pipe.add_documents()`` in batches.

    Each non-blank line must be a JSON object with a non-empty ``"text"``
    value; optional ``"id"`` and ``"metadata"`` keys are forwarded to
    *document_cls*. Lines that are not valid JSON, are not JSON objects, or
    have no text are counted as skipped. A batch whose ``add_documents()``
    call raises is reported on stderr and counted as failed, and loading
    continues with the next batch.

    Args:
        pipe: object exposing ``add_documents(list_of_documents)``.
        document_cls: callable accepting ``id=``, ``text=``, ``metadata=``.
        src: path of the JSONL corpus.
        batch_size: documents per ``add_documents()`` call (minimum 1).
        max_docs: stop once this many documents were loaded successfully,
            or ``None`` for no cap.

    Returns:
        ``(n_loaded, n_skipped, n_failed)``: documents accepted by the
        pipeline, lines skipped, and documents lost in failed batches.
    """
    batch_size = max(1, batch_size)  # 0 / negative would make no sense
    t0 = time.monotonic()
    last_print = t0
    n_loaded = 0
    n_skipped = 0
    n_failed = 0
    batch: List = []

    def flush() -> None:
        """Send the pending batch to the pipeline and start a fresh one."""
        nonlocal batch, n_loaded, n_failed
        if not batch:
            return
        pending = batch
        # Rebind instead of clear(): if the pipeline keeps a reference to
        # the list it was given, clearing in place would empty it.
        batch = []
        try:
            pipe.add_documents(pending)
        except Exception as e:
            # Deliberate best-effort: one bad batch must not abort the load.
            n_failed += len(pending)
            print(f"\n ⚠ batch failed: "
                  f"{type(e).__name__}: {e}", file=sys.stderr)
        else:
            n_loaded += len(pending)

    # utf-8-sig drops a leading BOM, which would otherwise make the first
    # record fail to parse; plain UTF-8 files decode identically.
    with open(src, encoding="utf-8-sig") as f:
        for line_n, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
            except (ValueError, RecursionError):
                n_skipped += 1
                continue
            # A valid JSON line may still be a list/str/number: skip it
            # instead of crashing on .get().
            text = d.get("text", "") if isinstance(d, dict) else ""
            if not text:
                n_skipped += 1
                continue
            raw_id = d.get("id")
            batch.append(document_cls(
                # An explicit null id falls back to the line number rather
                # than becoming the literal string "None".
                id=str(raw_id) if raw_id is not None else f"line-{line_n}",
                text=text,
                metadata=d.get("metadata", {}) or {},
            ))
            # Flush early when the pending batch would reach the cap, so
            # --max-docs is honoured exactly and not rounded up to a whole
            # batch.
            cap_reached = (max_docs is not None
                           and n_loaded + len(batch) >= max_docs)
            if len(batch) < batch_size and not cap_reached:
                continue
            flush()
            now = time.monotonic()
            if now - last_print >= 1.0:
                elapsed = now - t0
                rate = n_loaded / elapsed if elapsed > 0 else 0
                print(
                    f"\r loaded {n_loaded:>9,} docs | "
                    f"{rate:6.0f}/s",
                    end="", flush=True)
                last_print = now
            if max_docs is not None and n_loaded >= max_docs:
                print(f"\n hit --max-docs={max_docs}, stopping load")
                break
    flush()
    return n_loaded, n_skipped, n_failed


def main(argv=None) -> int:
    """Load a JSONL corpus into the tau-rag pipeline, then serve it.

    Runs in three steps within a single process: parse arguments and export
    the ``TAU_RAG_*`` environment variables, populate the pipeline returned
    by ``get_pipeline()`` from the corpus file, and finally run the FastAPI
    app under uvicorn (blocking until the server stops).

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 after uvicorn returns, 2 if the corpus file
        does not exist, 3 if ``tau_rag`` cannot be imported, 4 if the
        pipeline has no ``add_documents()``, 5 if uvicorn or the FastAPI
        app cannot be imported.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    args = _build_parser().parse_args(argv)
    src = Path(args.jsonl).expanduser().resolve()
    if not src.exists():
        print(f"✗ JSONL not found: {src}", file=sys.stderr)
        return 2

    # Config env BEFORE importing tau_rag (presumably tau_rag reads these
    # at import / pipeline-construction time; verify against tau_rag).
    os.environ["TAU_RAG_PRESET"] = args.preset
    os.environ["TAU_RAG_CORS_ORIGINS"] = args.cors_origins
    if args.no_auth:
        os.environ["TAU_RAG_AUTH_REQUIRED"] = "false"

    # ============================================ Phase 1: Load
    try:
        from tau_rag.pipeline import get_pipeline
        from tau_rag.core.types import Document
    except ImportError as e:
        print(f"✗ Cannot import tau_rag: {e}", file=sys.stderr)
        return 3

    print("╭─────────────────────────────────────────────╮")
    print("│ tau-rag · serve-with-corpus │")
    print("╰─────────────────────────────────────────────╯")
    print(f" preset: {args.preset}")
    print(f" corpus: {src}")
    print(f" size: {src.stat().st_size/1024/1024:.1f} MB")
    print()
    print("→ Initializing pipeline...")
    pipe = get_pipeline()
    print(f"✓ Pipeline: {type(pipe).__name__}")
    if not hasattr(pipe, "add_documents"):
        print(f"✗ Pipeline has no add_documents() — preset "
              f"'{args.preset}' may be too minimal. Try 'hebrew_legal'.",
              file=sys.stderr)
        return 4

    # As in the original truthiness test, 0 means "no cap"; negative values
    # are treated the same way.
    max_docs = args.max_docs if args.max_docs and args.max_docs > 0 else None

    print(f"→ Loading {src.name}...")
    t0 = time.monotonic()
    n_loaded, n_skipped, n_failed = _load_corpus(
        pipe, Document, src, args.batch_size, max_docs)
    print()
    elapsed = time.monotonic() - t0
    # Guard the rate: a tiny corpus can finish within timer resolution.
    rate = n_loaded / elapsed if elapsed > 0 else 0.0
    print(f"✓ Loaded {n_loaded:,} docs in {elapsed:.0f}s "
          f"({rate:.0f}/s)")
    if n_skipped:
        print(f" ({n_skipped} lines skipped)")
    if n_failed:
        print(f" ({n_failed} docs lost in failed batches)")
    if not n_loaded:
        print(" ⚠ no documents were loaded from the corpus",
              file=sys.stderr)

    # ============================================ Phase 2: Serve
    print()
    print(f"→ Starting uvicorn on http://{args.host}:{args.port}")
    print(f" Open chat UI: http://{args.host}:{args.port}/")
    print(f" Open admin: http://{args.host}:{args.port}/admin")
    print(" Stop: Ctrl+C")
    print()
    try:
        import uvicorn
    except ImportError as e:
        print(f"✗ uvicorn not installed: {e}", file=sys.stderr)
        return 5
    try:
        from tau_rag.api.fastapi_app import app
    except ImportError as e:
        # Reported separately so a broken app import is not misdiagnosed
        # as a missing uvicorn.
        print(f"✗ Cannot import tau_rag.api.fastapi_app: {e}",
              file=sys.stderr)
        return 5
    uvicorn.run(app, host=args.host, port=args.port,
                log_level="info")
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s return code becomes the exit status.
    raise SystemExit(main())