"""文档处理服务,负责文档的分段、向量化和索引"""
|
||
|
||
import os
|
||
from typing import List, Dict, Any, Optional
|
||
from pathlib import Path
|
||
from urllib.parse import quote
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import QueuePool
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||
from langchain_community.document_loaders import (
|
||
TextLoader,
|
||
PyPDFLoader,
|
||
Docx2txtLoader,
|
||
UnstructuredMarkdownLoader
|
||
)
|
||
import pdfplumber
|
||
from langchain_core.documents import Document
|
||
from langchain_postgres import PGVector
|
||
from typing import List
|
||
# 旧的ZhipuEmbeddings类已移除,现在统一使用EmbeddingFactory创建embedding实例
|
||
|
||
from ..core.config import BaseSettings, get_settings
|
||
from ..models.knowledge_base import Document as DocumentModel
|
||
from ..db.database import get_session
|
||
from loguru import logger
|
||
|
||
settings = get_settings()
|
||
class PGVectorConnectionPool:
    """Connection pool manager for PGVector."""

    def __init__(self):
        # NOTE: an async implementation is still pending; this pool is synchronous.
        self.engine = None
        self.SessionLocal = None
        self._init_connection_pool()

    def _init_connection_pool(self):
        """Initialize the connection pool."""
        if settings.vector_db.type == "pgvector":
            # Build the connection string; URL-encode the password so that
            # special characters (such as "@") are handled correctly.
            encoded_password = quote(settings.vector_db.pgvector_password, safe="")
            connection_string = (
                f"postgresql://{settings.vector_db.pgvector_user}:"
                f"{encoded_password}@"
                f"{settings.vector_db.pgvector_host}:"
                f"{settings.vector_db.pgvector_port}/"
                f"{settings.vector_db.pgvector_database}"
            )

            # Create the SQLAlchemy engine with an explicit pool configuration
            self.engine = create_engine(
                connection_string,
                poolclass=QueuePool,
                pool_size=5,          # pool size
                max_overflow=10,      # maximum overflow connections
                pool_pre_ping=True,   # ping connections before handing them out
                pool_recycle=3600,    # recycle connections after this many seconds
                echo=False            # do not log SQL statements
            )

            # Create the session factory
            self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
            logger.info(f"PGVector connection pool initialized: {settings.vector_db.pgvector_host}:{settings.vector_db.pgvector_port}")

    def get_session(self):
        """Return a database session from the pool."""
        if self.SessionLocal is None:
            raise RuntimeError("Connection pool is not initialized")
        return self.SessionLocal()

    def execute_query(self, query: str, params: dict = None):
        """Execute a query and return all rows."""
        session = self.get_session()
        try:
            # text() queries use named bind parameters, so params must be a dict
            result = session.execute(text(query), params or {})
            return result.fetchall()
        finally:
            session.close()
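
# A minimal usage sketch for the pool above (illustrative only; the query and
# collection name are assumptions, not part of this module's public API):
#
#     pool = PGVectorConnectionPool()
#     rows = pool.execute_query(
#         "SELECT name FROM langchain_pg_collection WHERE name = :name",
#         {"name": "kb_docs_kb_1"},
#     )

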
class DocumentProcessor:
    """Document processor: loads, splits, and vectorizes documents."""

    def __init__(self):
        # Initialize semantic splitter configuration
        self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=settings.file.chunk_size,
            chunk_overlap=settings.file.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

        # Initialize the embedding model according to the configured provider
        self._init_embeddings()

        # Connection pool (PGVector only)
        self.pgvector_pool = None

        # PostgreSQL pgvector connection configuration
        logger.debug(f"settings.vector_db.type = {settings.vector_db.type}")
        if settings.vector_db.type == "pgvector":
            # The new PGVector release uses a psycopg3 connection string.
            # URL-encode the password to handle special characters such as "@".
            encoded_password = quote(settings.vector_db.pgvector_password, safe="")
            self.connection_string = (
                f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
                f"{encoded_password}@"
                f"{settings.vector_db.pgvector_host}:"
                f"{settings.vector_db.pgvector_port}/"
                f"{settings.vector_db.pgvector_database}"
            )
            # Initialize the connection pool
            self.pgvector_pool = PGVectorConnectionPool()
        else:
            # Vector store path (Chroma-compatible mode)
            vector_db_path = settings.vector_db.persist_directory
            if not os.path.isabs(vector_db_path):
                # For a relative path, resolve it against the project root,
                # which is the parent directory of backend/
                backend_dir = Path(__file__).parent.parent.parent
                vector_db_path = str(backend_dir / vector_db_path)
            self.vector_db_path = vector_db_path

    def _init_embeddings(self):
        """Initialize the embedding model according to the configuration."""
        from .embedding_factory import EmbeddingFactory
        self.embeddings = EmbeddingFactory.create_embeddings()

    def load_document(self, file_path: str) -> List[Document]:
        """Load a document according to its file type."""
        file_extension = Path(file_path).suffix.lower()

        try:
            if file_extension == '.txt':
                loader = TextLoader(file_path, encoding='utf-8')
                documents = loader.load()
            elif file_extension == '.pdf':
                # Use pdfplumber for PDF files; it is more robust
                documents = self._load_pdf_with_pdfplumber(file_path)
            elif file_extension == '.docx':
                loader = Docx2txtLoader(file_path)
                documents = loader.load()
            elif file_extension == '.md':
                loader = UnstructuredMarkdownLoader(file_path)
                documents = loader.load()
            else:
                raise ValueError(f"Unsupported file type: {file_extension}")

            logger.info(f"Document loaded: {file_path}, pages: {len(documents)}")
            return documents

        except Exception as e:
            logger.error(f"Failed to load document {file_path}: {str(e)}")
            raise

    def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]:
        """Load a PDF document using pdfplumber."""
        documents = []
        try:
            with pdfplumber.open(file_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    text = page.extract_text()
                    if text and text.strip():  # only keep pages that contain text
                        doc = Document(
                            page_content=text,
                            metadata={
                                "source": file_path,
                                "page": page_num + 1
                            }
                        )
                        documents.append(doc)
            return documents
        except Exception as e:
            logger.error(f"pdfplumber failed to load PDF {file_path}: {str(e)}")
            # Fall back to PyPDFLoader if pdfplumber fails
            try:
                loader = PyPDFLoader(file_path)
                return loader.load()
            except Exception as fallback_e:
                logger.error(f"PyPDFLoader fallback also failed {file_path}: {str(fallback_e)}")
                raise fallback_e

    def _merge_documents(self, documents: List[Document]) -> Document:
        """Merge multiple documents into a single document."""
        merged_text = ""
        merged_metadata = {}

        for doc in documents:
            if merged_text:
                merged_text += "\n\n"
            merged_text += doc.page_content
            # Merge metadata
            merged_metadata.update(doc.metadata)

        return Document(page_content=merged_text, metadata=merged_metadata)

    def _get_semantic_split_points(self, text: str) -> List[str]:
        """Ask the LLM to analyze the document and return a list of suitable split points."""
        try:
            # Truncate the input to limit prompt length
            prompt = f"""
# Task
Analyze the document content and identify key positions that can serve as split points. A split point should be a text fragment that divides the document into meaningful sections.

# Splitting rules
Strictly follow these rules when identifying split points:

## Basic requirements
1. A split point must be the start of a complete sentence or paragraph
2. Each resulting section should carry relatively complete semantic content
3. The ideal section length is at most 500 characters; never exceed 1000 characters. Sections over 1000 characters must be split further.

## Short sections
4. If a section would be shorter than 50 characters, merge it with the following content to avoid overly short fragments

## Uniqueness guarantee (important)
5. Ensure every split point is unique within the document:
   - Check whether the same text fragment appears elsewhere in the document
   - If it does, extend the split-point string until it becomes a unique identifier
   - Extension method: append a few characters after the current split point to form a longer, unique string

## Example
Original document:
"Table of contents:
Chapter 1 Title One
Chapter 2 Title Two
Body
Chapter 1 Title One
This is the content of chapter one

Chapter 2 Title Two
This is the content of chapter two"

Wrong split point: "Chapter 1 Title One" (appears in both the table of contents and the body)

Correct split point: "Chapter 1 Title One\nThis is the" (made unique by appending more content)

# Output format
- Return only the split-point strings
- Separate split points with ~~
- Do not include any other content or explanation

Example output: split point 1~~split point 2~~split point 3


Document content:
{text[:10000]}
"""
            from ..core.llm import create_llm
            llm = create_llm(temperature=0.2)

            response = llm.invoke(prompt)

            # Parse the response into a list of split points
            split_points = [point.strip() for point in response.content.split('~~') if point.strip()]
            logger.info(f"Semantic analysis produced {len(split_points)} split points")
            return split_points

        except Exception as e:
            logger.error(f"Failed to get semantic split points: {str(e)}")
            return []
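
    # Illustrative sketch of the parsing step above (the LLM reply is shown as
    # a hypothetical literal; at runtime it would be response.content):
    #
    #     reply = "Chapter 1 Title One~~Chapter 2 Title Two"
    #     points = [p.strip() for p in reply.split('~~') if p.strip()]
    #     # points == ["Chapter 1 Title One", "Chapter 2 Title Two"]
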
    def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]:
        """Split the text at the given semantic split points."""
        chunks = []
        current_pos = 0

        # Locate each split point in order and cut the text at it
        for point in split_points:
            pos = text.find(point, current_pos)
            if pos != -1:
                # Emit the text between the current position and the split point
                if pos > current_pos:
                    chunk = text[current_pos:pos].strip()
                    if chunk:
                        chunks.append(chunk)
                current_pos = pos

        # Emit the final chunk
        if current_pos < len(text):
            chunk = text[current_pos:].strip()
            if chunk:
                chunks.append(chunk)

        return chunks
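
    # Worked example for the method above (hypothetical inputs): with
    # text = "intro Chapter 1 aaa Chapter 2 bbb" and
    # split_points = ["Chapter 1", "Chapter 2"], a cut lands at each split
    # point, yielding:
    #
    #     ["intro", "Chapter 1 aaa", "Chapter 2 bbb"]
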
    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents into chunks (with short-section merging and forced splitting of over-long sections)."""
        try:
            if self.semantic_splitter_enabled and documents:
                # 1. Merge the documents
                merged_doc = self._merge_documents(documents)

                # 2. Get semantic split points
                split_points = self._get_semantic_split_points(merged_doc.page_content)

                if split_points:
                    # 3. Split the text at the semantic split points
                    text_chunks = self._split_by_semantic_points(merged_doc.page_content, split_points)

                    # 4. Merge short sections and force-split over-long ones,
                    #    tracking each chunk as (text, was_merged, was_force_split)
                    processed_chunks = []
                    buffer = ""
                    buffer_merged = False
                    for chunk in text_chunks:
                        # Check first whether this chunk is over-long (more than 1000 characters)
                        if len(chunk) > 1000:
                            # Flush any buffered content first
                            if buffer:
                                processed_chunks.append((buffer, buffer_merged, False))
                                buffer = ""
                                buffer_merged = False

                            # Force-split the over-long chunk
                            for piece in self._force_split_long_chunk(chunk):
                                processed_chunks.append((piece, False, True))
                        else:
                            # Normal short-section merging
                            if not buffer:
                                buffer = chunk
                                buffer_merged = False
                            elif len(buffer) < 100:
                                buffer = f"{buffer}\n{chunk}"
                                buffer_merged = True
                            else:
                                processed_chunks.append((buffer, buffer_merged, False))
                                buffer = chunk
                                buffer_merged = False

                    # Flush whatever remains in the buffer
                    if buffer:
                        processed_chunks.append((buffer, buffer_merged, False))

                    # 5. Create Document objects
                    chunks = []
                    for i, (chunk_text, was_merged, was_forced) in enumerate(processed_chunks):
                        doc = Document(
                            page_content=chunk_text,
                            metadata={
                                **merged_doc.metadata,
                                'chunk_index': i,
                                'merged': was_merged,       # whether short sections were merged into this chunk
                                'forced_split': was_forced  # whether this chunk came from a forced split
                            }
                        )
                        chunks.append(doc)
                else:
                    # Fall back to the default splitter if no split points were obtained
                    logger.warning("Semantic splitting failed; using the default splitter")
                    chunks = self.text_splitter.split_documents(documents)
            else:
                # Use the default splitter
                chunks = self.text_splitter.split_documents(documents)

            logger.info(f"Document splitting finished; produced {len(chunks)} chunks")
            return chunks

        except Exception as e:
            logger.error(f"Document splitting failed: {str(e)}")
            raise

    def _force_split_long_chunk(self, chunk: str) -> List[str]:
        """Force-split an over-long section (more than 1000 characters)."""
        max_length = 1000
        chunks = []

        # Try splitting on newlines first
        if '\n' in chunk:
            lines = chunk.split('\n')
            current_chunk = ""
            for line in lines:
                if len(current_chunk) + len(line) + 1 > max_length:
                    if current_chunk:
                        chunks.append(current_chunk)
                        current_chunk = ""
                    # A single line longer than max_length is sliced directly
                    while len(line) > max_length:
                        chunks.append(line[:max_length])
                        line = line[max_length:]
                    current_chunk = line
                else:
                    if current_chunk:
                        current_chunk += "\n" + line
                    else:
                        current_chunk = line
            if current_chunk:
                chunks.append(current_chunk)
        else:
            # No newlines: slice purely by length
            chunks = [chunk[i:i + max_length] for i in range(0, len(chunk), max_length)]

        return chunks
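
    # Sketch of the slicing behavior above (hypothetical input and instance
    # name): a 2500-character chunk with no newlines becomes three pieces of
    # 1000, 1000, and 500 characters:
    #
    #     pieces = processor._force_split_long_chunk("x" * 2500)
    #     # [len(p) for p in pieces] == [1000, 1000, 500]
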
    def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str:
        """Create a vector store for a knowledge base."""
        try:
            if settings.vector_db.type == "pgvector":
                # Attach metadata
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # Create the PostgreSQL pgvector store
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                # Instantiate the new-style PGVector store
                vector_store = PGVector(
                    connection=self.connection_string,
                    embeddings=self.embeddings,
                    collection_name=collection_name,
                    use_jsonb=True  # store metadata as JSONB
                )

                # Add the documents explicitly
                vector_store.add_documents(documents)

                logger.info(f"PostgreSQL pgvector store created: {collection_name}")
                return collection_name
            else:
                # Chroma-compatible mode
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                # Attach metadata
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # Create the vector store
                vector_store = Chroma.from_documents(
                    documents=documents,
                    embedding=self.embeddings,
                    persist_directory=kb_vector_path
                )

                # Persist the vector store
                vector_store.persist()

                logger.info(f"Vector store created: {kb_vector_path}")
                return kb_vector_path

        except Exception as e:
            logger.error(f"Failed to create vector store: {str(e)}")
            raise

    def add_documents_to_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None:
        """Add documents to an existing vector store."""
        try:
            if settings.vector_db.type == "pgvector":
                # Attach metadata
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # PostgreSQL pgvector store
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
                try:
                    # Connect to the existing collection
                    vector_store = PGVector(
                        connection=self.connection_string,
                        embeddings=self.embeddings,
                        collection_name=collection_name,
                        use_jsonb=True
                    )
                    # Add the new documents
                    vector_store.add_documents(documents)
                except Exception as e:
                    # If the collection does not exist, create a new vector store
                    logger.warning(f"Failed to connect to the existing vector store; creating a new one: {e}")
                    self.create_vector_store(knowledge_base_id, documents, document_id)
                    return

                logger.info(f"Documents added to the PostgreSQL pgvector store: {collection_name}")
            else:
                # Chroma-compatible mode
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                # Check whether the vector store exists
                if not os.path.exists(kb_vector_path):
                    # If not, create a new vector store
                    self.create_vector_store(knowledge_base_id, documents, document_id)
                    return

                # Attach metadata
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # Load the existing vector store
                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                # Add the new documents
                vector_store.add_documents(documents)
                vector_store.persist()

                logger.info(f"Documents added to the vector store: {kb_vector_path}")

        except Exception as e:
            logger.error(f"Failed to add documents to the vector store: {str(e)}")
            raise

    def process_document(self, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
        """Process a single document: load, split, and vectorize it."""
        try:
            logger.info(f"Processing document ID: {document_id}, path: {file_path}")

            # 1. Load the document
            documents = self.load_document(file_path)

            # 2. Split the document
            chunks = self.split_documents(documents)

            # 3. Add the chunks to the vector store
            self.add_documents_to_vector_store(knowledge_base_id, chunks, document_id)

            # 4. Update the document status
            with next(get_session()) as session:
                document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
                if document:
                    document.status = "processed"
                    document.chunk_count = len(chunks)
                    session.commit()

            result = {
                "document_id": document_id,
                "status": "success",
                "chunks_count": len(chunks),
                "message": "Document processed successfully"
            }

            logger.info(f"Document processing finished: {result}")
            return result

        except Exception as e:
            logger.error(f"Document processing failed, ID {document_id}: {str(e)}")

            # Mark the document as failed
            try:
                with next(get_session()) as session:
                    document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
                    if document:
                        document.status = "failed"
                        document.error_message = str(e)
                        session.commit()
            except Exception as db_error:
                logger.error(f"Failed to update document status: {str(db_error)}")

            return {
                "document_id": document_id,
                "status": "failed",
                "error": str(e),
                "message": "Document processing failed"
            }

    def _get_document_ids_from_vector_store(self, knowledge_base_id: int, document_id: int) -> List[str]:
        """Look up the UUIDs of all vector records belonging to a given document_id."""
        try:
            collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

            # Run the lookup through the connection pool
            if self.pgvector_pool:
                query = """
                    SELECT uuid FROM langchain_pg_embedding
                    WHERE collection_id = (
                        SELECT uuid FROM langchain_pg_collection
                        WHERE name = :collection_name
                    ) AND cmetadata->>'document_id' = :document_id
                """

                result = self.pgvector_pool.execute_query(
                    query,
                    {"collection_name": collection_name, "document_id": str(document_id)}
                )
                return [row[0] for row in result] if result else []
            else:
                logger.warning("PGVector connection pool is not initialized")
                return []

        except Exception as e:
            logger.error(f"Failed to look up vector records for document: {str(e)}")
            return []

    def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None:
        """Delete a document from the vector store."""
        try:
            if settings.vector_db.type == "pgvector":
                # PostgreSQL pgvector store
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                try:
                    # Instantiate the new-style PGVector store
                    vector_store = PGVector(
                        connection=self.connection_string,
                        embeddings=self.embeddings,
                        collection_name=collection_name,
                        use_jsonb=True
                    )

                    # Query the IDs of the records to delete directly from the database
                    try:
                        from sqlalchemy import text
                        from sqlalchemy.orm import Session

                        # Reuse the store's engine (note: _engine is a private attribute)
                        engine = vector_store._engine

                        with Session(engine) as session:
                            # Find the IDs of all records matching document_id
                            query_sql = text(
                                "SELECT id FROM langchain_pg_embedding "
                                "WHERE cmetadata->>'document_id' = :doc_id"
                            )
                            result = session.execute(query_sql, {"doc_id": str(document_id)})
                            ids_to_delete = [row[0] for row in result.fetchall()]

                        if ids_to_delete:
                            # Delete the documents by ID
                            vector_store.delete(ids=ids_to_delete)
                            logger.info(f"Deleted {len(ids_to_delete)} chunks: document_id={document_id}")
                        else:
                            logger.warning(f"No chunks found to delete: document_id={document_id}")

                    except Exception as query_error:
                        logger.error(f"Error while querying documents to delete: {query_error}")
                        # A failed query most likely means the document does not exist
                        logger.warning(f"Could not find the document to delete: document_id={document_id}")
                        return

                    logger.info(f"Document removed from the PostgreSQL pgvector store: document_id={document_id}")
                except Exception as e:
                    logger.warning(f"PostgreSQL pgvector store missing or deletion failed: {collection_name}, {str(e)}")
            else:
                # Chroma-compatible mode
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                if not os.path.exists(kb_vector_path):
                    logger.warning(f"Vector store does not exist: {kb_vector_path}")
                    return

                # Load the vector store
                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                # Deleting the related chunks requires the appropriate Chroma API
                # (note: Chroma deletion may need a store-specific implementation)
                logger.info(f"Document removed from the vector store: document_id={document_id}")

        except Exception as e:
            logger.error(f"Failed to delete document from the vector store: {str(e)}")
            raise

    def get_document_chunks(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
        """Get all chunks of a document.

        Improvements:
        - Avoid similarity search with an empty query, which would trigger unnecessary embedding API calls
        - Prefer a direct SQL query for better performance
        - Ensure the results are ordered by chunk_index
        """
        try:
            if settings.vector_db.type == "pgvector":
                # PostgreSQL pgvector store: use a direct SQL query instead of similarity search
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                try:
                    # Try the direct SQL query first (preferred method)
                    chunks = self._get_chunks_by_sql(knowledge_base_id, document_id)
                    if chunks:
                        return chunks

                    # If the SQL query fails, fall back to the improved LangChain method
                    logger.info("SQL query failed; using the LangChain fallback")
                    return self._get_chunks_by_langchain_improved(knowledge_base_id, document_id, collection_name)

                except Exception as e:
                    logger.warning(f"PostgreSQL pgvector store access failed: {collection_name}, {str(e)}")
                    return []
            else:
                # Chroma-compatible mode
                return self._get_chunks_chroma(knowledge_base_id, document_id)

        except Exception as e:
            logger.error(f"Failed to get document chunks, document_id: {document_id}, kb_id: {knowledge_base_id}: {str(e)}")
            return []

    def _get_chunks_by_sql(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
        """Get document chunks via a pooled SQLAlchemy query (preferred method)."""
        try:
            if not self.pgvector_pool:
                logger.error("PGVector connection pool is not initialized")
                return []

            # Direct SQL query; avoids similarity search and embedding computation
            query = """
                SELECT
                    id,
                    document,
                    cmetadata
                FROM langchain_pg_embedding
                WHERE cmetadata->>'document_id' = :document_id
                  AND cmetadata->>'knowledge_base_id' = :knowledge_base_id
                ORDER BY
                    CAST(cmetadata->>'chunk_index' AS INTEGER) ASC;
            """

            # Run the query through the connection pool
            session = self.pgvector_pool.get_session()
            try:
                result = session.execute(
                    text(query),
                    {
                        'document_id': str(document_id),
                        'knowledge_base_id': str(knowledge_base_id)
                    }
                )
                results = result.fetchall()

                chunks = []
                for row in results:
                    # SQLAlchemy rows are accessed by column name
                    metadata = row.cmetadata
                    chunk = {
                        "id": f"chunk_{document_id}_{metadata.get('chunk_index', 0)}",
                        "content": row.document,
                        "metadata": metadata,
                        "page_number": metadata.get("page"),
                        "chunk_index": metadata.get("chunk_index", 0),
                        "start_char": metadata.get("start_char"),
                        "end_char": metadata.get("end_char")
                    }
                    chunks.append(chunk)

                logger.info(f"Pooled SQLAlchemy query returned {len(chunks)} chunks for document {document_id}")
                return chunks

            finally:
                session.close()

        except Exception as e:
            logger.error(f"Pooled SQLAlchemy query failed: {e}")
            return []

    def _get_chunks_by_langchain_improved(self, knowledge_base_id: int, document_id: int, collection_name: str) -> List[Dict[str, Any]]:
        """Improved LangChain-based lookup (fallback method)."""
        try:
            vector_store = PGVector(
                connection=self.connection_string,
                embeddings=self.embeddings,
                collection_name=collection_name,
                use_jsonb=True
            )

            # Use a meaningful query instead of an empty one to avoid embedding API errors;
            # first fetch a few results to construct a better query
            try:
                sample_results = vector_store.similarity_search(
                    query="document content",  # generic query term instead of an empty string
                    k=5,
                    filter={"document_id": {"$eq": str(document_id)}}
                )

                if sample_results:
                    # Use a snippet of the first result as the query
                    first_content = sample_results[0].page_content[:50]
                    results = vector_store.similarity_search(
                        query=first_content,
                        k=1000,
                        filter={"document_id": {"$eq": str(document_id)}}
                    )
                else:
                    # If nothing matched, retry without the filter
                    results = vector_store.similarity_search(
                        query="document",
                        k=1000
                    )
                    # Filter the results manually
                    results = [doc for doc in results if doc.metadata.get("document_id") == str(document_id)]

            except Exception as e:
                logger.warning(f"Improved similarity search failed: {e}")
                return []

            chunks = []
            for i, doc in enumerate(results):
                chunk = {
                    "id": f"chunk_{document_id}_{i}",
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "page_number": doc.metadata.get("page"),
                    "chunk_index": doc.metadata.get("chunk_index", i),
                    "start_char": doc.metadata.get("start_char"),
                    "end_char": doc.metadata.get("end_char")
                }
                chunks.append(chunk)

            # Sort by chunk_index
            chunks.sort(key=lambda x: x.get("chunk_index", 0))

            logger.info(f"Improved LangChain method returned {len(chunks)} chunks for document {document_id}")
            return chunks

        except Exception as e:
            logger.error(f"Improved LangChain method failed: {e}")
            return []

    def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
        """Chunk lookup for the Chroma store."""
        try:
            from langchain_community.vectorstores import Chroma

            # Build the vector store path
            vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

            if not os.path.exists(vector_db_path):
                logger.warning(f"Vector store does not exist: {vector_db_path}")
                return []

            # Load the vector store
            vectorstore = Chroma(
                persist_directory=vector_db_path,
                embedding_function=self.embeddings
            )

            # Fetch all metadata and filter out the chunks of the requested document
            collection = vectorstore._collection
            all_docs = collection.get(include=["metadatas", "documents"])

            chunks = []
            chunk_index = 0

            for i, metadata in enumerate(all_docs["metadatas"]):
                if metadata.get("document_id") == str(document_id):
                    chunk_content = all_docs["documents"][i]

                    chunk = {
                        "id": f"chunk_{document_id}_{chunk_index}",
                        "content": chunk_content,
                        "metadata": metadata,
                        "page_number": metadata.get("page"),
                        "chunk_index": chunk_index,
                        "start_char": metadata.get("start_char"),
                        "end_char": metadata.get("end_char")
                    }
                    chunks.append(chunk)
                    chunk_index += 1

            logger.info(f"Found {len(chunks)} chunks for document {document_id}")
            return chunks

        except Exception as e:
            logger.error(f"Chroma chunk lookup failed: {e}")
            return []

    def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents within a knowledge base."""
        try:
            if settings.vector_db.type == "pgvector":
                # PostgreSQL pgvector store
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                try:
                    vector_store = PGVector(
                        connection=self.connection_string,
                        embeddings=self.embeddings,
                        collection_name=collection_name,
                        use_jsonb=True
                    )

                    # Run the similarity search
                    results = vector_store.similarity_search_with_score(query, k=k)

                    # Format the results
                    formatted_results = []
                    for doc, distance_score in results:
                        # pgvector uses cosine distance: the smaller the distance, the more similar.
                        # Convert the distance into a similarity score between 0 and 1.
                        similarity_score = 1.0 / (1.0 + distance_score)

                        formatted_results.append({
                            "content": doc.page_content,
                            "metadata": doc.metadata,
                            "similarity_score": distance_score,    # raw distance score
                            "normalized_score": similarity_score,  # normalized similarity score
                            "source": doc.metadata.get('filename', 'unknown'),
                            "document_id": doc.metadata.get('document_id', 'unknown'),
                            "chunk_id": doc.metadata.get('chunk_id', 'unknown')
                        })

                    # Sort by distance (smaller distance means more similar)
                    formatted_results.sort(key=lambda x: x['similarity_score'])

                    logger.info(f"PostgreSQL pgvector search finished; found {len(formatted_results)} related documents")
                    return formatted_results

                except Exception as e:
                    logger.warning(f"PostgreSQL pgvector store does not exist: {collection_name}, {str(e)}")
                    return []
            else:
                # Chroma-compatible mode
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                if not os.path.exists(kb_vector_path):
                    logger.warning(f"Vector store does not exist: {kb_vector_path}")
                    return []

                # Load the vector store
                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                # Run the similarity search
                results = vector_store.similarity_search_with_score(query, k=k)

                # Format the results
                formatted_results = []
                for doc, distance_score in results:
                    # Chroma uses Euclidean distance: the smaller the distance, the more similar.
                    # Convert the distance into a similarity score between 0 and 1.
                    similarity_score = 1.0 / (1.0 + distance_score)

                    formatted_results.append({
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                        "similarity_score": distance_score,    # raw distance score
                        "normalized_score": similarity_score,  # normalized similarity score
                        "source": doc.metadata.get('filename', 'unknown'),
                        "document_id": doc.metadata.get('document_id', 'unknown'),
                        "chunk_id": doc.metadata.get('chunk_id', 'unknown')
                    })

                # Sort by distance (smaller distance means more similar)
                formatted_results.sort(key=lambda x: x['similarity_score'])

                logger.info(f"Search finished; found {len(formatted_results)} related documents")
                return formatted_results

        except Exception as e:
            logger.error(f"Failed to search documents: {str(e)}")
            return []  # return an empty list instead of raising
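
    # Worked example of the distance-to-similarity conversion above: with
    # distance 0.0 the normalized score is 1.0 / (1.0 + 0.0) = 1.0, with
    # distance 1.0 it is 0.5, and with distance 3.0 it is 0.25, so smaller
    # distances map to higher normalized scores.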


# Global document processor instance (lazily initialized)
document_processor = None

def get_document_processor():
    """Get the document processor instance (lazily initialized)."""
    global document_processor
    if document_processor is None:
        document_processor = DocumentProcessor()
    return document_processor
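

# Minimal end-to-end usage sketch, assuming a fully configured environment and
# an existing knowledge base. The file path and IDs below are placeholders,
# not values defined by this module. Because this module uses relative
# imports, run it with `python -m <your_package>.document_processor` (package
# path is an assumption) rather than as a standalone script:
if __name__ == "__main__":
    processor = get_document_processor()
    # Ingest one document into knowledge base 1 (hypothetical IDs and path)
    outcome = processor.process_document(document_id=1, file_path="docs/example.txt", knowledge_base_id=1)
    print(outcome)
    # Query the same knowledge base
    hits = processor.search_similar_documents(knowledge_base_id=1, query="example topic", k=3)
    for hit in hits:
        print(hit["chunk_id"], hit["normalized_score"])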