hxf/backend/th_agenter/models/llm_config.py

163 lines
7.0 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""LLM Configuration model for managing multiple AI models."""
2025-12-16 13:55:16 +08:00
from datetime import datetime
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
from sqlalchemy.orm import Mapped, mapped_column
2025-12-04 14:48:38 +08:00
from typing import Dict, Any, Optional
from ..db.base import BaseModel
class LLMConfig(BaseModel):
    """LLM configuration model for managing AI model settings.

    Each row describes one model endpoint: the provider and model name,
    the credentials and base URL used to reach it, default sampling
    parameters, status flags, free-form extra settings, and usage counters.
    """

    __tablename__ = "llm_configs"

    # Providers accepted by validate_config(). Left unannotated on purpose so
    # it stays a plain class attribute rather than a mapped column.
    SUPPORTED_PROVIDERS = ('openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu')

    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)  # configuration name
    provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True)  # provider: openai, deepseek, doubao, zhipu, moonshot, baidu
    model_name: Mapped[str] = mapped_column(String(100), nullable=False)  # model name
    # API key. The original note says it is stored encrypted; this class
    # performs no encryption or decryption itself.
    api_key: Mapped[str] = mapped_column(String(500), nullable=False)
    base_url: Mapped[str | None] = mapped_column(String(200), nullable=True)  # API base URL

    # Model parameters
    max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
    temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
    top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
    frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
    presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)

    # Configuration metadata
    description: Mapped[str | None] = mapped_column(Text, nullable=True)  # configuration description
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)  # whether enabled
    is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # whether this is the default configuration
    is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # whether this is an embedding model

    # Extended configuration (JSON)
    extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)  # extra parameters

    # Usage statistics
    usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)  # number of uses
    last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)  # time of last use

    def __repr__(self):
        return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"

    def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
        """Convert to a dictionary, optionally including the raw API key.

        Args:
            include_sensitive: When true, the result carries the full key
                under ``'api_key'``. Otherwise it carries ``'api_key_masked'``.

        Returns:
            The base model's dictionary extended with this model's fields.
        """
        data = super().to_dict()
        data.update({
            'name': self.name,
            'provider': self.provider,
            'model_name': self.model_name,
            'base_url': self.base_url,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'description': self.description,
            'is_active': self.is_active,
            'is_default': self.is_default,
            'is_embedding': self.is_embedding,
            'extra_config': self.extra_config,
            'usage_count': self.usage_count,
            'last_used_at': self.last_used_at,
        })

        if include_sensitive:
            data['api_key'] = self.api_key
        elif not self.api_key:
            data['api_key_masked'] = None
        elif len(self.api_key) > 8:
            # Show only the first and last four characters of the key.
            data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
        else:
            # Too short to reveal any part without exposing most of it.
            data['api_key_masked'] = "***"
        return data

    def get_client_config(self) -> Dict[str, Any]:
        """Return the settings used to create a client.

        Entries from ``extra_config`` are merged last, so they override any
        base key of the same name (including ``'api_key'`` and ``'model'``).
        """
        config = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'model': self.model_name,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
        }
        # Merge in the extra configuration.
        if self.extra_config:
            config.update(self.extra_config)
        return config

    def validate_config(self) -> Dict[str, Any]:
        """Check whether this configuration is valid.

        Returns:
            ``{"valid": True, "error": None}`` on success, otherwise
            ``{"valid": False, "error": <message>}`` for the first failed check.
        """
        if not self.name or not self.name.strip():
            return {"valid": False, "error": "配置名称不能为空"}
        if not self.provider or self.provider not in self.SUPPORTED_PROVIDERS:
            return {"valid": False, "error": "不支持的服务商"}
        if not self.model_name or not self.model_name.strip():
            return {"valid": False, "error": "模型名称不能为空"}
        if not self.api_key or not self.api_key.strip():
            return {"valid": False, "error": "API密钥不能为空"}

        # Column defaults are applied at INSERT time, so on an instance that
        # has not been flushed these attributes can still be None. None means
        # "use the column default"; skip the range check instead of raising
        # TypeError on the comparison.
        max_tokens = self.max_tokens
        if max_tokens is not None and (max_tokens <= 0 or max_tokens > 32000):
            return {"valid": False, "error": "最大令牌数必须在1-32000之间"}
        temperature = self.temperature
        if temperature is not None and (temperature < 0 or temperature > 2):
            return {"valid": False, "error": "温度参数必须在0-2之间"}
        return {"valid": True, "error": None}

    def increment_usage(self):
        """Increase the usage counter by one and record the current time."""
        # usage_count is None until the column default is applied on INSERT;
        # treat that as zero rather than failing on None + 1.
        self.usage_count = (self.usage_count or 0) + 1
        self.last_used_at = datetime.now()

    @classmethod
    def get_default_config(cls, provider: str, is_embedding: bool = False) -> Dict[str, Any]:
        """Return the default configuration template for a provider.

        Args:
            provider: Provider identifier, e.g. ``'openai'``.
            is_embedding: Select the embedding model name instead of the
                chat model name.

        Returns:
            A dict with ``base_url``, ``model_name``, ``max_tokens`` and
            ``temperature``, or an empty dict when no template exists.
        """
        # NOTE: 'baidu' is accepted by validate_config() but has no template
        # here, so it yields an empty dict.
        templates = {
            'openai': {
                'base_url': 'https://api.openai.com/v1',
                # The OpenAI model id is "gpt-4o-mini"; "gpt-4.0-mini" does not exist.
                'model_name': 'text-embedding-ada-002' if is_embedding else 'gpt-4o-mini',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'deepseek': {
                'base_url': 'https://api.deepseek.com/v1',
                'model_name': 'deepseek-embedding' if is_embedding else 'deepseek-chat',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'doubao': {
                'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
                'model_name': 'doubao-embedding' if is_embedding else 'doubao-lite-4k',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'zhipu': {
                'base_url': 'https://open.bigmodel.cn/api/paas/v4',
                'model_name': 'embedding-3' if is_embedding else 'glm-4',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
            'moonshot': {
                'base_url': 'https://api.moonshot.cn/v1',
                'model_name': 'moonshot-embedding' if is_embedding else 'moonshot-v1-8k',
                'max_tokens': 2048,
                'temperature': 0.7,
            },
        }
        return templates.get(provider, {})