"""Convert model completions to a common JSON format and optionally score them.

The script reads completions (a JSON list or JSONL), groups them per task,
writes ``<gen-name>.json`` to the output directory and, when ``--evaluate`` is
given, compiles/simulates every sample with ``iverilog``/``vvp`` and writes
``<gen-name>_eval.json`` containing pass rates and pass@k values.
"""

import os
import sys
import argparse
import json
import re
import subprocess
import tempfile
import time
from collections import defaultdict
from threading import Thread

import numpy as np

try:
    import tqdm
except ImportError:  # the progress bar is cosmetic; evaluation works without it
    tqdm = None


def pass_at_k(n, c, k):
    """
    Calculate pass@k

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@k
    :return: pass@k as a float in [0, 1]
    """
    # Fewer than k incorrect samples: every size-k draw contains a correct one.
    if n - c < k:
        return 1.0
    # 1 - C(n - c, k) / C(n, k), written as a running product.
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


def _read_jsonl_by_task_id(path, label):
    """Load a JSONL file into a dict keyed by each record's ``task_id``.

    Blank lines are skipped and records without a ``task_id`` are ignored.
    Any read or parse error is reported as "Error reading <label> file" and
    terminates the process with exit status 1.
    """
    records = {}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line is not a record; json.loads("") would raise.
                    continue
                item = json.loads(line)
                if "task_id" in item:
                    records[item["task_id"]] = item
    except Exception as e:
        print(f"Error reading {label} file: {e}")
        sys.exit(1)
    return records


def read_test_cases(test_file):
    """
    Read test cases from the test file in JSONL format.
    Returns a dictionary mapping task_id to test data.
    Exits the process with status 1 if the file cannot be read or parsed.
    """
    return _read_jsonl_by_task_id(test_file, "test")


def read_description_file(description_file):
    """
    Read problem descriptions from a file in JSONL format.
    Returns a dictionary mapping task_id to description data.
    Exits the process with status 1 if the file cannot be read or parsed.
    """
    return _read_jsonl_by_task_id(description_file, "description")


def read_completions_file(completions_file):
    """
    Read completions from file, handling different possible formats.

    A file whose content starts with ``[`` and ends with ``]`` is parsed as a
    single JSON document; anything else is parsed line by line as JSONL, with
    unparseable lines skipped (and counted in a warning).

    :param completions_file: path to the completions file
    :return: the parsed list of task/completion records
    Exits the process with status 1 if nothing could be parsed.
    """
    try:
        with open(completions_file, 'r', encoding='utf-8') as f:
            content = f.read().strip()

        # Check if it's already in the target format
        if content.startswith('[') and content.endswith(']'):
            return json.loads(content)

        # Try as JSONL
        tasks = []
        skipped = 0
        for line in content.split('\n'):
            if not line.strip():
                continue
            try:
                tasks.append(json.loads(line))
            except json.JSONDecodeError:
                # Best effort: keep the lines that do parse, but say so.
                skipped += 1

        if tasks:
            if skipped:
                print(f"Warning: skipped {skipped} unparseable line(s) in {completions_file}")
            return tasks

        # If we get here, something is wrong
        print(f"Could not parse the completions file: {completions_file}")
        print(f"First 200 characters: {content[:200]}")
        sys.exit(1)
    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception, so the
        # exit above is not intercepted here.
        print(f"Error reading completions file: {e}")
        sys.exit(1)


def exec_command(cmd, timeout=5):
    """Execute a command with timeout and return result, stdout, and stderr

    :param cmd: argument list passed to ``subprocess.run`` (no shell)
    :param timeout: seconds to wait before giving up
    :return: ``(returncode, stdout, stderr)``; the return code is ``-1`` when
        the command timed out or could not be started
    """
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=timeout
        )
        return result.returncode, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return -1, "", "Command timed out"
    except Exception as e:
        return -1, "", str(e)


def _safe_file_stem(task_id):
    """Return ``task_id`` as text with path separators replaced by ``_``."""
    return re.sub(r"[\\/]", "_", str(task_id))


def _simulation_passed(run_out):
    """Decide from simulator stdout whether the testbench reported success."""
    # Look for success patterns in output
    if "PASS" in run_out or "Your Design Passed" in run_out:
        return True
    # Mismatches pattern check - if 0 mismatches, it passed
    match = re.search(r'Mismatches: ([0-9]*) in ([0-9]*) samples', run_out)
    # The pattern allows an empty count; treat that as "not passed" rather
    # than letting int('') raise.
    return bool(match and match.group(1) and int(match.group(1)) == 0)


def check_syntax_and_functionality(task_id, completion, test_code, timeout=5):
    """
    Check the syntax and functionality of a completion.

    The completion is compiled alone with ``iverilog`` (syntax check); if that
    succeeds it is compiled together with the testbench using top module
    ``tb`` and simulated with ``vvp``.

    :param task_id: task identifier, used to name the temporary source files
    :param completion: design source text to check
    :param test_code: testbench source text
    :param timeout: per-command timeout in seconds
    :return: a tuple (syntax_success, func_success)
    """
    # Every early exit below leaves this as False, so it is always bound.
    func_success = False
    stem = _safe_file_stem(task_id)

    # Create temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        # Write completion to a file
        design_file = os.path.join(temp_dir, f"{stem}.v")
        with open(design_file, 'w', encoding='utf-8') as f:
            f.write(completion)

        # Write test code to a file
        test_file = os.path.join(temp_dir, f"{stem}_tb.v")
        with open(test_file, 'w', encoding='utf-8') as f:
            f.write(test_code)

        # Check syntax. Without -o, iverilog writes "a.out" into the current
        # working directory; keep that artefact inside the temp dir instead.
        syntax_out_file = os.path.join(temp_dir, "syntax_check.out")
        syntax_cmd = ["iverilog", "-Wall", "-g2012", "-o", syntax_out_file, design_file]
        syntax_code, _, _ = exec_command(syntax_cmd, timeout)
        syntax_success = syntax_code == 0

        # Check functionality
        if syntax_success:
            # Compile the design and testbench
            vvp_file = os.path.join(temp_dir, "sim.vvp")
            compile_cmd = ["iverilog", "-Wall", "-g2012", "-s", "tb", "-o", vvp_file,
                           design_file, test_file]
            compile_code, _, _ = exec_command(compile_cmd, timeout)

            if compile_code == 0:
                # Run the simulation
                run_code, run_out, _ = exec_command(["vvp", vvp_file], timeout)
                if run_code == 0:
                    func_success = _simulation_passed(run_out)

    return syntax_success, func_success


def _resolve_spec(task_id, test_cases, descriptions, spec=""):
    """Pick the spec text for a task.

    An empty ``spec`` is filled from the description's ``detail_description``;
    a ``prompt`` in the matching test case then takes precedence.
    """
    if not spec and task_id in descriptions:
        spec = descriptions[task_id].get("detail_description", "")
    # Try to get from test case prompt if available
    if task_id in test_cases and "prompt" in test_cases[task_id]:
        spec = test_cases[task_id]["prompt"]
    return spec


def prepare_formatted_output(tasks, test_cases, descriptions):
    """
    Ensure tasks are in the correct format with all required fields.

    :param tasks: either flat records carrying ``task_id``/``completion`` or
        records that already carry ``name``/``spec``/``generated_responses``
    :param test_cases: mapping of task_id to test data (``prompt`` is used)
    :param descriptions: mapping of task_id to description data
        (``detail_description`` is used)
    :return: list of dicts with keys ``name``, ``spec``, ``generated_responses``
    """
    formatted_tasks = []

    # Group completions by task_id if needed
    flat_records = all(
        isinstance(t, dict) and ("task_id" in t or "completion" in t) for t in tasks
    )
    if flat_records:
        by_task = defaultdict(list)
        for t in tasks:
            by_task[t.get("task_id", "unknown")].append(t.get("completion", ""))

        for task_id, responses in by_task.items():
            formatted_tasks.append({
                "name": task_id,
                "spec": _resolve_spec(task_id, test_cases, descriptions),
                "generated_responses": responses
            })
    else:
        # Assume tasks are already in the right structure or close to it
        for task in tasks:
            task_id = task.get("name", "unknown")
            formatted_tasks.append({
                "name": task_id,
                "spec": _resolve_spec(task_id, test_cases, descriptions,
                                      task.get("spec", "")),
                "generated_responses": task.get("generated_responses", [])
            })

    return formatted_tasks


def _evaluate_task(task_id, responses, test_code, timeout):
    """Run every response of one task and return its per-task result dict."""
    syntax_pass = 0
    func_pass = 0
    for resp in responses:
        syntax_ok, func_ok = check_syntax_and_functionality(task_id, resp, test_code, timeout)
        if syntax_ok:
            syntax_pass += 1
        if func_ok:
            func_pass += 1

    num = len(responses)
    result = {
        "syntax_success": syntax_pass,
        "func_success": func_pass,
        "num_samples": num,
        "syntax_pass_rate": syntax_pass / num if responses else 0,
        "func_pass_rate": func_pass / num if responses else 0
    }

    # pass@k is only reported when at least k samples were evaluated
    for k in (1, 5):
        if num >= k:
            result[f"syntax_pass@{k}"] = pass_at_k(num, syntax_pass, k)
            result[f"func_pass@{k}"] = pass_at_k(num, func_pass, k)
    return result


def _aggregate_metrics(evaluation_results):
    """Return ``(overall_metrics, total_syntax, total_func)`` over all tasks."""
    results = list(evaluation_results.values())
    total_samples = sum(r["num_samples"] for r in results)
    total_syntax = sum(r["syntax_success"] for r in results)
    total_func = sum(r["func_success"] for r in results)

    overall_metrics = {
        "num_tasks": len(results),
        "num_samples": total_samples,
        "overall_syntax_pass_rate": total_syntax / total_samples if total_samples else 0,
        "overall_func_pass_rate": total_func / total_samples if total_samples else 0
    }

    # Mean of each pass@k over the tasks that report it
    for metric in ("syntax_pass@1", "func_pass@1", "syntax_pass@5", "func_pass@5"):
        values = [r[metric] for r in results if metric in r]
        if values:
            overall_metrics[f"overall_{metric}"] = sum(values) / len(values)

    return overall_metrics, total_syntax, total_func


def main():
    """Parse the command line, write the formatted tasks and optionally evaluate."""
    parser = argparse.ArgumentParser(description="Convert completions to required format and evaluate")
    parser.add_argument("--completions-file", type=str, required=True,
                        help="Path to the completions file (JSON or JSONL)")
    parser.add_argument("--test-file", type=str, required=True,
                        help="Path to the test cases file")
    parser.add_argument("--description-file", type=str, required=False,
                        help="Path to the problem descriptions file")
    parser.add_argument("--output-dir", type=str, default="./outputs",
                        help="Directory to store outputs")
    parser.add_argument("--gen-name", type=str, default="qwen_rtlcoder",
                        help="Name of the generation model")
    parser.add_argument("--num-samples", type=int, default=20,
                        help="Number of samples per task to include")
    parser.add_argument("--evaluate", action="store_true",
                        help="Run evaluation on the responses")
    parser.add_argument("--timeout", type=int, default=5,
                        help="Timeout in seconds for command execution")

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Read test cases
    print(f"Reading test cases from {args.test_file}")
    test_cases = read_test_cases(args.test_file)
    print(f"Loaded {len(test_cases)} test cases")

    # Read descriptions if provided
    descriptions = {}
    if args.description_file:
        print(f"Reading descriptions from {args.description_file}")
        descriptions = read_description_file(args.description_file)
        print(f"Loaded {len(descriptions)} descriptions")

    # Read completions
    print(f"Reading completions from {args.completions_file}")
    tasks = read_completions_file(args.completions_file)
    print(f"Loaded data with {len(tasks)} tasks")

    # Format tasks
    formatted_tasks = prepare_formatted_output(tasks, test_cases, descriptions)

    # Save formatted tasks
    output_format_file = os.path.join(args.output_dir, f"{args.gen_name}.json")
    with open(output_format_file, 'w', encoding='utf-8') as f:
        json.dump(formatted_tasks, f, indent=4)

    print(f"Formatted results saved to {output_format_file}")

    # Evaluate if requested
    if not args.evaluate:
        return

    evaluation_results = {}

    if tqdm is not None:
        task_iter = tqdm.tqdm(formatted_tasks, desc="Evaluating tasks")
    else:
        task_iter = formatted_tasks

    # Process each task
    for task in task_iter:
        task_id = task["name"]
        if task_id not in test_cases:
            continue
        if "test" not in test_cases[task_id]:
            # Without a testbench there is nothing to run; skip instead of
            # aborting the whole evaluation with a KeyError.
            print(f"Warning: no test code for task {task_id}; skipping")
            continue

        responses = task["generated_responses"][:args.num_samples]
        evaluation_results[task_id] = _evaluate_task(
            task_id, responses, test_cases[task_id]["test"], args.timeout
        )

    # Calculate overall results
    overall_metrics, total_syntax, total_func = _aggregate_metrics(evaluation_results)
    total_tasks = overall_metrics["num_tasks"]
    total_samples = overall_metrics["num_samples"]
    overall_syntax_rate = overall_metrics["overall_syntax_pass_rate"]
    overall_func_rate = overall_metrics["overall_func_pass_rate"]

    # Final results
    final_results = {
        "overall": overall_metrics,
        "per_task": evaluation_results
    }

    # Save evaluation results
    eval_output_file = os.path.join(args.output_dir, f"{args.gen_name}_eval.json")
    with open(eval_output_file, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=4)

    # Print summary
    print("\nEvaluation Results:")
    print(f"Total tasks: {total_tasks}")
    print(f"Total samples: {total_samples}")
    print(f"Syntax pass rate: {overall_syntax_rate:.4f} ({total_syntax}/{total_samples})")
    print(f"Functionality pass rate: {overall_func_rate:.4f} ({total_func}/{total_samples})")

    if "overall_syntax_pass@1" in overall_metrics:
        print(f"Syntax pass@1: {overall_metrics['overall_syntax_pass@1']:.4f}")
    if "overall_func_pass@1" in overall_metrics:
        print(f"Func pass@1: {overall_metrics['overall_func_pass@1']:.4f}")
    if "overall_syntax_pass@5" in overall_metrics:
        print(f"Syntax pass@5: {overall_metrics['overall_syntax_pass@5']:.4f}")
    if "overall_func_pass@5" in overall_metrics:
        print(f"Func pass@5: {overall_metrics['overall_func_pass@5']:.4f}")

    print(f"Results saved to {eval_output_file}")


if __name__ == "__main__":
    main()