2025-12-04 14:48:38 +08:00
|
|
|
|
"""数据库配置服务"""
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from loguru import logger
|
2025-12-04 14:48:38 +08:00
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from sqlalchemy import select
|
2025-12-04 14:48:38 +08:00
|
|
|
|
from cryptography.fernet import Fernet
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
from ..models.database_config import DatabaseConfig
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from utils.util_exceptions import ValidationError, NotFoundError
|
2025-12-04 14:48:38 +08:00
|
|
|
|
from .postgresql_tool_manager import get_postgresql_tool
|
|
|
|
|
|
from .mysql_tool_manager import get_mysql_tool
|
|
|
|
|
|
|
|
|
|
|
|
class DatabaseConfigService:
    """Service that manages users' external database connection configs.

    Configs are persisted as ``DatabaseConfig`` rows whose password is
    encrypted with a Fernet key stored on local disk. Live connections are
    held by the PostgreSQL / MySQL tool objects (their ``connections``
    mapping is keyed by ``str(user_id)``); this service only asks those tools
    to test, connect, query, describe and disconnect.

    NOTE(review): every query/commit below is awaited, so ``db_session`` is
    expected to be an async SQLAlchemy session even though the constructor
    annotation says ``Session`` -- confirm against the caller.
    """

    # Path of the Fernet key file, relative to the process working directory.
    ENCRYPTION_KEY_FILE = "db/db_config_key.key"

    # Keys create_config requires in ``config_data`` (checked in this order).
    _REQUIRED_FIELDS = ('name', 'db_type', 'host', 'port', 'database', 'username', 'password')

    # Columns create_or_update_config must never copy from caller-supplied data;
    # otherwise a request could reassign ownership (``created_by``) or the key.
    _PROTECTED_FIELDS = frozenset({'id', 'created_by', 'created_at', 'updated_at'})

    # Strong references to fire-and-forget disconnect tasks created by
    # disconnect_database. The event loop keeps only weak references to tasks,
    # so an unreferenced task could be garbage-collected before it finishes.
    # Shared across instances on purpose.
    _background_tasks: set = set()

    def __init__(self, db_session: Session):
        """Bind the service to a DB session and load the password cipher.

        Args:
            db_session: session used to persist ``DatabaseConfig`` rows.
        """
        self.session = db_session
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # Key used to encrypt/decrypt stored database passwords.
        self.encryption_key = self._get_or_create_encryption_key()
        self.cipher = Fernet(self.encryption_key)

    # ------------------------------------------------------------------ #
    # Encryption helpers
    # ------------------------------------------------------------------ #

    def _get_or_create_encryption_key(self) -> bytes:
        """Return the Fernet key from ``ENCRYPTION_KEY_FILE``, creating it if absent.

        A new key file is created with owner-only permissions (0o600 where the
        OS honours it). If another process creates the file concurrently, that
        process's key is used, so all processes share one key.
        """
        key_file = self.ENCRYPTION_KEY_FILE
        if os.path.exists(key_file):
            logger.info('find db_config_key')
            with open(key_file, 'rb') as f:
                # Fernet keys are url-safe base64, so stripping only removes
                # stray whitespace (e.g. a trailing newline added by an editor).
                return f.read().strip()

        logger.info('not find db_config_key')
        key = Fernet.generate_key()
        key_dir = os.path.dirname(key_file)
        if key_dir:
            os.makedirs(key_dir, exist_ok=True)
        try:
            # O_EXCL: never clobber a key another process just wrote -- that
            # would make its already-encrypted passwords undecryptable.
            fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            with open(key_file, 'rb') as f:
                return f.read().strip()
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        return key

    def _encrypt_password(self, password: str) -> str:
        """Encrypt a plaintext password into a Fernet token string."""
        return self.cipher.encrypt(password.encode()).decode()

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a Fernet token string back into the plaintext password."""
        return self.cipher.decrypt(encrypted_password.encode()).decode()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _resolve_db_tool(self, db_type: Optional[str]):
        """Return the tool handling ``db_type`` (case-insensitive), or None if unsupported."""
        normalized = (db_type or '').lower()
        if normalized == 'postgresql':
            return self.postgresql_tool
        if normalized == 'mysql':
            return self.mysql_tool
        return None

    def _build_connection_config(self, config: DatabaseConfig) -> Dict[str, Any]:
        """Build the tool ``connection_config`` dict from a stored config (password decrypted)."""
        return {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password),
        }

    async def _clear_default_configs(self, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Set ``is_default = False`` on the user's default configs, except ``exclude_id``.

        Changes are left pending on the session; the caller commits.
        """
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_default == True,  # noqa: E712 -- SQL expression
        )
        if exclude_id is not None:
            stmt = stmt.where(DatabaseConfig.id != exclude_id)
        result = await self.session.execute(stmt)
        for other in result.scalars():
            other.is_default = False

    @staticmethod
    def _is_safe_table_name(table_name: Any) -> bool:
        """Return True if ``table_name`` is a plain (optionally schema-qualified) identifier.

        Identifiers cannot be bound as SQL parameters, so anything that is
        interpolated into SQL must be restricted to word characters and ``$``.
        This rejects quotes, whitespace, ``;``, comments and parentheses.
        """
        import re

        return isinstance(table_name, str) and re.fullmatch(r"[\w$]+(?:\.[\w$]+)?", table_name) is not None

    @classmethod
    def _drive_awaitable(cls, awaitable) -> None:
        """Run ``awaitable`` from synchronous code.

        If an event loop is running in this thread, the awaitable is scheduled
        as a background task (a sync caller cannot block on it) and a
        reference is kept until it completes. Otherwise it is run to
        completion with ``asyncio.run``.
        """
        import asyncio

        async def _wait():
            return await awaitable

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # NOTE(review): a fresh loop may be unable to close connections
            # created on another loop; any error propagates to the caller.
            asyncio.run(_wait())
            return
        task = loop.create_task(_wait())
        cls._background_tasks.add(task)
        task.add_done_callback(cls._on_background_task_done)

    @classmethod
    def _on_background_task_done(cls, task) -> None:
        """Drop the finished task's reference and log any failure it raised."""
        cls._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"断开连接失败: {exc}")

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #

    async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Validate, connection-test and persist a new database config.

        Args:
            user_id: owner of the new config (stored as ``created_by``).
            config_data: must contain every key in ``_REQUIRED_FIELDS``;
                ``is_active``, ``is_default`` and ``connection_params`` are optional.

        Returns:
            The committed and refreshed ``DatabaseConfig``.

        Raises:
            ValidationError: a required field is missing, ``db_type`` is not
                supported, or the connection test fails.
            Any other exception is re-raised after the session is rolled back.
        """
        try:
            for field in self._REQUIRED_FIELDS:
                if field not in config_data:
                    raise ValidationError(f"缺少必需字段: {field}")

            # Only save configs that can actually be used later: every other
            # method can serve PostgreSQL and MySQL only.
            db_tool = self._resolve_db_tool(config_data['db_type'])
            if db_tool is None:
                raise ValidationError(f"不支持的数据库类型: {config_data['db_type']}")

            # Test the connection with the plaintext password before storing it.
            test_config = {
                key: config_data[key]
                for key in ('host', 'port', 'database', 'username', 'password')
            }
            test_result = await db_tool.execute(
                operation="test_connection",
                connection_config=test_config
            )
            if not test_result.success:
                raise ValidationError(f"数据库连接测试失败: {test_result.error}")

            # A user has at most one default config: clear the others first.
            if config_data.get('is_default', False):
                await self._clear_default_configs(user_id)

            db_config = DatabaseConfig(
                created_by=user_id,
                name=config_data['name'],
                db_type=config_data['db_type'],
                host=config_data['host'],
                port=config_data['port'],
                database=config_data['database'],
                username=config_data['username'],
                password=self._encrypt_password(config_data['password']),
                is_active=config_data.get('is_active', True),
                is_default=config_data.get('is_default', False),
                connection_params=config_data.get('connection_params')
            )
            self.session.add(db_config)
            await self.session.commit()
            await self.session.refresh(db_config)

            logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
            return db_config

        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建数据库配置失败: {str(e)}")
            raise

    async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
        """Return the user's configs, newest first; only active ones when ``active_only``."""
        stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
        if active_only:
            stmt = stmt.where(DatabaseConfig.is_active == True)  # noqa: E712
        stmt = stmt.order_by(DatabaseConfig.created_at.desc())
        return (await self.session.execute(stmt)).scalars().all()

    async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
        """Return config ``config_id`` if it belongs to ``user_id``, else None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == config_id,
            DatabaseConfig.created_by == user_id
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
        """Return the user's active default config, falling back to the newest active one.

        Returns None when the user has no active config. Unlike
        ``scalar_one_or_none``, this never raises when several active configs exist.
        """
        active = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_active == True  # noqa: E712
        )
        default_stmt = active.where(DatabaseConfig.is_default == True).limit(1)  # noqa: E712
        config = (await self.session.execute(default_stmt)).scalars().first()
        if config is not None:
            return config
        # No explicit default: use the most recently created active config.
        fallback_stmt = active.order_by(DatabaseConfig.created_at.desc()).limit(1)
        return (await self.session.execute(fallback_stmt)).scalars().first()

    async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
        """Return the user's newest active config with exactly this ``db_type``, or None."""
        stmt = (
            select(DatabaseConfig)
            .where(
                DatabaseConfig.created_by == user_id,
                DatabaseConfig.db_type == db_type,
                DatabaseConfig.is_active == True  # noqa: E712
            )
            .order_by(DatabaseConfig.created_at.desc())
            .limit(1)
        )
        return await self.session.scalar(stmt)

    async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Create a config, or update the user's existing active config of the same ``db_type``.

        On update:
        - Keys in ``_PROTECTED_FIELDS`` are ignored.
        - ``password`` is re-encrypted; a ``None`` password keeps the stored one.
        - Other keys are copied only if the model has that attribute.
        - Setting ``is_default`` clears the user's other default configs.

        Raises:
            KeyError: ``config_data`` has no ``db_type``.
            Any exception from ``create_config`` or the commit, re-raised after rollback.
        """
        try:
            existing_config = await self.get_config_by_type(user_id, config_data['db_type'])

            if existing_config is None:
                return await self.create_config(user_id, config_data)

            if config_data.get('is_default', False):
                await self._clear_default_configs(user_id, exclude_id=existing_config.id)

            for key, value in config_data.items():
                if key in self._PROTECTED_FIELDS:
                    continue
                if key == 'password':
                    if value is None:
                        continue
                    setattr(existing_config, key, self._encrypt_password(value))
                elif hasattr(existing_config, key):
                    setattr(existing_config, key, value)

            await self.session.commit()
            await self.session.refresh(existing_config)
            logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
            return existing_config

        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建或更新数据库配置失败: {str(e)}")
            raise

    # ------------------------------------------------------------------ #
    # Connection operations (delegated to the DB tools)
    # ------------------------------------------------------------------ #

    async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Test connectivity of a stored config using the tool matching its ``db_type``.

        Returns:
            ``{'success', 'message', 'details'}``; ``details`` is the tool
            result on success, else None.

        Raises:
            NotFoundError: the config does not exist or is not owned by the user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._resolve_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
                'details': None
            }

        result = await db_tool.execute(
            operation="test_connection",
            connection_config=self._build_connection_config(config)
        )

        if result.success:
            details = result.result
            message = details.get('message') if hasattr(details, 'get') else None
        else:
            details = None
            message = result.error
        return {
            'success': result.success,
            'message': message,
            'details': details
        }

    async def connect_and_get_tables(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Open a connection for the user with a stored config.

        The connection is kept by the tool (keyed by ``str(user_id)``) so that
        later calls such as ``get_table_data`` can reuse it.

        Returns:
            ``{'success': True, 'data', 'config_name'}`` on success, otherwise
            ``{'success': False, 'message'}``.

        Raises:
            NotFoundError: the config does not exist or is not owned by the user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._resolve_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}'
            }

        connect_result = await db_tool.execute(
            operation="connect",
            connection_config=self._build_connection_config(config),
            user_id=str(user_id)
        )
        if not connect_result.success:
            return {
                'success': False,
                'message': connect_result.error
            }

        # The connection now lives in db_tool.connections[str(user_id)].
        return {
            'success': True,
            'data': connect_result.result,
            'config_name': config.name
        }

    async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
        """Preview up to ``limit`` rows of ``table_name`` over the user's existing connection.

        Returns:
            ``{'success': True, 'data', 'db_type'}`` on success, otherwise
            ``{'success': False, 'message'}``. Never raises: errors are logged
            and returned.
        """
        try:
            user_id_str = str(user_id)

            db_tool = self._resolve_db_tool(db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {db_type}'
                }

            # Reuse the connection opened by connect_and_get_tables.
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            # table_name is interpolated into SQL (identifiers cannot be bound
            # parameters), so reject anything that is not a plain identifier.
            if not self._is_safe_table_name(table_name):
                return {
                    'success': False,
                    'message': f'非法的表名: {table_name}'
                }

            result = await db_tool.execute(
                operation="execute_query",
                user_id=user_id_str,
                sql_query=f"SELECT * FROM {table_name}",
                limit=limit
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': db_type
            }

        except Exception as e:
            # loguru: logger.exception records the traceback; exc_info=True is
            # not a loguru option and would be fed to str.format instead.
            logger.exception(f"获取表数据失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表数据失败: {str(e)}'
            }

    def disconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Ask every DB tool holding a connection for this user to disconnect it.

        Kept synchronous for existing callers. Because the tools' ``execute``
        returns an awaitable, the disconnect is handled as follows:
        - With a running event loop, it is scheduled in the background and
          the method returns before it finishes; failures are logged.
        - Without a running loop, it runs to completion before returning.

        Returns:
            ``{'success': True, 'message'}``, or ``{'success': False, 'message'}``
            if issuing the disconnect raised.
        """
        import inspect

        user_id_str = str(user_id)
        try:
            for db_tool in (self.postgresql_tool, self.mysql_tool):
                if user_id_str not in db_tool.connections:
                    continue
                outcome = db_tool.execute(
                    operation="disconnect",
                    user_id=user_id_str
                )
                if inspect.isawaitable(outcome):
                    self._drive_awaitable(outcome)

            return {
                'success': True,
                'message': '数据库连接已断开'
            }
        except Exception as e:
            return {
                'success': False,
                'message': f'断开连接失败: {str(e)}'
            }

    async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Describe ``table_name``'s structure over the user's existing connection.

        The tool is chosen from the ``db_type`` of the user's default config
        (see ``get_default_config``).

        Returns:
            ``{'success': True, 'data', 'db_type'}`` on success, otherwise
            ``{'success': False, 'message'}``. Never raises: errors are logged
            and returned.
        """
        try:
            logger.error(f"未实现的逻辑,暂自编 - describe_table: {table_name}")
            user_id_str = str(user_id)

            default_config = await self.get_default_config(user_id)
            if not default_config:
                return {
                    'success': False,
                    'message': '未找到默认数据库配置'
                }

            db_tool = self._resolve_db_tool(default_config.db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {default_config.db_type}'
                }

            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            result = await db_tool.execute(
                operation="describe_table",
                user_id=user_id_str,
                table_name=table_name
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': default_config.db_type
            }

        except Exception as e:
            logger.exception(f"获取表结构失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表结构失败: {str(e)}'
            }
|
|
|
|
|
|
|