csci4130projectdemo / pipeline.py
SennPiee's picture
fix pipeline.py
1db7285 verified
Raw
History Blame Contribute Delete
13.5 kB
import ast
import os, sys, json, ast, csv, torch, tempfile, subprocess, copy
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from transformers import BitsAndBytesConfig
# ── configuration ─────────────────────────────────────────────────────────────
# Hugging Face id of the base model, and the local directory that
# load_model() reads the fine-tuned tokenizer and adapter from.
BASE_MODEL = "stabilityai/stable-code-3b"
FINETUNED_DIR = r"finetuned_model"

# Input CSV (read by load_csv) and JSON report path (written by main).
INPUT_CSV = "functions.csv"
OUTPUT_FILE = "pipeline2_results.json"

# Score weights (same as pipeline 1); compute_score() multiplies each metric,
# normalised to 0..1, by its weight. The weights add up to 1.0.
WEIGHTS = dict(
    line_coverage=0.35,
    branch_coverage=0.35,
    mutation_score=0.20,
    assertion_density=0.10,
)
# ── evaluation helpers (same logic as evaluate.py)
def write_files(fn_code, test_code, tmpdir):
    """Write the function and its tests into *tmpdir* as UTF-8 files.

    The function source goes to ``target.py`` and the test source to
    ``test_target.py``. Returns nothing.
    """
    workdir = Path(tmpdir)
    sources = (("target.py", fn_code), ("test_target.py", test_code))
    for filename, source in sources:
        (workdir / filename).write_text(source, encoding="utf-8")
def run_mutant(mutant_src, test_code, timeout=5):
script = f"""
import sys, unittest, io, types
fn = {repr(mutant_src)}
tc = {repr(test_code)}
mod = types.ModuleType("target")
exec(compile(fn, "<t>", "exec"), mod.__dict__)
sys.modules["target"] = mod
tm = types.ModuleType("_tm")
exec(compile(tc, "<tc>", "exec"), tm.__dict__)
sys.modules["_tm"] = tm
suite = unittest.TestLoader().loadTestsFromModule(tm)
buf = io.StringIO()
r = unittest.TextTestRunner(stream=buf, verbosity=0).run(suite)
print("KILLED" if (r.failures or r.errors) else "SURVIVED")
"""
try:
r = subprocess.run([sys.executable, "-c", script],
capture_output=True, text=True, timeout=timeout)
if "KILLED" in r.stdout: return "killed"
if "SURVIVED" in r.stdout: return "survived"
return "killed"
except subprocess.TimeoutExpired:
return "timeout"
# Operator replacement tables used by make_mutants():
# arithmetic (AOR), relational (ROR) and logical-connector (LCR) swaps.
AOR = {ast.Add: ast.Sub, ast.Sub: ast.Add, ast.Mult: ast.Div, ast.Div: ast.Mult, ast.Mod: ast.Add}
ROR = {ast.Eq: ast.NotEq, ast.NotEq: ast.Eq, ast.Lt: ast.Gt, ast.Gt: ast.Lt,
       ast.LtE: ast.GtE, ast.GtE: ast.LtE}
LCR = {ast.And: ast.Or, ast.Or: ast.And}


def make_mutants(fn_code):
    """Return a list of mutated source strings derived from *fn_code*.

    Each mutant differs from the original in exactly one AST node. The
    mutation operators, applied in this order, are:

    * arithmetic operator swap (``AOR``) on binary operations,
    * relational operator swap (``ROR``) on the first operator of a comparison,
    * ``and`` / ``or`` swap (``LCR``),
    * replacing a returned value with ``None``,
    * flipping the constants ``0`` and ``1``.

    Args:
        fn_code: Python source to mutate.

    Returns:
        One unparsed source string per mutated node.

    Raises:
        SyntaxError: if *fn_code* cannot be parsed.
    """
    tree = ast.parse(fn_code)
    mutants = []

    def mutate(pred, do_replace):
        # ast.walk visits a tree and its deep copy in the same order, so the
        # position in the walk identifies the node to change. Matching by
        # line number instead would hit the first candidate on a line again
        # and again when several candidates share that line.
        for index, node in enumerate(ast.walk(tree)):
            if not pred(node):
                continue
            mutant_tree = copy.deepcopy(tree)
            do_replace(list(ast.walk(mutant_tree))[index])
            ast.fix_missing_locations(mutant_tree)
            try:
                mutants.append(ast.unparse(mutant_tree))
            except Exception:
                # best effort: a mutant that cannot be rendered is skipped
                continue

    def returns_non_none(node):
        # ``return`` and ``return None`` would mutate into themselves, which
        # yields an equivalent mutant that no test can ever kill.
        if not isinstance(node, ast.Return) or node.value is None:
            return False
        value = node.value
        return not (isinstance(value, ast.Constant) and value.value is None)

    mutate(lambda n: isinstance(n, ast.BinOp) and type(n.op) in AOR,
           lambda n: setattr(n, "op", AOR[type(n.op)]()))
    mutate(lambda n: isinstance(n, ast.Compare) and n.ops and type(n.ops[0]) in ROR,
           lambda n: n.ops.__setitem__(0, ROR[type(n.ops[0])]()))
    mutate(lambda n: isinstance(n, ast.BoolOp) and type(n.op) in LCR,
           lambda n: setattr(n, "op", LCR[type(n.op)]()))
    mutate(returns_non_none,
           lambda n: setattr(n, "value", ast.Constant(value=None)))
    mutate(lambda n: isinstance(n, ast.Constant) and n.value in (0, 1),
           lambda n: setattr(n, "value", 1 - n.value))
    return mutants
ASSERT_METHODS = {"assertEqual","assertNotEqual","assertTrue","assertFalse",
"assertIn","assertAlmostEqual","assertRaises","assertGreater",
"assertLess","assertIsNone","assertIsNotNone","assertNotIn"}
def count_assertions(fn_node):
count = 0
for n in ast.walk(fn_node):
if isinstance(n, ast.Assert): count += 1
elif (isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute)
and n.func.attr in ASSERT_METHODS): count += 1
return count
def eval_coverage(fn_code, test_code, tmpdir):
    """Measure line and branch coverage of *fn_code* under *test_code*.

    The two sources are written into *tmpdir*, pytest is run there under
    ``coverage run --branch`` and the JSON report is read back.

    Returns:
        A dict with ``line_coverage_pct``, ``branch_coverage_pct``,
        ``missing_lines`` and ``missing_branches``; an empty dict when no
        report was produced or the report has no entry for ``target.py``.
    """
    write_files(fn_code, test_code, tmpdir)

    coverage_cmd = [sys.executable, "-m", "coverage"]
    run_options = dict(capture_output=True, text=True, cwd=tmpdir)
    subprocess.run(
        coverage_cmd + ["run", "--branch", "--include=target.py",
                        "-m", "pytest", "test_target.py", "-q", "--tb=no"],
        **run_options)
    subprocess.run(coverage_cmd + ["json", "-o", "cov.json"], **run_options)

    report_path = Path(tmpdir, "cov.json")
    if not report_path.exists():
        return {}

    report = json.loads(report_path.read_text())
    file_data = None
    for filename, data in report["files"].items():
        if filename.endswith("target.py"):
            file_data = data
            break
    if not file_data:
        return {}

    summary = file_data["summary"]
    missing_branches = file_data.get("missing_branches", [])
    branch_total = summary.get("num_branches", 0)
    if branch_total:
        covered = branch_total - len(missing_branches)
        branch_pct = round(100 * covered / branch_total, 1)
    else:
        # code without branches counts as fully branch-covered
        branch_pct = 100.0

    return {
        "line_coverage_pct": round(summary.get("percent_covered", 0), 1),
        "branch_coverage_pct": branch_pct,
        "missing_lines": file_data.get("missing_lines", []),
        "missing_branches": missing_branches,
    }
def eval_mutation(fn_code, test_code):
    """Compute the mutation score of *test_code* for *fn_code*.

    Mutants come from make_mutants() and each one is run with run_mutant().
    A mutant counts as killed when the run reports ``"killed"`` or
    ``"timeout"``.

    The tests are first run against the unmutated function. run_mutant()
    reports ``"killed"`` whenever the test process fails, so tests that are
    already broken or failing on the original code would otherwise "kill"
    every mutant and earn a perfect score. When that baseline run is not
    ``"survived"`` the score is 0.

    Returns:
        A dict with ``total_mutants``, ``killed`` and ``mutation_score``
        (a percentage rounded to one decimal). With no mutants at all the
        score is reported as 100.0.
    """
    mutants = make_mutants(fn_code)
    if not mutants:
        return {"total_mutants": 0, "killed": 0, "mutation_score": 100.0}

    # baseline: the suite has to be green on the original function
    if run_mutant(fn_code, test_code) != "survived":
        return {"total_mutants": len(mutants), "killed": 0, "mutation_score": 0.0}

    killed = sum(
        1 for mutant in mutants
        if run_mutant(mutant, test_code) in ("killed", "timeout")
    )
    return {
        "total_mutants": len(mutants),
        "killed": killed,
        "mutation_score": round(100 * killed / len(mutants), 1),
    }
def eval_static(fn_code, test_code):
    """Collect static metrics about *test_code* (*fn_code* is not used).

    Every function definition whose name starts with ``test`` is a test
    function; assertions inside each one are counted with count_assertions().

    Returns:
        A dict with ``total_test_functions``, ``total_assertions`` and
        ``assertion_density`` (assertions per test function, two decimals,
        0.0 when there are no test functions).

    Raises:
        SyntaxError: if *test_code* cannot be parsed.
    """
    test_functions = []
    for node in ast.walk(ast.parse(test_code)):
        if isinstance(node, ast.FunctionDef) and node.name.startswith("test"):
            test_functions.append(node)

    n_tests = len(test_functions)
    n_assertions = sum(count_assertions(fn) for fn in test_functions)
    if n_tests:
        density = round(n_assertions / n_tests, 2)
    else:
        density = 0.0

    return {
        "total_test_functions": n_tests,
        "total_assertions": n_assertions,
        "assertion_density": density,
    }
def evaluate(fn_code, test_code):
    """Run the coverage, mutation and static evaluations for one test suite.

    Coverage is measured inside a temporary directory that is removed
    afterwards.

    Returns:
        A dict with the three sub-reports under the keys ``coverage``,
        ``mutation`` and ``static``.
    """
    report = {}
    with tempfile.TemporaryDirectory() as workdir:
        report["coverage"] = eval_coverage(fn_code, test_code, workdir)
        report["mutation"] = eval_mutation(fn_code, test_code)
        report["static"] = eval_static(fn_code, test_code)
    return report
def compute_score(eval_result):
    """Fold an evaluate() report into one weighted score, rounded to 4 decimals.

    Coverage and mutation percentages are scaled to 0..1. Assertion density
    is halved and capped at 1.0, so two assertions per test already earn the
    full share. Metrics missing from *eval_result* count as 0.
    """
    coverage = eval_result.get("coverage", {})
    mutation = eval_result.get("mutation", {})
    static = eval_result.get("static", {})

    components = (
        ("line_coverage", coverage.get("line_coverage_pct", 0) / 100),
        ("branch_coverage", coverage.get("branch_coverage_pct", 0) / 100),
        ("mutation_score", mutation.get("mutation_score", 0) / 100),
        # cap at 1
        ("assertion_density", min(static.get("assertion_density", 0) / 2, 1.0)),
    )

    total = 0.0
    for metric, value in components:
        total += WEIGHTS[metric] * value
    return round(total, 4)
# ── model ─────────────────────────────────────────────────────────────────────
def load_model(use_finetuned=True):
    """Load the tokenizer and model and return them as ``(model, tokenizer)``.

    With *use_finetuned* the tokenizer comes from ``FINETUNED_DIR`` and the
    base model is wrapped with the PEFT weights stored there. Otherwise the
    plain base model and tokenizer are used, and the tokenizer's pad token
    falls back to its EOS token when none is set. The model is switched to
    eval mode and the device it sits on is reported.
    """
    print("loading model...")

    tokenizer_source = FINETUNED_DIR if use_finetuned else BASE_MODEL
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    if not use_finetuned and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, device_map="auto", dtype=torch.float16)
    if use_finetuned:
        model = PeftModel.from_pretrained(model, FINETUNED_DIR)
        print("using fine-tuned model")
    else:
        print("using base model (before fine-tuning)")

    model.eval()

    # warn if running on CPU - generation will be very slow
    device = next(model.parameters()).device
    on_cpu = str(device) == "cpu"
    if on_cpu:
        print(" WARNING: model is on CPU, generation will be very slow")
        print(" consider using a smaller model or a machine with GPU")
    else:
        print(f" running on {device}")

    print("model loaded\n")
    return model, tokenizer
def generate_tests(model, tokenizer, fn_code, old_test_code):
    """Ask the model for improved tests and return them as source text.

    The prompt holds the function and its current tests. The decoded output
    is cut after the last ``### Response:`` marker and then trimmed to the
    longest leading run of lines that parses as Python, because the model
    sometimes stops in the middle of a statement. If no leading run parses,
    the untrimmed text is returned.
    """
    prompt = f"""### Instruction:
Given the following Python function and its current test cases, generate improved test cases.
The new tests should have better coverage, catch more bugs, and include edge cases.
### Function:
{fn_code}
### Current tests:
{old_test_code}
### Response:
"""
    device = next(model.parameters()).device
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    sampling_options = dict(
        max_new_tokens=300,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
    )
    with torch.no_grad():
        output_ids = model.generate(**encoded, **sampling_options)

    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # keep only what follows the last response marker (everything if absent)
    code = decoded.rpartition("### Response:")[2].strip()

    # grow the snippet line by line and remember the longest prefix that parses
    lines = code.splitlines()
    best_prefix = None
    for end in range(1, len(lines) + 1):
        snippet = "\n".join(lines[:end])
        try:
            ast.parse(snippet)
        except SyntaxError:
            continue
        best_prefix = snippet

    return best_prefix if best_prefix is not None else code
# ── main ──────────────────────────────────────────────────────────────────────
def load_csv(path):
    """Read the CSV at *path* and return its rows as a list of dicts.

    Each dict carries the ``function_code`` and ``test_code`` columns with
    surrounding whitespace stripped; any other column is dropped.
    """
    with open(path, newline="", encoding="utf-8") as handle:
        return [
            {
                "function_code": record["function_code"].strip(),
                "test_code": record["test_code"].strip(),
            }
            for record in csv.DictReader(handle)
        ]
def _guess_function_name(fn_code, index):
    """Return the first function name found in *fn_code*, or ``function_<index+1>``."""
    fallback = f"function_{index + 1}"
    try:
        tree = ast.parse(fn_code)
    except (SyntaxError, ValueError, RecursionError):
        return fallback
    names = (n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef))
    return next(names, fallback)


def _score_tests(fn_code, test_code):
    """Evaluate and score one test suite; on failure return an error report and 0.0."""
    try:
        report = evaluate(fn_code, test_code)
        return report, compute_score(report)
    except Exception as exc:
        return {"error": str(exc)}, 0.0


def _print_summary(results):
    """Print the per-function score table and the average improvement."""
    print("\n" + "="*60)
    print(f" {'function':<25} {'old':>6} {'new':>6} {'diff':>7}")
    print(" " + "-"*58)
    for r in results:
        diff = f"+{r['improvement']*100:.1f}%" if r['improvement'] >= 0 else f"{r['improvement']*100:.1f}%"
        print(f" {r['function_name']:<25} {r['old_score']:>6.3f} {r['new_score']:>6.3f} {diff:>7}")
    avg = sum(r['improvement'] for r in results) / len(results) if results else 0
    print(f"\n average improvement: {avg*100:.2f}%")
    print("="*60)


def main():
    """Run the pipeline: score the old tests, generate new ones, report.

    For every row of ``INPUT_CSV`` the existing tests are scored, the model
    is asked for new tests (up to three attempts with the fine-tuned model,
    one with the base model) and generation stops at the first candidate
    that beats the old score. When no candidate beats it, the best-scoring
    candidate is recorded instead of whichever happened to come last.
    Results go to ``OUTPUT_FILE`` as JSON and a summary table is printed.

    An evaluation that raises is recorded as ``{"error": ...}`` with score
    0.0, for the old tests as well as for generated candidates, so a single
    malformed row cannot abort the whole run.
    """
    rows = load_csv(INPUT_CSV)
    use_finetuned = True  # set False to test base model
    model, tokenizer = load_model(use_finetuned=use_finetuned)
    results = []
    for i, row in enumerate(rows):
        fn_code = row["function_code"]
        old_tests = row["test_code"]
        fn_name = _guess_function_name(fn_code, i)
        print(f"[{i+1}/{len(rows)}] {fn_name}")

        # step 1 - evaluate old tests
        print(" evaluating old tests...")
        old_eval, old_score = _score_tests(fn_code, old_tests)
        if "error" in old_eval:
            print(f" could not evaluate old tests: {old_eval['error']}")
        print(f" old score: {old_score}")

        # step 2 - generate new tests (retry up to 3 times if finetuned)
        max_attempts = 3 if use_finetuned else 1
        new_tests = new_eval = None
        new_score = 0.0
        for attempt in range(1, max_attempts + 1):
            if max_attempts > 1:
                print(f" generating new tests (attempt {attempt}/{max_attempts})...")
            else:
                print(" generating new tests...")
            candidate = generate_tests(model, tokenizer, fn_code, old_tests)
            print(" ── generated tests ──────────────────────────────")
            print(candidate)
            print(" ─────────────────────────────────────────────────")
            candidate_eval, candidate_score = _score_tests(fn_code, candidate)
            print(f" score: {candidate_score}")

            # remember the best candidate seen so far (the first one on ties)
            if new_tests is None or candidate_score > new_score:
                new_tests = candidate
                new_eval = candidate_eval
                new_score = candidate_score

            if candidate_score > old_score:
                print(f" improvement found: {round((new_score - old_score) * 100, 2)}%\n")
                break
            print(f" no improvement (attempt {attempt}/{max_attempts})")
            if attempt == max_attempts:
                print(" failed to generate better tests\n")

        results.append({
            "function_name": fn_name,
            "function_code": fn_code,
            "old_test_code": old_tests,
            "new_test_code": new_tests,
            "old_score": old_score,
            "new_score": new_score,
            "improvement": round(new_score - old_score, 4),
            "old_evaluation": old_eval,
            "new_evaluation": new_eval,
        })

    Path(OUTPUT_FILE).write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"saved -> {OUTPUT_FILE}")

    # print summary
    _print_summary(results)


if __name__ == "__main__":
    main()