hxf/backend/th_agenter/services/tools/search.py

75 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""基于TavilySearch的搜索工具"""
from th_agenter.core.config import get_settings
from loguru import logger
from langchain.tools import BaseTool
from langchain_community.tools.tavily_search import TavilySearchResults
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, Type, ClassVar
from langchain_tavily import TavilySearch
# Input-argument model for the tool (replaces the former get_parameters()).
class SearchInput(BaseModel):
    """Arguments accepted by ``TavilySearchTool``.

    Used as the tool's ``args_schema``; the ``description`` texts become part of
    the argument schema LangChain builds for the tool.
    """

    query: str = Field(description="搜索查询内容")
    # NOTE(review): no lower bound is enforced here; the tool truncates locally.
    max_results: Optional[int] = Field(
        default=5,
        description="返回结果的最大数量默认5"
    )
    # Tavily's search only supports the topics "general", "news" and "finance";
    # advertising other values (e.g. "academic") leads callers to send requests
    # that fail inside the Tavily client.
    topic: Optional[str] = Field(
        default="general",
        description="搜索主题可选值general, news, finance"
    )
class TavilySearchTool(BaseTool):
    """LangChain tool that performs a web search through ``langchain_tavily.TavilySearch``.

    The Tavily API key is read from ``get_settings().tool.tavily_api_key`` when
    the tool is constructed. Search output is normalized to
    ``{"status": "success", "results": [{"title", "url", "content"}, ...]}``;
    any failure is returned as ``{"status": "error", "message": ...}`` instead
    of being raised.
    """

    name: ClassVar[str] = "tavily_search_tool"
    description: ClassVar[str] = """使用Tavily搜索引擎进行网络搜索可以获取最新信息。
输入应该包含搜索查询(query)可选参数包括max_results和topic。"""  # replaces get_description()
    args_schema: Type[BaseModel] = SearchInput  # arguments defined by a Pydantic model
    _tavily_api_key: str = PrivateAttr()
    # Annotated with the class that is actually instantiated in __init__.
    _search_client: TavilySearch = PrivateAttr()

    def __init__(self, **kwargs):
        """Create the tool and its Tavily client.

        Raises:
            ValueError: if no Tavily API key is configured in settings.
        """
        super().__init__(**kwargs)
        self._tavily_api_key = get_settings().tool.tavily_api_key
        if not self._tavily_api_key:
            raise ValueError("Tavily API key not found in settings")
        # Initialize the Tavily client with the configured key.
        self._search_client = TavilySearch(
            tavily_api_key=self._tavily_api_key
        )

    @staticmethod
    def _extract_items(raw):
        """Return the list of hit dicts contained in a Tavily response, or None.

        Accepts both a bare list of hits (older ``TavilySearchResults`` style)
        and a dict carrying the hits under ``"results"`` (``TavilySearch``
        style). Non-dict entries are dropped.
        """
        hits = raw.get("results") if isinstance(raw, dict) else raw
        if not isinstance(hits, list):
            return None
        return [hit for hit in hits if isinstance(hit, dict)]

    @staticmethod
    def _summarize(hit, limit=200):
        """Reduce one hit to title/url and a content preview.

        Content longer than ``limit`` characters is cut to ``limit`` characters
        and suffixed with ``"..."``; shorter content is returned unchanged.
        Missing or None fields become empty strings.
        """
        content = str(hit.get("content") or "")
        if len(content) > limit:
            content = content[:limit] + "..."
        return {
            "title": hit.get("title") or "",
            "url": hit.get("url") or "",
            "content": content,
        }

    def _run(self, query: str, max_results: Optional[int] = 5, topic: Optional[str] = "general"):
        """Run a Tavily search and return a normalized result dict.

        Args:
            query: Search query text.
            max_results: Upper bound on the number of results returned
                (None means no local limit).
            topic: Tavily search topic, forwarded as-is.

        Returns:
            ``{"status": "success", "results": [...]}`` on success, otherwise
            ``{"status": "error", "message": <str>}``. Exceptions are logged and
            converted to the error form, never raised.
        """
        try:
            logger.info(f"执行搜索:{query}")
            raw = self._search_client.run({
                "query": query,
                "max_results": max_results,
                "topic": topic
            })
            items = self._extract_items(raw)
            if items is None:
                # NOTE(review): TavilySearch appears to report some failures as
                # {"error": ...}; surface that message when present.
                if isinstance(raw, dict) and "error" in raw:
                    return {"status": "error", "message": str(raw["error"])}
                return {"status": "error", "message": "Unexpected result format"}
            # max_results is not among TavilySearch's documented invocation
            # arguments (it is an instantiation setting), so it may be ignored
            # above; enforce the requested upper bound here as well.
            if max_results is not None:
                items = items[:max(max_results, 0)]
            return {
                "status": "success",
                "results": [self._summarize(item) for item in items]
            }
        except Exception as e:
            # Boundary of the tool: log with traceback, report as an error result.
            logger.exception(f"搜索失败: {e}")
            return {"status": "error", "message": str(e)}

    async def _arun(self, *args, **kwargs):
        """Async variant: run the synchronous search in a worker thread.

        The Tavily call in ``_run`` is synchronous, so running it in a thread
        avoids blocking the event loop for the duration of the request.
        """
        import asyncio  # local import keeps this block self-contained

        return await asyncio.to_thread(self._run, *args, **kwargs)