"""LangChain-based chat service."""

import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any

from sqlalchemy.orm import Session

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from loguru import logger

from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, OpenAIError, AuthenticationError, RateLimitError

from .conversation import ConversationService


class StreamingCallbackHandler(BaseCallbackHandler):
    """Accumulates tokens emitted by a streaming LLM run.

    Tokens arrive one at a time via on_llm_new_token(); the full reply can
    be reassembled with get_response() and the buffer reset with clear().
    """

    def __init__(self):
        # Token strings in arrival order.
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Buffer one streamed token."""
        self.tokens.append(token)

    def get_response(self) -> str:
        """Return all buffered tokens joined into a single string."""
        return "".join(self.tokens)

    def clear(self):
        """Reset the buffer (rebinds rather than mutating the old list)."""
        self.tokens = []

class LangChainChatService:
    """LangChain-based chat service for AI model integration.

    Persists conversations through a ConversationService and, after
    initialize() has been awaited, talks to the model through a pair of
    LangChain clients (blocking and streaming).
    """

    def __init__(self, session: Session):
        # DB session; also handed to new_agent() during initialize().
        self.session = session
        # Persistence layer for conversations and messages.
        self.conversation_service = ConversationService(session)

    async def initialize(self):
        """Create the blocking and streaming LLM clients and the stream handler.

        Must be awaited after construction and before chat()/chat_stream(),
        which rely on self.llm / self.streaming_llm created here.
        """
        # Deferred import — presumably avoids a circular import at module
        # load time; TODO confirm.
        from ..core.new_agent import new_agent

        # Non-streaming client used by chat().
        self.llm = await new_agent(self.session, streaming=False)
        # NOTE(review): session.desc looks like a progress/diagnostic marker
        # written onto the DB session object — confirm something consumes it.
        self.session.desc = "LangChainChatService初始化 - llm 实例化完毕"

        # Streaming client used by chat_stream().
        self.streaming_llm = await new_agent(self.session, streaming=True)
        self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕"

        # Token collector for streamed responses.
        self.streaming_handler = StreamingCallbackHandler()
        self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕"

def _prepare_langchain_messages(self, conversation, history: List) -> List:
|
||
"""Prepare messages for LangChain format."""
|
||
messages = []
|
||
|
||
# Add system message if conversation has system prompt
|
||
if hasattr(conversation, 'system_prompt') and conversation.system_prompt:
|
||
messages.append(SystemMessage(content=conversation.system_prompt))
|
||
else:
|
||
# Default system message
|
||
messages.append(SystemMessage(
|
||
content="You are a helpful AI assistant. Please provide accurate and helpful responses."
|
||
))
|
||
|
||
# Add conversation history
|
||
for msg in history[:-1]: # Exclude the last message (current user message)
|
||
if msg.role == MessageRole.USER:
|
||
messages.append(HumanMessage(content=msg.content))
|
||
elif msg.role == MessageRole.ASSISTANT:
|
||
messages.append(AIMessage(content=msg.content))
|
||
|
||
# Add current user message
|
||
if history:
|
||
last_msg = history[-1]
|
||
if last_msg.role == MessageRole.USER:
|
||
messages.append(HumanMessage(content=last_msg.content))
|
||
|
||
return messages
|
||
|
||
async def chat(
|
||
self,
|
||
conversation_id: int,
|
||
message: str,
|
||
stream: bool = False,
|
||
temperature: Optional[float] = None,
|
||
max_tokens: Optional[int] = None
|
||
) -> ChatResponse:
|
||
"""Send a message and get AI response using LangChain."""
|
||
logger.info(f"Processing LangChain chat request for conversation {conversation_id}")
|
||
|
||
try:
|
||
# Get conversation details
|
||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||
if not conversation:
|
||
raise ChatServiceError("Conversation not found")
|
||
|
||
# Add user message to database
|
||
user_message = await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=message,
|
||
role=MessageRole.USER
|
||
)
|
||
|
||
# Get conversation history for context
|
||
history = await self.conversation_service.get_conversation_history(
|
||
conversation_id, limit=20
|
||
)
|
||
|
||
# Prepare messages for LangChain
|
||
langchain_messages = self._prepare_langchain_messages(conversation, history)
|
||
|
||
# Update LLM parameters if provided
|
||
llm_to_use = self.llm
|
||
if temperature is not None or max_tokens is not None:
|
||
llm_config = await settings.llm.get_current_config()
|
||
llm_to_use = ChatOpenAI(
|
||
model=llm_config["model"],
|
||
openai_api_key=llm_config["api_key"],
|
||
openai_api_base=llm_config["base_url"],
|
||
temperature=temperature if temperature is not None else float(conversation.temperature),
|
||
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
|
||
streaming=False
|
||
)
|
||
|
||
# Call LangChain LLM
|
||
response = await llm_to_use.ainvoke(langchain_messages)
|
||
|
||
# Extract response content
|
||
assistant_content = response.content
|
||
|
||
# Add assistant message to database
|
||
assistant_message = await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=assistant_content,
|
||
role=MessageRole.ASSISTANT,
|
||
message_metadata={
|
||
"model": llm_to_use.model_name,
|
||
"langchain_version": "0.1.0",
|
||
"provider": "langchain_openai"
|
||
}
|
||
)
|
||
|
||
# Update conversation timestamp
|
||
await self.conversation_service.update_conversation_timestamp(conversation_id)
|
||
|
||
logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
|
||
|
||
return ChatResponse(
|
||
user_message=MessageResponse.from_orm(user_message),
|
||
assistant_message=MessageResponse.from_orm(assistant_message),
|
||
total_tokens=None, # LangChain doesn't provide token count by default
|
||
model_used=llm_to_use.model_name
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to process LangChain chat request for conversation {conversation_id}: {str(e)}", exc_info=True)
|
||
|
||
# Classify error types for better handling
|
||
error_type = type(e).__name__
|
||
error_message = self._format_error_message(e)
|
||
|
||
# Add error message to database
|
||
assistant_message = await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=error_message,
|
||
role=MessageRole.ASSISTANT,
|
||
message_metadata={
|
||
"error": True,
|
||
"error_type": error_type,
|
||
"original_error": str(e),
|
||
"langchain_error": True
|
||
}
|
||
)
|
||
|
||
# Re-raise specific exceptions for proper error handling
|
||
if "rate limit" in str(e).lower():
|
||
raise RateLimitError(str(e))
|
||
elif "api key" in str(e).lower() or "authentication" in str(e).lower():
|
||
raise AuthenticationError(str(e))
|
||
elif "openai" in str(e).lower():
|
||
raise OpenAIError(str(e))
|
||
|
||
return ChatResponse(
|
||
user_message=MessageResponse.from_orm(user_message),
|
||
assistant_message=MessageResponse.from_orm(assistant_message),
|
||
total_tokens=0,
|
||
model_used=self.llm.model_name
|
||
)
|
||
|
||
async def chat_stream(
|
||
self,
|
||
conversation_id: int,
|
||
message: str,
|
||
temperature: Optional[float] = None,
|
||
max_tokens: Optional[int] = None
|
||
) -> AsyncGenerator[str, None]:
|
||
"""Send a message and get streaming AI response using LangChain."""
|
||
logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}")
|
||
|
||
try:
|
||
# Get conversation details
|
||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||
conv = conversation.to_dict()
|
||
if not conversation:
|
||
raise ChatServiceError("Conversation not found")
|
||
|
||
# Add user message to database
|
||
user_message = await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=message,
|
||
role=MessageRole.USER
|
||
)
|
||
|
||
# Get conversation history for context
|
||
history = await self.conversation_service.get_conversation_history(
|
||
conversation_id, limit=20
|
||
)
|
||
|
||
# Prepare messages for LangChain
|
||
langchain_messages = self._prepare_langchain_messages(conv, history)
|
||
# Update streaming LLM parameters if provided
|
||
streaming_llm_to_use = self.streaming_llm
|
||
if temperature is not None or max_tokens is not None:
|
||
llm_config = await settings.llm.get_current_config()
|
||
streaming_llm_to_use = ChatOpenAI(
|
||
model=llm_config["model"],
|
||
openai_api_key=llm_config["api_key"],
|
||
openai_api_base=llm_config["base_url"],
|
||
temperature=temperature if temperature is not None else float(conversation.temperature),
|
||
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
|
||
streaming=True
|
||
)
|
||
# Clear previous streaming handler state
|
||
self.streaming_handler.clear()
|
||
|
||
# Stream response
|
||
full_response = ""
|
||
try:
|
||
async for chunk in streaming_llm_to_use._astream(langchain_messages):
|
||
# Handle different chunk types to avoid KeyError
|
||
chunk_content = None
|
||
if hasattr(chunk, 'content'):
|
||
# For object-like chunks with content attribute
|
||
chunk_content = chunk.content
|
||
elif isinstance(chunk, dict) and 'content' in chunk:
|
||
# For dict-like chunks with content key
|
||
chunk_content = chunk['content']
|
||
elif isinstance(chunk, dict) and 'error' in chunk:
|
||
# Handle error chunks explicitly
|
||
logger.error(f"Error in LLM response: {chunk['error']}")
|
||
yield self._format_error_message(Exception(chunk['error']))
|
||
continue
|
||
|
||
if chunk_content:
|
||
full_response += chunk_content
|
||
yield chunk_content
|
||
except Exception as e:
|
||
logger.error(f"Error in LLM streaming: {e}")
|
||
yield f"{self._format_error_message(e)} >>> {e}"
|
||
# Add complete assistant message to database
|
||
assistant_message = await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=full_response,
|
||
role=MessageRole.ASSISTANT,
|
||
message_metadata={
|
||
"model": streaming_llm_to_use.model_name,
|
||
"langchain_version": "0.1.0",
|
||
"provider": "langchain_openai",
|
||
"streaming": True
|
||
}
|
||
)
|
||
|
||
# Update conversation timestamp
|
||
await self.conversation_service.update_conversation_timestamp(conversation_id)
|
||
logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}")
|
||
|
||
except Exception as e:
|
||
# 安全地格式化异常信息,避免再次引发KeyError
|
||
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
|
||
logger.error(error_info, exc_info=True)
|
||
|
||
# Format error message for user
|
||
error_message = self._format_error_message(e)
|
||
yield error_message
|
||
|
||
# Add error message to database
|
||
await self.conversation_service.add_message(
|
||
conversation_id=conversation_id,
|
||
content=error_message,
|
||
role=MessageRole.ASSISTANT,
|
||
message_metadata={
|
||
"error": True,
|
||
"error_type": type(e).__name__,
|
||
"original_error": str(e),
|
||
"langchain_error": True,
|
||
"streaming": True
|
||
}
|
||
)
|
||
|
||
async def get_available_models(self) -> List[str]:
|
||
"""Get list of available models from LangChain."""
|
||
try:
|
||
# LangChain doesn't have a direct method to list models
|
||
# Return commonly available OpenAI models
|
||
return [
|
||
"gpt-3.5-turbo",
|
||
"gpt-3.5-turbo-16k",
|
||
"gpt-4",
|
||
"gpt-4-turbo-preview",
|
||
"gpt-4o",
|
||
"gpt-4o-mini"
|
||
]
|
||
except Exception as e:
|
||
logger.error(f"Failed to get available models: {str(e)}")
|
||
return ["gpt-3.5-turbo"]
|
||
|
||
    async def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Rebuild both LLM clients with a new model/temperature.

        NOTE(review): initialize() calls new_agent(self.session, streaming=...),
        while this method calls it with model=/temperature= keywords and no
        session — confirm new_agent actually supports both signatures.
        NOTE(review): max_tokens is accepted but never forwarded to new_agent;
        it only appears in the log line below — confirm this is intentional.
        """
        # Deferred import — presumably avoids a circular import; TODO confirm.
        from ..core.new_agent import new_agent

        # Recreate the non-streaming LLM instance.
        self.llm = await new_agent(
            model=model,
            temperature=temperature,
            streaming=False
        )

        # Recreate the streaming LLM instance.
        self.streaming_llm = await new_agent(
            model=model,
            temperature=temperature,
            streaming=True
        )

        logger.info(f"Updated LLM configuration: model={model}, temperature={temperature}, max_tokens={max_tokens}")

def _format_error_message(self, error: Exception) -> str:
|
||
"""Format error message for user display."""
|
||
error_type = type(error).__name__
|
||
error_str = str(error)
|
||
|
||
# Provide user-friendly error messages
|
||
if "rate limit" in error_str.lower():
|
||
return "服务器繁忙,请稍后再试。"
|
||
elif "api key" in error_str.lower() or "authentication" in error_str.lower():
|
||
return f"API认证失败,请检查配置文件。"
|
||
elif "timeout" in error_str.lower():
|
||
return "请求超时,请重试。"
|
||
elif "connection" in error_str.lower():
|
||
return "网络连接错误,请检查网络连接。"
|
||
elif "model" in error_str.lower() and "not found" in error_str.lower():
|
||
return "指定的模型不可用,请选择其他模型。"
|
||
else:
|
||
return f"处理请求时发生错误:{error_str}"
|
||
|
||
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
|
||
"""Retry function with exponential backoff."""
|
||
for attempt in range(max_retries):
|
||
try:
|
||
return await func()
|
||
except Exception as e:
|
||
if attempt == max_retries - 1:
|
||
raise e
|
||
|
||
# Check if error is retryable
|
||
if not self._is_retryable_error(e):
|
||
raise e
|
||
|
||
delay = base_delay * (2 ** attempt)
|
||
logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
|
||
await asyncio.sleep(delay)
|
||
|
||
def _is_retryable_error(self, error: Exception) -> bool:
|
||
"""Check if an error is retryable."""
|
||
error_str = str(error).lower()
|
||
retryable_errors = [
|
||
"timeout",
|
||
"connection",
|
||
"server error",
|
||
"internal error",
|
||
"rate limit"
|
||
]
|
||
return any(err in error_str for err in retryable_errors) |