#!/usr/bin/env python3
"""Fine-tune the TAU Native LM on collected (query, context, answer) traces.

Reads from runtime/training_data/traces.jsonl (populated whenever the user
runs with TAU_RAG_COLLECT_TRAINING=1), and continues training the
checkpoint at TAU_RAG_TAU_CKPT.

Training objective: supervised next-token prediction with LOSS MASKING.
Each trace becomes a prompt of the form

    שאלה: <query>            ("question")
    מקור: <top-1 context>    ("source"; omitted when there is no context)
    תשובה:                   ("answer" marker)

followed by the answer text. Only the answer tokens contribute to the loss;
prompt and padding positions are masked out.

Only rows where feedback != "down" are used (filter out known-bad).
Rows with feedback == "up" are oversampled 3x (emphasize known-good).

Usage:
    python -m tau_rag.scripts.finetune_from_traces \\
        --epochs 3 \\
        --lr 5e-5 \\
        --out training/tau_hebrew_legal_llm_v2.pt

``--out`` is an alias of ``--ckpt-out``. The result is written to that path;
point it somewhere other than the base checkpoint to keep the original for
A/B comparison.
"""
from __future__ import annotations

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

# Strips one or more leading bracketed tags (e.g. "[1] [src] ") from the
# retrieved context text.
_LEADING_TAGS_RE = re.compile(r"^(\[[^\]]+\]\s*)+")
_WHITESPACE_RE = re.compile(r"\s+")

# Label value for positions excluded from the loss (prompt and padding).
# -100 is also F.cross_entropy's default ignore_index.
_IGNORE_INDEX = -100


def _format_row(row: Dict[str, Any],
                max_ctx_chars: int = 600) -> Tuple[str, str]:
    """Build the training (prompt, answer) pair for one trace row.

    Returns ``(prompt_text, answer_text)`` so the trainer can apply loss
    masking: only the answer portion contributes to the loss. Without masking
    the model wastes capacity learning to autoregress over question/context
    tokens, drowning out the Q+ctx -> A signal.

    The answer is ``teacher_answer`` when present, otherwise the extractive
    ``answer``. Only the top-1 retrieved document is used as grounding
    context: leading ``[...]`` tags are stripped, whitespace is collapsed and
    the text is cut to ``max_ctx_chars`` characters. (Multi-doc context
    inflates the prompt past the context length and leaves no room for the
    answer.) Context entries may be dicts with a ``"text"`` key or plain
    strings.
    """
    question = (row.get("query") or "").strip()
    answer = (row.get("teacher_answer") or row.get("answer") or "").strip()

    ctx_text = ""
    ctx_list = row.get("context") or []
    if ctx_list:
        first = ctx_list[0]
        raw = first if isinstance(first, str) else (first.get("text") or "")
        raw = _LEADING_TAGS_RE.sub("", raw.strip())
        raw = _WHITESPACE_RE.sub(" ", raw)
        ctx_text = raw[:max_ctx_chars]

    if ctx_text:
        prompt = f"שאלה: {question}\nמקור: {ctx_text}\nתשובה:"
    else:
        prompt = f"שאלה: {question}\nתשובה:"
    return prompt, answer


def _question_part(prompt: str) -> str:
    """Return the question segment of a prompt built by ``_format_row``.

    That is everything before the "מקור:" (source) line; for prompts without
    context it is the whole prompt.
    """
    return prompt.partition("\nמקור:")[0]


def iter_traces(path: Path, skip_negative: bool = True,
                oversample_positive: int = 3) -> Iterable[tuple]:
    """Stream (prompt, answer) pairs from a traces.jsonl file.

    Blank lines, lines that are not valid JSON and JSON values that are not
    objects are skipped. Rows with ``feedback == "down"`` are dropped when
    ``skip_negative`` is true. Rows with ``feedback == "up"`` are yielded
    ``oversample_positive`` times in a row (at least once).
    """
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(row, dict):
                continue  # e.g. a stray list/number line — nothing to format
            feedback = row.get("feedback")
            if skip_negative and feedback == "down":
                continue
            pair = _format_row(row)
            repeats = oversample_positive if feedback == "up" else 1
            for _ in range(max(1, repeats)):
                yield pair


def _truncate_prompt_ids(prompt_ids: List[int], keep: int,
                         head_len: int = 0) -> List[int]:
    """Shrink ``prompt_ids`` to at most ``keep`` tokens.

    Keeps up to ``min(head_len, keep // 2)`` tokens from the start (the
    question) and fills the remaining room from the end (context tail plus
    the "תשובה:" marker that must sit right before the answer). The middle,
    i.e. the bulk of the retrieved context, is what gets dropped. With the
    default ``head_len=0`` this is a plain keep-the-tail truncation.
    """
    if keep <= 0:
        return []
    if len(prompt_ids) <= keep:
        return list(prompt_ids)
    head = max(0, min(head_len, keep // 2))
    tail = keep - head  # >= 1 because head <= keep // 2
    # Explicit start index: prompt_ids[-0:] would return the whole list.
    return list(prompt_ids[:head]) + list(prompt_ids[len(prompt_ids) - tail:])


def _build_example(prompt_ids: List[int], answer_ids: List[int], ctx_len: int,
                   head_len: int = 0,
                   pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Build one shifted, padded, loss-masked training example.

    ``prompt + answer`` is fitted into ``ctx_len + 1`` tokens (one extra
    because inputs and targets are shifted by one). The answer keeps at most
    ``max(budget - 20, budget * 3 // 4)`` tokens (truncated on the right);
    the prompt gets the remaining room via ``_truncate_prompt_ids`` so the
    question head and the answer marker survive.

    Returns ``(input_ids, labels)``, both of length ``ctx_len``. ``labels``
    holds the next token at answer positions and ``_IGNORE_INDEX`` at prompt
    and padding positions.
    """
    budget = ctx_len + 1
    ans_keep = min(len(answer_ids), max(budget - 20, budget * 3 // 4))
    prompt_ids = _truncate_prompt_ids(prompt_ids, budget - ans_keep, head_len)
    answer_ids = list(answer_ids[:ans_keep])

    full = (prompt_ids + answer_ids)[:budget]
    full += [pad_id] * (budget - len(full))

    labels = [_IGNORE_INDEX] * budget
    ans_start = len(prompt_ids)
    for j in range(ans_start, min(budget, ans_start + len(answer_ids))):
        labels[j] = full[j]

    # Next-token shift: input position i is trained to predict labels[i + 1].
    return full[:-1], labels[1:]


def main() -> int:
    """Command-line entry point; returns the process exit code (0 = OK)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--traces", default=None,
                        help="Path to traces.jsonl (default: runtime/training_data/traces.jsonl)")
    parser.add_argument("--ckpt-in", default=os.environ.get("TAU_RAG_TAU_CKPT"),
                        help="Base checkpoint to continue from")
    # "--out" is the spelling used in the module docstring's usage example.
    parser.add_argument("--ckpt-out", "--out", dest="ckpt_out", default=None,
                        help="Where to save the fine-tuned checkpoint "
                             "(required unless --dry-run)")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=None,
                        help="Override context length "
                             "(default: max(checkpoint's value, 512))")
    parser.add_argument("--device", default=None,
                        help="cpu/mps/cuda (default: auto)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Parse traces and report dataset stats, don't train")
    args = parser.parse_args()
    if not args.dry_run and not args.ckpt_out:
        parser.error("--ckpt-out is required unless --dry-run is given")
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")

    # Resolve paths
    here = Path(__file__).resolve().parent.parent
    traces_path = (Path(args.traces) if args.traces
                   else here / "runtime" / "training_data" / "traces.jsonl")
    if not traces_path.exists():
        print(f"❌ No traces found at {traces_path}")
        print(" Run queries with TAU_RAG_COLLECT_TRAINING=1 to generate some.")
        return 1

    # Build the dataset (materialized so the DataLoader can shuffle it).
    print(f"📖 Loading traces from {traces_path}...")
    pairs = list(iter_traces(traces_path))  # list of (prompt, answer)
    print(f"✅ {len(pairs)} training examples (including oversampled 👍)")
    if not pairs:
        print("❌ Dataset empty — nothing to train on.")
        return 1

    if args.dry_run:
        print("\n--- SAMPLE PROMPT ---")
        print(pairs[0][0][:500])
        print("\n--- SAMPLE ANSWER ---")
        print(pairs[0][1][:500])
        print("\n--- STATS ---")
        lens = [len(prompt) + len(answer) for prompt, answer in pairs]
        print(f" chars: min={min(lens)}, avg={sum(lens)//len(lens)}, max={max(lens)}")
        return 0

    # The heavy part — imports are deferred so --dry-run doesn't need torch.
    if not args.ckpt_in or not Path(args.ckpt_in).exists():
        print(f"❌ Base checkpoint missing: {args.ckpt_in}")
        return 1
    try:
        import torch
        import torch.nn.functional as F
        from torch.utils.data import DataLoader, Dataset
    except ImportError as e:
        print(f"❌ torch required: {e}")
        return 1

    # Import our model + vendored encoder.
    sys.path.insert(0, str(here.parent))  # make tau_rag.* importable
    from tau_rag.generate.tau_native_model import (
        TAULanguageModel, TrainingConfig,
    )
    from tau_rag import encoding as _enc_pkg
    # Alias the top-level names "encoding" / "encoding.hebrew_encoder" to
    # the vendored package — presumably so the pickled encoder stored in the
    # checkpoint can be resolved by torch.load below.
    sys.modules.setdefault("encoding", _enc_pkg)
    from tau_rag.encoding import hebrew_encoder as _heb
    sys.modules.setdefault("encoding.hebrew_encoder", _heb)

    print(f"🔋 Loading base checkpoint: {args.ckpt_in}")
    # weights_only=False unpickles arbitrary objects (needed for the
    # encoder); only point --ckpt-in at checkpoints you trust.
    ckpt = torch.load(args.ckpt_in, map_location="cpu", weights_only=False)
    cfg_dict = ckpt.get("config") or {}

    # Context length is at least 512 so a full (question + context + answer)
    # example fits in one window. The position encoding is sinusoidal
    # (parameter-free), so enlarging it only means rebuilding the buffer.
    ctx_len = args.context_len or max(cfg_dict.get("context_length", 128), 512)
    config = TrainingConfig(
        embedding_dim=cfg_dict.get("embedding_dim", 256),
        hidden_dim=cfg_dict.get("hidden_dim", 512),
        num_layers=cfg_dict.get("num_layers", 4),
        num_heads=cfg_dict.get("num_heads", 8),
        context_length=ctx_len,
        dropout=cfg_dict.get("dropout", 0.1),
        learning_rate=args.lr,
        batch_size=args.batch_size,
        epochs=args.epochs,
    )
    encoder = ckpt.get("encoder")
    if encoder is None:
        print("❌ Checkpoint has no 'encoder' — cannot tokenize traces.")
        return 1
    vocab = getattr(encoder, "vocab_layer", None) or encoder
    vocab_size = cfg_dict.get("vocab_size") or vocab.size()
    model = TAULanguageModel(config=config, vocab_size=vocab_size)

    # Drop every saved sinusoidal pos_encoding buffer: the freshly built one
    # in `model` already has the new context_length and is deterministic,
    # so the saved copy is redundant at best and shape-mismatched at worst.
    saved_sd = {k: v for k, v in ckpt["model_state_dict"].items()
                if not k.endswith(".pos_encoding")}
    missing, unexpected = model.load_state_dict(saved_sd, strict=False)
    real_missing = [m for m in missing if not m.endswith(".pos_encoding")]
    if real_missing:
        print(f" ⚠️ unexpected missing keys: {real_missing}")
    if unexpected:
        print(f" ⚠️ unexpected extra keys: {unexpected}")

    device = args.device or ("mps" if torch.backends.mps.is_available()
                             else "cuda" if torch.cuda.is_available()
                             else "cpu")

    # MPS workaround: only PyTorch < 2.4 needs weight-untying. 2.4+ handles
    # weight-tied backward on MPS, and untying doubles trainable params.
    torch_version = tuple(int(x) for x in torch.__version__.split(".")[:2])
    needs_mps_untie = (device == "mps" and torch_version < (2, 4))
    if needs_mps_untie:
        import torch.nn as nn
        emb_weight = model.embedding.token_embedding.weight.data.clone()
        model.lm_head.weight = nn.Parameter(emb_weight)
        print(f" [MPS workaround] untied lm_head ↔ embedding "
              f"(torch {torch.__version__} < 2.4)")
    elif device == "mps":
        print(f" [MPS] torch {torch.__version__} — keeping weight-tying "
              f"(no workaround needed)")
    model.to(device)
    model.train()
    print(f"✅ Loaded model ({sum(p.numel() for p in model.parameters()):,} params) "
          f"on {device}")

    def _to_ids(text):
        """Encode text to a plain list of token ids."""
        result = vocab.encode(text)
        ids = result[0] if isinstance(result, tuple) else result
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        return list(ids)

    # Tokenize PROMPT and ANSWER separately so labels can mask the prompt:
    # only answer tokens contribute to the loss.
    print("🔤 Tokenizing (prompt + answer separately for loss masking)...")
    examples = []  # list of (input_ids, labels) — both length ctx_len
    skipped = 0
    for prompt, ans in pairs:
        if not ans.strip():
            skipped += 1
            continue
        answer_ids = _to_ids(ans)
        if len(answer_ids) < 5:
            skipped += 1
            continue
        prompt_ids = _to_ids(prompt)
        # Token length of the question segment, so an over-long prompt is
        # cut in the context rather than losing the question. Approximate:
        # assumes tokenizing the segment alone matches the prompt's prefix.
        head_len = len(_to_ids(_question_part(prompt)))
        examples.append(
            _build_example(prompt_ids, answer_ids, ctx_len, head_len=head_len))

    print(f"✅ {len(examples)} training examples after build "
          f"(skipped {skipped} empty/short)")
    if not examples:
        print("❌ No usable examples after masking. Aborting.")
        return 1

    class TracesDataset(Dataset):
        """In-memory dataset of (input_ids, labels) tensor pairs."""

        def __init__(self, exs):
            self.exs = exs

        def __len__(self):
            return len(self.exs)

        def __getitem__(self, i):
            x, y = self.exs[i]
            return (torch.tensor(x, dtype=torch.long),
                    torch.tensor(y, dtype=torch.long))

    ds = TracesDataset(examples)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True)

    # Sanity: how many ANSWER tokens per example on average?
    ans_tokens = sum(sum(1 for t in y if t != _IGNORE_INDEX) for _, y in examples)
    avg_ans = ans_tokens / max(1, len(examples))
    print(f"🧱 {len(ds)} examples, batch size {args.batch_size}, "
          f"avg {avg_ans:.0f} answer-tokens/example "
          f"(prompt tokens are masked from loss)")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    start = time.time()

    def _safe_loss(logits_2d, targets_1d, ignore_idx=_IGNORE_INDEX):
        """Masked mean NLL over positions whose target != ignore_idx.

        On MPS this is computed manually (log_softmax + gather) because
        F.cross_entropy(ignore_index=...) bus-errors on PyTorch 2.0.x + MPS;
        elsewhere it defers to F.cross_entropy. Padding is excluded because
        its labels are ignore_idx.
        """
        if device == "mps":
            log_probs = F.log_softmax(logits_2d, dim=-1)
            mask = (targets_1d != ignore_idx).float()
            # Clamp so masked (-100) positions don't index out of range;
            # their contribution is zeroed by the mask anyway.
            tgt_safe = targets_1d.clamp(min=0)
            picked = log_probs.gather(1, tgt_safe.unsqueeze(1)).squeeze(1)
            return -(picked * mask).sum() / mask.sum().clamp(min=1.0)
        return F.cross_entropy(logits_2d, targets_1d, ignore_index=ignore_idx)

    # Track the best epoch and save ITS weights, not the last ones: with
    # small datasets the model overfits and the final epoch is often worse.
    best_loss = float("inf")
    best_state = None
    best_epoch = 0
    patience_counter = 0
    PATIENCE = max(10, args.epochs // 4)  # stop after N epochs w/o improvement
    epochs_run = 0
    last_loss = float("nan")
    for epoch in range(1, args.epochs + 1):
        total = 0.0
        n = 0
        for bx, by in dl:
            bx, by = bx.to(device), by.to(device)
            opt.zero_grad()
            logits, _ = model(bx)
            # reshape (not view): works for non-contiguous logits and uses
            # the model's real output width.
            loss = _safe_loss(logits.reshape(-1, logits.size(-1)),
                              by.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            total += loss.item()
            n += 1
            if n % 20 == 0:
                print(f"   epoch={epoch} step={n} loss={loss.item():.4f}")
        epochs_run = epoch
        last_loss = total / max(1, n)
        marker = ""
        if last_loss < best_loss - 0.01:  # need real improvement, not noise
            best_loss = last_loss
            best_epoch = epoch
            # Snapshot weights to CPU (cheaper than device memory).
            best_state = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}
            patience_counter = 0
            marker = "  ⭐ NEW BEST"
        else:
            patience_counter += 1
        print(f"📈 Epoch {epoch}: avg loss = {last_loss:.4f}{marker}")
        if patience_counter >= PATIENCE:
            print(f"⏹ Early stopping — no improvement for {PATIENCE} epochs.")
            break

    # Restore the best snapshot before saving.
    if best_state is not None and best_epoch < epochs_run:
        print(f"🔄 Restoring best checkpoint from epoch {best_epoch} "
              f"(loss={best_loss:.4f}, vs final loss={last_loss:.4f})")
        model.load_state_dict(best_state)
    # Epoch count and loss describe the weights actually being saved.
    saved_epoch = best_epoch if best_state is not None else epochs_run
    saved_loss = best_loss if best_state is not None else last_loss

    elapsed = time.time() - start
    print(f"⏱ Training time: {elapsed/60:.1f} min")

    # Re-tie only if we untied earlier (torch < 2.4 on MPS).
    if needs_mps_untie:
        with torch.no_grad():
            avg = (model.embedding.token_embedding.weight.data
                   + model.lm_head.weight.data) / 2
            model.embedding.token_embedding.weight.data = avg
            model.lm_head.weight = model.embedding.token_embedding.weight
        print(f"   [MPS workaround] re-tied lm_head ↔ embedding (averaged)")

    # Save CPU tensors so the checkpoint loads on machines without the
    # training device even when map_location is not given.
    out_path = Path(args.ckpt_out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save({
        "model_state_dict": cpu_state,
        "encoder": encoder,
        "config": {
            "embedding_dim": config.embedding_dim,
            "hidden_dim": config.hidden_dim,
            "num_layers": config.num_layers,
            "num_heads": config.num_heads,
            "context_length": config.context_length,
            "dropout": config.dropout,
            "vocab_size": vocab_size,
        },
        "epoch": (cfg_dict.get("epoch") or 0) + saved_epoch,
        "loss": saved_loss,
        "fine_tuned_from": str(args.ckpt_in),
        "n_traces": len(pairs),
    }, out_path)
    print(f"💾 Saved fine-tuned checkpoint → {out_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())