"""Retrieval nodes for LangGraph workflows.""" from typing import Dict, Optional from graphs.state import AgentState from retrievers.supabase import get_retriever, get_retriever_for, format_documents def retrieve_catalogue(state: AgentState) -> AgentState: """Builds a query from the state and retrieves formation and prestation docs. Returns the augmented state with docs and formatted contexts. """ query_text = state.get("query") or "" if not query_text: for msg in reversed(list(state.get("messages", []))): if getattr(msg, "type", "") == "human": query_text = (msg.content or "").strip() break formation_retriever = get_retriever("formation", k=8) prestation_retriever = get_retriever("prestation", k=8) formation_docs = formation_retriever.invoke(query_text) prestation_docs = prestation_retriever.invoke(query_text) new_formation_context = format_documents(formation_docs, "formation") new_prestation_context = format_documents(prestation_docs, "prestation") # Merge with any existing contexts (e.g., from project retrieval) to keep both old_formation_context = state.get("formation_context", "") old_prestation_context = state.get("prestation_context", "") formation_context = ( (old_formation_context + "\n\n---\n\n" + new_formation_context).strip() if old_formation_context else new_formation_context ) prestation_context = ( (old_prestation_context + "\n\n---\n\n" + new_prestation_context).strip() if old_prestation_context else new_prestation_context ) return { "formation_docs": formation_docs, "prestation_docs": prestation_docs, "formation_context": formation_context, "prestation_context": prestation_context, } def retrieve_projects(state: AgentState) -> AgentState: """Retrieve only project-scoped documents (formation and prestation) and add a system hint. Used when `project_id` is present to focus retrieval on the 'projects' vector index. """ # Extract user query query_text = state.get("query") or "" if not query_text: for msg in reversed(list(state.get("messages", []))): if getattr(msg, "type", "") == "human": query_text = (msg.content or "").strip() break project_id: Optional[str] = state.get("project_id") # type: ignore[assignment] index_name: Optional[str] = "projects" # type: ignore[assignment] # Safety: if no project_id, return state unchanged (router should avoid calling us) if not project_id: return {} extra = {"project_id": project_id} project_retriever = get_retriever_for(index_name, k=8, filter=extra) project_docs = project_retriever.invoke(query_text) projet_context = format_documents(project_docs, "project") return { "project_docs": project_docs, "project_context": projet_context, }