hyf-backend/th_agenter/schemas/llm_config.py

186 lines
7.9 KiB
Python
Raw Permalink Normal View History

2026-01-21 13:45:39 +08:00
"""LLM Configuration Pydantic schemas."""
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, field_validator, computed_field, model_validator
2026-01-21 13:45:39 +08:00
from datetime import datetime
class LLMConfigBase(BaseModel):
"""大模型配置基础模式."""
name: str = Field(..., min_length=1, max_length=100, description="配置名称")
provider: str = Field(..., min_length=1, max_length=50, description="服务商")
model_name: str = Field(..., min_length=1, max_length=100, description="模型名称")
api_key: str = Field(..., min_length=1, description="API密钥")
base_url: Optional[str] = Field(None, description="API基础URL")
max_tokens: Optional[int] = Field(4096, ge=1, le=32000, description="最大令牌数")
temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="温度参数")
top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Top-p参数")
frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="频率惩罚")
presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="存在惩罚")
description: Optional[str] = Field(None, max_length=500, description="配置描述")
is_active: bool = Field(True, description="是否激活")
is_default: bool = Field(False, description="是否为默认配置")
is_embedding: bool = Field(False, description="是否为嵌入模型")
extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")
class LLMConfigCreate(LLMConfigBase):
"""创建大模型配置模式."""
@field_validator('provider')
@classmethod
def validate_provider(cls, v: str) -> str:
allowed_providers = [
'openai', 'azure', 'anthropic', 'google', 'baidu',
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
'ollama', 'custom', "doubao", "ollama"
]
if v.lower() not in allowed_providers:
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
return v.lower()
@model_validator(mode='after')
def validate_api_key_for_local_service(self):
"""验证 API 密钥:对于本地服务允许较短的密钥"""
# 对于本地服务(如 OllamaAPI 密钥可以为空或较短
# 检查 base_url 是否指向本地服务
base_url = self.base_url or ''
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
api_key = self.api_key.strip() if self.api_key else ''
# 如果是本地服务,允许较短的 API 密钥至少1个字符
if is_local_service:
if len(api_key) < 1:
raise ValueError('API密钥不能为空')
else:
# 对于在线服务要求至少10个字符
if len(api_key) < 10:
raise ValueError('API密钥长度不能少于10个字符')
return self
2026-01-21 13:45:39 +08:00
class LLMConfigUpdate(BaseModel):
"""更新大模型配置模式."""
name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称")
provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商")
model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称")
api_key: Optional[str] = Field(None, min_length=1, description="API密钥")
base_url: Optional[str] = Field(None, description="API基础URL")
max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数")
temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数")
top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数")
frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚")
presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚")
description: Optional[str] = Field(None, max_length=500, description="配置描述")
is_active: Optional[bool] = Field(None, description="是否激活")
is_default: Optional[bool] = Field(None, description="是否为默认配置")
is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")
@field_validator('provider')
@classmethod
def validate_provider(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
allowed_providers = [
'openai', 'azure', 'anthropic', 'google', 'baidu',
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
'ollama', 'custom',"doubao", "ollama"
]
if v.lower() not in allowed_providers:
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
return v.lower()
return v
@model_validator(mode='after')
def validate_api_key_for_local_service(self):
"""验证 API 密钥:对于本地服务允许较短的密钥"""
# 如果 api_key 不为 None进行验证
if self.api_key is not None:
# 对于本地服务(如 OllamaAPI 密钥可以为空或较短
# 检查 base_url 是否指向本地服务
base_url = self.base_url or ''
is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.'])
api_key = self.api_key.strip() if self.api_key else ''
# 如果是本地服务,允许较短的 API 密钥至少1个字符
if is_local_service:
if len(api_key) < 1:
raise ValueError('API密钥不能为空')
else:
# 对于在线服务要求至少10个字符
if len(api_key) < 10:
raise ValueError('API密钥长度不能少于10个字符')
# 更新 api_key 为去除空格后的值
self.api_key = api_key
return self
2026-01-21 13:45:39 +08:00
class LLMConfigResponse(BaseModel):
"""大模型配置响应模式."""
id: int
name: str
provider: str
model_name: str
api_key: Optional[str] = None # 完整的API密钥仅在include_sensitive=True时返回
base_url: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
description: Optional[str] = None
is_active: bool
is_default: bool
is_embedding: bool
extra_config: Optional[Dict[str, Any]] = None
created_at: datetime
updated_at: Optional[datetime] = None
created_by: Optional[int] = None
updated_by: Optional[int] = None
model_config = {
'from_attributes': True
}
@computed_field
@property
def api_key_masked(self) -> Optional[str]:
# 在响应中隐藏API密钥只显示前4位和后4位
if self.api_key:
key = self.api_key
if len(key) > 8:
return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}"
else:
return '*' * len(key)
return None
class LLMConfigTest(BaseModel):
"""大模型配置测试模式."""
message: Optional[str] = Field(
"Hello, this is a test message.",
max_length=1000,
description="测试消息"
)
class LLMConfigClientResponse(BaseModel):
"""大模型配置客户端响应模式(用于前端)."""
id: int
name: str
provider: str
model_name: str
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
is_active: bool
description: Optional[str] = None
model_config = {
'from_attributes': True
}