75 lines
2.9 KiB
Python
75 lines
2.9 KiB
Python
"""基于TavilySearch的搜索工具"""
|
||
|
||
from th_agenter.core.config import get_settings
|
||
from loguru import logger
|
||
|
||
from langchain.tools import BaseTool
|
||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||
from pydantic import BaseModel, Field, PrivateAttr
|
||
from typing import Optional, Type, ClassVar
|
||
from langchain_tavily import TavilySearch
|
||
|
||
# 定义输入参数模型(替代原get_parameters())
|
||
class SearchInput(BaseModel):
|
||
query: str = Field(description="搜索查询内容")
|
||
max_results: Optional[int] = Field(
|
||
default=5,
|
||
description="返回结果的最大数量(默认:5)"
|
||
)
|
||
topic: Optional[str] = Field(
|
||
default="general",
|
||
description="搜索主题,可选值:general, academic, news, places"
|
||
)
|
||
|
||
|
||
class TavilySearchTool(BaseTool):
|
||
name:ClassVar[str] = "tavily_search_tool"
|
||
description:ClassVar[str] = """使用Tavily搜索引擎进行网络搜索,可以获取最新信息。
|
||
输入应该包含搜索查询(query),可选参数包括max_results和topic。""" # 替代get_description()
|
||
args_schema: Type[BaseModel] = SearchInput # 用Pydantic模型定义参数
|
||
_tavily_api_key: str = PrivateAttr()
|
||
_search_client: TavilySearchResults = PrivateAttr()
|
||
def __init__(self, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self._tavily_api_key = get_settings().tool.tavily_api_key
|
||
if not self._tavily_api_key:
|
||
raise ValueError("Tavily API key not found in settings")
|
||
|
||
# 初始化Tavily客户端
|
||
self._search_client = TavilySearch(
|
||
tavily_api_key=self._tavily_api_key
|
||
)
|
||
|
||
def _run(self, query: str, max_results: int = 5, topic: str = "general"):
|
||
try:
|
||
logger.info(f"执行搜索:{query}")
|
||
# 调用Tavily(LangChain已内置Tavily工具,这里直接使用)
|
||
results = self._search_client.run({
|
||
"query": query,
|
||
"max_results": max_results,
|
||
"topic": topic
|
||
})
|
||
|
||
# 格式化结果(根据Tavily的实际返回结构调整)
|
||
if isinstance(results, list):
|
||
return {
|
||
"status": "success",
|
||
"results": [
|
||
{
|
||
"title": r.get("title", ""),
|
||
"url": r.get("url", ""),
|
||
"content": r.get("content", "")[:200] + "..."
|
||
} for r in results
|
||
]
|
||
}
|
||
else:
|
||
return {"status": "error", "message": "Unexpected result format"}
|
||
|
||
except Exception as e:
|
||
logger.error(f"搜索失败: {str(e)}")
|
||
return {"status": "error", "message": str(e)}
|
||
|
||
async def _arun(self, **kwargs):
|
||
"""异步版本"""
|
||
"""直接调用同步版本"""
|
||
return self._run(**kwargs) # 直接委托给同步方法 |