File size: 9,302 Bytes
3be54c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/usr/bin/env python3
"""
Diagnose v10 — runs 12 representative queries through the live API and
produces a single comprehensive report comparing each output to the
teacher answer it was supposed to learn.

Categories tested:
  • Section quotes seen in training (3 queries)
  • Conceptual questions seen in training (3 queries)
  • Applied scenarios seen in training (3 queries)
  • "I don't know" probes (1 query)
  • Generalization — paraphrased versions of trained queries (2 queries)
"""
import json
import re
import sys
import urllib.request
from collections import Counter
from pathlib import Path

# Endpoint of the locally running query API; call_api POSTs each test query here.
API = "http://127.0.0.1:8000/v1/query"
# JSONL training traces, resolved relative to this script (<script dir>/../runtime/
# training_data/). Rows with "query" and "teacher_answer" supply the reference answers.
TRACES = Path(__file__).resolve().parent.parent / "runtime" / "training_data" / "traces.jsonl"

# Test set — (query, expected_topic, type)
# - query: sent to the API and matched EXACTLY against trace queries by
#   find_teacher, so its spelling must stay identical to the trained text.
# - expected_topic: key into TOPIC_TERMS (selects the recall term list).
# - type: category label used for the per-category breakdown in main().
TESTS = [
    # === Section quotes (trained verbatim) ===
    ("מה אומר סעיף 39 לחוק החוזים?",       "תום לב",           "section"),
    ("מה אומר סעיף 12 לחוק החוזים?",       "תום לב במשא ומתן",  "section"),
    ("מה אומר סעיף 30 לחוק החוזים?",       "חוזה פסול",         "section"),
    # === Conceptual (trained) ===
    ("מה זה גמירת דעת?",                    "כריתה",             "concept"),
    ("מה זה מסויימות?",                     "כריתה",             "concept"),
    ("הסבר על קיום בתום לב",                "תום לב",            "concept"),
    # === Applied (trained) ===
    ("מה התרופות במכר פגום?",                "תרופות",            "applied"),
    ("מתי קונה רשאי לבטל מכר?",              "ביטול מכר",         "applied"),
    ("מה ההבדל בין שכירות למכירה?",          "מכר vs שכירות",     "applied"),
    # === "I don't know" probe ===
    # Deliberately meaningless input; main() only checks that the answer is short.
    ("שנשר25512551",                        "gibberish",         "idk"),
    # === Generalization (paraphrase of trained) ===
    ("הסבר את חובת תום הלב בחוזים",          "תום לב",            "generalize"),
    ("איך מתבטל חוזה פסול?",                "סעיף 30/31",        "generalize"),
]

# Key terms expected per topic — used for lexical recall scoring
# Recall = fraction of a topic's terms found as plain substrings of the answer.
# An empty list makes recall None (not scored) for that topic.
TOPIC_TERMS = {
    "תום לב":            ["תום", "לב", "סעיף", "חוזה", "חיוב"],
    "תום לב במשא ומתן":  ["תום", "לב", "משא", "ומתן", "פיצויים"],
    "חוזה פסול":         ["פסול", "תקנת", "ציבור", "בלתי", "חוקיים", "בטל"],
    "כריתה":             ["סעיף", "חוזה", "הצעה", "קיבול", "גמירת"],
    "תרופות":            ["אכיפה", "ביטול", "פיצויים", "השבה", "הפרה"],
    "ביטול מכר":         ["ביטול", "מכר", "הפרה", "ארכה", "פגם"],
    "מכר vs שכירות":     ["מכר", "שכירות", "בעלות", "שוכר", "משכיר"],
    "gibberish":         [],  # for idk: should be SHORT and apologetic
    "סעיף 30/31":        ["פסול", "תקנת", "ציבור", "בטל", "השבה"],
}


def call_api(query):
    """POST *query* to ``API`` and return the decoded JSON response.

    The request body is ``{"query": query, "top_k": 3}`` encoded as UTF-8
    JSON (non-ASCII kept literal). The call waits at most 30 seconds; any
    network, HTTP or JSON decoding error propagates to the caller.
    """
    payload = {"query": query, "top_k": 3}
    request = urllib.request.Request(
        API,
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json; charset=utf-8"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))


def find_teacher(query, rows):
    """Return the teacher answer recorded for *query*, or ``None``.

    Scans *rows* in order and picks the first one whose ``"query"`` equals
    *query* exactly and whose ``"teacher_answer"`` is non-empty.
    """
    candidates = (
        row["teacher_answer"]
        for row in rows
        if row.get("query") == query and row.get("teacher_answer")
    )
    return next(candidates, None)


def metrics(text):
    """Summarise the whitespace-delimited tokens of *text*.

    Returns a dict with:
      words      -- number of tokens
      unique     -- number of distinct tokens
      ratio      -- unique / words (0.0 when there are no tokens)
      max_repeat -- occurrence count of the most frequent token
    A low ratio / high max_repeat indicates repetitive output.
    """
    tally = Counter(re.findall(r"\S+", text))
    total = sum(tally.values())
    if total == 0:
        return {"words": 0, "unique": 0, "ratio": 0.0, "max_repeat": 0}
    distinct = len(tally)
    return {
        "words": total,
        "unique": distinct,
        "ratio": distinct / total,
        "max_repeat": max(tally.values()),
    }


def overlap(a, b):
    """Jaccard similarity between the Hebrew-word sets of *a* and *b*.

    A "word" is a maximal run of Hebrew letters (U+05D0..U+05EA, which
    includes final forms), so punctuation, digits, spaces and Latin text
    are all ignored. Returns 0.0 when either side has no Hebrew words.
    """
    left, right = (set(re.findall(r"[א-ת]+", s)) for s in (a, b))
    if not (left and right):
        return 0.0
    shared = left & right
    return len(shared) / len(left | right)


def _load_traces(path):
    """Return the JSON object parsed from each non-blank line of the JSONL file *path*."""
    # `with` guarantees the handle is closed; blank lines (e.g. a trailing
    # empty line) would otherwise make json.loads raise.
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


def _mean(values):
    """Arithmetic mean of *values*, or 0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else 0


def _verdict(kind, used, m, recall, ov):
    """Classify one answer; the first matching rule wins.

    kind:   test category ("idk" answers are only judged on length).
    used:   generator name reported by the API.
    m:      metrics() result for the answer text.
    recall: fraction of expected topic terms present, or None if unscored.
    ov:     Jaccard overlap with the teacher answer (0.0 if none was found).
    """
    # Accept both tau_native (LM) and extractive (verbatim) as legitimate.
    # Extractive = correct verbatim citation = WIN.
    if used not in ("tau_native", "extractive"):
        return f"⚠️  fallback to {used}"
    if kind == "idk":
        return "✅ short" if m["words"] < 30 else "❌ long for idk"
    if m["ratio"] < 0.4:
        return "❌ degenerate"
    if recall is not None and recall < 0.4:
        return "❌ off-topic"
    if used == "extractive":
        # Verbatim citation — Jaccard with paraphrased teacher will be lower
        # than the LM target but the answer IS the law text.
        return "✅ verbatim cite"
    if ov < 0.05:
        return "❌ no teacher overlap"
    return "✓ on-topic, fluent?"


def _print_aggregate(summary, n_errors=0):
    """Print overall and per-category statistics for the *summary* rows.

    All ratios and averages fall back to 0 when *summary* is empty (e.g.
    every API call failed) instead of dividing by zero.
    """
    print("\n" + "=" * 78)
    print(" AGGREGATE")
    print("=" * 78)
    n = len(summary)
    n_native = sum(1 for s in summary if s["used"] == "tau_native")
    n_pass = sum(1 for s in summary if s["verdict"].startswith(("✓", "✅")))
    n_degen = sum(1 for s in summary if "degenerate" in s["verdict"])
    n_off = sum(1 for s in summary if "off-topic" in s["verdict"])
    n_fallbk = sum(1 for s in summary if "fallback" in s["verdict"])
    avg_conf = _mean(s["conf"] for s in summary)
    avg_ratio = _mean(s["ratio"] for s in summary)
    avg_ov = _mean(s["overlap"] for s in summary)
    # Queries without topic terms (recall None) are excluded from the mean.
    avg_recall = _mean(s["recall"] for s in summary if s["recall"] is not None)
    pct_native = 100 * n_native / n if n else 0

    print(f"  total queries:            {n}")
    print(f"  API errors:               {n_errors}")
    print(f"  used tau_native:          {n_native}/{n} ({pct_native:.0f}%)")
    print(f"  passed verdict (✓):       {n_pass}/{n}")
    print(f"  degenerate:               {n_degen}/{n}")
    print(f"  off-topic:                {n_off}/{n}")
    print(f"  fell back:                {n_fallbk}/{n}")
    print(f"  avg confidence:           {avg_conf:.2f}")
    # NOTE(review): _verdict flags "degenerate" below 0.4; the 0.35 printed
    # here presumably refers to a server-side gate — confirm.
    print(f"  avg unique_ratio:         {avg_ratio:.2f}  (gate ≥ 0.35)")
    print(f"  avg topic-term recall:    {avg_recall:.2f}")
    print(f"  avg teacher overlap (J):  {avg_ov:.2f}")

    print("\n  BY CATEGORY:")
    for kind in ["section", "concept", "applied", "idk", "generalize"]:
        rows_k = [s for s in summary if s["kind"] == kind]
        if not rows_k:
            continue
        rec = _mean(s["recall"] for s in rows_k if s["recall"] is not None)
        ov = _mean(s["overlap"] for s in rows_k)
        print(f"    {kind:11s}  n={len(rows_k)}  recall={rec:.2f}  "
              f"overlap={ov:.2f}")


def main():
    """Run every TESTS query against the live API and report the results.

    For each query: fetch the answer, look up the teacher answer in TRACES,
    compute repetition metrics, topic-term recall and teacher overlap, and
    print a verdict. Then print aggregate statistics and save the per-query
    rows as JSON under runtime/v10_diagnostic.json.
    """
    rows = _load_traces(TRACES)
    print("=" * 78)
    print(" v10 DIAGNOSTIC — comparing model output to teacher answers")
    print("=" * 78)

    summary = []
    n_errors = 0
    for query, topic, kind in TESTS:
        try:
            r = call_api(query)
        except Exception as e:  # report and move on to the next query
            print(f"\n❌ {query}\n   API error: {e}")
            n_errors += 1
            continue

        # `or` defaults also cover keys that are present but null.
        ans = r.get("answer") or ""
        gen = r.get("generator") or {}
        conf = r.get("confidence") or 0
        used = gen.get("used", "?")
        teacher = find_teacher(query, rows)

        m = metrics(ans)
        ov = overlap(ans, teacher) if teacher else 0.0

        # Lexical recall — fraction of expected terms present (substring match).
        expected = TOPIC_TERMS.get(topic, [])
        hits = [t for t in expected if t in ans]
        recall = len(hits) / len(expected) if expected else None

        verdict = _verdict(kind, used, m, recall, ov)

        summary.append({
            "query": query,
            "kind": kind,
            "topic": topic,
            "verdict": verdict,
            "used": used,
            "conf": conf,
            "words": m["words"],
            "ratio": m["ratio"],
            "max_repeat": m["max_repeat"],
            "recall": recall,
            "overlap": ov,
            "answer": ans,
        })

        recall_txt = "None" if recall is None else f"{recall:.2f}"
        print(f"\n{'─' * 78}")
        print(f"Q [{kind}]: {query}")
        print(f"   verdict: {verdict}  used={used}  conf={conf:.2f}")
        print(f"   words={m['words']}  uniq_ratio={m['ratio']:.2f}  "
              f"max_repeat={m['max_repeat']}  "
              f"recall={recall_txt}  "
              f"teacher_overlap={ov:.2f}")
        missing = [t for t in expected if t not in ans]
        if missing:
            print(f"   missing terms: {missing}")
        print(f"   answer: {ans[:200]}")

    _print_aggregate(summary, n_errors)

    # Save the full report. Explicit UTF-8: ensure_ascii=False keeps Hebrew
    # unescaped, which write_text's locale-default encoding may not support.
    out = Path(__file__).resolve().parent.parent / "runtime" / "v10_diagnostic.json"
    out.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n  full report → {out}")


if __name__ == "__main__":
    main()