#!/usr/bin/env python3
"""Fine-tune a classifier for argument detection + acceptance from labeled
Hebrew legal paragraphs.

Uses ONLY local models (no external API). Two strategies are defined:

  1. HEBERT_HEAD (recommended for small datasets <5k labels):
     Loads `avichr/heBERT` (or another Hebrew encoder), adds a multi-task
     classification head, fine-tunes on Mac M1/M2 in 2-6 hours.

  2. TAU_LLM_HEAD (advanced, uses your custom 16.5M model):
     Loads runtime/models/tau_hebrew_legal_llm_v11.pt, adds a small
     classification head over its hidden states. Faster inference at
     runtime but requires the encoder modules to be importable.
     NOTE: this strategy is not implemented yet; selecting it exits with
     an error message.

Inputs:
  --labels-jsonl /path/to/labeled_paragraphs.jsonl
    Each line: {"text": "...",
                "is_argument": bool,
                "outcome": "accepted|rejected|partial|unknown",
                "side": "plaintiff|defendant|court|unknown",
                "arg_type": "factual|legal|procedural|policy|equitable|"
                            "constitutional|substantive|unknown"}

Output (default --out location):
  tau_rag/runtime/models/argument_classifier_v1.pt
  tau_rag/runtime/models/argument_classifier_v1.metrics.json

Usage:
  # Bare minimum — point at the labeled JSONL
  python3 -m tau_rag.scripts.finetune_argument_classifier \\
      --labels-jsonl data/labeled_args.jsonl

  # Custom config
  python3 -m tau_rag.scripts.finetune_argument_classifier \\
      --labels-jsonl data/labeled_args.jsonl \\
      --strategy HEBERT_HEAD \\
      --epochs 6 \\
      --batch-size 16 \\
      --lr 2e-5 \\
      --out runtime/models/argument_classifier_v2.pt

Hardware:
  - Mac M1/M2: trains in 2-6 hours on 5k labels (HEBERT_HEAD)
  - GPU: 30-60 minutes
  - CPU x86: 8-12 hours (still feasible)
"""
from __future__ import annotations

# CRITICAL: these env vars must be set BEFORE importing transformers, to
# prevent it from auto-loading tensorflow. tf imports often fail on Python
# 3.11 due to the `wrapt`/`formatargspec` deprecation chain. We use PyTorch
# exclusively for fine-tuning, so TF is unnecessary.
import os

os.environ.setdefault("USE_TF", "0")
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
# Disable telemetry that sometimes triggers extra imports
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

import argparse
import json
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple

# Defer torch import — let user run --help without having torch installed
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    _HAS_TORCH = True
except ImportError:
    _HAS_TORCH = False

# Label spaces — keep these stable across runs
IS_ARGUMENT_LABELS = ["no", "yes"]
OUTCOME_LABELS = ["unknown", "accepted", "rejected", "partial"]
SIDE_LABELS = ["unknown", "plaintiff", "defendant", "court"]
ARG_TYPE_LABELS = ["unknown", "factual", "legal", "procedural", "policy",
                   "equitable", "constitutional", "substantive"]

# Names of the four prediction heads, in the order they are reported.
_TASKS = ("is_argument", "outcome", "side", "arg_type")

# Sequence length assumed for checkpoints that do not record one.
_DEFAULT_MAX_LEN = 256

# Accepted spellings when `is_argument` arrives as a string.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "y"})
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "n"})


def _label_index(labels: List[str], value, field: str) -> int:
    """Return the position of *value* in *labels*.

    A missing/null/empty value maps to "unknown"; string values are matched
    after stripping whitespace and lower-casing.

    Raises:
        ValueError: if the value is not one of *labels*.
    """
    key = "" if value is None else str(value).strip().lower()
    if not key:
        key = "unknown"
    try:
        return labels.index(key)
    except ValueError:
        raise ValueError(
            f"unrecognized {field} label {value!r}; expected one of {labels}"
        ) from None


def _binary_label(value) -> int:
    """Coerce an `is_argument` value to 0 or 1.

    Booleans and numbers use their truthiness (so 2 becomes 1, never an
    out-of-range class); null becomes 0; strings must be a known spelling.

    Raises:
        ValueError: if a string value is not a recognized true/false word.
    """
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _TRUE_STRINGS:
            return 1
        if key in _FALSE_STRINGS:
            return 0
        raise ValueError(f"unrecognized is_argument value {value!r}")
    return int(bool(value))


# =============================================================================
# Dataset
# =============================================================================

class ArgumentLabelDataset:
    """In-memory dataset of labeled paragraphs read from a JSONL file.

    The raw records are loaded eagerly in ``__init__``; tokenization happens
    per item in ``__getitem__``. The class does not depend on torch itself,
    so importing this module works without torch installed (e.g. for --help).
    """

    def __init__(self, jsonl_path: Path, tokenizer, max_len: int = 256):
        """Read *jsonl_path* and keep every JSON object that has a "text" key.

        Args:
            jsonl_path: path to the JSONL label file.
            tokenizer: callable invoked as
                ``tokenizer(text, max_length=, padding=, truncation=,
                return_tensors=)``.
            max_len: sequence length passed to the tokenizer.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON; the
                message names the file and line.
        """
        self.jsonl_path = jsonl_path
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.records: List[Dict] = []
        # utf-8-sig also reads plain UTF-8, and strips a BOM if one is
        # present so the first record still parses.
        with jsonl_path.open("r", encoding="utf-8-sig") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError as err:
                    raise json.JSONDecodeError(
                        f"{jsonl_path.name} line {lineno}: {err.msg}",
                        err.doc, err.pos,
                    ) from err
                # Skip anything that is not a labeled-paragraph object.
                if not isinstance(rec, dict) or "text" not in rec:
                    continue
                self.records.append(rec)
        print(f" loaded {len(self.records):,} labels from {jsonl_path.name}")

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.records)

    def __getitem__(self, i):
        """Tokenize record *i* and encode its four labels as integers.

        Returns:
            dict with "input_ids" and "attention_mask" (tokenizer outputs with
            the leading batch dimension squeezed away) plus integer class ids
            under "is_argument", "outcome", "side" and "arg_type".

        Raises:
            ValueError: if a label value is not in its label space.
        """
        rec = self.records[i]
        encoded = self.tokenizer(
            rec["text"],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "is_argument": _binary_label(rec.get("is_argument", False)),
            "outcome": _label_index(OUTCOME_LABELS, rec.get("outcome"), "outcome"),
            "side": _label_index(SIDE_LABELS, rec.get("side"), "side"),
            "arg_type": _label_index(ARG_TYPE_LABELS, rec.get("arg_type"), "arg_type"),
        }


# =============================================================================
# Multi-task classifier head
# =============================================================================

class MultiTaskArgumentClassifier(nn.Module if _HAS_TORCH else object):
    """Four linear heads on top of one shared encoder.

    Head outputs (logits):
        is_argument: len(IS_ARGUMENT_LABELS) classes
        outcome:     len(OUTCOME_LABELS) classes
        side:        len(SIDE_LABELS) classes
        arg_type:    len(ARG_TYPE_LABELS) classes

    When torch is unavailable the base class degrades to ``object`` so the
    module can still be imported; the class is not usable in that case.
    """

    def __init__(self, encoder, hidden_size: int = 768, dropout: float = 0.1):
        """Wrap *encoder* and create one linear head per task.

        Args:
            encoder: callable taking ``input_ids=`` / ``attention_mask=`` and
                returning an object with a ``last_hidden_state`` tensor.
            hidden_size: width of the encoder's hidden states.
            dropout: dropout probability applied before the heads.
        """
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)
        self.head_is_arg = nn.Linear(hidden_size, len(IS_ARGUMENT_LABELS))
        self.head_outcome = nn.Linear(hidden_size, len(OUTCOME_LABELS))
        self.head_side = nn.Linear(hidden_size, len(SIDE_LABELS))
        self.head_argtype = nn.Linear(hidden_size, len(ARG_TYPE_LABELS))

    def forward(self, input_ids, attention_mask):
        """Return a dict of per-task logits for a batch."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state ([CLS] for BERT-style encoders)
        # as the pooled paragraph representation.
        h = out.last_hidden_state[:, 0, :]
        h = self.dropout(h)
        return {
            "is_argument": self.head_is_arg(h),
            "outcome": self.head_outcome(h),
            "side": self.head_side(h),
            "arg_type": self.head_argtype(h),
        }


# =============================================================================
# Training loop
# =============================================================================

def _select_device():
    """Pick mps, then cuda, then cpu — whichever is available first."""
    # torch.backends.mps does not exist on older torch builds.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _collate(batch_items):
    """Stack a list of dataset items into one batch of tensors."""
    batch = {
        "input_ids": torch.stack([b["input_ids"] for b in batch_items]),
        "attention_mask": torch.stack([b["attention_mask"] for b in batch_items]),
    }
    for task in _TASKS:
        batch[task] = torch.tensor([b[task] for b in batch_items])
    return batch


def _to_device(batch, device):
    """Move every tensor in *batch* to *device*; leave other values alone."""
    return {k: (v.to(device) if torch.is_tensor(v) else v)
            for k, v in batch.items()}


def _evaluate(model, val_set, batch_size, device) -> Dict[str, float]:
    """Return per-task accuracy of *model* over *val_set*."""
    correct = {task: 0 for task in _TASKS}
    model.eval()
    with torch.no_grad():
        for start in range(0, len(val_set), batch_size):
            batch = _to_device(_collate(val_set[start:start + batch_size]), device)
            preds = model(batch["input_ids"], batch["attention_mask"])
            for task in _TASKS:
                correct[task] += (
                    preds[task].argmax(-1) == batch[task]
                ).sum().item()
    return {task: hits / len(val_set) for task, hits in correct.items()}


def train_hebert_strategy(args):
    """HEBERT_HEAD strategy — recommended.

    Loads the encoder named by ``args.hebert_model``, fine-tunes it together
    with the four classification heads on ``args.labels_jsonl``, saves the
    checkpoint with the best average validation accuracy to ``args.out`` and
    writes a ``.metrics.json`` file next to it.

    Exits the process (``sys.exit``) when torch/transformers are missing,
    when fewer than 50 labels are available, when a label value is invalid,
    or when ``--epochs`` / ``--batch-size`` is below 1.
    """
    if not _HAS_TORCH:
        sys.exit("ERROR: torch + transformers required. "
                 "pip install torch transformers")
    try:
        from transformers import AutoModel, AutoTokenizer
    except ImportError:
        sys.exit("ERROR: transformers required. pip install transformers")
    # A zero batch size would make range() raise; zero epochs would write
    # metrics for a model that was never saved.
    if args.epochs < 1 or args.batch_size < 1:
        sys.exit("ERROR: --epochs and --batch-size must both be >= 1.")

    print(f"\n[1/5] Loading HeBERT encoder ({args.hebert_model})...")
    tokenizer = AutoTokenizer.from_pretrained(args.hebert_model)
    encoder = AutoModel.from_pretrained(args.hebert_model)
    hidden = encoder.config.hidden_size
    print(f" encoder hidden_size={hidden}, vocab={tokenizer.vocab_size:,}")

    print("\n[2/5] Loading dataset...")
    full = ArgumentLabelDataset(Path(args.labels_jsonl), tokenizer,
                                max_len=args.max_len)
    n = len(full)
    if n < 50:
        sys.exit(f"ERROR: only {n} labels — need at least 50 for sensible training.")

    # Train/val split (fixed seed so the split is reproducible)
    random.seed(42)
    indices = list(range(n))
    random.shuffle(indices)
    val_size = max(20, int(n * 0.15))
    try:
        # Items are tokenized once here and reused every epoch.
        train_set = [full[i] for i in indices[val_size:]]
        val_set = [full[i] for i in indices[:val_size]]
    except ValueError as err:
        sys.exit(f"ERROR: bad label in {args.labels_jsonl}: {err}")
    print(f" train={len(train_set):,} val={len(val_set):,}")

    print("\n[3/5] Building model...")
    device = _select_device()
    print(f" device={device}")
    # Seed torch as well so head initialization and dropout are repeatable.
    torch.manual_seed(42)
    model = MultiTaskArgumentClassifier(encoder, hidden_size=hidden,
                                        dropout=args.dropout).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    print(f"\n[4/5] Training {args.epochs} epochs...")
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Start below any reachable accuracy so the first epoch is always saved,
    # even if its validation accuracy is exactly 0.
    best_val = -1.0
    total_steps = len(train_set) // args.batch_size
    for epoch in range(1, args.epochs + 1):
        # ---- train ----
        model.train()
        random.shuffle(train_set)
        losses = []
        t0 = time.time()
        for step, start in enumerate(range(0, len(train_set), args.batch_size)):
            batch = _to_device(
                _collate(train_set[start:start + args.batch_size]), device)
            optim.zero_grad()
            preds = model(batch["input_ids"], batch["attention_mask"])
            # Weighted sum of the four task losses; is_argument dominates.
            loss = (
                loss_fn(preds["is_argument"], batch["is_argument"])
                + 0.7 * loss_fn(preds["outcome"], batch["outcome"])
                + 0.5 * loss_fn(preds["side"], batch["side"])
                + 0.5 * loss_fn(preds["arg_type"], batch["arg_type"])
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            loss_value = loss.item()
            losses.append(loss_value)
            if step % 10 == 0:
                print(f" epoch {epoch} step {step}/"
                      f"{total_steps}: loss={loss_value:.4f}", end="\r")
        train_loss = sum(losses) / len(losses)
        train_dt = time.time() - t0

        # ---- val ----
        val_acc = _evaluate(model, val_set, args.batch_size, device)
        avg_val = sum(val_acc.values()) / len(val_acc)
        print(f" epoch {epoch}: train_loss={train_loss:.4f} "
              f"val_acc[is_arg={val_acc['is_argument']:.2f} "
              f"outcome={val_acc['outcome']:.2f} "
              f"side={val_acc['side']:.2f} "
              f"argtype={val_acc['arg_type']:.2f}] "
              f"avg={avg_val:.2f} ({train_dt:.0f}s)")
        if avg_val > best_val:
            best_val = avg_val
            torch.save({
                "model_state": model.state_dict(),
                "tokenizer_name": args.hebert_model,
                "label_spaces": {
                    "is_argument": IS_ARGUMENT_LABELS,
                    "outcome": OUTCOME_LABELS,
                    "side": SIDE_LABELS,
                    "arg_type": ARG_TYPE_LABELS,
                },
                # max_len is stored so inference truncates exactly like
                # training did.
                "config": {"hidden_size": hidden, "dropout": args.dropout,
                           "max_len": args.max_len},
                "epoch": epoch,
                "val_acc": val_acc,
            }, out_path)
            print(f" ✓ saved best model to {out_path} (avg_val={avg_val:.3f})")

    print(f"\n[5/5] Best val accuracy = {best_val:.3f}")
    metrics_path = out_path.with_suffix(".metrics.json")
    metrics_path.write_text(json.dumps({
        "best_val_avg_accuracy": best_val,
        "n_train": len(train_set),
        "n_val": len(val_set),
        "config": vars(args),
    }, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f" metrics saved to {metrics_path}")


# =============================================================================
# Inference helper (used by strategy_synthesizer at runtime)
# =============================================================================

def load_classifier(checkpoint_path: str):
    """Load a trained classifier as a callable.

    Args:
        checkpoint_path: path to a checkpoint written by
            ``train_hebert_strategy``.

    Returns:
        fn(text: str) -> dict with keys
            is_argument (bool), outcome, side, arg_type (label strings),
            <task>_confidence for each of the four tasks (softmax probability
            of the chosen label), confidence (geometric mean of the
            is_argument/outcome/side confidences) and strength (outcome-based
            heuristic scaled by the outcome confidence).

    Raises:
        RuntimeError: if torch is not installed.
    """
    if not _HAS_TORCH:
        raise RuntimeError("torch not installed")
    from transformers import AutoModel, AutoTokenizer

    # NOTE(security): torch.load reads a pickle-based file. Only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    tokenizer = AutoTokenizer.from_pretrained(ckpt["tokenizer_name"])
    encoder = AutoModel.from_pretrained(ckpt["tokenizer_name"])
    config = ckpt["config"]
    model = MultiTaskArgumentClassifier(
        encoder,
        hidden_size=config["hidden_size"],
        dropout=config["dropout"],
    )
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    label_spaces = ckpt["label_spaces"]
    # Checkpoints written before max_len was recorded fall back to 256.
    max_len = config.get("max_len", _DEFAULT_MAX_LEN)

    def predict(text: str) -> dict:
        """Classify one paragraph; see ``load_classifier`` for the keys."""
        with torch.no_grad():
            enc = tokenizer(text, max_length=max_len, padding="max_length",
                            truncation=True, return_tensors="pt")
            preds = model(enc["input_ids"], enc["attention_mask"])
        out = {}
        for task in _TASKS:
            labels = label_spaces[task]
            probs = torch.softmax(preds[task], dim=-1)
            idx = int(probs.argmax(-1).item())
            out[task] = labels[idx]
            out[f"{task}_confidence"] = float(probs[0, idx].item())
        # Convert is_argument string to bool for downstream
        out["is_argument"] = (out["is_argument"] == "yes")
        out["confidence"] = (
            out["is_argument_confidence"]
            * out["outcome_confidence"]
            * out["side_confidence"]
        ) ** (1/3)
        # Strength heuristic from outcome confidence
        out["strength"] = (
            0.85 if out["outcome"] == "accepted"
            else 0.50 if out["outcome"] == "partial"
            else 0.25
        ) * out["outcome_confidence"]
        return out

    return predict


# =============================================================================
# CLI
# =============================================================================

def main():
    """Parse command-line arguments and run the selected training strategy."""
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--labels-jsonl", required=True,
                   help="JSONL with labeled paragraphs")
    p.add_argument("--out",
                   default="tau_rag/runtime/models/argument_classifier_v1.pt",
                   help="output checkpoint path")
    p.add_argument("--strategy", choices=["HEBERT_HEAD", "TAU_LLM_HEAD"],
                   default="HEBERT_HEAD",
                   help="which encoder to fine-tune (default: HEBERT_HEAD)")
    p.add_argument("--hebert-model", default="avichr/heBERT",
                   help="HuggingFace model name (HeBERT default)")
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--max-len", type=int, default=256)
    args = p.parse_args()

    if args.strategy == "HEBERT_HEAD":
        train_hebert_strategy(args)
    else:
        sys.exit("TAU_LLM_HEAD strategy not implemented yet — use HEBERT_HEAD.")


if __name__ == "__main__":
    main()