# FinAI / app.py
# junaid17's picture
# Upload 13 files
# ca67025 verified
from fastapi import FastAPI, UploadFile, File, HTTPException
import shutil
from fastapi.responses import StreamingResponse
import json
import os
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from scripts.rag import RagPipeline
from scripts.main import set_rag_instance
from scripts.main import stream_chat_response
# Application object, created with its version, title and description.
app = FastAPI(version='1.0', title='FinAI', description="A finetuned qwen model for financial QA with rag.")
# A single RagPipeline is built at import time; the endpoints in this module
# all operate on this one instance.
rag = RagPipeline()
# Hand the same instance to scripts.main — presumably so that
# stream_chat_response can reach it; verify against scripts/main.py.
set_rag_instance(rag)
# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat/stream`` endpoint."""

    # Text of the user's message; passed on as ``user_message``.
    message: str
    # Conversation identifier passed on as ``thread_id``; falls back to
    # "default" when the client does not send one.
    thread_id: str = "default"
@app.get('/')
def health_check():
    """Liveness probe: answer with a fixed status payload."""
    liveness = {'status': 'The api is live.'}
    return liveness
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Store an uploaded PDF under ``temp_docs/`` and index it with ``rag``.

    The upload is written to disk, then ``rag.delete_all_docs()`` is called,
    the file is loaded and split, and the resulting chunks are passed to
    ``rag.add_docs`` and ``rag.create_bm25``.

    Args:
        file: The multipart upload. Its filename must end with ``.pdf``.

    Returns:
        A dict with ``status``, ``message``, the number of ``chunks`` produced
        by the split step, and the stored ``document`` name.

    Raises:
        HTTPException: 400 when the filename is missing or does not end with
            ``.pdf``; 500 (with the underlying error text as detail) when
            saving or indexing fails.
    """
    # The filename is client-controlled: it can be absent and it can carry
    # directory components such as "../../x.pdf". Keep only the last path
    # component (treating "\" as a separator too) so the write below cannot
    # land outside temp_docs/.
    raw_name = file.filename or ""
    safe_name = os.path.basename(raw_name.replace("\\", "/"))

    # Validate before entering the try block: previously this 400 was raised
    # inside it, caught by the generic handler and re-reported as a 500.
    if not safe_name.endswith(".pdf"):
        raise HTTPException(
            status_code=400,
            detail="Only PDF files allowed."
        )

    # NOTE(review): the copy and the rag calls below are synchronous and run
    # inside an ``async def`` handler — confirm that blocking here is acceptable.
    try:
        os.makedirs("temp_docs", exist_ok=True)
        file_path = f"temp_docs/{safe_name}"
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        rag.delete_all_docs()
        docs = rag.load_docs(file_path)
        split_docs = rag.split_docs(docs)
        rag.add_docs(split_docs)
        rag.create_bm25(split_docs)
    except HTTPException:
        # Preserve any deliberate HTTP error instead of masking it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e

    return {
        "status": "success",
        "message": "Document uploaded successfully",
        "chunks": len(split_docs),
        "document": safe_name
    }
@app.delete("/reset")
async def reset_docs():
    """Call ``rag.delete_all_docs()`` and report the outcome.

    Returns a success payload when the call completes; any exception it raises
    is turned into an HTTP 500 whose detail is the exception text.
    """
    try:
        rag.delete_all_docs()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    outcome = {"status": "success", "message": "All docs deleted"}
    return outcome
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream the reply to a chat message as ``text/event-stream`` frames.

    Frames are JSON objects tagged by ``type``: one ``metadata`` frame built
    from the first chunk, a ``token`` frame for every chunk, then ``done``.
    If anything raises while streaming, a single ``error`` frame carrying the
    exception text is emitted instead and the stream ends.
    """

    def as_frame(payload):
        # One event: a "data:" line holding the JSON payload, then a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    def sse_frames():
        try:
            pieces = stream_chat_response(
                user_message=request.message,
                thread_id=request.thread_id,
            )
            for position, piece in enumerate(pieces):
                # The metadata frame goes out exactly once, ahead of the
                # first token.
                if position == 0:
                    meta = piece["metadata"]
                    yield as_frame({
                        "type": "metadata",
                        "used_rag": meta["used_rag"],
                        "sources": meta["sources"],
                        "thread_id": meta["thread_id"],
                    })
                yield as_frame({"type": "token", "content": piece["token"]})
            yield as_frame({"type": "done"})
        except Exception as exc:
            yield as_frame({"type": "error", "message": str(exc)})

    return StreamingResponse(sse_frames(), media_type="text/event-stream")