295 lines
11 KiB
Python
295 lines
11 KiB
Python
|
|
"""Conversation service."""
|
|||
|
|
|
|||
|
|
from typing import List, Optional
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
from sqlalchemy import select, desc, func, or_
|
|||
|
|
from langchain_core.messages import HumanMessage, AIMessage
|
|||
|
|
|
|||
|
|
from th_agenter.db.database import AsyncSessionFactory
|
|||
|
|
|
|||
|
|
from ..models.conversation import Conversation
|
|||
|
|
from ..models.message import Message, MessageRole
|
|||
|
|
from utils.util_schemas import ConversationCreate, ConversationUpdate
|
|||
|
|
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
|
|||
|
|
from ..core.context import UserContext
|
|||
|
|
from datetime import datetime, timezone
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
class ConversationService:
    """Service for managing conversations and messages.

    Every method uses the awaitable session API (``await session.commit()``,
    ``await session.scalar(...)``), so the injected session is expected to be an
    async session even though the constructor annotation says ``Session``.

    Lookups that act on behalf of the current user resolve the user through
    ``UserContext.get_current_user_id()`` and only match that user's conversations.
    """

    def __init__(self, session: Session):
        """Bind the service to ``session``.

        Args:
            session: Session used for all reads/writes except ``add_message``,
                which opens its own session from ``AsyncSessionFactory``.
        """
        self.session = session

    # ------------------------------------------------------------------ helpers

    async def _commit_or_rollback(self) -> None:
        """Commit ``self.session``; on failure roll back and re-raise the original error.

        Rolling back keeps the session usable for subsequent calls instead of
        leaving it stuck in a failed transaction.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in ``term`` using ``/`` as the escape character.

        Without this, a user searching for ``_`` or ``%`` would match everything.
        """
        return term.replace("/", "//").replace("%", "/%").replace("_", "/_")

    @classmethod
    def _apply_conversation_filters(
        cls,
        query,
        user_id: int,
        search_query: Optional[str],
        include_archived: bool
    ):
        """Restrict a ``select`` over conversations to one user plus optional filters.

        Shared by ``get_user_conversations`` and ``get_user_conversations_count``
        so listing and counting always apply identical criteria.

        Args:
            query: A SQLAlchemy ``select`` whose FROM includes ``Conversation``.
            user_id: Owner whose conversations are kept.
            search_query: Optional case-insensitive substring matched against
                title or system prompt; blank/whitespace means "no search".
            include_archived: When False, archived conversations are excluded.

        Returns:
            The filtered query.
        """
        query = query.where(Conversation.user_id == user_id)

        if not include_archived:
            # SQL expression, not a Python comparison -- ``is False`` would not work here.
            query = query.where(Conversation.is_archived == False)  # noqa: E712

        term = (search_query or "").strip()
        if term:
            pattern = f"%{cls._escape_like(term)}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(pattern, escape="/"),
                    Conversation.system_prompt.ilike(pattern, escape="/")
                )
            )
        return query

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set the archived flag on one of the current user's conversations.

        Returns:
            True if the conversation was found and updated, False otherwise.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------ conversations

    async def create_conversation(
        self,
        user_id: int,
        conversation_data: ConversationCreate
    ) -> Conversation:
        """Create and persist a new conversation owned by ``user_id``.

        Args:
            user_id: Owner of the new conversation (also used for audit fields).
            conversation_data: Field values; dumped with ``model_dump()``.

        Returns:
            The committed, refreshed conversation.

        Raises:
            DatabaseError: If creating or committing fails (the session is rolled back).
        """
        # ``session.desc`` carries a human-readable description of the current
        # operation -- presumably read by request/audit logging elsewhere (not visible here).
        self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"

        try:
            conversation = Conversation(
                **conversation_data.model_dump(),
                user_id=user_id
            )
            conversation.set_audit_fields(user_id=user_id, is_update=False)

            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)

            self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}"
            return conversation

        except Exception as e:
            self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
            await self.session.rollback()
            raise DatabaseError(f"创建会话失败: {str(e)}") from e

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Fetch a conversation by ID, restricted to the current user.

        Returns:
            The conversation, or None if there is no user context or the current
            user has no conversation with this ID.

        Raises:
            DatabaseError: If resolving the user or querying fails.
        """
        # Bound before ``try`` so the except handler can always reference it,
        # even when resolving the current user is what failed.
        user_id = None
        try:
            user_id = UserContext.get_current_user_id()
            self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}"
            if user_id is None:
                logger.error(f"Failed to get conversation {conversation_id}: No user context available")
                return None

            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id
                )
            )

            if not conversation:
                self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}"
            return conversation

        except Exception as e:
            self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}"
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True
    ) -> List[Conversation]:
        """List the current user's conversations with search, filtering and paging.

        Args:
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            search_query: Optional case-insensitive substring matched against
                title or system prompt.
            include_archived: Whether archived conversations are included.
            order_by: Name of the ``Conversation`` attribute to sort by; falls
                back to ``updated_at`` when the attribute does not exist.
            order_desc: Sort descending when True.

        Returns:
            The matching conversations; empty when there is no user context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to get user conversations: No user context available")
            return []

        query = self._apply_conversation_filters(
            select(Conversation), user_id, search_query, include_archived
        )

        # NOTE(review): any attribute name is accepted here; a non-column name
        # (e.g. a method) would fail when the query is built. Consider
        # whitelisting sortable columns if ``order_by`` comes from the client.
        order_column = getattr(Conversation, order_by, Conversation.updated_at)
        query = query.order_by(desc(order_column) if order_desc else order_column)

        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self,
        conversation_id: int,
        conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the explicitly-set fields of ``conversation_update``.

        Returns:
            The refreshed conversation, or None if the current user has no
            conversation with this ID.

        Raises:
            DatabaseError: If the commit fails (the session is rolled back).
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None

        # Only fields the client actually sent; ``model_dump`` matches the
        # pydantic v2 API already used by ``create_conversation``.
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)

        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            # loguru: ``exception`` attaches the traceback; positional args keep
            # braces inside the error text from being treated as format fields.
            logger.exception("Failed to update conversation {}: {}", conversation_id, e)
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete one of the current user's conversations.

        Returns:
            True if deleted, False if no such conversation exists for the user.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        await self.session.delete(conversation)
        await self._commit_or_rollback()
        return True

    # ------------------------------------------------------------------ messages

    async def get_conversation_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        NOTE(review): no ownership check is made here; callers are presumably
        expected to have verified access to ``conversation_id`` -- confirm.
        """
        query = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
            .offset(skip)
            .limit(limit)
        )
        return (await self.session.scalars(query)).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None
    ) -> Message:
        """Persist a message and bump its conversation's ``updated_at``.

        The write goes through a dedicated session from ``AsyncSessionFactory``
        rather than ``self.session``. The message insert and the timestamp bump
        share that session and are committed together.

        Persistence is best-effort: on failure the error is logged with its
        traceback, the dedicated session is rolled back, and the (unsaved)
        message object is still returned.

        Returns:
            The message; refreshed from the database when the commit succeeded.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )
        message.set_audit_fields()

        async with AsyncSessionFactory() as session:
            try:
                session.add(message)

                # Load the conversation through the SAME session that commits, so
                # the timestamp change is actually part of this transaction.
                conversation = await session.get(Conversation, conversation_id)
                if conversation:
                    conversation.updated_at = datetime.now(timezone.utc)
                    conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

                await session.commit()
                # Load DB-generated values (e.g. id) before the session closes.
                await session.refresh(message)
            except Exception as e:
                logger.exception("Failed to add message to conversation {}: {}", conversation_id, e)
                await session.rollback()

        return message

    async def get_conversation_history_messages(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> "List[HumanMessage | AIMessage]":
        """Return recent history converted to LangChain messages, oldest first.

        Only USER and ASSISTANT messages are converted; messages with any other
        role are skipped.
        """
        history = await self.get_conversation_history(conversation_id, limit)
        history_messages = []
        for message in history:
            if message.role == MessageRole.USER:
                history_messages.append(HumanMessage(content=message.content))
            elif message.role == MessageRole.ASSISTANT:
                history_messages.append(AIMessage(content=message.content))
        return history_messages

    async def get_conversation_history(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Message]:
        """Return the ``limit`` most recent messages in chronological order."""
        newest_first = (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(desc(Message.created_at))
            .limit(limit)
        )).all()
        # The query picks the newest rows; reverse them back into chronological order.
        return newest_first[::-1]

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set ``updated_at`` to now on one of the current user's conversations.

        Does nothing if the conversation is not found for the current user.
        """
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit_or_rollback()

    async def get_user_conversations_count(
        self,
        search_query: Optional[str] = None,
        include_archived: bool = False
    ) -> int:
        """Count the current user's conversations using the listing filters.

        Returns:
            The count; 0 when there is no user context (previously this counted
            rows with a NULL owner).
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to count user conversations: No user context available")
            return 0

        query = self._apply_conversation_filters(
            select(func.count(Conversation.id)), user_id, search_query, include_archived
        )
        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive a conversation; returns False if it is not found for the current user."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive a conversation; returns False if it is not found for the current user."""
        return await self._set_archived(conversation_id, False)