hxf/backend/th_agenter/core/config.py

468 lines
19 KiB
Python
Raw Normal View History

2025-12-16 13:55:16 +08:00
"""Configuration management for TH Agenter."""
2025-12-04 14:48:38 +08:00
import os
2026-01-07 11:30:54 +08:00
from requests import Session
2025-12-04 14:48:38 +08:00
import yaml
from pathlib import Path
2025-12-16 13:55:16 +08:00
from loguru import logger
2025-12-04 14:48:38 +08:00
from typing import Any, Dict, List, Optional, Union
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from functools import lru_cache
class DatabaseSettings(BaseSettings):
    """Database configuration.

    Values are loaded by ``BaseSettings`` using the ``.env`` file named in
    ``model_config``; keys that match no field are ignored.
    """

    # Required: there is no default, so it must be supplied (for example
    # through an environment variable) under the ``database_url`` alias.
    url: str = Field(..., alias="database_url")
    echo: bool = False
    pool_size: int = 5
    max_overflow: int = 10

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
class SecuritySettings(BaseSettings):
    """Security configuration: secret key, algorithm and access-token lifetime."""

    # Placeholder default; as the value itself says, it must be replaced
    # in production.
    secret_key: str = "your-secret-key-here-change-in-production"
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 300

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
class ToolSetings(BaseSettings):
    """API keys for external tools; each key is optional and defaults to ``None``."""

    # Tavily search configuration
    tavily_api_key: Optional[str] = None
    weather_api_key: Optional[str] = None

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
2025-12-16 13:55:16 +08:00
2025-12-04 14:48:38 +08:00
class LLMSettings(BaseSettings):
    """LLM configuration supporting several OpenAI-protocol-compatible providers."""

    # Active provider: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field(default="openai", alias="llm_provider")

    # OpenAI settings
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek settings
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao settings
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI settings
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot settings
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Settings shared by every provider
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the chat-model configuration currently in effect.

        The default chat configuration stored in the database is preferred.
        It is only looked up when ``session`` is truthy. If nothing is found,
        or the lookup raises, the values configured on this object for
        ``self.provider`` are returned instead; an unrecognised provider
        falls back to ``"openai"``.

        Args:
            session: Passed to ``LLMConfigService.get_default_chat_config``
                and used for status reporting through its ``desc`` attribute.
                May be falsy, in which case the database is skipped and
                failures go to ``logger``.
                NOTE(review): ``Session`` is imported from ``requests`` at the
                top of this file, which looks unintended; confirm the real
                session type.

        Returns:
            A dict with the keys ``api_key``, ``base_url``, ``model``,
            ``max_tokens`` and ``temperature``.
        """
        try:
            from th_agenter.services.llm_config_service import LLMConfigService

            # Try the default chat-model configuration stored in the database.
            llm_service = LLMConfigService()
            if session:
                db_config = await llm_service.get_default_chat_config(session)
                if db_config:
                    config = {
                        "api_key": db_config.api_key,
                        "base_url": db_config.base_url,
                        "model": db_config.model_name,
                        "max_tokens": self.max_tokens,
                        "temperature": self.temperature
                    }
                    # Report a redacted copy: the raw API key must never end
                    # up in logs or session descriptions.
                    reported = dict(config)
                    if reported["api_key"]:
                        reported["api_key"] = "***"
                    session.desc = f"使用LLM配置(get_default_chat_config)> {reported}"
                    return config
        except Exception as e:
            # Database lookup failed: record it and fall back to the
            # environment-based configuration below.
            if session:
                session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
            else:
                logger.error(f"获取默认对话模型配置失败: {str(e)}")

        # Fallback: settings loaded from the environment for the provider.
        known_providers = ("openai", "deepseek", "doubao", "zhipu", "moonshot")
        name = self.provider if self.provider in known_providers else "openai"
        return {
            "api_key": getattr(self, f"{name}_api_key"),
            "base_url": getattr(self, f"{name}_base_url"),
            "model": getattr(self, f"{name}_model"),
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }
class EmbeddingSettings(BaseSettings):
    """Embedding-model configuration supporting several providers."""

    # Active provider: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field("zhipu", alias="embedding_provider")

    # OpenAI settings
    openai_api_key: Optional[str] = None
    openai_base_url: str = "https://api.openai.com/v1"
    openai_embedding_model: str = "text-embedding-ada-002"

    # DeepSeek settings
    deepseek_api_key: Optional[str] = None
    deepseek_base_url: str = "https://api.deepseek.com/v1"
    deepseek_embedding_model: str = "deepseek-embedding"

    # Doubao settings
    doubao_api_key: Optional[str] = None
    doubao_base_url: str = "https://ark.cn-beijing.volces.com/api/v3"
    doubao_embedding_model: str = "doubao-embedding"

    # Zhipu AI settings
    zhipu_api_key: Optional[str] = None
    zhipu_base_url: str = "https://open.bigmodel.cn/api/paas/v4"
    zhipu_embedding_model: str = "embedding-3"

    # Moonshot settings
    moonshot_api_key: Optional[str] = None
    moonshot_base_url: str = "https://api.moonshot.cn/v1"
    moonshot_embedding_model: str = "moonshot-embedding"

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the embedding configuration currently in effect.

        The default embedding configuration stored in the database wins when
        one is found. If none is found, or the lookup raises, the values
        configured on this object for ``self.provider`` are returned; an
        unrecognised provider falls back to ``"zhipu"``.

        Returns:
            A dict with the keys ``api_key``, ``base_url`` and ``model``.
        """
        try:
            if session:
                session.desc = "尝试从数据库读取默认嵌入模型配置 ... >>> get_current_config"

            # Ask the config service for the default embedding model.
            from th_agenter.services.llm_config_service import LLMConfigService

            stored = await LLMConfigService().get_default_embedding_config(session)
            if stored:
                return {
                    "api_key": stored.api_key,
                    "base_url": stored.base_url,
                    "model": stored.model_name,
                }
        except Exception as e:
            # Lookup failed: report it, then use the fallback below.
            # NOTE(review): LLMSettings reports failures by assigning
            # ``session.desc``; here ``session.error(...)`` is called.
            # Confirm the session type provides both.
            message = f"Failed to read embedding config from database, falling back to env vars: {e}"
            if session:
                session.error(message)
            else:
                logger.error(message)

        # Fallback: settings loaded from the environment for the provider.
        supported = ("openai", "deepseek", "doubao", "zhipu", "moonshot")
        chosen = self.provider if self.provider in supported else "zhipu"
        return {
            "api_key": getattr(self, f"{chosen}_api_key"),
            "base_url": getattr(self, f"{chosen}_base_url"),
            "model": getattr(self, f"{chosen}_embedding_model"),
        }
class VectorDBSettings(BaseSettings):
    """Vector database configuration."""

    type: str = Field("pgvector", alias="vector_db_type")
    persist_directory: str = "./data/chroma"
    collection_name: str = "documents"
    # Dimension of the Zhipu AI embedding-3 model.
    # NOTE(review): this default (2048) differs from
    # ``pgvector_vector_dimension`` (1024); confirm that is intended.
    embedding_dimension: int = 2048

    # PostgreSQL pgvector configuration
    pgvector_host: str = "localhost"
    pgvector_port: int = 5432
    pgvector_database: str = "vectordb"
    pgvector_user: str = "postgres"
    pgvector_password: str = ""
    pgvector_table_name: str = "embeddings"
    pgvector_vector_dimension: int = 1024

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
class FileSettings(BaseSettings):
    """File processing configuration."""

    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10MB
    # Accepts either a comma-separated string or a list; both are normalised
    # by ``parse_allowed_extensions`` to a list of dot-prefixed extensions.
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    semantic_splitter_enabled: bool = Field(default=False)  # enable the semantic splitter

    @staticmethod
    def _normalize_extensions(raw: Union[str, List[str]]) -> List[str]:
        """Return the extensions in *raw* as a list, each starting with a dot.

        *raw* is either a comma-separated string or a list of strings.
        Surrounding whitespace is stripped and blank entries (for example
        from ``"txt,,pdf"`` or a trailing comma) are skipped, so they do not
        turn into a bare ``"."`` extension.
        """
        items = raw.split(",") if isinstance(raw, str) else raw
        stripped = (item.strip() for item in items)
        return [ext if ext.startswith(".") else f".{ext}" for ext in stripped if ext]

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Normalise a comma-separated string or a list to dot-prefixed extensions.

        Values of any other type are returned unchanged and left to regular
        field validation.
        """
        if isinstance(v, (str, list)):
            return cls._normalize_extensions(v)
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Get allowed extensions as a list.

        Returns a list if ``allowed_extensions`` is already one, parses it if
        it is still a string, and returns an empty list for any other type.
        """
        if isinstance(self.allowed_extensions, list):
            return self.allowed_extensions
        if isinstance(self.allowed_extensions, str):
            return self._normalize_extensions(self.allowed_extensions)
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class StorageSettings(BaseSettings):
    """Storage configuration."""

    storage_type: str = "local"  # local or s3
    upload_directory: str = "./data/uploads"

    # S3 settings
    s3_bucket_name: str = "chat-agent-files"
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_region: str = "us-east-1"
    s3_endpoint_url: Optional[str] = None  # For S3-compatible services

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
class CORSSettings(BaseSettings):
    """CORS configuration: allowed origins, methods and headers."""

    # By default every origin and every header is allowed ("*").
    allowed_origins: List[str] = ["*"]
    allowed_methods: List[str] = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
    allowed_headers: List[str] = ["*"]

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }
class ChatSettings(BaseSettings):
    """Chat configuration."""

    max_history_length: int = Field(default=10)
    system_prompt: str = Field(default="你是一个有用的AI助手请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)

    # Same loading rules as every other settings section in this module:
    # read ``.env`` and ignore keys that match no field.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class Settings(BaseSettings):
    """Main application settings: top-level values plus one object per section."""

    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Load settings from a YAML file.

        If ``config_path`` does not exist, ``configs/settings.yaml`` two
        directories above this file's directory is tried. If that is missing
        too, a default instance is returned.

        Args:
            config_path: Path of the YAML file to load.

        Returns:
            A ``Settings`` instance. Every section object is built
            explicitly, so a section that is absent (or null) in the YAML
            is still constructed from its own defaults.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            # Directory containing this file.
            current_dir = Path(__file__).parent
            # Two levels up, then configs/settings.yaml.
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if not backend_config_path.exists():
                return cls()
            config_file = backend_config_path

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Substitute ``${VAR}`` placeholders with environment values.
        config_data = cls._resolve_env_vars_nested(config_data)

        # Build every section explicitly instead of relying on the
        # default factories.
        section_classes = {
            'database': DatabaseSettings,
            'security': SecuritySettings,
            'llm': LLMSettings,
            'embedding': EmbeddingSettings,
            'vector_db': VectorDBSettings,
            'file': FileSettings,
            'storage': StorageSettings,
            'cors': CORSSettings,
            'chat': ChatSettings,
            'tool': ToolSetings,
        }
        # ``or {}`` covers a section that is present but empty in the YAML
        # (``database:`` parses to None, which cannot be **-unpacked).
        settings_kwargs = {
            name: section_cls(**(config_data.get(name) or {}))
            for name, section_cls in section_classes.items()
        }

        # Pass the remaining top-level keys straight through.
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested dict, joining nested keys with ``_``.

        Example: ``{"a": {"b": 1}}`` becomes ``{"a_b": 1}``.
        """
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Any) -> Any:
        """Resolve ``${VAR}`` placeholders throughout a nested configuration.

        Dicts and lists are walked recursively. A string of the exact form
        ``${VAR}`` is replaced by the value of environment variable ``VAR``,
        or kept unchanged when that variable is not set. Every other value
        is returned as is.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        if isinstance(config, list):
            # Placeholders may also appear as list items.
            return [Settings._resolve_env_vars_nested(item) for item in config]
        if isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve ``${VAR}`` placeholders in the top-level values of *config*.

        Non-recursive counterpart of ``_resolve_env_vars_nested``; unset
        variables leave the placeholder text unchanged.
        """
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved
@lru_cache()
def get_settings() -> Settings:
    """Return the application settings, loaded once and then served from cache."""
    return Settings.load_from_yaml()
2025-12-04 14:48:38 +08:00
# Module-level settings instance, created when this module is imported;
# it is the same object later get_settings() calls return from the cache.
settings = get_settings()