# hxf/backend/th_agenter/services/document.py
"""Document service."""
import os
from pathlib import Path
from typing import List, Optional, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from fastapi import UploadFile
from ..models.knowledge_base import Document, KnowledgeBase
from ..core.config import get_settings
from utils.util_file import FileUtils
from .storage import storage_service
from .document_processor import get_document_processor
from utils.util_schemas import DocumentChunk
from loguru import logger
settings = get_settings()
class DocumentService:
    """Document service for managing documents in knowledge bases."""

    def __init__(self, session: Session):
        # NOTE(review): the methods below use `await self.session.scalar(...)`,
        # so this is presumably an AsyncSession despite the sync `Session`
        # annotation — confirm against the callers.
        self.session = session
        self.file_utils = FileUtils()
async def upload_document(self, file: UploadFile, kb_id: int) -> Document:
"""Upload a document to knowledge base."""
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
# Validate knowledge base exists
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
kb = await self.session.scalar(stmt)
if not kb:
self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise ValueError(f"知识库 {kb_id} 不存在")
# Validate file
if not file.filename:
self.session.desc = f"ERROR: 上传文件时未提供文件名"
raise ValueError("No filename provided")
# Validate file extension
file_extension = Path(file.filename).suffix.lower()
if file_extension not in settings.file.allowed_extensions:
self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}"
raise ValueError(f"非期望的文件类型 {file_extension}")
# Upload file using storage service
storage_info = await storage_service.upload_file(file, kb_id)
self.session.desc = f"文档 {file.filename} 上传到 {storage_info}"
# Create document record
document = Document(
knowledge_base_id=kb_id,
filename=os.path.basename(storage_info["file_path"]),
original_filename=file.filename,
file_path=storage_info.get("full_path", storage_info["file_path"]), # Use absolute path if available
file_size=storage_info["size"],
file_type=file_extension,
mime_type=storage_info["mime_type"],
is_processed=False
)
# Set audit fields
document.set_audit_fields()
self.session.add(document)
await self.session.commit()
await self.session.refresh(document)
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
return document
    async def get_document(self, doc_id: int, kb_id: Optional[int] = None) -> Optional[Document]:
        """Look up a document by id, optionally scoped to one knowledge base.

        Returns the Document or None when no matching row exists.
        """
        self.session.desc = f"根据文档ID查询文档 {doc_id}"
        stmt = select(Document).where(Document.id == doc_id)
        if kb_id is not None:
            # Scope the lookup so a document from another KB is never returned.
            stmt = stmt.where(Document.knowledge_base_id == kb_id)
        return await self.session.scalar(stmt)
async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
"""根据知识库ID查询文档支持分页。"""
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
stmt = (
select(Document)
.where(Document.knowledge_base_id == kb_id)
.offset(skip)
.limit(limit)
)
return (await self.session.scalars(stmt)).all()
async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
"""根据知识库ID查询文档支持分页并返回总文档数。"""
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
# Get total count
count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
total = await self.session.scalar(count_stmt)
# Get documents with pagination
documents_stmt = (
select(Document)
.where(Document.knowledge_base_id == kb_id)
.offset(skip)
.limit(limit)
)
documents = (await self.session.scalars(documents_stmt)).all()
return documents, total
async def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
"""根据文档ID删除文档可选地根据知识库ID过滤。"""
self.session.desc = f"删除文档 {doc_id}"
document = await self.get_document(doc_id, kb_id)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return False
# Delete file from storage
try:
await storage_service.delete_file(document.file_path)
self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
except Exception as e:
self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
# TODO: Remove from vector database
# This should be implemented when vector database service is ready
self.session.desc = f"从向量数据库删除文档 {doc_id}"
(await get_document_processor(self.session)).delete_document_from_vector_store(kb_id,doc_id)
# Delete database record
self.session.desc = f"删除数据库记录 {doc_id}"
await self.session.delete(document)
await self.session.commit()
self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
return True
    async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]:
        """Process a document: extract its text and build embedding vectors.

        Idempotent: returns an "already_processed" status for processed rows.
        Never raises — on failure the error message is persisted onto the
        document row (best effort) and a ``{"status": "failed", ...}`` dict
        is returned instead.
        """
        try:
            self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量"
            document = await self.get_document(doc_id, kb_id)
            self.session.desc = f"获取文档 {doc_id} >>> {document}"
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                raise ValueError(f"Document {doc_id} not found")
            # NOTE(review): file_path has been observed as a 1-tuple, e.g.
            # ('C:\\...\\997eccbb-..._答辩.txt',); if that still occurs the
            # first element would need to be taken — confirm upstream.
            file_path = document.file_path
            knowledge_base_id=document.knowledge_base_id
            is_processed=document.is_processed
            if is_processed:
                # Second call on a processed document is a no-op.
                self.session.desc = f"INFO: 文档 {doc_id} 已处理"
                return {
                    "document_id": doc_id,
                    "status": "already_processed",
                    "message": "文档已处理"
                }
            self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}"
            # Mark as "in progress" by clearing any previous error and committing.
            document.processing_error = None
            await self.session.commit()
            self.session.desc = f"更新文档状态为处理中 {doc_id}"
            # Delegate extraction / chunking / embedding to the document processor.
            document_processor = await get_document_processor(self.session)
            self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}"
            result = await document_processor.process_document(
                session=self.session,
                document_id=doc_id,
                file_path=file_path,
                knowledge_base_id=knowledge_base_id
            )
            self.session.desc = f"处理文档完毕 {doc_id}"
            # On success, flip the processed flag and record the chunk count.
            if result["status"] == "success":
                document.is_processed = True
                document.chunk_count = result.get("chunks_count", 0)
                await self.session.commit()
                await self.session.refresh(document)
                logger.info(f"Processed document: {document.filename} (ID: {doc_id})")
            return result
        except Exception as e:
            await self.session.rollback()
            self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
            # Persist the failure reason on the document row (best effort —
            # a second failure here is only logged).
            try:
                document = await self.get_document(doc_id, kb_id)
                if document:
                    document.processing_error = str(e)
                    await self.session.commit()
            except Exception as db_error:
                logger.error(f"Failed to update document error status: {db_error}")
            return {
                "document_id": doc_id,
                "status": "failed",
                "error": str(e),
                "message": "文档处理失败"
            }
async def _extract_text(self, document: Document) -> str:
"""从文档文件中提取文本内容。"""
try:
if document.is_text_file:
# Read text files directly
with open(document.file_path, 'r', encoding='utf-8') as f:
return f.read()
elif document.is_pdf_file:
# TODO: Implement PDF text extraction using PyPDF2 or similar
# For now, return placeholder
return f"PDF content from {document.original_filename}"
elif document.is_office_file:
# TODO: Implement Office file text extraction using python-docx, openpyxl, etc.
# For now, return placeholder
return f"Office document content from {document.original_filename}"
else:
self.session.desc = f"ERROR: 不支持的文件类型: {document.file_type}"
raise ValueError(f"不支持的文件类型: {document.file_type}")
except Exception as e:
self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
raise
async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
"""更新文档处理状态。"""
self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
document = await self.get_document(doc_id)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return False
document.is_processed = is_processed
document.processing_error = error
await self.session.commit()
self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
return True
async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""在知识库中搜索文档使用向量相似度。"""
try:
# 使用文档处理器进行相似性搜索
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}"
results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit)
self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
return results
except Exception as e:
self.session.desc = f"EXCEPTION: 搜索知识库 {kb_id} 中的文档使用向量相似度时失败: {e}"
logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
return []
async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
"""获取知识库中的文档统计信息。"""
documents = await self.get_documents(kb_id, limit=1000) # Get all documents
total_count = len(documents)
processed_count = len([doc for doc in documents if doc.is_processed])
total_size = sum(doc.file_size for doc in documents)
file_types = {}
for doc in documents:
file_type = doc.file_type
file_types[file_type] = file_types.get(file_type, 0) + 1
return {
"total_documents": total_count,
"processed_documents": processed_count,
"pending_documents": total_count - processed_count,
"total_size_bytes": total_size,
"total_size_mb": round(total_size / (1024 * 1024), 2),
"file_types": file_types
}
async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
"""获取特定文档的文档块。"""
try:
self.session.desc = f"获取文档 {doc_id} 的文档块"
stmt = select(Document).where(Document.id == doc_id)
document = await self.session.scalar(stmt)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return []
self.session.desc = f"获取文档 {doc_id} 的文档块 > document"
# Get chunks from document processor
chunks_data = (await get_document_processor(self.session)).get_document_chunks(document.knowledge_base_id, doc_id)
self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data"
# Convert to DocumentChunk objects
chunks = []
for chunk_data in chunks_data:
chunk = DocumentChunk(
id=chunk_data["id"],
content=chunk_data["content"],
metadata=chunk_data["metadata"],
page_number=chunk_data.get("page_number"),
chunk_index=chunk_data["chunk_index"],
start_char=chunk_data.get("start_char"),
end_char=chunk_data.get("end_char"),
vector_id=chunk_data.get("vector_id")
)
chunks.append(chunk)
self.session.desc = f"SUCCESS: 获取文档 {doc_id} 的文档块: {len(chunks)}"
return chunks
except Exception as e:
self.session.desc = f"EXCEPTION: 获取文档 {doc_id} 的文档块时失败: {e}"
return []