717 lines
25 KiB
Python
717 lines
25 KiB
Python
import pandas as pd
|
||
import pymysql
|
||
import psycopg2
|
||
import tempfile
|
||
import os
|
||
from typing import Dict, Any, List
|
||
from datetime import datetime
|
||
import asyncio
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
from langchain_community.chat_models import ChatZhipuAI
|
||
from langchain_core.messages import HumanMessage
|
||
from loguru import logger
|
||
|
||
# 在 SmartQueryService 类中添加方法
|
||
|
||
from .table_metadata_service import TableMetadataService
|
||
|
||
class SmartQueryService:
|
||
"""
|
||
智能问数服务基类
|
||
"""
|
||
def __init__(self):
|
||
self.executor = ThreadPoolExecutor(max_workers=4)
|
||
self.table_metadata_service = None
|
||
|
||
def set_db_session(self, db_session):
|
||
"""设置数据库会话"""
|
||
self.table_metadata_service = TableMetadataService(db_session)
|
||
|
||
async def _run_in_executor(self, func, *args):
|
||
"""在线程池中运行阻塞函数"""
|
||
loop = asyncio.get_event_loop()
|
||
return await loop.run_in_executor(self.executor, func, *args)
|
||
|
||
class ExcelAnalysisService(SmartQueryService):
|
||
"""
|
||
Excel数据分析服务
|
||
"""
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.user_dataframes = {} # 存储用户的DataFrame
|
||
|
||
def analyze_dataframe(self, df: pd.DataFrame, filename: str) -> Dict[str, Any]:
|
||
"""
|
||
分析DataFrame并返回基本信息
|
||
"""
|
||
try:
|
||
# 基本统计信息
|
||
rows, columns = df.shape
|
||
|
||
# 列信息
|
||
column_info = []
|
||
for col in df.columns:
|
||
col_info = {
|
||
'name': col,
|
||
'dtype': str(df[col].dtype),
|
||
'null_count': int(df[col].isnull().sum()),
|
||
'unique_count': int(df[col].nunique())
|
||
}
|
||
|
||
# 如果是数值列,添加统计信息
|
||
if pd.api.types.is_numeric_dtype(df[col]):
|
||
df.fillna({col:0}) #数值列,将空值补0
|
||
col_info.update({
|
||
'mean': float(df[col].mean()) if not df[col].isnull().all() else None,
|
||
'std': float(df[col].std()) if not df[col].isnull().all() else None,
|
||
'min': float(df[col].min()) if not df[col].isnull().all() else None,
|
||
'max': float(df[col].max()) if not df[col].isnull().all() else None
|
||
})
|
||
|
||
column_info.append(col_info)
|
||
|
||
# 数据预览(前5行)
|
||
preview_data = df.head().fillna('').to_dict('records')
|
||
|
||
# 数据质量检查
|
||
quality_issues = []
|
||
|
||
# 检查缺失值
|
||
missing_cols = df.columns[df.isnull().any()].tolist()
|
||
if missing_cols:
|
||
quality_issues.append({
|
||
'type': 'missing_values',
|
||
'description': f'以下列存在缺失值: {", ".join(map(str, missing_cols))}',
|
||
'columns': missing_cols
|
||
})
|
||
|
||
# 检查重复行
|
||
duplicate_count = df.duplicated().sum()
|
||
if duplicate_count > 0:
|
||
quality_issues.append({
|
||
'type': 'duplicate_rows',
|
||
'description': f'发现 {duplicate_count} 行重复数据',
|
||
'count': int(duplicate_count)
|
||
})
|
||
|
||
return {
|
||
'filename': filename,
|
||
'rows': rows,
|
||
'columns': columns,
|
||
'column_names': [str(col) for col in df.columns.tolist()],
|
||
'column_info': column_info,
|
||
'preview': preview_data,
|
||
'quality_issues': quality_issues,
|
||
'memory_usage': f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB"
|
||
}
|
||
|
||
except Exception as e:
|
||
print(e)
|
||
raise Exception(f"DataFrame分析失败: {str(e)}")
|
||
|
||
def _create_pandas_agent(self, df: pd.DataFrame):
|
||
"""
|
||
创建pandas代理
|
||
"""
|
||
try:
|
||
# 使用智谱AI作为LLM
|
||
llm = ChatZhipuAI(
|
||
model="glm-4",
|
||
api_key=os.getenv("ZHIPUAI_API_KEY"),
|
||
temperature=0.1
|
||
)
|
||
agent = None
|
||
logger.error('创建pandas代理失败 - 暂屏蔽处理')
|
||
|
||
# # 创建pandas代理
|
||
# agent = create_pandas_dataframe_agent(
|
||
# llm=llm,
|
||
# df=df,
|
||
# verbose=True,
|
||
# return_intermediate_steps=True,
|
||
# handle_parsing_errors=True,
|
||
# max_iterations=3,
|
||
# early_stopping_method="force",
|
||
# allow_dangerous_code=True # 允许执行代码以支持数据分析
|
||
# )
|
||
|
||
return agent
|
||
|
||
except Exception as e:
|
||
raise Exception(f"创建pandas代理失败: {str(e)}")
|
||
|
||
def _execute_pandas_query(self, agent, query: str) -> Dict[str, Any]:
|
||
"""
|
||
执行pandas查询
|
||
"""
|
||
try:
|
||
# 执行查询
|
||
# 使用invoke方法来处理有多个输出键的情况
|
||
agent_result = agent.invoke({"input": query})
|
||
# 提取主要结果
|
||
result = agent_result.get('output', agent_result)
|
||
|
||
# 解析结果
|
||
if isinstance(result, pd.DataFrame):
|
||
# 如果结果是DataFrame
|
||
data = result.fillna('').to_dict('records')
|
||
columns = result.columns.tolist()
|
||
total = len(result)
|
||
|
||
return {
|
||
'data': data,
|
||
'columns': columns,
|
||
'total': total,
|
||
'result_type': 'dataframe'
|
||
}
|
||
else:
|
||
# 如果结果是其他类型(字符串、数字等)
|
||
return {
|
||
'data': [{'result': str(result)}],
|
||
'columns': ['result'],
|
||
'total': 1,
|
||
'result_type': 'scalar'
|
||
}
|
||
|
||
except Exception as e:
|
||
raise Exception(f"pandas查询执行失败: {str(e)}")
|
||
|
||
async def execute_natural_language_query(
|
||
self,
|
||
query: str,
|
||
user_id: int,
|
||
page: int = 1,
|
||
page_size: int = 20
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行自然语言查询
|
||
"""
|
||
try:
|
||
# 查找用户的临时文件
|
||
temp_dir = tempfile.gettempdir()
|
||
user_files = [f for f in os.listdir(temp_dir)
|
||
if f.startswith(f"excel_{user_id}_") and f.endswith('.pkl')]
|
||
|
||
if not user_files:
|
||
return {
|
||
'success': False,
|
||
'message': '未找到上传的Excel文件,请先上传文件'
|
||
}
|
||
|
||
# 使用最新的文件
|
||
latest_file = sorted(user_files)[-1]
|
||
file_path = os.path.join(temp_dir, latest_file)
|
||
|
||
# 加载DataFrame
|
||
df = pd.read_pickle(file_path)
|
||
|
||
# 创建pandas代理
|
||
agent = self._create_pandas_agent(df)
|
||
|
||
# 执行查询
|
||
query_result = await self._run_in_executor(
|
||
self._execute_pandas_query, agent, query
|
||
)
|
||
|
||
# 分页处理
|
||
total = query_result['total']
|
||
start_idx = (page - 1) * page_size
|
||
end_idx = start_idx + page_size
|
||
|
||
paginated_data = query_result['data'][start_idx:end_idx]
|
||
|
||
# 生成AI总结
|
||
summary = await self._generate_summary(query, query_result, df)
|
||
|
||
return {
|
||
'success': True,
|
||
'data': {
|
||
'data': paginated_data,
|
||
'columns': query_result['columns'],
|
||
'total': total,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'generated_code': f"# 基于自然语言查询: {query}\n# 使用LangChain Pandas代理执行",
|
||
'summary': summary,
|
||
'result_type': query_result['result_type']
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'message': f"查询执行失败: {str(e)}"
|
||
}
|
||
|
||
async def _generate_summary(self, query: str, result: Dict[str, Any], df: pd.DataFrame) -> str:
|
||
"""
|
||
生成AI总结
|
||
"""
|
||
try:
|
||
llm = ChatZhipuAI(
|
||
model="glm-4",
|
||
api_key=os.getenv("ZHIPUAI_API_KEY"),
|
||
temperature=0.3
|
||
)
|
||
|
||
# 构建总结提示
|
||
prompt = f"""
|
||
用户查询: {query}
|
||
|
||
数据集信息:
|
||
- 总行数: {len(df)}
|
||
- 总列数: {len(df.columns)}
|
||
- 列名: {', '.join(str(col) for col in df.columns.tolist())}
|
||
|
||
查询结果:
|
||
- 结果类型: {result['result_type']}
|
||
- 结果行数: {result['total']}
|
||
- 结果列数: {len(result['columns'])}
|
||
|
||
请基于以上信息,用中文生成一个简洁的分析总结,包括:
|
||
1. 查询的主要目的
|
||
2. 关键发现
|
||
3. 数据洞察
|
||
4. 建议的后续分析方向
|
||
|
||
总结应该专业、准确、易懂,控制在200字以内。
|
||
"""
|
||
|
||
response = await self._run_in_executor(
|
||
lambda: llm.invoke([HumanMessage(content=prompt)])
|
||
)
|
||
|
||
return response.content
|
||
|
||
except Exception as e:
|
||
return f"查询已完成,但生成总结时出现错误: {str(e)}"
|
||
|
||
class DatabaseQueryService(SmartQueryService):
|
||
"""
|
||
数据库查询服务
|
||
"""
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.user_connections = {} # 存储用户的数据库连接信息
|
||
|
||
def _create_connection(self, config: Dict[str, str]):
|
||
"""
|
||
创建数据库连接
|
||
"""
|
||
db_type = config['type'].lower()
|
||
|
||
try:
|
||
if db_type == 'mysql':
|
||
connection = pymysql.connect(
|
||
host=config['host'],
|
||
port=int(config['port']),
|
||
user=config['username'],
|
||
password=config['password'],
|
||
database=config['database'],
|
||
charset='utf8mb4'
|
||
)
|
||
elif db_type == 'postgresql':
|
||
connection = psycopg2.connect(
|
||
host=config['host'],
|
||
port=int(config['port']),
|
||
user=config['username'],
|
||
password=config['password'],
|
||
database=config['database']
|
||
)
|
||
|
||
return connection
|
||
|
||
except Exception as e:
|
||
raise Exception(f"数据库连接失败: {str(e)}")
|
||
|
||
async def test_connection(self, config: Dict[str, str]) -> bool:
|
||
"""
|
||
测试数据库连接
|
||
"""
|
||
try:
|
||
connection = await self._run_in_executor(self._create_connection, config)
|
||
connection.close()
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
async def connect_database(self, config: Dict[str, str], user_id: int) -> Dict[str, Any]:
|
||
"""
|
||
连接数据库并获取表列表
|
||
"""
|
||
try:
|
||
connection = await self._run_in_executor(self._create_connection, config)
|
||
|
||
# 获取表列表
|
||
tables = await self._run_in_executor(self._get_tables, connection, config['type'])
|
||
|
||
# 存储连接信息
|
||
self.user_connections[user_id] = {
|
||
'config': config,
|
||
'connection': connection,
|
||
'connected_at': datetime.now()
|
||
}
|
||
|
||
return {
|
||
'success': True,
|
||
'data': {
|
||
'tables': tables,
|
||
'database_type': config['type'],
|
||
'database_name': config['database']
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'message': f"数据库连接失败: {str(e)}"
|
||
}
|
||
|
||
def _get_tables(self, connection, db_type: str) -> List[str]:
|
||
"""
|
||
获取数据库表列表
|
||
"""
|
||
cursor = connection.cursor()
|
||
|
||
try:
|
||
if db_type.lower() == 'mysql':
|
||
cursor.execute("SHOW TABLES")
|
||
tables = [row[0] for row in cursor.fetchall()]
|
||
elif db_type.lower() == 'postgresql':
|
||
cursor.execute("""
|
||
SELECT table_name
|
||
FROM information_schema.tables
|
||
WHERE table_schema = 'public'
|
||
""")
|
||
tables = [row[0] for row in cursor.fetchall()]
|
||
|
||
elif db_type.lower() == 'sqlserver':
|
||
cursor.execute("""
|
||
SELECT table_name
|
||
FROM information_schema.tables
|
||
WHERE table_type = 'BASE TABLE'
|
||
""")
|
||
tables = [row[0] for row in cursor.fetchall()]
|
||
else:
|
||
tables = []
|
||
|
||
return tables
|
||
|
||
finally:
|
||
cursor.close()
|
||
|
||
async def get_table_schema(self, table_name: str, user_id: int) -> Dict[str, Any]:
|
||
"""
|
||
获取表结构
|
||
"""
|
||
try:
|
||
if user_id not in self.user_connections:
|
||
return {
|
||
'success': False,
|
||
'message': '数据库连接已断开,请重新连接'
|
||
}
|
||
|
||
connection = self.user_connections[user_id]['connection']
|
||
db_type = self.user_connections[user_id]['config']['type']
|
||
|
||
schema = await self._run_in_executor(
|
||
self._get_table_schema, connection, table_name, db_type
|
||
)
|
||
|
||
return {
|
||
'success': True,
|
||
'data': {
|
||
'schema': schema,
|
||
'table_name': table_name
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'message': f"获取表结构失败: {str(e)}"
|
||
}
|
||
|
||
def _get_table_schema(self, connection, table_name: str, db_type: str) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取表结构信息
|
||
"""
|
||
cursor = connection.cursor()
|
||
|
||
try:
|
||
if db_type.lower() == 'mysql':
|
||
cursor.execute(f"DESCRIBE {table_name}")
|
||
columns = cursor.fetchall()
|
||
schema = [{
|
||
'column_name': col[0],
|
||
'data_type': col[1],
|
||
'is_nullable': 'YES' if col[2] == 'YES' else 'NO',
|
||
'column_key': col[3],
|
||
'column_default': col[4]
|
||
} for col in columns]
|
||
elif db_type.lower() == 'postgresql':
|
||
cursor.execute("""
|
||
SELECT column_name, data_type, is_nullable, column_default
|
||
FROM information_schema.columns
|
||
WHERE table_name = %s
|
||
ORDER BY ordinal_position
|
||
""", (table_name,))
|
||
columns = cursor.fetchall()
|
||
schema = [{
|
||
'column_name': col[0],
|
||
'data_type': col[1],
|
||
'is_nullable': col[2],
|
||
'column_default': col[3]
|
||
} for col in columns]
|
||
|
||
else:
|
||
schema = []
|
||
|
||
return schema
|
||
|
||
finally:
|
||
cursor.close()
|
||
|
||
async def execute_natural_language_query(
|
||
self,
|
||
query: str,
|
||
table_name: str,
|
||
user_id: int,
|
||
page: int = 1,
|
||
page_size: int = 20
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行自然语言数据库查询
|
||
"""
|
||
try:
|
||
if user_id not in self.user_connections:
|
||
return {
|
||
'success': False,
|
||
'message': '数据库连接已断开,请重新连接'
|
||
}
|
||
|
||
connection = self.user_connections[user_id]['connection']
|
||
|
||
# 这里应该集成MCP服务来将自然语言转换为SQL
|
||
# 目前先使用简单的实现
|
||
sql_query = await self._convert_to_sql(query, table_name, connection)
|
||
|
||
# 执行SQL查询
|
||
result = await self._run_in_executor(
|
||
self._execute_sql_query, connection, sql_query, page, page_size
|
||
)
|
||
|
||
# 生成AI总结
|
||
summary = await self._generate_db_summary(query, result, table_name)
|
||
|
||
result['generated_code'] = sql_query
|
||
result['summary'] = summary
|
||
|
||
return {
|
||
'success': True,
|
||
'data': result
|
||
}
|
||
|
||
except Exception as e:
|
||
return {
|
||
'success': False,
|
||
'message': f"数据库查询执行失败: {str(e)}"
|
||
}
|
||
|
||
async def _convert_to_sql(self, query: str, table_name: str, connection) -> str:
|
||
"""
|
||
将自然语言转换为SQL查询
|
||
TODO: 集成MCP服务
|
||
"""
|
||
# 这是一个简化的实现,实际应该使用MCP服务
|
||
# 根据常见的查询模式生成SQL
|
||
|
||
query_lower = query.lower()
|
||
|
||
if '所有' in query or '全部' in query or 'all' in query_lower:
|
||
return f"SELECT * FROM {table_name} LIMIT 100"
|
||
elif '统计' in query or '总数' in query or 'count' in query_lower:
|
||
return f"SELECT COUNT(*) as total_count FROM {table_name}"
|
||
elif '最近' in query or 'recent' in query_lower:
|
||
return f"SELECT * FROM {table_name} ORDER BY id DESC LIMIT 10"
|
||
elif '分组' in query or 'group' in query_lower:
|
||
# 简单的分组查询,需要根据实际表结构调整
|
||
return f"SELECT COUNT(*) as count FROM {table_name} GROUP BY id LIMIT 10"
|
||
else:
|
||
# 默认查询
|
||
return f"SELECT * FROM {table_name} LIMIT 20"
|
||
|
||
def _execute_sql_query(self, connection, sql_query: str, page: int, page_size: int) -> Dict[str, Any]:
|
||
"""
|
||
执行SQL查询
|
||
"""
|
||
cursor = connection.cursor()
|
||
|
||
try:
|
||
# 执行查询
|
||
cursor.execute(sql_query)
|
||
|
||
# 获取列名
|
||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||
|
||
# 获取所有结果
|
||
all_results = cursor.fetchall()
|
||
total = len(all_results)
|
||
|
||
# 分页
|
||
start_idx = (page - 1) * page_size
|
||
end_idx = start_idx + page_size
|
||
paginated_results = all_results[start_idx:end_idx]
|
||
|
||
# 转换为字典格式
|
||
data = []
|
||
for row in paginated_results:
|
||
row_dict = {}
|
||
for i, value in enumerate(row):
|
||
if i < len(columns):
|
||
row_dict[columns[i]] = value
|
||
data.append(row_dict)
|
||
|
||
return {
|
||
'data': data,
|
||
'columns': columns,
|
||
'total': total,
|
||
'page': page,
|
||
'page_size': page_size
|
||
}
|
||
|
||
finally:
|
||
cursor.close()
|
||
|
||
async def _generate_db_summary(self, query: str, result: Dict[str, Any], table_name: str) -> str:
|
||
"""
|
||
生成数据库查询总结
|
||
"""
|
||
try:
|
||
llm = ChatZhipuAI(
|
||
model="glm-4",
|
||
api_key=os.getenv("ZHIPUAI_API_KEY"),
|
||
temperature=0.3
|
||
)
|
||
|
||
prompt = f"""
|
||
用户查询: {query}
|
||
目标表: {table_name}
|
||
|
||
查询结果:
|
||
- 结果行数: {result['total']}
|
||
- 结果列数: {len(result['columns'])}
|
||
- 列名: {', '.join(result['columns'])}
|
||
|
||
请基于以上信息,用中文生成一个简洁的数据库查询分析总结,包括:
|
||
1. 查询的主要目的
|
||
2. 关键数据发现
|
||
3. 数据特征分析
|
||
4. 建议的后续查询方向
|
||
|
||
总结应该专业、准确、易懂,控制在200字以内。
|
||
"""
|
||
|
||
response = await self._run_in_executor(
|
||
lambda: llm.invoke([HumanMessage(content=prompt)])
|
||
)
|
||
|
||
return response.content
|
||
|
||
except Exception as e:
|
||
return f"查询已完成,但生成总结时出现错误: {str(e)}"
|
||
|
||
# 在 SmartQueryService 类中添加方法
|
||
|
||
from .table_metadata_service import TableMetadataService
|
||
|
||
class SmartQueryService:
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.table_metadata_service = None
|
||
|
||
def set_db_session(self, db_session):
|
||
"""设置数据库会话"""
|
||
self.table_metadata_service = TableMetadataService(db_session)
|
||
|
||
async def get_database_context(self, user_id: int, query: str) -> str:
|
||
"""获取数据库上下文信息用于问答"""
|
||
if not self.table_metadata_service:
|
||
return ""
|
||
|
||
try:
|
||
# 获取用户的表元数据
|
||
table_metadata_list = self.table_metadata_service.get_user_table_metadata(user_id)
|
||
|
||
if not table_metadata_list:
|
||
return ""
|
||
|
||
# 构建数据库上下文
|
||
context_parts = []
|
||
context_parts.append("=== 数据库表信息 ===")
|
||
|
||
for metadata in table_metadata_list:
|
||
table_info = []
|
||
table_info.append(f"表名: {metadata.table_name}")
|
||
|
||
if metadata.table_comment:
|
||
table_info.append(f"表描述: {metadata.table_comment}")
|
||
|
||
if metadata.qa_description:
|
||
table_info.append(f"业务说明: {metadata.qa_description}")
|
||
|
||
# 添加列信息
|
||
if metadata.columns_info:
|
||
columns = []
|
||
for col in metadata.columns_info:
|
||
col_desc = f"{col['column_name']} ({col['data_type']})"
|
||
if col.get('column_comment'):
|
||
col_desc += f" - {col['column_comment']}"
|
||
columns.append(col_desc)
|
||
table_info.append(f"字段: {', '.join(columns)}")
|
||
|
||
# 添加示例数据
|
||
if metadata.sample_data:
|
||
table_info.append(f"示例数据: {metadata.sample_data[:2]}")
|
||
|
||
table_info.append(f"总行数: {metadata.row_count}")
|
||
|
||
context_parts.append("\n".join(table_info))
|
||
context_parts.append("---")
|
||
|
||
return "\n".join(context_parts)
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取数据库上下文失败: {str(e)}")
|
||
return ""
|
||
|
||
async def execute_smart_query(self, query: str, user_id: int, **kwargs) -> Dict[str, Any]:
|
||
"""执行智能查询(集成表元数据)"""
|
||
try:
|
||
# 获取数据库上下文
|
||
db_context = await self.get_database_context(user_id, query)
|
||
|
||
# 构建增强的提示词
|
||
enhanced_prompt = f"""
|
||
{db_context}
|
||
|
||
用户问题: {query}
|
||
|
||
请基于上述数据库表信息,生成相应的SQL查询语句。
|
||
注意:
|
||
1. 使用准确的表名和字段名
|
||
2. 考虑数据类型和约束
|
||
3. 参考示例数据理解数据格式
|
||
4. 生成高效的查询语句
|
||
"""
|
||
|
||
# 调用原有的查询逻辑
|
||
return await super().execute_smart_query(enhanced_prompt, user_id, **kwargs)
|
||
|
||
except Exception as e:
|
||
logger.error(f"智能查询失败: {str(e)}")
|
||
return {
|
||
'success': False,
|
||
'message': f"查询失败: {str(e)}"
|
||
} |