238 lines
7.4 KiB
Python
238 lines
7.4 KiB
Python
"""Chat endpoints."""
|
|
|
|
from typing import List
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ...db.database import get_db
|
|
from ...models.user import User
|
|
from ...services.auth import AuthService
|
|
from ...services.chat import ChatService
|
|
from ...services.conversation import ConversationService
|
|
from ...utils.schemas import (
|
|
ConversationCreate,
|
|
ConversationResponse,
|
|
ConversationUpdate,
|
|
MessageCreate,
|
|
MessageResponse,
|
|
ChatRequest,
|
|
ChatResponse
|
|
)
|
|
|
|
# Router that the endpoint decorators in this module register their routes on.
router = APIRouter()
|
|
|
|
|
|
# Conversation management
|
|
@router.post("/conversations", response_model=ConversationResponse)
async def create_conversation(
    conversation_data: ConversationCreate,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db),
):
    """Create a conversation for the current user.

    Args:
        conversation_data: Creation payload for the new conversation.
        current_user: User resolved by ``AuthService.get_current_user``;
            its ``id`` is passed to the service as ``user_id``.
        db: Database session supplied by ``get_db``.

    Returns:
        The object returned by the service, converted with
        ``ConversationResponse.from_orm``.
    """
    created = ConversationService(db).create_conversation(
        user_id=current_user.id,
        conversation_data=conversation_data,
    )
    return ConversationResponse.from_orm(created)
|
|
|
|
|
|
@router.get("/conversations", response_model=List[ConversationResponse])
async def list_conversations(
    skip: int = 0,
    limit: int = 50,
    search: str = None,
    include_archived: bool = False,
    order_by: str = "updated_at",
    order_desc: bool = True,
    db: Session = Depends(get_db),
):
    """Return conversations, with paging, search, and ordering options.

    Every query parameter is handed to
    ``ConversationService.get_user_conversations`` unchanged; ``search`` is
    passed under the keyword ``search_query``. Each returned row is converted
    with ``ConversationResponse.from_orm``.
    """
    # Collect the query options first so the service call reads as one unit.
    query_options = {
        "skip": skip,
        "limit": limit,
        "search_query": search,
        "include_archived": include_archived,
        "order_by": order_by,
        "order_desc": order_desc,
    }
    service = ConversationService(db)
    rows = service.get_user_conversations(**query_options)
    return [ConversationResponse.from_orm(row) for row in rows]
|
|
|
|
|
|
@router.get("/conversations/count")
async def get_conversations_count(
    search: str = None,
    include_archived: bool = False,
    db: Session = Depends(get_db),
):
    """Return the number of conversations matching the given filters.

    ``search`` is forwarded to
    ``ConversationService.get_user_conversations_count`` as ``search_query``
    together with ``include_archived``. The result is wrapped in a
    ``{"count": ...}`` dictionary.
    """
    total = ConversationService(db).get_user_conversations_count(
        search_query=search,
        include_archived=include_archived,
    )
    return {"count": total}
|
|
|
|
|
|
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse)
async def get_conversation(
    conversation_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single conversation by its id.

    Returns:
        The conversation converted with ``ConversationResponse.from_orm``.

    Raises:
        HTTPException: 404 with detail "Conversation not found" when the
            service returns a falsy result.
    """
    found = ConversationService(db).get_conversation(
        conversation_id=conversation_id
    )
    if found:
        return ConversationResponse.from_orm(found)

    # Nothing came back from the service: report it as a missing resource.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Conversation not found",
    )
|
|
|
|
|
|
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse)
async def update_conversation(
    conversation_id: int,
    conversation_update: ConversationUpdate,
    db: Session = Depends(get_db),
):
    """Update a conversation and return its new state.

    Args:
        conversation_id: Identifier taken from the URL path.
        conversation_update: Update payload handed to the service as-is.
        db: Database session supplied by ``get_db``.

    Returns:
        The service result converted with ``ConversationResponse.from_orm``.

    Raises:
        HTTPException: 404 with detail "Conversation not found" when the
            service returns a falsy result.
    """
    conversation_service = ConversationService(db)
    updated_conversation = conversation_service.update_conversation(
        conversation_id, conversation_update
    )
    # Mirror get_conversation's not-found handling: without this guard a
    # falsy result would be passed to from_orm and fail during serialization
    # instead of producing a clean 404.
    # NOTE(review): if the service already raises for unknown ids this branch
    # is simply never taken - confirm against ConversationService.
    if not updated_conversation:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found",
        )
    return ConversationResponse.from_orm(updated_conversation)
|
|
|
|
|
|
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
    conversation_id: int,
    db: Session = Depends(get_db),
):
    """Delete the conversation with the given id.

    Calls ``ConversationService.delete_conversation`` and ignores its return
    value; the confirmation message is returned whenever the call completes
    without raising.
    """
    service = ConversationService(db)
    service.delete_conversation(conversation_id)
    confirmation = {"message": "Conversation deleted successfully"}
    return confirmation
|
|
|
|
|
|
@router.delete("/conversations")
async def delete_all_conversations(
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db),
):
    """Delete every conversation via the conversation service.

    ``current_user`` is not read in the body; only the dependency
    declaration is used.
    NOTE(review): presumably declared so the request must resolve a user
    before the delete runs, and the service scopes the delete itself -
    confirm against ConversationService.
    """
    service = ConversationService(db)
    service.delete_all_conversations()
    confirmation = {"message": "All conversations deleted successfully"}
    return confirmation
|
|
|
|
|
|
@router.put("/conversations/{conversation_id}/archive")
async def archive_conversation(
    conversation_id: int,
    db: Session = Depends(get_db),
):
    """Archive the conversation with the given id.

    Returns:
        A confirmation message when the service reports a truthy result.

    Raises:
        HTTPException: 400 with detail "Failed to archive conversation" when
            the service result is falsy.
    """
    archived = ConversationService(db).archive_conversation(conversation_id)
    if archived:
        return {"message": "Conversation archived successfully"}

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Failed to archive conversation",
    )
|
|
|
|
|
|
@router.put("/conversations/{conversation_id}/unarchive")
async def unarchive_conversation(
    conversation_id: int,
    db: Session = Depends(get_db),
):
    """Unarchive the conversation with the given id.

    Returns:
        A confirmation message when the service reports a truthy result.

    Raises:
        HTTPException: 400 with detail "Failed to unarchive conversation"
            when the service result is falsy.
    """
    restored = ConversationService(db).unarchive_conversation(conversation_id)
    if restored:
        return {"message": "Conversation unarchived successfully"}

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Failed to unarchive conversation",
    )
|
|
|
|
|
|
# Message management
|
|
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse])
async def get_conversation_messages(
    conversation_id: int,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """Return a page of messages belonging to one conversation.

    ``skip`` and ``limit`` are forwarded unchanged to
    ``ConversationService.get_conversation_messages``; each returned item is
    converted with ``MessageResponse.from_orm``.
    """
    service = ConversationService(db)
    page = service.get_conversation_messages(
        conversation_id, skip=skip, limit=limit
    )
    return [MessageResponse.from_orm(item) for item in page]
|
|
|
|
|
|
# Chat functionality
|
|
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse)
async def chat(
    conversation_id: int,
    chat_request: ChatRequest,
    db: Session = Depends(get_db),
):
    """Send a message to a conversation and return the non-streaming reply.

    The message and every option carried by ``chat_request`` are forwarded to
    ``ChatService.chat`` with ``stream=False``; whatever the service returns
    is returned to the caller unchanged.
    """
    service = ChatService(db)
    # Awaited and returned directly: no post-processing happens here.
    return await service.chat(
        conversation_id=conversation_id,
        message=chat_request.message,
        stream=False,
        temperature=chat_request.temperature,
        max_tokens=chat_request.max_tokens,
        use_agent=chat_request.use_agent,
        use_langgraph=chat_request.use_langgraph,
        use_knowledge_base=chat_request.use_knowledge_base,
        knowledge_base_id=chat_request.knowledge_base_id,
    )
|
|
|
|
|
|
@router.post("/conversations/{conversation_id}/chat/stream")
async def chat_stream(
    conversation_id: int,
    chat_request: ChatRequest,
    db: Session = Depends(get_db),
):
    """Send a message and stream the AI response as Server-Sent Events.

    Args:
        conversation_id: Identifier taken from the URL path.
        chat_request: Carries the message and the generation options that are
            forwarded to ``ChatService.chat_stream``.
        db: Database session supplied by ``get_db``.

    Returns:
        A ``StreamingResponse`` whose body is one ``data: <chunk>`` event,
        terminated by a blank line, per chunk yielded by the service.
    """
    chat_service = ChatService(db)

    async def generate_response():
        """Wrap each chunk from the service in SSE ``data:`` framing."""
        async for chunk in chat_service.chat_stream(
            conversation_id=conversation_id,
            message=chat_request.message,
            temperature=chat_request.temperature,
            max_tokens=chat_request.max_tokens,
            use_agent=chat_request.use_agent,
            use_langgraph=chat_request.use_langgraph,
            use_knowledge_base=chat_request.use_knowledge_base,
            knowledge_base_id=chat_request.knowledge_base_id,
        ):
            # NOTE(review): a chunk containing a raw newline would break the
            # single-line `data:` framing - confirm chunks are pre-encoded.
            yield f"data: {chunk}\n\n"

    return StreamingResponse(
        generate_response(),
        # The body uses SSE framing, so advertise the SSE media type;
        # "text/plain" is rejected by EventSource clients.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|