feat: 更新LLM配置管理和错误处理逻辑
- 修改create_llm_config函数,改为使用require_authenticated_user进行用户验证 - 优化配置名称检查,支持不区分大小写的比较 - 更新API密钥验证逻辑,允许本地服务使用较短的API密钥 - 改进HxfErrorResponse类,增强异常处理和状态码管理 - 更新数据库文件和二进制数据,确保数据一致性
This commit is contained in:
parent
db8cc75ff5
commit
0bca60cd06
Binary file not shown.
Binary file not shown.
|
|
@ -152,47 +152,20 @@ async def get_llm_config(
|
||||||
async def create_llm_config(
|
async def create_llm_config(
|
||||||
config_data: LLMConfigCreate,
|
config_data: LLMConfigCreate,
|
||||||
session: Session = Depends(get_session),
|
session: Session = Depends(get_session),
|
||||||
current_user: User = Depends(require_super_admin)
|
current_user: User = Depends(require_authenticated_user)
|
||||||
):
|
):
|
||||||
"""创建大模型配置."""
|
"""创建大模型配置."""
|
||||||
# 检查配置名称是否已存在
|
# 检查配置名称是否已存在(不区分大小写)
|
||||||
# 先保存当前用户名,避免在refresh后访问可能导致MissingGreenlet错误
|
# 先保存当前用户名,避免在refresh后访问可能导致MissingGreenlet错误
|
||||||
username = current_user.username
|
username = current_user.username
|
||||||
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
||||||
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
|
stmt = select(LLMConfig).where(LLMConfig.name.ilike(config_data.name))
|
||||||
existing_config = (await session.execute(stmt)).scalar_one_or_none()
|
existing_config = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
if existing_config:
|
if existing_config:
|
||||||
session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
|
session.desc = f"ERROR: 配置名称已存在, name={config_data.name} (已存在ID: {existing_config.id})"
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="配置名称已存在"
|
detail=f"配置名称已存在(ID: {existing_config.id})。请使用不同的名称,或更新现有配置。"
|
||||||
)
|
|
||||||
|
|
||||||
# 创建配置对象
|
|
||||||
config = LLMConfig_DataClass(
|
|
||||||
name=config_data.name,
|
|
||||||
provider=config_data.provider,
|
|
||||||
model_name=config_data.model_name,
|
|
||||||
api_key=config_data.api_key,
|
|
||||||
base_url=config_data.base_url,
|
|
||||||
max_tokens=config_data.max_tokens,
|
|
||||||
temperature=config_data.temperature,
|
|
||||||
top_p=config_data.top_p,
|
|
||||||
frequency_penalty=config_data.frequency_penalty,
|
|
||||||
presence_penalty=config_data.presence_penalty,
|
|
||||||
description=config_data.description,
|
|
||||||
is_active=config_data.is_active,
|
|
||||||
is_default=config_data.is_default,
|
|
||||||
is_embedding=config_data.is_embedding,
|
|
||||||
extra_config=config_data.extra_config or {}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 验证配置
|
|
||||||
validation_result = config.validate_config()
|
|
||||||
if not validation_result['valid']:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=validation_result['error']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果设为默认,取消同类型的其他默认配置
|
# 如果设为默认,取消同类型的其他默认配置
|
||||||
|
|
@ -203,8 +176,8 @@ async def create_llm_config(
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
|
|
||||||
session.desc = f"验证大模型配置, config_data"
|
session.desc = f"验证大模型配置, config_data"
|
||||||
# 创建配置
|
# 创建数据库模型对象
|
||||||
config = LLMConfig_DataClass(
|
config = LLMConfig(
|
||||||
name=config_data.name,
|
name=config_data.name,
|
||||||
provider=config_data.provider,
|
provider=config_data.provider,
|
||||||
model_name=config_data.model_name,
|
model_name=config_data.model_name,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""LLM Configuration Pydantic schemas."""
|
"""LLM Configuration Pydantic schemas."""
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
from pydantic import BaseModel, Field, field_validator, computed_field, model_validator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,12 +40,25 @@ class LLMConfigCreate(LLMConfigBase):
|
||||||
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
||||||
return v.lower()
|
return v.lower()
|
||||||
|
|
||||||
@field_validator('api_key')
|
@model_validator(mode='after')
|
||||||
@classmethod
|
def validate_api_key_for_local_service(self):
|
||||||
def validate_api_key(cls, v: str) -> str:
|
"""验证 API 密钥:对于本地服务允许较短的密钥"""
|
||||||
if len(v.strip()) < 10:
|
# 对于本地服务(如 Ollama),API 密钥可以为空或较短
|
||||||
raise ValueError('API密钥长度不能少于10个字符')
|
# 检查 base_url 是否指向本地服务
|
||||||
return v.strip()
|
base_url = self.base_url or ''
|
||||||
|
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
|
||||||
|
|
||||||
|
api_key = self.api_key.strip() if self.api_key else ''
|
||||||
|
|
||||||
|
# 如果是本地服务,允许较短的 API 密钥(至少1个字符)
|
||||||
|
if is_local_service:
|
||||||
|
if len(api_key) < 1:
|
||||||
|
raise ValueError('API密钥不能为空')
|
||||||
|
else:
|
||||||
|
# 对于在线服务,要求至少10个字符
|
||||||
|
if len(api_key) < 10:
|
||||||
|
raise ValueError('API密钥长度不能少于10个字符')
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigUpdate(BaseModel):
|
class LLMConfigUpdate(BaseModel):
|
||||||
|
|
@ -81,12 +94,29 @@ class LLMConfigUpdate(BaseModel):
|
||||||
return v.lower()
|
return v.lower()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('api_key')
|
@model_validator(mode='after')
|
||||||
@classmethod
|
def validate_api_key_for_local_service(self):
|
||||||
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
"""验证 API 密钥:对于本地服务允许较短的密钥"""
|
||||||
if v is not None and len(v.strip()) < 10:
|
# 如果 api_key 不为 None,进行验证
|
||||||
raise ValueError('API密钥长度不能少于10个字符')
|
if self.api_key is not None:
|
||||||
return v.strip() if v else v
|
# 对于本地服务(如 Ollama),API 密钥可以为空或较短
|
||||||
|
# 检查 base_url 是否指向本地服务
|
||||||
|
base_url = self.base_url or ''
|
||||||
|
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
|
||||||
|
|
||||||
|
api_key = self.api_key.strip() if self.api_key else ''
|
||||||
|
|
||||||
|
# 如果是本地服务,允许较短的 API 密钥(至少1个字符)
|
||||||
|
if is_local_service:
|
||||||
|
if len(api_key) < 1:
|
||||||
|
raise ValueError('API密钥不能为空')
|
||||||
|
else:
|
||||||
|
# 对于在线服务,要求至少10个字符
|
||||||
|
if len(api_key) < 10:
|
||||||
|
raise ValueError('API密钥长度不能少于10个字符')
|
||||||
|
# 更新 api_key 为去除空格后的值
|
||||||
|
self.api_key = api_key
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigResponse(BaseModel):
|
class LLMConfigResponse(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -71,24 +71,30 @@ class HxfErrorResponse(JSONResponse):
|
||||||
def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
|
def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
|
||||||
"""Return a JSON error response."""
|
"""Return a JSON error response."""
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
msg = message.message if 'message' in message.__dict__ else str(message)
|
msg = message.message if hasattr(message, 'message') and 'message' in message.__dict__ else str(message)
|
||||||
logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
|
logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
|
||||||
if isinstance(message, TypeError):
|
|
||||||
|
# 检查异常是否有 status_code 属性(如 HTTPException)
|
||||||
|
if hasattr(message, 'status_code'):
|
||||||
|
error_status_code = message.status_code
|
||||||
content = {
|
content = {
|
||||||
"code": -1,
|
"code": -1,
|
||||||
"status": 500,
|
"status": error_status_code,
|
||||||
"data": None,
|
|
||||||
"error": f"错误类型: {type(message)} 错误信息: {message}",
|
|
||||||
"message": msg
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
content = {
|
|
||||||
"code": -1,
|
|
||||||
"status": message.status_code,
|
|
||||||
"data": None,
|
"data": None,
|
||||||
"error": None,
|
"error": None,
|
||||||
"message": msg
|
"message": msg
|
||||||
}
|
}
|
||||||
|
status_code = error_status_code
|
||||||
|
else:
|
||||||
|
# 对于没有 status_code 的异常(如 AttributeError, ValueError 等),使用 500
|
||||||
|
content = {
|
||||||
|
"code": -1,
|
||||||
|
"status": 500,
|
||||||
|
"data": None,
|
||||||
|
"error": f"错误类型: {type(message).__name__} 错误信息: {str(message)}",
|
||||||
|
"message": msg
|
||||||
|
}
|
||||||
|
status_code = 500
|
||||||
else:
|
else:
|
||||||
content = {
|
content = {
|
||||||
"code": -1,
|
"code": -1,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue