"""Command-line entry point that trains a ``TIVDTrainer`` on prompts built from a Hugging Face dataset."""

from __future__ import annotations

import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from datasets import load_dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import HfArgumentParser, set_seed

from trainer.tivd.online_trainer import (
    TIVDConfig,
    TIVDTrainer,
    assert_qwen_tokenizer_compatibility,
    build_student_model,
    build_teacher_model,
    build_tokenizer,
    copy_training_sources,
    render_math_prompt,
)

# GSM8K-style final-answer marker: "#### <answer>". Compiled once because the
# parser runs for every dataset row.
_GSM8K_FINAL_ANSWER_RE = re.compile(r"####\s*(.+)$", flags=re.MULTILINE)


@dataclass
class DataArguments:
    """Command-line options describing which dataset to load and how to read its columns."""

    # Passed to ``load_dataset`` as ``path``.
    dataset_name: str = field(default="openai/gsm8k")
    # Passed to ``load_dataset`` as ``name``; omitted when empty/None.
    dataset_config_name: Optional[str] = field(default="main")
    # Passed to ``load_dataset`` as ``split``.
    dataset_split: str = field(default="train")
    # Column holding the question text; it is rendered into the prompt.
    question_column: str = field(default="question")
    # Column holding the raw answer; also parsed for the final answer when
    # ``final_answer_column`` is empty.
    answer_column: str = field(default="answer")
    # When non-empty, the final answer is read from this column verbatim
    # instead of being parsed out of the raw answer.
    final_answer_column: str = field(default="")
    # When non-empty and present in the dataset, rows are filtered against
    # ``train_args.difficulty_threshold`` and the value is copied to each row.
    difficulty_column: str = field(default="")
    # When non-empty, the column value is copied to each row under "topic".
    topic_column: str = field(default="")
    # Comma-separated names of extra columns to copy into each row unchanged.
    solution_columns: str = field(default="")
    # Keep at most this many examples (applied after difficulty filtering).
    limit: Optional[int] = field(default=None)


class PromptListDataset(TorchDataset):
    """Simple Python dataset wrapper to avoid Arrow batched-indexing quirks in custom Trainer flows.

    Rows are held in a plain list and returned as-is by index.
    """

    def __init__(self, rows: list[dict]):
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> dict:
        return self.rows[idx]


def _parse_gsm8k_final_answer(answer_text: Optional[str]) -> Optional[str]:
    """Extract the final answer from a GSM8K-style solution string.

    Args:
        answer_text: Full answer/solution text, or ``None``.

    Returns:
        The text following the first ``####`` marker (stripped) when one is
        found; otherwise the last line of the text (stripped). ``None`` when
        the input is ``None``, empty, or contains only whitespace.
    """
    if not answer_text:
        return None
    stripped = answer_text.strip()
    # A whitespace-only answer has no lines at all; indexing the last line
    # would raise IndexError, so treat it like a missing answer.
    if not stripped:
        return None
    match = _GSM8K_FINAL_ANSWER_RE.search(stripped)
    if match:
        return match.group(1).strip()
    return stripped.splitlines()[-1].strip()


def _build_row(example: dict, data_args: DataArguments, solution_columns: list[str]) -> dict:
    """Convert one dataset example into the row dict stored in ``PromptListDataset``."""
    question = example[data_args.question_column]
    raw_answer = example.get(data_args.answer_column) if data_args.answer_column else None

    # An explicit final-answer column wins over parsing the raw answer text.
    if data_args.final_answer_column:
        final_answer = example.get(data_args.final_answer_column)
    else:
        final_answer = _parse_gsm8k_final_answer(raw_answer)

    difficulty_column = data_args.difficulty_column
    if difficulty_column and difficulty_column in example:
        # ``or 0.0`` maps a None/empty difficulty value to 0.0.
        difficulty = float(example[difficulty_column] or 0.0)
    else:
        difficulty = 0.0

    row = {
        "prompt": render_math_prompt(question),
        "question": question,
        "final_answer": final_answer,
        "answer": raw_answer,
        "difficulty": difficulty,
        "topic": example.get(data_args.topic_column) if data_args.topic_column else None,
    }
    # Requested pass-through columns are copied only when the example has them.
    for col in solution_columns:
        if col in example:
            row[col] = example[col]
    return row


def build_filtered_dataset(data_args: DataArguments, train_args: TIVDConfig) -> PromptListDataset:
    """Load the configured dataset, filter/limit it, and materialize prompt rows.

    Args:
        data_args: Dataset location and column mapping.
        train_args: Training config; only ``difficulty_threshold`` is read here.

    Returns:
        A ``PromptListDataset`` whose rows contain the keys ``prompt``,
        ``question``, ``final_answer``, ``answer``, ``difficulty`` and
        ``topic``, plus any requested solution columns present in the example.
    """
    load_kwargs = {"path": data_args.dataset_name, "split": data_args.dataset_split}
    if data_args.dataset_config_name:
        load_kwargs["name"] = data_args.dataset_config_name
    dataset = load_dataset(**load_kwargs)

    difficulty_column = data_args.difficulty_column
    if difficulty_column and difficulty_column in dataset.column_names:
        # Convert the threshold once instead of once per example.
        threshold = float(train_args.difficulty_threshold)
        dataset = dataset.filter(
            lambda ex: ex.get(difficulty_column) is not None and float(ex[difficulty_column]) >= threshold,
            desc=f"Filtering difficulty >= {train_args.difficulty_threshold}",
        )

    if data_args.limit is not None:
        dataset = dataset.select(range(min(len(dataset), data_args.limit)))

    solution_columns = [col.strip() for col in data_args.solution_columns.split(",") if col.strip()]
    rows: list[dict] = [_build_row(example, data_args, solution_columns) for example in dataset]
    return PromptListDataset(rows)


def main() -> None:
    """Parse CLI arguments, build models and data, run training, and save the results."""
    parser = HfArgumentParser((TIVDConfig, DataArguments))
    train_args, data_args = parser.parse_args_into_dataclasses()

    # Rows carry extra keys (question, final_answer, ...) beyond model inputs.
    # NOTE(review): presumably the trainer consumes them itself, hence column
    # pruning and label detection are switched off - confirm against TIVDTrainer.
    train_args.remove_unused_columns = False
    train_args.label_names = []

    # setdefault: values already present in the environment take precedence.
    if train_args.wandb_project:
        os.environ.setdefault("WANDB_PROJECT", train_args.wandb_project)
    if train_args.wandb_run_name:
        os.environ.setdefault("WANDB_NAME", train_args.wandb_run_name)

    Path(train_args.output_dir).mkdir(parents=True, exist_ok=True)
    set_seed(train_args.seed)

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if train_args.use_vllm and train_args.vllm_mode == "server" and world_size > 1:
        raise ValueError(
            "For this trainer, server-mode vLLM should be run with a single training process. "
            "Use accelerate launch --num_processes 1 so training stays on one GPU and the vLLM server on another, "
            "or use --vllm_mode colocate for same-GPU execution."
        )

    student_tokenizer = build_tokenizer(train_args.student_model_name_or_path, train_args.trust_remote_code)
    teacher_tokenizer = build_tokenizer(train_args.teacher_model_name_or_path, train_args.trust_remote_code)
    assert_qwen_tokenizer_compatibility(student_tokenizer, teacher_tokenizer)

    train_dataset = build_filtered_dataset(data_args, train_args)
    student_model = build_student_model(train_args)
    teacher_model = build_teacher_model(train_args)

    # The trainer implementation is expected to sit next to this script.
    online_trainer_path = Path(__file__).parent / "online_trainer.py"
    copy_training_sources(train_args.output_dir, __file__, online_trainer_path)

    trainer = TIVDTrainer(
        model=student_model,
        args=train_args,
        tokenizer=student_tokenizer,
        teacher_model=teacher_model,
        target_model=None,
        train_dataset=train_dataset,
        eval_dataset=None,
        ref_model=None,
        source_file_paths=[__file__, str(online_trainer_path)],
    )

    train_result = trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)
    trainer.save_model(train_args.output_dir)
    student_tokenizer.save_pretrained(train_args.output_dir)

    metrics = train_result.metrics
    metrics["train_examples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if train_args.push_to_hub:
        kwargs = {}
        if train_args.hub_model_id:
            kwargs["repo_id"] = train_args.hub_model_id
        trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    main()