#!/usr/bin/env python3
"""
BM25 quality probe — DIRECT (bypasses API response stripping).

Loads the same pipeline the API uses, runs each query through the
retriever ONLY (no generation), and reports rank of the right doc.

Exit status: 0 after a completed probe, 1 if the pipeline reports an
empty index.
"""
import os
import re
import sys
from pathlib import Path

# Make tau_rag importable (puts the directory three levels above this file
# at the front of sys.path).
HERE = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(HERE))

# Force the same preset the API uses. Set before the tau_rag imports below
# so anything that reads it during import sees this value too.
os.environ.setdefault("TAU_RAG_PRESET", "hebrew_legal_local")

from tau_rag.pipeline import get_pipeline  # noqa: E402  (needs sys.path above)
from tau_rag.core.types import Query  # noqa: E402

# Number of candidates requested per query; a hit anywhere inside this
# window counts as "retrieved" in the report.
TOP_K = 30

# Sentinel that distinguishes "pipeline has no _indexed_docs attribute"
# from "attribute present but empty/None".
_MISSING = object()

# (query, anchor regex) pairs. A retrieved chunk whose text matches the
# anchor counts as the right document; the reported rank is the 1-based
# position of the first such chunk (0 = not found). The gibberish query is
# a negative control: its anchor is a sentinel that should never occur in
# the corpus, so it always reports a miss and is counted in the aggregate
# like any other query.
TESTS = [
    ("מה אומר סעיף 39 לחוק החוזים?", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
    ("מה אומר סעיף 12 לחוק החוזים?", r"(KL.?12|סעיף\s*12|תום\s*לב\s*במשא)"),
    ("מה אומר סעיף 30 לחוק החוזים?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת\s*הציבור)"),
    ("מה זה גמירת דעת?", r"גמירת\s*דעת"),
    ("מה זה מסויימות?", r"מסוימות|מסויימות"),
    ("הסבר על קיום בתום לב", r"(KL.?39|סעיף\s*39|תום\s*לב|קיום)"),
    ("מה התרופות במכר פגום?", r"(תרופות|אכיפה|מכר|פגם)"),
    ("מתי קונה רשאי לבטל מכר?", r"(ביטול|מכר|הפרה\s*יסודית|ארכה)"),
    ("מה ההבדל בין שכירות למכירה?", r"(שכירות|מכר|חוק\s*המכר|חוק\s*השכירות)"),
    ("שנשר25512551", r"(NOMATCH_GIBBERISH)"),
    ("הסבר את חובת תום הלב בחוזים", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
    ("איך מתבטל חוזה פסול?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת)"),
]


def get_retrieved(pipe, query, top_k=TOP_K):
    """Return up to ``top_k`` raw retrieval results for ``query``.

    Tries ``pipe.retrievers.retrieve`` first (retrieval only). If the
    pipeline has no ``retrievers`` attribute or that call raises, falls back
    to ``pipe.run`` and reads ``response.retrieved``. Failures are printed
    rather than raised so one bad query does not abort the whole probe.

    Returns:
        A list-like of retrieval results (possibly empty); ``[]`` when both
        paths fail.
    """
    q = Query(text=query, k=top_k)

    # Path 1: direct call to the retriever (no generation).
    if hasattr(pipe, "retrievers"):
        try:
            results = pipe.retrievers.retrieve(q)
            return results[:top_k]
        except Exception as e:  # best-effort: fall through to path 2
            print(f" [retrievers.retrieve failed: {e}]")

    # Path 2: pipe.run() and use response.retrieved.
    try:
        resp = pipe.run(q)
        # `retrieved` may be absent or None; treat both as "no results"
        # rather than letting the slice raise and be misreported as a
        # pipe.run failure.
        return list(getattr(resp, "retrieved", None) or [])[:top_k]
    except Exception as e:  # best-effort: report and return nothing
        print(f" [pipe.run failed: {e}]")
    return []


def get_text(retrieved_item):
    """Extract the text of a retrieval result as a string.

    Looks for a nested ``chunk`` or ``document`` object (falling back to the
    item itself), then reads its ``text`` or ``content`` attribute. Returns
    ``""`` when neither is present or both are empty.
    """
    chunk = (
        getattr(retrieved_item, "chunk", None)
        or getattr(retrieved_item, "document", None)
        or retrieved_item
    )
    text = getattr(chunk, "text", None) or getattr(chunk, "content", "") or ""
    return str(text)


def _check_index(pipe):
    """Print the indexed-doc count; return False only if the index is known empty."""
    indexed = getattr(pipe, "_indexed_docs", _MISSING)
    if indexed is _MISSING:
        # Private attribute not exposed by this pipeline: the size is
        # unknown, so probe anyway instead of wrongly declaring it empty.
        print("Indexed docs: unknown (pipeline has no _indexed_docs)")
        print()
        return True

    n_docs = len(indexed or [])
    print(f"Indexed docs: {n_docs}")
    print()
    if n_docs == 0:
        print("❌ Index is empty — no docs to retrieve from.")
        print(" The pipeline started but the corpus wasn't loaded.")
        print(" Check TAU_RAG_CORPUS_JSONL or pipeline hydration.")
        return False
    return True


def _probe_queries(pipe, top_k=TOP_K):
    """Run every TESTS query, print one result line each, and return the rows.

    Each row is ``{"q": query, "rank": rank, "n": number_retrieved}`` where
    ``rank`` is the 1-based position of the first anchor match, or 0.
    """
    rows = []
    for q, anchor in TESTS:
        # Compile once per query instead of re-parsing for every chunk.
        pattern = re.compile(anchor)
        retrieved = get_retrieved(pipe, q, top_k=top_k)
        rank = next(
            (i for i, r in enumerate(retrieved, 1) if pattern.search(get_text(r))),
            0,
        )
        rows.append({"q": q, "rank": rank, "n": len(retrieved)})

        marker = "✓" if rank > 0 else "✗"
        print(f"{marker} rank={rank or '—':>3} /{len(retrieved):>2} {q}")
        if rank == 0 and retrieved:
            # Show first retrieved doc's start so we can debug
            t = get_text(retrieved[0])[:100]
            print(f" [top-1: {t}...]")
    return rows


def _print_aggregate(rows, top_k=TOP_K):
    """Print hit-rate statistics over ``rows`` and a reranker-viability verdict."""
    n = len(rows)
    if n == 0:
        # Nothing to summarize; also avoids dividing by zero below.
        print("No queries were run.")
        return

    ranks = [r["rank"] for r in rows]
    n_top1 = sum(1 for k in ranks if k == 1)
    n_top3 = sum(1 for k in ranks if 1 <= k <= 3)
    n_top10 = sum(1 for k in ranks if 1 <= k <= 10)
    n_topk = sum(1 for k in ranks if k >= 1)
    n_miss = sum(1 for k in ranks if k == 0)

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    print(f" hit @ top-1: {n_top1}/{n} ({100*n_top1/n:.0f}%)")
    print(f" hit @ top-3: {n_top3}/{n} ({100*n_top3/n:.0f}%)")
    print(f" hit @ top-10: {n_top10}/{n} ({100*n_top10/n:.0f}%)")
    print(f" hit @ top-{top_k}: {n_topk}/{n} ({100*n_topk/n:.0f}%)")
    print(f" missed entirely: {n_miss}/{n}")
    print()
    print(" VERDICT for reranker viability:")
    if n_top10 >= n - 2:
        print(f" ✅ STRONG — reranker on top-{top_k} will likely succeed.")
    elif n_topk >= n * 0.7:
        print(f" ⚠️ MIXED — top-{top_k} mostly covers but BM25 needs tuning.")
    else:
        print(" ❌ WEAK — reranker can't fix retrieval gap.")


def main():
    """Load the pipeline, probe every query in TESTS, and print a report.

    Returns:
        Process exit status: 1 when the pipeline reports an empty index,
        otherwise 0.
    """
    print("=" * 78)
    print(" BM25 QUALITY PROBE — DIRECT (loads same pipeline as API)")
    print("=" * 78)
    print(f"Loading pipeline (TAU_RAG_PRESET={os.environ.get('TAU_RAG_PRESET')})...")
    pipe = get_pipeline()
    print(f"Pipeline ready. Generator: {type(getattr(pipe, 'generator', None)).__name__}")
    print(f"Retrievers: {type(getattr(pipe, 'retrievers', None)).__name__}")

    if not _check_index(pipe):
        return 1

    rows = _probe_queries(pipe)
    _print_aggregate(rows)
    return 0


if __name__ == "__main__":
    sys.exit(main())