| |
"""
BM25 quality probe — DIRECT (bypasses API response stripping).
Loads the same pipeline the API uses, runs each query through the
retriever ONLY (no generation), and reports the rank of the first
retrieved document whose text matches the query's expected anchor pattern.
"""
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
| |
# Directory three levels above this file; it is put first on sys.path so the
# `tau_rag` package imported below can be resolved when this script is run
# directly.
HERE = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(HERE))

# Choose the pipeline preset unless the caller already exported one.
# This has to run before the `tau_rag` imports that follow.
os.environ.setdefault("TAU_RAG_PRESET", "hebrew_legal_local")
|
|
| from tau_rag.pipeline import get_pipeline |
| from tau_rag.core.types import Query |
|
|
|
|
# Probe cases as (query, anchor) pairs. `anchor` is a regular expression; a
# retrieved item counts as a hit when `re.search(anchor, text)` succeeds on
# its text. The "NOMATCH_GIBBERISH" entry is a negative control: its anchor is
# not expected to match anything, so it is reported as a miss.
TESTS = [
    ("מה אומר סעיף 39 לחוק החוזים?", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
    ("מה אומר סעיף 12 לחוק החוזים?", r"(KL.?12|סעיף\s*12|תום\s*לב\s*במשא)"),
    ("מה אומר סעיף 30 לחוק החוזים?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת\s*הציבור)"),
    ("מה זה גמירת דעת?", r"גמירת\s*דעת"),
    ("מה זה מסויימות?", r"מסוימות|מסויימות"),
    ("הסבר על קיום בתום לב", r"(KL.?39|סעיף\s*39|תום\s*לב|קיום)"),
    ("מה התרופות במכר פגום?", r"(תרופות|אכיפה|מכר|פגם)"),
    ("מתי קונה רשאי לבטל מכר?", r"(ביטול|מכר|הפרה\s*יסודית|ארכה)"),
    ("מה ההבדל בין שכירות למכירה?", r"(שכירות|מכר|חוק\s*המכר|חוק\s*השכירות)"),
    ("שנשר25512551", r"(NOMATCH_GIBBERISH)"),
    ("הסבר את חובת תום הלב בחוזים", r"(KL.?39|סעיף\s*39|תום\s*לב)"),
    ("איך מתבטל חוזה פסול?", r"(KL.?30|סעיף\s*30|חוזה\s*פסול|תקנת)"),
]
|
|
|
|
def get_retrieved(pipe, query, top_k=30):
    """Return up to *top_k* raw retrieval results for *query*.

    Two paths are tried in order:

    1. ``pipe.retrievers.retrieve(q)``, when the pipeline has a non-None
       ``retrievers`` attribute.
    2. ``pipe.run(q)``, reading the ``retrieved`` attribute of the response.

    Args:
        pipe: Pipeline object to probe.
        query: Query text.
        top_k: Maximum number of results to return; also passed to ``Query``
            as ``k``.

    Returns:
        A list of at most *top_k* retrieved items, or ``[]`` when both paths
        fail. Failures are printed rather than raised.
    """
    q = Query(text=query, k=top_k)

    retrievers = getattr(pipe, "retrievers", None)
    if retrievers is not None:
        # Deliberately best-effort: any failure is reported and the probe
        # falls through to the full pipeline run below.
        try:
            results = retrievers.retrieve(q)
            # list() so tuples/generators can be truncated like lists.
            return list(results)[:top_k]
        except Exception as e:
            print(f" [retrievers.retrieve failed: {e}]")

    try:
        resp = pipe.run(q)
        # A response with no `retrieved` attribute, or with it set to None,
        # is treated as an empty result set.
        return list(getattr(resp, "retrieved", None) or [])[:top_k]
    except Exception as e:
        print(f" [pipe.run failed: {e}]")
        return []
|
|
|
|
def get_text(retrieved_item):
    """Extract the text of a retrieval result as a string.

    The payload is looked up under ``chunk``, then ``document``, falling back
    to the item itself. Its text is then read from ``text``, falling back to
    ``content``. Fields may be attributes or dict keys, and a plain string
    (either the item or its payload) is returned as-is.

    Args:
        retrieved_item: A retrieval result: an object, a dict, or a string.

    Returns:
        The text converted with ``str``, or ``""`` when no non-empty text
        field is found.
    """
    def _field(obj, name):
        # Read `name` from a mapping key or an attribute; None when absent.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    if isinstance(retrieved_item, str):
        return retrieved_item

    chunk = (
        _field(retrieved_item, "chunk")
        or _field(retrieved_item, "document")
        or retrieved_item
    )
    if isinstance(chunk, str):
        return chunk

    # `or` (not an `is None` check) so an empty `text` falls back to `content`.
    text = _field(chunk, "text") or _field(chunk, "content") or ""
    return str(text)
|
|
|
|
def _first_match_rank(retrieved, anchor):
    """Return the 1-based rank of the first item matching *anchor*, or 0."""
    pattern = re.compile(anchor)
    for position, item in enumerate(retrieved, 1):
        if pattern.search(get_text(item)):
            return position
    return 0


def main():
    """Run every probe in ``TESTS`` through the retriever and print a report.

    Loads the pipeline via ``get_pipeline()``, aborts early when the
    pipeline's ``_indexed_docs`` is missing or empty, then for each
    (query, anchor) pair prints the rank of the first retrieved item whose
    text matches the anchor. Finishes with aggregate hit counts and a verdict
    on whether reranking the top-30 looks viable.
    """
    print("=" * 78)
    print(" BM25 QUALITY PROBE — DIRECT (loads same pipeline as API)")
    print("=" * 78)
    print(f"Loading pipeline (TAU_RAG_PRESET={os.environ.get('TAU_RAG_PRESET')})...")
    pipe = get_pipeline()
    print(f"Pipeline ready. Generator: {type(getattr(pipe, 'generator', None)).__name__}")
    print(f"Retrievers: {type(getattr(pipe, 'retrievers', None)).__name__}")
    n_docs = len(getattr(pipe, "_indexed_docs", []) or [])
    print(f"Indexed docs: {n_docs}")
    print()

    if n_docs == 0:
        print("❌ Index is empty — no docs to retrieve from.")
        print(" The pipeline started but the corpus wasn't loaded.")
        print(" Check TAU_RAG_CORPUS_JSONL or pipeline hydration.")
        return

    ranks = []
    for q, anchor in TESTS:
        retrieved = get_retrieved(pipe, q, top_k=30)
        rank = _first_match_rank(retrieved, anchor)
        ranks.append(rank)
        marker = "✓" if rank > 0 else "✗"
        print(f"{marker} rank={rank or '—':>3} /{len(retrieved):>2} {q}")
        if rank == 0 and retrieved:
            # On a miss, show what ranked first to help diagnose the query.
            t = get_text(retrieved[0])[:100]
            print(f" [top-1: {t}...]")

    n = len(ranks)
    if n == 0:
        # Nothing was probed; the percentages below would divide by zero.
        print("No test queries defined — nothing to aggregate.")
        return

    n_top1 = sum(1 for r in ranks if r == 1)
    n_top3 = sum(1 for r in ranks if 1 <= r <= 3)
    n_top10 = sum(1 for r in ranks if 1 <= r <= 10)
    # Retrieval is capped at 30 items, so any hit at all is a top-30 hit.
    n_top30 = sum(1 for r in ranks if r >= 1)
    n_miss = sum(1 for r in ranks if r == 0)

    print()
    print("=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    print(f" total queries: {n}")
    print(f" hit @ top-1: {n_top1}/{n} ({100*n_top1/n:.0f}%)")
    print(f" hit @ top-3: {n_top3}/{n} ({100*n_top3/n:.0f}%)")
    print(f" hit @ top-10: {n_top10}/{n} ({100*n_top10/n:.0f}%)")
    print(f" hit @ top-30: {n_top30}/{n} ({100*n_top30/n:.0f}%)")
    print(f" missed entirely: {n_miss}/{n}")
    print()
    print(" VERDICT for reranker viability:")
    # STRONG tolerates up to two queries outside the top-10; MIXED needs at
    # least 70% of queries to hit somewhere in the top-30.
    if n_top10 >= n - 2:
        print(" ✅ STRONG — reranker on top-30 will likely succeed.")
    elif n_top30 >= n * 0.7:
        print(" ⚠️ MIXED — top-30 mostly covers but BM25 needs tuning.")
    else:
        print(" ❌ WEAK — reranker can't fix retrieval gap.")
|
|
|
|
# Run the probe only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|