"""LLM Configuration Pydantic schemas.""" from typing import Optional, Dict, Any from pydantic import BaseModel, Field, field_validator, computed_field, model_validator from datetime import datetime class LLMConfigBase(BaseModel): """大模型配置基础模式.""" name: str = Field(..., min_length=1, max_length=100, description="配置名称") provider: str = Field(..., min_length=1, max_length=50, description="服务商") model_name: str = Field(..., min_length=1, max_length=100, description="模型名称") api_key: str = Field(..., min_length=1, description="API密钥") base_url: Optional[str] = Field(None, description="API基础URL") max_tokens: Optional[int] = Field(4096, ge=1, le=32000, description="最大令牌数") temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="温度参数") top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Top-p参数") frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="频率惩罚") presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="存在惩罚") description: Optional[str] = Field(None, max_length=500, description="配置描述") is_active: bool = Field(True, description="是否激活") is_default: bool = Field(False, description="是否为默认配置") is_embedding: bool = Field(False, description="是否为嵌入模型") extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置") class LLMConfigCreate(LLMConfigBase): """创建大模型配置模式.""" @field_validator('provider') @classmethod def validate_provider(cls, v: str) -> str: allowed_providers = [ 'openai', 'azure', 'anthropic', 'google', 'baidu', 'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek', 'ollama', 'custom', "doubao", "ollama" ] if v.lower() not in allowed_providers: raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}') return v.lower() @model_validator(mode='after') def validate_api_key_for_local_service(self): """验证 API 密钥:对于本地服务允许较短的密钥""" # 对于本地服务(如 Ollama),API 密钥可以为空或较短 # 检查 base_url 是否指向本地服务 base_url = self.base_url or '' is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.']) api_key = self.api_key.strip() if self.api_key else '' # 如果是本地服务,允许较短的 API 密钥(至少1个字符) if is_local_service: if len(api_key) < 1: raise ValueError('API密钥不能为空') else: # 对于在线服务,要求至少10个字符 if len(api_key) < 10: raise ValueError('API密钥长度不能少于10个字符') return self class LLMConfigUpdate(BaseModel): """更新大模型配置模式.""" name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称") provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商") model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称") api_key: Optional[str] = Field(None, min_length=1, description="API密钥") base_url: Optional[str] = Field(None, description="API基础URL") max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数") temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数") top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数") frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚") presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚") description: Optional[str] = Field(None, max_length=500, description="配置描述") is_active: Optional[bool] = Field(None, description="是否激活") is_default: Optional[bool] = Field(None, description="是否为默认配置") is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型") extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置") @field_validator('provider') @classmethod def validate_provider(cls, v: Optional[str]) -> Optional[str]: if v is not None: allowed_providers = [ 'openai', 'azure', 'anthropic', 'google', 'baidu', 'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek', 'ollama', 'custom',"doubao", "ollama" ] if v.lower() not in allowed_providers: raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}') return v.lower() return v @model_validator(mode='after') def validate_api_key_for_local_service(self): """验证 API 密钥:对于本地服务允许较短的密钥""" # 如果 api_key 不为 None,进行验证 if self.api_key is not None: # 对于本地服务(如 Ollama),API 密钥可以为空或较短 # 检查 base_url 是否指向本地服务 base_url = self.base_url or '' is_local_service = base_url and any(local in base_url.lower() for local in ['localhost', '127.0.0.1', '192.168.', '10.', '172.']) api_key = self.api_key.strip() if self.api_key else '' # 如果是本地服务,允许较短的 API 密钥(至少1个字符) if is_local_service: if len(api_key) < 1: raise ValueError('API密钥不能为空') else: # 对于在线服务,要求至少10个字符 if len(api_key) < 10: raise ValueError('API密钥长度不能少于10个字符') # 更新 api_key 为去除空格后的值 self.api_key = api_key return self class LLMConfigResponse(BaseModel): """大模型配置响应模式.""" id: int name: str provider: str model_name: str api_key: Optional[str] = None # 完整的API密钥(仅在include_sensitive=True时返回) base_url: Optional[str] = None max_tokens: Optional[int] = None temperature: Optional[float] = None top_p: Optional[float] = None frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None description: Optional[str] = None is_active: bool is_default: bool is_embedding: bool extra_config: Optional[Dict[str, Any]] = None created_at: datetime updated_at: Optional[datetime] = None created_by: Optional[int] = None updated_by: Optional[int] = None model_config = { 'from_attributes': True } @computed_field @property def api_key_masked(self) -> Optional[str]: # 在响应中隐藏API密钥,只显示前4位和后4位 if self.api_key: key = self.api_key if len(key) > 8: return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}" else: return '*' * len(key) return None class LLMConfigTest(BaseModel): """大模型配置测试模式.""" message: Optional[str] = Field( "Hello, this is a test message.", max_length=1000, description="测试消息" ) class LLMConfigClientResponse(BaseModel): """大模型配置客户端响应模式(用于前端).""" id: int name: str provider: str model_name: str max_tokens: Optional[int] = None temperature: Optional[float] = None top_p: Optional[float] = None is_active: bool description: Optional[str] = None model_config = { 'from_attributes': True }