# swadeshb's picture
# Training in progress, step 100
# 11a6bae verified
from __future__ import annotations
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from datasets import load_dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import HfArgumentParser, set_seed
from trainer.tivd.online_trainer import (
TIVDConfig,
TIVDTrainer,
assert_qwen_tokenizer_compatibility,
build_student_model,
build_teacher_model,
build_tokenizer,
copy_training_sources,
render_math_prompt,
)
@dataclass
class DataArguments:
    """Command-line options selecting the dataset and the columns read from it.

    The fields are parsed by ``HfArgumentParser`` and consumed by
    ``build_filtered_dataset``. Column options that default to the empty
    string are treated as "not set" by that function.
    """

    dataset_name: str = field(
        default="openai/gsm8k",
        metadata={"help": "Dataset identifier passed to `load_dataset` as `path`."},
    )
    dataset_config_name: Optional[str] = field(
        default="main",
        metadata={"help": "Dataset configuration passed to `load_dataset` as `name`; omitted when empty."},
    )
    dataset_split: str = field(
        default="train",
        metadata={"help": "Dataset split passed to `load_dataset` as `split`."},
    )
    question_column: str = field(
        default="question",
        metadata={"help": "Column holding the question text that is rendered into the prompt."},
    )
    answer_column: str = field(
        default="answer",
        metadata={
            "help": "Column holding the raw answer text. When `final_answer_column` is empty, "
            "the final answer is parsed from this column (text after `####`, else the last line)."
        },
    )
    final_answer_column: str = field(
        default="",
        metadata={"help": "Column holding the final answer directly; when set, no parsing of `answer_column` is done."},
    )
    difficulty_column: str = field(
        default="",
        metadata={
            "help": "Column holding a numeric difficulty. When set and present in the dataset, "
            "examples below the trainer's `difficulty_threshold` are filtered out."
        },
    )
    topic_column: str = field(
        default="",
        metadata={"help": "Column copied into each row as `topic`; rows get `topic=None` when empty."},
    )
    solution_columns: str = field(
        default="",
        metadata={"help": "Comma-separated list of extra columns copied verbatim into each row when present."},
    )
    limit: Optional[int] = field(
        default=None,
        metadata={"help": "Keep at most this many examples (applied after difficulty filtering)."},
    )
class PromptListDataset(TorchDataset):
    """Map-style dataset backed by a plain Python list of row dicts.

    Keeping the rows in an ordinary list sidesteps Arrow batched-indexing
    quirks when the dataset is consumed by custom Trainer flows.
    """

    def __init__(self, rows: list[dict]):
        # The list is stored as given; no copy is made.
        self.rows = rows

    def __len__(self) -> int:
        """Return the number of stored rows."""
        row_count = len(self.rows)
        return row_count

    def __getitem__(self, idx: int) -> dict:
        """Return the row stored at position ``idx``."""
        selected_row = self.rows[idx]
        return selected_row
def _parse_gsm8k_final_answer(answer_text: Optional[str]) -> Optional[str]:
    """Extract the final answer from an answer string that may use a ``####`` marker.

    The text following the first ``####`` marker (up to the end of its line)
    is returned. When no marker is found, the last line of the text is
    returned instead.

    Args:
        answer_text: Raw answer text, or ``None``.

    Returns:
        The stripped final answer, or ``None`` when ``answer_text`` is
        ``None``, empty, or consists only of whitespace.
    """
    if not answer_text:
        return None
    text = answer_text.strip()
    if not text:
        # Whitespace-only input: splitlines() would be empty below and
        # indexing [-1] would raise IndexError, so treat it like empty input.
        return None
    match = re.search(r"####\s*(.+)$", text, flags=re.MULTILINE)
    if match:
        return match.group(1).strip()
    return text.splitlines()[-1].strip()
def build_filtered_dataset(data_args: DataArguments, train_args: TIVDConfig) -> PromptListDataset:
    """Load the configured dataset and convert it into a list-backed prompt dataset.

    Steps, in order:
      1. Call ``load_dataset`` with the path/split (and config name, if set).
      2. If a difficulty column is configured and exists in the dataset, keep
         only examples whose difficulty is not ``None`` and is at least
         ``train_args.difficulty_threshold``.
      3. If ``data_args.limit`` is set, keep at most that many leading examples.
      4. Build one row dict per example with the keys ``prompt``, ``question``,
         ``final_answer``, ``answer``, ``difficulty`` and ``topic``, plus any
         configured solution columns that the example contains.

    Args:
        data_args: Dataset and column selection options.
        train_args: Trainer config; only ``difficulty_threshold`` is read here.

    Returns:
        A ``PromptListDataset`` wrapping the constructed rows.
    """
    hub_kwargs = {"path": data_args.dataset_name, "split": data_args.dataset_split}
    if data_args.dataset_config_name:
        hub_kwargs["name"] = data_args.dataset_config_name
    examples = load_dataset(**hub_kwargs)

    difficulty_col = data_args.difficulty_column

    def _meets_threshold(ex) -> bool:
        # Examples without a difficulty value are dropped rather than kept.
        if ex.get(difficulty_col) is None:
            return False
        return float(ex[difficulty_col]) >= float(train_args.difficulty_threshold)

    if difficulty_col and difficulty_col in examples.column_names:
        examples = examples.filter(
            _meets_threshold,
            desc=f"Filtering difficulty >= {train_args.difficulty_threshold}",
        )

    if data_args.limit is not None:
        keep_count = min(len(examples), data_args.limit)
        examples = examples.select(range(keep_count))

    # Comma-separated option -> list of non-empty, stripped column names.
    extra_cols = [name.strip() for name in data_args.solution_columns.split(",") if name.strip()]

    records: list[dict] = []
    for example in examples:
        raw_answer = example.get(data_args.answer_column) if data_args.answer_column else None
        final_answer = (
            example.get(data_args.final_answer_column)
            if data_args.final_answer_column
            else _parse_gsm8k_final_answer(raw_answer)
        )

        question = example[data_args.question_column]
        prompt = render_math_prompt(question)

        if difficulty_col and difficulty_col in example:
            # `or 0.0` maps None (and other falsy values) to 0.0 before the cast.
            difficulty = float(example.get(difficulty_col, 0.0) or 0.0)
        else:
            difficulty = 0.0

        topic = example.get(data_args.topic_column) if data_args.topic_column else None

        record = {
            "prompt": prompt,
            "question": question,
            "final_answer": final_answer,
            "answer": raw_answer,
            "difficulty": difficulty,
            "topic": topic,
        }
        # Extra columns are copied last, so they may overwrite the keys above.
        record.update({name: example[name] for name in extra_cols if name in example})
        records.append(record)

    return PromptListDataset(records)
def main() -> None:
    """Parse CLI arguments, build models and data, run training, and save the results.

    Raises:
        ValueError: If vLLM server mode is requested while ``WORLD_SIZE`` is
            greater than one.
    """
    train_args, data_args = HfArgumentParser((TIVDConfig, DataArguments)).parse_args_into_dataclasses()

    # Rows are plain dicts with custom keys and there are no label columns.
    train_args.remove_unused_columns = False
    train_args.label_names = []

    # setdefault: values already present in the environment take precedence.
    if train_args.wandb_project:
        os.environ.setdefault("WANDB_PROJECT", train_args.wandb_project)
    if train_args.wandb_run_name:
        os.environ.setdefault("WANDB_NAME", train_args.wandb_run_name)

    Path(train_args.output_dir).mkdir(parents=True, exist_ok=True)
    set_seed(train_args.seed)

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    wants_vllm_server = train_args.use_vllm and train_args.vllm_mode == "server"
    if wants_vllm_server and world_size > 1:
        raise ValueError(
            "For this trainer, server-mode vLLM should be run with a single training process. "
            "Use accelerate launch --num_processes 1 so training stays on one GPU and the vLLM server on another, "
            "or use --vllm_mode colocate for same-GPU execution."
        )

    student_tokenizer = build_tokenizer(train_args.student_model_name_or_path, train_args.trust_remote_code)
    teacher_tokenizer = build_tokenizer(train_args.teacher_model_name_or_path, train_args.trust_remote_code)
    assert_qwen_tokenizer_compatibility(student_tokenizer, teacher_tokenizer)

    train_dataset = build_filtered_dataset(data_args, train_args)
    student_model = build_student_model(train_args)
    teacher_model = build_teacher_model(train_args)

    # The sibling trainer module is recorded next to this script's own source.
    online_trainer_path = Path(__file__).parent / "online_trainer.py"
    copy_training_sources(train_args.output_dir, __file__, online_trainer_path)

    trainer = TIVDTrainer(
        model=student_model,
        args=train_args,
        tokenizer=student_tokenizer,
        teacher_model=teacher_model,
        target_model=None,
        train_dataset=train_dataset,
        eval_dataset=None,
        ref_model=None,
        source_file_paths=[__file__, str(online_trainer_path)],
    )

    train_result = trainer.train(resume_from_checkpoint=train_args.resume_from_checkpoint)

    # Persist the trained model and the student tokenizer side by side.
    trainer.save_model(train_args.output_dir)
    student_tokenizer.save_pretrained(train_args.output_dir)

    metrics = train_result.metrics
    metrics["train_examples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    if train_args.push_to_hub:
        # NOTE(review): a stock transformers `Trainer.push_to_hub` forwards extra
        # kwargs to `create_model_card`, which has no `repo_id` parameter --
        # confirm that TIVDTrainer accepts `repo_id` here.
        hub_kwargs = {"repo_id": train_args.hub_model_id} if train_args.hub_model_id else {}
        trainer.push_to_hub(**hub_kwargs)


if __name__ == "__main__":
    main()