# Provenance (Hugging Face file view): uploaded by SennPiee
# Commit f474b8f (verified): "Rename appreal.py to app.py"
# File size: 2.24 kB
# app.py
import ast
import json

import gradio as gr
import torch

from pipeline import load_model, generate_tests, evaluate, compute_score
# Load the model and tokenizer once, at import time, so that every call to
# process_function reuses the same objects instead of reloading them.
# NOTE(review): use_finetuned=True presumably selects fine-tuned weights --
# confirm against pipeline.load_model.
model, tokenizer = load_model(use_finetuned=True)
def process_function(fn_code, old_test_code):
"""
Given a function and optional existing tests, generate improved tests
and compute evaluation metrics.
"""
if not fn_code.strip():
return "Error: Function code is empty.", "", ""
# Try to extract function name
try:
fn_name = next(n.name for n in ast.walk(ast.parse(fn_code))
if isinstance(n, ast.FunctionDef))
except:
fn_name = "function"
# Step 1: Evaluate old tests
old_eval = evaluate(fn_code, old_test_code)
old_score = compute_score(old_eval)
# Step 2: Generate new tests
new_tests = generate_tests(model, tokenizer, fn_code, old_test_code or "")
try:
new_eval = evaluate(fn_code, new_tests)
new_score = compute_score(new_eval)
except Exception as e:
new_eval = {"error": str(e)}
new_score = 0.0
improvement = new_score - old_score
summary = f"Function: {fn_name}\nOld score: {old_score}\nNew score: {new_score}\nImprovement: {round(improvement*100, 2)}%"
return summary, new_tests, str(new_eval)
# Gradio UI
# NOTE(review): the original indentation was lost; the nesting below (both
# input boxes inside the Row, outputs and button below it) is reconstructed
# from the statement order -- confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🧪 PyTest Generator Demo")
    gr.Markdown("Paste a Python function and optional existing test code. The model will generate improved test cases and provide evaluation metrics.")

    with gr.Row():
        fn_input = gr.Textbox(
            label="Function Code (exact code needed)",
            placeholder="def add(a, b): ...",
            lines=10,
        )
        test_input = gr.Textbox(
            label="Existing Tests (optional)",
            placeholder="def test_add(): ...",
            lines=10,
        )

    output_summary = gr.Textbox(label="Evaluation Summary", lines=5)
    output_tests = gr.Textbox(label="Generated Test Code", lines=15)
    output_eval = gr.Textbox(label="Evaluation Details (JSON)", lines=15)

    run_btn = gr.Button("Generate Tests")
    # The three outputs line up with process_function's 3-tuple return value.
    run_btn.click(
        fn=process_function,
        inputs=[fn_input, test_input],
        outputs=[output_summary, output_tests, output_eval],
    )

# Only start the server when run as a script, so importing this module
# (e.g. to reuse `demo` or process_function) does not launch it.
if __name__ == "__main__":
    demo.launch()