#!/usr/bin/env python3 """benchmark_pipeline.py — End-to-end empirical validation on real corpus. We've built each stage of the legal-argument pipeline and tested each in isolation. This script runs the WHOLE thing on real Hebrew judgments from the parquet corpus and reports per-stage metrics: Stage 1: Outcome-detector accuracy (regex catches actual rulings) Stage 2: Polarity-classifier accuracy on discussion paragraphs Stage 3: Retrieval health on real queries (Ω distribution) Stage 4: CBR template extraction (n templates, source diversity) Stage 5: Outcome-Ω prediction accuracy (does Ω match actual outcome?) Method: Pull N random judgments from parquet → for each, treat its FACTS section as a user-input simulation → run the full strategy pipeline → compare predicted Ω against the judgment's actual outcome. Self-leakage prevention: the test judgment is excluded from retrieval results during its own evaluation. Usage: python3 -m tau_rag.scripts.benchmark_pipeline \\ --parquet storage/raw/datasets/698f9b2b-c27f-4857-aa0a-985bace9f2e2.parquet \\ --n 30 """ from __future__ import annotations import argparse import json import random import re import sys import time from collections import Counter, defaultdict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple sys.path.insert(0, str(Path(__file__).resolve().parents[2])) def _repair_hebrew_encoding(text: str) -> str: """Detect and repair cp1255 mojibake (Windows-Hebrew shown as Latin-1). Heuristic: if the text contains characters in the range \\u00C0-\\u00FF (Latin-1 supplement) but very few real Hebrew chars, it's likely cp1255 misencoded. Try to round-trip and check if Hebrew appears. """ if not text: return text hebrew_count = sum(1 for c in text if "א" <= c <= "ת") latin_supp_count = sum(1 for c in text if "À" <= c <= "ÿ") if hebrew_count >= 50: return text # already proper Hebrew, no repair needed if latin_supp_count >= 50: # Likely cp1255 misencoded — try repair try: repaired = text.encode("latin-1", errors="ignore").decode( "cp1255", errors="ignore") new_hebrew = sum(1 for c in repaired if "א" <= c <= "ת") if new_hebrew > hebrew_count + 50: return repaired except Exception: pass return text def _is_substantive(text: str, min_chars: int = 3000) -> bool: """A judgment is 'substantive' (worth benchmarking on) if: • length ≥ min_chars (filters out scheduling orders, brief motions) • has at least 200 Hebrew letters (filters mojibake-only docs) """ if not text or len(text) < min_chars: return False hebrew = sum(1 for c in text if "א" <= c <= "ת") return hebrew >= 200 def _load_random_docs( parquet_path: str, n: int, seed: int = 42, min_chars: int = 3000, ): """Load `n` random *substantive* judgments from the parquet. Filters: - docs with broken cp1255 encoding (auto-repair attempted) - docs <3000 chars (scheduling orders, brief motions — no argument structure to mine) - docs without enough Hebrew (mojibake fallout that didn't repair) """ import pyarrow.parquet as pq schema = pq.read_schema(parquet_path) cols = ["text"] id_col = None for cand in ("doc_id", "id", "case_id", "__filename", "filename"): if cand in schema.names: id_col = cand cols.append(cand) break table = pq.read_table(parquet_path, columns=cols) rows = [] n_total = n_repaired = n_too_short = n_bad_encoding = 0 for batch in table.to_batches(): texts = batch.column("text").to_pylist() ids = (batch.column(id_col).to_pylist() if id_col else [None] * len(texts)) for i, txt in enumerate(texts): n_total += 1 if not txt: continue txt_repaired = _repair_hebrew_encoding(txt) if txt_repaired is not txt and len(txt_repaired) >= min_chars: n_repaired += 1 txt = txt_repaired if not _is_substantive(txt, min_chars): if len(txt) < min_chars: n_too_short += 1 else: n_bad_encoding += 1 continue rows.append((ids[i] or f"doc_{len(rows)}", txt)) print(f"[loader] scanned {n_total} docs | substantive: {len(rows)} | " f"too short: {n_too_short} | bad encoding: {n_bad_encoding} | " f"repaired: {n_repaired}", flush=True) random.seed(seed) random.shuffle(rows) return rows[:n] def _extract_facts_section(text: str) -> str: """Get the 'facts' / opening narrative of a judgment. Heuristic: take everything between header and the first appearance of 'דיון' or 'טענות'. If neither is found, return chars[300:1500]. """ body = text for marker in ["דיון", "ההכרעה", "טענות התובע", "טענות"]: idx = body.find(marker) if idx > 200: return body[200:idx][:1500] return body[300:1500] _DIAG_FIRED = {"v": False} def _evaluate_one( doc_id: str, text: str, pipeline, outcome_detector_func, ): """Run the full pipeline on one judgment, treating its facts as input. Returns a dict of per-stage results for aggregation. """ from tau_rag.intelligence import ( compute_retrieval_signals, compute_outcome_signals, CaseBasedArgumentExtractor, ) from tau_rag.core.types import Query from tau_rag.intelligence.case_based_arguments import ( _ACCEPT_MARKERS, _REJECT_MARKERS, ) # Stage 1: ground-truth outcome from the operative section actual_outcome = outcome_detector_func(text) facts_input = _extract_facts_section(text) # Stage 2: polarity-classification spot check on the discussion section discussion = text[int(len(text) * 0.3):int(len(text) * 0.7)] accept_hits = sum(1 for w in _ACCEPT_MARKERS if w in discussion) reject_hits = sum(1 for w in _REJECT_MARKERS if w in discussion) polarity_predicted = ( "ACCEPT" if accept_hits > reject_hits + 0.5 else "REJECT" if reject_hits > accept_hits + 0.5 else "UNCLEAR" ) polarity_correct = (polarity_predicted == actual_outcome) # Stage 3: retrieval health on this judgment's facts # Prefer the STRATIFIED retriever (built once at indexing) — it breaks # base-rate dominance by returning balanced accepted/rejected results. retriever = getattr(pipeline, "_stratified_retriever", None) if retriever is None: named = getattr(pipeline.retrievers, "_retrievers", {}) or {} retriever = named.get("hebrew_encoder") or pipeline.retrievers try: hits = retriever.search(Query(text=facts_input), k=10) except Exception: hits = [] # Filter self-hits (no leakage) hits = [h for h in hits if h.chunk.doc_id != doc_id][:8] retrieval_sig = compute_retrieval_signals(hits) if hits else None # Stage 4: CBR template extraction cbr_extractor = CaseBasedArgumentExtractor( retriever=retriever, tau_llm_polish=False, ) # Manually filter the test-doc out of retrieval at the extractor level # by passing a wrapped retriever class _FilteredRetriever: def __init__(self, inner, exclude_id): self._inner = inner self._exclude = exclude_id def search(self, q, k): res = self._inner.search(q, k=k + 5) return [r for r in res if r.chunk.doc_id != self._exclude][:k] @property def name(self): return getattr(self._inner, "name", "filtered") cbr_extractor.retriever = _FilteredRetriever(retriever, doc_id) try: cbr_result = cbr_extractor.extract_and_draft( user_facts=facts_input, side="claimant", top_k_cases=10, full_text_loader=pipeline.get_text, ) except Exception as e: cbr_result = {"error": str(e), "argument_templates": []} # One-shot diagnostic: walk the path manually for the first eval and # show exactly where templates are lost, if anywhere. if not _DIAG_FIRED["v"]: _DIAG_FIRED["v"] = True print(f"\n[deep-diag] first eval: doc={doc_id}", flush=True) print(f" cbr_result keys: {list(cbr_result.keys())}", flush=True) if "error" in cbr_result: print(f" ERROR: {cbr_result['error']}", flush=True) n_tmpl = len(cbr_result.get("argument_templates") or []) n_drft = len(cbr_result.get("drafted_arguments_for_user") or []) print(f" argument_templates returned: {n_tmpl}", flush=True) print(f" drafted_arguments returned: {n_drft}", flush=True) # Now manually walk extract path for direct comparison manual_hits = cbr_extractor.retriever.search( __import__("tau_rag.core.types", fromlist=["Query"]).Query(text=facts_input), k=10, ) print(f" manual retrieval: {len(manual_hits)} hits", flush=True) for h in manual_hits[:3]: ft = pipeline.get_text(h.chunk.doc_id) or "" t_count = len(cbr_extractor._extract_arguments_from_one_case( h.chunk.doc_id, ft)) print(f" hit {h.chunk.doc_id} | text={len(ft)} chars | " f"templates={t_count}", flush=True) # Stage 5: outcome-Ω prediction with delta-based τ # We pass the RAW retrieval hits + outcome_map so compute_outcome_signals # can derive τ from the empirically-validated similarity-delta signal # (probe_signal showed +33pt gap between actual ACCEPT and REJECT cases). omega_signals = None omega_predicted = None if cbr_result.get("argument_templates"): try: # Re-retrieve RAW hits for the τ calculation (CBR uses templates; # we need hits with outcome labels). Use the same stratified # retriever to ensure both classes are represented. outcome_map_pipe = getattr(pipeline, "_outcome_map", None) tau_hits = [] if outcome_map_pipe is not None: strat = getattr(pipeline, "_stratified_retriever", None) if strat is not None: try: tau_hits = strat.search( Query(text=facts_input), k=10) tau_hits = [h for h in tau_hits if h.chunk.doc_id != doc_id] except Exception: tau_hits = [] omega_signals = compute_outcome_signals( argument_templates=cbr_result.get("argument_templates", []), drafted_arguments=cbr_result.get( "drafted_arguments_for_user", []), retrieved_hits=tau_hits or None, outcome_map=outcome_map_pipe, ) omega_predicted = ("ACCEPT" if omega_signals.omega >= 0.5 else "REJECT") except Exception: pass omega_correct = ( omega_predicted == actual_outcome if (omega_predicted and actual_outcome in ("ACCEPT", "REJECT")) else None ) # Capture the raw delta (s_acc_mean - s_rej_mean) so threshold sweeps # can re-classify without re-running the slow CBR pipeline. delta_value = None s_acc_mean = None s_rej_mean = None if outcome_map_pipe is not None: try: acc_scores = [h.score for h in (tau_hits or []) if outcome_map_pipe.get(h.chunk.doc_id) == "accepted"] rej_scores = [h.score for h in (tau_hits or []) if outcome_map_pipe.get(h.chunk.doc_id) == "rejected"] if len(acc_scores) >= 2 and len(rej_scores) >= 2: s_acc_mean = sum(acc_scores) / len(acc_scores) s_rej_mean = sum(rej_scores) / len(rej_scores) delta_value = s_acc_mean - s_rej_mean except Exception: pass return { "doc_id": doc_id, "actual_outcome": actual_outcome, "polarity_predicted": polarity_predicted, "polarity_correct": polarity_correct, "n_templates": len(cbr_result.get("argument_templates") or []), "n_unique_source_cases": cbr_result.get("n_similar_cases"), "retrieval_omega": ( retrieval_sig.omega if retrieval_sig else None), "outcome_omega": ( omega_signals.omega if omega_signals else None), "omega_predicted": omega_predicted, "omega_correct": omega_correct, "delta_value": delta_value, # raw similarity delta "s_acc_mean": s_acc_mean, "s_rej_mean": s_rej_mean, } def _aggregate(rows: List[Dict]) -> Dict[str, Any]: """Compute per-stage aggregate metrics, including per-class breakdown for outcome prediction (so we can tell if the model is genuinely detecting ACCEPT cases or just matching the base rate).""" n = len(rows) n_with_outcome = sum(1 for r in rows if r["actual_outcome"] in ("ACCEPT", "REJECT")) polarity_correct = sum(1 for r in rows if r["polarity_correct"] is True) polarity_evaluable = sum(1 for r in rows if r["actual_outcome"] in ("ACCEPT", "REJECT")) omega_correct = sum(1 for r in rows if r["omega_correct"] is True) omega_evaluable = sum(1 for r in rows if r["omega_correct"] is not None) retr_omegas = [r["retrieval_omega"] for r in rows if r["retrieval_omega"] is not None] out_omegas = [r["outcome_omega"] for r in rows if r["outcome_omega"] is not None] template_counts = [r["n_templates"] for r in rows] # Confusion matrix for outcome prediction (the critical question: # does Ω actually identify ACCEPT cases, or is it always-REJECT?) tp = fp = tn = fn = 0 for r in rows: actual = r["actual_outcome"] pred = r["omega_predicted"] if actual not in ("ACCEPT", "REJECT") or pred is None: continue if actual == "ACCEPT" and pred == "ACCEPT": tp += 1 elif actual == "REJECT" and pred == "ACCEPT": fp += 1 elif actual == "REJECT" and pred == "REJECT": tn += 1 elif actual == "ACCEPT" and pred == "REJECT": fn += 1 actual_accept = tp + fn actual_reject = tn + fp pred_accept = tp + fp base_rate_baseline = round(actual_reject / max(actual_accept + actual_reject, 1), 3) precision = round(tp / max(pred_accept, 1), 3) if pred_accept > 0 else None recall = round(tp / max(actual_accept, 1), 3) if actual_accept > 0 else None f1 = (round(2 * tp / (2 * tp + fp + fn), 3) if (2 * tp + fp + fn) > 0 else None) # Lift: how much better than always-REJECT baseline? accuracy = (tp + tn) / max(tp + tn + fp + fn, 1) lift = round(accuracy - base_rate_baseline, 3) return { "n_total": n, "n_with_outcome_detected": n_with_outcome, "outcome_detection_rate": round(n_with_outcome / max(n, 1), 3), "polarity_classifier": { "correct": polarity_correct, "evaluable": polarity_evaluable, "accuracy": (round(polarity_correct / polarity_evaluable, 3) if polarity_evaluable > 0 else None), }, "retrieval_health": { "n": len(retr_omegas), "omega_mean": round(sum(retr_omegas) / max(len(retr_omegas), 1), 3), "omega_min": round(min(retr_omegas), 3) if retr_omegas else None, "omega_max": round(max(retr_omegas), 3) if retr_omegas else None, }, "cbr_extraction": { "templates_mean": round( sum(template_counts) / max(len(template_counts), 1), 1), "templates_min": min(template_counts) if template_counts else 0, "templates_max": max(template_counts) if template_counts else 0, }, "outcome_omega_prediction": { "correct": omega_correct, "evaluable": omega_evaluable, "accuracy": (round(omega_correct / omega_evaluable, 3) if omega_evaluable > 0 else None), "omega_mean": (round(sum(out_omegas) / max(len(out_omegas), 1), 3) if out_omegas else None), # Confusion matrix — the critical detail "actual_accept_count": actual_accept, "actual_reject_count": actual_reject, "pred_accept_count": pred_accept, "true_positives": tp, # ACCEPT correctly identified "false_positives": fp, # REJECT wrongly called ACCEPT "true_negatives": tn, # REJECT correctly identified "false_negatives": fn, # ACCEPT missed (predicted REJECT) "precision_on_accept": precision, "recall_on_accept": recall, "f1_on_accept": f1, "always_reject_baseline_accuracy": base_rate_baseline, "lift_over_baseline": lift, }, } def main(): ap = argparse.ArgumentParser() ap.add_argument("--parquet", required=True) ap.add_argument("--n", type=int, default=30, help="Number of test judgments") ap.add_argument("--corpus-size", type=int, default=200, help="Total docs to index (test + reference). The " "additional docs serve as a retrieval corpus so " "leave-one-out queries return meaningful hits.") ap.add_argument("--seed", type=int, default=42) ap.add_argument("--output", default="benchmark_report.json") ap.add_argument("--sweep-thresholds", type=str, default=None, help="Comma-separated delta thresholds to evaluate, " "e.g. '2.3,0,-5,-10'. When given, reports a " "precision/recall curve at each threshold. " "Without this, uses default threshold 2.3.") args = ap.parse_args() # CRITICAL: disable autoload so retrieval only sees the corpus we index # below. Otherwise hits come back referring to autoloaded doc_ids whose # text we can't resolve via pipeline.get_text → 0 templates. import os os.environ["TAU_RAG_AUTOLOAD_CORPUS"] = "0" os.environ["TAU_RAG_LAZY_TEXT"] = "0" # Pull more docs than tested — extra ones populate the retrieval corpus. total = max(args.n, args.corpus_size) print(f"[benchmark] loading {total} random judgments " f"(of which {args.n} will be tested)...", flush=True) docs = _load_random_docs(args.parquet, n=total, seed=args.seed) print(f"[benchmark] got {len(docs)} docs", flush=True) print(f"[benchmark] initializing pipeline + indexing corpus...", flush=True) from tau_rag.pipeline import get_pipeline from tau_rag.core.types import Document, Chunk pipeline = get_pipeline() # Force-index ALL loaded docs so retrieval has something to find. We # store one Chunk per Document — sufficient for similarity queries # at this corpus size. indexed_docs = [] chunks = [] for doc_id, txt in docs: d = Document(id=doc_id, text=txt, metadata={"title": doc_id}) c = Chunk(doc_id=doc_id, chunk_id=doc_id, text=txt, metadata={"title": doc_id}) indexed_docs.append(d) chunks.append(c) pipeline._indexed_docs = indexed_docs pipeline.add_chunks(chunks) print(f"[benchmark] indexed {len(chunks)} docs into pipeline", flush=True) # ───────────────────────────────────────────────────────────────── # NEW: build outcome-stratified retriever to break base-rate dominance # ───────────────────────────────────────────────────────────────── from tau_rag.scripts.build_polarity_lexicon import detect_outcome from tau_rag.retrieve.stratified import ( StratifiedRetriever, build_outcome_map, ) print(f"[benchmark] computing outcome map for {len(docs)} docs...", flush=True) outcome_map = build_outcome_map(docs, detect_outcome_fn=detect_outcome) n_acc = sum(1 for v in outcome_map.values() if v == "accepted") n_rej = sum(1 for v in outcome_map.values() if v == "rejected") n_unk = sum(1 for v in outcome_map.values() if v is None) print(f"[benchmark] outcome distribution: accepted={n_acc} " f"rejected={n_rej} unknown={n_unk}", flush=True) # Wrap the hebrew_encoder (or multi) retriever in stratified balancing named = getattr(pipeline.retrievers, "_retrievers", {}) or {} inner_retriever = named.get("hebrew_encoder") or pipeline.retrievers # pool_factor=12: corpus has ~5% acceptance rate, so to reliably get # 5 accepted candidates from the pool we need to query for ~120 docs # (60 × 2 = ~12k). Empirically tuned for this corpus's outcome # distribution. stratified = StratifiedRetriever( inner=inner_retriever, outcome_map=outcome_map, pool_factor=12, balance="balanced", ) # Stash for _evaluate_one to pick up pipeline._stratified_retriever = stratified pipeline._outcome_map = outcome_map from tau_rag.scripts.build_polarity_lexicon import detect_outcome # Test only the first args.n docs (rest serve as reference corpus) test_docs = docs[:args.n] # One-shot diagnostic on the first test doc — proves whether retrieval # is finding our indexed docs and whether get_text resolves them. diag_doc_id, diag_text = test_docs[0] print(f"\n[diagnostic] testing retrieval on first doc ({diag_doc_id})") facts_in = _extract_facts_section(diag_text) named = getattr(pipeline.retrievers, "_retrievers", {}) or {} retr = named.get("hebrew_encoder") or pipeline.retrievers from tau_rag.core.types import Query as _Q diag_hits = retr.search(_Q(text=facts_in), k=5) print(f"[diagnostic] retrieved {len(diag_hits)} hits") for h in diag_hits[:3]: in_index = any(d.id == h.chunk.doc_id for d in indexed_docs) get_text_result = pipeline.get_text(h.chunk.doc_id) or "" print(f" hit: doc_id={h.chunk.doc_id!r:25s} score={h.score:.3f} " f"in_indexed_docs={in_index} " f"get_text returns {len(get_text_result)} chars") rows: List[Dict] = [] t0 = time.time() for i, (doc_id, text) in enumerate(test_docs): try: r = _evaluate_one(doc_id, text, pipeline, detect_outcome) rows.append(r) if (i + 1) % 5 == 0: elapsed = time.time() - t0 print(f"[benchmark] {i+1}/{len(test_docs)} ({elapsed:.1f}s)", flush=True) except Exception as e: print(f"[benchmark] failed on {doc_id}: {e}", flush=True) continue agg = _aggregate(rows) print("\n" + "=" * 70) print("BENCHMARK RESULTS") print("=" * 70) print(json.dumps(agg, ensure_ascii=False, indent=2)) # Optional threshold sweep — uses the cached delta_value so it doesn't # need to re-run CBR. ~instant. if args.sweep_thresholds: thresholds = [float(t) for t in args.sweep_thresholds.split(",")] print("\n" + "=" * 70) print("THRESHOLD SWEEP (precision-recall trade-off)") print("=" * 70) print(f"{'thresh':>8s} {'acc':>5s} {'prec':>5s} {'recall':>6s} " f"{'F1':>5s} {'TP':>3s} {'FP':>3s} {'TN':>3s} {'FN':>3s} " f"{'lift':>5s}") baseline = agg["outcome_omega_prediction"][ "always_reject_baseline_accuracy"] for thr in thresholds: tp = fp = tn = fn = 0 for r in rows: actual = r["actual_outcome"] delta = r.get("delta_value") if actual not in ("ACCEPT", "REJECT") or delta is None: continue pred = "ACCEPT" if delta > thr else "REJECT" if actual == "ACCEPT" and pred == "ACCEPT": tp += 1 elif actual == "REJECT" and pred == "ACCEPT": fp += 1 elif actual == "REJECT" and pred == "REJECT": tn += 1 elif actual == "ACCEPT" and pred == "REJECT": fn += 1 n_eval = tp + fp + tn + fn if n_eval == 0: continue acc = (tp + tn) / n_eval prec = tp / max(tp + fp, 1) if (tp + fp) > 0 else 0.0 rec = tp / max(tp + fn, 1) if (tp + fn) > 0 else 0.0 f1 = 2 * tp / max(2 * tp + fp + fn, 1) if (2*tp+fp+fn) > 0 else 0.0 lift = acc - baseline print(f"{thr:+8.1f} {acc:>5.3f} {prec:>5.3f} {rec:>6.3f} " f"{f1:>5.3f} {tp:>3d} {fp:>3d} {tn:>3d} {fn:>3d} " f"{lift:>+5.3f}") Path(args.output).write_text( json.dumps({ "summary": agg, "per_doc_rows": rows, "config": vars(args), }, ensure_ascii=False, indent=2) ) print(f"\n[benchmark] full results saved to {args.output}") if __name__ == "__main__": main()