hyf-backend/th_agenter/services/conversation_context.py

324 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from th_agenter.db.database import get_session
from th_agenter.models.conversation import Conversation
from th_agenter.models.message import Message
class ConversationContextService:
    """Manage conversation history and per-conversation context for smart data Q&A.

    Context dicts are cached in memory in ``context_cache``, keyed by
    conversation ID. On a cache miss they are rebuilt from the stored
    ``Conversation`` and ``Message`` rows. Every context dict has the keys
    ``conversation_id``, ``user_id``, ``file_list``, ``selected_files``,
    ``query_history`` and ``created_at`` (ISO-8601 string, or None).

    NOTE(review): the cache is per-instance and only shrinks via
    ``clear_cache``. Changes made by ``update_conversation_context`` and
    ``reset_conversation_context`` live in the cache only and are not
    written to the database.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        # In-memory cache: conversation_id -> context dict.
        self.context_cache: Dict[int, Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    @asynccontextmanager
    async def _session_scope():
        """Yield exactly one session from ``get_session()`` and always release it.

        This replaces the ``async for session in get_session(): ... break``
        pattern. In that pattern, ``break`` inside ``finally`` discarded both
        return values and exceptions, and ``session.close()`` was never awaited.
        ``__anext__`` is called directly, instead of the 3.10+ ``anext``
        builtin, to stay Python 3.9 compatible.

        NOTE(review): assumes ``get_session()`` is an async generator (it was
        consumed with ``async for`` before), so ``aclose()`` is available.

        Raises:
            RuntimeError: if ``get_session()`` yields no session.
        """
        session_gen = get_session()
        try:
            session = await session_gen.__anext__()
        except StopAsyncIteration:
            raise RuntimeError("get_session() did not yield a session") from None
        try:
            yield session
        finally:
            try:
                await session.close()
            finally:
                # Resume the generator so any cleanup it defines runs now,
                # instead of whenever the generator is garbage-collected.
                await session_gen.aclose()

    @staticmethod
    def _isoformat(value: Optional[datetime]) -> Optional[str]:
        """Return ``value.isoformat()``, or None when ``value`` is None."""
        return value.isoformat() if value is not None else None

    @staticmethod
    def _new_context(
        conversation_id: int, user_id: int, created_at: Optional[datetime]
    ) -> Dict[str, Any]:
        """Build an empty context dict for one conversation."""
        return {
            'conversation_id': conversation_id,
            'user_id': user_id,
            'file_list': [],
            'selected_files': [],
            'query_history': [],
            'created_at': created_at.isoformat() if created_at is not None else None,
        }

    @staticmethod
    def _parse_metadata(raw: Any) -> Optional[Any]:
        """Decode a message's stored metadata.

        Strings are parsed as JSON. Any other non-empty value is returned
        unchanged, on the assumption that it is already decoded. Returns None
        for empty values and for strings that are not valid JSON.
        """
        if not raw:
            return None
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except json.JSONDecodeError:
                return None
        return raw

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create_conversation(self, user_id: int, title: str = "智能问数对话") -> int:
        """Create a new conversation row and seed its cached context.

        Args:
            user_id: ID of the owning user.
            title: Conversation title.

        Returns:
            The ID of the newly created conversation.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        # Naive UTC timestamps, as used throughout this service.
        now = datetime.utcnow()
        try:
            async with self._session_scope() as session:
                conversation = Conversation(
                    user_id=user_id,
                    title=title,
                    created_at=now,
                    updated_at=now,
                )
                session.add(conversation)
                await session.commit()
                # Reload server-generated fields (the primary key) after commit.
                await session.refresh(conversation)
                # Read the ID while the session is still open.
                conversation_id = conversation.id
        except Exception as e:
            self._logger.exception("创建对话失败: %s", e)
            raise

        self.context_cache[conversation_id] = self._new_context(conversation_id, user_id, now)
        return conversation_id

    async def get_conversation_context(self, conversation_id: int) -> Optional[Dict[str, Any]]:
        """Return the context for a conversation, loading it from the DB on a cache miss.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The (cached) context dict, or None if the conversation does not
            exist or loading fails (the failure is logged).
        """
        cached = self.context_cache.get(conversation_id)
        if cached is not None:
            return cached

        try:
            async with self._session_scope() as session:
                conversation = await session.scalar(
                    select(Conversation).where(Conversation.id == conversation_id)
                )
                if conversation is None:
                    return None

                messages = await session.scalars(
                    select(Message)
                    .where(Message.conversation_id == conversation_id)
                    .order_by(Message.created_at)
                )

                context = self._new_context(
                    conversation_id, conversation.user_id, conversation.created_at
                )
                for message in messages:
                    if message.role == 'user':
                        context['query_history'].append({
                            'query': message.content,
                            'timestamp': self._isoformat(message.created_at),
                        })
                    elif message.role == 'assistant':
                        # File info is taken from assistant-message metadata.
                        # Messages are ordered by created_at, so later values
                        # overwrite earlier ones.
                        metadata = self._parse_metadata(message.metadata)
                        if isinstance(metadata, dict):
                            if 'selected_files' in metadata:
                                context['selected_files'] = metadata['selected_files']
                            if 'file_list' in metadata:
                                context['file_list'] = metadata['file_list']
        except Exception as e:
            self._logger.exception("获取对话上下文失败: %s", e)
            return None

        self.context_cache[conversation_id] = context
        return context

    async def update_conversation_context(
        self,
        conversation_id: int,
        file_list: Optional[List[Dict[str, Any]]] = None,
        selected_files: Optional[List[Dict[str, Any]]] = None,
        query: Optional[str] = None,
    ) -> bool:
        """Update the cached context of a conversation.

        Only arguments that are not None are applied. ``query`` is appended to
        the query history with the current UTC timestamp. The change is made
        in the in-memory cache only.

        Args:
            conversation_id: Conversation ID.
            file_list: Replacement file list.
            selected_files: Replacement selected-file list.
            query: A user query to append to the history.

        Returns:
            True on success; False if the context could not be found/loaded
            or the update failed (the failure is logged).
        """
        try:
            context = await self.get_conversation_context(conversation_id)
            if not context:
                return False
            if file_list is not None:
                context['file_list'] = file_list
            if selected_files is not None:
                context['selected_files'] = selected_files
            if query is not None:
                context['query_history'].append({
                    'query': query,
                    'timestamp': datetime.utcnow().isoformat(),
                })
            self.context_cache[conversation_id] = context
            return True
        except Exception as e:
            self._logger.exception("更新对话上下文失败: %s", e)
            return False

    async def save_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Persist a message and bump the conversation's ``updated_at``.

        Both writes are committed together, so a failure leaves neither in
        place.

        Args:
            conversation_id: Conversation ID.
            role: Message role (user/assistant).
            content: Message text.
            metadata: Optional metadata, stored as a JSON string.

        Returns:
            True on success, False on failure (the failure is logged).
        """
        try:
            async with self._session_scope() as session:
                now = datetime.utcnow()
                session.add(Message(
                    conversation_id=conversation_id,
                    role=role,
                    content=content,
                    # NOTE(review): SQLAlchemy declarative models reserve the
                    # attribute name ``metadata``; confirm Message really maps
                    # a column under this keyword.
                    metadata=json.dumps(metadata) if metadata else None,
                    created_at=now,
                ))
                conversation = await session.scalar(
                    select(Conversation).where(Conversation.id == conversation_id)
                )
                if conversation is not None:
                    conversation.updated_at = now
                await session.commit()
            return True
        except Exception as e:
            self._logger.exception("保存消息失败: %s", e)
            return False

    async def reset_conversation_context(self, conversation_id: int) -> bool:
        """Clear the file lists and query history of a cached context.

        The conversation ID, user ID and creation time are kept.

        NOTE(review): only a cached context is reset. If the conversation is
        not cached, the next ``get_conversation_context`` call rebuilds the
        full history from the database.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Always True.
        """
        context = self.context_cache.get(conversation_id)
        if context is not None:
            context.update({
                'file_list': [],
                'selected_files': [],
                'query_history': [],
            })
        return True

    async def get_conversation_history(self, conversation_id: int) -> List[Dict[str, Any]]:
        """Return all messages of a conversation, oldest first.

        Each entry has ``id``, ``role``, ``content`` and ``timestamp``, plus
        ``metadata`` when the stored metadata could be decoded.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The message list, or an empty list on failure (the failure is logged).
        """
        try:
            async with self._session_scope() as session:
                messages = await session.scalars(
                    select(Message)
                    .where(Message.conversation_id == conversation_id)
                    .order_by(Message.created_at)
                )
                history = []
                for message in messages:
                    entry = {
                        'id': message.id,
                        'role': message.role,
                        'content': message.content,
                        'timestamp': self._isoformat(message.created_at),
                    }
                    metadata = self._parse_metadata(message.metadata)
                    if metadata is not None:
                        entry['metadata'] = metadata
                    history.append(entry)
            return history
        except Exception as e:
            self._logger.exception("获取对话历史失败: %s", e)
            return []

    def clear_cache(self, conversation_id: Optional[int] = None):
        """Evict cached context.

        Args:
            conversation_id: Conversation to evict; when None, the whole cache
                is cleared. A falsy ID such as 0 evicts only that entry.
        """
        if conversation_id is None:
            self.context_cache.clear()
        else:
            self.context_cache.pop(conversation_id, None)
# Module-level shared instance of the service. Every importer of this module gets
# the same object, and therefore the same in-memory context cache.
conversation_context_service: ConversationContextService = ConversationContextService()