# hyf-backend/th_agenter/services/smart_query.py
# (code-viewer header: 717 lines, 25 KiB, Python; last change 2026-01-21 13:45:39 +08:00)
import pandas as pd
import pymysql
import psycopg2
import tempfile
import os
from typing import Dict, Any, List
from datetime import datetime
import asyncio
from concurrent.futures import ThreadPoolExecutor
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from loguru import logger
# TableMetadataService is used by SmartQueryService.set_db_session.
from .table_metadata_service import TableMetadataService
class SmartQueryService:
    """Base class for the smart-query ("ask your data") services.

    Provides a small thread pool for running blocking calls from coroutines
    and an optional table-metadata service that is created once a database
    session has been supplied through ``set_db_session``.
    """

    def __init__(self):
        # Pool shared by every _run_in_executor call made on this instance.
        self.executor = ThreadPoolExecutor(max_workers=4)
        # Stays None until set_db_session() is called.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Attach a database session by wrapping it in a TableMetadataService."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def _run_in_executor(self, func, *args):
        """Run a blocking callable on this service's thread pool.

        Args:
            func: Callable to execute in a worker thread.
            *args: Positional arguments forwarded to ``func``.

        Returns:
            Whatever ``func(*args)`` returns; exceptions raised by ``func``
            propagate to the awaiting coroutine.
        """
        # This coroutine always runs inside a loop, so ask for the running
        # loop directly instead of going through the event-loop policy.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)
class ExcelAnalysisService(SmartQueryService):
    """Analyse uploaded Excel data held as pandas DataFrames.

    Offers a static profile of a DataFrame (``analyze_dataframe``) and a
    natural-language query entry point that loads the user's most recent
    pickled DataFrame from the system temp directory.
    """

    def __init__(self):
        super().__init__()
        # Storage for users' DataFrames; nothing in this class fills it.
        self.user_dataframes = {}

    @staticmethod
    def _stat_or_none(value):
        """Convert a pandas aggregate to ``float``, mapping NaN/NA to ``None``."""
        if pd.isna(value):
            return None
        return float(value)

    def analyze_dataframe(self, df: pd.DataFrame, filename: str) -> Dict[str, Any]:
        """Build a profile of ``df``: shape, per-column info, preview, quality issues.

        Args:
            df: DataFrame to inspect. It is not modified.
            filename: Name echoed back in the result.

        Returns:
            Dict with the keys ``filename``, ``rows``, ``columns``,
            ``column_names``, ``column_info``, ``preview``, ``quality_issues``
            and ``memory_usage``. Numeric columns additionally carry
            ``mean``/``std``/``min``/``max``; a statistic that cannot be
            computed (all-null column, ``std`` of a single row) is ``None``.

        Raises:
            Exception: Any failure, re-raised with the prefix "DataFrame分析失败".
        """
        try:
            rows, columns = df.shape

            column_info = []
            for col in df.columns:
                series = df[col]
                col_info = {
                    'name': col,
                    'dtype': str(series.dtype),
                    'null_count': int(series.isnull().sum()),
                    'unique_count': int(series.nunique()),
                }
                if pd.api.types.is_numeric_dtype(series):
                    # pandas aggregates skip nulls by default. The former
                    # `df.fillna({col: 0})` call discarded its result, so it
                    # never changed these numbers; it only copied the frame.
                    col_info.update({
                        'mean': self._stat_or_none(series.mean()),
                        'std': self._stat_or_none(series.std()),
                        'min': self._stat_or_none(series.min()),
                        'max': self._stat_or_none(series.max()),
                    })
                column_info.append(col_info)

            # Preview of the first five rows with nulls blanked out.
            preview_data = df.head().fillna('').to_dict('records')

            quality_issues = []

            missing_cols = df.columns[df.isnull().any()].tolist()
            if missing_cols:
                quality_issues.append({
                    'type': 'missing_values',
                    'description': f'以下列存在缺失值: {", ".join(map(str, missing_cols))}',
                    'columns': missing_cols,
                })

            duplicate_count = int(df.duplicated().sum())
            if duplicate_count > 0:
                quality_issues.append({
                    'type': 'duplicate_rows',
                    'description': f'发现 {duplicate_count} 行重复数据',
                    'count': duplicate_count,
                })

            memory_mb = df.memory_usage(deep=True).sum() / 1024 / 1024
            return {
                'filename': filename,
                'rows': rows,
                'columns': columns,
                'column_names': [str(col) for col in df.columns.tolist()],
                'column_info': column_info,
                'preview': preview_data,
                'quality_issues': quality_issues,
                'memory_usage': f"{memory_mb:.2f} MB",
            }
        except Exception as e:
            logger.error(f"DataFrame分析失败: {str(e)}")
            raise Exception(f"DataFrame分析失败: {str(e)}") from e

    def _create_pandas_agent(self, df: pd.DataFrame):
        """Return the pandas agent for ``df``; currently disabled, so always ``None``.

        The agent construction is switched off, and an error is logged on
        every call. No LLM client is built while the agent is disabled,
        because nothing would use it.
        """
        # TODO: re-enable once create_pandas_dataframe_agent is available again.
        # Intended configuration, kept for reference:
        #   llm = ChatZhipuAI(model="glm-4",
        #                     api_key=os.getenv("ZHIPUAI_API_KEY"),
        #                     temperature=0.1)
        #   create_pandas_dataframe_agent(
        #       llm=llm, df=df, verbose=True, return_intermediate_steps=True,
        #       handle_parsing_errors=True, max_iterations=3,
        #       early_stopping_method="force",
        #       allow_dangerous_code=True,  # lets the agent execute code
        #   )
        logger.error('创建pandas代理失败 - 暂屏蔽处理')
        return None

    def _execute_pandas_query(self, agent, query: str) -> Dict[str, Any]:
        """Run ``query`` through ``agent`` and normalise the answer.

        Args:
            agent: Object exposing ``invoke({"input": ...})``.
            query: Natural-language question.

        Returns:
            Dict with ``data`` (list of row dicts), ``columns``, ``total`` and
            ``result_type`` ("dataframe" or "scalar").

        Raises:
            Exception: Prefixed with "pandas查询执行失败" when the agent is
                missing or the invocation fails.
        """
        if agent is None:
            # Fail with a readable reason instead of an AttributeError on None.
            raise Exception("pandas查询执行失败: pandas代理不可用")
        try:
            agent_result = agent.invoke({"input": query})
            # Agents with several output keys return a dict; take 'output'
            # from it and use any other return value as-is.
            if isinstance(agent_result, dict):
                result = agent_result.get('output', agent_result)
            else:
                result = agent_result

            if isinstance(result, pd.DataFrame):
                return {
                    'data': result.fillna('').to_dict('records'),
                    'columns': result.columns.tolist(),
                    'total': len(result),
                    'result_type': 'dataframe',
                }
            # Anything else (string, number, ...) becomes a one-row result.
            return {
                'data': [{'result': str(result)}],
                'columns': ['result'],
                'total': 1,
                'result_type': 'scalar',
            }
        except Exception as e:
            raise Exception(f"pandas查询执行失败: {str(e)}") from e

    async def execute_natural_language_query(
        self,
        query: str,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """Answer a natural-language question about the user's latest upload.

        Args:
            query: Natural-language question.
            user_id: Selects the files named ``excel_<user_id>_*.pkl`` in the
                system temp directory.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            ``{'success': True, 'data': {...}}`` on success, otherwise
            ``{'success': False, 'message': ...}``. This method never raises.
        """
        try:
            # Guard against negative slice indices further down.
            page = max(page, 1)
            page_size = max(page_size, 1)

            temp_dir = tempfile.gettempdir()
            prefix = f"excel_{user_id}_"
            user_files = [
                name for name in os.listdir(temp_dir)
                if name.startswith(prefix) and name.endswith('.pkl')
            ]
            if not user_files:
                return {
                    'success': False,
                    'message': '未找到上传的Excel文件请先上传文件'
                }

            # "Latest" means last in lexicographic order of the file names.
            # NOTE(review): presumably the names embed a sortable timestamp;
            # confirm against the upload code.
            file_path = os.path.join(temp_dir, max(user_files))

            # SECURITY: unpickling executes arbitrary code, and this file sits
            # in the shared temp directory. Safe only if nothing but this
            # application can write matching files there.
            # Loaded in the pool so the event loop is not blocked.
            df = await self._run_in_executor(pd.read_pickle, file_path)

            agent = self._create_pandas_agent(df)
            if agent is None:
                return {
                    'success': False,
                    'message': '查询执行失败: pandas代理暂不可用'
                }

            query_result = await self._run_in_executor(
                self._execute_pandas_query, agent, query
            )

            start_idx = (page - 1) * page_size
            paginated_data = query_result['data'][start_idx:start_idx + page_size]

            summary = await self._generate_summary(query, query_result, df)
            return {
                'success': True,
                'data': {
                    'data': paginated_data,
                    'columns': query_result['columns'],
                    'total': query_result['total'],
                    'page': page,
                    'page_size': page_size,
                    'generated_code': f"# 基于自然语言查询: {query}\n# 使用LangChain Pandas代理执行",
                    'summary': summary,
                    'result_type': query_result['result_type']
                }
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"查询执行失败: {str(e)}"
            }

    async def _generate_summary(self, query: str, result: Dict[str, Any], df: pd.DataFrame) -> str:
        """Ask the LLM for a short Chinese summary of a query result.

        Args:
            query: The user's original question.
            result: Normalised result dict; ``result_type``, ``total`` and
                ``columns`` are read.
            df: Source DataFrame; only its size and column names are sent.

        Returns:
            The model's reply text, or a fallback message describing the error
            if the LLM call fails.
        """
        try:
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )
            column_names = ', '.join(str(col) for col in df.columns.tolist())
            # Leading and trailing "" keep the blank first and last lines of
            # the prompt.
            prompt = "\n".join([
                "",
                f"用户查询: {query}",
                "数据集信息:",
                f"- 总行数: {len(df)}",
                f"- 总列数: {len(df.columns)}",
                f"- 列名: {column_names}",
                "查询结果:",
                f"- 结果类型: {result['result_type']}",
                f"- 结果行数: {result['total']}",
                f"- 结果列数: {len(result['columns'])}",
                "请基于以上信息用中文生成一个简洁的分析总结包括:",
                "1. 查询的主要目的",
                "2. 关键发现",
                "3. 数据洞察",
                "4. 建议的后续分析方向",
                "总结应该专业准确易懂控制在200字以内",
                "",
            ])
            # llm.invoke blocks, so run it on the worker pool.
            response = await self._run_in_executor(
                llm.invoke, [HumanMessage(content=prompt)]
            )
            return response.content
        except Exception as e:
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
class DatabaseQueryService(SmartQueryService):
    """Connect to a user's MySQL/PostgreSQL database and run simple queries.

    One live connection is kept per user in ``user_connections``. The public
    coroutines return ``{'success': ..., ...}`` dicts instead of raising.
    """

    def __init__(self):
        super().__init__()
        # user_id -> {'config', 'connection', 'connected_at'}; the stored
        # config includes the password supplied by the caller.
        self.user_connections = {}

    @staticmethod
    def _require_safe_table_name(table_name: str) -> str:
        """Return ``table_name`` if it is a plain (optionally schema-qualified) identifier.

        Only letters, digits, ``_`` and ``$`` are accepted, with at most one
        ``.`` separating schema and table. That is the character set of
        unquoted identifiers, so names that could be interpolated
        successfully before are still accepted, while quotes, spaces and
        statement separators are rejected.

        Raises:
            ValueError: If the name contains anything else.
        """
        parts = table_name.split('.') if isinstance(table_name, str) else []
        is_valid = 1 <= len(parts) <= 2 and all(
            part and all(ch.isalnum() or ch in '_$' for ch in part)
            for part in parts
        )
        if not is_valid:
            raise ValueError(f"非法的表名: {table_name!r}")
        return table_name

    @staticmethod
    def _close_quietly(connection) -> None:
        """Close ``connection``; a failure is logged rather than raised."""
        try:
            connection.close()
        except Exception as e:
            # Best effort: the connection is being discarded anyway.
            logger.warning(f"关闭数据库连接失败: {str(e)}")

    def _create_connection(self, config: Dict[str, str]):
        """Open a new connection described by ``config``.

        Args:
            config: Needs ``type`` ("mysql" or "postgresql", any case),
                ``host``, ``port``, ``username``, ``password``, ``database``.

        Returns:
            A pymysql or psycopg2 connection object.

        Raises:
            Exception: Prefixed with "数据库连接失败" when the driver fails
                or the database type is not supported.
        """
        db_type = config['type'].lower()
        try:
            if db_type == 'mysql':
                return pymysql.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database'],
                    charset='utf8mb4'
                )
            if db_type == 'postgresql':
                return psycopg2.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database']
                )
        except Exception as e:
            raise Exception(f"数据库连接失败: {str(e)}") from e
        # Previously this fell through to an unbound local variable.
        raise Exception(f"数据库连接失败: 不支持的数据库类型 {config['type']}")

    async def test_connection(self, config: Dict[str, str]) -> bool:
        """Return True if a connection can be opened and closed, else False."""
        try:
            connection = await self._run_in_executor(self._create_connection, config)
            connection.close()
            return True
        except Exception:
            return False

    async def connect_database(self, config: Dict[str, str], user_id: int) -> Dict[str, Any]:
        """Open a connection for ``user_id`` and list the database's tables.

        On success the connection is stored in ``user_connections``, and any
        connection previously stored for the same user is closed. If listing
        the tables fails, the new connection is closed before the failure is
        reported.

        Returns:
            ``{'success': True, 'data': {'tables', 'database_type',
            'database_name'}}`` or ``{'success': False, 'message': ...}``.
        """
        try:
            connection = await self._run_in_executor(self._create_connection, config)
            try:
                tables = await self._run_in_executor(
                    self._get_tables, connection, config['type']
                )
            except Exception:
                # Do not leak the connection that was just opened.
                self._close_quietly(connection)
                raise

            previous = self.user_connections.get(user_id)
            self.user_connections[user_id] = {
                'config': config,
                'connection': connection,
                'connected_at': datetime.now()
            }
            if previous is not None:
                # The old connection is unreachable from here on; release it.
                self._close_quietly(previous['connection'])

            return {
                'success': True,
                'data': {
                    'tables': tables,
                    'database_type': config['type'],
                    'database_name': config['database']
                }
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"数据库连接失败: {str(e)}"
            }

    def _get_tables(self, connection, db_type: str) -> List[str]:
        """List table names visible through ``connection``.

        MySQL uses ``SHOW TABLES``. PostgreSQL lists the ``public`` schema
        only. "sqlserver" lists base tables from ``information_schema``. Any
        other type yields an empty list.
        """
        kind = db_type.lower()
        if kind == 'mysql':
            sql = "SHOW TABLES"
        elif kind == 'postgresql':
            sql = (
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_schema = 'public'"
            )
        elif kind == 'sqlserver':
            # NOTE(review): _create_connection cannot open a sqlserver
            # connection, so this branch looks unreachable via this class.
            sql = (
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_type = 'BASE TABLE'"
            )
        else:
            return []

        cursor = connection.cursor()
        try:
            cursor.execute(sql)
            return [row[0] for row in cursor.fetchall()]
        finally:
            cursor.close()

    async def get_table_schema(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Return the column definitions of ``table_name`` for a connected user.

        Returns:
            ``{'success': True, 'data': {'schema', 'table_name'}}`` or
            ``{'success': False, 'message': ...}`` (also when the user has no
            stored connection).
        """
        try:
            entry = self.user_connections.get(user_id)
            if entry is None:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }
            schema = await self._run_in_executor(
                self._get_table_schema,
                entry['connection'],
                table_name,
                entry['config']['type'],
            )
            return {
                'success': True,
                'data': {
                    'schema': schema,
                    'table_name': table_name
                }
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"获取表结构失败: {str(e)}"
            }

    def _get_table_schema(self, connection, table_name: str, db_type: str) -> List[Dict[str, Any]]:
        """Read column metadata for ``table_name``.

        Returns:
            One dict per column. MySQL rows carry ``column_name``,
            ``data_type``, ``is_nullable``, ``column_key``, ``column_default``;
            PostgreSQL rows carry the same keys except ``column_key``. Other
            database types yield an empty list.

        Raises:
            ValueError: For MySQL, when ``table_name`` is not a plain
                identifier (it has to be embedded in the statement text).
        """
        kind = db_type.lower()
        cursor = connection.cursor()
        try:
            if kind == 'mysql':
                # DESCRIBE cannot take the table as a bound parameter, so
                # validate the name and backtick-quote each part.
                safe_name = self._require_safe_table_name(table_name)
                quoted = '.'.join(f"`{part}`" for part in safe_name.split('.'))
                cursor.execute(f"DESCRIBE {quoted}")
                # DESCRIBE columns: Field, Type, Null, Key, Default, Extra.
                return [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': 'YES' if col[2] == 'YES' else 'NO',
                    'column_key': col[3],
                    'column_default': col[4]
                } for col in cursor.fetchall()]
            if kind == 'postgresql':
                # The name is bound as a parameter, so no validation needed.
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, (table_name,))
                return [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': col[2],
                    'column_default': col[3]
                } for col in cursor.fetchall()]
            return []
        finally:
            cursor.close()

    async def execute_natural_language_query(
        self,
        query: str,
        table_name: str,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """Translate ``query`` to SQL, run it on the user's connection, summarise it.

        Args:
            query: Natural-language question.
            table_name: Table the generated SQL targets.
            user_id: Key of the stored connection.
            page: 1-based page number.
            page_size: Rows per page.

        Returns:
            ``{'success': True, 'data': {...}}`` where ``data`` holds the page
            of rows plus ``generated_code`` (the SQL) and ``summary``, or
            ``{'success': False, 'message': ...}``.
        """
        try:
            entry = self.user_connections.get(user_id)
            if entry is None:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }
            connection = entry['connection']

            # Placeholder translation; an MCP service is meant to replace it.
            sql_query = await self._convert_to_sql(query, table_name, connection)

            result = await self._run_in_executor(
                self._execute_sql_query, connection, sql_query, page, page_size
            )

            summary = await self._generate_db_summary(query, result, table_name)
            result['generated_code'] = sql_query
            result['summary'] = summary
            return {
                'success': True,
                'data': result
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"数据库查询执行失败: {str(e)}"
            }

    async def _convert_to_sql(self, query: str, table_name: str, connection) -> str:
        """Map a natural-language question to a canned SQL statement.

        TODO: integrate the MCP service. Until then this is keyword matching
        on the question, first match wins: "all" -> first 100 rows,
        "count" -> row count, "recent" -> last 10 by ``id``, "group" ->
        counts grouped by ``id``, otherwise the first 20 rows. The "recent"
        and "group" statements assume the table has an ``id`` column.
        ``connection`` is not used.

        Raises:
            ValueError: If ``table_name`` is not a plain identifier. It is
                embedded in the SQL text, so it must not carry SQL syntax.
        """
        table = self._require_safe_table_name(table_name)
        query_lower = query.lower()

        if '所有' in query or '全部' in query or 'all' in query_lower:
            return f"SELECT * FROM {table} LIMIT 100"
        if '统计' in query or '总数' in query or 'count' in query_lower:
            return f"SELECT COUNT(*) as total_count FROM {table}"
        if '最近' in query or 'recent' in query_lower:
            return f"SELECT * FROM {table} ORDER BY id DESC LIMIT 10"
        if '分组' in query or 'group' in query_lower:
            # Simplistic grouping; has to be adapted to the real table.
            return f"SELECT COUNT(*) as count FROM {table} GROUP BY id LIMIT 10"
        return f"SELECT * FROM {table} LIMIT 20"

    def _execute_sql_query(self, connection, sql_query: str, page: int, page_size: int) -> Dict[str, Any]:
        """Execute ``sql_query`` and return one page of its rows.

        The whole result set is fetched and then sliced in memory, so
        ``total`` is the number of rows the statement returned.

        Args:
            connection: Open DB-API connection.
            sql_query: Statement to run.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``data`` (list of column -> value dicts), ``columns``,
            ``total``, ``page`` and ``page_size``.
        """
        # Guard against negative slice indices.
        page = max(page, 1)
        page_size = max(page_size, 1)

        cursor = connection.cursor()
        try:
            cursor.execute(sql_query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            all_results = cursor.fetchall()

            start_idx = (page - 1) * page_size
            page_rows = all_results[start_idx:start_idx + page_size]
            # zip stops at the shorter side, like the previous index guard.
            data = [dict(zip(columns, row)) for row in page_rows]
            return {
                'data': data,
                'columns': columns,
                'total': len(all_results),
                'page': page,
                'page_size': page_size
            }
        except Exception:
            # The connection is reused across requests. After a failed
            # statement PostgreSQL rejects further commands until the
            # transaction is rolled back, so reset it before re-raising.
            try:
                connection.rollback()
            except Exception as rollback_error:
                logger.warning(f"回滚失败: {str(rollback_error)}")
            raise
        finally:
            cursor.close()

    async def _generate_db_summary(self, query: str, result: Dict[str, Any], table_name: str) -> str:
        """Ask the LLM for a short Chinese summary of a database query result.

        Args:
            query: The user's original question.
            result: Result dict; ``total`` and ``columns`` are read.
            table_name: Table the query targeted.

        Returns:
            The model's reply text, or a fallback message describing the error
            if the LLM call fails.
        """
        try:
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )
            column_names = ', '.join(map(str, result['columns']))
            # Leading and trailing "" keep the blank first and last lines of
            # the prompt.
            prompt = "\n".join([
                "",
                f"用户查询: {query}",
                f"目标表: {table_name}",
                "查询结果:",
                f"- 结果行数: {result['total']}",
                f"- 结果列数: {len(result['columns'])}",
                f"- 列名: {column_names}",
                "请基于以上信息用中文生成一个简洁的数据库查询分析总结包括:",
                "1. 查询的主要目的",
                "2. 关键数据发现",
                "3. 数据特征分析",
                "4. 建议的后续查询方向",
                "总结应该专业准确易懂控制在200字以内",
                "",
            ])
            # llm.invoke blocks, so run it on the worker pool.
            response = await self._run_in_executor(
                llm.invoke, [HumanMessage(content=prompt)]
            )
            return response.content
        except Exception as e:
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
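# NOTE: the class below re-declares SmartQueryService, which is already defined
# near the top of this file; it was written as extra methods for that class.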
from .table_metadata_service import TableMetadataService
class SmartQueryService:
    """Smart-query service with table-metadata context for SQL generation.

    NOTE(review): this re-declares ``SmartQueryService``, which is already
    defined near the top of this file, so after import the name refers to
    this class. It therefore also provides ``executor`` and
    ``_run_in_executor`` exactly as the first definition does, so nothing is
    lost through the re-binding. The two definitions should be merged.
    """

    def __init__(self):
        super().__init__()
        # Pool for blocking calls, mirroring the first definition.
        self.executor = ThreadPoolExecutor(max_workers=4)
        # Stays None until set_db_session() is called.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Attach a database session by wrapping it in a TableMetadataService."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def _run_in_executor(self, func, *args):
        """Run the blocking ``func(*args)`` on this service's thread pool."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    async def get_database_context(self, user_id: int, query: str) -> str:
        """Describe the user's tables as text for use in an LLM prompt.

        Args:
            user_id: Passed to ``get_user_table_metadata``.
            query: Accepted for interface compatibility; not used here.

        Returns:
            A header line followed, per table, by its name, optional comment,
            optional business description, columns, up to two sample rows and
            the row count, each table closed by a ``---`` line. Returns ""
            when no metadata service is set, the user has no metadata, or the
            lookup fails (the failure is logged).
        """
        if not self.table_metadata_service:
            return ""
        try:
            table_metadata_list = self.table_metadata_service.get_user_table_metadata(user_id)
            if not table_metadata_list:
                return ""

            context_parts = ["=== 数据库表信息 ==="]
            for metadata in table_metadata_list:
                table_info = [f"表名: {metadata.table_name}"]
                if metadata.table_comment:
                    table_info.append(f"表描述: {metadata.table_comment}")
                if metadata.qa_description:
                    table_info.append(f"业务说明: {metadata.qa_description}")

                if metadata.columns_info:
                    columns = []
                    for col in metadata.columns_info:
                        col_desc = f"{col['column_name']} ({col['data_type']})"
                        if col.get('column_comment'):
                            col_desc += f" - {col['column_comment']}"
                        columns.append(col_desc)
                    table_info.append(f"字段: {', '.join(columns)}")

                if metadata.sample_data:
                    # Two rows are enough to show the data format.
                    table_info.append(f"示例数据: {metadata.sample_data[:2]}")

                table_info.append(f"总行数: {metadata.row_count}")
                context_parts.append("\n".join(table_info))
                context_parts.append("---")
            return "\n".join(context_parts)
        except Exception as e:
            logger.error(f"获取数据库上下文失败: {str(e)}")
            return ""

    async def execute_smart_query(self, query: str, user_id: int, **kwargs) -> Dict[str, Any]:
        """Enrich ``query`` with table metadata and delegate to the parent class.

        The enriched prompt is handed to ``execute_smart_query`` of the next
        class in the MRO. This class declares no base that provides one, so a
        missing parent implementation is reported as a failure result instead
        of surfacing as an ``AttributeError`` message.

        Args:
            query: The user's question.
            user_id: Owner of the table metadata.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            The parent's result, or ``{'success': False, 'message': ...}``.
        """
        try:
            parent_impl = getattr(super(), 'execute_smart_query', None)
            if parent_impl is None:
                return {
                    'success': False,
                    'message': "查询失败: 未找到底层查询实现 execute_smart_query"
                }

            db_context = await self.get_database_context(user_id, query)
            # Leading and trailing "" keep the blank first and last lines of
            # the prompt.
            enhanced_prompt = "\n".join([
                "",
                f"{db_context}",
                f"用户问题: {query}",
                "请基于上述数据库表信息生成相应的SQL查询语句",
                "注意",
                "1. 使用准确的表名和字段名",
                "2. 考虑数据类型和约束",
                "3. 参考示例数据理解数据格式",
                "4. 生成高效的查询语句",
                "",
            ])
            return await parent_impl(enhanced_prompt, user_id, **kwargs)
        except Exception as e:
            logger.error(f"智能查询失败: {str(e)}")
            return {
                'success': False,
                'message': f"查询失败: {str(e)}"
            }