#!/usr/bin/env python3
"""
Full benchmark on all 55 unique queries from train_pairs.parquet.

For each query:
  • POST to /v1/query (top_k=10)
  • Extract doc_ids from response (handle multiple ID formats)
  • Cross-reference with train_pairs labels
    (0=irrelevant, 1=relevant, 2=highly relevant)
  • Also note when retrieval came from Tier 1 (statute/Kol-Zchut) — these
    docs aren't in train_pairs labels but may still be topically correct.

Outputs per-query verdicts and aggregate accuracy.
"""

import json
import re
import sys
import time
import urllib.request
from collections import Counter, defaultdict
from pathlib import Path

API = "http://127.0.0.1:8000/v1/query"
# NOTE: machine-specific absolute path; edit before running elsewhere.
PARQUET = "/Users/avrahambarzel/Library/Mobile Documents/com~apple~CloudDocs/LawDBHeb/train_pairs.parquet"


def load_relevance():
    """Load the relevance labels from the PARQUET file.

    Returns:
        A ``(rel, n_high)`` tuple where
        ``rel`` maps query → {doc_id (as str): label (as int)} and
        ``n_high`` maps query → number of label-2 docs for that query.
    """
    # Imported lazily so the rest of the module can be imported without
    # pyarrow installed.
    import pyarrow.parquet as pq

    t = pq.read_table(PARQUET, columns=['query', 'doc_id', 'label'])
    rel = defaultdict(dict)
    rows = zip(
        t['query'].to_pylist(),
        t['doc_id'].to_pylist(),
        t['label'].to_pylist(),
    )
    for q, did, lbl in rows:
        # doc ids are normalised to str so they compare equal to the ids
        # produced by extract_doc_ids().
        rel[q][str(did)] = int(lbl)
    n_high = {
        q: sum(1 for lbl in d.values() if lbl == 2)
        for q, d in rel.items()
    }
    return dict(rel), n_high


def call_api(query, top_k=10):
    """POST ``query`` to the API endpoint and return the decoded JSON body.

    Args:
        query: The query text to send.
        top_k: Number of results to request.

    Raises:
        Whatever ``urllib.request.urlopen`` or ``json.loads`` raise on
        connection failure, timeout (60 s) or an undecodable body.
    """
    body = json.dumps(
        {"query": query, "top_k": top_k}, ensure_ascii=False
    ).encode("utf-8")
    req = urllib.request.Request(
        API,
        data=body,
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(req, timeout=60) as resp:
        return json.loads(resp.read().decode("utf-8"))


def extract_doc_ids(response):
    """Pull ``(source, doc_id)`` pairs from each retrieved doc.

    Handles the several ID schemes seen in responses, in this order:

    * ``metadata.parquet_doc_id`` / ``metadata.case_doc_id`` → ``"parquet"``
    * id starting with ``caselaw/`` → ``"caselaw"`` (segment after the prefix)
    * id starting with ``heb_law/`` → ``"statute"`` (last path segment)
    * id starting with ``kolzchut/`` → ``"kolzchut"`` (last path segment)
    * anything else → ``"?"`` with the raw id

    Args:
        response: The decoded JSON response; its ``"docs"`` list is read.

    Returns:
        A list of ``(source, doc_id)`` tuples in response order, with
        ``doc_id`` always a ``str``. A response that is not a dict, or whose
        ``"docs"`` is missing or null, yields an empty list.
    """
    if not isinstance(response, dict):
        return []

    out = []
    # `or []` also covers a "docs" key that is present but null.
    for d in response.get("docs") or []:
        # Tier 2 (parquet): metadata may carry the id directly.
        meta = d.get("metadata") or {}
        case_id = meta.get("parquet_doc_id") or meta.get("case_doc_id")
        if case_id:
            out.append(("parquet", str(case_id)))
            continue

        # Coerce to str: a numeric id would otherwise break startswith().
        rid = str(d.get("id") or d.get("doc_id") or "")
        if rid.startswith("caselaw/"):
            # The prefix guarantees at least two "/"-separated parts.
            out.append(("caselaw", rid.split("/")[1]))
        # Tier 1: statute or Kol-Zchut — not in train_pairs labels
        elif rid.startswith("heb_law/"):
            out.append(("statute", rid.split("/")[-1]))
        elif rid.startswith("kolzchut/"):
            out.append(("kolzchut", rid.split("/")[-1]))
        else:
            out.append(("?", rid))
    return out


def _pct(count, total):
    """Return ``count / total`` as a percentage, or 0.0 when total is zero."""
    return 100 * count / total if total else 0.0


def main():
    """Run every labelled query against the API and print the verdicts.

    Prints one progress line per query followed by an aggregate report.
    All percentages use the total number of queries as the denominator,
    so errored and empty queries count as misses.
    """
    print("Loading train_pairs.parquet …")
    # The per-query label-2 counts are loaded but not used in this report.
    rel, n_high_per_q = load_relevance()
    print(f" {len(rel)} unique queries with relevance labels")
    print()

    # Verdict buckets
    n = len(rel)
    queries = sorted(rel.keys())
    n_top1_high = 0      # top-1 is a label-2 doc
    n_top3_high = 0      # any label-2 in top-3
    n_top10_high = 0
    n_top3_relevant = 0  # any label≥1 in top-3
    n_topical_only = 0   # all top-3 are unlabeled (e.g. statute/kolzchut)
    #                      — could still be right but we can't verify
    n_no_results = 0
    n_errors = 0
    by_class = Counter()  # source mix in top-3 across all queries

    t_start = time.time()
    for i, q in enumerate(queries, 1):
        try:
            r = call_api(q, top_k=10)
        except Exception as e:
            # Benchmark boundary: report the failure and keep going.
            print(f"[{i:>2}] ❌ ERROR: {q[:50]} ({e})")
            n_errors += 1
            continue

        ids = extract_doc_ids(r)
        if not ids:
            n_no_results += 1
            print(f"[{i:>2}] ✗ empty: {q[:50]}")
            continue

        # Lookup labels (None means the doc has no label for this query)
        ql = rel.get(q, {})
        top3 = ids[:3]
        by_class.update(src for src, _ in top3)
        labels_top3 = [ql.get(did) for _, did in top3]
        labels_top10 = [ql.get(did) for _, did in ids[:10]]

        high_in_top3 = 2 in labels_top3
        relevant_in_top3 = any(
            lbl is not None and lbl >= 1 for lbl in labels_top3
        )

        # Verdict
        if labels_top3[0] == 2:
            n_top1_high += 1
        if high_in_top3:
            n_top3_high += 1
        if 2 in labels_top10:
            n_top10_high += 1
        if relevant_in_top3:
            n_top3_relevant += 1
        if all(lbl is None for lbl in labels_top3):
            n_topical_only += 1

        # One-line progress
        if high_in_top3:
            marker = "✓"
        elif relevant_in_top3:
            marker = "~"
        else:
            marker = "✗"
        sources = ",".join(s for s, _ in top3)
        print(f"[{i:>2}] {marker} {q[:55]:<55} src=[{sources}]")

    elapsed = time.time() - t_start

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    print(f" errors: {n_errors}")
    print(f" empty results: {n_no_results}")
    print(f" total time: {elapsed:.1f}s ({elapsed/max(1,n):.2f}s/q)")
    print()
    print(" RELEVANCE @ top-1:")
    print(f" highly relevant (lbl=2): {n_top1_high:>2}/{n} ({_pct(n_top1_high, n):.0f}%)")
    print()
    print(" RELEVANCE @ top-3:")
    print(f" highly relevant in top-3: {n_top3_high:>2}/{n} ({_pct(n_top3_high, n):.0f}%)")
    print(f" any relevant in top-3: {n_top3_relevant:>2}/{n} ({_pct(n_top3_relevant, n):.0f}%)")
    print()
    print(" RELEVANCE @ top-10:")
    print(f" highly relevant in top-10:{n_top10_high:>2}/{n} ({_pct(n_top10_high, n):.0f}%)")
    print()
    print(" TOPICAL coverage:")
    print(f" all top-3 from statute/kolzchut (not in labels): {n_topical_only:>2}/{n}")
    print(f" (could be correct, but unverifiable from train_pairs)")
    print()
    print(" SOURCE MIX in top-3:")
    for src, c in by_class.most_common():
        print(f" {src:>10}: {c}")


if __name__ == "__main__":
    main()