#!/usr/bin/env python3
"""continue_training_arguments.py — Fine-tune TAU LLM on argument pairs.

Continues from your existing v11 checkpoint and trains for additional
epochs on the focused (facts → argument) dataset produced by
extract_argument_training_data.py.

Why this approach (vs. training from scratch on whole judgments):
  1. The current v11 model has the LANGUAGE (Hebrew) and the VOCABULARY
     (legal terms) — what it lacks is the FLOW (coherent sentence
     structure when adapting an argument).
  2. By continuing from v11 on facts→argument PAIRS, we teach it the
     exact transformation we use it for at inference: "given these
     facts, produce this kind of legal reasoning."
  3. ~200K paired examples × 3 epochs ≈ 600K training-example passes
     targeted at the right task — should be enough to fix the
     word-soup problem without rebuilding the model.

Usage:
    python3 -m tau_rag.scripts.continue_training_arguments \\
        --data tau_rag/runtime/training_data/legal_arguments.jsonl \\
        --base-checkpoint tau_rag/runtime/models/tau_hebrew_legal_llm_v11.pt \\
        --output tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt \\
        --epochs 3 \\
        --batch-size 16
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))


def _load_training_texts(
    path: str,
    prefer_paired: bool,
    max_samples: int,
) -> tuple[list[str], int, int, int]:
    """Read training texts from a JSONL file.

    Each non-blank line is expected to be a JSON object carrying the text
    under ``"training_text"`` (preferred) or ``"text"``, plus an optional
    truthy ``"paired"`` flag.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        prefer_paired: When true, records without a truthy ``"paired"``
            flag are dropped.
        max_samples: Stop after this many accepted records; values <= 0
            mean "no cap".

    Returns:
        ``(texts, n_pairs, n_solo, n_skipped)`` where ``n_skipped`` counts
        lines that were not valid JSON objects or whose text field was not
        a string. Records with an empty/missing text are dropped silently.
    """
    texts: list[str] = []
    n_pairs = n_solo = n_skipped = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError subclass.
                n_skipped += 1
                continue
            if not isinstance(rec, dict):
                # Valid JSON but not an object (e.g. a list): no fields to read.
                n_skipped += 1
                continue
            text = rec.get("training_text") or rec.get("text") or ""
            if not isinstance(text, str):
                n_skipped += 1
                continue
            if not text:
                continue
            paired = rec.get("paired", False)
            if prefer_paired and not paired:
                continue
            texts.append(text)
            if paired:
                n_pairs += 1
            else:
                n_solo += 1
            if max_samples > 0 and len(texts) >= max_samples:
                break
    return texts, n_pairs, n_solo, n_skipped


def _derived_checkpoint_path(output: str, tag: str) -> str:
    """Return *output* with its extension replaced by ``_<tag>.pt``."""
    return Path(output).with_suffix("").as_posix() + f"_{tag}.pt"


def main() -> None:
    """Parse CLI arguments and continue training a checkpoint on JSONL text.

    Steps: load the JSONL records, load the base checkpoint, rebuild the
    model with the checkpoint's hyper-parameters, build (and optionally
    cap) the windowed dataset, train for ``--epochs`` epochs while writing
    rolling / per-epoch / interrupt checkpoints, then write the final
    checkpoint to ``--output``.

    Raises:
        SystemExit: if the data file yields no usable training records
            (raised before the checkpoint is loaded), or on argparse
            errors.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", required=True,
                    help="JSONL produced by extract_argument_training_data")
    ap.add_argument("--base-checkpoint", required=True)
    ap.add_argument("--output",
                    default="tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-4,
                    help="Lower LR for fine-tuning (default 1e-4 vs initial 3e-4)")
    ap.add_argument("--context-length", type=int, default=256,
                    help="Longer context to fit facts+argument pairs")
    ap.add_argument("--max-samples", type=int, default=0,
                    help="Cap samples for quick iteration. 0 = all.")
    ap.add_argument("--prefer-paired", action="store_true",
                    help="Use only paired records (drops solo arguments)")
    ap.add_argument("--max-windows", type=int, default=120_000,
                    help="Cap on training windows after dataset built. "
                         "TextDataset uses stride=1 → ~10× more windows "
                         "than tokens/context_len; 120K ≈ 90min/epoch on mps. "
                         "Set 0 to disable (use full overlapping windows).")
    ap.add_argument("--checkpoint-every", type=int, default=2000,
                    help="Save a rolling checkpoint every N batches so "
                         "interrupted runs aren't a total loss. 0 = off.")
    args = ap.parse_args()

    banner = "=" * 75
    print(banner, flush=True)
    print("CONTINUE TRAINING ON LEGAL ARGUMENTS", flush=True)
    print(banner, flush=True)
    print(f" data: {args.data}", flush=True)
    print(f" base ckpt: {args.base_checkpoint}", flush=True)
    print(f" output ckpt: {args.output}", flush=True)
    print(f" epochs: {args.epochs}", flush=True)
    print(f" batch_size: {args.batch_size}", flush=True)
    print(f" lr: {args.lr}", flush=True)
    print(f" context_len: {args.context_length}", flush=True)
    print(f" prefer_paired: {args.prefer_paired}", flush=True)

    # Load training data
    print("\n[1] Loading JSONL training data...", flush=True)
    t0 = time.time()
    texts, n_pairs, n_solo, n_skipped = _load_training_texts(
        args.data, args.prefer_paired, args.max_samples)
    print(f" ✅ {len(texts):,} records loaded "
          f"(pairs={n_pairs:,} solo={n_solo:,}) "
          f"in {time.time()-t0:.1f}s", flush=True)
    if n_skipped:
        print(f" ⚠ skipped {n_skipped:,} malformed line(s)", flush=True)
    if not texts:
        # Fail before the (slow) checkpoint load rather than deep inside
        # dataset / loader construction.
        sys.exit(f"error: no usable training records found in {args.data}")

    # Load v11 checkpoint
    print("\n[2] Loading base checkpoint...", flush=True)
    t0 = time.time()
    import torch

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
    # (needed here because the checkpoint stores the encoder object).
    # Only point --base-checkpoint at files you trust.
    ckpt = torch.load(args.base_checkpoint, map_location="cpu",
                      weights_only=False)
    encoder = ckpt["encoder"]
    config_dict = ckpt["config"]
    # The vocabulary lives on encoder.vocab_layer when that attribute
    # exists, otherwise on the encoder itself.
    vocab_layer = getattr(encoder, "vocab_layer", encoder)
    vocab_size = len(vocab_layer.word2idx)
    print(f" ✅ checkpoint loaded ({time.time()-t0:.1f}s)", flush=True)
    print(f" vocab size: {vocab_size}", flush=True)

    # Build trainer using the existing infrastructure
    print("\n[3] Building model + trainer...", flush=True)
    # next_token_trainer lives at the top-level `training/` package
    # (tau_platform_v4/training/), not under tau_rag/. Make sure the
    # platform root is on sys.path so the import resolves whether the
    # script is run from inside tau_rag/ or from the project root.
    _PLATFORM_ROOT = Path(__file__).resolve().parents[2]
    if str(_PLATFORM_ROOT) not in sys.path:
        sys.path.insert(0, str(_PLATFORM_ROOT))
    from training.next_token_trainer import (
        TAULanguageModel,
        TAUTrainer,
        TextDataset,
        TrainingConfig,
    )

    # Honor the checkpoint's context_length so pos_encoding shape matches
    # the saved weights (v11 was trained with 512). If the saved config
    # doesn't specify, fall back to detecting from the state dict, then
    # to the CLI arg.
    sd = ckpt["model_state_dict"]
    pos_enc = sd.get("embedding.pos_encoding")
    # NOTE(review): assumes dim 0 of pos_encoding is the context length —
    # confirm against the model definition.
    detected_ctx = pos_enc.shape[0] if pos_enc is not None else None
    ctx_len = (config_dict.get("context_length")
               or detected_ctx
               or args.context_length)
    if ctx_len != args.context_length:
        print(f" ℹ context_length: using {ctx_len} (from checkpoint) "
              f"instead of CLI default {args.context_length}", flush=True)

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    cfg = TrainingConfig(
        embedding_dim=config_dict.get("embedding_dim", 256),
        hidden_dim=config_dict.get("hidden_dim", 512),
        num_layers=config_dict.get("num_layers", 4),
        num_heads=config_dict.get("num_heads", 8),
        context_length=ctx_len,
        dropout=config_dict.get("dropout", 0.1),
        batch_size=args.batch_size,
        epochs=args.epochs,
        learning_rate=args.lr,
        device=device,
    )
    model = TAULanguageModel(vocab_size=vocab_size, config=cfg)
    model.load_state_dict(sd)
    print(" ✅ model built, weights loaded from v11", flush=True)
    print(f" device: {cfg.device}", flush=True)

    # Build dataset
    print("\n[4] Building dataset...", flush=True)
    dataset = TextDataset(
        texts=texts,
        encoder=encoder,
        context_length=cfg.context_length,
    )
    print(f" ✅ {len(dataset):,} training windows", flush=True)

    from torch.utils.data import DataLoader, Subset

    # Cap windows to keep epoch wall-time tractable. TextDataset uses
    # stride=1 (each token starts a new window) so the raw count is
    # tokens-context_len. 120K windows ≈ 7500 batches @ batch_size=16
    # ≈ 90 min/epoch on mps for a 16M-param model.
    if args.max_windows > 0 and len(dataset) > args.max_windows:
        # Fixed seed so the same subset is chosen on every run.
        g = torch.Generator().manual_seed(42)
        idx = torch.randperm(len(dataset), generator=g)[:args.max_windows]
        dataset = Subset(dataset, idx.tolist())
        print(f" ✂ capped to {len(dataset):,} windows "
              f"(--max-windows {args.max_windows})", flush=True)
    loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                        num_workers=0)

    # Train
    print("\n[5] Training...", flush=True)
    trainer = TAUTrainer(model=model, config=cfg, vocab=vocab_layer)

    # Save helper — used for periodic mid-epoch checkpoints, per-epoch
    # checkpoints, the interrupt path and the final write. Keeps the same
    # save shape so every file is drop-in usable from the runtime loader.
    def _save(path, epoch_idx, last_loss):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            "model_state_dict": model.state_dict(),
            "encoder": encoder,
            "config": config_dict,
            "epoch": ckpt.get("epoch", 0) + epoch_idx,
            "loss": last_loss,
            "trained_on": "legal_arguments_v1",
            "base_checkpoint": args.base_checkpoint,
        }, path)

    # Shared progress record; `epochs_done` lets rolling checkpoints store
    # the number of completed epochs, matching the interrupt path.
    progress = {"step": 0, "last_loss": float("nan"), "epochs_done": 0}

    # Patch trainer.train_epoch to call back periodically. Cleaner than
    # rewriting train_epoch — we monkey-patch the optimizer.step so we
    # can hook on every optimizer update.
    if args.checkpoint_every > 0:
        _orig_step = trainer.optimizer.step

        def _hooked_step(*a, **kw):
            r = _orig_step(*a, **kw)
            progress["step"] += 1
            if progress["step"] % args.checkpoint_every == 0:
                if trainer.loss_history:
                    progress["last_loss"] = trainer.loss_history[-1]
                roll = _derived_checkpoint_path(
                    args.output, f"step{progress['step']}")
                _save(roll, epoch_idx=progress["epochs_done"],
                      last_loss=progress["last_loss"])
                print(f" 💾 rolling checkpoint @ step {progress['step']} "
                      f"(loss={progress['last_loss']:.4f}) → {roll}",
                      flush=True)
            return r

        trainer.optimizer.step = _hooked_step

    initial_loss = None
    avg_loss = None
    for epoch in range(args.epochs):
        print(f"\n ━━━ Epoch {epoch+1}/{args.epochs} ━━━", flush=True)
        try:
            avg_loss = trainer.train_epoch(loader)
        except KeyboardInterrupt:
            print("\n ⚠ interrupted mid-epoch — saving partial...",
                  flush=True)
            partial = _derived_checkpoint_path(
                args.output, f"partial_epoch{epoch+1}")
            last = (trainer.loss_history[-1]
                    if trainer.loss_history else float("nan"))
            # `epoch` == number of fully completed epochs at this point.
            _save(partial, epoch_idx=epoch, last_loss=last)
            print(f" 💾 partial saved → {partial}", flush=True)
            return
        progress["epochs_done"] = epoch + 1
        if initial_loss is None:
            initial_loss = avg_loss
        improvement = ((initial_loss - avg_loss) / initial_loss * 100
                       if initial_loss > 0 else 0)
        print(f" epoch {epoch+1}: avg_loss = {avg_loss:.4f} "
              f"(Δ {improvement:+.1f}% from start)", flush=True)
        # Per-epoch save so partial training is never lost
        per_epoch = _derived_checkpoint_path(args.output, f"epoch{epoch+1}")
        _save(per_epoch, epoch_idx=epoch + 1, last_loss=avg_loss)
        print(f" 💾 epoch checkpoint → {per_epoch}", flush=True)

    # Save (same payload shape as every intermediate checkpoint)
    print(f"\n[6] Saving to {args.output}...", flush=True)
    _save(args.output, epoch_idx=args.epochs, last_loss=avg_loss)
    print(" ✅ saved", flush=True)

    # Copy-pasteable shell command: a single trailing backslash per line
    # is the shell line-continuation marker.
    print("\n[7] Test the new checkpoint:", flush=True)
    print(f" TAU_RAG_TAU_CKPT={args.output} \\", flush=True)
    print(" TAU_RAG_TAU_DEBUG=1 python3 -B -m \\", flush=True)
    print(" tau_rag.scripts.test_tau_polish \\", flush=True)
    print(" --parquet ... --n 50", flush=True)


if __name__ == "__main__":
    main()