"""表元数据管理服务""" import json from typing import List, Dict, Any, Optional from sqlalchemy.orm import Session from sqlalchemy import select, func from datetime import datetime from ..models.table_metadata import TableMetadata from ..models.database_config import DatabaseConfig from utils.util_exceptions import ValidationError, NotFoundError from .postgresql_tool_manager import get_postgresql_tool from .mysql_tool_manager import get_mysql_tool from loguru import logger class TableMetadataService: """表元数据管理服务""" def __init__(self, db_session: Session): self.session = db_session self.postgresql_tool = get_postgresql_tool() self.mysql_tool = get_mysql_tool() async def collect_and_save_table_metadata( self, user_id: int, database_config_id: int, table_names: List[str] ) -> Dict[str, Any]: """收集并保存表元数据""" self.session.desc = f"为用户 {user_id} 收集数据库 {database_config_id} 的表元数据" try: # 获取数据库配置 stmt = select(DatabaseConfig).where( DatabaseConfig.id == database_config_id, DatabaseConfig.created_by == user_id ) db_config = (await self.session.execute(stmt)).scalar_one_or_none() if not db_config: self.session.desc = "ERROR: 数据库配置不存在" raise NotFoundError("数据库配置不存在") # 根据数据库类型选择相应的工具 if db_config.db_type.lower() == 'postgresql': db_tool = self.postgresql_tool elif db_config.db_type.lower() == 'mysql': db_tool = self.mysql_tool else: self.session.desc = f"ERROR: 不支持的数据库类型: {db_config.db_type}, 期望为postgresql或mysql" raise Exception(f"不支持的数据库类型: {db_config.db_type}") # 检查是否已有连接,如果没有则建立连接 user_id_str = str(user_id) if user_id_str not in db_tool.connections: connection_config = { 'host': db_config.host, 'port': db_config.port, 'database': db_config.database, 'username': db_config.username, 'password': self._decrypt_password(db_config.password) } # 连接数据库 connect_result = await db_tool.execute( operation="connect", connection_config=connection_config, user_id=user_id_str ) if not connect_result.success: self.session.desc = f"ERROR: 数据库连接失败: {connect_result.error}" raise Exception(f"数据库连接失败: {connect_result.error}") self.session.desc = f"SUCCESS: 为用户 {user_id} 建立了新的{db_config.db_type}数据库连接" else: self.session.desc = f"SUCCESS: 复用用户 {user_id} 的现有{db_config.db_type}数据库连接" collected_tables = [] failed_tables = [] for table_name in table_names: try: # 收集表元数据 metadata = await self._collect_single_table_metadata( user_id, table_name, db_config.db_type ) # 保存或更新元数据 table_metadata = await self._save_table_metadata( user_id, database_config_id, table_name, metadata ) collected_tables.append({ 'table_name': table_name, 'metadata_id': table_metadata.id, 'columns_count': len(metadata['columns_info']), 'sample_rows': len(metadata['sample_data']) }) except Exception as e: self.session.desc = f"ERROR: 收集表 {table_name} 元数据失败: {str(e)}" failed_tables.append({ 'table_name': table_name, 'error': str(e) }) return { 'success': True, 'collected_tables': collected_tables, 'failed_tables': failed_tables, 'total_collected': len(collected_tables), 'total_failed': len(failed_tables) } except Exception as e: self.session.desc = f"ERROR: 收集表元数据失败: {str(e)}" return { 'success': False, 'message': str(e) } async def _collect_single_table_metadata( self, user_id: int, table_name: str, db_type: str ) -> Dict[str, Any]: """收集单个表的元数据""" self.session.desc = f"为用户 {user_id} 收集表 {table_name} 的元数据" # 根据数据库类型选择相应的工具 if db_type.lower() == 'postgresql': db_tool = self.postgresql_tool elif db_type.lower() == 'mysql': db_tool = self.mysql_tool else: self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql" raise Exception(f"不支持的数据库类型: {db_type}") # 获取表结构 schema_result = await db_tool.execute( operation="describe_table", user_id=str(user_id), table_name=table_name ) if not schema_result.success: self.session.desc = f"ERROR: 获取表 {table_name} 结构失败: {schema_result.error}" raise Exception(f"获取表结构失败: {schema_result.error}") schema_data = schema_result.result # 获取示例数据(前5条) sample_result = await db_tool.execute( operation="execute_query", user_id=str(user_id), sql_query=f"SELECT * FROM {table_name} LIMIT 5", limit=5 ) sample_data = [] if sample_result.success: sample_data = sample_result.result.get('data', []) # 获取行数统计 count_result = await db_tool.execute( operation="execute_query", user_id=str(user_id), sql_query=f"SELECT COUNT(*) as total_rows FROM {table_name}", limit=1 ) row_count = 0 if count_result.success and count_result.result.get('data'): row_count = count_result.result['data'][0].get('total_rows', 0) self.session.desc = f"SUCCESS: 为用户 {user_id} 收集表 {table_name} 的元数据, 包含 {len(schema_data.get('columns', []))} 列, {row_count} 行数据" return { 'columns_info': schema_data.get('columns', []), 'primary_keys': schema_data.get('primary_keys', []), 'foreign_keys': schema_data.get('foreign_keys', []), 'indexes': schema_data.get('indexes', []), 'sample_data': sample_data, 'row_count': row_count, 'table_comment': schema_data.get('table_comment', '') } async def _save_table_metadata( self, user_id: int, database_config_id: int, table_name: str, metadata: Dict[str, Any] ) -> TableMetadata: """保存表元数据""" self.session.desc = f"为用户 {user_id} 保存表 {table_name} 的元数据" # 检查是否已存在 stmt = select(TableMetadata).where( TableMetadata.created_by == user_id, TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" # 更新现有记录 existing.columns_info = metadata['columns_info'] existing.primary_keys = metadata['primary_keys'] existing.foreign_keys = metadata['foreign_keys'] existing.indexes = metadata['indexes'] existing.sample_data = metadata['sample_data'] existing.row_count = metadata['row_count'] existing.table_comment = metadata['table_comment'] existing.last_synced_at = datetime.utcnow() await self.session.commit() await self.session.refresh(existing) return existing else: self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" # 创建新记录 table_metadata = TableMetadata( created_by=user_id, database_config_id=database_config_id, table_name=table_name, table_schema='public', table_type='BASE TABLE', table_comment=metadata['table_comment'], columns_info=metadata['columns_info'], primary_keys=metadata['primary_keys'], foreign_keys=metadata['foreign_keys'], indexes=metadata['indexes'], sample_data=metadata['sample_data'], row_count=metadata['row_count'], is_enabled_for_qa=True, last_synced_at=datetime.utcnow() ) self.session.add(table_metadata) await self.session.commit() await self.session.refresh(table_metadata) self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" return table_metadata async def save_table_metadata_config( self, user_id: int, database_config_id: int, table_names: List[str] ) -> Dict[str, Any]: """保存表元数据配置(简化版,只保存基本信息)""" self.session.desc = f"为用户 {user_id} 保存数据库配置 {database_config_id} 表 {table_names} 的元数据配置" # 获取数据库配置 stmt = select(DatabaseConfig).where( DatabaseConfig.id == database_config_id, DatabaseConfig.user_id == user_id ) db_config = (await self.session.execute(stmt)).scalar_one_or_none() if not db_config: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在" raise NotFoundError("数据库配置不存在") saved_tables = [] failed_tables = [] for table_name in table_names: try: # 检查是否已存在 stmt = select(TableMetadata).where( TableMetadata.user_id == user_id, TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置" # 更新现有记录 existing.is_enabled_for_qa = True existing.last_synced_at = datetime.utcnow() saved_tables.append({ 'table_name': table_name, 'action': 'updated' }) else: self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置" # 创建新记录 metadata = TableMetadata( created_by=user_id, database_config_id=database_config_id, table_name=table_name, table_schema='public', # 默认值 table_type='table', # 默认值 table_comment='', columns_count=0, # 后续可通过collect接口更新 row_count=0, # 后续可通过collect接口更新 is_enabled_for_qa=True, qa_description='', business_context='', sample_data='{}', column_info='{}', last_synced_at=datetime.utcnow() ) self.session.add(metadata) saved_tables.append({ 'table_name': table_name, 'action': 'created' }) except Exception as e: self.session.desc = f"ERROR: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置失败: {str(e)}" failed_tables.append({ 'table_name': table_name, 'error': str(e) }) # 提交事务 await self.session.commit() self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_names} 的元数据配置" return { 'saved_tables': saved_tables, 'failed_tables': failed_tables, 'total_saved': len(saved_tables), 'total_failed': len(failed_tables) } async def get_user_table_metadata( self, user_id: int, database_config_id: Optional[int] = None ) -> List[TableMetadata]: """获取用户的表元数据列表""" self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表元数据列表" stmt = select(TableMetadata).where(TableMetadata.created_by == user_id) if database_config_id: stmt = stmt.where(TableMetadata.database_config_id == database_config_id) else: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在" raise NotFoundError("数据库配置不存在") stmt = stmt.where(TableMetadata.is_enabled_for_qa == True) return (await self.session.scalars(stmt)).all() async def get_table_metadata_by_name( self, user_id: int, database_config_id: int, table_name: str ) -> Optional[TableMetadata]: """根据表名获取表元数据""" self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" stmt = select(TableMetadata).where( TableMetadata.created_by == user_id, TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) return (await self.session.execute(stmt)).scalar_one_or_none() async def update_table_qa_settings( self, user_id: int, metadata_id: int, settings: Dict[str, Any] ) -> bool: """更新表的问答设置""" self.session.desc = f"更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置" try: stmt = select(TableMetadata).where( TableMetadata.id == metadata_id, TableMetadata.created_by == user_id ) metadata = (await self.session.execute(stmt)).scalar_one_or_none() if not metadata: self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在" return False if 'is_enabled_for_qa' in settings: metadata.is_enabled_for_qa = settings['is_enabled_for_qa'] if 'qa_description' in settings: metadata.qa_description = settings['qa_description'] if 'business_context' in settings: metadata.business_context = settings['business_context'] await self.session.commit() return True except Exception as e: self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}" await self.session.rollback() return False async def save_table_metadata( self, user_id: int, database_config_id: int, table_name: str, columns_info: List[Dict[str, Any]], primary_keys: List[str], row_count: int, table_comment: str = '' ) -> TableMetadata: """保存单个表的元数据""" self.session.desc = f"保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" # 检查是否已存在 stmt = select(TableMetadata).where( TableMetadata.created_by == user_id, TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 已存在,更新其元数据" # 更新现有记录 existing.columns_info = columns_info existing.primary_keys = primary_keys existing.row_count = row_count existing.table_comment = table_comment existing.last_synced_at = datetime.utcnow() await self.session.commit() return existing else: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 不存在,创建新记录" # 创建新记录 metadata = TableMetadata( created_by=user_id, database_config_id=database_config_id, table_name=table_name, table_schema='public', table_type='BASE TABLE', table_comment=table_comment, columns_info=columns_info, primary_keys=primary_keys, row_count=row_count, is_enabled_for_qa=True, last_synced_at=datetime.utcnow() ) self.session.add(metadata) await self.session.commit() await self.session.refresh(metadata) return metadata def _decrypt_password(self, encrypted_password: str) -> str: """解密密码(需要实现加密逻辑)""" # 这里需要实现与DatabaseConfigService相同的解密逻辑 # 暂时返回原始密码 return encrypted_password