# Source-viewer residue (not code): 324 lines · 11 KiB · Python
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from th_agenter.db.database import get_session
from th_agenter.models.conversation import Conversation
from th_agenter.models.message import Message

class ConversationContextService:
    """Manage conversation history and context for the smart data-query feature.

    Contexts are kept in a per-instance, unbounded in-memory dict
    (``context_cache``) keyed by conversation id. On a cache miss they are
    rebuilt from the ``Conversation`` / ``Message`` tables.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        # conversation_id -> context dict with keys: conversation_id, user_id,
        # file_list, selected_files, query_history, created_at
        self.context_cache: Dict[int, Dict[str, Any]] = {}

    @asynccontextmanager
    async def _session_scope(self):
        """Yield the first session produced by ``get_session()`` and always release it.

        This replaces the former ``async for ...: try: ... finally: break``
        pattern. A ``break`` inside ``finally`` discards any pending ``return``
        value or exception, so callers silently got ``None``. Abandoning the
        ``async for`` also left the ``get_session()`` generator un-finalized.

        ``__anext__`` is used instead of the ``anext`` builtin to stay
        compatible with Python 3.9.

        Raises:
            RuntimeError: if ``get_session()`` yields no session.
        """
        session_gen = get_session()
        try:
            session = await session_gen.__anext__()
        except StopAsyncIteration:
            raise RuntimeError("get_session() did not yield a session") from None
        try:
            yield session
        finally:
            try:
                await session.close()
            finally:
                # Presumably get_session is an async generator; aclose() runs its
                # own cleanup now instead of whenever it is garbage-collected.
                aclose = getattr(session_gen, "aclose", None)
                if aclose is not None:
                    await aclose()

    @staticmethod
    def _to_iso(value: Optional[datetime]) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when the timestamp is missing."""
        return value.isoformat() if value is not None else None

    @staticmethod
    def _decode_metadata(raw: Any) -> Any:
        """Decode stored message metadata; JSON strings are parsed, other values returned as-is.

        Raises:
            json.JSONDecodeError: if ``raw`` is a string that is not valid JSON.
        """
        return json.loads(raw) if isinstance(raw, str) else raw

    async def create_conversation(self, user_id: int, title: str = "智能问数对话") -> int:
        """Create a new conversation row and seed its in-memory context.

        Args:
            user_id: Owner of the conversation.
            title: Conversation title.

        Returns:
            The id of the newly created conversation.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            async with self._session_scope() as session:
                # One timestamp so created_at == updated_at for a fresh row.
                now = datetime.utcnow()
                conversation = Conversation(
                    user_id=user_id,
                    title=title,
                    created_at=now,
                    updated_at=now,
                )
                session.add(conversation)
                await session.commit()
                await session.refresh(conversation)
                conversation_id = conversation.id
        except Exception as e:
            self._logger.exception("创建对话失败: %s", e)
            raise

        # Initialise the conversation context.
        self.context_cache[conversation_id] = {
            'conversation_id': conversation_id,
            'user_id': user_id,
            'file_list': [],
            'selected_files': [],
            'query_history': [],
            'created_at': now.isoformat(),
        }
        return conversation_id

    async def get_conversation_context(self, conversation_id: int) -> Optional[Dict[str, Any]]:
        """Return the context for a conversation, loading it from the database on a cache miss.

        When rebuilding:
        - user messages become ``query_history`` entries;
        - ``selected_files`` / ``file_list`` come from the metadata of assistant
          messages, so the latest assistant message carrying them wins.

        Args:
            conversation_id: Conversation id.

        Returns:
            The (cached, mutable) context dict. Returns ``None`` if the
            conversation does not exist or loading failed (the failure is logged).
        """
        # Check the cache first.
        if conversation_id in self.context_cache:
            return self.context_cache[conversation_id]

        # Load from the database.
        try:
            async with self._session_scope() as session:
                conversation = await session.scalar(
                    select(Conversation).where(Conversation.id == conversation_id)
                )
                if conversation is None:
                    return None

                # Load the message history in chronological order.
                messages = (
                    await session.scalars(
                        select(Message)
                        .where(Message.conversation_id == conversation_id)
                        .order_by(Message.created_at)
                    )
                ).all()

                # Rebuild the context.
                context = {
                    'conversation_id': conversation_id,
                    'user_id': conversation.user_id,
                    'file_list': [],
                    'selected_files': [],
                    'query_history': [],
                    'created_at': self._to_iso(conversation.created_at),
                }

                for message in messages:
                    if message.role == 'user':
                        context['query_history'].append({
                            'query': message.content,
                            'timestamp': self._to_iso(message.created_at),
                        })
                    elif message.role == 'assistant' and message.metadata:
                        # NOTE(review): SQLAlchemy Declarative reserves the attribute
                        # name "metadata"; confirm the Message model really exposes its
                        # metadata column under this name.
                        try:
                            metadata = self._decode_metadata(message.metadata)
                        except json.JSONDecodeError:
                            metadata = None
                        # Only dict metadata can carry file information.
                        if isinstance(metadata, dict):
                            if 'selected_files' in metadata:
                                context['selected_files'] = metadata['selected_files']
                            if 'file_list' in metadata:
                                context['file_list'] = metadata['file_list']

                # Cache the rebuilt context.
                self.context_cache[conversation_id] = context
                return context
        except Exception as e:
            self._logger.exception("获取对话上下文失败: %s", e)
            return None

    async def update_conversation_context(
        self,
        conversation_id: int,
        file_list: Optional[List[Dict[str, Any]]] = None,
        selected_files: Optional[List[Dict[str, Any]]] = None,
        query: Optional[str] = None,
    ) -> bool:
        """Update the in-memory context of a conversation; ``None`` arguments are left unchanged.

        Args:
            conversation_id: Conversation id.
            file_list: Replacement file list.
            selected_files: Replacement list of selected files.
            query: A user query to append to ``query_history`` (timestamped now, UTC).

        Returns:
            ``True`` on success. ``False`` if the context could not be loaded
            or the update failed.
        """
        try:
            # Fetch (cache or database) the context to update.
            context = await self.get_conversation_context(conversation_id)
            if not context:
                return False

            if file_list is not None:
                context['file_list'] = file_list

            if selected_files is not None:
                context['selected_files'] = selected_files

            if query is not None:
                context['query_history'].append({
                    'query': query,
                    'timestamp': datetime.utcnow().isoformat(),
                })

            self.context_cache[conversation_id] = context
            return True
        except Exception as e:
            self._logger.exception("更新对话上下文失败: %s", e)
            return False

    async def save_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Persist a message and bump the conversation's ``updated_at`` in one transaction.

        Args:
            conversation_id: Conversation id.
            role: Message role ('user' / 'assistant').
            content: Message text.
            metadata: Optional metadata, stored as a JSON string. An empty dict
                is stored as NULL.

        Returns:
            ``True`` if saved. ``False`` on any error (the error is logged).
        """
        try:
            async with self._session_scope() as session:
                now = datetime.utcnow()
                session.add(Message(
                    conversation_id=conversation_id,
                    role=role,
                    content=content,
                    metadata=json.dumps(metadata) if metadata else None,
                    created_at=now,
                ))

                # Update the conversation's last-updated time in the same commit,
                # so the message and the timestamp are saved (or rolled back) together.
                conversation = await session.scalar(
                    select(Conversation).where(Conversation.id == conversation_id)
                )
                if conversation is not None:
                    conversation.updated_at = now

                await session.commit()
        except Exception as e:
            self._logger.exception("保存消息失败: %s", e)
            return False
        return True

    async def reset_conversation_context(self, conversation_id: int) -> bool:
        """Clear the files and query history of a cached context, keeping its identity fields.

        Only the in-memory cache is affected. An uncached conversation is a
        no-op that still returns ``True``.

        Args:
            conversation_id: Conversation id.

        Returns:
            ``True`` on success, ``False`` on error.
        """
        try:
            if conversation_id in self.context_cache:
                # Keep basic info; clear files and query history.
                self.context_cache[conversation_id].update({
                    'file_list': [],
                    'selected_files': [],
                    'query_history': [],
                })
            return True
        except Exception as e:
            self._logger.exception("重置对话上下文失败: %s", e)
            return False

    async def get_conversation_history(self, conversation_id: int) -> List[Dict[str, Any]]:
        """Return all messages of a conversation in chronological order.

        Args:
            conversation_id: Conversation id.

        Returns:
            A list of dicts with ``id``, ``role``, ``content``, ``timestamp`` and,
            when decodable, ``metadata``. Returns an empty list on error (logged).
        """
        try:
            async with self._session_scope() as session:
                messages = (
                    await session.scalars(
                        select(Message)
                        .where(Message.conversation_id == conversation_id)
                        .order_by(Message.created_at)
                    )
                ).all()

                history = []
                for message in messages:
                    msg_data = {
                        'id': message.id,
                        'role': message.role,
                        'content': message.content,
                        'timestamp': self._to_iso(message.created_at),
                    }
                    if message.metadata:
                        try:
                            msg_data['metadata'] = self._decode_metadata(message.metadata)
                        except (json.JSONDecodeError, TypeError):
                            # Undecodable metadata is omitted rather than failing the whole history.
                            pass
                    history.append(msg_data)
                return history
        except Exception as e:
            self._logger.exception("获取对话历史失败: %s", e)
            return []

    def clear_cache(self, conversation_id: Optional[int] = None) -> None:
        """Drop cached contexts.

        Args:
            conversation_id: Drop only this conversation's context. When
                ``None``, clear every cached context. An explicit id of ``0`` is
                treated as an id, not as "clear all".
        """
        if conversation_id is not None:
            self.context_cache.pop(conversation_id, None)
        else:
            self.context_cache.clear()

# Shared module-level instance: every importer reuses this one service and its in-memory context cache.
conversation_context_service: ConversationContextService = ConversationContextService()