hxf/backend/th_agenter/services/conversation.py

260 lines
8.9 KiB
Python

"""Conversation service."""
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import desc, func, or_
from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
from ..utils.schemas import ConversationCreate, ConversationUpdate
from ..utils.exceptions import ConversationNotFoundError, DatabaseError
from ..utils.logger import get_logger
from ..core.context import UserContext
logger = get_logger("conversation_service")
class ConversationService:
"""Service for managing conversations and messages."""
def __init__(self, db: Session):
self.db = db
def create_conversation(
self,
user_id: int,
conversation_data: ConversationCreate
) -> Conversation:
"""Create a new conversation."""
logger.info(f"Creating new conversation for user {user_id}")
try:
conversation = Conversation(
**conversation_data.dict(),
user_id=user_id
)
# Set audit fields
conversation.set_audit_fields(user_id=user_id, is_update=False)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
logger.info(f"Successfully created conversation {conversation.id} for user {user_id}")
return conversation
except Exception as e:
logger.error(f"Failed to create conversation: {str(e)}", exc_info=True)
self.db.rollback()
raise DatabaseError(f"Failed to create conversation: {str(e)}")
def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
"""Get a conversation by ID."""
try:
user_id = UserContext.get_current_user_id()
conversation = self.db.query(Conversation).filter(
Conversation.id == conversation_id,
Conversation.user_id == user_id
).first()
if not conversation:
logger.warning(f"Conversation {conversation_id} not found")
return conversation
except Exception as e:
logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True)
raise DatabaseError(f"Failed to get conversation: {str(e)}")
def get_user_conversations(
self,
skip: int = 0,
limit: int = 50,
search_query: Optional[str] = None,
include_archived: bool = False,
order_by: str = "updated_at",
order_desc: bool = True
) -> List[Conversation]:
"""Get user's conversations with search and filtering."""
user_id = UserContext.get_current_user_id()
query = self.db.query(Conversation).filter(
Conversation.user_id == user_id
)
# Filter archived conversations
if not include_archived:
query = query.filter(Conversation.is_archived == False)
# Search functionality
if search_query and search_query.strip():
search_term = f"%{search_query.strip()}%"
query = query.filter(
or_(
Conversation.title.ilike(search_term),
Conversation.system_prompt.ilike(search_term)
)
)
# Ordering
order_column = getattr(Conversation, order_by, Conversation.updated_at)
if order_desc:
query = query.order_by(desc(order_column))
else:
query = query.order_by(order_column)
return query.offset(skip).limit(limit).all()
def update_conversation(
self,
conversation_id: int,
conversation_update: ConversationUpdate
) -> Optional[Conversation]:
"""Update a conversation."""
conversation = self.get_conversation(conversation_id)
if not conversation:
return None
update_data = conversation_update.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(conversation, field, value)
self.db.commit()
self.db.refresh(conversation)
return conversation
def delete_conversation(self, conversation_id: int) -> bool:
"""Delete a conversation."""
conversation = self.get_conversation(conversation_id)
if not conversation:
return False
self.db.delete(conversation)
self.db.commit()
return True
def get_conversation_messages(
self,
conversation_id: int,
skip: int = 0,
limit: int = 100
) -> List[Message]:
"""Get messages from a conversation."""
return self.db.query(Message).filter(
Message.conversation_id == conversation_id
).order_by(Message.created_at).offset(skip).limit(limit).all()
def add_message(
self,
conversation_id: int,
content: str,
role: MessageRole,
message_metadata: Optional[dict] = None,
context_documents: Optional[list] = None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None
) -> Message:
"""Add a message to a conversation."""
message = Message(
conversation_id=conversation_id,
content=content,
role=role,
message_metadata=message_metadata,
context_documents=context_documents,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens
)
# Set audit fields
message.set_audit_fields()
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def get_conversation_history(
self,
conversation_id: int,
limit: int = 20
) -> List[Message]:
"""Get recent conversation history for context."""
return self.db.query(Message).filter(
Message.conversation_id == conversation_id
).order_by(desc(Message.created_at)).limit(limit).all()[::-1] # Reverse to get chronological order
def update_conversation_timestamp(self, conversation_id: int) -> None:
"""Update conversation's updated_at timestamp."""
conversation = self.get_conversation(conversation_id)
if conversation:
# SQLAlchemy will automatically update the updated_at field
self.db.commit()
def get_user_conversations_count(
self,
search_query: Optional[str] = None,
include_archived: bool = False
) -> int:
"""Get total count of user's conversations."""
user_id = UserContext.get_current_user_id()
query = self.db.query(func.count(Conversation.id)).filter(
Conversation.user_id == user_id
)
if not include_archived:
query = query.filter(Conversation.is_archived == False)
if search_query and search_query.strip():
search_term = f"%{search_query.strip()}%"
query = query.filter(
or_(
Conversation.title.ilike(search_term),
Conversation.system_prompt.ilike(search_term)
)
)
return query.scalar() or 0
def archive_conversation(self, conversation_id: int) -> bool:
"""Archive a conversation."""
conversation = self.get_conversation(conversation_id)
if not conversation:
return False
conversation.is_archived = True
self.db.commit()
return True
def unarchive_conversation(self, conversation_id: int) -> bool:
"""Unarchive a conversation."""
conversation = self.get_conversation(conversation_id)
if not conversation:
return False
conversation.is_archived = False
self.db.commit()
return True
def delete_all_conversations(self) -> bool:
"""Delete all conversations for the current user."""
try:
user_id = UserContext.get_current_user_id()
# Get all conversations for the user
conversations = self.db.query(Conversation).filter(
Conversation.user_id == user_id
).all()
# Delete each conversation
for conversation in conversations:
self.db.delete(conversation)
self.db.commit()
logger.info(f"Successfully deleted all conversations for user {user_id}")
return True
except Exception as e:
logger.error(f"Failed to delete all conversations: {str(e)}", exc_info=True)
self.db.rollback()
raise DatabaseError(f"Failed to delete all conversations: {str(e)}")