hxf/backend/th_agenter/services/agent/agent_service.py

282 lines
11 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""LangChain Agent service with tool calling capabilities."""
import asyncio
from typing import List, Dict, Any, Optional, AsyncGenerator
2025-12-16 13:55:16 +08:00
from langchain.agents import create_agent
2025-12-04 14:48:38 +08:00
from langchain_core.messages import HumanMessage, AIMessage
2025-12-16 13:55:16 +08:00
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2025-12-04 14:48:38 +08:00
from pydantic import BaseModel, Field
from .base import BaseTool, ToolRegistry, ToolResult
from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool
from ..postgresql_tool_manager import get_postgresql_tool
from ..mysql_tool_manager import get_mysql_tool
from ...core.config import get_settings
from ..agent_config import AgentConfigService
2025-12-16 13:55:16 +08:00
from loguru import logger
2025-12-04 14:48:38 +08:00
class AgentConfig(BaseModel):
    """Settings that control how the agent is assembled.

    Every field has a default, so ``AgentConfig()`` is a usable configuration.
    Values may later be overwritten attribute-by-attribute by the service that
    owns the instance.
    """

    # Names looked up in the tool registry when the agent is built; a fresh
    # list is produced per instance so instances never share the default.
    enabled_tools: List[str] = Field(
        default_factory=lambda: [
            "calculator",
            "weather",
            "search",
            "datetime",
            "file",
            "generate_image",
            "postgresql_mcp",
            "mysql_mcp",
        ]
    )

    # NOTE(review): max_iterations, temperature and verbose are not read by
    # any code visible in this module - presumably consumed elsewhere; confirm.
    max_iterations: int = 10
    temperature: float = 0.1

    # System prompt placed at the head of the agent's prompt template.
    system_message: str = (
        "You are a helpful AI assistant with access to various tools. "
        "Use the available tools to help answer user questions accurately. "
        "Always explain your reasoning and the tools you're using."
    )

    verbose: bool = True
class AgentService:
    """LangChain Agent service with tool calling capabilities.

    Construction is two-phase: ``__init__`` only reads the application
    settings, while ``initialize`` (a coroutine) creates the tool registry and
    the configuration object. ``initialize`` must therefore be awaited before
    any other method is used.
    """

    # Pacing of the simulated stream produced by ``chat_stream``.
    _STREAM_WORDS_PER_CHUNK = 2
    _STREAM_WORD_DELAY = 0.05

    def __init__(self):
        """Read application settings; call ``initialize`` before use."""
        self.settings = get_settings()

    async def initialize(self, session=None):
        """Create the tool registry and configuration, then load DB overrides.

        Args:
            session: Optional database session handed to ``AgentConfigService``.
                When falsy, no config service is created and the defaults of
                ``AgentConfig`` are used unchanged.
        """
        self.tool_registry = ToolRegistry()
        self.config = AgentConfig()
        self.session = session
        self.config_service = AgentConfigService(session) if session else None
        self._initialize_tools()
        await self._load_config()

    def _initialize_tools(self):
        """Initialize and register all available tools."""
        tools = [
            WeatherQueryTool(),
            TavilySearchTool(),
            DateTimeTool(),
            get_postgresql_tool(),  # use the singleton PostgreSQL MCP tool
            get_mysql_tool(),  # use the singleton MySQL MCP tool
        ]
        for tool in tools:
            self.tool_registry.register(tool)
            logger.info(f"Registered tool: {tool.get_name()}")

    async def _load_config(self):
        """Load configuration from database if available.

        Only keys that already exist as attributes on ``self.config`` are
        applied. Any failure is logged and the current values are kept, so a
        broken database never prevents the service from starting.
        """
        if self.config_service:
            try:
                config_dict = await self.config_service.get_config_dict()
                # Update config with database values
                for key, value in config_dict.items():
                    if hasattr(self.config, key):
                        setattr(self.config, key, value)
            except Exception as e:
                logger.error(f"Failed to load config from database, using defaults: {str(e)}")

    def _get_enabled_tools(self) -> List[Any]:
        """Return the registered tools named in ``config.enabled_tools``.

        Names without a registered tool are skipped with a warning.
        """
        enabled_tools = []
        for tool_name in self.config.enabled_tools:
            tool = self.tool_registry.get_tool(tool_name)
            if tool:
                enabled_tools.append(tool)
                logger.debug(f"Enabled tool: {tool_name}")
            else:
                logger.warning(f"Tool not found: {tool_name}")
        return enabled_tools

    async def _create_agent_executor(self) -> Any:
        """Build a fresh agent from the current configuration.

        Returns:
            The object returned by ``create_agent``; callers invoke it through
            ``ainvoke``.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle - confirm).
        from ...core.new_agent import new_agent

        llm = await new_agent()
        tools = self._get_enabled_tools()

        prompt = ChatPromptTemplate.from_messages([
            ("system", self.config.system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # NOTE(review): the published LangChain 1.0 signature is
        # create_agent(model, tools, *, system_prompt=...), and the resulting
        # graph is invoked with {"messages": [...]}. Confirm that the
        # llm=/prompt= keywords and the {"input", "chat_history"} payload used
        # by chat()/chat_stream() match the installed version.
        agent = create_agent(
            llm=llm,
            tools=tools,
            prompt=prompt,
        )
        return agent

    @staticmethod
    def _to_langchain_history(chat_history: Optional[List[Dict[str, str]]]) -> List[Any]:
        """Convert ``[{"role": ..., "content": ...}]`` into LangChain messages.

        ``user`` entries become ``HumanMessage`` and ``assistant`` entries
        become ``AIMessage``; entries with any other role are ignored. ``None``
        or an empty list yields an empty history.
        """
        history = []
        for msg in chat_history or []:
            role = msg["role"]
            if role == "user":
                history.append(HumanMessage(content=msg["content"]))
            elif role == "assistant":
                history.append(AIMessage(content=msg["content"]))
        return history

    @staticmethod
    def _extract_response(result: Any) -> str:
        """Pull the assistant text out of an agent invocation result.

        Order of preference:
        1. ``result["output"]`` (executor-style result).
        2. The string ``content`` of the last entry in ``result["messages"]``
           (message-state style result).
        3. ``str(result)`` as a last resort.
        """
        if isinstance(result, dict):
            if "output" in result:
                return result["output"]
            messages = result.get("messages")
            if messages:
                content = getattr(messages[-1], "content", None)
                if isinstance(content, str):
                    return content
        return str(result)

    @staticmethod
    def _cumulative_chunks(text: str, words_per_chunk: int = 2) -> List[tuple]:
        """Split ``text`` into growing prefixes for simulated streaming.

        Each item is ``(content, is_last)`` where ``content`` is the response
        up to and including a word boundary. Prefixes are slices of the
        original text, so newlines and indentation inside the response are
        preserved. A prefix is produced every ``words_per_chunk`` words and at
        the final word. Text without any word still produces one final empty
        chunk so that consumers always receive a terminal event.
        """
        import re  # local import: only this helper needs it

        stripped = text.strip()
        word_ends = [match.end() for match in re.finditer(r"\S+", stripped)]
        if not word_ends:
            return [("", True)]
        last = len(word_ends) - 1
        return [
            (stripped[:end], idx == last)
            for idx, end in enumerate(word_ends)
            if (idx + 1) % words_per_chunk == 0 or idx == last
        ]

    async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
        """Process chat message with agent.

        Args:
            message: The user's message.
            chat_history: Previous turns as ``{"role", "content"}`` dicts.

        Returns:
            A dict with ``response``, ``tool_calls`` (always empty here) and
            ``success``. On failure ``success`` is ``False`` and an ``error``
            key carries the exception text; no exception is propagated.
        """
        try:
            logger.info(f"Processing agent chat message: {message[:100]}...")

            agent = await self._create_agent_executor()
            langchain_history = self._to_langchain_history(chat_history)

            result = await agent.ainvoke({
                "input": message,
                "chat_history": langchain_history,
            })

            logger.info("Agent response generated successfully")
            return {
                "response": self._extract_response(result),
                "tool_calls": [],
                "success": True,
            }
        except Exception as e:
            # loguru formats the message with str.format whenever kwargs are
            # passed, so the stdlib-style exc_info=True could itself raise on
            # error texts containing braces. logger.exception records the
            # traceback and only formats the fixed template.
            logger.exception("Agent chat error: {}", e)
            return {
                "response": f"Sorry, I encountered an error: {str(e)}",
                "tool_calls": [],
                "success": False,
                "error": str(e),
            }

    async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
        """Process chat message with agent (streaming).

        The agent is invoked once; its complete answer is then replayed as a
        sequence of growing prefixes to simulate streaming.

        Yields:
            First a ``status`` event, then ``response`` events whose
            ``content`` is the cumulative text so far. Exactly one event has
            ``done`` set to ``True``: the last ``response`` event, or an
            ``error`` event if anything fails.
        """
        tool_calls = []  # never populated in this implementation
        try:
            logger.info(f"Processing agent chat stream: {message[:100]}...")

            agent = await self._create_agent_executor()
            langchain_history = self._to_langchain_history(chat_history)

            # Yield initial status
            yield {
                "type": "status",
                "content": "🤖 开始分析您的请求...",
                "done": False
            }
            await asyncio.sleep(0.2)

            result = await agent.ainvoke({
                "input": message,
                "chat_history": langchain_history,
            })
            response_content = self._extract_response(result)

            chunk_delay = self._STREAM_WORD_DELAY * self._STREAM_WORDS_PER_CHUNK
            chunks = self._cumulative_chunks(response_content, self._STREAM_WORDS_PER_CHUNK)
            for content, is_last in chunks:
                yield {
                    "type": "response",
                    "content": content,
                    "tool_calls": tool_calls if is_last else [],
                    "done": is_last
                }
                # Small delay to simulate typing
                if not is_last:
                    await asyncio.sleep(chunk_delay)

            logger.info("Agent stream response completed")
        except Exception as e:
            # See chat(): avoid exc_info=True with loguru.
            logger.exception("Agent chat stream error: {}", e)
            yield {
                "type": "error",
                "content": f"Sorry, I encountered an error: {str(e)}",
                "done": True
            }

    def update_config(self, config: Dict[str, Any]):
        """Update agent configuration.

        Keys that are not existing attributes of the configuration object are
        ignored.

        Raises:
            Exception: Any error raised while applying a value is logged and
                re-raised.
        """
        try:
            for key, value in config.items():
                if hasattr(self.config, key):
                    setattr(self.config, key, value)
                    logger.info(f"Updated agent config: {key} = {value}")
        except Exception as e:
            # See chat(): avoid exc_info=True with loguru.
            logger.exception("Error updating agent config: {}", e)
            raise

    def load_config_from_db(self, config_id: Optional[int] = None):
        """Load configuration from database.

        Does nothing (beyond a warning) when no config service is available.

        Raises:
            Exception: Any error while fetching or applying the configuration
                is logged and re-raised.
        """
        if not self.config_service:
            logger.warning("No database session available for loading config")
            return
        try:
            # NOTE(review): _load_config awaits get_config_dict(), so it looks
            # like a coroutine function. If so, this synchronous call receives
            # an un-awaited coroutine and update_config() fails on .items().
            # Fixing it means making this method async, which changes the
            # public interface - confirm with callers.
            config_dict = self.config_service.get_config_dict(config_id)
            self.update_config(config_dict)
            logger.info(f"Loaded configuration from database (ID: {config_id})")
        except Exception as e:
            logger.error(f"Error loading config from database: {str(e)}")
            raise

    def get_available_tools(self) -> List[Dict[str, Any]]:
        """Describe every registered tool.

        Returns:
            One dict per registered tool with its name, description, parameter
            descriptions and whether it is listed in ``config.enabled_tools``.
        """
        enabled_names = self.config.enabled_tools
        return [
            {
                "name": tool.get_name(),
                "description": tool.get_description(),
                "parameters": [
                    {
                        "name": param.name,
                        "type": param.type.value,
                        "description": param.description,
                        "required": param.required,
                        "default": param.default,
                        "enum": param.enum,
                    }
                    for param in tool.get_parameters()
                ],
                "enabled": tool_name in enabled_names,
            }
            for tool_name, tool in self.tool_registry._tools.items()
        ]

    def get_config(self) -> Dict[str, Any]:
        """Get current agent configuration as a plain dict."""
        # Prefer model_dump() when the config object provides it (pydantic v2
        # deprecates .dict()); otherwise fall back to .dict().
        dump = getattr(self.config, "model_dump", None)
        if callable(dump):
            return dump()
        return self.config.dict()
# Global agent service instance, created lazily by get_agent_service().
_global_agent_service: Optional[AgentService] = None


async def get_agent_service(session=None) -> AgentService:
    """Get global agent service instance.

    The first call creates and initializes the service. Later calls return the
    same instance; if they pass a session that differs from the one currently
    attached, the service is re-pointed at the new session and its
    configuration is reloaded from the database.

    Args:
        session: Optional database session used for loading configuration.

    Returns:
        The process-wide ``AgentService`` instance.
    """
    global _global_agent_service
    if _global_agent_service is None:
        _global_agent_service = AgentService()
        await _global_agent_service.initialize(session)
    elif session and session != _global_agent_service.session:
        # Attach the new session and refresh the configuration from it.
        _global_agent_service.session = session
        _global_agent_service.config_service = AgentConfigService(session)
        # _load_config is a coroutine function: without `await` the call only
        # creates a coroutine object and the configuration is never reloaded.
        await _global_agent_service._load_config()
    return _global_agent_service