"""Configuration management for TH Agenter."""
import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from loguru import logger
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings

|
||
class DatabaseSettings(BaseSettings):
    """Database configuration.

    Values are loaded by pydantic-settings from environment variables and
    the ``.env`` file (see ``model_config`` below).
    """

    # Connection URL. No default: it MUST be provided via the
    # DATABASE_URL environment variable (field alias below).
    url: str = Field(..., alias="database_url")
    # Echo executed SQL statements (debug aid).
    echo: bool = Field(default=False)
    # Connection-pool sizing.
    pool_size: int = Field(default=5)
    max_overflow: int = Field(default=10)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class SecuritySettings(BaseSettings):
    """Security configuration (token signing)."""

    # WARNING: insecure placeholder default — must be overridden via the
    # SECRET_KEY environment variable in production.
    secret_key: str = Field(default="your-secret-key-here-change-in-production")
    # Token signing algorithm (HS256 suggests JWT — confirm with token code).
    algorithm: str = Field(default="HS256")
    # Access-token lifetime in minutes (default: 5 hours).
    access_token_expire_minutes: int = Field(default=300)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class ToolSetings(BaseSettings):
    """External tool API configuration (Tavily search, weather).

    NOTE: the class name is misspelled ("Setings"); it is kept unchanged for
    backward compatibility and a correctly spelled alias is defined below.
    """

    # Tavily search API key (optional; tool presumably disabled when unset —
    # confirm against the tool implementation).
    tavily_api_key: Optional[str] = Field(default=None)
    # Weather service API key (optional).
    weather_api_key: Optional[str] = Field(default=None)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


# Correctly spelled alias so new code does not have to propagate the typo.
ToolSettings = ToolSetings

class LLMSettings(BaseSettings):
    """Chat-model configuration supporting multiple OpenAI-compatible providers.

    Each provider (openai, deepseek, doubao, zhipu, moonshot) has its own
    api_key / base_url / model triple; ``provider`` selects which triple is
    used when falling back to environment-variable configuration.
    """

    # Active provider: openai, deepseek, doubao, zhipu or moonshot.
    provider: str = Field(default="openai", alias="llm_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao (Volcano Engine Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Generation parameters shared by every provider.
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Return the active chat-model config, preferring the database.

        Tries the default chat config stored via ``LLMConfigService`` first;
        on any failure (or when no default exists) falls back to the
        environment-variable configuration selected by ``self.provider``.

        Returns:
            dict with keys ``api_key``, ``base_url``, ``model``,
            ``max_tokens``, ``temperature``.
        """
        try:
            # Imported lazily — presumably to avoid a circular import at
            # module load time (TODO confirm).
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_chat_config()

            if db_config:
                # A default exists in the database: use it, keeping the
                # generation parameters from this settings object.
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                return config
        except Exception as e:
            # Database unavailable: log and fall through to env-var config.
            logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")

        # Fallback: per-provider configuration from environment variables.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_model
            }
        }

        # Unknown provider names silently fall back to "openai".
        config = provider_configs.get(self.provider, provider_configs["openai"])
        config.update({
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        })
        return config

class EmbeddingSettings(BaseSettings):
    """Embedding-model configuration supporting multiple providers.

    Mirrors :class:`LLMSettings` but selects embedding models; the fallback
    provider here is ``zhipu`` rather than ``openai``.
    """

    # Active provider: openai, deepseek, doubao, zhipu or moonshot.
    provider: str = Field(default="zhipu", alias="embedding_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_embedding_model: str = Field(default="text-embedding-ada-002")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_embedding_model: str = Field(default="deepseek-embedding")

    # Doubao (Volcano Engine Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_embedding_model: str = Field(default="doubao-embedding")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_embedding_model: str = Field(default="moonshot-embedding")

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Return the active embedding config, preferring the database.

        Tries the default embedding config stored via ``LLMConfigService``
        first; on any failure (or when no default exists) falls back to the
        environment-variable configuration selected by ``self.provider``.

        Returns:
            dict with keys ``api_key``, ``base_url``, ``model``.
        """
        try:
            # Imported lazily — presumably to avoid a circular import at
            # module load time (TODO confirm).
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_embedding_config()

            if db_config:
                # A default exists in the database: use it directly.
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name
                }
                return config
        except Exception as e:
            # Database unavailable: log and fall through to env-var config.
            logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")

        # Fallback: per-provider configuration from environment variables.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_embedding_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_embedding_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_embedding_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_embedding_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_embedding_model
            }
        }

        # Unknown provider names silently fall back to "zhipu".
        return provider_configs.get(self.provider, provider_configs["zhipu"])

class VectorDBSettings(BaseSettings):
    """Vector database configuration."""

    # Backend selector; "pgvector" by default. The persist_directory default
    # suggests a Chroma-style local backend is also supported — confirm
    # against the vector-store factory.
    type: str = Field(default="pgvector", alias="vector_db_type")
    # Local persistence directory (used by file-backed backends).
    persist_directory: str = Field(default="./data/chroma")
    collection_name: str = Field(default="documents")
    # Dimension of the Zhipu AI embedding-3 model.
    embedding_dimension: int = Field(default=2048)

    # PostgreSQL pgvector configuration
    pgvector_host: str = Field(default="localhost")
    pgvector_port: int = Field(default=5432)
    pgvector_database: str = Field(default="vectordb")
    pgvector_user: str = Field(default="postgres")
    pgvector_password: str = Field(default="")
    pgvector_table_name: str = Field(default="embeddings")
    # NOTE(review): 1024 differs from embedding_dimension (2048) above —
    # verify which dimension the pgvector table actually uses.
    pgvector_vector_dimension: int = Field(default=1024)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class FileSettings(BaseSettings):
    """File processing configuration (uploads and text chunking)."""

    upload_dir: str = Field(default="./data/uploads")
    # Maximum upload size in bytes (10 MB).
    max_size: int = Field(default=10485760)
    # May be given as a list or a comma-separated string (e.g. "txt,pdf");
    # the validator below normalizes either form to a dotted list.
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    # Whether the semantic text splitter is enabled.
    semantic_splitter_enabled: bool = Field(default=False)

    @staticmethod
    def _ensure_dots(extensions: List[str]) -> List[str]:
        """Ensure every extension starts with a dot (shared normalizer)."""
        return [ext if ext.startswith('.') else f'.{ext}' for ext in extensions]

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Normalize a comma-separated string or list into a dotted list.

        Non-str/list values are passed through unchanged for pydantic to
        validate/reject.
        """
        if isinstance(v, str):
            return cls._ensure_dots([ext.strip() for ext in v.split(',')])
        elif isinstance(v, list):
            return cls._ensure_dots(v)
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Return the allowed extensions as a normalized list (never None).

        After validation the field is already a list; the string branch is a
        defensive fallback mirroring the validator's parsing.
        """
        if isinstance(self.allowed_extensions, list):
            return self.allowed_extensions
        elif isinstance(self.allowed_extensions, str):
            return self._ensure_dots([ext.strip() for ext in self.allowed_extensions.split(',')])
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class StorageSettings(BaseSettings):
    """Storage configuration (local filesystem or S3)."""

    # Backend selector: "local" or "s3".
    storage_type: str = Field(default="local")
    upload_directory: str = Field(default="./data/uploads")

    # S3 settings (only used when storage_type == "s3").
    s3_bucket_name: str = Field(default="chat-agent-files")
    aws_access_key_id: Optional[str] = Field(default=None)
    aws_secret_access_key: Optional[str] = Field(default=None)
    aws_region: str = Field(default="us-east-1")
    # Custom endpoint for S3-compatible services (e.g. MinIO).
    s3_endpoint_url: Optional[str] = Field(default=None)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class CORSSettings(BaseSettings):
    """CORS configuration.

    NOTE(review): the wildcard origin default is permissive — confirm it is
    tightened for production deployments.
    """

    allowed_origins: List[str] = Field(default=["*"])
    allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    allowed_headers: List[str] = Field(default=["*"])

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class ChatSettings(BaseSettings):
    """Chat configuration."""

    # Maximum number of prior messages kept as conversation context.
    max_history_length: int = Field(default=10)
    # Default system prompt (Chinese): "You are a helpful AI assistant;
    # answer the user's question based on the provided context."
    system_prompt: str = Field(default="你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)

    # Consistency fix: every other settings section loads from .env and
    # ignores unknown keys; ChatSettings previously lacked this config.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

class Settings(BaseSettings):
    """Main application settings.

    Aggregates every configuration section. Values come from class defaults,
    the ``.env`` file / environment variables (via pydantic-settings), and
    optionally a YAML file loaded through :meth:`load_from_yaml`.
    """

    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Load settings from a YAML file, falling back to env/.env values.

        Lookup order for the file: ``config_path``, then
        ``<two levels above this module>/configs/settings.yaml``. If neither
        exists, a plain ``Settings()`` is returned (env vars / defaults only).

        ``${VAR}`` placeholders in the YAML are replaced with environment
        variable values before the sub-settings are instantiated.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            # Fall back to configs/settings.yaml two levels above this
            # module's directory.
            current_dir = Path(__file__).parent
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                # No YAML anywhere: build purely from env vars / defaults.
                return cls()

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Substitute ${VAR} placeholders with environment-variable values.
        config_data = cls._resolve_env_vars_nested(config_data)

        # Instantiate each sub-settings class explicitly so that keys absent
        # from the YAML are still loaded from the environment / .env file
        # (BaseSettings does that automatically on construction).
        settings_kwargs = {}
        settings_kwargs['database'] = DatabaseSettings(**(config_data.get('database', {})))
        settings_kwargs['security'] = SecuritySettings(**(config_data.get('security', {})))
        settings_kwargs['llm'] = LLMSettings(**(config_data.get('llm', {})))
        settings_kwargs['embedding'] = EmbeddingSettings(**(config_data.get('embedding', {})))
        settings_kwargs['vector_db'] = VectorDBSettings(**(config_data.get('vector_db', {})))
        settings_kwargs['file'] = FileSettings(**(config_data.get('file', {})))
        settings_kwargs['storage'] = StorageSettings(**(config_data.get('storage', {})))
        settings_kwargs['cors'] = CORSSettings(**(config_data.get('cors', {})))
        settings_kwargs['chat'] = ChatSettings(**(config_data.get('chat', {})))
        settings_kwargs['tool'] = ToolSetings(**(config_data.get('tool', {})))

        # Pass through any remaining top-level keys (app_name, port, ...).
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested dict into ``{"a_b_c": value}`` form.

        Nested keys are joined with underscores; currently unused within this
        module but kept for external callers.
        """
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Any) -> Any:
        """Recursively resolve ``${VAR}`` placeholders in a config tree.

        Handles dicts, lists (fix: list items were previously skipped) and
        plain strings. An unset variable leaves the original ``${VAR}`` text
        in place.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        elif isinstance(config, list):
            return [Settings._resolve_env_vars_nested(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        else:
            return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve ``${VAR}`` placeholders in a flat (one-level) dict.

        Shallow counterpart of :meth:`_resolve_env_vars_nested`; kept for
        external callers.
        """
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved

@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    return Settings.load_from_yaml()
|
||
|
||
# Module-level singleton, created at import time via the cached loader.
settings = get_settings()