1005 lines
43 KiB
Python
1005 lines
43 KiB
Python
|
|
"""文档处理服务,负责文档的分段、向量化和索引"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import logging
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
from pathlib import Path
|
|||
|
|
from urllib.parse import quote
|
|||
|
|
from sqlalchemy import create_engine, text
|
|||
|
|
from sqlalchemy.orm import sessionmaker
|
|||
|
|
from sqlalchemy.pool import QueuePool
|
|||
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
# Try to import pdfplumber with exception handling
|
|||
|
|
try:
|
|||
|
|
import pdfplumber
|
|||
|
|
PDFPLUMBER_AVAILABLE = True
|
|||
|
|
except ImportError:
|
|||
|
|
logger.warning("pdfplumber not available. PDF processing features will be disabled.")
|
|||
|
|
pdfplumber = None
|
|||
|
|
PDFPLUMBER_AVAILABLE = False
|
|||
|
|
|
|||
|
|
# Try to import langchain_core and langchain_postgres with exception handling
|
|||
|
|
try:
|
|||
|
|
from langchain_core.documents import Document
|
|||
|
|
from langchain_postgres import PGVector
|
|||
|
|
LANGCHAIN_CORE_AVAILABLE = True
|
|||
|
|
LANGCHAIN_POSTGRES_AVAILABLE = True
|
|||
|
|
except ImportError as e:
|
|||
|
|
logger.warning(f"Some langchain modules not available: {e}. Document processing features may be limited.")
|
|||
|
|
Document = None
|
|||
|
|
PGVector = None
|
|||
|
|
LANGCHAIN_CORE_AVAILABLE = False
|
|||
|
|
LANGCHAIN_POSTGRES_AVAILABLE = False
|
|||
|
|
from typing import List
|
|||
|
|
# 旧的ZhipuEmbeddings类已移除,现在统一使用EmbeddingFactory创建embedding实例
|
|||
|
|
|
|||
|
|
from ..core.config import settings
|
|||
|
|
from ..utils.file_utils import FileUtils
|
|||
|
|
from ..models.knowledge_base import Document as DocumentModel
|
|||
|
|
from ..db.database import get_db
|
|||
|
|
|
|||
|
|
# Try to import document loaders with exception handling
|
|||
|
|
try:
|
|||
|
|
from langchain_community.document_loaders import (
|
|||
|
|
TextLoader,
|
|||
|
|
PyPDFLoader,
|
|||
|
|
Docx2txtLoader,
|
|||
|
|
UnstructuredMarkdownLoader
|
|||
|
|
)
|
|||
|
|
DOCUMENT_LOADERS_AVAILABLE = True
|
|||
|
|
except ImportError:
|
|||
|
|
logger.warning("langchain_community.document_loaders not available. Document processing features will be disabled.")
|
|||
|
|
TextLoader = None
|
|||
|
|
PyPDFLoader = None
|
|||
|
|
Docx2txtLoader = None
|
|||
|
|
UnstructuredMarkdownLoader = None
|
|||
|
|
DOCUMENT_LOADERS_AVAILABLE = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PGVectorConnectionPool:
|
|||
|
|
"""PGVector连接池管理器"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
self.engine = None
|
|||
|
|
self.SessionLocal = None
|
|||
|
|
self._init_connection_pool()
|
|||
|
|
|
|||
|
|
def _init_connection_pool(self):
|
|||
|
|
"""初始化连接池"""
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# 构建连接字符串,对密码进行URL编码以处理特殊字符(如@符号)
|
|||
|
|
encoded_password = quote(settings.vector_db.pgvector_password, safe="")
|
|||
|
|
connection_string = (
|
|||
|
|
f"postgresql://{settings.vector_db.pgvector_user}:"
|
|||
|
|
f"{encoded_password}@"
|
|||
|
|
f"{settings.vector_db.pgvector_host}:"
|
|||
|
|
f"{settings.vector_db.pgvector_port}/"
|
|||
|
|
f"{settings.vector_db.pgvector_database}"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 创建SQLAlchemy引擎,配置连接池
|
|||
|
|
self.engine = create_engine(
|
|||
|
|
connection_string,
|
|||
|
|
poolclass=QueuePool,
|
|||
|
|
pool_size=5, # 连接池大小
|
|||
|
|
max_overflow=10, # 最大溢出连接数
|
|||
|
|
pool_pre_ping=True, # 连接前ping检查
|
|||
|
|
pool_recycle=3600, # 连接回收时间(秒)
|
|||
|
|
echo=False # 是否打印SQL语句
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 创建会话工厂
|
|||
|
|
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
|
|||
|
|
logger.info(f"PGVector连接池已初始化: {settings.vector_db.pgvector_host}:{settings.vector_db.pgvector_port}")
|
|||
|
|
|
|||
|
|
def get_session(self):
|
|||
|
|
"""获取数据库会话"""
|
|||
|
|
if self.SessionLocal is None:
|
|||
|
|
raise RuntimeError("连接池未初始化")
|
|||
|
|
return self.SessionLocal()
|
|||
|
|
|
|||
|
|
def execute_query(self, query: str, params: tuple = None):
|
|||
|
|
"""执行查询并返回结果"""
|
|||
|
|
session = self.get_session()
|
|||
|
|
try:
|
|||
|
|
result = session.execute(text(query), params or {})
|
|||
|
|
return result.fetchall()
|
|||
|
|
finally:
|
|||
|
|
session.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DocumentProcessor:
|
|||
|
|
"""文档处理器,负责文档的加载、分段和向量化"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
# 初始化语义分割器配置
|
|||
|
|
self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled
|
|||
|
|
self.text_splitter = RecursiveCharacterTextSplitter(
|
|||
|
|
chunk_size=settings.file.chunk_size,
|
|||
|
|
chunk_overlap=settings.file.chunk_overlap,
|
|||
|
|
length_function=len,
|
|||
|
|
separators=["\n\n", "\n", " ", ""]
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 初始化嵌入模型 - 根据配置选择提供商
|
|||
|
|
self._init_embeddings()
|
|||
|
|
|
|||
|
|
# 初始化连接池(仅对PGVector)
|
|||
|
|
self.pgvector_pool = None
|
|||
|
|
|
|||
|
|
# PostgreSQL pgvector连接配置
|
|||
|
|
print('settings.vector_db.type=============', settings.vector_db.type)
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# 新版本PGVector使用psycopg3连接字符串
|
|||
|
|
# 对密码进行URL编码以处理特殊字符(如@符号)
|
|||
|
|
encoded_password = quote(settings.vector_db.pgvector_password, safe="")
|
|||
|
|
self.connection_string = (
|
|||
|
|
f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
|
|||
|
|
f"{encoded_password}@"
|
|||
|
|
f"{settings.vector_db.pgvector_host}:"
|
|||
|
|
f"{settings.vector_db.pgvector_port}/"
|
|||
|
|
f"{settings.vector_db.pgvector_database}"
|
|||
|
|
)
|
|||
|
|
# 初始化连接池
|
|||
|
|
self.pgvector_pool = PGVectorConnectionPool()
|
|||
|
|
else:
|
|||
|
|
# 向量数据库存储路径(Chroma兼容)
|
|||
|
|
vector_db_path = settings.vector_db.persist_directory
|
|||
|
|
if not os.path.isabs(vector_db_path):
|
|||
|
|
# 如果是相对路径,则基于项目根目录计算绝对路径
|
|||
|
|
# 项目根目录是backend的父目录
|
|||
|
|
backend_dir = Path(__file__).parent.parent.parent
|
|||
|
|
vector_db_path = str(backend_dir / vector_db_path)
|
|||
|
|
self.vector_db_path = vector_db_path
|
|||
|
|
|
|||
|
|
def _init_embeddings(self):
|
|||
|
|
"""根据配置初始化embedding模型"""
|
|||
|
|
from .embedding_factory import EmbeddingFactory
|
|||
|
|
self.embeddings = EmbeddingFactory.create_embeddings()
|
|||
|
|
|
|||
|
|
def load_document(self, file_path: str) -> List[Document]:
|
|||
|
|
"""根据文件类型加载文档"""
|
|||
|
|
file_extension = Path(file_path).suffix.lower()
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
if file_extension == '.txt':
|
|||
|
|
loader = TextLoader(file_path, encoding='utf-8')
|
|||
|
|
documents = loader.load()
|
|||
|
|
elif file_extension == '.pdf':
|
|||
|
|
# 使用pdfplumber处理PDF文件,更稳定
|
|||
|
|
documents = self._load_pdf_with_pdfplumber(file_path)
|
|||
|
|
elif file_extension == '.docx':
|
|||
|
|
loader = Docx2txtLoader(file_path)
|
|||
|
|
documents = loader.load()
|
|||
|
|
elif file_extension == '.md':
|
|||
|
|
loader = UnstructuredMarkdownLoader(file_path)
|
|||
|
|
documents = loader.load()
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的文件类型: {file_extension}")
|
|||
|
|
|
|||
|
|
logger.info(f"成功加载文档: {file_path}, 页数: {len(documents)}")
|
|||
|
|
return documents
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"加载文档失败 {file_path}: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]:
|
|||
|
|
"""使用pdfplumber加载PDF文档"""
|
|||
|
|
documents = []
|
|||
|
|
try:
|
|||
|
|
with pdfplumber.open(file_path) as pdf:
|
|||
|
|
for page_num, page in enumerate(pdf.pages):
|
|||
|
|
text = page.extract_text()
|
|||
|
|
if text and text.strip(): # 只处理有文本内容的页面
|
|||
|
|
doc = Document(
|
|||
|
|
page_content=text,
|
|||
|
|
metadata={
|
|||
|
|
"source": file_path,
|
|||
|
|
"page": page_num + 1
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
documents.append(doc)
|
|||
|
|
return documents
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"使用pdfplumber加载PDF失败 {file_path}: {str(e)}")
|
|||
|
|
# 如果pdfplumber失败,回退到PyPDFLoader
|
|||
|
|
try:
|
|||
|
|
loader = PyPDFLoader(file_path)
|
|||
|
|
return loader.load()
|
|||
|
|
except Exception as fallback_e:
|
|||
|
|
logger.error(f"PyPDFLoader回退也失败 {file_path}: {str(fallback_e)}")
|
|||
|
|
raise fallback_e
|
|||
|
|
|
|||
|
|
def _merge_documents(self, documents: List[Document]) -> Document:
|
|||
|
|
"""将多个文档合并成一个文档"""
|
|||
|
|
merged_text = ""
|
|||
|
|
merged_metadata = {}
|
|||
|
|
|
|||
|
|
for doc in documents:
|
|||
|
|
if merged_text:
|
|||
|
|
merged_text += "\n\n"
|
|||
|
|
merged_text += doc.page_content
|
|||
|
|
# 合并元数据
|
|||
|
|
merged_metadata.update(doc.metadata)
|
|||
|
|
|
|||
|
|
return Document(page_content=merged_text, metadata=merged_metadata)
|
|||
|
|
|
|||
|
|
def _get_semantic_split_points(self, text: str) -> List[str]:
|
|||
|
|
"""使用大模型分析文档内容,返回合适的分割点列表"""
|
|||
|
|
try:
|
|||
|
|
from langchain.chat_models import ChatOpenAI
|
|||
|
|
from ..core.config import get_settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
prompt = f"""
|
|||
|
|
# 任务说明
|
|||
|
|
请分析文档内容,识别出适合作为分割点的关键位置。分割点应该是能够将文档划分为有意义段落的文本片段。
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 分割规则
|
|||
|
|
请严格按照以下规则识别分割点:
|
|||
|
|
|
|||
|
|
## 基本要求
|
|||
|
|
1. 分割点必须是完整的句子开头或段落开头
|
|||
|
|
2. 每个分割后的部分应包含相对完整的语义内容
|
|||
|
|
3. 每个分割部分的理想长度控制在500字以内,严禁超过1000字。如果超过了1000字,要强制分段。
|
|||
|
|
|
|||
|
|
## 短段落处理
|
|||
|
|
4. 如果某部分长度可能小于50字,应将其与后续内容合并,避免产生过短片段
|
|||
|
|
|
|||
|
|
## 唯一性保证(重要)
|
|||
|
|
5. 确保每个分割点在文档中具有唯一性:
|
|||
|
|
- 检查文内是否存在相同的文本片段
|
|||
|
|
- 如果存在重复,需要扩展分割点字符串,直到获得唯一标识
|
|||
|
|
- 扩展方法:在当前分割点后追加几个字符,形成更长的唯一字符串
|
|||
|
|
|
|||
|
|
## 示例说明
|
|||
|
|
原始文档:
|
|||
|
|
"目录:
|
|||
|
|
第一章 标题一
|
|||
|
|
第二章 标题二
|
|||
|
|
正文
|
|||
|
|
第一章 标题一
|
|||
|
|
这是第一章的内容
|
|||
|
|
|
|||
|
|
第二章 标题二
|
|||
|
|
这是第二章的内容"
|
|||
|
|
|
|||
|
|
错误分割点:"第一章 标题一"(在目录和正文中重复出现)
|
|||
|
|
|
|||
|
|
正确分割点:"第一章 标题一\n这是第"(通过追加内容确保唯一性)
|
|||
|
|
|
|||
|
|
# 输出格式
|
|||
|
|
- 只返回分割点文本字符串
|
|||
|
|
- 每个分割点用~~分隔
|
|||
|
|
- 不要包含任何其他内容或解释
|
|||
|
|
|
|||
|
|
示例输出:分割点1~~分割点2~~分割点3
|
|||
|
|
|
|||
|
|
|
|||
|
|
文档内容:
|
|||
|
|
{text[:10000]} # 限制输入长度
|
|||
|
|
"""
|
|||
|
|
from ..core.llm import create_llm
|
|||
|
|
llm = create_llm(temperature=0.2)
|
|||
|
|
|
|||
|
|
response = llm.invoke(prompt)
|
|||
|
|
|
|||
|
|
# 解析响应获取分割点列表
|
|||
|
|
split_points = [point.strip() for point in response.content.split('~~') if point.strip()]
|
|||
|
|
logger.info(f"语义分析得到 {len(split_points)} 个分割点")
|
|||
|
|
return split_points
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"获取语义分割点失败: {str(e)}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]:
|
|||
|
|
"""根据语义分割点切分文本"""
|
|||
|
|
chunks = []
|
|||
|
|
current_pos = 0
|
|||
|
|
|
|||
|
|
# 按顺序查找每个分割点并切分文本
|
|||
|
|
for point in split_points:
|
|||
|
|
pos = text.find(point, current_pos)
|
|||
|
|
if pos != -1:
|
|||
|
|
# 添加当前位置到分割点位置的文本块
|
|||
|
|
if pos > current_pos:
|
|||
|
|
chunk = text[current_pos:pos].strip()
|
|||
|
|
if chunk:
|
|||
|
|
chunks.append(chunk)
|
|||
|
|
current_pos = pos
|
|||
|
|
|
|||
|
|
# 添加最后一个文本块
|
|||
|
|
if current_pos < len(text):
|
|||
|
|
chunk = text[current_pos:].strip()
|
|||
|
|
if chunk:
|
|||
|
|
chunks.append(chunk)
|
|||
|
|
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
def split_documents(self, documents: List[Document]) -> List[Document]:
|
|||
|
|
"""将文档分割成小块(含短段落合并和超长强制分割功能)"""
|
|||
|
|
try:
|
|||
|
|
if self.semantic_splitter_enabled and documents:
|
|||
|
|
# 1. 合并文档
|
|||
|
|
merged_doc = self._merge_documents(documents)
|
|||
|
|
|
|||
|
|
# 2. 获取语义分割点
|
|||
|
|
split_points = self._get_semantic_split_points(merged_doc.page_content)
|
|||
|
|
|
|||
|
|
if split_points:
|
|||
|
|
# 3. 根据语义分割点切分文本
|
|||
|
|
text_chunks = self._split_by_semantic_points(merged_doc.page_content, split_points)
|
|||
|
|
|
|||
|
|
# 4. 处理短段落合并和超长强制分割(新增逻辑)
|
|||
|
|
processed_chunks = []
|
|||
|
|
buffer = ""
|
|||
|
|
for chunk in text_chunks:
|
|||
|
|
# 先检查当前chunk是否超长(超过1000字符)
|
|||
|
|
if len(chunk) > 1000:
|
|||
|
|
# 如果有缓冲内容,先处理缓冲
|
|||
|
|
if buffer:
|
|||
|
|
processed_chunks.append(buffer)
|
|||
|
|
buffer = ""
|
|||
|
|
|
|||
|
|
# 对超长chunk进行强制分割
|
|||
|
|
forced_splits = self._force_split_long_chunk(chunk)
|
|||
|
|
processed_chunks.extend(forced_splits)
|
|||
|
|
else:
|
|||
|
|
# 正常处理短段落合并
|
|||
|
|
if not buffer:
|
|||
|
|
buffer = chunk
|
|||
|
|
else:
|
|||
|
|
if len(buffer) < 100:
|
|||
|
|
buffer = f"{buffer}\n{chunk}"
|
|||
|
|
else:
|
|||
|
|
processed_chunks.append(buffer)
|
|||
|
|
buffer = chunk
|
|||
|
|
|
|||
|
|
# 添加最后剩余的缓冲内容
|
|||
|
|
if buffer:
|
|||
|
|
processed_chunks.append(buffer)
|
|||
|
|
|
|||
|
|
# 5. 创建Document对象
|
|||
|
|
chunks = []
|
|||
|
|
for i, chunk in enumerate(processed_chunks):
|
|||
|
|
doc = Document(
|
|||
|
|
page_content=chunk,
|
|||
|
|
metadata={
|
|||
|
|
**merged_doc.metadata,
|
|||
|
|
'chunk_index': i,
|
|||
|
|
'merged': len(chunk) > 100, # 标记是否经过合并
|
|||
|
|
'forced_split': len(chunk) > 1000 # 标记是否经过强制分割
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
chunks.append(doc)
|
|||
|
|
else:
|
|||
|
|
# 如果获取分割点失败,回退到默认分割器
|
|||
|
|
logger.warning("语义分割失败,使用默认分割器")
|
|||
|
|
chunks = self.text_splitter.split_documents(documents)
|
|||
|
|
else:
|
|||
|
|
# 使用默认分割器
|
|||
|
|
chunks = self.text_splitter.split_documents(documents)
|
|||
|
|
|
|||
|
|
logger.info(f"文档分割完成,共生成 {len(chunks)} 个文档块")
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"文档分割失败: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def _force_split_long_chunk(self, chunk: str) -> List[str]:
|
|||
|
|
"""强制分割超长段落(超过1000字符)"""
|
|||
|
|
max_length = 1000
|
|||
|
|
chunks = []
|
|||
|
|
|
|||
|
|
# 先尝试按换行符分割
|
|||
|
|
if '\n' in chunk:
|
|||
|
|
lines = chunk.split('\n')
|
|||
|
|
current_chunk = ""
|
|||
|
|
for line in lines:
|
|||
|
|
if len(current_chunk) + len(line) + 1 > max_length:
|
|||
|
|
if current_chunk:
|
|||
|
|
chunks.append(current_chunk)
|
|||
|
|
current_chunk = line
|
|||
|
|
else:
|
|||
|
|
chunks.append(line[:max_length])
|
|||
|
|
current_chunk = line[max_length:]
|
|||
|
|
else:
|
|||
|
|
if current_chunk:
|
|||
|
|
current_chunk += "\n" + line
|
|||
|
|
else:
|
|||
|
|
current_chunk = line
|
|||
|
|
if current_chunk:
|
|||
|
|
chunks.append(current_chunk)
|
|||
|
|
else:
|
|||
|
|
# 没有换行符则直接按长度分割
|
|||
|
|
chunks = [chunk[i:i + max_length] for i in range(0, len(chunk), max_length)]
|
|||
|
|
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str:
|
|||
|
|
"""为知识库创建向量存储"""
|
|||
|
|
try:
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# 添加元数据
|
|||
|
|
for i, doc in enumerate(documents):
|
|||
|
|
doc.metadata.update({
|
|||
|
|
"knowledge_base_id": knowledge_base_id,
|
|||
|
|
"document_id": str(document_id) if document_id else "unknown",
|
|||
|
|
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|||
|
|
"chunk_index": i
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 创建PostgreSQL pgvector存储
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
|
|||
|
|
# 创建新版本PGVector实例
|
|||
|
|
vector_store = PGVector(
|
|||
|
|
connection=self.connection_string,
|
|||
|
|
embeddings=self.embeddings,
|
|||
|
|
collection_name=collection_name,
|
|||
|
|
use_jsonb=True # 使用JSONB存储元数据
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 手动添加文档
|
|||
|
|
vector_store.add_documents(documents)
|
|||
|
|
|
|||
|
|
logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}")
|
|||
|
|
return collection_name
|
|||
|
|
else:
|
|||
|
|
# Chroma兼容模式
|
|||
|
|
from langchain_community.vectorstores import Chroma
|
|||
|
|
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|||
|
|
|
|||
|
|
# 添加元数据
|
|||
|
|
for i, doc in enumerate(documents):
|
|||
|
|
doc.metadata.update({
|
|||
|
|
"knowledge_base_id": knowledge_base_id,
|
|||
|
|
"document_id": str(document_id) if document_id else "unknown",
|
|||
|
|
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|||
|
|
"chunk_index": i
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 创建向量存储
|
|||
|
|
vector_store = Chroma.from_documents(
|
|||
|
|
documents=documents,
|
|||
|
|
embedding=self.embeddings,
|
|||
|
|
persist_directory=kb_vector_path
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 持久化向量存储
|
|||
|
|
vector_store.persist()
|
|||
|
|
|
|||
|
|
logger.info(f"向量存储创建成功: {kb_vector_path}")
|
|||
|
|
return kb_vector_path
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"创建向量存储失败: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def add_documents_to_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None:
|
|||
|
|
"""向现有向量存储添加文档"""
|
|||
|
|
try:
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# 添加元数据
|
|||
|
|
for i, doc in enumerate(documents):
|
|||
|
|
doc.metadata.update({
|
|||
|
|
"knowledge_base_id": knowledge_base_id,
|
|||
|
|
"document_id": str(document_id) if document_id else "unknown",
|
|||
|
|
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|||
|
|
"chunk_index": i
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# PostgreSQL pgvector存储
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
try:
|
|||
|
|
# 连接到现有集合
|
|||
|
|
vector_store = PGVector(
|
|||
|
|
connection=self.connection_string,
|
|||
|
|
embeddings=self.embeddings,
|
|||
|
|
collection_name=collection_name,
|
|||
|
|
use_jsonb=True
|
|||
|
|
)
|
|||
|
|
# 添加新文档
|
|||
|
|
vector_store.add_documents(documents)
|
|||
|
|
except Exception as e:
|
|||
|
|
# 如果集合不存在,创建新的向量存储
|
|||
|
|
logger.warning(f"连接现有向量存储失败,创建新的向量存储: {e}")
|
|||
|
|
self.create_vector_store(knowledge_base_id, documents, document_id)
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
logger.info(f"文档已添加到PostgreSQL pgvector存储: {collection_name}")
|
|||
|
|
else:
|
|||
|
|
# Chroma兼容模式
|
|||
|
|
from langchain_community.vectorstores import Chroma
|
|||
|
|
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|||
|
|
|
|||
|
|
# 检查向量存储是否存在
|
|||
|
|
if not os.path.exists(kb_vector_path):
|
|||
|
|
# 如果不存在,创建新的向量存储
|
|||
|
|
self.create_vector_store(knowledge_base_id, documents, document_id)
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 添加元数据
|
|||
|
|
for i, doc in enumerate(documents):
|
|||
|
|
doc.metadata.update({
|
|||
|
|
"knowledge_base_id": knowledge_base_id,
|
|||
|
|
"document_id": str(document_id) if document_id else "unknown",
|
|||
|
|
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|||
|
|
"chunk_index": i
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 加载现有向量存储
|
|||
|
|
vector_store = Chroma(
|
|||
|
|
persist_directory=kb_vector_path,
|
|||
|
|
embedding_function=self.embeddings
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 添加新文档
|
|||
|
|
vector_store.add_documents(documents)
|
|||
|
|
vector_store.persist()
|
|||
|
|
|
|||
|
|
logger.info(f"文档已添加到向量存储: {kb_vector_path}")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"添加文档到向量存储失败: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def process_document(self, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
|
|||
|
|
"""处理单个文档:加载、分段、向量化"""
|
|||
|
|
try:
|
|||
|
|
logger.info(f"开始处理文档 ID: {document_id}, 路径: {file_path}")
|
|||
|
|
|
|||
|
|
# 1. 加载文档
|
|||
|
|
documents = self.load_document(file_path)
|
|||
|
|
|
|||
|
|
# 2. 分割文档
|
|||
|
|
chunks = self.split_documents(documents)
|
|||
|
|
|
|||
|
|
# 3. 添加到向量存储
|
|||
|
|
self.add_documents_to_vector_store(knowledge_base_id, chunks, document_id)
|
|||
|
|
|
|||
|
|
# 4. 更新文档状态
|
|||
|
|
with next(get_db()) as db:
|
|||
|
|
document = db.query(DocumentModel).filter(DocumentModel.id == document_id).first()
|
|||
|
|
if document:
|
|||
|
|
document.status = "processed"
|
|||
|
|
document.chunk_count = len(chunks)
|
|||
|
|
db.commit()
|
|||
|
|
|
|||
|
|
result = {
|
|||
|
|
"document_id": document_id,
|
|||
|
|
"status": "success",
|
|||
|
|
"chunks_count": len(chunks),
|
|||
|
|
"message": "文档处理完成"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
logger.info(f"文档处理完成: {result}")
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"文档处理失败 ID: {document_id}: {str(e)}")
|
|||
|
|
|
|||
|
|
# 更新文档状态为失败
|
|||
|
|
try:
|
|||
|
|
with next(get_db()) as db:
|
|||
|
|
document = db.query(DocumentModel).filter(DocumentModel.id == document_id).first()
|
|||
|
|
if document:
|
|||
|
|
document.status = "failed"
|
|||
|
|
document.error_message = str(e)
|
|||
|
|
db.commit()
|
|||
|
|
except Exception as db_error:
|
|||
|
|
logger.error(f"更新文档状态失败: {str(db_error)}")
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
"document_id": document_id,
|
|||
|
|
"status": "failed",
|
|||
|
|
"error": str(e),
|
|||
|
|
"message": "文档处理失败"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def _get_document_ids_from_vector_store(self, knowledge_base_id: int, document_id: int) -> List[str]:
|
|||
|
|
"""查询指定document_id的所有向量记录的uuid"""
|
|||
|
|
try:
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
|
|||
|
|
# 使用连接池执行查询
|
|||
|
|
if self.pgvector_pool:
|
|||
|
|
query = f"""
|
|||
|
|
SELECT uuid FROM langchain_pg_embedding
|
|||
|
|
WHERE collection_id = (
|
|||
|
|
SELECT uuid FROM langchain_pg_collection
|
|||
|
|
WHERE name = %s
|
|||
|
|
) AND cmetadata->>'document_id' = %s
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
result = self.pgvector_pool.execute_query(query, (collection_name, str(document_id)))
|
|||
|
|
return [row[0] for row in result] if result else []
|
|||
|
|
else:
|
|||
|
|
logger.warning("PGVector连接池未初始化")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"查询文档向量记录失败: {str(e)}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None:
|
|||
|
|
"""从向量存储中删除文档"""
|
|||
|
|
try:
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# PostgreSQL pgvector存储
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 创建新版本PGVector实例
|
|||
|
|
vector_store = PGVector(
|
|||
|
|
connection=self.connection_string,
|
|||
|
|
embeddings=self.embeddings,
|
|||
|
|
collection_name=collection_name,
|
|||
|
|
use_jsonb=True
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 直接从数据库查询要删除的文档UUID
|
|||
|
|
try:
|
|||
|
|
from sqlalchemy import text
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
# 获取数据库引擎
|
|||
|
|
engine = vector_store._engine
|
|||
|
|
|
|||
|
|
with Session(engine) as session:
|
|||
|
|
# 查询匹配document_id的所有记录的ID
|
|||
|
|
query_sql = text(
|
|||
|
|
f"SELECT id FROM langchain_pg_embedding "
|
|||
|
|
f"WHERE cmetadata->>'document_id' = :doc_id"
|
|||
|
|
)
|
|||
|
|
result = session.execute(query_sql, {"doc_id": str(document_id)})
|
|||
|
|
ids_to_delete = [row[0] for row in result.fetchall()]
|
|||
|
|
|
|||
|
|
if ids_to_delete:
|
|||
|
|
# 使用ID删除文档
|
|||
|
|
vector_store.delete(ids=ids_to_delete)
|
|||
|
|
logger.info(f"成功删除 {len(ids_to_delete)} 个文档块: document_id={document_id}")
|
|||
|
|
else:
|
|||
|
|
logger.warning(f"未找到要删除的文档ID: document_id={document_id}")
|
|||
|
|
|
|||
|
|
except Exception as query_error:
|
|||
|
|
logger.error(f"查询要删除的文档时出错: {query_error}")
|
|||
|
|
# 如果查询失败,说明文档可能不存在
|
|||
|
|
logger.warning(f"无法查询到要删除的文档: document_id={document_id}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
logger.info(f"文档已从PostgreSQL pgvector存储中删除: document_id={document_id}")
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"PostgreSQL pgvector存储不存在或删除失败: {collection_name}, {str(e)}")
|
|||
|
|
else:
|
|||
|
|
# Chroma兼容模式
|
|||
|
|
from langchain_community.vectorstores import Chroma
|
|||
|
|
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|||
|
|
|
|||
|
|
if not os.path.exists(kb_vector_path):
|
|||
|
|
logger.warning(f"向量存储不存在: {kb_vector_path}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 加载向量存储
|
|||
|
|
vector_store = Chroma(
|
|||
|
|
persist_directory=kb_vector_path,
|
|||
|
|
embedding_function=self.embeddings
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 删除相关文档块(这里需要根据实际的Chroma API来实现)
|
|||
|
|
# 注意:Chroma的删除功能可能需要特定的实现方式
|
|||
|
|
logger.info(f"文档已从向量存储中删除: document_id={document_id}")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"从向量存储删除文档失败: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
def get_document_chunks(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
|
|||
|
|
"""获取文档的所有分段内容
|
|||
|
|
|
|||
|
|
改进说明:
|
|||
|
|
- 避免使用空查询进行相似性搜索,防止触发不必要的embedding API调用
|
|||
|
|
- 优先使用直接SQL查询,提高性能
|
|||
|
|
- 确保结果按chunk_index排序
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# PostgreSQL pgvector存储 - 使用直接SQL查询避免相似性搜索
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 尝试直接SQL查询(推荐方法)
|
|||
|
|
chunks = self._get_chunks_by_sql(knowledge_base_id, document_id)
|
|||
|
|
if chunks:
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
# 如果SQL查询失败,回退到改进的LangChain方法
|
|||
|
|
logger.info("SQL查询失败,使用LangChain回退方案")
|
|||
|
|
return self._get_chunks_by_langchain_improved(knowledge_base_id, document_id, collection_name)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"PostgreSQL pgvector存储访问失败: {collection_name}, {str(e)}")
|
|||
|
|
return []
|
|||
|
|
else:
|
|||
|
|
# Chroma兼容模式
|
|||
|
|
return self._get_chunks_chroma(knowledge_base_id, document_id)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"获取文档分段失败 document_id: {document_id}, kb_id: {knowledge_base_id}: {str(e)}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def _get_chunks_by_sql(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
|
|||
|
|
"""使用SQLAlchemy连接池查询获取文档分段(推荐方法)"""
|
|||
|
|
try:
|
|||
|
|
if not self.pgvector_pool:
|
|||
|
|
logger.error("PGVector连接池未初始化")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
# 直接SQL查询,避免相似性搜索和embedding计算
|
|||
|
|
query = """
|
|||
|
|
SELECT
|
|||
|
|
id,
|
|||
|
|
document,
|
|||
|
|
cmetadata
|
|||
|
|
FROM langchain_pg_embedding
|
|||
|
|
WHERE cmetadata->>'document_id' = :document_id
|
|||
|
|
AND cmetadata->>'knowledge_base_id' = :knowledge_base_id
|
|||
|
|
ORDER BY
|
|||
|
|
CAST(cmetadata->>'chunk_index' AS INTEGER) ASC;
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
# 使用连接池执行查询
|
|||
|
|
session = self.pgvector_pool.get_session()
|
|||
|
|
try:
|
|||
|
|
result = session.execute(
|
|||
|
|
text(query),
|
|||
|
|
{
|
|||
|
|
'document_id': str(document_id),
|
|||
|
|
'knowledge_base_id': str(knowledge_base_id)
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
results = result.fetchall()
|
|||
|
|
|
|||
|
|
chunks = []
|
|||
|
|
for row in results:
|
|||
|
|
# SQLAlchemy结果行访问
|
|||
|
|
metadata = row.cmetadata
|
|||
|
|
chunk = {
|
|||
|
|
"id": f"chunk_{document_id}_{metadata.get('chunk_index', 0)}",
|
|||
|
|
"content": row.document,
|
|||
|
|
"metadata": metadata,
|
|||
|
|
"page_number": metadata.get("page"),
|
|||
|
|
"chunk_index": metadata.get("chunk_index", 0),
|
|||
|
|
"start_char": metadata.get("start_char"),
|
|||
|
|
"end_char": metadata.get("end_char")
|
|||
|
|
}
|
|||
|
|
chunks.append(chunk)
|
|||
|
|
|
|||
|
|
logger.info(f"通过SQLAlchemy连接池查询获取到文档 {document_id} 的 {len(chunks)} 个分段")
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
finally:
|
|||
|
|
session.close()
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"SQLAlchemy连接池查询失败: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def _get_chunks_by_langchain_improved(self, knowledge_base_id: int, document_id: int, collection_name: str) -> List[Dict[str, Any]]:
|
|||
|
|
"""改进的LangChain查询方法(回退方案)"""
|
|||
|
|
try:
|
|||
|
|
vector_store = PGVector(
|
|||
|
|
connection=self.connection_string,
|
|||
|
|
embeddings=self.embeddings,
|
|||
|
|
collection_name=collection_name,
|
|||
|
|
use_jsonb=True
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 使用有意义的查询而不是空查询,避免触发embedding API错误
|
|||
|
|
# 先尝试获取少量结果来构造查询
|
|||
|
|
try:
|
|||
|
|
sample_results = vector_store.similarity_search(
|
|||
|
|
query="文档内容", # 使用通用查询词而非空字符串
|
|||
|
|
k=5,
|
|||
|
|
filter={"document_id": {"$eq": str(document_id)}}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if sample_results:
|
|||
|
|
# 使用第一个结果的内容片段作为查询
|
|||
|
|
first_content = sample_results[0].page_content[:50]
|
|||
|
|
results = vector_store.similarity_search(
|
|||
|
|
query=first_content,
|
|||
|
|
k=1000,
|
|||
|
|
filter={"document_id": {"$eq": str(document_id)}}
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
# 如果没有结果,尝试不使用filter的查询
|
|||
|
|
results = vector_store.similarity_search(
|
|||
|
|
query="文档",
|
|||
|
|
k=1000
|
|||
|
|
)
|
|||
|
|
# 手动过滤结果
|
|||
|
|
results = [doc for doc in results if doc.metadata.get("document_id") == str(document_id)]
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"改进的相似性搜索失败: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
chunks = []
|
|||
|
|
for i, doc in enumerate(results):
|
|||
|
|
chunk = {
|
|||
|
|
"id": f"chunk_{document_id}_{i}",
|
|||
|
|
"content": doc.page_content,
|
|||
|
|
"metadata": doc.metadata,
|
|||
|
|
"page_number": doc.metadata.get("page"),
|
|||
|
|
"chunk_index": doc.metadata.get("chunk_index", i),
|
|||
|
|
"start_char": doc.metadata.get("start_char"),
|
|||
|
|
"end_char": doc.metadata.get("end_char")
|
|||
|
|
}
|
|||
|
|
chunks.append(chunk)
|
|||
|
|
|
|||
|
|
# 按chunk_index排序
|
|||
|
|
chunks.sort(key=lambda x: x.get("chunk_index", 0))
|
|||
|
|
|
|||
|
|
logger.info(f"通过改进的LangChain方法获取到文档 {document_id} 的 {len(chunks)} 个分段")
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"LangChain改进方法失败: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
|
|||
|
|
"""Chroma存储的处理逻辑"""
|
|||
|
|
try:
|
|||
|
|
from langchain_community.vectorstores import Chroma
|
|||
|
|
|
|||
|
|
# 构建向量数据库路径
|
|||
|
|
vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|||
|
|
|
|||
|
|
if not os.path.exists(vector_db_path):
|
|||
|
|
logger.warning(f"向量数据库不存在: {vector_db_path}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
# 加载向量数据库
|
|||
|
|
vectorstore = Chroma(
|
|||
|
|
persist_directory=vector_db_path,
|
|||
|
|
embedding_function=self.embeddings
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 获取所有文档的元数据,筛选出指定文档的分段
|
|||
|
|
collection = vectorstore._collection
|
|||
|
|
all_docs = collection.get(include=["metadatas", "documents"])
|
|||
|
|
|
|||
|
|
chunks = []
|
|||
|
|
chunk_index = 0
|
|||
|
|
|
|||
|
|
for i, metadata in enumerate(all_docs["metadatas"]):
|
|||
|
|
if metadata.get("document_id") == str(document_id):
|
|||
|
|
chunk_content = all_docs["documents"][i]
|
|||
|
|
|
|||
|
|
chunk = {
|
|||
|
|
"id": f"chunk_{document_id}_{chunk_index}",
|
|||
|
|
"content": chunk_content,
|
|||
|
|
"metadata": metadata,
|
|||
|
|
"page_number": metadata.get("page"),
|
|||
|
|
"chunk_index": chunk_index,
|
|||
|
|
"start_char": metadata.get("start_char"),
|
|||
|
|
"end_char": metadata.get("end_char")
|
|||
|
|
}
|
|||
|
|
chunks.append(chunk)
|
|||
|
|
chunk_index += 1
|
|||
|
|
|
|||
|
|
logger.info(f"获取到文档 {document_id} 的 {len(chunks)} 个分段")
|
|||
|
|
return chunks
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"Chroma存储处理失败: {e}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
|||
|
|
"""在知识库中搜索相似文档"""
|
|||
|
|
try:
|
|||
|
|
if settings.vector_db.type == "pgvector":
|
|||
|
|
# PostgreSQL pgvector存储
|
|||
|
|
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
vector_store = PGVector(
|
|||
|
|
connection=self.connection_string,
|
|||
|
|
embeddings=self.embeddings,
|
|||
|
|
collection_name=collection_name,
|
|||
|
|
use_jsonb=True
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 执行相似性搜索
|
|||
|
|
results = vector_store.similarity_search_with_score(query, k=k)
|
|||
|
|
|
|||
|
|
# 格式化结果
|
|||
|
|
formatted_results = []
|
|||
|
|
for doc, distance_score in results:
|
|||
|
|
# pgvector使用余弦距离,距离越小相似度越高
|
|||
|
|
# 将距离转换为0-1之间的相似度分数
|
|||
|
|
similarity_score = 1.0 / (1.0 + distance_score)
|
|||
|
|
|
|||
|
|
formatted_results.append({
|
|||
|
|
"content": doc.page_content,
|
|||
|
|
"metadata": doc.metadata,
|
|||
|
|
"similarity_score": distance_score, # 保留原始距离分数
|
|||
|
|
"normalized_score": similarity_score, # 归一化相似度分数
|
|||
|
|
"source": doc.metadata.get('filename', 'unknown'),
|
|||
|
|
"document_id": doc.metadata.get('document_id', 'unknown'),
|
|||
|
|
"chunk_id": doc.metadata.get('chunk_id', 'unknown')
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 按相似度分数排序(距离越小越相似)
|
|||
|
|
formatted_results.sort(key=lambda x: x['similarity_score'])
|
|||
|
|
|
|||
|
|
logger.info(f"PostgreSQL pgvector搜索完成,找到 {len(formatted_results)} 个相关文档")
|
|||
|
|
return formatted_results
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}")
|
|||
|
|
return []
|
|||
|
|
else:
|
|||
|
|
# Chroma兼容模式
|
|||
|
|
from langchain_community.vectorstores import Chroma
|
|||
|
|
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|||
|
|
|
|||
|
|
if not os.path.exists(kb_vector_path):
|
|||
|
|
logger.warning(f"向量存储不存在: {kb_vector_path}")
|
|||
|
|
return []
|
|||
|
|
|
|||
|
|
# 加载向量存储
|
|||
|
|
vector_store = Chroma(
|
|||
|
|
persist_directory=kb_vector_path,
|
|||
|
|
embedding_function=self.embeddings
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 执行相似性搜索
|
|||
|
|
results = vector_store.similarity_search_with_score(query, k=k)
|
|||
|
|
|
|||
|
|
# 格式化结果
|
|||
|
|
formatted_results = []
|
|||
|
|
for doc, distance_score in results:
|
|||
|
|
# Chroma使用欧几里得距离,距离越小相似度越高
|
|||
|
|
# 将距离转换为0-1之间的相似度分数
|
|||
|
|
similarity_score = 1.0 / (1.0 + distance_score)
|
|||
|
|
|
|||
|
|
formatted_results.append({
|
|||
|
|
"content": doc.page_content,
|
|||
|
|
"metadata": doc.metadata,
|
|||
|
|
"similarity_score": distance_score, # 保留原始距离分数
|
|||
|
|
"normalized_score": similarity_score, # 归一化相似度分数
|
|||
|
|
"source": doc.metadata.get('filename', 'unknown'),
|
|||
|
|
"document_id": doc.metadata.get('document_id', 'unknown'),
|
|||
|
|
"chunk_id": doc.metadata.get('chunk_id', 'unknown')
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 按相似度分数排序(距离越小越相似)
|
|||
|
|
formatted_results.sort(key=lambda x: x['similarity_score'])
|
|||
|
|
|
|||
|
|
logger.info(f"搜索完成,找到 {len(formatted_results)} 个相关文档")
|
|||
|
|
return formatted_results
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"搜索文档失败: {str(e)}")
|
|||
|
|
return [] # 返回空列表而不是抛出异常
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局文档处理器实例(延迟初始化)
|
|||
|
|
document_processor = None
|
|||
|
|
|
|||
|
|
def get_document_processor():
|
|||
|
|
"""获取文档处理器实例(延迟初始化)"""
|
|||
|
|
global document_processor
|
|||
|
|
if document_processor is None:
|
|||
|
|
document_processor = DocumentProcessor()
|
|||
|
|
return document_processor
|