260 lines
8.9 KiB
Python
260 lines
8.9 KiB
Python
"""Conversation service."""
|
|
|
|
from typing import List, Optional
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import desc, func, or_
|
|
|
|
from ..models.conversation import Conversation
|
|
from ..models.message import Message, MessageRole
|
|
from ..utils.schemas import ConversationCreate, ConversationUpdate
|
|
from ..utils.exceptions import ConversationNotFoundError, DatabaseError
|
|
from ..utils.logger import get_logger
|
|
from ..core.context import UserContext
|
|
|
|
logger = get_logger("conversation_service")
|
|
|
|
|
|
class ConversationService:
|
|
"""Service for managing conversations and messages."""
|
|
|
|
def __init__(self, db: Session):
|
|
self.db = db
|
|
|
|
def create_conversation(
|
|
self,
|
|
user_id: int,
|
|
conversation_data: ConversationCreate
|
|
) -> Conversation:
|
|
"""Create a new conversation."""
|
|
logger.info(f"Creating new conversation for user {user_id}")
|
|
|
|
try:
|
|
conversation = Conversation(
|
|
**conversation_data.dict(),
|
|
user_id=user_id
|
|
)
|
|
|
|
# Set audit fields
|
|
conversation.set_audit_fields(user_id=user_id, is_update=False)
|
|
|
|
self.db.add(conversation)
|
|
self.db.commit()
|
|
self.db.refresh(conversation)
|
|
|
|
logger.info(f"Successfully created conversation {conversation.id} for user {user_id}")
|
|
return conversation
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to create conversation: {str(e)}", exc_info=True)
|
|
self.db.rollback()
|
|
raise DatabaseError(f"Failed to create conversation: {str(e)}")
|
|
|
|
def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
|
|
"""Get a conversation by ID."""
|
|
try:
|
|
user_id = UserContext.get_current_user_id()
|
|
conversation = self.db.query(Conversation).filter(
|
|
Conversation.id == conversation_id,
|
|
Conversation.user_id == user_id
|
|
).first()
|
|
|
|
if not conversation:
|
|
logger.warning(f"Conversation {conversation_id} not found")
|
|
|
|
return conversation
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True)
|
|
raise DatabaseError(f"Failed to get conversation: {str(e)}")
|
|
|
|
def get_user_conversations(
|
|
self,
|
|
skip: int = 0,
|
|
limit: int = 50,
|
|
search_query: Optional[str] = None,
|
|
include_archived: bool = False,
|
|
order_by: str = "updated_at",
|
|
order_desc: bool = True
|
|
) -> List[Conversation]:
|
|
"""Get user's conversations with search and filtering."""
|
|
user_id = UserContext.get_current_user_id()
|
|
query = self.db.query(Conversation).filter(
|
|
Conversation.user_id == user_id
|
|
)
|
|
|
|
# Filter archived conversations
|
|
if not include_archived:
|
|
query = query.filter(Conversation.is_archived == False)
|
|
|
|
# Search functionality
|
|
if search_query and search_query.strip():
|
|
search_term = f"%{search_query.strip()}%"
|
|
query = query.filter(
|
|
or_(
|
|
Conversation.title.ilike(search_term),
|
|
Conversation.system_prompt.ilike(search_term)
|
|
)
|
|
)
|
|
|
|
# Ordering
|
|
order_column = getattr(Conversation, order_by, Conversation.updated_at)
|
|
if order_desc:
|
|
query = query.order_by(desc(order_column))
|
|
else:
|
|
query = query.order_by(order_column)
|
|
|
|
return query.offset(skip).limit(limit).all()
|
|
|
|
def update_conversation(
|
|
self,
|
|
conversation_id: int,
|
|
conversation_update: ConversationUpdate
|
|
) -> Optional[Conversation]:
|
|
"""Update a conversation."""
|
|
conversation = self.get_conversation(conversation_id)
|
|
if not conversation:
|
|
return None
|
|
|
|
update_data = conversation_update.dict(exclude_unset=True)
|
|
for field, value in update_data.items():
|
|
setattr(conversation, field, value)
|
|
|
|
self.db.commit()
|
|
self.db.refresh(conversation)
|
|
return conversation
|
|
|
|
def delete_conversation(self, conversation_id: int) -> bool:
|
|
"""Delete a conversation."""
|
|
conversation = self.get_conversation(conversation_id)
|
|
if not conversation:
|
|
return False
|
|
|
|
self.db.delete(conversation)
|
|
self.db.commit()
|
|
return True
|
|
|
|
def get_conversation_messages(
|
|
self,
|
|
conversation_id: int,
|
|
skip: int = 0,
|
|
limit: int = 100
|
|
) -> List[Message]:
|
|
"""Get messages from a conversation."""
|
|
return self.db.query(Message).filter(
|
|
Message.conversation_id == conversation_id
|
|
).order_by(Message.created_at).offset(skip).limit(limit).all()
|
|
|
|
def add_message(
|
|
self,
|
|
conversation_id: int,
|
|
content: str,
|
|
role: MessageRole,
|
|
message_metadata: Optional[dict] = None,
|
|
context_documents: Optional[list] = None,
|
|
prompt_tokens: Optional[int] = None,
|
|
completion_tokens: Optional[int] = None,
|
|
total_tokens: Optional[int] = None
|
|
) -> Message:
|
|
"""Add a message to a conversation."""
|
|
message = Message(
|
|
conversation_id=conversation_id,
|
|
content=content,
|
|
role=role,
|
|
message_metadata=message_metadata,
|
|
context_documents=context_documents,
|
|
prompt_tokens=prompt_tokens,
|
|
completion_tokens=completion_tokens,
|
|
total_tokens=total_tokens
|
|
)
|
|
|
|
# Set audit fields
|
|
message.set_audit_fields()
|
|
|
|
self.db.add(message)
|
|
self.db.commit()
|
|
self.db.refresh(message)
|
|
return message
|
|
|
|
def get_conversation_history(
|
|
self,
|
|
conversation_id: int,
|
|
limit: int = 20
|
|
) -> List[Message]:
|
|
"""Get recent conversation history for context."""
|
|
return self.db.query(Message).filter(
|
|
Message.conversation_id == conversation_id
|
|
).order_by(desc(Message.created_at)).limit(limit).all()[::-1] # Reverse to get chronological order
|
|
|
|
def update_conversation_timestamp(self, conversation_id: int) -> None:
|
|
"""Update conversation's updated_at timestamp."""
|
|
conversation = self.get_conversation(conversation_id)
|
|
if conversation:
|
|
# SQLAlchemy will automatically update the updated_at field
|
|
self.db.commit()
|
|
|
|
def get_user_conversations_count(
|
|
self,
|
|
search_query: Optional[str] = None,
|
|
include_archived: bool = False
|
|
) -> int:
|
|
"""Get total count of user's conversations."""
|
|
user_id = UserContext.get_current_user_id()
|
|
query = self.db.query(func.count(Conversation.id)).filter(
|
|
Conversation.user_id == user_id
|
|
)
|
|
|
|
if not include_archived:
|
|
query = query.filter(Conversation.is_archived == False)
|
|
|
|
if search_query and search_query.strip():
|
|
search_term = f"%{search_query.strip()}%"
|
|
query = query.filter(
|
|
or_(
|
|
Conversation.title.ilike(search_term),
|
|
Conversation.system_prompt.ilike(search_term)
|
|
)
|
|
)
|
|
|
|
return query.scalar() or 0
|
|
|
|
def archive_conversation(self, conversation_id: int) -> bool:
|
|
"""Archive a conversation."""
|
|
conversation = self.get_conversation(conversation_id)
|
|
if not conversation:
|
|
return False
|
|
|
|
conversation.is_archived = True
|
|
self.db.commit()
|
|
return True
|
|
|
|
def unarchive_conversation(self, conversation_id: int) -> bool:
|
|
"""Unarchive a conversation."""
|
|
conversation = self.get_conversation(conversation_id)
|
|
if not conversation:
|
|
return False
|
|
|
|
conversation.is_archived = False
|
|
self.db.commit()
|
|
return True
|
|
|
|
def delete_all_conversations(self) -> bool:
|
|
"""Delete all conversations for the current user."""
|
|
try:
|
|
user_id = UserContext.get_current_user_id()
|
|
# Get all conversations for the user
|
|
conversations = self.db.query(Conversation).filter(
|
|
Conversation.user_id == user_id
|
|
).all()
|
|
|
|
# Delete each conversation
|
|
for conversation in conversations:
|
|
self.db.delete(conversation)
|
|
|
|
self.db.commit()
|
|
logger.info(f"Successfully deleted all conversations for user {user_id}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to delete all conversations: {str(e)}", exc_info=True)
|
|
self.db.rollback()
|
|
raise DatabaseError(f"Failed to delete all conversations: {str(e)}") |