Spaces:
Sleeping
Sleeping
| import ast | |
| import os, sys, json, ast, csv, torch, tempfile, subprocess, copy | |
| from pathlib import Path | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| from transformers import BitsAndBytesConfig | |
# Pipeline configuration: model locations and input/output paths.

# Hub id handed to AutoModelForCausalLM / AutoTokenizer for the base checkpoint.
BASE_MODEL: str = "stabilityai/stable-code-3b"
# Directory the fine-tuned tokenizer and PEFT weights are loaded from.
FINETUNED_DIR: str = "finetuned_model"
# CSV read by load_csv(); needs "function_code" and "test_code" columns.
INPUT_CSV: str = "functions.csv"
# JSON file main() writes the per-function results to.
OUTPUT_FILE: str = "pipeline2_results.json"
# Score weights (same as pipeline 1). compute_score() multiplies each metric,
# normalised to the 0..1 range, by its weight and sums the four products.
WEIGHTS = dict(
    line_coverage=0.35,
    branch_coverage=0.35,
    mutation_score=0.20,
    assertion_density=0.10,
)
# ── evaluation helpers (same logic as evaluate.py) ──────────────────────────
def write_files(fn_code, test_code, tmpdir):
    """Write the function and its tests into *tmpdir* as UTF-8 source files.

    The function source goes to ``target.py`` and the test source to
    ``test_target.py``; existing files with those names are overwritten.
    """
    workdir = Path(tmpdir)
    sources = (("target.py", fn_code), ("test_target.py", test_code))
    for filename, text in sources:
        (workdir / filename).write_text(text, encoding="utf-8")
| def run_mutant(mutant_src, test_code, timeout=5): | |
| script = f""" | |
| import sys, unittest, io, types | |
| fn = {repr(mutant_src)} | |
| tc = {repr(test_code)} | |
| mod = types.ModuleType("target") | |
| exec(compile(fn, "<t>", "exec"), mod.__dict__) | |
| sys.modules["target"] = mod | |
| tm = types.ModuleType("_tm") | |
| exec(compile(tc, "<tc>", "exec"), tm.__dict__) | |
| sys.modules["_tm"] = tm | |
| suite = unittest.TestLoader().loadTestsFromModule(tm) | |
| buf = io.StringIO() | |
| r = unittest.TextTestRunner(stream=buf, verbosity=0).run(suite) | |
| print("KILLED" if (r.failures or r.errors) else "SURVIVED") | |
| """ | |
| try: | |
| r = subprocess.run([sys.executable, "-c", script], | |
| capture_output=True, text=True, timeout=timeout) | |
| if "KILLED" in r.stdout: return "killed" | |
| if "SURVIVED" in r.stdout: return "survived" | |
| return "killed" | |
| except subprocess.TimeoutExpired: | |
| return "timeout" | |
# Operator replacement tables used by make_mutants(): each key class is
# swapped for its value class.
# Arithmetic operator replacement.
AOR = {ast.Add: ast.Sub, ast.Sub: ast.Add, ast.Mult: ast.Div, ast.Div: ast.Mult, ast.Mod: ast.Add}
# Relational operator replacement.
ROR = {ast.Eq: ast.NotEq, ast.NotEq: ast.Eq, ast.Lt: ast.Gt, ast.Gt: ast.Lt,
       ast.LtE: ast.GtE, ast.GtE: ast.LtE}
# Logical connector replacement.
LCR = {ast.And: ast.Or, ast.Or: ast.And}


def make_mutants(fn_code):
    """Return the source of every single-change mutant of *fn_code*.

    Five mutation operators are applied, in this order: arithmetic operator
    swap (AOR), first comparison operator swap (ROR), boolean connector swap
    (LCR), ``return <expr>`` -> ``return None``, and constant ``0`` <-> ``1``.
    Each mutant differs from the original in exactly one AST node.

    Mutants whose source equals the original (for example ``return None``
    rewritten to ``return None``) or equals an earlier mutant are dropped,
    because they can never give additional information.

    Args:
        fn_code: Source code to mutate.

    Returns:
        A list of mutant sources produced by ``ast.unparse``.

    Raises:
        SyntaxError: If *fn_code* cannot be parsed.
    """
    tree = ast.parse(fn_code)
    mutants = []
    seen = {ast.unparse(tree)}

    def mutate(pred, do_replace):
        # Identify targets by their position in walk order: a deep copy is
        # walked in the same order, and unlike a line number the position is
        # unique even when several candidates share one source line.
        positions = [pos for pos, node in enumerate(ast.walk(tree)) if pred(node)]
        for pos in positions:
            clone = copy.deepcopy(tree)
            do_replace(list(ast.walk(clone))[pos])
            ast.fix_missing_locations(clone)
            try:
                text = ast.unparse(clone)
            except Exception:
                # Best effort: a mutant that cannot be rendered is skipped.
                continue
            if text not in seen:
                seen.add(text)
                mutants.append(text)

    mutate(lambda n: isinstance(n, ast.BinOp) and type(n.op) in AOR,
           lambda n: setattr(n, "op", AOR[type(n.op)]()))
    mutate(lambda n: isinstance(n, ast.Compare) and n.ops and type(n.ops[0]) in ROR,
           lambda n: n.ops.__setitem__(0, ROR[type(n.ops[0])]()))
    mutate(lambda n: isinstance(n, ast.BoolOp) and type(n.op) in LCR,
           lambda n: setattr(n, "op", LCR[type(n.op)]()))
    mutate(lambda n: isinstance(n, ast.Return) and n.value is not None,
           lambda n: setattr(n, "value", ast.Constant(value=None)))
    # NOTE: True/False compare equal to 1/0, so booleans are flipped here too
    # (True -> 0, False -> 1).
    mutate(lambda n: isinstance(n, ast.Constant) and n.value in (0, 1),
           lambda n: setattr(n, "value", 1 - n.value))
    return mutants
| ASSERT_METHODS = {"assertEqual","assertNotEqual","assertTrue","assertFalse", | |
| "assertIn","assertAlmostEqual","assertRaises","assertGreater", | |
| "assertLess","assertIsNone","assertIsNotNone","assertNotIn"} | |
| def count_assertions(fn_node): | |
| count = 0 | |
| for n in ast.walk(fn_node): | |
| if isinstance(n, ast.Assert): count += 1 | |
| elif (isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) | |
| and n.func.attr in ASSERT_METHODS): count += 1 | |
| return count | |
def eval_coverage(fn_code, test_code, tmpdir, timeout=60):
    """Measure line and branch coverage of the function under its tests.

    The sources are written into *tmpdir*, the tests are run with pytest
    under ``coverage run --branch`` (restricted to ``target.py``), and the
    JSON report produced by ``coverage json`` is summarised.

    Args:
        fn_code: Source of the function under test (written to target.py).
        test_code: Source of the tests (written to test_target.py).
        tmpdir: Working directory for the run.
        timeout: Seconds allowed for each subprocess, so a hanging test
            cannot stall the whole pipeline.

    Returns:
        A dict with ``line_coverage_pct``, ``branch_coverage_pct``,
        ``missing_lines`` and ``missing_branches``; an empty dict when the
        run timed out or no report for ``target.py`` was produced.
    """
    write_files(fn_code, test_code, tmpdir)
    commands = (
        [sys.executable, "-m", "coverage", "run", "--branch",
         "--include=target.py", "-m", "pytest", "test_target.py", "-q", "--tb=no"],
        [sys.executable, "-m", "coverage", "json", "-o", "cov.json"],
    )
    try:
        for cmd in commands:
            subprocess.run(cmd, capture_output=True, text=True,
                           cwd=tmpdir, timeout=timeout)
    except subprocess.TimeoutExpired:
        return {}

    cov_json = Path(tmpdir, "cov.json")
    if not cov_json.exists():
        return {}
    try:
        raw = json.loads(cov_json.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return {}

    # Compare the file name exactly: a suffix test would also accept
    # "test_target.py".
    fd = next((entry for name, entry in raw.get("files", {}).items()
               if Path(name).name == "target.py"), None)
    if not fd:
        return {}

    s = fd["summary"]
    missing_br = fd.get("missing_branches", [])
    total_br = s.get("num_branches", 0)
    branch_pct = round(100 * (total_br - len(missing_br)) / total_br, 1) if total_br else 100.0

    # With --branch, "percent_covered" mixes lines and branches. Derive the
    # pure line figure from the statement counts so branches are not counted
    # twice in the final score; fall back to the combined figure otherwise.
    statements = s.get("num_statements", 0)
    if statements and "covered_lines" in s:
        line_pct = 100 * s["covered_lines"] / statements
    else:
        line_pct = s.get("percent_covered", 0)

    return {
        "line_coverage_pct": round(line_pct, 1),
        "branch_coverage_pct": branch_pct,
        "missing_lines": fd.get("missing_lines", []),
        "missing_branches": missing_br,
    }
def eval_mutation(fn_code, test_code):
    """Compute the mutation score of *test_code* against *fn_code*.

    The tests are first run against the unmutated function. If they do not
    pass there, a "kill" would only reflect tests that are already broken, so
    no mutant is credited and the score is 0. Otherwise every mutant from
    make_mutants() is run; a mutant counts as killed when run_mutant()
    reports ``"killed"`` or ``"timeout"``.

    Returns:
        A dict with ``total_mutants``, ``killed`` and ``mutation_score``
        (percentage, one decimal). A function with no mutants scores 100.0.
    """
    mutants = make_mutants(fn_code)
    if not mutants:
        return {"total_mutants": 0, "killed": 0, "mutation_score": 100.0}

    # Baseline run: the tests must pass on the original code for kills to
    # mean anything.
    if run_mutant(fn_code, test_code) != "survived":
        return {"total_mutants": len(mutants), "killed": 0, "mutation_score": 0.0}

    killed = sum(1 for m in mutants if run_mutant(m, test_code) in ("killed", "timeout"))
    return {
        "total_mutants": len(mutants),
        "killed": killed,
        "mutation_score": round(100 * killed / len(mutants), 1),
    }
def eval_static(fn_code, test_code):
    """Collect static statistics about the test source.

    Every ``def`` whose name starts with ``test`` is treated as a test
    function; ``assertion_density`` is the mean number of assertions per
    test function, rounded to two decimals (0.0 when there are none).
    *fn_code* is accepted for signature parity but not inspected.

    Raises:
        SyntaxError: If *test_code* cannot be parsed.
    """
    tests = []
    for node in ast.walk(ast.parse(test_code)):
        if isinstance(node, ast.FunctionDef) and node.name.startswith("test"):
            tests.append(node)

    n_tests = len(tests)
    n_assertions = sum(map(count_assertions, tests))
    if n_tests:
        density = round(n_assertions / n_tests, 2)
    else:
        density = 0.0

    return {
        "total_test_functions": n_tests,
        "total_assertions": n_assertions,
        "assertion_density": density,
    }
def evaluate(fn_code, test_code):
    """Run the coverage, mutation and static evaluations for one test suite.

    Returns a dict with the keys ``coverage``, ``mutation`` and ``static``
    holding the result of the corresponding ``eval_*`` helper. Exceptions
    raised by the helpers propagate to the caller.
    """
    report = {}
    # The scratch directory only serves the coverage run; it is removed when
    # the block exits.
    with tempfile.TemporaryDirectory() as workdir:
        report["coverage"] = eval_coverage(fn_code, test_code, workdir)
        report["mutation"] = eval_mutation(fn_code, test_code)
        report["static"] = eval_static(fn_code, test_code)
    return report
def compute_score(eval_result, weights=None):
    """Fold an evaluation result into a single score between 0 and 1.

    Line coverage, branch coverage and mutation score are percentages and are
    divided by 100. Assertion density is divided by 2 and capped at 1.0, so
    two assertions per test function already earn full credit. Metrics that
    are missing from *eval_result* (for example when it only carries an
    ``"error"`` entry) count as 0.

    Args:
        eval_result: Dict with optional ``coverage``, ``mutation`` and
            ``static`` sub-dicts, as returned by evaluate().
        weights: Optional mapping with the keys ``line_coverage``,
            ``branch_coverage``, ``mutation_score`` and
            ``assertion_density``. Defaults to the module-level WEIGHTS.

    Returns:
        The weighted sum, rounded to four decimals.
    """
    if weights is None:
        weights = WEIGHTS

    cov = eval_result.get("coverage") or {}
    mut = eval_result.get("mutation") or {}
    sta = eval_result.get("static") or {}

    metrics = {
        "line_coverage": cov.get("line_coverage_pct", 0) / 100,
        "branch_coverage": cov.get("branch_coverage_pct", 0) / 100,
        "mutation_score": mut.get("mutation_score", 0) / 100,
        "assertion_density": min(sta.get("assertion_density", 0) / 2, 1.0),  # cap at 1
    }
    score = sum(weights[name] * value for name, value in metrics.items())
    return round(score, 4)
# ── model ───────────────────────────────────────────────────────────────────
def load_model(use_finetuned=True):
    """Load the generation model and its tokenizer.

    The base checkpoint BASE_MODEL is always loaded in float16 with
    ``device_map="auto"``. When *use_finetuned* is true the tokenizer comes
    from FINETUNED_DIR and the base model is wrapped with
    ``PeftModel.from_pretrained(..., FINETUNED_DIR)``; otherwise the base
    tokenizer and the plain base model are used.

    Args:
        use_finetuned: Whether to apply the fine-tuned weights.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    print("loading model...")
    tokenizer = AutoTokenizer.from_pretrained(FINETUNED_DIR if use_finetuned else BASE_MODEL)
    # Apply the pad-token fallback on both paths so the two tokenizers behave
    # the same when no pad token is configured.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, device_map="auto", dtype=torch.float16)
    if use_finetuned:
        model = PeftModel.from_pretrained(model, FINETUNED_DIR)
        print("using fine-tuned model")
    else:
        print("using base model (before fine-tuning)")
    model.eval()

    # warn if running on CPU — will be very slow
    device = next(model.parameters()).device
    if str(device) == "cpu":
        print(" WARNING: model is on CPU, generation will be very slow")
        print(" consider using a smaller model or a machine with GPU")
    else:
        print(f" running on {device}")
    print("model loaded\n")
    return model, tokenizer
| def generate_tests(model, tokenizer, fn_code, old_test_code): | |
| prompt = f"""### Instruction: | |
| Given the following Python function and its current test cases, generate improved test cases. | |
| The new tests should have better coverage, catch more bugs, and include edge cases. | |
| ### Function: | |
| {fn_code} | |
| ### Current tests: | |
| {old_test_code} | |
| ### Response: | |
| """ | |
| device = next(model.parameters()).device | |
| inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=300, | |
| temperature=0.7, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| repetition_penalty=1.1, | |
| ) | |
| generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| code = generated.split("### Response:")[-1].strip() | |
| # model sometimes cuts off mid-line β trim to the last complete line | |
| # that ends with a valid Python ending character | |
| lines = code.splitlines() | |
| clean = [] | |
| last_valid = [] | |
| for line in lines: | |
| clean.append(line) | |
| try: | |
| ast.parse("\n".join(clean)) | |
| last_valid = list(clean) | |
| except SyntaxError: | |
| pass | |
| return "\n".join(last_valid) if last_valid else code | |
# ── main ────────────────────────────────────────────────────────────────────
def load_csv(path):
    """Read the function/test pairs from a CSV file.

    The file must have ``function_code`` and ``test_code`` columns. It is
    decoded as UTF-8 and a leading byte-order mark (as written by
    spreadsheet tools) is ignored, so the first header is still recognised.
    A cell missing from a short row is treated as an empty string.

    Args:
        path: Path of the CSV file.

    Returns:
        A list of dicts with the stripped ``function_code`` and
        ``test_code`` of each row.

    Raises:
        KeyError: If one of the two columns is absent from the header.
    """
    with open(path, newline="", encoding="utf-8-sig") as f:
        return [
            {"function_code": (row["function_code"] or "").strip(),
             "test_code": (row["test_code"] or "").strip()}
            for row in csv.DictReader(f)
        ]
def _first_function_name(fn_code, fallback):
    """Return the name of the first ``def`` found in *fn_code*, else *fallback*."""
    try:
        tree = ast.parse(fn_code)
    except (SyntaxError, ValueError):
        return fallback
    return next((n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)),
                fallback)


def _score_tests(fn_code, test_code):
    """Evaluate one test suite; return ``(evaluation, score)``, never raising."""
    try:
        result = evaluate(fn_code, test_code)
        return result, compute_score(result)
    except Exception as e:
        return {"error": str(e)}, 0.0


def _print_summary(results):
    """Print the per-function score table and the average improvement."""
    print("\n" + "="*60)
    print(f" {'function':<25} {'old':>6} {'new':>6} {'diff':>7}")
    print(" " + "-"*58)
    for r in results:
        # The "+" format flag signs the value itself, so a negative zero is
        # not rendered as "+-0.0%".
        diff = f"{r['improvement']*100:+.1f}%"
        print(f" {r['function_name']:<25} {r['old_score']:>6.3f} {r['new_score']:>6.3f} {diff:>7}")
    avg = sum(r['improvement'] for r in results) / len(results) if results else 0
    print(f"\n average improvement: {avg*100:.2f}%")
    print("="*60)


def main():
    """Run the pipeline: score existing tests, generate new ones, compare.

    For every row of INPUT_CSV the existing tests are scored, then up to
    three candidate suites are generated (one when the base model is used).
    Generation stops at the first candidate that beats the old score; if none
    does, the best-scoring candidate is recorded. Results are written to
    OUTPUT_FILE after every function, so an interrupted run keeps what was
    already computed, and a summary table is printed at the end.
    """
    rows = load_csv(INPUT_CSV)
    use_finetuned = True  # set False to test base model
    model, tokenizer = load_model(use_finetuned=use_finetuned)
    results = []
    for i, row in enumerate(rows):
        fn_code = row["function_code"]
        old_tests = row["test_code"]
        fn_name = _first_function_name(fn_code, f"function_{i+1}")
        print(f"[{i+1}/{len(rows)}] {fn_name}")

        # step 1 - evaluate old tests (a broken row scores 0 instead of
        # aborting the whole run)
        print(" evaluating old tests...")
        old_eval, old_score = _score_tests(fn_code, old_tests)
        print(f" old score: {old_score}")

        # step 2 - generate new tests (retry up to 3 times if finetuned)
        max_attempts = 3 if use_finetuned else 1
        new_tests = new_eval = None
        new_score = 0.0
        for attempt in range(1, max_attempts + 1):
            if max_attempts > 1:
                print(f" generating new tests (attempt {attempt}/{max_attempts})...")
            else:
                print(" generating new tests...")
            candidate = generate_tests(model, tokenizer, fn_code, old_tests)
            print(" ── generated tests " + "─" * 30)
            print(candidate)
            print(" " + "─" * 49)
            candidate_eval, candidate_score = _score_tests(fn_code, candidate)
            print(f" score: {candidate_score}")

            # Remember the best candidate seen so far, so a failed final
            # attempt does not discard a better earlier one.
            if new_tests is None or candidate_score > new_score:
                new_tests = candidate
                new_eval = candidate_eval
                new_score = candidate_score

            if candidate_score > old_score:
                print(f" improvement found: {round((new_score - old_score) * 100, 2)}%\n")
                break
            print(f" no improvement (attempt {attempt}/{max_attempts})")
            if attempt == max_attempts:
                print(" failed to generate better tests\n")

        results.append({
            "function_name": fn_name,
            "function_code": fn_code,
            "old_test_code": old_tests,
            "new_test_code": new_tests,
            "old_score": old_score,
            "new_score": new_score,
            "improvement": round(new_score - old_score, 4),
            "old_evaluation": old_eval,
            "new_evaluation": new_eval,
        })
        # Checkpoint after every function.
        Path(OUTPUT_FILE).write_text(json.dumps(results, indent=2), encoding="utf-8")

    # Final write also covers the case of an empty input file.
    Path(OUTPUT_FILE).write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"saved -> {OUTPUT_FILE}")

    # print summary
    _print_summary(results)


if __name__ == "__main__":
    main()