hyf-backend/th_agenter/services/table_metadata_service.py

455 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""表元数据管理服务"""
import json
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from sqlalchemy import select, func
from datetime import datetime
from ..models.table_metadata import TableMetadata
from ..models.database_config import DatabaseConfig
from utils.util_exceptions import ValidationError, NotFoundError
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
from loguru import logger
class TableMetadataService:
"""表元数据管理服务"""
def __init__(self, db_session: Session):
    """Bind the service to a database session and the per-backend SQL tools.

    Args:
        db_session: Session used for every metadata query and commit. The
            other methods ``await`` its calls, so it is presumably an async
            session despite the ``Session`` annotation — TODO confirm.
    """
    self.session = db_session
    # One tool handle per supported backend, resolved in a fixed order.
    tool_factories = (
        ("postgresql_tool", get_postgresql_tool),
        ("mysql_tool", get_mysql_tool),
    )
    for attribute, factory in tool_factories:
        setattr(self, attribute, factory())
async def collect_and_save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_names: List[str]
) -> Dict[str, Any]:
    """Collect metadata for the given tables and persist one row per table.

    Looks up the caller's database configuration, selects the PostgreSQL or
    MySQL tool, connects that tool for this user when it holds no connection
    yet, and then collects and saves every table independently so that one
    failing table does not abort the others.

    Args:
        user_id: Owner of the database configuration; its string form is the
            key looked up in the tool's ``connections`` mapping.
        database_config_id: Primary key of the ``DatabaseConfig`` row.
        table_names: Names of the tables to inspect.

    Returns:
        When the loop is reached: a dict with ``success=True``,
        ``collected_tables``, ``failed_tables``, ``total_collected`` and
        ``total_failed``. Any error outside the per-table loop (missing
        configuration, unsupported database type, failed connection) is
        reported as ``{'success': False, 'message': ...}`` instead of being
        raised.
    """
    self.session.desc = f"为用户 {user_id} 收集数据库 {database_config_id} 的表元数据"
    try:
        # Load the configuration, restricted to rows created by this user.
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == database_config_id,
            DatabaseConfig.created_by == user_id
        )
        db_config = (await self.session.execute(stmt)).scalar_one_or_none()
        if not db_config:
            self.session.desc = "ERROR: 数据库配置不存在"
            raise NotFoundError("数据库配置不存在")

        # Read db_type once. The per-table commit/rollback below can expire
        # the ORM instance, and the loop must not depend on reloading it.
        db_type = db_config.db_type
        normalized_type = db_type.lower()

        # Pick the tool that matches the configured backend.
        if normalized_type == 'postgresql':
            db_tool = self.postgresql_tool
        elif normalized_type == 'mysql':
            db_tool = self.mysql_tool
        else:
            self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
            raise Exception(f"不支持的数据库类型: {db_type}")

        # Reuse the user's existing connection, or open a new one.
        user_id_str = str(user_id)
        if user_id_str not in db_tool.connections:
            connection_config = {
                'host': db_config.host,
                'port': db_config.port,
                'database': db_config.database,
                'username': db_config.username,
                'password': self._decrypt_password(db_config.password)
            }
            connect_result = await db_tool.execute(
                operation="connect",
                connection_config=connection_config,
                user_id=user_id_str
            )
            if not connect_result.success:
                self.session.desc = f"ERROR: 数据库连接失败: {connect_result.error}"
                raise Exception(f"数据库连接失败: {connect_result.error}")
            self.session.desc = f"SUCCESS: 为用户 {user_id} 建立了新的{db_type}数据库连接"
        else:
            self.session.desc = f"SUCCESS: 复用用户 {user_id} 的现有{db_type}数据库连接"

        collected_tables = []
        failed_tables = []
        for table_name in table_names:
            try:
                # Inspect the remote table.
                metadata = await self._collect_single_table_metadata(
                    user_id, table_name, db_type
                )
                # Insert or update the stored metadata row.
                table_metadata = await self._save_table_metadata(
                    user_id, database_config_id, table_name, metadata
                )
                collected_tables.append({
                    'table_name': table_name,
                    'metadata_id': table_metadata.id,
                    'columns_count': len(metadata['columns_info']),
                    'sample_rows': len(metadata['sample_data'])
                })
            except Exception as e:
                # A failed flush/commit leaves the session unusable until it
                # is rolled back; without this, every later table in the
                # same call would fail as well.
                await self.session.rollback()
                self.session.desc = f"ERROR: 收集表 {table_name} 元数据失败: {str(e)}"
                failed_tables.append({
                    'table_name': table_name,
                    'error': str(e)
                })
        return {
            'success': True,
            'collected_tables': collected_tables,
            'failed_tables': failed_tables,
            'total_collected': len(collected_tables),
            'total_failed': len(failed_tables)
        }
    except Exception as e:
        self.session.desc = f"ERROR: 收集表元数据失败: {str(e)}"
        return {
            'success': False,
            'message': str(e)
        }
async def _collect_single_table_metadata(
self,
user_id: int,
table_name: str,
db_type: str
) -> Dict[str, Any]:
"""收集单个表的元数据"""
self.session.desc = f"为用户 {user_id} 收集表 {table_name} 的元数据"
# 根据数据库类型选择相应的工具
if db_type.lower() == 'postgresql':
db_tool = self.postgresql_tool
elif db_type.lower() == 'mysql':
db_tool = self.mysql_tool
else:
self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
raise Exception(f"不支持的数据库类型: {db_type}")
# 获取表结构
schema_result = await db_tool.execute(
operation="describe_table",
user_id=str(user_id),
table_name=table_name
)
if not schema_result.success:
self.session.desc = f"ERROR: 获取表 {table_name} 结构失败: {schema_result.error}"
raise Exception(f"获取表结构失败: {schema_result.error}")
schema_data = schema_result.result
# 获取示例数据前5条
sample_result = await db_tool.execute(
operation="execute_query",
user_id=str(user_id),
sql_query=f"SELECT * FROM {table_name} LIMIT 5",
limit=5
)
sample_data = []
if sample_result.success:
sample_data = sample_result.result.get('data', [])
# 获取行数统计
count_result = await db_tool.execute(
operation="execute_query",
user_id=str(user_id),
sql_query=f"SELECT COUNT(*) as total_rows FROM {table_name}",
limit=1
)
row_count = 0
if count_result.success and count_result.result.get('data'):
row_count = count_result.result['data'][0].get('total_rows', 0)
self.session.desc = f"SUCCESS: 为用户 {user_id} 收集表 {table_name} 的元数据, 包含 {len(schema_data.get('columns', []))} 列, {row_count} 行数据"
return {
'columns_info': schema_data.get('columns', []),
'primary_keys': schema_data.get('primary_keys', []),
'foreign_keys': schema_data.get('foreign_keys', []),
'indexes': schema_data.get('indexes', []),
'sample_data': sample_data,
'row_count': row_count,
'table_comment': schema_data.get('table_comment', '')
}
async def _save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str,
    metadata: Dict[str, Any]
) -> TableMetadata:
    """Insert or update the stored metadata row for one table.

    The row is identified by (creator, database config, table name). An
    existing row has its collected fields overwritten; otherwise a new row
    is created with schema ``'public'`` and type ``'BASE TABLE'``. Both
    paths commit and refresh the instance before returning it.

    Args:
        user_id: Creator of the metadata row.
        database_config_id: Database configuration the table belongs to.
        table_name: Name of the table.
        metadata: Collected values; must contain ``columns_info``,
            ``primary_keys``, ``foreign_keys``, ``indexes``, ``sample_data``,
            ``row_count`` and ``table_comment``.

    Returns:
        The persisted ``TableMetadata`` instance.
    """
    self.session.desc = f"为用户 {user_id} 保存表 {table_name} 的元数据"
    # Look for a row already stored for this user / config / table.
    lookup = select(TableMetadata).where(
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name
    )
    record = (await self.session.execute(lookup)).scalar_one_or_none()

    if not record:
        self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        # Pull the collected values in the order the constructor receives them.
        collected = {
            key: metadata[key]
            for key in (
                'table_comment',
                'columns_info',
                'primary_keys',
                'foreign_keys',
                'indexes',
                'sample_data',
                'row_count',
            )
        }
        record = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            **collected,
            is_enabled_for_qa=True,
            last_synced_at=datetime.utcnow()
        )
        self.session.add(record)
        await self.session.commit()
        await self.session.refresh(record)
        self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
        return record

    self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
    # Overwrite the collected fields on the existing row, one by one.
    for key in (
        'columns_info',
        'primary_keys',
        'foreign_keys',
        'indexes',
        'sample_data',
        'row_count',
        'table_comment',
    ):
        setattr(record, key, metadata[key])
    record.last_synced_at = datetime.utcnow()
    await self.session.commit()
    await self.session.refresh(record)
    return record
async def save_table_metadata_config(
    self,
    user_id: int,
    database_config_id: int,
    table_names: List[str]
) -> Dict[str, Any]:
    """Register tables for Q&A without collecting their structure.

    For each table name, an existing metadata row is re-enabled for Q&A and
    its sync time refreshed; otherwise a placeholder row with empty
    column/sample data is created. All changes are committed once, after
    the loop.

    Args:
        user_id: Creator of the database configuration and metadata rows.
        database_config_id: Primary key of the ``DatabaseConfig`` row.
        table_names: Names of the tables to register.

    Returns:
        A dict with ``saved_tables`` (each entry carries ``table_name`` and
        an ``action`` of ``'updated'`` or ``'created'``), ``failed_tables``,
        ``total_saved`` and ``total_failed``.

    Raises:
        NotFoundError: If the configuration does not exist for this user.
    """
    self.session.desc = f"为用户 {user_id} 保存数据库配置 {database_config_id}{table_names} 的元数据配置"
    # Load the configuration. Ownership is stored in ``created_by``, the
    # column every other query in this service filters on.
    stmt = select(DatabaseConfig).where(
        DatabaseConfig.id == database_config_id,
        DatabaseConfig.created_by == user_id
    )
    db_config = (await self.session.execute(stmt)).scalar_one_or_none()
    if not db_config:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
        raise NotFoundError("数据库配置不存在")

    saved_tables = []
    failed_tables = []
    for table_name in table_names:
        try:
            # Look for a row already stored for this user / config / table.
            stmt = select(TableMetadata).where(
                TableMetadata.created_by == user_id,
                TableMetadata.database_config_id == database_config_id,
                TableMetadata.table_name == table_name
            )
            existing = (await self.session.execute(stmt)).scalar_one_or_none()
            if existing:
                self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置"
                # Re-enable the existing row.
                existing.is_enabled_for_qa = True
                existing.last_synced_at = datetime.utcnow()
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'updated'
                })
            else:
                self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置"
                # Create a placeholder row. Field names and value types
                # mirror the other create paths in this service:
                # ``columns_info`` and ``sample_data`` hold lists there.
                metadata = TableMetadata(
                    created_by=user_id,
                    database_config_id=database_config_id,
                    table_name=table_name,
                    table_schema='public',  # default value
                    table_type='table',  # default value
                    table_comment='',
                    columns_info=[],  # filled in later by the collect path
                    row_count=0,  # filled in later by the collect path
                    is_enabled_for_qa=True,
                    qa_description='',
                    business_context='',
                    sample_data=[],
                    last_synced_at=datetime.utcnow()
                )
                self.session.add(metadata)
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'created'
                })
        except Exception as e:
            self.session.desc = f"ERROR: 保存用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置失败: {str(e)}"
            failed_tables.append({
                'table_name': table_name,
                'error': str(e)
            })

    # Commit all pending inserts/updates in one transaction.
    await self.session.commit()
    self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id}{table_names} 的元数据配置"
    return {
        'saved_tables': saved_tables,
        'failed_tables': failed_tables,
        'total_saved': len(saved_tables),
        'total_failed': len(failed_tables)
    }
async def get_user_table_metadata(
    self,
    user_id: int,
    database_config_id: Optional[int] = None
) -> List[TableMetadata]:
    """List the user's table metadata rows that are enabled for Q&A.

    Args:
        user_id: Creator whose rows are returned.
        database_config_id: When given, restricts the result to that
            database configuration. When omitted (the declared default),
            rows from all of the user's configurations are returned.

    Returns:
        The matching ``TableMetadata`` rows; empty when nothing matches.
    """
    self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表元数据列表"
    stmt = select(TableMetadata).where(TableMetadata.created_by == user_id)
    # The parameter is optional by signature, so a missing id means
    # "no config filter" rather than an error.
    if database_config_id is not None:
        stmt = stmt.where(TableMetadata.database_config_id == database_config_id)
    # ``== True`` is intentional: it builds a SQL comparison expression.
    stmt = stmt.where(TableMetadata.is_enabled_for_qa == True)  # noqa: E712
    return (await self.session.scalars(stmt)).all()
async def get_table_metadata_by_name(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str
) -> Optional[TableMetadata]:
    """Fetch the metadata row for one table of one database configuration.

    Args:
        user_id: Creator of the metadata row.
        database_config_id: Database configuration the table belongs to.
        table_name: Name of the table.

    Returns:
        The matching ``TableMetadata`` row, or ``None`` when there is none.
    """
    self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
    # All three conditions must hold for the row to match.
    criteria = (
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name,
    )
    query = select(TableMetadata).where(*criteria)
    rows = await self.session.execute(query)
    return rows.scalar_one_or_none()
async def update_table_qa_settings(
    self,
    user_id: int,
    metadata_id: int,
    settings: Dict[str, Any]
) -> bool:
    """Apply Q&A-related settings to one metadata row.

    Only the keys ``is_enabled_for_qa``, ``qa_description`` and
    ``business_context`` are read from ``settings``; other keys are ignored.

    Args:
        user_id: Creator of the metadata row.
        metadata_id: Primary key of the ``TableMetadata`` row.
        settings: New values, keyed by setting name.

    Returns:
        ``True`` after a successful commit. ``False`` when the row does not
        exist for this user, or when any error occurred (the session is
        rolled back in that case).
    """
    self.session.desc = f"更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置"
    try:
        lookup = select(TableMetadata).where(
            TableMetadata.id == metadata_id,
            TableMetadata.created_by == user_id
        )
        record = (await self.session.execute(lookup)).scalar_one_or_none()
        if not record:
            self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在"
            return False

        # Copy over only the settings the caller actually supplied.
        for field in ('is_enabled_for_qa', 'qa_description', 'business_context'):
            if field in settings:
                setattr(record, field, settings[field])

        await self.session.commit()
        return True
    except Exception as exc:
        self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(exc)}"
        await self.session.rollback()
        return False
async def save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str,
    columns_info: List[Dict[str, Any]],
    primary_keys: List[str],
    row_count: int,
    table_comment: str = ''
) -> TableMetadata:
    """Insert or update the metadata row for one table from explicit values.

    The row is identified by (creator, database config, table name). Both
    the update and the create path commit and then refresh the instance, so
    the returned object carries loaded attribute values.

    Args:
        user_id: Creator of the metadata row.
        database_config_id: Database configuration the table belongs to.
        table_name: Name of the table.
        columns_info: Column descriptions to store.
        primary_keys: Primary-key column names to store.
        row_count: Number of rows to store.
        table_comment: Table comment to store; defaults to an empty string.

    Returns:
        The persisted ``TableMetadata`` instance.
    """
    self.session.desc = f"保存用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
    # Look for a row already stored for this user / config / table.
    stmt = select(TableMetadata).where(
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name
    )
    existing = (await self.session.execute(stmt)).scalar_one_or_none()
    if existing:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 已存在,更新其元数据"
        # Overwrite the stored values on the existing row.
        existing.columns_info = columns_info
        existing.primary_keys = primary_keys
        existing.row_count = row_count
        existing.table_comment = table_comment
        existing.last_synced_at = datetime.utcnow()
        await self.session.commit()
        # Refresh after commit, matching the create path below and
        # _save_table_metadata, so callers get a loaded instance back.
        await self.session.refresh(existing)
        return existing
    else:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 不存在,创建新记录"
        # Create a new row.
        metadata = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            table_comment=table_comment,
            columns_info=columns_info,
            primary_keys=primary_keys,
            row_count=row_count,
            is_enabled_for_qa=True,
            last_synced_at=datetime.utcnow()
        )
        self.session.add(metadata)
        await self.session.commit()
        await self.session.refresh(metadata)
        return metadata
def _decrypt_password(self, encrypted_password: str) -> str:
"""解密密码(需要实现加密逻辑)"""
# 这里需要实现与DatabaseConfigService相同的解密逻辑
# 暂时返回原始密码
return encrypted_password