Cyril Dupland committed on
Commit
53e5530
·
1 Parent(s): 595f77d

Include project knowledge in Workflow

Browse files
api/routes/completion.py CHANGED
@@ -116,7 +116,8 @@ async def _complete(request: CompletionRequest) -> CompletionResponse:
116
  agent_type=request.agent_type,
117
  temperature=request.temperature,
118
  max_tokens=request.max_tokens,
119
- conversation_history=request.conversation_history
 
120
  )
121
 
122
  return CompletionResponse(**result)
@@ -141,7 +142,8 @@ async def _stream_completion(request: CompletionRequest) -> StreamingResponse:
141
  agent_type=request.agent_type,
142
  temperature=request.temperature,
143
  max_tokens=request.max_tokens,
144
- conversation_history=request.conversation_history
 
145
  ):
146
  # Format as SSE: "data: {json}\n\n"
147
  chunk_json = json.dumps(chunk, ensure_ascii=False)
 
116
  agent_type=request.agent_type,
117
  temperature=request.temperature,
118
  max_tokens=request.max_tokens,
119
+ conversation_history=request.conversation_history,
120
+ project_id=request.project_id
121
  )
122
 
123
  return CompletionResponse(**result)
 
142
  agent_type=request.agent_type,
143
  temperature=request.temperature,
144
  max_tokens=request.max_tokens,
145
+ conversation_history=request.conversation_history,
146
+ project_id=request.project_id,
147
  ):
148
  # Format as SSE: "data: {json}\n\n"
149
  chunk_json = json.dumps(chunk, ensure_ascii=False)
domain/models.py CHANGED
@@ -35,6 +35,8 @@ class CompletionRequest(BaseModel):
35
  default=None,
36
  description="Optional conversation history"
37
  )
 
 
38
 
39
 
40
  class CompletionResponse(BaseModel):
 
35
  default=None,
36
  description="Optional conversation history"
37
  )
38
+ # Project-scoped retrieval
39
+ project_id: Optional[str] = Field(default=None, description="Optional project id to scope retrieval")
40
 
41
 
42
  class CompletionResponse(BaseModel):
graphs/agents/chat_agent.py CHANGED
@@ -17,6 +17,19 @@ def chat_node(llm: BaseChatModel) -> Callable[[AgentState], AgentState]:
17
 
18
  formation_context = state.get("formation_context", "")
19
  prestation_context = state.get("prestation_context", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  if formation_context:
22
  sys_msgs.append(
 
17
 
18
  formation_context = state.get("formation_context", "")
19
  prestation_context = state.get("prestation_context", "")
20
+ project_context = state.get("project_context", "")
21
+
22
+
23
+ if project_context:
24
+ sys_msgs.append(
25
+ SystemMessage(
26
+ content=(
27
+ "CONTEXTE PROJET (extraits des documents du projet; n'utilise rien d'autre):\n\n"
28
+ f"{project_context}\n\n"
29
+ "Consignes projet: Ce contenu indique des informations complémentaires à prendre en compte pour répondre à la question. "
30
+ )
31
+ )
32
+ )
33
 
34
  if formation_context:
35
  sys_msgs.append(
graphs/base_graph.py CHANGED
@@ -23,6 +23,8 @@ class AgentState(TypedDict, total=False):
23
  prestation_docs: List[Document]
24
  formation_context: str
25
  prestation_context: str
 
 
26
 
27
 
28
  def create_simple_graph(llm: BaseChatModel):
 
23
  prestation_docs: List[Document]
24
  formation_context: str
25
  prestation_context: str
26
+ project_docs: List[Document]
27
+ project_context: str
28
 
29
 
30
  def create_simple_graph(llm: BaseChatModel):
graphs/nodes/retrieval.py CHANGED
@@ -1,11 +1,12 @@
1
  """Retrieval nodes for LangGraph workflows."""
2
- from typing import Dict
3
 
4
  from graphs.state import AgentState
5
- from retrievers.supabase import get_retriever, format_documents
 
6
 
7
 
8
- def retrieve_both_types(state: AgentState) -> AgentState:
9
  """Builds a query from the state and retrieves formation and prestation docs.
10
 
11
  Returns the augmented state with docs and formatted contexts.
@@ -18,14 +19,29 @@ def retrieve_both_types(state: AgentState) -> AgentState:
18
  query_text = (msg.content or "").strip()
19
  break
20
 
 
21
  formation_retriever = get_retriever("formation", k=8)
22
  prestation_retriever = get_retriever("prestation", k=8)
23
 
 
24
  formation_docs = formation_retriever.invoke(query_text)
25
  prestation_docs = prestation_retriever.invoke(query_text)
26
 
27
- formation_context = format_documents(formation_docs, "formation")
28
- prestation_context = format_documents(prestation_docs, "prestation")
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  return {
31
  "formation_docs": formation_docs,
@@ -35,3 +51,38 @@ def retrieve_both_types(state: AgentState) -> AgentState:
35
  }
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Retrieval nodes for LangGraph workflows."""
2
+ from typing import Dict, Optional
3
 
4
  from graphs.state import AgentState
5
+ from retrievers.supabase import get_retriever, get_retriever_for, format_documents
6
+ from langchain_core.messages import SystemMessage
7
 
8
 
9
+ def retrieve_catalogue(state: AgentState) -> AgentState:
10
  """Builds a query from the state and retrieves formation and prestation docs.
11
 
12
  Returns the augmented state with docs and formatted contexts.
 
19
  query_text = (msg.content or "").strip()
20
  break
21
 
22
+ # retrievers
23
  formation_retriever = get_retriever("formation", k=8)
24
  prestation_retriever = get_retriever("prestation", k=8)
25
 
26
+ # Invoke
27
  formation_docs = formation_retriever.invoke(query_text)
28
  prestation_docs = prestation_retriever.invoke(query_text)
29
 
30
+ new_formation_context = format_documents(formation_docs, "formation")
31
+ new_prestation_context = format_documents(prestation_docs, "prestation")
32
+
33
+ # Merge with any existing contexts (e.g., from project retrieval) to keep both
34
+ old_formation_context = state.get("formation_context", "")
35
+ old_prestation_context = state.get("prestation_context", "")
36
+
37
+ formation_context = (
38
+ (old_formation_context + "\n\n---\n\n" + new_formation_context).strip()
39
+ if old_formation_context else new_formation_context
40
+ )
41
+ prestation_context = (
42
+ (old_prestation_context + "\n\n---\n\n" + new_prestation_context).strip()
43
+ if old_prestation_context else new_prestation_context
44
+ )
45
 
46
  return {
47
  "formation_docs": formation_docs,
 
51
  }
52
 
53
 
54
def retrieve_projects(state: AgentState) -> AgentState:
    """Retrieve the documents attached to the current project.

    Queries the ``"projects"`` index through ``get_retriever_for``, filtered
    on the state's ``project_id``, and formats the hits with
    ``format_documents``.

    Args:
        state: Graph state. Reads ``query`` (falling back to the content of
            the most recent human message in ``messages``) and ``project_id``.

    Returns:
        A partial state update holding ``project_docs`` and
        ``project_context``. Both are empty when ``project_id`` is missing.
    """
    # Prefer the explicit query; otherwise use the latest human message.
    query_text = state.get("query") or ""
    if not query_text:
        for msg in reversed(list(state.get("messages", []))):
            if getattr(msg, "type", "") == "human":
                query_text = (msg.content or "").strip()
                break

    project_id = state.get("project_id")

    # Safety net: the router is expected to skip this node when there is no
    # project_id. Write explicit empty values rather than an empty update,
    # since a node that writes no state key may be rejected by the graph
    # runtime (NOTE(review): version dependent -- confirm against the pinned
    # LangGraph release).
    if not project_id:
        return {"project_docs": [], "project_context": ""}

    project_retriever = get_retriever_for(
        "projects",
        k=8,
        filter={"project_id": project_id},
    )
    project_docs = project_retriever.invoke(query_text)

    return {
        "project_docs": project_docs,
        "project_context": format_documents(project_docs, "project"),
    }
87
+
88
+
graphs/state.py CHANGED
@@ -14,13 +14,17 @@ class AgentState(TypedDict, total=False):
14
  # Conversation
15
  messages: Annotated[Sequence[BaseMessage], add_messages]
16
  query: Optional[str]
 
 
 
17
 
18
  # RAG retrieval results
19
  formation_docs: List[Document]
20
  prestation_docs: List[Document]
21
  formation_context: str
22
  prestation_context: str
23
-
 
24
  # Summarization artifacts
25
  summary_markdown: str
26
  summary_pdf_path: str # local path or URL if uploaded
 
14
  # Conversation
15
  messages: Annotated[Sequence[BaseMessage], add_messages]
16
  query: Optional[str]
17
+ # Project scoping
18
+ project_id: Optional[str]
19
+ index_name: Optional[str]
20
 
21
  # RAG retrieval results
22
  formation_docs: List[Document]
23
  prestation_docs: List[Document]
24
  formation_context: str
25
  prestation_context: str
26
+ project_docs: List[Document]
27
+ project_context: str
28
  # Summarization artifacts
29
  summary_markdown: str
30
  summary_pdf_path: str # local path or URL if uploaded
graphs/workflows/orchestrated.py CHANGED
@@ -4,7 +4,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
4
 
5
  from graphs.state import AgentState
6
  from graphs.agents.classifier_agent import classifier_node
7
- from graphs.nodes.retrieval import retrieve_both_types
8
  from graphs.agents.chat_agent import chat_node
9
  from graphs.agents.summarizer_agent import summarizer_llm_node, summarizer_export_node
10
  # from tools.markdown import markdown_to_html
@@ -17,7 +17,14 @@ def create_orchestrated_graph(llm: BaseChatModel):
17
 
18
  # Nodes
19
  workflow.add_node("classify", classifier_node(llm))
20
- workflow.add_node("retrieve", retrieve_both_types)
 
 
 
 
 
 
 
21
  workflow.add_node("agent", chat_node(llm))
22
  workflow.add_node("summarizer_llm", summarizer_llm_node(llm))
23
  workflow.add_node(
@@ -37,15 +44,29 @@ def create_orchestrated_graph(llm: BaseChatModel):
37
  "classify",
38
  lambda s: getattr(s.get("classification"), "classification", "CLASSIC"),
39
  {
40
- "CLASSIC": "retrieve",
 
41
  "SUMMARIZE": "summarizer_llm",
42
- "UNKNOWN": "retrieve",
 
 
 
 
 
 
 
 
 
 
43
  },
44
  )
45
 
46
  # Linear branches
 
 
47
  workflow.add_edge("retrieve", "agent")
48
  workflow.add_edge("agent", END)
 
49
  workflow.add_edge("summarizer_llm", "summarizer_export")
50
  workflow.add_edge("summarizer_export", END)
51
 
 
4
 
5
  from graphs.state import AgentState
6
  from graphs.agents.classifier_agent import classifier_node
7
+ from graphs.nodes.retrieval import retrieve_catalogue, retrieve_projects
8
  from graphs.agents.chat_agent import chat_node
9
  from graphs.agents.summarizer_agent import summarizer_llm_node, summarizer_export_node
10
  # from tools.markdown import markdown_to_html
 
17
 
18
  # Nodes
19
  workflow.add_node("classify", classifier_node(llm))
20
+ workflow.add_node("retrieve", retrieve_catalogue)
21
+ # Route to classic vs project retrieval
22
+ def _router_passthrough(state: AgentState) -> AgentState:
23
+ # Must write at least one allowed key; pass through the current query
24
+ q = state.get("query") or ""
25
+ return {"query": q}
26
+ workflow.add_node("retrieve_router", _router_passthrough)
27
+ workflow.add_node("retrieve_project", retrieve_projects)
28
  workflow.add_node("agent", chat_node(llm))
29
  workflow.add_node("summarizer_llm", summarizer_llm_node(llm))
30
  workflow.add_node(
 
44
  "classify",
45
  lambda s: getattr(s.get("classification"), "classification", "CLASSIC"),
46
  {
47
+ # Route through a retrieval router to optionally branch to project retrieval
48
+ "CLASSIC": "retrieve_router",
49
  "SUMMARIZE": "summarizer_llm",
50
+ "UNKNOWN": "retrieve_router",
51
+ },
52
+ )
53
+
54
+ # Conditional choice between project vs classic retrieval
55
+ workflow.add_conditional_edges(
56
+ "retrieve_router",
57
+ lambda s: "PROJECT" if s.get("project_id") else "CLASSIC",
58
+ {
59
+ "PROJECT": "retrieve_project",
60
+ "CLASSIC": "retrieve",
61
  },
62
  )
63
 
64
  # Linear branches
65
+ # If project path is taken, run project retrieval then classic retrieval
66
+ workflow.add_edge("retrieve_project", "retrieve")
67
  workflow.add_edge("retrieve", "agent")
68
  workflow.add_edge("agent", END)
69
+
70
  workflow.add_edge("summarizer_llm", "summarizer_export")
71
  workflow.add_edge("summarizer_export", END)
72
 
retrievers/supabase.py CHANGED
@@ -5,6 +5,7 @@ import os
5
 
6
  from langchain_core.documents import Document
7
  from langchain_mistralai import MistralAIEmbeddings
 
8
  from langchain_community.vectorstores import SupabaseVectorStore
9
  from supabase import create_client, Client
10
 
@@ -31,7 +32,7 @@ def get_retriever(doc_type: str, k: Optional[int] = None):
31
 
32
  client: Client = create_client(url, key)
33
  vector_store = SupabaseVectorStore(
34
- embedding=MistralAIEmbeddings(model="mistral-embed", api_key=settings.mistralai_api_key),
35
  client=client,
36
  table_name=settings.supabase_table,
37
  query_name=settings.supabase_match_fn,
@@ -42,12 +43,17 @@ def get_retriever(doc_type: str, k: Optional[int] = None):
42
  )
43
 
44
 
45
- def get_retriever_for(index_name: str, doc_type: str, k: Optional[int] = None):
 
 
 
 
46
  """Return a retriever for a specific logical index (table/query pair)."""
47
  vector_store = get_vector_store(index_name=index_name)
48
  top_k = int(k or settings.rag_top_k)
 
49
  return vector_store.as_retriever(
50
- search_kwargs={"k": top_k, "filter": {"type": doc_type}}
51
  )
52
 
53
 
 
5
 
6
  from langchain_core.documents import Document
7
  from langchain_mistralai import MistralAIEmbeddings
8
+ from langchain_openai import OpenAIEmbeddings
9
  from langchain_community.vectorstores import SupabaseVectorStore
10
  from supabase import create_client, Client
11
 
 
32
 
33
  client: Client = create_client(url, key)
34
  vector_store = SupabaseVectorStore(
35
+ embedding=OpenAIEmbeddings(),
36
  client=client,
37
  table_name=settings.supabase_table,
38
  query_name=settings.supabase_match_fn,
 
43
  )
44
 
45
 
46
def get_retriever_for(
    index_name: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
):
    """Build a retriever on the vector store of a logical index.

    Args:
        index_name: Logical index name handed to ``get_vector_store``.
        k: Number of results to request. A falsy value falls back to
            ``settings.rag_top_k``.
        filter: Forwarded unchanged as the ``filter`` search kwarg.

    Returns:
        Whatever ``as_retriever`` returns for the resolved vector store.
    """
    store = get_vector_store(index_name=index_name)
    search_kwargs = {
        "k": int(k if k else settings.rag_top_k),
        "filter": filter,
    }
    return store.as_retriever(search_kwargs=search_kwargs)
58
 
59
 
services/agent_service.py CHANGED
@@ -32,7 +32,8 @@ class AgentService:
32
  agent_type: AgentType = AgentType.SIMPLE,
33
  temperature: float = 0.7,
34
  max_tokens: Optional[int] = None,
35
- conversation_history: Optional[List[Dict[str, str]]] = None
 
36
  ) -> dict:
37
  """
38
  Invoke agent for a single response (non-streaming).
@@ -65,7 +66,11 @@ class AgentService:
65
 
66
  # Execute graph with latency
67
  start_time = time.time()
68
- result = await graph.ainvoke({"messages": messages})
 
 
 
 
69
  latency_s = time.time() - start_time
70
 
71
  # Extract response
@@ -106,7 +111,8 @@ class AgentService:
106
  agent_type: AgentType = AgentType.SIMPLE,
107
  temperature: float = 0.7,
108
  max_tokens: Optional[int] = None,
109
- conversation_history: Optional[List[Dict[str, str]]] = None
 
110
  ) -> AsyncIterator[dict]:
111
  """
112
  Stream agent response token by token.
@@ -144,7 +150,11 @@ class AgentService:
144
  documents = []
145
 
146
  # Stream graph execution
147
- async for msg in graph.astream({"messages": messages}, stream_mode=["messages","updates"]):
 
 
 
 
148
  # LangGraph may yield (node_name, message) tuples in messages mode
149
  event = None
150
  params = None
 
32
  agent_type: AgentType = AgentType.SIMPLE,
33
  temperature: float = 0.7,
34
  max_tokens: Optional[int] = None,
35
+ conversation_history: Optional[List[Dict[str, str]]] = None,
36
+ project_id: Optional[str] = None
37
  ) -> dict:
38
  """
39
  Invoke agent for a single response (non-streaming).
 
66
 
67
  # Execute graph with latency
68
  start_time = time.time()
69
+ result = await graph.ainvoke({
70
+ "messages": messages,
71
+ "query": message,
72
+ "project_id": project_id
73
+ })
74
  latency_s = time.time() - start_time
75
 
76
  # Extract response
 
111
  agent_type: AgentType = AgentType.SIMPLE,
112
  temperature: float = 0.7,
113
  max_tokens: Optional[int] = None,
114
+ conversation_history: Optional[List[Dict[str, str]]] = None,
115
+ project_id: Optional[str] = None
116
  ) -> AsyncIterator[dict]:
117
  """
118
  Stream agent response token by token.
 
150
  documents = []
151
 
152
  # Stream graph execution
153
+ async for msg in graph.astream({
154
+ "messages": messages,
155
+ "query": message,
156
+ "project_id": project_id
157
+ }, stream_mode=["messages","updates"]):
158
  # LangGraph may yield (node_name, message) tuples in messages mode
159
  event = None
160
  params = None