hyf-backend/th_agenter/services/smart_query.py

717 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import functools
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List

import pandas as pd
import psycopg2
import pymysql
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from loguru import logger
# 在 SmartQueryService 类中添加方法
from .table_metadata_service import TableMetadataService
class SmartQueryService:
    """
    Base class for the smart-query services.

    Owns a small thread pool used to run blocking work off the asyncio event
    loop, plus an optional ``TableMetadataService`` that is only created once
    :meth:`set_db_session` is called.
    """

    def __init__(self):
        # Shared by all coroutines of this instance; it is never shut down here.
        self.executor = ThreadPoolExecutor(max_workers=4)
        # None until set_db_session() binds a session.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Bind a database session and build a TableMetadataService on it."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def _run_in_executor(self, func, *args, **kwargs):
        """
        Run a blocking callable in ``self.executor`` and await its result.

        Args:
            func: the blocking callable.
            *args: positional arguments forwarded to ``func``.
            **kwargs: keyword arguments forwarded to ``func``. ``run_in_executor``
                only forwards positionals, so these are bound with
                ``functools.partial`` first.

        Returns:
            Whatever ``func`` returns; exceptions raised by ``func`` propagate
            to the awaiting coroutine.
        """
        # Inside a coroutine the asyncio docs recommend get_running_loop()
        # over get_event_loop().
        loop = asyncio.get_running_loop()
        if kwargs:
            func = functools.partial(func, **kwargs)
        return await loop.run_in_executor(self.executor, func, *args)
class ExcelAnalysisService(SmartQueryService):
    """
    Excel data-analysis service.

    Profiles uploaded spreadsheets (as pandas DataFrames) and answers
    natural-language questions about a user's most recently uploaded file.
    The LLM-backed pandas agent is currently disabled (see
    ``_create_pandas_agent``), so natural-language queries report a failure.
    """

    def __init__(self):
        super().__init__()
        # Per-user DataFrame store; not read or written by the methods below.
        self.user_dataframes = {}

    def analyze_dataframe(self, df: pd.DataFrame, filename: str) -> Dict[str, Any]:
        """
        Profile a DataFrame.

        The profile covers the shape, per-column stats, a 5-row preview and
        simple data-quality findings (missing values, duplicate rows).

        Args:
            df: the DataFrame to profile; it is not modified.
            filename: echoed back in the result.

        Returns:
            A dict with keys ``filename``, ``rows``, ``columns``,
            ``column_names``, ``column_info``, ``preview``, ``quality_issues``
            and ``memory_usage``.

        Raises:
            Exception: wrapping any error raised while profiling.
        """
        try:
            rows, columns = df.shape

            # Per-column information.
            column_info = []
            for col in df.columns:
                series = df[col]
                col_info = {
                    'name': col,
                    'dtype': str(series.dtype),
                    'null_count': int(series.isnull().sum()),
                    'unique_count': int(series.nunique())
                }
                if pd.api.types.is_numeric_dtype(series):
                    # Nulls are deliberately not filled: pandas' mean/std/min/max
                    # skip NaN, and an all-null column reports None instead of
                    # NaN.
                    all_null = bool(series.isnull().all())
                    col_info.update({
                        'mean': None if all_null else float(series.mean()),
                        'std': None if all_null else float(series.std()),
                        'min': None if all_null else float(series.min()),
                        'max': None if all_null else float(series.max())
                    })
                column_info.append(col_info)

            # First 5 rows, with nulls shown as empty strings.
            preview_data = df.head().fillna('').to_dict('records')

            quality_issues = []

            # Columns that contain at least one missing value.
            missing_cols = df.columns[df.isnull().any()].tolist()
            if missing_cols:
                quality_issues.append({
                    'type': 'missing_values',
                    'description': f'以下列存在缺失值: {", ".join(map(str, missing_cols))}',
                    'columns': missing_cols
                })

            # Fully duplicated rows (the first occurrence is not counted).
            duplicate_count = df.duplicated().sum()
            if duplicate_count > 0:
                quality_issues.append({
                    'type': 'duplicate_rows',
                    'description': f'发现 {duplicate_count} 行重复数据',
                    'count': int(duplicate_count)
                })

            return {
                'filename': filename,
                'rows': rows,
                'columns': columns,
                'column_names': [str(col) for col in df.columns.tolist()],
                'column_info': column_info,
                'preview': preview_data,
                'quality_issues': quality_issues,
                'memory_usage': f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB"
            }
        except Exception as e:
            logger.exception("DataFrame分析失败")
            raise Exception(f"DataFrame分析失败: {str(e)}") from e

    def _create_pandas_agent(self, df: pd.DataFrame):
        """
        Create the LangChain pandas-dataframe agent for ``df``.

        Agent creation is currently disabled: after logging an error this
        always returns ``None``. The ZhipuAI client is still constructed (and
        left unused); any error from that construction is re-raised wrapped in
        an ``Exception``.
        """
        try:
            # ZhipuAI chat model; only needed once the agent below is re-enabled.
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.1
            )
            agent = None
            logger.error('创建pandas代理失败 - 暂屏蔽处理')
            # Disabled agent construction, kept for re-enabling:
            # agent = create_pandas_dataframe_agent(
            #     llm=llm,
            #     df=df,
            #     verbose=True,
            #     return_intermediate_steps=True,
            #     handle_parsing_errors=True,
            #     max_iterations=3,
            #     early_stopping_method="force",
            #     allow_dangerous_code=True  # lets the agent execute code for analysis
            # )
            return agent
        except Exception as e:
            raise Exception(f"创建pandas代理失败: {str(e)}") from e

    def _execute_pandas_query(self, agent, query: str) -> Dict[str, Any]:
        """
        Run ``query`` through the pandas agent and normalise its output.

        Returns:
            For a DataFrame result: its records, column names and row count
            with ``result_type='dataframe'``. For any other result: a single
            ``{'result': str(result)}`` row with ``result_type='scalar'``.

        Raises:
            Exception: wrapping any agent error, including a missing agent.
        """
        try:
            if agent is None:
                # _create_pandas_agent() currently returns None; fail with a
                # clear reason instead of "'NoneType' has no attribute 'invoke'".
                raise RuntimeError("pandas代理未启用")

            # invoke() copes with agents that return several output keys.
            agent_result = agent.invoke({"input": query})
            result = agent_result.get('output', agent_result)

            if isinstance(result, pd.DataFrame):
                return {
                    'data': result.fillna('').to_dict('records'),
                    'columns': result.columns.tolist(),
                    'total': len(result),
                    'result_type': 'dataframe'
                }
            # Strings, numbers, etc. become a single-cell table.
            return {
                'data': [{'result': str(result)}],
                'columns': ['result'],
                'total': 1,
                'result_type': 'scalar'
            }
        except Exception as e:
            raise Exception(f"pandas查询执行失败: {str(e)}") from e

    async def execute_natural_language_query(
        self,
        query: str,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        Answer a natural-language question about the user's latest upload.

        The upload is looked up as ``excel_<user_id>_*.pkl`` in the system temp
        directory. Results are paginated (1-based ``page``) and an AI summary
        is attached.

        Returns:
            ``{'success': True, 'data': {...}}`` on success, otherwise
            ``{'success': False, 'message': ...}``; errors are never raised.
        """
        try:
            temp_dir = tempfile.gettempdir()
            prefix = f"excel_{user_id}_"
            user_files = [
                name for name in os.listdir(temp_dir)
                if name.startswith(prefix) and name.endswith('.pkl')
            ]
            if not user_files:
                return {
                    'success': False,
                    'message': '未找到上传的Excel文件请先上传文件'
                }

            # The lexicographically greatest name is taken as the newest upload;
            # presumably the uploader embeds a sortable timestamp - TODO confirm.
            latest_file = max(user_files)
            file_path = os.path.join(temp_dir, latest_file)

            # SECURITY NOTE(review): unpickling can execute arbitrary code, and
            # the system temp dir is usually writable by other local users, so a
            # planted excel_<id>_*.pkl would run here. Consider a private upload
            # directory or a non-executable format (parquet/feather).
            # read_pickle is blocking disk I/O, so keep it off the event loop.
            df = await self._run_in_executor(pd.read_pickle, file_path)

            agent = self._create_pandas_agent(df)
            query_result = await self._run_in_executor(
                self._execute_pandas_query, agent, query
            )

            # 1-based pagination over the full result.
            total = query_result['total']
            start_idx = (page - 1) * page_size
            paginated_data = query_result['data'][start_idx:start_idx + page_size]

            summary = await self._generate_summary(query, query_result, df)

            return {
                'success': True,
                'data': {
                    'data': paginated_data,
                    'columns': query_result['columns'],
                    'total': total,
                    'page': page,
                    'page_size': page_size,
                    'generated_code': f"# 基于自然语言查询: {query}\n# 使用LangChain Pandas代理执行",
                    'summary': summary,
                    'result_type': query_result['result_type']
                }
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"查询执行失败: {str(e)}"
            }

    async def _generate_summary(self, query: str, result: Dict[str, Any], df: pd.DataFrame) -> str:
        """
        Ask ZhipuAI (glm-4) for a short Chinese summary of a query result.

        Only metadata goes into the prompt: dataset shape/column names and
        result type/shape, not the data itself. On any error a fallback
        message string is returned instead of raising.
        """
        try:
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )
            prompt = f"""
用户查询: {query}
数据集信息:
- 总行数: {len(df)}
- 总列数: {len(df.columns)}
- 列名: {', '.join(str(col) for col in df.columns.tolist())}
查询结果:
- 结果类型: {result['result_type']}
- 结果行数: {result['total']}
- 结果列数: {len(result['columns'])}
请基于以上信息,用中文生成一个简洁的分析总结,包括:
1. 查询的主要目的
2. 关键发现
3. 数据洞察
4. 建议的后续分析方向
总结应该专业、准确、易懂控制在200字以内。
"""
            # The synchronous LLM call runs in the thread pool.
            response = await self._run_in_executor(
                lambda: llm.invoke([HumanMessage(content=prompt)])
            )
            return response.content
        except Exception as e:
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
class DatabaseQueryService(SmartQueryService):
    """
    Database query service.

    Keeps one live MySQL/PostgreSQL connection per user in
    ``user_connections``. It answers natural-language questions about one
    table by turning them into SQL with a keyword heuristic
    (``_convert_to_sql``), executing that SQL and asking ZhipuAI for a short
    summary.
    """

    def __init__(self):
        super().__init__()
        # user_id -> {'config': dict, 'connection': DB-API connection,
        #             'connected_at': datetime}
        self.user_connections = {}

    @staticmethod
    def _quote_identifier(name: str, db_type: str) -> str:
        """
        Quote a table name so it can be spliced into SQL text.

        MySQL uses backticks and PostgreSQL double quotes; an embedded quote
        character is doubled. Quoting also makes reserved words (e.g. ``user``,
        ``order``) usable as table names.

        PostgreSQL compares quoted names case-sensitively, so ``name`` must be
        the exact stored name, as returned by ``_get_tables``. For any other
        ``db_type`` the name is returned unchanged.
        """
        kind = (db_type or '').lower()
        if kind == 'mysql':
            return '`' + name.replace('`', '``') + '`'
        if kind == 'postgresql':
            return '"' + name.replace('"', '""') + '"'
        return name

    @staticmethod
    def _rollback_quietly(connection) -> None:
        """
        Best-effort rollback after a failed statement.

        After any error, psycopg2 keeps the connection in an aborted
        transaction and rejects later statements until a rollback. Connections
        are stored and reused per user, so one bad query would otherwise
        poison the session. Rollback errors are logged, not raised.
        """
        try:
            connection.rollback()
        except Exception as rollback_error:
            logger.warning(f"数据库回滚失败: {rollback_error}")

    @staticmethod
    def _close_quietly(connection) -> None:
        """Best-effort close of a connection; errors are logged, not raised."""
        try:
            connection.close()
        except Exception as close_error:
            logger.warning(f"关闭数据库连接失败: {close_error}")

    def _create_connection(self, config: Dict[str, str]):
        """
        Open a new connection from ``config``.

        ``config`` must contain the keys type, host, port, username, password
        and database. Supported types are ``mysql`` (pymysql, utf8mb4) and
        ``postgresql`` (psycopg2), matched case-insensitively.

        Raises:
            Exception: "数据库连接失败: ..." for driver errors or an unsupported
                type (previously an unsupported type hit an UnboundLocalError).
        """
        db_type = config['type'].lower()
        try:
            if db_type == 'mysql':
                return pymysql.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database'],
                    charset='utf8mb4'
                )
            if db_type == 'postgresql':
                return psycopg2.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database']
                )
            raise ValueError(f"不支持的数据库类型: {config['type']}")
        except Exception as e:
            raise Exception(f"数据库连接失败: {str(e)}") from e

    async def test_connection(self, config: Dict[str, str]) -> bool:
        """Return True if a connection can be opened and closed, else False."""
        try:
            connection = await self._run_in_executor(self._create_connection, config)
            connection.close()
            return True
        except Exception:
            return False

    async def connect_database(self, config: Dict[str, str], user_id: int) -> Dict[str, Any]:
        """
        Connect to the database, list its tables and remember the connection.

        If the user already had a connection, it is replaced and closed. If
        listing tables fails, the freshly opened connection is closed instead
        of leaking.

        Returns:
            ``{'success': True, 'data': {'tables', 'database_type',
            'database_name'}}`` or ``{'success': False, 'message': ...}``.
        """
        connection = None
        try:
            connection = await self._run_in_executor(self._create_connection, config)
            tables = await self._run_in_executor(self._get_tables, connection, config['type'])
        except Exception as e:
            if connection is not None:
                self._close_quietly(connection)
            return {
                'success': False,
                'message': f"数据库连接失败: {str(e)}"
            }

        previous = self.user_connections.get(user_id)
        self.user_connections[user_id] = {
            'config': config,
            'connection': connection,
            'connected_at': datetime.now()
        }
        if previous is not None:
            # The old connection is unreachable once replaced; close it.
            self._close_quietly(previous['connection'])

        return {
            'success': True,
            'data': {
                'tables': tables,
                'database_type': config['type'],
                'database_name': config['database']
            }
        }

    def _get_tables(self, connection, db_type: str) -> List[str]:
        """
        List table names for the given database type.

        MySQL uses SHOW TABLES. PostgreSQL lists the ``public`` schema from
        information_schema. The ``sqlserver`` branch lists base tables, though
        ``_create_connection`` cannot open such connections. Unknown types
        yield an empty list.
        """
        cursor = connection.cursor()
        try:
            kind = db_type.lower()
            if kind == 'mysql':
                cursor.execute("SHOW TABLES")
            elif kind == 'postgresql':
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                """)
            elif kind == 'sqlserver':
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_type = 'BASE TABLE'
                """)
            else:
                return []
            return [row[0] for row in cursor.fetchall()]
        except Exception:
            self._rollback_quietly(connection)
            raise
        finally:
            cursor.close()

    async def get_table_schema(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """
        Return the column list of ``table_name`` on the user's connection.

        Returns:
            ``{'success': True, 'data': {'schema', 'table_name'}}`` or
            ``{'success': False, 'message': ...}``; errors are never raised.
        """
        try:
            if user_id not in self.user_connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }
            connection = self.user_connections[user_id]['connection']
            db_type = self.user_connections[user_id]['config']['type']
            schema = await self._run_in_executor(
                self._get_table_schema, connection, table_name, db_type
            )
            return {
                'success': True,
                'data': {
                    'schema': schema,
                    'table_name': table_name
                }
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"获取表结构失败: {str(e)}"
            }

    def _get_table_schema(self, connection, table_name: str, db_type: str) -> List[Dict[str, Any]]:
        """
        Describe the columns of ``table_name``.

        MySQL rows come from DESCRIBE and include ``column_key``. PostgreSQL
        rows come from information_schema.columns, ordered by position.
        Unknown types yield an empty list.
        """
        cursor = connection.cursor()
        try:
            kind = db_type.lower()
            if kind == 'mysql':
                # DESCRIBE cannot take a bound parameter, so quote the identifier.
                cursor.execute(f"DESCRIBE {self._quote_identifier(table_name, kind)}")
                return [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': 'YES' if col[2] == 'YES' else 'NO',
                    'column_key': col[3],
                    'column_default': col[4]
                } for col in cursor.fetchall()]
            if kind == 'postgresql':
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, (table_name,))
                return [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': col[2],
                    'column_default': col[3]
                } for col in cursor.fetchall()]
            return []
        except Exception:
            self._rollback_quietly(connection)
            raise
        finally:
            cursor.close()

    async def execute_natural_language_query(
        self,
        query: str,
        table_name: str,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """
        Answer a natural-language question about ``table_name``.

        The question is converted to SQL, the SQL is executed on the user's
        stored connection and the result is paginated (1-based ``page``). The
        generated SQL and an AI summary are attached to the result.

        Returns:
            ``{'success': True, 'data': {...}}`` or
            ``{'success': False, 'message': ...}``; errors are never raised.
        """
        try:
            session = self.user_connections.get(user_id)
            if session is None:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }
            connection = session['connection']
            db_type = session['config']['type']

            # TODO: replace the keyword heuristic with an NL-to-SQL (MCP) service.
            sql_query = await self._convert_to_sql(query, table_name, connection, db_type)

            result = await self._run_in_executor(
                self._execute_sql_query, connection, sql_query, page, page_size
            )
            summary = await self._generate_db_summary(query, result, table_name)
            result['generated_code'] = sql_query
            result['summary'] = summary
            return {
                'success': True,
                'data': result
            }
        except Exception as e:
            return {
                'success': False,
                'message': f"数据库查询执行失败: {str(e)}"
            }

    async def _convert_to_sql(self, query: str, table_name: str, connection, db_type: str = '') -> str:
        """
        Map a natural-language question to SQL with a keyword heuristic.

        TODO: integrate the MCP service. ``connection`` is currently unused.
        ``table_name`` is quoted for ``db_type`` (mysql/postgresql); with an
        empty/unknown ``db_type`` it is inserted unquoted, as before.
        """
        table = self._quote_identifier(table_name, db_type)
        query_lower = query.lower()
        # NOTE(review): English keywords are plain substring matches
        # ("all" also matches "small"); acceptable for a placeholder.
        if '所有' in query or '全部' in query or 'all' in query_lower:
            return f"SELECT * FROM {table} LIMIT 100"
        if '统计' in query or '总数' in query or 'count' in query_lower:
            return f"SELECT COUNT(*) as total_count FROM {table}"
        if '最近' in query or 'recent' in query_lower:
            # Assumes the table has an `id` column.
            return f"SELECT * FROM {table} ORDER BY id DESC LIMIT 10"
        if '分组' in query or 'group' in query_lower:
            # Placeholder grouping; assumes an `id` column and should be adapted
            # to the real table schema.
            return f"SELECT COUNT(*) as count FROM {table} GROUP BY id LIMIT 10"
        return f"SELECT * FROM {table} LIMIT 20"

    def _execute_sql_query(self, connection, sql_query: str, page: int, page_size: int) -> Dict[str, Any]:
        """
        Execute ``sql_query`` and return one page of rows as dicts.

        All rows are fetched and pagination is applied in Python (1-based
        ``page``); ``total`` is the full row count. On failure the transaction
        is rolled back so the stored connection remains usable, and the error
        is re-raised.
        """
        cursor = connection.cursor()
        try:
            cursor.execute(sql_query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            all_results = cursor.fetchall()
            start_idx = (page - 1) * page_size
            page_rows = all_results[start_idx:start_idx + page_size]
            # zip() pairs values with column names and ignores extra values.
            data = [dict(zip(columns, row)) for row in page_rows]
            return {
                'data': data,
                'columns': columns,
                'total': len(all_results),
                'page': page,
                'page_size': page_size
            }
        except Exception:
            self._rollback_quietly(connection)
            raise
        finally:
            cursor.close()

    async def _generate_db_summary(self, query: str, result: Dict[str, Any], table_name: str) -> str:
        """
        Ask ZhipuAI (glm-4) for a short Chinese summary of a query result.

        Only metadata goes into the prompt: the question, table name, row and
        column counts and column names. On any error a fallback message string
        is returned instead of raising.
        """
        try:
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )
            prompt = f"""
用户查询: {query}
目标表: {table_name}
查询结果:
- 结果行数: {result['total']}
- 结果列数: {len(result['columns'])}
- 列名: {', '.join(result['columns'])}
请基于以上信息,用中文生成一个简洁的数据库查询分析总结,包括:
1. 查询的主要目的
2. 关键数据发现
3. 数据特征分析
4. 建议的后续查询方向
总结应该专业、准确、易懂控制在200字以内。
"""
            response = await self._run_in_executor(
                lambda: llm.invoke([HumanMessage(content=prompt)])
            )
            return response.content
        except Exception as e:
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
# NOTE: This second ``class SmartQueryService`` rebinds the module-level name.
# ExcelAnalysisService and DatabaseQueryService above still subclass the
# original class, which was bound when they were defined. Subclassing the
# original here keeps the final ``SmartQueryService`` a superset of both
# definitions, so it still has ``executor`` and ``_run_in_executor``.
from .table_metadata_service import TableMetadataService

_BaseSmartQueryService = SmartQueryService


class SmartQueryService(_BaseSmartQueryService):
    """
    Smart-query service extended with table-metadata context.

    Builds a textual description of the user's tables (from
    TableMetadataService) and prepends it to the question before delegating
    to a parent ``execute_smart_query``, if one exists.
    """

    def __init__(self):
        super().__init__()
        # None until set_db_session() binds a session.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Bind a database session and build a TableMetadataService on it."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def get_database_context(self, user_id: int, query: str) -> str:
        """
        Describe the user's tables as text for use in a prompt.

        Each table contributes its name, optional comment and business
        description, columns ("name (type) - comment"), the first two sample
        rows and the row count, followed by a "---" separator. Returns "" when
        no metadata service is bound, the user has no table metadata, or an
        error occurs (the error is logged). ``query`` is currently unused.
        """
        if not self.table_metadata_service:
            return ""
        try:
            table_metadata_list = self.table_metadata_service.get_user_table_metadata(user_id)
            if not table_metadata_list:
                return ""

            context_parts = ["=== 数据库表信息 ==="]
            for metadata in table_metadata_list:
                table_info = [f"表名: {metadata.table_name}"]
                if metadata.table_comment:
                    table_info.append(f"表描述: {metadata.table_comment}")
                if metadata.qa_description:
                    table_info.append(f"业务说明: {metadata.qa_description}")
                if metadata.columns_info:
                    columns = []
                    for col in metadata.columns_info:
                        col_desc = f"{col['column_name']} ({col['data_type']})"
                        if col.get('column_comment'):
                            col_desc += f" - {col['column_comment']}"
                        columns.append(col_desc)
                    table_info.append(f"字段: {', '.join(columns)}")
                if metadata.sample_data:
                    table_info.append(f"示例数据: {metadata.sample_data[:2]}")
                table_info.append(f"总行数: {metadata.row_count}")
                context_parts.append("\n".join(table_info))
                context_parts.append("---")
            return "\n".join(context_parts)
        except Exception as e:
            logger.error(f"获取数据库上下文失败: {str(e)}")
            return ""

    async def execute_smart_query(self, query: str, user_id: int, **kwargs) -> Dict[str, Any]:
        """
        Run a smart query with the user's table metadata prepended to it.

        The enhanced prompt is passed to the parent class's
        ``execute_smart_query``. If no parent implementation exists, or any
        error occurs, it returns ``{'success': False, 'message': ...}`` after
        logging instead of raising.
        """
        try:
            db_context = await self.get_database_context(user_id, query)
            enhanced_prompt = f"""
{db_context}
用户问题: {query}
请基于上述数据库表信息生成相应的SQL查询语句。
注意:
1. 使用准确的表名和字段名
2. 考虑数据类型和约束
3. 参考示例数据理解数据格式
4. 生成高效的查询语句
"""
            # Report a missing parent implementation explicitly rather than via
            # an opaque AttributeError.
            parent_impl = getattr(super(), 'execute_smart_query', None)
            if parent_impl is None:
                raise NotImplementedError("基类未实现 execute_smart_query")
            return await parent_impl(enhanced_prompt, user_id, **kwargs)
        except Exception as e:
            logger.error(f"智能查询失败: {str(e)}")
            return {
                'success': False,
                'message': f"查询失败: {str(e)}"
            }