163 lines
7.0 KiB
Python
163 lines
7.0 KiB
Python
"""LLM Configuration model for managing multiple AI models."""
|
||
|
||
from datetime import datetime
|
||
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
|
||
from sqlalchemy.orm import Mapped, mapped_column
|
||
from typing import Dict, Any, Optional
|
||
|
||
from ..db.base import BaseModel
|
||
|
||
class LLMConfig(BaseModel):
|
||
"""LLM Configuration model for managing AI model settings."""
|
||
|
||
__tablename__ = "llm_configs"
|
||
|
||
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # 配置名称
|
||
provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # 服务商:openai, deepseek, doubao, zhipu, moonshot, baidu
|
||
model_name: Mapped[str] = mapped_column(String(100), nullable=False) # 模型名称
|
||
api_key: Mapped[str] = mapped_column(String(500), nullable=False) # API密钥(加密存储)
|
||
base_url: Mapped[str | None] = mapped_column(String(200), nullable=True) # API基础URL
|
||
|
||
# 模型参数
|
||
max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
|
||
temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
||
top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
|
||
frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||
presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||
|
||
# 配置信息
|
||
description: Mapped[str | None] = mapped_column(Text, nullable=True) # 配置描述
|
||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 是否启用
|
||
is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为默认配置
|
||
is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为嵌入模型
|
||
|
||
# 扩展配置(JSON格式)
|
||
extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) # 额外配置参数
|
||
|
||
# 使用统计
|
||
usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) # 使用次数
|
||
last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 最后使用时间
|
||
|
||
def __repr__(self):
|
||
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"
|
||
|
||
def to_dict(self, include_sensitive=False):
|
||
"""Convert to dictionary, optionally excluding sensitive data."""
|
||
data = super().to_dict()
|
||
data.update({
|
||
'name': self.name,
|
||
'provider': self.provider,
|
||
'model_name': self.model_name,
|
||
'base_url': self.base_url,
|
||
'max_tokens': self.max_tokens,
|
||
'temperature': self.temperature,
|
||
'top_p': self.top_p,
|
||
'frequency_penalty': self.frequency_penalty,
|
||
'presence_penalty': self.presence_penalty,
|
||
'description': self.description,
|
||
'is_active': self.is_active,
|
||
'is_default': self.is_default,
|
||
'is_embedding': self.is_embedding,
|
||
'extra_config': self.extra_config,
|
||
'usage_count': self.usage_count,
|
||
'last_used_at': self.last_used_at
|
||
})
|
||
|
||
if include_sensitive:
|
||
data['api_key'] = self.api_key
|
||
else:
|
||
# 只显示API密钥的前几位和后几位
|
||
if self.api_key:
|
||
key_len = len(self.api_key)
|
||
if key_len > 8:
|
||
data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
|
||
else:
|
||
data['api_key_masked'] = "***"
|
||
else:
|
||
data['api_key_masked'] = None
|
||
|
||
return data
|
||
|
||
def get_client_config(self) -> Dict[str, Any]:
|
||
"""获取用于创建客户端的配置."""
|
||
config = {
|
||
'api_key': self.api_key,
|
||
'base_url': self.base_url,
|
||
'model': self.model_name,
|
||
'max_tokens': self.max_tokens,
|
||
'temperature': self.temperature,
|
||
'top_p': self.top_p,
|
||
'frequency_penalty': self.frequency_penalty,
|
||
'presence_penalty': self.presence_penalty
|
||
}
|
||
|
||
# 添加额外配置
|
||
if self.extra_config:
|
||
config.update(self.extra_config)
|
||
|
||
return config
|
||
|
||
def validate_config(self) -> Dict[str, Any]:
|
||
"""验证配置是否有效."""
|
||
if not self.name or not self.name.strip():
|
||
return {"valid": False, "error": "配置名称不能为空"}
|
||
|
||
if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']:
|
||
return {"valid": False, "error": "不支持的服务商"}
|
||
|
||
if not self.model_name or not self.model_name.strip():
|
||
return {"valid": False, "error": "模型名称不能为空"}
|
||
|
||
if not self.api_key or not self.api_key.strip():
|
||
return {"valid": False, "error": "API密钥不能为空"}
|
||
|
||
if self.max_tokens <= 0 or self.max_tokens > 32000:
|
||
return {"valid": False, "error": "最大令牌数必须在1-32000之间"}
|
||
|
||
if self.temperature < 0 or self.temperature > 2:
|
||
return {"valid": False, "error": "温度参数必须在0-2之间"}
|
||
|
||
return {"valid": True, "error": None}
|
||
|
||
def increment_usage(self):
|
||
"""增加使用次数."""
|
||
self.usage_count += 1
|
||
self.last_used_at = datetime.now()
|
||
|
||
@classmethod
|
||
def get_default_config(cls, provider: str, is_embedding: bool = False):
|
||
"""获取服务商的默认配置模板."""
|
||
templates = {
|
||
'openai': {
|
||
'base_url': 'https://api.openai.com/v1',
|
||
'model_name': 'gpt-4.0-mini' if not is_embedding else 'text-embedding-ada-002',
|
||
'max_tokens': 2048,
|
||
'temperature': 0.7
|
||
},
|
||
'deepseek': {
|
||
'base_url': 'https://api.deepseek.com/v1',
|
||
'model_name': 'deepseek-chat' if not is_embedding else 'deepseek-embedding',
|
||
'max_tokens': 2048,
|
||
'temperature': 0.7
|
||
},
|
||
'doubao': {
|
||
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
|
||
'model_name': 'doubao-lite-4k' if not is_embedding else 'doubao-embedding',
|
||
'max_tokens': 2048,
|
||
'temperature': 0.7
|
||
},
|
||
'zhipu': {
|
||
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
|
||
'model_name': 'glm-4' if not is_embedding else 'embedding-3',
|
||
'max_tokens': 2048,
|
||
'temperature': 0.7
|
||
},
|
||
'moonshot': {
|
||
'base_url': 'https://api.moonshot.cn/v1',
|
||
'model_name': 'moonshot-v1-8k' if not is_embedding else 'moonshot-embedding',
|
||
'max_tokens': 2048,
|
||
'temperature': 0.7
|
||
}
|
||
}
|
||
|
||
return templates.get(provider, {}) |