#!/usr/bin/env python3
"""probe_signal.py — Test whether retrieval-similarity discriminates outcomes.

Before redesigning Ω, we need to know empirically: for cases where we KNOW
the actual outcome, is the user's fact-similarity to accepted precedents
systematically higher or lower than similarity to rejected precedents?

Method:
    For each test judgment with known outcome:
    1. Retrieve a wide candidate pool (default 80) for its facts section,
       drop the judgment itself, and keep the top-5 accepted and top-5
       rejected precedents found in that pool
    2. Compute s_acc_mean = mean retrieval score of accepted hits
    3. Compute s_rej_mean = mean retrieval score of rejected hits
    4. Delta = s_acc - s_rej
    5. Group by actual outcome and compare distributions

If actual ACCEPT cases have systematically larger delta than actual REJECT
cases, retrieval-similarity HAS discriminative power and we should base τ
on it. If distributions overlap, fact-similarity alone is insufficient.

Usage:
    python3 -B -m tau_rag.scripts.probe_signal \\
        --parquet ... --n 200 --corpus-size 1000
"""
from __future__ import annotations

import argparse
import statistics
import sys
import time
from collections import Counter
from itertools import groupby
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))


def _probe_case(doc_id, text, *, retr, outcome_map, detect_outcome,
                extract_facts, make_query, pool_size, per_side,
                min_per_side, skips):
    """Score one judgment; return a result row dict, or None if skipped.

    Every skip is tallied in *skips* (a Counter) under one of
    "unknown_outcome", "retrieval_error" or "thin_pool", so the caller can
    report how many cases were dropped and why.
    """
    actual = detect_outcome(text)
    if actual not in ("ACCEPT", "REJECT"):
        skips["unknown_outcome"] += 1
        return None

    facts = extract_facts(text)
    try:
        # Wide pool first, then partition by precedent outcome.
        pool = retr.search(make_query(text=facts), k=pool_size)
    except Exception as exc:  # best-effort probe: one bad query must not abort the run
        skips["retrieval_error"] += 1
        if skips["retrieval_error"] == 1:
            print(f"  warning: retrieval failed for {doc_id}: {exc!r}",
                  file=sys.stderr)
        return None

    # The probed judgment is itself in the index; drop it so it cannot
    # match itself.
    pool = [h for h in pool if h.chunk.doc_id != doc_id]
    acc_hits = [h for h in pool
                if outcome_map.get(h.chunk.doc_id) == "accepted"][:per_side]
    rej_hits = [h for h in pool
                if outcome_map.get(h.chunk.doc_id) == "rejected"][:per_side]
    if len(acc_hits) < min_per_side or len(rej_hits) < min_per_side:
        skips["thin_pool"] += 1
        return None

    s_acc = sum(h.score for h in acc_hits) / len(acc_hits)
    s_rej = sum(h.score for h in rej_hits) / len(rej_hits)
    return {
        "doc_id": doc_id,
        "actual": actual,
        "s_acc_mean": s_acc,
        "s_rej_mean": s_rej,
        "delta": s_acc - s_rej,
        "n_acc_pool": len(acc_hits),
        "n_rej_pool": len(rej_hits),
    }


def _print_group_stats(label: str, group: list[dict], note: str) -> None:
    """Print score / delta summary statistics for one actual-outcome group.

    *note* is appended to the delta line to say how to read its sign.
    Prints nothing for an empty group.
    """
    if not group:
        return
    deltas = [r["delta"] for r in group]
    s_acc = [r["s_acc_mean"] for r in group]
    s_rej = [r["s_rej_mean"] for r in group]
    positive_frac = sum(1 for d in deltas if d > 0) / len(deltas)
    print(f"For actual-{label} cases:")
    print(f" s_acc_mean: mean={statistics.mean(s_acc):.2f} "
          f"median={statistics.median(s_acc):.2f}")
    print(f" s_rej_mean: mean={statistics.mean(s_rej):.2f} "
          f"median={statistics.median(s_rej):.2f}")
    print(f" delta: mean={statistics.mean(deltas):+.2f} "
          f"median={statistics.median(deltas):+.2f} "
          f"({note})")
    print(f" delta>0 fraction: {positive_frac:.2f}")


def _best_delta_threshold(pairs: list[tuple[float, str]]) -> tuple[float, float | None]:
    """Find the threshold t maximizing accuracy of "predict ACCEPT iff delta > t".

    Candidate thresholds are -inf (everything predicted ACCEPT) and every
    distinct observed delta (the largest one predicts REJECT for
    everything). Ties keep the smallest threshold. This is one sorted
    sweep, O(n log n), instead of re-counting for every candidate.

    Args:
        pairs: (delta, actual) tuples, where actual is "ACCEPT" or "REJECT".

    Returns:
        (accuracy, threshold), or (0.0, None) for empty input.
    """
    if not pairs:
        return 0.0, None
    ordered = sorted(pairs, key=lambda p: p[0])
    # Threshold -inf: every case is predicted ACCEPT.
    correct = sum(1 for _, actual in ordered if actual == "ACCEPT")
    best_correct, best_t = correct, float("-inf")
    # Raising t to `thr` flips every case with delta == thr to REJECT.
    for thr, tied in groupby(ordered, key=lambda p: p[0]):
        for _, actual in tied:
            if actual == "REJECT":
                correct += 1
            elif actual == "ACCEPT":
                correct -= 1
        if correct > best_correct:
            best_correct, best_t = correct, thr
    return best_correct / len(ordered), best_t


def _report(rows: list[dict]) -> None:
    """Print per-outcome statistics, the key gap metric and threshold accuracy."""
    accepts = [r for r in rows if r["actual"] == "ACCEPT"]
    rejects = [r for r in rows if r["actual"] == "REJECT"]
    print(f"Actual ACCEPT cases: n={len(accepts)}")
    print(f"Actual REJECT cases: n={len(rejects)}")
    print()

    if accepts:
        _print_group_stats("ACCEPT", accepts, "positive = better signal")
    if rejects:
        print()
        _print_group_stats("REJECT", rejects, "positive = WRONG signal")

    if not (accepts and rejects):
        return

    mean_a = statistics.mean([r["delta"] for r in accepts])
    mean_r = statistics.mean([r["delta"] for r in rejects])
    gap = mean_a - mean_r
    print("\n=== KEY METRIC ===")
    print(f" mean(delta | ACCEPT) - mean(delta | REJECT) = {gap:+.3f}")
    print(" Positive value means similarity DOES discriminate.")
    print(" Larger gap = stronger signal.")

    # Best single-threshold classifier on delta: how well can delta alone
    # separate the two outcomes?
    best_acc, best_t = _best_delta_threshold(
        [(r["delta"], r["actual"]) for r in rows])
    n_total = len(accepts) + len(rejects)
    always_reject = len(rejects) / n_total
    # Compare against the majority class: when ACCEPT dominates, the
    # always-REJECT baseline would overstate the lift.
    majority = max(len(accepts), len(rejects)) / n_total
    print(f" Best threshold accuracy: {best_acc:.3f} "
          f"(at delta_threshold={best_t:+.3f})")
    print(f" Always-REJECT baseline: {always_reject:.3f}")
    print(f" Majority-class baseline: {majority:.3f}")
    print(f" Lift over majority baseline: {best_acc - majority:+.3f}")


def main():
    """Index a random judgment sample, probe retrieval deltas, print a report."""
    ap = argparse.ArgumentParser(
        description="Probe whether retrieval similarity discriminates outcomes.")
    ap.add_argument("--parquet", required=True)
    ap.add_argument("--n", type=int, default=200,
                    help="number of sampled judgments to probe")
    ap.add_argument("--corpus-size", type=int, default=1000)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--pool-size", type=int, default=80,
                    help="retrieval candidates per query before partitioning")
    ap.add_argument("--per-side", type=int, default=5,
                    help="max accepted / rejected hits averaged per case")
    ap.add_argument("--min-per-side", type=int, default=2,
                    help="skip cases with fewer accepted or rejected hits")
    args = ap.parse_args()
    if args.pool_size < 1 or not 1 <= args.min_per_side <= args.per_side:
        ap.error("require --pool-size >= 1 and "
                 "1 <= --min-per-side <= --per-side")

    import os

    # Set before the tau_rag imports below; presumably these flags are read
    # at import / pipeline-construction time (TODO confirm).
    os.environ["TAU_RAG_AUTOLOAD_CORPUS"] = "0"
    os.environ["TAU_RAG_LAZY_TEXT"] = "0"

    from tau_rag.scripts.benchmark_pipeline import (
        _load_random_docs,
        _extract_facts_section,
    )
    from tau_rag.scripts.build_polarity_lexicon import detect_outcome
    from tau_rag.retrieve.stratified import build_outcome_map
    from tau_rag.pipeline import get_pipeline
    from tau_rag.core.types import Document, Chunk, Query

    docs = _load_random_docs(args.parquet, n=args.corpus_size, seed=args.seed)
    print(f"loaded {len(docs)} substantive judgments")

    # Index every sampled judgment as one whole-document chunk.
    pipeline = get_pipeline()
    indexed_docs = [Document(id=doc_id, text=txt, metadata={})
                    for doc_id, txt in docs]
    chunks = [Chunk(doc_id=doc_id, chunk_id=doc_id, text=txt)
              for doc_id, txt in docs]
    pipeline._indexed_docs = indexed_docs
    pipeline.add_chunks(chunks)

    outcome_map = build_outcome_map(docs, detect_outcome)
    dist = Counter(outcome_map.values())
    print(f"outcome distribution: "
          f"accepted={dist['accepted']} "
          f"rejected={dist['rejected']} "
          f"unknown={dist[None]}")

    # Prefer the named "hebrew_encoder" retriever when the pipeline exposes
    # a `_retrievers` mapping; otherwise search via pipeline.retrievers.
    named = getattr(pipeline.retrievers, "_retrievers", {}) or {}
    retr = named.get("hebrew_encoder") or pipeline.retrievers

    test_docs = docs[:args.n]
    rows: list[dict] = []  # one dict per scored case (see _probe_case)
    skips: Counter = Counter()
    t0 = time.time()
    for i, (doc_id, text) in enumerate(test_docs, start=1):
        row = _probe_case(
            doc_id, text,
            retr=retr,
            outcome_map=outcome_map,
            detect_outcome=detect_outcome,
            extract_facts=_extract_facts_section,
            make_query=Query,
            pool_size=args.pool_size,
            per_side=args.per_side,
            min_per_side=args.min_per_side,
            skips=skips,
        )
        if row is not None:
            rows.append(row)
        # Report progress for every case, including skipped ones.
        if i % 25 == 0:
            print(f" {i}/{len(test_docs)} ({time.time()-t0:.1f}s)", flush=True)

    print(f"\nProbed {len(rows)} cases with detectable outcome\n")
    if skips:
        print("Skipped: " + ", ".join(
            f"{reason}={count}" for reason, count in sorted(skips.items())))
        print()

    _report(rows)


if __name__ == "__main__":
    main()