"""FastAPI service for a supportive mental-health chat companion.

Exposes three endpoints backed by Groq's OpenAI-compatible API:

* ``POST /chat``         -- one-shot reply plus an avatar expression label.
* ``POST /chat_stream``  -- the same reply streamed as server-sent events.
* ``POST /summarize``    -- short summary of the stored conversation.

Conversation history is kept in process memory only, keyed by the
client-supplied ``session_id``; it is lost on restart and is not shared
between worker processes.
"""

import json
import logging
import os
from typing import Dict, Iterator, List, Tuple

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from openai import OpenAI
from pydantic import BaseModel

from rag import retrieve

logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Groq client setup.
# The API key must come from the environment: a key committed to source
# control is effectively public and has to be rotated.
_GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not _GROQ_API_KEY:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "export it before starting the server."
    )

client = OpenAI(
    api_key=_GROQ_API_KEY,
    base_url="https://api.groq.com/openai/v1",
)

# Model and sampling settings shared by the endpoints.
_CHAT_MODEL = "llama-3.1-8b-instant"
_CHAT_PARAMS = {
    "model": _CHAT_MODEL,
    "max_tokens": 150,  # Keep responses relatively short
    "temperature": 0.6,
    "top_p": 0.9,
}

# Number of (user, assistant) turns kept per session and replayed to the
# model, to prevent context window overflow.
_MAX_HISTORY_TURNS = 5

# Terminator event sent at the end of every SSE stream.
_SSE_DONE = "data: [DONE]\n\n"

# --- Emotional Expression Logic ---
# Order matters: the first expression with a matching keyword wins, so
# higher-priority expressions (e.g. SAFETY) are listed first.
_EXPRESSION_KEYWORDS = {
    "SAFETY": ["concerned", "helpline", "reach out", "trusted person", "crisis", "emergency"],
    "COMFORTING": ["comfort", "here for you", "not alone", "support", "care"],
    "EMPATHETIC": ["understand", "hear you", "feel", "must be", "sounds"],
    "REFLECTIVE": ["think", "ponder", "reflect", "maybe", "perhaps"],
    "WARM": ["proud", "happy", "glad", "wonderful", "great", "good"],
    "STRESSED": ["overwhelmed", "anxious", "panic", "stress", "worry"],
    "TIRED": ["exhausted", "tired", "sleep", "fatigue", "drain"],
    "STEADY": ["breathe", "calm", "ground", "focus", "present"],
    "TALKING": ["tell me more", "go on", "listening", "share"],
    "NEUTRAL": ["yes", "no", "noted", "okay", "alright"],
}

# Markers after which model output is discarded by clean_response().
_BAD_PATTERNS = (
    "User:",
    "<|user|>",
    "<|assistant|>",
    "example",
    "response",
    "assistant:",
    "Sure,",
)


def get_bot_expression(response_text: str) -> str:
    """Return the expression label for a bot reply.

    The reply is lower-cased and checked for each keyword as a plain
    substring (so ``"stress"`` also matches ``"stressed"``). The first
    expression in ``_EXPRESSION_KEYWORDS`` with any match is returned;
    ``"DEFAULT"`` is returned when nothing matches.
    """
    text_lower = response_text.lower()
    for expression, keywords in _EXPRESSION_KEYWORDS.items():
        if any(kw in text_lower for kw in keywords):
            return expression
    return "DEFAULT"


class ChatRequest(BaseModel):
    """Request body for ``/chat`` and ``/chat_stream``."""

    session_id: str
    message: str


# session_id -> list of (user_message, assistant_reply) turns, oldest first.
sessions: Dict[str, List[Tuple[str, str]]] = {}


def get_messages(req_session_id: str, req_message: str):
    """Build the chat-completion message list for one user message.

    Creates the session on first use, looks up knowledge-base context via
    ``retrieve`` and, when that context is truthy, appends it to the
    system prompt. The most recent stored turns are replayed before the
    new user message.

    Returns:
        A ``(history, messages)`` tuple: the session's stored turn list
        (the same object held in ``sessions``) and the message dicts to
        send to the model.
    """
    history = sessions.setdefault(req_session_id, [])

    # Retrieve RAG context
    rag_context = retrieve(req_message)

    system_prompt = (
        "You are a supportive mental health companion.\n\n"
        "Speak directly to the user.\n\n"
        "Do not explain anything.\n"
        "Do not give examples.\n"
        "Do not describe your response.\n"
        "Do not mention being an assistant.\n\n"
        "Just reply naturally like a caring human."
    )
    if rag_context:
        system_prompt += (
            "\n\nUse the following knowledge base context to inform your "
            f"response if relevant:\n{rag_context}"
        )

    messages = [{"role": "system", "content": system_prompt}]

    # Keep only the last few conversation turns to prevent context window overflow
    for user_text, assistant_text in history[-_MAX_HISTORY_TURNS:]:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": req_message})
    return history, messages


def clean_response(text: str) -> str:
    """Truncate *text* at the first occurrence of any unwanted marker.

    Each pattern in ``_BAD_PATTERNS`` is applied in turn (case-sensitive,
    plain substring) and everything from the match onward is dropped. The
    remainder is stripped of surrounding whitespace, so the result may be
    an empty string when a marker appears at the very start.
    """
    for pattern in _BAD_PATTERNS:
        text = text.split(pattern)[0]
    return text.strip()


def _remember_turn(session_id: str, history, user_message: str, reply: str) -> None:
    """Append a completed turn to *history* and store the trimmed session."""
    history.append((user_message, reply))
    sessions[session_id] = history[-_MAX_HISTORY_TURNS:]


def _sse(payload: dict) -> str:
    """Format *payload* as a single server-sent ``data:`` event."""
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    """Answer one user message.

    Returns ``{"reply", "expression"}`` on success. If the model call
    fails, the exception is logged and ``{"error": <message>}`` is
    returned instead (the HTTP status is left at its default).
    """
    history, messages = get_messages(req.session_id, req.message)

    try:
        response = client.chat.completions.create(messages=messages, **_CHAT_PARAMS)
        cleaned_response = clean_response(response.choices[0].message.content)
    except Exception as e:
        # Top-level boundary: keep the error payload clients rely on, but
        # leave a server-side trace instead of swallowing the failure.
        logger.exception("Chat completion failed")
        return {"error": str(e)}

    _remember_turn(req.session_id, history, req.message, cleaned_response)
    return {
        "reply": cleaned_response,
        "expression": get_bot_expression(cleaned_response),
    }


@app.post("/chat_stream")
def chat_stream_endpoint(req: ChatRequest):
    """Answer one user message as a ``text/event-stream`` response.

    Emits one ``{"token": ...}`` event per generated chunk, then a final
    ``{"reply", "expression"}`` event with the cleaned text, then
    ``[DONE]``. On failure an ``{"error": ...}`` event followed by
    ``[DONE]`` is emitted instead.
    """
    history, messages = get_messages(req.session_id, req.message)

    def event_stream() -> Iterator[str]:
        try:
            response = client.chat.completions.create(
                messages=messages,
                stream=True,
                **_CHAT_PARAMS,
            )

            tokens = []
            for chunk in response:
                # Some chunks carry no choices or an empty delta; skip them.
                if chunk.choices and chunk.choices[0].delta.content:
                    token = chunk.choices[0].delta.content
                    tokens.append(token)
                    yield _sse({"token": token})

            # Streamed tokens are raw; the cleaned text is sent once at the end.
            cleaned_response = clean_response("".join(tokens))
            expression = get_bot_expression(cleaned_response)
            _remember_turn(req.session_id, history, req.message, cleaned_response)

            yield _sse({"reply": cleaned_response, "expression": expression})
            yield _SSE_DONE
        except Exception as e:
            logger.exception("Streaming chat completion failed")
            yield _sse({"error": str(e)})
            yield _SSE_DONE

    return StreamingResponse(event_stream(), media_type="text/event-stream")


class SummarizeRequest(BaseModel):
    """Request body for ``/summarize``."""

    session_id: str


@app.post("/summarize")
def summarize_endpoint(req: SummarizeRequest):
    """Summarize the stored conversation for a session.

    Returns a fixed friendly message when the session has no history,
    ``{"summary": ...}`` on success, or ``{"error": <message>}`` if the
    model call fails (the exception is logged).
    """
    history = sessions.get(req.session_id, [])
    if not history:
        return {"summary": "We haven't talked much yet, but I'm here if you need me!"}

    convo_text = "".join(
        f"User: {user_text}\nAssistant: {assistant_text}\n\n"
        for user_text, assistant_text in history
    )

    prompt = (
        "Please provide a brief, compassionate, and supportive 2-3 sentence "
        "summary of the following mental health conversation. Focus on the "
        "user's feelings and any positive steps or realizations. Do not be "
        "overly clinical, and do not use robotic terms.\n\n"
        f"Conversation:\n{convo_text}"
    )

    try:
        response = client.chat.completions.create(
            model=_CHAT_MODEL,
            messages=[{"role": "system", "content": prompt}],
            max_tokens=150,
            temperature=0.5,
            top_p=0.9,
        )
        summary_text = response.choices[0].message.content.strip()
    except Exception as e:
        logger.exception("Summary completion failed")
        return {"error": str(e)}

    # History is left intact after summarizing. To reset the session
    # instead, assign sessions[req.session_id] = [] here.
    return {"summary": summary_text}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)