| |
"""
Full benchmark over every unique query in train_pairs.parquet (55 at the
time of writing).

For each query:
  • POST it to /v1/query (top_k=10).
  • Extract doc_ids from the response (several retriever ID formats are handled).
  • Cross-reference them with the train_pairs labels
    (0 = irrelevant, 1 = relevant, 2 = highly relevant).
  • Also note when retrieval came from Tier 1 (statute / Kol-Zchut): these
    docs are not in the train_pairs labels but may still be topically correct.

Prints a per-query verdict and aggregate accuracy.
"""
| import json |
| import re |
| import sys |
| import time |
| import urllib.request |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
# Retrieval endpoint under test; 127.0.0.1 means a server on this machine.
API = "http://127.0.0.1:8000/v1/query"
# Relevance labels file; load_relevance() reads its query, doc_id and label columns.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
PARQUET = "/Users/avrahambarzel/Library/Mobile Documents/com~apple~CloudDocs/LawDBHeb/train_pairs.parquet"
|
|
|
|
def load_relevance():
    """Read the graded relevance labels from the train_pairs parquet file.

    Returns a 2-tuple:
      rel:          {query: {doc_id (as str): label (as int)}}. If a
                    (query, doc_id) pair occurs in several rows, the last
                    row read wins.
      n_high_per_q: {query: number of docs labelled 2 for that query}
    """
    import pyarrow.parquet as pq

    table = pq.read_table(PARQUET, columns=['query', 'doc_id', 'label'])
    columns = (table[name].to_pylist() for name in ('query', 'doc_id', 'label'))

    rel = defaultdict(dict)
    for query, doc_id, label in zip(*columns):
        # doc_ids are normalised to str so they compare equal to the
        # string IDs pulled out of API responses.
        rel[query][str(doc_id)] = int(label)

    n_high = {}
    for query, labels in rel.items():
        n_high[query] = list(labels.values()).count(2)
    return dict(rel), n_high
|
|
|
|
def call_api(query, top_k=10):
    """POST *query* to the retrieval endpoint and return the decoded JSON reply.

    The body is UTF-8 JSON built with ensure_ascii=False, so non-ASCII text
    (e.g. Hebrew) is sent verbatim rather than as \\u escapes. The request
    times out after 60 seconds. urllib network/HTTP errors and JSON decoding
    errors propagate to the caller.
    """
    payload = {"query": query, "top_k": top_k}
    request = urllib.request.Request(
        API,
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(request, timeout=60) as reply:
        raw = reply.read()
    return json.loads(raw.decode("utf-8"))
|
|
|
|
def _first_id(mapping, *keys):
    """Return the first value under *keys* that is neither None nor "", else None.

    Unlike an ``a or b`` chain, this keeps falsy-but-valid IDs such as 0.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None and value != "":
            return value
    return None


def extract_doc_ids(response):
    """Pull one (source, doc_id) pair from each retrieved doc.

    Different retrievers label their hits differently, so these ID schemes
    are tried in order:

    1. ``metadata.parquet_doc_id`` / ``metadata.case_doc_id`` -> ("parquet", id)
    2. ``id``/``doc_id`` of the form ``caselaw/<id>/...``      -> ("caselaw", <id>)
    3. ``heb_law/.../<id>``                                    -> ("statute", <id>)
    4. ``kolzchut/.../<id>``                                   -> ("kolzchut", <id>)
    5. anything else, including a missing id                   -> ("?", raw id or "")

    IDs are always returned as strings, matching the ``str(doc_id)`` keys
    built by load_relevance(). Output order follows ``response["docs"]``.
    A missing or null ``docs`` / ``metadata`` is treated as empty.
    """
    out = []
    for doc in response.get("docs") or []:
        meta = doc.get("metadata") or {}
        case_id = _first_id(meta, "parquet_doc_id", "case_doc_id")
        if case_id is not None:
            out.append(("parquet", str(case_id)))
            continue

        raw = _first_id(doc, "id", "doc_id")
        # Coerce to str: some retrievers may return numeric ids, which
        # would otherwise crash on .startswith().
        rid = "" if raw is None else str(raw)
        if rid.startswith("caselaw/"):
            # "caselaw/<case_id>/<chunk...>" -> the case id is segment 1.
            out.append(("caselaw", rid.split("/")[1]))
        elif rid.startswith("heb_law/"):
            out.append(("statute", rid.split("/")[-1]))
        elif rid.startswith("kolzchut/"):
            out.append(("kolzchut", rid.split("/")[-1]))
        else:
            out.append(("?", rid))
    return out
|
|
|
|
# Sources whose IDs do not share the train_pairs doc_id namespace. The module
# docstring notes statute / Kol-Zchut docs are not in the labels, so their
# trailing ID segment must never be looked up (it could collide with an
# unrelated case doc_id and be credited with that case's label).
_UNLABELLED_SOURCES = frozenset({"statute", "kolzchut"})


def _pct(k, total):
    """Return 100*k/total, or 0.0 when total is 0 (e.g. an empty label file)."""
    return 100 * k / total if total else 0.0


def main():
    """Run every labelled query through the API and print a relevance report.

    For each query, the top-3 and top-10 hits are looked up in that query's
    train_pairs labels. One line is printed per query with a marker:
      ✓  a label-2 (highly relevant) doc is in the top-3
      ~  otherwise, a doc labelled >= 1 is in the top-3
      ✗  neither
    The per-query lines are followed by aggregate hit rates. Percentages use
    the total number of queries as the denominator, so queries that errored
    or returned nothing count as misses.
    """
    print("Loading train_pairs.parquet …")
    rel, n_high_per_q = load_relevance()
    print(f" {len(rel)} unique queries with relevance labels")
    print()

    n = len(rel)
    queries = sorted(rel)

    n_top1_high = 0
    n_top3_high = 0
    n_top10_high = 0
    n_top3_relevant = 0
    n_topical_only = 0
    n_no_results = 0
    n_errors = 0
    by_class = Counter()

    # perf_counter is monotonic, so clock adjustments can't skew the timing.
    t_start = time.perf_counter()
    for i, q in enumerate(queries, 1):
        try:
            r = call_api(q, top_k=10)
        except Exception as e:  # benchmark boundary: record the failure, keep going
            print(f"[{i:>2}] ❌ ERROR: {q[:50]} ({e})")
            n_errors += 1
            continue
        ids = extract_doc_ids(r)
        if not ids:
            n_no_results += 1
            print(f"[{i:>2}] ✗ empty: {q[:50]}")
            continue

        ql = rel.get(q, {})
        # None means "no label for this query", as opposed to label 0.
        labels = [None if src in _UNLABELLED_SOURCES else ql.get(did)
                  for src, did in ids[:10]]
        labels_top3 = labels[:3]
        by_class.update(src for src, _ in ids[:3])

        top3_high = any(lbl == 2 for lbl in labels_top3)
        top3_relevant = any(lbl is not None and lbl >= 1 for lbl in labels_top3)

        if labels_top3[0] == 2:
            n_top1_high += 1
        if top3_high:
            n_top3_high += 1
        if any(lbl == 2 for lbl in labels):
            n_top10_high += 1
        if top3_relevant:
            n_top3_relevant += 1
        if all(lbl is None for lbl in labels_top3):
            n_topical_only += 1

        marker = "✓" if top3_high else ("~" if top3_relevant else "✗")
        sources = ",".join(s for s, _ in ids[:3])
        print(f"[{i:>2}] {marker} {q[:55]:<55} src=[{sources}]")

    elapsed = time.perf_counter() - t_start
    n_eval = n - n_errors
    # Queries that have at least one label-2 doc: the ceiling for the lbl=2 rows.
    n_with_high = sum(1 for c in n_high_per_q.values() if c > 0)

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    print(f" errors: {n_errors}")
    print(f" evaluated: {n_eval}")
    print(f" empty results: {n_no_results}")
    print(f" total time: {elapsed:.1f}s ({elapsed/max(1,n):.2f}s/q)")
    print(f" queries with a lbl=2 doc: {n_with_high:>2}/{n} (upper bound for the lbl=2 rows below)")
    print()
    print(" RELEVANCE @ top-1:")
    print(f" highly relevant (lbl=2): {n_top1_high:>2}/{n} ({_pct(n_top1_high, n):.0f}%)")
    print()
    print(" RELEVANCE @ top-3:")
    print(f" highly relevant in top-3: {n_top3_high:>2}/{n} ({_pct(n_top3_high, n):.0f}%)")
    print(f" any relevant in top-3: {n_top3_relevant:>2}/{n} ({_pct(n_top3_relevant, n):.0f}%)")
    print()
    print(" RELEVANCE @ top-10:")
    print(f" highly relevant in top-10:{n_top10_high:>2}/{n} ({_pct(n_top10_high, n):.0f}%)")
    print()
    print(" TOPICAL coverage:")
    print(f" no top-3 doc labelled for the query (e.g. statute/kolzchut): {n_topical_only:>2}/{n}")
    print(" (could be correct, but unverifiable from train_pairs)")
    print()
    print(" SOURCE MIX in top-3:")
    for src, c in by_class.most_common():
        print(f" {src:>10}: {c}")
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|