legal-eye / tau_rag /scripts /diagnose_v10.py
Legal-i's picture
Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
3be54c6 verified
#!/usr/bin/env python3
"""
Diagnose v10 — runs 12 representative queries through the live API and
produces a single comprehensive report comparing each output to the
teacher answer it was supposed to learn.
Categories tested:
• Section quotes seen in training (3 queries)
• Conceptual questions seen in training (3 queries)
• Applied scenarios seen in training (3 queries)
• "I don't know" probes (1 query)
• Generalization — paraphrased versions of trained queries (2 queries)
"""
import json
import re
import sys
import urllib.request
from collections import Counter
from pathlib import Path
# Endpoint of the locally running query service; call_api() sends every test
# query here as a JSON request body.
API = "http://127.0.0.1:8000/v1/query"
# JSON-Lines trace file, resolved relative to this script's location:
# <script dir>/../runtime/training_data/traces.jsonl. main() parses every line
# and find_teacher() reads the "query" / "teacher_answer" fields of each record.
TRACES = Path(__file__).resolve().parent.parent / "runtime" / "training_data" / "traces.jsonl"
# Test set — one (query, expected_topic, type) tuple per probe.
#   query          : text sent verbatim to the API; also the exact-match key
#                    used by find_teacher() to look up the teacher answer.
#   expected_topic : key into TOPIC_TERMS, used for lexical recall scoring.
#   type           : category label ("section", "concept", "applied", "idk",
#                    "generalize"); main() groups the per-category report by it
#                    and applies a length-only verdict to "idk" probes.
TESTS = [
# === Section quotes (trained verbatim) ===
("מה אומר סעיף 39 לחוק החוזים?", "תום לב", "section"),
("מה אומר סעיף 12 לחוק החוזים?", "תום לב במשא ומתן", "section"),
("מה אומר סעיף 30 לחוק החוזים?", "חוזה פסול", "section"),
# === Conceptual (trained) ===
("מה זה גמירת דעת?", "כריתה", "concept"),
("מה זה מסויימות?", "כריתה", "concept"),
("הסבר על קיום בתום לב", "תום לב", "concept"),
# === Applied (trained) ===
("מה התרופות במכר פגום?", "תרופות", "applied"),
("מתי קונה רשאי לבטל מכר?", "ביטול מכר", "applied"),
("מה ההבדל בין שכירות למכירה?", "מכר vs שכירות", "applied"),
# === "I don't know" probe ===
("שנשר25512551", "gibberish", "idk"),
# === Generalization (paraphrase of trained) ===
("הסבר את חובת תום הלב בחוזים", "תום לב", "generalize"),
("איך מתבטל חוזה פסול?", "סעיף 30/31", "generalize"),
]
# Key terms expected per topic — used for lexical recall scoring.
# Keys are the expected_topic values used in TESTS. main() checks each term
# with a plain substring test against the answer text, and recall is the
# fraction of terms found. An empty list makes recall None, which disables
# the "off-topic" check for that topic.
TOPIC_TERMS = {
"תום לב": ["תום", "לב", "סעיף", "חוזה", "חיוב"],
"תום לב במשא ומתן": ["תום", "לב", "משא", "ומתן", "פיצויים"],
"חוזה פסול": ["פסול", "תקנת", "ציבור", "בלתי", "חוקיים", "בטל"],
"כריתה": ["סעיף", "חוזה", "הצעה", "קיבול", "גמירת"],
"תרופות": ["אכיפה", "ביטול", "פיצויים", "השבה", "הפרה"],
"ביטול מכר": ["ביטול", "מכר", "הפרה", "ארכה", "פגם"],
"מכר vs שכירות": ["מכר", "שכירות", "בעלות", "שוכר", "משכיר"],
"gibberish": [], # for idk: should be SHORT and apologetic
"סעיף 30/31": ["פסול", "תקנת", "ציבור", "בטל", "השבה"],
}
def call_api(query):
    """Send *query* to the endpoint in ``API`` and return the parsed JSON reply.

    The request body is a UTF-8 encoded JSON object carrying the query text
    and a fixed ``top_k`` of 3. The call waits at most 30 seconds; any error
    raised by ``urllib`` or by JSON decoding propagates to the caller.
    """
    payload = {"query": query, "top_k": 3}
    # ensure_ascii=False keeps the Hebrew text as real UTF-8 instead of \u escapes.
    encoded = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    request = urllib.request.Request(
        API,
        data=encoded,
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))
def find_teacher(query, rows):
    """Return the first non-empty ``teacher_answer`` recorded for *query*.

    *rows* is an iterable of dict-like trace records. A record matches when
    its ``"query"`` value equals *query* exactly and its ``"teacher_answer"``
    is truthy. Returns ``None`` when no record matches.
    """
    candidates = (
        row["teacher_answer"]
        for row in rows
        if row.get("query") == query and row.get("teacher_answer")
    )
    return next(candidates, None)
def metrics(text):
    """Compute simple repetition statistics over whitespace-separated tokens.

    Returns a dict with:
      * ``words``      — total number of tokens,
      * ``unique``     — number of distinct tokens,
      * ``ratio``      — ``unique / words`` (0.0 for empty text),
      * ``max_repeat`` — occurrences of the most frequent token.
    """
    tokens = re.findall(r"\S+", text)
    total = len(tokens)
    if total == 0:
        return {"words": 0, "unique": 0, "ratio": 0.0, "max_repeat": 0}
    frequencies = Counter(tokens)
    distinct = len(frequencies)
    # most_common(1) yields [(token, count)] for the single most frequent token.
    top_count = frequencies.most_common(1)[0][1]
    return {
        "words": total,
        "unique": distinct,
        "ratio": distinct / total,
        "max_repeat": top_count,
    }
def overlap(a, b):
    """Jaccard similarity between the Hebrew word sets of *a* and *b*.

    Only maximal runs of the letters alef through tav count as words, so
    punctuation, digits and non-Hebrew text are ignored. Returns 0.0 when
    either text contains no such word.
    """
    hebrew_word = r"[א-ת]+"
    left = set(re.findall(hebrew_word, a))
    right = set(re.findall(hebrew_word, b))
    if not (left and right):
        return 0.0
    shared = left.intersection(right)
    combined = left.union(right)
    return len(shared) / len(combined)
def _load_traces(path):
    """Parse the JSON-Lines file at *path* into a list of records.

    The file is closed deterministically and blank lines are skipped.
    A missing file or a malformed line still raises, so a broken trace
    file is reported loudly rather than silently ignored.
    """
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


def _verdict(kind, used, m, recall, ov, has_teacher):
    """Map one query's measurements to a human-readable verdict string.

    Checks run in priority order: unexpected generator, "idk" length gate,
    degenerate repetition, topic-term recall, verbatim citation, and finally
    overlap with the teacher answer.
    """
    # Both tau_native (LM) and extractive (verbatim) generators are accepted;
    # anything else is treated as a fallback path.
    if used not in ("tau_native", "extractive"):
        return f"⚠️ fallback to {used}"
    if kind == "idk":
        return "✅ short" if m["words"] < 30 else "❌ long for idk"
    if m["ratio"] < 0.4:
        return "❌ degenerate"
    if recall is not None and recall < 0.4:
        return "❌ off-topic"
    if used == "extractive":
        # Verbatim citation — Jaccard with a paraphrased teacher is expected
        # to be low, but the answer is the law text itself.
        return "✅ verbatim cite"
    # Only judge overlap when a teacher answer was actually found; with no
    # teacher the overlap is a placeholder 0.0, not a measured mismatch.
    if has_teacher and ov < 0.05:
        return "❌ no teacher overlap"
    return "✓ on-topic, fluent?"


def _mean(values):
    """Arithmetic mean of *values*, or 0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else 0


def main():
    """Run every probe in ``TESTS`` against the API and report the results.

    For each query the answer is scored (repetition metrics, topic-term
    recall, Jaccard overlap with the teacher answer from ``TRACES``), given a
    verdict and printed. Aggregate and per-category statistics follow, and
    the per-query records are saved as UTF-8 JSON to
    ``<script dir>/../runtime/v10_diagnostic.json``.

    Queries whose API call raises are reported and left out of the summary.
    The aggregate section tolerates an empty summary (e.g. server down).
    """
    rows = _load_traces(TRACES)
    print("=" * 78)
    print(" v10 DIAGNOSTIC — comparing model output to teacher answers")
    print("=" * 78)
    summary = []
    for query, topic, kind in TESTS:
        try:
            r = call_api(query)
        except Exception as e:
            # Script-level boundary: report the failure and move on so one
            # bad query does not abort the whole diagnostic run.
            print(f"\n❌ {query}\n API error: {e}")
            continue
        # "or" fallbacks also cover explicit JSON nulls in the reply.
        ans = r.get("answer") or ""
        gen = r.get("generator") or {}
        conf = r.get("confidence") or 0
        used = gen.get("used", "?")
        teacher = find_teacher(query, rows)
        has_teacher = teacher is not None
        m = metrics(ans)
        ov = overlap(ans, teacher) if has_teacher else 0.0
        # lexical recall — fraction of expected terms present (substring match)
        expected = TOPIC_TERMS.get(topic, [])
        missing = [t for t in expected if t not in ans]
        recall = (len(expected) - len(missing)) / len(expected) if expected else None
        verdict = _verdict(kind, used, m, recall, ov, has_teacher)
        summary.append({
            "query": query,
            "kind": kind,
            "topic": topic,
            "verdict": verdict,
            "used": used,
            "conf": conf,
            "words": m["words"],
            "ratio": m["ratio"],
            "max_repeat": m["max_repeat"],
            "recall": recall,
            "overlap": ov,
            "answer": ans,
        })
        recall_txt = recall if recall is None else f"{recall:.2f}"
        print(f"\n{'─' * 78}")
        print(f"Q [{kind}]: {query}")
        print(f" verdict: {verdict} used={used} conf={conf:.2f}")
        print(f" words={m['words']} uniq_ratio={m['ratio']:.2f} "
              f"max_repeat={m['max_repeat']} "
              f"recall={recall_txt} "
              f"teacher_overlap={ov:.2f}")
        if missing:
            print(f" missing terms: {missing}")
        print(f" answer: {ans[:200]}")
    # ===== AGGREGATE =====
    print("\n" + "=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    n = len(summary)
    n_native = sum(1 for s in summary if s["used"] == "tau_native")
    n_pass = sum(1 for s in summary if s["verdict"].startswith(("✓", "✅")))
    n_degen = sum(1 for s in summary if "degenerate" in s["verdict"])
    n_off = sum(1 for s in summary if "off-topic" in s["verdict"])
    n_fallbk = sum(1 for s in summary if "fallback" in s["verdict"])
    # Guarded: n is 0 when every API call failed.
    pct_native = 100 * n_native / n if n else 0
    avg_conf = _mean(s["conf"] for s in summary)
    avg_ratio = _mean(s["ratio"] for s in summary)
    avg_ov = _mean(s["overlap"] for s in summary)
    avg_recall = _mean(s["recall"] for s in summary if s["recall"] is not None)
    print(f" total queries: {n}")
    print(f" used tau_native: {n_native}/{n} ({pct_native:.0f}%)")
    print(f" passed verdict (✓): {n_pass}/{n}")
    print(f" degenerate: {n_degen}/{n}")
    print(f" off-topic: {n_off}/{n}")
    print(f" fell back: {n_fallbk}/{n}")
    print(f" avg confidence: {avg_conf:.2f}")
    # NOTE(review): the label says 0.35 while the verdict above uses 0.4;
    # presumably 0.35 is a gate applied elsewhere — confirm.
    print(f" avg unique_ratio: {avg_ratio:.2f} (gate ≥ 0.35)")
    print(f" avg topic-term recall: {avg_recall:.2f}")
    print(f" avg teacher overlap (J): {avg_ov:.2f}")
    # By category
    print("\n BY CATEGORY:")
    for category in ["section", "concept", "applied", "idk", "generalize"]:
        rows_k = [s for s in summary if s["kind"] == category]
        if not rows_k:
            continue
        rec = _mean(s["recall"] for s in rows_k if s["recall"] is not None)
        cat_ov = _mean(s["overlap"] for s in rows_k)
        print(f" {category:11s} n={len(rows_k)} recall={rec:.2f} "
              f"overlap={cat_ov:.2f}")
    # save full report — explicit UTF-8 because ensure_ascii=False leaves raw
    # Hebrew in the JSON, which the locale default encoding may not handle.
    out = Path(__file__).resolve().parent.parent / "runtime" / "v10_diagnostic.json"
    out.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n full report → {out}")
if __name__ == "__main__":
main()