"""Agent service for executing LangGraph agents.""" from typing import Optional, AsyncIterator, List, Dict, Any import time from langchain_core.messages import AIMessageChunk, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage from langchain_core.language_models.chat_models import BaseChatModel from langgraph.checkpoint.memory import MemorySaver from domain.enums import ModelName from .llm_service import llm_service from .agent_registry import agent_registry from .usage_utils import normalize_usage from services.postprocessing.registry import build_orchestrator from services.postprocessing.context import RunContext from services.stream_payloads import normalize_custom_writer_payload, parse_stream_graph_item # Shared checkpointer for text completion so server-side memory (thread_id) is consistent across requests. _text_checkpointer = MemorySaver() class AgentService: """ Service for executing agent graphs with different LLMs. This service is the bridge between the API layer and the LangGraph agents. It handles: - Creating the right LLM based on model selection - Getting the right agent graph from the registry - Executing the graph with or without streaming """ def __init__(self): """Initialize the agent service.""" pass async def invoke( self, message: str, model_name: ModelName, agent: Optional[str] = None, temperature: float = 0.7, max_tokens: Optional[int] = None, conversation_history: Optional[List[Dict[str, str]]] = None, project_id: Optional[str] = None, sources: Optional[List[str]] = None, conversation_id: Optional[str] = None, ) -> dict: """ Invoke agent for a single response (non-streaming). Args: message: User message model_name: LLM model to use agent: Agent identifier. Defaults to "AGENT" when omitted. temperature: Sampling temperature max_tokens: Max tokens to generate conversation_history: Optional conversation history (ignored if conversation_id is set) project_id: Optional project id for retrieval conversation_id: Optional conversation id for server-side memory (thread_id) Returns: Response dictionary with content and metadata """ # Create LLM instance llm = llm_service.get_llm( model_name=model_name, temperature=temperature, streaming=False, max_tokens=max_tokens ) resolved_agent = agent_registry.resolve_agent_id(agent) builder = agent_registry.get_builder_for_request(agent=resolved_agent) use_memory = bool(conversation_id) graph = builder(llm, checkpointer=_text_checkpointer) if use_memory else builder(llm) if use_memory: messages = [HumanMessage(content=message)] config: Optional[Dict[str, Any]] = {"configurable": {"thread_id": conversation_id}} else: messages = self._prepare_messages(message, conversation_history) config = None # Execute graph with latency start_time = time.time() result = await graph.ainvoke( { "messages": messages, "query": message, "project_id": project_id, "sources": sources, }, config=config, ) latency_s = time.time() - start_time # Extract response response_message = result["messages"][-1] response_content = response_message.content # Prepare metadata and run post-processing pipeline usage = getattr(response_message, "usage_metadata", None) or {} usage_totals = normalize_usage(usage) usage_by_model = {model_name.value: usage_totals} ctx = RunContext( provider=model_name.provider.value, model=model_name.value, usage_totals=usage_totals, usage_by_model=usage_by_model, latency_s=latency_s, ) build_orchestrator().run(ctx) base_metadata: Dict[str, Any] = { "message_count": len(result["messages"]), } base_metadata.update(ctx.metadata_out) return { "response": response_content, "model": model_name.value, "agent": resolved_agent, "usage": usage, "metadata": base_metadata, } async def stream( self, message: str, model_name: ModelName, agent: Optional[str] = None, temperature: float = 0.7, max_tokens: Optional[int] = None, conversation_history: Optional[List[Dict[str, str]]] = None, project_id: Optional[str] = None, sources: Optional[List[str]] = None, conversation_id: Optional[str] = None, ) -> AsyncIterator[dict]: """ Stream agent response token by token. Args: message: User message model_name: LLM model to use agent: Agent identifier. Defaults to "AGENT" when omitted. temperature: Sampling temperature max_tokens: Max tokens to generate conversation_history: Optional conversation history (ignored if conversation_id is set) project_id: Optional project id for retrieval conversation_id: Optional conversation id for server-side memory (thread_id) Yields: Dictionary chunks with content and metadata """ # Create LLM instance with streaming enabled llm = llm_service.get_llm( model_name=model_name, temperature=temperature, streaming=True, max_tokens=max_tokens ) resolved_agent = agent_registry.resolve_agent_id(agent) builder = agent_registry.get_builder_for_request(agent=resolved_agent) use_memory = bool(conversation_id) graph = builder(llm, checkpointer=_text_checkpointer) if use_memory else builder(llm) if use_memory: messages = [HumanMessage(content=message)] config = {"configurable": {"thread_id": conversation_id}} else: messages = self._prepare_messages(message, conversation_history) config = None # Track usage and latency for final emissions calculation usage_totals: Dict[str, int] = {} usage_by_model: Dict[str, Dict[str, int]] = {} start_time = time.time() documents = [] # Stream graph execution for msg in graph.stream( { "messages": messages, "query": message, "project_id": project_id, "sources": sources, }, config=config, stream_mode=["messages", "updates", "custom"], ): mode, stream_payload = parse_stream_graph_item(msg) # LangGraph may yield (node_name, message) tuples in messages mode event = None params = None # Only emit assistant outputs; ignore user/history echoes text: Optional[str] = None if mode == "custom": custom_payload = normalize_custom_writer_payload(stream_payload) yield { "chunk_kind": "custom", "custom": custom_payload, "content": "", "done": False, "metadata": { "model": model_name.value, "agent": resolved_agent, "usage": usage_totals, }, "documents": list(documents), } continue if mode == "messages": chunk = stream_payload if isinstance(chunk, tuple) and len(chunk) == 2: # Prefer the BaseMessage element if present from langchain_core.messages import BaseMessage as _LCBaseMessage if isinstance(chunk[1], _LCBaseMessage): event = chunk[1] params = chunk[0] elif isinstance(chunk[0], _LCBaseMessage): event = chunk[0] params = chunk[1] else: # Fallback to second element by convention event = chunk[1] else: event = chunk if mode == "updates": node = stream_payload # Parse node updates generically to stay compatible across workflows. node_messages: List[Any] = [] if isinstance(node, dict): for node_payload in node.values(): if not isinstance(node_payload, dict): continue payload_messages = node_payload.get("messages", []) if isinstance(payload_messages, list): node_messages.extend(payload_messages) tool_documents = node_payload.get("documents", []) if isinstance(tool_documents, list): for doc in tool_documents: if doc is not None and doc not in documents: documents.append(doc) # Get the latest message, if available, from the messages list last_message = node_messages[-1] if node_messages else None # Keep assistant text for streaming when available. # if isinstance(last_message, AIMessage): # text = self._extract_text_content(last_message.content) if isinstance(last_message, (AIMessage, ToolMessage)): doc_meta = self._extract_document_metadata(last_message) if doc_meta is not None and doc_meta not in documents: documents.append(doc_meta) if isinstance(event, AIMessageChunk): tool_calls = getattr(event, "tool_calls", None) or [] tool_call_chunks = getattr(event, "tool_call_chunks", None) or [] # AIMessageChunk is a Pydantic model: `"tool_calls" in event` is always False; # use the attributes directly. Skip deltas that belong to a tool call stream. if not tool_calls and not tool_call_chunks: text = self._extract_text_content(event.content) # Capture usage if present on chunks try: chunk_usage = getattr(event, "usage_metadata", None) if isinstance(chunk_usage, dict): norm = normalize_usage(chunk_usage) model_id = self._extract_model_from_params(params) or model_name.value bucket = usage_by_model.setdefault(model_id, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}) bucket["input_tokens"] += norm["input_tokens"] bucket["output_tokens"] += norm["output_tokens"] bucket["total_tokens"] += norm["total_tokens"] usage_totals["input_tokens"] = usage_totals.get("input_tokens", 0) + norm["input_tokens"] usage_totals["output_tokens"] = usage_totals.get("output_tokens", 0) + norm["output_tokens"] usage_totals["total_tokens"] = usage_totals.get("total_tokens", 0) + norm["total_tokens"] except Exception: pass else: # Not an assistant output we should stream (e.g., HumanMessage) continue if text: yield { "chunk_kind": "text", "content": text, "done": False, "metadata": { "model": model_name.value, "agent": resolved_agent, "usage": usage_totals }, "documents": list(documents), } # Compute latency and run post-processing pipeline for the final chunk latency_s = time.time() - start_time ctx = RunContext( provider=model_name.provider.value, model=model_name.value, usage_totals=usage_totals, usage_by_model=usage_by_model, latency_s=latency_s, ) build_orchestrator().run(ctx) # Send final chunk yield { "chunk_kind": "final", "content": "", "done": True, "metadata": { "model": model_name.value, "agent": resolved_agent, "usage": usage_totals, "usage_by_model": usage_by_model, "latency_s": latency_s, **ctx.metadata_out }, "documents": list(documents), } def _prepare_messages( self, message: str, conversation_history: Optional[List[Dict[str, str]]] = None ) -> List[BaseMessage]: """ Prepare messages list from user input and optional history. Args: message: Current user message conversation_history: Optional list of previous messages Returns: List of LangChain messages """ messages = [] # Add conversation history if provided if conversation_history: for msg in conversation_history: role = msg.get("role", "user") content = msg.get("content", "") if role == "user": messages.append(HumanMessage(content=content)) elif role == "assistant": messages.append(AIMessage(content=content)) elif role == "system": messages.append(SystemMessage(content=content)) # Add current message messages.append(HumanMessage(content=message)) return messages def _extract_text_content(self, content: object) -> Optional[str]: """ Normalize LangChain message content into a plain text string. Handles both string content and list-structured content with text parts. """ if content is None: return None if isinstance(content, str): return content if isinstance(content, list): # LangChain can represent content as a list of parts like # [{"type": "text", "text": "..."}, ...] text_parts: List[str] = [] for part in content: try: # Dict-like parts with type/text if isinstance(part, dict): if part.get("type") == "text" and isinstance(part.get("text"), str): text_parts.append(part["text"]) # Object-like parts with attributes elif hasattr(part, "type") and getattr(part, "type") == "text" and hasattr(part, "text"): value = getattr(part, "text") if isinstance(value, str): text_parts.append(value) except Exception: # Skip any malformed parts continue return "".join(text_parts) if text_parts else None # Fallback: unknown content structure return None def _extract_model_from_params(self, params: Optional[Dict[str, Any]]) -> Optional[str]: """Best-effort extraction of model identifier from LangGraph params.""" if not isinstance(params, dict): return None keys = ("ls_model_name", "model", "model_name", "model_id", "name", "llm_model", "openai_model") for key in keys: val = params.get(key) if isinstance(val, str) and val: return val # Search likely nested containers for container_key in ("configuration", "config", "kwargs", "meta", "metadata"): sub = params.get(container_key) if isinstance(sub, dict): for key in keys: val = sub.get(key) if isinstance(val, str) and val: return val return None def _extract_document_metadata(self, message: object) -> Optional[Dict[str, Any]]: """Extract a normalized document payload from AI/Tool messages.""" if not isinstance(message, (AIMessage, ToolMessage)): return None metadata = getattr(message, "metadata", None) if isinstance(metadata, dict): doc = metadata.get("document") if isinstance(doc, dict): return doc artifact = getattr(message, "artifact", None) if isinstance(artifact, dict): doc = artifact.get("document") if isinstance(doc, dict): return doc additional_kwargs = getattr(message, "additional_kwargs", None) if isinstance(additional_kwargs, dict): meta = additional_kwargs.get("metadata") if isinstance(meta, dict): doc = meta.get("document") if isinstance(doc, dict): return doc return None # Singleton instance agent_service = AgentService()