legal-eye / tau_rag /scripts /check_bm25_quality.py
Legal-i's picture
Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
3be54c6 verified
Raw
History Blame Contribute Delete
4.64 kB
#!/usr/bin/env python3
"""
BM25 retrieval quality probe — does the retriever find the RIGHT doc
in top-K for each diagnostic query? This determines whether a reranker
can possibly succeed (reranker is only as good as its candidate pool).
For each test query: ask the API for top-30 docs, then check whether
the doc containing the expected anchor term shows up in top-1 / top-3 /
top-10 / top-30. The anchor is the ground-truth section/concept that
the teacher answer cites.
"""
import json
import re
import urllib.request
from pathlib import Path
# Query endpoint that call() POSTs to (loopback address, port 8000).
API = "http://127.0.0.1:8000/v1/query"
# Diagnostic cases as (query, anchor_regex) pairs. A retrieved doc counts
# as "the right one" when anchor_regex matches anywhere in its text
# (hit_rank uses re.search). Section-style anchors accept any of:
#   - "KL" + at most one arbitrary character + the section number
#     (intended for explicit "[KL§N]" markers),
#   - the Hebrew word for "section" followed by the number,
#   - the section's hallmark term.
# NOTE(review): the numeric alternatives have no trailing digit boundary,
# so the pattern for section 39 also matches a reference to section 390 —
# confirm whether that looseness is intended.
TESTS = [
("מה אומר סעיף 39 לחוק החוזים?", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
("מה אומר סעיף 12 לחוק החוזים?", r"(KL.?12|סעיף\s*12|תום\s*לב\s*במשא)"),
("מה אומר סעיף 30 לחוק החוזים?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת\s*הציבור)"),
("מה זה גמירת דעת?", r"גמירת\s*דעת"),
("מה זה מסויימות?", r"מסוימות|מסויימות"),
("הסבר על קיום בתום לב", r"(KL.?39|סעיף\s*39|תום\s*לב|קיום)"),
("מה התרופות במכר פגום?", r"(תרופות|אכיפה|מכר|פגם)"),
("מתי קונה רשאי לבטל מכר?", r"(ביטול|מכר|הפרה\s*יסודית|ארכה)"),
("מה ההבדל בין שכירות למכירה?", r"(שכירות|מכר|חוק\s*המכר|חוק\s*השכירות)"),
# Control query: the anchor is a placeholder that is not expected to occur
# in any doc, so this row is meant to be counted as a miss.
("שנשר25512551", r"(NOMATCH_GIBBERISH)"), # no anchor — can't pass
("הסבר את חובת תום הלב בחוזים", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
("איך מתבטל חוזה פסול?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת)"),
]
def call(query, top_k=30):
    """Send one retrieval request to ``API`` and return the parsed reply.

    The request body is the JSON object ``{"query": ..., "top_k": ...}``
    encoded as UTF-8 (non-ASCII characters are sent unescaped). Supplying
    ``data`` makes urllib issue a POST. The response body is decoded as
    UTF-8 and parsed as JSON.

    Nothing is caught here: errors raised by ``urlopen`` (including the
    30-second timeout) or by JSON decoding propagate to the caller.
    """
    payload = json.dumps({"query": query, "top_k": top_k}, ensure_ascii=False)
    request = urllib.request.Request(
        API,
        data=payload.encode("utf-8"),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))
def hit_rank(docs, anchor_re):
    """Find the position of the first doc whose text matches `anchor_re`.

    Each doc is read through ``doc.get("text")``; a missing or falsy text
    is treated as the empty string. Matching uses ``re.search``, so the
    anchor may occur anywhere in the text.

    Returns the 1-indexed rank of the first matching doc, or 0 when no doc
    in `docs` matches (including when `docs` is empty).
    """
    matching_ranks = (
        position
        for position, doc in enumerate(docs, start=1)
        if re.search(anchor_re, doc.get("text") or "")
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matching_ranks, 0)
def main():
    """Run every diagnostic query and print per-query ranks plus a verdict.

    For each ``(query, anchor)`` pair in ``TESTS`` the API is asked for the
    top-30 docs, and the 1-indexed rank of the first doc matching the anchor
    is recorded (0 means no returned doc matched). Queries whose API call
    raises are reported, counted, and left out of the aggregate.

    After the loop, hit counts at top-1 / top-3 / top-10 / top-30 are
    printed, followed by a coarse verdict on reranker viability. If no
    query completed at all, the aggregate and verdict are skipped instead
    of dividing by zero.

    Returns:
        None. All results are written to stdout.
    """
    banner = "=" * 78
    print(banner)
    print(" BM25 QUALITY PROBE — does retrieval surface the right doc?")
    print(banner)

    rows = []
    n_errors = 0
    for q, anchor in TESTS:
        try:
            reply = call(q, top_k=30)
        except Exception as e:
            # Best-effort probe: one failing request must not abort the
            # remaining queries, but it is counted so the summary can say
            # how much of the test set the aggregate really covers.
            n_errors += 1
            print(f"\n❌ {q}\n API error: {e}")
            continue
        docs = reply.get("docs") or []
        rank = hit_rank(docs, anchor)
        rows.append({"q": q, "rank": rank, "n_returned": len(docs)})
        marker = "✓" if rank > 0 else "✗"
        print(f"{marker} rank={rank or '—':>3} /{len(docs):>2} {q}")

    n = len(rows)
    if n_errors:
        print(f"\n API errors: {n_errors}/{len(TESTS)} queries failed"
              " (excluded from aggregate)")
    if n == 0:
        # Every request failed (or TESTS is empty): the percentage lines
        # below would raise ZeroDivisionError, and a verdict computed from
        # zero rows would be meaningless.
        print("\n No query completed — nothing to aggregate, no verdict.")
        return

    ranks = [row["rank"] for row in rows]
    n_top1 = sum(1 for rank in ranks if 1 <= rank <= 1)
    n_top3 = sum(1 for rank in ranks if 1 <= rank <= 3)
    n_top10 = sum(1 for rank in ranks if 1 <= rank <= 10)
    n_top30 = sum(1 for rank in ranks if rank >= 1)
    n_miss = sum(1 for rank in ranks if rank == 0)

    print()
    print(banner)
    print(" AGGREGATE")
    print(banner)
    print(f" total queries: {n}")
    print(f" hit @ top-1: {n_top1}/{n} ({100*n_top1/n:.0f}%)")
    print(f" hit @ top-3: {n_top3}/{n} ({100*n_top3/n:.0f}%)")
    print(f" hit @ top-10: {n_top10}/{n} ({100*n_top10/n:.0f}%)")
    print(f" hit @ top-30: {n_top30}/{n} ({100*n_top30/n:.0f}%)")
    print(f" missed entirely: {n_miss}/{n}")
    print()
    print(" VERDICT for reranker viability:")
    # The tolerance of 2 is absolute, so it is only meaningful when most of
    # TESTS completed; with n <= 2 this first branch is always taken.
    if n_top10 >= n - 2:  # tolerate 2 misses (e.g. gibberish + 1 hard)
        print(" ✅ STRONG — reranker on top-30 will likely succeed.")
    elif n_top30 >= n * 0.7:
        print(" ⚠️ MIXED — top-30 mostly covers but BM25 needs tuning")
        print(" OR add dense retriever before reranker.")
    else:
        print(" ❌ WEAK — reranker can't fix retrieval gap. Fix BM25 first.")
# Script entry point: run the probe only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()