"""Agent service for executing LangGraph agents.""" from typing import Optional, AsyncIterator, List, Dict, Any import time from langchain_core.messages import AIMessageChunk, HumanMessage, AIMessage, BaseMessage, SystemMessage from langchain_core.language_models.chat_models import BaseChatModel from langgraph.checkpoint.memory import MemorySaver from domain.enums import ModelName, AgentType from .llm_service import llm_service from .agent_registry import agent_registry from services.postprocessing.registry import build_orchestrator from services.postprocessing.context import RunContext # Shared checkpointer for text completion so server-side memory (thread_id) is consistent across requests. _text_checkpointer = MemorySaver() class AgentService: """ Service for executing agent graphs with different LLMs. This service is the bridge between the API layer and the LangGraph agents. It handles: - Creating the right LLM based on model selection - Getting the right agent graph from the registry - Executing the graph with or without streaming """ def __init__(self): """Initialize the agent service.""" pass async def invoke( self, message: str, model_name: ModelName, agent_type: AgentType = AgentType.SIMPLE, temperature: float = 0.7, max_tokens: Optional[int] = None, conversation_history: Optional[List[Dict[str, str]]] = None, project_id: Optional[str] = None, conversation_id: Optional[str] = None, ) -> dict: """ Invoke agent for a single response (non-streaming). Args: message: User message model_name: LLM model to use agent_type: Type of agent graph temperature: Sampling temperature max_tokens: Max tokens to generate conversation_history: Optional conversation history (ignored if conversation_id is set) project_id: Optional project id for retrieval conversation_id: Optional conversation id for server-side memory (thread_id) Returns: Response dictionary with content and metadata """ # Create LLM instance llm = llm_service.get_llm( model_name=model_name, temperature=temperature, streaming=False, max_tokens=max_tokens ) builder = agent_registry.get_builder(agent_type) use_memory = bool(conversation_id) graph = builder(llm, checkpointer=_text_checkpointer) if use_memory else builder(llm) if use_memory: messages = [HumanMessage(content=message)] config: Optional[Dict[str, Any]] = {"configurable": {"thread_id": conversation_id}} else: messages = self._prepare_messages(message, conversation_history) config = None # Execute graph with latency start_time = time.time() result = await graph.ainvoke( { "messages": messages, "query": message, "project_id": project_id }, config=config, ) latency_s = time.time() - start_time # Extract response response_message = result["messages"][-1] response_content = response_message.content # Prepare metadata and run post-processing pipeline usage = getattr(response_message, "usage_metadata", None) or {} usage_totals = self._normalize_usage(usage) usage_by_model = {model_name.value: usage_totals} ctx = RunContext( provider=model_name.provider.value, model=model_name.value, usage_totals=usage_totals, usage_by_model=usage_by_model, latency_s=latency_s, ) build_orchestrator().run(ctx) base_metadata: Dict[str, Any] = { "message_count": len(result["messages"]), } base_metadata.update(ctx.metadata_out) return { "response": response_content, "model": model_name.value, "agent_type": agent_type.value, "usage": usage, "metadata": base_metadata, } async def stream( self, message: str, model_name: ModelName, agent_type: AgentType = AgentType.SIMPLE, temperature: float = 0.7, max_tokens: Optional[int] = None, conversation_history: Optional[List[Dict[str, str]]] = None, project_id: Optional[str] = None, conversation_id: Optional[str] = None, ) -> AsyncIterator[dict]: """ Stream agent response token by token. Args: message: User message model_name: LLM model to use agent_type: Type of agent graph temperature: Sampling temperature max_tokens: Max tokens to generate conversation_history: Optional conversation history (ignored if conversation_id is set) project_id: Optional project id for retrieval conversation_id: Optional conversation id for server-side memory (thread_id) Yields: Dictionary chunks with content and metadata """ # Create LLM instance with streaming enabled llm = llm_service.get_llm( model_name=model_name, temperature=temperature, streaming=True, max_tokens=max_tokens ) builder = agent_registry.get_builder(agent_type) use_memory = bool(conversation_id) graph = builder(llm, checkpointer=_text_checkpointer) if use_memory else builder(llm) if use_memory: messages = [HumanMessage(content=message)] config = {"configurable": {"thread_id": conversation_id}} else: messages = self._prepare_messages(message, conversation_history) config = None # Track usage and latency for final emissions calculation usage_totals: Dict[str, int] = {} usage_by_model: Dict[str, Dict[str, int]] = {} start_time = time.time() documents = [] # Stream graph execution async for msg in graph.astream( { "messages": messages, "query": message, "project_id": project_id }, config=config, stream_mode=["messages", "updates"], ): # LangGraph may yield (node_name, message) tuples in messages mode event = None params = None # Only emit assistant outputs; ignore user/history echoes text: Optional[str] = None if msg[0] == "messages": chunk = msg[1] if isinstance(chunk, tuple) and len(chunk) == 2: # Prefer the BaseMessage element if present from langchain_core.messages import BaseMessage as _LCBaseMessage if isinstance(chunk[1], _LCBaseMessage): event = chunk[1] params = chunk[0] elif isinstance(chunk[0], _LCBaseMessage): event = chunk[0] params = chunk[1] else: # Fallback to second element by convention event = chunk[1] else: event = chunk if msg[0] == "updates": node = msg[1] # Ajout de contrôles pour éviter KeyError/TypeError et assurer l'existence de la clé/messages attendus messages = [] if isinstance(node, dict): summarizer_export = node.get("summarizer_export") if summarizer_export and isinstance(summarizer_export, dict): messages = summarizer_export.get("messages", []) # Get the latest message, if available, from the messages list last_message = messages[-1] if messages else None # On chaque message assistant (AIMessage), si metadata.document présent, l'ajouter au tableau documents if isinstance(last_message, AIMessage): text = self._extract_text_content(last_message.content) doc_meta = last_message.metadata.get("document") if last_message.metadata else None if doc_meta is not None: documents.append(doc_meta) if isinstance(event, AIMessageChunk): text = self._extract_text_content(event.content) # Capture usage if present on chunks try: chunk_usage = getattr(event, "usage_metadata", None) if isinstance(chunk_usage, dict): norm = self._normalize_usage(chunk_usage) model_id = self._extract_model_from_params(params) or model_name.value bucket = usage_by_model.setdefault(model_id, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) bucket["input_tokens"] += norm["input_tokens"] bucket["output_tokens"] += norm["output_tokens"] bucket["total_tokens"] += norm["total_tokens"] usage_totals["input_tokens"] = usage_totals.get("input_tokens", 0) + norm["input_tokens"] usage_totals["output_tokens"] = usage_totals.get("output_tokens", 0) + norm["output_tokens"] usage_totals["total_tokens"] = usage_totals.get("total_tokens", 0) + norm["total_tokens"] except Exception: pass else: # Not an assistant output we should stream (e.g., HumanMessage) continue if text: yield { "content": text, "done": False, "metadata": { "model": model_name.value, "agent_type": agent_type.value, "usage": usage_totals }, "documents": documents } # Compute latency and run post-processing pipeline for the final chunk latency_s = time.time() - start_time ctx = RunContext( provider=model_name.provider.value, model=model_name.value, usage_totals=usage_totals, usage_by_model=usage_by_model, latency_s=latency_s, ) build_orchestrator().run(ctx) # Send final chunk yield { "content": "", "done": True, "metadata": { "model": model_name.value, "agent_type": agent_type.value, "usage": usage_totals, "usage_by_model": usage_by_model, "latency_s": latency_s, **ctx.metadata_out }, "documents": documents } def _prepare_messages( self, message: str, conversation_history: Optional[List[Dict[str, str]]] = None ) -> List[BaseMessage]: """ Prepare messages list from user input and optional history. Args: message: Current user message conversation_history: Optional list of previous messages Returns: List of LangChain messages """ messages = [] # Add conversation history if provided if conversation_history: for msg in conversation_history: role = msg.get("role", "user") content = msg.get("content", "") if role == "user": messages.append(HumanMessage(content=content)) elif role == "assistant": messages.append(AIMessage(content=content)) elif role == "system": messages.append(SystemMessage(content=content)) # Add current message messages.append(HumanMessage(content=message)) return messages def _extract_text_content(self, content: object) -> Optional[str]: """ Normalize LangChain message content into a plain text string. Handles both string content and list-structured content with text parts. """ if content is None: return None if isinstance(content, str): return content if isinstance(content, list): # LangChain can represent content as a list of parts like # [{"type": "text", "text": "..."}, ...] text_parts: List[str] = [] for part in content: try: # Dict-like parts with type/text if isinstance(part, dict): if part.get("type") == "text" and isinstance(part.get("text"), str): text_parts.append(part["text"]) # Object-like parts with attributes elif hasattr(part, "type") and getattr(part, "type") == "text" and hasattr(part, "text"): value = getattr(part, "text") if isinstance(value, str): text_parts.append(value) except Exception: # Skip any malformed parts continue return "".join(text_parts) if text_parts else None # Fallback: unknown content structure return None def _normalize_usage(self, usage: Dict[str, Any]) -> Dict[str, int]: """Normalize usage keys to input/output/total integers. Supports variants like prompt_tokens/completion_tokens. """ try: input_val = usage.get("input_tokens") if not isinstance(input_val, (int, float)): input_val = usage.get("prompt_tokens", 0) output_val = usage.get("output_tokens") if not isinstance(output_val, (int, float)): output_val = usage.get("completion_tokens", 0) total_val = usage.get("total_tokens") if not isinstance(total_val, (int, float)): total_val = (int(input_val or 0)) + (int(output_val or 0)) return { "input_tokens": int(input_val or 0), "output_tokens": int(output_val or 0), "total_tokens": int(total_val or 0), } except Exception: return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} def _extract_model_from_params(self, params: Optional[Dict[str, Any]]) -> Optional[str]: """Best-effort extraction of model identifier from LangGraph params.""" if not isinstance(params, dict): return None keys = ("ls_model_name", "model", "model_name", "model_id", "name", "llm_model", "openai_model") for key in keys: val = params.get(key) if isinstance(val, str) and val: return val # Search likely nested containers for container_key in ("configuration", "config", "kwargs", "meta", "metadata"): sub = params.get(container_key) if isinstance(sub, dict): for key in keys: val = sub.get(key) if isinstance(val, str) and val: return val return None # Singleton instance agent_service = AgentService()