311 lines
14 KiB
Python
311 lines
14 KiB
Python
"""Chat service for AI model integration using LangChain."""
|
||
|
||
from th_agenter import db
|
||
import json
|
||
import asyncio
|
||
import os
|
||
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
|
||
from sqlalchemy.orm import Session
|
||
|
||
from loguru import logger
|
||
|
||
from th_agenter.core.new_agent import new_agent, new_llm
|
||
from ..core.config import settings
|
||
from ..models.message import MessageRole
|
||
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
|
||
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
|
||
from .conversation import ConversationService
|
||
from .langchain_chat import LangChainChatService
|
||
from .knowledge_chat import KnowledgeChatService
|
||
from .agent.agent_service import get_agent_service
|
||
from .agent.langgraph_agent_service import get_langgraph_agent_service
|
||
|
||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||
from langgraph.checkpoint.postgres import PostgresSaver
|
||
from langgraph.graph import StateGraph, START, END
|
||
from langgraph.graph.state import CompiledStateGraph
|
||
class AgentState(TypedDict):
    """Typed dictionary describing agent state: the list of chat messages."""

    # Stores the conversation messages (the core memory).
    messages: List[dict]
|
||
|
||
class ChatService:
    """Service for handling AI chat functionality using LangChain.

    Two usage styles coexist in this class:

    * ``chat`` / ``get_available_models`` / ``update_model_config`` delegate
      to collaborating services stored on the instance.
    * ``initialize`` + ``chat_stream`` / ``get_conversation_history_messages``
      use a LangGraph Postgres checkpointer keyed by the conversation id.

    NOTE(review): ``knowledge_chat_service``, ``langgraph_agent_service``,
    ``agent_service`` and ``langchain_chat_service`` are read by the
    delegating methods but are not assigned anywhere in this class; confirm
    they are attached elsewhere before those methods are called.
    """

    # Connection string used when the environment variable below is unset.
    _DEFAULT_CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
    # Environment variable that overrides the checkpoint database DSN.
    _CONN_STRING_ENV_VAR = "LANGGRAPH_CHECKPOINT_DSN"
    # Tables whose presence means the checkpointer schema is already set up.
    _REQUIRED_CHECKPOINT_TABLES = ("checkpoints", "checkpoint_writes", "checkpoint_blobs")

    # Process-wide flags shared by every instance.
    _checkpointer_initialized = False
    _conn_string = None

    def __init__(self, session: "Session"):
        """Store the database session; call ``initialize`` before streaming."""
        self.session = session

    # ------------------------------------------------------------------
    # Delegating (non-checkpointer) API
    # ------------------------------------------------------------------
    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        use_agent: bool = False,
        use_langgraph: bool = False,
        use_knowledge_base: bool = False,
        knowledge_base_id: Optional[int] = None
    ) -> "ChatResponse":
        """Send a message and get AI response using LangChain, Agent, or Knowledge Base.

        Backend selection, in priority order: knowledge base (needs both
        ``use_knowledge_base`` and a truthy ``knowledge_base_id``), LangGraph
        agent, plain agent, then the LangChain service as the default.
        ``stream``, ``temperature`` and ``max_tokens`` are forwarded only to
        the knowledge-base and LangChain backends.

        Raises:
            ChatServiceError: on the agent paths, if the conversation does
                not exist or the agent reports failure.
        """
        if use_knowledge_base and knowledge_base_id:
            logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
            return await self.knowledge_chat_service.chat_with_knowledge_base(
                conversation_id=conversation_id,
                message=message,
                knowledge_base_id=knowledge_base_id,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )

        if use_langgraph:
            return await self._chat_via_agent(
                conversation_id,
                message,
                agent_service=self.langgraph_agent_service,
                label="LangGraph Agent",
                metadata_key="intermediate_steps",
            )

        if use_agent:
            return await self._chat_via_agent(
                conversation_id,
                message,
                agent_service=self.agent_service,
                label="Agent",
                metadata_key="tool_calls",
            )

        logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
        return await self.langchain_chat_service.chat(
            conversation_id=conversation_id,
            message=message,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    async def _chat_via_agent(
        self,
        conversation_id: int,
        message: str,
        agent_service: Any,
        label: str,
        metadata_key: str,
    ) -> "ChatResponse":
        """Run one agent-backed turn and persist both sides of it.

        Shared by the LangGraph and plain agent paths, which differ only in
        the service called, the label used in log/error text, and the key of
        the agent result that is stored as message metadata.
        """
        logger.info(f"Processing chat request for conversation {conversation_id} via {label}")

        conversation = await self.conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        messages = await self.conversation_service.get_conversation_messages(conversation_id)
        chat_history = [
            {
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content,
            }
            for msg in messages
        ]

        agent_result = await agent_service.chat(message, chat_history)
        if not agent_result["success"]:
            raise ChatServiceError(f"{label} error: {agent_result.get('error', 'Unknown error')}")

        # Messages are written only after the agent succeeded, so a failed
        # turn leaves the stored conversation untouched.
        await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=message,
            role=MessageRole.USER
        )
        assistant_message = await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=agent_result["response"],
            role=MessageRole.ASSISTANT,
            message_metadata={metadata_key: agent_result[metadata_key]}
        )

        return ChatResponse(
            message=MessageResponse(
                id=assistant_message.id,
                content=agent_result["response"],
                role=MessageRole.ASSISTANT,
                conversation_id=conversation_id,
                created_at=assistant_message.created_at,
                # NOTE(review): the message is saved with `message_metadata=`
                # but read back through `.metadata`; confirm the model
                # exposes the stored value under this attribute.
                metadata=assistant_message.metadata
            )
        )

    async def get_available_models(self) -> List[str]:
        """Get list of available models from LangChain."""
        logger.info("Getting available models via LangChain")
        return await self.langchain_chat_service.get_available_models()

    def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Update LLM configuration via LangChain."""
        logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
        self.langchain_chat_service.update_model_config(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )

    # ------------------------------------------------------------------
    # Checkpointer-backed API
    # ------------------------------------------------------------------
    async def initialize(self, conversation_id: int, streaming: bool = False):
        """Load the conversation, prepare the checkpointer schema, build the LLM.

        Progress is recorded on ``self.session.desc`` at each step. The
        checkpointer schema check/setup runs only until it has succeeded once
        in this process.

        Raises:
            ChatServiceError: if the conversation does not exist.
            Exception: whatever ``AsyncPostgresSaver.setup`` raises is
                re-raised after being logged.
        """
        self.conversation_service = ConversationService(self.session)
        self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
        self.conversation = await self.conversation_service.get_conversation(
            conversation_id=conversation_id
        )
        if not self.conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        if not ChatService._checkpointer_initialized:
            await self._ensure_checkpoint_tables()

        self.llm = await new_llm(session=self.session, streaming=streaming)
        self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"

    async def _ensure_checkpoint_tables(self) -> None:
        """Create the LangGraph checkpoint tables unless they already exist."""
        conn_string = os.environ.get(self._CONN_STRING_ENV_VAR, self._DEFAULT_CONN_STRING)
        # Store the connection string for later use by the other methods.
        ChatService._conn_string = conn_string

        try:
            # psycopg2 is blocking, so run the probe in a worker thread
            # instead of stalling the event loop.
            tables_exist = await asyncio.to_thread(
                self._checkpoint_tables_exist, conn_string
            )
        except Exception as e:
            # Best effort: any probe failure (including a missing driver)
            # falls back to running setup.
            self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)},将进行setup"
            tables_exist = False
        else:
            if tables_exist:
                self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup"

        if tables_exist:
            self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
        else:
            self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
            try:
                async with AsyncPostgresSaver.from_conn_string(conn_string) as checkpointer:
                    await checkpointer.setup()
                self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
                logger.info("PostgresSaver setup完成")
            except Exception as e:
                self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
                logger.error(f"PostgresSaver setup失败: {e}")
                raise

        # Only mark success after the check or setup completed, so a failed
        # attempt is retried by the next initialize() call.
        ChatService._checkpointer_initialized = True

    @staticmethod
    def _checkpoint_tables_exist(conn_string: str) -> bool:
        """Return True if every required checkpoint table is in schema ``public``.

        Blocking; intended to be run off the event loop.
        """
        import psycopg2

        conn = psycopg2.connect(conn_string)
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                """)
                existing_tables = {row[0] for row in cursor.fetchall()}
        finally:
            # Close even when the query fails; the connection context
            # manager alone would not close it.
            conn.close()
        return existing_tables.issuperset(ChatService._REQUIRED_CHECKPOINT_TABLES)

    def get_config(self):
        """Return the LangGraph run config for the loaded conversation.

        The conversation id (as a string) is the checkpoint thread id.
        """
        return {
            "configurable": {
                "thread_id": str(self.conversation.id),
                "checkpoint_ns": "drgraph"
            }
        }

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message-chunk content to plain text.

        Strings pass through unchanged. For a list, plain strings and dict
        blocks of type ``"text"`` are concatenated and other blocks are
        skipped. ``None`` becomes an empty string.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    pieces.append(str(block.get("text", "")))
            return "".join(pieces)
        return str(content)

    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Send a message and stream the AI response as JSON strings.

        The user message is stored first. Each streamed chunk is yielded as
        ``{"data": {"v": <text>}}`` serialized with ``json.dumps``. After the
        stream ends, the concatenated reply is stored as the assistant
        message if it is non-empty. Requires ``initialize`` to have run.
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )
        reply_parts = []

        async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            from langchain.agents import create_agent

            agent = create_agent(
                model=self.llm,
                checkpointer=checkpointer
            )
            async for chunk in agent.astream(
                {"messages": [{"role": "user", "content": message}]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                # The first element of each streamed item is the message chunk.
                text = self._content_to_text(chunk[0].content)
                reply_parts.append(text)
                yield json.dumps(
                    {"data": {"v": text}},
                    ensure_ascii=True
                )

        full_assistant_content = "".join(reply_parts)
        if full_assistant_content:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )

    @staticmethod
    def _messages_from_checkpoint(checkpoint_tuple: Any) -> List[Any]:
        """Return the ``messages`` channel stored in one checkpoint tuple.

        Returns an empty list when the checkpoint has no such channel.
        """
        checkpoint = getattr(checkpoint_tuple, "checkpoint", None) or {}
        channel_values = checkpoint.get("channel_values") or {}
        return channel_values.get("messages", [])

    def get_conversation_history_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ):
        """Get conversation history messages with pagination.

        Returns one entry per stored checkpoint, each being that
        checkpoint's message list, in the order the checkpointer lists them.
        ``skip`` checkpoints are dropped and at most ``limit`` are returned.

        NOTE(review): ``conversation_id`` is not used; the thread is taken
        from the conversation loaded by ``initialize`` via ``get_config``.
        Confirm whether callers expect the argument to select the thread.
        """
        skip = max(skip, 0)
        if limit <= 0:
            return []

        with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            # Materialize inside the `with`: the listing is lazy and needs
            # the open connection.
            checkpoint_tuples = list(
                checkpointer.list(self.get_config(), limit=skip + limit)
            )

        return [
            self._messages_from_checkpoint(checkpoint_tuple)
            for checkpoint_tuple in checkpoint_tuples[skip:skip + limit]
        ]