legal-eye / tau_rag /scripts /continue_training_arguments.py
Legal-i's picture
Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
3be54c6 verified
Raw
History Blame Contribute Delete
12.5 kB
#!/usr/bin/env python3
"""continue_training_arguments.py — Fine-tune TAU LLM on argument pairs.
Continues from your existing v11 checkpoint and trains for additional
epochs on the focused (facts → argument) dataset produced by
extract_argument_training_data.py.
Why this approach (vs. training from scratch on whole judgments):
1. The current v11 model has the LANGUAGE (Hebrew) and the
VOCABULARY (legal terms) — what it lacks is the FLOW
(coherent sentence structure when adapting an argument).
2. By continuing from v11 on facts→argument PAIRS, we teach
it the exact transformation we use it for at inference:
"given these facts, produce this kind of legal reasoning."
3. ~200K paired examples × 3 epochs ≈ 600K training examples seen
(one weight update per batch, so far fewer optimizer steps)
targeted at the right task — should be enough to fix the
word-soup problem without rebuilding the model.
Usage:
python3 -m tau_rag.scripts.continue_training_arguments \\
--data tau_rag/runtime/training_data/legal_arguments.jsonl \\
--base-checkpoint tau_rag/runtime/models/tau_hebrew_legal_llm_v11.pt \\
--output tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt \\
--epochs 3 \\
--batch-size 16
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
# Put the directory two levels above this file on sys.path (inserted at the
# front so it wins over other entries) before anything project-local is
# imported.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
def _load_training_texts(path, prefer_paired=False, max_samples=0):
    """Read training texts from a JSONL file.

    Each non-blank line should be a JSON object holding the text under
    ``training_text`` (preferred) or ``text`` and, optionally, a boolean
    ``paired`` flag.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        prefer_paired: When true, records whose ``paired`` flag is falsy
            are dropped.
        max_samples: Stop after this many accepted records; 0 means no cap.

    Returns:
        A 4-tuple ``(texts, n_pairs, n_solo, n_skipped)`` where
        ``n_skipped`` counts malformed lines (invalid JSON, a JSON value
        that is not an object, or a non-string text field). Blank lines and
        records with empty text are dropped without being counted.
    """
    texts = []
    n_pairs = n_solo = n_skipped = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # Best-effort loader: one bad line must not abort the run,
                # but it is counted so the caller can report it.
                n_skipped += 1
                continue
            if not isinstance(rec, dict):
                n_skipped += 1
                continue
            tt = rec.get("training_text") or rec.get("text") or ""
            if not tt:
                continue
            if not isinstance(tt, str):
                n_skipped += 1
                continue
            paired = rec.get("paired", False)
            if prefer_paired and not paired:
                continue
            texts.append(tt)
            if paired:
                n_pairs += 1
            else:
                n_solo += 1
            if max_samples and len(texts) >= max_samples:
                break
    return texts, n_pairs, n_solo, n_skipped


def main():
    """Continue training a saved TAU language model on argument texts.

    Steps: parse CLI arguments, load the JSONL corpus, load the base
    checkpoint, rebuild the model with the checkpoint's architecture, build
    (and optionally subsample) the window dataset, train for ``--epochs``
    epochs, and write checkpoints (periodic, per-epoch, on Ctrl-C, final).

    Raises:
        SystemExit: with a message when the data file yields no usable
            records (nothing is loaded or trained in that case).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", required=True,
                    help="JSONL produced by extract_argument_training_data")
    ap.add_argument("--base-checkpoint", required=True)
    ap.add_argument("--output",
                    default="tau_rag/runtime/models/tau_hebrew_legal_llm_v12.pt")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=1e-4,
                    help="Lower LR for fine-tuning (default 1e-4 vs initial 3e-4)")
    ap.add_argument("--context-length", type=int, default=256,
                    help="Longer context to fit facts+argument pairs")
    ap.add_argument("--max-samples", type=int, default=0,
                    help="Cap samples for quick iteration. 0 = all.")
    ap.add_argument("--prefer-paired", action="store_true",
                    help="Use only paired records (drops solo arguments)")
    ap.add_argument("--max-windows", type=int, default=120_000,
                    help="Cap on training windows after dataset built. "
                         "TextDataset uses stride=1 → ~10× more windows "
                         "than tokens/context_len; 120K ≈ 90min/epoch on mps. "
                         "Set 0 to disable (use full overlapping windows).")
    ap.add_argument("--checkpoint-every", type=int, default=2000,
                    help="Save a rolling checkpoint every N batches so "
                         "interrupted runs aren't a total loss. 0 = off.")
    args = ap.parse_args()

    print("=" * 75, flush=True)
    print("CONTINUE TRAINING ON LEGAL ARGUMENTS", flush=True)
    print("=" * 75, flush=True)
    print(f" data: {args.data}", flush=True)
    print(f" base ckpt: {args.base_checkpoint}", flush=True)
    print(f" output ckpt: {args.output}", flush=True)
    print(f" epochs: {args.epochs}", flush=True)
    print(f" batch_size: {args.batch_size}", flush=True)
    print(f" lr: {args.lr}", flush=True)
    print(f" context_len: {args.context_length}", flush=True)
    print(f" prefer_paired: {args.prefer_paired}", flush=True)

    # [1] Load training data
    print("\n[1] Loading JSONL training data...", flush=True)
    t0 = time.time()
    texts, n_pairs, n_solo, n_skipped = _load_training_texts(
        args.data,
        prefer_paired=args.prefer_paired,
        max_samples=args.max_samples,
    )
    print(f" ✅ {len(texts):,} records loaded "
          f"(pairs={n_pairs:,} solo={n_solo:,}) "
          f"in {time.time()-t0:.1f}s", flush=True)
    if n_skipped:
        print(f" ⚠ skipped {n_skipped:,} malformed line(s)", flush=True)
    if not texts:
        # Fail fast: continuing would only re-save the base weights under
        # a new name after loading a large checkpoint for nothing.
        raise SystemExit(
            f"no usable training records found in {args.data}; "
            "nothing to train on"
        )

    # [2] Load the base checkpoint
    print("\n[2] Loading base checkpoint...", flush=True)
    t0 = time.time()
    import torch
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
    # (the checkpoint stores the encoder object itself), so only load
    # checkpoints from a trusted source.
    ckpt = torch.load(args.base_checkpoint, map_location="cpu",
                      weights_only=False)
    encoder = ckpt["encoder"]
    config_dict = ckpt["config"]
    # The encoder either wraps a vocab layer or is the vocab layer itself.
    vocab_layer = getattr(encoder, "vocab_layer", encoder)
    vocab_size = len(vocab_layer.word2idx)
    # A stored epoch of None is treated like a missing one.
    base_epoch = ckpt.get("epoch", 0) or 0
    print(f" ✅ checkpoint loaded ({time.time()-t0:.1f}s)", flush=True)
    print(f" vocab size: {vocab_size}", flush=True)

    # [3] Build trainer using the existing infrastructure
    print("\n[3] Building model + trainer...", flush=True)
    # next_token_trainer lives at the top-level `training/` package
    # (tau_platform_v4/training/), not under tau_rag/. Make sure the
    # platform root is on sys.path so the import resolves whether the
    # script is run from inside tau_rag/ or from the project root.
    _PLATFORM_ROOT = Path(__file__).resolve().parents[2]
    if str(_PLATFORM_ROOT) not in sys.path:
        sys.path.insert(0, str(_PLATFORM_ROOT))
    from training.next_token_trainer import (
        TAULanguageModel, TAUTrainer, TextDataset, TrainingConfig,
    )

    # Honor the checkpoint's context_length so pos_encoding shape matches
    # the saved weights (v11 was trained with 512). If the saved config
    # doesn't specify, fall back to detecting from the state dict, then
    # to the CLI arg.
    sd = ckpt["model_state_dict"]
    pos_enc = sd.get("embedding.pos_encoding")
    # NOTE(review): assumes dim 0 of pos_encoding is the context length
    # (no leading batch dim) — confirm against TAULanguageModel.
    detected_ctx = pos_enc.shape[0] if pos_enc is not None else None
    ctx_len = (config_dict.get("context_length")
               or detected_ctx
               or args.context_length)
    if ctx_len != args.context_length:
        print(f" ℹ context_length: using {ctx_len} (from checkpoint) "
              f"instead of CLI default {args.context_length}", flush=True)

    # Architecture comes from the checkpoint; optimisation settings
    # (batch size, epochs, LR) come from the CLI.
    cfg = TrainingConfig(
        embedding_dim=config_dict.get("embedding_dim", 256),
        hidden_dim=config_dict.get("hidden_dim", 512),
        num_layers=config_dict.get("num_layers", 4),
        num_heads=config_dict.get("num_heads", 8),
        context_length=ctx_len,
        dropout=config_dict.get("dropout", 0.1),
        batch_size=args.batch_size,
        epochs=args.epochs,
        learning_rate=args.lr,
        device=("mps" if torch.backends.mps.is_available()
                else "cuda" if torch.cuda.is_available() else "cpu"),
    )
    model = TAULanguageModel(vocab_size=vocab_size, config=cfg)
    model.load_state_dict(sd)
    print(" ✅ model built, weights loaded from v11", flush=True)
    print(f" device: {cfg.device}", flush=True)

    # [4] Build dataset
    print("\n[4] Building dataset...", flush=True)
    dataset = TextDataset(
        texts=texts,
        encoder=encoder,
        context_length=cfg.context_length,
    )
    print(f" ✅ {len(dataset):,} training windows", flush=True)

    from torch.utils.data import DataLoader, Subset
    # Cap windows to keep epoch wall-time tractable. TextDataset uses
    # stride=1 (each token starts a new window) so the raw count is
    # tokens-context_len. 120K windows ≈ 7500 batches @ batch_size=16
    # ≈ 90 min/epoch on mps for a 16M-param model.
    if args.max_windows and len(dataset) > args.max_windows:
        # Fixed seed so repeated runs train on the same subset of windows.
        g = torch.Generator().manual_seed(42)
        idx = torch.randperm(len(dataset), generator=g)[:args.max_windows]
        dataset = Subset(dataset, idx.tolist())
        print(f" ✂ capped to {len(dataset):,} windows "
              f"(--max-windows {args.max_windows})", flush=True)
    loader = DataLoader(dataset, batch_size=cfg.batch_size,
                        shuffle=True, num_workers=0)

    # [5] Train
    print("\n[5] Training...", flush=True)
    trainer = TAUTrainer(model=model, config=cfg, vocab=vocab_layer)

    # Every intermediate checkpoint is named "<output without suffix>_<tag>.pt".
    output_stem = Path(args.output).with_suffix("").as_posix()

    # Save helper — used for periodic mid-epoch checkpoints, per-epoch
    # checkpoints, the interrupt save, and the final write, so every file
    # has the same shape and is drop-in usable from the runtime loader.
    def _save(path, epoch_idx, last_loss):
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            "model_state_dict": model.state_dict(),
            "encoder": encoder,
            "config": config_dict,
            "epoch": base_epoch + epoch_idx,
            "loss": last_loss,
            "trained_on": "legal_arguments_v1",
            "base_checkpoint": args.base_checkpoint,
        }, path)

    # Shared mutable progress so the optimizer hook knows how many epochs
    # have fully completed when it writes a rolling checkpoint.
    progress = {"step": 0, "epochs_done": 0, "last_loss": float("nan")}

    # Patch trainer.train_epoch to call back periodically. Cleaner than
    # rewriting train_epoch — we monkey-patch the optimizer.step so we
    # can hook on every optimizer update.
    if args.checkpoint_every and args.checkpoint_every > 0:
        _orig_step = trainer.optimizer.step

        def _hooked_step(*a, **kw):
            r = _orig_step(*a, **kw)
            progress["step"] += 1
            if progress["step"] % args.checkpoint_every == 0:
                if trainer.loss_history:
                    progress["last_loss"] = trainer.loss_history[-1]
                roll = f"{output_stem}_step{progress['step']}.pt"
                # Record completed epochs, like the interrupt save does.
                _save(roll, epoch_idx=progress["epochs_done"],
                      last_loss=progress["last_loss"])
                print(f" 💾 rolling checkpoint @ step {progress['step']} "
                      f"(loss={progress['last_loss']:.4f}) → {roll}",
                      flush=True)
            return r

        trainer.optimizer.step = _hooked_step

    initial_loss = None
    avg_loss = None
    for epoch in range(args.epochs):
        print(f"\n ━━━ Epoch {epoch+1}/{args.epochs} ━━━", flush=True)
        try:
            avg_loss = trainer.train_epoch(loader)
        except KeyboardInterrupt:
            print("\n ⚠ interrupted mid-epoch — saving partial...",
                  flush=True)
            partial = f"{output_stem}_partial_epoch{epoch+1}.pt"
            last = (trainer.loss_history[-1]
                    if trainer.loss_history else float("nan"))
            # Only `epoch` epochs finished; the current one is incomplete.
            _save(partial, epoch_idx=epoch, last_loss=last)
            print(f" 💾 partial saved → {partial}", flush=True)
            return
        progress["epochs_done"] = epoch + 1
        if initial_loss is None:
            initial_loss = avg_loss
        improvement = ((initial_loss - avg_loss) / initial_loss * 100
                       if initial_loss > 0 else 0)
        print(f" epoch {epoch+1}: avg_loss = {avg_loss:.4f} "
              f"(Δ {improvement:+.1f}% from start)", flush=True)
        # Per-epoch save so partial training is never lost
        per_epoch = f"{output_stem}_epoch{epoch+1}.pt"
        _save(per_epoch, epoch_idx=epoch + 1, last_loss=avg_loss)
        print(f" 💾 epoch checkpoint → {per_epoch}", flush=True)

    # [6] Final save (same payload shape as every other checkpoint)
    print(f"\n[6] Saving to {args.output}...", flush=True)
    _save(args.output, epoch_idx=args.epochs, last_loss=avg_loss)
    print(" ✅ saved", flush=True)

    # [7] Print a copy-pasteable smoke-test command. A single trailing
    # backslash is emitted so the shell treats it as a line continuation.
    print("\n[7] Test the new checkpoint:", flush=True)
    print(f" TAU_RAG_TAU_CKPT={args.output} \\", flush=True)
    print(" TAU_RAG_TAU_DEBUG=1 python3 -B -m \\", flush=True)
    print(" tau_rag.scripts.test_tau_polish \\", flush=True)
    print(" --parquet ... --n 50", flush=True)
# Script entry point: run the fine-tuning CLI only when this file is executed
# directly (or via `python -m`), not when it is imported.
if __name__ == "__main__":
    main()