hyf-backend/th_agenter/services/agent/base.py

244 lines
8.1 KiB
Python
Raw Permalink Normal View History

2026-01-21 13:45:39 +08:00
"""Base classes for Agent tools."""
import json
from loguru import logger
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Callable
from pydantic import BaseModel, Field
from dataclasses import dataclass
from enum import Enum
class ToolParameterType(str, Enum):
"""Tool parameter types."""
STRING = "string"
INTEGER = "integer"
FLOAT = "float"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
@dataclass
class ToolParameter:
"""Tool parameter definition."""
name: str
type: ToolParameterType
description: str
required: bool = True
default: Any = None
enum: Optional[List[Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON schema."""
param_dict = {
"type": self.type.value,
"description": self.description
}
if self.enum:
param_dict["enum"] = self.enum
if self.default is not None:
param_dict["default"] = self.default
return param_dict
class ToolResult(BaseModel):
"""Tool execution result."""
success: bool = Field(description="Whether the tool execution was successful")
result: Any = Field(description="The result data")
error: Optional[str] = Field(default=None, description="Error message if failed")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
class BaseTool(ABC):
"""Base class for all Agent tools."""
def __init__(self):
self.name = self.get_name()
self.description = self.get_description()
self.parameters = self.get_parameters()
@abstractmethod
def get_name(self) -> str:
"""Get tool name."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get tool description."""
pass
@abstractmethod
def get_parameters(self) -> List[ToolParameter]:
"""Get tool parameters."""
pass
@abstractmethod
async def execute(self, **kwargs) -> ToolResult:
"""Execute the tool with given parameters."""
pass
def get_schema(self) -> Dict[str, Any]:
"""Get tool schema for LangChain."""
properties = {}
required = []
for param in self.parameters:
properties[param.name] = param.to_dict()
if param.required:
required.append(param.name)
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": properties,
"required": required
}
}
}
def validate_parameters(self, **kwargs) -> Dict[str, Any]:
"""Validate and process input parameters."""
validated = {}
for param in self.parameters:
value = kwargs.get(param.name)
# Check required parameters
if param.required and value is None:
raise ValueError(f"Required parameter '{param.name}' is missing")
# Use default if not provided
if value is None and param.default is not None:
value = param.default
# Type validation (basic)
if value is not None:
if param.type == ToolParameterType.INTEGER and not isinstance(value, int):
try:
value = int(value)
except (ValueError, TypeError):
raise ValueError(f"Parameter '{param.name}' must be an integer")
elif param.type == ToolParameterType.FLOAT and not isinstance(value, (int, float)):
try:
value = float(value)
except (ValueError, TypeError):
raise ValueError(f"Parameter '{param.name}' must be a number")
elif param.type == ToolParameterType.BOOLEAN and not isinstance(value, bool):
if isinstance(value, str):
value = value.lower() in ('true', '1', 'yes', 'on')
else:
value = bool(value)
# Enum validation
if param.enum and value not in param.enum:
raise ValueError(f"Parameter '{param.name}' must be one of {param.enum}")
validated[param.name] = value
return validated
class ToolRegistry:
"""Registry for managing Agent tools."""
def __init__(self):
self._tools: Dict[str, BaseTool] = {}
self._enabled_tools: Dict[str, bool] = {}
def register(self, tool: BaseTool, enabled: bool = True) -> None:
"""Register a tool."""
tool_name = tool.get_name()
self._tools[tool_name] = tool
self._enabled_tools[tool_name] = enabled
logger.info(f"Registered tool: {tool_name} (enabled: {enabled})")
def unregister(self, tool_name: str) -> None:
"""Unregister a tool."""
if tool_name in self._tools:
del self._tools[tool_name]
del self._enabled_tools[tool_name]
logger.info(f"Unregistered tool: {tool_name}")
def get_tool(self, tool_name: str) -> Optional[BaseTool]:
"""Get a tool by name."""
return self._tools.get(tool_name)
def get_enabled_tools(self) -> Dict[str, BaseTool]:
"""Get all enabled tools."""
return {
name: tool for name, tool in self._tools.items()
if self._enabled_tools.get(name, False)
}
def get_all_tools(self) -> Dict[str, BaseTool]:
"""Get all registered tools."""
return self._tools.copy()
def enable_tool(self, tool_name: str) -> None:
"""Enable a tool."""
if tool_name in self._tools:
self._enabled_tools[tool_name] = True
logger.info(f"Enabled tool: {tool_name}")
def disable_tool(self, tool_name: str) -> None:
"""Disable a tool."""
if tool_name in self._tools:
self._enabled_tools[tool_name] = False
logger.info(f"Disabled tool: {tool_name}")
def is_enabled(self, tool_name: str) -> bool:
"""Check if a tool is enabled."""
return self._enabled_tools.get(tool_name, False)
def get_tools_schema(self) -> List[Dict[str, Any]]:
"""Get schema for all enabled tools."""
enabled_tools = self.get_enabled_tools()
return [tool.get_schema() for tool in enabled_tools.values()]
async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
"""Execute a tool with given parameters."""
tool = self.get_tool(tool_name)
if not tool:
return ToolResult(
success=False,
result=None,
error=f"Tool '{tool_name}' not found"
)
if not self.is_enabled(tool_name):
return ToolResult(
success=False,
result=None,
error=f"Tool '{tool_name}' is disabled"
)
try:
# Validate parameters
validated_params = tool.validate_parameters(**kwargs)
# Execute tool
result = await tool.execute(**validated_params)
logger.info(f"Tool '{tool_name}' executed successfully")
return result
except Exception as e:
logger.error(f"Tool '{tool_name}' execution failed: {str(e)}", exc_info=True)
return ToolResult(
success=False,
result=None,
error=f"Tool execution failed: {str(e)}"
)
# Global tool registry instance
tool_registry = ToolRegistry()