legal-eye / tau_rag /scripts /benchmark_55_full.py
Legal-i's picture
Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
3be54c6 verified
Raw
History Blame Contribute Delete
6.14 kB
#!/usr/bin/env python3
"""
Full benchmark on all 55 unique queries from train_pairs.parquet.
For each query:
• POST to /v1/query (top_k=10)
• Extract doc_ids from response (handle multiple ID formats)
• Cross-reference with train_pairs labels (0=irrelevant, 1=relevant, 2=highly relevant)
• Also note when retrieval came from Tier 1 (statute/Kol-Zchut) — these
docs aren't in train_pairs labels but may still be topically correct.
Outputs per-query verdicts and aggregate accuracy.
"""
import json
import os
import re
import sys
import time
import urllib.request
from collections import Counter, defaultdict
from pathlib import Path
# Retrieval endpoint under test. Set LEGAL_EYE_API_URL to benchmark a
# different server; the default is the local dev server.
API = os.environ.get("LEGAL_EYE_API_URL", "http://127.0.0.1:8000/v1/query")
# Labelled (query, doc_id, label) pairs. Set TRAIN_PAIRS_PARQUET when the file
# is not at the original author's iCloud location (the default below).
PARQUET = os.environ.get(
    "TRAIN_PAIRS_PARQUET",
    "/Users/avrahambarzel/Library/Mobile Documents/com~apple~CloudDocs/LawDBHeb/train_pairs.parquet",
)
def load_relevance(path=None):
    """Load graded relevance labels from the train_pairs parquet file.

    Args:
        path: Parquet file to read. ``None`` (the default) means the
            module-level ``PARQUET`` path. The file must provide ``query``,
            ``doc_id`` and ``label`` columns.

    Returns:
        ``(rel, n_high_per_q)`` where

        * ``rel`` maps each query to ``{doc_id: label}``. ``doc_id`` is
          stringified and ``label`` is converted to ``int``
          (0=irrelevant, 1=relevant, 2=highly relevant). If a
          (query, doc_id) pair appears on several rows, the last row wins.
        * ``n_high_per_q`` maps each query to its number of label-2 docs.
    """
    # Local import keeps pyarrow out of module import time.
    import pyarrow.parquet as pq

    table = pq.read_table(PARQUET if path is None else path,
                          columns=["query", "doc_id", "label"])
    rel = defaultdict(dict)
    for query, doc_id, label in zip(table["query"].to_pylist(),
                                    table["doc_id"].to_pylist(),
                                    table["label"].to_pylist()):
        # str() so lookups match the string ids produced from API responses.
        rel[query][str(doc_id)] = int(label)
    n_high = {query: sum(1 for label in labels.values() if label == 2)
              for query, labels in rel.items()}
    return dict(rel), n_high
def call_api(query, top_k=10):
    """Send *query* to the ``API`` endpoint and return the decoded JSON reply.

    The body is ``{"query": ..., "top_k": ...}`` encoded as UTF-8 JSON with
    non-ASCII characters (e.g. Hebrew) written as-is rather than escaped.
    Because a request body is attached, urllib issues a POST. The request
    times out after 60 seconds. Network, HTTP and JSON-decoding errors are
    not caught here and propagate to the caller.
    """
    payload = {"query": query, "top_k": top_k}
    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    headers = {"Content-Type": "application/json; charset=utf-8"}
    request = urllib.request.Request(API, data=data, headers=headers)
    with urllib.request.urlopen(request, timeout=60) as resp:
        raw = resp.read()
    return json.loads(raw.decode("utf-8"))
def _first_present(*values):
    """Return the first value that is neither ``None`` nor ``""``, else ``None``."""
    for value in values:
        if value is not None and value != "":
            return value
    return None


def extract_doc_ids(response):
    """Return ``(source, doc_id)`` pairs for every doc in an API response.

    Different retrievers label their hits differently, so several ID schemes
    are recognised, checked in this order:

    * ``metadata.parquet_doc_id`` / ``metadata.case_doc_id`` -> ``"parquet"``
    * ``id``/``doc_id`` shaped ``caselaw/<doc_id>/<chunk_id>`` -> ``"caselaw"``
      (only the ``<doc_id>`` segment is kept)
    * ``heb_law/...`` -> ``"statute"``, ``kolzchut/...`` -> ``"kolzchut"``.
      These are Tier 1 sources, which are not in the train_pairs labels;
      the last path segment is kept.
    * anything else -> ``"?"`` with the raw id

    Args:
        response: Decoded JSON body from the query endpoint. Hits are read
            from its ``"docs"`` list; a missing or null ``"docs"`` yields
            an empty result.

    Returns:
        One ``(source, doc_id)`` tuple per retrieved doc, in response order.
        ``doc_id`` is always a ``str`` so it can be looked up directly in
        the string-keyed relevance labels.
    """
    out = []
    for d in response.get("docs") or []:
        meta = d.get("metadata") or {}
        # Only missing/empty values are skipped, so a legitimate id of 0 is
        # not mistaken for "absent" (which a plain `or` chain would do).
        case_id = _first_present(meta.get("parquet_doc_id"),
                                 meta.get("case_doc_id"))
        if case_id is not None:
            out.append(("parquet", str(case_id)))
            continue

        rid = _first_present(d.get("id"), d.get("doc_id"))
        # Coerce to str: a numeric id would otherwise crash on .startswith().
        rid = "" if rid is None else str(rid)
        if rid.startswith("caselaw/"):
            # Tier 2: "caselaw/<doc_id>/<chunk_id>" -> "<doc_id>".
            out.append(("caselaw", rid.split("/")[1]))
        elif rid.startswith("heb_law/"):
            out.append(("statute", rid.split("/")[-1]))
        elif rid.startswith("kolzchut/"):
            out.append(("kolzchut", rid.split("/")[-1]))
        else:
            out.append(("?", rid))
    return out
def _pct(k, d):
    """Return ``k/d`` as a whole-number percentage string, or ``"n/a"`` when ``d`` is 0."""
    return f"{100 * k / d:.0f}%" if d else "n/a"


def main():
    """Benchmark every labelled query against the API and print a report.

    Each query in train_pairs (sorted, for a stable order) is sent to the
    API with ``top_k=10``. The returned doc ids are looked up in that
    query's labels, and a one-line verdict is printed:

    * ``✓``: a label-2 (highly relevant) doc is in the top 3
    * ``~``: no label-2 doc, but a label-1+ doc is in the top 3
    * ``✗``: nothing labelled relevant in the top 3, or no results at all
    * ``❌``: the request itself failed

    Relevance rates are reported over *evaluated* queries (total minus
    request errors), so server or transport failures are not counted as
    retrieval misses. Queries with empty results still count as misses.
    """
    print("Loading train_pairs.parquet …")
    rel, n_high_per_q = load_relevance()
    print(f" {len(rel)} unique queries with relevance labels")
    print()

    # Source tags from extract_doc_ids whose docs never carry train_pairs labels.
    tier1 = {"statute", "kolzchut"}
    queries = sorted(rel)
    n = len(queries)
    stats = Counter()     # verdict buckets, keyed by metric name
    by_class = Counter()  # source mix in top-3 across all queries

    t_start = time.time()
    for i, q in enumerate(queries, 1):
        try:
            r = call_api(q, top_k=10)
        except Exception as e:  # per-query boundary: report and keep going
            print(f"[{i:>2}] ❌ ERROR: {q[:50]} ({e})")
            stats["errors"] += 1
            continue
        ids = extract_doc_ids(r)
        if not ids:
            stats["no_results"] += 1
            print(f"[{i:>2}] ✗ empty: {q[:50]}")
            continue

        ql = rel[q]
        top3 = ids[:3]
        # None means "this doc has no label for this query".
        labels_top3 = [ql.get(did) for _, did in top3]
        labels_top10 = [ql.get(did) for _, did in ids[:10]]
        by_class.update(src for src, _ in top3)

        high3 = 2 in labels_top3
        relevant3 = any(lbl is not None and lbl >= 1 for lbl in labels_top3)
        stats["top1_high"] += labels_top3[0] == 2
        stats["top3_high"] += high3
        stats["top10_high"] += 2 in labels_top10
        stats["top3_relevant"] += relevant3
        # Source-based, matching the report wording: every top-3 hit is Tier 1.
        stats["topical_only"] += all(src in tier1 for src, _ in top3)
        # Label-based: no top-3 hit is labelled for this query. This also
        # catches case-law docs that were simply never annotated.
        stats["unlabeled_top3"] += all(lbl is None for lbl in labels_top3)

        marker = "✓" if high3 else ("~" if relevant3 else "✗")
        sources = ",".join(src for src, _ in top3)
        print(f"[{i:>2}] {marker} {q[:55]:<55} src=[{sources}]")
    elapsed = time.time() - t_start

    n_eval = n - stats["errors"]
    n_with_high = sum(1 for q in queries if n_high_per_q[q] > 0)

    def report_rate(label, key):
        # One "count/evaluated (pct)" line; safe when nothing was evaluated.
        k = stats[key]
        print(f" {label}{k:>2}/{n_eval} ({_pct(k, n_eval)})")

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    print(f" errors: {stats['errors']}")
    print(f" evaluated: {n_eval} (denominator for the rates below)")
    print(f" empty results: {stats['no_results']}")
    print(f" queries with any lbl=2 label: {n_with_high}")
    print(f" total time: {elapsed:.1f}s ({elapsed / max(1, n):.2f}s/q)")
    print()
    print(" RELEVANCE @ top-1:")
    report_rate("highly relevant (lbl=2): ", "top1_high")
    print()
    print(" RELEVANCE @ top-3:")
    report_rate("highly relevant in top-3: ", "top3_high")
    report_rate("any relevant in top-3: ", "top3_relevant")
    print()
    print(" RELEVANCE @ top-10:")
    report_rate("highly relevant in top-10:", "top10_high")
    print()
    print(" TOPICAL coverage:")
    report_rate("all top-3 from statute/kolzchut (not in labels): ", "topical_only")
    report_rate("all top-3 unlabeled for the query (any source): ", "unlabeled_top3")
    print(" (could be correct, but unverifiable from train_pairs)")
    print()
    print(" SOURCE MIX in top-3:")
    for src, c in by_class.most_common():
        print(f" {src:>10}: {c}")


if __name__ == "__main__":
    main()