244 lines
8.1 KiB
Python
244 lines
8.1 KiB
Python
|
|
"""Base classes for Agent tools."""
|
||
|
|
|
||
|
|
import json
|
||
|
|
from loguru import logger
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
from typing import Any, Dict, List, Optional, Type, Callable
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from enum import Enum
|
||
|
|
|
||
|
|
class ToolParameterType(str, Enum):
|
||
|
|
"""Tool parameter types."""
|
||
|
|
STRING = "string"
|
||
|
|
INTEGER = "integer"
|
||
|
|
FLOAT = "float"
|
||
|
|
BOOLEAN = "boolean"
|
||
|
|
ARRAY = "array"
|
||
|
|
OBJECT = "object"
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class ToolParameter:
|
||
|
|
"""Tool parameter definition."""
|
||
|
|
name: str
|
||
|
|
type: ToolParameterType
|
||
|
|
description: str
|
||
|
|
required: bool = True
|
||
|
|
default: Any = None
|
||
|
|
enum: Optional[List[Any]] = None
|
||
|
|
|
||
|
|
def to_dict(self) -> Dict[str, Any]:
|
||
|
|
"""Convert to dictionary for JSON schema."""
|
||
|
|
param_dict = {
|
||
|
|
"type": self.type.value,
|
||
|
|
"description": self.description
|
||
|
|
}
|
||
|
|
|
||
|
|
if self.enum:
|
||
|
|
param_dict["enum"] = self.enum
|
||
|
|
|
||
|
|
if self.default is not None:
|
||
|
|
param_dict["default"] = self.default
|
||
|
|
|
||
|
|
return param_dict
|
||
|
|
|
||
|
|
|
||
|
|
class ToolResult(BaseModel):
|
||
|
|
"""Tool execution result."""
|
||
|
|
success: bool = Field(description="Whether the tool execution was successful")
|
||
|
|
result: Any = Field(description="The result data")
|
||
|
|
error: Optional[str] = Field(default=None, description="Error message if failed")
|
||
|
|
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
|
||
|
|
|
||
|
|
|
||
|
|
class BaseTool(ABC):
|
||
|
|
"""Base class for all Agent tools."""
|
||
|
|
|
||
|
|
def __init__(self):
|
||
|
|
self.name = self.get_name()
|
||
|
|
self.description = self.get_description()
|
||
|
|
self.parameters = self.get_parameters()
|
||
|
|
|
||
|
|
@abstractmethod
|
||
|
|
def get_name(self) -> str:
|
||
|
|
"""Get tool name."""
|
||
|
|
pass
|
||
|
|
|
||
|
|
@abstractmethod
|
||
|
|
def get_description(self) -> str:
|
||
|
|
"""Get tool description."""
|
||
|
|
pass
|
||
|
|
|
||
|
|
@abstractmethod
|
||
|
|
def get_parameters(self) -> List[ToolParameter]:
|
||
|
|
"""Get tool parameters."""
|
||
|
|
pass
|
||
|
|
|
||
|
|
@abstractmethod
|
||
|
|
async def execute(self, **kwargs) -> ToolResult:
|
||
|
|
"""Execute the tool with given parameters."""
|
||
|
|
pass
|
||
|
|
|
||
|
|
def get_schema(self) -> Dict[str, Any]:
|
||
|
|
"""Get tool schema for LangChain."""
|
||
|
|
properties = {}
|
||
|
|
required = []
|
||
|
|
|
||
|
|
for param in self.parameters:
|
||
|
|
properties[param.name] = param.to_dict()
|
||
|
|
if param.required:
|
||
|
|
required.append(param.name)
|
||
|
|
|
||
|
|
return {
|
||
|
|
"type": "function",
|
||
|
|
"function": {
|
||
|
|
"name": self.name,
|
||
|
|
"description": self.description,
|
||
|
|
"parameters": {
|
||
|
|
"type": "object",
|
||
|
|
"properties": properties,
|
||
|
|
"required": required
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
def validate_parameters(self, **kwargs) -> Dict[str, Any]:
|
||
|
|
"""Validate and process input parameters."""
|
||
|
|
validated = {}
|
||
|
|
|
||
|
|
for param in self.parameters:
|
||
|
|
value = kwargs.get(param.name)
|
||
|
|
|
||
|
|
# Check required parameters
|
||
|
|
if param.required and value is None:
|
||
|
|
raise ValueError(f"Required parameter '{param.name}' is missing")
|
||
|
|
|
||
|
|
# Use default if not provided
|
||
|
|
if value is None and param.default is not None:
|
||
|
|
value = param.default
|
||
|
|
|
||
|
|
# Type validation (basic)
|
||
|
|
if value is not None:
|
||
|
|
if param.type == ToolParameterType.INTEGER and not isinstance(value, int):
|
||
|
|
try:
|
||
|
|
value = int(value)
|
||
|
|
except (ValueError, TypeError):
|
||
|
|
raise ValueError(f"Parameter '{param.name}' must be an integer")
|
||
|
|
|
||
|
|
elif param.type == ToolParameterType.FLOAT and not isinstance(value, (int, float)):
|
||
|
|
try:
|
||
|
|
value = float(value)
|
||
|
|
except (ValueError, TypeError):
|
||
|
|
raise ValueError(f"Parameter '{param.name}' must be a number")
|
||
|
|
|
||
|
|
elif param.type == ToolParameterType.BOOLEAN and not isinstance(value, bool):
|
||
|
|
if isinstance(value, str):
|
||
|
|
value = value.lower() in ('true', '1', 'yes', 'on')
|
||
|
|
else:
|
||
|
|
value = bool(value)
|
||
|
|
|
||
|
|
# Enum validation
|
||
|
|
if param.enum and value not in param.enum:
|
||
|
|
raise ValueError(f"Parameter '{param.name}' must be one of {param.enum}")
|
||
|
|
|
||
|
|
validated[param.name] = value
|
||
|
|
|
||
|
|
return validated
|
||
|
|
|
||
|
|
|
||
|
|
class ToolRegistry:
|
||
|
|
"""Registry for managing Agent tools."""
|
||
|
|
|
||
|
|
def __init__(self):
|
||
|
|
self._tools: Dict[str, BaseTool] = {}
|
||
|
|
self._enabled_tools: Dict[str, bool] = {}
|
||
|
|
|
||
|
|
def register(self, tool: BaseTool, enabled: bool = True) -> None:
|
||
|
|
"""Register a tool."""
|
||
|
|
tool_name = tool.get_name()
|
||
|
|
self._tools[tool_name] = tool
|
||
|
|
self._enabled_tools[tool_name] = enabled
|
||
|
|
logger.info(f"Registered tool: {tool_name} (enabled: {enabled})")
|
||
|
|
|
||
|
|
def unregister(self, tool_name: str) -> None:
|
||
|
|
"""Unregister a tool."""
|
||
|
|
if tool_name in self._tools:
|
||
|
|
del self._tools[tool_name]
|
||
|
|
del self._enabled_tools[tool_name]
|
||
|
|
logger.info(f"Unregistered tool: {tool_name}")
|
||
|
|
|
||
|
|
def get_tool(self, tool_name: str) -> Optional[BaseTool]:
|
||
|
|
"""Get a tool by name."""
|
||
|
|
return self._tools.get(tool_name)
|
||
|
|
|
||
|
|
def get_enabled_tools(self) -> Dict[str, BaseTool]:
|
||
|
|
"""Get all enabled tools."""
|
||
|
|
return {
|
||
|
|
name: tool for name, tool in self._tools.items()
|
||
|
|
if self._enabled_tools.get(name, False)
|
||
|
|
}
|
||
|
|
|
||
|
|
def get_all_tools(self) -> Dict[str, BaseTool]:
|
||
|
|
"""Get all registered tools."""
|
||
|
|
return self._tools.copy()
|
||
|
|
|
||
|
|
def enable_tool(self, tool_name: str) -> None:
|
||
|
|
"""Enable a tool."""
|
||
|
|
if tool_name in self._tools:
|
||
|
|
self._enabled_tools[tool_name] = True
|
||
|
|
logger.info(f"Enabled tool: {tool_name}")
|
||
|
|
|
||
|
|
def disable_tool(self, tool_name: str) -> None:
|
||
|
|
"""Disable a tool."""
|
||
|
|
if tool_name in self._tools:
|
||
|
|
self._enabled_tools[tool_name] = False
|
||
|
|
logger.info(f"Disabled tool: {tool_name}")
|
||
|
|
|
||
|
|
def is_enabled(self, tool_name: str) -> bool:
|
||
|
|
"""Check if a tool is enabled."""
|
||
|
|
return self._enabled_tools.get(tool_name, False)
|
||
|
|
|
||
|
|
def get_tools_schema(self) -> List[Dict[str, Any]]:
|
||
|
|
"""Get schema for all enabled tools."""
|
||
|
|
enabled_tools = self.get_enabled_tools()
|
||
|
|
return [tool.get_schema() for tool in enabled_tools.values()]
|
||
|
|
|
||
|
|
async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
|
||
|
|
"""Execute a tool with given parameters."""
|
||
|
|
tool = self.get_tool(tool_name)
|
||
|
|
|
||
|
|
if not tool:
|
||
|
|
return ToolResult(
|
||
|
|
success=False,
|
||
|
|
result=None,
|
||
|
|
error=f"Tool '{tool_name}' not found"
|
||
|
|
)
|
||
|
|
|
||
|
|
if not self.is_enabled(tool_name):
|
||
|
|
return ToolResult(
|
||
|
|
success=False,
|
||
|
|
result=None,
|
||
|
|
error=f"Tool '{tool_name}' is disabled"
|
||
|
|
)
|
||
|
|
|
||
|
|
try:
|
||
|
|
# Validate parameters
|
||
|
|
validated_params = tool.validate_parameters(**kwargs)
|
||
|
|
|
||
|
|
# Execute tool
|
||
|
|
result = await tool.execute(**validated_params)
|
||
|
|
logger.info(f"Tool '{tool_name}' executed successfully")
|
||
|
|
return result
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Tool '{tool_name}' execution failed: {str(e)}", exc_info=True)
|
||
|
|
return ToolResult(
|
||
|
|
success=False,
|
||
|
|
result=None,
|
||
|
|
error=f"Tool execution failed: {str(e)}"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# Global tool registry instance
|
||
|
|
tool_registry = ToolRegistry()
|