hyf-backend/th_agenter/services/smart_db_workflow.py

880 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Dict, Any, List, Optional
import logging
from datetime import datetime
import asyncio
from concurrent.futures import ThreadPoolExecutor
from langchain_openai import ChatOpenAI
from th_agenter.core.context import UserContext
from .smart_query import DatabaseQueryService
from .postgresql_tool_manager import get_postgresql_tool
from .mysql_tool_manager import get_mysql_tool
from .table_metadata_service import TableMetadataService
from ..core.config import get_settings
# 配置日志
logger = logging.getLogger(__name__)
class SmartWorkflowError(Exception):
"""智能工作流自定义异常"""
pass
class DatabaseConnectionError(SmartWorkflowError):
"""数据库连接异常"""
pass
class TableSchemaError(SmartWorkflowError):
"""表结构获取异常"""
pass
class SQLGenerationError(SmartWorkflowError):
"""SQL生成异常"""
pass
class QueryExecutionError(SmartWorkflowError):
"""查询执行异常"""
pass
class SmartDatabaseWorkflowManager:
"""
智能数据库工作流管理器
负责协调数据库连接、表元数据获取、SQL生成、查询执行和AI总结的完整流程
"""
def __init__(self, db=None):
self.executor = ThreadPoolExecutor(max_workers=4)
self.database_service = DatabaseQueryService()
self.postgresql_tool = get_postgresql_tool()
self.mysql_tool = get_mysql_tool()
self.db = db
self.table_metadata_service = TableMetadataService(db) if db else None
async def initialize(self):
from ..core.new_agent import new_agent
self.llm = await new_agent()
def _get_database_tool(self, db_type: str):
"""根据数据库类型获取对应的数据库工具"""
if db_type.lower() == 'postgresql':
return self.postgresql_tool
elif db_type.lower() == 'mysql':
return self.mysql_tool
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
async def _run_in_executor(self, func, *args):
"""在线程池中运行阻塞函数"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, func, *args)
def _convert_query_result_to_table_data(self, query_result: Dict[str, Any]) -> Dict[str, Any]:
"""
将数据库查询结果转换为前端表格数据格式
参考Excel处理方式以表格形式返回结果
"""
try:
data = query_result.get('data', [])
columns = query_result.get('columns', [])
row_count = query_result.get('row_count', 0)
if not data or not columns:
return {
'result_type': 'table',
'columns': [],
'data': [],
'total': 0,
'message': '查询未返回数据'
}
# 构建列定义
table_columns = []
for i, col_name in enumerate(columns):
table_columns.append({
'prop': f'col_{i}',
'label': str(col_name),
'width': 'auto'
})
# 转换数据行
table_data = []
for row_index, row in enumerate(data):
row_data = {'_index': str(row_index)}
# 处理字典格式的行数据
if isinstance(row, dict):
for i, col_name in enumerate(columns):
col_prop = f'col_{i}'
value = row.get(col_name)
# 处理None值和特殊值
if value is None:
row_data[col_prop] = ''
elif isinstance(value, (int, float, str, bool)):
row_data[col_prop] = str(value)
else:
row_data[col_prop] = str(value)
else:
# 处理列表格式的行数据(兼容性处理)
for i, value in enumerate(row):
col_prop = f'col_{i}'
# 处理None值和特殊值
if value is None:
row_data[col_prop] = ''
elif isinstance(value, (int, float, str, bool)):
row_data[col_prop] = str(value)
else:
row_data[col_prop] = str(value)
table_data.append(row_data)
return {
'result_type': 'table_data',
'columns': table_columns,
'data': table_data,
'total': row_count,
'message': f'查询成功,共返回 {row_count} 条记录'
}
except Exception as e:
logger.error(f"转换查询结果异常: {str(e)}")
return {
'result_type': 'error',
'columns': [],
'data': [],
'total': 0,
'message': f'结果转换失败: {str(e)}'
}
async def process_database_query_stream(
self,
user_query: str,
user_id: int,
database_config_id: int
):
"""
流式处理数据库智能问数查询的主要工作流(基于保存的表元数据)
实时推送每个工作流步骤
新流程:
1. 根据database_config_id获取数据库配置并创建连接
2. 从系统数据库读取表元数据(只包含启用问答的表)
3. 根据表元数据生成SQL
4. 执行SQL查询
5. 查询数据后处理成表格形式
6. 生成数据总结
7. 返回结果
Args:
user_query: 用户问题
user_id: 用户ID
database_config_id: 数据库配置ID
Yields:
包含工作流步骤或最终结果的字典
"""
workflow_steps = []
try:
logger.info(f"开始执行流式数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")
# 步骤1: 根据database_config_id获取数据库配置并创建连接
try:
step_data = {
'type': 'workflow_step',
'step': 'database_connection',
'status': 'running',
'message': '正在建立数据库连接...',
'timestamp': datetime.now().isoformat()
}
yield step_data
# 获取数据库配置并建立连接
connection_result = await self._connect_database(user_id, database_config_id)
if not connection_result['success']:
raise DatabaseConnectionError(connection_result['message'])
step_data.update({
'status': 'completed',
'message': '数据库连接成功',
'details': {'database': connection_result.get('database_name', 'Unknown')}
})
yield step_data
workflow_steps.append({
'step': 'database_connection',
'status': 'completed',
'message': '数据库连接成功'
})
except Exception as e:
error_msg = f'数据库连接失败: {str(e)}'
step_data = {
'type': 'workflow_step',
'step': 'database_connection',
'status': 'failed',
'message': error_msg,
'timestamp': datetime.now().isoformat()
}
yield step_data
yield {
'type': 'error',
'message': error_msg,
'workflow_steps': workflow_steps
}
return
# 步骤2: 从系统数据库读取表元数据(只包含启用问答的表)
try:
step_data = {
'type': 'workflow_step',
'step': 'table_metadata',
'status': 'running',
'message': '正在从系统数据库读取表元数据...',
'timestamp': datetime.now().isoformat()
}
yield step_data
# 从系统数据库读取已保存的表元数据(只包含启用问答的表)
tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)
step_data.update({
'status': 'completed',
'message': f'成功读取 {len(tables_info)} 个启用问答的表元数据',
'details': {'table_count': len(tables_info), 'tables': list(tables_info.keys())}
})
yield step_data
workflow_steps.append({
'step': 'table_metadata',
'status': 'completed',
'message': f'成功读取表元数据'
})
except Exception as e:
error_msg = f'获取表元数据失败: {str(e)}'
step_data = {
'type': 'workflow_step',
'step': 'table_metadata',
'status': 'failed',
'message': error_msg,
'timestamp': datetime.now().isoformat()
}
yield step_data
yield {
'type': 'error',
'message': error_msg,
'workflow_steps': workflow_steps
}
return
# 步骤3: 根据表元数据生成SQL
try:
step_data = {
'type': 'workflow_step',
'step': 'sql_generation',
'status': 'running',
'message': '正在根据表元数据生成SQL查询...',
'timestamp': datetime.now().isoformat()
}
yield step_data
# 根据表元数据选择相关表并生成SQL
target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
step_data = {
'type': 'workflow_step',
'step': 'table_selected',
'status': 'completed',
'message': f'已经智能选择了相关表: {", ".join(target_tables)}',
'timestamp': datetime.now().isoformat()
}
yield step_data
workflow_steps.append({
'step': 'table_metadata',
'status': 'completed',
'message': f'已经智能选择了相关表: {", ".join(target_tables)}',
})
sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)
step_data.update({
'status': 'completed',
'message': 'SQL查询生成成功',
'details': {
'target_tables': target_tables,
'generated_sql': sql_query[:100] + '...' if len(sql_query) > 100 else sql_query
}
})
yield step_data
workflow_steps.append({
'step': 'sql_generation',
'status': 'completed',
'message': 'SQL语句生成成功'
})
except Exception as e:
error_msg = f'SQL生成失败: {str(e)}'
step_data = {
'type': 'workflow_step',
'step': 'sql_generation',
'status': 'failed',
'message': error_msg,
'timestamp': datetime.now().isoformat()
}
yield step_data
yield {
'type': 'error',
'message': error_msg,
'workflow_steps': workflow_steps
}
return
# 步骤4: 执行SQL查询
try:
step_data = {
'type': 'workflow_step',
'step': 'query_execution',
'status': 'running',
'message': '正在执行SQL查询...',
'timestamp': datetime.now().isoformat()
}
yield step_data
query_result = await self._execute_database_query(user_id, sql_query, database_config_id)
step_data.update({
'status': 'completed',
'message': f'查询执行成功,返回 {query_result.get("row_count", 0)} 条记录',
'details': {'row_count': query_result.get('row_count', 0)}
})
yield step_data
workflow_steps.append({
'step': 'query_execution',
'status': 'completed',
'message': '查询执行成功'
})
except Exception as e:
error_msg = f'查询执行失败: {str(e)}'
step_data = {
'type': 'workflow_step',
'step': 'query_execution',
'status': 'failed',
'message': error_msg,
'timestamp': datetime.now().isoformat()
}
yield step_data
yield {
'type': 'error',
'message': error_msg,
'workflow_steps': workflow_steps
}
return
# 步骤5: 查询数据后处理成表格形式在步骤6中完成
# 步骤6: 生成数据总结
try:
step_data = {
'type': 'workflow_step',
'step': 'ai_summary',
'status': 'running',
'message': '正在生成查询结果总结...',
'timestamp': datetime.now().isoformat()
}
yield step_data
summary = await self._generate_database_summary(user_query, query_result, ', '.join(target_tables))
step_data.update({
'status': 'completed',
'message': '总结生成完成',
'details': {
'tables_analyzed': target_tables,
'summary_length': len(summary)
}
})
yield step_data
workflow_steps.append({
'step': 'ai_summary',
'status': 'completed',
'message': '总结生成完成'
})
except Exception as e:
logger.warning(f'生成总结失败: {str(e)}')
summary = '查询执行完成,但生成总结时出现问题。'
workflow_steps.append({
'step': 'ai_summary',
'status': 'warning',
'message': '总结生成失败,但查询成功'
})
# 步骤7: 返回最终结果且结果参考excel的处理方式尽量以表格形式返回
try:
step_data = {
'type': 'workflow_step',
'step': 'result_formatting',
'status': 'running',
'message': '正在格式化查询结果...',
'timestamp': datetime.now().isoformat()
}
yield step_data
# 转换为表格格式
table_data = self._convert_query_result_to_table_data(query_result)
step_data.update({
'status': 'completed',
'message': '结果格式化完成'
})
yield step_data
workflow_steps.append({
'step': 'result_formatting',
'status': 'completed',
'message': '结果格式化完成'
})
# 返回最终结果
final_result = {
'type': 'final_result',
'success': True,
'data': {
**table_data,
'generated_sql': sql_query,
'summary': summary,
'table_name': target_tables,
'query_result': query_result,
'metadata_source': 'saved_database' # 标记元数据来源
},
'workflow_steps': workflow_steps,
'timestamp': datetime.now().isoformat()
}
yield final_result
logger.info(f"数据库查询工作流完成 - 用户ID: {user_id}")
except Exception as e:
error_msg = f'结果格式化失败: {str(e)}'
yield {
'type': 'error',
'message': error_msg,
'workflow_steps': workflow_steps
}
return
except Exception as e:
logger.error(f"数据库查询工作流异常: {str(e)}", exc_info=True)
yield {
'type': 'error',
'message': f'系统异常: {str(e)}',
'workflow_steps': workflow_steps
}
async def _connect_database(self, user_id: int, database_config_id: int) -> Dict[str, Any]:
"""连接数据库(判断用户现有连接)"""
try:
# 获取数据库配置
from ..services.database_config_service import DatabaseConfigService
config_service = DatabaseConfigService(self.db)
config = config_service.get_config_by_id(database_config_id, user_id)
if not config:
return {'success': False, 'message': '数据库配置不存在'}
# 根据数据库类型选择对应的工具
try:
db_tool = self._get_database_tool(config.db_type)
except ValueError as e:
return {'success': False, 'message': str(e)}
# 测试连接(如果已经有连接则直接复用)
connection_config = {
'host': config.host,
'port': config.port,
'database': config.database,
'username': config.username,
'password': config_service._decrypt_password(config.password)
}
try:
connection = db_tool._test_connection(connection_config)
if connection['success'] == True:
return {
'success': True,
'database_name': config.database,
'db_type': config.db_type,
'message': '连接成功'
}
else:
return {
'success': False,
'database_name': config.database,
'db_type': config.db_type,
'message': '连接失败'
}
except Exception as e:
return {
'success': False,
'message': f'连接失败: {str(e)}'
}
except Exception as e:
logger.error(f"数据库连接异常: {str(e)}")
return {'success': False, 'message': f'连接异常: {str(e)}'}
async def _get_saved_tables_metadata(self, user_id: int, database_config_id: int) -> Dict[str, Dict[str, Any]]:
"""从系统数据库中读取已保存的表元数据"""
try:
if not self.table_metadata_service:
raise TableSchemaError("表元数据服务未初始化")
# 从数据库中获取表元数据
saved_metadata = self.table_metadata_service.get_user_table_metadata(
user_id, database_config_id
)
if not saved_metadata:
raise TableSchemaError(f"未找到数据库配置ID {database_config_id} 的表元数据,请先在数据库管理页面收集表元数据")
# 转换为所需格式
tables_metadata = {}
for meta in saved_metadata:
# 只处理启用问答的表
if meta.is_enabled_for_qa:
tables_metadata[meta.table_name] = {
'table_name': meta.table_name,
'columns': meta.columns_info or [],
'primary_keys': meta.primary_keys or [],
'row_count': meta.row_count or 0,
'table_comment': meta.table_comment or '',
'qa_description': meta.qa_description or '',
'business_context': meta.business_context or '',
'from_saved_metadata': True # 标记来源
}
if not tables_metadata:
raise TableSchemaError("没有启用问答的表,请在数据库管理页面启用相关表的问答功能")
logger.info(f"从系统数据库读取表元数据成功,共 {len(tables_metadata)} 个启用问答的表")
return tables_metadata
except Exception as e:
logger.error(f"读取保存的表元数据异常: {str(e)}")
raise TableSchemaError(f'读取表元数据失败: {str(e)}')
async def _get_table_schema(self, user_id: int, table_name: str, database_config_id: int) -> Dict[str, Any]:
"""获取指定表结构"""
try:
# 获取数据库配置
from ..services.database_config_service import DatabaseConfigService
config_service = DatabaseConfigService(self.db)
config = config_service.get_config_by_id(database_config_id, user_id)
if not config:
raise TableSchemaError('数据库配置不存在')
# 根据数据库类型选择对应的工具
try:
db_tool = self._get_database_tool(config.db_type)
except ValueError as e:
raise TableSchemaError(str(e))
# 使用对应的数据库工具获取表结构
schema_result = await db_tool.describe_table(table_name)
if schema_result.get('success'):
return schema_result.get('schema', {})
else:
raise TableSchemaError(schema_result.get('error', '获取表结构失败'))
except Exception as e:
logger.error(f"获取表结构异常: {str(e)}")
raise TableSchemaError(f'获取表结构失败: {str(e)}')
async def _select_target_table(self, user_query: str, tables_info: Dict[str, Dict]) -> tuple[List[str], List[Dict]]:
"""根据用户查询选择相关的表,支持返回多个表"""
try:
if len(tables_info) == 1:
# 只有一个表,直接返回
table_name = list(tables_info.keys())[0]
return [table_name], [tables_info[table_name]]
# 多个表时使用LLM选择相关的表
tables_summary = []
for table_name, schema in tables_info.items():
columns = schema.get('columns', [])
column_names = [col.get('column_name', col.get('name', '')) for col in columns]
qa_desc = schema.get('qa_description', '')
business_ctx = schema.get('business_context', '')
tables_summary.append(f"表名: {table_name}\n字段: {', '.join(column_names[:10])}\n表描述: {qa_desc}\n业务上下文: {business_ctx}")
prompt = f"""
用户查询: {user_query}
可用的表:
{chr(10).join(tables_summary)}
请根据用户查询选择相关的表,可以选择多个表。分析表之间可能的关联关系,返回所有相关的表名,用逗号分隔。
可以通过qa_description表描述business_context(表的业务上下文以及column_names几个字段判断要使用哪些表。
注意:只返回表名列表,后面不要跟其他的内容。
例如直接输出: table1,table2,table3
"""
response = await self.llm.ainvoke(prompt)
selected_tables = [t.strip() for t in response.content.strip().split(',')]
# 验证选择的表是否存在
valid_tables = []
valid_schemas = []
for table in selected_tables:
if table in tables_info:
valid_tables.append(table)
valid_schemas.append(tables_info[table])
else:
logger.warning(f"LLM选择的表 {table} 不存在")
if valid_tables:
return valid_tables, valid_schemas
else:
# 如果没有有效的表,选择第一个表
table_name = list(tables_info.keys())[0]
logger.warning(f"没有找到有效的表,使用默认表 {table_name}")
return [table_name], [tables_info[table_name]]
except Exception as e:
logger.error(f"选择目标表异常: {str(e)}")
# 出现异常时选择第一个表
table_name = list(tables_info.keys())[0]
return [table_name], [tables_info[table_name]]
async def _generate_sql_query(self, user_query: str, table_names: List[str], table_schemas: List[Dict]) -> str:
"""生成SQL语句支持多表关联查询"""
try:
# 构建所有表的结构信息
tables_info = []
for table_name, schema in zip(table_names, table_schemas):
columns_info = []
for col in schema.get('columns', []):
col_info = f"{col['column_name']} ({col['data_type']})"
columns_info.append(col_info)
table_info = f"表名: {table_name}\n"
table_info += f"表描述: {schema.get('qa_description', '')}\n"
table_info += f"业务上下文: {schema.get('business_context', '')}\n"
table_info += "字段信息:\n" + "\n".join(columns_info)
tables_info.append(table_info)
schema_text = "\n\n".join(tables_info)
prompt = f"""
基于以下表结构将自然语言查询转换为SQL语句。如果需要关联多个表请分析表之间的关系使用合适的JOIN语法
{schema_text}
用户查询: {user_query}
请生成对应的SQL查询语句要求
1. 只返回SQL语句不要包含其他解释
2. 如果查询涉及多个表,需要正确处理表之间的关联关系
3. 使用合适的JOIN类型INNER JOIN、LEFT JOIN等
4. 确保SELECT的字段来源明确必要时使用表名前缀
"""
# 使用LLM生成SQL
response = await self.llm.ainvoke(prompt)
sql_query = response.content.strip()
# 清理SQL语句
if sql_query.startswith('```sql'):
sql_query = sql_query[6:]
if sql_query.endswith('```'):
sql_query = sql_query[:-3]
sql_query = sql_query.strip()
logger.info(f"生成的SQL查询: {sql_query}")
return sql_query
except Exception as e:
logger.error(f"SQL生成异常: {str(e)}")
raise SQLGenerationError(f'SQL生成失败: {str(e)}')
async def _execute_database_query(self, user_id: int, sql_query: str, database_config_id: int) -> Dict[str, Any]:
"""执行SQL语句"""
try:
# 获取数据库配置
from ..services.database_config_service import DatabaseConfigService
config_service = DatabaseConfigService(self.db)
config = config_service.get_config_by_id(database_config_id, user_id)
if not config:
raise QueryExecutionError('数据库配置不存在')
# 根据数据库类型选择对应的工具
try:
db_tool = self._get_database_tool(config.db_type)
except ValueError as e:
raise QueryExecutionError(str(e))
# 使用对应的数据库工具执行查询
if str(user_id) in db_tool.connections:
query_result = db_tool._execute_query(db_tool.connections[str(user_id)]['connection'], sql_query)
else:
raise QueryExecutionError('请重新进行数据库连接')
if query_result.get('success'):
data = query_result.get('data', [])
return {
'success': True,
'data': data,
'row_count': len(data),
'columns': query_result.get('columns', []),
'sql_query': sql_query
}
else:
raise QueryExecutionError(query_result.get('error', '查询执行失败'))
except Exception as e:
logger.error(f"查询执行异常: {str(e)}")
raise QueryExecutionError(f'查询执行失败: {str(e)}')
async def _generate_database_summary(self, user_query: str, query_result: Dict, tables_str: str) -> str:
"""生成AI总结支持多表查询结果"""
try:
data = query_result.get('data', [])
row_count = query_result.get('row_count', 0)
columns = query_result.get('columns', [])
sql_query = query_result.get('sql_query', '')
# 构建总结提示词
prompt = f"""
用户查询: {user_query}
涉及的表: {tables_str}
查询结果: 共 {row_count} 条记录
查询的字段: {', '.join(columns)}
执行的SQL: {sql_query}
前几条数据示例:
{str(data[:3]) if data else '无数据'}
请基于以上信息,用中文生成一个简洁的查询结果总结,包括:
1. 查询涉及的表及其关系
2. 查询的主要发现和数据特征
3. 如果有关联查询,说明关联的结果特点
4. 最后对用户的问题进行回答
总结要求:
1. 语言简洁明了
2. 重点突出查询结果
3. 如果是多表查询,需要说明表之间的关系
4. 总结不超过300字
"""
# 使用LLM生成总结
response = await self.llm.ainvoke(prompt)
summary = response.content.strip()
logger.info(f"生成的总结: {summary[:100]}...")
return summary
except Exception as e:
logger.error(f"总结生成异常: {str(e)}")
return f"查询完成,共返回 {query_result.get('row_count', 0)} 条记录。涉及的表: {tables_str}"
async def process_database_query(
self,
user_query: str,
user_id: int,
database_config_id: int,
table_name: Optional[str] = None,
conversation_id: Optional[int] = None,
is_new_conversation: bool = False
) -> Dict[str, Any]:
"""
处理数据库智能问数查询的主要工作流(基于保存的表元数据)
新流程:
1. 根据database_config_id获取数据库配置
2. 创建数据库连接
3. 从系统数据库读取表元数据(只包含启用问答的表)
4. 根据表元数据生成SQL
5. 执行SQL查询
6. 查询数据后处理成表格形式
7. 生成数据总结
8. 返回结果
Args:
user_query: 用户问题
user_id: 用户ID
database_config_id: 数据库配置ID
table_name: 表名(可选)
conversation_id: 对话ID
is_new_conversation: 是否为新对话
Returns:
包含查询结果的字典
"""
try:
logger.info(f"开始执行数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")
# 步骤1: 根据database_config_id获取数据库配置并创建连接
connection_result = await self._connect_database(user_id, database_config_id)
if not connection_result['success']:
raise DatabaseConnectionError(connection_result['message'])
logger.info("数据库连接成功")
# 步骤2: 从系统数据库读取表元数据(只包含启用问答的表)
tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)
logger.info(f"表元数据读取完成 - 共{len(tables_info)}个启用问答的表")
# 步骤3: 根据表元数据选择相关表并生成SQL
target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)
logger.info(f"SQL生成完成 - 目标表: {', '.join(target_tables)}")
# 步骤4: 执行SQL查询
query_result = await self._execute_database_query(user_id, sql_query, database_config_id)
logger.info("查询执行完成")
# 步骤5: 查询数据后处理成表格形式
table_data = self._convert_query_result_to_table_data(query_result)
# 步骤6: 生成数据总结
summary = await self._generate_database_summary(user_query, query_result, ', '.join(target_tables))
# 步骤7: 返回结果
return {
'success': True,
'data': {
**table_data,
'generated_sql': sql_query,
'summary': summary,
'table_names': target_tables,
'query_result': query_result,
'metadata_source': 'saved_database' # 标记元数据来源
}
}
except SmartWorkflowError as e:
logger.error(f"数据库工作流异常: {str(e)}")
return {
'success': False,
'error': str(e),
'error_type': type(e).__name__
}
except Exception as e:
logger.error(f"数据库工作流未知异常: {str(e)}", exc_info=True)
return {
'success': False,
'error': f'系统异常: {str(e)}',
'error_type': 'SystemError'
}