| """Load a JSONL corpus DIRECTLY into tau-rag's pipeline. |
| |
No HTTP, no 8GB upload: this reads the JSONL locally and calls
pipeline.add_documents() in batches. Running it in the same
process as the server is not ideal; instead, run it before
the server starts, or load into a persistent store.
| |
| For the pilot: this populates the in-memory pipeline so that |
| `/v1/query` returns real results. |
| |
| Usage: |
| python -m scripts.load_corpus ~/tau_pilot.jsonl \\ |
| --preset hebrew_legal_prod |
| |
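After loading, the pipeline state is saved to the --snapshot path (default: ~/tau_snapshot) when the pipeline exposes snapshot(); the pipeline may also persist to runtime/snapshots/ if that is enabled.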
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
|
|
def _parse_jsonl_line(line):
    """Decode one JSONL line; return the object, or None if it is unusable.

    None is returned both for malformed JSON and for valid JSON that is
    not an object (e.g. a bare list or number), so the caller can count
    the line as bad instead of failing later on ``.get()``.
    """
    try:
        record = json.loads(line)
    except (ValueError, RecursionError):
        # json.JSONDecodeError subclasses ValueError; RecursionError guards
        # against pathologically nested input.
        return None
    return record if isinstance(record, dict) else None


def _record_fields(record, fallback_id):
    """Return ``(doc_id, text, metadata)`` for *record*, or None if it has no text.

    A missing or null ``id`` uses *fallback_id* (a null id would otherwise
    become the literal string "None" and collide across records); a
    missing or null ``metadata`` becomes an empty dict.
    """
    text = record.get("text") or ""
    if not text:
        return None
    raw_id = record.get("id")
    doc_id = fallback_id if raw_id is None else str(raw_id)
    return doc_id, text, record.get("metadata") or {}


def main(argv=None):
    """Stream a JSONL corpus into the tau-rag pipeline in batches.

    Each non-blank line should be a JSON object with a ``text`` field and
    optional ``id`` and ``metadata`` fields. Lines that are not valid JSON
    objects are counted as bad and skipped, and records without text are
    skipped. After loading, the pipeline is snapshotted (if it exposes
    ``snapshot()``) and a smoke query is run; failures in either step are
    reported, not raised.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 after loading (even if some batches failed),
        2 if the input file does not exist, 3 if ``tau_rag`` cannot be
        imported, 4 if the pipeline has no ``add_documents()``.
    """
    p = argparse.ArgumentParser(
        description="Load JSONL corpus directly into tau-rag pipeline")
    p.add_argument("input", help="JSONL file")
    p.add_argument("--preset", default=None,
                   help="tau-rag preset (default: $TAU_RAG_PRESET, "
                        "else no_llm)")
    p.add_argument("--batch-size", type=int, default=500,
                   help="docs per add_documents call")
    p.add_argument("--max", type=int, default=None,
                   help="stop after N docs")
    p.add_argument("--snapshot",
                   default=os.path.expanduser("~/tau_snapshot"),
                   help="save pipeline state here for later load")
    args = p.parse_args(argv)

    in_path = Path(args.input).expanduser().resolve()
    if not in_path.exists():
        print(f"✗ Not found: {in_path}", file=sys.stderr)
        return 2

    # Make tau_rag importable when run as a script. Presumably this file
    # lives at <pkg_parent>/tau_rag/scripts/, so the grandparent's parent
    # is the directory that contains the tau_rag package.
    here = Path(__file__).resolve().parent.parent
    sys.path.insert(0, str(here.parent))

    # An explicit --preset wins; otherwise keep any TAU_RAG_PRESET already
    # in the environment, falling back to no_llm. Set before importing
    # tau_rag in case the pipeline reads it at import time.
    preset = args.preset or os.environ.get("TAU_RAG_PRESET") or "no_llm"
    os.environ["TAU_RAG_PRESET"] = preset
    try:
        from tau_rag.core.types import Document
        from tau_rag.pipeline import get_pipeline
    except ImportError as e:
        print(f"✗ Cannot import tau_rag: {e}", file=sys.stderr)
        print(" Run from tau_rag parent dir, or set PYTHONPATH.",
              file=sys.stderr)
        return 3

    print(f"→ Preset: {preset}")
    print(f"→ Loading {in_path.stat().st_size / 1024 / 1024:.1f}MB "
          f"from {in_path}")

    pipe = get_pipeline()
    if not hasattr(pipe, "add_documents"):
        print(f"✗ Pipeline has no add_documents(); "
              f"preset {preset} doesn't support ingestion.",
              file=sys.stderr)
        return 4

    batch = []
    n_seen = 0      # docs queued so far; also numbers fallback ids uniquely
    n_added = 0     # docs the pipeline accepted
    n_failed = 0    # docs dropped because add_documents() raised
    n_no_text = 0   # valid records skipped for having no text
    bad = 0         # lines that were not valid JSON objects
    t0 = time.monotonic()

    def flush():
        """Send the pending batch to the pipeline and report progress."""
        nonlocal batch, n_added, n_failed
        if not batch:
            return
        # Give the pipeline its own list and start a fresh one, so nothing
        # the pipeline might retain is mutated by later appends.
        pending, batch = batch, []
        try:
            pipe.add_documents(pending)
        except Exception as e:  # best-effort: report and keep loading
            n_failed += len(pending)
            print(f"\n batch error: {e}", file=sys.stderr)
        else:
            n_added += len(pending)
        elapsed = time.monotonic() - t0
        rate = n_added / elapsed if elapsed > 0 else 0
        print(f"\r loaded {n_added:,} docs | {rate:.0f}/s",
              end="", flush=True)

    # utf-8-sig transparently drops a leading BOM, which would otherwise
    # make the first line fail to parse.
    with open(in_path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = _parse_jsonl_line(line)
            if record is None:
                bad += 1
                continue
            fields = _record_fields(record, f"doc-{n_seen}")
            if fields is None:
                n_no_text += 1
                continue
            doc_id, text, metadata = fields
            batch.append(Document(id=doc_id, text=text, metadata=metadata))
            n_seen += 1
            if len(batch) >= args.batch_size:
                flush()
            if args.max and n_seen >= args.max:
                break
    flush()
    print()
    print(f"✓ Loaded {n_added:,} docs into pipeline "
          f"({bad} bad lines skipped)")
    if n_no_text:
        print(f"  ({n_no_text:,} records without text skipped)")
    if n_failed:
        print(f"⚠ {n_failed:,} docs lost to batch errors (see above)",
              file=sys.stderr)

    if args.snapshot and hasattr(pipe, "snapshot"):
        snap = Path(args.snapshot).expanduser().resolve()
        try:
            snap.parent.mkdir(parents=True, exist_ok=True)
            pipe.snapshot(str(snap))
            print(f"✓ Saved snapshot → {snap}")
        except Exception as e:
            print(f"⚠ Snapshot failed: {e}", file=sys.stderr)

    # Smoke query: "חוזה" is Hebrew for "contract", matching the Hebrew
    # legal preset in the usage example. NOTE(review): assumes query()
    # returns a dict whose "docs" entry holds objects with .id/.text;
    # any mismatch is reported rather than raised.
    try:
        test_q = "חוזה"
        res = pipe.query(test_q, k=3)
        docs = res.get("docs", []) or []
        print(f"\n🔎 Smoke query '{test_q}': "
              f"got {len(docs)} results")
        for d in docs[:3]:
            print(f" [{d.id}] {(d.text or '')[:80]}...")
    except Exception as e:
        print(f"⚠ Smoke query failed: {e}", file=sys.stderr)

    return 0
|
|
|
|
if __name__ == "__main__":
    # main() returns an integer exit code; SystemExit hands it to the shell.
    raise SystemExit(main())
|
|