| |
| import pandas as pd |
| import json |
| import os |
| import re |
|
|
| |
def parse_llama_template(prompt_string):
    """Parse a LLaMA-3 style chat template string into a list of turns.

    Each segment of the form
    ``<|start_header_id|>ROLE<|end_header_id|>`` + blank line + ``CONTENT``
    + ``<|eot_id|>`` becomes one turn. Segments that are not terminated by
    ``<|eot_id|>`` (for example a trailing generation header) are not matched
    and therefore ignored.

    Role names are stripped, lower-cased and mapped to ShareGPT tags:
    ``system`` -> ``system``, ``user`` -> ``human``, ``assistant`` -> ``gpt``.
    Turns with any other role are reported on stdout and skipped.

    Args:
        prompt_string (str): Text containing LLaMA template markers.

    Returns:
        list: Dicts of the form ``{"from": role, "value": content}`` in the
        order they appear. Empty when nothing could be parsed, including
        when ``prompt_string`` is not a string (e.g. ``None`` or NaN coming
        from a missing table cell).
    """
    role_aliases = {"system": "system", "user": "human", "assistant": "gpt"}
    turn_pattern = r"<\|start_header_id\|>(.*?)<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>"

    conversation_history = []

    # Guard against non-string input: re.finditer would raise TypeError,
    # while the documented contract is "empty list on failure".
    if isinstance(prompt_string, str):
        # re.DOTALL lets a turn's content span multiple lines.
        for match in re.finditer(turn_pattern, prompt_string, re.DOTALL):
            role_raw = match.group(1).strip().lower()
            role_mapped = role_aliases.get(role_raw)
            if role_mapped is None:
                print(f"警告: 无法识别的角色 '{role_raw}',跳过此轮次。")
                continue
            conversation_history.append(
                {"from": role_mapped, "value": match.group(2).strip()}
            )

    if not conversation_history:
        print("警告: 无法解析 LLaMA 模板,可能格式不正确或内容为空。")

    return conversation_history
|
|
| |
def normalize_conversation_history(raw_history, entry_index):
    """Normalize a parsed conversation into a strict ShareGPT prompt history.

    The result obeys these rules:
      1. An optional single ``system`` turn comes first.
      2. The first non-system turn is ``human``; if the raw history starts
         with ``gpt``, an empty ``human`` turn is inserted before it.
      3. Subsequent turns strictly alternate between ``human`` and ``gpt``.
      4. Consecutive turns of the same role are merged, joining their
         non-empty values with a newline.
      5. A trailing ``gpt`` turn is dropped so the history ends on ``human``.
      6. Histories longer than 21 entries are truncated to the first entry
         plus the last 20 entries. When the first entry is ``system`` and
         that tail would start with ``gpt``, the leading ``gpt`` is dropped
         as well so rule 2 still holds after truncation.

    The input list and its dicts are never modified; turns are copied.

    Args:
        raw_history (list): Turns as produced by ``parse_llama_template``.
        entry_index (int): Zero-based row index, used only in log messages
            (printed as ``entry_index + 1``).

    Returns:
        list: The normalized history, or ``[]`` when the input is empty,
        contains only a system prompt, or has an unrecoverable role order.
    """
    if not raw_history:
        return []

    normalized_history = []
    cursor = 0

    # Rule 1: optional leading system prompt.
    if raw_history[0]["from"] == "system":
        if len(raw_history) == 1:
            print(f"警告 (行 {entry_index+1}): 原始历史只包含 system prompt,无法规范化为交替对话。")
            return []
        normalized_history.append(dict(raw_history[0]))
        cursor = 1

    # Rule 2: the conversation proper must open with a human turn.
    first_conv_turn = raw_history[cursor]
    if first_conv_turn["from"] == "gpt":
        # Insert a placeholder; the gpt turn itself is consumed by the loop.
        normalized_history.append({"from": "human", "value": ""})
    elif first_conv_turn["from"] == "human":
        normalized_history.append(dict(first_conv_turn))
        cursor += 1
    else:
        print(f"错误 (行 {entry_index+1}): 处理第一个对话回合时遇到非预期的角色 '{first_conv_turn['from']}'。")
        return []

    # Rules 3 and 4: merge repeats, enforce alternation.
    opposite_role = {"human": "gpt", "gpt": "human"}
    for current_raw_turn in raw_history[cursor:]:
        last_normalized_turn = normalized_history[-1]

        if current_raw_turn["from"] == last_normalized_turn["from"]:
            # last_normalized_turn is always a private copy, so updating it
            # in place does not touch the caller's data.
            if last_normalized_turn["value"] and current_raw_turn["value"]:
                last_normalized_turn["value"] += "\n" + current_raw_turn["value"]
            elif current_raw_turn["value"]:
                last_normalized_turn["value"] = current_raw_turn["value"]
            continue

        expected_next_role = opposite_role.get(last_normalized_turn["from"], "")
        if current_raw_turn["from"] != expected_next_role:
            print(f"错误 (行 {entry_index+1}): 对话角色顺序错误。期望在 '{last_normalized_turn['from']}' 之后是 '{expected_next_role}',但得到 '{current_raw_turn['from']}'。")
            return []
        normalized_history.append(dict(current_raw_turn))

    # Rule 5: the prompt history must end on a human turn.
    if normalized_history[-1]["from"] == "gpt":
        del normalized_history[-1]

    # Rule 6: bound the history length.
    max_tail_turns = 20
    if len(normalized_history) <= max_tail_turns + 1:
        return normalized_history

    head = normalized_history[0]
    tail = normalized_history[-max_tail_turns:]
    # The history ends on human, so an even-length tail starts with gpt.
    # Directly after a system prompt that would break the required
    # system -> human -> gpt order, so drop that orphaned gpt turn.
    if head["from"] == "system" and tail[0]["from"] == "gpt":
        tail = tail[1:]
    return [head] + tail
|
|
| |
def process_prompt(prompt_to_process, entry_index, chosen_response, rejected_response):
    """Build one ShareGPT DPO record from a LLaMA-template prompt.

    The prompt is parsed with ``parse_llama_template`` and the resulting
    turns are cleaned up with ``normalize_conversation_history``.

    Args:
        prompt_to_process (str): Prompt text in LLaMA template form.
        entry_index (int): Zero-based row index, used for log messages.
        chosen_response (str): The preferred response.
        rejected_response (str): The rejected response.

    Returns:
        dict: A record with ``conversations``, ``chosen`` and ``rejected``
        keys; the latter two are ``gpt`` turns wrapping the responses.

    Raises:
        ValueError: If normalization yields an empty conversation history.
    """
    turns = normalize_conversation_history(
        parse_llama_template(prompt_to_process),
        entry_index,
    )
    if not turns:
        raise ValueError(f"错误 (行 {entry_index+1}): 解析出的对话历史为空,跳过此条记录。")

    record = {"conversations": turns}
    record["chosen"] = {"from": "gpt", "value": chosen_response}
    record["rejected"] = {"from": "gpt", "value": rejected_response}
    return record
|
|
| |
def convert_parquet_to_sharegpt_dpo_llama(parquet_paths, json_path):
    """Convert Parquet DPO data (LLaMA template prompts) to ShareGPT DPO JSON.

    Each input file is read with pandas and every row is turned into one
    record via ``process_prompt``. The prompt is taken from ``chosen_prompt``
    if it contains ``<|start_header_id|>``, otherwise from ``reject_prompt``.
    Rows are skipped (and counted) when neither prompt contains the marker,
    when ``reject_prompt`` / ``chosen`` / ``reject`` is empty or missing, or
    when ``process_prompt`` raises ``ValueError``.

    If a file cannot be read or lacks a required column, an error is printed
    and the function returns immediately without writing any output.

    Args:
        parquet_paths (str or list or tuple): One input Parquet path, or a
            sequence of paths whose records are merged in order.
        json_path (str): Destination path of the JSON file. It is written as
            UTF-8 with non-ASCII characters kept readable.

    Returns:
        None
    """
    # A single path is wrapped; lists and tuples are iterated as given.
    if not isinstance(parquet_paths, (list, tuple)):
        parquet_paths = [parquet_paths]

    header_marker = '<|start_header_id|>'
    required_columns = ['chosen_prompt', 'reject_prompt', 'chosen', 'reject', 'chosen_model', 'reject_model']
    merged_data = []

    for parquet_path in parquet_paths:
        print(f"开始转换文件: {parquet_path}")

        # Keep the try body minimal: only the read itself is expected to fail.
        try:
            df = pd.read_parquet(parquet_path)
        except Exception as e:
            print(f"错误: 读取 Parquet 文件时出错: {e}")
            return
        print(f"成功读取 Parquet 文件,包含 {len(df)} 条记录。")

        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            print(f"错误: Parquet 文件缺少必需的列: {missing}")
            return

        sharegpt_data = []
        skipped_basic_validation = 0

        print("开始处理数据行...")
        for index, row in df.iterrows():
            chosen_prompt = row.get('chosen_prompt', None)
            rejected_prompt = row.get('reject_prompt', None)
            chosen_response = row.get('chosen', None)
            rejected_response = row.get('reject', None)

            # Null cells arrive as None/NaN; the isinstance guard keeps the
            # substring test from raising TypeError on such rows.
            if isinstance(chosen_prompt, str) and header_marker in chosen_prompt:
                prompt = chosen_prompt
            elif isinstance(rejected_prompt, str) and header_marker in rejected_prompt:
                prompt = rejected_prompt
            else:
                print(f"警告 (行 {index+1}): 没有 <|start_header_id|> 符号出现在 Prompt 中,跳过此条记录。")
                skipped_basic_validation += 1
                continue

            # Truth-testing array-like cells can raise; treat that as invalid.
            try:
                incomplete = not rejected_prompt or not chosen_response or not rejected_response
            except (ValueError, TypeError) as e:
                print(f"错误 (行 {index+1}): 基础有效性检查失败: {e}")
                skipped_basic_validation += 1
                continue
            if incomplete:
                skipped_basic_validation += 1
                continue

            try:
                conv_chosen = process_prompt(prompt, index, chosen_response, rejected_response)
            except ValueError as e:
                print(e)
                skipped_basic_validation += 1
                continue
            sharegpt_data.append(conv_chosen)

        print(f"数据处理完成。成功转换 {len(sharegpt_data)} 条记录,跳过 {skipped_basic_validation} 条记录。\n")
        merged_data.extend(sharegpt_data)

    print(f"正在将结果写入 JSON 文件: {json_path}, 一共保存了 {len(merged_data)} 条记录。")
    # Explicit UTF-8 plus ensure_ascii=False keeps non-ASCII dialogue text
    # human-readable instead of \uXXXX escapes, independent of the locale.
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(merged_data, f, indent=2, ensure_ascii=False)
|
|
if __name__ == "__main__":
    # Script entry point: convert the hard-coded training shard into a
    # ShareGPT DPO JSON file at the hard-coded output location.
    convert_parquet_to_sharegpt_dpo_llama(
        ['/home/hsichen/LLaMA-Factory/data/chaiting/pk925/data/train-00000-of-00001.parquet'],
        "/home/hsichen/LLaMA-Factory/data/chaiting/xty_collected/pk_433_test_llama.json",
    )