"""Chat endpoints for TH Agenter.""" from typing import List from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from loguru import logger from ...db.database import get_session from ...models.user import User from ...services.auth import AuthService from ...services.chat import ChatService from ...services.conversation import ConversationService from utils.util_schemas import ( ConversationCreate, ConversationResponse, ConversationUpdate, MessageCreate, MessageResponse, ChatRequest, ChatResponse ) router = APIRouter() # Conversation management @router.post("/conversations", response_model=ConversationResponse, summary="创建新对话") async def create_conversation( conversation_data: ConversationCreate, current_user: User = Depends(AuthService.get_current_user), session: Session = Depends(get_session) ): """创建新对话""" session.desc = "START: 创建新对话" conversation_service = ConversationService(session) conversation = await conversation_service.create_conversation( user_id=current_user.id, conversation_data=conversation_data ) session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}" return ConversationResponse.model_validate(conversation) @router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表") async def list_conversations( skip: int = 0, limit: int = 50, search: str = None, include_archived: bool = False, order_by: str = "updated_at", order_desc: bool = True, session: Session = Depends(get_session) ): """获取用户对话列表""" session.desc = "START: 获取用户对话列表" conversation_service = ConversationService(session) conversations = await conversation_service.get_user_conversations( skip=skip, limit=limit, search_query=search, include_archived=include_archived, order_by=order_by, order_desc=order_desc ) session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话" return [ConversationResponse.model_validate(conv) for conv in conversations] @router.get("/conversations/count", summary="获取用户对话总数") async def get_conversations_count( search: str = None, include_archived: bool = False, session: Session = Depends(get_session) ): """获取用户对话总数""" session.desc = "START: 获取用户对话总数" conversation_service = ConversationService(session) count = await conversation_service.get_user_conversations_count( search_query=search, include_archived=include_archived ) session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话" return {"count": count} @router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话") async def get_conversation( conversation_id: int, session: Session = Depends(get_session) ): """获取指定对话""" session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}" conversation_service = ConversationService(session) conversation = await conversation_service.get_conversation( conversation_id=conversation_id ) if not conversation: session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Conversation not found" ) session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}" return ConversationResponse.model_validate(conversation) @router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话") async def update_conversation( conversation_id: int, conversation_update: ConversationUpdate, session: Session = Depends(get_session) ): """更新指定对话""" session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}" conversation_service = ConversationService(session) updated_conversation = await conversation_service.update_conversation( conversation_id, conversation_update ) session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}" return ConversationResponse.model_validate(updated_conversation) @router.delete("/conversations/{conversation_id}", summary="删除指定对话") async def delete_conversation( conversation_id: int, session: Session = Depends(get_session) ): """删除指定对话""" session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}" conversation_service = ConversationService(session) await conversation_service.delete_conversation(conversation_id) session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}" return {"message": "Conversation deleted successfully"} @router.put("/conversations/{conversation_id}/archive", summary="归档指定对话") async def archive_conversation( conversation_id: int, session: Session = Depends(get_session) ): """归档指定对话.""" conversation_service = ConversationService(session) success = await conversation_service.archive_conversation(conversation_id) if not success: session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}" raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to archive conversation" ) session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}" return {"message": "Conversation archived successfully"} @router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话") async def unarchive_conversation( conversation_id: int, session: Session = Depends(get_session) ): """取消归档指定对话.""" session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}" conversation_service = ConversationService(session) success = await conversation_service.unarchive_conversation(conversation_id) if not success: session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}" raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Failed to unarchive conversation" ) session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}" return {"message": "Conversation unarchived successfully"} # Message management @router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息") async def get_conversation_messages( conversation_id: int, skip: int = 0, limit: int = 100, session: Session = Depends(get_session) ): """获取指定对话的消息""" session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" conversation_service = ConversationService(session) messages = await conversation_service.get_conversation_messages( conversation_id, skip=skip, limit=limit ) session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" return [MessageResponse.model_validate(msg) for msg in messages] # Chat functionality @router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应") async def chat( conversation_id: int, chat_request: ChatRequest, session: Session = Depends(get_session) ): """发送消息并获取AI响应""" session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}" chat_service = ChatService(session) response = await chat_service.chat( conversation_id=conversation_id, message=chat_request.message, stream=False, temperature=chat_request.temperature, max_tokens=chat_request.max_tokens, use_agent=chat_request.use_agent, use_langgraph=chat_request.use_langgraph, use_knowledge_base=chat_request.use_knowledge_base, knowledge_base_id=chat_request.knowledge_base_id ) session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}" return response @router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应") async def chat_stream( conversation_id: int, chat_request: ChatRequest, session: Session = Depends(get_session) ): """发送消息并获取流式AI响应.""" chat_service = ChatService(session) async def generate_response(): async for chunk in chat_service.chat_stream( conversation_id=conversation_id, message=chat_request.message, temperature=chat_request.temperature, max_tokens=chat_request.max_tokens, use_agent=chat_request.use_agent, use_langgraph=chat_request.use_langgraph, use_knowledge_base=chat_request.use_knowledge_base, knowledge_base_id=chat_request.knowledge_base_id ): yield f"data: {chunk}\n\n" return StreamingResponse( generate_response(), media_type="text/plain", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", } )