hyf-backend/th_agenter/api/endpoints/workflow.py

537 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""工作流管理API"""
from typing import List, Optional, AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import select, and_, func
import json
from datetime import datetime
from ...db.database import get_session
from ...schemas.workflow import (
WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
)
from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
from ...services.workflow_engine import get_workflow_engine
from ...services.auth import AuthService
from ...models.user import User
from loguru import logger
from utils.util_exceptions import HxfResponse
# Router that collects the workflow-management endpoints defined in this module.
# It is presumably included into the application under a prefix elsewhere —
# confirm at the include site.
router = APIRouter()
def convert_workflow_for_response(workflow_dict):
    """Rename connection endpoint keys so the dict matches the response schema.

    Each entry of ``workflow_dict['definition']['connections']`` has its
    ``from_node`` / ``to_node`` keys renamed to ``from`` / ``to``. Entries
    without those keys are left alone. The dict is modified in place and is
    also returned for convenience.
    """
    definition = workflow_dict.get('definition')
    if not definition:
        return workflow_dict
    connections = definition.get('connections')
    if not connections:
        return workflow_dict
    # (stored key, response key) pairs, applied in this order per connection.
    renames = (('from_node', 'from'), ('to_node', 'to'))
    for connection in connections:
        for old_key, new_key in renames:
            if old_key in connection:
                connection[new_key] = connection.pop(old_key)
    return workflow_dict
def _escape_like(term: str, escape_char: str = "/") -> str:
    """Escape LIKE/ILIKE wildcards in *term* so it is matched literally.

    ``%`` and ``_`` are wildcards in SQL LIKE patterns; unescaped, a search for
    ``"100%"`` would match every name containing ``100``. The escape character
    itself is escaped first so it cannot pair with the characters after it.
    ``/`` is used (as SQLAlchemy's own ``autoescape`` does) to avoid dialect
    differences in how backslashes are treated inside string literals.
    """
    return (
        term.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


@router.get("/", response_model=WorkflowListResponse, summary="获取工作流列表")
async def list_workflows(
    skip: Optional[int] = Query(None, ge=0),
    limit: Optional[int] = Query(None, ge=1, le=100),
    workflow_status: Optional[WorkflowStatus] = None,
    search: Optional[str] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List the current user's workflows, optionally filtered and paginated.

    Args:
        skip: Rows to skip. When both ``skip`` and ``limit`` are omitted, every
            matching workflow is returned as a single page.
        limit: Page size (1-100). Defaults to 10 when only ``skip`` is given.
        workflow_status: Only return workflows with this status.
        search: Case-insensitive substring match on the workflow name;
            ``%`` and ``_`` in the term are matched literally.
        session: Database session (awaited, i.e. used as an async session).
        current_user: Authenticated user; only their workflows are listed.

    Returns:
        ``HxfResponse`` wrapping a ``WorkflowListResponse``.
    """
    from ...models.workflow import Workflow
    session.title = f"获取用户 {current_user.username} 的所有工作流"
    session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})"

    # Build the filter set once so the row query and the count query can
    # never drift apart.
    conditions = [Workflow.owner_id == current_user.id]
    if workflow_status:
        conditions.append(Workflow.status == workflow_status)
    if search:
        conditions.append(
            Workflow.name.ilike(f"%{_escape_like(search)}%", escape="/")
        )

    # Explicit ordering: OFFSET/LIMIT without ORDER BY gives no guarantee that
    # consecutive pages are stable or disjoint.
    stmt = select(Workflow).where(*conditions).order_by(Workflow.id)
    count_query = select(func.count(Workflow.id)).where(*conditions)

    session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}"
    total = await session.scalar(count_query)
    session.desc = f"查询结果: 共 {total}"

    # No pagination parameters at all: return every matching workflow.
    if skip is None and limit is None:
        workflows = (await session.scalars(stmt)).all()
        session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)}"
        response = WorkflowListResponse(
            workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
            total=total,
            page=1,
            size=total
        )
        return HxfResponse(response)

    # Fill in defaults for whichever pagination parameter was omitted.
    if skip is None:
        skip = 0
    if limit is None:
        limit = 10

    workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all()
    session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)}"
    response = WorkflowListResponse(
        workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
        total=total,
        page=skip // limit + 1,  # 1-based page that contains row number `skip`
        size=limit
    )
    return HxfResponse(response)
@router.get("/{workflow_id}", response_model=WorkflowResponse, summary="获取工作流详情")
async def get_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return a single workflow owned by the current user.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            ``current_user``.
    """
    from ...models.workflow import Workflow
    session.title = f"获取工作流 {workflow_id}"
    session.desc = f"START: 获取工作流 {workflow_id}"

    # Ownership is part of the lookup, so other users' workflows read as missing.
    owned_by_user = and_(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    found = await session.scalar(select(Workflow).where(owned_by_user))

    if not found:
        session.desc = f"ERROR: 获取工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在",
        )

    session.desc = f"SUCCESS: 获取工作流数据 {workflow_id}"
    payload = convert_workflow_for_response(found.to_dict())
    return HxfResponse(WorkflowResponse(**payload))
@router.put("/{workflow_id}", response_model=WorkflowResponse, summary="更新工作流")
async def update_workflow(
    workflow_id: int,
    workflow_data: WorkflowUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Apply a partial update to a workflow owned by the current user.

    Only fields explicitly present in the request body are written. The status
    is always forced to ``PUBLISHED`` (mirroring ``create_workflow``), so any
    status the client sends is overridden.

    Raises:
        HTTPException: 404 when the workflow does not exist or belongs to
            another user.
    """
    from ...models.workflow import Workflow
    session.title = f"更新工作流 {workflow_id}"
    session.desc = f"START: 更新工作流 {workflow_id}"
    workflow = await session.scalar(
        select(Workflow).where(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )
    if not workflow:
        session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )
    # In Pydantic v2, attribute assignment adds the field to model_fields_set,
    # so `status` is included in the exclude_unset dump below and always written.
    workflow_data.status = WorkflowStatus.PUBLISHED

    # Dump once. model_dump converts nested models (including `definition`)
    # into plain dicts recursively, so every value can be assigned as-is.
    update_data = workflow_data.model_dump(exclude_unset=True)
    session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {update_data}"
    for field, value in update_data.items():
        setattr(workflow, field, value)

    workflow.set_audit_fields(current_user.id, is_update=True)
    await session.commit()
    await session.refresh(workflow)
    session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}"
    response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
    return HxfResponse(response)
@router.delete("/{workflow_id}", summary="删除工作流")
async def delete_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete a workflow owned by the current user and commit immediately.

    Raises:
        HTTPException: 404 when the workflow does not exist or is not owned by
            ``current_user``.
    """
    from ...models.workflow import Workflow
    session.title = f"删除工作流 {workflow_id}"
    session.desc = f"START: 删除工作流 {workflow_id}"

    lookup = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    target = await session.scalar(lookup)
    if not target:
        session.desc = f"ERROR: 删除工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在",
        )

    session.desc = f"删除工作流: {target.name}"
    await session.delete(target)
    await session.commit()
    session.desc = f"SUCCESS: 删除工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流删除成功"})
@router.post("/{workflow_id}/activate", summary="激活工作流")
async def activate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Mark a workflow owned by the current user as ``PUBLISHED``.

    ``PUBLISHED`` is the status the execute endpoints require.

    Raises:
        HTTPException: 404 when the workflow does not exist or is not owned by
            ``current_user``.
    """
    from ...models.workflow import Workflow
    session.title = f"激活工作流 {workflow_id}"
    session.desc = f"START: 激活工作流 {workflow_id}"

    ownership_filter = (
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    target = await session.scalar(select(Workflow).where(*ownership_filter))
    if not target:
        session.desc = f"ERROR: 激活工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在",
        )

    target.status = ModelWorkflowStatus.PUBLISHED
    target.set_audit_fields(current_user.id, is_update=True)
    await session.commit()
    session.desc = f"SUCCESS: 激活工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流激活成功"})
@router.post("/{workflow_id}/deactivate", summary="停用工作流")
async def deactivate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Mark a workflow owned by the current user as ``ARCHIVED``.

    Archived workflows fail the ``PUBLISHED`` check in the execute endpoints.

    Raises:
        HTTPException: 404 when the workflow does not exist or is not owned by
            ``current_user``.
    """
    from ...models.workflow import Workflow
    session.title = f"停用工作流 {workflow_id}"
    session.desc = f"START: 停用工作流 {workflow_id}"

    ownership_filter = (
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    target = await session.scalar(select(Workflow).where(*ownership_filter))
    if not target:
        session.desc = f"ERROR: 停用工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在",
        )

    target.status = ModelWorkflowStatus.ARCHIVED
    target.set_audit_fields(current_user.id, is_update=True)
    await session.commit()
    session.desc = f"SUCCESS: 停用工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流停用成功"})
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse], summary="获取工作流执行历史")
async def list_workflow_executions(
    workflow_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List executions of a workflow owned by the current user, newest first.

    Args:
        workflow_id: Workflow whose execution history is requested.
        skip: Number of executions to skip (offset).
        limit: Maximum number of executions to return (1-100).

    Raises:
        HTTPException: 404 when the workflow does not exist or belongs to
            another user; 500 for any unexpected failure (logged with traceback).
    """
    session.title = f"获取工作流执行历史 {workflow_id}"
    session.desc = f"START: 获取工作流执行历史 {workflow_id}"
    try:
        from ...models.workflow import Workflow, WorkflowExecution
        # Verify ownership before exposing any execution records.
        workflow = await session.scalar(
            select(Workflow).where(
                and_(
                    Workflow.id == workflow_id,
                    Workflow.owner_id == current_user.id
                )
            )
        )
        if not workflow:
            session.desc = "ERROR: 获取工作流执行历史数据 - 工作流不存在"
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="工作流不存在"
            )
        executions = (await session.scalars(
            select(WorkflowExecution).where(
                WorkflowExecution.workflow_id == workflow_id
            ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
        )).all()
        session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}"
        response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
        return HxfResponse(response)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client
        # unchanged instead of being rewritten as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"获取工作流执行历史失败: workflow_id={workflow_id}")
        session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行历史失败"
        ) from e
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse, summary="获取工作流执行详情")
async def get_workflow_execution(
    execution_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one execution record whose workflow belongs to the current user.

    Ownership is enforced by joining the execution to its workflow and
    filtering on ``Workflow.owner_id``.

    Raises:
        HTTPException: 404 when the execution does not exist or its workflow
            belongs to another user; 500 for any unexpected failure (logged
            with traceback).
    """
    session.title = f"获取工作流执行详情 {execution_id}"
    session.desc = f"START: 获取工作流执行详情 {execution_id}"
    try:
        from ...models.workflow import WorkflowExecution, Workflow
        execution = await session.scalar(
            select(WorkflowExecution).join(
                Workflow, WorkflowExecution.workflow_id == Workflow.id
            ).where(
                WorkflowExecution.id == execution_id,
                Workflow.owner_id == current_user.id
            )
        )
        if not execution:
            session.desc = "ERROR: 获取工作流执行详情数据 - 执行记录不存在"
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="执行记录不存在"
            )
        response = WorkflowExecutionResponse.model_validate(execution)
        session.desc = f"SUCCESS: 获取工作流执行详情数据 commit {execution_id}"
        return HxfResponse(response)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client
        # unchanged instead of being rewritten as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"获取工作流执行详情失败: execution_id={execution_id}")
        session.desc = f"ERROR: 获取工作流执行详情数据 commit {execution_id}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行详情失败"
        ) from e
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


@router.post("/{workflow_id}/execute-stream", summary="流式执行工作流")
async def execute_workflow_stream(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Execute a workflow and stream per-step progress as Server-Sent Events.

    Every frame is ``data: <json>`` followed by a blank line. This endpoint
    itself emits frames of type ``workflow_start``, ``workflow_complete`` and
    ``error``; step frames from the engine are forwarded unchanged.

    Failures (missing or unpublished workflow, engine exceptions) are reported
    as ``error`` frames inside the stream, not as HTTP status codes. The
    generator only runs after the streaming response has started.
    """
    session.title = f"流式执行工作流 {workflow_id}"
    session.desc = f"START: 流式执行工作流 {workflow_id}"

    # NOTE(review): this generator runs after the endpoint returns. Whether
    # `session` is still open at that point depends on how get_session cleans
    # up and on the FastAPI version's handling of yield-dependencies — confirm.
    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            from ...models.workflow import Workflow
            workflow = await session.scalar(
                select(Workflow).where(
                    and_(
                        Workflow.id == workflow_id,
                        Workflow.owner_id == current_user.id
                    )
                )
            )
            if not workflow:
                yield _sse_event({'type': 'error', 'message': '工作流不存在'})
                return
            if workflow.status != ModelWorkflowStatus.PUBLISHED:
                yield _sse_event({'type': 'error', 'message': '工作流未激活,无法执行'})
                return

            yield _sse_event({
                'type': 'workflow_start',
                'workflow_id': workflow_id,
                'workflow_name': workflow.name,
                'timestamp': datetime.now().isoformat(),
            })

            workflow_engine = await get_workflow_engine(session)
            async for step_data in workflow_engine.execute_workflow_stream(
                workflow=workflow,
                input_data=request.input_data,
                user_id=current_user.id,
                session=session
            ):
                yield _sse_event(step_data)

            yield _sse_event({
                'type': 'workflow_complete',
                'message': '工作流执行完成',
                'timestamp': datetime.now().isoformat(),
            })
        except Exception as e:
            # Usually an exception raised by a node or the engine itself.
            # Log the full traceback plus type and repr for diagnosis.
            logger.exception(
                f"流式工作流执行异常type={type(e).__name__}, repr={repr(e)}"
            )
            yield _sse_event({'type': 'error', 'message': f'工作流执行失败: {str(e)}'})

    # media_type now states the real content type. The original passed
    # "text/plain" and also a conflicting explicit Content-Type header, and
    # the header silently won.
    response = StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # NOTE(review): wildcard CORS headers on this one route bypass any
            # app-level CORS policy — consider relying on middleware instead.
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )
    session.desc = f"SUCCESS: 流式执行工作流 {workflow_id} 完毕"
    return HxfResponse(response)
# -----------------------------------------------------------------------
@router.post("/", response_model=WorkflowResponse, summary="创建工作流")
async def create_workflow(
    workflow_data: WorkflowCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a workflow owned by the current user.

    New workflows start at version ``1.0.0`` and are stored as ``PUBLISHED``.
    ``PUBLISHED`` is the status the execute endpoints require. Any status in
    the request is not used.
    """
    from ...models.workflow import Workflow
    name = workflow_data.name
    session.title = f"创建工作流 {name}"
    session.desc = f"START: 创建工作流 {name}"

    new_workflow = Workflow(
        name=name,
        description=workflow_data.description,
        definition=workflow_data.definition.model_dump(),
        version="1.0.0",
        # Hard-coded status; the request's status is not consulted.
        status=ModelWorkflowStatus.PUBLISHED,
        owner_id=current_user.id,
    )
    session.desc = f"创建工作流实例 - Workflow() {name}"
    new_workflow.set_audit_fields(current_user.id)
    session.desc = f"保存工作流 - set_audit_fields {name}"

    session.add(new_workflow)
    await session.commit()
    await session.refresh(new_workflow)
    session.desc = f"保存工作流 - commit & refresh {name}"

    # Rename connection keys (from_node/to_node -> from/to) for the response schema.
    payload = convert_workflow_for_response(new_workflow.to_dict())
    session.desc = f"转换工作流数据 - convert_workflow_for_response {name}"
    result = WorkflowResponse(**payload)
    session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {name}"
    return HxfResponse(result)
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse, summary="执行工作流")
async def execute_workflow(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Run a published workflow owned by the current user and return the result.

    Raises:
        HTTPException: 404 when the workflow does not exist or belongs to
            another user; 400 when its status is not ``PUBLISHED``.
    """
    from ...models.workflow import Workflow
    session.title = f"执行工作流 {workflow_id}"
    session.desc = f"START: 执行工作流 {workflow_id}"

    target = await session.scalar(
        select(Workflow).where(
            Workflow.id == workflow_id,
            Workflow.owner_id == current_user.id,
        )
    )
    if not target:
        session.desc = f"ERROR: 执行工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在",
        )
    session.desc = f"获取工作流数据 - Workflow() {workflow_id}"

    # Only published workflows may be executed.
    if target.status != ModelWorkflowStatus.PUBLISHED:
        session.desc = f"ERROR: 执行工作流数据 - 工作流未激活 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="工作流未激活,无法执行",
        )

    engine = await get_workflow_engine(session)
    session.desc = f"获取工作流引擎 - get_workflow_engine {workflow_id}"
    result = await engine.execute_workflow(
        workflow=target,
        input_data=request.input_data,
        user_id=current_user.id,
        session=session,
    )
    session.desc = f"SUCCESS: 执行工作流数据 commit {workflow_id}"
    return HxfResponse(result)