hyf-backend/th_agenter/services/knowledge_base.py

245 lines
9.5 KiB
Python
Raw Permalink Normal View History

2026-01-21 13:45:39 +08:00
"""Knowledge base service."""
# Standard library imports
from typing import List, Optional, Dict, Any
# Third-party imports
from loguru import logger
from sqlalchemy import select, and_, or_
from sqlalchemy.orm import Session
# Local imports
from ..core.config import get_settings
from ..core.context import UserContext
from ..models.knowledge_base import KnowledgeBase
from .document_processor import get_document_processor
from utils.util_schemas import KnowledgeBaseCreate, KnowledgeBaseUpdate
# Application settings, fetched once when this module is imported; the service
# below reads ``settings.vector_db.type`` from it when creating a knowledge base.
settings = get_settings()
class KnowledgeBaseService:
    """Service layer for managing knowledge bases.

    Provides create, read, update, delete and search operations on
    ``KnowledgeBase`` rows, plus vector-similarity search that is delegated
    to the document processor.
    """

    def __init__(self, session: Session):
        """Store the database session used for all ORM operations.

        Args:
            session (Session): Database session. A ``None`` session is only
                logged as an error, not rejected, so construction never raises.

        NOTE(review): the annotation is the synchronous ``Session``, yet every
        method awaits the session's methods -- presumably an ``AsyncSession``
        is passed in practice; confirm against the callers.
        """
        if session is None:
            logger.error("session为空session must be an instance of Session")
        self.session = session

    @staticmethod
    def _build_collection_name(name: str) -> str:
        """Derive the vector-store collection name from a knowledge base name.

        The name is lower-cased, spaces and hyphens become underscores, and
        the result is prefixed with ``kb_``.
        """
        return f"kb_{name.lower().replace(' ', '_').replace('-', '_')}"

    @staticmethod
    def _filter_by_threshold(search_results, similarity_threshold: float) -> List[Dict[str, Any]]:
        """Keep hits whose normalized score reaches the threshold.

        Args:
            search_results: Iterable of dict-like hits as returned by the
                document processor.
            similarity_threshold (float): Minimum ``normalized_score`` to keep.

        Returns:
            List[Dict[str, Any]]: One dict per kept hit with the keys
            ``content``, ``source``, ``score``, ``metadata``, ``document_id``
            and ``chunk_id``; missing fields fall back to defaults.
        """
        kept: List[Dict[str, Any]] = []
        for hit in search_results:
            score = hit.get('normalized_score')
            if score is None:
                # A missing or explicit-None score counts as 0 so that one
                # malformed hit cannot raise and discard every other result.
                score = 0
            if score >= similarity_threshold:
                kept.append({
                    "content": hit.get('content', ''),
                    "source": hit.get('source', 'unknown'),
                    "score": score,
                    "metadata": hit.get('metadata', {}),
                    "document_id": hit.get('document_id', 'unknown'),
                    "chunk_id": hit.get('chunk_id', 'unknown'),
                })
        return kept

    async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
        """Create and persist a new knowledge base.

        Args:
            kb_data (KnowledgeBaseCreate): Data for the new knowledge base.

        Returns:
            KnowledgeBase: The persisted instance, refreshed from the database.

        Raises:
            Exception: Any error raised while building or committing the row;
                the session is rolled back and the error is re-raised.
        """
        try:
            # Collection name used by the vector database for this knowledge base.
            collection_name = self._build_collection_name(kb_data.name)
            kb = KnowledgeBase(
                name=kb_data.name,
                description=kb_data.description,
                embedding_model=kb_data.embedding_model,
                chunk_size=kb_data.chunk_size,
                chunk_overlap=kb_data.chunk_overlap,
                vector_db_type=settings.vector_db.type,
                collection_name=collection_name,
            )
            # Per the original comment, this fills created_by / updated_by.
            kb.set_audit_fields()
            self.session.add(kb)
            await self.session.commit()
            await self.session.refresh(kb)
            # Human-readable summary attached to the session; presumably read
            # by request logging/auditing elsewhere -- TODO confirm.
            self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}"
            return kb
        except Exception as e:
            await self.session.rollback()
            logger.error(f"Failed to create knowledge base: {str(e)}")
            raise

    async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
        """Search the current user's active knowledge bases by name or description.

        Args:
            query (str): Substring to look for (case-insensitive).
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return. Defaults to 50.

        Returns:
            List[KnowledgeBase]: Matching knowledge bases.
        """
        # NOTE(review): ``query`` is embedded in the LIKE pattern unescaped, so
        # ``%`` and ``_`` typed by the user act as wildcards -- confirm whether
        # that is intended.
        pattern = f"%{query}%"
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == UserContext.get_current_user()['id'],
            KnowledgeBase.is_active == True,
            or_(
                KnowledgeBase.name.ilike(pattern),
                KnowledgeBase.description.ilike(pattern),
            ),
        )
        stmt = stmt.offset(skip).limit(limit)
        return (await self.session.execute(stmt)).scalars().all()

    async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """Search inside a knowledge base using vector similarity.

        Args:
            kb_id (int): ID of the knowledge base to search in.
            query (str): Search query.
            top_k (int, optional): Maximum number of hits requested from the
                document processor. Defaults to 5.
            similarity_threshold (float, optional): Minimum normalized score
                for a hit to be returned. Defaults to 0.7.

        Returns:
            List[Dict[str, Any]]: Hits with content, source, score, metadata,
            document_id and chunk_id. Any failure is logged and yields an
            empty list instead of raising (best-effort search).
        """
        try:
            logger.info(f"Searching in knowledge base {kb_id} for: {query}")
            # The processor is obtained asynchronously; the search call itself
            # is not awaited.
            processor = await get_document_processor(self.session)
            search_results = processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=query,
                k=top_k,
            )
            # Scores are taken as already normalized by the processor.
            filtered_results = self._filter_by_threshold(search_results, similarity_threshold)
            logger.info(f"Found {len(filtered_results)} relevant documents (threshold: {similarity_threshold})")
            return filtered_results
        except Exception as e:
            logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
            return []

    # ----------------------------------------------------------------------------------
    async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
        """Fetch the current user's knowledge base with the given name.

        Args:
            name (str): Name of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The matching instance, or None if not found.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.name == name,
            KnowledgeBase.created_by == UserContext.get_current_user()['id'],
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
        """List the current user's knowledge bases.

        Args:
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return. Defaults to 50.
            active_only (bool, optional): Return only active knowledge bases.
                Defaults to True.

        Returns:
            List[KnowledgeBase]: The current user's knowledge bases.
        """
        # The current user is a mapping; its ID is read with key access.
        stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user()['id'])
        if active_only:
            stmt = stmt.where(KnowledgeBase.is_active == True)
        stmt = stmt.offset(skip).limit(limit)
        return (await self.session.execute(stmt)).scalars().all()

    async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
        """Fetch a knowledge base by ID.

        Args:
            kb_id (int): ID of the knowledge base.

        Returns:
            Optional[KnowledgeBase]: The matching instance, or None if not found.

        NOTE(review): unlike the name/list lookups, this query is not filtered
        by the current user, so update/delete (which use it) are not
        owner-scoped here -- confirm that authorization happens in the caller.
        """
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
        """Update an existing knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to update.
            kb_update (KnowledgeBaseUpdate): Fields to change; only fields
                explicitly set on the model are applied.

        Returns:
            Optional[KnowledgeBase]: The updated instance, or None if no
            knowledge base has that ID.

        Raises:
            Exception: Any error raised while applying or committing the
                update; the session is rolled back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return None
        try:
            # Apply only the fields the caller actually provided.
            for field, value in kb_update.model_dump(exclude_unset=True).items():
                setattr(kb, field, value)
            # Refresh the audit fields for an update.
            kb.set_audit_fields(is_update=True)
            await self.session.commit()
            await self.session.refresh(kb)
        except Exception as e:
            # Mirror create_knowledge_base: never leave the session in a
            # failed transaction.
            await self.session.rollback()
            logger.error(f"Failed to update knowledge base {kb_id}: {str(e)}")
            raise
        self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
        return kb

    async def delete_knowledge_base(self, kb_id: int) -> bool:
        """Delete a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to delete.

        Returns:
            bool: True if the knowledge base was deleted, False if no
            knowledge base has that ID.

        Raises:
            Exception: Any error raised while deleting or committing; the
                session is rolled back and the error is re-raised.
        """
        kb = await self.get_knowledge_base(kb_id)
        if not kb:
            return False
        # TODO: Clean up vector database collection
        # This should be implemented when vector database service is ready
        try:
            await self.session.delete(kb)
            await self.session.commit()
        except Exception as e:
            # Mirror create_knowledge_base: never leave the session in a
            # failed transaction.
            await self.session.rollback()
            logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
            raise
        return True