hxf/backend/th_agenter/api/endpoints/smart_chat.py

338 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any
from datetime import datetime
from th_agenter.db.database import get_session
from th_agenter.services.auth import AuthService
from th_agenter.services.smart_workflow import SmartWorkflowManager
from th_agenter.services.conversation import ConversationService
from th_agenter.services.conversation_context import conversation_context_service
from utils.util_schemas import BaseResponse
from pydantic import BaseModel
from loguru import logger
from utils.util_exceptions import HxfResponse
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
security = HTTPBearer()
# Request/Response Models
class SmartQueryRequest(BaseModel):
query: str
conversation_id: Optional[int] = None
is_new_conversation: bool = False
class SmartQueryResponse(BaseModel):
success: bool
message: str
data: Optional[Dict[str, Any]] = None
workflow_steps: Optional[list] = None
conversation_id: Optional[int] = None
class ConversationContextResponse(BaseModel):
success: bool
message: str
data: Optional[Dict[str, Any]] = None
@router.post("/query", response_model=SmartQueryResponse, summary="智能问数查询")
async def smart_query(
request: SmartQueryRequest,
current_user = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""
智能问数查询接口
支持新对话时自动加载文件列表智能选择相关Excel文件生成和执行pandas代码
"""
session.desc = f"START: 用户 {current_user.username} 智能问数查询"
conversation_id = None
try:
# 验证请求参数
if not request.query or not request.query.strip():
session.desc = "ERROR: 用户输入为空, 查询内容不能为空"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="查询内容不能为空"
)
if len(request.query) > 1000:
session.desc = "ERROR: 用户输入过长, 查询内容不能超过1000字符"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="查询内容过长请控制在1000字符以内"
)
# 初始化工作流管理器
workflow_manager = SmartWorkflowManager(session)
await workflow_manager.initialize()
conversation_service = ConversationService(session)
# 处理对话上下文
conversation_id = request.conversation_id
# 如果是新对话或没有指定对话ID创建新对话
if request.is_new_conversation or not conversation_id:
try:
conversation_id = await conversation_context_service.create_conversation(
user_id=current_user.id,
title=f"智能问数: {request.query[:20]}..."
)
request.is_new_conversation = True
session.desc = f"创建新对话: {conversation_id}"
except Exception as e:
session.desc = f"WARNING: 创建对话失败,使用临时会话: {e}"
conversation_id = None
else:
# 验证对话是否存在且属于当前用户
try:
context = await conversation_context_service.get_conversation_context(conversation_id)
if not context or context.get('user_id') != current_user.id:
session.desc = f"ERROR: 对话 {conversation_id} 不存在或无权访问"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在或无权访问"
)
session.desc = f"使用现有对话: {conversation_id}"
except HTTPException:
session.desc = f"EXCEPTION: 对话 {conversation_id} 不存在或无权访问"
raise
except Exception as e:
session.desc = f"ERROR: 验证对话失败: {e}"
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="对话验证失败"
)
# 保存用户消息
if conversation_id:
try:
await conversation_context_service.save_message(
conversation_id=conversation_id,
role="user",
content=request.query
)
except Exception as e:
session.desc = f"WARNING: 保存用户消息失败: {e}"
# 不阻断流程,继续执行查询
# 执行智能查询工作流
try:
result = await workflow_manager.process_smart_query(
user_query=request.query,
user_id=current_user.id,
conversation_id=conversation_id,
is_new_conversation=request.is_new_conversation
)
except Exception as e:
session.desc = f"ERROR: 智能查询执行失败: {e}"
# 返回结构化的错误响应
response = SmartQueryResponse(
success=False,
message=f"查询执行失败: {str(e)}",
data={'error_type': 'query_execution_error'},
workflow_steps=[{
'step': 'error',
'status': 'failed',
'message': str(e)
}],
conversation_id=conversation_id
)
return HxfResponse(response)
# 如果查询成功,保存助手回复和更新上下文
if result['success'] and conversation_id:
try:
# 保存助手回复
await conversation_context_service.save_message(
conversation_id=conversation_id,
role="assistant",
content=result.get('data', {}).get('summary', '查询完成'),
metadata={
'query_result': result.get('data'),
'workflow_steps': result.get('workflow_steps', []),
'selected_files': result.get('data', {}).get('used_files', [])
}
)
# 更新对话上下文
await conversation_context_service.update_conversation_context(
conversation_id=conversation_id,
query=request.query,
selected_files=result.get('data', {}).get('used_files', [])
)
except Exception as e:
session.desc = f"EXCEPTION: 保存消息到对话历史失败: {e}"
# 不影响返回结果,只记录警告
# 返回结果包含对话ID
response_data = result.get('data', {})
if conversation_id:
response_data['conversation_id'] = conversation_id
session.desc = f"SUCCESS: 保存助手回复和更新上下文对话ID: {conversation_id}"
response = SmartQueryResponse(
success=result['success'],
message=result.get('message', '查询完成'),
data=response_data,
workflow_steps=result.get('workflow_steps', []),
conversation_id=conversation_id
)
return HxfResponse(response)
except HTTPException as e:
session.desc = f"EXCEPTION: HTTP异常: {e}"
raise e
except Exception as e:
session.desc = f"ERROR: 智能查询接口异常: {e}"
# 返回通用错误响应
response = SmartQueryResponse(
success=False,
message="服务器内部错误,请稍后重试",
data={'error_type': 'internal_server_error'},
workflow_steps=[{
'step': 'error',
'status': 'failed',
'message': '系统异常'
}],
conversation_id=conversation_id
)
return HxfResponse(response)
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
async def get_conversation_context(
conversation_id: int,
current_user = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""
获取对话上下文信息,包括已使用的文件和历史查询
"""
# 获取对话上下文
session.desc = f"START: 获取对话上下文对话ID: {conversation_id}"
context = await conversation_context_service.get_conversation_context(conversation_id)
if not context:
session.desc = f"ERROR: 对话上下文不存在对话ID: {conversation_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话上下文不存在"
)
# 验证用户权限
if context['user_id'] != current_user.id:
session.desc = f"ERROR: 无权访问对话上下文对话ID: {conversation_id}"
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此对话"
)
# 获取对话历史
history = await conversation_context_service.get_conversation_history(conversation_id)
context['message_history'] = history
session.desc = f"SUCCESS: 获取对话上下文成功对话ID: {conversation_id}"
response = ConversationContextResponse(
success=True,
message="获取对话上下文成功",
data=context
)
return HxfResponse(response)
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
async def get_files_status(
current_user = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""
获取用户当前的文件状态和统计信息
"""
session.desc = f"START: 获取用户文件状态和统计信息用户ID: {current_user.id}"
workflow_manager = SmartWorkflowManager()
await workflow_manager.initialize()
# 获取用户文件列表
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
# 统计信息
total_files = len(file_list)
total_rows = sum(f.get('row_count', 0) for f in file_list)
total_columns = sum(f.get('column_count', 0) for f in file_list)
# 文件类型统计
file_types = {}
for file_info in file_list:
filename = file_info['filename']
ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
file_types[ext] = file_types.get(ext, 0) + 1
status_data = {
'total_files': total_files,
'total_rows': total_rows,
'total_columns': total_columns,
'file_types': file_types,
'files': [{
'id': f['id'],
'filename': f['filename'],
'row_count': f.get('row_count', 0),
'column_count': f.get('column_count', 0),
'columns': f.get('columns', []),
'upload_time': f.get('upload_time')
} for f in file_list],
'ready_for_query': total_files > 0
}
session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功用户ID: {current_user.id}"
response = ConversationContextResponse(
success=True,
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件请先上传Excel文件",
data=status_data
)
return HxfResponse(response)
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
async def reset_conversation_context(
conversation_id: int,
current_user = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""
重置对话上下文,清除历史查询记录但保留文件
"""
session.desc = f"START: 重置对话上下文对话ID: {conversation_id}"
# 验证对话存在和用户权限
context = await conversation_context_service.get_conversation_context(conversation_id)
if not context:
session.desc = f"ERROR: 对话上下文不存在对话ID: {conversation_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话上下文不存在"
)
if context['user_id'] != current_user.id:
session.desc = f"ERROR: 无权访问对话上下文对话ID: {conversation_id}"
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此对话"
)
# 重置对话上下文
success = await conversation_context_service.reset_conversation_context(conversation_id)
if success:
session.desc = f"SUCCESS: 重置对话上下文成功对话ID: {conversation_id}"
response = ConversationContextResponse(
success=True,
message="对话上下文已重置,可以开始新的数据分析会话"
)
return HxfResponse(response)
else:
session.desc = f"EXCEPTION: 重置对话上下文失败对话ID: {conversation_id}"
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="重置对话上下文失败"
)