| |
| """continue_training_arguments.py — Fine-tune TAU LLM on argument pairs. |
| |
| Continues from your existing v11 checkpoint and trains for additional |
| epochs on the focused (facts → argument) dataset produced by |
| extract_argument_training_data.py. |
| |
| Why this approach (vs. training from scratch on whole judgments): |
| |
| 1. The current v11 model has the LANGUAGE (Hebrew) and the |
| VOCABULARY (legal terms) — what it lacks is the FLOW |
| (coherent sentence structure when adapting an argument). |
| |
| 2. By continuing from v11 on facts→argument PAIRS, we teach |
| it the exact transformation we use it for at inference: |
| "given these facts, produce this kind of legal reasoning." |
| |
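| 3. ~200K paired examples × 3 epochs ≈ 600K example presentations |
| (≈ 37.5K optimizer steps at batch size 16; --max-windows further |
| caps the windows seen per epoch) targeted at the right task — |
| should be enough to fix the word-soup problem without rebuilding |
| the model. |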
| |
| Usage: |
| python3 -m tau_rag.scripts.continue_training_arguments \\ |
| --data tau_rag/runtime/training_data/legal_arguments.jsonl \\ |
| --base-checkpoint tau_rag/runtime/models/tau_hebrew_legal_llm_v11.pt \\ |
| --output tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt \\ |
| --epochs 3 \\ |
| --batch-size 16 |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
# Make the directory two levels above this script's folder importable so
# the platform-level `training` package used by main() can be found.
# Guarded so a reload/re-import does not stack duplicate sys.path entries.
_PLATFORM_ROOT = Path(__file__).resolve().parents[2]
if str(_PLATFORM_ROOT) not in sys.path:
    sys.path.insert(0, str(_PLATFORM_ROOT))
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", required=True,
                    help="JSONL produced by extract_argument_training_data")
    ap.add_argument("--base-checkpoint", required=True)
    ap.add_argument("--output",
                    default="tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-4,
                    help="Lower LR for fine-tuning (default 1e-4 vs initial 3e-4)")
    ap.add_argument("--context-length", type=int, default=256,
                    help="Longer context to fit facts+argument pairs "
                         "(only used when the checkpoint does not record "
                         "its own context length)")
    ap.add_argument("--max-samples", type=int, default=0,
                    help="Cap samples for quick iteration. 0 = all.")
    ap.add_argument("--prefer-paired", action="store_true",
                    help="Use only paired records (drops solo arguments)")
    ap.add_argument("--max-windows", type=int, default=120_000,
                    help="Cap on training windows after dataset built. "
                         "TextDataset uses stride=1 → ~10× more windows "
                         "than tokens/context_len; 120K ≈ 90min/epoch on mps. "
                         "Set 0 to disable (use full overlapping windows).")
    ap.add_argument("--checkpoint-every", type=int, default=2000,
                    help="Save a rolling checkpoint every N batches so "
                         "interrupted runs aren't a total loss. 0 = off.")
    return ap


def _load_training_texts(
    path: str, *, prefer_paired: bool, max_samples: int,
) -> tuple[list[str], int, int, int]:
    """Collect training strings from a JSONL file.

    Each line is expected to be a JSON object carrying ``training_text``
    (falling back to ``text``) and an optional truthy ``paired`` flag.

    Args:
        path: JSONL file to read (UTF-8).
        prefer_paired: when true, records whose ``paired`` is falsy are skipped.
        max_samples: stop after this many texts; ``<= 0`` means no cap.

    Returns:
        ``(texts, n_pairs, n_solo, n_bad)``; ``n_bad`` counts lines that were
        not JSON objects or whose text field was not a string. Blank lines
        are ignored and not counted.
    """
    texts: list[str] = []
    n_pairs = n_solo = n_bad = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                n_bad += 1
                continue
            # Valid JSON that is not an object (array, number...) has no .get().
            if not isinstance(rec, dict):
                n_bad += 1
                continue
            tt = rec.get("training_text") or rec.get("text") or ""
            if not isinstance(tt, str):
                n_bad += 1
                continue
            if not tt:
                continue
            paired = bool(rec.get("paired", False))
            if prefer_paired and not paired:
                continue
            texts.append(tt)
            if paired:
                n_pairs += 1
            else:
                n_solo += 1
            if max_samples > 0 and len(texts) >= max_samples:
                break
    return texts, n_pairs, n_solo, n_bad


def _resolve_context_length(config_dict, state_dict, cli_value: int) -> int:
    """Pick the context length to build the model with.

    Precedence: ``config_dict["context_length"]`` → first dimension of the
    ``embedding.pos_encoding`` tensor in *state_dict* → *cli_value*. The
    checkpoint wins because the rebuilt model presumably has to match its
    positional table for the strict ``load_state_dict`` to succeed.
    """
    pos_enc = state_dict.get("embedding.pos_encoding")
    # NOTE(review): assumes pos_encoding is laid out (context_len, dim); a
    # leading batch axis would make shape[0] == 1 — confirm.
    detected = pos_enc.shape[0] if pos_enc is not None else None
    return config_dict.get("context_length") or detected or cli_value


def _sibling_path(output: str, suffix: str) -> str:
    """Return *output* with its extension replaced by *suffix*.

    E.g. ``("models/v12.pt", "_epoch1.pt")`` → ``"models/v12_epoch1.pt"``.
    """
    return Path(output).with_suffix("").as_posix() + suffix


def _add_step_callback(optimizer, callback) -> None:
    """Arrange for ``callback()`` to run after every ``optimizer.step()``.

    Prefers ``Optimizer.register_step_post_hook`` (torch >= 2.0), which does
    not replace ``optimizer.step`` — torch LR schedulers also patch that
    attribute. Falls back to wrapping ``step`` for optimizers without the
    hook API.
    """
    register = getattr(optimizer, "register_step_post_hook", None)
    if callable(register):
        register(lambda _opt, _args, _kwargs: callback())
        return
    orig_step = optimizer.step

    def _hooked_step(*a, **kw):
        result = orig_step(*a, **kw)
        callback()
        return result

    optimizer.step = _hooked_step


def main(argv: list[str] | None = None) -> None:
    """Continue training a TAU checkpoint on (facts → argument) texts.

    Pipeline: load JSONL texts → load base checkpoint → rebuild the model
    with the checkpoint's architecture → window the texts → train for
    ``--epochs`` epochs, saving per-epoch (and optionally rolling)
    checkpoints next to ``--output`` → save the final checkpoint.

    Args:
        argv: argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.

    Raises:
        SystemExit: non-zero when the JSONL yields no usable records or the
            dataset yields no training windows.
    """
    args = _build_arg_parser().parse_args(argv)

    print("=" * 75, flush=True)
    print("CONTINUE TRAINING ON LEGAL ARGUMENTS", flush=True)
    print("=" * 75, flush=True)
    print(f" data: {args.data}", flush=True)
    print(f" base ckpt: {args.base_checkpoint}", flush=True)
    print(f" output ckpt: {args.output}", flush=True)
    print(f" epochs: {args.epochs}", flush=True)
    print(f" batch_size: {args.batch_size}", flush=True)
    print(f" lr: {args.lr}", flush=True)
    print(f" context_len: {args.context_length}", flush=True)
    print(f" prefer_paired: {args.prefer_paired}", flush=True)

    # ── [1] training texts ────────────────────────────────────────────
    print("\n[1] Loading JSONL training data...", flush=True)
    t0 = time.time()
    texts, n_pairs, n_solo, n_bad = _load_training_texts(
        args.data, prefer_paired=args.prefer_paired,
        max_samples=args.max_samples)
    print(f" ✅ {len(texts):,} records loaded "
          f"(pairs={n_pairs:,} solo={n_solo:,}) "
          f"in {time.time()-t0:.1f}s", flush=True)
    if n_bad:
        print(f" ⚠ skipped {n_bad:,} malformed line(s)", flush=True)
    if not texts:
        # Fail fast, before the (slow) torch import and checkpoint load.
        sys.exit(f"error: no usable training records in {args.data}")

    # ── [2] base checkpoint ───────────────────────────────────────────
    print("\n[2] Loading base checkpoint...", flush=True)
    t0 = time.time()
    import torch  # imported lazily so data errors surface before this import
    from torch.utils.data import DataLoader, Subset

    # NOTE(review): weights_only=False unpickles arbitrary objects (needed
    # for the encoder) — only load checkpoints from a trusted source.
    ckpt = torch.load(args.base_checkpoint, map_location="cpu",
                      weights_only=False)
    encoder = ckpt["encoder"]
    config_dict = ckpt["config"]
    state_dict = ckpt["model_state_dict"]
    base_epoch = ckpt.get("epoch") or 0
    del ckpt  # keep only what is still needed
    vocab_layer = getattr(encoder, "vocab_layer", encoder)
    vocab_size = len(vocab_layer.word2idx)
    print(f" ✅ checkpoint loaded ({time.time()-t0:.1f}s)", flush=True)
    print(f" vocab size: {vocab_size}", flush=True)

    # ── [3] model + trainer config ────────────────────────────────────
    print("\n[3] Building model + trainer...", flush=True)
    # The platform root is put on sys.path at module import time.
    from training.next_token_trainer import (
        TAULanguageModel, TAUTrainer, TextDataset, TrainingConfig,
    )

    ctx_len = _resolve_context_length(config_dict, state_dict,
                                      args.context_length)
    if ctx_len != args.context_length:
        print(f" ℹ context_length: using {ctx_len} (from checkpoint) "
              f"instead of CLI default {args.context_length}", flush=True)
    cfg = TrainingConfig(
        embedding_dim=config_dict.get("embedding_dim", 256),
        hidden_dim=config_dict.get("hidden_dim", 512),
        num_layers=config_dict.get("num_layers", 4),
        num_heads=config_dict.get("num_heads", 8),
        context_length=ctx_len,
        dropout=config_dict.get("dropout", 0.1),
        batch_size=args.batch_size,
        epochs=args.epochs,
        learning_rate=args.lr,
        device=("mps" if torch.backends.mps.is_available()
                else "cuda" if torch.cuda.is_available() else "cpu"),
    )
    model = TAULanguageModel(vocab_size=vocab_size, config=cfg)
    model.load_state_dict(state_dict)
    del state_dict  # weights were copied into the model
    print(f" ✅ model built, weights loaded from "
          f"{Path(args.base_checkpoint).name}", flush=True)
    # NOTE(review): the model is not moved to cfg.device here; presumably
    # TAUTrainer does that — confirm.
    print(f" device: {cfg.device}", flush=True)

    # ── [4] dataset ───────────────────────────────────────────────────
    print("\n[4] Building dataset...", flush=True)
    dataset = TextDataset(
        texts=texts,
        encoder=encoder,
        context_length=cfg.context_length,
    )
    print(f" ✅ {len(dataset):,} training windows", flush=True)
    if len(dataset) == 0:
        sys.exit("error: dataset produced no training windows "
                 f"(context_length={cfg.context_length})")
    if args.max_windows > 0 and len(dataset) > args.max_windows:
        # Fixed seed: the same random subset is chosen on every run.
        g = torch.Generator().manual_seed(42)
        idx = torch.randperm(len(dataset), generator=g)[:args.max_windows]
        dataset = Subset(dataset, idx.tolist())
        print(f" ✂ capped to {len(dataset):,} windows "
              f"(--max-windows {args.max_windows})", flush=True)
    loader = DataLoader(dataset, batch_size=cfg.batch_size,
                        shuffle=True, num_workers=0)

    # ── [5] training ──────────────────────────────────────────────────
    print("\n[5] Training...", flush=True)
    trainer = TAUTrainer(model=model, config=cfg, vocab=vocab_layer)

    def _save(path, epoch_idx, last_loss):
        """Write a checkpoint compatible with the base checkpoint's layout.

        ``epoch_idx`` is the number of fine-tuning epochs completed; it is
        added to the base checkpoint's epoch. The file is written to a
        temporary name and then renamed, so an interrupted save never
        leaves a truncated checkpoint at *path*.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        torch.save({
            "model_state_dict": model.state_dict(),
            "encoder": encoder,
            "config": config_dict,
            "epoch": base_epoch + epoch_idx,
            "loss": last_loss,
            "trained_on": "legal_arguments_v1",
            "base_checkpoint": args.base_checkpoint,
        }, tmp)
        os.replace(tmp, target)

    # Shared with the step callback. "epochs_done" lets mid-epoch rolling
    # checkpoints record the correct epoch instead of always the base.
    progress = {"step": 0, "epochs_done": 0, "last_loss": float("nan")}

    if args.checkpoint_every > 0:
        def _on_step():
            progress["step"] += 1
            step = progress["step"]
            if step % args.checkpoint_every:
                return
            # NOTE(review): loss_history may only hold per-epoch averages,
            # in which case this stays nan during the first epoch.
            if trainer.loss_history:
                progress["last_loss"] = trainer.loss_history[-1]
            roll = _sibling_path(args.output, f"_step{step}.pt")
            try:
                _save(roll, epoch_idx=progress["epochs_done"],
                      last_loss=progress["last_loss"])
            except OSError as exc:
                # Best effort: a failed rolling save must not kill the run.
                print(f" ⚠ rolling checkpoint @ step {step} failed: {exc}",
                      flush=True)
                return
            print(f" 💾 rolling checkpoint @ step {step} "
                  f"(loss={progress['last_loss']:.4f}) → {roll}",
                  flush=True)

        _add_step_callback(trainer.optimizer, _on_step)

    initial_loss = None
    avg_loss = None
    for epoch in range(args.epochs):
        progress["epochs_done"] = epoch
        print(f"\n ━━━ Epoch {epoch+1}/{args.epochs} ━━━", flush=True)
        try:
            avg_loss = trainer.train_epoch(loader)
        except KeyboardInterrupt:
            print("\n ⚠ interrupted mid-epoch — saving partial...",
                  flush=True)
            partial = _sibling_path(args.output,
                                    f"_partial_epoch{epoch+1}.pt")
            last = (trainer.loss_history[-1]
                    if trainer.loss_history else float("nan"))
            _save(partial, epoch_idx=epoch, last_loss=last)
            print(f" 💾 partial saved → {partial}", flush=True)
            return
        if initial_loss is None:
            initial_loss = avg_loss
        # Relative change vs. the first completed epoch's average loss.
        improvement = ((initial_loss - avg_loss) / initial_loss * 100
                       if initial_loss > 0 else 0)
        print(f" epoch {epoch+1}: avg_loss = {avg_loss:.4f} "
              f"(Δ {improvement:+.1f}% from start)", flush=True)

        per_epoch = _sibling_path(args.output, f"_epoch{epoch+1}.pt")
        _save(per_epoch, epoch_idx=epoch + 1, last_loss=avg_loss)
        print(f" 💾 epoch checkpoint → {per_epoch}", flush=True)

    # ── [6] final checkpoint ──────────────────────────────────────────
    print(f"\n[6] Saving to {args.output}...", flush=True)
    _save(args.output, epoch_idx=args.epochs, last_loss=avg_loss)
    print(" ✅ saved", flush=True)

    # Single trailing backslash = shell line continuation when pasted.
    print("\n[7] Test the new checkpoint:", flush=True)
    print(f" TAU_RAG_TAU_CKPT={args.output} \\", flush=True)
    print(" TAU_RAG_TAU_DEBUG=1 python3 -B -m \\", flush=True)
    print(" tau_rag.scripts.test_tau_polish \\", flush=True)
    print(" --parquet ... --n 50", flush=True)
|
|
|
|
# Start training only when executed as a script (e.g. `python3 -m
# tau_rag.scripts.continue_training_arguments ...`), not when imported.
if __name__ == "__main__":
    main()
|
|