#!/usr/bin/env python3
"""
Benchmark the API against 10 representative queries from train_pairs.parquet.

Each query has labeled relevant docs (label=2 = highly relevant). We measure:
  - top-1 hit rate: did we return a label-2 doc as the first answer?
  - top-3 hit rate: any label-2 in our top-3?
  - has-relevant:   any label≥1 in our top-3?

This validates whether the case-law rulings we just loaded actually help
on real-world legal queries.
"""
import json
import urllib.request
from collections import defaultdict
from pathlib import Path

API = "http://127.0.0.1:8000/v1/query"

# 10 representative queries — mix of labor/consumer/contract/tort
SAMPLE = [
    "אחריות יצרן לטלפון סלולרי",
    "אחריות מעביד לנזקי עובד",
    "ביטול עסקה באינטרנט תוך 14 יום",
    "הטעיה בפרסום לפי חוק הגנת הצרכן",
    "הפרת חובה חקוקה",
    "טעות יסודית בחוזה",
    "חוזה אחיד ותנאים מקפחים",
    "פיטורין בהריון לפי חוק עבודת נשים",
    "רשלנות רפואית בלידה",
    "תרמית במשא ומתן",
]

PARQUET = "/Users/avrahambarzel/Library/Mobile Documents/com~apple~CloudDocs/LawDBHeb/train_pairs.parquet"


def load_relevance():
    """Load the relevance judgments from PARQUET.

    Reads the 'query', 'doc_id' and 'label' columns and groups them by query.

    Returns:
        A defaultdict mapping query -> {str(doc_id): label}. If the same
        (query, doc_id) pair occurs more than once, the last row wins.
    """
    # Imported here rather than at module level so that importing this
    # module does not require pyarrow.
    import pyarrow.parquet as pq

    t = pq.read_table(PARQUET, columns=['query', 'doc_id', 'label'])
    rel = defaultdict(dict)
    rows = zip(
        t['query'].to_pylist(),
        t['doc_id'].to_pylist(),
        t['label'].to_pylist(),
    )
    for q, did, lbl in rows:
        # doc_ids are stringified so they compare equal to the ids parsed
        # out of the API's "caselaw/<doc_id>/chunk_N" identifiers.
        rel[q][str(did)] = lbl
    return rel


def call_api(query, top_k=10):
    """POST a query to API and return the decoded JSON response.

    Args:
        query: The query text; sent as UTF-8 JSON without ASCII escaping.
        top_k: Value sent as the request's "top_k" field.

    Returns:
        The parsed JSON body of the response.

    Raises:
        urllib.error.URLError: on connection or HTTP errors, or timeout (60 s).
        ValueError: if the response body is not valid JSON.
    """
    body = json.dumps(
        {"query": query, "top_k": top_k}, ensure_ascii=False
    ).encode("utf-8")
    req = urllib.request.Request(
        API,
        data=body,
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(req, timeout=60) as resp:
        return json.loads(resp.read().decode("utf-8"))


def get_doc_ids_from_response(r):
    """Extract case doc_ids from returned docs (format 'caselaw/{doc_id}/chunk_N').

    Args:
        r: Decoded API response; its "docs" entry is read as a list of dicts
            each carrying an "id" key. A missing or null "docs" is treated
            as an empty list.

    Returns:
        A list with one entry per returned doc, in rank order: the doc_id
        string for case-law chunks, or None for anything else, so that list
        positions still correspond to result ranks.
    """
    out = []
    for d in r.get("docs") or []:
        rid = d.get("id") or ""
        # Our caselaw IDs look like: "caselaw/<doc_id>/chunk_N"
        if rid.startswith("caselaw/"):
            # The prefix guarantees at least two "/"-separated parts.
            out.append(rid.split("/")[1])
        else:
            # Statute IDs (e.g., "heb_law/contracts_general/section_2")
            # don't appear in train_pairs labels — keep a None placeholder.
            out.append(None)
    return out


def _score_labels(labels):
    """Return (top1_hit, top3_high, top3_any) for a ranked list of labels.

    Entries are relevance labels, or None for results without a judgment.
    """
    top3 = labels[:3]
    top1_hit = bool(labels) and labels[0] == 2
    top3_high = any(lbl == 2 for lbl in top3)
    top3_any = any(lbl is not None and lbl >= 1 for lbl in top3)
    return top1_hit, top3_high, top3_any


def _pct(count, total):
    """Return count as a percentage of total, or 0.0 when total is zero."""
    return 100 * count / total if total else 0.0


def main():
    """Run every SAMPLE query through the API and print hit-rate statistics.

    Queries whose API call fails are reported and counted as misses; the
    aggregate denominators are always len(SAMPLE).
    """
    print("Loading train_pairs.parquet...")
    rel = load_relevance()
    print(f"Found relevance for {len(rel)} queries")
    print()
    print("=" * 78)
    print(" BENCHMARK 10/55 — does case-law help?")
    print("=" * 78)

    n = len(SAMPLE)
    n_top1 = n_top3_high = n_top3_any = 0
    n_failed = 0
    for q in SAMPLE:
        try:
            r = call_api(q, top_k=10)
        except Exception as e:
            # Deliberately broad: one failing query (network error, bad JSON,
            # timeout) must not abort the rest of the benchmark.
            n_failed += 1
            print(f"\n❌ {q}\n {e}")
            continue

        # "answer" may be absent or null in the response.
        ans = (r.get("answer") or "")[:120]
        doc_ids = get_doc_ids_from_response(r)
        judged = rel.get(q, {})
        labels = [judged.get(d) for d in doc_ids]

        top1_hit, top3_high, top3_any = _score_labels(labels)
        n_top1 += int(top1_hit)
        n_top3_high += int(top3_high)
        n_top3_any += int(top3_any)

        marker = "✓" if top3_high else ("~" if top3_any else "✗")
        print(f"\n{marker} {q}")
        print(f" top-3 labels: {labels[:3]} (relevant_in_corpus: {len(judged):,})")
        print(f" answer head: {ans}")

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" top-1 highly relevant: {n_top1}/{n} ({_pct(n_top1, n):.0f}%)")
    print(f" top-3 has highly rel: {n_top3_high}/{n} ({_pct(n_top3_high, n):.0f}%)")
    print(f" top-3 has any relevant: {n_top3_any}/{n} ({_pct(n_top3_any, n):.0f}%)")
    if n_failed:
        print(f" API errors: {n_failed}/{n} (counted as misses)")


if __name__ == "__main__":
    main()