# hxf/backend/th_agenter/core/config.py
"""Configuration management for TH Agenter."""
import os
from requests import Session
import yaml
from pathlib import Path
from loguru import logger
from typing import Any, Dict, List, Optional, Union
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from functools import lru_cache
class DatabaseSettings(BaseSettings):
    """Database configuration.

    Values load from the environment / `.env` file (case-insensitive keys,
    unknown keys ignored).
    """
    # Required: must be provided via the DATABASE_URL environment variable.
    url: str = Field(..., alias="database_url")
    # Echo SQL statements (debugging aid for the DB engine).
    echo: bool = Field(default=False)
    # Base connection-pool size.
    pool_size: int = Field(default=5)
    # Extra connections allowed beyond pool_size.
    max_overflow: int = Field(default=10)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class SecuritySettings(BaseSettings):
    """Security / JWT configuration.

    Values load from the environment / `.env` file (case-insensitive keys,
    unknown keys ignored).
    """
    # Default is a placeholder only — MUST be overridden in production.
    secret_key: str = Field(default="your-secret-key-here-change-in-production")
    # JWT signing algorithm.
    algorithm: str = Field(default="HS256")
    # Access-token lifetime in minutes.
    access_token_expire_minutes: int = Field(default=300)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
# NOTE(review): class name is a typo for "ToolSettings", but it is referenced
# elsewhere in this file (Settings.tool, load_from_yaml) — renaming would
# break callers, so the misspelling is kept.
class ToolSetings(BaseSettings):
    """External tool API-key configuration."""
    # Tavily web-search API key.
    tavily_api_key: Optional[str] = Field(default=None)
    # Weather-service API key.
    weather_api_key: Optional[str] = Field(default=None)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class LLMSettings(BaseSettings):
    """LLM configuration — supports multiple OpenAI-protocol-compatible providers."""
    # Active provider: openai, deepseek, doubao, zhipu, moonshot.
    provider: str = Field(default="openai", alias="llm_provider")
    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")
    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")
    # Doubao (Volcengine Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")
    # ZhipuAI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")
    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")
    # Generation defaults applied to every provider.
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the active chat-model config, preferring the database default.

        Tries the database default chat config first; on any failure (or when
        none exists) falls back to the environment-variable provider settings.

        Args:
            session: passed through to the config service.
                NOTE(review): annotated as requests.Session but used like a
                DB/trace session with a ``.desc`` attribute — confirm the
                intended type.

        Returns:
            dict with keys ``api_key``, ``base_url``, ``model``,
            ``max_tokens`` and ``temperature``.
        """
        try:
            from th_agenter.services.llm_config_service import LLMConfigService
            # Try the database-stored default chat-model config first.
            llm_service = LLMConfigService()
            db_config = None
            if session:
                db_config = await llm_service.get_default_chat_config(session)
            if db_config:
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                # Security fix: never log raw credentials — mask the api_key
                # before the config is recorded anywhere.
                safe_config = {**config, "api_key": "***"}
                if session:
                    session.desc = f"使用LLM配置(get_default_chat_config)> {safe_config}"
                else:
                    logger.info(f"使用LLM配置(get_default_chat_config) > {safe_config}")
                return config
        except Exception as e:
            # Database lookup failed: record the error and fall through to
            # the environment-variable configuration below.
            if session:
                session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
            else:
                logger.error(f"获取默认对话模型配置失败: {str(e)}")
        # Fallback: per-provider configuration from environment variables.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_model
            }
        }
        # Unknown provider names fall back to the OpenAI entry.
        config = provider_configs.get(self.provider, provider_configs["openai"])
        config.update({
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        })
        return config
class EmbeddingSettings(BaseSettings):
    """Embedding-model configuration — supports multiple providers."""
    # Active provider: openai, deepseek, doubao, zhipu, moonshot.
    provider: str = Field(default="zhipu", alias="embedding_provider")
    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_embedding_model: str = Field(default="text-embedding-ada-002")
    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_embedding_model: str = Field(default="deepseek-embedding")
    # Doubao (Volcengine Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_embedding_model: str = Field(default="doubao-embedding")
    # ZhipuAI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_embedding_model: str = Field(default="embedding-3")
    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_embedding_model: str = Field(default="moonshot-embedding")
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the active embedding-provider config, preferring the database default.

        Tries the database default embedding config first; on any failure (or
        when none exists) falls back to the environment-variable settings.

        Args:
            session: passed through to the config service.
                NOTE(review): annotated as requests.Session but used like a
                DB/trace session with a ``.desc`` attribute — confirm.

        Returns:
            dict with keys ``api_key``, ``base_url`` and ``model``.
        """
        try:
            if session:
                session.desc = "尝试从数据库读取默认嵌入模型配置 ... >>> get_current_config"
            # Try the database-stored default embedding-model config first.
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = await llm_service.get_default_embedding_config(session)
            if db_config:
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name
                }
                return config
        except Exception as e:
            # Bug fix: requests.Session has no `error` method — the original
            # `session.error(...)` would itself raise inside this handler.
            # Record the failure via `.desc`, consistent with the rest of
            # this module, then fall through to env-var configuration.
            if session:
                session.desc = f"Failed to read embedding config from database, falling back to env vars: {e}"
            else:
                logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}")
        # Fallback: per-provider configuration from environment variables.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_embedding_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_embedding_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_embedding_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_embedding_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_embedding_model
            }
        }
        # Unknown provider names fall back to the ZhipuAI entry.
        return provider_configs.get(self.provider, provider_configs["zhipu"])
class VectorDBSettings(BaseSettings):
    """Vector database configuration."""
    # Backend type, e.g. "pgvector" or a chroma-style local store.
    type: str = Field(default="pgvector", alias="vector_db_type")
    # Local persistence directory (used by the chroma-style backend).
    persist_directory: str = Field(default="./data/chroma")
    collection_name: str = Field(default="documents")
    # Dimension of the ZhipuAI embedding-3 model.
    # NOTE(review): this default (2048) differs from
    # pgvector_vector_dimension (1024) below — confirm which one the
    # active backend actually uses.
    embedding_dimension: int = Field(default=2048)
    # PostgreSQL pgvector configuration
    pgvector_host: str = Field(default="localhost")
    pgvector_port: int = Field(default=5432)
    pgvector_database: str = Field(default="vectordb")
    pgvector_user: str = Field(default="postgres")
    pgvector_password: str = Field(default="")
    pgvector_table_name: str = Field(default="embeddings")
    pgvector_vector_dimension: int = Field(default=1024)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class FileSettings(BaseSettings):
    """File processing configuration."""
    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10 MB
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    # Whether the semantic splitter is enabled.
    semantic_splitter_enabled: bool = Field(default=False)

    @staticmethod
    def _with_dots(items) -> List[str]:
        """Prefix each extension with a dot when it lacks one."""
        return [item if item.startswith('.') else f'.{item}' for item in items]

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Normalize a comma-separated string (or a list) into dotted extensions."""
        if isinstance(v, str):
            return cls._with_dots(part.strip() for part in v.split(','))
        if isinstance(v, list):
            return cls._with_dots(v)
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Return the allowed extensions as a list (empty for unexpected types)."""
        value = self.allowed_extensions
        if isinstance(value, list):
            return value
        if isinstance(value, str):
            return self._with_dots(part.strip() for part in value.split(','))
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class StorageSettings(BaseSettings):
    """Storage configuration."""
    # Backend selector: "local" or "s3".
    storage_type: str = Field(default="local")
    upload_directory: str = Field(default="./data/uploads")
    # S3 settings
    s3_bucket_name: str = Field(default="chat-agent-files")
    aws_access_key_id: Optional[str] = Field(default=None)
    aws_secret_access_key: Optional[str] = Field(default=None)
    aws_region: str = Field(default="us-east-1")
    # For S3-compatible services (e.g. MinIO); None uses the AWS default.
    s3_endpoint_url: Optional[str] = Field(default=None)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class CORSSettings(BaseSettings):
    """CORS configuration."""
    # Wildcard defaults are permissive — tighten for production deployments.
    allowed_origins: List[str] = Field(default=["*"])
    allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    allowed_headers: List[str] = Field(default=["*"])
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class ChatSettings(BaseSettings):
    """Chat configuration."""
    # Maximum number of history turns kept in context.
    max_history_length: int = Field(default=10)
    system_prompt: str = Field(default="你是一个有用的AI助手请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)
    # Consistency fix: this was the only settings class without the shared
    # model_config, so it neither read `.env` nor ignored extra keys like
    # its siblings do.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class Settings(BaseSettings):
    """Main application settings."""
    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")
    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)
    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Load settings from a YAML file, falling back to pure env/.env defaults.

        Search order: `config_path`, then `<backend>/configs/settings.yaml`
        relative to this module. `${VAR}` placeholders in the YAML are
        replaced with environment-variable values before validation.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            # This module lives at backend/<pkg>/core; go two levels up to
            # the backend directory and look for configs/settings.yaml there.
            current_dir = Path(__file__).parent
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                # No YAML anywhere: plain BaseSettings loading (.env / env vars).
                return cls()
        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}
        # Substitute ${ENV_VAR} placeholders throughout the loaded tree.
        config_data = cls._resolve_env_vars_nested(config_data)
        # Instantiate each section class explicitly so sections absent from
        # the YAML still load their own .env/env-var configuration.
        section_types = {
            'database': DatabaseSettings,
            'security': SecuritySettings,
            'llm': LLMSettings,
            'embedding': EmbeddingSettings,
            'vector_db': VectorDBSettings,
            'file': FileSettings,
            'storage': StorageSettings,
            'cors': CORSSettings,
            'chat': ChatSettings,
            'tool': ToolSetings,
        }
        settings_kwargs = {
            name: section_cls(**(config_data.get(name, {})))
            for name, section_cls in section_types.items()
        }
        # Pass remaining top-level YAML keys (app info, server, ...) straight through.
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value
        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested configuration dict into `prefix_key` form.

        NOTE(review): not referenced within this file — possibly used by
        external callers; kept for compatibility.
        """
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Any) -> Any:
        """Recursively replace `${VAR}` string values with `os.getenv(VAR)`.

        Unset variables leave the placeholder string unchanged. Fix: the
        original skipped strings nested inside YAML lists; lists are now
        traversed as well.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        elif isinstance(config, list):
            return [Settings._resolve_env_vars_nested(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        else:
            return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Replace `${VAR}` values at the top level only (non-recursive).

        NOTE(review): not referenced within this file — possibly used by
        external callers; kept for compatibility.
        """
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (loaded once, then cached)."""
    return Settings.load_from_yaml()


# Eagerly materialized module-level singleton used throughout the app.
settings = get_settings()