2025-12-16 13:55:16 +08:00
|
|
|
"""Chat endpoints for TH Agenter."""
|
2025-12-04 14:48:38 +08:00
|
|
|
|
2026-01-07 11:30:54 +08:00
|
|
|
import json
|
2025-12-04 14:48:38 +08:00
|
|
|
from typing import List
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from sqlalchemy.orm import Session
|
2025-12-16 13:55:16 +08:00
|
|
|
from loguru import logger
|
2025-12-04 14:48:38 +08:00
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
from ...db.database import get_session
|
2025-12-04 14:48:38 +08:00
|
|
|
from ...models.user import User
|
|
|
|
|
from ...services.auth import AuthService
|
|
|
|
|
from ...services.chat import ChatService
|
|
|
|
|
from ...services.conversation import ConversationService
|
2026-01-07 11:30:54 +08:00
|
|
|
from utils.util_exceptions import HxfResponse
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
from utils.util_schemas import (
|
2025-12-04 14:48:38 +08:00
|
|
|
ConversationCreate,
|
|
|
|
|
ConversationResponse,
|
|
|
|
|
ConversationUpdate,
|
|
|
|
|
MessageCreate,
|
|
|
|
|
MessageResponse,
|
|
|
|
|
ChatRequest,
|
|
|
|
|
ChatResponse
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Router that collects the conversation and chat endpoints of this module.
router: APIRouter = APIRouter()
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
async def update_conversation(
    conversation_id: int,
    conversation_update: ConversationUpdate,
    session: Session = Depends(get_session)
):
    """Update the conversation identified by ``conversation_id``.

    Args:
        conversation_id: Primary key of the conversation to update.
        conversation_update: Fields to change; forwarded unchanged to
            ``ConversationService.update_conversation``.
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/ERROR/SUCCESS trace labels as the request proceeds.

    Returns:
        ``HxfResponse`` wrapping the updated conversation as a
        ``ConversationResponse``.

    Raises:
        HTTPException: 404 when the service returns no conversation.
    """
    session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}"
    conversation_service = ConversationService(session)
    updated_conversation = await conversation_service.update_conversation(
        conversation_id, conversation_update
    )
    # Mirror get_conversation: a falsy result means there was nothing to
    # update. Without this guard, model_validate(None) raises a pydantic
    # ValidationError that surfaces as an opaque 500 instead of a 404.
    # NOTE(review): the service may already raise on a missing id; this guard
    # is then simply never taken.
    if not updated_conversation:
        session.desc = f"ERROR: 更新指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found"
        )
    session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
    response = ConversationResponse.model_validate(updated_conversation)
    return HxfResponse(response)
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
async def delete_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Delete the conversation identified by ``conversation_id``.

    Args:
        conversation_id: Primary key of the conversation to delete.
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/SUCCESS trace labels.

    Returns:
        ``HxfResponse`` wrapping a confirmation message.
    """
    # Every other endpoint in this module opens its trace with "START:".
    # The delete endpoint omitted the prefix, which made its traces
    # inconsistent with the rest.
    session.desc = f"START: 删除指定对话 >>> conversation_id: {conversation_id}"
    conversation_service = ConversationService(session)
    # NOTE(review): the service's return value is ignored. Whether it signals
    # "not found" by returning a falsy value or by raising is not visible
    # from this file; confirm.
    await conversation_service.delete_conversation(conversation_id)
    session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
    response = {"message": "Conversation deleted successfully"}
    return HxfResponse(response)
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
async def archive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Archive the conversation identified by ``conversation_id``.

    Args:
        conversation_id: Primary key of the conversation to archive.
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/ERROR/SUCCESS trace labels.

    Returns:
        ``HxfResponse`` wrapping a confirmation message.

    Raises:
        HTTPException: 400 when the service reports the archive failed
            (falsy return value).
    """
    # This endpoint previously skipped the START trace that every sibling
    # endpoint records before calling the service, so a failure inside the
    # service call was not bracketed by an entry label.
    session.desc = f"START: 归档指定对话 >>> conversation_id: {conversation_id}"
    conversation_service = ConversationService(session)
    success = await conversation_service.archive_conversation(conversation_id)
    if not success:
        session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to archive conversation"
        )
    session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
    response = {"message": "Conversation archived successfully"}
    return HxfResponse(response)
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
async def unarchive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Restore an archived conversation.

    Args:
        conversation_id: Primary key of the conversation to unarchive.
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/ERROR/SUCCESS trace labels.

    Returns:
        ``HxfResponse`` wrapping a confirmation message.

    Raises:
        HTTPException: 400 when the service returns a falsy result.
    """
    session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}"
    service = ConversationService(session)
    unarchived = await service.unarchive_conversation(conversation_id)

    if unarchived:
        session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
        return HxfResponse({"message": "Conversation unarchived successfully"})

    # Falsy result from the service: report as a bad request.
    session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}"
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Failed to unarchive conversation"
    )
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Message management
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息")
async def get_conversation_messages(
    conversation_id: int,
    skip: int = 0,
    limit: int = 100,
    session: Session = Depends(get_session)
):
    """Return one page of messages from a conversation.

    Args:
        conversation_id: Conversation whose messages are listed.
        skip: Number of messages to skip; forwarded to the service.
        limit: Maximum number of messages to return; forwarded to the service.
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/SUCCESS trace labels.

    Returns:
        ``HxfResponse`` wrapping a list of ``MessageResponse`` models.
    """
    session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    service = ConversationService(session)
    rows = await service.get_conversation_messages(conversation_id, skip=skip, limit=limit)
    session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    return HxfResponse([MessageResponse.model_validate(row) for row in rows])
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
# Chat functionality
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
async def chat(
    conversation_id: int,
    chat_request: ChatRequest,
    session: Session = Depends(get_session)
):
    """Send a message to a conversation and return the AI reply (non-streaming).

    NOTE(review): this endpoint is currently a stub. ``chat_request`` is never
    read, and the response body is a fixed placeholder string rather than a
    ``ChatResponse``. The chat service is still initialized for the
    conversation, so any exception from ``ChatService.initialize`` propagates
    to the caller.

    Args:
        conversation_id: Conversation the message belongs to.
        chat_request: Incoming chat payload (unused while this is a stub).
        session: Request-scoped database session. Its ``desc`` attribute is
            set to START/SUCCESS trace labels.

    Returns:
        ``HxfResponse`` wrapping the placeholder string.
    """
    session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
    service = ChatService(session)
    await service.initialize(conversation_id)

    # TODO: restore the real call. The previously disabled version was
    # ``chat_service.chat(conversation_id=..., message=chat_request.message,
    # stream=False, temperature=..., max_tokens=..., use_agent=...,
    # use_langgraph=..., use_knowledge_base=..., knowledge_base_id=...)``.
    # Confirm it still matches the current ChatService API before re-enabling.
    placeholder = "oooooooooooooooooooK"

    session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
    return HxfResponse(placeholder)
|
|
|
|
|
# ------------------------------------------------------------------------
|
2025-12-16 13:55:16 +08:00
|
|
|
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
async def chat_stream(
    conversation_id: int,
    chat_request: ChatRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Send a message and stream the AI reply back line by line.

    Every chunk produced by ``ChatService.chat_stream`` is forwarded with a
    trailing newline. If the generator fails mid-stream, a single error line
    is emitted instead. That line is a JSON object
    ``{"success": false, "data": "data: {...}"}`` whose ``data`` field
    carries an ``{"type": "error", "message": ...}`` event.

    Args:
        conversation_id: Conversation the message belongs to.
        chat_request: Incoming chat payload; only ``message`` is used.
        current_user: Authenticated user. It is not read in the body; the
            dependency is kept so the request must be authenticated.
        session: Request-scoped database session. ``title`` and ``desc`` are
            set for tracing.

    Returns:
        ``StreamingResponse`` over the generated text lines.
    """
    session.title = f"对话{conversation_id} 发送消息并获取流式AI响应"
    session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> "
    chat_service = ChatService(session)
    # Runs before streaming starts, so errors here become a normal HTTP error.
    await chat_service.initialize(conversation_id, streaming=True)

    # NOTE(review): this generator runs after the endpoint returns. If
    # get_session is a yield-dependency, verify the DB session is still open
    # while streaming; FastAPI's teardown timing for yield-dependencies has
    # changed between versions.
    async def generate_response(service):
        try:
            async for chunk in service.chat_stream(message=chat_request.message):
                yield chunk + "\n"
        except Exception as e:
            # Headers (HTTP 200) are already sent, so the error must be
            # reported in-band. logger.exception also records the traceback.
            logger.exception(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}")
            error_event = json.dumps(
                {'type': 'error', 'message': f'流式响应生成异常: {str(e)}'},
                ensure_ascii=False,
            )
            # StreamingResponse can only send str/bytes. The original yielded
            # a raw dict here, which made Starlette fail while encoding the
            # chunk and aborted the stream instead of delivering the error.
            # Serialize it as one JSON line, like the normal chunks.
            yield json.dumps(
                {'success': False, 'data': f"data: {error_event}"},
                ensure_ascii=False,
            ) + "\n"

    return StreamingResponse(
        generate_response(chat_service),
        # NOTE(review): "text/stream" is not a registered MIME type. It is
        # kept for client compatibility; "text/event-stream" (SSE) or
        # "application/x-ndjson" may be what is intended. Confirm with the
        # frontend before changing it.
        media_type="text/stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
|
|
|
|
|
|
|
|
|
|
# Conversation management
|
|
|
|
|
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
async def create_conversation(
    conversation_data: ConversationCreate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Create a new conversation owned by the authenticated user.

    Args:
        conversation_data: Payload describing the new conversation.
        current_user: Authenticated user; their ``id`` becomes the owner.
        session: Request-scoped database session. ``title`` and ``desc`` are
            set for tracing.

    Returns:
        ``HxfResponse`` wrapping the created conversation as a
        ``ConversationResponse``.
    """
    # Named user_id rather than ``id`` so the builtin is not shadowed.
    user_id = current_user.id
    session.title = f"用户{current_user.username} - 创建新对话"
    session.desc = "START: 创建新对话"
    conversation_service = ConversationService(session)
    conversation = await conversation_service.create_conversation(
        user_id=user_id,
        conversation_data=conversation_data
    )
    session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {user_id}, conversation_id: {conversation.id}"
    response = ConversationResponse.model_validate(conversation)
    return HxfResponse(response)
|
|
|
|
|
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
async def list_conversations(
    skip: int = 0,
    limit: int = 50,
    search: str = None,
    include_archived: bool = False,
    order_by: str = "updated_at",
    order_desc: bool = True,
    session: Session = Depends(get_session)
):
    """List conversations, with paging, search, and ordering options.

    All query parameters are forwarded to
    ``ConversationService.get_user_conversations``; ``search`` is passed as
    ``search_query``. Which user's conversations are returned is decided by
    the service (no user is passed from here).

    Returns:
        ``HxfResponse`` wrapping a list of ``ConversationResponse`` models.
    """
    session.title = "获取用户对话列表"
    session.desc = "START: 获取用户对话列表"

    query = {
        "skip": skip,
        "limit": limit,
        "search_query": search,
        "include_archived": include_archived,
        "order_by": order_by,
        "order_desc": order_desc,
    }
    conversations = await ConversationService(session).get_user_conversations(**query)

    session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..."
    payload = [ConversationResponse.model_validate(item) for item in conversations]
    return HxfResponse(payload)
|
|
|
|
|
|
|
|
|
|
@router.get("/conversations/count", summary="获取用户对话总数")
async def get_conversations_count(
    search: str = None,
    include_archived: bool = False,
    session: Session = Depends(get_session)
):
    """Return the number of conversations matching the filters.

    Args:
        search: Optional search text, forwarded as ``search_query``.
        include_archived: Whether archived conversations are counted.
        session: Request-scoped database session. ``title`` and ``desc`` are
            set for tracing.

    Returns:
        ``HxfResponse`` wrapping ``{"count": <int>}``.
    """
    # Imported inside the function as in the original; presumably to avoid an
    # import cycle. Verify before hoisting it to module level.
    from th_agenter.core.context import UserContext

    # The current user id is only used for the trace title here; filtering by
    # user is presumably done inside the service.
    user_id = UserContext.get_current_user_id()
    session.title = f"获取用户对话总数[用户id = {user_id}]"
    session.desc = "START: 获取用户对话总数"

    total = await ConversationService(session).get_user_conversations_count(
        search_query=search,
        include_archived=include_archived
    )

    session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {total} 个对话"
    return HxfResponse({"count": total})
|
|
|
|
|
|
|
|
|
|
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
async def get_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Fetch one conversation together with its first page of messages.

    Args:
        conversation_id: Primary key of the conversation to fetch.
        session: Request-scoped database session. ``title`` and ``desc`` are
            set for tracing.

    Returns:
        ``HxfResponse`` wrapping a ``ConversationResponse``. Its ``messages``
        field holds up to the first 100 messages, and ``message_count`` is set
        to the length of that list.

    Raises:
        HTTPException: 404 when the service returns no conversation.
    """
    session.title = f"获取指定对话[对话id = {conversation_id}]"
    session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}"

    service = ConversationService(session)
    found = await service.get_conversation(conversation_id=conversation_id)
    if not found:
        session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found"
        )

    session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {found}"
    payload = ConversationResponse.model_validate(found)

    # Attach the first page of messages (fixed window: skip=0, limit=100).
    # NOTE(review): message_count is derived from this truncated list, so it
    # caps at 100 for longer conversations and overwrites any count produced
    # by model_validate. Confirm this is intended.
    history = await service.get_conversation_messages(conversation_id, skip=0, limit=100)
    payload.messages = [MessageResponse.model_validate(item) for item in history]
    payload.message_count = len(payload.messages)
    return HxfResponse(payload)
|