hxf/backend/th_agenter/services/redis_memory.py

239 lines
11 KiB
Python

"""Redis-based conversation memory service for LangChain with local fallback."""
import json
import redis
import uuid
import os
from typing import List, Dict, Any, Optional
from datetime import timedelta
from ..core.config import settings
from ..utils.logger import get_logger
# Named logger used by every method of RedisMemoryService in this module.
_LOGGER_NAME = "redis_memory_service"
logger = get_logger(_LOGGER_NAME)
class RedisMemoryService:
    """Conversation memory stored in Redis and mirrored to local JSON files.

    Each saved message is pushed onto a per-conversation Redis list (when a
    Redis connection was established in ``__init__``). It is also written to
    ``<history_dir>/<conversation_id>.json``. Reads try Redis first. They fall
    back to the local file when Redis is unavailable, raises, or returns no
    messages for the conversation.
    """

    # Maximum number of messages kept per conversation. Applied to both the
    # Redis list and the local JSON file so the two stores stay comparable.
    MAX_HISTORY_MESSAGES = 100

    def __init__(self) -> None:
        """Connect to Redis if configured and reachable, and create the history directory.

        Any Redis error during setup is logged and leaves ``redis_client`` as
        ``None``; the service then works from local files only. A failure to
        create the local history directory is not caught.
        """
        # Defined unconditionally so the attribute exists even without Redis.
        # Lifetime of a conversation key, applied when the key is first created.
        self.default_ttl = timedelta(days=30)
        self.redis_client = None
        try:
            redis_config = getattr(settings, 'redis', None)
            if redis_config:
                logger.debug(f"Attempting to connect to Redis at {redis_config.host}:{redis_config.port} db={redis_config.db}")
                client = redis.Redis(
                    host=redis_config.host,
                    port=redis_config.port,
                    db=redis_config.db,
                    password=redis_config.password,
                    decode_responses=True,
                )
                # Verify the server actually answers before relying on the client.
                client.ping()
                self.redis_client = client
                logger.info(f"Successfully connected to Redis at {redis_config.host}:{redis_config.port} db={redis_config.db}")
            else:
                logger.warning("Redis settings not found")
        except Exception as e:
            logger.warning(f"Redis not available: {str(e)}. Conversation memory will be stored in local files only.")
            logger.debug(f"Redis connection error details: {repr(e)}")
            self.redis_client = None

        # Build the path from this module's location so it does not depend on the
        # current working directory.
        self.history_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "conversation_history")
        os.makedirs(self.history_dir, exist_ok=True)
        logger.info(f"Conversation history directory: {self.history_dir} (exists: {os.path.exists(self.history_dir)})")
        logger.info(f"Current working directory: {os.getcwd()}")

    def _get_conversation_key(self, conversation_id: str) -> str:
        """Return the Redis list key holding the conversation's messages."""
        return f"conversation:{conversation_id}:messages"

    def generate_conversation_id(self) -> str:
        """Return a new random (UUID4) conversation ID."""
        return str(uuid.uuid4())

    def save_message(self, conversation_id: str, role: str, content: str, message_metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Persist one message to Redis (if available) and to the local JSON file.

        Args:
            conversation_id: Conversation the message belongs to.
            role: Message role, stored verbatim.
            content: Message text, stored verbatim.
            message_metadata: Optional extra data; stored as ``{}`` when omitted.

        Returns:
            True if the message was stored in at least one backend.
        """
        message = {
            "role": role,
            "content": content,
            "metadata": message_metadata or {},
        }
        saved_to_redis = self._save_to_redis(conversation_id, message)
        # The local file is always written, as a backup of the Redis copy.
        saved_locally = self._save_to_local_file(conversation_id, message)
        return saved_locally or saved_to_redis

    def _save_to_redis(self, conversation_id: str, message: Dict[str, Any]) -> bool:
        """Push *message* onto the conversation's Redis list; return True on success."""
        if not self.redis_client:
            return False
        try:
            key = self._get_conversation_key(conversation_id)
            self.redis_client.lpush(key, json.dumps(message))
            # LPUSH keeps the newest message at index 0, so keeping 0..MAX-1 retains
            # the latest messages, matching the local-file cap.
            self.redis_client.ltrim(key, 0, self.MAX_HISTORY_MESSAGES - 1)
            # TTL reports -1 for a key without expiry and -2 for a missing key. Both are
            # truthy, so test for a negative value rather than for falsiness.
            if self.redis_client.ttl(key) < 0:
                self.redis_client.expire(key, self.default_ttl)
            logger.debug(f"Saved message to Redis for conversation {conversation_id}")
            return True
        except Exception as e:
            logger.warning(f"Redis save failed: {str(e)}")
            logger.debug(f"Redis save error details: {repr(e)}")
            return False

    def _save_to_local_file(self, conversation_id: str, message: Dict[str, Any]) -> bool:
        """Append *message* to the conversation's JSON file; return True on success."""
        try:
            logger.debug(f"Attempting to save message to local file for conversation {conversation_id}")
            all_messages = self._get_local_history(conversation_id)
            logger.debug(f"Current messages in local history: {len(all_messages)}")
            all_messages.append(message)
            logger.debug(f"Added new message, total messages: {len(all_messages)}")
            if len(all_messages) > self.MAX_HISTORY_MESSAGES:
                all_messages = all_messages[-self.MAX_HISTORY_MESSAGES:]
                logger.debug(f"History truncated to {self.MAX_HISTORY_MESSAGES} messages")

            history_file = self._get_local_history_file(conversation_id)
            logger.debug(f"Saving to file: {history_file}")
            # Serialize before touching the file: an unserializable message must not
            # leave a truncated history behind.
            data = json.dumps(all_messages, ensure_ascii=False, indent=2).encode("utf-8")
            os.makedirs(os.path.dirname(history_file), exist_ok=True)
            self._atomic_write(history_file, data)
            logger.debug(f"✓ Saved message to local file for conversation {conversation_id}, file size: {len(data)} bytes")
            return True
        except Exception as e:
            import traceback
            logger.error(f"✗ Local file save failed: {str(e)}")
            logger.error(f"✗ Error type: {type(e).__name__}")
            logger.error(f"✗ Traceback: {traceback.format_exc()}")
            return False

    @staticmethod
    def _atomic_write(path: str, data: bytes) -> None:
        """Write *data* to a temp file next to *path*, then swap it in with ``os.replace``.

        This replaces the existing file in one step, instead of truncating it and
        rewriting it in place.
        """
        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(data)
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave stray temp files behind when the write or rename fails.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def get_conversation_history(self, conversation_id: str, limit: int = 20) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the most recent messages, oldest first.

        Redis is consulted first. The local JSON file is used when Redis is
        unavailable, fails, or holds no messages for the conversation.

        Args:
            conversation_id: Conversation whose messages to fetch.
            limit: Maximum number of messages to return; values <= 0 yield ``[]``.

        Returns:
            A list of message dicts; ``[]`` if nothing could be read.
        """
        # Without this guard, limit == 0 would make both LRANGE 0 -1 and
        # all_messages[-0:] return the entire history.
        if limit <= 0:
            return []

        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                # The list is stored newest-first (LPUSH), so the head holds the latest messages.
                messages_json = self.redis_client.lrange(key, 0, limit - 1)
                messages = [json.loads(msg) for msg in reversed(messages_json)]
                if messages:
                    logger.debug(f"Retrieved {len(messages)} messages from Redis for conversation {conversation_id}")
                    return messages
            except Exception as e:
                logger.warning(f"Redis history retrieval failed: {str(e)}")
                logger.debug(f"Redis history error details: {repr(e)}")

        try:
            all_messages = self._get_local_history(conversation_id)
            result = all_messages[-limit:]
            logger.debug(f"Retrieved {len(result)} messages from local file for conversation {conversation_id}")
            return result
        except Exception as e:
            logger.error(f"Local file history retrieval failed: {str(e)}")
            logger.debug(f"Local file history error details: {repr(e)}")
            return []

    def _get_local_history_file(self, conversation_id: str) -> str:
        """Return the JSON file path for *conversation_id* directly inside ``history_dir``.

        Raises:
            ValueError: if the ID would place the file outside ``history_dir``
                (path separators, ``..`` components, absolute paths). This prevents
                path traversal through the ID.
        """
        file_path = os.path.join(self.history_dir, f"{conversation_id}.json")
        parent = os.path.realpath(os.path.dirname(file_path))
        if parent != os.path.realpath(self.history_dir):
            raise ValueError(f"Invalid conversation_id for local storage: {conversation_id!r}")
        logger.debug(f"Local history file path: {file_path} (directory exists: {os.path.exists(self.history_dir)})")
        return file_path

    def _get_local_history(self, conversation_id: str) -> List[Dict[str, Any]]:
        """Return all messages from the conversation's JSON file.

        Returns ``[]`` if the file is missing, unreadable, or does not hold a
        JSON list. Propagates ``ValueError`` for an invalid ``conversation_id``.
        """
        history_file = self._get_local_history_file(conversation_id)
        try:
            with open(history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return []
        except Exception as e:
            logger.error(f"Failed to read local history file {history_file}: {str(e)}")
            return []
        if not isinstance(data, list):
            logger.error(f"Unexpected content in local history file {history_file}: expected a JSON list")
            return []
        return data

    def clear_conversation_history(self, conversation_id: str) -> bool:
        """Delete the conversation from Redis and remove its local file.

        Returns:
            True if the Redis delete succeeded or a local file was removed.
        """
        cleared = False
        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                self.redis_client.delete(key)
                logger.debug(f"Cleared conversation history from Redis for conversation {conversation_id}")
                cleared = True
            except Exception as e:
                logger.error(f"Failed to clear conversation history from Redis: {str(e)}")
                logger.debug(f"Redis clear error details: {repr(e)}")

        try:
            history_file = self._get_local_history_file(conversation_id)
            os.remove(history_file)
            logger.debug(f"Cleared conversation history from local file for conversation {conversation_id}")
            cleared = True
        except FileNotFoundError:
            # Nothing stored locally; not an error.
            pass
        except Exception as e:
            logger.error(f"Failed to clear conversation history from local file: {str(e)}")
            logger.debug(f"Local file clear error details: {repr(e)}")
        return cleared

    def get_message_count(self, conversation_id: str) -> int:
        """Return the number of stored messages for the conversation.

        Uses the Redis list length when it is positive. Otherwise it counts the
        messages in the local file, and returns 0 on error.
        """
        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                count = self.redis_client.llen(key)
                if count > 0:
                    return count
            except Exception as e:
                logger.warning(f"Failed to get message count from Redis: {str(e)}")
                logger.debug(f"Redis message count error details: {repr(e)}")

        try:
            return len(self._get_local_history(conversation_id))
        except Exception as e:
            logger.error(f"Failed to get message count from local file: {str(e)}")
            logger.debug(f"Local file message count error details: {repr(e)}")
            return 0
# Module-level shared instance. Creating it here runs RedisMemoryService.__init__
# at import time, which attempts a Redis connection and creates the local
# conversation history directory.
redis_memory_service: RedisMemoryService = RedisMemoryService()