hyf-backend/th_agenter/api/endpoints/knowledge_base.py

617 lines
23 KiB
Python

"""Knowledge base API endpoints."""
from utils.util_exceptions import HxfResponse
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import JSONResponse
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from ...db.database import get_session
from ...models.user import User
from ...models.knowledge_base import KnowledgeBase, Document
from ...services.knowledge_base import KnowledgeBaseService
from ...services.document import DocumentService
from ...services.auth import AuthService
from utils.util_schemas import (
KnowledgeBaseCreate,
KnowledgeBaseResponse,
DocumentResponse,
DocumentListResponse,
DocumentUpload,
DocumentProcessingStatus,
DocumentChunksResponse,
ErrorResponse
)
from utils.util_file import FileUtils
from ...core.config import settings
# Router collecting every knowledge-base and document endpoint in this module;
# the "knowledge-bases" tag groups them in the generated OpenAPI docs.
router = APIRouter(tags=["knowledge-bases"])
@router.post("/", response_model=KnowledgeBaseResponse, summary="创建新的知识库")
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a new knowledge base.

    Args:
        kb_data: Creation payload (name, description, embedding and chunking
            settings) forwarded to ``KnowledgeBaseService.create_knowledge_base``.
        session: Request-scoped database session; ``session.desc`` receives a
            progress message at each step.
        current_user: User resolved by ``AuthService.get_current_user``; only
            referenced in progress messages here.

    Returns:
        ``HxfResponse`` wrapping a ``KnowledgeBaseResponse``. Document counts
        are reported as 0 because the knowledge base was just created.

    Raises:
        HTTPException: 400 when ``get_knowledge_base_by_name`` already returns
            a knowledge base with the requested name.
    """
    session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}"
    kb_service = KnowledgeBaseService(session)

    # Reject duplicate names before creating anything.
    session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
    existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name)
    if existing_kb:
        # Record the failure the same way every other error path in this module does.
        session.desc = f"ERROR: 知识库名称 {kb_data.name} 已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"知识库名称 {kb_data.name} 已存在"
        )

    # Create the knowledge base.
    session.desc = f"知识库 {kb_data.name}不存在,创建之"
    kb = await kb_service.create_knowledge_base(kb_data)
    session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
    response = KnowledgeBaseResponse(
        id=kb.id,
        created_at=kb.created_at,
        updated_at=kb.updated_at,
        name=kb.name,
        description=kb.description,
        embedding_model=kb.embedding_model,
        chunk_size=kb.chunk_size,
        chunk_overlap=kb.chunk_overlap,
        is_active=kb.is_active,
        vector_db_type=kb.vector_db_type,
        collection_name=kb.collection_name,
        document_count=0,
        active_document_count=0
    )
    return HxfResponse(response)
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
async def list_knowledge_bases(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List knowledge bases together with their document statistics.

    Args:
        skip: Offset forwarded to ``KnowledgeBaseService.get_knowledge_bases``.
        limit: Page size forwarded to ``KnowledgeBaseService.get_knowledge_bases``.
        search: Optional case-insensitive substring; when non-empty, only
            knowledge bases whose name or description contains it are returned.
        session: Request-scoped database session; ``session.title`` and
            ``session.desc`` receive progress messages.
        current_user: User resolved by ``AuthService.get_current_user``; only
            referenced in progress messages here.

    Returns:
        ``HxfResponse`` wrapping a list of ``KnowledgeBaseResponse``. Each item
        carries the total number of documents and the number with
        ``is_processed`` set.
    """
    session.title = f"获取用户 {current_user.username} 的所有知识库"
    session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
    kb_service = KnowledgeBaseService(session)
    knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit)

    # ``search`` used to be accepted but silently ignored.
    # NOTE(review): this filters the already-fetched page. Pushing the filter
    # into KnowledgeBaseService.get_knowledge_bases would keep skip/limit exact.
    if search:
        needle = search.casefold()
        knowledge_bases = [
            kb for kb in knowledge_bases
            if needle in (kb.name or "").casefold()
            or needle in (kb.description or "").casefold()
        ]

    # Count documents for every listed knowledge base with two grouped queries
    # instead of two COUNT queries per knowledge base (N+1).
    kb_ids = [kb.id for kb in knowledge_bases]
    total_counts = {}
    processed_counts = {}
    if kb_ids:
        total_rows = await session.execute(
            select(Document.knowledge_base_id, func.count())
            .where(Document.knowledge_base_id.in_(kb_ids))
            .group_by(Document.knowledge_base_id)
        )
        total_counts = {kb_id: count for kb_id, count in total_rows.all()}
        processed_rows = await session.execute(
            select(Document.knowledge_base_id, func.count())
            .where(
                Document.knowledge_base_id.in_(kb_ids),
                Document.is_processed == True  # noqa: E712 - SQL expression, not Python truthiness
            )
            .group_by(Document.knowledge_base_id)
        )
        processed_counts = {kb_id: count for kb_id, count in processed_rows.all()}

    # Knowledge bases without documents produce no grouped row, hence the 0 default.
    result = [
        KnowledgeBaseResponse(
            id=kb.id,
            created_at=kb.created_at,
            updated_at=kb.updated_at,
            name=kb.name,
            description=kb.description,
            embedding_model=kb.embedding_model,
            chunk_size=kb.chunk_size,
            chunk_overlap=kb.chunk_overlap,
            is_active=kb.is_active,
            vector_db_type=kb.vector_db_type,
            collection_name=kb.collection_name,
            document_count=total_counts.get(kb.id, 0),
            active_document_count=processed_counts.get(kb.id, 0)
        )
        for kb in knowledge_bases
    ]
    session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库"
    return HxfResponse(result)
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
async def get_knowledge_base(
    kb_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one knowledge base along with its document statistics.

    Args:
        kb_id: ID of the knowledge base to fetch.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``KnowledgeBaseResponse`` with the total
        document count and the count of documents whose ``is_processed`` is set.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    session.desc = f"START: 获取知识库 {kb_id} 的详情"
    kb_service = KnowledgeBaseService(session)

    session.desc = f"检查知识库 {kb_id} 是否存在"
    knowledge_base = await kb_service.get_knowledge_base(kb_id)
    if not knowledge_base:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # Both counts share the same "document belongs to this KB" predicate.
    in_this_kb = Document.knowledge_base_id == knowledge_base.id
    doc_total = await session.scalar(select(func.count()).where(in_this_kb))
    session.desc = f"获取知识库 {kb_id}{doc_total} 个文档"
    doc_processed = await session.scalar(
        select(func.count()).where(in_this_kb, Document.is_processed == True)  # noqa: E712
    )

    session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {doc_total} 个文档,其中 {doc_processed} 个已处理"
    return HxfResponse(KnowledgeBaseResponse(
        id=knowledge_base.id,
        created_at=knowledge_base.created_at,
        updated_at=knowledge_base.updated_at,
        name=knowledge_base.name,
        description=knowledge_base.description,
        embedding_model=knowledge_base.embedding_model,
        chunk_size=knowledge_base.chunk_size,
        chunk_overlap=knowledge_base.chunk_overlap,
        is_active=knowledge_base.is_active,
        vector_db_type=knowledge_base.vector_db_type,
        collection_name=knowledge_base.collection_name,
        document_count=doc_total,
        active_document_count=doc_processed
    ))
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
async def update_knowledge_base(
    kb_id: int,
    kb_data: KnowledgeBaseCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Update a knowledge base and return it with its document statistics.

    Args:
        kb_id: ID of the knowledge base to update.
        kb_data: New field values, forwarded to
            ``KnowledgeBaseService.update_knowledge_base``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping the updated ``KnowledgeBaseResponse`` with
        total and processed document counts.

    Raises:
        HTTPException: 404 when the knowledge base does not exist; 400 when
            ``kb_data.name`` is already used by a different knowledge base.
    """
    session.desc = f"START: 更新知识库 {kb_id}"
    service = KnowledgeBaseService(session)

    # Resolve the target first, so a missing knowledge base still yields 404
    # even when the requested name happens to be taken.
    if not await service.get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # A rename must respect the same unique-name rule that create_knowledge_base
    # enforces. Finding this very knowledge base by its own name is fine.
    same_name_kb = await service.get_knowledge_base_by_name(kb_data.name)
    if same_name_kb and same_name_kb.id != kb_id:
        session.desc = f"ERROR: 知识库名称 {kb_data.name} 已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"知识库名称 {kb_data.name} 已存在"
        )

    kb = await service.update_knowledge_base(kb_id, kb_data)
    if not kb:
        # The service can still report a miss here (presumably if the row
        # disappeared between the lookup above and the update).
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # Count documents.
    total_docs = await session.scalar(
        select(func.count()).where(Document.knowledge_base_id == kb.id)
    )
    active_docs = await session.scalar(
        select(func.count()).where(
            Document.knowledge_base_id == kb.id,
            Document.is_processed == True  # noqa: E712 - SQL expression
        )
    )
    session.desc = f"SUCCESS: 更新知识库 {kb_id},结果 - 共 {total_docs} 个文档,其中 {active_docs} 个已处理"
    response = KnowledgeBaseResponse(
        id=kb.id,
        created_at=kb.created_at,
        updated_at=kb.updated_at,
        name=kb.name,
        description=kb.description,
        embedding_model=kb.embedding_model,
        chunk_size=kb.chunk_size,
        chunk_overlap=kb.chunk_overlap,
        is_active=kb.is_active,
        vector_db_type=kb.vector_db_type,
        collection_name=kb.collection_name,
        document_count=total_docs,
        active_document_count=active_docs
    )
    return HxfResponse(response)
@router.delete("/{kb_id}", summary="删除知识库")
async def delete_knowledge_base(
    kb_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete a knowledge base.

    Args:
        kb_id: ID of the knowledge base to delete.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a confirmation message.

    Raises:
        HTTPException: 404 when ``KnowledgeBaseService.delete_knowledge_base``
            reports that nothing was deleted.
    """
    session.desc = f"START: 删除知识库 {kb_id}"
    deleted = await KnowledgeBaseService(session).delete_knowledge_base(kb_id)
    if deleted:
        session.desc = f"SUCCESS: 删除知识库 {kb_id}"
        return HxfResponse({"message": "Knowledge base deleted successfully"})

    session.desc = f"ERROR: 知识库 {kb_id} 不存在"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Knowledge base not found"
    )
# Document management endpoints
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
async def upload_document(
    kb_id: int,
    file: UploadFile = File(...),
    process_immediately: bool = Form(True),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Upload a document into a knowledge base and optionally process it at once.

    Args:
        kb_id: Target knowledge base ID.
        file: Uploaded file. Its extension must pass
            ``FileUtils.validate_file_extension`` and its size must not exceed 50 MB.
        process_immediately: When true, ``DocumentService.process_document`` runs
            right after the upload. A processing failure is only recorded in
            ``session.desc`` and does not fail the upload.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentResponse`` for the stored document.

    Raises:
        HTTPException: 404 when the knowledge base does not exist; 400 when the
            file type is not allowed or the file exceeds the size limit.
    """
    # ``file.size`` may be None (the size check below guards for it too), so
    # don't hand None to the formatter just to build a log message.
    size_text = FileUtils.format_file_size(file.size) if file.size is not None else "未知大小"
    session.desc = f"START: 上传文档 {file.filename} ({size_text}) 到知识库 (ID={kb_id})"

    # Verify the knowledge base exists.
    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(kb_id)
    if not kb:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )
    session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}"

    # Validate the file type.
    if not FileUtils.validate_file_extension(file.filename):
        session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
        )

    # Enforce the 50 MB size limit (skipped when the size is unknown).
    max_size = 50 * 1024 * 1024
    if file.size and file.size > max_size:
        session.desc = f"ERROR: 文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
        )
    session.desc = f"文件为期望类型,处理文件 {file.filename} - "

    # Store the document.
    doc_service = DocumentService(session)
    document = await doc_service.upload_document(
        file, kb_id
    )

    if process_immediately:
        try:
            await doc_service.process_document(document.id, kb_id)
            # Reload the document so the response reflects the processing result.
            await session.refresh(document)
        except Exception as e:
            # Deliberately best-effort: the upload itself succeeded, so record
            # the processing failure and still return the stored document.
            session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"

    session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"
    response = DocumentResponse(
        id=document.id,
        created_at=document.created_at,
        updated_at=document.updated_at,
        knowledge_base_id=document.knowledge_base_id,
        filename=document.filename,
        original_filename=document.original_filename,
        file_path=document.file_path,
        file_type=document.file_type,
        file_size=document.file_size,
        mime_type=document.mime_type,
        is_processed=document.is_processed,
        processing_error=document.processing_error,
        chunk_count=document.chunk_count or 0,
        embedding_model=document.embedding_model,
        file_size_mb=round(document.file_size / (1024 * 1024), 2)
    )
    return HxfResponse(response)
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
async def list_documents(
    kb_id: int,
    skip: int = 0,
    limit: int = 50,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List the documents of a knowledge base, paginated by ``skip``/``limit``.

    Args:
        kb_id: Knowledge base ID.
        skip: Number of documents to skip; must be >= 0.
        limit: Page size; must be >= 1 (it is also the divisor for ``page``).
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentListResponse`` with the page of
        documents, the ``total`` reported by ``DocumentService.list_documents``,
        the 1-based page number and the page size.

    Raises:
        HTTPException: 400 for invalid pagination values; 404 when the
            knowledge base does not exist.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"

    # Validate pagination up front. limit == 0 used to raise ZeroDivisionError
    # in the page computation below (a 500), and negative values are not valid
    # offsets or page sizes.
    if skip < 0 or limit < 1:
        session.desc = f"ERROR: 分页参数无效 (skip={skip}, limit={limit})"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid pagination: skip must be >= 0 and limit must be >= 1"
        )

    # Verify the knowledge base exists.
    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(kb_id)
    if not kb:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    doc_service = DocumentService(session)
    documents, total = await doc_service.list_documents(kb_id, skip, limit)
    doc_responses = [
        DocumentResponse(
            id=doc.id,
            created_at=doc.created_at,
            updated_at=doc.updated_at,
            knowledge_base_id=doc.knowledge_base_id,
            filename=doc.filename,
            original_filename=doc.original_filename,
            file_path=doc.file_path,
            file_type=doc.file_type,
            file_size=doc.file_size,
            mime_type=doc.mime_type,
            is_processed=doc.is_processed,
            processing_error=doc.processing_error,
            chunk_count=doc.chunk_count or 0,
            embedding_model=doc.embedding_model,
            file_size_mb=round(doc.file_size / (1024 * 1024), 2)
        )
        for doc in documents
    ]
    session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total}"
    response = DocumentListResponse(
        documents=doc_responses,
        total=total,
        page=skip // limit + 1,
        page_size=limit
    )
    return HxfResponse(response)
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
async def get_document_chunks(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return every chunk (segment) of one document in a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        doc_id: Document ID, looked up within ``kb_id``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentChunksResponse`` whose
        ``total_chunks`` equals the number of chunks returned.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在")

    # The document must exist inside this knowledge base.
    documents = DocumentService(session)
    session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService"
    doc = await documents.get_document(doc_id, kb_id)
    session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document"
    if not doc:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在")

    chunk_list = await documents.get_document_chunks(doc_id)
    chunk_total = len(chunk_list)
    session.desc = f"SUCCESS: 获取文档 {doc_id}{chunk_total} 个文档块(片段)"
    return HxfResponse(DocumentChunksResponse(
        document_id=doc_id,
        document_name=doc.filename,
        total_chunks=chunk_total,
        chunks=chunk_list
    ))
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
async def get_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return the details of one document in a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        doc_id: Document ID, looked up within ``kb_id``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentResponse``.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    doc = await DocumentService(session).get_document(doc_id, kb_id)
    if not doc:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
    size_in_mb = round(doc.file_size / (1024 * 1024), 2)
    return HxfResponse(DocumentResponse(
        id=doc.id,
        created_at=doc.created_at,
        updated_at=doc.updated_at,
        knowledge_base_id=doc.knowledge_base_id,
        filename=doc.filename,
        original_filename=doc.original_filename,
        file_path=doc.file_path,
        file_type=doc.file_type,
        file_size=doc.file_size,
        mime_type=doc.mime_type,
        is_processed=doc.is_processed,
        processing_error=doc.processing_error,
        chunk_count=doc.chunk_count or 0,
        embedding_model=doc.embedding_model,
        file_size_mb=size_in_mb
    ))
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
async def delete_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete one document from a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        doc_id: Document ID, deleted within ``kb_id``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a confirmation message.

    Raises:
        HTTPException: 404 when the knowledge base is missing, or when
            ``DocumentService.delete_document`` reports no deletion.
    """
    session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    removed = await DocumentService(session).delete_document(doc_id, kb_id)
    if not removed:
        session.desc = f"ERROR: 删除文档 {doc_id} 失败 - 文档不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
    return HxfResponse({"message": "Document deleted successfully"})
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
async def process_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Run ``DocumentService.process_document`` on one document (for vector search).

    Args:
        kb_id: Knowledge base ID.
        doc_id: Document ID, looked up within ``kb_id``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentProcessingStatus`` built from the
        dict returned by ``process_document``. ``status`` is required in that
        dict; ``progress``, ``error_message`` and ``chunks_created`` are optional.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    doc_service = DocumentService(session)
    target = await doc_service.get_document(doc_id, kb_id)
    if not target:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    outcome = await doc_service.process_document(doc_id, kb_id)
    # Reload the ORM instance after processing. The response below is built
    # from ``outcome``, not from the refreshed document.
    await session.refresh(target)

    session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
    return HxfResponse(DocumentProcessingStatus(
        document_id=doc_id,
        status=outcome["status"],
        progress=outcome.get("progress", 0.0),
        error_message=outcome.get("error_message"),
        chunks_created=outcome.get("chunks_created", 0)
    ))
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
async def get_document_processing_status(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Report the processing state of one document in a knowledge base.

    The state is derived from the stored document: a non-empty
    ``processing_error`` means "failed" (this wins over ``is_processed``),
    ``is_processed`` means "completed" (progress 100.0), and anything else is
    "pending".

    Args:
        kb_id: Knowledge base ID.
        doc_id: Document ID, looked up within ``kb_id``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a ``DocumentProcessingStatus``.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    doc = await DocumentService(session).get_document(doc_id, kb_id)
    if not doc:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    # ``state`` rather than ``status`` so the fastapi ``status`` module isn't shadowed.
    if doc.processing_error:
        state, percent = "failed", 0.0
        session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{doc.processing_error}"
    elif doc.is_processed:
        state, percent = "completed", 100.0
        session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
    else:
        state, percent = "pending", 0.0
        session.desc = f"文档 {doc_id} 处理pending中"

    return HxfResponse(DocumentProcessingStatus(
        document_id=doc.id,
        status=state,
        progress=percent,
        error_message=doc.processing_error,
        chunks_created=doc.chunk_count or 0
    ))
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
async def search_knowledge_base(
    kb_id: int,
    query: str,
    limit: int = 5,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Search the documents of a knowledge base.

    Args:
        kb_id: Knowledge base ID.
        query: Search text, forwarded to ``DocumentService.search_documents``.
        limit: Maximum number of results, forwarded to ``search_documents``.
        session: Request-scoped database session; ``session.desc`` receives
            progress messages.
        current_user: User resolved by ``AuthService.get_current_user`` (not
            otherwise used here).

    Returns:
        ``HxfResponse`` wrapping a dict with ``knowledge_base_id``, ``query``,
        ``results`` (as returned by ``search_documents``) and ``total_results``.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
    if not await KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    hits = await DocumentService(session).search_documents(kb_id, query, limit)
    hit_count = len(hits)
    session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {hit_count} 条结果"
    return HxfResponse({
        "knowledge_base_id": kb_id,
        "query": query,
        "results": hits,
        "total_results": hit_count
    })