hyf-backend/th_agenter/services/llm_config_service.py

157 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM配置服务 - 从数据库读取默认配置"""
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from sqlalchemy import and_, select
from ..models.llm_config import LLMConfig
from ..db.database import get_session
from loguru import logger
class LLMConfigService:
    """Read LLM (chat / embedding) model configurations from the database.

    The ``get_default_*`` coroutines take the session explicitly and await
    ``session.execute``, so they expect an async-capable session.  The
    remaining lookups accept an optional ``session`` argument and fall back to
    the session handed to the constructor (``self.db``).
    """

    def __init__(self, db: Optional[Session] = None) -> None:
        """Create the service.

        Args:
            db: Optional default session used by methods that are not given an
                explicit ``session``.  May be omitted, in which case those
                methods need the session passed per call.
        """
        self.db = db

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _default_config_stmt(is_embedding: bool):
        """Build the SELECT for the active, default config of one model kind."""
        return select(LLMConfig).where(
            and_(
                LLMConfig.is_default == True,  # noqa: E712 - SQL expression, not a Python bool test
                LLMConfig.is_embedding == is_embedding,
                LLMConfig.is_active == True,  # noqa: E712
            )
        )

    def _resolve_session(self, session: Optional[Session] = None) -> Optional[Session]:
        """Return the explicit session if given, otherwise the instance default."""
        if session is not None:
            return session
        # getattr keeps this safe for instances built without running __init__.
        return getattr(self, "db", None)

    def _execute_sync(self, stmt: Any, session: Optional[Session] = None) -> Any:
        """Execute ``stmt`` on a synchronous session and return the result.

        Raises:
            RuntimeError: No session is available.
            TypeError: The session's ``execute`` returned an awaitable, i.e.
                the session is asynchronous and cannot be used from a
                synchronous method.
        """
        db = self._resolve_session(session)
        if db is None:
            raise RuntimeError("数据库会话不可用")
        result = db.execute(stmt)
        if hasattr(result, "__await__"):
            # Close the un-awaited coroutine so it does not trigger a
            # "coroutine was never awaited" warning, then fail loudly.
            close = getattr(result, "close", None)
            if callable(close):
                close()
            raise TypeError("同步方法需要同步 Session,但当前会话是异步的")
        return result

    # ------------------------------------------------------------------ #
    # Default-config lookups (async)
    # ------------------------------------------------------------------ #
    async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the active default chat (non-embedding) model config.

        Args:
            session: Session whose ``execute`` is awaited.

        Returns:
            The matching ``LLMConfig``, or ``None`` when ``session`` is
            ``None``, no row matches, or the query fails (the failure is
            logged, not raised).
        """
        if session is None:
            logger.error("get_default_chat_config: session 为 None,无法查询配置")
            return None
        try:
            result = await session.execute(self._default_config_stmt(False))
            config = result.scalar_one_or_none()
            if not config:
                logger.warning("未找到默认对话模型配置")
                return None
            return config
        except Exception as e:
            logger.error(f"获取默认对话模型配置失败: {str(e)}")
            return None

    async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the default embedding config, falling back to any active one.

        Lookup order:
          1. the active config flagged as default embedding model;
          2. otherwise the oldest (by ``created_at``) active embedding config.

        Progress is also written to ``session.desc`` (a project-specific
        attribute; its consumer is not visible in this file).

        Args:
            session: Session whose ``execute`` is awaited.

        Returns:
            The selected ``LLMConfig``, or ``None`` when ``session`` is
            ``None``, nothing suitable exists, or the query fails (logged).
        """
        if session is None:
            logger.error("get_default_embedding_config: session 为 None无法查询配置")
            return None
        try:
            # 1) Preferred: the explicitly configured default embedding model.
            result = await session.execute(self._default_config_stmt(True))
            config = result.scalar_one_or_none()
            if config:
                session.desc = f"找到默认嵌入模型配置: {config.name} (ID: {config.id})"
                return config

            # 2) Fallback: the oldest active embedding config.
            session.desc = "未找到默认嵌入模型配置,尝试查找任何激活的嵌入模型配置"
            logger.info("未找到默认嵌入模型配置,尝试查找任何激活的嵌入模型配置")
            stmt = (
                select(LLMConfig)
                .where(
                    and_(
                        LLMConfig.is_embedding == True,  # noqa: E712
                        LLMConfig.is_active == True,  # noqa: E712
                    )
                )
                .order_by(LLMConfig.created_at)
                .limit(1)
            )
            # first() rather than scalar_one_or_none(): several active
            # embedding configs are legitimate here and must not raise
            # MultipleResultsFound.
            config = (await session.execute(stmt)).scalars().first()
            if config:
                session.desc = f"使用激活的嵌入模型配置(非默认): {config.name} (ID: {config.id})"
                logger.info(f"使用激活的嵌入模型配置(非默认): {config.name} (ID: {config.id})")
                return config

            # Nothing usable: record it and dump what exists to aid debugging.
            session.desc = "ERROR: 未找到任何激活的嵌入模型配置"
            logger.error("未找到任何激活的嵌入模型配置")
            all_embedding_stmt = select(LLMConfig).where(LLMConfig.is_embedding == True)  # noqa: E712
            all_embedding = (await session.execute(all_embedding_stmt)).scalars().all()
            if all_embedding:
                logger.warning(f"找到 {len(all_embedding)} 个嵌入模型配置,但都不是激活状态:")
                for cfg in all_embedding:
                    logger.warning(f" - {cfg.name} (ID: {cfg.id}, is_active={cfg.is_active}, is_default={cfg.is_default})")
            else:
                logger.warning("数据库中没有任何嵌入模型配置")
            return None
        except Exception as e:
            session.desc = f"ERROR: 获取嵌入模型配置失败: {str(e)}"
            # loguru has no ``exc_info`` keyword: passing it only turns the
            # message into a str.format template (which can itself raise on
            # braces).  logger.exception logs at ERROR with the traceback.
            logger.exception(f"获取嵌入模型配置失败: {str(e)}")
            return None

    # ------------------------------------------------------------------ #
    # Generic lookups
    # ------------------------------------------------------------------ #
    async def get_config_by_id(
        self, config_id: int, session: Optional[Session] = None
    ) -> Optional[LLMConfig]:
        """Return the config with primary key ``config_id``.

        Args:
            config_id: Value compared against ``LLMConfig.id``.
            session: Session whose ``execute`` is awaited; defaults to the
                session given to the constructor.

        Returns:
            The matching ``LLMConfig``, or ``None`` when it does not exist, no
            session is available, or the query fails (logged).
        """
        try:
            db = self._resolve_session(session)
            if db is None:
                raise RuntimeError("数据库会话不可用")
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            return (await db.execute(stmt)).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(
        self, is_embedding: Optional[bool] = None, session: Optional[Session] = None
    ) -> List[LLMConfig]:
        """Return all active configs ordered by creation time.

        This method is synchronous and therefore needs a synchronous session.

        Args:
            is_embedding: When not ``None``, restrict the result to embedding
                (``True``) or non-embedding (``False``) configs.
            session: Synchronous session; defaults to the constructor's one.

        Returns:
            A list of ``LLMConfig`` rows; an empty list when nothing matches,
            no usable session is available, or the query fails (logged).
        """
        try:
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)  # noqa: E712
            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
            stmt = stmt.order_by(LLMConfig.created_at)
            return list(self._execute_sync(stmt, session).scalars().all())
        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    # ------------------------------------------------------------------ #
    # Environment-based fallbacks
    # ------------------------------------------------------------------ #
    async def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """Return the chat model config provided by the application settings."""
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with the settings module - confirm before hoisting).
        from ..core.config import get_settings

        settings = get_settings()
        return await settings.llm.get_current_config()

    async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """Return the embedding model config provided by the application settings."""
        from ..core.config import get_settings

        settings = get_settings()
        return await settings.embedding.get_current_config()

    # ------------------------------------------------------------------ #
    # Connectivity test
    # ------------------------------------------------------------------ #
    def test_config(
        self,
        config_id: int,
        test_message: str = "Hello",
        session: Optional[Session] = None,
    ) -> Dict[str, Any]:
        """Check that a config exists (placeholder for a real connection test).

        The lookup is done synchronously.  Calling the ``get_config_by_id``
        coroutine here without awaiting it would yield a coroutine object,
        which is always truthy and would make every test "succeed".

        Args:
            config_id: Primary key of the config to test.
            test_message: Currently unused; reserved for the actual
                connection test.
            session: Synchronous session; defaults to the constructor's one.

        Returns:
            ``{"success": True, "message": ...}`` when the config exists,
            otherwise ``{"success": False, "error": ...}``.
        """
        try:
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            config = self._execute_sync(stmt, session).scalar_one_or_none()
            if config is None:
                return {"success": False, "error": "配置不存在"}
            # TODO: send a minimal request to the provider to verify that the
            # configuration actually works.
            return {"success": True, "message": "配置测试成功"}
        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}
# # 全局实例
# _llm_config_service = None
# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
# """获取LLM配置服务实例"""
# global _llm_config_service
# if _llm_config_service is None or db_session is not None:
# _llm_config_service = LLMConfigService(db_session)
# return _llm_config_service