# legal-eye / tau_rag/scripts/load_corpus.py
# Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
# Commit 3be54c6 (verified), 4.78 kB
"""Load a JSONL corpus directly into tau-rag's pipeline.

No HTTP, no 8GB upload — reads the JSONL locally and calls
pipeline.add_documents() in batches. Running this in the same
process as the server is not ideal; instead, run it before
the server starts, or load into a persistent store.

For the pilot: this populates the in-memory pipeline so that
`/v1/query` returns real results.

Usage:
    python -m scripts.load_corpus ~/tau_pilot.jsonl \\
        --preset hebrew_legal_prod

The pipeline state is saved to the --snapshot path (default:
~/tau_snapshot) when the pipeline provides snapshot().
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
def _parse_record(raw_line):
    """Decode one JSONL line into a record dict, or return None if unusable.

    None is returned both for text that is not valid JSON and for valid
    JSON that is not an object (a bare list, string, number, ...), so the
    caller can count the line as bad instead of failing on ``.get``.
    """
    try:
        record = json.loads(raw_line)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input.
        return None
    return record if isinstance(record, dict) else None


def main(argv=None):
    """Load a JSONL corpus into the tau-rag pipeline in batches.

    Each non-blank line of the input must be a JSON object with a
    non-empty "text" field; "id" and "metadata" are optional. Records are
    wrapped in ``Document`` and handed to ``pipe.add_documents`` in
    batches of ``--batch-size``. A failing batch is reported and dropped,
    and loading continues with the next one.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on completion, 2 when the input path is
        missing or not a regular file, 3 when ``tau_rag`` cannot be
        imported, 4 when the pipeline has no ``add_documents``.
    """
    p = argparse.ArgumentParser(
        description="Load JSONL corpus directly into tau-rag pipeline")
    p.add_argument("input", help="JSONL file")
    p.add_argument("--preset", default="no_llm",
                   help="tau-rag preset (default: no_llm)")
    p.add_argument("--batch-size", type=int, default=500,
                   help="docs per add_documents call")
    p.add_argument("--max", type=int, default=None,
                   help="stop after N docs")
    p.add_argument("--snapshot",
                   default=os.path.expanduser(
                       "~/tau_snapshot"),
                   help="save pipeline state here for later load")
    args = p.parse_args(argv)

    in_path = Path(args.input).expanduser().resolve()
    # is_file() rather than exists(): a directory would pass exists() and
    # then blow up with a traceback at open() further down.
    if not in_path.is_file():
        reason = "Not a file" if in_path.exists() else "Not found"
        print(f"✗ {reason}: {in_path}", file=sys.stderr)
        return 2

    # Bootstrap tau-rag: put the directory that contains the package
    # (two levels above this script's directory) on sys.path.
    here = Path(__file__).resolve().parent.parent
    pkg_parent = here.parent
    sys.path.insert(0, str(pkg_parent))

    # setdefault() keeps a TAU_RAG_PRESET that is already exported, so the
    # value in the environment (not necessarily --preset) is what is in
    # effect. NOTE(review): presumably get_pipeline() reads this variable
    # — confirm against tau_rag.pipeline.
    effective_preset = os.environ.setdefault("TAU_RAG_PRESET", args.preset)
    if effective_preset != args.preset:
        print(f"⚠ TAU_RAG_PRESET={effective_preset} is already set; "
              f"--preset {args.preset} is ignored.", file=sys.stderr)

    try:
        from tau_rag.core.types import Document
        from tau_rag.pipeline import get_pipeline
    except ImportError as e:
        print(f"✗ Cannot import tau_rag: {e}", file=sys.stderr)
        print(f" Run from tau_rag parent dir, "
              f"or set PYTHONPATH.", file=sys.stderr)
        return 3

    print(f"→ Preset: {effective_preset}")
    print(f"→ Loading {in_path.stat().st_size/1024/1024:.1f}MB "
          f"from {in_path}")
    pipe = get_pipeline()
    if not hasattr(pipe, "add_documents"):
        print(f"✗ Pipeline has no add_documents(); "
              f"preset {effective_preset} doesn't support ingestion.",
              file=sys.stderr)
        return 4

    batch: list = []
    n_added = 0    # docs accepted by add_documents
    n_failed = 0   # docs lost because their batch raised
    n_queued = 0   # docs ever appended to a batch; drives fallback ids
    bad = 0        # lines that were not a JSON object
    t0 = time.time()

    def flush():
        """Send the pending batch to the pipeline and empty it."""
        nonlocal n_added, n_failed
        if not batch:
            return
        try:
            pipe.add_documents(batch)
            n_added += len(batch)
        except Exception as e:
            # Best effort: report the batch, drop it, keep loading.
            n_failed += len(batch)
            print(f"\n batch error: {e}", file=sys.stderr)
        batch.clear()

    # utf-8-sig decodes plain UTF-8 identically but also swallows a
    # leading BOM, which would otherwise make the first record unparseable.
    with open(in_path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            d = _parse_record(line)
            if d is None:
                bad += 1
                continue
            text = d.get("text") or ""
            if not text:
                continue
            # Fallback ids come from a counter that never rewinds, so a
            # failed batch (which leaves n_added unchanged) cannot cause
            # later docs to reuse an id. An explicit null id also falls
            # back instead of becoming the string "None".
            doc_id = d.get("id")
            if doc_id is None:
                doc_id = f"doc-{n_queued}"
            batch.append(Document(
                id=str(doc_id),
                text=text,
                metadata=d.get("metadata") or {},
            ))
            n_queued += 1
            if len(batch) >= args.batch_size:
                flush()
                elapsed = time.time() - t0
                rate = n_added / elapsed if elapsed > 0 else 0
                print(f"\r loaded {n_added:,} docs | "
                      f"{rate:.0f}/s",
                      end="", flush=True)
            # Counts loaded + pending docs, so the cap is hit exactly even
            # when --max is not a multiple of --batch-size.
            if args.max and n_added + len(batch) >= args.max:
                break
    flush()
    print()
    print(f"✓ Loaded {n_added:,} docs into pipeline "
          f"({bad} bad lines skipped)")
    if n_failed:
        print(f"⚠ {n_failed:,} docs were dropped by failed batches",
              file=sys.stderr)

    # Snapshot (only when the pipeline offers snapshot()).
    if args.snapshot and hasattr(pipe, "snapshot"):
        snap = Path(args.snapshot).expanduser().resolve()
        snap.parent.mkdir(parents=True, exist_ok=True)
        try:
            pipe.snapshot(str(snap))
            print(f"✓ Saved snapshot → {snap}")
        except Exception as e:
            print(f"⚠ Snapshot failed: {e}", file=sys.stderr)

    # Quick smoke query; failures are reported but do not change the
    # exit code.
    try:
        test_q = "חוזה"
        res = pipe.query(test_q, k=3)
        docs = res.get("docs", []) or []
        print(f"\n🔎 Smoke query '{test_q}': "
              f"got {len(docs)} results")
        for d in docs[:3]:
            print(f" [{d.id}] {(d.text or '')[:80]}...")
    except Exception as e:
        print(f"⚠ Smoke query failed: {e}", file=sys.stderr)
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())