"""Chat endpoints for TH Agenter.""" import json from typing import List from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from loguru import logger from ...db.database import get_session from ...models.user import User from ...services.auth import AuthService from ...services.chat import ChatService from ...services.conversation import ConversationService from utils.util_exceptions import HxfResponse from utils.util_schemas import ( ConversationCreate, ConversationResponse, ConversationUpdate, MessageCreate, MessageResponse, ChatRequest, ChatResponse ) router = APIRouter() @router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话") async def update_conversation( conversation_id: int, conversation_update: ConversationUpdate, session: Session = Depends(get_session) ): """更新指定对话""" session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}" conversation_service = ConversationService(session) updated_conversation = await conversation_service.update_conversation( conversation_id, conversation_update ) session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}" response = ConversationResponse.model_validate(updated_conversation) return HxfResponse(response) @router.delete("/conversations/{conversation_id}", summary="删除指定对话") async def delete_conversation( conversation_id: int, session: Session = Depends(get_session) ): """删除指定对话""" session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}" conversation_service = ConversationService(session) await conversation_service.delete_conversation(conversation_id) session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}" response = {"message": "Conversation deleted successfully"} return HxfResponse(response) @router.put("/conversations/{conversation_id}/archive", summary="归档指定对话") async def archive_conversation( conversation_id: int, session: Session = Depends(get_session) ): """归档指定对话.""" conversation_service = ConversationService(session) success = await conversation_service.archive_conversation(conversation_id) if not success: session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}" raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to archive conversation" ) session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}" response = {"message": "Conversation archived successfully"} return HxfResponse(response) @router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话") async def unarchive_conversation( conversation_id: int, session: Session = Depends(get_session) ): """取消归档指定对话.""" session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}" conversation_service = ConversationService(session) success = await conversation_service.unarchive_conversation(conversation_id) if not success: session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}" raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to unarchive conversation" ) session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}" response = {"message": "Conversation unarchived successfully"} return HxfResponse(response) # Message management @router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息") async def get_conversation_messages( conversation_id: int, skip: int = 0, limit: int = 100, session: Session = Depends(get_session) ): """获取指定对话的消息""" session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" conversation_service = ConversationService(session) messages = await conversation_service.get_conversation_messages( conversation_id, skip=skip, limit=limit ) session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" response = [MessageResponse.model_validate(msg) for msg in messages] return HxfResponse(response) # Chat functionality @router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应") async def chat( conversation_id: int, chat_request: ChatRequest, session: Session = Depends(get_session) ): """发送消息并获取AI响应""" session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}" chat_service = ChatService(session) await chat_service.initialize(conversation_id) # response = await chat_service.chat( # conversation_id=conversation_id, # message=chat_request.message, # stream=False, # temperature=chat_request.temperature, # max_tokens=chat_request.max_tokens, # use_agent=chat_request.use_agent, # 可以简化掉 # use_langgraph=chat_request.use_langgraph, # 可以简化掉 # use_knowledge_base=chat_request.use_knowledge_base, # 可以简化掉 # knowledge_base_id=chat_request.knowledge_base_id # 可以简化掉 # ) response = "oooooooooooooooooooK" session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}" return HxfResponse(response) # ------------------------------------------------------------------------ @router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应") async def chat_stream( conversation_id: int, chat_request: ChatRequest, current_user = Depends(AuthService.get_current_user), session: Session = Depends(get_session) ): """发送消息并获取流式AI响应.""" session.title = f"对话{conversation_id} 发送消息并获取流式AI响应" session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> " chat_service = ChatService(session) await chat_service.initialize(conversation_id, streaming=True) async def generate_response(chat_service): try: async for chunk in chat_service.chat_stream( message=chat_request.message ): yield chunk + "\n" except Exception as e: logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}") yield {'success': False, 'data': f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}"} response = StreamingResponse( generate_response(chat_service), media_type="text/stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", } ) return response # Conversation management @router.post("/conversations", response_model=ConversationResponse, summary="创建新对话") async def create_conversation( conversation_data: ConversationCreate, current_user: User = Depends(AuthService.get_current_user), session: Session = Depends(get_session) ): """创建新对话""" id = current_user.id session.title = f"用户{current_user.username} - 创建新对话" session.desc = "START: 创建新对话" conversation_service = ConversationService(session) conversation = await conversation_service.create_conversation( user_id=id, conversation_data=conversation_data ) session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {id}, conversation_id: {conversation.id}" response = ConversationResponse.model_validate(conversation) return HxfResponse(response) @router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表") async def list_conversations( skip: int = 0, limit: int = 50, search: str = None, include_archived: bool = False, order_by: str = "updated_at", order_desc: bool = True, session: Session = Depends(get_session) ): """获取用户对话列表""" session.title = "获取用户对话列表" session.desc = "START: 获取用户对话列表" conversation_service = ConversationService(session) conversations = await conversation_service.get_user_conversations( skip=skip, limit=limit, search_query=search, include_archived=include_archived, order_by=order_by, order_desc=order_desc ) session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..." response = [ConversationResponse.model_validate(conv) for conv in conversations] return HxfResponse(response) @router.get("/conversations/count", summary="获取用户对话总数") async def get_conversations_count( search: str = None, include_archived: bool = False, session: Session = Depends(get_session) ): """获取用户对话总数""" from th_agenter.core.context import UserContext user_id = UserContext.get_current_user_id() session.title = f"获取用户对话总数[用户id = {user_id}]" session.desc = "START: 获取用户对话总数" conversation_service = ConversationService(session) count = await conversation_service.get_user_conversations_count( search_query=search, include_archived=include_archived ) session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话" response = {"count": count} return HxfResponse(response) @router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话") async def get_conversation( conversation_id: int, session: Session = Depends(get_session) ): """获取指定对话""" session.title = f"获取指定对话[对话id = {conversation_id}]" session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}" conversation_service = ConversationService(session) conversation = await conversation_service.get_conversation( conversation_id=conversation_id ) if not conversation: session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found" ) session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {conversation}" response = ConversationResponse.model_validate(conversation) # chat_service = ChatService(session) # await chat_service.initialize(conversation_id, streaming=False) # messages = await chat_service.get_conversation_history_messages( # conversation_id # ) # response.messages = messages messages = await conversation_service.get_conversation_messages( conversation_id, skip=0, limit=100 ) response.messages = [MessageResponse.model_validate(msg) for msg in messages] response.message_count = len(response.messages) return HxfResponse(response)