"""Classifier agent: decide routing between CLASSIC and SUMMARIZE.""" from typing import Callable from langchain_core.prompts import ChatPromptTemplate from langchain_core.language_models.chat_models import BaseChatModel from graphs.state import AgentState from graphs.models import QueryClassification from graphs.prompts import CLASSIFIER_SYSTEM_PROMPT def classifier_node(llm: BaseChatModel) -> Callable[[AgentState], AgentState]: print("Classifier node") prompt = ChatPromptTemplate.from_messages( [ ("system", CLASSIFIER_SYSTEM_PROMPT), ( "human", "Historique: {messages}\nQuestion: {query}", ), ] ) structured_llm = llm.with_structured_output(QueryClassification) chain = prompt | structured_llm def _run(graph_state: AgentState) -> AgentState: messages = graph_state.get("messages", []) # Prefer explicit state.query when provided, else last human message last_query = (graph_state.get("query") or "").strip() if not last_query: for msg in reversed(list(messages)): if getattr(msg, "type", "") == "human": last_query = (getattr(msg, "content", "") or "").strip() break try: result = chain.invoke({"messages": messages, "query": last_query}) except Exception as e: result = QueryClassification(classification="UNKNOWN", reasoning=str(e)) graph_state["classification"] = result return graph_state return _run