#!/usr/bin/env python3
"""Build a stratified gold validation set from auto-labels.

Picks N paragraphs that are *representative* of the label distribution —
mixing accepted / rejected / unknown / argument / non-argument so the eval
isn't dominated by one class. The result is the input to the labeling tool
where you'll manually verify each one (1-2 hours of work).

Use case:
  1. Run this script → produces `gold_pool.jsonl` (200 candidates).
  2. Open in labeling_server.py → manually verify each paragraph.
  3. The verified labels (`gold_pool.labels.jsonl`) become the test set.
  4. Run eval_classifier.py against it → see real precision/recall.

Without this gold set, "val_accuracy=0.85" during training is meaningless
because val labels themselves are noisy auto-labels.

Usage:
    python3 -m tau_rag.scripts.build_gold_set \\
        --auto-labels data/auto_labels.jsonl \\
        --out data/gold_pool.jsonl \\
        --n 200
"""
from __future__ import annotations

import argparse
import json
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

# Bucket label shared by every record whose ``is_argument`` flag is falsy.
_NOT_ARGUMENT = "not-argument"

# Lower bound on the per-bucket quota, so rare classes still get a usable
# number of examples even when ``n_total // n_buckets`` is tiny.
_MIN_BUCKET_QUOTA = 8


def _bucket_key(rec: dict) -> tuple:
    """Return the stratification bucket for one auto-label record.

    Non-arguments all share the 1-tuple ``("not-argument",)``; arguments are
    keyed by ``(outcome, side)``, with ``"unknown"`` standing in for a missing
    field.
    """
    if not rec.get("is_argument"):
        return (_NOT_ARGUMENT,)
    return (rec.get("outcome", "unknown"), rec.get("side", "unknown"))


def stratified_sample(
    labels: List[dict],
    n_total: int,
    rng: random.Random,
) -> List[dict]:
    """Pick a balanced sample across (outcome, side) buckets.

    The auto-label distribution is heavily skewed toward "is_argument=True,
    outcome=unknown, side=unknown" because most paragraphs don't have clear
    acceptance markers. For evaluation we want each meaningful class to be
    represented so we can actually measure precision per class.

    Records are grouped by ``(outcome, side)`` — the dimensions we care most
    about getting right — with every non-argument record in one extra bucket.
    Each bucket contributes up to ``max(8, n_total // n_buckets)`` randomly
    chosen records, favoring diversity over proportionality.

    Args:
        labels: Auto-label records (dicts). The list itself is not modified.
        n_total: Maximum number of records to return.
        rng: Random source used for all shuffling, so a seeded instance
            gives a reproducible sample.

    Returns:
        At most ``n_total`` records in random order. Fewer are returned when
        the buckets are too small to fill their quotas. When the quotas add
        up to more than ``n_total``, the shuffled selection is truncated, so
        the excess is dropped at random. An empty list is returned when
        ``labels`` is empty or ``n_total`` is not positive.
    """
    # Guard the degenerate inputs: zero buckets would divide by zero below,
    # and a negative n_total would make the final slice drop from the end
    # instead of limiting the size.
    if n_total <= 0 or not labels:
        return []

    buckets: Dict[tuple, List[dict]] = defaultdict(list)
    for rec in labels:
        buckets[_bucket_key(rec)].append(rec)

    base_quota = max(_MIN_BUCKET_QUOTA, n_total // len(buckets))

    selected: List[dict] = []
    for items in buckets.values():
        # The bucket lists are private copies, so shuffling them in place
        # leaves the caller's ``labels`` order untouched.
        rng.shuffle(items)
        selected.extend(items[:base_quota])

    rng.shuffle(selected)
    return selected[:n_total]


def _load_labels(src: Path) -> Tuple[List[dict], int]:
    """Read a JSONL file of auto-labels.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped and counted rather than
    silently dropped.

    Returns:
        ``(records, skipped)`` — the parsed dict records and the number of
        non-blank lines that could not be used.
    """
    records: List[dict] = []
    skipped = 0
    with src.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                skipped += 1
                continue
            if isinstance(rec, dict):
                records.append(rec)
            else:
                # Valid JSON but not an object; every later step needs .get().
                skipped += 1
    return records, skipped


def _to_pool_record(rec: dict) -> dict:
    """Convert an auto-label record into a labeling-pool record.

    The auto labels are moved under ``auto_hint`` — the labeler can confirm
    or correct them, but the gold label is whatever the human decides.
    """
    return {
        "id": rec.get("id"),
        "case_id": rec.get("case_id"),
        "domain": rec.get("domain"),
        "text": rec.get("text"),
        # Keep auto-labels as hint for the labeler — they can confirm/correct
        "auto_hint": {
            "is_argument": rec.get("is_argument"),
            "outcome": rec.get("outcome"),
            "side": rec.get("side"),
            "arg_type": rec.get("arg_type"),
            "confidence": rec.get("confidence"),
        },
    }


def main() -> None:
    """Command-line entry point: sample a stratified pool and write it as JSONL.

    Raises:
        SystemExit: If the input file is missing, contains no usable
            records, or ``--n`` is not a positive integer.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--auto-labels", default="data/auto_labels.jsonl",
                    help="JSONL of auto-labeled paragraphs")
    ap.add_argument("--out", default="data/gold_pool.jsonl",
                    help="output stratified pool for manual verification")
    ap.add_argument("--n", type=int, default=200,
                    help="how many paragraphs to include (default 200)")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    if args.n < 1:
        ap.error("--n must be a positive integer")

    src = Path(args.auto_labels)
    out = Path(args.out)

    # Validate the input before touching the filesystem, so a bad path does
    # not leave an empty output directory behind.
    if not src.exists():
        raise SystemExit(f"input not found: {src}")

    print(f"reading {src.name}...")
    labels, skipped = _load_labels(src)
    print(f" loaded {len(labels):,} auto-labels")
    if skipped:
        print(f" skipped {skipped:,} malformed line(s)")
    if not labels:
        raise SystemExit(f"no usable auto-labels in: {src}")

    rng = random.Random(args.seed)
    selected = stratified_sample(labels, args.n, rng)
    print(f"\n selected {len(selected):,} stratified candidates:")

    # Show distribution, using the same bucket keys the sampler used.
    dist: Counter = Counter()
    for rec in selected:
        key = _bucket_key(rec)
        dist[key[0] if len(key) == 1 else key] += 1
    for key, count in dist.most_common():
        print(f" {str(key):>40s} {count}")

    # Strip the auto labels — leave them as hints but the labeling tool
    # will overwrite. The gold set is what HUMAN says.
    pool_records = [_to_pool_record(rec) for rec in selected]

    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(rec, ensure_ascii=False) + "\n" for rec in pool_records
        )

    labels_path = out.with_suffix(".labels.jsonl")
    print(f"\n ✓ wrote {out}")
    print("\n next: manually verify each label")
    print(f" python3 -m tau_rag.scripts.labeling_server --pool {out}")
    print(" open http://localhost:8765")
    print("\n the verified labels will be saved to:")
    print(f" {labels_path}")
    print("\n then run evaluation:")
    print(" python3 -m tau_rag.scripts.eval_classifier \\")
    print(f" --gold {labels_path} \\")
    print(" --classifier tau_rag/runtime/models/argument_classifier_v1.pt")


if __name__ == "__main__":
    main()