2026-01-21 13:45:39 +08:00
|
|
|
|
"""Workflow execution engine."""
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import json
|
|
|
|
|
|
import time
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from typing import Dict, Any, Optional, List
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
|
|
from ..models.workflow import Workflow, WorkflowExecution, NodeExecution, ExecutionStatus, NodeType
|
|
|
|
|
|
from ..models.llm_config import LLMConfig
|
|
|
|
|
|
from ..services.llm_service import LLMService
|
|
|
|
|
|
|
|
|
|
|
|
from ..db.database import get_session, AsyncSessionFactory
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorkflowEngine:
|
|
|
|
|
|
"""工作流执行引擎"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, session: AsyncSession):
|
|
|
|
|
|
self.session = session
|
|
|
|
|
|
self.llm_service = LLMService()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def execute_workflow(self, workflow: Workflow, input_data: Optional[Dict[str, Any]] = None,
|
|
|
|
|
|
user_id: int = None, session: AsyncSession = None):
|
|
|
|
|
|
"""执行工作流"""
|
|
|
|
|
|
from ..schemas.workflow import WorkflowExecutionResponse
|
|
|
|
|
|
|
|
|
|
|
|
id = workflow.id
|
|
|
|
|
|
if session:
|
|
|
|
|
|
self.session = session
|
|
|
|
|
|
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > Enter"
|
|
|
|
|
|
# 创建执行记录
|
|
|
|
|
|
execution = WorkflowExecution(
|
|
|
|
|
|
workflow_id=id,
|
|
|
|
|
|
status=ExecutionStatus.RUNNING,
|
|
|
|
|
|
input_data=input_data or {},
|
|
|
|
|
|
executor_id=user_id,
|
|
|
|
|
|
started_at=datetime.now().isoformat()
|
|
|
|
|
|
)
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > 创建执行记录"
|
|
|
|
|
|
execution.set_audit_fields(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
self.session.add(execution)
|
|
|
|
|
|
await self.session.commit()
|
|
|
|
|
|
await self.session.refresh(execution)
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > 添加执行记录"
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 重新加载 workflow 对象,确保数据是最新的
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from ..models.workflow import Workflow
|
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
|
select(Workflow).where(Workflow.id == id)
|
|
|
|
|
|
)
|
|
|
|
|
|
workflow = result.scalar_one_or_none()
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > reload workflow"
|
|
|
|
|
|
|
|
|
|
|
|
# 解析工作流定义
|
|
|
|
|
|
definition = workflow.definition
|
|
|
|
|
|
nodes = {node['id']: node for node in definition['nodes']}
|
|
|
|
|
|
connections = definition['connections']
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > definition {id}"
|
|
|
|
|
|
|
|
|
|
|
|
# 构建节点依赖图
|
|
|
|
|
|
node_graph = self._build_node_graph(nodes, connections)
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > _build_node_graph {id}"
|
|
|
|
|
|
|
|
|
|
|
|
# 执行工作流
|
|
|
|
|
|
result = await self._execute_nodes(execution, nodes, node_graph, input_data or {})
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > _execute_nodes {id}"
|
|
|
|
|
|
|
|
|
|
|
|
# 更新执行状态
|
|
|
|
|
|
execution.status = ExecutionStatus.COMPLETED
|
|
|
|
|
|
execution.output_data = result
|
|
|
|
|
|
execution.completed_at = datetime.now().isoformat()
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > execution {id}"
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"工作流执行失败 - {id}: {str(e)}")
|
|
|
|
|
|
execution.status = ExecutionStatus.FAILED
|
|
|
|
|
|
execution.error_message = str(e)
|
|
|
|
|
|
execution.completed_at = datetime.now().isoformat()
|
|
|
|
|
|
|
|
|
|
|
|
execution.set_audit_fields(user_id, is_update=True)
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > set_audit_fields {id}"
|
|
|
|
|
|
await self.session.commit()
|
|
|
|
|
|
await self.session.refresh(execution)
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > refresh {id}"
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from ..models.workflow import NodeExecution
|
|
|
|
|
|
result = await session.execute(
|
|
|
|
|
|
select(NodeExecution).where(NodeExecution.workflow_execution_id == execution.id)
|
|
|
|
|
|
)
|
|
|
|
|
|
node_executions = result.scalars().all()
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > load node_executions {id}"
|
|
|
|
|
|
node_executions = [node.to_dict() for node in node_executions]
|
|
|
|
|
|
execution_dict = execution.to_dict()
|
|
|
|
|
|
execution_dict['node_executions'] = node_executions
|
|
|
|
|
|
session.desc = f"执行工作流数据 - {id} > build response {id}"
|
|
|
|
|
|
|
|
|
|
|
|
return WorkflowExecutionResponse(**execution_dict)
|
|
|
|
|
|
|
|
|
|
|
|
    async def execute_workflow_stream(self, workflow: 'Workflow', input_data: Optional[Dict[str, Any]] = None,
                                      user_id: int = None, session: AsyncSession = None):
        """Execute the workflow as a stream, yielding status events in real time.

        Yields dicts shaped {'type', 'execution_id', 'status', 'data',
        'timestamp'}. This method emits 'workflow_status' events
        (started/completed/failed); node-level and result events are produced
        by _execute_nodes_stream and forwarded unchanged.
        """
        from ..schemas.workflow import WorkflowExecutionResponse
        from typing import AsyncGenerator  # NOTE(review): unused import, kept as-is

        # Allow the caller to supply a session; otherwise keep the engine's own.
        if session:
            self.session = session

        # Create the execution record first so every event can carry its id.
        execution = WorkflowExecution(
            workflow_id=workflow.id,
            status=ExecutionStatus.RUNNING,
            input_data=input_data or {},
            executor_id=user_id,
            started_at=datetime.now().isoformat()
        )
        execution.set_audit_fields(user_id)

        self.session.add(execution)
        await self.session.commit()
        await self.session.refresh(execution)

        # Announce that the workflow has started executing.
        yield {
            'type': 'workflow_status',
            'execution_id': execution.id,
            'status': 'started',
            'data': {
                "workflow_id": workflow.id,
                "workflow_name": workflow.name,
                "input_data": input_data or {},
                "started_at": execution.started_at
            },
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Parse the workflow definition.
            definition = workflow.definition
            nodes = {node['id']: node for node in definition['nodes']}
            connections = definition['connections']

            # Build the node dependency graph.
            node_graph = self._build_node_graph(nodes, connections)

            # Execute the workflow (streaming version), forwarding every event.
            result = None
            async for step_data in self._execute_nodes_stream(execution, nodes, node_graph, input_data or {}):
                yield step_data
                # Capture the final result event so it can be persisted below.
                if step_data.get('type') == 'workflow_result':
                    result = step_data.get('data', {})

            # Mark the execution as completed.
            execution.status = ExecutionStatus.COMPLETED
            execution.output_data = result
            execution.completed_at = datetime.now().isoformat()

            # Announce workflow completion.
            yield {
                'type': 'workflow_status',
                'execution_id': execution.id,
                'status': 'completed',
                'data': {
                    "output_data": result,
                    "completed_at": execution.completed_at
                },
                'timestamp': datetime.now().isoformat()
            }

        except Exception as e:
            # Log the full traceback to ease debugging of issues such as
            # KeyError("'pk_1'").
            logger.exception(f"工作流执行失败: {str(e)}")
            execution.status = ExecutionStatus.FAILED
            execution.error_message = str(e)
            execution.completed_at = datetime.now().isoformat()

            # Announce workflow failure (the exception is swallowed here; the
            # FAILED status is persisted below instead of re-raising).
            yield {
                'type': 'workflow_status',
                'execution_id': execution.id,
                'status': 'failed',
                'data': {
                    "error_message": str(e),
                    "completed_at": execution.completed_at
                },
                'timestamp': datetime.now().isoformat()
            }

        # Persist the final state on both the success and failure paths.
        execution.set_audit_fields(user_id, is_update=True)
        await self.session.commit()
        await self.session.refresh(execution)
|
|
|
|
|
|
|
|
|
|
|
|
def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
|
|
|
|
|
"""构建节点依赖图"""
|
|
|
|
|
|
graph = {}
|
|
|
|
|
|
|
|
|
|
|
|
for node_id, node in nodes.items():
|
|
|
|
|
|
graph[node_id] = {
|
|
|
|
|
|
'node': node,
|
|
|
|
|
|
'inputs': [], # 输入节点
|
|
|
|
|
|
'outputs': [] # 输出节点
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for connection in connections:
|
|
|
|
|
|
# 支持两种字段名格式:from/to 和 from_node/to_node
|
|
|
|
|
|
from_node = connection.get('from') or connection.get('from_node')
|
|
|
|
|
|
to_node = connection.get('to') or connection.get('to_node')
|
|
|
|
|
|
|
|
|
|
|
|
if from_node in graph and to_node in graph:
|
|
|
|
|
|
graph[from_node]['outputs'].append(to_node)
|
|
|
|
|
|
graph[to_node]['inputs'].append(from_node)
|
|
|
|
|
|
|
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_nodes(self, execution: WorkflowExecution, nodes: Dict[str, Any],
|
|
|
|
|
|
node_graph: Dict[str, Dict[str, Any]], workflow_input: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行节点"""
|
|
|
|
|
|
# 找到开始节点
|
|
|
|
|
|
start_nodes = [node_id for node_id, info in node_graph.items()
|
|
|
|
|
|
if info['node']['type'] == 'start']
|
|
|
|
|
|
|
|
|
|
|
|
if not start_nodes:
|
|
|
|
|
|
raise ValueError("未找到开始节点")
|
|
|
|
|
|
|
|
|
|
|
|
if len(start_nodes) > 1:
|
|
|
|
|
|
raise ValueError("存在多个开始节点")
|
|
|
|
|
|
|
|
|
|
|
|
start_node_id = start_nodes[0]
|
|
|
|
|
|
|
|
|
|
|
|
# 执行上下文
|
|
|
|
|
|
context = {
|
|
|
|
|
|
'workflow_input': workflow_input,
|
|
|
|
|
|
'node_outputs': {}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 从开始节点开始执行
|
|
|
|
|
|
await self._execute_node_recursive(execution, start_node_id, node_graph, context)
|
|
|
|
|
|
|
|
|
|
|
|
# 找到结束节点的输出作为工作流结果
|
|
|
|
|
|
end_nodes = [node_id for node_id, info in node_graph.items()
|
|
|
|
|
|
if info['node']['type'] == 'end']
|
|
|
|
|
|
|
|
|
|
|
|
if end_nodes:
|
|
|
|
|
|
end_node_id = end_nodes[0]
|
|
|
|
|
|
return context['node_outputs'].get(end_node_id, {})
|
|
|
|
|
|
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_nodes_stream(self, execution: WorkflowExecution, nodes: Dict[str, Any],
|
|
|
|
|
|
node_graph: Dict[str, Dict[str, Any]], workflow_input: Dict[str, Any]):
|
|
|
|
|
|
"""流式执行节点,实时推送节点状态"""
|
|
|
|
|
|
# 找到开始节点
|
|
|
|
|
|
start_nodes = [node_id for node_id, info in node_graph.items()
|
|
|
|
|
|
if info['node']['type'] == 'start']
|
|
|
|
|
|
|
|
|
|
|
|
if not start_nodes:
|
|
|
|
|
|
raise ValueError("未找到开始节点")
|
|
|
|
|
|
|
|
|
|
|
|
if len(start_nodes) > 1:
|
|
|
|
|
|
raise ValueError("存在多个开始节点")
|
|
|
|
|
|
|
|
|
|
|
|
start_node_id = start_nodes[0]
|
|
|
|
|
|
|
|
|
|
|
|
# 执行上下文
|
|
|
|
|
|
context = {
|
|
|
|
|
|
'workflow_input': workflow_input,
|
|
|
|
|
|
'node_outputs': {}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 从开始节点开始执行
|
|
|
|
|
|
async for step_data in self._execute_node_recursive_stream(execution, start_node_id, node_graph, context):
|
|
|
|
|
|
yield step_data
|
|
|
|
|
|
|
|
|
|
|
|
# 找到结束节点的输出作为工作流结果
|
|
|
|
|
|
end_nodes = [node_id for node_id, info in node_graph.items()
|
|
|
|
|
|
if info['node']['type'] == 'end']
|
|
|
|
|
|
|
|
|
|
|
|
if end_nodes:
|
|
|
|
|
|
end_node_id = end_nodes[0]
|
|
|
|
|
|
result = context['node_outputs'].get(end_node_id, {})
|
|
|
|
|
|
else:
|
|
|
|
|
|
result = {}
|
|
|
|
|
|
|
|
|
|
|
|
# 发送最终结果
|
|
|
|
|
|
yield {
|
|
|
|
|
|
'type': 'workflow_result',
|
|
|
|
|
|
'execution_id': execution.id,
|
|
|
|
|
|
'data': result,
|
|
|
|
|
|
'timestamp': datetime.now().isoformat()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
    async def _execute_node_recursive_stream(self, execution: WorkflowExecution, node_id: str,
                                             node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
        """Recursively execute one node (streaming version).

        Executes all upstream inputs first, then this node — emitting
        'node_status' started/completed/failed events and, for LLM nodes,
        'node_stream' delta events — then recurses into downstream outputs.
        Outputs are memoized in context['node_outputs'] so shared nodes run once.
        """
        if node_id in context['node_outputs']:
            # Node already executed; its output is memoized in the context.
            return

        node_info = node_graph[node_id]
        node = node_info['node']
        node_type = node.get('type', '')

        # Wait for every input (upstream) node to finish first.
        for input_node_id in node_info['inputs']:
            async for step_data in self._execute_node_recursive_stream(execution, input_node_id, node_graph, context):
                yield step_data

        # Emit the node-started event.
        yield {
            'type': 'node_status',
            'execution_id': execution.id,
            'node_id': node_id,
            'status': 'started',
            'data': {
                'node_name': node.get('name', ''),
                'node_type': node.get('type', ''),
                'started_at': datetime.now().isoformat()
            },
            'timestamp': datetime.now().isoformat()
        }

        try:
            # Execute the current node.
            if node_type == 'llm':
                # LLM nodes use true token streaming.
                output = None
                async for event in self._execute_llm_node_stream(execution, node, context):
                    # Each event is an internal event carrying an event_type field.
                    if event.get('event_type') == 'delta':
                        # Forward the incremental output to the frontend.
                        yield {
                            'type': 'node_stream',
                            'execution_id': execution.id,
                            'node_id': node_id,
                            'status': 'streaming',
                            'data': {
                                'node_name': node.get('name', ''),
                                'node_type': node_type,
                                'delta': event.get('delta', ''),
                                'full_response': event.get('full_response', '')
                            },
                            'timestamp': datetime.now().isoformat()
                        }
                    elif event.get('event_type') == 'final':
                        # Final complete output, consumed by downstream nodes.
                        output = event.get('output', {})

                if output is None:
                    output = {}
            else:
                # Non-LLM nodes keep the original one-shot execution path.
                output = await self._execute_single_node(execution, node, context)

            # Memoize this node's output for downstream consumers.
            context['node_outputs'][node_id] = output

            # Emit the node-completed event.
            yield {
                'type': 'node_status',
                'execution_id': execution.id,
                'node_id': node_id,
                'status': 'completed',
                'data': {
                    'node_name': node.get('name', ''),
                    'node_type': node.get('type', ''),
                    'output': output,
                    'completed_at': datetime.now().isoformat()
                },
                'timestamp': datetime.now().isoformat()
            }

        except Exception as e:
            # Emit the node-failed event, then propagate the error upward.
            yield {
                'type': 'node_status',
                'execution_id': execution.id,
                'node_id': node_id,
                'status': 'failed',
                'data': {
                    'node_name': node.get('name', ''),
                    'node_type': node.get('type', ''),
                    'error_message': str(e),
                    'failed_at': datetime.now().isoformat()
                },
                'timestamp': datetime.now().isoformat()
            }
            raise

        # Execute every output (downstream) node.
        for output_node_id in node_info['outputs']:
            async for step_data in self._execute_node_recursive_stream(execution, output_node_id, node_graph, context):
                yield step_data
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_node_recursive(self, execution: WorkflowExecution, node_id: str,
|
|
|
|
|
|
node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
|
|
|
|
|
|
"""递归执行节点"""
|
|
|
|
|
|
if node_id in context['node_outputs']:
|
|
|
|
|
|
# 节点已执行过
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
node_info = node_graph[node_id]
|
|
|
|
|
|
node = node_info['node']
|
|
|
|
|
|
|
|
|
|
|
|
# 等待所有输入节点完成
|
|
|
|
|
|
for input_node_id in node_info['inputs']:
|
|
|
|
|
|
await self._execute_node_recursive(execution, input_node_id, node_graph, context)
|
|
|
|
|
|
|
|
|
|
|
|
# 执行当前节点
|
|
|
|
|
|
output = await self._execute_single_node(execution, node, context)
|
|
|
|
|
|
context['node_outputs'][node_id] = output
|
|
|
|
|
|
|
|
|
|
|
|
# 执行所有输出节点
|
|
|
|
|
|
for output_node_id in node_info['outputs']:
|
|
|
|
|
|
await self._execute_node_recursive(execution, output_node_id, node_graph, context)
|
|
|
|
|
|
|
|
|
|
|
|
    async def _execute_single_node(self, execution: WorkflowExecution, node: Dict[str, Any],
                                   context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a single node and persist a NodeExecution record for it.

        Creates the record as RUNNING, dispatches on the node's type, then
        marks the record COMPLETED (with output and duration) or FAILED (with
        the error message) before returning/raising.
        """
        node_id = node['id']
        node_type = node['type']
        node_name = node['name']

        # Create the node execution record in RUNNING state.
        node_execution = NodeExecution(
            workflow_execution_id=execution.id,
            node_id=node_id,
            node_type=NodeType(node_type),
            node_name=node_name,
            status=ExecutionStatus.RUNNING,
            started_at=datetime.now().isoformat()
        )
        self.session.add(node_execution)
        await self.session.commit()
        await self.session.refresh(node_execution)
        await self.session.refresh(execution)

        start_time = time.time()

        try:
            # Assemble the node's input payload from the execution context.
            input_data = self._prepare_node_input(node, context)
            # Log the node-level input data to help diagnose KeyError and
            # similar issues.
            try:
                logger.info(f"执行节点 {node_id} ({node_type}) 输入数据: {json.dumps(input_data, ensure_ascii=False)[:2000]}")
            except Exception:
                # Some payloads are not JSON-serializable; fall back to repr.
                logger.info(f"执行节点 {node_id} ({node_type}) 输入数据(非JSON): {repr(input_data)[:2000]}")

            # Prepare a display-friendly copy of the input for the frontend.
            display_input_data = input_data.copy()

            # For start nodes, the displayed input is the raw workflow input.
            if node_type == 'start':
                display_input_data = input_data['workflow_input']
            elif node_type == 'llm':
                # For LLM nodes, run variable substitution up-front so the
                # processed prompt can be shown alongside the template.
                config = input_data['node_config']
                prompt_template = config.get('prompt', '')
                enable_variable_substitution = config.get('enable_variable_substitution', True)

                if enable_variable_substitution:
                    processed_prompt = self._substitute_variables(prompt_template, input_data)
                else:
                    processed_prompt = prompt_template

                display_input_data = {
                    'original_prompt': prompt_template,
                    'processed_prompt': processed_prompt,
                    'model_config': config,
                    'resolved_inputs': input_data.get('resolved_inputs', {})
                }

            node_execution.input_data = display_input_data
            await self.session.commit()
            await self.session.refresh(execution)

            # Dispatch on the node type.
            if node_type == 'start':
                output_data = await self._execute_start_node(node, input_data)
            elif node_type == 'end':
                output_data = await self._execute_end_node(node, input_data)
            elif node_type == 'llm':
                output_data = await self._execute_llm_node(node, input_data)
            elif node_type == 'condition':
                output_data = await self._execute_condition_node(node, input_data)
            elif node_type == 'code':
                output_data = await self._execute_code_node(node, input_data)
            elif node_type == 'http':
                output_data = await self._execute_http_node(node, input_data)
            elif node_type == 'knowledge-base':
                output_data = await self._execute_knowledge_base_node(node, input_data)
            else:
                raise ValueError(f"不支持的节点类型: {node_type}")

            # Mark the record COMPLETED with output and wall-clock duration.
            end_time = time.time()
            node_execution.status = ExecutionStatus.COMPLETED
            node_execution.output_data = output_data
            node_execution.completed_at = datetime.now().isoformat()
            node_execution.duration_ms = int((end_time - start_time) * 1000)

            await self.session.commit()
            await self.session.refresh(execution)

            return output_data

        except Exception as e:
            # Record detailed node failure information (including traceback).
            logger.exception(
                f"节点执行失败 - id={node_id}, type={node_type}, name={node_name}, "
                f"error_type={type(e).__name__}, error={str(e)}"
            )
            end_time = time.time()
            node_execution.status = ExecutionStatus.FAILED
            node_execution.error_message = str(e)
            node_execution.completed_at = datetime.now().isoformat()
            node_execution.duration_ms = int((end_time - start_time) * 1000)
            await self.session.commit()
            await self.session.refresh(execution)

            raise
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_node_input(self, node: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""准备节点输入数据"""
|
|
|
|
|
|
# 基础输入数据
|
|
|
|
|
|
input_data = {
|
|
|
|
|
|
'workflow_input': context['workflow_input'],
|
|
|
|
|
|
'node_config': node.get('config', {}),
|
|
|
|
|
|
'previous_outputs': context['node_outputs']
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 处理节点参数配置
|
|
|
|
|
|
node_parameters = node.get('parameters', {})
|
|
|
|
|
|
if node_parameters and 'inputs' in node_parameters:
|
|
|
|
|
|
resolved_inputs = {}
|
|
|
|
|
|
|
|
|
|
|
|
for param in node_parameters['inputs']:
|
|
|
|
|
|
param_name = param.get('name')
|
|
|
|
|
|
param_source = param.get('source', 'default')
|
|
|
|
|
|
param_default = param.get('default_value')
|
|
|
|
|
|
variable_name = param.get('variable_name', '')
|
|
|
|
|
|
|
|
|
|
|
|
# 优先使用variable_name,如果存在的话
|
|
|
|
|
|
if variable_name:
|
|
|
|
|
|
resolved_value = self._resolve_variable_value(variable_name, context)
|
|
|
|
|
|
resolved_inputs[param_name] = resolved_value if resolved_value is not None else param_default
|
|
|
|
|
|
elif param_source == 'workflow':
|
|
|
|
|
|
# 从工作流输入获取
|
|
|
|
|
|
source_param_name = param.get('source_param_name', param_name)
|
|
|
|
|
|
resolved_inputs[param_name] = context['workflow_input'].get(source_param_name, param_default)
|
|
|
|
|
|
elif param_source == 'node':
|
|
|
|
|
|
# 从其他节点输出获取
|
|
|
|
|
|
source_node_id = param.get('source_node_id')
|
|
|
|
|
|
source_param_name = param.get('source_param_name', 'data')
|
|
|
|
|
|
|
|
|
|
|
|
if source_node_id and source_node_id in context['node_outputs']:
|
|
|
|
|
|
source_output = context['node_outputs'][source_node_id]
|
|
|
|
|
|
if isinstance(source_output, dict):
|
|
|
|
|
|
resolved_inputs[param_name] = source_output.get(source_param_name, param_default)
|
|
|
|
|
|
else:
|
|
|
|
|
|
resolved_inputs[param_name] = source_output
|
|
|
|
|
|
else:
|
|
|
|
|
|
resolved_inputs[param_name] = param_default
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 使用默认值
|
|
|
|
|
|
resolved_inputs[param_name] = param_default
|
|
|
|
|
|
|
|
|
|
|
|
# 将解析后的参数添加到输入数据
|
|
|
|
|
|
input_data['resolved_inputs'] = resolved_inputs
|
|
|
|
|
|
|
|
|
|
|
|
return input_data
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_variable_value(self, variable_name: str, context: Dict[str, Any]) -> Any:
|
|
|
|
|
|
"""解析变量值,支持格式如 "node_id.output.field_name" 或更深层路径"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 解析变量名格式:node_id.output.field_name 或 node_id.field1.field2.field3
|
|
|
|
|
|
parts = variable_name.split('.')
|
|
|
|
|
|
if len(parts) >= 2:
|
|
|
|
|
|
source_node_id = parts[0]
|
|
|
|
|
|
|
|
|
|
|
|
# 从previous_outputs中获取源节点的输出
|
|
|
|
|
|
if source_node_id in context['node_outputs']:
|
|
|
|
|
|
source_output = context['node_outputs'][source_node_id]
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(source_output, dict):
|
|
|
|
|
|
# 从第二个部分开始遍历路径
|
|
|
|
|
|
current_value = source_output
|
|
|
|
|
|
for field_name in parts[1:]:
|
|
|
|
|
|
if isinstance(current_value, dict) and field_name in current_value:
|
|
|
|
|
|
current_value = current_value[field_name]
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 如果路径不存在,返回None
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
return current_value
|
|
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.warning(f"解析变量值失败: {variable_name}, 错误: {str(e)}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_start_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行开始节点"""
|
|
|
|
|
|
# 开始节点的输入和输出应该一致,都是workflow_input
|
|
|
|
|
|
workflow_input = input_data['workflow_input']
|
|
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'message': '工作流开始',
|
|
|
|
|
|
'data': workflow_input,
|
|
|
|
|
|
'user_input': workflow_input # 添加用户输入显示
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
    async def _execute_end_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the end node: assemble the workflow's final result payload.

        If the node declares output parameters, each one is resolved from the
        referenced upstream node's output via its "variable_name"; otherwise a
        simplified dump of all previous node outputs is returned (legacy
        behavior kept for backward compatibility).
        """
        previous_outputs = input_data.get('previous_outputs', {})

        # Read the end node's configured output parameters, if any.
        node_parameters = node.get('parameters') or {}
        output_params = node_parameters.get('outputs', []) if isinstance(node_parameters, dict) else []

        result_data = {}

        # Resolve each configured output parameter to a concrete value.
        for param in output_params:
            param_name = param.get('name')
            variable_name = param.get('variable_name')

            if variable_name:
                # Parse variable_name, e.g. "node_1759022611056.output.response".
                try:
                    parts = variable_name.split('.')
                    if len(parts) >= 3:
                        source_node_id = parts[0]
                        output_type = parts[1]  # usually "output" (unused below)
                        field_name = parts[2]  # concrete field name, e.g. "response"

                        # Fetch the value from the upstream node's output.
                        if source_node_id in previous_outputs:
                            source_output = previous_outputs[source_node_id]
                            if isinstance(source_output, dict):
                                # First try the field at the top level (e.g. an
                                # LLM node's "response" field).
                                if field_name in source_output:
                                    result_data[param_name] = source_output[field_name]
                                # Otherwise fall back to looking inside "data".
                                elif 'data' in source_output and isinstance(source_output['data'], dict):
                                    if field_name in source_output['data']:
                                        result_data[param_name] = source_output['data'][field_name]
                                    else:
                                        result_data[param_name] = None
                                else:
                                    result_data[param_name] = None
                            else:
                                result_data[param_name] = source_output
                        else:
                            result_data[param_name] = None
                    else:
                        # Malformed variable_name: use the declared default.
                        result_data[param_name] = param.get('default_value')
                except Exception as e:
                    logger.warning(f"解析variable_name失败: {variable_name}, 错误: {str(e)}")
                    result_data[param_name] = param.get('default_value')
            else:
                # No variable_name configured: use the declared default.
                result_data[param_name] = param.get('default_value')

        # With no configured outputs, return a simplified view of all previous
        # node outputs (backward-compatible legacy path).
        if not output_params:
            simplified_outputs = {}
            for node_id, output in previous_outputs.items():
                if isinstance(output, dict):
                    simplified_outputs[node_id] = {
                        'success': output.get('success', False),
                        'message': output.get('message', ''),
                        # NOTE(review): this condition blanks 'data' when it is a
                        # dict whose string form mentions this node's own id —
                        # presumably to avoid self-referential payloads; confirm
                        # the intent before simplifying.
                        'data': output.get('data', {}) if not isinstance(output.get('data'), dict) or node_id not in str(output.get('data', {})) else {}
                    }
                else:
                    simplified_outputs[node_id] = output
            result_data = simplified_outputs

        return {
            'success': True,
            'message': '工作流结束',
            'data': result_data
        }
|
|
|
|
|
|
|
2026-01-28 20:14:29 +08:00
|
|
|
|
    def _build_llm_prompt(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> (str, str):
        """
        Build the prompt from node config, knowledge-base results and workflow input.

        Returns (prompt, prompt_template):
        - prompt: the final prompt after variable substitution
        - prompt_template: the original template (before substitution)

        When the node has no configured prompt, a template is derived
        automatically: a RAG-style prompt if relevant knowledge-base results
        exist upstream, otherwise the user's query (or a generic fallback).
        """
        config = input_data.get('node_config', {})
        prompt_template = config.get('prompt', '')

        # If the prompt is empty, try to build one automatically
        # (RAG-style, or directly from the user's input).
        if not prompt_template:
            previous_outputs = input_data.get('previous_outputs', {})
            knowledge_base_results = None
            user_query = None

            # Look for a knowledge-base node's output among upstream results.
            for node_id, output in previous_outputs.items():
                if isinstance(output, dict) and output.get('knowledge_base_id'):
                    knowledge_base_results = output.get('results', [])
                    user_query = output.get('query', '')
                    break

            # No KB query found: fall back to the first non-empty string in
            # the workflow input as the user's question.
            if not user_query:
                workflow_input = input_data.get('workflow_input', {})
                for key, value in workflow_input.items():
                    if isinstance(value, str) and value.strip():
                        user_query = value.strip()
                        break

            # Build the prompt template.
            if knowledge_base_results and len(knowledge_base_results) > 0:
                # Find the best similarity score among the KB results.
                max_score = 0
                for result in knowledge_base_results:
                    score = result.get('normalized_score', result.get('similarity_score', 0))
                    if score > max_score:
                        max_score = score

                # Relevance threshold for using the KB context at all.
                is_relevant = max_score >= 0.45

                if is_relevant:
                    # Relevant KB hits: build a RAG-style prompt from the top
                    # results (max 5 documents, each truncated to 1000 chars).
                    context_parts = []
                    for i, result in enumerate(knowledge_base_results[:5], 1):
                        content = result.get('content', '').strip()
                        if content:
                            max_length = 1000
                            if len(content) > max_length:
                                content = content[:max_length] + "..."
                            context_parts.append(f"【参考文档{i}】\n{content}\n")

                    context = "\n\n".join(context_parts)
                    prompt_template = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。

{context}

【用户问题】
{user_query or '请回答上述问题'}

【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""

                    logger.info(
                        f"自动构建RAG提示词,包含 {len(knowledge_base_results)} 个相关知识库结果(最高相似度: {max_score:.3f}),用户问题: {user_query}"
                    )
                else:
                    # Scores too low: treat the KB hits as irrelevant and just
                    # answer the user's question directly.
                    logger.warning(
                        f"知识库结果相似度较低(最高: {max_score:.3f}),认为不相关,将直接回答用户问题"
                    )
                    prompt_template = user_query or "请帮助我处理这个任务。"
            elif user_query:
                # No KB results at all: use the workflow input as the prompt.
                prompt_template = user_query
                logger.info(f"自动使用工作流输入作为提示词: {user_query}")
            else:
                # Nothing usable in context: fall back to a generic prompt.
                prompt_template = "请帮助我处理这个任务。"
                logger.warning("LLM节点提示词为空,且无法从上下文获取,使用默认提示词")

        # Variable substitution (enabled by default via node config).
        enable_variable_substitution = config.get('enable_variable_substitution', True)
        if enable_variable_substitution:
            prompt = self._substitute_variables(prompt_template, input_data)
        else:
            prompt = prompt_template

        return prompt, prompt_template
|
|
|
|
|
|
|
|
|
|
|
|
    async def _execute_llm_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute an LLM node (non-streaming; used by the /execute endpoint).

        Model resolution precedence:
          1. ``input_data['node_config']``: ``model_id``, then ``model_name``/``model``
             (an int value is treated as a config id, a string is looked up by
             ``LLMConfig.model_name``);
          2. the node definition's own ``config`` with the same keys;
          3. the system default chat configuration.

        Args:
            node: Node definition from the workflow graph.
            input_data: Prepared node input (contains ``node_config``,
                workflow input and upstream outputs).

        Returns:
            Dict with ``success``, ``response``, ``prompt``, ``model`` and
            ``tokens_used``; the fallback path additionally sets
            ``fallback_model_used``.

        Raises:
            ValueError: If no model can be resolved, the resolved model id
                does not exist, or the LLM call (and fallback, if attempted)
                fails.
        """
        config = input_data.get('node_config', {})

        # Whether the default model was already picked during resolution; if
        # so, the error handler below must not retry with it a second time.
        used_default_model = False

        # --- Resolve the LLM configuration from node_config ---
        model_id = config.get('model_id')
        if not model_id:
            model_value = config.get('model_name', config.get('model'))
            if model_value:
                if isinstance(model_value, int):
                    # An integer model_name/model value is treated as a config id.
                    model_id = model_value
                else:
                    from sqlalchemy import select
                    result = await self.session.execute(
                        select(LLMConfig).where(LLMConfig.model_name == model_value)
                    )
                    llm_cfg = result.scalar_one_or_none()
                    if llm_cfg:
                        model_id = llm_cfg.id

        # --- Fall back to the node definition's own config ---
        if not model_id:
            node_config = node.get('config', {})
            model_id = node_config.get('model_id')
            if not model_id:
                model_value = node_config.get('model_name', node_config.get('model'))
                if model_value:
                    if isinstance(model_value, int):
                        model_id = model_value
                    else:
                        from sqlalchemy import select
                        result = await self.session.execute(
                            select(LLMConfig).where(LLMConfig.model_name == model_value)
                        )
                        llm_cfg = result.scalar_one_or_none()
                        if llm_cfg:
                            model_id = llm_cfg.id

        # --- Last resort: the system default chat model ---
        if not model_id:
            from ..services.llm_config_service import LLMConfigService
            llm_config_service = LLMConfigService()
            default_config = await llm_config_service.get_default_chat_config(self.session)
            if default_config:
                model_id = default_config.id
                used_default_model = True
                logger.info(
                    f"LLM节点未指定模型配置,使用默认模型: {default_config.model_name} (ID: {model_id})"
                )
            else:
                raise ValueError(
                    "未指定有效的大模型配置,且未找到默认配置。\n"
                    "请在节点配置中添加模型ID或模型名称,例如:\n"
                    "  - config.model_id: 1\n"
                    "  - config.model_name: 'gpt-4'\n"
                    "  - config.model: 'gpt-4'\n"
                    "或者设置一个默认的LLM配置。"
                )

        # Load the resolved configuration row; a dangling id is a hard error.
        from sqlalchemy import select
        result = await self.session.execute(
            select(LLMConfig).where(LLMConfig.id == model_id)
        )
        llm_config = result.scalar_one_or_none()
        if not llm_config:
            raise ValueError(f"大模型配置 {model_id} 不存在")

        # Build the prompt with the logic shared by stream and non-stream paths.
        prompt, prompt_template = self._build_llm_prompt(node, input_data)

        logger.info(
            f"LLM 节点最终提示词(非流式): node_id={node.get('id')}, "
            f"model_id={llm_config.id}, model_name={llm_config.model_name}, prompt={prompt}"
        )

        # Record both prompts on the input data so the frontend can display them.
        input_data['processed_prompt'] = prompt
        input_data['original_prompt'] = prompt_template

        # Call the LLM service (non-streaming path, used by /execute).
        try:
            response = await self.llm_service.chat_completion(
                model_config=llm_config,
                messages=[{"role": "user", "content": prompt}],
                temperature=config.get('temperature', 0.7),
                max_tokens=config.get('max_tokens'),
            )

            return {
                'success': True,
                'response': response,
                'prompt': prompt,
                'model': llm_config.model_name,
                # NOTE(review): if chat_completion returns a plain string,
                # hasattr(response, 'usage') is always False and this stays 0 —
                # confirm against LLMService.chat_completion's return type.
                'tokens_used': getattr(response, 'usage', {}).get('total_tokens', 0)
                if hasattr(response, 'usage')
                else 0,
            }

        except Exception as e:
            error_msg = str(e)
            detailed_error = error_msg

            # Enrich the error with model info unless the message already has it.
            if "使用的模型:" not in error_msg and "模型:" not in error_msg:
                model_info = (
                    f"使用的模型: {llm_config.model_name} (ID: {llm_config.id}), "
                    f"base_url: {llm_config.base_url}"
                )
                if "Not Found" in error_msg or "404" in error_msg:
                    detailed_error = (
                        f"{detailed_error}。{model_info}。可能的原因:1) 模型名称格式不正确(SiliconFlow需要org/model格式);"
                        "2) base_url配置错误;3) API端点不存在"
                    )
                elif (
                    "403" in error_msg
                    or "account balance" in error_msg.lower()
                    or "insufficient" in error_msg.lower()
                ):
                    detailed_error = (
                        f"{detailed_error}。{model_info}。可能的原因:账户余额不足或API密钥权限不足"
                    )
                elif "401" in error_msg or "authentication" in error_msg.lower():
                    detailed_error = (
                        f"{detailed_error}。{model_info}。可能的原因:API密钥无效或已过期"
                    )
                else:
                    detailed_error = f"{detailed_error}。{model_info}"

            logger.error(f"LLM调用失败: {detailed_error}")

            # On a 404-style error with an explicitly configured model, retry
            # once with the system default model before giving up.
            if (not used_default_model) and (
                "Not Found" in error_msg or "404" in error_msg
            ):
                try:
                    from ..services.llm_config_service import LLMConfigService

                    llm_config_service = LLMConfigService()
                    default_config = await llm_config_service.get_default_chat_config(
                        self.session
                    )
                    if default_config:
                        logger.warning(
                            "LLM调用失败,模型可能不存在或端点错误,"
                            f"尝试使用默认模型重试: {default_config.model_name} (ID: {default_config.id})"
                        )
                        fallback_response = await self.llm_service.chat_completion(
                            model_config=default_config,
                            messages=[{"role": "user", "content": prompt}],
                            temperature=config.get('temperature', 0.7),
                            max_tokens=config.get('max_tokens'),
                        )
                        return {
                            'success': True,
                            'response': fallback_response,
                            'prompt': prompt,
                            'model': default_config.model_name,
                            'tokens_used': getattr(
                                fallback_response, 'usage', {}
                            ).get('total_tokens', 0)
                            if hasattr(fallback_response, 'usage')
                            else 0,
                            'fallback_model_used': True,
                        }
                except Exception as fallback_error:
                    # Fallback failed too; fall through and raise the original error.
                    logger.error(
                        f"使用默认模型重试LLM调用失败: {str(fallback_error)}"
                    )

            raise ValueError(f"LLM调用失败: {detailed_error}")
|
2026-01-28 20:14:29 +08:00
|
|
|
|
|
|
|
|
|
|
    async def _execute_llm_node_stream(self, execution: WorkflowExecution, node: Dict[str, Any], context: Dict[str, Any]):
        """Execute an LLM node as a stream (used by the /execute-stream endpoint).

        Async generator. Yields ``{'event_type': 'delta', 'delta', 'full_response'}``
        for every non-empty chunk, then one ``{'event_type': 'final', 'output': ...}``
        event whose output mirrors the non-streaming result shape.

        Args:
            execution: The running WorkflowExecution record. NOTE(review):
                not referenced inside this method — confirm it is needed.
            node: Node definition from the workflow graph.
            context: Workflow execution context passed to _prepare_node_input.

        Raises:
            ValueError: If no model can be resolved, the resolved model id
                does not exist, or the streaming call fails.
        """
        node_id = node['id']  # NOTE(review): assigned but never read below.
        config = self._prepare_node_input(node, context).get('node_config', {})

        # The logic below mirrors the model resolution and prompt building of
        # _execute_llm_node so the stream and non-stream paths stay consistent.
        # Unlike the non-streaming path there is no fallback retry here, so
        # used_default_model is currently write-only.
        used_default_model = False

        # --- Resolve the LLM configuration from node_config ---
        model_id = config.get('model_id')
        if not model_id:
            model_value = config.get('model_name', config.get('model'))
            if model_value:
                if isinstance(model_value, int):
                    # An integer model_name/model value is treated as a config id.
                    model_id = model_value
                else:
                    from sqlalchemy import select
                    result = await self.session.execute(
                        select(LLMConfig).where(LLMConfig.model_name == model_value)
                    )
                    llm_cfg = result.scalar_one_or_none()
                    if llm_cfg:
                        model_id = llm_cfg.id

        # --- Fall back to the node definition's own config ---
        if not model_id:
            node_config = node.get('config', {})
            model_id = node_config.get('model_id')
            if not model_id:
                model_value = node_config.get('model_name', node_config.get('model'))
                if model_value:
                    if isinstance(model_value, int):
                        model_id = model_value
                    else:
                        from sqlalchemy import select
                        result = await self.session.execute(
                            select(LLMConfig).where(LLMConfig.model_name == model_value)
                        )
                        llm_cfg = result.scalar_one_or_none()
                        if llm_cfg:
                            model_id = llm_cfg.id

        # --- Last resort: the system default chat model ---
        if not model_id:
            from ..services.llm_config_service import LLMConfigService
            llm_config_service = LLMConfigService()
            default_config = await llm_config_service.get_default_chat_config(self.session)
            if default_config:
                model_id = default_config.id
                used_default_model = True
                logger.info(f"[STREAM] LLM节点未指定模型配置,使用默认模型: {default_config.model_name} (ID: {model_id})")
            else:
                raise ValueError(
                    "未指定有效的大模型配置,且未找到默认配置。"
                )

        # Load the resolved configuration row; a dangling id is a hard error.
        from sqlalchemy import select
        result = await self.session.execute(
            select(LLMConfig).where(LLMConfig.id == model_id)
        )
        llm_config = result.scalar_one_or_none()
        if not llm_config:
            raise ValueError(f"大模型配置 {model_id} 不存在")

        # Build the prompt with the same logic as the non-streaming path.
        # NOTE(review): _prepare_node_input is invoked a second time here (it
        # was already called above for config) — confirm it is cheap and
        # side-effect free. prompt_template is unused on this path.
        input_data = self._prepare_node_input(node, context)
        config = input_data.get('node_config', {})
        prompt, prompt_template = self._build_llm_prompt(node, input_data)

        # Log the exact prompt sent to the model on the streaming path.
        logger.info(
            f"LLM 节点最终提示词(流式): node_id={node.get('id')}, "
            f"model_id={llm_config.id}, model_name={llm_config.model_name}, prompt={prompt}"
        )

        full_response = ""

        try:
            # Stream chunks from the LLM service.
            async for chunk in self.llm_service.chat_completion_stream(
                model_config=llm_config,
                messages=[{"role": "user", "content": prompt}],
                temperature=config.get('temperature', 0.7),
                max_tokens=config.get('max_tokens')
            ):
                if not chunk:
                    continue
                full_response += chunk
                # Surface the incremental result to the outer generator.
                yield {
                    'event_type': 'delta',
                    'delta': chunk,
                    'full_response': full_response,
                }

            # Emit the final aggregated result for downstream nodes.
            final_output = {
                'success': True,
                'response': full_response,
                'prompt': prompt,
                'model': llm_config.model_name,
                'tokens_used': 0  # The streaming API does not report usage yet.
            }
            yield {
                'event_type': 'final',
                'output': final_output,
            }

        except Exception as e:
            error_msg = str(e)
            detailed_error = error_msg
            # Enrich the error with model info unless the message already has it.
            if "使用的模型:" not in error_msg and "模型:" not in error_msg:
                model_info = f"使用的模型: {llm_config.model_name} (ID: {llm_config.id}), base_url: {llm_config.base_url}"
                detailed_error = f"{detailed_error}。{model_info}"
            logger.error(f"[STREAM] LLM流式调用失败: {detailed_error}")
            raise ValueError(f"LLM流式调用失败: {detailed_error}")
|
2026-01-21 13:45:39 +08:00
|
|
|
|
|
|
|
|
|
|
def _substitute_variables(self, template: str, input_data: Dict[str, Any]) -> str:
|
|
|
|
|
|
"""变量替换函数"""
|
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
# 获取解析后的输入参数
|
|
|
|
|
|
resolved_inputs = input_data.get('resolved_inputs', {})
|
|
|
|
|
|
|
|
|
|
|
|
# 获取工作流输入数据
|
|
|
|
|
|
# input_data['workflow_input'] 包含了用户输入的参数
|
|
|
|
|
|
workflow_input = input_data.get('workflow_input', {})
|
|
|
|
|
|
|
|
|
|
|
|
# 构建变量上下文
|
|
|
|
|
|
variable_context = {}
|
|
|
|
|
|
|
|
|
|
|
|
# 首先添加解析后的参数
|
|
|
|
|
|
variable_context.update(resolved_inputs)
|
|
|
|
|
|
|
|
|
|
|
|
# 添加工作流输入的顶层字段
|
|
|
|
|
|
variable_context.update(workflow_input)
|
|
|
|
|
|
|
|
|
|
|
|
# 如果 workflow_input 包含 user_input 字段,将其内容提升到顶层
|
|
|
|
|
|
if 'user_input' in workflow_input and isinstance(workflow_input['user_input'], dict):
|
|
|
|
|
|
variable_context.update(workflow_input['user_input'])
|
|
|
|
|
|
|
|
|
|
|
|
# 添加前一个节点的输出(简化访问)
|
|
|
|
|
|
for node_id, output in input_data.get('previous_outputs', {}).items():
|
|
|
|
|
|
if isinstance(output, dict):
|
|
|
|
|
|
# 添加节点输出的直接访问
|
|
|
|
|
|
variable_context[f'node_{node_id}'] = output.get('data', output)
|
|
|
|
|
|
# 如果输出有response字段,也添加直接访问
|
|
|
|
|
|
if 'response' in output:
|
|
|
|
|
|
variable_context[f'node_{node_id}_response'] = output['response']
|
|
|
|
|
|
|
|
|
|
|
|
# 调试日志:打印变量上下文
|
|
|
|
|
|
logger.info(f"变量替换上下文: {variable_context}")
|
|
|
|
|
|
logger.info(f"原始模板: {template}")
|
|
|
|
|
|
|
|
|
|
|
|
# 使用正则表达式替换变量 {{variable_name}} 和 {variable_name}
|
|
|
|
|
|
def replace_variable(match):
|
|
|
|
|
|
var_name = match.group(1)
|
|
|
|
|
|
replacement = variable_context.get(var_name, match.group(0))
|
|
|
|
|
|
logger.info(f"替换变量 {match.group(0)} -> {replacement}")
|
|
|
|
|
|
return str(replacement)
|
|
|
|
|
|
|
|
|
|
|
|
# 首先替换 {{variable_name}} 格式的变量
|
|
|
|
|
|
result = re.sub(r'\{\{([^}]+)\}\}', replace_variable, template)
|
|
|
|
|
|
# 然后替换 {variable_name} 格式的变量
|
|
|
|
|
|
result = re.sub(r'\{([^}]+)\}', replace_variable, result)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"替换后结果: {result}")
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_condition_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行条件节点"""
|
|
|
|
|
|
config = input_data['node_config']
|
|
|
|
|
|
condition = config.get('condition', '')
|
|
|
|
|
|
|
|
|
|
|
|
# 简单的条件评估(生产环境需要更安全的实现)
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 构建评估上下文
|
|
|
|
|
|
eval_context = {
|
|
|
|
|
|
'input': input_data['workflow_input'],
|
|
|
|
|
|
'previous': input_data['previous_outputs']
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 评估条件
|
|
|
|
|
|
result = eval(condition, {"__builtins__": {}}, eval_context)
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'condition': condition,
|
|
|
|
|
|
'result': bool(result)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"条件评估失败: {str(e)}")
|
|
|
|
|
|
raise ValueError(f"条件评估失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_code_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行代码节点"""
|
|
|
|
|
|
config = input_data['node_config']
|
|
|
|
|
|
language = config.get('language', 'python')
|
|
|
|
|
|
code = config.get('code', '')
|
|
|
|
|
|
|
|
|
|
|
|
if language == 'python':
|
|
|
|
|
|
# 执行Python代码
|
|
|
|
|
|
execution_result = await self._execute_python_code(code, input_data)
|
|
|
|
|
|
|
|
|
|
|
|
# 处理输出参数配置
|
|
|
|
|
|
node_parameters = node.get('parameters', {})
|
|
|
|
|
|
if node_parameters and 'outputs' in node_parameters:
|
|
|
|
|
|
output_params = node_parameters['outputs']
|
|
|
|
|
|
code_result = execution_result.get('result', {})
|
|
|
|
|
|
|
|
|
|
|
|
# 根据输出参数配置构建最终输出
|
|
|
|
|
|
final_output = {
|
|
|
|
|
|
'success': execution_result['success'],
|
|
|
|
|
|
'code': execution_result['code'],
|
|
|
|
|
|
'input_parameters': execution_result.get('input_parameters', {})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 如果代码返回的是字典,根据输出参数配置提取对应字段
|
|
|
|
|
|
if isinstance(code_result, dict):
|
|
|
|
|
|
for output_param in output_params:
|
|
|
|
|
|
param_name = output_param.get('name')
|
|
|
|
|
|
if param_name and param_name in code_result:
|
|
|
|
|
|
final_output[param_name] = code_result[param_name]
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 如果代码返回的不是字典,且只有一个输出参数,直接使用返回值
|
|
|
|
|
|
if len(output_params) == 1:
|
|
|
|
|
|
param_name = output_params[0].get('name')
|
|
|
|
|
|
if param_name:
|
|
|
|
|
|
final_output[param_name] = code_result
|
|
|
|
|
|
|
|
|
|
|
|
return final_output
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 如果没有输出参数配置,返回原始结果
|
|
|
|
|
|
return execution_result
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise ValueError(f"不支持的代码语言: {language}")
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_python_code(self, code: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行Python代码"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 构建执行上下文
|
|
|
|
|
|
safe_builtins = {
|
|
|
|
|
|
'len': len,
|
|
|
|
|
|
'str': str,
|
|
|
|
|
|
'int': int,
|
|
|
|
|
|
'float': float,
|
|
|
|
|
|
'bool': bool,
|
|
|
|
|
|
'list': list,
|
|
|
|
|
|
'dict': dict,
|
|
|
|
|
|
'tuple': tuple,
|
|
|
|
|
|
'set': set,
|
|
|
|
|
|
'range': range,
|
|
|
|
|
|
'enumerate': enumerate,
|
|
|
|
|
|
'zip': zip,
|
|
|
|
|
|
'sum': sum,
|
|
|
|
|
|
'min': min,
|
|
|
|
|
|
'max': max,
|
|
|
|
|
|
'abs': abs,
|
|
|
|
|
|
'round': round,
|
|
|
|
|
|
'sorted': sorted,
|
|
|
|
|
|
'reversed': reversed,
|
|
|
|
|
|
'print': print,
|
|
|
|
|
|
'__import__': __import__,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 导入常用模块
|
|
|
|
|
|
import json
|
|
|
|
|
|
import datetime
|
|
|
|
|
|
import math
|
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
exec_context = {
|
|
|
|
|
|
'__builtins__': safe_builtins,
|
|
|
|
|
|
'json': json, # 允许使用json模块
|
|
|
|
|
|
'datetime': datetime, # 允许使用datetime模块
|
|
|
|
|
|
'math': math, # 允许使用math模块
|
|
|
|
|
|
're': re, # 允许使用re模块
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 执行代码以定义函数
|
|
|
|
|
|
exec(code, exec_context)
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否定义了main函数
|
|
|
|
|
|
if 'main' not in exec_context:
|
|
|
|
|
|
raise ValueError("代码中必须定义一个main函数")
|
|
|
|
|
|
|
|
|
|
|
|
main_function = exec_context['main']
|
|
|
|
|
|
|
|
|
|
|
|
# 获取已解析的输入参数
|
|
|
|
|
|
resolved_inputs = input_data.get('resolved_inputs', {})
|
|
|
|
|
|
|
|
|
|
|
|
# 调用main函数并传递参数
|
|
|
|
|
|
if resolved_inputs:
|
|
|
|
|
|
# 使用解析后的输入参数调用main函数
|
|
|
|
|
|
result = main_function(**resolved_inputs)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 如果没有输入参数,直接调用main函数
|
|
|
|
|
|
result = main_function()
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'result': result,
|
|
|
|
|
|
'code': code,
|
|
|
|
|
|
'input_parameters': resolved_inputs
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"Python代码执行失败: {str(e)}")
|
|
|
|
|
|
raise ValueError(f"Python代码执行失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_http_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行HTTP请求节点"""
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
|
|
|
|
|
|
|
config = input_data['node_config']
|
|
|
|
|
|
method = config.get('method', 'GET').upper()
|
|
|
|
|
|
url = config.get('url', '')
|
|
|
|
|
|
headers = config.get('headers', {})
|
|
|
|
|
|
body = config.get('body')
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
|
async with session.request(
|
|
|
|
|
|
method=method,
|
|
|
|
|
|
url=url,
|
|
|
|
|
|
headers=headers,
|
|
|
|
|
|
data=body
|
|
|
|
|
|
) as response:
|
|
|
|
|
|
response_text = await response.text()
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'status_code': response.status,
|
|
|
|
|
|
'response': response_text,
|
|
|
|
|
|
'headers': dict(response.headers)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"HTTP请求失败: {str(e)}")
|
|
|
|
|
|
raise ValueError(f"HTTP请求失败: {str(e)}")
|
2026-01-23 12:45:05 +08:00
|
|
|
|
|
|
|
|
|
|
async def _execute_knowledge_base_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
"""执行知识库节点"""
|
|
|
|
|
|
from ..services.document import DocumentService
|
|
|
|
|
|
|
2026-01-23 13:59:20 +08:00
|
|
|
|
config = input_data.get('node_config', {})
|
2026-01-23 12:45:05 +08:00
|
|
|
|
# 支持多种字段名:knowledge_base_id, knowledgeBase, kb_id
|
2026-01-23 13:59:20 +08:00
|
|
|
|
# 先从 config 中获取
|
2026-01-23 12:45:05 +08:00
|
|
|
|
knowledge_base_id = config.get('knowledge_base_id') or config.get('knowledgeBase') or config.get('kb_id')
|
2026-01-23 13:59:20 +08:00
|
|
|
|
# 如果 config 中没有,尝试从节点定义本身获取
|
|
|
|
|
|
if not knowledge_base_id:
|
|
|
|
|
|
knowledge_base_id = node.get('knowledge_base_id') or node.get('knowledgeBase') or node.get('kb_id')
|
|
|
|
|
|
# 如果还是没有,尝试从节点的 config 字段(节点定义中的 config)获取
|
|
|
|
|
|
if not knowledge_base_id and 'config' in node:
|
|
|
|
|
|
node_config = node.get('config', {})
|
|
|
|
|
|
knowledge_base_id = node_config.get('knowledge_base_id') or node_config.get('knowledgeBase') or node_config.get('kb_id')
|
|
|
|
|
|
|
2026-01-23 12:45:05 +08:00
|
|
|
|
query = config.get('query', '')
|
|
|
|
|
|
top_k = config.get('top_k', config.get('topK', 5))
|
2026-01-30 16:56:10 +08:00
|
|
|
|
similarity_threshold = config.get('similarity_threshold', config.get('similarityThreshold', 0.45))
|
2026-01-23 12:45:05 +08:00
|
|
|
|
|
2026-01-23 13:59:20 +08:00
|
|
|
|
# 如果还是没有,尝试从节点名称中提取(例如 "knowledge-base 2" -> 2)
|
|
|
|
|
|
if not knowledge_base_id:
|
|
|
|
|
|
node_name = node.get('name', '')
|
|
|
|
|
|
import re
|
|
|
|
|
|
# 尝试从名称中提取数字(可能是知识库ID)
|
|
|
|
|
|
match = re.search(r'(\d+)', node_name)
|
|
|
|
|
|
if match:
|
|
|
|
|
|
try:
|
|
|
|
|
|
potential_id = int(match.group(1))
|
|
|
|
|
|
logger.info(f"从节点名称 '{node_name}' 中提取到潜在的知识库ID: {potential_id}")
|
|
|
|
|
|
knowledge_base_id = potential_id
|
|
|
|
|
|
except:
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
2026-01-23 12:45:05 +08:00
|
|
|
|
if not knowledge_base_id:
|
2026-01-23 13:59:20 +08:00
|
|
|
|
raise ValueError(
|
|
|
|
|
|
f"知识库节点配置缺少 knowledge_base_id (或 knowledgeBase/kb_id)。\n"
|
|
|
|
|
|
f"请在节点配置中添加知识库ID,例如:\n"
|
|
|
|
|
|
f" - config.knowledge_base_id: 2\n"
|
|
|
|
|
|
f" - config.knowledgeBase: 2\n"
|
|
|
|
|
|
f" - config.kb_id: 2\n"
|
|
|
|
|
|
f"当前节点配置: {config}"
|
|
|
|
|
|
)
|
2026-01-23 12:45:05 +08:00
|
|
|
|
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
# 尝试从 resolved_inputs 中获取查询文本
|
|
|
|
|
|
resolved_inputs = input_data.get('resolved_inputs', {})
|
|
|
|
|
|
query = resolved_inputs.get('query', '')
|
|
|
|
|
|
# 如果还是没有,尝试从 workflow_input 中获取
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
workflow_input = input_data.get('workflow_input', {})
|
2026-01-23 14:26:14 +08:00
|
|
|
|
# 首先尝试获取 'query' 字段
|
2026-01-23 12:45:05 +08:00
|
|
|
|
query = workflow_input.get('query', '')
|
2026-01-23 14:26:14 +08:00
|
|
|
|
# 如果没有 'query' 字段,尝试获取第一个非空的字符串值作为查询文本
|
|
|
|
|
|
if not query and isinstance(workflow_input, dict):
|
|
|
|
|
|
for key, value in workflow_input.items():
|
|
|
|
|
|
if isinstance(value, str) and value.strip():
|
|
|
|
|
|
query = value.strip()
|
|
|
|
|
|
logger.info(f"从工作流输入的 '{key}' 字段获取查询文本: {query}")
|
|
|
|
|
|
break
|
2026-01-23 12:45:05 +08:00
|
|
|
|
# 如果还是没有,尝试从 previous_outputs 中获取(可能是上一个节点的输出)
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
previous_outputs = input_data.get('previous_outputs', {})
|
|
|
|
|
|
# 尝试从上一个节点的输出中获取查询文本
|
|
|
|
|
|
for node_id, output in previous_outputs.items():
|
|
|
|
|
|
if isinstance(output, dict):
|
2026-01-23 14:26:14 +08:00
|
|
|
|
# 首先尝试从根级别获取
|
|
|
|
|
|
query = output.get('query', '')
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
# 尝试从 data 字段中获取
|
|
|
|
|
|
data = output.get('data', {})
|
|
|
|
|
|
if isinstance(data, dict):
|
|
|
|
|
|
query = data.get('query', '')
|
|
|
|
|
|
# 如果 data 中没有 query,尝试获取第一个非空字符串值
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
for key, value in data.items():
|
|
|
|
|
|
if isinstance(value, str) and value.strip():
|
|
|
|
|
|
query = value.strip()
|
|
|
|
|
|
logger.info(f"从节点 {node_id} 输出的 data.{key} 字段获取查询文本: {query}")
|
|
|
|
|
|
break
|
2026-01-23 12:45:05 +08:00
|
|
|
|
if query:
|
|
|
|
|
|
break
|
|
|
|
|
|
# 如果还是没有,使用默认查询或抛出错误
|
|
|
|
|
|
if not query:
|
|
|
|
|
|
# 如果没有查询文本,可以返回空结果而不是抛出错误
|
|
|
|
|
|
logger.warning(f"知识库节点缺少查询文本,将返回空结果")
|
|
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'query': '',
|
|
|
|
|
|
'knowledge_base_id': knowledge_base_id,
|
|
|
|
|
|
'results': [],
|
|
|
|
|
|
'total_results': 0,
|
|
|
|
|
|
'top_k': top_k,
|
|
|
|
|
|
'similarity_threshold': similarity_threshold,
|
|
|
|
|
|
'warning': '缺少查询文本,返回空结果'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 直接使用 document_processor 进行搜索
|
2026-01-23 14:26:14 +08:00
|
|
|
|
# 注意:get_document_processor 需要 session 来初始化嵌入模型
|
2026-01-23 12:45:05 +08:00
|
|
|
|
from ..services.document_processor import get_document_processor
|
2026-01-23 14:26:14 +08:00
|
|
|
|
# 传入 self.session 以便初始化嵌入模型(虽然类型不匹配,但 get_document_processor 会处理)
|
|
|
|
|
|
document_processor = await get_document_processor(self.session)
|
2026-01-23 12:45:05 +08:00
|
|
|
|
results = document_processor.search_similar_documents(
|
|
|
|
|
|
knowledge_base_id=knowledge_base_id,
|
|
|
|
|
|
query=query,
|
|
|
|
|
|
k=top_k
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-01-23 14:26:14 +08:00
|
|
|
|
logger.info(f"知识库 {knowledge_base_id} 搜索查询 '{query}' 返回 {len(results)} 个原始结果")
|
|
|
|
|
|
|
2026-01-23 12:45:05 +08:00
|
|
|
|
# 过滤相似度阈值
|
|
|
|
|
|
filtered_results = []
|
2026-01-23 14:26:14 +08:00
|
|
|
|
all_results = []
|
2026-01-23 12:45:05 +08:00
|
|
|
|
for result in results:
|
|
|
|
|
|
score = result.get('normalized_score', result.get('similarity_score', 0))
|
2026-01-23 14:26:14 +08:00
|
|
|
|
all_results.append({
|
|
|
|
|
|
**result,
|
|
|
|
|
|
'score': score
|
|
|
|
|
|
})
|
2026-01-23 12:45:05 +08:00
|
|
|
|
if score >= similarity_threshold:
|
|
|
|
|
|
filtered_results.append(result)
|
|
|
|
|
|
|
2026-01-23 14:26:14 +08:00
|
|
|
|
logger.info(f"应用相似度阈值 {similarity_threshold} 后,剩余 {len(filtered_results)} 个结果")
|
|
|
|
|
|
|
2026-01-23 14:36:49 +08:00
|
|
|
|
# 如果过滤后结果为空,但原始结果不为空
|
2026-01-23 14:26:14 +08:00
|
|
|
|
if not filtered_results and results:
|
2026-01-23 14:36:49 +08:00
|
|
|
|
# 检查最高相似度分数
|
|
|
|
|
|
max_score = max([r.get('score', 0) for r in all_results]) if all_results else 0
|
|
|
|
|
|
# 如果最高分数仍然很低(低于阈值的50%),说明结果完全不相关,返回空结果
|
|
|
|
|
|
if max_score < similarity_threshold * 0.5:
|
|
|
|
|
|
logger.warning(f"所有搜索结果相似度都很低(最高: {max_score:.3f}),返回空结果")
|
|
|
|
|
|
filtered_results = []
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 相似度还可以,返回前几个结果并添加警告
|
|
|
|
|
|
logger.warning(f"所有搜索结果都被相似度阈值 {similarity_threshold} 过滤,但最高分数 {max_score:.3f} 尚可,返回前 {min(len(results), top_k)} 个结果")
|
|
|
|
|
|
filtered_results = results[:top_k]
|
2026-01-23 14:26:14 +08:00
|
|
|
|
|
2026-01-23 12:45:05 +08:00
|
|
|
|
return {
|
|
|
|
|
|
'success': True,
|
|
|
|
|
|
'query': query,
|
|
|
|
|
|
'knowledge_base_id': knowledge_base_id,
|
|
|
|
|
|
'results': filtered_results,
|
|
|
|
|
|
'total_results': len(filtered_results),
|
2026-01-23 14:26:14 +08:00
|
|
|
|
'raw_results_count': len(results),
|
2026-01-23 12:45:05 +08:00
|
|
|
|
'top_k': top_k,
|
2026-01-23 14:26:14 +08:00
|
|
|
|
'similarity_threshold': similarity_threshold,
|
|
|
|
|
|
'all_results_scores': [r.get('score', 0) for r in all_results[:5]] if all_results else []
|
2026-01-23 12:45:05 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"知识库搜索失败: {str(e)}")
|
|
|
|
|
|
raise ValueError(f"知识库搜索失败: {str(e)}")
|
2026-01-21 13:45:39 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 工作流引擎实例
|
|
|
|
|
|
async def get_workflow_engine(session: AsyncSession = None) -> WorkflowEngine:
    """Return a WorkflowEngine bound to a database session.

    When no session is supplied, take the first one yielded by the
    get_session() dependency generator.
    """
    if session is not None:
        return WorkflowEngine(session)
    async for db_session in get_session():
        return WorkflowEngine(db_session)
    # Generator yielded nothing; preserve the original behavior of
    # constructing the engine with a None session.
    return WorkflowEngine(session)
|