hxf/backend/th_agenter/api/endpoints/chat.py

243 lines
9.6 KiB
Python
Raw Normal View History

2025-12-16 13:55:16 +08:00
"""Chat endpoints for TH Agenter."""
2025-12-04 14:48:38 +08:00
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
2025-12-16 13:55:16 +08:00
from loguru import logger
2025-12-04 14:48:38 +08:00
2025-12-16 13:55:16 +08:00
from ...db.database import get_session
2025-12-04 14:48:38 +08:00
from ...models.user import User
from ...services.auth import AuthService
from ...services.chat import ChatService
from ...services.conversation import ConversationService
2025-12-16 13:55:16 +08:00
from utils.util_schemas import (
2025-12-04 14:48:38 +08:00
ConversationCreate,
ConversationResponse,
ConversationUpdate,
MessageCreate,
MessageResponse,
ChatRequest,
ChatResponse
)
# Router collecting every conversation, message and chat endpoint defined below.
router: APIRouter = APIRouter()


# Conversation management
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
async def create_conversation(
    conversation_data: ConversationCreate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Create a new conversation for the authenticated user."""
    session.desc = "START: 创建新对话"
    # The owner id comes from the auth dependency (current_user).
    conversation = await ConversationService(session).create_conversation(
        user_id=current_user.id,
        conversation_data=conversation_data,
    )
    session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}"
    response = ConversationResponse.model_validate(conversation)
    return response
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
async def list_conversations(
    skip: int = 0,
    limit: int = 50,
    search: str = None,
    include_archived: bool = False,
    order_by: str = "updated_at",
    order_desc: bool = True,
    session: Session = Depends(get_session)
):
    """List conversations with pagination, filtering and ordering.

    All query parameters are forwarded unchanged to the service;
    ``search`` is passed on as ``search_query``.
    """
    session.desc = "START: 获取用户对话列表"
    # NOTE(review): unlike create_conversation there is no current_user
    # dependency here; presumably ConversationService scopes results to the
    # caller by some other means — confirm it filters by user.
    service = ConversationService(session)
    query = dict(
        skip=skip,
        limit=limit,
        search_query=search,
        include_archived=include_archived,
        order_by=order_by,
        order_desc=order_desc,
    )
    conversations = await service.get_user_conversations(**query)
    session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话"
    return list(map(ConversationResponse.model_validate, conversations))
# Registered before "/conversations/{conversation_id}": routes match in
# declaration order, so the literal "count" segment must come first.
@router.get("/conversations/count", summary="获取用户对话总数")
async def get_conversations_count(
    search: str = None,
    include_archived: bool = False,
    session: Session = Depends(get_session)
):
    """Return the number of conversations as ``{"count": <total>}``.

    ``search`` is forwarded to the service as ``search_query``.
    """
    session.desc = "START: 获取用户对话总数"
    count = await ConversationService(session).get_user_conversations_count(
        search_query=search,
        include_archived=include_archived,
    )
    session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
    return dict(count=count)
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
async def get_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Fetch a single conversation by id.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}"
    service = ConversationService(session)
    conversation = await service.get_conversation(conversation_id=conversation_id)
    if conversation:
        session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}"
        return ConversationResponse.model_validate(conversation)
    session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Conversation not found",
    )
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
async def update_conversation(
    conversation_id: int,
    conversation_update: ConversationUpdate,
    session: Session = Depends(get_session)
):
    """Apply ``conversation_update`` to a conversation and return the result.

    Raises:
        HTTPException: 404 when the service returns no conversation.
    """
    session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}"
    conversation_service = ConversationService(session)
    updated_conversation = await conversation_service.update_conversation(
        conversation_id, conversation_update
    )
    # Same guard as get_conversation: validating None would raise a pydantic
    # ValidationError here and surface as an unhandled 500 instead of a 404.
    if not updated_conversation:
        session.desc = f"ERROR: 更新指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found"
        )
    session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
    return ConversationResponse.model_validate(updated_conversation)
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
async def delete_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Delete a conversation by id and return a confirmation message."""
    # Uses the same START/SUCCESS/ERROR desc prefixes as the other endpoints.
    session.desc = f"START: 删除指定对话 >>> conversation_id: {conversation_id}"
    conversation_service = ConversationService(session)
    # NOTE(review): the result of delete_conversation is ignored, so an unknown
    # id still reports success unless the service raises — confirm intended.
    await conversation_service.delete_conversation(conversation_id)
    session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
    return {"message": "Conversation deleted successfully"}
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
async def archive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Archive a conversation via ``ConversationService.archive_conversation``.

    Raises:
        HTTPException: 400 when the service returns a falsy result.
    """
    # START marker, matching unarchive_conversation and the other endpoints.
    session.desc = f"START: 归档指定对话 >>> conversation_id: {conversation_id}"
    conversation_service = ConversationService(session)
    success = await conversation_service.archive_conversation(conversation_id)
    if not success:
        session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to archive conversation"
        )
    session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
    return {"message": "Conversation archived successfully"}
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
async def unarchive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Unarchive a conversation via ``ConversationService.unarchive_conversation``.

    Raises:
        HTTPException: 400 when the service returns a falsy result.
    """
    session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}"
    unarchived = await ConversationService(session).unarchive_conversation(conversation_id)
    if unarchived:
        session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
        return {"message": "Conversation unarchived successfully"}
    session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}"
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Failed to unarchive conversation",
    )
# Message management
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息")
async def get_conversation_messages(
    conversation_id: int,
    skip: int = 0,
    limit: int = 100,
    session: Session = Depends(get_session)
):
    """Return one page (``skip``/``limit``) of a conversation's messages."""
    session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    service = ConversationService(session)
    page = await service.get_conversation_messages(conversation_id, skip=skip, limit=limit)
    session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    return list(map(MessageResponse.model_validate, page))
# Chat functionality
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
async def chat(
    conversation_id: int,
    chat_request: ChatRequest,
    session: Session = Depends(get_session)
):
    """Send a message and return the complete (non-streaming) AI reply."""
    session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
    service = ChatService(session)
    # Generation options copied verbatim from the request body.
    options = {
        "temperature": chat_request.temperature,
        "max_tokens": chat_request.max_tokens,
        "use_agent": chat_request.use_agent,
        "use_langgraph": chat_request.use_langgraph,
        "use_knowledge_base": chat_request.use_knowledge_base,
        "knowledge_base_id": chat_request.knowledge_base_id,
    }
    response = await service.chat(
        conversation_id=conversation_id,
        message=chat_request.message,
        stream=False,
        **options,
    )
    session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
    return response
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
async def chat_stream(
    conversation_id: int,
    chat_request: ChatRequest,
    session: Session = Depends(get_session)
):
    """Send a message and stream the AI reply as Server-Sent Events.

    Each chunk yielded by ``ChatService.chat_stream`` is sent as one
    ``data: <chunk>`` event terminated by a blank line.
    """
    session.desc = f"START: 发送消息并获取流式AI响应 >>> conversation_id: {conversation_id}"
    chat_service = ChatService(session)
    # NOTE(review): generate_response runs after this function returns and still
    # uses the request-scoped session; whether it is open then depends on
    # get_session and FastAPI's teardown timing for yield dependencies — confirm.

    async def generate_response():
        try:
            async for chunk in chat_service.chat_stream(
                conversation_id=conversation_id,
                message=chat_request.message,
                temperature=chat_request.temperature,
                max_tokens=chat_request.max_tokens,
                use_agent=chat_request.use_agent,
                use_langgraph=chat_request.use_langgraph,
                use_knowledge_base=chat_request.use_knowledge_base,
                knowledge_base_id=chat_request.knowledge_base_id
            ):
                # Assumes chunk renders as a single line (e.g. a JSON string);
                # embedded newlines would break SSE framing — TODO confirm.
                yield f"data: {chunk}\n\n"
        except Exception:
            # Response headers are already sent, so the client only sees a
            # truncated stream; record and log the failure before propagating.
            session.desc = f"ERROR: 发送消息并获取流式AI响应失败 >>> conversation_id: {conversation_id}"
            logger.exception("Chat stream failed >>> conversation_id: {}", conversation_id)
            raise
        session.desc = f"SUCCESS: 发送消息并获取流式AI响应完毕 >>> conversation_id: {conversation_id}"

    return StreamingResponse(
        generate_response(),
        # The body uses SSE framing, which EventSource clients only accept
        # when served as text/event-stream.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        }
    )