"""Vector store service wrapping SupabaseVectorStore and embeddings.

Centralizes initialization to keep routes/services clean and consistent.
"""

from __future__ import annotations

import os
import warnings
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_mistralai import MistralAIEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from supabase import create_client, Client

from config.settings import settings


class VectorStoreServiceError(Exception):
    """Raised when the service is misconfigured or called with invalid arguments."""


class PatchedSupabaseVectorStore(SupabaseVectorStore):
    """Compatibility patch across postgrest builder API variants.

    Overrides ``similarity_search_by_vector_with_relevance_scores`` so that RPC
    query parameters are set through ``_set_query_param``. That helper works
    whether the builder exposes ``params`` directly or nests them under
    ``request``.
    """

    @staticmethod
    def _set_query_param(query_builder: Any, key: str, value: Any) -> None:
        """Set a query param on both legacy and newer builder shapes.

        Raises:
            AttributeError: if the builder exposes neither ``params`` nor
                ``request.params``.
        """
        if hasattr(query_builder, "params"):
            query_builder.params = query_builder.params.set(key, value)
            return

        request_obj = getattr(query_builder, "request", None)
        if request_obj is not None and hasattr(request_obj, "params"):
            request_obj.params = request_obj.params.set(key, value)
            return

        raise AttributeError(
            f"Unsupported RPC query builder shape: {type(query_builder).__name__}"
        )

    def similarity_search_by_vector_with_relevance_scores(
        self,
        query: List[float],
        k: int,
        filter: Optional[Dict[str, Any]] = None,
        postgrest_filter: Optional[str] = None,
        score_threshold: Optional[float] = None,
    ) -> List[Tuple[Document, float]]:
        """Run the match RPC and return ``(Document, similarity)`` pairs.

        Args:
            query: Query embedding vector.
            k: Row limit sent to the RPC as the ``limit`` query param.
            filter: Metadata filter passed to ``match_args``. Entries shaped like
                ``{key: {"$in": [...]}}`` are also translated into a PostgREST
                filter expression.
            postgrest_filter: Raw filter expression. It is AND-combined with any
                translated ``$in`` filters.
            score_threshold: If set, results with similarity below it are
                dropped, and a warning is emitted when nothing survives.

        Returns:
            Matches whose ``content`` is non-empty, in the order the RPC
            returned them.
        """
        if filter:
            for key, value in filter.items():
                if isinstance(value, dict) and "$in" in value:
                    # NOTE(review): values are wrapped in single quotes but not
                    # escaped; a value containing "'" could break the expression.
                    values_str = ",".join(f"'{str(v)}'" for v in value["$in"])
                    new_filter = f"metadata->>{key} IN ({values_str})"
                    if postgrest_filter:
                        postgrest_filter = f"({postgrest_filter}) and ({new_filter})"
                    else:
                        postgrest_filter = new_filter

        match_documents_params = self.match_args(query, filter)
        query_builder = self._client.rpc(self.query_name, match_documents_params)

        if postgrest_filter:
            self._set_query_param(query_builder, "and", f"({postgrest_filter})")
        self._set_query_param(query_builder, "limit", k)

        res = query_builder.execute()

        match_result = [
            (
                Document(
                    metadata=search.get("metadata", {}),
                    page_content=search.get("content", ""),
                ),
                search.get("similarity", 0.0),
            )
            for search in res.data
            if search.get("content")
        ]

        if score_threshold is not None:
            match_result = [
                (doc, similarity)
                for doc, similarity in match_result
                if similarity >= score_threshold
            ]
            if len(match_result) == 0:
                warnings.warn(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )

        return match_result


@lru_cache(maxsize=1)
def _get_supabase_client() -> Client:
    """Return a cached Supabase client.

    The URL comes from ``settings.supabase_url``, else the ``SUPABASE_URL`` env
    var. The key comes from ``settings.supabase_key``, else the first set env
    var among ``SUPABASE_KEY``, ``SUPABASE_SERVICE_ROLE_KEY``,
    ``SUPABASE_ANON_KEY``, ``NEXT_PUBLIC_SUPABASE_ANON_KEY``.

    Raises:
        VectorStoreServiceError: if no URL or no key can be resolved.
    """
    url = settings.supabase_url or os.getenv("SUPABASE_URL")
    key = settings.supabase_key or (
        os.getenv("SUPABASE_KEY")
        or os.getenv("SUPABASE_SERVICE_ROLE_KEY")
        or os.getenv("SUPABASE_ANON_KEY")
        or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY")
    )
    if not url or not key:
        raise VectorStoreServiceError("SUPABASE_URL and a SUPABASE_*KEY env var are required.")
    return create_client(url, key)


@lru_cache(maxsize=1)
def _get_embeddings() -> MistralAIEmbeddings:
    """Return a cached ``mistral-embed`` embeddings client."""
    return MistralAIEmbeddings(model="mistral-embed", api_key=settings.mistralai_api_key)


def _resolve_table_and_query(
    index_name: Optional[str],
    table_name: Optional[str],
    query_name: Optional[str],
) -> Tuple[str, str]:
    """Resolve the ``(table, match-function)`` pair to operate on.

    Precedence:
      1. Both ``table_name`` and ``query_name`` given: used as-is.
      2. Otherwise the base pair comes from ``settings.vector_indexes[index_name]``
         when ``index_name`` is given, or from ``settings.supabase_table`` /
         ``settings.supabase_match_fn`` (backward-compatible default).
         A single explicit ``table_name`` or ``query_name`` then overrides the
         corresponding half of that base pair.

    Raises:
        VectorStoreServiceError: if ``index_name`` is unknown, or its entry lacks
            ``table`` / ``query_name``.
    """
    # Explicit beats implicit
    if table_name and query_name:
        return table_name, query_name

    if index_name:
        # Named logical index
        idx = settings.vector_indexes.get(index_name)
        if not idx:
            raise VectorStoreServiceError(
                f"Unknown vector index '{index_name}'. Configure settings.vector_indexes."
            )
        try:
            base_table, base_query = idx["table"], idx["query_name"]
        except KeyError as exc:
            raise VectorStoreServiceError(
                f"Vector index '{index_name}' must define 'table' and 'query_name'."
            ) from exc
    else:
        # Backward-compatible default
        base_table, base_query = settings.supabase_table, settings.supabase_match_fn

    # A lone explicit override must not be silently dropped (e.g. a delete
    # targeting table_name would otherwise hit the default table instead).
    return table_name or base_table, query_name or base_query


@lru_cache(maxsize=16)
def get_vector_store(
    index_name: Optional[str] = None,
    *,
    table_name: Optional[str] = None,
    query_name: Optional[str] = None,
) -> PatchedSupabaseVectorStore:
    """Return a vector store bound to the resolved table and match function.

    Instances are cached per distinct argument combination (up to 16). The
    Supabase client and embeddings are the shared cached instances.
    """
    client = _get_supabase_client()
    emb = _get_embeddings()
    table, query = _resolve_table_and_query(index_name, table_name, query_name)
    return PatchedSupabaseVectorStore(
        embedding=emb,
        client=client,
        table_name=table,
        query_name=query,
    )


def add_documents(
    documents: List[Document],
    index_name: Optional[str] = None,
    *,
    table_name: Optional[str] = None,
    query_name: Optional[str] = None,
) -> List[str]:
    """Add ``documents`` to the resolved vector store and return their IDs.

    Returns ``[]`` immediately (without creating a store) when ``documents``
    is empty.
    """
    if not documents:
        return []
    vs = get_vector_store(index_name, table_name=table_name, query_name=query_name)
    return vs.add_documents(documents)


def document_exists(
    project_id: str,
    document_id: str,
    index_name: Optional[str] = None,
    *,
    table_name: Optional[str] = None,
    query_name: Optional[str] = None,
) -> bool:
    """
    Check if at least one row exists for the given project_id and document_id.
    """
    client = _get_supabase_client()
    table, _ = _resolve_table_and_query(index_name, table_name, query_name)
    resp = (
        client.table(table)
        .select("id")
        .eq("project_id", project_id)
        .eq("document_id", document_id)
        .limit(1)
        .execute()
    )
    data = getattr(resp, "data", None)
    return bool(data)


def _chunk_row_to_document(row: Dict[str, Any]) -> Document:
    """Convert one chunk row into a ``Document``, folding row columns into metadata.

    ``id`` is added only if metadata lacks it; non-null ``document_id`` and
    ``project_id`` columns always overwrite the metadata values.
    """
    meta_raw = row.get("metadata")
    meta: Dict[str, Any] = dict(meta_raw) if isinstance(meta_raw, dict) else {}
    if row.get("id") is not None:
        meta.setdefault("id", row.get("id"))
    if row.get("document_id") is not None:
        meta["document_id"] = row.get("document_id")
    if row.get("project_id") is not None:
        meta["project_id"] = row.get("project_id")
    return Document(page_content=row.get("content") or "", metadata=meta)


def list_project_chunks_paginated(
    project_id: str,
    *,
    document_ids: Optional[List[str]] = None,
    offset: int = 0,
    limit: int = 20,
    index_name: str = "projects",
) -> Dict[str, Any]:
    """
    List all vector rows (chunks) for a project from the projects index table.

    Pagination is stable on ``id`` ascending. ``limit`` is clamped to 1..20 and
    ``offset`` to >= 0. Optional filter: ``document_id`` must be one of the
    provided IDs. Blank entries are ignored, and if none remain, no document
    filter is applied.

    Returns a dict with:
      * ``documents``: a LangChain ``Document`` list.
      * ``offset`` and ``limit``: the clamped values.
      * ``has_more``: exact, determined with a one-row look-ahead.
      * ``next_offset``: ``None`` when there are no further rows.

    Raises:
        VectorStoreServiceError: if ``index_name`` is not ``"projects"`` or
            ``project_id`` is blank.
    """
    if index_name != "projects":
        raise VectorStoreServiceError("list_project_chunks_paginated is only supported for index 'projects'.")
    if not (project_id or "").strip():
        raise VectorStoreServiceError("project_id is required.")

    safe_limit = max(1, min(20, int(limit)))
    safe_offset = max(0, int(offset))

    client = _get_supabase_client()
    table, _ = _resolve_table_and_query(index_name, None, None)

    query_builder = (
        client.table(table)
        .select("id,content,metadata,document_id,project_id")
        .eq("project_id", project_id.strip())
        .order("id", desc=False)
    )
    if document_ids:
        cleaned = [str(doc_id).strip() for doc_id in document_ids if doc_id and str(doc_id).strip()]
        if cleaned:
            query_builder = query_builder.in_("document_id", cleaned)

    # range() bounds are inclusive, so this requests safe_limit + 1 rows. The
    # extra row is only a look-ahead: it makes has_more exact instead of
    # reporting "more" whenever the final page happens to be exactly full.
    response = query_builder.range(safe_offset, safe_offset + safe_limit).execute()
    rows = getattr(response, "data", None) or []
    has_more = len(rows) > safe_limit
    rows = rows[:safe_limit]

    documents: List[Document] = [_chunk_row_to_document(row) for row in rows]
    next_offset = safe_offset + len(rows) if has_more else None
    return {
        "documents": documents,
        "offset": safe_offset,
        "limit": safe_limit,
        "has_more": has_more,
        "next_offset": next_offset,
    }


def delete_documents_by_document_id(
    project_id: str,
    document_id: str,
    index_name: Optional[str] = None,
    *,
    table_name: Optional[str] = None,
    query_name: Optional[str] = None,
) -> int:
    """
    Delete all rows matching project_id and document_id.

    Returns the number of rows in the delete response's ``data`` (0 if it is
    empty or missing), i.e. the number of deleted rows reported back.
    """
    client = _get_supabase_client()
    table, _ = _resolve_table_and_query(index_name, table_name, query_name)
    resp = (
        client.table(table)
        .delete()
        .eq("project_id", project_id)
        .eq("document_id", document_id)
        .execute()
    )
    data = getattr(resp, "data", None)
    return len(data) if data else 0