hxf/backend/th_agenter/services/database_config_service.py

375 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据库配置服务"""
import asyncio
import os
import re
from typing import List, Dict, Any, Optional

from cryptography.fernet import Fernet
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from ..models.database_config import DatabaseConfig
from utils.util_exceptions import ValidationError, NotFoundError
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
class DatabaseConfigService:
    """Manage users' stored database connection configurations.

    Configs are persisted through an async SQLAlchemy session with their
    passwords Fernet-encrypted. Live operations (connection tests, connecting,
    querying, describing tables) are delegated to the PostgreSQL / MySQL tool
    objects, which keep open connections keyed by the stringified user id.
    """

    # Path of the Fernet key file, relative to the process working directory.
    KEY_FILE = "db/db_config_key.key"

    # Columns create_or_update_config must never copy from caller-supplied
    # data: rewriting them would change the row's identity or its owner.
    _PROTECTED_FIELDS = frozenset({'id', 'created_by', 'created_at', 'updated_at'})

    # `table` or `schema.table` made of word characters only. get_table_data
    # interpolates the name into SQL text, so anything else (quotes, spaces,
    # `;`, comments) is rejected rather than executed.
    _TABLE_NAME_RE = re.compile(r"\w+(?:\.\w+)?")

    # Strong references to disconnect tasks scheduled from synchronous code,
    # so they are not garbage-collected before they finish.
    _background_tasks: set = set()

    def __init__(self, db_session: AsyncSession):
        """Bind the service to *db_session* and load the encryption key.

        Args:
            db_session: Async SQLAlchemy session; every query and commit made
                by this service is awaited.
        """
        self.session = db_session
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # The same key must be reused across restarts, otherwise previously
        # stored passwords can no longer be decrypted.
        self.encryption_key = self._get_or_create_encryption_key()
        self.cipher = Fernet(self.encryption_key)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _get_or_create_encryption_key(self) -> bytes:
        """Return the Fernet key stored in KEY_FILE, creating it on first use.

        A new key file is created with O_EXCL and mode 0o600. The key decrypts
        every stored password, so it must not be world-readable. If another
        process creates the file first, its key is reused rather than
        overwritten.
        """
        key_file = self.KEY_FILE
        try:
            with open(key_file, 'rb') as f:
                logger.info('find db_config_key')
                return f.read().strip()
        except FileNotFoundError:
            pass

        logger.info('not find db_config_key')
        key_dir = os.path.dirname(key_file)
        if key_dir:
            os.makedirs(key_dir, exist_ok=True)
        key = Fernet.generate_key()
        try:
            fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            # Lost a creation race with another process: use its key.
            with open(key_file, 'rb') as f:
                return f.read().strip()
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        return key

    def _encrypt_password(self, password: str) -> str:
        """Encrypt *password* with the service's Fernet key (returns text)."""
        return self.cipher.encrypt(password.encode()).decode()

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a value previously produced by _encrypt_password."""
        return self.cipher.decrypt(encrypted_password.encode()).decode()

    def _get_db_tool(self, db_type: Optional[str]) -> Optional[Any]:
        """Return the tool handling *db_type* (case-insensitive), or None if unsupported."""
        return {
            'postgresql': self.postgresql_tool,
            'mysql': self.mysql_tool,
        }.get((db_type or '').lower())

    def _build_connection_config(self, config: DatabaseConfig) -> Dict[str, Any]:
        """Build the connection dict the DB tools expect, with the password decrypted."""
        return {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password),
        }

    @classmethod
    def _is_safe_table_name(cls, table_name: Any) -> bool:
        """True if *table_name* is `table` or `schema.table` of word characters."""
        return isinstance(table_name, str) and cls._TABLE_NAME_RE.fullmatch(table_name) is not None

    async def _clear_default_flags(self, user_id: int, exclude_id: Optional[int] = None) -> None:
        """Unset is_default on the user's configs, optionally sparing row *exclude_id*.

        Changes are left pending in the session; the caller commits.
        """
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.is_default == True,  # noqa: E712 - SQL expression
        )
        if exclude_id is not None:
            stmt = stmt.where(DatabaseConfig.id != exclude_id)
        result = await self.session.execute(stmt)
        for other in result.scalars():
            other.is_default = False

    @classmethod
    def _drive_coroutine(cls, coro) -> None:
        """Run the tool coroutine *coro* from synchronous code.

        With no event loop running in this thread, the coroutine is run to
        completion via asyncio.run. Inside a running loop a sync function
        cannot await, so the coroutine is scheduled as a task instead. The task
        is referenced from _background_tasks until it finishes, and any failure
        is logged.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(coro)
            return

        def _on_done(task: asyncio.Task) -> None:
            cls._background_tasks.discard(task)
            if not task.cancelled() and task.exception() is not None:
                logger.error(f"断开连接失败: {task.exception()}")

        task = loop.create_task(coro)
        cls._background_tasks.add(task)
        task.add_done_callback(_on_done)

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #

    async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Validate, connection-test and persist a new config owned by *user_id*.

        Args:
            user_id: Owner id stored in ``created_by``.
            config_data: Must contain name, db_type, host, port, database,
                username and password. Optional: is_active (default True),
                is_default (default False), connection_params.

        Returns:
            The committed and refreshed DatabaseConfig.

        Raises:
            ValidationError: A required field is missing, or the connection
                test fails for a supported db_type.
        """
        try:
            required_fields = ['name', 'db_type', 'host', 'port', 'database', 'username', 'password']
            for field in required_fields:
                if field not in config_data:
                    raise ValidationError(f"缺少必需字段: {field}")

            # Unsupported db types are stored without a connection test, as before.
            db_tool = self._get_db_tool(config_data['db_type'])
            if db_tool is not None:
                test_result = await db_tool.execute(
                    operation="test_connection",
                    connection_config={
                        'host': config_data['host'],
                        'port': config_data['port'],
                        'database': config_data['database'],
                        'username': config_data['username'],
                        'password': config_data['password'],
                    },
                )
                if not test_result.success:
                    raise ValidationError(f"数据库连接测试失败: {test_result.error}")

            # Only one default config per user.
            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id)

            db_config = DatabaseConfig(
                created_by=user_id,
                name=config_data['name'],
                db_type=config_data['db_type'],
                host=config_data['host'],
                port=config_data['port'],
                database=config_data['database'],
                username=config_data['username'],
                password=self._encrypt_password(config_data['password']),
                is_active=config_data.get('is_active', True),
                is_default=config_data.get('is_default', False),
                connection_params=config_data.get('connection_params'),
            )
            self.session.add(db_config)
            await self.session.commit()
            await self.session.refresh(db_config)
            logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
            return db_config
        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建数据库配置失败: {str(e)}")
            raise

    async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
        """Return the user's configs, newest first; only active ones when *active_only*."""
        stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
        if active_only:
            stmt = stmt.where(DatabaseConfig.is_active == True)  # noqa: E712
        stmt = stmt.order_by(DatabaseConfig.created_at.desc())
        return (await self.session.execute(stmt)).scalars().all()

    async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
        """Return config *config_id* if it belongs to *user_id*, else None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == config_id,
            DatabaseConfig.created_by == user_id,
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
        """Return the user's default active config, falling back to the newest active one.

        The previous single-row query raised MultipleResultsFound once a user
        had more than one active config (e.g. one postgresql plus one mysql).
        """
        stmt = (
            select(DatabaseConfig)
            .where(
                DatabaseConfig.created_by == user_id,
                DatabaseConfig.is_active == True,  # noqa: E712
            )
            .order_by(DatabaseConfig.created_at.desc())
        )
        configs = (await self.session.execute(stmt)).scalars().all()
        # Picked in Python: NULL ordering of a boolean column differs across backends.
        for config in configs:
            if config.is_default:
                return config
        return configs[0] if configs else None

    async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
        """Return the user's first active config of *db_type*, or None."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.db_type == db_type,
            DatabaseConfig.is_active == True,  # noqa: E712
        )
        return await self.session.scalar(stmt)

    async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Upsert the user's config for ``config_data['db_type']`` (one per type).

        When a matching active config exists, every key of *config_data* that
        names a model attribute is copied onto it. The password is re-encrypted
        and _PROTECTED_FIELDS are skipped. Otherwise the call is delegated to
        create_config. Any failure rolls back the session and re-raises.
        """
        try:
            existing_config = await self.get_config_by_type(user_id, config_data['db_type'])
            if existing_config is None:
                return await self.create_config(user_id, config_data)

            # Keep the single-default invariant enforced by create_config.
            if config_data.get('is_default', False):
                await self._clear_default_flags(user_id, exclude_id=existing_config.id)

            for key, value in config_data.items():
                if key in self._PROTECTED_FIELDS:
                    continue
                if key == 'password':
                    setattr(existing_config, key, self._encrypt_password(value))
                elif hasattr(existing_config, key):
                    setattr(existing_config, key, value)
            await self.session.commit()
            await self.session.refresh(existing_config)
            logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
            return existing_config
        except Exception as e:
            await self.session.rollback()
            logger.error(f"创建或更新数据库配置失败: {str(e)}")
            raise

    # ------------------------------------------------------------------ #
    # Live connection operations
    # ------------------------------------------------------------------ #

    async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Run the matching tool's connection test for a stored config.

        Returns:
            {'success', 'message', 'details'}. An unsupported db_type yields
            success False with details None.

        Raises:
            NotFoundError: The config does not exist or belongs to another user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
                'details': None,
            }

        result = await db_tool.execute(
            operation="test_connection",
            connection_config=self._build_connection_config(config),
        )
        return {
            'success': result.success,
            'message': result.result.get('message') if result.success else result.error,
            'details': result.result if result.success else None,
        }

    async def connect_and_get_tables(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Open a connection for *user_id* using a stored config and return the tool's result.

        The tool keeps the connection under ``str(user_id)``, so
        get_table_data / describe_table can reuse it.

        Raises:
            NotFoundError: The config does not exist or belongs to another user.
        """
        config = await self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        db_tool = self._get_db_tool(config.db_type)
        if db_tool is None:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {config.db_type}',
            }

        connect_result = await db_tool.execute(
            operation="connect",
            connection_config=self._build_connection_config(config),
            user_id=str(user_id),
        )
        if not connect_result.success:
            return {
                'success': False,
                'message': connect_result.error,
            }
        return {
            'success': True,
            'data': connect_result.result,
            'config_name': config.name,
        }

    async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
        """Preview rows of *table_name* over the user's already-open connection.

        Args:
            table_name: ``table`` or ``schema.table``. Other forms are rejected
                because the name is interpolated into the SQL text.
            user_id: Owner of the open connection.
            db_type: 'postgresql' or 'mysql' (case-insensitive).
            limit: Row limit passed through to the tool.

        Returns:
            {'success': True, 'data', 'db_type'} or {'success': False, 'message'};
            errors are reported in the dict, never raised.
        """
        try:
            user_id_str = str(user_id)
            db_tool = self._get_db_tool(db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {db_type}'
                }
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }
            if not self._is_safe_table_name(table_name):
                return {
                    'success': False,
                    'message': f'非法的表名: {table_name}'
                }

            result = await db_tool.execute(
                operation="execute_query",
                user_id=user_id_str,
                sql_query=f"SELECT * FROM {table_name}",
                limit=limit,
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }
            return {
                'success': True,
                'data': result.result,
                'db_type': db_type
            }
        except Exception as e:
            # loguru: logger.exception attaches the traceback; passing
            # exc_info=True would instead be used as a str.format argument.
            logger.exception(f"获取表数据失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表数据失败: {str(e)}'
            }

    def disconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Close the user's open PostgreSQL / MySQL tool connections.

        Kept synchronous for existing callers. The tools' ``execute`` is a
        coroutine, so it is run to completion when no event loop is running.
        Inside a running loop it is scheduled in the background, and success
        then means "disconnect scheduled".
        """
        user_id_str = str(user_id)
        try:
            for db_tool in (self.postgresql_tool, self.mysql_tool):
                if user_id_str in getattr(db_tool, 'connections', {}):
                    self._drive_coroutine(
                        db_tool.execute(operation="disconnect", user_id=user_id_str)
                    )
            return {
                'success': True,
                'message': '数据库连接已断开'
            }
        except Exception as e:
            return {
                'success': False,
                'message': f'断开连接失败: {str(e)}'
            }

    async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Describe *table_name* via the tool of the user's default config, reusing its connection.

        Returns:
            {'success': True, 'data', 'db_type'} or {'success': False, 'message'};
            errors are reported in the dict, never raised.
        """
        try:
            logger.debug(f"未实现的逻辑,暂自编 - describe_table: {table_name}")
            user_id_str = str(user_id)

            default_config = await self.get_default_config(user_id)
            if not default_config:
                return {
                    'success': False,
                    'message': '未找到默认数据库配置'
                }

            db_tool = self._get_db_tool(default_config.db_type)
            if db_tool is None:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {default_config.db_type}'
                }
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            result = await db_tool.execute(
                operation="describe_table",
                user_id=user_id_str,
                table_name=table_name,
            )
            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }
            return {
                'success': True,
                'data': result.result,
                'db_type': default_config.db_type
            }
        except Exception as e:
            logger.exception(f"获取表结构失败: {str(e)}")
            return {
                'success': False,
                'message': f'获取表结构失败: {str(e)}'
            }