#!/usr/bin/env python
"""Fine-tune a causal LM with LoRA to generate pytest suites for Python functions.

Training data is the MBPP dataset (``config.MBPP_DATASET``) concatenated with a
local JSON file (``config.LOCAL_DATA_PATH``). The trained adapter and the
tokenizer are written to ``config.OUTPUT_DIR``.
"""
import os

from dotenv import load_dotenv
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
)
from peft import get_peft_model, prepare_model_for_kbit_training

import config

# Columns every source dataset is normalized to before concatenation.
_KEEP_COLUMNS = ("code", "test_list")


# --------------------
# Prompt functions
# --------------------

def build_prompt(code):
    """Return the instruction prompt asking the model to write pytest tests for *code*.

    NOTE(review): keep this template identical to the prompt used at inference
    time; its exact blank-line layout should be verified against that script.
    """
    return f"""### Instruction:
Given the following Python function, generate a comprehensive pytest test suite.

The test suite MUST:
- Use pytest format
- Include at least one test function
- Wrap all assertions inside test functions
- Be valid Python code

### Code:
{code}

### Response:
"""


def format_example(example, tokenizer):
    """Tokenize one prompt/target pair into fixed-length causal-LM training features.

    Args:
        example: a row with ``code`` (str) and ``test_list`` (list of str).
        tokenizer: a Hugging Face tokenizer with ``pad_token`` set.

    Returns:
        The tokenizer output padded/truncated to ``config.MAX_LENGTH``, plus a
        ``labels`` list in which prompt tokens and padding are set to -100 so
        that only the test code (and the trailing EOS) is supervised.
    """
    prompt = build_prompt(example["code"])
    completion = "\n".join(example["test_list"])
    full_text = prompt + completion + tokenizer.eos_token

    tokenized = tokenizer(
        full_text,
        truncation=True,
        padding="max_length",
        max_length=config.MAX_LENGTH,
    )
    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    # Mask padding by attention_mask rather than by pad_token_id: pad_token may
    # be the EOS token, and the real trailing EOS must remain supervised.
    labels = [tok if keep else -100 for tok, keep in zip(input_ids, attention_mask)]

    # Start the prompt mask at the first real token so it is correct for both
    # right- and left-padding tokenizers. The prompt is tokenized the same way
    # (special tokens included) so its length lines up with the full sequence;
    # a token merged across the prompt/target boundary may shift this by one.
    start = attention_mask.index(1) if 1 in attention_mask else len(labels)
    prompt_len = len(
        tokenizer(prompt, truncation=True, max_length=config.MAX_LENGTH)["input_ids"]
    )
    end = min(start + prompt_len, len(labels))
    labels[start:end] = [-100] * (end - start)

    tokenized["labels"] = labels
    return tokenized


def convert_local(example):
    """Map a local JSON record onto the shared ``code`` / ``test_list`` schema.

    ``fn_code`` becomes ``code``. ``good_tests`` (one string of test source) is
    split into its non-blank lines; a missing (None) value yields an empty list.
    """
    tests_raw = example["good_tests"] or ""
    # splitlines() also drops the "\r" of CRLF line endings.
    test_list = [line for line in tests_raw.splitlines() if line.strip()]
    return {"code": example["fn_code"], "test_list": test_list}


# --------------------
# Pipeline helpers
# --------------------

def _keep_columns(dataset, names=_KEEP_COLUMNS):
    """Return *dataset* with every column not listed in *names* removed."""
    return dataset.remove_columns([col for col in dataset.column_names if col not in names])


def _has_trainable_tokens(example):
    """Return True if at least one label in *example* is supervised (not -100)."""
    return any(label != -100 for label in example["labels"])


def _load_model_and_tokenizer():
    """Load the tokenizer and the (optionally quantized) base model wrapped with LoRA."""
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_NAME,
        quantization_config=config.BNB_CONFIG,
        device_map="auto",
    )
    if config.BNB_CONFIG is not None:
        # PEFT's recommended preprocessing for k-bit models before adding
        # adapters; keep its gradient-checkpointing setup in sync with the
        # Trainer's setting.
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=config.TRAINING_ARGS.get("gradient_checkpointing", False),
        )
    model = get_peft_model(model, config.LORA_CONFIG)
    return model, tokenizer


def _load_datasets():
    """Load MBPP and the local JSON data, merge them, and return a train/test split.

    NOTE(review): MBPP ``test_list`` entries appear to be bare ``assert``
    statements, which do not satisfy the prompt's "wrap all assertions inside
    test functions" requirement — confirm this target mismatch is intended.
    """
    mbpp_dataset = _keep_columns(load_dataset(*config.MBPP_DATASET, split="train"))

    local_dataset = load_dataset("json", data_files=config.LOCAL_DATA_PATH, split="train")
    local_dataset = _keep_columns(local_dataset.map(convert_local))

    dataset = concatenate_datasets([mbpp_dataset, local_dataset])
    # Seed the split itself: train_test_split reshuffles, and without a seed the
    # train/eval partition would differ on every run.
    return dataset.train_test_split(test_size=config.TEST_SIZE, seed=config.SEED)


def _tokenize_split(split, tokenizer):
    """Tokenize a ``code``/``test_list`` split into model inputs.

    The raw text columns are removed so only tensor-able features reach the
    Trainer. Examples whose prompt fills the whole ``config.MAX_LENGTH`` window
    end up with every label masked; they are dropped because a batch made only
    of such examples can produce a NaN loss.
    """
    tokenized = split.map(
        format_example,
        fn_kwargs={"tokenizer": tokenizer},
        remove_columns=split.column_names,
    )
    return tokenized.filter(_has_trainable_tokens)


def main():
    """Load the model and data, fine-tune with LoRA, and save adapter + tokenizer."""
    load_dotenv()
    os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    # --------------------
    # Tokenizer / Model (+ LoRA)
    # --------------------
    model, tokenizer = _load_model_and_tokenizer()

    # --------------------
    # Dataset
    # --------------------
    dataset = _load_datasets()
    train_data = _tokenize_split(dataset["train"], tokenizer)
    eval_data = _tokenize_split(dataset["test"], tokenizer)

    # --------------------
    # Training
    # --------------------
    training_args = TrainingArguments(
        output_dir=config.OUTPUT_DIR,
        **config.TRAINING_ARGS,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=eval_data,
    )
    trainer.train()

    # --------------------
    # Save
    # --------------------
    # `model` is the PeftModel returned by get_peft_model, so this saves the
    # adapter rather than merged base-model weights.
    model.save_pretrained(config.OUTPUT_DIR)
    tokenizer.save_pretrained(config.OUTPUT_DIR)
    print(f"Model saved to {config.OUTPUT_DIR}")


if __name__ == "__main__":
    main()