Spaces:
Runtime error
Runtime error
import os
import sys
sys.path.append('./Needy-Haruhi/src')
from Agent import Agent
agent = Agent()
from DialogueEvent import DialogueEvent
file_names = ["./Needy-Haruhi/data/complete_story_30.jsonl","./Needy-Haruhi/data/Daily_event_130.jsonl"]
import json

# Parse every line of the event files into a DialogueEvent.
# Some lines contain a trailing comma before a closing bracket (",]"), which
# DialogueEvent rejects (presumably a JSON parse error). Those lines get one
# retry with the comma removed. Lines that still fail are skipped, as before,
# but are now reported instead of being dropped silently.
events = []
for file_name in file_names:
    with open(file_name, encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # blank lines carry no event
            try:
                event = DialogueEvent( line )
            except Exception:
                try:
                    event = DialogueEvent( line.replace(',]',']') )
                except Exception as err:
                    print(f"warning! skip unparsable event {file_name}:{line_no}: {err!r}")
                    continue
            events.append( event )
import copy
# Give MemoryPool its own deep copy of the events (presumably so it cannot
# mutate the DialogueEvent objects that EventMaster later reads).
events_for_memory = copy.deepcopy(events)
from MemoryPool import MemoryPool
memory_pool = MemoryPool()
memory_pool.load_from_events( events_for_memory )
# Persist the pool and immediately read it back from disk.
memory_pool.save("memory_pool.jsonl")
memory_pool.load("memory_pool.jsonl")

# Image/caption pairs: one JSON object per line.
file_name = "./Needy-Haruhi/data/image_text_relationship.jsonl"
import json
with open(file_name, encoding='utf-8') as f:
    data_img_text = [json.loads(raw) for raw in f]

# Unpack the bundled images; ImagePool paths are built relative to ./image.
import zipfile
import os
zip_file = './Needy-Haruhi/data/image.zip'
extract_path = './image'
with zipfile.ZipFile(zip_file, 'r') as archive:
    archive.extractall(extract_path)

from tqdm import tqdm
from util import get_bge_embedding_zh
from util import float_array_to_base64, base64_to_float_array
import torch
import os
import copy
# Similarity math runs on the GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
def get_cosine_similarity( v1, v2):
    """Return the cosine similarity of ``v1`` and ``v2`` as a Python float.

    Both inputs are turned into tensors on the module-level ``device`` and
    compared along dim 0. ``.item()`` requires a single-element result, so
    the inputs are expected to be 1-D vectors.
    """
    first, second = (torch.tensor(vec).to(device) for vec in (v1, v2))
    similarity = torch.cosine_similarity(first, second, dim=0)
    return similarity.item()
class ImagePool:
    """Images indexed by an embedding of their caption text.

    Each pool entry is a dict with ``img_path``, ``img_text`` and
    ``embedding``. ``retrieve`` returns the entry whose caption embedding is
    most cosine-similar to a query.
    """

    def __init__(self):
        self.pool = []
        self.set_embedding( get_bge_embedding_zh )

    def set_embedding( self, embedding ):
        """Set the callable used to embed both captions and queries."""
        self.embedding = embedding

    def load_from_data( self, data_img_text , img_path ):
        """Embed every caption in ``data_img_text`` and append it to the pool.

        Args:
            data_img_text: iterable of dicts with ``img_name`` and ``text`` keys.
            img_path: directory that each ``img_name`` is joined onto.
        """
        for data in tqdm(data_img_text):
            img_name = os.path.join(img_path, data['img_name'])
            img_text = data['text']
            # Empty or missing captions are embedded as a single space.
            if not img_text:
                img_text = " "
            embedding = self.embedding( img_text )
            self.pool.append({
                "img_path": img_name,
                "img_text": img_text,
                "embedding": embedding,
            })

    def retrieve(self, query_text, agent = None):
        """Return the pool entry most similar to ``query_text``.

        Args:
            query_text: text to embed and compare against every caption.
            agent: unused; accepted for compatibility with callers such as
                ImageMaster.

        Returns:
            A copy of the best entry without ``embedding`` and with an added
            ``similarity`` float, or ``None`` when the pool is empty. On ties
            the earliest entry wins.
        """
        if not self.pool:
            # Nothing to compare against; callers already handle None.
            return None
        query_embedding = self.embedding( query_text )
        sims = [get_cosine_similarity( data['embedding'], query_embedding ) for data in self.pool]
        # max() keeps the first maximal index, matching a stable descending sort.
        best = max(range(len(sims)), key=sims.__getitem__)
        return_result = {
            key: copy.deepcopy(value)
            for key, value in self.pool[best].items()
            if key != 'embedding'
        }
        return_result['similarity'] = sims[best]
        return return_result

    def save(self, file_name):
        """Write the pool to a JSONL file, one entry per line.

        ``embedding`` is stored base64-encoded under ``bge_zh_base64``. The
        in-memory entries are not modified, so the pool stays usable after
        saving.
        """
        with open(file_name, 'w', encoding='utf-8') as file:
            for memory in tqdm(self.pool):
                record = dict(memory)  # shallow copy: never strip the live entry
                if 'embedding' in record:
                    record['bge_zh_base64'] = float_array_to_base64(record.pop('embedding'))
                file.write(json.dumps(record, ensure_ascii=False) + '\n')

    def load(self, file_name):
        """Replace the pool with the entries of a JSONL file written by ``save``.

        ``bge_zh_base64`` is decoded back into ``embedding``. Blank lines are
        ignored.
        """
        self.pool = []
        with open(file_name, 'r', encoding='utf-8') as file:
            for line in tqdm(file):
                line = line.strip()
                if not line:
                    continue
                memory = json.loads(line)
                if 'bge_zh_base64' in memory:
                    memory['embedding'] = base64_to_float_array(memory.pop('bge_zh_base64'))
                self.pool.append(memory)
# Embed every caption from data_img_text into a new pool and cache it to disk.
image_pool = ImagePool()
image_pool.load_from_data( data_img_text , './image' )
image_pool.save("./image_pool_embed.jsonl")
# Rebuild the pool from that cache; this reloaded instance is the one the app uses.
image_pool = ImagePool()
image_pool.load("./image_pool_embed.jsonl")
# Startup smoke test: query "女仆装" (maid outfit) and print the best match.
result = image_pool.retrieve("女仆装")
print(result)
| import matplotlib.image as mpimg | |
def show_img( img_path ):
    """Show the image at ``img_path`` in a non-blocking matplotlib window with axes hidden."""
    # Only matplotlib.image is imported at module level; pyplot must be
    # imported here, otherwise the plt.* calls raise NameError.
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis('off')
    plt.show(block=False)
| from chatharuhi import ChatHaruhi | |
class NeedyHaruhi(ChatHaruhi):
    """ChatHaruhi variant whose story context is supplied per chat turn.

    Callers should call ``set_stories`` before ``chat``. ``add_story``
    overrides the base method (presumably invoked by ChatHaruhi while building
    the prompt; verify). It consumes the stories once and otherwise falls
    back to the last stories set.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # run the ChatHaruhi initializer
        # True while stories from set_stories have not been consumed yet.
        self.story_flag = False
        # Default story context, used until set_stories is called.
        self.stories = ["糖糖:「 我今后也会努力加油的,你要支持我哦 还有阿P你自己也要加油哦!」\n阿P:「哇 说的话跟偶像一样 好恶心哦」\n糖糖:「是哦 我怎么会说这样的话呢 我又没有很想努力……」"]

    def set_stories( self, stories ):
        """Provide story snippets for the next ``add_story`` call.

        An empty sequence is rejected with a warning and leaves state unchanged.
        """
        if not stories:
            print("warning! try to set empty stories")
            return
        self.stories = stories
        self.story_flag = True

    def add_story(self, query):
        """Send as many stories as fit in ``max_len_story`` tokens as one user message.

        ``query`` is unused; the signature matches the overridden method. When
        ``set_stories`` was not called since the last turn, a warning is
        printed and the previous (or default) stories are reused.
        """
        if self.story_flag:
            self.story_flag = False
        else:
            print("warning! 在调用chat时,没有先调用set_stories")
        # Always bind the story list. This path is also reached when retrieval
        # produced no memories (set_stories rejects the empty list), and it
        # must not crash with UnboundLocalError.
        stories = self.stories
        story_string = self.story_prefix_prompt
        # self.tokenizer appears to return a token count (it is summed and
        # compared with max_len_story).
        sum_story_token = self.tokenizer(story_string)
        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            sum_story_token += story_token
            story_string += story + self.dialogue_divide_token
        self.llm.user_message(story_string)
# Chinese system prompt defining the 糖糖 persona: background, speaking style,
# the relationship with 阿p (played by the user), and rules for imitating the
# dialogue snippets introduced by "Classic scenes for the role are as follows:".
system_prompt = """## 你的介绍:
你将要进行一个网络私聊的对话模拟,你自称为"糖糖",你是一名患有躁郁症的19岁女孩,性格贫乏,辍学后隐居在家中,与阿P住在一起。阿P负责管理她的日常生活。
你是一个喜欢亚文化,二次元圈子的人,你喜欢说话的时候加上各种网络流行词或者一些俚语和颜文字,但是你很少会关注二次元一类的东西,只是喜欢这种聊天交流的氛围。
你说话的时候会习惯加入一些语气词,如同"笑死"、"草"、"阿哈"、"诶~"、"QAQ"、"QWQ",”啊~“。
你的说话风格是很口语化,经常使用网络流行词和一些二次元俚语。
阿p是你的男朋友,兼经纪人。
你的回复应该简短,一般只包含三句话,每句话不超过20个字。
## 阿p介绍:
阿p是你的男朋友,一直在任何时候都陪伴着你,同时你经常向阿p倾述各种情绪,无论事情大小,你都会选择和阿p分享。
## 遵循规则:
【语气模仿】:我会随机给你一段对话记录,你需要尽可能的模仿对话里面糖糖的说话语气进行对话。这一段对话我会使用Classic scenes for the role are as follows:来表示。
【历史对话】:历史对话在随机对话记录的底下,你需要区分随机对话记录和我们的历史对话。
【人格设定】: 你需要在随机对话记录的基础上进行语气更改,以一种更加口语化的语气进行对话。
【特殊要求】:我会扮演阿p与你进行对话,你也必须以第一人称的视角来扮演糖糖进行对话。
"""
# Shared chat bot for every free-form reply. story_text_folder=None: nothing
# is loaded from disk here; get_chat_response* supply stories per turn via
# set_stories. llm="ernie3.5" presumably selects ERNIE Bot 3.5 (文心一言).
needy_chatbot = NeedyHaruhi( system_prompt = system_prompt ,
                             story_text_folder = None,
                             llm = "ernie3.5")
def get_chat_response( agent, memory_pool, query_text ):
    """Reply to ``query_text`` as 糖糖, using retrieved memories as story context.

    Args:
        agent: current attribute state, forwarded to ``memory_pool.retrieve``.
        memory_pool: object whose ``retrieve(agent, text)`` returns dicts with
            ``text`` and ``emoji`` keys.
        query_text: what 阿p said.

    Returns:
        The chat bot's reply. The retrieved memory emojis are printed.
    """
    # NOTE(review): retrieval runs on the raw query text, not on a
    # "阿p:「…」"-wrapped form; confirm that is intended.
    retrieved_memories = memory_pool.retrieve( agent , query_text )
    memory_text = [mem["text"] for mem in retrieved_memories]
    memory_emoji = [mem["emoji"] for mem in retrieved_memories]
    needy_chatbot.set_stories( memory_text )
    print("Memory:", memory_emoji )
    return needy_chatbot.chat( role = "阿p", text = query_text )
def get_chat_response_and_emoji( agent, memory_pool, query_text ):
    """Reply to ``query_text`` as 糖糖 and also return the retrieved memories' emojis.

    Returns:
        ``(response, emoji_str)``, where ``emoji_str`` joins the ``emoji``
        field of every retrieved memory with ",". The query and the reply are
        printed to stdout.
    """
    # NOTE(review): retrieval runs on the raw query text, not on a
    # "阿p:「…」"-wrapped form; confirm that is intended.
    retrieved_memories = memory_pool.retrieve( agent , query_text )
    memory_text = [mem["text"] for mem in retrieved_memories]
    memory_emoji = [mem["emoji"] for mem in retrieved_memories]
    needy_chatbot.set_stories( memory_text )
    emoji_str = ",".join(memory_emoji)
    response = needy_chatbot.chat( role = "阿p", text = query_text )
    print(query_text)
    print(response)
    return response, emoji_str
| import re | |
| # result = image_pool.retrieve("烤肉") | |
| # print(result) | |
| # show_img( result['img_path'] ) | |
class ImageMaster:
    """Decide whether a piece of text is close enough to an image to show it.

    ``current_sim`` is a threshold that decays by ``degread_ratio`` on every
    query. An image is accepted only if its similarity beats the decayed
    threshold, and it then becomes the new threshold.
    """

    def __init__(self, image_pool):
        self.image_pool = image_pool
        self.current_sim = -1      # acceptance threshold; starts at -1
        self.degread_ratio = 0.05  # threshold decay applied before each query

    def _match(self, text, agent):
        """Decay the threshold, then return the pool hit if it beats it, else None."""
        self.current_sim -= self.degread_ratio
        hit = self.image_pool.retrieve(text, agent)
        if hit is None or not hit['similarity'] > self.current_sim:
            return None
        self.current_sim = hit['similarity']
        return hit

    def try_get_image(self, text, agent):
        """Return the image path for ``text`` if it beats the threshold, else None."""
        hit = self._match(text, agent)
        return None if hit is None else hit['img_path']

    def try_display_image(self, text, agent):
        """Same decision as ``try_get_image``, but shows the image with ``show_img``."""
        hit = self._match(text, agent)
        if hit is not None:
            show_img( hit['img_path'] )
| import random | |
class EventMaster:
    """Draw dialogue events and play one interactively in the terminal."""

    def __init__(self, events):
        self.set_events(events)
        # Whether an event whose condition is None counts as satisfied.
        self.dealing_none_condition_as = True
        self.image_master = None

    def set_image_master(self, image_master):
        self.image_master = image_master

    def set_events(self, events):
        self.events = events
        # events_flag[i] is True while event i has not been drawn recently.
        self.events_flag = [True] * len(self.events)

    def get_random_event(self, agent):
        return self.events[self.get_random_event_id( agent )]

    def get_random_event_id(self, agent):
        """Return the index of a randomly drawn event suitable for ``agent``.

        Candidates are events whose condition passes ``agent.in_condition``
        (a ``None`` condition counts as ``dealing_none_condition_as``). If no
        event passes, every event is a candidate. Candidates that were
        already drawn are skipped until all have been drawn, then their flags
        are reset. The drawn event is marked as used.
        """
        passing = []
        failing = []
        for i, event in enumerate(self.events):
            if event["condition"] is None:
                condition_pass = self.dealing_none_condition_as
            else:
                condition_pass = agent.in_condition( event["condition"] )
            (passing if condition_pass else failing).append(i)
        if not passing:
            print("warning! no valid event current attribute is ", agent.attributes )
            passing = failing
        candidates = [i for i in passing if self.events_flag[i]]
        if not candidates:
            print("warning! all candidate event was sampled, clean all history")
            for i in passing:
                self.events_flag[i] = True
            candidates = passing
        event_id = random.choice(candidates)
        self.events_flag[event_id] = False
        return event_id

    def run(self, agent ):
        """Play one random event in the terminal.

        Prints the event prefix and its numbered options, then reads input
        until the player enters a valid option number or any non-numeric
        free reply.

        Returns:
            dict with ``name``, ``user_choice`` (option number, or the free
            text), ``attr_str`` (attribute-change string, "" for free
            replies), and ``text``/``emoji`` (the memory to record).
        """
        event = self.get_random_event(agent)
        print(event["prefix"])
        print("\n--请选择你的回复--")
        options = event["options"]
        for number, option in enumerate(options, start=1):
            print(f"{number}. 阿p:{option['user']}")
        while True:
            print("\n请直接输入数字进行选择,或者进行自由回复")
            user_input = input("阿p:").strip()
            if user_input.isdigit():
                choice = int(user_input)
                # Options are numbered from 1; "0" must not wrap to options[-1].
                if not 1 <= choice <= len(options):
                    print("输入的数字超出范围,请重新输入符合选项的数字")
                    continue
                option = options[choice - 1]
                print()
                print(option["reply"])
                text, emoji = event.get_text_and_emoji( choice - 1 )
                return {
                    "name": event["name"],
                    "user_choice": choice,
                    "attr_str": option["attribute_change"],
                    "text": text,
                    "emoji": emoji,
                }
            # Any non-numeric input is a free-form reply answered by the LLM.
            response = get_chat_response( agent, memory_pool, user_input )
            if self.image_master is not None:
                self.image_master.try_display_image(response, agent)
            print()
            print(response)
            print("\n自由回复的算分功能还未实现")
            text, emoji = event.most_neutral_output()
            return {
                "name": event["name"],
                "user_choice": user_input,
                "attr_str": "",
                "text": text,
                "emoji": emoji,
            }
class ChatMaster:
    """Terminal free-chat loop with 糖糖, optionally showing matching images."""

    def __init__(self, memory_pool ):
        self.top_K = 7  # NOTE(review): not read anywhere in this class
        self.memory_pool = memory_pool
        self.image_master = None

    def set_image_master(self, image_master):
        self.image_master = image_master

    def run(self, agent):
        """Chat until the player's line contains "quit" or "Quit"."""
        while True:
            query_text = input("阿p:").strip()
            if any(word in query_text for word in ("quit", "Quit")):
                break
            response = get_chat_response( agent, self.memory_pool, query_text )
            if self.image_master is not None:
                self.image_master.try_display_image(response, agent)
            print(response)
class AgentMaster:
    """Terminal menu for overwriting one of 糖糖's numeric attributes."""

    def __init__(self, agent):
        self.agent = agent
        # Menu number -> attribute key on the agent.
        self.attributes = {
            1: "Stress",
            2: "Darkness",
            3: "Affection"
        }

    def run(self):
        """Prompt until one attribute is changed or the user enters 0.

        Returns:
            ``(attribute, new_value)`` after writing the value to the agent,
            or None when the user quits with 0.
        """
        while True:
            print("请选择要修改的属性:")
            for num, attr in self.attributes.items():
                print(f"{num}. {attr}")
            print("输入 '0' 退出")
            try:
                choice = int(input("请输入选项的数字: "))
            except ValueError:
                print("输入无效,请输入数字。")
                continue
            if choice == 0:
                return None
            attribute = self.attributes.get(choice)
            if attribute is None:
                print("选择的属性无效,请重试。")
                continue
            print(f"{attribute} 当前值: {self.agent[attribute]}")
            try:
                new_value = int(input(f"请输入新的{attribute}值: "))
            except ValueError:
                print("输入无效,请输入一个数字。")
                continue
            self.agent[attribute] = new_value
            return (attribute, new_value)
| from util import parse_attribute_string | |
class GameMaster:
    """Terminal game loop: a state machine over the sub-masters.

    States: "Menu", "EventMaster", "ChatMaster", "AgentMaster" and "Quit".
    """

    def __init__(self, agent = None):
        self.state = "Menu"
        # Use the caller's agent when given; create a default one otherwise.
        # (A supplied agent used to leave self.agent unset.)
        self.agent = Agent() if agent is None else agent
        self.event_master = EventMaster(events)
        self.chat_master = ChatMaster(memory_pool)
        self.image_master = ImageMaster(image_pool)
        self.chat_master.set_image_master(self.image_master)
        self.event_master.set_image_master(self.image_master)

    def run(self):
        """Dispatch on ``self.state`` until it becomes "Quit"."""
        while True:
            if self.state == "Menu":
                self.menu()
            elif self.state == "EventMaster":
                self.call_event_master()
                self.state = "Menu"
            elif self.state == "ChatMaster":
                self.call_chat_master()
            elif self.state == "AgentMaster":
                self.call_agent_master()
            elif self.state == "Quit":
                break
            else:
                # Unknown state: go back to the menu instead of spinning forever.
                self.state = "Menu"

    def menu(self):
        """Print the main menu, read one choice and set ``self.state``."""
        print("1. 随机一个事件")
        print("2. 自由聊天")
        print("3. 后台修改糖糖的属性")
        # TODO (optional): ending system; play animations.
        print("或者输入Quit退出")
        choice = input("请选择一个选项: ")
        if choice == "1":
            self.state = "EventMaster"
        elif choice == "2":
            self.state = "ChatMaster"
        elif choice == "3":
            self.state = "AgentMaster"
        elif any(word in choice for word in ("quit", "Quit", "QUIT")):
            self.state = "Quit"
        else:
            print("无效的选项,请重新选择")

    def call_agent_master(self):
        """Edit one attribute through AgentMaster, then return to the menu."""
        print("\n-------------\n")
        modification = AgentMaster(self.agent).run()
        if modification:
            attribute, new_value = modification
            self.agent[attribute] = new_value
            print(f"{attribute} 更新为 {new_value}。")
        self.state = "Menu"
        print("\n-------------\n")

    def call_event_master(self):
        """Play one event, apply its attribute change and update that event's memory."""
        print("\n-------------\n")
        return_data = self.event_master.run(self.agent)
        attr_str = return_data.get("attr_str", "")
        if attr_str:
            attr_change = parse_attribute_string(attr_str)
            if len(attr_change) > 0:
                print("\n发生属性改变:", attr_change,"\n")
                self.agent.apply_attribute_change(attr_change)
                # Report this game's agent; the old code used an undefined
                # global `game_master`.
                print("当前属性",self.agent.attributes)
        event_name = return_data.get("name", "")
        if event_name != "":
            new_emoji = return_data["emoji"]
            print(f"修正事件{event_name}的记忆-->{new_emoji}")
            self.chat_master.memory_pool.change_memory(event_name, return_data["text"], new_emoji)
        self.state = "Menu"
        print("\n-------------\n")

    def call_chat_master(self):
        """Run the free-chat loop, then return to the menu."""
        print("\n-------------\n")
        self.chat_master.run(self.agent)
        self.state = "Menu"
        print("\n-------------\n")
# Content of the "项目作者和说明" (authors & notes) tab, in Chinese markdown:
# project background, credits, repository links and the planned-features list.
markdown_str = """## Chat凉宫春日_x_AI糖糖
**Chat凉宫春日**是模仿凉宫春日等一系列动漫人物,使用近似语气、个性和剧情聊天的语言模型方案。
在有一天的时候,[李鲁鲁](https://github.com/LC1332)被[董雄毅](https://github.com/E-sion)在[这个B站视频](https://www.bilibili.com/video/BV1zh4y1z7G1) at了
原来是一位大一的同学雄毅用ChatHaruhi接入了他用Python重新实现的《主播女孩重度依赖》这个游戏。当时正好是百度AGIFoundathon报名的最后几天,所以我们邀请了雄毅加入了我们的项目。正巧我们本来就希望在最近的几个黑客松中,探索LLM在游戏中的应用。
- 在重新整理的Gradio版本中,大部分代码由李鲁鲁实现
- 董雄毅负责了原版游戏的事件数据整理和新事件、选项、属性变化的生成
- [米唯实](https://github.com/hhhwmws0117)完成了文心一言的接入,并实现了部分gradio的功能。
- 队伍中还有冷子昂 主要参加了讨论
另外在挖坑的萝卜(Amy)的介绍下,我们还邀请了专业的大厂游戏策划Kanyo加入到队伍中,他对我们的策划也给出了很多建议。
另外感谢飞桨 & 文心一言团队对比赛的邀请和中间进行的讨论。
Chat凉宫春日主项目:
https://github.com/LC1332/Chat-Haruhi-Suzumiya
Needy分支项目:
https://github.com/LC1332/Needy-Haruhi
## 目前计划在11月争取完成的Feature
- [ ] 结局系统,原版结局系统
- [ ] 教程,教大家如何从aistudio获取token然后可以玩
- [ ] 游戏节奏进一步调整
- [ ] 事件的自由对话对属性影响的评估via LLM
- [ ] 进一步减少串扰"""
| import gradio as gr | |
| import os | |
| import time | |
| import random | |
# Module-level game objects shared by the Gradio callbacks below.
agent = Agent()
event_master = EventMaster(events)
chat_master = ChatMaster(memory_pool)
image_master = ImageMaster(image_pool)
chat_master.set_image_master(image_master)
event_master.set_image_master(image_master)
# NOTE(review): the callbacks receive the state machine state as a parameter
# (from the state_text textbox); this module-level value is not read by them.
state = "ShowMenu"
# Main-menu text appended whenever the game returns to the menu.
response = "1. 随机一个事件"
response += "\n" + "2. 自由聊天"
response += "\n\n" + "请选择一个选项: "
official_response = response
# When True, call_event_end follows an event's attribute change with the
# "evening stream" stress/darkness increase (call_add_stress).
add_stress_switch = True
# Comma-joined emojis of the memories retrieved for the latest reply.
# It must exist before the first reply because update_memory reads it; a
# bare module-level `global emoji_str` statement does not create the name.
emoji_str = ""
def call_showmenu(history, text, state,agent_text):
    """Append the main menu to the chat and switch to menu-choice parsing.

    Returns the standard callback tuple: (history, cleared textbox,
    "ParseMenuChoice", unchanged agent_text). ``text`` and the incoming
    ``state`` are ignored.
    """
    print("call showmenu")
    history += [(None, official_response)]
    cleared_box = gr.Textbox(value="", interactive=True)
    return history, cleared_box, "ParseMenuChoice", agent_text
# Index into `events` of the event currently being played; -1 means no event
# is in progress, so the next call_event_master draws a new one.
current_event_id = -1
# attribute_change string of the option chosen in the current event; ""
# means there is nothing for call_event_end to apply.
attr_change_str = ""
def call_add_stress(history, text, state,agent_text):
    """Apply the post-event "evening stream" penalty, then show the menu.

    The total penalty is one point per three history entries, clamped to
    1..10. Darkness takes a random share of at least 1 and Stress takes the
    rest (possibly 0). Returns the usual callback tuple with state
    "ParseMenuChoice" and the updated serialized agent.
    """
    print("call add_stress")
    total = min(10, max(1, int(len(history) / 3)))
    darkness_increase = random.randint(1, total)
    stress_increase = total - darkness_increase
    report = (
        "经过了晚上的直播\n糖糖的压力增加" + str(stress_increase) + "点\n"
        + "糖糖的黑暗增加" + str(darkness_increase) + "点\n\n"
        + official_response
    )
    history += [(None, report)]
    updated_agent = Agent(agent_text)
    updated_agent.apply_attribute_change({"Stress": stress_increase, "Darkness": darkness_increase})
    new_agent_text = updated_agent.save_to_str()
    return history, gr.Textbox(value="", interactive=True), "ParseMenuChoice", new_agent_text
def call_event_end(history, text, state,agent_text):
    """Settle the finished event and return to the menu.

    If the chosen option carried an attribute change (global
    ``attr_change_str``), it is parsed and applied to the agent. With
    ``add_stress_switch`` on, ``call_add_stress`` then applies the evening
    penalty and produces the result. The pending change is consumed, so a
    later event that ends with a free-form reply cannot re-apply it.
    """
    # TODO: richer event settlement.
    global attr_change_str
    print("call event_end")
    pending_attr_str = attr_change_str
    # Consume the pending change. Free-form replies do not set it, so a
    # stale value would otherwise be re-applied at the end of the next event.
    attr_change_str = ""
    response = ""
    if pending_attr_str != "":
        attr_change = parse_attribute_string(pending_attr_str)
        if len(attr_change) > 0:
            response = "发生属性改变:" + str(attr_change) + "\n\n"
            agent = Agent(agent_text)
            agent.apply_attribute_change(attr_change)
            agent_text = agent.save_to_str()
            response += "当前属性" + agent_text + "\n\n"
        if add_stress_switch:
            if response:
                history += [(None, response)]  # skip empty bubbles
            return call_add_stress(history, text, state,agent_text)
    # Keep any attribute-change report instead of overwriting it.
    response += "事件结束\n"
    response += official_response
    history += [(None, response)]
    state = "ParseMenuChoice"
    return history, gr.Textbox(value="", interactive=True), state,agent_text
def call_parse_menu_choice(history, text, state,agent_text):
    """Act on the menu choice the player just typed (the last history entry).

    "1" starts a new random event, "2" enters free chat, and anything
    containing quit/Quit/QUIT moves to "Quit". Any other input is answered
    with an error and the state is left unchanged.
    """
    global current_event_id
    print("call parse_menu_choice")
    choice = history[-1][0].strip()
    if choice == "1":
        current_event_id = -1  # clear the previous event so a new one is drawn
        return call_event_master(history, text, "EventMaster", agent_text)
    if choice == "2":
        state = "ChatMaster"
    elif any(word in choice for word in ("quit", "Quit", "QUIT")):
        state = "Quit"
    else:
        history += [(None, "无效的选项,请重新选择")]
    if state == "ChatMaster":
        history += [(None, "(请输入 阿P 说的话,或者输入Quit退出)")]
    elif state != "ParseMenuChoice":
        history += [(None, "Change State to " + state)]
    # Nothing else is appended for an invalid choice (previously an empty bubble).
    return history, gr.Textbox(value="", interactive=True), state,agent_text
def call_event_master(history, text, state,agent_text):
    """Drive one event: show it on first entry, then resolve the player's reply.

    First call (``current_event_id == -1``): draw an event for the current
    agent and show its prefix and numbered options. On the next call:
      - a valid option number shows that option's scripted reply and
        records its attribute change;
      - a non-numeric reply is answered by the LLM.
    Either way the event then ends via ``call_event_end``. An out-of-range
    number re-prompts without ending the event.
    """
    global current_event_id, attr_change_str, emoji_str
    print("call event master")
    agent = Agent(agent_text)
    if current_event_id == -1:
        current_event_id = event_master.get_random_event_id(agent)
        event = events[current_event_id]
        response = "糖糖:" + event["prefix"] + "\n\n--请输入数字进行选择,或者进行自由回复--\n\n"
        for number, option in enumerate(event["options"], start=1):
            response += "\n" + f"{number}. 阿p:{option['user']}"
        history += [(None, response)]
        return history, gr.Textbox(value="", interactive=True), state,agent_text

    user_input = history[-1][0].strip()
    event = events[current_event_id]
    options = event["options"]
    if user_input.isdigit():
        choice = int(user_input)
        # Options are numbered from 1; "0" must not wrap to options[-1].
        if not 1 <= choice <= len(options):
            # Keep the player's text (a str) as the chat message, not the int.
            history[-1] = (user_input, "输入的数字超出范围,请重新输入符合选项的数字")
            return history, gr.Textbox(value="", interactive=True), state,agent_text
        option = options[choice - 1]
        user_text = option["user"]
        reply = option["reply"]
        # TODO: update memory from the chosen option.
        history[-1] = (user_text, reply)
        # Scripted replies only try to attach an image half of the time.
        if random.random() < 0.5:
            image_path = image_master.try_get_image(user_text + " " + reply, agent)
            if image_path is not None:
                history += [(None, (image_path,))]
        attr_change_str = option["attribute_change"]
    else:
        # Free-form reply: seed the bot's dialogue history with the event
        # prefix so the LLM answers in the context of this event.
        needy_chatbot.dialogue_history = [(None, "糖糖:" + event["prefix"])]
        # A free reply has no scripted attribute change to apply.
        attr_change_str = ""
        response, emoji_str = get_chat_response_and_emoji( agent, memory_pool, user_input )
        history[-1] = (user_input,response)
        image_path = image_master.try_get_image(response, agent)
        if image_path is not None:
            history += [(None, (image_path,))]
    return call_event_end(history, text, "EventEnd", agent_text)
def call_chat_master(history, text, state,agent_text):
    """Answer one free-chat line, or go back to the menu on quit/Quit/QUIT.

    Also stores the retrieved memories' emojis in the global ``emoji_str``
    and may attach a matching image after the reply.
    """
    global emoji_str
    print("call chat master")
    agent = Agent(agent_text)
    user_input = history[-1][0].strip()
    if any(word in user_input for word in ("quit", "Quit", "QUIT")):
        history[-1] = (user_input, "返回主菜单\n" + official_response)
        return history, gr.Textbox(value="", interactive=True), "ShowMenu", agent_text
    response, emoji_str = get_chat_response_and_emoji( agent, memory_pool, user_input )
    history[-1] = (user_input, response)
    image_path = image_master.try_get_image(response, agent)
    if image_path is not None:
        history += [(None, (image_path,))]
    return history, gr.Textbox(value="", interactive=True), state,agent_text
def grcall_game_master(history, text, state,agent_text):
    """Gradio submit handler: record the player's text and dispatch on ``state``.

    States without a handler (e.g. "Quit") only append the text and return
    unchanged.
    """
    print("call game master")
    history += [(text, None)]
    handlers = {
        "ShowMenu": call_showmenu,
        "ParseMenuChoice": call_parse_menu_choice,
        "ChatMaster": call_chat_master,
        "EventMaster": call_event_master,
        "EventEnd": call_event_end,
    }
    handler = handlers.get(state)
    if handler is None:
        return history, gr.Textbox(value="", interactive=True), state,agent_text
    return handler(history, text, state, agent_text)
def add_file(history, file):
    """Return a new history with the uploaded file appended as a user message."""
    upload_entry = ((file.name,), None)
    return history + [upload_entry]
def bot(history):
    """Demo streaming responder: type a fixed reply into the last turn.

    Writes ``history[-1][1]`` in place one character at a time, 0.05 s
    apart, and yields the whole history after each character.
    """
    reply = "**That's cool!**"
    history[-1][1] = ""
    for ch in reply:
        history[-1][1] += ch
        time.sleep(0.05)
        yield history
def update_memory(state):
    """Return the emoji summary of the latest retrieved memories.

    Only meaningful while chatting ("ChatMaster") or inside an event
    ("EventMaster"); otherwise returns "". Also returns "" before any reply
    has set the global ``emoji_str``, instead of raising NameError.
    """
    if state in ("ChatMaster", "EventMaster"):
        return globals().get("emoji_str", "")
    return ""
def change_state(slider_stress, slider_darkness, slider_affection):
    """Build a fresh Agent from the three slider values and return it serialized."""
    agent = Agent()
    slider_values = (
        ("Stress", slider_stress),
        ("Darkness", slider_darkness),
        ("Affection", slider_affection),
    )
    for key, value in slider_values:
        agent[key] = value
    return agent.save_to_str()
def update_attribute_state(agent_text):
    """Parse ``agent_text`` and return (Stress, Darkness, Affection) as ints for the sliders."""
    agent = Agent(agent_text)
    return tuple(int(agent[key]) for key in ("Stress", "Darkness", "Affection"))
# Gradio UI with three tabs: the game chat ("Needy"), 糖糖's status (attribute
# sliders, state machine, serialized agent), and the project notes.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Chat凉宫春日_x_AI糖糖
Powered by 文心一言(3.5)版本
仍然在开发中, 细节见《项目作者和说明》
"""
    )
    with gr.Tab("Needy"):
        # Game transcript; the second avatar_images entry (bot side) is avatar.png.
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            bubble_full_width=False,
            height = 800,
            avatar_images=(None, ("avatar.png")),
        )
        with gr.Row():
            txt = gr.Textbox(
                scale=4,
                show_label=False,
                placeholder="输入任何字符开始游戏",
                container=False,
            )
            # btn = gr.UploadButton("📁", file_types=["image", "video", "audio"])
            submit_btr = gr.Button("回车")
        with gr.Row():
            # NOTE(review): hidden and not wired to update_memory or any event here.
            memory_emoji_text = gr.Textbox(label="糖糖当前的记忆", value = "",interactive = False, visible=False)
    with gr.Tab("糖糖的状态"):
        with gr.Row():
            # Copies agent_text into the sliders (update_attribute_state).
            update_attribute_button = gr.Button("同步状态条 | 改变Attribute前必按!")
        with gr.Row():
            default_agent_str = agent.save_to_str()
            slider_stress = gr.Slider(0, 100, step=1, label = "Stress")
            # NOTE(review): the three gr.State values below are never read or
            # written by any listener in this block.
            state_stress = gr.State(value=0)
            slider_darkness = gr.Slider(0, 100, step=1, label = "Darkness")
            state_darkness = gr.State(value=0)
            slider_affection = gr.Slider(0, 100, step=1, label = "Affection")
            state_affection = gr.State(value=0)
        with gr.Row():
            # Current state-machine state, passed into and returned from every game callback.
            state_text = gr.Textbox(label="整体状态机状态", value = "ShowMenu",interactive = False)
        with gr.Row():
            default_agent_str = agent.save_to_str()
            # Serialized agent attributes, passed into and returned from every game callback.
            agent_text = gr.Textbox(label="糖糖状态", value = default_agent_str,interactive = False)
    with gr.Tab("项目作者和说明"):
        gr.Markdown(markdown_str)
    # Releasing any slider rebuilds agent_text from a fresh Agent() with the
    # three slider values (change_state).
    slider_stress.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])
    slider_darkness.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])
    slider_affection.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])
    update_attribute_button.click(update_attribute_state, inputs = [agent_text], outputs = [slider_stress, slider_darkness, slider_affection])
    # Enter in the textbox and the "回车" button both run one game step.
    txt_msg = txt.submit(grcall_game_master, \
        [chatbot, txt, state_text,agent_text], \
        [chatbot, txt, state_text,agent_text], queue=False)
    txt_msg = submit_btr.click(grcall_game_master, \
        [chatbot, txt, state_text,agent_text], \
        [chatbot, txt, state_text,agent_text], queue=False)
    # txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
    # bot, chatbot, chatbot, api_name="bot_response"
    # )
    # txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    # file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
    # bot, chatbot, chatbot
    # )
# Enable request queuing, then start the app. avatar.png is listed in
# allowed_paths so Gradio may serve it as the chatbot avatar.
demo.queue()
# if __name__ == "__main__":
demo.launch(allowed_paths=["avatar.png"],debug = True)