295 lines
11 KiB
Python
295 lines
11 KiB
Python
"""Conversation service."""
|
||
|
||
from typing import List, Optional
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import select, desc, func, or_
|
||
from langchain_core.messages import HumanMessage, AIMessage
|
||
|
||
from th_agenter.db.database import AsyncSessionFactory
|
||
|
||
from ..models.conversation import Conversation
|
||
from ..models.message import Message, MessageRole
|
||
from utils.util_schemas import ConversationCreate, ConversationUpdate
|
||
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
|
||
from ..core.context import UserContext
|
||
from datetime import datetime, timezone
|
||
from loguru import logger
|
||
|
||
class ConversationService:
    """Service for managing conversations and messages.

    Conversation lookups are scoped to the user returned by
    ``UserContext.get_current_user_id()``; a missing user behaves like
    "not found". All work goes through the session given to ``__init__``,
    except ``add_message``, which opens (and closes) its own session from
    ``AsyncSessionFactory``.
    """

    def __init__(self, session: Session):
        """Store the session used by every method except ``add_message``.

        Args:
            session: Database session. It is annotated as ``Session``, but every
                call on it is awaited, so an async session is expected.
        """
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _select_owned_conversation(
        session, conversation_id: int, user_id: int
    ) -> Optional[Conversation]:
        """Return conversation ``conversation_id`` owned by ``user_id`` via ``session``, else None."""
        return await session.scalar(
            select(Conversation).where(
                Conversation.id == conversation_id,
                Conversation.user_id == user_id,
            )
        )

    @staticmethod
    def _apply_list_filters(query, search_query: Optional[str], include_archived: bool):
        """Apply the archived / free-text filters shared by the list and count queries."""
        if not include_archived:
            # ``== False`` (not ``is False``) so SQLAlchemy builds a SQL comparison.
            query = query.where(Conversation.is_archived == False)  # noqa: E712
        if search_query and search_query.strip():
            search_term = f"%{search_query.strip()}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term),
                    Conversation.system_prompt.ilike(search_term),
                )
            )
        return query

    async def _commit_or_rollback(self) -> None:
        """Commit ``self.session``; on failure roll back and re-raise the original error.

        Rolling back keeps the shared session usable for later calls; the
        exception type seen by callers is unchanged.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set ``is_archived`` on an owned conversation; return False if it is not found."""
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create_conversation(
        self,
        user_id: int,
        conversation_data: ConversationCreate
    ) -> Conversation:
        """Create a new conversation for ``user_id`` and return the refreshed row.

        Raises:
            DatabaseError: if building, committing or refreshing the row fails;
                the session is rolled back first.
        """
        self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"

        try:
            conversation = Conversation(
                **conversation_data.model_dump(),
                user_id=user_id
            )

            # Set audit fields
            conversation.set_audit_fields(user_id=user_id, is_update=False)

            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)

            self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}"
            return conversation

        except Exception as e:
            self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
            await self.session.rollback()
            raise DatabaseError(f"创建会话失败: {str(e)}") from e

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Return the conversation if it exists and belongs to the current user.

        Returns None when there is no user context or no matching row.

        Raises:
            DatabaseError: if resolving the user or running the query fails.
        """
        # Bound before ``try`` so the error handler can always reference it,
        # even when UserContext itself raises.
        user_id = None
        try:
            user_id = UserContext.get_current_user_id()
            self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}"
            if user_id is None:
                logger.error(f"Failed to get conversation {conversation_id}: No user context available")
                return None

            conversation = await self._select_owned_conversation(
                self.session, conversation_id, user_id
            )

            if not conversation:
                self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}"
            return conversation

        except Exception as e:
            self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}"
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True
    ) -> List[Conversation]:
        """List the current user's conversations with search, filtering and paging.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            search_query: Case-insensitive substring matched against ``title``
                and ``system_prompt``; blank values are ignored.
            include_archived: Include archived conversations when True.
            order_by: Name of a mapped ``Conversation`` column; anything else
                falls back to ``updated_at``.
            order_desc: Sort descending when True.

        Returns:
            The matching conversations, or ``[]`` when there is no user context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to get user conversations: No user context available")
            return []

        query = self._apply_list_filters(
            select(Conversation).where(Conversation.user_id == user_id),
            search_query,
            include_archived,
        )

        # Only sort on real mapped columns. ``order_by`` may be caller-supplied,
        # and a bare getattr would also resolve methods, relationships or dunder
        # attributes, none of which can be used in ORDER BY.
        if order_by in Conversation.__mapper__.columns:
            order_column = getattr(Conversation, order_by)
        else:
            order_column = Conversation.updated_at
        query = query.order_by(desc(order_column) if order_desc else order_column)

        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self,
        conversation_id: int,
        conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the explicitly-set fields of ``conversation_update``.

        Returns:
            The refreshed conversation, or None if it is not found for the
            current user.

        Raises:
            DatabaseError: if committing or refreshing fails; the session is
                rolled back first.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None

        # Only fields the client actually sent (pydantic v2 API, matching
        # ``model_dump`` in create_conversation).
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)

        # Update audit fields
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            # logger.exception: attaches the traceback and, with no extra
            # args/kwargs, loguru does not re-``format`` the message (braces in
            # ``e`` would otherwise raise inside this handler).
            logger.exception(f"Failed to update conversation {conversation_id}: {str(e)}")
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete an owned conversation; return False if it is not found.

        On commit failure the session is rolled back and the error re-raised.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        await self.session.delete(conversation)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def get_conversation_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        NOTE(review): unlike get_conversation, this does not check that the
        conversation belongs to the current user — confirm callers verify
        ownership before calling.
        """
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(Message.created_at).offset(skip).limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None
    ) -> Message:
        """Persist a message and bump the owning conversation's ``updated_at``.

        Uses a private session from ``AsyncSessionFactory``, not ``self.session``.
        The message insert and the timestamp update are committed together in
        that session. The timestamp is only touched when the conversation
        belongs to the current user.

        Failures are logged and rolled back but not raised. In that case the
        returned ``Message`` was not persisted.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )

        # Set audit fields
        message.set_audit_fields()

        session = AsyncSessionFactory()
        try:
            session.add(message)

            # Load the conversation in the SAME session that is committed
            # below. Otherwise the timestamp change would sit unflushed on
            # another session.
            user_id = UserContext.get_current_user_id()
            if user_id is not None:
                conversation = await self._select_owned_conversation(
                    session, conversation_id, user_id
                )
                if conversation:
                    conversation.updated_at = datetime.now(timezone.utc)
                    conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

            await session.commit()
            # Reload generated columns (id, timestamps) before the session closes.
            await session.refresh(message)
        except Exception as e:
            logger.exception(f"Failed to add message to conversation {conversation_id}: {str(e)}")
            await session.rollback()
        finally:
            await session.close()

        return message

    async def get_conversation_history_messages(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> "List[HumanMessage | AIMessage]":
        """Return recent history as LangChain messages in chronological order.

        USER messages become ``HumanMessage`` and ASSISTANT messages become
        ``AIMessage``; messages with any other role are skipped.
        """
        history = await self.get_conversation_history(conversation_id, limit)
        history_messages = []
        for message in history:
            if message.role == MessageRole.USER:
                history_messages.append(HumanMessage(content=message.content))
            elif message.role == MessageRole.ASSISTANT:
                history_messages.append(AIMessage(content=message.content))
        return history_messages

    async def get_conversation_history(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Message]:
        """Return the ``limit`` most recent messages, oldest first.

        NOTE(review): no ownership check against the current user — confirm
        callers verify access first.
        """
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(desc(Message.created_at)).limit(limit)
        )).all()[::-1]  # Query is newest-first; reverse to chronological order

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set ``updated_at`` to now (UTC) on an owned conversation; no-op if not found.

        On commit failure the session is rolled back and the error re-raised.
        """
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit_or_rollback()

    async def get_user_conversations_count(
        self,
        search_query: Optional[str] = None,
        include_archived: bool = False
    ) -> int:
        """Count the current user's conversations with the same filters as the list.

        Returns 0 when there is no user context, mirroring
        ``get_user_conversations`` returning ``[]``.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            return 0

        query = self._apply_list_filters(
            select(func.count(Conversation.id)).where(Conversation.user_id == user_id),
            search_query,
            include_archived,
        )
        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive an owned conversation; return False if it is not found."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive an owned conversation; return False if it is not found."""
        return await self._set_archived(conversation_id, False)