"""Document service."""

import os
from pathlib import Path
from typing import List, Optional, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from fastapi import UploadFile
from ..models.knowledge_base import Document, KnowledgeBase
from ..core.config import get_settings
from utils.util_file import FileUtils
from .storage import storage_service
from .document_processor import get_document_processor
from utils.util_schemas import DocumentChunk
from loguru import logger

settings = get_settings()


class DocumentService:
    """Document service for managing documents in knowledge bases.

    Every method records a human-readable progress/status message on
    ``self.session.desc`` before and after each significant step.
    """

    def __init__(self, session: Session):
        """Bind the service to a database session.

        Args:
            session: Database session used for all queries.
                NOTE(review): annotated as ``Session`` but every call on it
                is awaited, so an async session is presumably passed in --
                confirm against the caller.
        """
        self.session = session
        self.file_utils = FileUtils()

    async def upload_document(self, file: UploadFile, kb_id: int) -> Document:
        """Upload a file to storage and register it in a knowledge base.

        Args:
            file: The uploaded file; must carry a filename.
            kb_id: ID of the target knowledge base.

        Returns:
            The newly created, refreshed ``Document`` row.

        Raises:
            ValueError: If the knowledge base does not exist, the file has no
                filename, or its extension is not in
                ``settings.file.allowed_extensions``.
        """
        self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"

        # Validate that the knowledge base exists.
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        kb = await self.session.scalar(stmt)
        if not kb:
            self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
            raise ValueError(f"知识库 {kb_id} 不存在")

        # Validate the file itself.
        if not file.filename:
            self.session.desc = "ERROR: 上传文件时未提供文件名"
            raise ValueError("No filename provided")

        # Validate the (case-insensitive) file extension.
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in settings.file.allowed_extensions:
            self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}"
            raise ValueError(f"非期望的文件类型 {file_extension}")

        # Hand the payload to the storage service.
        storage_info = await storage_service.upload_file(file, kb_id)
        self.session.desc = f"文档 {file.filename} 上传到 {storage_info}"

        # Create the document record.
        document = Document(
            knowledge_base_id=kb_id,
            filename=os.path.basename(storage_info["file_path"]),
            original_filename=file.filename,
            # Use the absolute path when the storage service provides one.
            file_path=storage_info.get("full_path", storage_info["file_path"]),
            file_size=storage_info["size"],
            file_type=file_extension,
            mime_type=storage_info["mime_type"],
            is_processed=False,
        )

        # Set audit fields.
        document.set_audit_fields()

        self.session.add(document)
        await self.session.commit()
        await self.session.refresh(document)

        self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
        return document

    async def get_document(self, doc_id: int, kb_id: Optional[int] = None) -> Optional[Document]:
        """Fetch a document by ID, optionally restricted to one knowledge base.

        Args:
            doc_id: Document ID.
            kb_id: When given, the document must also belong to this
                knowledge base.

        Returns:
            The matching ``Document`` or ``None``.
        """
        self.session.desc = f"根据文档ID查询文档 {doc_id}"
        stmt = select(Document).where(Document.id == doc_id)
        if kb_id is not None:
            stmt = stmt.where(Document.knowledge_base_id == kb_id)
        return await self.session.scalar(stmt)

    async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
        """Return one page of the documents in a knowledge base.

        Args:
            kb_id: Knowledge base ID.
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            The documents on the requested page, ordered by document ID.
        """
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
        stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
            # pages could overlap or skip rows between calls.
            .order_by(Document.id)
            .offset(skip)
            .limit(limit)
        )
        return list((await self.session.scalars(stmt)).all())

    async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
        """Return one page of documents plus the total document count.

        Args:
            kb_id: Knowledge base ID.
            skip: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            ``(documents, total)`` where ``total`` counts every document in
            the knowledge base, not just the returned page.
        """
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"

        # Total count, independent of pagination.
        count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
        total = await self.session.scalar(count_stmt)

        # Requested page, in a stable order.
        documents_stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            .order_by(Document.id)
            .offset(skip)
            .limit(limit)
        )
        documents = list((await self.session.scalars(documents_stmt)).all())

        return documents, total or 0

    async def delete_document(self, doc_id: int, kb_id: Optional[int] = None) -> bool:
        """Delete a document's file, vector-store entries and database row.

        Args:
            doc_id: Document ID.
            kb_id: When given, the document must belong to this knowledge
                base.

        Returns:
            ``True`` when the document was deleted, ``False`` when no
            matching document exists.
        """
        self.session.desc = f"删除文档 {doc_id}"
        document = await self.get_document(doc_id, kb_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False

        # Delete the file from storage. Best effort: a missing or locked file
        # must not prevent the record from being removed.
        try:
            await storage_service.delete_file(document.file_path)
            self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
        except Exception as e:
            self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
            logger.warning(f"Failed to delete file for document {doc_id}: {e}")

        # Remove the document from the vector store. The knowledge base ID is
        # taken from the row itself because the ``kb_id`` argument is optional
        # and may be None.
        self.session.desc = f"从向量数据库删除文档 {doc_id}"
        (await get_document_processor(self.session)).delete_document_from_vector_store(
            document.knowledge_base_id, doc_id
        )

        # Delete the database record.
        self.session.desc = f"删除数据库记录 {doc_id}"
        await self.session.delete(document)
        await self.session.commit()

        self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
        return True

    async def process_document(self, doc_id: int, kb_id: Optional[int] = None) -> Dict[str, Any]:
        """Process a document: extract its text and create embeddings.

        Args:
            doc_id: Document ID.
            kb_id: When given, the document must belong to this knowledge
                base.

        Returns:
            The document processor's result dict on success; an
            ``already_processed`` dict if the document was processed before;
            or a ``failed`` dict (with ``error``) if anything raised. This
            method does not propagate exceptions.
        """
        try:
            self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量"
            document = await self.get_document(doc_id, kb_id)
            self.session.desc = f"获取文档 {doc_id} >>> {document}"
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                raise ValueError(f"Document {doc_id} not found")

            # NOTE(review): an earlier comment here reported ``file_path``
            # arriving as a one-element tuple that needed unwrapping; the
            # value is passed through unchanged -- confirm against the model.
            # Values are copied to locals before the commit below.
            file_path = document.file_path
            knowledge_base_id = document.knowledge_base_id
            is_processed = document.is_processed
            if is_processed:
                self.session.desc = f"INFO: 文档 {doc_id} 已处理"
                return {
                    "document_id": doc_id,
                    "status": "already_processed",
                    "message": "文档已处理"
                }
            self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}"

            # Mark the document as in progress by clearing any previous error.
            document.processing_error = None
            await self.session.commit()
            self.session.desc = f"更新文档状态为处理中 {doc_id}"

            # Delegate the actual work to the document processor.
            document_processor = await get_document_processor(self.session)
            self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}"
            result = await document_processor.process_document(
                session=self.session,
                document_id=doc_id,
                file_path=file_path,
                knowledge_base_id=knowledge_base_id
            )
            self.session.desc = f"处理文档完毕 {doc_id}"

            # On success, persist the processed flag and chunk count.
            if result["status"] == "success":
                document.is_processed = True
                document.chunk_count = result.get("chunks_count", 0)
                await self.session.commit()
                await self.session.refresh(document)
                logger.info(f"Processed document: {document.filename} (ID: {doc_id})")

            return result

        except Exception as e:
            await self.session.rollback()
            self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
            logger.error(f"Failed to process document {doc_id}: {e}")

            # Record the error on the document row (best effort).
            try:
                document = await self.get_document(doc_id, kb_id)
                if document:
                    document.processing_error = str(e)
                    await self.session.commit()
            except Exception as db_error:
                logger.error(f"Failed to update document error status: {db_error}")

            return {
                "document_id": doc_id,
                "status": "failed",
                "error": str(e),
                "message": "文档处理失败"
            }

    async def _extract_text(self, document: Document) -> str:
        """Extract the text content of a document file.

        Text files are read as UTF-8. PDF and Office files currently return
        a placeholder string.

        Raises:
            ValueError: If the file type is not supported.
            Exception: Any error raised while reading the file is recorded
                on ``session.desc`` and re-raised.
        """
        try:
            if document.is_text_file:
                # Read text files directly.
                with open(document.file_path, 'r', encoding='utf-8') as f:
                    return f.read()

            elif document.is_pdf_file:
                # TODO: Implement PDF text extraction using PyPDF2 or similar.
                # For now, return a placeholder.
                return f"PDF content from {document.original_filename}"

            elif document.is_office_file:
                # TODO: Implement Office file text extraction using
                # python-docx, openpyxl, etc. For now, return a placeholder.
                return f"Office document content from {document.original_filename}"

            else:
                self.session.desc = f"ERROR: 不支持的文件类型: {document.file_type}"
                raise ValueError(f"不支持的文件类型: {document.file_type}")

        except Exception as e:
            self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
            raise

    async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
        """Update a document's processed flag and processing error.

        Args:
            doc_id: Document ID.
            is_processed: New value of the processed flag.
            error: New processing error message, or ``None`` to clear it.

        Returns:
            ``True`` when the document was updated, ``False`` when it does
            not exist.
        """
        self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
        document = await self.get_document(doc_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False

        document.is_processed = is_processed
        document.processing_error = error

        await self.session.commit()
        self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
        return True

    async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Search a knowledge base by vector similarity.

        Args:
            kb_id: Knowledge base ID.
            query: Query text.
            limit: Maximum number of results.

        Returns:
            The document processor's search results, or an empty list if the
            search raised (the error is logged).
        """
        try:
            # Delegate the similarity search to the document processor.
            self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}条"
            results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit)
            self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
            return results
        except Exception as e:
            self.session.desc = f"EXCEPTION: 搜索知识库 {kb_id} 中的文档使用向量相似度时失败: {e}"
            logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
            return []

    async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
        """Return aggregate document statistics for a knowledge base.

        The figures are computed by the database over every document in the
        knowledge base rather than over a capped page of loaded rows.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            A dict with ``total_documents``, ``processed_documents``,
            ``pending_documents``, ``total_size_bytes``, ``total_size_mb``
            and ``file_types`` (file type -> document count).
        """
        stmt = (
            select(
                Document.file_type,
                Document.is_processed,
                func.count(Document.id),
                # A NULL file_size must not turn the whole sum into NULL.
                func.coalesce(func.sum(Document.file_size), 0),
            )
            .where(Document.knowledge_base_id == kb_id)
            .group_by(Document.file_type, Document.is_processed)
        )
        rows = (await self.session.execute(stmt)).all()

        total_count = 0
        processed_count = 0
        total_size = 0
        file_types: Dict[Any, int] = {}
        for file_type, is_processed, doc_count, size_sum in rows:
            total_count += doc_count
            if is_processed:
                processed_count += doc_count
            # Some drivers return SUM() as Decimal; normalise to int.
            total_size += int(size_sum or 0)
            file_types[file_type] = file_types.get(file_type, 0) + doc_count

        return {
            "total_documents": total_count,
            "processed_documents": processed_count,
            "pending_documents": total_count - processed_count,
            "total_size_bytes": total_size,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "file_types": file_types
        }

    async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
        """Return the chunks of a specific document.

        Args:
            doc_id: Document ID.

        Returns:
            ``DocumentChunk`` objects built from the document processor's
            chunk data; an empty list if the document does not exist or an
            error occurred (the error is logged).
        """
        try:
            self.session.desc = f"获取文档 {doc_id} 的文档块"
            stmt = select(Document).where(Document.id == doc_id)
            document = await self.session.scalar(stmt)
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                return []

            self.session.desc = f"获取文档 {doc_id} 的文档块 > document"
            # Fetch the raw chunk dicts from the document processor.
            chunks_data = (await get_document_processor(self.session)).get_document_chunks(
                document.knowledge_base_id, doc_id
            )

            self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data"
            # Convert to DocumentChunk objects; optional keys default to None.
            chunks = [
                DocumentChunk(
                    id=chunk_data["id"],
                    content=chunk_data["content"],
                    metadata=chunk_data["metadata"],
                    page_number=chunk_data.get("page_number"),
                    chunk_index=chunk_data["chunk_index"],
                    start_char=chunk_data.get("start_char"),
                    end_char=chunk_data.get("end_char"),
                    vector_id=chunk_data.get("vector_id")
                )
                for chunk_data in chunks_data
            ]

            self.session.desc = f"SUCCESS: 获取文档 {doc_id} 的文档块: {len(chunks)} 个"
            return chunks

        except Exception as e:
            self.session.desc = f"EXCEPTION: 获取文档 {doc_id} 的文档块时失败: {e}"
            logger.error(f"Failed to get chunks for document {doc_id}: {e}")
            return []