#!/usr/bin/env python3
"""Auto-label Hebrew legal paragraphs based on citations + linguistic markers.

Replaces 50+ hours of manual labeling with a few minutes of CPU work.
Output is noisy training data — typical precision 80-90% on records above
the confidence threshold. The classifier trained on this data learns to
denoise + generalize.

Workflow:

    # 1. Extract paragraph pool from the full corpus
    python3 -m tau_rag.scripts.extract_paragraphs_for_labeling \\
        --n 50000 --out data/pool_50k.jsonl --high-signal-only

    # 2. Auto-label everything in the pool
    python3 -m tau_rag.scripts.auto_label_paragraphs \\
        --in data/pool_50k.jsonl \\
        --out data/auto_labels_50k.jsonl \\
        --min-confidence 0.55

    # 3. Optional: validate a sample manually with the labeling tool
    python3 -m tau_rag.scripts.labeling_server \\
        --pool data/auto_labels_50k.jsonl

    # 4. Train the classifier on the auto-labels
    python3 -m tau_rag.scripts.finetune_argument_classifier \\
        --labels-jsonl data/auto_labels_50k.jsonl

You can also auto-label DIRECTLY on the corpus without the extraction step
by passing --corpus instead of --in. The script will scan paragraphs and
emit only those above the confidence threshold.
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from collections import Counter
from pathlib import Path
from typing import List, Optional, Tuple

# Allow running this file directly without installing the package
_THIS_DIR = Path(__file__).resolve().parent
_PKG_PARENT = _THIS_DIR.parent.parent
if str(_PKG_PARENT) not in sys.path:
    sys.path.insert(0, str(_PKG_PARENT))

from tau_rag.intelligence.auto_labeler import AutoLabeler  # noqa: E402

# How often (in scanned pool records / corpus cases) a progress line is printed.
_POOL_PROGRESS_EVERY = 5000
_CORPUS_PROGRESS_EVERY = 1000
# How many of the most frequent label combinations go into the stats report.
_TOP_COMBINATIONS = 10


def split_paragraphs(text: str, min_len: int = 60, max_len: int = 800) -> List[str]:
    """Split *text* into blank-line-separated paragraphs of acceptable length.

    Paragraphs are separated by an empty line. Windows (CRLF) line endings are
    normalized first so such documents split the same way as LF ones. Each
    paragraph is stripped and kept only if ``min_len <= len(p) <= max_len``.

    Args:
        text: Full document text; a falsy value yields an empty list.
        min_len: Minimum paragraph length in characters (inclusive).
        max_len: Maximum paragraph length in characters (inclusive).

    Returns:
        The surviving paragraphs, in document order.
    """
    if not text:
        return []
    # Without this, "\r\n\r\n" never matches the "\n\n" separator and a CRLF
    # document collapses into one oversized chunk that the length filter drops.
    normalized = text.replace("\r\n", "\n")
    stripped = (p.strip() for p in normalized.split("\n\n"))
    return [p for p in stripped if min_len <= len(p) <= max_len]


def label_record(rec: dict, labeler: AutoLabeler, min_conf: float) -> Optional[dict]:
    """Label a single paragraph record. Returns None if below threshold.

    Args:
        rec: Paragraph record. Reads the "text", "id", "case_id" and "domain"
            keys; missing keys become None and a missing/None text becomes "".
        labeler: Object whose ``label(text)`` returns either None (abstain) or
            a dict with the predicted fields and their confidences.
        min_conf: Labels whose ``overall_confidence`` is below this are dropped.

    Returns:
        The output record (source metadata, predicted labels, overall and
        per-field confidences, ``auto_labeled=True``), or None when the
        labeler abstained or the confidence is below *min_conf*.
    """
    text = rec.get("text") or ""
    label = labeler.label(text)
    if label is None or label["overall_confidence"] < min_conf:
        return None
    return {
        "id": rec.get("id"),
        "case_id": rec.get("case_id"),
        "domain": rec.get("domain"),
        "text": text,
        "is_argument": label["is_argument"],
        "outcome": label["outcome"],
        "side": label["side"],
        "arg_type": label["arg_type"],
        "confidence": label["overall_confidence"],
        "confidence_breakdown": {
            "is_argument": label["is_argument_confidence"],
            "outcome": label["outcome_confidence"],
            "side": label["side_confidence"],
            "arg_type": label["arg_type_confidence"],
        },
        "auto_labeled": True,
    }


def _parse_record(line: str) -> Optional[dict]:
    """Decode one non-blank JSONL line; return None unless it is a JSON object."""
    try:
        rec = json.loads(line)
    except json.JSONDecodeError:
        return None
    # A valid JSON scalar/array would otherwise crash later on ``rec.get``.
    return rec if isinstance(rec, dict) else None


def _combo_key(labeled: dict) -> Tuple:
    """Key used to tally label combinations: (is_argument, outcome, side)."""
    return labeled["is_argument"], labeled["outcome"], labeled["side"]


def _top_combinations(label_dist: Counter, k: int = _TOP_COMBINATIONS) -> List[dict]:
    """Render the *k* most frequent label combinations for the stats report."""
    return [
        {"is_arg": is_arg, "outcome": outcome, "side": side, "count": count}
        for (is_arg, outcome, side), count in label_dist.most_common(k)
    ]


def _rate(count: int, elapsed: float) -> float:
    """Items per second, guarding against a zero/near-zero elapsed time."""
    return count / max(elapsed, 0.001)


def run_on_pool(pool_path: Path, out_path: Path, min_conf: float) -> dict:
    """Mode A: input is a paragraph pool JSONL (one paragraph per line).

    Every JSON-object line is passed to :func:`label_record`; records that
    pass the confidence threshold are written to *out_path* as JSONL (parent
    directories are created). Blank lines are ignored; lines that are not a
    JSON object are skipped and counted in ``n_malformed_lines``.

    Returns:
        Summary stats: ``n_in`` (records scanned), ``n_out`` (records kept),
        ``n_dropped_low_confidence`` (scanned minus kept — this also includes
        records the labeler abstained on), ``n_malformed_lines``,
        ``elapsed_s``, ``rate_per_s`` and ``top_label_combinations``.
    """
    labeler = AutoLabeler()
    n_in = n_out = n_malformed = 0
    label_dist: Counter = Counter()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    t0 = time.monotonic()

    with pool_path.open("r", encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            rec = _parse_record(line)
            if rec is None:
                n_malformed += 1
                continue

            n_in += 1
            labeled = label_record(rec, labeler, min_conf)
            if labeled is not None:
                fout.write(json.dumps(labeled, ensure_ascii=False) + "\n")
                n_out += 1
                label_dist[_combo_key(labeled)] += 1

            # Report on every Nth *scanned* record. This must not sit behind the
            # "dropped" branch, or the progress line vanishes whenever the Nth
            # record happens to fall below the threshold.
            if n_in % _POOL_PROGRESS_EVERY == 0:
                elapsed = time.monotonic() - t0
                print(f" scanned {n_in:,} kept {n_out:,} "
                      f"({_rate(n_in, elapsed):.0f}/s)", flush=True)

    elapsed = time.monotonic() - t0
    return {
        "n_in": n_in,
        "n_out": n_out,
        "n_dropped_low_confidence": n_in - n_out,
        "n_malformed_lines": n_malformed,
        "elapsed_s": round(elapsed, 1),
        "rate_per_s": round(_rate(n_in, elapsed), 0),
        "top_label_combinations": _top_combinations(label_dist),
    }


def run_on_corpus(corpus_path: Path, out_path: Path, min_conf: float,
                  max_cases: Optional[int]) -> dict:
    """Mode B: input is the full corpus JSONL — extract paragraphs on the fly.

    Each JSON-object line is one case: its "text" is split with
    :func:`split_paragraphs` and every paragraph is labeled as a pseudo-record
    with id ``"<case id>::<paragraph index>"``, the case id, and the case's
    ``metadata["domain"]``. Blank lines are ignored; lines that are not a JSON
    object are skipped, counted in ``n_malformed_lines``, and do not count
    toward *max_cases*.

    Args:
        corpus_path: Corpus JSONL, one case per line.
        out_path: Destination JSONL for kept labels (parent dirs are created).
        min_conf: Minimum overall confidence for a label to be kept.
        max_cases: Stop after this many cases; None or 0 means no limit.

    Returns:
        Summary stats: ``n_cases``, ``n_paragraphs``, ``n_kept``,
        ``kept_ratio``, ``n_malformed_lines``, ``elapsed_s``,
        ``rate_paragraphs_per_s`` and ``top_label_combinations``.
    """
    labeler = AutoLabeler()
    n_cases = n_paras = n_kept = n_malformed = 0
    label_dist: Counter = Counter()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    t0 = time.monotonic()

    with corpus_path.open("r", encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            rec = _parse_record(line)
            if rec is None:
                n_malformed += 1
                continue
            # Check the cap *before* counting: incrementing first and breaking
            # on "> max_cases" reported one more case than was processed.
            if max_cases and n_cases >= max_cases:
                break
            n_cases += 1

            case_id = rec.get("id", "")
            metadata = rec.get("metadata")
            domain = metadata.get("domain") if isinstance(metadata, dict) else None
            text = rec.get("text", "") or ""

            for i, para in enumerate(split_paragraphs(text)):
                n_paras += 1
                pseudo = {
                    "id": f"{case_id}::{i}",
                    "case_id": case_id,
                    "domain": domain,
                    "text": para,
                }
                labeled = label_record(pseudo, labeler, min_conf)
                if labeled is None:
                    continue
                fout.write(json.dumps(labeled, ensure_ascii=False) + "\n")
                n_kept += 1
                label_dist[_combo_key(labeled)] += 1

            # Once per case (not per paragraph), every Nth case.
            if n_cases % _CORPUS_PROGRESS_EVERY == 0:
                elapsed = time.monotonic() - t0
                print(f" cases={n_cases:,} paragraphs={n_paras:,} "
                      f"kept={n_kept:,} ({_rate(n_paras, elapsed):.0f} para/s)",
                      flush=True)

    elapsed = time.monotonic() - t0
    return {
        "n_cases": n_cases,
        "n_paragraphs": n_paras,
        "n_kept": n_kept,
        "kept_ratio": round(n_kept / max(n_paras, 1), 3),
        "n_malformed_lines": n_malformed,
        "elapsed_s": round(elapsed, 1),
        "rate_paragraphs_per_s": round(_rate(n_paras, elapsed), 0),
        "top_label_combinations": _top_combinations(label_dist),
    }


def main() -> None:
    """CLI entry point: parse arguments, run the selected mode, report stats.

    Prints a human-readable summary, writes the stats dict as JSON next to the
    output (the output path with its suffix replaced by ``.stats.json``), and
    exits with an error message if the input file is missing.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    src = ap.add_mutually_exclusive_group(required=True)
    src.add_argument("--in", dest="input_pool",
                     help="paragraph pool JSONL (from extract_paragraphs_for_labeling)")
    src.add_argument("--corpus",
                     help="full corpus JSONL (e.g. parquet_cases.jsonl) — "
                          "paragraphs extracted on the fly")
    ap.add_argument("--out", required=True, help="output labels JSONL")
    ap.add_argument("--min-confidence", type=float, default=0.55,
                    help="drop labels below this overall confidence (default 0.55)")
    ap.add_argument("--max-cases", type=int, default=None,
                    help="(corpus mode) stop after this many cases")
    args = ap.parse_args()

    out_path = Path(args.out)
    print(f"\n auto-labeler — min_confidence={args.min_confidence}\n")

    # ``is not None`` rather than truthiness: ``--in ""`` must not silently
    # fall through to corpus mode with ``args.corpus`` unset.
    if args.input_pool is not None:
        pool_path = Path(args.input_pool)
        # is_file() also rejects directories, which exists() would accept.
        if not pool_path.is_file():
            sys.exit(f"input pool not found: {pool_path}")
        stats = run_on_pool(pool_path, out_path, args.min_confidence)
        print("\n ✓ done")
        print(f" scanned: {stats['n_in']:,}")
        print(f" kept: {stats['n_out']:,} "
              f"(dropped {stats['n_dropped_low_confidence']:,} below threshold)")
    else:
        corpus_path = Path(args.corpus)
        if not corpus_path.is_file():
            sys.exit(f"corpus not found: {corpus_path}")
        stats = run_on_corpus(corpus_path, out_path, args.min_confidence,
                              args.max_cases)
        print("\n ✓ done")
        print(f" cases: {stats['n_cases']:,}")
        print(f" paragraphs: {stats['n_paragraphs']:,}")
        print(f" labels kept: {stats['n_kept']:,} "
              f"({100 * stats['kept_ratio']:.1f}%)")
        print(f" rate: {stats['rate_paragraphs_per_s']:.0f} para/s")

    if stats["n_malformed_lines"]:
        print(f" skipped: {stats['n_malformed_lines']:,} malformed input lines")
    print(f" output: {out_path}\n")

    print(" top label combinations:")
    for combo in stats["top_label_combinations"]:
        # str() so a None outcome/side doesn't crash the width format spec.
        print(f" is_arg={combo['is_arg']} outcome={str(combo['outcome']):>9} "
              f"side={str(combo['side']):>10} → {combo['count']:,}")

    # Also dump stats sidecar
    stats_path = out_path.with_suffix(".stats.json")
    stats_path.write_text(json.dumps(stats, ensure_ascii=False, indent=2),
                          encoding="utf-8")
    print(f"\n stats: {stats_path}")
    print("\n next:")
    print(" python3 -m tau_rag.scripts.finetune_argument_classifier \\")
    print(f" --labels-jsonl {out_path}\n")


if __name__ == "__main__":
    main()