""" Setup: $env:TOGETHER_API_KEY="your-key" (from api.together.ai) pip install openai datasets coverage radon pytest Usage: python build_dataset.py --out training_data.jsonl python build_dataset.py --mbpp-limit 9999 --he-limit 9999 --out training_data.jsonl python build_dataset.py --skip 104 --out training_data.jsonl # resume """ import json, os, sys, ast, time, re, argparse from pathlib import Path from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout from datasets import load_dataset from openai import OpenAI from evaluate import evaluate_one from analyze import analyze_one, get_client, MODEL, API_TIMEOUT MIN_IMPROVEMENT = 10.0 EVAL_TIMEOUT = 60 MAX_STARTING_SCORE = 85 def run_with_timeout(fn, *args, timeout=60, label="call", **kwargs): with ThreadPoolExecutor(max_workers=1) as ex: future = ex.submit(fn, *args, **kwargs) try: return future.result(timeout=timeout) except FuturesTimeout: raise TimeoutError( f"{label} exceeded {timeout}s — thread killed, moving to next function" ) def parse_wait_seconds(msg): msg = str(msg) m = re.search(r'(\d+)m(\d+\.?\d*)s', msg) if m: return int(m.group(1)) * 60 + float(m.group(2)) + 5 m = re.search(r'(\d+\.?\d*)s', msg) if m: return float(m.group(1)) + 5 m = re.search(r'(\d+)m', msg) if m: return int(m.group(1)) * 60 + 5 return 60 def call_api(fn, *args, label="API call", **kwargs): last_error = None for attempt in range(3): try: return run_with_timeout(fn, *args, timeout=API_TIMEOUT, label=label, **kwargs) except TimeoutError: raise # timeout = always skip, don't retry except Exception as e: last_error = e msg = str(e) if "429" in msg or "quota" in msg.lower() or "rate" in msg.lower(): wait = parse_wait_seconds(msg) print(f" [{label}] rate limit — waiting {wait:.0f}s...", end=" ", flush=True) time.sleep(wait) print("retrying") elif attempt < 2: print(f" [{label}] error: {msg[:80]}, retrying in 3s...") time.sleep(3) else: raise RuntimeError( f"[{label}] failed after 3 attempts. 
Last error: {msg[:120]}" ) raise last_error def make_weak_test(fn_name, assert_line): return ( f"import unittest\n" f"from target import {fn_name}\n\n" f"class Test{fn_name.title().replace('_', '')}(unittest.TestCase):\n" f" def test_basic(self):\n" f" {assert_line.strip()}\n" ) def load_mbpp(limit): print("Loading MBPP...") ds = load_dataset("google-research-datasets/mbpp", "full", split="train") ds = ds.shuffle(seed=42).select(range(min(limit, len(ds)))) entries = [] for row in ds: fn_code = row.get("code", "").strip() test_list = row.get("test_list", []) task_text = row.get("text", "").strip() if not fn_code or not test_list: continue try: fn_name = next( n.name for n in ast.walk(ast.parse(fn_code)) if isinstance(n, ast.FunctionDef) ) except Exception: continue entries.append({ "source": "mbpp", "fn_name": fn_name, "fn_code": fn_code, "test_code": make_weak_test(fn_name, test_list[0]), "task_text": task_text, }) print(f" MBPP: {len(entries)} entries") return entries def load_humaneval(limit): print("Loading HumanEval...") try: ds = load_dataset("openai/openai_humaneval", split="test") except Exception: try: ds = load_dataset("evalplus/humanevalplus", split="test") except Exception as e: print(f" HumanEval load failed: {e}") return [] ds = ds.shuffle(seed=42).select(range(min(limit, len(ds)))) entries = [] for row in ds: fn_code = (row.get("prompt", "") + row.get("canonical_solution", "")).strip() test_code = row.get("test", "").strip() fn_name = row.get("entry_point", "").strip() task_text = row.get("prompt", "").strip() if not fn_code or not fn_name: continue try: ast.parse(fn_code) except SyntaxError: continue first_assert = next( (l.strip() for l in test_code.splitlines() if l.strip().startswith("assert ")), None ) if not first_assert: continue entries.append({ "source": "humaneval", "fn_name": fn_name, "fn_code": fn_code, "test_code": make_weak_test(fn_name, first_assert), "task_text": task_text, }) print(f" HumanEval: {len(entries)} entries") return entries 
def _strip_code_fences(raw):
    """Return the code inside the first ``` fence of *raw*, or *raw* itself if unfenced.

    A language tag on the opening fence line ("python", "py", "python3", ...)
    is dropped so it does not survive as a stray statement in the test file.
    """
    if "```" not in raw:
        return raw.strip()
    code = raw.split("```")[1]
    tag, newline, rest = code.partition("\n")
    if newline and re.fullmatch(r"[A-Za-z0-9_+.-]*", tag.strip()):
        code = rest
    return code.strip()


def generate_tests(client, fn_code, test_code, suggestions, problems, missing):
    """Ask the LLM (model MODEL) to rewrite *test_code* into a stronger test suite.

    Args:
        client: chat-completions client (used as client.chat.completions.create).
        fn_code: source of the function under test.
        test_code: the current (weak) test module.
        suggestions, problems, missing: analyzer findings, rendered as bullet lists.

    Returns:
        The improved test file as a non-empty string that parses as Python, or
        None if the call timed out, errored, returned nothing, or returned
        invalid Python. Failures are printed, not raised.
    """
    suggest_text = "\n".join(f"- {s}" for s in suggestions)
    problems_text = "\n".join(f"- {p}" for p in problems)
    missing_text = "\n".join(f"- {m}" for m in missing)

    prompt = f"""Improve this Python test suite by adding more test methods.
Output ONLY the complete improved Python test file. No explanations. No markdown.

What to add:
{suggest_text}

Problems to fix:
{problems_text}

Missing cases:
{missing_text}

Function being tested:
{fn_code}

Current tests (improve these):
{test_code}"""

    messages = [
        {"role": "system",
         "content": "You are a Python test engineer. "
                    "Output ONLY valid Python code. "
                    "No markdown fences, no explanations."},
        {"role": "user", "content": prompt},
    ]

    def _call():
        return client.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=2048,
            temperature=0.2,
        )

    try:
        response = call_api(_call, label="generate")
        # content can be None; an empty reply would otherwise "parse" as valid Python.
        content = response.choices[0].message.content or ""
        code = _strip_code_fences(content.strip())
        if not code:
            print("  generate error: empty response")
            return None
        ast.parse(code)  # validate Python syntax
        return code
    except TimeoutError as e:
        print(f"  generate TIMEOUT: {e}")
        return None
    except SyntaxError as e:
        print(f"  generate syntax error: {e}")
        return None
    except Exception as e:
        print(f"  generate error: {e}")
        return None


def build_prompt(fn_code, test_code, suggestions):
    """Build the instruction-style training prompt stored with each saved record.

    Only the first three suggestions are included, as "# - ..." comment lines.
    """
    suggest_text = "\n".join(f"# - {s}" for s in suggestions[:3])
    return (
        f"### Instruction:\n"
        f"Improve the test suite below. Add more test methods to cover missing cases.\n"
        f"Only output Python code, no explanation.\n\n"
        f"### Suggestions:\n{suggest_text}\n\n"
        f"### Function:\n{fn_code}\n\n"
        f"### Current tests:\n{test_code}\n\n"
        f"### Improved tests:\n"
    )


def save_record(out_path, record):
    """Append *record* to *out_path* as one UTF-8 JSON line."""
    with open(out_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def _line_pct(cov):
    """Return the line-coverage value for display, or "ERR" if *cov* reports an error."""
    return "ERR" if "error" in cov else cov.get("line_coverage_pct", "?")


def _as_list(value):
    """Normalize an analyzer field to a list (None -> [], a lone string -> [string])."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _process_entry(client, entry, args):
    """Run one entry through evaluate -> analyze -> generate -> re-evaluate -> save.

    Prints progress for each stage. Returns one of "saved", "too_good",
    "no_improve", "timeout" or "failed".
    """
    fn_name = entry["fn_name"]
    fn_code = entry["fn_code"]
    test_code = entry["test_code"]
    source = entry["source"]

    # 1. Score the weak starting tests.
    try:
        orig_cov, orig_mut, orig_sta, orig_score = run_with_timeout(
            evaluate_one, fn_code, test_code,
            timeout=args.eval_timeout, label="evaluate",
        )
    except TimeoutError as e:
        print(f"  STOP reason: TIMEOUT — {e}\n")
        return "timeout"
    except Exception as e:
        print(f"  STOP reason: EVAL FAIL — {e}\n")
        return "failed"

    # .get(): the mutation result may lack 'mutation_score' (the re-eval line
    # treats it as optional too); a hard subscript here would abort the whole run.
    print(f"  original : score={orig_score}/100  "
          f"line={_line_pct(orig_cov)}%  mut={orig_mut.get('mutation_score', '?')}%")

    if orig_score >= args.max_start:
        print(f"  STOP reason: TOO GOOD — score {orig_score} >= {args.max_start}\n")
        return "too_good"

    # 2. Ask the analyzer what the weak tests are missing.
    try:
        analysis = call_api(
            analyze_one, client, fn_name, fn_code, test_code,
            orig_cov, orig_mut, orig_sta, label="analyze",
        )
    except TimeoutError as e:
        print(f"  STOP reason: TIMEOUT — {e}\n")
        return "timeout"
    except Exception as e:
        print(f"  STOP reason: ANALYZE FAIL — {e}\n")
        return "failed"

    if not isinstance(analysis, dict):
        print(f"  STOP reason: ANALYZE FAIL — unexpected result {analysis!r:.80}\n")
        return "failed"
    if "error" in analysis:
        print(f"  STOP reason: ANALYZE FAIL — {analysis['error']}\n")
        return "failed"

    suggestions = _as_list(analysis.get("suggestions"))
    problems = _as_list(analysis.get("problems"))
    missing = _as_list(analysis.get("missing_cases"))
    print(f"  analyze  : score={analysis.get('score', '?')}/100  "
          f"suggestions={len(suggestions)}")

    # 3. Generate a stronger suite.
    new_tests = generate_tests(client, fn_code, test_code, suggestions, problems, missing)
    if new_tests is None:
        print("  STOP reason: GEN FAIL — LLM returned invalid Python or timed out\n")
        return "failed"

    # 4. Score the generated suite.
    try:
        new_cov, new_mut, new_sta, new_score = run_with_timeout(
            evaluate_one, fn_code, new_tests,
            timeout=args.eval_timeout, label="re-evaluate",
        )
    except TimeoutError as e:
        print(f"  STOP reason: TIMEOUT — {e}\n")
        return "timeout"
    except Exception as e:
        print(f"  STOP reason: EVAL FAIL (re-eval) — {e}\n")
        return "failed"

    improvement = round(new_score - orig_score, 1)
    arrow = "▲" if improvement > 0 else ("▼" if improvement < 0 else "─")
    print(f"  improved : score={new_score}/100  "
          f"line={_line_pct(new_cov)}%  mut={new_mut.get('mutation_score', '?')}%")
    print(f"  delta    : {arrow} {abs(improvement)} pts")

    if improvement < args.min_improvement:
        print(f"  STOP reason: NO IMPROVE — "
              f"{improvement} pts < {args.min_improvement} pts required\n")
        return "no_improve"

    # 5. Keep the pair.
    save_record(args.out, {
        "source": source,
        "fn_name": fn_name,
        "task_text": entry.get("task_text", ""),
        "fn_code": fn_code,
        "old_tests": test_code,
        "good_tests": new_tests,
        "prompt": build_prompt(fn_code, test_code, suggestions),
        "suggestions": suggestions,
        "old_score": orig_score,
        "new_score": new_score,
        "improvement": improvement,
    })
    print("  → SAVED\n")
    time.sleep(0.5)
    return "saved"


def main():
    """Parse CLI arguments, load MBPP + HumanEval entries and process each one.

    Entries are numbered by absolute position, so the number printed for an
    entry can be passed straight back to --skip. Ctrl+C stops the run, prints
    the exact --skip value for resuming, and still prints the summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mbpp-limit", default=9999, type=int)
    parser.add_argument("--he-limit", default=9999, type=int)
    parser.add_argument("--out", default="training_data.jsonl")
    parser.add_argument("--skip", default=0, type=int,
                        help="Skip first N entries to resume after a stop")
    parser.add_argument("--min-improvement", default=MIN_IMPROVEMENT, type=float)
    parser.add_argument("--max-start", default=MAX_STARTING_SCORE, type=float)
    parser.add_argument("--eval-timeout", default=EVAL_TIMEOUT, type=int,
                        help="Max seconds for coverage+mutation evaluation (default 60)")
    args = parser.parse_args()
    if args.skip < 0:
        parser.error("--skip must be >= 0")  # a negative slice would select the tail

    client = get_client()
    print(f"Model        : {MODEL}")
    print(f"API timeout  : {API_TIMEOUT}s per call")
    print(f"Eval timeout : {args.eval_timeout}s per evaluation\n")

    entries = load_mbpp(args.mbpp_limit) + load_humaneval(args.he_limit)
    total = len(entries)
    print(f"\nTotal: {total} entries combined")
    if args.skip > 0:
        entries = entries[args.skip:]
        print(f"Resuming from entry {args.skip + 1}\n")

    counts = dict.fromkeys(("saved", "too_good", "no_improve", "timeout", "failed"), 0)
    processed = 0

    print(f"Processing   : {len(entries)} functions → {args.out}")
    print(f"Skip if score >= {args.max_start} (nothing to improve)")
    print(f"Save if improvement >= {args.min_improvement} pts")
    print("=" * 65)

    try:
        for number, entry in enumerate(entries, start=args.skip + 1):
            print(f"[{number}/{total}] {entry['fn_name']}  [{entry['source']}]")
            counts[_process_entry(client, entry, args)] += 1
            processed += 1
    except KeyboardInterrupt:
        print(f"\nInterrupted — resume with --skip {args.skip + processed}\n")

    print("=" * 65)
    print(f"  Processed  : {processed}")
    print(f"  Saved      : {counts['saved']}")
    print(f"  Too good   : {counts['too_good']}  (score >= {args.max_start})")
    print(f"  No improve : {counts['no_improve']}  (improvement < {args.min_improvement} pts)")
    print(f"  Timeout    : {counts['timeout']}  (eval or API hung > timeout)")
    print(f"  Failed     : {counts['failed']}  (eval crash or LLM error)")
    print(f"  Output     : {args.out}")
    print("=" * 65)


if __name__ == "__main__":
    main()