import ast import os, sys, json, ast, csv, torch, tempfile, subprocess, copy from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel from transformers import BitsAndBytesConfig # config BASE_MODEL = "stabilityai/stable-code-3b" FINETUNED_DIR = r"finetuned_model" INPUT_CSV = "functions.csv" OUTPUT_FILE = "pipeline2_results.json" # score weights (same as pipeline 1) WEIGHTS = { "line_coverage": 0.35, "branch_coverage": 0.35, "mutation_score": 0.20, "assertion_density": 0.10, } # ── evaluation helpers (same logic as evaluate.py) def write_files(fn_code, test_code, tmpdir): Path(tmpdir, "target.py").write_text(fn_code, encoding="utf-8") Path(tmpdir, "test_target.py").write_text(test_code, encoding="utf-8") def run_mutant(mutant_src, test_code, timeout=5): script = f""" import sys, unittest, io, types fn = {repr(mutant_src)} tc = {repr(test_code)} mod = types.ModuleType("target") exec(compile(fn, "", "exec"), mod.__dict__) sys.modules["target"] = mod tm = types.ModuleType("_tm") exec(compile(tc, "", "exec"), tm.__dict__) sys.modules["_tm"] = tm suite = unittest.TestLoader().loadTestsFromModule(tm) buf = io.StringIO() r = unittest.TextTestRunner(stream=buf, verbosity=0).run(suite) print("KILLED" if (r.failures or r.errors) else "SURVIVED") """ try: r = subprocess.run([sys.executable, "-c", script], capture_output=True, text=True, timeout=timeout) if "KILLED" in r.stdout: return "killed" if "SURVIVED" in r.stdout: return "survived" return "killed" except subprocess.TimeoutExpired: return "timeout" AOR = {ast.Add: ast.Sub, ast.Sub: ast.Add, ast.Mult: ast.Div, ast.Div: ast.Mult, ast.Mod: ast.Add} ROR = {ast.Eq: ast.NotEq, ast.NotEq: ast.Eq, ast.Lt: ast.Gt, ast.Gt: ast.Lt, ast.LtE: ast.GtE, ast.GtE: ast.LtE} LCR = {ast.And: ast.Or, ast.Or: ast.And} def make_mutants(fn_code): tree = ast.parse(fn_code) mutants = [] def mutate(pred, do_replace): for node in ast.walk(copy.deepcopy(tree)): if not pred(node): continue m = copy.deepcopy(tree) for n in ast.walk(m): if pred(n) and getattr(n, "lineno", -1) == getattr(node, "lineno", -2): do_replace(n); break ast.fix_missing_locations(m) try: mutants.append(ast.unparse(m)) except: pass mutate(lambda n: isinstance(n, ast.BinOp) and type(n.op) in AOR, lambda n: setattr(n, "op", AOR[type(n.op)]())) mutate(lambda n: isinstance(n, ast.Compare) and n.ops and type(n.ops[0]) in ROR, lambda n: n.ops.__setitem__(0, ROR[type(n.ops[0])]())) mutate(lambda n: isinstance(n, ast.BoolOp) and type(n.op) in LCR, lambda n: setattr(n, "op", LCR[type(n.op)]())) mutate(lambda n: isinstance(n, ast.Return) and n.value is not None, lambda n: setattr(n, "value", ast.Constant(value=None))) mutate(lambda n: isinstance(n, ast.Constant) and n.value in (0, 1), lambda n: setattr(n, "value", 1 - n.value)) return mutants ASSERT_METHODS = {"assertEqual","assertNotEqual","assertTrue","assertFalse", "assertIn","assertAlmostEqual","assertRaises","assertGreater", "assertLess","assertIsNone","assertIsNotNone","assertNotIn"} def count_assertions(fn_node): count = 0 for n in ast.walk(fn_node): if isinstance(n, ast.Assert): count += 1 elif (isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) and n.func.attr in ASSERT_METHODS): count += 1 return count def eval_coverage(fn_code, test_code, tmpdir): write_files(fn_code, test_code, tmpdir) r1 = subprocess.run([sys.executable, "-m", "coverage", "run", "--branch", "--include=target.py", "-m", "pytest", "test_target.py", "-q", "--tb=no"], capture_output=True, text=True, cwd=tmpdir) subprocess.run([sys.executable, "-m", "coverage", "json", "-o", "cov.json"], capture_output=True, text=True, cwd=tmpdir) cov_json = Path(tmpdir, "cov.json") if not cov_json.exists(): return {} raw = json.loads(cov_json.read_text()) fd = next((v for k, v in raw["files"].items() if k.endswith("target.py")), None) if not fd: return {} s = fd["summary"] missing_br = fd.get("missing_branches", []) total_br = s.get("num_branches", 0) branch_pct = round(100 * (total_br - len(missing_br)) / total_br, 1) if total_br else 100.0 return { "line_coverage_pct": round(s.get("percent_covered", 0), 1), "branch_coverage_pct": branch_pct, "missing_lines": fd.get("missing_lines", []), "missing_branches": missing_br, } def eval_mutation(fn_code, test_code): mutants = make_mutants(fn_code) if not mutants: return {"total_mutants": 0, "killed": 0, "mutation_score": 100.0} killed = sum(1 for m in mutants if run_mutant(m, test_code) in ("killed", "timeout")) return { "total_mutants": len(mutants), "killed": killed, "mutation_score": round(100 * killed / len(mutants), 1), } def eval_static(fn_code, test_code): test_tree = ast.parse(test_code) test_fns = [n for n in ast.walk(test_tree) if isinstance(n, ast.FunctionDef) and n.name.startswith("test")] assertions = [count_assertions(fn) for fn in test_fns] total = sum(assertions) density = round(total / len(test_fns), 2) if test_fns else 0.0 return { "total_test_functions": len(test_fns), "total_assertions": total, "assertion_density": density, } def evaluate(fn_code, test_code): with tempfile.TemporaryDirectory() as tmp: cov = eval_coverage(fn_code, test_code, tmp) mut = eval_mutation(fn_code, test_code) sta = eval_static(fn_code, test_code) return {"coverage": cov, "mutation": mut, "static": sta} def compute_score(eval_result): cov = eval_result.get("coverage", {}) mut = eval_result.get("mutation", {}) sta = eval_result.get("static", {}) line = cov.get("line_coverage_pct", 0) / 100 branch = cov.get("branch_coverage_pct", 0) / 100 mut_s = mut.get("mutation_score", 0) / 100 dens = min(sta.get("assertion_density", 0) / 2, 1.0) # cap at 1 score = (WEIGHTS["line_coverage"] * line + WEIGHTS["branch_coverage"] * branch + WEIGHTS["mutation_score"] * mut_s + WEIGHTS["assertion_density"] * dens) return round(score, 4) # ── model ───────────────────────────────────────────────────────────────────── def load_model(use_finetuned=True): print("loading model...") if use_finetuned: tokenizer = AutoTokenizer.from_pretrained(FINETUNED_DIR) base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, device_map="auto", dtype=torch.float16) model = PeftModel.from_pretrained(base_model, FINETUNED_DIR) print("using fine-tuned model") else: tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, device_map="auto", dtype=torch.float16) print("using base model (before fine-tuning)") model.eval() # warn if running on CPU — will be very slow device = next(model.parameters()).device if str(device) == "cpu": print(" WARNING: model is on CPU, generation will be very slow") print(" consider using a smaller model or a machine with GPU") else: print(f" running on {device}") print("model loaded\n") return model, tokenizer def generate_tests(model, tokenizer, fn_code, old_test_code): prompt = f"""### Instruction: Given the following Python function and its current test cases, generate improved test cases. The new tests should have better coverage, catch more bugs, and include edge cases. ### Function: {fn_code} ### Current tests: {old_test_code} ### Response: """ device = next(model.parameters()).device inputs = tokenizer(prompt, return_tensors="pt").to(device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=300, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1, ) generated = tokenizer.decode(outputs[0], skip_special_tokens=True) code = generated.split("### Response:")[-1].strip() # model sometimes cuts off mid-line — trim to the last complete line # that ends with a valid Python ending character lines = code.splitlines() clean = [] last_valid = [] for line in lines: clean.append(line) try: ast.parse("\n".join(clean)) last_valid = list(clean) except SyntaxError: pass return "\n".join(last_valid) if last_valid else code # ── main ────────────────────────────────────────────────────────────────────── def load_csv(path): rows = [] with open(path, newline="", encoding="utf-8") as f: for row in csv.DictReader(f): rows.append({"function_code": row["function_code"].strip(), "test_code": row["test_code"].strip()}) return rows def main(): rows = load_csv(INPUT_CSV) use_finetuned = True # set False to test base model model, tokenizer = load_model(use_finetuned=use_finetuned) # set True to use fine-tuned results = [] for i, row in enumerate(rows): fn_code = row["function_code"] old_tests = row["test_code"] try: fn_name = next(n.name for n in ast.walk(ast.parse(fn_code)) if isinstance(n, ast.FunctionDef)) except: fn_name = f"function_{i+1}" print(f"[{i+1}/{len(rows)}] {fn_name}") # step 1 - evaluate old tests print(" evaluating old tests...") old_eval = evaluate(fn_code, old_tests) old_score = compute_score(old_eval) print(f" old score: {old_score}") # step 2 - generate new tests (retry up to 3 times if finetuned) max_attempts = 3 if use_finetuned else 1 new_tests = new_eval = None new_score = 0.0 for attempt in range(1, max_attempts + 1): if max_attempts > 1: print(f" generating new tests (attempt {attempt}/{max_attempts})...") else: print(" generating new tests...") candidate = generate_tests(model, tokenizer, fn_code, old_tests) print(" ── generated tests ──────────────────────────────") print(candidate) print(" ─────────────────────────────────────────────────") try: candidate_eval = evaluate(fn_code, candidate) candidate_score = compute_score(candidate_eval) except Exception as e: candidate_eval = {"error": str(e)} candidate_score = 0.0 print(f" score: {candidate_score}") if candidate_score > old_score: new_tests = candidate new_eval = candidate_eval new_score = candidate_score print(f" improvement found: {round((new_score - old_score) * 100, 2)}%\n") break else: print(f" no improvement (attempt {attempt}/{max_attempts})") if attempt == max_attempts: print(" failed to generate better tests\n") new_tests = candidate new_eval = candidate_eval new_score = candidate_score results.append({ "function_name": fn_name, "function_code": fn_code, "old_test_code": old_tests, "new_test_code": new_tests, "old_score": old_score, "new_score": new_score, "improvement": round(new_score - old_score, 4), "old_evaluation": old_eval, "new_evaluation": new_eval, }) Path(OUTPUT_FILE).write_text(json.dumps(results, indent=2)) print(f"saved -> {OUTPUT_FILE}") # print summary print("\n" + "="*60) print(f" {'function':<25} {'old':>6} {'new':>6} {'diff':>7}") print(" " + "-"*58) for r in results: diff = f"+{r['improvement']*100:.1f}%" if r['improvement'] >= 0 else f"{r['improvement']*100:.1f}%" print(f" {r['function_name']:<25} {r['old_score']:>6.3f} {r['new_score']:>6.3f} {diff:>7}") avg = sum(r['improvement'] for r in results) / len(results) if results else 0 print(f"\n average improvement: {avg*100:.2f}%") print("="*60) if __name__ == "__main__": main()