"""Chat service for AI model integration using LangChain."""
import asyncio
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict

from loguru import logger
from sqlalchemy.orm import Session
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph

from th_agenter import db
from th_agenter.core.new_agent import new_agent, new_llm
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError

from ..core.config import settings
from ..models.message import MessageRole
from .conversation import ConversationService
from .langchain_chat import LangChainChatService
from .knowledge_chat import KnowledgeChatService
from .agent.agent_service import get_agent_service
from .agent.langgraph_agent_service import get_langgraph_agent_service


class AgentState(TypedDict):
    """State carried through a LangGraph agent run."""

    messages: List[dict]  # conversation messages (the agent's core memory)
class ChatService:
    """Service for handling AI chat functionality using LangChain.

    A request to :meth:`chat` is routed to one of four backends depending on
    the flags it receives:

    * knowledge-base retrieval chat (``use_knowledge_base`` + ``knowledge_base_id``),
    * a LangGraph agent (``use_langgraph``),
    * a classic tool-calling agent (``use_agent``),
    * plain LangChain chat (default).

    :meth:`chat_stream` additionally persists LangGraph conversation
    checkpoints in Postgres via ``AsyncPostgresSaver``.
    """

    # Process-wide guard so the LangGraph checkpoint tables are set up only
    # once.  NOTE(review): the check-then-set in ``initialize`` is not safe
    # under concurrent first calls -- confirm startup is serialized or add a
    # lock before relying on this in a multi-worker deployment.
    _checkpointer_initialized = False
    # Postgres DSN shared by all instances once ``initialize`` has run.
    _conn_string = None

    def __init__(self, session: Session):
        """Store the DB session; call :meth:`initialize` before chatting.

        NOTE(review): the delegate services referenced by :meth:`chat`
        (``langchain_chat_service``, ``knowledge_chat_service``,
        ``agent_service``, ``langgraph_agent_service``) are never assigned in
        the visible code -- presumably they are attached elsewhere; verify
        before calling :meth:`chat`.
        """
        self.session = session

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        use_agent: bool = False,
        use_langgraph: bool = False,
        use_knowledge_base: bool = False,
        knowledge_base_id: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get AI response using LangChain, Agent, or Knowledge Base.

        Args:
            conversation_id: Conversation the message belongs to.
            message: The user's message text.
            stream: Passed through to the knowledge-base / LangChain backends.
            temperature: Optional sampling temperature override.
            max_tokens: Optional completion-length override.
            use_agent: Route through the classic tool-calling agent.
            use_langgraph: Route through the LangGraph agent.
            use_knowledge_base: Route through knowledge-base retrieval chat
                (only when ``knowledge_base_id`` is also given).
            knowledge_base_id: Knowledge base to query.

        Returns:
            The assistant's reply wrapped in a ``ChatResponse``.

        Raises:
            ChatServiceError: if the conversation does not exist or the
                selected agent reports a failure.
        """
        if use_knowledge_base and knowledge_base_id:
            logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
            # Delegate to the knowledge-base chat service.
            return await self.knowledge_chat_service.chat_with_knowledge_base(
                conversation_id=conversation_id,
                message=message,
                knowledge_base_id=knowledge_base_id,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )
        if use_langgraph:
            logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
            return await self._run_agent_chat(
                self.langgraph_agent_service,
                conversation_id,
                message,
                metadata_key="intermediate_steps",
                label="LangGraph Agent",
            )
        if use_agent:
            logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
            return await self._run_agent_chat(
                self.agent_service,
                conversation_id,
                message,
                metadata_key="tool_calls",
                label="Agent",
            )
        logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
        # Delegate to the plain LangChain service.
        return await self.langchain_chat_service.chat(
            conversation_id=conversation_id,
            message=message,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    async def _run_agent_chat(
        self,
        agent,
        conversation_id: int,
        message: str,
        *,
        metadata_key: str,
        label: str,
    ) -> ChatResponse:
        """Run one agent turn: load history, call *agent*, persist both messages.

        Shared by the LangGraph and classic agent branches of :meth:`chat`,
        which differed only in the service used, the metadata key stored
        (``intermediate_steps`` vs ``tool_calls``), and the error label.
        """
        conversation = await self.conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")
        messages = await self.conversation_service.get_conversation_messages(conversation_id)
        chat_history = [
            {
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content,
            }
            for msg in messages
        ]
        agent_result = await agent.chat(message, chat_history)
        if not agent_result["success"]:
            raise ChatServiceError(f"{label} error: {agent_result.get('error', 'Unknown error')}")
        # Persist the user message only after the agent succeeded (matches the
        # original ordering), then the assistant reply with its metadata.
        await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=message,
            role=MessageRole.USER
        )
        assistant_message = await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=agent_result["response"],
            role=MessageRole.ASSISTANT,
            message_metadata={metadata_key: agent_result[metadata_key]}
        )
        return ChatResponse(
            message=MessageResponse(
                id=assistant_message.id,
                content=agent_result["response"],
                role=MessageRole.ASSISTANT,
                conversation_id=conversation_id,
                created_at=assistant_message.created_at,
                # Read the column the message was saved with: the message was
                # stored via ``message_metadata=``, and bare ``.metadata`` on a
                # SQLAlchemy model is the table MetaData object, not this
                # message's payload.
                metadata=assistant_message.message_metadata
            )
        )

    async def get_available_models(self) -> List[str]:
        """Get list of available models from LangChain."""
        logger.info("Getting available models via LangChain")
        # Delegate to LangChain service.
        return await self.langchain_chat_service.get_available_models()

    def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Update LLM configuration via LangChain."""
        logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
        # Delegate to LangChain service.
        self.langchain_chat_service.update_model_config(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )

    # -------------------------------------------------------------------------
    async def initialize(self, conversation_id: int, streaming: bool = False):
        """Bind this service to a conversation and prepare the LLM.

        On the first call per process this also ensures the LangGraph
        Postgres checkpoint tables exist (running ``setup()`` if needed).

        Raises:
            ChatServiceError: if the conversation does not exist.
        """
        self.conversation_service = ConversationService(self.session)
        self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
        self.conversation = await self.conversation_service.get_conversation(
            conversation_id=conversation_id
        )
        if not self.conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")
        if not ChatService._checkpointer_initialized:
            await self._ensure_checkpointer_tables()
            ChatService._checkpointer_initialized = True
        self.llm = await new_llm(session=self.session, streaming=streaming)
        self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"

    async def _ensure_checkpointer_tables(self) -> None:
        """One-time setup of the LangGraph Postgres checkpoint tables.

        Records the connection string on the class for later use by
        :meth:`chat_stream` / :meth:`get_conversation_history_messages`.
        """
        # TODO(review): credentials were hard-coded; the env var override
        # keeps behaviour identical when unset but allows deployment config.
        conn_string = os.getenv(
            "LANGGRAPH_CHECKPOINT_DSN",
            "postgresql://postgres:postgres@localhost:5433/postgres",
        )
        ChatService._conn_string = conn_string
        if self._langgraph_tables_exist(conn_string):
            self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
            return
        self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
        try:
            async with AsyncPostgresSaver.from_conn_string(conn_string) as checkpointer:
                await checkpointer.setup()
            self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
            logger.info("PostgresSaver setup完成")
        except Exception as e:
            self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
            logger.error(f"PostgresSaver setup失败: {e}")
            raise

    def _langgraph_tables_exist(self, conn_string: str) -> bool:
        """Return True when all required LangGraph checkpoint tables exist.

        Any connection/query error is treated as "tables missing" so setup
        runs, with the error surfaced on ``session.desc`` as before.
        NOTE(review): this uses a synchronous psycopg2 connection inside an
        async code path -- acceptable only as a one-time startup check.
        """
        import psycopg2

        required_tables = ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
        try:
            conn = psycopg2.connect(conn_string)
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("""
                        SELECT table_name
                        FROM information_schema.tables
                        WHERE table_schema = 'public'
                        AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                    """)
                    existing_tables = {row[0] for row in cursor.fetchall()}
                finally:
                    # Close cursor/connection even when the query raises
                    # (the original leaked both on error).
                    cursor.close()
            finally:
                conn.close()
        except Exception as e:
            self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)}将进行setup"
            return False
        if all(table in existing_tables for table in required_tables):
            self.session.desc = "ChatService初始化 - 检测到langgraph表已存在跳过setup"
            return True
        return False

    def get_config(self):
        """Build the LangGraph run config for this conversation's thread."""
        return {
            "configurable": {
                "thread_id": str(self.conversation.id),
                "checkpoint_ns": "drgraph"
            }
        }

    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Send a message and stream the AI reply as JSON chunk strings.

        Yields:
            JSON strings of the form ``{"data": {"v": "<token>"}}``.

        Side effects:
            Persists the user message before streaming and the concatenated
            assistant reply afterwards (if non-empty).
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )
        full_assistant_content = ""
        async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            from langchain.agents import create_agent
            agent = create_agent(
                model=self.llm,
                checkpointer=checkpointer
            )
            async for chunk in agent.astream(
                {"messages": [{"role": "user", "content": message}]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                # ``stream_mode="messages"`` yields (message_chunk, metadata)
                # pairs; chunk[0].content is the token text.
                token = chunk[0].content
                full_assistant_content += token
                yield json.dumps({"data": {"v": token}}, ensure_ascii=True)
        if len(full_assistant_content) > 0:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )

    def get_conversation_history_messages(
        self, conversation_id: int, skip: int = 0, limit: int = 100
    ):
        """Get conversation history messages with pagination.

        Reads LangGraph checkpoints for the bound conversation thread and
        returns each checkpoint's ``messages`` channel, sliced by
        ``skip``/``limit`` (the original accepted but ignored both).

        NOTE(review): ``conversation_id`` is accepted for interface
        compatibility but the thread actually comes from
        ``self.conversation`` via :meth:`get_config` -- verify callers.
        """
        result = []
        with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
            for checkpoint_tuple in checkpointer.list(self.get_config()):
                logger.debug(f"checkpoint: {checkpoint_tuple}")
                # ``list`` yields CheckpointTuple items; the conversation
                # messages live in the checkpoint's channel values, not on the
                # tuple itself (the original ``checkpoint.messages`` raised
                # AttributeError).
                channel_values = checkpoint_tuple.checkpoint.get("channel_values", {})
                result.append(channel_values.get("messages", []))
        return result[skip:skip + limit]