#!/usr/bin/env python3
"""extract_argument_training_data.py — Build a focused training dataset.

The current TAU LLM (v11, 16M params) produces "word soup" when asked to
adapt legal arguments — it has the vocabulary but lacks coherent legal
sentence structure.

The hypothesis: training on FOCUSED legal-argument paragraphs (instead of
whole judgments which include headers, dates, procedural noise) will teach
the model the legal flow it's missing.

This script:
  1. Walks all substantive judgments in the corpus
  2. Runs judgment_structurer + CBR's argument detector on each
  3. Keeps only HIGH-QUALITY argument paragraphs:
       • 80–1500 chars (real arguments, not fragments or full sections)
       • Score ≥ 0.30 (has legal-reasoning signal)
       • From one of the sections listed in ARG_SECTIONS
         (discussion / arguments_*)
       • At least two punctuation marks out of ``. , ; : ! ?``
         (real prose, not a bare list)
  4. Writes the paragraphs as JSONL — one focused legal argument per line.
     A paragraph from the "discussion" section of a judgment that also has
     a usable facts section is written as a facts → argument pair.

The output dataset feeds directly into your existing training pipeline
(train_hebrew_llm.py / next_token_trainer.py). The model continues from
the v11 checkpoint and learns to PRODUCE coherent legal Hebrew, not just
recognize legal vocabulary.

Output format (one line per argument paragraph):
    {"text": "...", "case_id": "...", "section": "discussion",
     "score": 0.65, "side": "...", "has_statute": true,
     "has_case_ref": true, "facts": "...", "paired": true,
     "training_text": "..."}
("facts" is present only when "paired" is true.)

Usage:
    python3 -m tau_rag.scripts.extract_argument_training_data \\
        --parquet storage/raw/datasets/698f9b2b-...parquet \\
        --output tau_rag/runtime/training_data/legal_arguments.jsonl \\
        --max-docs 50000

Defaults to 50K docs (~10-15 min) which is enough to produce ~200K
paragraphs — a strong training set. Use --max-docs 0 for the full 134K.
"""
from __future__ import annotations

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path

# Put the directory two levels above this file's folder on sys.path so the
# ``tau_rag`` package imports inside main() resolve when the script is run
# directly.  Guarded so a shallow file location does not raise IndexError
# at import time.
_ANCESTORS = Path(__file__).resolve().parents
if len(_ANCESTORS) > 2:
    sys.path.insert(0, str(_ANCESTORS[2]))

# ---------------------------------------------------------------------------
# Filter thresholds (shared by the classifier wrapper and main()).
# ---------------------------------------------------------------------------
MIN_PARAGRAPH_CHARS = 80      # shorter paragraphs are treated as fragments
MAX_PARAGRAPH_CHARS = 1500    # longer paragraphs are treated as whole sections
MIN_ARGUMENT_SCORE = 0.30     # minimum classifier score to keep a paragraph
MIN_PUNCTUATION_MARKS = 2     # minimum count of PUNCTUATION_CHARS in the text
PUNCTUATION_CHARS = frozenset(".,;:!?")

MIN_FACTS_CHARS = 60          # bounds on the whitespace-normalised facts text
MAX_FACTS_CHARS = 1500

# Section ids whose text may serve as the "facts" prefix of a paired record.
FACTS_SECTION_IDS = ("facts", "background", "header")

# Section ids we mine for argument paragraphs.
ARG_SECTIONS = frozenset({
    "arguments_plaintiff",
    "arguments_claimant",
    "arguments_defendant",
    "arguments_respondent",
    "arguments_general",
    "discussion",
})

# Only paragraphs from this section are paired with the facts text.
PAIRED_SECTION = "discussion"

# How many structurer failures are echoed to stderr before going quiet.
_MAX_REPORTED_FAILURES = 5

_WHITESPACE_RE = re.compile(r"\s+")


# A paragraph is a "legal argument" if its text exhibits the same
# patterns a real lawyer would write: a claim, a citation, reasoning,
# possibly a conclusion. We reuse CBR's _classify_paragraph for this.
def _is_argument_paragraph(text: str, ext) -> dict:
    """Classify *text* and decide whether it is usable as training data.

    Args:
        text: Candidate paragraph.
        ext: Object exposing ``_classify_paragraph(text)`` that returns a
            mapping with the keys ``score``, ``side``, ``outcome_marker``,
            ``statute_refs`` and ``case_refs``.

    Returns:
        ``{"keep": False}`` when *text* is empty or shorter than
        ``MIN_PARAGRAPH_CHARS`` (the classifier is not called).  Otherwise
        a dict with ``keep``, ``score``, ``side``, ``outcome_marker``,
        ``has_statute`` and ``has_case_ref``.
    """
    if not text or len(text) < MIN_PARAGRAPH_CHARS:
        return {"keep": False}

    tags = ext._classify_paragraph(text)

    # Real legal prose is punctuated; word-soup often is not.
    n_punct = sum(1 for ch in text if ch in PUNCTUATION_CHARS)
    keep = (
        tags["score"] >= MIN_ARGUMENT_SCORE
        and n_punct >= MIN_PUNCTUATION_MARKS
        and len(text) <= MAX_PARAGRAPH_CHARS
    )
    return {
        "keep": keep,
        "score": tags["score"],
        "side": tags["side"],
        "outcome_marker": tags["outcome_marker"],
        "has_statute": bool(tags["statute_refs"]),
        "has_case_ref": bool(tags["case_refs"]),
    }


def _select_facts_text(sections) -> str:
    """Return the first usable facts text among *sections*, or ``""``.

    A section qualifies when its ``id`` is in ``FACTS_SECTION_IDS`` and its
    whitespace-normalised text is between ``MIN_FACTS_CHARS`` and
    ``MAX_FACTS_CHARS`` characters long (inclusive).
    """
    for sec in sections:
        if sec.get("id") not in FACTS_SECTION_IDS:
            continue
        normalised = _WHITESPACE_RE.sub(" ", sec.get("text") or "").strip()
        if MIN_FACTS_CHARS <= len(normalised) <= MAX_FACTS_CHARS:
            return normalised
    return ""


def _build_record(para: str, doc_id, sec_id: str, tags: dict,
                  facts_text: str) -> dict:
    """Build the JSONL record for one kept paragraph.

    Exactly one record is produced per paragraph.  When *facts_text* is
    non-empty and the paragraph comes from ``PAIRED_SECTION`` the record is
    a facts → argument pair (``paired`` is True, ``facts`` is set and
    ``training_text`` holds the combined prompt); otherwise it is a
    standalone record whose ``training_text`` is the paragraph itself.
    """
    rec = {
        "text": para,
        "case_id": doc_id,
        "section": sec_id,
        "score": round(tags["score"], 3),
        "side": tags["side"],
        "has_statute": tags["has_statute"],
        "has_case_ref": tags["has_case_ref"],
    }
    if facts_text and sec_id == PAIRED_SECTION:
        rec["facts"] = facts_text
        rec["paired"] = True
        # Causal-LM-ready training string:
        rec["training_text"] = (
            f"עובדות:\n{facts_text}\n\n"
            f"ניתוח משפטי:\n{para}"
        )
    else:
        rec["paired"] = False
        rec["training_text"] = para
    return rec


def main():
    """CLI entry point: scan judgments and write argument paragraphs as JSONL.

    Reads ``--parquet``, extracts qualifying paragraphs from up to
    ``--max-docs`` judgments and writes one JSON object per line to
    ``--output`` (parent directories are created).  Progress and a summary
    are printed to stdout; structurer failures are reported on stderr.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True)
    ap.add_argument("--output", default=
                    "tau_rag/runtime/training_data/legal_arguments.jsonl")
    ap.add_argument("--max-docs", type=int, default=50000,
                    help="Cap on judgments scanned. 0 = all.")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--min-paragraphs-per-doc", type=int, default=1,
                    help="Skip judgments that yield fewer kept paragraphs "
                         "than this.")
    args = ap.parse_args()

    # NOTE(review): presumably these are read when tau_rag is imported, which
    # is why they are set before the project imports below — confirm.
    os.environ["TAU_RAG_AUTOLOAD_CORPUS"] = "0"
    os.environ["TAU_RAG_LAZY_TEXT"] = "0"

    print("=" * 75, flush=True)
    print("EXTRACT LEGAL-ARGUMENT TRAINING DATA", flush=True)
    print("=" * 75, flush=True)

    print("\n[1] Loading substantive judgments from parquet...", flush=True)
    from tau_rag.scripts.benchmark_pipeline import _load_random_docs
    cap = args.max_docs if args.max_docs > 0 else 1_000_000
    t0 = time.time()
    docs = _load_random_docs(args.parquet, n=cap, seed=args.seed)
    print(f" ✅ {len(docs)} substantive docs ({time.time()-t0:.1f}s)\n",
          flush=True)

    print("[2] Initializing CBR extractor + structurer...", flush=True)
    from tau_rag.intelligence import CaseBasedArgumentExtractor
    from tau_rag.judgment_structurer import structure_judgment
    ext = CaseBasedArgumentExtractor()
    print(" ✅ ready\n", flush=True)

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)

    print(f"[3] Extracting (facts → argument) pairs → {args.output}...",
          flush=True)

    # A judgment must yield at least this many kept paragraphs to be written.
    min_per_doc = max(args.min_paragraphs_per_doc, 1)

    n_pairs = 0             # facts → argument PAIRS (the strong signal)
    n_solo = 0              # standalone argument paragraphs (weaker signal)
    n_dropped_short = 0
    n_dropped_score = 0
    n_dropped_other = 0     # too long or too little punctuation
    n_docs_sparse = 0       # had kept paragraphs, but fewer than min_per_doc
    n_struct_failed = 0     # structure_judgment raised / returned bad shape
    total_chars = 0         # summed length of every written paragraph

    t0 = time.time()
    with open(args.output, "w", encoding="utf-8") as out_f:
        for i, (doc_id, text) in enumerate(docs):
            # Best-effort: a judgment the structurer cannot handle is
            # skipped, but the failure is counted and reported rather than
            # silently swallowed.
            try:
                sections = structure_judgment(text).get("sections", []) or []
            except Exception as exc:
                sections = []
                n_struct_failed += 1
                if n_struct_failed <= _MAX_REPORTED_FAILURES:
                    print(f" ⚠ structure_judgment failed for {doc_id!r}: "
                          f"{type(exc).__name__}: {exc}",
                          file=sys.stderr, flush=True)

            # The facts text (if any) is the prefix/condition for paired
            # argument records from this judgment.
            facts_text = _select_facts_text(sections)

            # Buffer this judgment's output so --min-paragraphs-per-doc can
            # be applied before anything is written.
            doc_lines = []
            doc_pairs = 0
            doc_chars = 0
            for sec in sections:
                sec_id = sec.get("id", "")
                if sec_id not in ARG_SECTIONS:
                    continue
                sec_text = sec.get("text", "") or ""
                for para in ext._split_paragraphs(sec_text):
                    if len(para) < MIN_PARAGRAPH_CHARS:
                        n_dropped_short += 1
                        continue
                    tags = _is_argument_paragraph(para, ext)
                    if not tags["keep"]:
                        if tags.get("score", 0) < MIN_ARGUMENT_SCORE:
                            n_dropped_score += 1
                        else:
                            n_dropped_other += 1
                        continue
                    rec = _build_record(para, doc_id, sec_id, tags,
                                        facts_text)
                    if rec["paired"]:
                        doc_pairs += 1
                    doc_lines.append(
                        json.dumps(rec, ensure_ascii=False) + "\n")
                    doc_chars += len(para)

            if len(doc_lines) >= min_per_doc:
                out_f.writelines(doc_lines)
                n_pairs += doc_pairs
                n_solo += len(doc_lines) - doc_pairs
                total_chars += doc_chars
            elif doc_lines:
                n_docs_sparse += 1

            if (i + 1) % 1000 == 0:
                # Guard against a zero elapsed time on very fast runs.
                elapsed = max(time.time() - t0, 1e-9)
                rate = (i + 1) / elapsed
                eta = (len(docs) - i - 1) / rate
                print(f" {i+1}/{len(docs)} | pairs={n_pairs} solo={n_solo} | "
                      f"{rate:.0f} docs/s | ETA {eta:.0f}s", flush=True)

    n_kept = n_pairs + n_solo
    elapsed = time.time() - t0
    avg_len = total_chars / max(n_kept, 1)

    print(f"\n ✅ EXTRACTION COMPLETE ({elapsed:.0f}s)", flush=True)
    print(f" Total records: {n_kept:,}", flush=True)
    print(f" ├─ facts → argument: {n_pairs:,} (★ strong signal)", flush=True)
    print(f" └─ standalone arguments: {n_solo:,}", flush=True)
    print(f" Avg argument length: {avg_len:.0f} chars", flush=True)
    print(f" Total training characters: {total_chars:,}", flush=True)
    print(f" Dropped (too short): {n_dropped_short:,}", flush=True)
    print(f" Dropped (low score): {n_dropped_score:,}", flush=True)
    print(f" Dropped (too long / unpunctuated): {n_dropped_other:,}",
          flush=True)
    print(f" Docs below --min-paragraphs-per-doc: {n_docs_sparse:,}",
          flush=True)
    print(f" Docs that failed structuring: {n_struct_failed:,}", flush=True)
    print(f"\n Output: {args.output}", flush=True)

    # The trailing "\\" prints ONE backslash so the hint is a valid
    # copy-pasteable shell line continuation.
    print("\n[4] Next steps:", flush=True)
    print(" • Inspect a few lines:", flush=True)
    print(f" head -5 {args.output} | python3 -m json.tool", flush=True)
    print(" • Continue training from v11 on this dataset:", flush=True)
    print(" python3 -m tau_rag.scripts.continue_training_arguments \\",
          flush=True)
    print(f" --data {args.output} \\", flush=True)
    print(" --base-checkpoint "
          "tau_rag/runtime/models/tau_hebrew_legal_llm_v11.pt \\",
          flush=True)
    print(" --epochs 3", flush=True)


if __name__ == "__main__":
    main()