"""Conversation service."""
from datetime import datetime, timezone
from typing import List, Optional

from loguru import logger
from sqlalchemy import desc, func, inspect as sa_inspect, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session  # noqa: F401 - unused here; retained in case other modules import it from this one

from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
from utils.util_schemas import ConversationCreate, ConversationUpdate
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
from ..core.context import UserContext


# Escape character used when matching user search text with ILIKE.
_LIKE_ESCAPE = "\\"


def _escape_like(text: str) -> str:
    """Escape LIKE/ILIKE wildcards in *text* so it is matched literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it; the pattern must be used with ``escape=_LIKE_ESCAPE``.
    """
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class ConversationService:
    """Service for managing conversations and messages.

    Conversation lookups, listings and counts are restricted to the user
    returned by ``UserContext.get_current_user_id()``. Message queries filter
    only by ``conversation_id``. Every public method that writes commits the
    session itself.
    """

    def __init__(self, session: AsyncSession):
        # The session is awaited throughout (commit/refresh/scalar/delete),
        # so it must be an AsyncSession.
        self.session = session

    # ------------------------------------------------------------------ helpers

    async def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise the original error.

        Rolling back keeps the session usable after a failed commit, and the
        bare ``raise`` preserves the exception type callers already see.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    @staticmethod
    def _apply_conversation_filters(query, search_query: Optional[str], include_archived: bool):
        """Apply the archived and free-text filters shared by listing and counting.

        Args:
            query: A ``select(...)`` statement over ``Conversation``.
            search_query: Case-insensitive substring matched against ``title``
                and ``system_prompt``; ignored when ``None`` or blank.
            include_archived: When ``False``, archived conversations are excluded.

        Returns:
            The statement with the extra ``WHERE`` clauses applied.
        """
        if not include_archived:
            query = query.where(Conversation.is_archived == False)  # noqa: E712 - SQL expression, not a Python bool test

        if search_query and search_query.strip():
            # Escape user-typed % and _ so they are matched literally rather than as wildcards.
            search_term = f"%{_escape_like(search_query.strip())}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term, escape=_LIKE_ESCAPE),
                    Conversation.system_prompt.ilike(search_term, escape=_LIKE_ESCAPE),
                )
            )
        return query

    @staticmethod
    def _resolve_order_column(order_by: str):
        """Map *order_by* to a mapped column attribute of ``Conversation``.

        Only names of mapped column attributes are accepted; anything else
        (methods, relationships, dunder names, typos) falls back to
        ``Conversation.updated_at`` instead of producing an invalid ORDER BY.
        """
        if order_by in sa_inspect(Conversation).column_attrs:
            return getattr(Conversation, order_by)
        return Conversation.updated_at

    async def _set_archived(self, conversation_id: int, archived: bool) -> bool:
        """Set ``is_archived`` on the current user's conversation and commit.

        Returns:
            ``False`` if the conversation is not visible to the current user,
            otherwise ``True``.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = archived
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self._commit()
        return True

    # ------------------------------------------------------------ conversations

    async def create_conversation(
        self, user_id: int, conversation_data: ConversationCreate
    ) -> Conversation:
        """Create and persist a new conversation owned by *user_id*.

        Args:
            user_id: Owner of the new conversation; also used for audit fields.
            conversation_data: Schema whose ``model_dump()`` supplies the
                ``Conversation`` constructor fields.

        Returns:
            The committed and refreshed ``Conversation``.

        Raises:
            DatabaseError: If constructing, committing or refreshing fails; the
                session is rolled back first.
        """
        logger.info("Creating new conversation for user {}: {}", user_id, conversation_data)
        try:
            conversation = Conversation(**conversation_data.model_dump(), user_id=user_id)
            conversation.set_audit_fields(user_id=user_id, is_update=False)

            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)
        except Exception as e:
            # logger.exception attaches the traceback. The error is passed as a
            # format argument so braces in its text (e.g. SQL parameter dicts)
            # are not re-parsed by loguru's str.format.
            logger.exception("Failed to create conversation: {}", e)
            await self.session.rollback()
            raise DatabaseError(f"Failed to create conversation: {str(e)}") from e

        logger.info("Successfully created conversation {} for user {}", conversation.id, user_id)
        return conversation

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Return the current user's conversation with *conversation_id*, or ``None``.

        Raises:
            DatabaseError: If resolving the current user or running the query fails.
        """
        try:
            user_id = UserContext.get_current_user_id()
            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id,
                )
            )
        except Exception as e:
            logger.exception("Failed to get conversation {}: {}", conversation_id, e)
            raise DatabaseError(f"Failed to get conversation: {str(e)}") from e

        if not conversation:
            logger.warning("Conversation {} not found", conversation_id)
        return conversation

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True,
    ) -> List[Conversation]:
        """Return a page of the current user's conversations.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            search_query: Optional case-insensitive substring matched against
                ``title`` and ``system_prompt``.
            include_archived: Include archived conversations when ``True``.
            order_by: Name of a mapped column to sort by; unknown names fall
                back to ``updated_at``.
            order_desc: Sort descending when ``True``.
        """
        user_id = UserContext.get_current_user_id()
        query = self._apply_conversation_filters(
            select(Conversation).where(Conversation.user_id == user_id),
            search_query,
            include_archived,
        )

        order_column = self._resolve_order_column(order_by)
        query = query.order_by(desc(order_column) if order_desc else order_column)

        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self, conversation_id: int, conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Apply the explicitly-set fields of *conversation_update* to a conversation.

        Returns:
            The refreshed conversation, or ``None`` if it is not visible to the
            current user.

        Raises:
            DatabaseError: If the commit or refresh fails; the session is rolled back.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None

        # model_dump matches the Pydantic v2 API used by create_conversation;
        # exclude_unset keeps fields the client did not send untouched.
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)

        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        try:
            await self.session.commit()
            await self.session.refresh(conversation)
        except Exception as e:
            logger.exception("Failed to update conversation {}: {}", conversation_id, e)
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}") from e
        return conversation

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete the current user's conversation.

        Returns:
            ``False`` if the conversation is not visible to the current user,
            otherwise ``True`` after the delete is committed.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        # AsyncSession.delete is a coroutine: without ``await`` the deletion is
        # never registered and the commit below persists nothing.
        await self.session.delete(conversation)
        await self._commit()
        return True

    async def get_user_conversations_count(
        self, search_query: Optional[str] = None, include_archived: bool = False
    ) -> int:
        """Return how many of the current user's conversations match the filters.

        Uses the same archived/search filters as ``get_user_conversations`` so
        the count agrees with the paginated listing.
        """
        user_id = UserContext.get_current_user_id()
        query = self._apply_conversation_filters(
            select(func.count(Conversation.id)).where(Conversation.user_id == user_id),
            search_query,
            include_archived,
        )
        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive the current user's conversation; ``False`` if not found."""
        return await self._set_archived(conversation_id, True)

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive the current user's conversation; ``False`` if not found."""
        return await self._set_archived(conversation_id, False)

    # ----------------------------------------------------------------- messages

    async def get_conversation_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ) -> List[Message]:
        """Return a page of messages in *conversation_id*, oldest first.

        NOTE(review): filters only by ``conversation_id`` and does not check
        that the conversation belongs to the current user; presumably callers
        verify ownership first — confirm.
        """
        return (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
            .offset(skip)
            .limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> Message:
        """Persist a message and bump the parent conversation's timestamp.

        The message is committed first; then, if the conversation is visible to
        the current user, its ``updated_at`` is refreshed and committed.
        A failed commit rolls the session back and re-raises the original error.

        NOTE(review): the message is inserted even when the conversation is not
        the current user's (only the timestamp update is skipped); presumably
        callers verify ownership — confirm.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        message.set_audit_fields()

        self.session.add(message)
        await self._commit()

        await self.update_conversation_timestamp(conversation_id)

        # Refresh after the final commit: a later commit may expire the
        # instance (expire_on_commit), and lazy loads fail in async contexts.
        await self.session.refresh(message)
        return message

    async def get_conversation_history(
        self, conversation_id: int, limit: int = 20
    ) -> List[Message]:
        """Return up to *limit* most recent messages, in chronological order.

        NOTE(review): not scoped to the current user; see get_conversation_messages.
        """
        newest_first = (await self.session.scalars(
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(desc(Message.created_at))
            .limit(limit)
        )).all()
        # The query takes the newest rows; reverse them to oldest-first.
        return newest_first[::-1]

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Set ``updated_at`` to now (UTC) on the current user's conversation.

        Does nothing when the conversation is not visible to the current user.
        """
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self._commit()