hyf-backend/th_agenter/services/table_metadata_service.py

455 lines
19 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""表元数据管理服务"""
import json
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from loguru import logger
from sqlalchemy import select, func
from sqlalchemy.orm import Session

from ..models.table_metadata import TableMetadata
from ..models.database_config import DatabaseConfig
from utils.util_exceptions import ValidationError, NotFoundError
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
class TableMetadataService:
    """Table metadata management service.

    Reads table structure, sample rows and row counts from a user's external
    PostgreSQL or MySQL database through the shared tool managers, and stores
    the result as ``TableMetadata`` rows in the application database.

    Progress/error descriptions are written to ``self.session.desc``
    throughout; that attribute is a project convention on the injected
    session object.
    """

    def __init__(self, db_session: Session):
        """Bind the service to *db_session* and fetch the database tool managers.

        NOTE(review): every query below is awaited, so *db_session* is
        presumably an ``AsyncSession`` despite the ``Session`` annotation.
        """
        self.session = db_session
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``. The tzinfo is
        dropped so stored ``last_synced_at`` values keep their previous form.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _get_db_tool(self, db_type: str) -> Any:
        """Return the tool manager for *db_type* (case-insensitive).

        Raises:
            Exception: if *db_type* is neither ``postgresql`` nor ``mysql``.
        """
        normalized = db_type.lower()
        if normalized == 'postgresql':
            return self.postgresql_tool
        if normalized == 'mysql':
            return self.mysql_tool
        self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
        raise Exception(f"不支持的数据库类型: {db_type}")

    @staticmethod
    def _quote_identifier(name: str, db_type: str) -> str:
        """Quote a (possibly ``schema.table``) name for interpolation into SQL.

        Each dot-separated part is wrapped in the dialect's identifier quote
        (backticks for MySQL, double quotes otherwise). Embedded quote
        characters are doubled, so a crafted table name cannot inject SQL.
        Quoting makes PostgreSQL match the name case-sensitively, i.e. the
        name must be spelled exactly as it exists in the catalog.
        """
        quote = '`' if db_type.lower() == 'mysql' else '"'
        return '.'.join(
            f"{quote}{part.replace(quote, quote * 2)}{quote}"
            for part in name.split('.')
        )

    async def _find_table_metadata(
        self,
        user_id: int,
        database_config_id: int,
        table_name: str
    ) -> Optional[TableMetadata]:
        """Return the user's ``TableMetadata`` row for this config/table, or ``None``."""
        stmt = select(TableMetadata).where(
            TableMetadata.created_by == user_id,
            TableMetadata.database_config_id == database_config_id,
            TableMetadata.table_name == table_name,
        )
        return (await self.session.execute(stmt)).scalar_one_or_none()

    # ------------------------------------------------------------------ #
    # Collection
    # ------------------------------------------------------------------ #

    async def collect_and_save_table_metadata(
        self,
        user_id: int,
        database_config_id: int,
        table_names: List[str]
    ) -> Dict[str, Any]:
        """Collect metadata for *table_names* from the configured database and save it.

        A connection for the user is opened through the tool manager if one
        does not exist yet. Each table is processed independently, and a
        failing table is reported in ``failed_tables``.

        Returns:
            ``{'success': True, 'collected_tables', 'failed_tables',
            'total_collected', 'total_failed'}``. Errors raised before the
            per-table loop (missing config, unsupported db type, connection
            failure) are returned as ``{'success': False, 'message': str}``
            instead of being raised.
        """
        self.session.desc = f"为用户 {user_id} 收集数据库 {database_config_id} 的表元数据"
        try:
            stmt = select(DatabaseConfig).where(
                DatabaseConfig.id == database_config_id,
                DatabaseConfig.created_by == user_id
            )
            db_config = (await self.session.execute(stmt)).scalar_one_or_none()
            if not db_config:
                self.session.desc = "ERROR: 数据库配置不存在"
                raise NotFoundError("数据库配置不存在")

            # Cache the type now: commits/rollbacks below may expire db_config,
            # and touching an expired attribute would trigger implicit IO.
            db_type = db_config.db_type
            db_tool = self._get_db_tool(db_type)

            # Reuse an existing connection for this user, otherwise open one.
            user_id_str = str(user_id)
            if user_id_str not in db_tool.connections:
                connection_config = {
                    'host': db_config.host,
                    'port': db_config.port,
                    'database': db_config.database,
                    'username': db_config.username,
                    'password': self._decrypt_password(db_config.password)
                }
                connect_result = await db_tool.execute(
                    operation="connect",
                    connection_config=connection_config,
                    user_id=user_id_str
                )
                if not connect_result.success:
                    self.session.desc = f"ERROR: 数据库连接失败: {connect_result.error}"
                    raise Exception(f"数据库连接失败: {connect_result.error}")
                self.session.desc = f"SUCCESS: 为用户 {user_id} 建立了新的{db_type}数据库连接"
            else:
                self.session.desc = f"SUCCESS: 复用用户 {user_id} 的现有{db_type}数据库连接"

            collected_tables = []
            failed_tables = []
            for table_name in table_names:
                try:
                    metadata = await self._collect_single_table_metadata(
                        user_id, table_name, db_type
                    )
                    table_metadata = await self._save_table_metadata(
                        user_id, database_config_id, table_name, metadata
                    )
                    collected_tables.append({
                        'table_name': table_name,
                        'metadata_id': table_metadata.id,
                        'columns_count': len(metadata['columns_info']),
                        'sample_rows': len(metadata['sample_data'])
                    })
                except Exception as e:
                    # A failed flush/commit leaves the session unusable; roll
                    # back so the remaining tables can still be saved. Tables
                    # saved earlier were already committed individually.
                    await self.session.rollback()
                    logger.warning(f"收集表 {table_name} 元数据失败: {e}")
                    self.session.desc = f"ERROR: 收集表 {table_name} 元数据失败: {str(e)}"
                    failed_tables.append({
                        'table_name': table_name,
                        'error': str(e)
                    })

            return {
                'success': True,
                'collected_tables': collected_tables,
                'failed_tables': failed_tables,
                'total_collected': len(collected_tables),
                'total_failed': len(failed_tables)
            }
        except Exception as e:
            logger.warning(f"收集表元数据失败: {e}")
            self.session.desc = f"ERROR: 收集表元数据失败: {str(e)}"
            return {
                'success': False,
                'message': str(e)
            }

    async def _collect_single_table_metadata(
        self,
        user_id: int,
        table_name: str,
        db_type: str
    ) -> Dict[str, Any]:
        """Collect structure, up to 5 sample rows and the row count of one table.

        Raises:
            Exception: for an unsupported *db_type*, or if ``describe_table``
                fails. A failing sample or count query is tolerated and yields
                ``[]`` / ``0``.
        """
        self.session.desc = f"为用户 {user_id} 收集表 {table_name} 的元数据"
        db_tool = self._get_db_tool(db_type)
        user_id_str = str(user_id)

        # Table structure
        schema_result = await db_tool.execute(
            operation="describe_table",
            user_id=user_id_str,
            table_name=table_name
        )
        if not schema_result.success:
            self.session.desc = f"ERROR: 获取表 {table_name} 结构失败: {schema_result.error}"
            raise Exception(f"获取表结构失败: {schema_result.error}")
        schema_data = schema_result.result or {}

        # The name originates from the caller, so quote it before putting it in SQL.
        quoted_table = self._quote_identifier(table_name, db_type)

        # First 5 sample rows
        sample_result = await db_tool.execute(
            operation="execute_query",
            user_id=user_id_str,
            sql_query=f"SELECT * FROM {quoted_table} LIMIT 5",
            limit=5
        )
        sample_data = []
        if sample_result.success:
            sample_data = (sample_result.result or {}).get('data') or []

        # Total row count
        count_result = await db_tool.execute(
            operation="execute_query",
            user_id=user_id_str,
            sql_query=f"SELECT COUNT(*) as total_rows FROM {quoted_table}",
            limit=1
        )
        row_count = 0
        count_rows = (count_result.result or {}).get('data') if count_result.success else None
        if count_rows:
            row_count = count_rows[0].get('total_rows', 0)

        columns = schema_data.get('columns', [])
        self.session.desc = f"SUCCESS: 为用户 {user_id} 收集表 {table_name} 的元数据, 包含 {len(columns)} 列, {row_count} 行数据"
        return {
            'columns_info': columns,
            'primary_keys': schema_data.get('primary_keys', []),
            'foreign_keys': schema_data.get('foreign_keys', []),
            'indexes': schema_data.get('indexes', []),
            'sample_data': sample_data,
            'row_count': row_count,
            'table_comment': schema_data.get('table_comment', '')
        }

    async def _save_table_metadata(
        self,
        user_id: int,
        database_config_id: int,
        table_name: str,
        metadata: Dict[str, Any]
    ) -> TableMetadata:
        """Insert or update the ``TableMetadata`` row from a collected *metadata* dict.

        Commits immediately and returns the refreshed row.
        """
        self.session.desc = f"为用户 {user_id} 保存表 {table_name} 的元数据"
        existing = await self._find_table_metadata(user_id, database_config_id, table_name)
        if existing:
            self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
            existing.columns_info = metadata['columns_info']
            existing.primary_keys = metadata['primary_keys']
            existing.foreign_keys = metadata['foreign_keys']
            existing.indexes = metadata['indexes']
            existing.sample_data = metadata['sample_data']
            existing.row_count = metadata['row_count']
            existing.table_comment = metadata['table_comment']
            existing.last_synced_at = self._utcnow()
            await self.session.commit()
            await self.session.refresh(existing)
            return existing

        self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        table_metadata = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            table_comment=metadata['table_comment'],
            columns_info=metadata['columns_info'],
            primary_keys=metadata['primary_keys'],
            foreign_keys=metadata['foreign_keys'],
            indexes=metadata['indexes'],
            sample_data=metadata['sample_data'],
            row_count=metadata['row_count'],
            is_enabled_for_qa=True,
            last_synced_at=self._utcnow()
        )
        self.session.add(table_metadata)
        await self.session.commit()
        await self.session.refresh(table_metadata)
        self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        return table_metadata

    async def save_table_metadata_config(
        self,
        user_id: int,
        database_config_id: int,
        table_names: List[str]
    ) -> Dict[str, Any]:
        """Register *table_names* for QA without collecting their structure (simplified save).

        Existing rows are re-enabled for QA. Missing rows are created with
        empty placeholders that the collect endpoint fills in later. All
        changes are committed together at the end.

        Raises:
            NotFoundError: if the database config does not exist for this user.
        """
        self.session.desc = f"为用户 {user_id} 保存数据库配置 {database_config_id}{table_names} 的元数据配置"
        # Ownership is tracked in ``created_by`` (as everywhere else in this service).
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == database_config_id,
            DatabaseConfig.created_by == user_id
        )
        db_config = (await self.session.execute(stmt)).scalar_one_or_none()
        if not db_config:
            self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
            raise NotFoundError("数据库配置不存在")

        saved_tables = []
        failed_tables = []
        for table_name in table_names:
            try:
                existing = await self._find_table_metadata(user_id, database_config_id, table_name)
                if existing:
                    self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置"
                    existing.is_enabled_for_qa = True
                    existing.last_synced_at = self._utcnow()
                    action = 'updated'
                else:
                    self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置"
                    # Placeholders use the same field names and value shapes
                    # that _save_table_metadata writes. The collect endpoint
                    # fills in the real structure later.
                    placeholder = TableMetadata(
                        created_by=user_id,
                        database_config_id=database_config_id,
                        table_name=table_name,
                        table_schema='public',
                        table_type='BASE TABLE',
                        table_comment='',
                        columns_info=[],
                        sample_data=[],
                        row_count=0,
                        is_enabled_for_qa=True,
                        qa_description='',
                        business_context='',
                        last_synced_at=self._utcnow()
                    )
                    self.session.add(placeholder)
                    action = 'created'
                saved_tables.append({
                    'table_name': table_name,
                    'action': action
                })
            except Exception as e:
                logger.warning(f"保存表 {table_name} 元数据配置失败: {e}")
                self.session.desc = f"ERROR: 保存用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置失败: {str(e)}"
                failed_tables.append({
                    'table_name': table_name,
                    'error': str(e)
                })

        await self.session.commit()
        self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id}{table_names} 的元数据配置"
        return {
            'saved_tables': saved_tables,
            'failed_tables': failed_tables,
            'total_saved': len(saved_tables),
            'total_failed': len(failed_tables)
        }

    # ------------------------------------------------------------------ #
    # Queries / settings
    # ------------------------------------------------------------------ #

    async def get_user_table_metadata(
        self,
        user_id: int,
        database_config_id: Optional[int] = None
    ) -> List[TableMetadata]:
        """Return the user's QA-enabled table metadata for one database config.

        Raises:
            NotFoundError: when *database_config_id* is falsy (including the
                ``None`` default). NOTE(review): the optional signature
                suggests "all configs" was once intended. Confirm with callers
                before changing this.
        """
        self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表元数据列表"
        if not database_config_id:
            self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
            raise NotFoundError("数据库配置不存在")
        stmt = select(TableMetadata).where(
            TableMetadata.created_by == user_id,
            TableMetadata.database_config_id == database_config_id,
            TableMetadata.is_enabled_for_qa == True,  # noqa: E712 - SQL expression, not a Python bool test
        )
        return (await self.session.scalars(stmt)).all()

    async def get_table_metadata_by_name(
        self,
        user_id: int,
        database_config_id: int,
        table_name: str
    ) -> Optional[TableMetadata]:
        """Return the user's metadata row for *table_name* in the config, or ``None``."""
        self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        return await self._find_table_metadata(user_id, database_config_id, table_name)

    async def update_table_qa_settings(
        self,
        user_id: int,
        metadata_id: int,
        settings: Dict[str, Any]
    ) -> bool:
        """Update the QA settings of one metadata row owned by *user_id*.

        Only ``is_enabled_for_qa``, ``qa_description`` and
        ``business_context`` keys of *settings* are applied.

        Returns:
            ``True`` on success. ``False`` if the row is not found or the
            update fails (the transaction is rolled back in that case).
        """
        self.session.desc = f"更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置"
        try:
            stmt = select(TableMetadata).where(
                TableMetadata.id == metadata_id,
                TableMetadata.created_by == user_id
            )
            metadata = (await self.session.execute(stmt)).scalar_one_or_none()
            if not metadata:
                self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在"
                return False
            for field in ('is_enabled_for_qa', 'qa_description', 'business_context'):
                if field in settings:
                    setattr(metadata, field, settings[field])
            await self.session.commit()
            return True
        except Exception as e:
            logger.warning(f"更新问答设置失败 metadata_id={metadata_id}: {e}")
            self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}"
            await self.session.rollback()
            return False

    async def save_table_metadata(
        self,
        user_id: int,
        database_config_id: int,
        table_name: str,
        columns_info: List[Dict[str, Any]],
        primary_keys: List[str],
        row_count: int,
        table_comment: str = ''
    ) -> TableMetadata:
        """Insert or update one table's metadata from explicit values.

        Commits immediately and returns the refreshed row.
        """
        self.session.desc = f"保存用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        existing = await self._find_table_metadata(user_id, database_config_id, table_name)
        if existing:
            self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 已存在,更新其元数据"
            existing.columns_info = columns_info
            existing.primary_keys = primary_keys
            existing.row_count = row_count
            existing.table_comment = table_comment
            existing.last_synced_at = self._utcnow()
            await self.session.commit()
            # Refresh like _save_table_metadata does, so callers never hit
            # expired attributes (implicit IO) on the returned row.
            await self.session.refresh(existing)
            return existing

        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 不存在,创建新记录"
        table_metadata = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            table_comment=table_comment,
            columns_info=columns_info,
            primary_keys=primary_keys,
            row_count=row_count,
            is_enabled_for_qa=True,
            last_synced_at=self._utcnow()
        )
        self.session.add(table_metadata)
        await self.session.commit()
        await self.session.refresh(table_metadata)
        return table_metadata

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a stored database password.

        TODO: implement the same decryption as DatabaseConfigService. For now
        the stored value is returned unchanged.
        """
        return encrypted_password