feat: 添加知识库节点支持和相关功能

- 在工作流引擎中添加知识库节点的执行逻辑
- 更新数据库查询方式以使用 SQLAlchemy 的异步查询
- 增强知识库节点的查询处理：支持多种配置字段名；缺少查询文本时依次从上游输入中回退获取，仍为空则返回空结果
- 更新相关文档以反映新功能
This commit is contained in:
eason 2026-01-23 12:45:05 +08:00
parent 42d432acb1
commit e5e9eebcf2
4 changed files with 91 additions and 2 deletions

Binary file not shown.

View File

@ -23,6 +23,7 @@ class NodeType(enum.Enum):
CODE = "code" # 代码执行节点
HTTP = "http" # HTTP请求节点
TOOL = "tool" # 工具节点
KNOWLEDGE_BASE = "knowledge-base" # 知识库节点
class ExecutionStatus(enum.Enum):
"""执行状态枚举"""

View File

@ -458,6 +458,8 @@ class WorkflowEngine:
output_data = await self._execute_code_node(node, input_data)
elif node_type == 'http':
output_data = await self._execute_http_node(node, input_data)
elif node_type == 'knowledge-base':
output_data = await self._execute_knowledge_base_node(node, input_data)
else:
raise ValueError(f"不支持的节点类型: {node_type}")
@ -663,14 +665,22 @@ class WorkflowEngine:
model_id = model_value
else:
# 如果是字符串,按名称查询
llm_config = self.session.query(LLMConfig).filter(LLMConfig.model_name == model_value).first()
from sqlalchemy import select
result = await self.session.execute(
select(LLMConfig).where(LLMConfig.model_name == model_value)
)
llm_config = result.scalar_one_or_none()
if llm_config:
model_id = llm_config.id
if not model_id:
raise ValueError("未指定有效的大模型配置")
llm_config = self.session.query(LLMConfig).filter(LLMConfig.id == model_id).first()
from sqlalchemy import select
result = await self.session.execute(
select(LLMConfig).where(LLMConfig.id == model_id)
)
llm_config = result.scalar_one_or_none()
if not llm_config:
raise ValueError(f"大模型配置 {model_id} 不存在")
@ -936,6 +946,84 @@ class WorkflowEngine:
logger.error(f"HTTP请求失败: {str(e)}")
raise ValueError(f"HTTP请求失败: {str(e)}")
async def _execute_knowledge_base_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a knowledge-base node: run a similarity search and filter by score.

    The knowledge base id is read from the node config under any of the keys
    ``knowledge_base_id``, ``knowledgeBase`` or ``kb_id``. The query text is
    taken from the first non-empty source, in this order:

    1. ``node_config['query']``
    2. ``resolved_inputs['query']``
    3. ``workflow_input['query']``
    4. the first entry of ``previous_outputs`` that carries ``query`` or
       ``data.query``

    Args:
        node: Node definition. Not read by this method.
        input_data: Execution context. ``node_config`` is required;
            ``resolved_inputs``, ``workflow_input`` and ``previous_outputs``
            are optional fallbacks for the query text.

    Returns:
        A dict with ``success``, ``query``, ``knowledge_base_id``, ``results``,
        ``total_results``, ``top_k`` and ``similarity_threshold``. When no
        query text can be found, an empty result set is returned together
        with a ``warning`` entry instead of raising.

    Raises:
        KeyError: If ``input_data`` has no ``node_config`` entry.
        ValueError: If no knowledge base id is configured, or if the search
            fails for any reason (the original exception is chained).
    """
    config = input_data['node_config']

    # Several field names are accepted: knowledge_base_id, knowledgeBase, kb_id.
    knowledge_base_id = (
        config.get('knowledge_base_id')
        or config.get('knowledgeBase')
        or config.get('kb_id')
    )
    if not knowledge_base_id:
        raise ValueError("知识库节点配置缺少 knowledge_base_id (或 knowledgeBase/kb_id)")

    top_k = config.get('top_k', config.get('topK', 5))
    similarity_threshold = config.get(
        'similarity_threshold', config.get('similarityThreshold', 0.7)
    )

    query = config.get('query', '')
    # `or {}` guards against keys that are present but explicitly None.
    if not query:
        query = (input_data.get('resolved_inputs') or {}).get('query', '')
    if not query:
        query = (input_data.get('workflow_input') or {}).get('query', '')
    if not query:
        # Last resort: look through upstream node outputs for a query text.
        for output in (input_data.get('previous_outputs') or {}).values():
            if not isinstance(output, dict):
                continue
            nested = output.get('data')
            # `data` may be any type; only dig into it when it is a dict.
            nested_query = nested.get('query', '') if isinstance(nested, dict) else ''
            query = output.get('query') or nested_query
            if query:
                break

    if not query:
        # A missing query is treated as "nothing to search", not as an error.
        logger.warning("知识库节点缺少查询文本,将返回空结果")
        return {
            'success': True,
            'query': '',
            'knowledge_base_id': knowledge_base_id,
            'results': [],
            'total_results': 0,
            'top_k': top_k,
            'similarity_threshold': similarity_threshold,
            'warning': '缺少查询文本,返回空结果'
        }

    try:
        # Search through the document processor directly.
        # Note: get_document_processor expects a synchronous Session; None is
        # passed here to avoid a type mismatch with the async session.
        from ..services.document_processor import get_document_processor
        document_processor = await get_document_processor(None)
        # NOTE(review): this call is not awaited; if it performs blocking I/O
        # it will stall the event loop -- confirm against the processor.
        results = document_processor.search_similar_documents(
            knowledge_base_id=knowledge_base_id,
            query=query,
            k=top_k
        )

        # Keep only hits at or above the threshold, preferring the
        # normalized score over the raw similarity score when both exist.
        filtered_results = [
            hit for hit in results
            if hit.get('normalized_score', hit.get('similarity_score', 0)) >= similarity_threshold
        ]
    except Exception as e:
        logger.error(f"知识库搜索失败: {str(e)}")
        raise ValueError(f"知识库搜索失败: {str(e)}") from e

    return {
        'success': True,
        'query': query,
        'knowledge_base_id': knowledge_base_id,
        'results': filtered_results,
        'total_results': len(filtered_results),
        'top_k': top_k,
        'similarity_threshold': similarity_threshold
    }
# 工作流引擎实例
async def get_workflow_engine(session: AsyncSession = None) -> WorkflowEngine: