"""Retriever tools for V2 tool-calling agent.""" from typing import Any, Dict, List, Optional from langchain_core.documents import Document from langchain_core.tools import tool from langchain.tools import ToolRuntime from langgraph.config import get_stream_writer from retrievers.supabase import get_retriever from services.retrieval.prestations_v2 import get_prestations_v2_retriever from services.vectorstore_service import VectorStoreServiceError, list_project_chunks_paginated # Bornes communes pour le top-k des tools de retrieval (formations / prestations). _TOOL_K_DEFAULT = 5 _TOOL_K_MIN = 3 _TOOL_K_MAX = 10 # Pagination liste projet (scan complet des chunks, pas de recherche vectorielle). # Max 20 chunks par appel : enchaƮner les appels avec next_offset tant que has_more. _PROJECT_LIST_DEFAULT = 10 _PROJECT_LIST_MAX = 10 _META_ALLOWLIST = ( "title", "similarity", "source", "page_number", "type", "contact", "link", "url", ) def _clamp_k( k: int, default: int = _TOOL_K_DEFAULT, min_k: int = _TOOL_K_MIN, max_k: int = _TOOL_K_MAX, ) -> int: try: value = int(k) except Exception: value = default return max(min_k, min(max_k, value)) def _clamp_project_list_limit(limit: int) -> int: try: value = int(limit) except Exception: value = _PROJECT_LIST_DEFAULT return max(1, min(_PROJECT_LIST_MAX, value)) def _clamp_offset(offset: int) -> int: try: return max(0, int(offset)) except Exception: return 0 def _clamp_threshold( value: float, default: float = 0.5, lo: float = 0.0, hi: float = 1.0, ) -> float: try: v = float(value) except Exception: v = default return max(lo, min(hi, v)) def _trim_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: if not meta: return {} return { key: meta[key] for key in _META_ALLOWLIST if key in meta and meta[key] not in (None, "") } def _serialize_docs(docs: List[Document], max_items: int = 8) -> List[Dict[str, Any]]: items: List[Dict[str, Any]] = [] for doc in docs[:max_items]: items.append( { "page_content": doc.page_content or "", "metadata": _trim_metadata(doc.metadata), } ) return items @tool("search_formations") def search_formations(query: str, k: int = _TOOL_K_DEFAULT) -> Dict[str, Any]: """Search formation catalogue documents by semantic similarity. Use when user needs training recommendations or details. """ writer = get_stream_writer() writer({"kind": "tool", "tool": "search_formations", "message": f"Recherche de formations: {query}"}) top_k = _clamp_k(k) retriever = get_retriever("formation", k=top_k) docs = retriever.invoke(query or "") return { "tool": "search_formations", "count": len(docs), "items": _serialize_docs(docs, max_items=top_k), "applied": {"k": top_k, "query": query}, } @tool("search_prestations") def search_prestations( query: str, k: int = _TOOL_K_DEFAULT, score_threshold: float = 0.5, offset: int = 0, ) -> Dict[str, Any]: """Search service/prestation catalogue (V2: documents_v2, mistral-embed, match_documents_v2_full). Use when user needs service recommendations or details. """ writer = get_stream_writer() writer({"kind": "tool", "tool": "search_prestations", "message": f"Recherche de prestations: {query}"}) top_k = _clamp_k(k) safe_threshold = _clamp_threshold(score_threshold, default=0.5) safe_offset = _clamp_offset(offset) retriever = get_prestations_v2_retriever() docs = retriever.search( query or "", k=top_k, score_threshold=safe_threshold, offset=safe_offset, ) return { "tool": "search_prestations", "count": len(docs), "items": _serialize_docs(docs, max_items=top_k), "applied": { "k": top_k, "score_threshold": safe_threshold, "offset": safe_offset, }, } @tool("search_project_docs") def search_project_docs( runtime: ToolRuntime, project_id: Optional[str], offset: int = 0, limit: int = _PROJECT_LIST_DEFAULT, ) -> Dict[str, Any]: """List project knowledge-base chunks from Supabase (projects index only). Returns a page of chunks ordered by stable id. This is not semantic search: call again with ``next_offset`` from the response until ``has_more`` is false. ``project_id`` may be omitted: the server injects it from request context when present. """ if not project_id: return { "tool": "search_project_docs", "count": 0, "items": [], "error": "project_id is required", } writer = get_stream_writer() writer({"kind": "tool", "tool": "search_project_docs", "message": f"Recherche dans le projet"}) safe_limit = _clamp_project_list_limit(limit) safe_offset = _clamp_offset(offset) state = getattr(runtime, "state", None) raw_filters = ( state.get("sources") if isinstance(state, dict) else getattr(state, "sources", None) ) document_ids: Optional[List[str]] = None if isinstance(raw_filters, list): cleaned = [str(item).strip() for item in raw_filters if item and str(item).strip()] document_ids = cleaned or None try: payload = list_project_chunks_paginated( project_id, document_ids=document_ids, offset=safe_offset, limit=safe_limit, index_name="projects", ) except VectorStoreServiceError as exc: return { "tool": "search_project_docs", "count": 0, "items": [], "error": str(exc), } docs: List[Document] = payload["documents"] n = len(docs) serialized_items = _serialize_docs(docs, max_items=n) return { "tool": "search_project_docs", "count": n, "items": serialized_items, "offset": payload["offset"], "limit": payload["limit"], "has_more": payload["has_more"], "next_offset": payload["next_offset"], "applied_sources": document_ids, }