diff --git a/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin b/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin index 69e7080..d352f28 100644 Binary files a/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin and b/data/chroma/kb_2/714cba61-f48b-47c5-bcb8-eed8ffd400b0/data_level0.bin differ diff --git a/data/chroma/kb_2/chroma.sqlite3 b/data/chroma/kb_2/chroma.sqlite3 index c4673cd..504c313 100644 Binary files a/data/chroma/kb_2/chroma.sqlite3 and b/data/chroma/kb_2/chroma.sqlite3 differ diff --git a/th_agenter/models/workflow.py b/th_agenter/models/workflow.py index 1df374d..bec2add 100644 --- a/th_agenter/models/workflow.py +++ b/th_agenter/models/workflow.py @@ -23,6 +23,7 @@ class NodeType(enum.Enum): CODE = "code" # 代码执行节点 HTTP = "http" # HTTP请求节点 TOOL = "tool" # 工具节点 + KNOWLEDGE_BASE = "knowledge-base" # 知识库节点 class ExecutionStatus(enum.Enum): """执行状态枚举""" diff --git a/th_agenter/services/workflow_engine.py b/th_agenter/services/workflow_engine.py index d958a4f..5fddb48 100644 --- a/th_agenter/services/workflow_engine.py +++ b/th_agenter/services/workflow_engine.py @@ -458,6 +458,8 @@ class WorkflowEngine: output_data = await self._execute_code_node(node, input_data) elif node_type == 'http': output_data = await self._execute_http_node(node, input_data) + elif node_type == 'knowledge-base': + output_data = await self._execute_knowledge_base_node(node, input_data) else: raise ValueError(f"不支持的节点类型: {node_type}") @@ -663,14 +665,22 @@ class WorkflowEngine: model_id = model_value else: # 如果是字符串,按名称查询 - llm_config = self.session.query(LLMConfig).filter(LLMConfig.model_name == model_value).first() + from sqlalchemy import select + result = await self.session.execute( + select(LLMConfig).where(LLMConfig.model_name == model_value) + ) + llm_config = result.scalar_one_or_none() if llm_config: model_id = llm_config.id if not model_id: raise ValueError("未指定有效的大模型配置") - llm_config = self.session.query(LLMConfig).filter(LLMConfig.id == model_id).first() + from sqlalchemy import select + result = await self.session.execute( + select(LLMConfig).where(LLMConfig.id == model_id) + ) + llm_config = result.scalar_one_or_none() if not llm_config: raise ValueError(f"大模型配置 {model_id} 不存在") @@ -935,6 +945,84 @@ class WorkflowEngine: except Exception as e: logger.error(f"HTTP请求失败: {str(e)}") raise ValueError(f"HTTP请求失败: {str(e)}") + + async def _execute_knowledge_base_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]: + """执行知识库节点""" + from ..services.document import DocumentService + + config = input_data['node_config'] + # 支持多种字段名:knowledge_base_id, knowledgeBase, kb_id + knowledge_base_id = config.get('knowledge_base_id') or config.get('knowledgeBase') or config.get('kb_id') + query = config.get('query', '') + top_k = config.get('top_k', config.get('topK', 5)) + similarity_threshold = config.get('similarity_threshold', config.get('similarityThreshold', 0.7)) + + if not knowledge_base_id: + raise ValueError("知识库节点配置缺少 knowledge_base_id (或 knowledgeBase/kb_id)") + + if not query: + # 尝试从 resolved_inputs 中获取查询文本 + resolved_inputs = input_data.get('resolved_inputs', {}) + query = resolved_inputs.get('query', '') + # 如果还是没有,尝试从 workflow_input 中获取 + if not query: + workflow_input = input_data.get('workflow_input', {}) + query = workflow_input.get('query', '') + # 如果还是没有,尝试从 previous_outputs 中获取(可能是上一个节点的输出) + if not query: + previous_outputs = input_data.get('previous_outputs', {}) + # 尝试从上一个节点的输出中获取查询文本 + for node_id, output in previous_outputs.items(): + if isinstance(output, dict): + query = output.get('query') or output.get('data', {}).get('query', '') + if query: + break + # 如果还是没有,使用默认查询或抛出错误 + if not query: + # 如果没有查询文本,可以返回空结果而不是抛出错误 + logger.warning(f"知识库节点缺少查询文本,将返回空结果") + return { + 'success': True, + 'query': '', + 'knowledge_base_id': knowledge_base_id, + 'results': [], + 'total_results': 0, + 'top_k': top_k, + 'similarity_threshold': similarity_threshold, + 'warning': '缺少查询文本,返回空结果' + } + + try: + # 直接使用 document_processor 进行搜索 + # 注意:get_document_processor 期望同步 Session,但这里传入 None 以避免类型不匹配 + from ..services.document_processor import get_document_processor + document_processor = await get_document_processor(None) + results = document_processor.search_similar_documents( + knowledge_base_id=knowledge_base_id, + query=query, + k=top_k + ) + + # 过滤相似度阈值 + filtered_results = [] + for result in results: + score = result.get('normalized_score', result.get('similarity_score', 0)) + if score >= similarity_threshold: + filtered_results.append(result) + + return { + 'success': True, + 'query': query, + 'knowledge_base_id': knowledge_base_id, + 'results': filtered_results, + 'total_results': len(filtered_results), + 'top_k': top_k, + 'similarity_threshold': similarity_threshold + } + + except Exception as e: + logger.error(f"知识库搜索失败: {str(e)}") + raise ValueError(f"知识库搜索失败: {str(e)}") # 工作流引擎实例