hyf-backend/th_agenter/schemas/llm_config.py

156 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM Configuration Pydantic schemas."""
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, field_validator, computed_field
from datetime import datetime
class LLMConfigBase(BaseModel):
    # NOTE: the docstring is kept verbatim on purpose: pydantic copies a model's
    # docstring into its JSON schema / OpenAPI ``description``.
    # English: "Base schema for LLM configurations."
    """大模型配置基础模式."""

    # Field order below is significant: pydantic uses declaration order for
    # serialization and for the generated schema.

    # Identity / connection. ``default=...`` marks a field as required.
    name: str = Field(default=..., min_length=1, max_length=100, description="配置名称")
    provider: str = Field(default=..., min_length=1, max_length=50, description="服务商")
    model_name: str = Field(default=..., min_length=1, max_length=100, description="模型名称")
    api_key: str = Field(default=..., min_length=1, description="API密钥")
    base_url: Optional[str] = Field(default=None, description="API基础URL")

    # Generation parameters, with pydantic-enforced numeric bounds (ge/le).
    max_tokens: Optional[int] = Field(default=4096, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")

    # Descriptive metadata and boolean flags.
    description: Optional[str] = Field(default=None, max_length=500, description="配置描述")
    is_active: bool = Field(default=True, description="是否激活")
    is_default: bool = Field(default=False, description="是否为默认配置")
    is_embedding: bool = Field(default=False, description="是否为嵌入模型")
    # Free-form provider-specific settings; the schema does not constrain the keys.
    extra_config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
class LLMConfigCreate(LLMConfigBase):
    # NOTE: the docstring is kept verbatim: pydantic exposes it as the
    # JSON-schema description. English: "Schema for creating an LLM configuration."
    """创建大模型配置模式."""

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Lowercase ``provider`` and ensure it is a supported provider.

        Args:
            v: Raw provider name, already checked for length by ``Field``.

        Returns:
            The lowercased provider name.

        Raises:
            ValueError: If the lowercased name is not in the supported list.
        """
        # A tuple keeps a stable order for the error message. Each provider is
        # listed exactly once so the message has no duplicates. Keep this in
        # sync with LLMConfigUpdate.validate_provider.
        allowed_providers = (
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',
        )
        normalized = v.lower()
        if normalized not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return normalized

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Strip surrounding whitespace from ``api_key`` and require >= 10 characters.

        Raises:
            ValueError: If the stripped key is shorter than 10 characters.
        """
        stripped = v.strip()
        if len(stripped) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return stripped
class LLMConfigUpdate(BaseModel):
    # NOTE: the docstring is kept verbatim: pydantic exposes it as the
    # JSON-schema description. English: "Schema for updating an LLM configuration."
    """更新大模型配置模式."""

    # Every field is optional and defaults to None, so a partial payload
    # validates. Field order is significant for serialization/schema output.
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称")
    provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商")
    model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称")
    api_key: Optional[str] = Field(None, min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")
    is_active: Optional[bool] = Field(None, description="是否激活")
    is_default: Optional[bool] = Field(None, description="是否为默认配置")
    is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: Optional[str]) -> Optional[str]:
        """Lowercase ``provider`` when given and ensure it is supported.

        Returns:
            ``None`` unchanged, otherwise the lowercased provider name.

        Raises:
            ValueError: If a non-None value is not a supported provider.
        """
        if v is None:
            return v
        # Each provider is listed exactly once so the error message has no
        # duplicates. Keep this in sync with LLMConfigCreate.validate_provider.
        allowed_providers = (
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',
        )
        normalized = v.lower()
        if normalized not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return normalized

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
        """Strip ``api_key`` when given and require >= 10 characters after stripping.

        Returns:
            ``None`` unchanged, otherwise the stripped key.

        Raises:
            ValueError: If a non-None key is shorter than 10 characters once stripped.
        """
        if v is None:
            return v
        stripped = v.strip()
        if len(stripped) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return stripped
class LLMConfigResponse(BaseModel):
    # NOTE: the docstring is kept verbatim: pydantic exposes it as the
    # JSON-schema description. English: "LLM configuration response schema."
    """大模型配置响应模式."""

    id: int
    name: str
    provider: str
    model_name: str
    # The full API key. Per the original note, it is only meant to be filled in
    # when the caller asks for sensitive data (include_sensitive=True); that
    # policy is enforced by the caller, not by this schema.
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    description: Optional[str] = None
    is_active: bool
    is_default: bool
    is_embedding: bool
    extra_config: Optional[Dict[str, Any]] = None
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    # Allow construction from attribute-bearing objects via model_validate(obj)
    # (presumably the ORM row; confirm against callers).
    model_config = {
        'from_attributes': True
    }

    @computed_field
    @property
    def api_key_masked(self) -> Optional[str]:
        # Serialized alongside the fields. For keys longer than 8 characters,
        # keep the first and last 4 characters and replace the middle with '*'.
        # Keys of 8 characters or fewer are starred out entirely. A missing or
        # empty key yields None.
        # (No docstring on purpose: computed_field would publish it as the
        # field's schema description.)
        secret = self.api_key
        if not secret:
            return None
        length = len(secret)
        if length <= 8:
            return '*' * length
        return secret[:4] + '*' * (length - 8) + secret[-4:]
class LLMConfigTest(BaseModel):
    # NOTE: the docstring is kept verbatim: pydantic exposes it as the
    # JSON-schema description. English: "LLM configuration test schema."
    """大模型配置测试模式."""

    # Probe text sent when testing a configuration. It defaults to a fixed
    # English greeting and is capped at 1000 characters.
    message: Optional[str] = Field(
        default="Hello, this is a test message.",
        description="测试消息",
        max_length=1000,
    )
class LLMConfigClientResponse(BaseModel):
    # NOTE: the docstring is kept verbatim: pydantic exposes it as the
    # JSON-schema description.
    # English: "LLM configuration client response schema (for the frontend)."
    """大模型配置客户端响应模式(用于前端)."""
    # Reduced, client-facing view of a configuration. It declares no api_key or
    # base_url field, so those values are never serialized through this schema.
    # Declaration order below determines serialization/schema order.
    id: int
    name: str
    provider: str
    model_name: str
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    is_active: bool
    description: Optional[str] = None
    # Allow construction from attribute-bearing objects via model_validate(obj)
    # (presumably the ORM row; confirm against callers).
    model_config = {
        'from_attributes': True
    }