feat: 更新LLM配置管理和错误处理逻辑
- 修改create_llm_config函数,改为使用require_authenticated_user进行用户验证 - 优化配置名称检查,支持不区分大小写的比较 - 更新API密钥验证逻辑,允许本地服务使用较短的API密钥 - 改进HxfErrorResponse类,增强异常处理和状态码管理 - 更新数据库文件和二进制数据,确保数据一致性
This commit is contained in:
parent
db8cc75ff5
commit
0bca60cd06
Binary file not shown.
Binary file not shown.
|
|
@ -152,47 +152,20 @@ async def get_llm_config(
|
|||
async def create_llm_config(
|
||||
config_data: LLMConfigCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""创建大模型配置."""
|
||||
# 检查配置名称是否已存在
|
||||
# 检查配置名称是否已存在(不区分大小写)
|
||||
# 先保存当前用户名,避免在refresh后访问可能导致MissingGreenlet错误
|
||||
username = current_user.username
|
||||
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
|
||||
stmt = select(LLMConfig).where(LLMConfig.name.ilike(config_data.name))
|
||||
existing_config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if existing_config:
|
||||
session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
|
||||
session.desc = f"ERROR: 配置名称已存在, name={config_data.name} (已存在ID: {existing_config.id})"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 创建配置对象
|
||||
config = LLMConfig_DataClass(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = config.validate_config()
|
||||
if not validation_result['valid']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=validation_result['error']
|
||||
detail=f"配置名称已存在(ID: {existing_config.id})。请使用不同的名称,或更新现有配置。"
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
|
|
@ -203,8 +176,8 @@ async def create_llm_config(
|
|||
await session.execute(stmt)
|
||||
|
||||
session.desc = f"验证大模型配置, config_data"
|
||||
# 创建配置
|
||||
config = LLMConfig_DataClass(
|
||||
# 创建数据库模型对象
|
||||
config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""LLM Configuration Pydantic schemas."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field, model_validator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
|
@ -40,12 +40,25 @@ class LLMConfigCreate(LLMConfigBase):
|
|||
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
||||
return v.lower()
|
||||
|
||||
@field_validator('api_key')
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
if len(v.strip()) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
return v.strip()
|
||||
@model_validator(mode='after')
|
||||
def validate_api_key_for_local_service(self):
|
||||
"""验证 API 密钥:对于本地服务允许较短的密钥"""
|
||||
# 对于本地服务(如 Ollama),API 密钥可以为空或较短
|
||||
# 检查 base_url 是否指向本地服务
|
||||
base_url = self.base_url or ''
|
||||
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
|
||||
|
||||
api_key = self.api_key.strip() if self.api_key else ''
|
||||
|
||||
# 如果是本地服务,允许较短的 API 密钥(至少1个字符)
|
||||
if is_local_service:
|
||||
if len(api_key) < 1:
|
||||
raise ValueError('API密钥不能为空')
|
||||
else:
|
||||
# 对于在线服务,要求至少10个字符
|
||||
if len(api_key) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
return self
|
||||
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
|
|
@ -81,12 +94,29 @@ class LLMConfigUpdate(BaseModel):
|
|||
return v.lower()
|
||||
return v
|
||||
|
||||
@field_validator('api_key')
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and len(v.strip()) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
return v.strip() if v else v
|
||||
@model_validator(mode='after')
|
||||
def validate_api_key_for_local_service(self):
|
||||
"""验证 API 密钥:对于本地服务允许较短的密钥"""
|
||||
# 如果 api_key 不为 None,进行验证
|
||||
if self.api_key is not None:
|
||||
# 对于本地服务(如 Ollama),API 密钥可以为空或较短
|
||||
# 检查 base_url 是否指向本地服务
|
||||
base_url = self.base_url or ''
|
||||
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
|
||||
|
||||
api_key = self.api_key.strip() if self.api_key else ''
|
||||
|
||||
# 如果是本地服务,允许较短的 API 密钥(至少1个字符)
|
||||
if is_local_service:
|
||||
if len(api_key) < 1:
|
||||
raise ValueError('API密钥不能为空')
|
||||
else:
|
||||
# 对于在线服务,要求至少10个字符
|
||||
if len(api_key) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
# 更新 api_key 为去除空格后的值
|
||||
self.api_key = api_key
|
||||
return self
|
||||
|
||||
|
||||
class LLMConfigResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -71,24 +71,30 @@ class HxfErrorResponse(JSONResponse):
|
|||
def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
|
||||
"""Return a JSON error response."""
|
||||
if isinstance(message, Exception):
|
||||
msg = message.message if 'message' in message.__dict__ else str(message)
|
||||
msg = message.message if hasattr(message, 'message') and 'message' in message.__dict__ else str(message)
|
||||
logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
|
||||
if isinstance(message, TypeError):
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": 500,
|
||||
"data": None,
|
||||
"error": f"错误类型: {type(message)} 错误信息: {message}",
|
||||
"message": msg
|
||||
}
|
||||
else:
|
||||
|
||||
# 检查异常是否有 status_code 属性(如 HTTPException)
|
||||
if hasattr(message, 'status_code'):
|
||||
error_status_code = message.status_code
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": message.status_code,
|
||||
"status": error_status_code,
|
||||
"data": None,
|
||||
"error": None,
|
||||
"message": msg
|
||||
}
|
||||
status_code = error_status_code
|
||||
else:
|
||||
# 对于没有 status_code 的异常(如 AttributeError, ValueError 等),使用 500
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": 500,
|
||||
"data": None,
|
||||
"error": f"错误类型: {type(message).__name__} 错误信息: {str(message)}",
|
||||
"message": msg
|
||||
}
|
||||
status_code = 500
|
||||
else:
|
||||
content = {
|
||||
"code": -1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue