# hxf/backend/th_agenter/services/chat.py
"""Chat service for AI model integration using LangChain."""
from th_agenter import db
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
from sqlalchemy.orm import Session
from loguru import logger
from th_agenter.core.new_agent import new_agent, new_llm
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
from .conversation import ConversationService
from .langchain_chat import LangChainChatService
from .knowledge_chat import KnowledgeChatService
from .agent.agent_service import get_agent_service
from .agent.langgraph_agent_service import get_langgraph_agent_service
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
class AgentState(TypedDict):
    """State dictionary threaded through the agent graph."""

    # Conversation messages — the core memory of the dialogue.
    messages: List[dict]
class ChatService:
    """Service for handling AI chat functionality using LangChain.

    Routes a chat request to one of four backends — knowledge base, LangGraph
    agent, tool-calling agent, or plain LangChain chat — and persists both the
    user message and the assistant reply through ``ConversationService``.

    NOTE(review): ``chat`` reads ``self.knowledge_chat_service``,
    ``self.langgraph_agent_service``, ``self.agent_service`` and
    ``self.langchain_chat_service``, but none of them are assigned anywhere in
    this class (``__init__``/``initialize`` do not set them) — presumably they
    are injected elsewhere; verify against callers before relying on ``chat``.
    """

    # Class-level (shared across all instances) one-time-setup flag for the
    # langgraph checkpointer tables, and the Postgres DSN used for them.
    _checkpointer_initialized = False
    _conn_string = None

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        use_agent: bool = False,
        use_langgraph: bool = False,
        use_knowledge_base: bool = False,
        knowledge_base_id: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get AI response using LangChain, Agent, or Knowledge Base.

        Dispatch priority: knowledge base (when both ``use_knowledge_base`` and
        ``knowledge_base_id`` are set) > LangGraph agent > tool-calling agent >
        plain LangChain chat.

        Args:
            conversation_id: Id of an existing conversation; the agent paths
                raise ``ChatServiceError`` if it does not exist.
            message: The user's message text.
            stream: Forwarded to the knowledge-base / LangChain services only;
                ignored by the two agent paths.
            temperature: Optional sampling temperature override (knowledge-base
                and LangChain paths only).
            max_tokens: Optional completion-length override (knowledge-base and
                LangChain paths only).
            use_agent: Route through the tool-calling agent service.
            use_langgraph: Route through the LangGraph agent service.
            use_knowledge_base: Route through the knowledge-base chat service.
            knowledge_base_id: Knowledge base to query; required (together with
                ``use_knowledge_base``) for the knowledge-base path.

        Returns:
            ChatResponse wrapping the persisted assistant message.

        Raises:
            ChatServiceError: If the conversation is missing or an agent
                reports failure.
        """
        if use_knowledge_base and knowledge_base_id:
            logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
            # Use knowledge base chat service
            return await self.knowledge_chat_service.chat_with_knowledge_base(
                conversation_id=conversation_id,
                message=message,
                knowledge_base_id=knowledge_base_id,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )
        elif use_langgraph:
            logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
            # Get conversation history for LangGraph agent
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError(f"Conversation {conversation_id} not found")
            messages = await self.conversation_service.get_conversation_messages(conversation_id)
            # Collapse DB message rows into the {"role", "content"} shape the
            # agent expects; any non-USER role is mapped to "assistant".
            chat_history = [{
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content
            } for msg in messages]
            # Use LangGraph agent service
            agent_result = await self.langgraph_agent_service.chat(message, chat_history)
            if agent_result["success"]:
                # Save user message.
                # NOTE(review): the user message is persisted only on success,
                # and only after the agent ran — a failed agent call drops the
                # user's message entirely; confirm this is intentional.
                user_message = await self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    content=message,
                    role=MessageRole.USER
                )
                # Save assistant response
                assistant_message = await self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    content=agent_result["response"],
                    role=MessageRole.ASSISTANT,
                    message_metadata={"intermediate_steps": agent_result["intermediate_steps"]}
                )
                return ChatResponse(
                    message=MessageResponse(
                        id=assistant_message.id,
                        content=agent_result["response"],
                        role=MessageRole.ASSISTANT,
                        conversation_id=conversation_id,
                        created_at=assistant_message.created_at,
                        metadata=assistant_message.metadata
                    )
                )
            else:
                raise ChatServiceError(f"LangGraph Agent error: {agent_result.get('error', 'Unknown error')}")
        elif use_agent:
            logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
            # Get conversation history for agent
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError(f"Conversation {conversation_id} not found")
            messages = await self.conversation_service.get_conversation_messages(conversation_id)
            chat_history = [{
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content
            } for msg in messages]
            # Use agent service
            agent_result = await self.agent_service.chat(message, chat_history)
            if agent_result["success"]:
                # Save user message (same persist-on-success caveat as the
                # LangGraph branch above).
                user_message = await self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    content=message,
                    role=MessageRole.USER
                )
                # Save assistant response
                assistant_message = await self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    content=agent_result["response"],
                    role=MessageRole.ASSISTANT,
                    message_metadata={"tool_calls": agent_result["tool_calls"]}
                )
                return ChatResponse(
                    message=MessageResponse(
                        id=assistant_message.id,
                        content=agent_result["response"],
                        role=MessageRole.ASSISTANT,
                        conversation_id=conversation_id,
                        created_at=assistant_message.created_at,
                        metadata=assistant_message.metadata
                    )
                )
            else:
                raise ChatServiceError(f"Agent error: {agent_result.get('error', 'Unknown error')}")
        else:
            logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
            # Delegate to LangChain service
            return await self.langchain_chat_service.chat(
                conversation_id=conversation_id,
                message=message,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )

    async def get_available_models(self) -> List[str]:
        """Get list of available models from LangChain."""
        logger.info("Getting available models via LangChain")
        # Delegate to LangChain service
        return await self.langchain_chat_service.get_available_models()

    def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> None:
        """Update LLM configuration via LangChain.

        Any argument left as ``None`` is forwarded unchanged; the LangChain
        service decides how to treat unset values.
        """
        logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
        # Delegate to LangChain service
        self.langchain_chat_service.update_model_config(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )

    # -------------------------------------------------------------------------
    def __init__(self, session: Session):
        # Only the DB session is stored here; the heavy async setup lives in
        # ``initialize``, which MUST be awaited before using this instance.
        self.session = session

    async def initialize(self, conversation_id: int, streaming: bool = False) -> None:
        """Async second-phase constructor.

        Loads the conversation, performs one-time setup of the langgraph
        Postgres checkpointer tables (once per process), and builds the LLM.

        Raises:
            ChatServiceError: If ``conversation_id`` does not exist.
        """
        self.conversation_service = ConversationService(self.session)
        self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
        self.conversation = await self.conversation_service.get_conversation(
            conversation_id=conversation_id
        )
        if not self.conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")
        if not ChatService._checkpointer_initialized:
            from langgraph.checkpoint.postgres import PostgresSaver
            import psycopg2
            # NOTE(review): hard-coded DSN with default credentials — should
            # presumably come from ``settings``; verify before deployment.
            CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
            ChatService._conn_string = CONN_STRING
            # Check whether the required tables already exist
            tables_need_setup = True
            try:
                # Connect to the database and check whether the tables exist
                conn = psycopg2.connect(CONN_STRING)
                cursor = conn.cursor()
                # Check whether the tables langgraph needs are present
                cursor.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                """)
                existing_tables = [row[0] for row in cursor.fetchall()]
                # Verify that every required table exists
                required_tables = ['checkpoints', 'checkpoint_writes', 'checkpoint_blobs']
                if all(table in existing_tables for table in required_tables):
                    tables_need_setup = False
                    self.session.desc = "ChatService初始化 - 检测到langgraph表已存在跳过setup"
                cursor.close()
                conn.close()
            except Exception as e:
                # Existence check failed (e.g. DB unreachable): fall through to
                # a full setup attempt rather than failing here.
                self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)}将进行setup"
                tables_need_setup = True
            # Only run setup when needed.
            # NOTE(review): check-then-setup is not atomic across processes —
            # two workers starting simultaneously may both run setup; confirm
            # ``checkpointer.setup()`` is idempotent (it is documented to be).
            if tables_need_setup:
                self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
                try:
                    async with AsyncPostgresSaver.from_conn_string(CONN_STRING) as checkpointer:
                        await checkpointer.setup()
                    self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
                    logger.info("PostgresSaver setup完成")
                except Exception as e:
                    self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
                    logger.error(f"PostgresSaver setup失败: {e}")
                    raise
            else:
                self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
            # Store the connection string for later use
            ChatService._checkpointer_initialized = True
        self.llm = await new_llm(session=self.session, streaming=streaming)
        self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"

    def get_config(self) -> dict:
        """Build the langgraph run config keyed by this conversation.

        The conversation id doubles as the checkpointer ``thread_id``, so each
        conversation gets its own persisted state under the "drgraph" namespace.
        """
        config = {
            "configurable": {
                "thread_id": str(self.conversation.id),
                "checkpoint_ns": "drgraph"
            }
        }
        return config

    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base.

        Persists the user message up front, streams the agent's token chunks as
        JSON strings of the form ``{"data": {"v": <token>}}``, and persists the
        concatenated assistant reply once the stream ends.

        Yields:
            JSON-encoded strings, one per streamed chunk.
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )
        full_assistant_content = ""
        async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            from langchain.agents import create_agent
            agent = create_agent(
                model=self.llm,
                checkpointer=checkpointer
            )
            # stream_mode="messages" yields (message_chunk, metadata) pairs;
            # chunk[0] is the message chunk carrying ``.content``.
            async for chunk in agent.astream(
                {"messages": [{"role": "user", "content": message}]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                full_assistant_content += chunk[0].content
                json_result = {"data": {"v": chunk[0].content }}
                # NOTE(review): ensure_ascii=True escapes non-ASCII (e.g. CJK)
                # output into \uXXXX sequences — confirm the frontend decodes
                # JSON rather than displaying the raw escapes.
                yield json.dumps(
                    json_result,
                    ensure_ascii=True
                )
        # Persist the assembled assistant reply only if anything was streamed.
        if len(full_assistant_content) > 0:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )

    def get_conversation_history_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ):
        """Get conversation history messages with pagination.

        NOTE(review): despite the docstring, ``conversation_id``, ``skip`` and
        ``limit`` are never used — the thread is taken from ``self.conversation``
        via ``get_config()`` and no pagination is applied; verify intent.

        NOTE(review): ``checkpointer.list()`` yields ``CheckpointTuple``
        objects, which do not obviously expose a ``.messages`` attribute — this
        looks like it would raise ``AttributeError``; confirm against the
        installed langgraph version. The ``print`` also looks like leftover
        debugging.
        """
        result = []
        with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            checkpoints = checkpointer.list(self.get_config())
            for checkpoint in checkpoints:
                print(checkpoint)
                result.append(checkpoint.messages)
        return result