"""Knowledge base service."""

# Standard library imports
import inspect
from typing import List, Optional, Dict, Any

# Third-party imports
from loguru import logger
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import Session

# Local imports
from ..core.config import get_settings
from ..core.context import UserContext
from ..models.knowledge_base import KnowledgeBase
from .document_processor import get_document_processor
from utils.util_schemas import KnowledgeBaseCreate, KnowledgeBaseUpdate

settings = get_settings()

# Escape character used when user-supplied text is embedded in a LIKE pattern.
# A forward slash avoids the backslash-quoting differences between dialects.
_LIKE_ESCAPE = "/"


class KnowledgeBaseService:
    """Service for managing knowledge bases.

    Provides creation, lookup, update, deletion and search of
    ``KnowledgeBase`` records, plus similarity search inside a knowledge
    base via the document processor.
    """

    def __init__(self, session: Session):
        """Initialize the service.

        Args:
            session (Session): Database session used for all ORM operations.
        """
        self.session = session

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Return a ``%text%`` LIKE pattern with wildcards in *text* escaped.

        The escape character itself is doubled first so that it cannot
        combine with the escapes added for ``%`` and ``_``.
        """
        escaped = (
            text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
            .replace("%", f"{_LIKE_ESCAPE}%")
            .replace("_", f"{_LIKE_ESCAPE}_")
        )
        return f"%{escaped}%"

    def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
        """Create and persist a new knowledge base.

        Args:
            kb_data (KnowledgeBaseCreate): Data for the new knowledge base.

        Returns:
            KnowledgeBase: The persisted (committed and refreshed) instance.

        Raises:
            Exception: Any error raised while building or committing the
                record; the session is rolled back before re-raising.
        """
        try:
            # Derive the vector-database collection name from the KB name:
            # lower-cased, with spaces and hyphens turned into underscores.
            # NOTE(review): other characters are passed through unchanged and
            # the name is not user-scoped — confirm the vector store accepts
            # that and that cross-user name collisions are acceptable.
            normalized_name = kb_data.name.lower().replace(' ', '_').replace('-', '_')
            collection_name = f"kb_{normalized_name}"

            kb = KnowledgeBase(
                name=kb_data.name,
                description=kb_data.description,
                embedding_model=kb_data.embedding_model,
                chunk_size=kb_data.chunk_size,
                chunk_overlap=kb_data.chunk_overlap,
                vector_db_type=settings.vector_db.type,
                collection_name=collection_name,
            )

            # Let the model populate its own audit fields.
            kb.set_audit_fields()

            self.session.add(kb)
            self.session.commit()
            self.session.refresh(kb)

            logger.info(f"Created knowledge base: {kb.name} (ID: {kb.id})")
            return kb
        except Exception as e:
            self.session.rollback()
            logger.error(f"Failed to create knowledge base: {str(e)}")
            raise

    def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
        """Fetch a knowledge base by primary key.

        Args:
            kb_id (int): ID of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The matching instance, or None.
        """
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        return self.session.execute(stmt).scalar_one_or_none()

    def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
        """Fetch a knowledge base of the current user by name.

        Args:
            name (str): Name of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The matching instance, or None.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.name == name,
            KnowledgeBase.created_by == UserContext.get_current_user().id,
        )
        return self.session.execute(stmt).scalar_one_or_none()

    async def get_knowledge_bases(
        self, skip: int = 0, limit: int = 50, active_only: bool = True
    ) -> List[KnowledgeBase]:
        """List the current user's knowledge bases.

        Args:
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.
            active_only (bool, optional): Return only active knowledge bases.
                Defaults to True.

        Returns:
            List[KnowledgeBase]: Knowledge bases created by the current user,
            ordered by ID.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == UserContext.get_current_user().id
        )
        if active_only:
            # SQLAlchemy column comparison, not a Python truthiness test.
            stmt = stmt.where(KnowledgeBase.is_active == True)  # noqa: E712

        # OFFSET/LIMIT without ORDER BY gives unstable pages; order by the
        # primary key so consecutive pages never overlap or skip rows.
        stmt = stmt.order_by(KnowledgeBase.id).offset(skip).limit(limit)

        # The session is declared as a synchronous ``Session``, whose
        # ``execute`` result cannot be awaited. Await only when the session
        # actually returned an awaitable (e.g. an async session was injected).
        result = self.session.execute(stmt)
        if inspect.isawaitable(result):
            result = await result
        return result.scalars().all()

    def update_knowledge_base(
        self, kb_id: int, kb_update: KnowledgeBaseUpdate
    ) -> Optional[KnowledgeBase]:
        """Update an existing knowledge base.

        Only fields explicitly set on ``kb_update`` are applied.

        Args:
            kb_id (int): ID of the knowledge base to update.
            kb_update (KnowledgeBaseUpdate): Fields to change.

        Returns:
            Optional[KnowledgeBase]: The updated instance, or None when no
            knowledge base with that ID exists.

        Raises:
            Exception: Any error raised during the update; the session is
                rolled back before re-raising.
        """
        try:
            kb = self.get_knowledge_base(kb_id)
            if kb is None:
                return None

            # Apply only the fields the caller actually provided.
            update_data = kb_update.model_dump(exclude_unset=True)
            for field, value in update_data.items():
                setattr(kb, field, value)

            # Let the model refresh its audit fields for an update.
            kb.set_audit_fields(is_update=True)

            self.session.commit()
            self.session.refresh(kb)

            # NOTE(review): ``desc`` is set on the session object itself;
            # presumably read by an operation-log hook elsewhere — confirm.
            self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
            return kb
        except Exception as e:
            self.session.rollback()
            self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}"
            raise

    def delete_knowledge_base(self, kb_id: int) -> bool:
        """Delete a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to delete.

        Returns:
            bool: True when the knowledge base was deleted, False when no
            knowledge base with that ID exists.

        Raises:
            Exception: Any error raised while deleting or committing; the
                session is rolled back before re-raising.
        """
        kb = self.get_knowledge_base(kb_id)
        if kb is None:
            return False

        # TODO: Clean up vector database collection
        # This should be implemented when vector database service is ready

        try:
            self.session.delete(kb)
            self.session.commit()
        except Exception as e:
            # Match create/update: never leave the session in a failed
            # transaction state for the next caller.
            self.session.rollback()
            logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
            raise
        return True

    def search_knowledge_bases(
        self, query: str, skip: int = 0, limit: int = 50
    ) -> List[KnowledgeBase]:
        """Search the current user's active knowledge bases by name or description.

        The query is matched as a case-insensitive literal substring; ``%``
        and ``_`` in the query are escaped rather than treated as wildcards.

        Args:
            query (str): Search text.
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return.
                Defaults to 50.

        Returns:
            List[KnowledgeBase]: Matching knowledge bases, ordered by ID.
        """
        pattern = self._like_pattern(query)
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == UserContext.get_current_user().id,
            KnowledgeBase.is_active == True,  # noqa: E712
            or_(
                KnowledgeBase.name.ilike(pattern, escape=_LIKE_ESCAPE),
                KnowledgeBase.description.ilike(pattern, escape=_LIKE_ESCAPE),
            ),
        )
        # Stable ordering so OFFSET/LIMIT pages are deterministic.
        stmt = stmt.order_by(KnowledgeBase.id).offset(skip).limit(limit)
        return self.session.execute(stmt).scalars().all()

    async def search(
        self,
        kb_id: int,
        query: str,
        top_k: int = 5,
        similarity_threshold: float = 0.7,
    ) -> List[Dict[str, Any]]:
        """Search inside a knowledge base using vector similarity.

        This is best-effort: any failure is logged and an empty list is
        returned instead of raising.

        Args:
            kb_id (int): ID of the knowledge base to search in.
            query (str): Search query.
            top_k (int, optional): Maximum number of candidates requested
                from the document processor. Defaults to 5.
            similarity_threshold (float, optional): Minimum
                ``normalized_score`` a result must have to be returned.
                Defaults to 0.7.

        Returns:
            List[Dict[str, Any]]: Results with ``content``, ``source``,
            ``score``, ``metadata``, ``document_id`` and ``chunk_id`` keys.
        """
        try:
            logger.info(f"Searching in knowledge base {kb_id} for: {query}")

            # Delegate the vector search to the document processor.
            search_results = get_document_processor().search_similar_documents(
                knowledge_base_id=kb_id,
                query=query,
                k=top_k,
            )

            # Keep only results at or above the similarity threshold.
            filtered_results = []
            for result in search_results:
                # Results lacking a normalized score are treated as 0.
                normalized_score = result.get('normalized_score', 0)
                if normalized_score < similarity_threshold:
                    continue
                filtered_results.append({
                    "content": result.get('content', ''),
                    "source": result.get('source', 'unknown'),
                    "score": normalized_score,
                    "metadata": result.get('metadata', {}),
                    "document_id": result.get('document_id', 'unknown'),
                    "chunk_id": result.get('chunk_id', 'unknown'),
                })

            logger.info(f"Found {len(filtered_results)} relevant documents (threshold: {similarity_threshold})")
            return filtered_results
        except Exception as e:
            logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
            return []