legal-eye / tau_rag /scripts /check_bm25_direct.py
Legal-i's picture
Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
3be54c6 verified
#!/usr/bin/env python3
"""
BM25 quality probe — DIRECT (bypasses API response stripping).
Loads the same pipeline the API uses, runs each query through the
retriever ONLY (no generation), and reports rank of the right doc.
"""
import os
import re
import sys
from pathlib import Path
# Make tau_rag importable: HERE is three levels above this file (the file's
# directory, its parent, and that parent's parent). It is presumably the
# folder that contains the tau_rag package imported below — verify against
# the actual repository layout.
HERE = Path(__file__).resolve().parent.parent.parent
# Insert at position 0 so this location is searched before other sys.path entries.
sys.path.insert(0, str(HERE))
# Use the same preset the API uses. setdefault only fills the variable in when
# it is not already set, so a TAU_RAG_PRESET already present in the
# environment takes precedence over this value.
os.environ.setdefault("TAU_RAG_PRESET", "hebrew_legal_local")
from tau_rag.pipeline import get_pipeline
from tau_rag.core.types import Query
# Probe set: each entry is a (query text, anchor regex) pair. main() sends the
# query text to the retriever and applies the anchor with re.search to the text
# of each retrieved item; the position of the first match is reported as the
# rank for that query (0 when nothing matches).
TESTS = [
("מה אומר סעיף 39 לחוק החוזים?", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
("מה אומר סעיף 12 לחוק החוזים?", r"(KL.?12|סעיף\s*12|תום\s*לב\s*במשא)"),
("מה אומר סעיף 30 לחוק החוזים?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת\s*הציבור)"),
("מה זה גמירת דעת?", r"גמירת\s*דעת"),
("מה זה מסויימות?", r"מסוימות|מסויימות"),
("הסבר על קיום בתום לב", r"(KL.?39|סעיף\s*39|תום\s*לב|קיום)"),
("מה התרופות במכר פגום?", r"(תרופות|אכיפה|מכר|פגם)"),
("מתי קונה רשאי לבטל מכר?", r"(ביטול|מכר|הפרה\s*יסודית|ארכה)"),
("מה ההבדל בין שכירות למכירה?", r"(שכירות|מכר|חוק\s*המכר|חוק\s*השכירות)"),
# Negative control: a gibberish query whose anchor is a literal marker, so it
# only registers a hit if a retrieved text contains "NOMATCH_GIBBERISH".
("שנשר25512551", r"(NOMATCH_GIBBERISH)"),
("הסבר את חובת תום הלב בחוזים", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
("איך מתבטל חוזה פסול?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת)"),
]
def get_retrieved(pipe, query, top_k=30):
    """Fetch up to *top_k* raw retrieval results for *query*, best effort.

    Two routes are attempted in order, and the first one that completes
    without raising supplies the result:

    1. ``pipe.retrievers.retrieve(...)`` — only when *pipe* has a
       ``retrievers`` attribute.
    2. ``pipe.run(...)`` followed by reading the response's ``retrieved``
       attribute (an empty list is used when that attribute is missing).

    Args:
        pipe: Pipeline object to query.
        query: Query text; wrapped in a ``Query`` with ``k=top_k``.
        top_k: Maximum number of results to return.

    Returns:
        The first *top_k* results from whichever route succeeded, or an
        empty list when both routes raised. A failed route prints a
        one-line diagnostic instead of propagating the exception.
    """
    request = Query(text=query, k=top_k)

    # Route 1: ask the retriever stack directly.
    if hasattr(pipe, "retrievers"):
        try:
            return pipe.retrievers.retrieve(request)[:top_k]
        except Exception as err:
            print(f" [retrievers.retrieve failed: {err}]")

    # Route 2: run the pipeline and pick the retrieved items off the response.
    try:
        response = pipe.run(request)
        hits = getattr(response, "retrieved", [])
        return hits[:top_k]
    except Exception as err:
        print(f" [pipe.run failed: {err}]")

    return []
def get_text(retrieved_item):
    """Return the text carried by a retrieval result, always as a ``str``.

    The object holding the text is the first truthy value among
    ``retrieved_item.chunk`` and ``retrieved_item.document``; when neither
    is present (or both are falsy) the item itself is used. From that
    object the ``text`` attribute is preferred, then ``content``; an empty
    string is returned when neither yields a truthy value.
    """
    holder = retrieved_item
    for attr_name in ("chunk", "document"):
        wrapped = getattr(retrieved_item, attr_name, None)
        if wrapped:
            holder = wrapped
            break

    body = getattr(holder, "text", None)
    if not body:
        body = getattr(holder, "content", "")
    if not body:
        body = ""
    return str(body)
def _first_hit_rank(retrieved, pattern):
    """Return the 1-based position of the first item whose text matches.

    Args:
        retrieved: Retrieval results, in the order the retriever returned them.
        pattern: Compiled regex searched in each item's text (via get_text).

    Returns:
        The position of the first matching item, or 0 when nothing matches.
    """
    for position, item in enumerate(retrieved, 1):
        if pattern.search(get_text(item)):
            return position
    return 0


def _percent(count, total):
    """Return *count* as a percentage of *total* (0.0 when *total* is 0)."""
    return 100 * count / total if total else 0.0


def _print_summary(rows):
    """Print the hit-rate aggregate and the reranker verdict.

    Args:
        rows: One dict per probed query; only the "rank" key is read here
            (0 means the anchor was not found in the retrieved results).
    """
    ranks = [row["rank"] for row in rows]
    n = len(ranks)
    n_top1 = sum(1 for rank in ranks if rank == 1)
    n_top3 = sum(1 for rank in ranks if 1 <= rank <= 3)
    n_top10 = sum(1 for rank in ranks if 1 <= rank <= 10)
    n_top30 = sum(1 for rank in ranks if rank >= 1)
    n_miss = sum(1 for rank in ranks if rank == 0)

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    if n == 0:
        # With no queries the percentages are undefined and the verdict
        # thresholds below would trivially report "STRONG".
        print(" no queries were probed — nothing to aggregate.")
        return

    # NOTE(review): the gibberish control entry in TESTS is anchored on a
    # literal marker, so it presumably always lands in "missed" and lowers
    # every hit rate below — confirm that counting it here is intended.
    print(f" hit @ top-1: {n_top1}/{n} ({_percent(n_top1, n):.0f}%)")
    print(f" hit @ top-3: {n_top3}/{n} ({_percent(n_top3, n):.0f}%)")
    print(f" hit @ top-10: {n_top10}/{n} ({_percent(n_top10, n):.0f}%)")
    print(f" hit @ top-30: {n_top30}/{n} ({_percent(n_top30, n):.0f}%)")
    print(f" missed entirely: {n_miss}/{n}")
    print()
    print(" VERDICT for reranker viability:")
    if n_top10 >= n - 2:
        # At most two queries may fall outside the top 10 for a STRONG verdict.
        print(" ✅ STRONG — reranker on top-30 will likely succeed.")
    elif n_top30 >= n * 0.7:
        print(" ⚠️ MIXED — top-30 mostly covers but BM25 needs tuning.")
    else:
        print(" ❌ WEAK — reranker can't fix retrieval gap.")


def main():
    """Run every probe in TESTS through the retriever and print a report.

    Loads the pipeline via get_pipeline(), stops early with a diagnostic
    when the pipeline reports no indexed documents, otherwise prints one
    line per query (rank of the first retrieved item matching the query's
    anchor regex, out of the number of items retrieved) followed by the
    aggregate hit rates and a verdict. Returns None.
    """
    print("=" * 78)
    print(" BM25 QUALITY PROBE — DIRECT (loads same pipeline as API)")
    print("=" * 78)
    print(f"Loading pipeline (TAU_RAG_PRESET={os.environ.get('TAU_RAG_PRESET')})...")
    pipe = get_pipeline()
    print(f"Pipeline ready. Generator: {type(getattr(pipe, 'generator', None)).__name__}")
    print(f"Retrievers: {type(getattr(pipe, 'retrievers', None)).__name__}")
    # A missing or None _indexed_docs attribute is treated as an empty index.
    n_docs = len(getattr(pipe, "_indexed_docs", []) or [])
    print(f"Indexed docs: {n_docs}")
    print()
    if n_docs == 0:
        print("❌ Index is empty — no docs to retrieve from.")
        print(" The pipeline started but the corpus wasn't loaded.")
        print(" Check TAU_RAG_CORPUS_JSONL or pipeline hydration.")
        return

    rows = []
    for q, anchor in TESTS:
        retrieved = get_retrieved(pipe, q, top_k=30)
        # Compile once per query rather than passing the string to
        # re.search for every retrieved item.
        rank = _first_hit_rank(retrieved, re.compile(anchor))
        rows.append({"q": q, "rank": rank, "n": len(retrieved)})
        marker = "✓" if rank > 0 else "✗"
        print(f"{marker} rank={rank or '—':>3} /{len(retrieved):>2} {q}")
        if rank == 0 and retrieved:
            # Show the start of the top result so a miss can be debugged.
            t = get_text(retrieved[0])[:100]
            print(f" [top-1: {t}...]")

    _print_summary(rows)


if __name__ == "__main__":
    main()