245 lines
9.5 KiB
Python
245 lines
9.5 KiB
Python
"""Knowledge base service."""
|
||
|
||
# Standard library imports
|
||
from typing import List, Optional, Dict, Any
|
||
|
||
# Third-party imports
|
||
from loguru import logger
|
||
from sqlalchemy import select, and_, or_
|
||
from sqlalchemy.orm import Session
|
||
|
||
# Local imports
|
||
from ..core.config import get_settings
|
||
from ..core.context import UserContext
|
||
from ..models.knowledge_base import KnowledgeBase
|
||
from .document_processor import get_document_processor
|
||
from utils.util_schemas import KnowledgeBaseCreate, KnowledgeBaseUpdate
|
||
|
||
settings = get_settings()
|
||
|
||
class KnowledgeBaseService:
    """Service layer for managing knowledge bases.

    Offers create, read, update, delete and search operations on
    ``KnowledgeBase`` rows, plus vector-similarity search that is delegated to
    the document processor.
    """

    def __init__(self, session: Session):
        """Store the database session used by every operation.

        Args:
            session (Session): Database session used for ORM operations.
                Its ``execute``/``commit``/``refresh``/``rollback``/``delete``
                calls are awaited throughout this class, so it is presumably
                an async session despite the annotation -- TODO confirm.
        """
        if session is None:
            # A missing session is only logged here; the failure will surface
            # on the first database call instead.
            logger.error("session为空,session must be an instance of Session")
        self.session = session

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE/ILIKE metacharacters so *term* is matched literally.

        The backslash is escaped first so that the escapes added for ``%`` and
        ``_`` are not themselves doubled.

        Args:
            term (str): Raw user-supplied search text.

        Returns:
            str: The text with ``\\``, ``%`` and ``_`` prefixed by a backslash,
            to be used together with ``escape="\\"``.
        """
        return (
            term.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
        """Create and persist a new knowledge base.

        Args:
            kb_data (KnowledgeBaseCreate): Data for the new knowledge base.

        Returns:
            KnowledgeBase: The persisted, refreshed instance.

        Raises:
            Exception: Any error raised while building or committing the row;
                the session is rolled back and the error is re-raised.
        """
        try:
            # Derive the vector-database collection name from the display
            # name: lower-cased, with spaces and hyphens turned into "_".
            collection_name = f"kb_{kb_data.name.lower().replace(' ', '_').replace('-', '_')}"

            kb = KnowledgeBase(
                name=kb_data.name,
                description=kb_data.description,
                embedding_model=kb_data.embedding_model,
                chunk_size=kb_data.chunk_size,
                chunk_overlap=kb_data.chunk_overlap,
                vector_db_type=settings.vector_db.type,
                collection_name=collection_name,
            )

            # Populate the created_by / updated_by audit fields.
            kb.set_audit_fields()

            self.session.add(kb)
            await self.session.commit()
            await self.session.refresh(kb)

            # Human-readable description of the operation, attached to the
            # session (presumably consumed by audit logging -- TODO confirm).
            self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}"
            return kb

        except Exception as e:
            await self.session.rollback()
            logger.error(f"Failed to create knowledge base: {str(e)}")
            raise

    async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
        """Search the current user's active knowledge bases by name or description.

        The query is matched case-insensitively as a literal substring: LIKE
        wildcards (``%``, ``_``) typed by the user are escaped rather than
        interpreted.

        Args:
            query (str): Search text.
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.

        Returns:
            List[KnowledgeBase]: Matching knowledge bases, ordered by id.
        """
        pattern = f"%{self._escape_like(query)}%"
        stmt = (
            select(KnowledgeBase)
            .where(
                KnowledgeBase.created_by == UserContext.get_current_user()['id'],
                KnowledgeBase.is_active == True,  # noqa: E712 - SQL expression
                or_(
                    KnowledgeBase.name.ilike(pattern, escape="\\"),
                    KnowledgeBase.description.ilike(pattern, escape="\\"),
                ),
            )
            # OFFSET/LIMIT without ORDER BY yields unstable pages.
            .order_by(KnowledgeBase.id)
            .offset(skip)
            .limit(limit)
        )
        return (await self.session.execute(stmt)).scalars().all()

    async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """Search inside a knowledge base using vector similarity.

        Args:
            kb_id (int): ID of the knowledge base to search in.
            query (str): Search query.
            top_k (int, optional): Maximum number of candidates requested from
                the document processor. Defaults to 5.
            similarity_threshold (float, optional): Minimum normalized score a
                candidate needs to be returned. Defaults to 0.7.

        Returns:
            List[Dict[str, Any]]: One dict per retained hit with the keys
            ``content``, ``source``, ``score``, ``metadata``, ``document_id``
            and ``chunk_id``. An empty list is returned when nothing passes
            the threshold or when any error occurs (errors are logged, not
            raised).
        """
        try:
            logger.info(f"Searching in knowledge base {kb_id} for: {query}")

            # Vector search is delegated to the document processor.
            processor = await get_document_processor(self.session)
            search_results = processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=query,
                k=top_k,
            )

            filtered_results = []
            for result in search_results:
                # The processor is expected to provide an already normalized
                # score; hits without one are treated as score 0.
                normalized_score = result.get('normalized_score', 0)
                if normalized_score < similarity_threshold:
                    continue
                filtered_results.append({
                    "content": result.get('content', ''),
                    "source": result.get('source', 'unknown'),
                    "score": normalized_score,
                    "metadata": result.get('metadata', {}),
                    "document_id": result.get('document_id', 'unknown'),
                    "chunk_id": result.get('chunk_id', 'unknown'),
                })

            logger.info(f"Found {len(filtered_results)} relevant documents (threshold: {similarity_threshold})")
            return filtered_results

        except Exception as e:
            # Best effort: a failed search degrades to "no results".
            logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
            return []

    # ----------------------------------------------------------------------------------
    async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
        """Fetch one of the current user's knowledge bases by exact name.

        Args:
            name (str): Name of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The matching instance, or ``None`` when
            the current user has no knowledge base with that name.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.name == name,
            KnowledgeBase.created_by == UserContext.get_current_user()['id'],
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
        """List the current user's knowledge bases.

        Args:
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.
            active_only (bool, optional): When true, only active knowledge
                bases are returned. Defaults to True.

        Returns:
            List[KnowledgeBase]: The user's knowledge bases, ordered by id.
        """
        # The current user is a mapping; its id is read by key.
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == UserContext.get_current_user()['id']
        )
        if active_only:
            stmt = stmt.where(KnowledgeBase.is_active == True)  # noqa: E712 - SQL expression
        # OFFSET/LIMIT without ORDER BY yields unstable pages.
        stmt = stmt.order_by(KnowledgeBase.id).offset(skip).limit(limit)
        return (await self.session.execute(stmt)).scalars().all()

    async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
        """Fetch a knowledge base by primary key.

        Args:
            kb_id (int): ID of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The instance, or ``None`` if no row has
            that id.
        """
        # NOTE(review): unlike the other queries in this class, this lookup is
        # not restricted to the current user, and update/delete rely on it.
        # Confirm that ownership is enforced by the callers.
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
        """Apply a partial update to a knowledge base.

        Only the fields explicitly set on ``kb_update`` are written.

        Args:
            kb_id (int): ID of the knowledge base to update.
            kb_update (KnowledgeBaseUpdate): New field values.

        Returns:
            Optional[KnowledgeBase]: The refreshed instance, or ``None`` when
            no knowledge base has that id.

        Raises:
            Exception: Any error raised while applying or committing the
                update; the session is rolled back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return None

        try:
            # exclude_unset keeps fields the caller did not send untouched.
            for field, value in kb_update.model_dump(exclude_unset=True).items():
                setattr(kb, field, value)

            # Refresh the audit fields for an update.
            kb.set_audit_fields(is_update=True)

            await self.session.commit()
            await self.session.refresh(kb)
        except Exception as e:
            # Mirror create_knowledge_base: never leave the session stuck in
            # a failed transaction.
            await self.session.rollback()
            logger.error(f"Failed to update knowledge base {kb_id}: {str(e)}")
            raise

        self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
        return kb

    async def delete_knowledge_base(self, kb_id: int) -> bool:
        """Delete a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to delete.

        Returns:
            bool: ``True`` if the row was deleted, ``False`` if no knowledge
            base has that id.

        Raises:
            Exception: Any error raised while deleting or committing; the
                session is rolled back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return False

        # TODO: Clean up vector database collection
        # This should be implemented when vector database service is ready

        try:
            await self.session.delete(kb)
            await self.session.commit()
        except Exception as e:
            # Mirror create_knowledge_base: never leave the session stuck in
            # a failed transaction.
            await self.session.rollback()
            logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
            raise

        return True
|