# Source: hxf/backend/th_agenter/services/conversation_context.py
# (web view metadata: 310 lines, 9.7 KiB, Python; blame timestamp 2025-12-04 14:48:38 +08:00)
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Dict, Any, List, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from th_agenter.db.database import get_session
from th_agenter.models.conversation import Conversation
from th_agenter.models.message import Message
logger = logging.getLogger(__name__)


class ConversationContextService:
    """
    Conversation context management service.

    Manages the message history and per-conversation context (file list,
    selected files, query history) for the smart data-query feature.
    Contexts are cached in process memory keyed by conversation id and are
    rebuilt from the database on a cache miss.
    """

    def __init__(self):
        # In-memory cache: conversation_id -> context dict.
        self.context_cache: Dict[int, Dict[str, Any]] = {}

    @asynccontextmanager
    async def _session_scope(self):
        """Yield a session obtained from ``get_session()`` and always release it.

        Acquiring the session happens before the ``try`` so a failure there
        propagates as-is (no unbound-variable error in cleanup). The callers
        ``await`` commit/refresh, so the session is assumed to be an
        ``AsyncSession`` whose ``close()`` must itself be awaited.
        """
        session_gen = get_session()
        session = await anext(session_gen)
        try:
            yield session
        finally:
            await session.close()
            # Finalize the dependency generator now so its own post-yield
            # cleanup runs, instead of leaving it suspended for the GC.
            # (Assumes get_session is an async generator function.)
            await session_gen.aclose()

    @staticmethod
    def _parse_metadata(raw: Any) -> Any:
        """Decode a message's stored metadata.

        Returns ``None`` for empty values or strings that are not valid JSON,
        the decoded value for JSON strings, and the value unchanged otherwise.
        """
        if not raw:
            return None
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return None
        return raw

    async def create_conversation(self, user_id: int, title: str = "智能问数对话") -> int:
        """
        Create a new conversation and seed its in-memory context.

        Args:
            user_id: ID of the owning user.
            title: Conversation title.

        Returns:
            ID of the newly created conversation.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        now = datetime.utcnow()
        try:
            async with self._session_scope() as session:
                conversation = Conversation(
                    user_id=user_id,
                    title=title,
                    created_at=now,
                    updated_at=now,
                )
                session.add(conversation)
                await session.commit()
                # Reload the row so its attributes (including the id) are populated.
                await session.refresh(conversation)
                conversation_id = conversation.id
        except Exception as e:
            logger.exception("创建对话失败: %s", e)
            raise

        self.context_cache[conversation_id] = {
            'conversation_id': conversation_id,
            'user_id': user_id,
            'file_list': [],
            'selected_files': [],
            'query_history': [],
            # Same timestamp as the persisted row's created_at.
            'created_at': now.isoformat(),
        }
        return conversation_id

    async def get_conversation_context(self, conversation_id: int) -> Optional[Dict[str, Any]]:
        """
        Return the context of a conversation, loading it from the database on a cache miss.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The (cached, shared) context dict, or ``None`` if the conversation
            does not exist or loading failed.
        """
        if conversation_id in self.context_cache:
            return self.context_cache[conversation_id]

        try:
            async with self._session_scope() as session:
                conversation = await session.get(Conversation, conversation_id)
                if conversation is None:
                    return None

                result = await session.execute(
                    select(Message)
                    .where(Message.conversation_id == conversation_id)
                    .order_by(Message.created_at)
                )
                messages = result.scalars().all()

                context = {
                    'conversation_id': conversation_id,
                    'user_id': conversation.user_id,
                    'file_list': [],
                    'selected_files': [],
                    'query_history': [],
                    'created_at': conversation.created_at.isoformat(),
                }
                # Rebuild query history from user messages and file state from
                # assistant-message metadata; later messages override earlier ones.
                for message in messages:
                    if message.role == 'user':
                        context['query_history'].append({
                            'query': message.content,
                            'timestamp': message.created_at.isoformat(),
                        })
                    elif message.role == 'assistant':
                        # NOTE(review): `metadata` is reserved on SQLAlchemy
                        # declarative classes — confirm the Message model
                        # really exposes the stored JSON under this name.
                        metadata = self._parse_metadata(message.metadata)
                        if isinstance(metadata, dict):
                            if 'selected_files' in metadata:
                                context['selected_files'] = metadata['selected_files']
                            if 'file_list' in metadata:
                                context['file_list'] = metadata['file_list']
        except Exception as e:
            logger.exception("获取对话上下文失败: %s", e)
            return None

        self.context_cache[conversation_id] = context
        return context

    async def update_conversation_context(
        self,
        conversation_id: int,
        file_list: Optional[List[Dict[str, Any]]] = None,
        selected_files: Optional[List[Dict[str, Any]]] = None,
        query: Optional[str] = None
    ) -> bool:
        """
        Update the cached context of a conversation (memory only, not persisted).

        Args:
            conversation_id: Conversation ID.
            file_list: Replacement file list; ``None`` leaves it unchanged.
            selected_files: Replacement selected files; ``None`` leaves it unchanged.
            query: User query to append to the query history.

        Returns:
            True on success, False if the context could not be found or updated.
        """
        try:
            context = await self.get_conversation_context(conversation_id)
            if not context:
                return False

            if file_list is not None:
                context['file_list'] = file_list
            if selected_files is not None:
                context['selected_files'] = selected_files
            if query is not None:
                context['query_history'].append({
                    'query': query,
                    'timestamp': datetime.utcnow().isoformat(),
                })

            self.context_cache[conversation_id] = context
            return True
        except Exception as e:
            logger.exception("更新对话上下文失败: %s", e)
            return False

    async def save_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Persist a message and bump the conversation's ``updated_at`` in one transaction.

        Args:
            conversation_id: Conversation ID.
            role: Message role (user/assistant).
            content: Message content.
            metadata: Optional metadata, stored JSON-encoded.

        Returns:
            True if the message was saved, False otherwise.
        """
        try:
            async with self._session_scope() as session:
                now = datetime.utcnow()
                session.add(Message(
                    conversation_id=conversation_id,
                    role=role,
                    content=content,
                    metadata=json.dumps(metadata) if metadata else None,
                    created_at=now,
                ))
                conversation = await session.get(Conversation, conversation_id)
                if conversation is not None:
                    conversation.updated_at = now
                # Single commit: message insert and timestamp update succeed or fail together.
                await session.commit()
            return True
        except Exception as e:
            logger.exception("保存消息失败: %s", e)
            return False

    async def reset_conversation_context(self, conversation_id: int) -> bool:
        """
        Clear the file lists and query history of a cached context, keeping its basic info.

        Args:
            conversation_id: Conversation ID.

        Returns:
            True (also when the conversation is not cached), False on error.
        """
        try:
            if conversation_id in self.context_cache:
                self.context_cache[conversation_id].update({
                    'file_list': [],
                    'selected_files': [],
                    'query_history': [],
                })
            return True
        except Exception as e:
            logger.exception("重置对话上下文失败: %s", e)
            return False

    async def get_conversation_history(self, conversation_id: int) -> List[Dict[str, Any]]:
        """
        Return the messages of a conversation in creation order.

        Args:
            conversation_id: Conversation ID.

        Returns:
            List of message dicts (id, role, content, timestamp and, when
            decodable, metadata); an empty list on error.
        """
        try:
            async with self._session_scope() as session:
                result = await session.execute(
                    select(Message)
                    .where(Message.conversation_id == conversation_id)
                    .order_by(Message.created_at)
                )
                history = []
                for message in result.scalars():
                    msg_data = {
                        'id': message.id,
                        'role': message.role,
                        'content': message.content,
                        'timestamp': message.created_at.isoformat(),
                    }
                    metadata = self._parse_metadata(message.metadata)
                    if metadata is not None:
                        msg_data['metadata'] = metadata
                    history.append(msg_data)
        except Exception as e:
            logger.exception("获取对话历史失败: %s", e)
            return []
        return history

    def clear_cache(self, conversation_id: Optional[int] = None) -> None:
        """
        Clear cached contexts.

        Args:
            conversation_id: Evict only this conversation; when ``None``,
                clear the whole cache. (An id of 0 evicts only id 0.)
        """
        if conversation_id is not None:
            self.context_cache.pop(conversation_id, None)
        else:
            self.context_cache.clear()
# Module-level service instance; every importer shares this object and its cache.
conversation_context_service: ConversationContextService = ConversationContextService()