2026-01-21 13:45:39 +08:00
|
|
|
|
"""文档处理服务,负责文档的分段、向量化和索引"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from fastapi import HTTPException
|
|
|
|
|
|
from requests import Session
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
|
|
|
|
from langchain_community.document_loaders import (
|
|
|
|
|
|
TextLoader,
|
|
|
|
|
|
PyPDFLoader,
|
|
|
|
|
|
Docx2txtLoader,
|
|
|
|
|
|
UnstructuredMarkdownLoader
|
|
|
|
|
|
)
|
|
|
|
|
|
import pdfplumber
|
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
from langchain_postgres import PGVector
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..core.config import BaseSettings, get_settings
|
|
|
|
|
|
from ..models.knowledge_base import Document as DocumentModel
|
|
|
|
|
|
from ..db.database import get_session
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
settings = get_settings()
|
|
|
|
|
|
class DocumentProcessor:
|
|
|
|
|
|
"""文档处理器,负责文档的加载、分段和向量化"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
# 初始化语义分割器配置
|
|
|
|
|
|
self.embeddings = None
|
|
|
|
|
|
self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled
|
|
|
|
|
|
self.text_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
|
|
chunk_size=settings.file.chunk_size,
|
|
|
|
|
|
chunk_overlap=settings.file.chunk_overlap,
|
|
|
|
|
|
length_function=len,
|
|
|
|
|
|
separators=["\n\n", "\n", " ", ""]
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def initialize(self, session: Session = None):
|
2026-01-26 17:46:07 +08:00
|
|
|
|
# 先设置向量数据库路径(即使后续初始化失败,路径也应该被设置)
|
|
|
|
|
|
# 向量数据库存储路径(Chroma兼容)
|
|
|
|
|
|
vector_db_path = settings.vector_db.persist_directory
|
|
|
|
|
|
if not os.path.isabs(vector_db_path):
|
|
|
|
|
|
# 如果是相对路径,则基于项目根目录计算绝对路径
|
|
|
|
|
|
# 项目根目录是backend的父目录
|
|
|
|
|
|
backend_dir = Path(__file__).parent.parent.parent
|
|
|
|
|
|
vector_db_path = str(backend_dir / vector_db_path)
|
|
|
|
|
|
self.vector_db_path = vector_db_path
|
|
|
|
|
|
if session:
|
|
|
|
|
|
session.desc = f"初始化向量数据库 - 路径 = {self.vector_db_path}"
|
|
|
|
|
|
|
2026-01-21 13:45:39 +08:00
|
|
|
|
# 初始化嵌入模型 - 根据配置选择提供商
|
|
|
|
|
|
await self._init_embeddings(session)
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化连接池(仅对PGVector)
|
|
|
|
|
|
self.pgvector_pool = None
|
|
|
|
|
|
|
|
|
|
|
|
# PostgreSQL pgvector连接配置
|
|
|
|
|
|
# if settings.vector_db.type == "pgvector":
|
|
|
|
|
|
# # 新版本PGVector使用psycopg3连接字符串
|
|
|
|
|
|
# # 对密码进行URL编码以处理特殊字符(如@符号)
|
|
|
|
|
|
# encoded_password = quote(settings.vector_db.pgvector_password, safe="")
|
|
|
|
|
|
# self.connection_string = (
|
|
|
|
|
|
# f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
|
|
|
|
|
|
# f"{encoded_password}@"
|
|
|
|
|
|
# f"{settings.vector_db.pgvector_host}:"
|
|
|
|
|
|
# f"{settings.vector_db.pgvector_port}/"
|
|
|
|
|
|
# f"{settings.vector_db.pgvector_database}"
|
|
|
|
|
|
# )
|
|
|
|
|
|
# # 初始化连接池
|
|
|
|
|
|
# self.pgvector_pool = PGVectorConnectionPool()
|
|
|
|
|
|
# logger.info("新版本PGVector使用psycopg3连接字符串: %s", self.connection_string)
|
|
|
|
|
|
|
|
|
|
|
|
    async def _init_embeddings(self, session: Optional[Any] = None):
        """Create the embedding model from the default LLM configuration.

        Lazily initialises ``self.embeddings`` (no-op if already set) by
        looking up the default embedding config through LLMConfigService and
        instantiating the matching provider class (openai / zhipu / ollama /
        local HuggingFace).  Unknown providers fall back to OpenAIEmbeddings
        with an error note written to ``session.desc``.

        Args:
            session: optional tracing object; progress/errors are reported via
                its ``desc`` attribute.  It is also passed to the config lookup.

        Returns:
            The (possibly pre-existing) embeddings instance.

        Raises:
            HTTPException: 400 when no embedding configuration can be found.
        """
        try:
            if not self.embeddings:
                # Use llm_config_service to obtain the embedding configuration.
                from .llm_config_service import LLMConfigService
                llm_config_service = LLMConfigService()

                # Fetch the default embedding config (requires a session).
                config = None
                if session:
                    config = await llm_config_service.get_default_embedding_config(session)
                if config and session:
                    session.desc = f"获取默认嵌入模型配置: {config}"

                # Without a configuration we cannot build any embedding model.
                if not config:
                    if session:
                        session.desc = f"ERROR: 未找到嵌入模型配置"
                    raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
                if session:
                    session.desc = f"获取嵌入模型配置 > 结果:{config}"

                # Instantiate the provider-specific embedding implementation.
                if config.provider == "openai":
                    from langchain_openai import OpenAIEmbeddings
                    self.embeddings = OpenAIEmbeddings(
                        model=config.model_name or "text-embedding-3-small",
                        api_key=config.api_key
                    )
                    if session:
                        session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.model_name or 'text-embedding-3-small'})"
                elif config.provider == "zhipu":
                    from .zhipu_embeddings import ZhipuOpenAIEmbeddings
                    self.embeddings = ZhipuOpenAIEmbeddings(
                        api_key=config.api_key,
                        base_url=config.base_url or "https://open.bigmodel.cn/api/paas/v4",
                        model=config.model_name or "embedding-3",
                        dimensions=settings.vector_db.embedding_dimension
                    )
                    if session:
                        session.desc = f"创建嵌入模型 - ZhipuOpenAIEmbeddings(model={config.model_name or 'embedding-3'}, base_url={config.base_url})"
                elif config.provider == "ollama":
                    from langchain_ollama import OllamaEmbeddings
                    self.embeddings = OllamaEmbeddings(
                        model=config.model_name,
                        base_url=config.base_url
                    )
                    if session:
                        session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
                elif config.provider == "local":
                    from langchain_huggingface import HuggingFaceEmbeddings
                    self.embeddings = HuggingFaceEmbeddings(
                        model_name=config.model_name or "sentence-transformers/all-MiniLM-L6-v2"
                    )
                    if session:
                        session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.model_name or 'sentence-transformers/all-MiniLM-L6-v2'})"
                else:
                    # Unknown provider: fall back to OpenAI and flag the problem.
                    from langchain_openai import OpenAIEmbeddings
                    self.embeddings = OpenAIEmbeddings(
                        model=config.model_name or "text-embedding-3-small",
                        api_key=config.api_key
                    )
                    if session:
                        session.desc = f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"

            return self.embeddings
        except Exception as e:
            logger.error(f"初始化嵌入模型时出错: {e}")
            raise
|
|
|
|
|
|
|
|
|
|
|
|
def load_document(self, session: Session, file_path: str) -> List[Document]:
|
|
|
|
|
|
"""根据文件类型加载文档"""
|
|
|
|
|
|
file_extension = Path(file_path).suffix.lower()
|
|
|
|
|
|
try:
|
|
|
|
|
|
if file_extension == '.txt':
|
|
|
|
|
|
session.desc = f"加载文档 - 文件路径: {file_path} - 类型: txt"
|
|
|
|
|
|
loader = TextLoader(file_path, encoding='utf-8')
|
|
|
|
|
|
documents = loader.load()
|
|
|
|
|
|
elif file_extension == '.pdf':
|
|
|
|
|
|
# 使用pdfplumber处理PDF文件,更稳定
|
|
|
|
|
|
session.desc = f"加载文档 - 文件路径: {file_path} - 类型: pdf"
|
|
|
|
|
|
from langchain_community.document_loaders import PyPDFLoader
|
|
|
|
|
|
loader = PyPDFLoader(file_path)
|
|
|
|
|
|
documents = loader.load()
|
|
|
|
|
|
# documents = self._load_pdf_with_pdfplumber(file_path)
|
|
|
|
|
|
elif file_extension == '.docx':
|
|
|
|
|
|
session.desc = f"加载文档 - 文件路径: {file_path} - 类型: docx"
|
|
|
|
|
|
loader = Docx2txtLoader(file_path)
|
|
|
|
|
|
documents = loader.load()
|
|
|
|
|
|
elif file_extension == '.md':
|
|
|
|
|
|
session.desc = f"加载文档 - 文件路径: {file_path} - 类型: md"
|
|
|
|
|
|
loader = UnstructuredMarkdownLoader(file_path)
|
|
|
|
|
|
documents = loader.load()
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise ValueError(f"不支持的文件类型: {file_extension}")
|
|
|
|
|
|
|
|
|
|
|
|
session.desc = f"已载文档: {file_path}, 页数: {len(documents)}"
|
|
|
|
|
|
# if len(documents) > 0:
|
|
|
|
|
|
# session.desc = f"文档内容示例: {type(documents[0])} - {documents[0]}"
|
|
|
|
|
|
return documents
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
session.desc = f"ERROR: 加载文档失败 {file_path}: {str(e)}"
|
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
|
|
def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]:
|
|
|
|
|
|
"""使用pdfplumber加载PDF文档"""
|
|
|
|
|
|
documents = []
|
|
|
|
|
|
try:
|
|
|
|
|
|
with pdfplumber.open(file_path) as pdf:
|
|
|
|
|
|
for page_num, page in enumerate(pdf.pages):
|
|
|
|
|
|
text = page.extract_text()
|
|
|
|
|
|
if text and text.strip(): # 只处理有文本内容的页面
|
|
|
|
|
|
doc = Document(
|
|
|
|
|
|
page_content=text,
|
|
|
|
|
|
metadata={
|
|
|
|
|
|
"source": file_path,
|
|
|
|
|
|
"page": page_num + 1
|
|
|
|
|
|
}
|
|
|
|
|
|
)
|
|
|
|
|
|
documents.append(doc)
|
|
|
|
|
|
return documents
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"使用pdfplumber加载PDF失败 {file_path}: {str(e)}")
|
|
|
|
|
|
# 如果pdfplumber失败,回退到PyPDFLoader
|
|
|
|
|
|
try:
|
|
|
|
|
|
loader = PyPDFLoader(file_path)
|
|
|
|
|
|
return loader.load()
|
|
|
|
|
|
except Exception as fallback_e:
|
|
|
|
|
|
logger.error(f"PyPDFLoader回退也失败 {file_path}: {str(fallback_e)}")
|
|
|
|
|
|
raise fallback_e
|
|
|
|
|
|
|
|
|
|
|
|
def _merge_documents(self, documents: List[Document]) -> Document:
|
|
|
|
|
|
"""将多个文档合并成一个文档"""
|
|
|
|
|
|
merged_text = ""
|
|
|
|
|
|
merged_metadata = {}
|
|
|
|
|
|
|
|
|
|
|
|
for doc in documents:
|
|
|
|
|
|
if merged_text:
|
|
|
|
|
|
merged_text += "\n\n"
|
|
|
|
|
|
merged_text += doc.page_content
|
|
|
|
|
|
# 合并元数据
|
|
|
|
|
|
merged_metadata.update(doc.metadata)
|
|
|
|
|
|
|
|
|
|
|
|
return Document(page_content=merged_text, metadata=merged_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]:
|
|
|
|
|
|
"""根据语义分割点切分文本"""
|
|
|
|
|
|
chunks = []
|
|
|
|
|
|
current_pos = 0
|
|
|
|
|
|
|
|
|
|
|
|
# 按顺序查找每个分割点并切分文本
|
|
|
|
|
|
for point in split_points:
|
|
|
|
|
|
pos = text.find(point, current_pos)
|
|
|
|
|
|
if pos != -1:
|
|
|
|
|
|
# 添加当前位置到分割点位置的文本块
|
|
|
|
|
|
if pos > current_pos:
|
|
|
|
|
|
chunk = text[current_pos:pos].strip()
|
|
|
|
|
|
if chunk:
|
|
|
|
|
|
chunks.append(chunk)
|
|
|
|
|
|
current_pos = pos
|
|
|
|
|
|
|
|
|
|
|
|
# 添加最后一个文本块
|
|
|
|
|
|
if current_pos < len(text):
|
|
|
|
|
|
chunk = text[current_pos:].strip()
|
|
|
|
|
|
if chunk:
|
|
|
|
|
|
chunks.append(chunk)
|
|
|
|
|
|
|
|
|
|
|
|
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
    async def split_documents(self, session: Session, documents: List[Document]) -> List[Document]:
        """Split loaded documents into chunks with the configured text splitter.

        NOTE(review): the original docstring claimed short-paragraph merging
        and forced splitting of over-long chunks, but only
        RecursiveCharacterTextSplitter.split_documents is applied here;
        _force_split_long_chunk is not called by this method — confirm intent.

        Args:
            session: tracing object; progress/errors are written to ``desc``.
            documents: documents produced by load_document().

        Returns:
            The list of chunk Documents.

        Raises:
            Exception: re-raised from the splitter on failure.
        """
        try:
            chunks = self.text_splitter.split_documents(documents)

            # Report progress (and a sample chunk) through the tracing session.
            session.desc = f"文档分割完成,共生成 {len(chunks)} 个文档块"
            if len(chunks) > 0:
                session.desc = f"文档块内容示例: {type(chunks[0])} - {chunks[0]}"
            return chunks

        except Exception as e:
            session.desc = f"ERROR: 文档分割失败: {str(e)}"
            raise e
|
|
|
|
|
|
|
|
|
|
|
|
def _force_split_long_chunk(self, chunk: str) -> List[str]:
|
|
|
|
|
|
"""强制分割超长段落(超过1000字符)"""
|
|
|
|
|
|
max_length = 1000
|
|
|
|
|
|
chunks = []
|
|
|
|
|
|
|
|
|
|
|
|
# 先尝试按换行符分割
|
|
|
|
|
|
if '\n' in chunk:
|
|
|
|
|
|
lines = chunk.split('\n')
|
|
|
|
|
|
current_chunk = ""
|
|
|
|
|
|
for line in lines:
|
|
|
|
|
|
if len(current_chunk) + len(line) + 1 > max_length:
|
|
|
|
|
|
if current_chunk:
|
|
|
|
|
|
chunks.append(current_chunk)
|
|
|
|
|
|
current_chunk = line
|
|
|
|
|
|
else:
|
|
|
|
|
|
chunks.append(line[:max_length])
|
|
|
|
|
|
current_chunk = line[max_length:]
|
|
|
|
|
|
else:
|
|
|
|
|
|
if current_chunk:
|
|
|
|
|
|
current_chunk += "\n" + line
|
|
|
|
|
|
else:
|
|
|
|
|
|
current_chunk = line
|
|
|
|
|
|
if current_chunk:
|
|
|
|
|
|
chunks.append(current_chunk)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 没有换行符则直接按长度分割
|
|
|
|
|
|
chunks = [chunk[i:i + max_length] for i in range(0, len(chunk), max_length)]
|
|
|
|
|
|
|
|
|
|
|
|
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str:
|
|
|
|
|
|
"""为知识库创建向量存储"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# if settings.vector_db.type == "pgvector":
|
|
|
|
|
|
# # 添加元数据
|
|
|
|
|
|
# for i, doc in enumerate(documents):
|
|
|
|
|
|
# doc.metadata.update({
|
|
|
|
|
|
# "knowledge_base_id": knowledge_base_id,
|
|
|
|
|
|
# "document_id": str(document_id) if document_id else "unknown",
|
|
|
|
|
|
# "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|
|
|
|
|
# "chunk_index": i
|
|
|
|
|
|
# })
|
|
|
|
|
|
|
|
|
|
|
|
# # 创建PostgreSQL pgvector存储
|
|
|
|
|
|
# collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
|
|
|
|
|
|
|
|
|
|
|
# # 创建新版本PGVector实例
|
|
|
|
|
|
# vector_store = PGVector(
|
|
|
|
|
|
# connection=self.connection_string,
|
|
|
|
|
|
# embeddings=self.embeddings,
|
|
|
|
|
|
# collection_name=collection_name,
|
|
|
|
|
|
# use_jsonb=True # 使用JSONB存储元数据
|
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
|
|
|
|
# # 手动添加文档
|
|
|
|
|
|
# vector_store.add_documents(documents)
|
|
|
|
|
|
|
|
|
|
|
|
# logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}")
|
|
|
|
|
|
# return collection_name
|
|
|
|
|
|
# else:
|
|
|
|
|
|
# Chroma兼容模式
|
2026-01-26 17:46:07 +08:00
|
|
|
|
# 检查 vector_db_path 和 embeddings 是否已初始化
|
|
|
|
|
|
if not hasattr(self, 'vector_db_path') or not self.vector_db_path:
|
|
|
|
|
|
error_msg = "DocumentProcessor 未正确初始化:vector_db_path 未设置"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
if not hasattr(self, 'embeddings') or not self.embeddings:
|
|
|
|
|
|
error_msg = "DocumentProcessor 未正确初始化:embeddings 未设置"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
|
2026-01-21 13:45:39 +08:00
|
|
|
|
from langchain_chroma import Chroma
|
|
|
|
|
|
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|
|
|
|
|
|
|
|
|
|
|
# 添加元数据
|
|
|
|
|
|
for i, doc in enumerate(documents):
|
|
|
|
|
|
doc.metadata.update({
|
|
|
|
|
|
"knowledge_base_id": knowledge_base_id,
|
|
|
|
|
|
"document_id": str(document_id) if document_id else "unknown",
|
|
|
|
|
|
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
|
|
|
|
|
|
"chunk_index": i
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
# 创建向量存储
|
|
|
|
|
|
vector_store = Chroma.from_documents(
|
|
|
|
|
|
documents=documents,
|
|
|
|
|
|
embedding=self.embeddings,
|
|
|
|
|
|
persist_directory=kb_vector_path
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"向量存储创建成功: {kb_vector_path}")
|
|
|
|
|
|
return kb_vector_path
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"创建向量存储失败: {str(e)}")
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
    def add_documents_to_vector_store(self, session: Session, knowledge_base_id: int, documents: List[Document], document_id: Optional[int] = None) -> None:
        """Append already-split chunk Documents to a knowledge base's store.

        If the per-KB store directory does not exist yet, it is created from
        this batch instead.  Progress is reported through ``session.desc``.

        Args:
            session: tracing object; progress/errors are written to ``desc``.
            knowledge_base_id: owning knowledge base.
            documents: chunk Documents to add.
            document_id: source document id recorded in chunk metadata.

        Raises:
            ValueError: if initialize() has not set vector_db_path.
        """
        # Nothing to do for an empty batch.
        if len(documents) == 0:
            session.desc = f"WARNING: 文档列表为空,不执行添加操作"
            return

        # Guard: vector_db_path must have been set by initialize().
        if not hasattr(self, 'vector_db_path') or not self.vector_db_path:
            error_msg = "DocumentProcessor 未正确初始化:vector_db_path 未设置"
            session.desc = f"ERROR: {error_msg}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        from langchain_chroma import Chroma

        kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
        session.desc = f"添加文档到向量存储: {kb_vector_path} - documents number: {len(documents)}"
        # If the store does not exist yet, build it from this batch instead.
        if not os.path.exists(kb_vector_path):
            session.desc = f"WARNING: 向量存储不存在,创建新的向量存储"
            self.create_vector_store(knowledge_base_id, documents, document_id)
            return
        session.desc = f"添加文档到向量存储: exists"
        # Tag every chunk so it can later be filtered or deleted by document.
        for i, doc in enumerate(documents):
            doc.metadata.update({
                "knowledge_base_id": knowledge_base_id,
                "document_id": str(document_id) if document_id else "unknown",
                "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                "chunk_index": i
            })

        session.desc = f"添加文档到向量存储: enumerate"
        # Open the existing store ...
        vector_store = Chroma(
            persist_directory=kb_vector_path,
            embedding_function=self.embeddings
        )

        session.desc = f"添加文档到向量存储: Chroma"
        # ... and append the new chunks (embedding happens inside add_documents).
        ids = vector_store.add_documents(documents)
        session.desc = f"文档已添加到向量存储: {kb_vector_path} -> {len(ids)} IDS - \n{ids}"
|
|
|
|
|
|
|
|
|
|
|
|
async def process_document(self, session: Session, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
|
|
|
|
|
|
"""处理单个文档:加载、分段、向量化"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
session.desc = f"处理文档 ID: {document_id} 文件路径: {file_path}"
|
|
|
|
|
|
|
|
|
|
|
|
# 1. 加载文档
|
|
|
|
|
|
documents = self.load_document(session, file_path)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 分割文档
|
|
|
|
|
|
chunks = await self.split_documents(session, documents)
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 添加到向量存储
|
|
|
|
|
|
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
|
|
|
|
|
|
|
|
|
|
|
|
# 4. 更新文档状态
|
2026-01-22 16:18:29 +08:00
|
|
|
|
# Python 3.9 兼容:使用 async for 替代 anext
|
|
|
|
|
|
async for db_session in get_session():
|
|
|
|
|
|
try:
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
|
|
|
|
|
|
|
|
|
|
|
if document:
|
|
|
|
|
|
document.is_processed = True
|
|
|
|
|
|
document.chunk_count = len(chunks)
|
|
|
|
|
|
await db_session.commit()
|
|
|
|
|
|
finally:
|
|
|
|
|
|
await db_session.close()
|
|
|
|
|
|
break # 只取第一个 session
|
2026-01-21 13:45:39 +08:00
|
|
|
|
|
|
|
|
|
|
result = {
|
|
|
|
|
|
"document_id": document_id,
|
|
|
|
|
|
"status": "success",
|
|
|
|
|
|
"chunks_count": len(chunks),
|
|
|
|
|
|
"message": "文档处理完成"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session.desc = f"文档处理完成: {result}"
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
session.desc = f"ERROR: 文档处理失败 ID: {document_id}: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
# 更新文档状态为失败
|
|
|
|
|
|
try:
|
2026-01-22 16:18:29 +08:00
|
|
|
|
# Python 3.9 兼容:使用 async for 替代 anext
|
|
|
|
|
|
async for db_session in get_session():
|
|
|
|
|
|
try:
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
|
|
|
|
|
if document:
|
|
|
|
|
|
document.is_processed = False
|
|
|
|
|
|
document.processing_error = str(e)
|
|
|
|
|
|
await db_session.commit()
|
|
|
|
|
|
finally:
|
|
|
|
|
|
await db_session.close()
|
|
|
|
|
|
break # 只取第一个 session
|
2026-01-21 13:45:39 +08:00
|
|
|
|
except Exception as db_error:
|
|
|
|
|
|
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
"document_id": document_id,
|
|
|
|
|
|
"status": "failed",
|
|
|
|
|
|
"error": str(e),
|
|
|
|
|
|
"message": "文档处理失败"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
    def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None:
        """Delete all vectors belonging to *document_id* from the KB's store.

        No-op (with a warning) when the store directory does not exist, and
        deletion is only attempted when the document actually has chunks.

        Raises:
            ValueError: if initialize() has not set vector_db_path/embeddings.
            Exception: re-raised from the vector store on failure.
        """
        try:
            # Guard: both attributes are populated by initialize().
            if not hasattr(self, 'vector_db_path') or not self.vector_db_path:
                error_msg = "DocumentProcessor 未正确初始化:vector_db_path 未设置"
                logger.error(error_msg)
                raise ValueError(error_msg)
            if not hasattr(self, 'embeddings') or not self.embeddings:
                error_msg = "DocumentProcessor 未正确初始化:embeddings 未设置"
                logger.error(error_msg)
                raise ValueError(error_msg)

            # Chroma-compatible mode.
            from langchain_chroma import Chroma
            kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

            if not os.path.exists(kb_vector_path):
                logger.warning(f"向量存储不存在: {kb_vector_path}")
                return

            # Look up the chunks first so we only delete when vectors exist.
            chunks = self.get_document_chunks(knowledge_base_id, document_id)
            # Open the existing store.
            vector_store = Chroma(
                persist_directory=kb_vector_path,
                embedding_function=self.embeddings
            )

            # NOTE(review): counts go through Chroma's private _collection;
            # there is no public count API on the LangChain wrapper.
            count_before = vector_store._collection.count()
            count_after = count_before

            if len(chunks) > 0:
                # Delete by metadata filter on the stringified document_id
                # (chunk metadata stores document_id as a string).
                where_filter = {"document_id": str(document_id)}
                vector_store.delete(where=where_filter)
                count_after = vector_store._collection.count()

            # Note: Chroma's delete semantics may need a backend-specific filter form.
            logger.info(f"文档已从向量存储中删除: document_id={document_id},删除前有 {count_before} 个向量,删除后有 {count_after} 个向量")

        except Exception as e:
            logger.error(f"从向量存储删除文档失败: {str(e)}")
            raise
|
|
|
|
|
|
|
|
|
|
|
|
def get_document_chunks(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
|
|
|
|
|
|
"""获取文档的所有分段内容
|
|
|
|
|
|
|
|
|
|
|
|
改进说明:
|
|
|
|
|
|
- 避免使用空查询进行相似性搜索,防止触发不必要的embedding API调用
|
|
|
|
|
|
- 优先使用直接SQL查询,提高性能
|
|
|
|
|
|
- 确保结果按chunk_index排序
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
return self._get_chunks_chroma(knowledge_base_id, document_id)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"获取文档分段失败 document_id: {document_id}, kb_id: {knowledge_base_id}: {str(e)}")
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
    def _get_chunks_by_sql(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
        """Fetch document chunks via direct SQL against langchain_pg_embedding.

        Avoids similarity search (and the embedding API calls it would
        trigger); rows are ordered by the numeric chunk_index in cmetadata.

        NOTE(review): initialize() currently sets self.pgvector_pool = None
        (the pool setup there is commented out), so as written this method
        always logs an error and returns [] — confirm before relying on it.

        Returns:
            List of chunk dicts, or [] when the pool is missing or the query fails.
        """
        try:
            if not self.pgvector_pool:
                logger.error("PGVector连接池未初始化")
                return []

            # Direct SQL: no similarity search, no embedding computation.
            query = """
            SELECT
                id,
                document,
                cmetadata
            FROM langchain_pg_embedding
            WHERE cmetadata->>'document_id' = :document_id
            AND cmetadata->>'knowledge_base_id' = :knowledge_base_id
            ORDER BY
                CAST(cmetadata->>'chunk_index' AS INTEGER) ASC;
            """

            # Run the query through the pooled session.
            session = self.pgvector_pool.get_session()
            try:
                result = session.execute(
                    text(query),
                    {
                        'document_id': str(document_id),
                        'knowledge_base_id': str(knowledge_base_id)
                    }
                )
                results = result.fetchall()

                chunks = []
                for row in results:
                    # SQLAlchemy row attribute access; cmetadata is JSONB.
                    metadata = row.cmetadata
                    chunk = {
                        "id": f"chunk_{document_id}_{metadata.get('chunk_index', 0)}",
                        "content": row.document,
                        "metadata": metadata,
                        "page_number": metadata.get("page"),
                        "chunk_index": metadata.get("chunk_index", 0),
                        "start_char": metadata.get("start_char"),
                        "end_char": metadata.get("end_char")
                    }
                    chunks.append(chunk)

                logger.info(f"通过SQLAlchemy连接池查询获取到文档 {document_id} 的 {len(chunks)} 个分段")
                return chunks

            finally:
                session.close()

        except Exception as e:
            logger.error(f"SQLAlchemy连接池查询失败: {e}")
            return []
|
|
|
|
|
|
|
|
|
|
|
|
    def _get_chunks_by_langchain_improved(self, knowledge_base_id: int, document_id: int, collection_name: str) -> List[Dict[str, Any]]:
        """Fallback chunk lookup via PGVector similarity search.

        Uses a non-empty query string to avoid embedding-API errors on empty
        input, then filters/sorts the hits by document_id and chunk_index.

        NOTE(review): self.connection_string is never assigned in the visible
        code (its setup in initialize() is commented out), so calling this
        would raise AttributeError as things stand — confirm before use.

        Returns:
            List of chunk dicts sorted by chunk_index, or [] on any failure.
        """
        try:
            vector_store = PGVector(
                connection=self.connection_string,
                embeddings=self.embeddings,
                collection_name=collection_name,
                use_jsonb=True
            )

            # Use a meaningful query rather than an empty string, which some
            # embedding providers reject.  First probe with a small k to find
            # representative content for the real query.
            try:
                sample_results = vector_store.similarity_search(
                    query="文档内容",  # generic probe term, never empty
                    k=5,
                    filter={"document_id": {"$eq": str(document_id)}}
                )

                if sample_results:
                    # Re-query with a snippet of actual document content.
                    first_content = sample_results[0].page_content[:50]
                    results = vector_store.similarity_search(
                        query=first_content,
                        k=1000,
                        filter={"document_id": {"$eq": str(document_id)}}
                    )
                else:
                    # No filtered hits: query without the filter ...
                    results = vector_store.similarity_search(
                        query="文档",
                        k=1000
                    )
                    # ... and filter manually on document_id.
                    results = [doc for doc in results if doc.metadata.get("document_id") == str(document_id)]

            except Exception as e:
                logger.warning(f"改进的相似性搜索失败: {e}")
                return []

            chunks = []
            for i, doc in enumerate(results):
                chunk = {
                    "id": f"chunk_{document_id}_{i}",
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "page_number": doc.metadata.get("page"),
                    "chunk_index": doc.metadata.get("chunk_index", i),
                    "start_char": doc.metadata.get("start_char"),
                    "end_char": doc.metadata.get("end_char")
                }
                chunks.append(chunk)

            # Restore document order regardless of similarity ranking.
            chunks.sort(key=lambda x: x.get("chunk_index", 0))

            logger.info(f"通过改进的LangChain方法获取到文档 {document_id} 的 {len(chunks)} 个分段")
            return chunks

        except Exception as e:
            logger.error(f"LangChain改进方法失败: {e}")
            return []
|
|
|
|
|
|
|
|
|
|
|
|
def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
|
|
|
|
|
|
"""Chroma存储的处理逻辑"""
|
2026-01-26 17:46:07 +08:00
|
|
|
|
# 检查 vector_db_path 和 embeddings 是否已初始化
|
|
|
|
|
|
if not hasattr(self, 'vector_db_path') or not self.vector_db_path:
|
|
|
|
|
|
error_msg = "DocumentProcessor 未正确初始化:vector_db_path 未设置"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
if not hasattr(self, 'embeddings') or not self.embeddings:
|
|
|
|
|
|
error_msg = "DocumentProcessor 未正确初始化:embeddings 未设置"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
|
2026-01-21 13:45:39 +08:00
|
|
|
|
from langchain_chroma import Chroma
|
|
|
|
|
|
# 构建向量数据库路径
|
|
|
|
|
|
vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
|
|
|
|
|
|
|
2026-01-26 17:46:07 +08:00
|
|
|
|
logger.info(f"获取文档块: kb_id={knowledge_base_id}, doc_id={document_id}, vector_db_path={vector_db_path}")
|
|
|
|
|
|
|
2026-01-21 13:45:39 +08:00
|
|
|
|
if not os.path.exists(vector_db_path):
|
|
|
|
|
|
logger.warning(f"向量数据库不存在: {vector_db_path}")
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
# 加载向量数据库
|
2026-01-26 17:46:07 +08:00
|
|
|
|
try:
|
|
|
|
|
|
vectorstore = Chroma(
|
|
|
|
|
|
persist_directory=vector_db_path,
|
|
|
|
|
|
embedding_function=self.embeddings
|
|
|
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"加载向量数据库失败: {vector_db_path}, 错误: {str(e)}")
|
|
|
|
|
|
return []
|
2026-01-21 13:45:39 +08:00
|
|
|
|
|
|
|
|
|
|
# 获取所有文档的元数据,筛选出指定文档的分段
|
2026-01-26 17:46:07 +08:00
|
|
|
|
try:
|
|
|
|
|
|
collection = vectorstore._collection
|
|
|
|
|
|
total_count = collection.count()
|
|
|
|
|
|
logger.info(f"向量数据库中共有 {total_count} 个向量")
|
|
|
|
|
|
|
|
|
|
|
|
if total_count == 0:
|
|
|
|
|
|
logger.warning(f"向量数据库为空: {vector_db_path}")
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
all_docs = collection.get(include=["metadatas", "documents"])
|
|
|
|
|
|
all_ids_data = collection.get()
|
|
|
|
|
|
|
|
|
|
|
|
if not all_docs or "metadatas" not in all_docs or not all_docs["metadatas"]:
|
|
|
|
|
|
logger.warning(f"向量数据库中没有元数据: {vector_db_path}")
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"获取到 {len(all_docs['metadatas'])} 个元数据项")
|
|
|
|
|
|
|
|
|
|
|
|
chunks = []
|
|
|
|
|
|
chunk_index = 0
|
|
|
|
|
|
document_id_str = str(document_id)
|
|
|
|
|
|
|
|
|
|
|
|
# 记录所有 document_id 以便调试
|
|
|
|
|
|
all_document_ids = set()
|
|
|
|
|
|
for metadata in all_docs["metadatas"]:
|
|
|
|
|
|
doc_id = metadata.get("document_id")
|
|
|
|
|
|
if doc_id:
|
|
|
|
|
|
all_document_ids.add(str(doc_id))
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"向量数据库中的所有 document_id: {sorted(all_document_ids)}")
|
|
|
|
|
|
logger.info(f"正在查找 document_id: {document_id_str} (类型: {type(document_id_str)})")
|
|
|
|
|
|
|
|
|
|
|
|
for i, metadata in enumerate(all_docs["metadatas"]):
|
|
|
|
|
|
metadata_doc_id = metadata.get("document_id")
|
|
|
|
|
|
if metadata_doc_id:
|
|
|
|
|
|
metadata_doc_id_str = str(metadata_doc_id)
|
|
|
|
|
|
if metadata_doc_id_str == document_id_str:
|
|
|
|
|
|
chunk_content = all_docs["documents"][i] if i < len(all_docs["documents"]) else ""
|
|
|
|
|
|
vector_id = all_ids_data["ids"][i] if i < len(all_ids_data["ids"]) else None
|
|
|
|
|
|
|
|
|
|
|
|
# 使用元数据中的 chunk_index,如果没有则使用递增索引
|
|
|
|
|
|
chunk_idx = metadata.get("chunk_index", chunk_index)
|
|
|
|
|
|
|
|
|
|
|
|
chunk = {
|
|
|
|
|
|
"id": f"chunk_{document_id}_{chunk_idx}",
|
|
|
|
|
|
"content": chunk_content,
|
|
|
|
|
|
"metadata": metadata,
|
|
|
|
|
|
"page_number": metadata.get("page"),
|
|
|
|
|
|
"chunk_index": chunk_idx,
|
|
|
|
|
|
"start_char": metadata.get("start_char"),
|
|
|
|
|
|
"end_char": metadata.get("end_char"),
|
|
|
|
|
|
"vector_id": vector_id
|
|
|
|
|
|
}
|
|
|
|
|
|
chunks.append(chunk)
|
|
|
|
|
|
chunk_index += 1
|
|
|
|
|
|
|
|
|
|
|
|
# 按 chunk_index 排序
|
|
|
|
|
|
chunks.sort(key=lambda x: x["chunk_index"])
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"找到 {len(chunks)} 个文档块 (document_id={document_id})")
|
|
|
|
|
|
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"获取文档块时出错: kb_id={knowledge_base_id}, doc_id={document_id}, 错误: {str(e)}", exc_info=True)
|
|
|
|
|
|
return []
|
2026-01-21 13:45:39 +08:00
|
|
|
|
|
|
|
|
|
|
def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]:
    """Search a knowledge base's vector store for documents similar to *query*.

    Args:
        knowledge_base_id: Id of the knowledge base; its Chroma store lives
            in the ``kb_<id>`` sub-directory of ``self.vector_db_path``.
        query: Natural-language query text to embed and match.
        k: Maximum number of results to return (default 5).

    Returns:
        A list of result dicts sorted by ascending raw distance (most
        similar first). Each dict carries ``content``, ``metadata``,
        ``similarity_score`` (raw distance), ``normalized_score``
        (``1 / (1 + distance)``), ``source``, ``document_id`` and
        ``chunk_id``. An empty list is returned when the processor is not
        initialized, when the store does not exist, or on any error —
        this method never raises.
    """
    try:
        # Guard clauses: initialize() must have run before searching.
        # NOTE: the previous version raised ValueError here, but the raise
        # was immediately swallowed by the enclosing ``except`` below and
        # logged a second time — returning [] directly preserves the
        # external behavior without the dead raise.
        if not getattr(self, 'vector_db_path', None):
            logger.error("DocumentProcessor 未正确初始化:vector_db_path 未设置")
            return []
        if not getattr(self, 'embeddings', None):
            logger.error("DocumentProcessor 未正确初始化:embeddings 未设置")
            return []

        # Chroma compatibility mode: one persisted store per knowledge base.
        from langchain_chroma import Chroma

        kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
        if not os.path.exists(kb_vector_path):
            logger.warning(f"向量存储不存在: {kb_vector_path}")
            return []

        # Load the persisted vector store for this knowledge base.
        vector_store = Chroma(
            persist_directory=kb_vector_path,
            embedding_function=self.embeddings
        )

        # similarity_search_with_score yields (Document, distance) pairs;
        # a smaller distance means a better match.
        results = vector_store.similarity_search_with_score(query, k=k)

        formatted_results = []
        for doc, distance_score in results:
            # Map the unbounded distance into (0, 1]; 1.0 means identical.
            similarity_score = 1.0 / (1.0 + distance_score)

            formatted_results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "similarity_score": distance_score,  # raw distance score
                "normalized_score": similarity_score,  # normalized similarity
                "source": doc.metadata.get('filename', 'unknown'),
                "document_id": doc.metadata.get('document_id', 'unknown'),
                "chunk_id": doc.metadata.get('chunk_id', 'unknown')
            })

        # Sort by ascending distance so the most similar chunks come first.
        formatted_results.sort(key=lambda x: x['similarity_score'])

        logger.info(f"搜索完成,找到 {len(formatted_results)} 个相关文档")
        return formatted_results

    except Exception as e:
        logger.error(f"搜索文档失败: {str(e)}")
        return []  # return an empty list instead of propagating the error
|
|
|
|
|
|
|
|
|
|
|
|
# Global document processor instance (lazily initialized by
# get_document_processor(); remains None until first use).
document_processor = None
|
|
|
|
|
|
|
|
|
|
|
|
async def get_document_processor(session: Session = None):
    """Return the process-wide DocumentProcessor singleton, creating and
    initializing it lazily on first use.

    Args:
        session: Optional session forwarded to ``initialize``; when given,
            its ``desc`` attribute is updated with progress / error text.
            NOTE(review): ``Session`` is imported from ``requests`` at the
            top of this file, which looks like the wrong import (a DB
            session is presumably intended) — confirm against callers.

    Returns:
        The initialized DocumentProcessor instance.

    Raises:
        ValueError: If initialization (or re-initialization) fails. The
            partially built singleton is discarded first so the next call
            can retry from scratch.
    """
    global document_processor

    if session:
        session.desc = "获取文档处理器实例"

    if document_processor is None:
        document_processor = DocumentProcessor()
        try:
            await document_processor.initialize(session)
        except Exception as e:
            # Drop the partially initialized instance so later calls retry.
            document_processor = None
            error_msg = f"DocumentProcessor 初始化失败: {str(e)}"
            if session:
                session.desc = f"ERROR: {error_msg}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e

    # Defensive check: an existing instance may still lack embeddings
    # (e.g. after a partial init). Attempt one re-initialization.
    if not getattr(document_processor, 'embeddings', None):
        error_msg = "DocumentProcessor 未正确初始化:embeddings 未设置"
        if session:
            session.desc = f"ERROR: {error_msg}"
        logger.error(error_msg)
        try:
            document_processor = DocumentProcessor()
            await document_processor.initialize(session)
        except Exception as e:
            document_processor = None
            # Fix: log the re-initialization failure and record it on the
            # session, consistent with the first failure path above (the
            # original raised silently here).
            retry_msg = f"DocumentProcessor 重新初始化失败: {str(e)}"
            if session:
                session.desc = f"ERROR: {retry_msg}"
            logger.error(retry_msg)
            raise ValueError(retry_msg) from e

    return document_processor
|