# 导入所需库 import pandas as pd # 用于读取 Parquet 文件 import json # 用于处理 JSON 数据 import os # 用于处理文件路径 import re # 用于正则表达式解析 LLaMA 模板 # --- LLaMA 模板解析函数 --- def parse_llama_template(prompt_string): """ 解析 LLaMA 风格的模板字符串,还原成对话历史列表。 Args: prompt_string (str): 包含 LLaMA 模板标记的字符串。 Returns: list: 一个包含对话轮次的列表,每个轮次是 {"from": role, "value": content} 格式的字典。 如果解析失败或格式不符,可能返回空列表或部分解析结果。 """ conversation_history = [] # 使用正则表达式匹配 LLaMA 模板的标记 # 匹配格式: <|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|> pattern = r"<\|start_header_id\|>(.*?)<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>" matches = re.finditer(pattern, prompt_string, re.DOTALL) for match in matches: role_raw = match.group(1).strip().lower() # 获取角色 value = match.group(2).strip() # 获取内容 # 映射角色到 ShareGPT 格式 role_mapped = "" if role_raw == "system": role_mapped = "system" elif role_raw == "user": role_mapped = "human" elif role_raw == "assistant": role_mapped = "gpt" else: print(f"警告: 无法识别的角色 '{role_raw}',跳过此轮次。") continue # 跳过无法识别角色的轮次 # 添加到对话历史 conversation_history.append({"from": role_mapped, "value": value}) if not conversation_history: print("警告: 无法解析 LLaMA 模板,可能格式不正确或内容为空。") return conversation_history # --- 对话历史规范化函数 (与 Mistral 版本相同) --- def normalize_conversation_history(raw_history, entry_index): """ 规范化原始解析出的对话历史,确保满足以下规则: 1. 可选的单个 system 开头。 2. 第一个非 system 回合必须是 human (如果不是,前面插入空 human)。 3. 后续回合严格在 human 和 gpt 之间交替。 4. 连续相同角色的回合将被合并 (value 用 \n 连接)。 Args: raw_history (list): 从 parse_llama_template 输出的原始对话历史列表。 entry_index (int): 当前处理的数据行索引(用于日志记录)。 Returns: list: 规范化后的对话历史列表。如果出现无法处理的结构错误,返回空列表 `[]`。 """ if not raw_history: return [] # 输入为空,直接返回 normalized_history = [] processed_raw_index = 0 # 跟踪处理到 raw_history 的哪个索引 # 1. 处理可选的 system 开头 if raw_history[0]["from"] == "system": normalized_history.append(raw_history[0]) processed_raw_index = 1 # system 处理完了,从下一个开始 if len(raw_history) == 1: # 只有 system prompt print(f"警告 (行 {entry_index+1}): 原始历史只包含 system prompt,无法规范化为交替对话。") return [] # 无法形成有效对话 # 如果处理完 system 后没有更多内容了 if processed_raw_index >= len(raw_history): print(f"警告 (行 {entry_index+1}): 原始历史在 system prompt 后为空,无法规范化。") return [] # 2. 处理第一个对话回合 (必须是 human) first_conv_turn = raw_history[processed_raw_index] if first_conv_turn["from"] == "gpt": # 如果第一个是 gpt,插入一个空的 human normalized_history.append({"from": "human", "value": ""}) # 注意:原始的 gpt 回合将在后续循环中处理 elif first_conv_turn["from"] == "human": # 如果第一个是 human,直接添加 normalized_history.append(first_conv_turn) processed_raw_index += 1 # 这个 human 处理完了 else: # 如果是 system (不应该在这里出现) 或其他错误 print(f"错误 (行 {entry_index+1}): 处理第一个对话回合时遇到非预期的角色 '{first_conv_turn['from']}'。") return [] # 结构错误 # 3. 迭代处理剩余的回合,进行合并和交替检查 while processed_raw_index < len(raw_history): current_raw_turn = raw_history[processed_raw_index] last_normalized_turn = normalized_history[-1] # 获取已添加到规范列表的最后一个 # 检查角色是否与上一个规范化的回合相同 if current_raw_turn["from"] == last_normalized_turn["from"]: # 角色相同,合并 value # 只有在两个 value 都非空时才加换行符 if last_normalized_turn["value"] and current_raw_turn["value"]: last_normalized_turn["value"] += "\n" + current_raw_turn["value"] elif current_raw_turn["value"]: # 如果上一个为空,当前不为空 last_normalized_turn["value"] = current_raw_turn["value"] # 如果当前为空,则无需操作 processed_raw_index += 1 # 当前原始回合已被合并 else: # 角色不同,检查是否符合交替规则 expected_next_role = "" if last_normalized_turn["from"] == "human": expected_next_role = "gpt" elif last_normalized_turn["from"] == "gpt": expected_next_role = "human" # (System 应该只在开头,理论上不会在这里的 last_normalized_turn) if current_raw_turn["from"] == expected_next_role: # 符合交替规则,直接添加 normalized_history.append(current_raw_turn) processed_raw_index += 1 # 当前原始回合处理完毕 else: # 不符合交替规则 (例如 system 再次出现,或 human->human 未合并等逻辑错误) print(f"错误 (行 {entry_index+1}): 对话角色顺序错误。期望在 '{last_normalized_turn['from']}' 之后是 '{expected_next_role}',但得到 '{current_raw_turn['from']}'。") return [] # 结构错误 if normalized_history[-1]["from"] == "gpt": del normalized_history[-1] # 仅保留 system_prompt 与 10 个回合(20 个 turn) if len(normalized_history) > 21: final_history = [normalized_history[0]] + normalized_history[-20:] else: final_history = normalized_history return final_history # --- 处理 Prompt 的函数 --- def process_prompt(prompt_to_process, entry_index, chosen_response, rejected_response): """ 处理单个 Prompt,解析并规范化对话历史,生成 ShareGPT DPO 格式的条目。 Args: prompt_to_process (str): 要解析的 Prompt 字符串(LLaMA 模板)。 entry_index (int): 当前处理的数据行索引(用于日志记录)。 chosen_response (str): 选择的响应。 rejected_response (str): 拒绝的响应。 Returns: dict: ShareGPT DPO 格式的条目,包含 conversations、chosen 和 rejected。 """ # 1. 尝试基础解析 raw_parsed_history = parse_llama_template(prompt_to_process) conversation_history = normalize_conversation_history(raw_parsed_history, entry_index) if len(conversation_history) == 0: raise ValueError(f"错误 (行 {entry_index+1}): 解析出的对话历史为空,跳过此条记录。") # --- 构建 ShareGPT DPO 条目 --- return { "conversations": conversation_history, "chosen": {"from": "gpt", "value": chosen_response}, "rejected": {"from": "gpt", "value": rejected_response} } # --- 主处理函数 --- def convert_parquet_to_sharegpt_dpo_llama(parquet_paths, json_path): """ 将 Parquet 格式的 DPO 数据转换为 ShareGPT DPO JSON 格式,针对 LLaMA 模板。 Args: parquet_paths (str or list): 输入 Parquet 文件的路径(支持单个文件或文件列表)。 json_path (str): 输出 JSON 文件的路径。 """ if not isinstance(parquet_paths, list): parquet_paths = [parquet_paths] merged_data = [] for parquet_path in parquet_paths: print(f"开始转换文件: {parquet_path}") try: # 读取 Parquet 文件 df = pd.read_parquet(parquet_path) print(f"成功读取 Parquet 文件,包含 {len(df)} 条记录。") required_columns = ['chosen_prompt', 'reject_prompt', 'chosen', 'reject', 'chosen_model', 'reject_model'] if not all(col in df.columns for col in required_columns): missing = [col for col in required_columns if col not in df.columns] print(f"错误: Parquet 文件缺少必需的列: {missing}") return except Exception as e: print(f"错误: 读取 Parquet 文件时出错: {e}") return sharegpt_data = [] # 用于存储转换后的数据 skipped_basic_validation = 0 # 因基础字段缺失跳过 print("开始处理数据行...") for index, row in df.iterrows(): chosen_prompt = row.get('chosen_prompt', None) rejected_prompt = row.get('reject_prompt', None) chosen_response = row.get('chosen', None) rejected_response = row.get('reject', None) # 选择包含 LLaMA 标记的 Prompt if '<|start_header_id|>' in chosen_prompt: prompt = chosen_prompt elif '<|start_header_id|>' in rejected_prompt: prompt = rejected_prompt else: print(f"警告 (行 {index+1}): 没有 <|start_header_id|> 符号出现在 Prompt 中,跳过此条记录。") skipped_basic_validation += 1 continue # --- 基础有效性检查 --- try: if not rejected_prompt or not chosen_response or not rejected_response: skipped_basic_validation += 1 continue except Exception as e: print(f"错误 (行 {index+1}): 基础有效性检查失败: {e}") skipped_basic_validation += 1 continue try: conv_chosen = process_prompt(prompt, index, chosen_response, rejected_response) sharegpt_data.append(conv_chosen) except ValueError as e: print(e) skipped_basic_validation += 1 continue print(f"数据处理完成。成功转换 {len(sharegpt_data)} 条记录,跳过 {skipped_basic_validation} 条记录。\n") merged_data = merged_data + sharegpt_data # --- 将结果写入 JSON 文件 --- print(f"正在将结果写入 JSON 文件: {json_path}, 一共保存了 {len(merged_data)} 条记录。") with open(json_path, 'w') as f: json.dump(merged_data, f, indent=2) if __name__ == "__main__": parquet_file_list = [ '/home/hsichen/LLaMA-Factory/data/chaiting/pk925/data/train-00000-of-00001.parquet' ] output_json_file = "/home/hsichen/LLaMA-Factory/data/chaiting/xty_collected/pk_433_test_llama.json" convert_parquet_to_sharegpt_dpo_llama(parquet_file_list, output_json_file)