"""LangChain Agent service with tool calling capabilities.""" import asyncio from typing import List, Dict, Any, Optional, AsyncGenerator from langchain_core.tools import BaseTool as LangChainBaseTool from langchain_openai import ChatOpenAI from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.messages import HumanMessage, AIMessage from ...utils.logger import get_logger logger = get_logger("agent_service") # Try to import langchain_classic with exception handling try: from langchain_classic.agents import AgentExecutor from langchain_classic.agents.tool_calling_agent.base import create_tool_calling_agent LANGCHAIN_CLASSIC_AVAILABLE = True except ImportError: logger.warning("langchain_classic not available. Agent functionality will be disabled.") AgentExecutor = None create_tool_calling_agent = None LANGCHAIN_CLASSIC_AVAILABLE = False from pydantic import BaseModel, Field from .base import BaseTool, ToolRegistry, ToolResult from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool from ..postgresql_tool_manager import get_postgresql_tool from ..mysql_tool_manager import get_mysql_tool from ...core.config import get_settings from ..agent_config import AgentConfigService class LangChainToolWrapper(LangChainBaseTool): """Wrapper to convert our BaseTool to LangChain tool.""" name: str = Field(...) description: str = Field(...) base_tool: BaseTool = Field(...) def __init__(self, base_tool: BaseTool, **kwargs): super().__init__( name=base_tool.get_name(), description=base_tool.get_description(), base_tool=base_tool, **kwargs ) def _run(self, *args, **kwargs) -> str: """Synchronous run method.""" # Handle both positional and keyword arguments if args: # If positional arguments are provided, convert them to kwargs # based on the tool's parameter names params = self.base_tool.get_parameters() for i, arg in enumerate(args): if i < len(params): kwargs[params[i].name] = arg # Run async method in sync context loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: result = loop.run_until_complete(self.base_tool.execute(**kwargs)) return self._format_result(result) finally: loop.close() async def _arun(self, *args, **kwargs) -> str: """Asynchronous run method.""" # Handle both positional and keyword arguments if args: # If positional arguments are provided, convert them to kwargs # based on the tool's parameter names params = self.base_tool.get_parameters() for i, arg in enumerate(args): if i < len(params): kwargs[params[i].name] = arg result = await self.base_tool.execute(**kwargs) return self._format_result(result) def _format_result(self, result: ToolResult) -> str: """Format tool result for LangChain.""" if result.success: if isinstance(result.result, dict) and "summary" in result.result: return result.result["summary"] return str(result.result) else: return f"Error: {result.error}" class AgentConfig(BaseModel): """Agent configuration.""" enabled_tools: List[str] = Field(default_factory=lambda: [ "calculator", "weather", "search", "datetime", "file", "generate_image", "postgresql_mcp", "mysql_mcp" ]) max_iterations: int = Field(default=10) temperature: float = Field(default=0.1) system_message: str = Field( default="You are a helpful AI assistant with access to various tools. " "Use the available tools to help answer user questions accurately. " "Always explain your reasoning and the tools you're using." ) verbose: bool = Field(default=True) class AgentService: """LangChain Agent service with tool calling capabilities.""" def __init__(self, db_session=None): self.settings = get_settings() self.tool_registry = ToolRegistry() self.config = AgentConfig() self.agent_executor: Optional[AgentExecutor] = None self.db_session = db_session self.config_service = AgentConfigService(db_session) if db_session else None self._initialize_tools() self._load_config() def _initialize_tools(self): """Initialize and register all available tools.""" tools = [ WeatherQueryTool(), TavilySearchTool(), DateTimeTool(), get_postgresql_tool(), # 使用单例PostgreSQL MCP工具 get_mysql_tool() # 使用单例MySQL MCP工具 ] for tool in tools: self.tool_registry.register(tool) logger.info(f"Registered tool: {tool.get_name()}") def _load_config(self): """Load configuration from database if available.""" if self.config_service: try: config_dict = self.config_service.get_config_dict() # Update config with database values for key, value in config_dict.items(): if hasattr(self.config, key): setattr(self.config, key, value) logger.info("Loaded agent configuration from database") except Exception as e: logger.warning(f"Failed to load config from database, using defaults: {str(e)}") def _get_enabled_tools(self) -> List[LangChainToolWrapper]: """Get list of enabled LangChain tools.""" enabled_tools = [] for tool_name in self.config.enabled_tools: tool = self.tool_registry.get_tool(tool_name) if tool: langchain_tool = LangChainToolWrapper(base_tool=tool) enabled_tools.append(langchain_tool) logger.debug(f"Enabled tool: {tool_name}") else: logger.warning(f"Tool not found: {tool_name}") return enabled_tools def _create_agent_executor(self) -> AgentExecutor: """Create LangChain agent executor.""" if not LANGCHAIN_CLASSIC_AVAILABLE: raise ValueError("Agent functionality is disabled because langchain_classic is not available.") # Get LLM configuration from ...core.llm import create_llm llm = create_llm() # Get enabled tools tools = self._get_enabled_tools() # Create prompt template prompt = ChatPromptTemplate.from_messages([ ("system", self.config.system_message), MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}"), MessagesPlaceholder(variable_name="agent_scratchpad") ]) # Create agent agent = create_tool_calling_agent(llm, tools, prompt) # Create agent executor agent_executor = AgentExecutor( agent=agent, tools=tools, max_iterations=self.config.max_iterations, verbose=self.config.verbose, return_intermediate_steps=True ) return agent_executor async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]: """Process chat message with agent.""" try: logger.info(f"Processing agent chat message: {message[:100]}...") # Create agent executor if not exists if not self.agent_executor: self.agent_executor = self._create_agent_executor() # Convert chat history to LangChain format langchain_history = [] if chat_history: for msg in chat_history: if msg["role"] == "user": langchain_history.append(HumanMessage(content=msg["content"])) elif msg["role"] == "assistant": langchain_history.append(AIMessage(content=msg["content"])) # Execute agent result = await self.agent_executor.ainvoke({ "input": message, "chat_history": langchain_history }) # Extract response and intermediate steps response = result["output"] intermediate_steps = result.get("intermediate_steps", []) # Format tool calls for response tool_calls = [] for step in intermediate_steps: if len(step) >= 2: action, observation = step[0], step[1] tool_calls.append({ "tool": action.tool, "input": action.tool_input, "output": observation }) logger.info(f"Agent response generated successfully with {len(tool_calls)} tool calls") return { "response": response, "tool_calls": tool_calls, "success": True } except Exception as e: logger.error(f"Agent chat error: {str(e)}", exc_info=True) return { "response": f"Sorry, I encountered an error: {str(e)}", "tool_calls": [], "success": False, "error": str(e) } async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]: """Process chat message with agent (streaming).""" tool_calls = [] # Initialize tool_calls at the beginning try: logger.info(f"Processing agent chat stream: {message[:100]}...") # Create agent executor if not exists if not self.agent_executor: self.agent_executor = self._create_agent_executor() # Convert chat history to LangChain format langchain_history = [] if chat_history: for msg in chat_history: if msg["role"] == "user": langchain_history.append(HumanMessage(content=msg["content"])) elif msg["role"] == "assistant": langchain_history.append(AIMessage(content=msg["content"])) # Yield initial status yield { "type": "status", "content": "🤖 开始分析您的请求...", "done": False } await asyncio.sleep(0.2) # Use astream_events for real streaming (if available) or fallback to simulation try: # Try to use streaming events if available async for event in self.agent_executor.astream_events( {"input": message, "chat_history": langchain_history}, version="v1" ): if event["event"] == "on_tool_start": tool_name = event["name"] yield { "type": "tool_start", "content": f"🔧 正在使用工具: {tool_name}", "tool_name": tool_name, "done": False } await asyncio.sleep(0.1) elif event["event"] == "on_tool_end": tool_name = event["name"] yield { "type": "tool_end", "content": f"✅ 工具 {tool_name} 执行完成", "tool_name": tool_name, "done": False } await asyncio.sleep(0.1) elif event["event"] == "on_chat_model_stream": chunk = event["data"]["chunk"] if hasattr(chunk, 'content') and chunk.content: yield { "type": "content", "content": chunk.content, "done": False } await asyncio.sleep(0.05) except Exception as stream_error: logger.warning(f"Streaming events not available, falling back to simulation: {stream_error}") # Fallback: Execute agent and simulate streaming result = await self.agent_executor.ainvoke({ "input": message, "chat_history": langchain_history }) # Extract response and intermediate steps response = result["output"] intermediate_steps = result.get("intermediate_steps", []) # Yield tool execution steps tool_calls = [] for i, step in enumerate(intermediate_steps): if len(step) >= 2: action, observation = step[0], step[1] tool_calls.append({ "tool": action.tool, "input": action.tool_input, "output": observation }) # Yield tool execution status yield { "type": "tool", "content": f"🔧 使用工具 {action.tool}: {str(action.tool_input)[:100]}...", "tool_name": action.tool, "tool_input": action.tool_input, "done": False } await asyncio.sleep(0.3) yield { "type": "tool_result", "content": f"✅ 工具结果: {str(observation)[:200]}...", "tool_name": action.tool, "done": False } await asyncio.sleep(0.2) # Yield thinking status yield { "type": "thinking", "content": "🤔 正在整理回答...", "done": False } await asyncio.sleep(0.3) # Yield the final response in chunks to simulate streaming words = response.split() current_content = "" for i, word in enumerate(words): current_content += word + " " # Yield every 2-3 words or at the end if (i + 1) % 2 == 0 or i == len(words) - 1: yield { "type": "response", "content": current_content.strip(), "tool_calls": tool_calls if i == len(words) - 1 else [], "done": i == len(words) - 1 } # Small delay to simulate typing if i < len(words) - 1: await asyncio.sleep(0.05) logger.info(f"Agent stream response completed with {len(tool_calls)} tool calls") except Exception as e: logger.error(f"Agent chat stream error: {str(e)}", exc_info=True) yield { "type": "error", "content": f"Sorry, I encountered an error: {str(e)}", "done": True } def update_config(self, config: Dict[str, Any]): """Update agent configuration.""" try: # Update configuration for key, value in config.items(): if hasattr(self.config, key): setattr(self.config, key, value) logger.info(f"Updated agent config: {key} = {value}") # Reset agent executor to apply new config self.agent_executor = None except Exception as e: logger.error(f"Error updating agent config: {str(e)}", exc_info=True) raise def load_config_from_db(self, config_id: Optional[int] = None): """Load configuration from database.""" if not self.config_service: logger.warning("No database session available for loading config") return try: config_dict = self.config_service.get_config_dict(config_id) self.update_config(config_dict) logger.info(f"Loaded configuration from database (ID: {config_id})") except Exception as e: logger.error(f"Error loading config from database: {str(e)}") raise def get_available_tools(self) -> List[Dict[str, Any]]: """Get list of available tools.""" tools = [] for tool_name, tool in self.tool_registry._tools.items(): tools.append({ "name": tool.get_name(), "description": tool.get_description(), "parameters": [{ "name": param.name, "type": param.type.value, "description": param.description, "required": param.required, "default": param.default, "enum": param.enum } for param in tool.get_parameters()], "enabled": tool_name in self.config.enabled_tools }) return tools def get_config(self) -> Dict[str, Any]: """Get current agent configuration.""" return self.config.dict() # Global agent service instance _agent_service: Optional[AgentService] = None def get_agent_service(db_session=None) -> Optional[AgentService]: """Get global agent service instance.""" global _agent_service if _agent_service is None: try: _agent_service = AgentService(db_session) except Exception as e: logger.warning(f"Failed to initialize AgentService: {str(e)}. Agent functionality will be disabled.") _agent_service = None elif db_session and _agent_service and not _agent_service.db_session: # Update with database session if not already set _agent_service.db_session = db_session _agent_service.config_service = AgentConfigService(db_session) _agent_service._load_config() return _agent_service