hxf/backend/th_agenter/api/endpoints/workflow.py

480 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""工作流管理API"""
from typing import List, Optional, AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import select, and_, func
import json
from datetime import datetime
from ...db.database import get_session
from ...schemas.workflow import (
WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
)
from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
from ...services.workflow_engine import get_workflow_engine
from ...services.auth import AuthService
from ...models.user import User
from loguru import logger
router = APIRouter()
def convert_workflow_for_response(workflow_dict):
    """Adapt a workflow dict to the field names used by the response model.

    Each entry of ``workflow_dict['definition']['connections']`` has its
    ``from_node`` key renamed to ``from`` and its ``to_node`` key renamed to
    ``to``. All other keys and values are carried over unchanged.

    The conversion works on copies of the top-level dict, the definition and
    each connection, so the caller's data (which may be the very object held
    by a model instance) is never modified.

    Args:
        workflow_dict: Workflow data as a dict; ``definition`` and its
            ``connections`` may be missing or empty.

    Returns:
        A converted copy, or ``workflow_dict`` itself when there are no
        connections to convert.
    """
    definition = workflow_dict.get('definition')
    connections = definition.get('connections') if definition else None
    if not connections:
        return workflow_dict

    key_renames = (('from_node', 'from'), ('to_node', 'to'))
    converted_connections = []
    for conn in connections:
        # Entries that are not dicts cannot carry the keys; pass them through.
        if isinstance(conn, dict):
            conn = dict(conn)
            for old_key, new_key in key_renames:
                if old_key in conn:
                    conn[new_key] = conn.pop(old_key)
        converted_connections.append(conn)

    converted = dict(workflow_dict)
    converted['definition'] = dict(definition, connections=converted_connections)
    return converted
@router.post("/", response_model=WorkflowResponse)
async def create_workflow(
    workflow_data: WorkflowCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a workflow owned by the authenticated user.

    The new workflow takes its name, description, definition and status from
    the request body, starts at version "1.0.0" and is stamped with the
    caller's id as owner and in its audit fields.

    Args:
        workflow_data: Request body describing the workflow to create.
        session: Database session supplied by ``get_session``.
        current_user: Authenticated user who becomes the owner.

    Returns:
        The persisted workflow as a ``WorkflowResponse``.
    """
    from ...models.workflow import Workflow

    workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        definition=workflow_data.definition.model_dump(),
        version="1.0.0",
        status=workflow_data.status,
        owner_id=current_user.id
    )
    workflow.set_audit_fields(current_user.id)

    # Session.add() returns None, so awaiting it raises TypeError. commit()
    # and refresh() are called synchronously to match the Session annotation
    # and every other endpoint in this module.
    # NOTE(review): assumes get_session yields a synchronous Session — confirm.
    session.add(workflow)
    session.commit()
    session.refresh(workflow)

    # Map the stored connection keys onto the names the response model uses.
    workflow_dict = convert_workflow_for_response(workflow.to_dict())
    logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
    return WorkflowResponse(**workflow_dict)
@router.get("/", response_model=WorkflowListResponse)
async def list_workflows(
    skip: Optional[int] = Query(None, ge=0),
    limit: Optional[int] = Query(None, ge=1, le=100),
    workflow_status: Optional[WorkflowStatus] = None,
    search: Optional[str] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List the current user's workflows, optionally filtered and paginated.

    Args:
        skip: Number of rows to skip. Defaults to 0 when only ``limit`` is given.
        limit: Page size (1-100). Defaults to 10 when only ``skip`` is given.
        workflow_status: When set, only workflows with this status are returned.
        search: When set, only workflows whose name contains this text
            (case-insensitive) are returned.
        session: Database session supplied by ``get_session``.
        current_user: Authenticated user whose workflows are listed.

    Returns:
        A ``WorkflowListResponse``. When neither ``skip`` nor ``limit`` is
        supplied, every matching workflow is returned as page 1 with
        ``size`` equal to ``total``.
    """
    from ...models.workflow import Workflow

    # Build the filters once so the data query and the count query can never
    # drift apart.
    filters = [Workflow.owner_id == current_user.id]
    if workflow_status:
        filters.append(Workflow.status == workflow_status)
    if search:
        filters.append(Workflow.name.ilike(f"%{search}%"))

    total = session.scalar(select(func.count(Workflow.id)).where(*filters)) or 0

    # OFFSET/LIMIT without ORDER BY returns rows in an unspecified order, so
    # pages could overlap or skip rows; order by primary key for stable paging.
    stmt = select(Workflow).where(*filters).order_by(Workflow.id)

    if skip is None and limit is None:
        # No pagination parameters: return everything as a single page.
        page, size = 1, total
    else:
        skip = 0 if skip is None else skip
        limit = 10 if limit is None else limit
        stmt = stmt.offset(skip).limit(limit)
        page, size = skip // limit + 1, limit

    workflows = session.scalars(stmt).all()
    return WorkflowListResponse(
        workflows=[
            WorkflowResponse(**convert_workflow_for_response(w.to_dict()))
            for w in workflows
        ],
        total=total,
        page=page,
        size=size
    )
@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one workflow owned by the current user.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller.
    """
    from ...models.workflow import Workflow

    lookup = select(Workflow).where(
        and_(
            Workflow.id == workflow_id,
            Workflow.owner_id == current_user.id,
        )
    )
    workflow = session.scalar(lookup)
    if not workflow:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")

    response_fields = convert_workflow_for_response(workflow.to_dict())
    return WorkflowResponse(**response_fields)
@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: int,
    workflow_data: WorkflowUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Apply a partial update to a workflow owned by the current user.

    Only the fields explicitly present in the request body are written to
    the workflow; the audit fields are then refreshed for the caller.

    Args:
        workflow_id: Id of the workflow to update.
        workflow_data: Request body; unset fields are left untouched.
        session: Database session supplied by ``get_session``.
        current_user: Authenticated user, who must own the workflow.

    Returns:
        The updated workflow as a ``WorkflowResponse``.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller.
    """
    from ...models.workflow import Workflow

    workflow = session.scalar(
        select(Workflow).where(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )
    if not workflow:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    # model_dump() is the Pydantic v2 API already used elsewhere in this
    # module; the v1-style .dict() is deprecated.
    update_data = workflow_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        # Defensive: store a plain dict if a model instance is still present.
        if field == "definition" and hasattr(value, "model_dump"):
            value = value.model_dump()
        setattr(workflow, field, value)

    workflow.set_audit_fields(current_user.id, is_update=True)
    session.commit()
    session.refresh(workflow)
    logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
    return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
@router.delete("/{workflow_id}")
async def delete_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete a workflow owned by the current user.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller.
    """
    from ...models.workflow import Workflow

    lookup = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    workflow = session.scalar(lookup)
    if not workflow:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")

    session.delete(workflow)
    session.commit()

    logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
    return {"message": "工作流删除成功"}
@router.post("/{workflow_id}/activate")
async def activate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Set a workflow owned by the current user to the PUBLISHED status.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller.
    """
    from ...models.workflow import Workflow

    lookup = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    workflow = session.scalar(lookup)
    if not workflow:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")

    workflow.status = ModelWorkflowStatus.PUBLISHED
    workflow.set_audit_fields(current_user.id, is_update=True)
    session.commit()

    logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
    return {"message": "工作流激活成功"}
@router.post("/{workflow_id}/deactivate")
async def deactivate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Set a workflow owned by the current user to the ARCHIVED status.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller.
    """
    from ...models.workflow import Workflow

    lookup = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    workflow = session.scalar(lookup)
    if not workflow:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")

    workflow.status = ModelWorkflowStatus.ARCHIVED
    workflow.set_audit_fields(current_user.id, is_update=True)
    session.commit()

    logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
    return {"message": "工作流停用成功"}
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
async def execute_workflow(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Run a workflow owned by the current user and return the engine's result.

    The workflow is handed to the workflow engine together with the request's
    input data, the caller's id and the database session.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller; 400 when the workflow's status is not PUBLISHED.
    """
    from ...models.workflow import Workflow

    lookup = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id,
    )
    workflow = session.scalar(lookup)

    # Guard clauses: the workflow must exist and be published.
    if not workflow:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")
    if workflow.status != ModelWorkflowStatus.PUBLISHED:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="工作流未激活,无法执行")

    engine = get_workflow_engine()
    execution_result = await engine.execute_workflow(
        workflow=workflow,
        input_data=request.input_data,
        user_id=current_user.id,
        session=session,
    )

    logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
    return execution_result
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
async def list_workflow_executions(
    workflow_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return a page of a workflow's executions, newest first.

    Args:
        workflow_id: Id of the workflow whose executions are listed.
        skip: Number of executions to skip.
        limit: Maximum number of executions to return (1-100).
        session: Database session supplied by ``get_session``.
        current_user: Authenticated user, who must own the workflow.

    Raises:
        HTTPException: 404 when no workflow with ``workflow_id`` belongs to
            the caller; 500 when anything else fails.
    """
    try:
        from ...models.workflow import Workflow, WorkflowExecution

        # The executions are only exposed to the workflow's owner.
        ownership_check = select(Workflow).where(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id,
            )
        )
        workflow = session.scalar(ownership_check)
        if not workflow:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在")

        history_stmt = (
            select(WorkflowExecution)
            .where(WorkflowExecution.workflow_id == workflow_id)
            .order_by(WorkflowExecution.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        executions = session.scalars(history_stmt).all()
        return [WorkflowExecutionResponse.model_validate(item) for item in executions]
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行历史失败"
        )
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
async def get_workflow_execution(
    execution_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one execution record of a workflow owned by the current user.

    Raises:
        HTTPException: 404 when no execution with ``execution_id`` exists on
            a workflow owned by the caller; 500 when anything else fails.
    """
    try:
        from ...models.workflow import WorkflowExecution, Workflow

        # Join to the parent workflow so ownership is checked in the same query.
        lookup = (
            select(WorkflowExecution)
            .join(Workflow, WorkflowExecution.workflow_id == Workflow.id)
            .where(
                WorkflowExecution.id == execution_id,
                Workflow.owner_id == current_user.id,
            )
        )
        execution = session.scalar(lookup)
        if not execution:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="执行记录不存在")

        return WorkflowExecutionResponse.model_validate(execution)
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error getting workflow execution {execution_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行详情失败"
        )
@router.post("/{workflow_id}/execute-stream")
async def execute_workflow_stream(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Execute a workflow and stream its progress as server-sent events.

    Every event is one ``data: <json>`` frame followed by a blank line. The
    stream emits a ``workflow_start`` event, then each item produced by the
    engine's ``execute_workflow_stream``, then a ``workflow_complete`` event.
    A missing workflow, a workflow that is not PUBLISHED, and any exception
    raised during execution are reported as an ``error`` event inside the
    stream rather than through the HTTP status code.

    Args:
        workflow_id: Id of the workflow to execute.
        request: Request body carrying the workflow's ``input_data``.
        session: Database session supplied by ``get_session``.
        current_user: Authenticated user, who must own the workflow.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    def sse_event(payload) -> str:
        """Serialize one payload as a server-sent-event data frame."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def generate_stream() -> AsyncGenerator[str, None]:
        # NOTE(review): the session is used while the response body is being
        # streamed — confirm get_session keeps it open for that long.
        try:
            from ...models.workflow import Workflow

            # Validate ownership and status before starting the run.
            workflow = session.scalar(
                select(Workflow).filter(
                    and_(
                        Workflow.id == workflow_id,
                        Workflow.owner_id == current_user.id
                    )
                )
            )
            if not workflow:
                yield sse_event({'type': 'error', 'message': '工作流不存在'})
                return
            if workflow.status != ModelWorkflowStatus.PUBLISHED:
                yield sse_event({'type': 'error', 'message': '工作流未激活,无法执行'})
                return

            # Start signal.
            yield sse_event({
                'type': 'workflow_start',
                'workflow_id': workflow_id,
                'workflow_name': workflow.name,
                'timestamp': datetime.now().isoformat()
            })

            workflow_engine = get_workflow_engine()
            # Forward every step reported by the engine as it arrives.
            async for step_data in workflow_engine.execute_workflow_stream(
                workflow=workflow,
                input_data=request.input_data,
                user_id=current_user.id,
                session=session
            ):
                yield sse_event(step_data)

            # Completion signal.
            yield sse_event({
                'type': 'workflow_complete',
                'message': '工作流执行完成',
                'timestamp': datetime.now().isoformat()
            })
        except Exception as e:
            # loguru has no exc_info flag: extra keyword arguments make it run
            # str.format() on the message, which can itself raise when the
            # exception text contains braces. logger.exception() records the
            # traceback, and the placeholder keeps the text out of the template.
            logger.exception("流式工作流执行异常: {}", e)
            yield sse_event({'type': 'error', 'message': f'工作流执行失败: {str(e)}'})

    return StreamingResponse(
        generate_stream(),
        # Previously media_type said text/plain while an explicit Content-Type
        # header said text/event-stream; declare the SSE type once here.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )