"""FastAPI backend for the Local RAG Comparator.

Serves the static front end, manages uploaded documents and their PDF
thumbnails, triggers background indexing for the Page, Vector and TOC RAG
pipelines, and streams chat answers from each pipeline as server-sent events.
"""
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any

from fastapi import FastAPI, UploadFile, File, BackgroundTasks, Form
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn

# Page RAG imports
from src.page_rag.retriever import retrieve as page_retrieve
from src.page_rag.llm_engine import chat_with_rag as page_chat_with_rag, chat_no_context as page_chat_no_context
from src.page_rag.indexer import index_all_documents as page_index_all

# Vector RAG imports
from src.vector_rag.retriever import retrieve as vector_retrieve
from src.vector_rag.llm_engine import chat_with_rag as vector_chat_with_rag, chat_no_context as vector_chat_no_context
from src.vector_rag.indexer import index_all_documents as vector_index_all

# TOC RAG imports
from src.toc_rag.retriever import retrieve as toc_retrieve
from src.toc_rag.llm_engine import chat_with_rag as toc_chat_with_rag, chat_no_context as toc_chat_no_context
from src.toc_rag.indexer import index_all_documents as toc_index_all

from src.cost_tracker import add_chat_tokens, get_token_summary, clear_stats

app = FastAPI(title="Local RAG Comparator")

DOCS_DIR = Path("./documents")
THUMBNAILS_DIR = Path("./thumbnails")
TREES_DIR = Path("./db/trees")

# Ensure directories exist (parents are listed before their children).
for _directory in (DOCS_DIR, Path("./db"), TREES_DIR, Path("./static"), THUMBNAILS_DIR):
    _directory.mkdir(exist_ok=True)

# NOTE(review): kept after the directory setup, matching the original ordering,
# in case src.shared_db touches ./db at import time — confirm before moving.
from src.shared_db import get_collection

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/thumbnails", StaticFiles(directory="thumbnails"), name="thumbnails")

# Indexing state shared between request handlers and the background indexing task.
is_indexing = False
indexing_progress = ""

# Size of each piece copied from an upload to disk (1 MiB).
_UPLOAD_CHUNK_SIZE = 1024 * 1024


@app.get("/")
def read_index():
    """Serve the single-page front end."""
    return FileResponse("static/index.html")


# ── Documents ──────────────────────────────────────────────────────────────

def _thumbnail_path(pdf_path: Path) -> Path:
    """Return where the first-page PNG thumbnail for *pdf_path* is stored."""
    return THUMBNAILS_DIR / f"{pdf_path.stem}.png"


def _pdf_thumbnail_url(pdf_path: Path) -> str | None:
    """Return the thumbnail URL for *pdf_path*, rendering the PNG if missing.

    Rendering failures are logged and yield ``None`` so the document is still
    listed without a thumbnail.
    """
    thumb_file = _thumbnail_path(pdf_path)
    try:
        if not thumb_file.exists():
            import fitz  # PyMuPDF; only needed when a thumbnail must be rendered
            # The context manager closes the document handle, which was
            # previously left open on every render.
            with fitz.open(str(pdf_path)) as doc:
                if len(doc) > 0:
                    pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.2, 0.2))
                    pix.save(str(thumb_file))
        if thumb_file.exists():
            return f"/thumbnails/{pdf_path.stem}.png"
    except Exception as e:
        print("Thumbnail error:", e)
    return None


@app.get("/api/documents")
def list_documents():
    """List uploaded documents with size, lowercase extension and thumbnail URL.

    Returns ``{"documents": [{"name", "size", "type", "thumbnail"}, ...]}``;
    ``thumbnail`` is ``None`` for non-PDFs or when rendering fails.
    """
    files = []
    if not DOCS_DIR.exists():
        return {"documents": files}
    for f in DOCS_DIR.iterdir():
        if not f.is_file():
            continue
        ext = f.suffix.lower().replace(".", "")
        files.append({
            "name": f.name,
            "size": f.stat().st_size,
            "type": ext,
            "thumbnail": _pdf_thumbnail_url(f) if ext == "pdf" else None,
        })
    return {"documents": files}


def _fetch_chunks(
    collection_name: str,
    filename: str,
    fields: tuple[tuple[str, Any], ...],
    sort_key: str,
    label: str,
) -> list[dict[str, Any]]:
    """Fetch and format all chunks of *filename* stored in one collection.

    Each chunk becomes ``{"id", "text", <field>...}``, where every ``(field,
    default)`` in *fields* is read from the chunk's metadata. The result is
    sorted by *sort_key*. Errors are logged (using *label*) and whatever was
    collected so far is returned, so one broken index does not hide the others.
    """
    chunks: list[dict[str, Any]] = []
    try:
        results = get_collection(collection_name).get(where={"source": filename})
        if results and results["ids"]:
            documents = results["documents"]
            metadatas = results["metadatas"]
            for i, chunk_id in enumerate(results["ids"]):
                # Guard against a None metadata entry, which previously raised
                # AttributeError and aborted the whole collection.
                meta = (metadatas[i] or {}) if metadatas else {}
                chunks.append({
                    "id": chunk_id,
                    "text": documents[i] if documents else "",
                    **{key: meta.get(key, default) for key, default in fields},
                })
            chunks.sort(key=lambda chunk: chunk[sort_key])
    except Exception as e:
        print(f"Error fetching {label} chunks: {e}")
    return chunks


@app.get("/api/documents/{filename}/chunks")
def get_document_chunks(filename: str):
    """Fetch structured chunk data from all RAG indices for a specific document."""
    return {
        "filename": filename,
        "page_chunks": _fetch_chunks(
            "page_index", filename, (("page_num", 0), ("tokens", 0)), "page_num", "page"
        ),
        "vector_chunks": _fetch_chunks(
            "vector_index", filename, (("chunk_index", 0), ("tokens", 0)), "chunk_index", "vector"
        ),
        # NOTE(review): node_id is sorted as a string, so "10" sorts before "2"
        # if ids are plain numbers — confirm the node_id format.
        "toc_chunks": _fetch_chunks(
            "toc_index", filename, (("node_id", ""), ("title", ""), ("tokens", 0)), "node_id", "toc"
        ),
    }


@app.get("/api/documents/{filename}/tree")
def get_document_tree(filename: str):
    """Return the saved TOC tree JSON for a document, or a ``tree: None`` stub."""
    tree_path = TREES_DIR / f"{filename}.json"
    if tree_path.exists():
        with open(tree_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {"filename": filename, "tree": None}


@app.get("/api/index/status")
def get_index_status():
    """Report whether indexing is running and its latest progress message."""
    return {
        "is_indexing": is_indexing,
        "progress": indexing_progress,
    }


@app.get("/api/stats")
def get_stats():
    """Return the accumulated token-usage summary."""
    return get_token_summary()


# ── Chat ───────────────────────────────────────────────────────────────────

class ChatRequest(BaseModel):
    query: str
    top_k: int = 5
    chat_history: list[dict] | None = None


def _sse(payload: dict) -> str:
    """Format *payload* as one server-sent-events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _stream_rag_chat(
    req: ChatRequest,
    retrieve_fn: Callable,
    chat_with_rag_fn: Callable,
    chat_no_context_fn: Callable,
    describe_source: Callable[[Any], dict],
) -> StreamingResponse:
    """Stream one RAG pipeline's answer to *req* as server-sent events.

    The first event is ``{"type": "sources", ...}`` built with
    *describe_source*. It is followed by every chunk the LLM engine yields,
    unchanged. If nothing is retrieved, the no-context engine answers instead.
    ``stats`` chunks are also recorded with ``add_chat_tokens``. Any exception
    is reported as a final ``{"type": "error"}`` event, because the HTTP status
    has already been sent.
    """
    def generate():
        try:
            hits = retrieve_fn(req.query, top_k=req.top_k)
            yield _sse({"type": "sources", "sources": [describe_source(h) for h in hits]})
            if hits:
                generator = chat_with_rag_fn(req.query, hits, chat_history=req.chat_history)
            else:
                generator = chat_no_context_fn(req.query, chat_history=req.chat_history)
            for chunk in generator:
                if chunk.get("type") == "stats":
                    add_chat_tokens(chunk.get("prompt_eval_count", 0), chunk.get("eval_count", 0))
                yield _sse(chunk)
        except Exception as e:
            yield _sse({"type": "error", "content": str(e)})

    return StreamingResponse(generate(), media_type="text/event-stream")


@app.post("/api/chat/page")
def chat_page(req: ChatRequest):
    """Answer with Page RAG, streaming sources and LLM output as SSE."""
    return _stream_rag_chat(
        req,
        page_retrieve,
        page_chat_with_rag,
        page_chat_no_context,
        lambda p: {
            "source": p.source,
            "page_num": p.page_num,
            "score": p.score,
            "text": p.text,
            "tokens": p.tokens,
        },
    )


@app.post("/api/chat/vector")
def chat_vector(req: ChatRequest):
    """Answer with Vector RAG, streaming sources and LLM output as SSE."""
    return _stream_rag_chat(
        req,
        vector_retrieve,
        vector_chat_with_rag,
        vector_chat_no_context,
        lambda c: {
            "source": c.source,
            "chunk_index": c.chunk_index,
            "score": c.score,
            "mmr_score": c.mmr_score,
            "text": c.text,
            "tokens": c.tokens,
        },
    )


@app.post("/api/chat/toc")
def chat_toc(req: ChatRequest):
    """Answer with TOC RAG, streaming sources and LLM output as SSE."""
    return _stream_rag_chat(
        req,
        toc_retrieve,
        toc_chat_with_rag,
        toc_chat_no_context,
        lambda n: {
            "source": n.source,
            "node_id": n.node_id,
            "title": n.title,
            "score": n.score,
            "text": n.text,
            "tokens": n.tokens,
        },
    )


# ── Indexing ───────────────────────────────────────────────────────────────

class IndexRequest(BaseModel):
    rag_mode: str = "all"  # 'all', 'page', 'vector', 'toc'


# (mode keyword, indexer) in the order the indexers run.
_INDEXERS = (
    ("page", page_index_all),
    ("vector", vector_index_all),
    ("toc", toc_index_all),
)

# Progress messages emitted when each indexer starts.
_UPLOAD_START_MESSAGES = {
    "page": "Auto-indexing: Starting Page RAG...",
    "vector": "Auto-indexing: Starting Vector RAG...",
    "toc": "Auto-indexing: Starting TOC RAG...",
}
_MANUAL_START_MESSAGES = {
    "page": "Starting Page Indexing...",
    "vector": "Starting Vector Indexing...",
    "toc": "Starting TOC Indexing...",
}


def _set_progress(msg: str) -> None:
    """Publish *msg* as the current indexing progress and log it."""
    global indexing_progress
    indexing_progress = msg
    print(f"[Backend] {msg}")


def _run_indexing(
    rag_mode: str,
    start_messages: dict[str, str],
    done_message: str,
    error_prefix: str,
) -> None:
    """Run the indexers selected by *rag_mode*, then clear ``is_indexing``.

    Selection is by substring: "all" runs everything, and any mode keyword
    contained in *rag_mode* (e.g. "page,toc") runs that indexer. An exception
    stops the run and is published as ``error_prefix + str(e)``.
    """
    global is_indexing
    try:
        for mode, index_all in _INDEXERS:
            if "all" in rag_mode or mode in rag_mode:
                _set_progress(start_messages[mode])
                index_all(progress_callback=_set_progress)
        _set_progress(done_message)
    except Exception as e:
        _set_progress(f"{error_prefix}{e}")
    finally:
        is_indexing = False


@app.post("/api/upload")
async def upload_files(
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(...),
    rag_mode: str = Form("all"),
):
    """Save uploaded files to the documents directory and start indexing.

    Only the final path component of each client-supplied filename is used.
    Files without a usable name are skipped. If indexing is already running,
    the files are saved but no second run is started.
    """
    global is_indexing
    count = 0
    for file in files:
        # Keep only the base name so a crafted filename such as "../../x"
        # cannot write outside the documents directory.
        name = Path(file.filename or "").name
        if name in ("", ".", ".."):
            continue
        file_path = DOCS_DIR / name
        with open(file_path, "wb") as buffer:
            # Copy in pieces instead of loading the whole upload into memory.
            while chunk := await file.read(_UPLOAD_CHUNK_SIZE):
                buffer.write(chunk)
        # A replaced PDF must not keep serving the previous file's thumbnail.
        if file_path.suffix.lower() == ".pdf":
            _thumbnail_path(file_path).unlink(missing_ok=True)
        count += 1

    if is_indexing:
        # Starting a second run would let two indexers write the same stores.
        return {
            "message": f"Saved {count} files successfully. "
            "Indexing is already in progress; re-run indexing to include them."
        }

    is_indexing = True
    background_tasks.add_task(
        _run_indexing,
        rag_mode,
        _UPLOAD_START_MESSAGES,
        "Successfully indexed all documents.",
        "Error during indexing: ",
    )
    return {"message": f"Saved {count} files successfully. Indexing started automatically."}


@app.post("/api/index")
def index_documents(req: IndexRequest, background_tasks: BackgroundTasks):
    """Start background indexing for ``req.rag_mode`` unless a run is active."""
    global is_indexing
    if is_indexing:
        return {"message": "Indexing is already in progress."}

    is_indexing = True
    background_tasks.add_task(
        _run_indexing,
        req.rag_mode,
        _MANUAL_START_MESSAGES,
        "Indexing completed.",
        "Error: ",
    )
    return {"message": "Indexing started in the background."}


def _delete_files_in(directory: Path) -> None:
    """Delete every regular file directly inside *directory* (not recursive)."""
    if directory.exists():
        for entry in directory.iterdir():
            if entry.is_file():
                entry.unlink()


@app.post("/api/index/clear")
def clear_index():
    """Wipe the vector stores, token stats, documents, thumbnails and TOC trees.

    Refuses while indexing runs, since the indexers would be reading the
    documents and writing the stores being deleted.
    """
    from src.shared_db import clear_chroma_storage

    if is_indexing:
        return {"message": "Cannot clear data while indexing is in progress."}

    # Purge on-disk database files and re-initialize a fresh client
    clear_chroma_storage()
    # Clear any app-specific stats
    clear_stats()

    _delete_files_in(DOCS_DIR)
    _delete_files_in(THUMBNAILS_DIR)
    _delete_files_in(TREES_DIR)

    return {"message": "All data cleared: collections, storage directory, documents, thumbnails, and trees."}


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)