hyf-backend/th_agenter/services/llm_config_service.py

123 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM配置服务 - 从数据库读取默认配置"""
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from sqlalchemy import and_, select
from ..models.llm_config import LLMConfig
from ..db.database import get_session
from loguru import logger
class LLMConfigService:
    """LLM configuration service — reads default model configs from the database."""

    async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the default chat (non-embedding) model config.

        Args:
            session: async SQLAlchemy session (``session.execute`` is awaited).

        Returns:
            The single active default chat ``LLMConfig``, or ``None`` when no
            row matches or the query fails (errors are logged, never raised).
        """
        try:
            # NOTE: ``== True`` / ``== False`` build SQLAlchemy column
            # expressions, not Python comparisons — do not "fix" them to ``is``.
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,     # noqa: E712
                    LLMConfig.is_embedding == False,  # noqa: E712
                    LLMConfig.is_active == True,      # noqa: E712
                )
            )
            config = (await session.execute(stmt)).scalar_one_or_none()
            if not config:
                logger.warning("未找到默认对话模型配置")
                return None
            return config
        except Exception as e:
            logger.error(f"获取默认对话模型配置失败: {str(e)}")
            return None

    async def get_default_embedding_config(self, session: Optional[Session]) -> Optional[LLMConfig]:
        """Return the default embedding model config.

        Unlike :meth:`get_default_chat_config`, this tolerates ``session``
        being ``None`` (then it simply returns ``None``) and reports status by
        writing ``session.desc`` — presumably consumed by the caller for
        tracing; TODO confirm that convention against the call sites.

        Returns:
            The single active default embedding ``LLMConfig``, or ``None``.
        """
        try:
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,    # noqa: E712
                    LLMConfig.is_embedding == True,  # noqa: E712
                    LLMConfig.is_active == True,     # noqa: E712
                )
            )
            config = None
            if session is not None:  # idiom fix: was ``session != None``
                config = (await session.execute(stmt)).scalar_one_or_none()
            if not config:
                if session is not None:
                    session.desc = "ERROR: 未找到默认嵌入模型配置"
                # Also log, so the failure is visible even when session is None
                # (previously this path was completely silent) — consistent
                # with get_default_chat_config.
                logger.warning("未找到默认嵌入模型配置")
                return None
            # Reaching here implies session is not None (config came from it).
            session.desc = f"获取默认嵌入模型配置 > 结果:{config}"
            return config
        except Exception as e:
            if session is not None:
                session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}"
            logger.error(f"获取默认嵌入模型配置失败: {str(e)}")
            return None

    async def get_config_by_id(self, config_id: int, session: Optional[Session] = None) -> Optional[LLMConfig]:
        """Fetch a config row by primary key.

        Args:
            config_id: ``LLMConfig.id`` to look up.
            session: async session to query with. For backward compatibility
                the method falls back to ``self.db`` when omitted — but this
                class never sets ``self.db`` (its constructor is commented
                out at the bottom of the file), so omitting ``session``
                preserves the old behavior: log an error and return ``None``.

        Returns:
            The matching ``LLMConfig`` or ``None`` (also ``None`` on error).
        """
        try:
            db = session if session is not None else self.db
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            return (await db.execute(stmt)).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(self, is_embedding: Optional[bool] = None, session: Optional[Session] = None) -> List[LLMConfig]:
        """List all active configs, ordered by creation time (oldest first).

        Args:
            is_embedding: when given, additionally filter on
                ``LLMConfig.is_embedding``.
            session: *synchronous* session — ``execute`` is not awaited here.
                Falls back to ``self.db`` when omitted (never set; see
                ``get_config_by_id``), which logs an error and returns ``[]``.

        Returns:
            Matching configs, or ``[]`` on any error.
        """
        try:
            db = session if session is not None else self.db
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)  # noqa: E712
            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
            stmt = stmt.order_by(LLMConfig.created_at)
            return db.execute(stmt).scalars().all()
        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    async def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """Fallback chat-model config read from environment-backed settings."""
        # Function-scope import kept from the original — presumably avoids a
        # circular import with ..core.config; TODO confirm.
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.llm.get_current_config()

    async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """Fallback embedding-model config read from environment-backed settings."""
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.embedding.get_current_config()

    async def test_config(self, config_id: int, test_message: str = "Hello", session: Optional[Session] = None) -> Dict[str, Any]:
        """Check that a stored config exists (real connection test is a stub).

        Bug fix: the original was a *sync* method calling the async
        ``get_config_by_id`` without ``await``. The returned coroutine object
        is always truthy, so the "配置不存在" branch was unreachable and the
        method reported success even for missing configs (and leaked a
        never-awaited coroutine). Now async, awaiting the lookup; callers
        must ``await`` it — unavoidable, since a working lookup is async.

        Args:
            config_id: config row to validate.
            test_message: reserved for the future real connection test.
            session: async session forwarded to ``get_config_by_id``.

        Returns:
            ``{"success": True, "message": ...}`` when the config exists,
            otherwise ``{"success": False, "error": ...}``.
        """
        try:
            config = await self.get_config_by_id(config_id, session)
            if not config:
                return {"success": False, "error": "配置不存在"}
            # 这里可以添加实际的连接测试逻辑
            # 例如发送一个简单的请求来验证配置是否有效
            # TODO: send a minimal request using ``test_message`` to verify
            # the config actually works, instead of only checking existence.
            return {"success": True, "message": "配置测试成功"}
        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}
# # 全局实例
# _llm_config_service = None
# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
# """获取LLM配置服务实例"""
# global _llm_config_service
# if _llm_config_service is None or db_session is not None:
# _llm_config_service = LLMConfigService(db_session)
# return _llm_config_service