"""Chat service for AI model integration using LangChain."""
from th_agenter import db
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
from sqlalchemy.orm import Session
from loguru import logger
from th_agenter.core.new_agent import new_agent, new_llm
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
from .conversation import ConversationService
from .langchain_chat import LangChainChatService
from .knowledge_chat import KnowledgeChatService
from .agent.agent_service import get_agent_service
from .agent.langgraph_agent_service import get_langgraph_agent_service
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph


class AgentState(TypedDict):
    messages: List[dict]  # Stored conversation messages (the core memory)


class ChatService:
    """Service for handling AI chat functionality using LangChain.

    Typical lifecycle: construct with a DB session, ``await initialize(conversation_id)``
    (which creates ``conversation_service``, loads the conversation, makes sure the
    LangGraph Postgres checkpoint tables exist and builds ``self.llm``), then call
    ``chat_stream`` / ``get_conversation_history_messages``.
    """

    # Process-wide flag: the checkpoint-table check/setup runs at most once per process.
    _checkpointer_initialized = False
    # Connection string used by the LangGraph checkpointers; set by ``initialize``.
    _conn_string: Optional[str] = None

    # Fallback checkpoint DB URL (the previously hard-coded value). It embeds
    # credentials, so deployments should override it via the env var below.
    _DEFAULT_CHECKPOINT_CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
    _CHECKPOINT_CONN_ENV_VAR = "LANGGRAPH_CHECKPOINT_DB_URL"
    # Tables whose presence is taken to mean ``AsyncPostgresSaver.setup()`` already ran.
    _CHECKPOINT_TABLES = ("checkpoints", "checkpoint_writes", "checkpoint_blobs")

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        use_agent: bool = False,
        use_langgraph: bool = False,
        use_knowledge_base: bool = False,
        knowledge_base_id: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get AI response using LangChain, Agent, or Knowledge Base.

        Dispatch order: knowledge base (when enabled and an id is given), then the
        LangGraph agent, then the classic agent, and finally plain LangChain.

        Raises:
            ChatServiceError: if the conversation does not exist (agent paths) or the
                agent reports ``success == False``.
        """
        # NOTE(review): ``knowledge_chat_service``, ``langgraph_agent_service``,
        # ``agent_service`` and ``langchain_chat_service`` are not assigned anywhere
        # in this class as shown; confirm they are attached elsewhere, otherwise
        # these branches raise AttributeError. ``conversation_service`` is only
        # created by ``initialize()``.
        if use_knowledge_base and knowledge_base_id:
            logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
            # Use knowledge base chat service
            return await self.knowledge_chat_service.chat_with_knowledge_base(
                conversation_id=conversation_id,
                message=message,
                knowledge_base_id=knowledge_base_id,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens
            )

        if use_langgraph:
            logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
            chat_history = await self._load_chat_history(conversation_id)
            agent_result = await self.langgraph_agent_service.chat(message, chat_history)
            if not agent_result["success"]:
                raise ChatServiceError(f"LangGraph Agent error: {agent_result.get('error', 'Unknown error')}")
            return await self._persist_agent_turn(
                conversation_id=conversation_id,
                user_content=message,
                assistant_content=agent_result["response"],
                # Built before anything is saved so a missing key cannot leave an
                # orphan user message behind.
                message_metadata={"intermediate_steps": agent_result["intermediate_steps"]},
            )

        if use_agent:
            logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
            chat_history = await self._load_chat_history(conversation_id)
            agent_result = await self.agent_service.chat(message, chat_history)
            if not agent_result["success"]:
                raise ChatServiceError(f"Agent error: {agent_result.get('error', 'Unknown error')}")
            return await self._persist_agent_turn(
                conversation_id=conversation_id,
                user_content=message,
                assistant_content=agent_result["response"],
                message_metadata={"tool_calls": agent_result["tool_calls"]},
            )

        logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
        # Delegate to LangChain service
        return await self.langchain_chat_service.chat(
            conversation_id=conversation_id,
            message=message,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    async def _load_chat_history(self, conversation_id: int) -> List[Dict[str, Any]]:
        """Return the conversation's stored messages as ``{"role", "content"}`` dicts.

        Any role other than USER is mapped to ``"assistant"``.

        Raises:
            ChatServiceError: if the conversation does not exist.
        """
        conversation = await self.conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        messages = await self.conversation_service.get_conversation_messages(conversation_id)
        return [
            {
                "role": "user" if msg.role == MessageRole.USER else "assistant",
                "content": msg.content,
            }
            for msg in messages
        ]

    async def _persist_agent_turn(
        self,
        conversation_id: int,
        user_content: str,
        assistant_content: str,
        message_metadata: Dict[str, Any],
    ) -> ChatResponse:
        """Save the user message and the agent's reply, and wrap the reply in a ChatResponse."""
        await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=user_content,
            role=MessageRole.USER
        )
        assistant_message = await self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=assistant_content,
            role=MessageRole.ASSISTANT,
            message_metadata=message_metadata
        )
        return ChatResponse(
            message=MessageResponse(
                id=assistant_message.id,
                content=assistant_content,
                role=MessageRole.ASSISTANT,
                conversation_id=conversation_id,
                created_at=assistant_message.created_at,
                # NOTE(review): if ``assistant_message`` is a SQLAlchemy declarative
                # model, ``.metadata`` is the reserved table MetaData, not the stored
                # ``message_metadata``. Verify against the Message model.
                metadata=assistant_message.metadata
            )
        )

    async def get_available_models(self) -> List[str]:
        """Get list of available models from LangChain."""
        logger.info("Getting available models via LangChain")
        # Delegate to LangChain service
        return await self.langchain_chat_service.get_available_models()

    def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Update LLM configuration via LangChain."""
        logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
        # Delegate to LangChain service
        self.langchain_chat_service.update_model_config(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens
        )

    # -------------------------------------------------------------------------
    def __init__(self, session: Session):
        """Store the DB session. Real setup happens in ``initialize``."""
        self.session = session

    @classmethod
    def _checkpoint_conn_string(cls) -> str:
        """Return the checkpoint DB URL: the env override if set, else the default."""
        return os.environ.get(cls._CHECKPOINT_CONN_ENV_VAR) or cls._DEFAULT_CHECKPOINT_CONN_STRING

    @staticmethod
    def _checkpoint_tables_exist(conn_string: str) -> bool:
        """Blocking check: do all LangGraph checkpoint tables exist in schema ``public``?

        The connection is always closed, even when the query fails. Errors propagate
        to the caller.
        """
        import psycopg2

        conn = psycopg2.connect(conn_string)
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                """)
                existing_tables = {row[0] for row in cursor.fetchall()}
        finally:
            conn.close()
        return all(table in existing_tables for table in ChatService._CHECKPOINT_TABLES)

    async def _ensure_checkpoint_tables(self) -> None:
        """Record the checkpoint connection string, and run ``setup()`` only if tables are missing.

        If the existence check itself fails, setup is attempted anyway.

        Raises:
            Exception: whatever ``AsyncPostgresSaver.setup()`` raised (re-raised after logging).
        """
        conn_string = self._checkpoint_conn_string()
        # Store the connection string for later checkpointer use.
        ChatService._conn_string = conn_string

        tables_need_setup = True
        try:
            # psycopg2 is synchronous; run it off the event loop.
            tables_exist = await asyncio.to_thread(self._checkpoint_tables_exist, conn_string)
        except Exception as e:
            self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)},将进行setup"
            logger.warning(f"Checking LangGraph checkpoint tables failed, running setup: {e}")
        else:
            if tables_exist:
                tables_need_setup = False
                self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup"

        if not tables_need_setup:
            self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
            return

        self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
        try:
            async with AsyncPostgresSaver.from_conn_string(conn_string) as checkpointer:
                await checkpointer.setup()
            self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
            logger.info("PostgresSaver setup完成")
        except Exception as e:
            self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
            logger.error(f"PostgresSaver setup失败: {e}")
            raise

    async def initialize(self, conversation_id: int, streaming: bool = False):
        """Load the conversation, prepare the checkpointer tables once, and build the LLM.

        Args:
            conversation_id: Conversation this service instance is bound to.
            streaming: Forwarded to ``new_llm``.

        Raises:
            ChatServiceError: if the conversation does not exist.
        """
        self.conversation_service = ConversationService(self.session)
        self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"

        self.conversation = await self.conversation_service.get_conversation(
            conversation_id=conversation_id
        )
        if not self.conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        if not ChatService._checkpointer_initialized:
            await self._ensure_checkpoint_tables()
            ChatService._checkpointer_initialized = True

        self.llm = await new_llm(session=self.session, streaming=streaming)
        self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"

    def get_config(self, conversation_id: Optional[int] = None):
        """Build the LangGraph run config: thread id is the conversation id.

        Args:
            conversation_id: Thread to address. Defaults to the conversation loaded by
                ``initialize``.
        """
        thread_id = self.conversation.id if conversation_id is None else conversation_id
        return {
            "configurable": {
                "thread_id": str(thread_id),
                "checkpoint_ns": "drgraph"
            }
        }

    @staticmethod
    def _message_text(message: Any) -> str:
        """Return the textual part of a streamed message chunk.

        ``content`` may be a plain string or, for some providers, a list of content
        blocks (strings or dicts with ``type == "text"``). Non-text blocks are skipped.
        """
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    pieces.append(str(block.get("text", "")))
            return "".join(pieces)
        return "" if content is None else str(content)

    async def chat_stream(
        self,
        message: str
    ) -> AsyncGenerator[str, None]:
        """Stream the agent's reply to ``message`` as JSON strings ``{"data": {"v": <text>}}``.

        The user message is saved before streaming. The concatenated assistant text is
        saved afterwards, if it is non-empty. Conversation memory is kept by the
        LangGraph Postgres checkpointer, keyed by ``get_config()``.
        """
        self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
        await self.conversation_service.add_message(
            conversation_id=self.conversation.id,
            role=MessageRole.USER,
            content=message
        )

        assistant_parts: List[str] = []
        conn_string = self._conn_string or self._checkpoint_conn_string()
        async with AsyncPostgresSaver.from_conn_string(conn_string=conn_string) as checkpointer:
            from langchain.agents import create_agent
            agent = create_agent(
                model=self.llm,
                checkpointer=checkpointer
            )
            async for chunk in agent.astream(
                {"messages": [{"role": "user", "content": message}]},
                config=self.get_config(),
                stream_mode="messages"
            ):
                # With stream_mode="messages", each chunk is (message_chunk, metadata).
                text = self._message_text(chunk[0])
                assistant_parts.append(text)
                yield json.dumps({"data": {"v": text}}, ensure_ascii=True)

        full_assistant_content = "".join(assistant_parts)
        if full_assistant_content:
            await self.conversation_service.add_message(
                conversation_id=self.conversation.id,
                role=MessageRole.ASSISTANT,
                content=full_assistant_content
            )

    def get_conversation_history_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ):
        """Get conversation history messages with pagination.

        Reads the latest LangGraph checkpoint of ``conversation_id``'s thread. That
        checkpoint holds the full accumulated ``messages`` channel. Returns the slice
        ``messages[skip:skip + limit]``, or ``[]`` if the thread has no checkpoint.
        """
        skip = max(skip, 0)
        if limit <= 0:
            return []

        conn_string = self._conn_string or self._checkpoint_conn_string()
        with PostgresSaver.from_conn_string(conn_string=conn_string) as checkpointer:
            checkpoint_tuple = checkpointer.get_tuple(self.get_config(conversation_id))

        if checkpoint_tuple is None:
            return []
        channel_values = checkpoint_tuple.checkpoint.get("channel_values") or {}
        messages = channel_values.get("messages") or []
        return list(messages[skip:skip + limit])