Spaces:
Sleeping
Sleeping
| # app.py ── 2025-06-08 适配 HuggingFace CPU Space | |
| import os, logging, gradio as gr | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain.chains import RetrievalQA | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from langchain.llms import HuggingFacePipeline | |
logging.basicConfig(level=logging.INFO)

# ========= 1. Load the local vector store =========
VEC_DIR = "vector_store"
# Validate the persisted Chroma store *before* constructing the embedding
# model: HuggingFaceEmbeddings may need to download/load a sentence-transformers
# model, which is wasted work if the index is missing and we abort anyway.
if not (
    os.path.isdir(VEC_DIR)
    and os.path.isfile(os.path.join(VEC_DIR, "chroma.sqlite3"))
):
    raise RuntimeError(f"❌ 未找到完整向量库 {VEC_DIR},请先执行 build_vector_store.py")

# Must be the same embedding model that was used to build the store.
# (Assumption: build_vector_store.py uses this model too — TODO confirm.)
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
vectordb = Chroma(persist_directory=VEC_DIR, embedding_function=embedder)
# ========= 2. Load a lightweight LLM =========
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B parameters; runs on CPU
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # requires `accelerate` in requirements
    torch_dtype="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    # Without do_sample=True generation is greedy and temperature/top_p
    # are ignored (transformers only emits a warning).
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    # Return only the newly generated text. With the default (True) the
    # prompt, including the retrieved context, can be echoed into the answer.
    return_full_text=False,
)
llm = HuggingFacePipeline(pipeline=generator)
# ========= 3. Build the RAG question-answering chain =========
# The "stuff" chain type places all retrieved documents (top 3 by vector
# similarity here) into a single prompt for the LLM.
qa_chain = RetrievalQA.from_chain_type(
    chain_type="stuff",
    llm=llm,
    retriever=vectordb.as_retriever(search_kwargs=dict(k=3)),
)
# ========= 4. Business functions =========
def simple_qa(user_query: str):
    """Answer a study question with the RetrievalQA chain.

    Args:
        user_query: The question typed by the user. ``None``, empty and
            whitespace-only input are all treated as "no question".

    Returns:
        The chain's answer text. Returns a user-facing warning string instead
        when the input is empty or the chain raises.
    """
    if not user_query or not user_query.strip():
        return "⚠️ 请输入学习问题,例如:什么是定积分?"
    try:
        # `Chain.run` is deprecated. `invoke` takes the chain's input dict
        # (RetrievalQA's input key is "query") and returns its output dict.
        return qa_chain.invoke({"query": user_query})["result"]
    except Exception as e:
        # UI boundary: keep the app responsive but log the full traceback.
        logging.exception("[QA ERROR] %s", e)
        return f"⚠️ 问答失败:{e}"
def generate_outline(topic: str):
    """Generate a structured study outline for ``topic`` from retrieved material.

    Retrieves the 3 most similar chunks from the vector store, puts them into
    a prompt asking for an outline in a fixed numbered format, and asks the
    LLM to complete it.

    Args:
        topic: Chapter or topic name. ``None``, empty and whitespace-only
            input are rejected.

    Returns:
        A ``(outline, snippet)`` tuple. ``snippet`` holds the retrieved chunks
        joined by ``---`` separators, for debugging. On empty input or any
        error, the first element is a warning and the second is ``""``.
    """
    if not topic or not topic.strip():
        return "⚠️ 请输入章节或主题", ""
    try:
        # Retrievers are Runnables; `get_relevant_documents` is deprecated.
        docs = vectordb.as_retriever(search_kwargs={"k": 3}).invoke(topic)
        snippet = "\n---\n".join(d.page_content for d in docs)
        prompt = (
            f"请基于以下资料,为「{topic}」生成结构化学习大纲,格式:\n"
            f"一、章节标题\n 1. 节标题\n (1)要点…\n\n"
            f"资料:\n{snippet}\n\n大纲:"
        )
        outline = llm.invoke(prompt)
        # Depending on the pipeline's `return_full_text` setting and the
        # langchain version, the result may start with the prompt itself.
        # Drop it so only the generated outline is shown.
        if outline.startswith(prompt):
            outline = outline[len(prompt):]
        return outline.strip(), snippet
    except Exception as e:
        # UI boundary: keep the app responsive but log the full traceback.
        logging.exception("[OUTLINE ERROR] %s", e)
        return "⚠️ 生成失败", ""
def placeholder(*_args, **_kwargs):
    """Return the stub text shown for features that are not implemented yet.

    Accepts and ignores any positional *and* keyword arguments. The previous
    signature ``(*_)`` made calls like ``placeholder(label=...)`` raise
    ``TypeError``.
    """
    return "功能待开发…"
# ========= 5. Gradio UI =========
with gr.Blocks(title="智能学习助手") as demo:
    gr.Markdown("# 智能学习助手 v2.0\n💡 大学生专业课 RAG Demo")
    with gr.Tabs():
        with gr.TabItem("智能问答"):
            chatbot = gr.Chatbot(height=350)
            msg = gr.Textbox(placeholder="在此提问…")

            def chat(m, hist):
                """Answer ``m`` and append the (question, answer) pair to the history.

                Returns ``""`` to clear the textbox, plus the updated history.
                """
                ans = simple_qa(m)
                # Copy into a fresh list, tolerating a None history, instead
                # of mutating the object Gradio passed in.
                hist = list(hist or [])
                hist.append((m, ans))
                return "", hist

            msg.submit(chat, [msg, chatbot], [msg, chatbot])
        with gr.TabItem("生成学习大纲"):
            topic = gr.Textbox(label="章节主题", placeholder="高等数学 第六章 定积分")
            outline = gr.Textbox(label="学习大纲", lines=12)
            debug = gr.Textbox(label="调试:检索片段", lines=6)
            gen = gr.Button("生成")
            gen.click(generate_outline, [topic], [outline, debug])
        with gr.TabItem("自动出题"):
            # `placeholder` returns a plain string, not a component, so render
            # it explicitly. Calling `placeholder(label=...)` built no UI and,
            # with a `(*_)` signature, raised TypeError at startup.
            gr.Markdown(placeholder())
        with gr.TabItem("答案批改"):
            gr.Markdown(placeholder())
    gr.Markdown("---\nPowered by LangChain • TinyLlama • Chroma")

if __name__ == "__main__":
    demo.launch()