2026-01-21 13:45:39 +08:00
|
|
|
|
"""Chat service for AI model integration using LangChain."""
|
|
|
|
|
|
|
|
|
|
|
|
from th_agenter import db
|
|
|
|
|
|
import json
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import os
|
|
|
|
|
|
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
from th_agenter.core.new_agent import new_agent, new_llm
|
|
|
|
|
|
from ..core.config import settings
|
|
|
|
|
|
from ..models.message import MessageRole
|
|
|
|
|
|
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
|
|
|
|
|
|
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
|
|
|
|
|
|
from .conversation import ConversationService
|
|
|
|
|
|
from .langchain_chat import LangChainChatService
|
|
|
|
|
|
try:
|
|
|
|
|
|
from .knowledge_chat import KnowledgeChatService
|
|
|
|
|
|
except ModuleNotFoundError as e:
|
|
|
|
|
|
KnowledgeChatService = None # 需 pip install langchain-chroma
|
|
|
|
|
|
from .agent.agent_service import get_agent_service
|
|
|
|
|
|
from .agent.langgraph_agent_service import get_langgraph_agent_service
|
|
|
|
|
|
|
|
|
|
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
|
|
|
|
from langgraph.checkpoint.postgres import PostgresSaver
|
|
|
|
|
|
from langgraph.graph import StateGraph, START, END
|
|
|
|
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
|
|
class AgentState(TypedDict):
    """State schema for a LangGraph agent graph.

    The only channel is ``messages``: the running list of chat messages,
    which acts as the agent's core conversational memory.
    """

    messages: List[dict]  # conversation messages (core memory)
|
|
|
|
|
|
|
|
|
|
|
|
class ChatService:
    """Service for handling AI chat functionality using LangChain.

    Two usage modes are visible in this class:

    * :meth:`chat` routes one non-streaming request to the knowledge-base,
      LangGraph-agent, tool-agent or plain LangChain back end.
    * :meth:`initialize` followed by :meth:`chat_stream` streams a reply from a
      ReAct agent whose state is checkpointed in Postgres, using the
      conversation id as the LangGraph thread id.
    """

    # Process-wide flag: once the LangGraph checkpoint tables have been
    # verified or created, later instances skip the probe/setup entirely.
    _checkpointer_initialized = False
    # psycopg-style ("postgresql://") connection string for the LangGraph
    # checkpointer; resolved by initialize().
    _conn_string = None
    # Tables required by LangGraph's Postgres checkpointer; if all exist,
    # initialize() skips AsyncPostgresSaver.setup().
    _CHECKPOINT_TABLES = ("checkpoints", "checkpoint_writes", "checkpoint_blobs")
    # Checkpoint namespace used for every conversation thread.
    _CHECKPOINT_NS = "drgraph"

    def __init__(self, session: Session):
        """Bind the service to a database session.

        Only the knowledge-base chat service is created here; it is ``None``
        when its optional dependency (langchain-chroma) is not installed.
        Per-conversation state is set up by :meth:`initialize`.
        """
        self.session = session
        self.knowledge_chat_service = KnowledgeChatService(session) if KnowledgeChatService else None

    # ------------------------------------------------------------------
    # Collaborator access
    # ------------------------------------------------------------------

    def _get_conversation_service(self) -> ConversationService:
        """Return the conversation service, creating it on first use.

        initialize() assigns ``self.conversation_service``; this lets
        :meth:`chat` work on an instance where initialize() was not called.
        """
        service = getattr(self, "conversation_service", None)
        if service is None:
            service = ConversationService(self.session)
            self.conversation_service = service
        return service

    def _require_service(self, attr_name: str) -> Any:
        """Return the collaborator stored on ``self.<attr_name>``.

        ``__init__`` does not set the agent/LangChain services, so they must
        be assigned by the code that wires this service up.

        Raises:
            ChatServiceError: if the attribute is missing or ``None``.
        """
        service = getattr(self, attr_name, None)
        if service is None:
            raise ChatServiceError(f"ChatService.{attr_name} is not configured")
        return service

    # ------------------------------------------------------------------
    # Non-streaming chat
    # ------------------------------------------------------------------

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        use_agent: bool = False,
        use_langgraph: bool = False,
        use_knowledge_base: bool = False,
        knowledge_base_id: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get AI response using LangChain, Agent, or Knowledge Base.

        Routing priority:

        1. knowledge base, when ``use_knowledge_base`` and a truthy
           ``knowledge_base_id`` are given;
        2. LangGraph agent, when ``use_langgraph``;
        3. tool agent, when ``use_agent``;
        4. plain LangChain otherwise.

        ``stream``, ``temperature`` and ``max_tokens`` are forwarded only to
        the knowledge-base and LangChain back ends.

        Raises:
            ChatServiceError: the knowledge-base dependency is missing, the
                conversation does not exist, an agent reports failure, or a
                required back-end service is not configured.
        """
        if use_knowledge_base and knowledge_base_id:
            if not self.knowledge_chat_service:
                raise ChatServiceError("知识库功能需要安装: pip install langchain-chroma")
            logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
            return await self.knowledge_chat_service.chat_with_knowledge_base(
                conversation_id=conversation_id,
                message=message,
                knowledge_base_id=knowledge_base_id,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )

        if use_langgraph:
            logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
            return await self._chat_via_agent(
                conversation_id,
                message,
                agent_service=self._require_service("langgraph_agent_service"),
                metadata_key="intermediate_steps",
                error_label="LangGraph Agent",
            )

        if use_agent:
            logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
            return await self._chat_via_agent(
                conversation_id,
                message,
                agent_service=self._require_service("agent_service"),
                metadata_key="tool_calls",
                error_label="Agent",
            )

        logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
        return await self._require_service("langchain_chat_service").chat(
            conversation_id=conversation_id,
            message=message,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    async def _chat_via_agent(
        self,
        conversation_id: int,
        message: str,
        *,
        agent_service: Any,
        metadata_key: str,
        error_label: str,
    ) -> ChatResponse:
        """Run one non-streaming turn through an agent service and persist it.

        Args:
            conversation_id: conversation to load history from and append to.
            message: the new user message.
            agent_service: object exposing ``await chat(message, history)``
                that returns a dict with ``success``, ``response``,
                ``metadata_key`` and optionally ``error``.
            metadata_key: result key stored as the assistant message metadata
                (``"intermediate_steps"`` or ``"tool_calls"``).
            error_label: prefix for the error message on agent failure.

        Raises:
            ChatServiceError: conversation not found or the agent failed.
        """
        conversation_service = self._get_conversation_service()

        conversation = await conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        # History is loaded before the new message is stored, so the agent
        # receives it separately from the prior turns.
        messages = await conversation_service.get_conversation_messages(conversation_id)
        chat_history = [
            {
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content,
            }
            for msg in messages
        ]

        agent_result = await agent_service.chat(message, chat_history)
        if not agent_result["success"]:
            raise ChatServiceError(f"{error_label} error: {agent_result.get('error', 'Unknown error')}")

        # Persist the user turn only after the agent succeeded, so a failed
        # turn leaves no dangling user message.
        await conversation_service.add_message(
            conversation_id=conversation_id,
            content=message,
            role=MessageRole.USER
        )

        message_metadata = {metadata_key: agent_result[metadata_key]}
        assistant_message = await conversation_service.add_message(
            conversation_id=conversation_id,
            content=agent_result["response"],
            role=MessageRole.ASSISTANT,
            message_metadata=message_metadata
        )

        return ChatResponse(
            message=MessageResponse(
                id=assistant_message.id,
                content=agent_result["response"],
                role=MessageRole.ASSISTANT,
                conversation_id=conversation_id,
                created_at=assistant_message.created_at,
                # Return the metadata that was just stored. The row was
                # written via ``message_metadata``; ``assistant_message.metadata``
                # is presumably the SQLAlchemy declarative ``MetaData``
                # registry (a reserved name), not this row's metadata.
                metadata=message_metadata
            )
        )

    async def get_available_models(self) -> List[str]:
        """Get list of available models from LangChain.

        Raises:
            ChatServiceError: if ``langchain_chat_service`` is not configured.
        """
        logger.info("Getting available models via LangChain")
        return await self._require_service("langchain_chat_service").get_available_models()

    def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Update LLM configuration via LangChain.

        Raises:
            ChatServiceError: if ``langchain_chat_service`` is not configured.
        """
        logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
        self._require_service("langchain_chat_service").update_model_config(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )

    # ------------------------------------------------------------------
    # Checkpointed streaming chat
    # ------------------------------------------------------------------

    async def initialize(self, conversation_id: int, streaming: bool = False):
        """Prepare the service for :meth:`chat_stream` on one conversation.

        Loads the conversation, makes sure the LangGraph checkpoint tables
        exist (once per process), and builds the LLM via ``new_llm``.

        Raises:
            ChatServiceError: if the conversation does not exist.
            Exception: re-raised from ``AsyncPostgresSaver.setup()`` when
                creating the checkpoint tables fails.
        """
        self.conversation_service = ConversationService(self.session)
        self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
        self.conversation = await self.conversation_service.get_conversation(
            conversation_id=conversation_id
        )
        if not self.conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        if not ChatService._checkpointer_initialized:
            await self._ensure_checkpointer_tables()

        self.llm = await new_llm(session=self.session, streaming=streaming)
        self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"

    @staticmethod
    def _resolve_conn_string() -> str:
        """Return the Postgres connection string used by the LangGraph checkpointer.

        Priority: ``LANGGRAPH_PG_URL`` > derived from a Postgres
        ``DATABASE_URL`` > a local default.
        """
        import re

        conn_string = os.getenv("LANGGRAPH_PG_URL")
        if conn_string:
            return conn_string

        db_url = os.getenv("DATABASE_URL", "")
        if db_url and "postgresql" in db_url.split("://")[0].lower():
            # Strip a SQLAlchemy driver suffix (e.g. postgresql+asyncpg://)
            # so psycopg, which LangGraph uses, accepts the URL.
            return re.sub(
                r"^postgresql\+[a-zA-Z0-9]+://",
                "postgresql://",
                db_url,
                count=1,
            )

        # NOTE(review): hard-coded fallback credentials; consider requiring
        # LANGGRAPH_PG_URL / DATABASE_URL outside local development.
        return "postgresql://drgraph:yingping@localhost:5433/th_agenter"

    @staticmethod
    def _checkpoint_tables_exist(conn_string: str) -> bool:
        """Return True when every table in ``_CHECKPOINT_TABLES`` exists in ``public``.

        Blocking (psycopg2); call it via ``asyncio.to_thread`` from async code.
        """
        from contextlib import closing

        import psycopg2

        # closing() releases the connection even if the query raises
        # (psycopg2's own ``with conn`` only ends the transaction).
        with closing(psycopg2.connect(conn_string)) as conn:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                """)
                existing_tables = {row[0] for row in cursor.fetchall()}
        return all(table in existing_tables for table in ChatService._CHECKPOINT_TABLES)

    async def _ensure_checkpointer_tables(self) -> None:
        """Resolve the checkpointer connection string and create the tables if needed.

        Any failure while probing for the tables is treated as "needs setup".
        A failure during setup is logged and re-raised.
        """
        conn_string = self._resolve_conn_string()
        ChatService._conn_string = conn_string

        tables_need_setup = True
        try:
            # psycopg2 is synchronous; keep the probe off the event loop.
            if await asyncio.to_thread(self._checkpoint_tables_exist, conn_string):
                tables_need_setup = False
                self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup"
        except Exception as e:
            self.session.desc = f"ChatService初始化 - checkpoint失败: {str(e)},将进行setup"
            tables_need_setup = True

        if tables_need_setup:
            self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
            try:
                async with AsyncPostgresSaver.from_conn_string(conn_string) as checkpointer:
                    await checkpointer.setup()
                self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
                logger.info("PostgresSaver setup完成")
            except Exception as e:
                self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
                logger.error(f"PostgresSaver setup失败: {e}")
                raise
        else:
            self.session.desc = "ChatService初始化 - 使用现有的langgraph表"

        ChatService._checkpointer_initialized = True

    @classmethod
    def _thread_config(cls, conversation_id) -> Dict[str, Any]:
        """Build the LangGraph run config for a conversation (thread id = conversation id)."""
        return {
            "configurable": {
                "thread_id": str(conversation_id),
                "checkpoint_ns": cls._CHECKPOINT_NS
            }
        }

    def get_config(self):
        """Return the LangGraph run config for the initialized conversation.

        Requires :meth:`initialize` to have set ``self.conversation``.
        """
        return self._thread_config(self.conversation.id)

    @staticmethod
    def _chunk_text(message_chunk: Any) -> str:
        """Extract the text of one streamed message chunk.

        LangChain message ``content`` is either a ``str`` or a list of content
        blocks (strings or dicts with a ``"text"`` entry). The list form is
        flattened to its text so it can be concatenated. Objects without
        ``content`` fall back to ``str()``.
        """
        if not hasattr(message_chunk, "content"):
            return str(message_chunk)
        content = message_chunk.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for block in content:
                if isinstance(block, str):
                    texts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    texts.append(block["text"])
            return "".join(texts)
        return str(content)

    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Stream an AI reply to ``message`` for the initialized conversation.

        Flow:

        1. Store the user message.
        2. Run a tool-less ReAct agent, checkpointed in Postgres, and yield
           each text delta as the JSON string ``{"data": {"v": <delta>}}``.
        3. Store the concatenated reply as the assistant message, if it is
           non-empty.

        Requires :meth:`initialize` to have been awaited first.
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )

        parts: List[str] = []
        async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            from langgraph.prebuilt import create_react_agent
            from langchain_core.messages import HumanMessage

            agent = create_react_agent(self.llm, [], checkpointer=checkpointer)
            # stream_mode="messages" yields (message_chunk, metadata) tuples.
            async for chunk in agent.astream(
                {"messages": [HumanMessage(content=message)]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                part = self._chunk_text(chunk[0])
                parts.append(part)
                yield json.dumps({"data": {"v": part}}, ensure_ascii=True)

        full_assistant_content = "".join(parts)
        if full_assistant_content:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )

    def get_conversation_history_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ):
        """Get conversation history messages with pagination.

        Reads the latest LangGraph checkpoint for ``conversation_id``. The
        checkpoint's ``messages`` channel already holds the full accumulated
        history. Returns ``messages[skip:skip + limit]``.

        Args:
            conversation_id: conversation whose thread is read.
            skip: number of leading messages to skip (negative treated as 0).
            limit: maximum number of messages to return (<= 0 returns []).

        Returns:
            list: message objects as deserialized by the checkpointer; empty
            when the thread has no checkpoint yet.
        """
        if limit <= 0:
            return []
        start = max(skip, 0)

        # Fall back to resolving the URL so this also works without initialize().
        conn_string = self._conn_string or self._resolve_conn_string()
        with PostgresSaver.from_conn_string(conn_string=conn_string) as checkpointer:
            checkpoint_tuple = checkpointer.get_tuple(self._thread_config(conversation_id))

        if checkpoint_tuple is None:
            return []
        channel_values = checkpoint_tuple.checkpoint.get("channel_values") or {}
        messages = channel_values.get("messages") or []
        return list(messages[start:start + limit])
|