| from fastapi import FastAPI, UploadFile, File, HTTPException
|
| import shutil
|
| from fastapi.responses import StreamingResponse
|
| import json
|
| import os
|
| from pydantic import BaseModel
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from scripts.rag import RagPipeline
|
| from scripts.main import set_rag_instance
|
| from scripts.main import stream_chat_response
|
|
|
|
|
app = FastAPI(
    version='1.0',
    title='FinAI',
    description="A finetuned qwen model for financial QA with rag.",
)

# One pipeline instance is shared by the whole process: the /upload and
# /reset endpoints below mutate it, and it is handed to scripts.main via
# set_rag_instance (presumably so stream_chat_response can query it).
rag = RagPipeline()
set_rag_instance(rag)

# NOTE(review): Starlette's docs state that credentials are not meant to be
# combined with wildcard origins/methods/headers; consider listing explicit
# frontend origins instead of "*" before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class ChatRequest(BaseModel):
    """JSON body accepted by ``POST /chat/stream``."""

    # The user's chat message for this turn.
    message: str
    # Conversation identifier forwarded to stream_chat_response; requests
    # that omit it all share the "default" thread.
    thread_id: str = "default"
|
|
|
|
|
@app.get('/')
def health_check():
    """Liveness probe: answer with a fixed status payload."""
    payload = {'status': 'The api is live.'}
    return payload
|
|
|
|
|
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Replace the indexed corpus with a single uploaded PDF.

    The upload is written under ``temp_docs/``, loaded and split with the
    shared ``rag`` pipeline, and only then swapped in: existing documents
    are deleted and the new chunks are passed to ``rag.add_docs`` and
    ``rag.create_bm25``.

    Args:
        file: The multipart upload. Only names ending in ``.pdf``
            (case-insensitive) are accepted.

    Returns:
        A dict with ``status``, ``message``, the number of ``chunks``
        produced by ``rag.split_docs``, and the uploaded ``document`` name.

    Raises:
        HTTPException: 400 if the upload is not a PDF (or has no name);
            500 carrying the underlying error message if saving, parsing
            or indexing fails.
    """
    # Drop any directory components the client sent so the write cannot
    # escape temp_docs/ (e.g. a filename of "../../main.pdf"). A missing
    # filename becomes "" and is rejected by the extension check below.
    filename = os.path.basename(file.filename or "")

    # Validated outside the try block: otherwise the generic handler below
    # would turn this client error into a 500.
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(
            status_code=400,
            detail="Only PDF files allowed."
        )

    try:
        os.makedirs("temp_docs", exist_ok=True)
        file_path = os.path.join("temp_docs", filename)

        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Parse and chunk the new document before wiping the current index,
        # so a corrupt or unreadable PDF leaves the previous corpus in place.
        docs = rag.load_docs(file_path)
        split_docs = rag.split_docs(docs)

        rag.delete_all_docs()
        rag.add_docs(split_docs)
        rag.create_bm25(split_docs)

        chunk_count = len(split_docs)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e

    return {
        "status": "success",
        "message": "Document uploaded successfully",
        "chunks": chunk_count,
        "document": file.filename
    }
|
|
|
|
|
@app.delete("/reset")
async def reset_docs():
    """Remove every document from the shared RAG pipeline.

    Returns:
        A success payload once ``rag.delete_all_docs()`` has returned.

    Raises:
        HTTPException: 500 carrying the underlying error message if the
            deletion raises.
    """
    # NOTE(review): PDFs previously written to temp_docs/ by /upload are not
    # removed here; also confirm delete_all_docs clears whatever
    # rag.create_bm25 built.
    try:
        rag.delete_all_docs()
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=str(err)
        )

    return {"status": "success", "message": "All docs deleted"}
|
|
|
|
|
|
|
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a chat reply to the client as Server-Sent Events.

    Event sequence (each frame is ``data: <json>`` followed by a blank line):

    * ``{"type": "metadata", ...}`` once, built from the first chunk's
      ``metadata`` (``used_rag``, ``sources``, ``thread_id``);
    * ``{"type": "token", "content": ...}`` for every chunk's ``token``;
    * ``{"type": "done"}`` after the last chunk.

    If anything raises while producing the stream, a single
    ``{"type": "error", "message": ...}`` frame is sent instead of the
    remaining events.

    Args:
        request: The chat message and the thread id to continue.

    Returns:
        A ``text/event-stream`` StreamingResponse.
    """

    def _sse(payload):
        # Encode one SSE frame.
        return f"data: {json.dumps(payload)}\n\n"

    def event_generator():
        try:
            chunks = stream_chat_response(
                user_message=request.message,
                thread_id=request.thread_id
            )
            first = True
            for chunk in chunks:
                if first:
                    # Metadata is taken from the first chunk only.
                    meta = chunk["metadata"]
                    yield _sse({
                        "type": "metadata",
                        "used_rag": meta["used_rag"],
                        "sources": meta["sources"],
                        "thread_id": meta["thread_id"]
                    })
                    first = False
                yield _sse({"type": "token", "content": chunk["token"]})
            yield _sse({"type": "done"})
        except Exception as e:
            # Headers are already sent once streaming starts, so errors are
            # reported in-band as an SSE event rather than an HTTP status.
            yield _sse({"type": "error", "message": str(e)})

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
|
|