#!/usr/bin/env python3
"""Evaluate the argument classifier against a human-verified gold set.

Computes per-task precision / recall / F1 for:
    is_argument   (binary)
    outcome       (4-class: accepted / rejected / partial / unknown)
    side          (4-class: plaintiff / defendant / court / unknown)
    arg_type      (8-class)

Three eval modes:
    --rule-based          eval the AutoLabeler (rule-based baseline)
    --classifier PATH     eval a fine-tuned HeBERT classifier
    --compare CKPT_PATH   run BOTH and produce a side-by-side report

Output: a console report PLUS a JSON sidecar written next to the gold file
(the gold path with its last suffix replaced by .eval.json) unless --out
is given.

Usage:
    # Just the rule-based baseline
    python3 -m tau_rag.scripts.eval_classifier \\
        --gold data/gold_pool.labels.jsonl \\
        --rule-based

    # Just the trained model
    python3 -m tau_rag.scripts.eval_classifier \\
        --gold data/gold_pool.labels.jsonl \\
        --classifier tau_rag/runtime/models/argument_classifier_v1.pt

    # Compare both — the most useful — tells you if training was worth it
    python3 -m tau_rag.scripts.eval_classifier \\
        --gold data/gold_pool.labels.jsonl \\
        --compare tau_rag/runtime/models/argument_classifier_v1.pt
"""
from __future__ import annotations

import argparse
import json
import os
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

# CRITICAL: same env-var dance as in finetune script — prevent TF auto-load
os.environ.setdefault("USE_TF", "0")
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")

# package import dance for stand-alone execution
_THIS_DIR = Path(__file__).resolve().parent
_PKG_PARENT = _THIS_DIR.parent.parent
if str(_PKG_PARENT) not in sys.path:
    sys.path.insert(0, str(_PKG_PARENT))

# Tasks evaluated and their possible label values
TASK_LABELS: Dict[str, List[str]] = {
    "is_argument": ["true", "false"],  # bool flattened to strings
    "outcome": ["accepted", "rejected", "partial", "unknown"],
    "side": ["plaintiff", "defendant", "court", "unknown"],
    "arg_type": [
        "legal",
        "factual",
        "procedural",
        "policy",
        "equitable",
        "constitutional",
        "substantive",
        "unknown",
    ],
}


def _normalize_label(task: str, value: Any) -> str:
    """Coerce a raw label to its canonical lowercase string.

    For ``is_argument`` the result is always ``"true"`` or ``"false"``:
    booleans map directly, and any other value is true only if its string
    form is ``true`` / ``yes`` / ``1`` (case-insensitive).  For every other
    task ``None`` becomes ``"unknown"`` and anything else is stringified.
    Surrounding whitespace is stripped so that ``" Accepted "`` and
    ``"accepted"`` compare equal.
    """
    if task == "is_argument":
        if isinstance(value, bool):
            return "true" if value else "false"
        text = str(value).strip().lower()
        return "true" if text in ("true", "yes", "1") else "false"
    if value is None:
        return "unknown"
    return str(value).strip().lower()


def confusion_matrix(
    gold: List[str], pred: List[str], labels: List[str]
) -> Dict[str, Dict[str, int]]:
    """Build dict-of-dict confusion matrix: cm[gold_label][pred_label] = count.

    Every label in ``labels`` gets a zero-filled row and column up front.
    Gold or predicted values outside ``labels`` are added on demand, so no
    pair is ever dropped.
    """
    cm: Dict[str, Dict[str, int]] = {
        row_label: {col_label: 0 for col_label in labels} for row_label in labels
    }
    for g, p in zip(gold, pred):
        # Unexpected gold label: give it a full zero row like the known ones.
        row = cm.setdefault(g, {col_label: 0 for col_label in labels})
        row[p] = row.get(p, 0) + 1
    return cm


def per_class_metrics(
    gold: List[str], pred: List[str], labels: List[str]
) -> Dict[str, Dict[str, float]]:
    """Compute precision/recall/F1 per class (and accuracy overall).

    Returns one entry per label with ``precision``, ``recall``, ``f1``
    (each rounded to 3 decimals) and ``support`` (number of gold items with
    that label), plus an ``"__overall__"`` entry holding ``accuracy``,
    ``macro_f1`` and ``n_samples``.  Ratios with a zero denominator are
    reported as 0.0.  Macro-F1 is the mean over ``labels`` of the
    *unrounded* per-class F1, so rounding error does not accumulate.
    """
    # One pass over the pairs instead of four passes per label.
    pairs = list(zip(gold, pred))
    hits = Counter(g for g, p in pairs if g == p)
    gold_seen = Counter(g for g, _ in pairs)
    pred_seen = Counter(p for _, p in pairs)
    support = Counter(gold)

    out: Dict[str, Dict[str, float]] = {}
    raw_f1: Dict[str, float] = {}
    for label in labels:
        tp = hits[label]
        fp = pred_seen[label] - tp
        fn = gold_seen[label] - tp
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if (precision + recall) > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        raw_f1[label] = f1
        out[label] = {
            "precision": round(precision, 3),
            "recall": round(recall, 3),
            "f1": round(f1, 3),
            "support": support[label],
        }

    accuracy = sum(hits.values()) / max(len(gold), 1)
    macro_f1 = sum(raw_f1.values()) / max(len(raw_f1), 1)
    out["__overall__"] = {
        "accuracy": round(accuracy, 3),
        "macro_f1": round(macro_f1, 3),
        "n_samples": len(gold),
    }
    return out


def evaluate_predictor(
    gold_records: List[dict],
    predictor: Callable[[str], dict],
    name: str,
) -> Dict[str, Any]:
    """Run a predictor on each gold paragraph, compute per-task metrics.

    ``predictor`` is called with each record's ``"text"``.  If it raises, or
    returns something falsy, the record is scored as if every task had been
    predicted ``"unknown"`` (which normalizes to ``"false"`` for
    ``is_argument``).  Failures are printed per record and summarized at
    the end.  Prints a per-task table and returns ``{task: metrics}`` where
    ``metrics`` is the output of :func:`per_class_metrics`.
    """
    print(f"\n{'='*60}")
    print(f" Evaluating: {name}")
    print(f"{'='*60}")

    pred_per_task: Dict[str, List[str]] = {t: [] for t in TASK_LABELS}
    gold_per_task: Dict[str, List[str]] = {t: [] for t in TASK_LABELS}
    n_errors = 0

    for i, rec in enumerate(gold_records):
        text = rec.get("text", "")
        try:
            pred = predictor(text) or {}
        except Exception as e:
            # Deliberate best-effort: one bad record must not abort the run.
            n_errors += 1
            print(f" predictor error on record {i}: {e}")
            pred = {}
        for task in TASK_LABELS:
            gold_per_task[task].append(_normalize_label(task, rec.get(task)))
            pred_per_task[task].append(
                _normalize_label(task, pred.get(task, "unknown"))
            )
        if (i + 1) % 50 == 0:
            print(f" [{i+1}/{len(gold_records)}]", flush=True)

    if n_errors:
        print(f" WARNING: predictor failed on {n_errors} of "
              f"{len(gold_records)} records (scored as 'unknown')")

    metrics = {}
    for task, labels in TASK_LABELS.items():
        m = per_class_metrics(gold_per_task[task], pred_per_task[task], labels)
        metrics[task] = m
        print(f"\n Task: {task}")
        print(f" {'class':<14} {'P':>7} {'R':>7} {'F1':>7} {'n':>5}")
        for lbl, vals in m.items():
            if lbl == "__overall__":
                continue
            print(f" {lbl:<12} {vals['precision']:>7.3f} "
                  f"{vals['recall']:>7.3f} {vals['f1']:>7.3f} {vals['support']:>5}")
        ov = m["__overall__"]
        print(f" {'OVERALL':<12} acc={ov['accuracy']:.3f} "
              f"macro_f1={ov['macro_f1']:.3f} n={ov['n_samples']}")
    return metrics


def make_rule_based_predictor():
    """Wrap AutoLabeler.label() as a (text → label dict) callable."""
    # Imported lazily so the module can be loaded without the package.
    from tau_rag.intelligence import AutoLabeler

    labeler = AutoLabeler()

    def predict(text: str) -> dict:
        out = labeler.label(text)
        if out is None:
            return {"is_argument": False, "outcome": "unknown",
                    "side": "unknown", "arg_type": "unknown"}
        return {
            "is_argument": out["is_argument"],
            "outcome": out["outcome"],
            "side": out["side"],
            "arg_type": out["arg_type"],
        }

    return predict


def make_classifier_predictor(checkpoint_path: str):
    """Wrap the trained classifier as a (text → label dict) callable."""
    # Imported lazily so rule-based runs do not pull in the training stack.
    from tau_rag.scripts.finetune_argument_classifier import load_classifier

    clf = load_classifier(checkpoint_path)

    def predict(text: str) -> dict:
        out = clf(text) or {}
        return {
            "is_argument": out.get("is_argument", False),
            "outcome": out.get("outcome", "unknown"),
            "side": out.get("side", "unknown"),
            "arg_type": out.get("arg_type", "unknown"),
        }

    return predict


def print_comparison(rule_metrics: Dict, clf_metrics: Dict) -> None:
    """Side-by-side overall summary.

    Both arguments are ``{task: metrics}`` dicts as returned by
    :func:`evaluate_predictor`; only each task's ``"__overall__"`` entry is
    read.  The last column is classifier macro-F1 minus rule-based macro-F1.
    """
    print(f"\n{'='*60}")
    print(" COMPARISON")
    print(f"{'='*60}")
    print(f" {'task':<14} {'rule_acc':>10} {'clf_acc':>10} "
          f"{'rule_f1':>9} {'clf_f1':>9} {'Δf1':>8}")
    for task in TASK_LABELS:
        r_ov = rule_metrics[task]["__overall__"]
        c_ov = clf_metrics[task]["__overall__"]
        delta = c_ov["macro_f1"] - r_ov["macro_f1"]
        delta_s = f"{'+' if delta >= 0 else ''}{delta:.3f}"
        print(f" {task:<14} {r_ov['accuracy']:>10.3f} "
              f"{c_ov['accuracy']:>10.3f} "
              f"{r_ov['macro_f1']:>9.3f} {c_ov['macro_f1']:>9.3f} "
              f"{delta_s:>8}")


def _load_gold_records(gold_path: Path) -> Tuple[List[dict], int, int]:
    """Read the gold JSONL file.

    Returns ``(records, n_invalid_json, n_incomplete)``.  A line is kept
    only if it parses to a JSON object containing both ``"text"`` and
    ``"is_argument"``.  Blank lines are ignored and not counted.
    """
    records: List[dict] = []
    n_invalid_json = 0
    n_incomplete = 0
    # utf-8-sig also reads plain UTF-8, but strips a leading BOM that would
    # otherwise make the first record unparseable.
    with gold_path.open("r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                n_invalid_json += 1
                continue
            if isinstance(rec, dict) and "text" in rec and "is_argument" in rec:
                records.append(rec)
            else:
                n_incomplete += 1
    return records, n_invalid_json, n_incomplete


def main():
    """CLI entry point: load the gold set, run the chosen mode, save the report."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--gold", required=True,
                    help="JSONL of human-verified labels")
    mode = ap.add_mutually_exclusive_group(required=True)
    mode.add_argument("--rule-based", action="store_true",
                      help="evaluate AutoLabeler only")
    mode.add_argument("--classifier",
                      help="path to trained .pt — eval that only")
    mode.add_argument("--compare",
                      help="path to trained .pt — run BOTH + compare")
    ap.add_argument("--out", default=None,
                    help="output JSON report (default: gold path with its "
                         "last suffix replaced by .eval.json)")
    args = ap.parse_args()

    gold_path = Path(args.gold)
    if not gold_path.exists():
        raise SystemExit(f"gold set not found: {gold_path}")
    out_path = Path(args.out) if args.out else gold_path.with_suffix(".eval.json")

    # Load gold records
    gold_records, n_invalid_json, n_incomplete = _load_gold_records(gold_path)
    print(f"loaded {len(gold_records)} gold records from {gold_path.name}")
    n_skipped = n_invalid_json + n_incomplete
    if n_skipped:
        print(f"WARNING: skipped {n_skipped} line(s): "
              f"{n_invalid_json} invalid JSON, "
              f"{n_incomplete} missing 'text'/'is_argument'")
    if len(gold_records) < 50:
        print("WARNING: fewer than 50 records — metrics will be unreliable")

    report: Dict[str, Any] = {
        "gold_file": str(gold_path),
        "n_records": len(gold_records),
        "n_skipped_lines": n_skipped,
    }

    # Compare against None rather than truthiness so that an empty path
    # fails loudly in the loader instead of silently running nothing.
    if args.rule_based:
        report["rule_based"] = evaluate_predictor(
            gold_records, make_rule_based_predictor(),
            "Rule-based AutoLabeler")
    elif args.classifier is not None:
        report["trained_classifier"] = evaluate_predictor(
            gold_records, make_classifier_predictor(args.classifier),
            f"Trained classifier ({Path(args.classifier).name})")
    elif args.compare is not None:
        report["rule_based"] = evaluate_predictor(
            gold_records, make_rule_based_predictor(),
            "Rule-based AutoLabeler")
        report["trained_classifier"] = evaluate_predictor(
            gold_records, make_classifier_predictor(args.compare),
            f"Trained classifier ({Path(args.compare).name})")
        print_comparison(report["rule_based"], report["trained_classifier"])

    # Create the target directory so a long evaluation is not lost to a
    # missing --out parent.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2),
                        encoding="utf-8")
    print(f"\n ✓ report saved to {out_path}")


if __name__ == "__main__":
    main()