hyf-backend/th_agenter/services/conversation.py

295 lines
11 KiB
Python
Raw Permalink Normal View History

2026-01-21 13:45:39 +08:00
"""Conversation service."""
from datetime import datetime, timezone
from typing import List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from loguru import logger
from sqlalchemy import desc, func, inspect as sa_inspect, or_, select
from sqlalchemy.orm import Session

from th_agenter.db.database import AsyncSessionFactory
from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
from utils.util_schemas import ConversationCreate, ConversationUpdate
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
from ..core.context import UserContext
class ConversationService:
    """Service for managing conversations and messages.

    All methods except ``add_message`` operate on the session passed to the
    constructor and commit it themselves. ``add_message`` opens its own
    short-lived session from ``AsyncSessionFactory``.

    Single-conversation lookups and the list/count methods are scoped to the
    user returned by ``UserContext.get_current_user_id()``.
    """

    def __init__(self, session: Session):
        # NOTE(review): annotated as ``Session``, but every method of it is
        # awaited below, so an async session is presumably what callers pass.
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _commit(self) -> None:
        """Commit ``self.session``; on failure, roll it back and re-raise.

        The rollback keeps the shared session usable for later calls instead
        of leaving it in a failed transaction. The original exception
        propagates unchanged.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    @staticmethod
    def _apply_user_filters(query, user_id, search_query, include_archived):
        """Restrict *query* to conversations owned by *user_id*.

        Archived conversations are excluded unless *include_archived* is true.
        A non-blank *search_query* is matched case-insensitively, as a
        substring, against the title or the system prompt.
        """
        query = query.where(Conversation.user_id == user_id)
        if not include_archived:
            # ``== False`` is deliberate: it builds a SQL comparison.
            query = query.where(Conversation.is_archived == False)  # noqa: E712
        if search_query and search_query.strip():
            search_term = f"%{search_query.strip()}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term),
                    Conversation.system_prompt.ilike(search_term),
                )
            )
        return query

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set ``is_archived`` on the current user's conversation.

        Returns:
            False if the conversation is not visible to the current user,
            True after the change has been committed.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False
        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit()
        return True

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create_conversation(
        self,
        user_id: int,
        conversation_data: ConversationCreate
    ) -> Conversation:
        """Create and persist a new conversation owned by *user_id*.

        Raises:
            DatabaseError: If building, inserting or committing the row fails.
                The session is rolled back before raising.
        """
        self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"
        try:
            conversation = Conversation(
                **conversation_data.model_dump(),
                user_id=user_id
            )
            conversation.set_audit_fields(user_id=user_id, is_update=False)
            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)
            self.session.desc = f"创建新会话 Conversation ID: {conversation.id}用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
            await self.session.rollback()
            raise DatabaseError(f"创建会话失败: {str(e)}") from e

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Return the conversation with *conversation_id* if the current user owns it.

        Returns:
            The conversation, or None when there is no user in ``UserContext``
            or no matching conversation belongs to that user.

        Raises:
            DatabaseError: If resolving the user or running the query fails.
        """
        # Bind before the ``try`` so the error message in the ``except`` block
        # can always reference it, even if get_current_user_id() itself raises.
        user_id = None
        try:
            user_id = UserContext.get_current_user_id()
            self.session.desc = f"获取会话 - 会话ID: {conversation_id}用户ID: {user_id}"
            if user_id is None:
                logger.error(f"Failed to get conversation {conversation_id}: No user context available")
                return None
            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id
                )
            )
            if not conversation:
                self.session.desc = f"警告: 会话 {conversation_id} 不存在用户ID: {user_id}"
            return conversation
        except Exception as e:
            self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id}用户ID: {user_id},错误: {str(e)}"
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True
    ) -> List[Conversation]:
        """List the current user's conversations with search, filtering and paging.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            search_query: Optional substring matched against title/system prompt.
            include_archived: Whether archived conversations are included.
            order_by: Name of a mapped ``Conversation`` column to sort by.
                Unknown names fall back to ``updated_at``.
            order_desc: Sort descending when true.

        Returns:
            The matching conversations; an empty list when no user is in context.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            logger.error("Failed to get user conversations: No user context available")
            return []
        query = self._apply_user_filters(
            select(Conversation), user_id, search_query, include_archived
        )
        # Accept only real mapped columns. Other attributes (methods,
        # relationships, unknown names) fall back to ``updated_at`` instead of
        # producing an invalid ORDER BY.
        if order_by in sa_inspect(Conversation).columns:
            order_column = getattr(Conversation, order_by)
        else:
            order_column = Conversation.updated_at
        query = query.order_by(desc(order_column) if order_desc else order_column)
        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self,
        conversation_id: int,
        conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the fields explicitly set on *conversation_update* to a conversation.

        Returns:
            The refreshed conversation, or None if the current user has no such
            conversation.

        Raises:
            DatabaseError: If the commit fails. The session is rolled back first.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None
        # Only fields the client actually sent; unset fields are left alone.
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            logger.error(f"Failed to update conversation {conversation_id}: {str(e)}", exc_info=True)
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete the current user's conversation.

        Returns:
            False if the conversation is not visible to the current user,
            True once the deletion has been committed.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False
        await self.session.delete(conversation)
        await self._commit()
        return True

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def get_conversation_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        NOTE(review): unlike get_conversation(), this does not check that the
        conversation belongs to the current user. Confirm that callers do.
        """
        return (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
            .offset(skip)
            .limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None
    ) -> Message:
        """Persist a message and bump its conversation's ``updated_at``.

        The insert and the timestamp update run in one dedicated session from
        ``AsyncSessionFactory`` and are committed together.

        This is best-effort: failures are logged and rolled back, not raised.
        In that case the returned message was not saved.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )
        message.set_audit_fields()
        session = AsyncSessionFactory()
        try:
            session.add(message)
            # Load the conversation through the SAME session, so the timestamp
            # change is included in the commit below.
            conversation = await session.get(Conversation, conversation_id)
            if conversation is not None:
                conversation.updated_at = datetime.now(timezone.utc)
                conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await session.commit()
            # Refresh after the final commit, so the returned instance keeps
            # loaded attributes once the session is closed.
            await session.refresh(message)
        except Exception as e:
            logger.error(f"Failed to add message to conversation {conversation_id}: {str(e)}", exc_info=True)
            await session.rollback()
        finally:
            await session.close()
        return message

    async def get_conversation_history_messages(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[BaseMessage]:
        """Return recent history as LangChain messages, oldest first.

        USER messages become ``HumanMessage`` and ASSISTANT messages become
        ``AIMessage``. Messages with any other role are skipped.
        """
        history = await self.get_conversation_history(conversation_id, limit)
        history_messages = []
        for message in history:
            if message.role == MessageRole.USER:
                history_messages.append(HumanMessage(content=message.content))
            elif message.role == MessageRole.ASSISTANT:
                history_messages.append(AIMessage(content=message.content))
        return history_messages

    async def get_conversation_history(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Message]:
        """Return up to *limit* of the most recent messages, oldest first."""
        newest_first = (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(desc(Message.created_at))
            .limit(limit)
        )).all()
        # The query picks the newest rows; reverse them into chronological order.
        return newest_first[::-1]

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set ``updated_at`` to now (UTC) on the current user's conversation, if found."""
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit()

    async def get_user_conversations_count(
        self,
        search_query: Optional[str] = None,
        include_archived: bool = False
    ) -> int:
        """Count the current user's conversations using the same filters as the listing.

        Returns:
            The count; 0 when no user is in context. This matches
            get_user_conversations(), which returns an empty list in that case.
        """
        user_id = UserContext.get_current_user_id()
        if user_id is None:
            # Without this guard the query would compare user_id to NULL and
            # count ownerless rows.
            logger.error("Failed to count user conversations: No user context available")
            return 0
        query = self._apply_user_filters(
            select(func.count(Conversation.id)), user_id, search_query, include_archived
        )
        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive the current user's conversation; return False if not found."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive the current user's conversation; return False if not found."""
        return await self._set_archived(conversation_id, False)