375 lines
14 KiB
Python
375 lines
14 KiB
Python
"""数据库配置服务"""
|
||
import asyncio
import os
import re
from typing import List, Dict, Any, Optional

from cryptography.fernet import Fernet
from loguru import logger
from sqlalchemy import select
from sqlalchemy.orm import Session

from ..models.database_config import DatabaseConfig
from utils.util_exceptions import ValidationError, NotFoundError
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
|
||
|
||
class DatabaseConfigService:
    """Database configuration management service.

    Stores per-user ``DatabaseConfig`` rows with the password encrypted by a
    Fernet key kept on local disk, and delegates connection testing, connecting
    and querying to the PostgreSQL / MySQL tool objects returned by
    ``get_postgresql_tool()`` / ``get_mysql_tool()``.

    NOTE(review): every query path awaits ``self.session`` (``execute``,
    ``commit``, ``refresh``, ``rollback``, ``scalar``), so the session is
    presumably an ``AsyncSession`` despite the ``Session`` annotation -- confirm
    with callers.
    """

    # Location of the Fernet key that encrypts stored passwords. Replacing or
    # losing this file makes every already-stored password undecryptable.
    ENCRYPTION_KEY_FILE = "db/db_config_key.key"

    # Keys that create_config requires in ``config_data`` (checked in this order).
    REQUIRED_FIELDS = ('name', 'db_type', 'host', 'port', 'database', 'username', 'password')

    # Columns create_or_update_config must never take from client-supplied data.
    _PROTECTED_FIELDS = frozenset({'id', 'created_by', 'created_at', 'updated_at'})

    # Accepted table-name shapes for the interpolated preview query: one to three
    # dot-separated parts, each a bare word (Unicode letters/digits, `_`, `$`) or a
    # name in the dialect's own identifier quotes that contains no such quote.
    # Identifiers cannot be bound as query parameters, so anything else
    # (whitespace, `;`, comments, string quotes) is rejected to rule out SQL injection.
    _TABLE_NAME_PATTERNS = {
        'postgresql': re.compile(r'(?:[\w$]+|"[^"]+")(?:\.(?:[\w$]+|"[^"]+")){0,2}'),
        'mysql': re.compile(r'(?:[\w$]+|`[^`]+`)(?:\.(?:[\w$]+|`[^`]+`)){0,2}'),
    }

    # Strong references to disconnect tasks scheduled by disconnect_database;
    # the event loop itself only keeps weak references to tasks.
    _background_tasks: set = set()

    def __init__(self, db_session: Session):
        """Bind the service to a session and load (or create) the encryption key.

        Args:
            db_session: Session used for all ``DatabaseConfig`` queries.
        """
        self.session = db_session
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # Read by disconnect_database, so it must exist before the first disconnect.
        self.user_connections: Dict[int, Any] = {}
        # Initialise the encryption key and cipher.
        self.encryption_key = self._get_or_create_encryption_key()
        self.cipher = Fernet(self.encryption_key)

    def _get_or_create_encryption_key(self) -> bytes:
        """Load the Fernet key from ``ENCRYPTION_KEY_FILE``, creating it on first use.

        Returns:
            The key bytes passed to ``Fernet``.
        """
        key_file = self.ENCRYPTION_KEY_FILE
        if os.path.exists(key_file):
            logger.info('find db_config_key')
            return self._read_key_file(key_file)

        logger.info('not find db_config_key')
        key = Fernet.generate_key()
        key_dir = os.path.dirname(key_file)
        if key_dir:
            # Writing the key would fail if its directory does not exist yet.
            os.makedirs(key_dir, exist_ok=True)
        try:
            # O_EXCL: never overwrite a key another process created meanwhile, which
            # would orphan every password already encrypted with it.
            # 0o600: the key decrypts every stored password; keep it owner-only.
            fd = os.open(
                key_file,
                os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, 'O_BINARY', 0),
                0o600,
            )
        except FileExistsError:
            return self._read_key_file(key_file)
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        return key

    @staticmethod
    def _read_key_file(key_file: str) -> bytes:
        """Read a stored key, ignoring surrounding whitespace such as a trailing newline."""
        with open(key_file, 'rb') as f:
            return f.read().strip()

    def _encrypt_password(self, password: str) -> str:
        """Encrypt a plain-text password into a Fernet token string."""
        return self.cipher.encrypt(password.encode()).decode()

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a Fernet token string produced by ``_encrypt_password``."""
        return self.cipher.decrypt(encrypted_password.encode()).decode()

    def _get_db_tool(self, db_type: Any) -> Optional[Any]:
        """Return the tool for ``db_type`` (case-insensitive), or None if unsupported."""
        normalized = db_type.lower() if isinstance(db_type, str) else None
        if normalized == 'postgresql':
            return self.postgresql_tool
        if normalized == 'mysql':
            return self.mysql_tool
        return None

    def _build_connection_config(self, config: DatabaseConfig) -> Dict[str, Any]:
        """Build the ``connection_config`` dict the tools expect, with the password decrypted."""
        return {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password),
        }

    async def _clear_default_flags(self, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Unset ``is_default`` on the user's configs (except ``exclude_id``); no commit."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_default == True,  # noqa: E712 - SQL expression
        )
        if exclude_id is not None:
            stmt = stmt.where(DatabaseConfig.id != exclude_id)
        result = await self.session.execute(stmt)
        for config in result.scalars():
            config.is_default = False

    async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Validate, connection-test and persist a new database configuration.

        Args:
            user_id: Owner of the configuration (stored as ``created_by``).
            config_data: Must contain every key in ``REQUIRED_FIELDS``; may also
                contain ``is_active`` (default True), ``is_default`` (default False)
                and ``connection_params``. ``password`` is plain text and is stored
                encrypted.

        Returns:
            The committed and refreshed ``DatabaseConfig``.

        Raises:
            ValidationError: A required field is missing, ``db_type`` is not
                supported, or the connection test fails. Any error rolls the
                session back and is re-raised.
        """
        try:
            for field in self.REQUIRED_FIELDS:
                if field not in config_data:
                    raise ValidationError(f"缺少必需字段: {field}")

            db_type = config_data['db_type']
            db_tool = self._get_db_tool(db_type)
            if db_tool is None:
                # An untestable config could be saved but never connected to.
                raise ValidationError(f"不支持的数据库类型: {db_type}")

            # Test the connection before anything is written.
            test_result = await db_tool.execute(
                operation="test_connection",
                connection_config={
                    'host': config_data['host'],
                    'port': config_data['port'],
                    'database': config_data['database'],
                    'username': config_data['username'],
                    'password': config_data['password'],
                },
            )
            if not test_result.success:
                raise ValidationError(f"数据库连接测试失败: {test_result.error}")

            # Only one default per user: clear the flag on the others first.
            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id)

            db_config = DatabaseConfig(
                created_by=user_id,
                name=config_data['name'],
                db_type=db_type,
                host=config_data['host'],
                port=config_data['port'],
                database=config_data['database'],
                username=config_data['username'],
                password=self._encrypt_password(config_data['password']),
                is_active=config_data.get('is_active', True),
                is_default=config_data.get('is_default', False),
                connection_params=config_data.get('connection_params'),
            )
            self.session.add(db_config)
            await self.session.commit()
            await self.session.refresh(db_config)

            logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
            return db_config

        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建数据库配置失败: {str(e)}")
            raise

    async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
        """Return the user's configurations, newest first (only active ones by default)."""
        stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
        if active_only:
            stmt = stmt.where(DatabaseConfig.is_active == True)  # noqa: E712 - SQL expression
        stmt = stmt.order_by(DatabaseConfig.created_at.desc())
        return (await self.session.execute(stmt)).scalars().all()

    async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
        """Return the configuration with ``config_id`` if it belongs to the user, else None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == config_id,
            DatabaseConfig.created_by == user_id,
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
        """Return the user's default active configuration, or else the newest active one.

        ``is_default`` is used for ranking, not filtering, so a user whose active
        configs are all non-default still gets one. Limiting to a single ranked row
        also means several active configs (e.g. one PostgreSQL and one MySQL) no
        longer raise ``MultipleResultsFound``.
        """
        stmt = (
            select(DatabaseConfig)
            .where(
                DatabaseConfig.created_by == user_id,
                DatabaseConfig.is_active == True,  # noqa: E712 - SQL expression
            )
            .order_by(
                # IS TRUE keeps NULL flags from sorting ahead of real defaults.
                DatabaseConfig.is_default.is_(True).desc(),
                DatabaseConfig.created_at.desc(),
            )
            .limit(1)
        )
        return (await self.session.execute(stmt)).scalars().first()

    async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Test a stored configuration with the tool that matches its ``db_type``.

        Returns:
            ``{'success', 'message', 'details'}``; ``details`` is the tool's result
            on success and None otherwise.

        Raises:
            NotFoundError: No configuration with this id belongs to the user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
                'details': None,
            }

        result = await db_tool.execute(
            operation="test_connection",
            connection_config=self._build_connection_config(config),
        )
        if not result.success:
            return {'success': False, 'message': result.error, 'details': None}

        details = result.result
        return {
            'success': True,
            'message': details.get('message') if isinstance(details, dict) else None,
            'details': details,
        }

    async def connect_and_get_tables(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Open a connection for the user with a stored configuration.

        Returns:
            ``{'success': True, 'data': <tool connect result>, 'config_name': ...}``,
            or ``{'success': False, 'message': ...}`` when the type is unsupported
            or the connect fails.

        Raises:
            NotFoundError: No configuration with this id belongs to the user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
            }

        connect_result = await db_tool.execute(
            operation="connect",
            connection_config=self._build_connection_config(config),
            user_id=str(user_id),
        )
        if not connect_result.success:
            return {
                'success': False,
                'message': connect_result.error,
            }

        # The tool keeps the connection in its ``connections`` (keyed by str(user_id)),
        # which get_table_data / describe_table check before reusing it.
        return {
            'success': True,
            'data': connect_result.result,
            'config_name': config.name,
        }

    async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
        """Preview rows of ``table_name`` over the user's existing connection.

        Args:
            table_name: Table to read; must match ``_TABLE_NAME_PATTERNS`` for the
                dialect because it is interpolated into the SQL text.
            user_id: Owner of the connection held by the tool.
            db_type: ``'postgresql'`` or ``'mysql'`` (case-insensitive).
            limit: Passed through to the tool's ``execute_query``.

        Returns:
            ``{'success': True, 'data': ..., 'db_type': db_type}`` or
            ``{'success': False, 'message': ...}``; errors are never raised.
        """
        try:
            user_id_str = str(user_id)

            db_tool = self._get_db_tool(db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {db_type}',
                }

            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            # Identifiers cannot be bound as parameters, so validate before interpolating.
            pattern = self._TABLE_NAME_PATTERNS[db_type.lower()]
            if not isinstance(table_name, str) or not pattern.fullmatch(table_name):
                return {
                    'success': False,
                    'message': f'非法的表名: {table_name}',
                }

            sql_query = f"SELECT * FROM {table_name}"
            result = await db_tool.execute(
                operation="execute_query",
                user_id=user_id_str,
                sql_query=sql_query,
                limit=limit,
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error,
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': db_type,
            }

        except Exception as e:
            logger.error(f"获取表数据失败: {str(e)}", exc_info=True)
            return {
                'success': False,
                'message': f'获取表数据失败: {str(e)}'
            }

    async def adisconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Close the user's connections in every tool that holds one.

        Async counterpart of ``disconnect_database``; prefer this from async code.

        Returns:
            ``{'success': bool, 'message': str}``; errors are logged, not raised.
        """
        user_id_str = str(user_id)
        try:
            for db_tool in (self.postgresql_tool, self.mysql_tool):
                if user_id_str in getattr(db_tool, 'connections', {}):
                    await db_tool.execute(
                        operation="disconnect",
                        user_id=user_id_str,
                    )

            # Remove from local connection tracking.
            self.user_connections.pop(user_id, None)

            return {
                'success': True,
                'message': '数据库连接已断开'
            }
        except Exception as e:
            logger.error(f"断开连接失败: {str(e)}")
            return {
                'success': False,
                'message': f'断开连接失败: {str(e)}'
            }

    def disconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Synchronous disconnect entry point, kept for existing callers.

        The tools' ``execute`` must be awaited, so the work is done by
        ``adisconnect_database``. With no event loop running in this thread it
        runs to completion and its result is returned. Inside a running loop it
        cannot be awaited here, so it is scheduled as a task and success is
        reported immediately; failures are then only logged.

        NOTE(review): in the no-loop case a fresh loop is used; if the tools'
        connections are bound to another loop the disconnect may fail -- confirm,
        and prefer ``await adisconnect_database(...)`` from async code.
        """
        coro = self.adisconnect_database(user_id)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        task = loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return {
            'success': True,
            'message': '数据库连接已断开'
        }

    async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
        """Return one active configuration of the user with exactly this ``db_type``, or None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.db_type == db_type,
            DatabaseConfig.is_active == True,  # noqa: E712 - SQL expression
        )
        return await self.session.scalar(stmt)

    async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Create the user's config for ``config_data['db_type']`` or update the existing one.

        On update every key of ``config_data`` that names an attribute of the
        model is copied over, except ``_PROTECTED_FIELDS`` and private names;
        ``password`` is encrypted (a None password leaves the stored one intact).
        No connection test is run on update; creation goes through
        ``create_config`` and its validation.

        Raises:
            KeyError: ``db_type`` is missing from ``config_data``.
            ValidationError: Propagated from ``create_config``.
            Any error rolls the session back and is re-raised.
        """
        try:
            existing_config = await self.get_config_by_type(user_id, config_data['db_type'])
            if not existing_config:
                return await self.create_config(user_id, config_data)

            for key, value in config_data.items():
                if key in self._PROTECTED_FIELDS or key.startswith('_'):
                    continue
                if key == 'password':
                    if value is None:
                        continue
                    existing_config.password = self._encrypt_password(value)
                elif hasattr(existing_config, key):
                    setattr(existing_config, key, value)

            # Keep a single default per user, as create_config does.
            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id, exclude_id=existing_config.id)

            await self.session.commit()
            await self.session.refresh(existing_config)
            logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
            return existing_config

        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建或更新数据库配置失败: {str(e)}")
            raise

    async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Describe ``table_name`` over the user's existing connection.

        The tool is chosen from the user's default config (see
        ``get_default_config``).
        NOTE(review): that may not be the config the user actually connected
        with -- confirm this is intended.

        Returns:
            ``{'success': True, 'data': ..., 'db_type': ...}`` or
            ``{'success': False, 'message': ...}``; errors are never raised.
        """
        try:
            logger.error(f"未实现的逻辑,暂自编 - describe_table: {table_name}")
            user_id_str = str(user_id)

            default_config = await self.get_default_config(user_id)
            if not default_config:
                return {
                    'success': False,
                    'message': '未找到默认数据库配置'
                }

            db_tool = self._get_db_tool(default_config.db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {default_config.db_type}'
                }

            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            result = await db_tool.execute(
                operation="describe_table",
                user_id=user_id_str,
                table_name=table_name,
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': default_config.db_type
            }

        except Exception as e:
            logger.error(f"获取表结构失败: {str(e)}", exc_info=True)
            return {
                'success': False,
                'message': f'获取表结构失败: {str(e)}'
            }
|