mistral_2407_14b_rm_m / data /scripts_hsichen /pk4preference_llama.py
bingqin111's picture
Upload folder using huggingface_hub
677cc91 verified
# 导入所需库
import pandas as pd # 用于读取 Parquet 文件
import json # 用于处理 JSON 数据
import os # 用于处理文件路径
import re # 用于正则表达式解析 LLaMA 模板
# --- LLaMA 模板解析函数 ---
def parse_llama_template(prompt_string):
"""
解析 LLaMA 风格的模板字符串,还原成对话历史列表。
Args:
prompt_string (str): 包含 LLaMA 模板标记的字符串。
Returns:
list: 一个包含对话轮次的列表,每个轮次是 {"from": role, "value": content} 格式的字典。
如果解析失败或格式不符,可能返回空列表或部分解析结果。
"""
conversation_history = []
# 使用正则表达式匹配 LLaMA 模板的标记
# 匹配格式: <|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>
pattern = r"<\|start_header_id\|>(.*?)<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>"
matches = re.finditer(pattern, prompt_string, re.DOTALL)
for match in matches:
role_raw = match.group(1).strip().lower() # 获取角色
value = match.group(2).strip() # 获取内容
# 映射角色到 ShareGPT 格式
role_mapped = ""
if role_raw == "system":
role_mapped = "system"
elif role_raw == "user":
role_mapped = "human"
elif role_raw == "assistant":
role_mapped = "gpt"
else:
print(f"警告: 无法识别的角色 '{role_raw}',跳过此轮次。")
continue # 跳过无法识别角色的轮次
# 添加到对话历史
conversation_history.append({"from": role_mapped, "value": value})
if not conversation_history:
print("警告: 无法解析 LLaMA 模板,可能格式不正确或内容为空。")
return conversation_history
# --- 对话历史规范化函数 (与 Mistral 版本相同) ---
def normalize_conversation_history(raw_history, entry_index):
"""
规范化原始解析出的对话历史,确保满足以下规则:
1. 可选的单个 system 开头。
2. 第一个非 system 回合必须是 human (如果不是,前面插入空 human)。
3. 后续回合严格在 human 和 gpt 之间交替。
4. 连续相同角色的回合将被合并 (value 用 \n 连接)。
Args:
raw_history (list): 从 parse_llama_template 输出的原始对话历史列表。
entry_index (int): 当前处理的数据行索引(用于日志记录)。
Returns:
list: 规范化后的对话历史列表。如果出现无法处理的结构错误,返回空列表 `[]`。
"""
if not raw_history:
return [] # 输入为空,直接返回
normalized_history = []
processed_raw_index = 0 # 跟踪处理到 raw_history 的哪个索引
# 1. 处理可选的 system 开头
if raw_history[0]["from"] == "system":
normalized_history.append(raw_history[0])
processed_raw_index = 1 # system 处理完了,从下一个开始
if len(raw_history) == 1: # 只有 system prompt
print(f"警告 (行 {entry_index+1}): 原始历史只包含 system prompt,无法规范化为交替对话。")
return [] # 无法形成有效对话
# 如果处理完 system 后没有更多内容了
if processed_raw_index >= len(raw_history):
print(f"警告 (行 {entry_index+1}): 原始历史在 system prompt 后为空,无法规范化。")
return []
# 2. 处理第一个对话回合 (必须是 human)
first_conv_turn = raw_history[processed_raw_index]
if first_conv_turn["from"] == "gpt":
# 如果第一个是 gpt,插入一个空的 human
normalized_history.append({"from": "human", "value": ""})
# 注意:原始的 gpt 回合将在后续循环中处理
elif first_conv_turn["from"] == "human":
# 如果第一个是 human,直接添加
normalized_history.append(first_conv_turn)
processed_raw_index += 1 # 这个 human 处理完了
else: # 如果是 system (不应该在这里出现) 或其他错误
print(f"错误 (行 {entry_index+1}): 处理第一个对话回合时遇到非预期的角色 '{first_conv_turn['from']}'。")
return [] # 结构错误
# 3. 迭代处理剩余的回合,进行合并和交替检查
while processed_raw_index < len(raw_history):
current_raw_turn = raw_history[processed_raw_index]
last_normalized_turn = normalized_history[-1] # 获取已添加到规范列表的最后一个
# 检查角色是否与上一个规范化的回合相同
if current_raw_turn["from"] == last_normalized_turn["from"]:
# 角色相同,合并 value
# 只有在两个 value 都非空时才加换行符
if last_normalized_turn["value"] and current_raw_turn["value"]:
last_normalized_turn["value"] += "\n" + current_raw_turn["value"]
elif current_raw_turn["value"]: # 如果上一个为空,当前不为空
last_normalized_turn["value"] = current_raw_turn["value"]
# 如果当前为空,则无需操作
processed_raw_index += 1 # 当前原始回合已被合并
else:
# 角色不同,检查是否符合交替规则
expected_next_role = ""
if last_normalized_turn["from"] == "human":
expected_next_role = "gpt"
elif last_normalized_turn["from"] == "gpt":
expected_next_role = "human"
# (System 应该只在开头,理论上不会在这里的 last_normalized_turn)
if current_raw_turn["from"] == expected_next_role:
# 符合交替规则,直接添加
normalized_history.append(current_raw_turn)
processed_raw_index += 1 # 当前原始回合处理完毕
else:
# 不符合交替规则 (例如 system 再次出现,或 human->human 未合并等逻辑错误)
print(f"错误 (行 {entry_index+1}): 对话角色顺序错误。期望在 '{last_normalized_turn['from']}' 之后是 '{expected_next_role}',但得到 '{current_raw_turn['from']}'。")
return [] # 结构错误
if normalized_history[-1]["from"] == "gpt":
del normalized_history[-1]
# 仅保留 system_prompt 与 10 个回合(20 个 turn)
if len(normalized_history) > 21:
final_history = [normalized_history[0]] + normalized_history[-20:]
else:
final_history = normalized_history
return final_history
# --- 处理 Prompt 的函数 ---
def process_prompt(prompt_to_process, entry_index, chosen_response, rejected_response):
"""
处理单个 Prompt,解析并规范化对话历史,生成 ShareGPT DPO 格式的条目。
Args:
prompt_to_process (str): 要解析的 Prompt 字符串(LLaMA 模板)。
entry_index (int): 当前处理的数据行索引(用于日志记录)。
chosen_response (str): 选择的响应。
rejected_response (str): 拒绝的响应。
Returns:
dict: ShareGPT DPO 格式的条目,包含 conversations、chosen 和 rejected。
"""
# 1. 尝试基础解析
raw_parsed_history = parse_llama_template(prompt_to_process)
conversation_history = normalize_conversation_history(raw_parsed_history, entry_index)
if len(conversation_history) == 0:
raise ValueError(f"错误 (行 {entry_index+1}): 解析出的对话历史为空,跳过此条记录。")
# --- 构建 ShareGPT DPO 条目 ---
return {
"conversations": conversation_history,
"chosen": {"from": "gpt", "value": chosen_response},
"rejected": {"from": "gpt", "value": rejected_response}
}
# --- 主处理函数 ---
def convert_parquet_to_sharegpt_dpo_llama(parquet_paths, json_path):
"""
将 Parquet 格式的 DPO 数据转换为 ShareGPT DPO JSON 格式,针对 LLaMA 模板。
Args:
parquet_paths (str or list): 输入 Parquet 文件的路径(支持单个文件或文件列表)。
json_path (str): 输出 JSON 文件的路径。
"""
if not isinstance(parquet_paths, list):
parquet_paths = [parquet_paths]
merged_data = []
for parquet_path in parquet_paths:
print(f"开始转换文件: {parquet_path}")
try:
# 读取 Parquet 文件
df = pd.read_parquet(parquet_path)
print(f"成功读取 Parquet 文件,包含 {len(df)} 条记录。")
required_columns = ['chosen_prompt', 'reject_prompt', 'chosen', 'reject', 'chosen_model', 'reject_model']
if not all(col in df.columns for col in required_columns):
missing = [col for col in required_columns if col not in df.columns]
print(f"错误: Parquet 文件缺少必需的列: {missing}")
return
except Exception as e:
print(f"错误: 读取 Parquet 文件时出错: {e}")
return
sharegpt_data = [] # 用于存储转换后的数据
skipped_basic_validation = 0 # 因基础字段缺失跳过
print("开始处理数据行...")
for index, row in df.iterrows():
chosen_prompt = row.get('chosen_prompt', None)
rejected_prompt = row.get('reject_prompt', None)
chosen_response = row.get('chosen', None)
rejected_response = row.get('reject', None)
# 选择包含 LLaMA 标记的 Prompt
if '<|start_header_id|>' in chosen_prompt:
prompt = chosen_prompt
elif '<|start_header_id|>' in rejected_prompt:
prompt = rejected_prompt
else:
print(f"警告 (行 {index+1}): 没有 <|start_header_id|> 符号出现在 Prompt 中,跳过此条记录。")
skipped_basic_validation += 1
continue
# --- 基础有效性检查 ---
try:
if not rejected_prompt or not chosen_response or not rejected_response:
skipped_basic_validation += 1
continue
except Exception as e:
print(f"错误 (行 {index+1}): 基础有效性检查失败: {e}")
skipped_basic_validation += 1
continue
try:
conv_chosen = process_prompt(prompt, index, chosen_response, rejected_response)
sharegpt_data.append(conv_chosen)
except ValueError as e:
print(e)
skipped_basic_validation += 1
continue
print(f"数据处理完成。成功转换 {len(sharegpt_data)} 条记录,跳过 {skipped_basic_validation} 条记录。\n")
merged_data = merged_data + sharegpt_data
# --- 将结果写入 JSON 文件 ---
print(f"正在将结果写入 JSON 文件: {json_path}, 一共保存了 {len(merged_data)} 条记录。")
with open(json_path, 'w') as f:
json.dump(merged_data, f, indent=2)
if __name__ == "__main__":
parquet_file_list = [
'/home/hsichen/LLaMA-Factory/data/chaiting/pk925/data/train-00000-of-00001.parquet'
]
output_json_file = "/home/hsichen/LLaMA-Factory/data/chaiting/xty_collected/pk_433_test_llama.json"
convert_parquet_to_sharegpt_dpo_llama(parquet_file_list, output_json_file)