hxf/backend/th_agenter/api/endpoints/knowledge_base.py

599 lines
22 KiB
Python

"""Knowledge base API endpoints."""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import JSONResponse
from sqlalchemy import select, func
from sqlalchemy.orm import Session
from ...db.database import get_session
from ...models.user import User
from ...models.knowledge_base import KnowledgeBase, Document
from ...services.knowledge_base import KnowledgeBaseService
from ...services.document import DocumentService
from ...services.auth import AuthService
from utils.util_schemas import (
KnowledgeBaseCreate,
KnowledgeBaseResponse,
DocumentResponse,
DocumentListResponse,
DocumentUpload,
DocumentProcessingStatus,
DocumentChunksResponse,
ErrorResponse
)
from utils.util_file import FileUtils
from ...core.config import settings
# Router holding every knowledge-base and document endpoint in this module.
# The "knowledge-bases" tag labels these routes in the generated OpenAPI schema.
# No path prefix is set here; presumably one is applied where the router is included.
router = APIRouter(tags=["knowledge-bases"])
@router.post("/", response_model=KnowledgeBaseResponse, summary="创建新的知识库")
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a new knowledge base.

    Args:
        kb_data: Name and configuration of the knowledge base to create.
        session: Database session; its ``desc`` attribute is overwritten with a
            progress message at each step.
        current_user: Authenticated user (only used in progress messages here).

    Returns:
        KnowledgeBaseResponse: The created knowledge base. Document counts are
        reported as 0 without querying.

    Raises:
        HTTPException: 400 if ``get_knowledge_base_by_name`` already finds a
            knowledge base with the requested name.
    """
    session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data.name}"
    service = KnowledgeBaseService(session)

    # Reject duplicate names before creating anything.
    # NOTE(review): the lookup receives only the name, not current_user; whether
    # uniqueness is per-user or global depends on KnowledgeBaseService - confirm.
    session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
    existing_kb = service.get_knowledge_base_by_name(kb_data.name)
    if existing_kb:
        # Record the failure before raising, like every other error path in this module.
        session.desc = f"ERROR: 知识库名称 {kb_data.name} 已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"知识库名称 {kb_data.name} 已存在"
        )

    session.desc = f"知识库 {kb_data.name}不存在,创建之"
    kb = service.create_knowledge_base(kb_data)
    session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
    return KnowledgeBaseResponse(
        id=kb.id,
        created_at=kb.created_at,
        updated_at=kb.updated_at,
        name=kb.name,
        description=kb.description,
        embedding_model=kb.embedding_model,
        chunk_size=kb.chunk_size,
        chunk_overlap=kb.chunk_overlap,
        is_active=kb.is_active,
        vector_db_type=kb.vector_db_type,
        collection_name=kb.collection_name,
        document_count=0,
        active_document_count=0
    )
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
async def list_knowledge_bases(
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return a page of knowledge bases with their document counts.

    For each knowledge base returned by ``get_knowledge_bases``, two count
    queries are issued: all documents, and documents with ``is_processed``
    true.

    NOTE(review): ``search`` is accepted but never used, so results are not
    filtered. The per-row counts also cost two queries per knowledge base.

    Args:
        skip: Number of knowledge bases to skip.
        limit: Maximum number of knowledge bases to return.
        search: Currently ignored (see note above).
        session: Database session; ``desc`` is updated with progress messages.
            Annotated as ``Session`` but awaited, so presumably an async
            session - confirm.
        current_user: Authenticated user (only used in progress messages here).

    Returns:
        List[KnowledgeBaseResponse]: One entry per knowledge base, with counts.
    """
    user_name = current_user.username
    session.desc = f"START: 获取用户 {user_name} 的所有知识库"
    kb_service = KnowledgeBaseService(session)
    session.desc = f"获取用户 {user_name} 的所有知识库 (skip={skip}, limit={limit})"
    knowledge_bases_page = await kb_service.get_knowledge_bases(skip=skip, limit=limit)

    responses: List[KnowledgeBaseResponse] = []
    for knowledge_base in knowledge_bases_page:
        belongs_to_kb = Document.knowledge_base_id == knowledge_base.id
        document_count = await session.scalar(select(func.count()).where(belongs_to_kb))
        processed_count = await session.scalar(
            # SQLAlchemy needs `== True` to build the SQL comparison.
            select(func.count()).where(belongs_to_kb, Document.is_processed == True)  # noqa: E712
        )
        responses.append(
            KnowledgeBaseResponse(
                id=knowledge_base.id,
                created_at=knowledge_base.created_at,
                updated_at=knowledge_base.updated_at,
                name=knowledge_base.name,
                description=knowledge_base.description,
                embedding_model=knowledge_base.embedding_model,
                chunk_size=knowledge_base.chunk_size,
                chunk_overlap=knowledge_base.chunk_overlap,
                is_active=knowledge_base.is_active,
                vector_db_type=knowledge_base.vector_db_type,
                collection_name=knowledge_base.collection_name,
                document_count=document_count,
                active_document_count=processed_count,
            )
        )

    session.desc = f"SUCCESS: 获取用户 {user_name} 的所有 {len(responses)} 知识库"
    return responses
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
async def get_knowledge_base(
    kb_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one knowledge base with its total and processed document counts.

    Args:
        kb_id: Id of the knowledge base.
        session: Database session; ``desc`` is updated with progress messages.
            Annotated as ``Session`` but awaited, so presumably an async
            session - confirm.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        KnowledgeBaseResponse: Knowledge base details plus document counts.

    Raises:
        HTTPException: 404 if the knowledge base does not exist.
    """
    session.desc = f"START: 获取知识库 {kb_id} 的详情"
    service = KnowledgeBaseService(session)
    session.desc = f"检查知识库 {kb_id} 是否存在"
    kb = service.get_knowledge_base(kb_id)
    if not kb:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # Count all documents, then the subset already processed.
    total_docs = await session.scalar(
        select(func.count()).where(Document.knowledge_base_id == kb.id)
    )
    # Separate the id from the count; interpolating them back-to-back merged
    # the two numbers (kb 5 with 12 docs logged as "512").
    session.desc = f"获取知识库 {kb_id} 的 {total_docs} 个文档"
    active_docs = await session.scalar(
        select(func.count()).where(
            Document.knowledge_base_id == kb.id,
            Document.is_processed == True  # noqa: E712 - SQLAlchemy needs == here
        )
    )
    session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理"
    return KnowledgeBaseResponse(
        id=kb.id,
        created_at=kb.created_at,
        updated_at=kb.updated_at,
        name=kb.name,
        description=kb.description,
        embedding_model=kb.embedding_model,
        chunk_size=kb.chunk_size,
        chunk_overlap=kb.chunk_overlap,
        is_active=kb.is_active,
        vector_db_type=kb.vector_db_type,
        collection_name=kb.collection_name,
        document_count=total_docs,
        active_document_count=active_docs
    )
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
async def update_knowledge_base(
    kb_id: int,
    kb_data: KnowledgeBaseCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Update a knowledge base and return it with its document counts.

    Args:
        kb_id: Id of the knowledge base to update.
        kb_data: New name and configuration.
        session: Database session; ``desc`` is updated with progress messages.
            Annotated as ``Session`` but awaited, so presumably an async
            session - confirm.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        KnowledgeBaseResponse: The updated knowledge base plus document counts.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 400 if another
            knowledge base already uses ``kb_data.name``.
    """
    session.desc = f"START: 更新知识库 {kb_id}"
    service = KnowledgeBaseService(session)

    # Resolve "not found" first, so a missing kb_id is never reported as a
    # name conflict.
    if not service.get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # Apply the same name-uniqueness rule as create_knowledge_base. Keeping
    # this knowledge base's own current name is allowed.
    existing_kb = service.get_knowledge_base_by_name(kb_data.name)
    if existing_kb and existing_kb.id != kb_id:
        session.desc = f"ERROR: 知识库名称 {kb_data.name} 已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"知识库名称 {kb_data.name} 已存在"
        )

    kb = service.update_knowledge_base(kb_id, kb_data)
    if not kb:
        # Still honour the service's own not-found signal.
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # Count all documents, then the subset already processed.
    total_docs = await session.scalar(
        select(func.count()).where(Document.knowledge_base_id == kb.id)
    )
    active_docs = await session.scalar(
        select(func.count()).where(
            Document.knowledge_base_id == kb.id,
            Document.is_processed == True  # noqa: E712 - SQLAlchemy needs == here
        )
    )
    session.desc = f"SUCCESS: 更新知识库 {kb_id},共 {total_docs} 个文档,其中 {active_docs} 个已处理"
    return KnowledgeBaseResponse(
        id=kb.id,
        created_at=kb.created_at,
        updated_at=kb.updated_at,
        name=kb.name,
        description=kb.description,
        embedding_model=kb.embedding_model,
        chunk_size=kb.chunk_size,
        chunk_overlap=kb.chunk_overlap,
        is_active=kb.is_active,
        vector_db_type=kb.vector_db_type,
        collection_name=kb.collection_name,
        document_count=total_docs,
        active_document_count=active_docs
    )
@router.delete("/{kb_id}", summary="删除知识库")
async def delete_knowledge_base(
    kb_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete a knowledge base by id.

    Args:
        kb_id: Id of the knowledge base to delete.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 404 when ``delete_knowledge_base`` reports failure.
    """
    session.desc = f"START: 删除知识库 {kb_id}"
    deleted = KnowledgeBaseService(session).delete_knowledge_base(kb_id)

    if deleted:
        session.desc = f"SUCCESS: 删除知识库 {kb_id}"
        return {"message": "Knowledge base deleted successfully"}

    # A falsy result from the service is treated as "not found".
    session.desc = f"ERROR: 知识库 {kb_id} 不存在"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Knowledge base not found",
    )
# Document management endpoints
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
async def upload_document(
    kb_id: int,
    file: UploadFile = File(...),
    process_immediately: bool = Form(True),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Upload a file into a knowledge base and optionally process it at once.

    Processing is best-effort. If ``process_document`` raises, the error is
    recorded in ``session.desc`` and the uploaded document is still returned.

    Args:
        kb_id: Id of the target knowledge base.
        file: Uploaded file. Its extension must pass
            ``FileUtils.validate_file_extension``, and its reported size must
            not exceed 50MB.
        process_immediately: When true, run ``process_document`` right after
            the upload.
        session: Database session; ``desc`` is updated with progress messages.
            Annotated as ``Session`` but awaited, so presumably an async
            session - confirm.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        DocumentResponse: The stored document.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 400 for an
            unsupported file type or an oversized file.
    """
    session.desc = f"START: 上传文档到知识库 {kb_id}"
    # Only existence is checked here.
    kb_service = KnowledgeBaseService(session)
    kb = kb_service.get_knowledge_base(kb_id)
    if not kb:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # UploadFile.filename is Optional. Treat a missing name as "" so the
    # extension check rejects it instead of `.split` raising AttributeError.
    filename = file.filename or ""
    allowed = ', '.join(FileUtils.ALLOWED_EXTENSIONS)
    if not FileUtils.validate_file_extension(filename):
        session.desc = f"ERROR: 文件 {filename} 类型不支持,仅支持 {allowed}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件类型 {filename.split('.')[-1]} 不支持。支持类型: {allowed}"
        )

    # Size limit: 50MB. The check is skipped when file.size is None/0.
    # NOTE(review): if settings defines an upload limit, it may belong here.
    max_size = 50 * 1024 * 1024
    if file.size and file.size > max_size:
        session.desc = f"ERROR: 文件 {filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件 {filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
        )

    doc_service = DocumentService(session)
    document = await doc_service.upload_document(file, kb_id)

    processing_failed = False
    if process_immediately:
        try:
            await doc_service.process_document(document.id, kb_id)
            # Reload the row so the response reflects the processing result.
            await session.refresh(document)
        except Exception as e:
            # Deliberate best-effort: the upload succeeded, so a processing
            # failure is recorded instead of failing the request.
            # NOTE(review): on this path `document` is not refreshed, so the
            # response may show pre-processing state - confirm whether the
            # session also needs a rollback here.
            processing_failed = True
            session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"

    # Do not overwrite or contradict the processing error with a SUCCESS message.
    if not processing_failed:
        session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"

    return DocumentResponse(
        id=document.id,
        created_at=document.created_at,
        updated_at=document.updated_at,
        knowledge_base_id=document.knowledge_base_id,
        filename=document.filename,
        original_filename=document.original_filename,
        file_path=document.file_path,
        file_type=document.file_type,
        file_size=document.file_size,
        mime_type=document.mime_type,
        is_processed=document.is_processed,
        processing_error=document.processing_error,
        chunk_count=document.chunk_count or 0,
        embedding_model=document.embedding_model,
        file_size_mb=round(document.file_size / (1024 * 1024), 2)
    )
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
async def list_documents(
    kb_id: int,
    skip: int = 0,
    limit: int = 50,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List one page of documents in a knowledge base.

    Args:
        kb_id: Id of the knowledge base.
        skip: Number of documents to skip; must be >= 0.
        limit: Page size; must be > 0.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        DocumentListResponse: The page of documents, the ``total`` reported by
        ``DocumentService.list_documents``, the 1-based page number
        ``skip // limit + 1``, and ``page_size`` equal to ``limit``.

    Raises:
        HTTPException: 400 for invalid pagination parameters; 404 if the
            knowledge base does not exist.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"

    # `limit` is the divisor in the page computation below. Without this guard,
    # limit == 0 would raise ZeroDivisionError (HTTP 500). A negative skip is
    # not a meaningful offset either.
    if skip < 0 or limit <= 0:
        session.desc = f"ERROR: 分页参数无效 skip={skip}, limit={limit}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="skip must be >= 0 and limit must be > 0"
        )

    # Only existence is checked here.
    kb_service = KnowledgeBaseService(session)
    kb = kb_service.get_knowledge_base(kb_id)
    if not kb:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    doc_service = DocumentService(session)
    documents, total = doc_service.list_documents(kb_id, skip, limit)
    doc_responses = [
        DocumentResponse(
            id=doc.id,
            created_at=doc.created_at,
            updated_at=doc.updated_at,
            knowledge_base_id=doc.knowledge_base_id,
            filename=doc.filename,
            original_filename=doc.original_filename,
            file_path=doc.file_path,
            file_type=doc.file_type,
            file_size=doc.file_size,
            mime_type=doc.mime_type,
            is_processed=doc.is_processed,
            processing_error=doc.processing_error,
            chunk_count=doc.chunk_count or 0,
            embedding_model=doc.embedding_model,
            file_size_mb=round(doc.file_size / (1024 * 1024), 2)
        )
        for doc in documents
    ]
    session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total}"
    return DocumentListResponse(
        documents=doc_responses,
        total=total,
        page=skip // limit + 1,
        page_size=limit
    )
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
async def get_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return the details of one document inside a knowledge base.

    Args:
        kb_id: Id of the knowledge base that should contain the document.
        doc_id: Id of the document.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        DocumentResponse: The document's stored metadata.

    Raises:
        HTTPException: 404 if the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"

    # Guard 1: the knowledge base must exist (existence only, no ownership check).
    if not KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    # Guard 2: the document must be found for this knowledge base.
    doc = DocumentService(session).get_document(doc_id, kb_id)
    if not doc:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found",
        )

    session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
    size_in_mb = round(doc.file_size / (1024 * 1024), 2)
    return DocumentResponse(
        id=doc.id,
        created_at=doc.created_at,
        updated_at=doc.updated_at,
        knowledge_base_id=doc.knowledge_base_id,
        filename=doc.filename,
        original_filename=doc.original_filename,
        file_path=doc.file_path,
        file_type=doc.file_type,
        file_size=doc.file_size,
        mime_type=doc.mime_type,
        is_processed=doc.is_processed,
        processing_error=doc.processing_error,
        chunk_count=doc.chunk_count or 0,
        embedding_model=doc.embedding_model,
        file_size_mb=size_in_mb,
    )
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
async def delete_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete one document from a knowledge base.

    Args:
        kb_id: Id of the knowledge base that should contain the document.
        doc_id: Id of the document to delete.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 404 if the knowledge base is missing, or if
            ``DocumentService.delete_document`` reports failure.
    """
    session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"

    knowledge_base = KnowledgeBaseService(session).get_knowledge_base(kb_id)
    if not knowledge_base:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    removed = DocumentService(session).delete_document(doc_id, kb_id)
    if removed:
        session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
        return {"message": "Document deleted successfully"}

    # A falsy result from the service is treated as "document not found".
    session.desc = f"ERROR: 文档 {doc_id} 不存在"
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Document not found",
    )
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
async def process_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Run document processing (for vector search) on one document.

    Args:
        kb_id: Id of the knowledge base that should contain the document.
        doc_id: Id of the document to process.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        DocumentProcessingStatus: Built from the dict returned by
        ``DocumentService.process_document``. ``"status"`` is required;
        ``progress``, ``error_message`` and ``chunks_created`` fall back to
        0.0, None and 0.

    Raises:
        HTTPException: 404 if the knowledge base or the document is missing.
        KeyError: If the processing result lacks a ``"status"`` key.
    """
    session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"

    if not KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    doc_service = DocumentService(session)
    if not doc_service.get_document(doc_id, kb_id):
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found",
        )

    outcome = await doc_service.process_document(doc_id, kb_id)
    # NOTE(review): SUCCESS is recorded whatever outcome["status"] says.
    session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"

    reported_status = outcome["status"]
    return DocumentProcessingStatus(
        document_id=doc_id,
        status=reported_status,
        progress=outcome.get("progress", 0.0),
        error_message=outcome.get("error_message"),
        chunks_created=outcome.get("chunks_created", 0),
    )
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
async def get_document_processing_status(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Report the processing status of one document.

    The status is derived from the stored document, in this priority:
    a truthy ``processing_error`` means "failed" (progress 0.0); otherwise
    ``is_processed`` means "completed" (progress 100.0); otherwise "pending"
    (progress 0.0).

    Args:
        kb_id: Id of the knowledge base that should contain the document.
        doc_id: Id of the document.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        DocumentProcessingStatus: Derived status, error text and chunk count.

    Raises:
        HTTPException: 404 if the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"

    if not KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    doc = DocumentService(session).get_document(doc_id, kb_id)
    if not doc:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found",
        )

    # An error takes precedence over the processed flag.
    if doc.processing_error:
        state, pct = "failed", 0.0
        session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{doc.processing_error}"
    elif doc.is_processed:
        state, pct = "completed", 100.0
        session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
    else:
        state, pct = "pending", 0.0
        session.desc = f"文档 {doc_id} 处理pending中"

    return DocumentProcessingStatus(
        document_id=doc.id,
        status=state,
        progress=pct,
        error_message=doc.processing_error,
        chunks_created=doc.chunk_count or 0,
    )
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
async def search_knowledge_base(
    kb_id: int,
    query: str,
    limit: int = 5,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Search the documents of a knowledge base.

    Args:
        kb_id: Id of the knowledge base to search.
        query: Search text, passed unchanged to ``DocumentService.search_documents``.
        limit: Result cap forwarded to the search service.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Authenticated user. NOTE(review): not used for any
            ownership check in this function.

    Returns:
        dict: ``knowledge_base_id``, ``query``, the raw ``results`` from the
        search service, and ``total_results`` (their count).

    Raises:
        HTTPException: 404 if the knowledge base does not exist.
    """
    session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"

    if not KnowledgeBaseService(session).get_knowledge_base(kb_id):
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    hits = DocumentService(session).search_documents(kb_id, query, limit)
    hit_count = len(hits)
    session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {hit_count} 条结果"
    return dict(
        knowledge_base_id=kb_id,
        query=query,
        results=hits,
        total_results=hit_count,
    )
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
async def get_document_chunks(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """
    Return all chunks (fragments) of one document in a knowledge base.

    Args:
        kb_id: Knowledge base id.
        doc_id: Document id.
        session: Database session; ``desc`` is updated with progress messages.
        current_user: Currently authenticated user. NOTE(review): not used for
            any ownership check in this function.

    Returns:
        DocumentChunksResponse: The document id and filename, the chunk count,
        and the chunks returned by ``DocumentService.get_document_chunks``.

    Raises:
        HTTPException: 404 if the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
    kb_service = KnowledgeBaseService(session)
    knowledge_base = kb_service.get_knowledge_base(kb_id)
    if not knowledge_base:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="知识库不存在"
        )

    # The document must be found for this knowledge base.
    doc_service = DocumentService(session)
    document = doc_service.get_document(doc_id, kb_id)
    if not document:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文档不存在"
        )

    # Chunks are looked up by document id only; kb membership was checked above.
    chunks = doc_service.get_document_chunks(doc_id)
    # Separate the document id from the chunk count; interpolating them
    # back-to-back merged the two numbers (doc 3 with 10 chunks logged as "310").
    session.desc = f"SUCCESS: 获取文档 {doc_id} 的 {len(chunks)} 个文档块(片段)"
    return DocumentChunksResponse(
        document_id=doc_id,
        document_name=document.filename,
        total_chunks=len(chunks),
        chunks=chunks
    )