# hxf/backend/th_agenter/services/document.py
"""Document service."""
import os
from pathlib import Path
from typing import List, Optional, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from fastapi import UploadFile
from ..models.knowledge_base import Document, KnowledgeBase
from ..core.config import get_settings
from utils.util_file import FileUtils
from .storage import storage_service
from .document_processor import get_document_processor
from utils.util_schemas import DocumentChunk
from loguru import logger
settings = get_settings()
class DocumentService:
    """Service layer for managing documents inside knowledge bases.

    Handles upload, lookup, deletion, processing (text extraction and
    embedding), search and statistics for ``Document`` rows.  File bytes are
    delegated to ``storage_service``; chunking/vector work is delegated to
    the processor returned by ``get_document_processor()``.

    NOTE(review): progress/audit messages are written to ``session.desc``
    throughout — presumably a project convention for operation logging;
    confirm against the session implementation.
    """

    def __init__(self, session: Session):
        self.session = session
        self.file_utils = FileUtils()

    async def upload_document(self, file: UploadFile, kb_id: int) -> Document:
        """Upload ``file`` into knowledge base ``kb_id`` and create its DB row.

        Args:
            file: Incoming multipart upload.
            kb_id: Target knowledge base id.

        Returns:
            The freshly committed ``Document`` (not yet processed).

        Raises:
            ValueError: if the knowledge base does not exist, no filename was
                provided, or the file extension is not in the allowed list.
        """
        self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
        # The knowledge base must exist before we accept any bytes.
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        kb = self.session.scalar(stmt)
        if not kb:
            self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
            raise ValueError(f"知识库 {kb_id} 不存在")
        if not file.filename:
            self.session.desc = "ERROR: 上传文件时未提供文件名"
            raise ValueError("No filename provided")
        # Reject extensions that are not whitelisted in settings.
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in settings.file.allowed_extensions:
            self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}"
            raise ValueError(f"非期望的文件类型 {file_extension}")
        # Persist the bytes first; only then create the database record.
        storage_info = await storage_service.upload_file(file, kb_id)
        document = Document(
            knowledge_base_id=kb_id,
            filename=os.path.basename(storage_info["file_path"]),
            original_filename=file.filename,
            # Prefer the absolute path when the storage backend provides one.
            file_path=storage_info.get("full_path", storage_info["file_path"]),
            file_size=storage_info["size"],
            file_type=file_extension,
            mime_type=storage_info["mime_type"],
            is_processed=False,
        )
        # Set created/updated audit columns before committing.
        document.set_audit_fields()
        self.session.add(document)
        self.session.commit()
        self.session.refresh(document)
        self.session.desc = f"SUCCESS: 成功上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
        return document

    def get_document(self, doc_id: int, kb_id: Optional[int] = None) -> Optional[Document]:
        """Fetch a document by id, optionally restricted to one knowledge base."""
        self.session.desc = f"查询文档 {doc_id}"
        stmt = select(Document).where(Document.id == doc_id)
        if kb_id is not None:
            stmt = stmt.where(Document.knowledge_base_id == kb_id)
        return self.session.scalar(stmt)

    def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
        """Return one page of documents belonging to knowledge base ``kb_id``."""
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
        stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            .offset(skip)
            .limit(limit)
        )
        return self.session.scalars(stmt).all()

    def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
        """Return one page of documents plus the total count for ``kb_id``."""
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
        count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
        total = self.session.scalar(count_stmt)
        documents_stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            .offset(skip)
            .limit(limit)
        )
        documents = self.session.scalars(documents_stmt).all()
        return documents, total

    def delete_document(self, doc_id: int, kb_id: Optional[int] = None) -> bool:
        """Delete a document row, its stored file and its vector-store entries.

        Returns:
            True if the document existed and was deleted, False otherwise.
        """
        self.session.desc = f"删除文档 {doc_id}"
        document = self.get_document(doc_id, kb_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False
        # Best-effort file removal: a missing/locked file must not block
        # database cleanup, but the failure is now logged instead of silent.
        try:
            storage_service.delete_file(document.file_path)
            logger.info(f"Deleted file: {document.file_path}")
        except Exception as e:
            logger.warning(f"Failed to delete file {document.file_path}: {e}")
            self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
        # BUG FIX: use the document's own knowledge-base id.  The optional
        # ``kb_id`` argument may be None, which was previously passed straight
        # through to the vector store.
        get_document_processor().delete_document_from_vector_store(
            document.knowledge_base_id, doc_id
        )
        self.session.delete(document)
        self.session.commit()
        self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
        return True

    async def process_document(self, doc_id: int, kb_id: Optional[int] = None) -> Dict[str, Any]:
        """Process a document: extract text, chunk it and create embeddings.

        Returns a status dict from the processor; on failure, returns a dict
        with ``status == "failed"`` and records the error on the document.
        """
        try:
            self.session.desc = f"处理文档 {doc_id}"
            document = self.get_document(doc_id, kb_id)
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                raise ValueError(f"Document {doc_id} not found")
            if document.is_processed:
                self.session.desc = f"INFO: 文档 {doc_id} 已处理"
                return {
                    "document_id": doc_id,
                    "status": "already_processed",
                    "message": "文档已处理"
                }
            # Clear any stale error from a previous failed attempt.
            document.processing_error = None
            self.session.commit()
            # Delegate extraction/chunking/embedding to the document processor.
            result = get_document_processor().process_document(
                document_id=doc_id,
                file_path=document.file_path,
                knowledge_base_id=document.knowledge_base_id
            )
            self.session.desc = f"SUCCESS: 成功处理文档 {doc_id}"
            # Only mark the row processed when the processor reports success.
            if result["status"] == "success":
                document.is_processed = True
                document.chunk_count = result.get("chunks_count", 0)
                self.session.commit()
                self.session.refresh(document)
                logger.info(f"Processed document: {document.filename} (ID: {doc_id})")
            return result
        except Exception as e:
            self.session.rollback()
            self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
            # Best-effort: persist the error message on the document row.
            try:
                document = self.get_document(doc_id)
                if document:
                    document.processing_error = str(e)
                    self.session.commit()
            except Exception as db_error:
                logger.error(f"Failed to update document error status: {db_error}")
            return {
                "document_id": doc_id,
                "status": "failed",
                "error": str(e),
                "message": "文档处理失败"
            }

    async def _extract_text(self, document: Document) -> str:
        """Extract text content from the document's file.

        Raises:
            ValueError: for unsupported file types.
        """
        try:
            if document.is_text_file:
                # Plain-text files are read directly as UTF-8.
                with open(document.file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            elif document.is_pdf_file:
                # TODO: Implement PDF text extraction (e.g. PyPDF2);
                # placeholder content for now.
                return f"PDF content from {document.original_filename}"
            elif document.is_office_file:
                # TODO: Implement Office extraction (python-docx, openpyxl, ...);
                # placeholder content for now.
                return f"Office document content from {document.original_filename}"
            else:
                self.session.desc = f"ERROR: 不支持的文件类型: {document.file_type}"
                raise ValueError(f"不支持的文件类型: {document.file_type}")
        except Exception as e:
            self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
            raise

    def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
        """Update a document's processed flag and (optionally) its error text."""
        self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
        document = self.get_document(doc_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False
        document.is_processed = is_processed
        document.processing_error = error
        self.session.commit()
        self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
        return True

    def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Vector-similarity search within one knowledge base.

        Returns an empty list on any processor failure (best-effort search).
        """
        try:
            self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query}"
            results = get_document_processor().search_similar_documents(kb_id, query, limit)
            self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
            return results
        except Exception as e:
            self.session.desc = f"EXCEPTION: 搜索知识库 {kb_id} 中的文档使用向量相似度时失败: {e}"
            logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
            return []

    def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
        """Aggregate counts, sizes and a file-type breakdown for ``kb_id``."""
        # BUG FIX: page through *all* documents instead of silently capping at
        # the first 1000, which produced wrong stats for large knowledge bases.
        documents: List[Document] = []
        page_size = 1000
        skip = 0
        while True:
            page = self.get_documents(kb_id, skip=skip, limit=page_size)
            documents.extend(page)
            if len(page) < page_size:
                break
            skip += page_size
        total_count = len(documents)
        processed_count = sum(1 for doc in documents if doc.is_processed)
        total_size = sum(doc.file_size for doc in documents)
        file_types: Dict[str, int] = {}
        for doc in documents:
            file_types[doc.file_type] = file_types.get(doc.file_type, 0) + 1
        return {
            "total_documents": total_count,
            "processed_documents": processed_count,
            "pending_documents": total_count - processed_count,
            "total_size_bytes": total_size,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "file_types": file_types
        }

    def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
        """Return the stored chunks for one document (empty list on failure)."""
        try:
            self.session.desc = f"获取文档 {doc_id} 的文档块"
            stmt = select(Document).where(Document.id == doc_id)
            document = self.session.scalar(stmt)
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                return []
            # Raw chunk dicts come from the processor's vector store.
            chunks_data = get_document_processor().get_document_chunks(
                document.knowledge_base_id, doc_id
            )
            # Convert the raw dicts into typed DocumentChunk objects.
            chunks = []
            for chunk_data in chunks_data:
                chunks.append(DocumentChunk(
                    id=chunk_data["id"],
                    content=chunk_data["content"],
                    metadata=chunk_data["metadata"],
                    page_number=chunk_data.get("page_number"),
                    chunk_index=chunk_data["chunk_index"],
                    start_char=chunk_data.get("start_char"),
                    end_char=chunk_data.get("end_char")
                ))
            self.session.desc = f"SUCCESS: 获取文档 {doc_id} 的文档块: {len(chunks)}"
            return chunks
        except Exception as e:
            self.session.desc = f"EXCEPTION: 获取文档 {doc_id} 的文档块时失败: {e}"
            return []