#!/usr/bin/env python3 """Distill teacher answers from Claude into the training-traces dataset. Reads runtime/training_data/traces.jsonl, finds rows missing `teacher_answer`, sends each (query, context) to Anthropic Claude, and writes the response back to the row. The fine-tune script then prefers teacher_answer over the original extractive answer as the training target — bootstrapping a high-quality dataset in minutes instead of weeks of manual labeling. Cost: at default rate (~$3/1M input + $15/1M output, claude-sonnet-4) each row is ~2K input + 500 output ≈ $0.013. 100 rows ≈ $1.30. 500 rows ≈ $6.50. Usage: export ANTHROPIC_API_KEY=sk-ant-... ./scripts/distill.sh # process all unfilled rows ./scripts/distill.sh --max 50 # cap at 50 rows ./scripts/distill.sh --redo # overwrite existing teachers ./scripts/distill.sh --model claude-haiku-4-5 # cheaper model ./scripts/distill.sh --dry-run # show what would run, no calls """ from __future__ import annotations import argparse import json import os import sys import time from pathlib import Path from typing import Any, Dict, List, Optional SYSTEM_PROMPT = """אתה עוזר משפטי מומחה במשפט ישראלי, מתמחה בדיני חוזים ופסיקה. תפקידך: לקבל שאלה משפטית + מקורות (סעיפי חוק / פסקי דין שאוחזרו אוטומטית), \ ולכתוב תשובה משפטית מקצועית, מבוססת אך-ורק על המקורות שניתנו. כללים: 1. צטט את הסעיפים/הפסקים בצורתם המלאה — לא רק קטעים. 2. השתמש בציטוטי ייחוס בסגנון `[1]`, `[2]` המתייחסים לסדר המקורות. 3. אם המקורות לא מספיקים לתשובה — אמור זאת בכנות. 4. אסור להמציא ציטוטים, מספרי סעיפים, או הלכות שלא מופיעים במקורות. 5. כתוב בעברית משפטית רהוטה, מובנית — מתאימה למאמר משפטי. 6. הבא קשרים בין מקורות (לדוג' "סעיף 12 משלים את סעיף 39 בכך ש..."). 7. סגור בנקודות תיוג של פעולה מעשית אם רלוונטי. אורך מומלץ: 200-500 מילים.""" def build_user_prompt(query: str, context: List[Dict[str, Any]]) -> str: parts = [f"שאלה: {query.strip()}", "", "מקורות:"] for i, c in enumerate(context[:6], 1): meta = c.get("metadata") or {} title = meta.get("citation") or ( f"{meta.get('law', '')} · סעיף {meta.get('section', '')}" if meta.get("section") else c.get("doc_id", f"מקור {i}")) text = (c.get("text") or "").strip() # Strip boost markers import re as _re text = _re.sub(r"^(\[[^\]]+\]\s*)+", "", text) text = _re.sub(r"\[[^\]]+\]\s*", "", text) text = _re.sub(r"\s+", " ", text) parts.append(f"\n[{i}] {title}\n{text}") parts.append("\n\nכתוב תשובה משפטית מבוססת מקורות:") return "\n".join(parts) def call_claude(query: str, context: List[Dict[str, Any]], model: str, api_key: str, max_tokens: int = 800) -> str: """Single sync call to Claude. Uses urllib so we don't add a runtime dep on the anthropic SDK — the script can run on a fresh box.""" import urllib.request import urllib.error body = { "model": model, "max_tokens": max_tokens, "system": SYSTEM_PROMPT, "messages": [{ "role": "user", "content": build_user_prompt(query, context), }], } req = urllib.request.Request( "https://api.anthropic.com/v1/messages", data=json.dumps(body).encode("utf-8"), headers={ "Content-Type": "application/json", "x-api-key": api_key, "anthropic-version": "2023-06-01", }, ) with urllib.request.urlopen(req, timeout=60) as r: resp = json.loads(r.read().decode("utf-8")) # Anthropic returns content as a list of blocks blocks = resp.get("content") or [] return "".join(b.get("text", "") for b in blocks if b.get("type") == "text").strip() def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--traces", default=None, help="Path to traces.jsonl (default: runtime/training_data/traces.jsonl)") parser.add_argument("--model", default="claude-sonnet-4-6", help="Anthropic model id (claude-haiku-4-5 = cheaper)") parser.add_argument("--max", type=int, default=None, help="Cap number of rows to process this run") parser.add_argument("--redo", action="store_true", help="Re-run even on rows that already have teacher_answer") parser.add_argument("--dry-run", action="store_true", help="Show planned work without calling Claude") parser.add_argument("--sleep", type=float, default=0.5, help="Seconds between API calls (rate-limit friendly)") args = parser.parse_args() api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key and not args.dry_run: print("❌ ANTHROPIC_API_KEY not set. Run:") print(" export ANTHROPIC_API_KEY=sk-ant-...") print(" (use --dry-run to plan without an API key)") return 1 here = Path(__file__).resolve().parent.parent traces_path = Path(args.traces) if args.traces else ( here / "runtime" / "training_data" / "traces.jsonl") if not traces_path.exists(): print(f"❌ No traces at {traces_path}") print(" Run queries with TAU_RAG_COLLECT_TRAINING=1 first.") return 1 # Load all rows rows = [] with traces_path.open(encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: rows.append(json.loads(line)) except json.JSONDecodeError: continue print(f"📂 Loaded {len(rows)} rows from {traces_path}") # Filter: rows that need processing needs_work = [] for i, row in enumerate(rows): if not row.get("query") or not row.get("context"): continue if not args.redo and row.get("teacher_answer"): continue # Skip negative-feedback rows — we don't want to ground a teacher # call on a query that we already know was bad. Saves money. if row.get("feedback") == "down": continue needs_work.append(i) if args.max: needs_work = needs_work[:args.max] print(f"🎯 {len(needs_work)} rows need a teacher answer") if args.dry_run: print("\n--- DRY RUN — sample row plan ---") for i in needs_work[:3]: r = rows[i] ctx_n = len(r.get("context") or []) print(f" row {i}: query={r['query'][:50]!r} " f"context_chunks={ctx_n} " f"existing_teacher={'yes' if r.get('teacher_answer') else 'no'}") cost_low = len(needs_work) * 0.005 # haiku-ish cost_high = len(needs_work) * 0.020 # sonnet-ish print(f"\nEstimated cost: ${cost_low:.2f} – ${cost_high:.2f} " f"(model={args.model})") return 0 # Process n_done = 0 n_failed = 0 t0 = time.time() for idx in needs_work: row = rows[idx] try: teacher = call_claude( row["query"], row.get("context") or [], model=args.model, api_key=api_key, ) if teacher and len(teacher) > 30: row["teacher_answer"] = teacher row["teacher_model"] = args.model row["teacher_timestamp"] = time.time() n_done += 1 # Re-write the file periodically (so we don't lose progress # if the script is killed mid-run on a long batch) if n_done % 5 == 0: _save_all(traces_path, rows) preview = teacher[:80].replace("\n", " ") print(f" ✅ row {idx} ({n_done}/{len(needs_work)}): " f"\"{preview}...\"") else: n_failed += 1 print(f" ⚠️ row {idx}: empty response, skipping") except Exception as e: n_failed += 1 print(f" ❌ row {idx}: {type(e).__name__}: {e}") time.sleep(args.sleep) _save_all(traces_path, rows) elapsed = time.time() - t0 print(f"\n📊 Distillation done: {n_done} succeeded, {n_failed} failed " f"in {elapsed/60:.1f} min") print(f" Updated traces saved → {traces_path}") return 0 if n_done > 0 else 1 def _save_all(path: Path, rows: List[Dict[str, Any]]) -> None: tmp = path.with_suffix(".jsonl.tmp") with tmp.open("w", encoding="utf-8") as f: for r in rows: f.write(json.dumps(r, ensure_ascii=False) + "\n") tmp.replace(path) if __name__ == "__main__": sys.exit(main())