hxf/backend/th_agenter/services/agent/base.py

244 lines
8.1 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""Base classes for Agent tools."""
import json
2025-12-16 13:55:16 +08:00
from loguru import logger
2025-12-04 14:48:38 +08:00
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Callable
from pydantic import BaseModel, Field
from dataclasses import dataclass
from enum import Enum
class ToolParameterType(str, Enum):
    """Closed set of value types a tool parameter can declare.

    Members also subclass ``str``, so each member compares equal to its
    plain string value.
    """

    # Textual value.
    STRING = "string"

    # Numeric values: whole numbers and floating point numbers.
    INTEGER = "integer"
    FLOAT = "float"

    # Truth value.
    BOOLEAN = "boolean"

    # Container values.
    ARRAY = "array"
    OBJECT = "object"
@dataclass
class ToolParameter:
    """Declarative description of a single argument accepted by a tool.

    Attributes:
        name: Keyword under which the argument is passed to the tool.
        type: Declared value type of the argument.
        description: Human-readable explanation of the argument.
        required: Whether the argument must be supplied by the caller.
        default: Fallback value; ``None`` means "no default".
        enum: Optional list of the only values the argument may take.
    """

    name: str
    type: ToolParameterType
    description: str
    required: bool = True
    default: Any = None
    enum: Optional[List[Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Build the JSON Schema fragment describing this parameter.

        ``name`` and ``required`` are not part of the fragment. ``enum`` is
        emitted only when it is non-empty and ``default`` only when it is
        not ``None``.

        Returns:
            A dict with ``type`` and ``description`` keys plus the optional
            ``enum`` and ``default`` keys.
        """
        # JSON Schema has no "float" primitive; floating point values are
        # declared as "number". Only the emitted schema is translated, the
        # enum member itself keeps its "float" value.
        if self.type == ToolParameterType.FLOAT:
            schema_type = "number"
        else:
            schema_type = self.type.value

        param_dict: Dict[str, Any] = {
            "type": schema_type,
            "description": self.description,
        }
        if self.enum:
            param_dict["enum"] = self.enum
        if self.default is not None:
            param_dict["default"] = self.default
        return param_dict
class ToolResult(BaseModel):
    """Outcome reported by a tool run."""

    # Flag telling whether the run succeeded.
    success: bool = Field(
        description="Whether the tool execution was successful",
    )

    # Payload produced by the tool; any type is accepted.
    result: Any = Field(
        description="The result data",
    )

    # Failure text; stays None unless a message is supplied.
    error: Optional[str] = Field(
        description="Error message if failed",
        default=None,
    )

    # Free-form extra information; stays None unless supplied.
    metadata: Optional[Dict[str, Any]] = Field(
        description="Additional metadata",
        default=None,
    )
class BaseTool(ABC):
    """Abstract base class for all Agent tools.

    Subclasses provide the tool name, description, parameter list and the
    asynchronous ``execute`` implementation. This base class builds the
    function-calling schema from those and performs basic validation and
    coercion of incoming arguments.
    """

    def __init__(self):
        """Cache the name, description and parameters supplied by the subclass."""
        self.name = self.get_name()
        self.description = self.get_description()
        self.parameters = self.get_parameters()

    @abstractmethod
    def get_name(self) -> str:
        """Get tool name."""

    @abstractmethod
    def get_description(self) -> str:
        """Get tool description."""

    @abstractmethod
    def get_parameters(self) -> List[ToolParameter]:
        """Get tool parameters."""

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool with given parameters."""

    def get_schema(self) -> Dict[str, Any]:
        """Get tool schema for LangChain.

        Returns:
            A ``{"type": "function", "function": {...}}`` dict whose
            ``parameters`` entry is an object schema built from
            ``self.parameters``; ``required`` lists the names of all
            parameters flagged as required.
        """
        properties = {param.name: param.to_dict() for param in self.parameters}
        required = [param.name for param in self.parameters if param.required]

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    @staticmethod
    def _coerce_value(param: ToolParameter, value: Any) -> Any:
        """Coerce a non-None ``value`` to the basic type declared by ``param``.

        Only INTEGER, FLOAT and BOOLEAN parameters are coerced; values for
        every other declared type are returned unchanged.

        Raises:
            ValueError: If an INTEGER or FLOAT value cannot be converted.
        """
        # bool is a subclass of int, so a plain isinstance(value, int) check
        # would let True/False through numeric parameters untouched. Treat
        # bools as "not yet a number" so they are normalised to 1/0.
        is_bool = isinstance(value, bool)

        if param.type == ToolParameterType.INTEGER:
            if is_bool or not isinstance(value, int):
                try:
                    return int(value)
                except (ValueError, TypeError) as exc:
                    raise ValueError(
                        f"Parameter '{param.name}' must be an integer"
                    ) from exc
        elif param.type == ToolParameterType.FLOAT:
            if is_bool or not isinstance(value, (int, float)):
                try:
                    return float(value)
                except (ValueError, TypeError) as exc:
                    raise ValueError(
                        f"Parameter '{param.name}' must be a number"
                    ) from exc
        elif param.type == ToolParameterType.BOOLEAN and not is_bool:
            if isinstance(value, str):
                # Only these spellings count as true; any other string is False.
                return value.lower() in ('true', '1', 'yes', 'on')
            return bool(value)

        return value

    def validate_parameters(self, **kwargs) -> Dict[str, Any]:
        """Validate and process input parameters.

        Keyword arguments that do not match a declared parameter are
        dropped. A required parameter must be supplied with a non-None
        value; the default is only consulted for optional parameters. An
        optional parameter that is omitted and has no default is returned
        as ``None`` and skips type and enum validation.

        Args:
            **kwargs: Raw arguments keyed by parameter name.

        Returns:
            A dict with one entry per declared parameter, holding the
            coerced value.

        Raises:
            ValueError: If a required parameter is missing, a numeric value
                cannot be converted, or a value is not among the allowed
                ``enum`` values.
        """
        validated: Dict[str, Any] = {}

        for param in self.parameters:
            value = kwargs.get(param.name)

            # Check required parameters
            if param.required and value is None:
                raise ValueError(f"Required parameter '{param.name}' is missing")

            # Use default if not provided
            if value is None and param.default is not None:
                value = param.default

            if value is not None:
                # Type validation (basic)
                value = self._coerce_value(param, value)

                # Enum validation runs on the coerced value.
                if param.enum and value not in param.enum:
                    raise ValueError(
                        f"Parameter '{param.name}' must be one of {param.enum}"
                    )

            validated[param.name] = value

        return validated
class ToolRegistry:
    """Registry for managing Agent tools.

    Tools are stored by name together with an enabled/disabled flag. Only
    enabled tools are exposed through the schema listing and can be run
    through ``execute_tool``.
    """

    def __init__(self):
        """Create an empty registry."""
        self._tools: Dict[str, BaseTool] = {}
        self._enabled_tools: Dict[str, bool] = {}

    def register(self, tool: BaseTool, enabled: bool = True) -> None:
        """Register a tool under the name returned by ``tool.get_name()``.

        Registering a second tool with the same name replaces the first.
        """
        tool_name = tool.get_name()
        self._tools[tool_name] = tool
        self._enabled_tools[tool_name] = enabled
        logger.info(f"Registered tool: {tool_name} (enabled: {enabled})")

    def unregister(self, tool_name: str) -> None:
        """Unregister a tool; unknown names are ignored."""
        if tool_name in self._tools:
            del self._tools[tool_name]
            del self._enabled_tools[tool_name]
            logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Get a tool by name, or ``None`` if it is not registered."""
        return self._tools.get(tool_name)

    def get_enabled_tools(self) -> Dict[str, BaseTool]:
        """Get all enabled tools as a new name-to-tool dict."""
        return {
            name: tool
            for name, tool in self._tools.items()
            if self._enabled_tools.get(name, False)
        }

    def get_all_tools(self) -> Dict[str, BaseTool]:
        """Get all registered tools as a shallow copy of the registry."""
        return self._tools.copy()

    def enable_tool(self, tool_name: str) -> None:
        """Enable a tool; unknown names are ignored."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = True
            logger.info(f"Enabled tool: {tool_name}")

    def disable_tool(self, tool_name: str) -> None:
        """Disable a tool; unknown names are ignored."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = False
            logger.info(f"Disabled tool: {tool_name}")

    def is_enabled(self, tool_name: str) -> bool:
        """Check if a tool is enabled; unknown names report ``False``."""
        return self._enabled_tools.get(tool_name, False)

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """Get schema for all enabled tools."""
        return [tool.get_schema() for tool in self.get_enabled_tools().values()]

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """Execute a tool with given parameters.

        Args:
            tool_name: Name of the registered tool to run.
            **kwargs: Raw arguments, validated by the tool before execution.

        Returns:
            Whatever the tool's ``execute`` returns. If the tool is unknown,
            disabled, or validation/execution raises, a ``ToolResult`` with
            ``success=False`` and an ``error`` message is returned instead.
        """
        tool = self.get_tool(tool_name)
        if tool is None:
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' not found"
            )

        if not self.is_enabled(tool_name):
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' is disabled"
            )

        try:
            # Validate parameters
            validated_params = tool.validate_parameters(**kwargs)

            # Execute tool
            result = await tool.execute(**validated_params)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return result
        except Exception as e:
            # loguru has no ``exc_info`` option: passing extra keyword
            # arguments makes it run str.format on the message, which raises
            # here if the exception text contains braces. Use a constant
            # template with positional values, and logger.exception so the
            # traceback is actually recorded.
            logger.exception("Tool '{}' execution failed: {}", tool_name, e)
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool execution failed: {str(e)}"
            )
# Global tool registry instance, created once when this module is imported.
# It starts empty; tools are added through its register() method.
tool_registry = ToolRegistry()