239 lines
11 KiB
Python
239 lines
11 KiB
Python
"""Redis-based conversation memory service for LangChain with local fallback."""
|
|
|
|
import json
|
|
import redis
|
|
import uuid
|
|
import os
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import timedelta
|
|
|
|
from ..core.config import settings
|
|
from ..utils.logger import get_logger
|
|
|
|
# Module-level logger; every RedisMemoryService method reports through it.
logger = get_logger("redis_memory_service")
|
|
|
|
|
|
class RedisMemoryService:
    """Conversation memory backed by Redis and mirrored to local JSON files.

    Every message is written to a per-conversation JSON file under
    ``history_dir``. When a Redis connection could be established at
    construction time, the message is also pushed onto a Redis list. Reads
    prefer Redis and fall back to the local file when Redis is unavailable,
    raises, or holds no messages for the conversation.

    Redis layout: ``LPUSH`` places the newest message at index 0, so the
    list is stored newest-first and reversed into chronological order on read.
    """

    # Maximum number of messages retained per conversation, applied to both
    # the local file and the Redis list so the two backends agree.
    MAX_HISTORY_MESSAGES = 100

    def __init__(self):
        """Connect to Redis (if configured and reachable) and prepare local storage.

        Redis problems are never fatal: on any error the service logs a
        warning and continues with ``redis_client = None`` (local files only).
        """
        # Defined unconditionally so the attribute exists even when the Redis
        # branch below is skipped or fails.
        self.default_ttl = timedelta(days=30)  # conversation TTL in Redis
        self.redis_client = None

        try:
            redis_config = getattr(settings, 'redis', None)
            if redis_config:
                logger.debug(f"Attempting to connect to Redis at {redis_config.host}:{redis_config.port} db={redis_config.db}")
                # Use the same connection style as the known-working setup.
                client = redis.Redis(
                    host=redis_config.host,
                    port=redis_config.port,
                    db=redis_config.db,
                    password=redis_config.password,
                    decode_responses=True,
                )
                # Fail fast here rather than on the first real command.
                client.ping()
                self.redis_client = client
                logger.info(f"Successfully connected to Redis at {redis_config.host}:{redis_config.port} db={redis_config.db}")
            else:
                logger.warning("Redis settings not found")
        except Exception as e:
            logger.warning(f"Redis not available: {str(e)}. Conversation memory will be stored in local files only.")
            logger.debug(f"Redis connection error details: {repr(e)}")
            self.redis_client = None

        # Anchor the directory to this file's location (not the CWD) and
        # normalise it so the path-containment check in
        # _get_local_history_file compares clean absolute paths.
        self.history_dir = os.path.abspath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "conversation_history")
        )
        os.makedirs(self.history_dir, exist_ok=True)
        logger.info(f"Conversation history directory: {self.history_dir} (exists: {os.path.exists(self.history_dir)})")
        logger.info(f"Current working directory: {os.getcwd()}")

    def _get_conversation_key(self, conversation_id: str) -> str:
        """Return the Redis list key holding *conversation_id*'s messages."""
        return f"conversation:{conversation_id}:messages"

    def generate_conversation_id(self) -> str:
        """Return a new random (UUID4) conversation ID as a string."""
        return str(uuid.uuid4())

    def save_message(self, conversation_id: str, role: str, content: str, message_metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Save a message to conversation memory (Redis plus local file).

        Args:
            conversation_id: Conversation the message belongs to.
            role: Speaker role stored under the ``"role"`` key.
            content: Message text stored under the ``"content"`` key.
            message_metadata: Optional extra data; stored as ``{}`` when falsy.

        Returns:
            True if the message reached the local file, otherwise whether it
            reached Redis. False only when both backends failed.
        """
        message = {
            "role": role,
            "content": content,
            "metadata": message_metadata or {},
        }
        saved_to_redis = self._save_message_to_redis(conversation_id, message)
        # The local file is always written, as a backup to Redis.
        saved_locally = self._save_message_to_local_file(conversation_id, message)
        return saved_locally or saved_to_redis

    def _save_message_to_redis(self, conversation_id: str, message: Dict[str, Any]) -> bool:
        """Push *message* onto the conversation's Redis list; return success."""
        if not self.redis_client:
            return False
        try:
            key = self._get_conversation_key(conversation_id)
            self.redis_client.lpush(key, json.dumps(message))
            # LPUSH stores newest first, so keeping indexes 0..N-1 retains the
            # newest N messages, the same cap the local file applies.
            self.redis_client.ltrim(key, 0, self.MAX_HISTORY_MESSAGES - 1)
            # TTL is -1 for "no expiry" (and -2 for a missing key); both are
            # truthy, so a plain `not ttl` test would never set the expiry.
            # Older redis-py versions return None instead, which is handled too.
            ttl = self.redis_client.ttl(key)
            if ttl is None or ttl < 0:
                self.redis_client.expire(key, self.default_ttl)
            logger.debug(f"Saved message to Redis for conversation {conversation_id}")
            return True
        except Exception as e:
            logger.warning(f"Redis save failed: {str(e)}")
            logger.debug(f"Redis save error details: {repr(e)}")
            return False

    def _save_message_to_local_file(self, conversation_id: str, message: Dict[str, Any]) -> bool:
        """Append *message* to the conversation's JSON file; return success."""
        try:
            logger.debug(f"Attempting to save message to local file for conversation {conversation_id}")

            all_messages = self._get_local_history(conversation_id)
            logger.debug(f"Current messages in local history: {len(all_messages)}")

            all_messages.append(message)
            logger.debug(f"Added new message, total messages: {len(all_messages)}")

            if len(all_messages) > self.MAX_HISTORY_MESSAGES:
                all_messages = all_messages[-self.MAX_HISTORY_MESSAGES:]
                logger.debug(f"History truncated to {self.MAX_HISTORY_MESSAGES} messages")

            history_file = self._get_local_history_file(conversation_id)
            logger.debug(f"Saving to file: {history_file}")

            # Serialize BEFORE touching the file: opening with 'w' truncates
            # it, so a serialization error mid-dump would otherwise wipe the
            # existing history.
            payload = json.dumps(all_messages, ensure_ascii=False, indent=2)

            os.makedirs(os.path.dirname(history_file), exist_ok=True)

            # Write to a unique temp file, then atomically swap it in, so
            # readers never see a partially written history file.
            tmp_file = f"{history_file}.{uuid.uuid4().hex}.tmp"
            try:
                with open(tmp_file, 'w', encoding='utf-8') as f:
                    f.write(payload)
                os.replace(tmp_file, history_file)
            finally:
                # Only still present if the write or the rename failed.
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)

            file_size = os.path.getsize(history_file)
            logger.debug(f"✓ Saved message to local file for conversation {conversation_id}, file size: {file_size} bytes")
            return True
        except Exception as e:
            logger.error(f"✗ Local file save failed: {str(e)}")
            logger.error(f"✗ Error type: {type(e).__name__}")
            import traceback
            logger.error(f"✗ Traceback: {traceback.format_exc()}")
            return False

    def get_conversation_history(self, conversation_id: str, limit: int = 20) -> List[Dict[str, Any]]:
        """Return up to *limit* most recent messages in chronological order.

        Redis is tried first; the local file is used when Redis is not
        connected, raises, or returns no messages. ``limit=0`` returns the
        whole stored history from either backend. Returns ``[]`` if both fail.
        """
        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                # The list is newest-first, so 0..limit-1 are the latest messages.
                messages_json = self.redis_client.lrange(key, 0, limit - 1)
                # Parse, then reverse into chronological order.
                messages = [json.loads(msg) for msg in messages_json]
                messages.reverse()
                if messages:
                    logger.debug(f"Retrieved {len(messages)} messages from Redis for conversation {conversation_id}")
                    return messages
            except Exception as e:
                logger.warning(f"Redis history retrieval failed: {str(e)}")
                logger.debug(f"Redis history error details: {repr(e)}")

        try:
            all_messages = self._get_local_history(conversation_id)
            # Slicing already copes with histories shorter than `limit`.
            result = all_messages[-limit:]
            logger.debug(f"Retrieved {len(result)} messages from local file for conversation {conversation_id}")
            return result
        except Exception as e:
            logger.error(f"Local file history retrieval failed: {str(e)}")
            logger.debug(f"Local file history error details: {repr(e)}")
            return []

    def _get_local_history_file(self, conversation_id: str) -> str:
        """Return the JSON file path for *conversation_id* inside ``history_dir``.

        Raises:
            ValueError: if the id would resolve to a path outside
                ``history_dir`` (e.g. it contains ``..`` or is absolute).
        """
        base_dir = os.path.abspath(self.history_dir)
        file_path = os.path.abspath(os.path.join(base_dir, f"{conversation_id}.json"))
        # The id comes from callers and is used as a filename; refuse ids that
        # would escape the history directory.
        if os.path.commonpath([base_dir, file_path]) != base_dir:
            raise ValueError(f"Invalid conversation id: {conversation_id!r}")
        logger.debug(f"Local history file path: {file_path} (directory exists: {os.path.exists(self.history_dir)})")
        return file_path

    def _get_local_history(self, conversation_id: str) -> List[Dict[str, Any]]:
        """Return the full stored history from the local file.

        Returns ``[]`` if the file does not exist, cannot be read or parsed,
        or does not contain a JSON list.
        """
        history_file = self._get_local_history_file(conversation_id)
        if not os.path.exists(history_file):
            return []
        try:
            with open(history_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read local history file {history_file}: {str(e)}")
            return []
        # Callers append to the result, so anything but a list is unusable.
        if not isinstance(data, list):
            logger.error(f"Unexpected content in local history file {history_file}: expected a JSON list")
            return []
        return data

    def clear_conversation_history(self, conversation_id: str) -> bool:
        """Delete the conversation from Redis and remove its local file.

        Returns:
            True if the Redis delete command succeeded or a local file was
            removed; False otherwise. Errors are logged, not raised.
        """
        cleared = False

        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                self.redis_client.delete(key)
                logger.debug(f"Cleared conversation history from Redis for conversation {conversation_id}")
                cleared = True
            except Exception as e:
                logger.error(f"Failed to clear conversation history from Redis: {str(e)}")
                logger.debug(f"Redis clear error details: {repr(e)}")

        try:
            history_file = self._get_local_history_file(conversation_id)
            if os.path.exists(history_file):
                os.remove(history_file)
                logger.debug(f"Cleared conversation history from local file for conversation {conversation_id}")
                cleared = True
        except Exception as e:
            logger.error(f"Failed to clear conversation history from local file: {str(e)}")
            logger.debug(f"Local file clear error details: {repr(e)}")

        return cleared

    def get_message_count(self, conversation_id: str) -> int:
        """Return the number of stored messages for the conversation.

        A positive Redis list length is returned directly. Otherwise the
        local file is counted; returns 0 if that also fails.
        """
        if self.redis_client:
            try:
                key = self._get_conversation_key(conversation_id)
                count = self.redis_client.llen(key)
                if count > 0:
                    return count
            except Exception as e:
                logger.warning(f"Failed to get message count from Redis: {str(e)}")
                logger.debug(f"Redis message count error details: {repr(e)}")

        try:
            return len(self._get_local_history(conversation_id))
        except Exception as e:
            logger.error(f"Failed to get message count from local file: {str(e)}")
            logger.debug(f"Local file message count error details: {repr(e)}")
            return 0
|
|
|
|
|
|
# Global instance.
# Created at import time, so importing this module attempts the Redis
# connection and creates the history directory once per process.
redis_memory_service = RedisMemoryService()