hyf-backend/th_agenter/services/langchain_chat.py

"""LangChain-based chat service."""
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any
from sqlalchemy.orm import Session
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, OpenAIError, AuthenticationError, RateLimitError
from loguru import logger
from .conversation import ConversationService


class StreamingCallbackHandler(BaseCallbackHandler):
    """Custom callback handler for streaming responses."""

    def __init__(self):
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Handle a new token from the LLM."""
        self.tokens.append(token)

    def get_response(self) -> str:
        """Get the complete response."""
        return "".join(self.tokens)

    def clear(self):
        """Clear the accumulated tokens."""
        self.tokens = []
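
# Usage sketch (illustrative; the handler is instantiated below but not yet
# registered anywhere in this module). LangChain calls on_llm_new_token for
# every streamed token once the handler is passed via the callbacks list:
#   handler = StreamingCallbackHandler()
#   llm = ChatOpenAI(streaming=True, callbacks=[handler])
#   await llm.ainvoke(messages)
#   text = handler.get_response()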


class LangChainChatService:
    """LangChain-based chat service for AI model integration."""

    def __init__(self, session: Session):
        self.session = session
        self.conversation_service = ConversationService(session)

    async def initialize(self):
        from ..core.new_agent import new_agent
        # Initialize LangChain ChatOpenAI
        self.llm = await new_agent(self.session, streaming=False)
        self.session.desc = "LangChainChatService init - llm instantiated"
        # Streaming LLM for stream responses
        self.streaming_llm = await new_agent(self.session, streaming=True)
        self.session.desc = "LangChainChatService init - streaming_llm instantiated"
        self.streaming_handler = StreamingCallbackHandler()
        self.session.desc = "LangChainChatService init - streaming_handler instantiated"

    def _prepare_langchain_messages(self, conversation, history: List) -> List:
        """Prepare messages in LangChain format."""
        messages = []
        # Add system message if the conversation has a system prompt
        if hasattr(conversation, 'system_prompt') and conversation.system_prompt:
            messages.append(SystemMessage(content=conversation.system_prompt))
        else:
            # Default system message
            messages.append(SystemMessage(
                content="You are a helpful AI assistant. Please provide accurate and helpful responses."
            ))
        # Add conversation history, excluding the last entry (the current user message)
        for msg in history[:-1]:
            if msg.role == MessageRole.USER:
                messages.append(HumanMessage(content=msg.content))
            elif msg.role == MessageRole.ASSISTANT:
                messages.append(AIMessage(content=msg.content))
        # Add the current user message
        if history:
            last_msg = history[-1]
            if last_msg.role == MessageRole.USER:
                messages.append(HumanMessage(content=last_msg.content))
        return messages
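
    # Illustrative example: for a stored history of
    #   [user "Hi", assistant "Hello!", user "What is LangChain?"]
    # and no custom system prompt, the method above returns
    #   [SystemMessage(<default prompt>), HumanMessage("Hi"),
    #    AIMessage("Hello!"), HumanMessage("What is LangChain?")]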

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get an AI response using LangChain."""
        logger.info(f"Processing LangChain chat request for conversation {conversation_id}")
        user_message = None
        try:
            # Get conversation details
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")
            # Add user message to database
            user_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )
            # Get conversation history for context
            history = await self.conversation_service.get_conversation_history(
                conversation_id, limit=20
            )
            # Prepare messages for LangChain
            langchain_messages = self._prepare_langchain_messages(conversation, history)
            # Build a one-off LLM if per-request parameters were provided
            llm_to_use = self.llm
            if temperature is not None or max_tokens is not None:
                llm_config = await settings.llm.get_current_config()
                llm_to_use = ChatOpenAI(
                    model=llm_config["model"],
                    openai_api_key=llm_config["api_key"],
                    openai_api_base=llm_config["base_url"],
                    temperature=temperature if temperature is not None else float(conversation.temperature),
                    max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
                    streaming=False
                )
            # Call the LangChain LLM
            response = await llm_to_use.ainvoke(langchain_messages)
            # Extract response content
            assistant_content = response.content
            # Add assistant message to database
            assistant_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=assistant_content,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "model": llm_to_use.model_name,
                    "langchain_version": "0.1.0",
                    "provider": "langchain_openai"
                }
            )
            # Update conversation timestamp
            await self.conversation_service.update_conversation_timestamp(conversation_id)
            logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
            return ChatResponse(
                user_message=MessageResponse.from_orm(user_message),
                assistant_message=MessageResponse.from_orm(assistant_message),
                total_tokens=None,  # LangChain does not expose a token count by default
                model_used=llm_to_use.model_name
            )
        except Exception as e:
            logger.error(f"Failed to process LangChain chat request for conversation {conversation_id}: {str(e)}", exc_info=True)
            # If the failure happened before the user message was stored (e.g.
            # the conversation was not found), there is nothing to build a
            # fallback response from, so propagate the original error
            if user_message is None:
                raise
            # Classify error types for better handling
            error_type = type(e).__name__
            error_message = self._format_error_message(e)
            # Add error message to database
            assistant_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "error": True,
                    "error_type": error_type,
                    "original_error": str(e),
                    "langchain_error": True
                }
            )
            # Re-raise specific exceptions for proper error handling
            error_text = str(e).lower()
            if "rate limit" in error_text:
                raise RateLimitError(str(e))
            elif "api key" in error_text or "authentication" in error_text:
                raise AuthenticationError(str(e))
            elif "openai" in error_text:
                raise OpenAIError(str(e))
            return ChatResponse(
                user_message=MessageResponse.from_orm(user_message),
                assistant_message=MessageResponse.from_orm(assistant_message),
                total_tokens=0,
                model_used=self.llm.model_name
            )

    async def chat_stream(
        self,
        conversation_id: int,
        message: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Send a message and get a streaming AI response using LangChain."""
        logger.info(f"Processing streaming chat request via LangChain for conversation {conversation_id}")
        try:
            # Get conversation details, checking for None before dereferencing
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")
            # Add user message to database
            user_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )
            # Get conversation history for context
            history = await self.conversation_service.get_conversation_history(
                conversation_id, limit=20
            )
            # Prepare messages for LangChain (pass the ORM object, as chat() does,
            # so _prepare_langchain_messages can read system_prompt via hasattr)
            langchain_messages = self._prepare_langchain_messages(conversation, history)
            # Build a one-off streaming LLM if per-request parameters were provided
            streaming_llm_to_use = self.streaming_llm
            if temperature is not None or max_tokens is not None:
                llm_config = await settings.llm.get_current_config()
                streaming_llm_to_use = ChatOpenAI(
                    model=llm_config["model"],
                    openai_api_key=llm_config["api_key"],
                    openai_api_base=llm_config["base_url"],
                    temperature=temperature if temperature is not None else float(conversation.temperature),
                    max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
                    streaming=True
                )
            # Clear previous streaming handler state
            self.streaming_handler.clear()
            # Stream the response via the public astream API (not the private _astream)
            full_response = ""
            try:
                async for chunk in streaming_llm_to_use.astream(langchain_messages):
                    # Handle different chunk shapes defensively to avoid KeyError
                    chunk_content = None
                    if hasattr(chunk, 'content'):
                        # Message-like chunks (e.g. AIMessageChunk) expose .content
                        chunk_content = chunk.content
                    elif isinstance(chunk, dict) and 'content' in chunk:
                        # Dict-like chunks with a content key
                        chunk_content = chunk['content']
                    elif isinstance(chunk, dict) and 'error' in chunk:
                        # Handle error chunks explicitly
                        logger.error(f"Error in LLM response: {chunk['error']}")
                        yield self._format_error_message(Exception(chunk['error']))
                        continue
                    if chunk_content:
                        full_response += chunk_content
                        yield chunk_content
            except Exception as e:
                logger.error(f"Error in LLM streaming: {e}")
                yield f"{self._format_error_message(e)} >>> {e}"
            # Add the complete assistant message to the database
            assistant_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=full_response,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "model": streaming_llm_to_use.model_name,
                    "langchain_version": "0.1.0",
                    "provider": "langchain_openai",
                    "streaming": True
                }
            )
            # Update conversation timestamp
            await self.conversation_service.update_conversation_timestamp(conversation_id)
            logger.info(f"Completed LangChain streaming chat for conversation {conversation_id}")
        except Exception as e:
            # Format the exception message safely to avoid triggering another KeyError
            error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
            logger.error(error_info, exc_info=True)
            # Format error message for the user
            error_message = self._format_error_message(e)
            yield error_message
            # Add error message to database
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "error": True,
                    "error_type": type(e).__name__,
                    "original_error": str(e),
                    "langchain_error": True,
                    "streaming": True
                }
            )
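
    # Endpoint sketch (hypothetical; this module defines no routes, and
    # `router`, `ChatRequest`, and `session` are stand-ins). The generator
    # above can back a streaming HTTP response, e.g. with FastAPI:
    #   @router.post("/conversations/{conversation_id}/chat/stream")
    #   async def stream_endpoint(conversation_id: int, body: ChatRequest):
    #       service = LangChainChatService(session)
    #       await service.initialize()
    #       return StreamingResponse(
    #           service.chat_stream(conversation_id, body.message),
    #           media_type="text/plain"
    #       )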

    async def get_available_models(self) -> List[str]:
        """Get the list of available models."""
        try:
            # LangChain has no direct method to list models,
            # so return commonly available OpenAI models
            return [
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-4",
                "gpt-4-turbo-preview",
                "gpt-4o",
                "gpt-4o-mini"
            ]
        except Exception as e:
            logger.error(f"Failed to get available models: {str(e)}")
            return ["gpt-3.5-turbo"]

    async def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Update the LLM configuration."""
        from ..core.new_agent import new_agent
        # Recreate the LLM instances (passing the session first, as initialize() does)
        self.llm = await new_agent(
            self.session,
            model=model,
            temperature=temperature,
            streaming=False
        )
        self.streaming_llm = await new_agent(
            self.session,
            model=model,
            temperature=temperature,
            streaming=True
        )
        logger.info(f"Updated LLM configuration: model={model}, temperature={temperature}, max_tokens={max_tokens}")

    def _format_error_message(self, error: Exception) -> str:
        """Format an error message for user display."""
        error_str = str(error)
        # Provide user-friendly error messages
        if "rate limit" in error_str.lower():
            return "The server is busy; please try again later."
        elif "api key" in error_str.lower() or "authentication" in error_str.lower():
            return "API authentication failed; please check the configuration file."
        elif "timeout" in error_str.lower():
            return "The request timed out; please retry."
        elif "connection" in error_str.lower():
            return "Network connection error; please check your network connection."
        elif "model" in error_str.lower() and "not found" in error_str.lower():
            return "The specified model is unavailable; please choose another model."
        else:
            return f"An error occurred while processing the request: {error_str}"

    async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
        """Retry an async callable with exponential backoff."""
        for attempt in range(max_retries):
            try:
                return await func()
            except Exception as e:
                if attempt == max_retries - 1:
                    raise
                # Only retry errors that look transient
                if not self._is_retryable_error(e):
                    raise
                delay = base_delay * (2 ** attempt)
                logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
                await asyncio.sleep(delay)
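
    # Usage sketch (hypothetical; the retry helper is not currently applied
    # anywhere in this class). A transient-failure-prone call can be wrapped as:
    #   response = await self._retry_with_backoff(
    #       lambda: llm_to_use.ainvoke(langchain_messages)
    #   )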

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check whether an error is retryable."""
        error_str = str(error).lower()
        retryable_errors = [
            "timeout",
            "connection",
            "server error",
            "internal error",
            "rate limit"
        ]
        return any(err in error_str for err in retryable_errors)
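

# Minimal end-to-end usage sketch (illustrative; `SessionLocal` is a stand-in
# for whatever SQLAlchemy session factory the application provides):
#   async def _demo():
#       session = SessionLocal()
#       service = LangChainChatService(session)
#       await service.initialize()
#       # One-shot chat
#       response = await service.chat(conversation_id=1, message="Hello")
#       print(response.assistant_message.content)
#       # Streaming chat
#       async for chunk in service.chat_stream(conversation_id=1, message="Hello"):
#           print(chunk, end="", flush=True)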