hyf-backend/th_agenter/services/knowledge_chat.py

351 lines
14 KiB
Python

"""Knowledge base chat service using LangChain RAG."""
import asyncio
from typing import List, Dict, Any, Optional, AsyncGenerator
from sqlalchemy.orm import Session
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_chroma import Chroma
from langchain_postgres import PGVector
from .embedding_factory import EmbeddingFactory
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, MessageResponse
from utils.util_exceptions import ChatServiceError
from .conversation import ConversationService
from .document_processor import get_document_processor
from loguru import logger
class KnowledgeChatService:
"""Knowledge base chat service using LangChain RAG."""
def __init__(self, session: Session):
self.session = session
self.conversation_service = ConversationService(session)
async def initialize(self):
# 获取当前LLM配置
llm_config = await settings.llm.get_current_config(self.session)
# Initialize LangChain ChatOpenAI
self.llm = ChatOpenAI(
model=llm_config["model"],
api_key=llm_config["api_key"],
base_url=llm_config["base_url"],
temperature=llm_config["temperature"],
max_tokens=llm_config["max_tokens"],
streaming=False
)
# Streaming LLM for stream responses
self.streaming_llm = ChatOpenAI(
model=llm_config["model"],
api_key=llm_config["api_key"],
base_url=llm_config["base_url"],
temperature=llm_config["temperature"],
max_tokens=llm_config["max_tokens"],
streaming=True
)
# Initialize embeddings based on provider
self.embeddings = await EmbeddingFactory.create_embeddings(self.session)
async def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
"""Get vector store for knowledge base."""
try:
import os
kb_vector_path = os.path.join((await get_document_processor(self.session)).vector_db_path, f"kb_{knowledge_base_id}")
if not os.path.exists(kb_vector_path):
logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
return None
vector_store = Chroma(
persist_directory=kb_vector_path,
embedding_function=self.embeddings
)
return vector_store
except Exception as e:
logger.error(f"Failed to load vector store for KB {knowledge_base_id}: {str(e)}")
return None
def _create_rag_chain(self, vector_store, conversation_history: List[Dict[str, str]]):
"""Create RAG chain with conversation history."""
# Create retriever
retriever = vector_store.as_retriever(
search_type="similarity",
search_kwargs={"k": 5}
)
# Create prompt template
system_prompt = """你是一个智能助手,基于提供的上下文信息回答用户问题。
上下文信息:
{context}
请根据上下文信息回答用户的问题。如果上下文信息不足以回答问题,请诚实地说明。
保持回答准确、有用且简洁。"""
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}")
])
# Create chain
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{
"context": retriever | format_docs,
"question": RunnablePassthrough(),
"chat_history": lambda x: conversation_history
}
| prompt
| self.llm
| StrOutputParser()
)
return rag_chain, retriever
def _prepare_conversation_history(self, messages: List) -> List[Dict[str, str]]:
"""Prepare conversation history for RAG chain."""
history = []
for msg in messages[:-1]: # Exclude the last message (current user message)
if msg.role == MessageRole.USER:
history.append({"role": "human", "content": msg.content})
elif msg.role == MessageRole.ASSISTANT:
history.append({"role": "assistant", "content": msg.content})
return history
async def chat_with_knowledge_base(
self,
conversation_id: int,
message: str,
knowledge_base_id: int,
stream: bool = False,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> ChatResponse:
"""Chat with knowledge base using RAG."""
try:
# Get conversation and validate
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError("Conversation not found")
# Get vector store
vector_store = self._get_vector_store(knowledge_base_id)
if not vector_store:
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
# Save user message
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history
messages = await self.conversation_service.get_conversation_messages(conversation_id)
conversation_history = self._prepare_conversation_history(messages)
# Create RAG chain
rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)
# Get relevant documents for context
relevant_docs = retriever.get_relevant_documents(message)
context_documents = []
for doc in relevant_docs:
context_documents.append({
"content": doc.page_content[:500], # Limit content length
"metadata": doc.metadata,
"source": doc.metadata.get("filename", "unknown")
})
# Generate response
if stream:
# For streaming, we'll use a different approach
response_content = await self._generate_streaming_response(
rag_chain, message, conversation_id
)
else:
response_content = await asyncio.to_thread(rag_chain.invoke, message)
# Save assistant message with context
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=response_content,
role=MessageRole.ASSISTANT,
context_documents=context_documents
)
# Create response
return ChatResponse(
user_message=MessageResponse.from_orm(user_message),
assistant_message=MessageResponse.from_orm(assistant_message),
model_used=self.llm.model_name,
total_tokens=None # TODO: Calculate tokens if needed
)
except Exception as e:
logger.error(f"Knowledge base chat failed: {str(e)}")
raise ChatServiceError(f"Knowledge base chat failed: {str(e)}")
async def _generate_streaming_response(
self,
rag_chain,
message: str,
conversation_id: int
) -> str:
"""Generate streaming response (placeholder for now)."""
# For now, use non-streaming approach
# TODO: Implement proper streaming with RAG chain
return await asyncio.to_thread(rag_chain.invoke, message)
async def chat_stream_with_knowledge_base(
self,
conversation_id: int,
message: str,
knowledge_base_id: int,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> AsyncGenerator[str, None]:
"""Chat with knowledge base using RAG with streaming response."""
try:
# Get vector store
vector_store = self._get_vector_store(knowledge_base_id)
if not vector_store:
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
# Get conversation history
messages = await self.conversation_service.get_conversation_messages(conversation_id)
conversation_history = self._prepare_conversation_history(messages)
# Create RAG chain
rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)
# Save user message
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get relevant documents
relevant_docs = retriever.get_relevant_documents(message)
context = "\n\n".join([doc.page_content for doc in relevant_docs])
# Create streaming LLM
llm_config = await settings.llm.get_current_config()
streaming_llm = ChatOpenAI(
model=llm_config["model"],
temperature=temperature or llm_config["temperature"],
max_tokens=max_tokens or llm_config["max_tokens"],
streaming=True,
api_key=llm_config["api_key"],
base_url=llm_config["base_url"]
)
# Create prompt for streaming
prompt = ChatPromptTemplate.from_messages([
("system", "你是一个智能助手。请基于以下上下文信息回答用户的问题。如果上下文中没有相关信息,请诚实地说明。\n\n上下文信息:\n{context}"),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}")
])
# Prepare chat history for prompt
chat_history_messages = []
for hist in conversation_history:
if hist["role"] == "human":
chat_history_messages.append(HumanMessage(content=hist["content"]))
elif hist["role"] == "assistant":
chat_history_messages.append(AIMessage(content=hist["content"]))
# Create streaming chain
streaming_chain = (
{
"context": lambda x: context,
"chat_history": lambda x: chat_history_messages,
"question": lambda x: x["question"]
}
| prompt
| streaming_llm
| StrOutputParser()
)
# Generate streaming response
full_response = ""
async for chunk in streaming_chain.astream({"question": message}):
if chunk:
full_response += chunk
yield chunk
# Save assistant response
if full_response:
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
message_metadata={
"knowledge_base_id": knowledge_base_id,
"relevant_docs_count": len(relevant_docs)
}
)
except Exception as e:
logger.error(f"Error in knowledge base streaming chat: {str(e)}")
error_message = f"知识库对话出错: {str(e)}"
yield error_message
# Save error message
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT
)
async def search_knowledge_base(
self,
knowledge_base_id: int,
query: str,
k: int = 5
) -> List[Dict[str, Any]]:
"""Search knowledge base for relevant documents."""
try:
vector_store = self._get_vector_store(knowledge_base_id)
if not vector_store:
return []
# Perform similarity search
results = vector_store.similarity_search_with_score(query, k=k)
formatted_results = []
for doc, score in results:
formatted_results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"similarity_score": float(score),
"source": doc.metadata.get("filename", "unknown")
})
return formatted_results
except Exception as e:
logger.error(f"Knowledge base search failed: {str(e)}")
return []