# Source path: hyf-backend/th_agenter/services/mcp/postgresql_mcp.py
# (385 lines, 14 KiB, Python)
#
# Note: this file contains Unicode characters that might be confused with
# other, similar-looking characters (for example full-width punctuation in
# message strings). If that is intentional, the warning can be ignored.

"""PostgreSQL MCP (Model Context Protocol) tool for database operations."""
import json
import psycopg2
from typing import List, Dict, Any, Optional
from datetime import datetime
from th_agenter.services.agent.base import BaseTool, ToolParameter, ToolParameterType, ToolResult
class PostgreSQLMCPTool(BaseTool):
    """PostgreSQL MCP tool for database operations and intelligent querying.

    The tool keeps one open psycopg2 connection per ``user_id`` in
    ``self.connections`` and dispatches on the ``operation`` argument of
    :meth:`execute` (``connect``, ``list_tables``, ``describe_table``,
    ``execute_query``, ``test_connection``, ``disconnect``).

    Every helper that runs SQL on a cached connection ends the transaction
    it opened (commit on success, rollback on failure) so that one failed
    statement cannot leave the cached connection in an aborted-transaction
    state for later calls.
    """

    def __init__(self):
        super().__init__()
        # user_id -> {"connection", "config", "connected_at"} for each open session.
        self.connections = {}

    def get_name(self) -> str:
        """Return the identifier under which this tool is registered."""
        return "postgresql_mcp"

    def get_description(self) -> str:
        """Return the human-readable description of the tool."""
        return "PostgreSQL MCP服务工具提供数据库连接、表结构查询、SQL执行等功能支持智能数据问答。"

    def get_parameters(self) -> List[ToolParameter]:
        """Describe the arguments accepted by :meth:`execute`."""
        return [
            ToolParameter(
                name="operation",
                type=ToolParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["connect", "list_tables", "describe_table", "execute_query", "test_connection", "disconnect"]
            ),
            ToolParameter(
                name="connection_config",
                type=ToolParameterType.OBJECT,
                description="数据库连接配置 {host, port, database, username, password}",
                required=False
            ),
            ToolParameter(
                name="user_id",
                type=ToolParameterType.STRING,
                description="用户ID用于管理连接",
                required=False
            ),
            ToolParameter(
                name="table_name",
                type=ToolParameterType.STRING,
                description="表名用于describe_table操作",
                required=False
            ),
            ToolParameter(
                name="sql_query",
                type=ToolParameterType.STRING,
                description="SQL查询语句用于execute_query操作",
                required=False
            ),
            ToolParameter(
                name="limit",
                type=ToolParameterType.INTEGER,
                description="查询结果限制数量默认100",
                required=False,
                default=100
            )
        ]

    def _create_connection(self, config: Dict[str, Any]) -> psycopg2.extensions.connection:
        """Open a new PostgreSQL connection from a config mapping.

        Args:
            config: Mapping with ``host``, ``username``, ``password`` and
                ``database`` keys; ``port`` is optional and defaults to 5432.

        Returns:
            The connection returned by ``psycopg2.connect``.

        Raises:
            Exception: Any failure (missing key, bad port, connection error),
                re-raised with a descriptive message and chained to the
                original error.
        """
        try:
            return psycopg2.connect(
                host=config['host'],
                # ``or`` also covers a port key that is present but None/empty.
                port=int(config.get('port') or 5432),
                user=config['username'],
                password=config['password'],
                database=config['database'],
                connect_timeout=10
            )
        except Exception as e:
            raise Exception(f"PostgreSQL连接失败: {str(e)}") from e

    def _test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Open a throwaway connection and report server version and pgvector presence.

        Args:
            config: Connection settings, as accepted by :meth:`_create_connection`.

        Returns:
            A dict with ``success`` and ``message``; on success also ``version``
            and ``has_pgvector``, on failure also ``error``. Never raises.
        """
        conn = None
        try:
            conn = self._create_connection(config)
            cursor = conn.cursor()
            try:
                # Server version string.
                cursor.execute("SELECT version();")
                version = cursor.fetchone()[0]
                # Is the pgvector extension installed in this database?
                cursor.execute("SELECT * FROM pg_extension WHERE extname = 'vector';")
                has_vector = bool(cursor.fetchall())
            finally:
                cursor.close()
            return {
                "success": True,
                "version": version,
                "has_pgvector": has_vector,
                "message": "连接测试成功"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "message": "连接测试失败"
            }
        finally:
            # Close the probe connection even when one of the queries failed.
            if conn is not None:
                conn.close()

    @staticmethod
    def _finish_transaction(connection, succeeded: bool) -> None:
        """End the transaction opened by the preceding statements.

        Commits when ``succeeded`` is true. Otherwise rolls back so the
        connection is usable for the next call.
        """
        if succeeded:
            connection.commit()
            return
        try:
            connection.rollback()
        except Exception:
            # Best effort: the caller is already propagating the original
            # error, which is more useful than a secondary rollback failure
            # (for example on a connection that has already dropped).
            pass

    def _get_tables(self, connection) -> List[Dict[str, Any]]:
        """List the tables and views of the ``public`` schema.

        Args:
            connection: An open psycopg2 connection.

        Returns:
            One dict per table with ``table_name``, ``table_type`` and
            ``table_schema``, ordered by table name.
        """
        cursor = connection.cursor()
        succeeded = False
        try:
            cursor.execute("""
                SELECT
                    table_name,
                    table_type,
                    table_schema
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            tables = [
                {
                    "table_name": row[0],
                    "table_type": row[1],
                    "table_schema": row[2]
                }
                for row in cursor.fetchall()
            ]
            succeeded = True
            return tables
        finally:
            cursor.close()
            self._finish_transaction(connection, succeeded)

    def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
        """Return columns, primary keys and row count of a ``public`` table.

        Args:
            connection: An open psycopg2 connection.
            table_name: Exact (case-sensitive) table name in the ``public`` schema.

        Returns:
            A dict with ``table_name``, ``columns``, ``primary_keys`` and
            ``row_count``.

        Raises:
            psycopg2.Error: If a query fails, e.g. the table does not exist.
        """
        # Local import: only ``psycopg2`` itself is imported at module level.
        from psycopg2 import sql

        cursor = connection.cursor()
        succeeded = False
        try:
            # Column metadata.
            cursor.execute("""
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default,
                    character_maximum_length,
                    numeric_precision,
                    numeric_scale
                FROM information_schema.columns
                WHERE table_name = %s AND table_schema = 'public'
                ORDER BY ordinal_position;
            """, (table_name,))
            columns = [
                {
                    "column_name": row[0],
                    "data_type": row[1],
                    "is_nullable": row[2],
                    "column_default": row[3],
                    "character_maximum_length": row[4],
                    "numeric_precision": row[5],
                    "numeric_scale": row[6]
                }
                for row in cursor.fetchall()
            ]

            # Primary-key columns. The constraint lookup is limited to the
            # same schema so a same-named table elsewhere cannot leak in.
            cursor.execute("""
                SELECT column_name
                FROM information_schema.key_column_usage
                WHERE table_name = %s AND table_schema = 'public'
                AND constraint_name IN (
                    SELECT constraint_name
                    FROM information_schema.table_constraints
                    WHERE table_name = %s AND table_schema = 'public'
                    AND constraint_type = 'PRIMARY KEY'
                );
            """, (table_name, table_name))
            primary_keys = [row[0] for row in cursor.fetchall()]

            # Row count. The table name is composed as a quoted identifier
            # instead of being pasted into the SQL text: that blocks SQL
            # injection and makes mixed-case / reserved-word names work.
            count_statement = sql.SQL("SELECT COUNT(*) FROM {}.{};").format(
                sql.Identifier("public"),
                sql.Identifier(table_name)
            )
            cursor.execute(count_statement)
            row_count = cursor.fetchone()[0]

            succeeded = True
            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": primary_keys,
                "row_count": row_count
            }
        finally:
            cursor.close()
            self._finish_transaction(connection, succeeded)

    @staticmethod
    def _apply_limit(sql_query: str, limit: Optional[int]) -> tuple:
        """Decide how the row limit is enforced for a statement.

        Args:
            sql_query: The statement as supplied by the caller.
            limit: Maximum number of rows; ``None`` or ``0`` disables the cap.

        Returns:
            ``(statement, row_cap)``. ``statement`` is the SQL to execute,
            with ``LIMIT n`` appended only where that is syntactically safe.
            ``row_cap`` is the number of rows to fetch at most, or ``None``
            when everything should be fetched (no limit requested, or the
            statement already carries its own ``LIMIT``).

        Raises:
            ValueError: If ``limit`` cannot be converted to an integer.
        """
        import re

        # Coerce to int so nothing but a number is ever formatted into SQL.
        row_cap = int(limit) if limit else 0
        if row_cap <= 0:
            return sql_query, None

        # Whole-word match: identifiers such as ``rate_limit`` do not count.
        if re.search(r"\bLIMIT\b", sql_query, re.IGNORECASE):
            return sql_query, None

        # Strip surrounding whitespace too, so "SELECT 1;\n" is handled.
        statement = sql_query.strip().rstrip(";").rstrip()
        first_word = re.match(r"\s*(\w+)", statement)
        keyword = first_word.group(1).upper() if first_word else ""

        writes_data = re.search(r"\b(INSERT|UPDATE|DELETE|MERGE)\b", statement, re.IGNORECASE)
        if keyword == "SELECT" or (keyword == "WITH" and not writes_data):
            return f"{statement} LIMIT {row_cap};", row_cap

        # DML and other statements do not accept a trailing LIMIT; run them
        # unchanged and cap any returned rows on the client side instead.
        return sql_query, row_cap

    def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
        """Execute one SQL statement and package the outcome.

        Args:
            connection: An open psycopg2 connection.
            sql_query: Statement to run.
            limit: Row cap for statements without their own ``LIMIT``; a
                falsy value disables the cap.

        Returns:
            For statements that produce rows: ``success``, ``data`` (list of
            column-name -> value dicts), ``columns``, ``row_count`` and
            ``query`` (the SQL actually executed). Otherwise: ``success``,
            ``affected_rows``, ``query`` and ``message``.

        Raises:
            psycopg2.Error: If the statement fails; the transaction is rolled
                back first.
            ValueError: If ``limit`` is not an integer.
        """
        # Local import: the module only imports ``datetime`` from ``datetime``.
        from datetime import date, time

        # ``datetime`` is a subclass of ``date``, so this covers all three.
        temporal_types = (date, time)

        sql_query, row_cap = self._apply_limit(sql_query, limit)
        cursor = connection.cursor()
        succeeded = False
        try:
            cursor.execute(sql_query)

            if cursor.description:  # the statement produced a result set
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall() if row_cap is None else cursor.fetchmany(row_cap)
                data = [
                    {
                        # Temporal values become ISO-8601 strings.
                        name: value.isoformat() if isinstance(value, temporal_types) else value
                        for name, value in zip(columns, row)
                    }
                    for row in rows
                ]
                outcome = {
                    "success": True,
                    "data": data,
                    "columns": columns,
                    "row_count": len(data),
                    "query": sql_query
                }
            else:  # INSERT / UPDATE / DELETE and other row-less statements
                affected_rows = cursor.rowcount
                outcome = {
                    "success": True,
                    "affected_rows": affected_rows,
                    "query": sql_query,
                    "message": f"查询执行成功,影响 {affected_rows}"
                }

            succeeded = True
            return outcome
        finally:
            cursor.close()
            # Commit so reported changes are actually persisted; roll back on
            # error so the cached connection stays usable.
            self._finish_transaction(connection, succeeded)

    def _connection_for(self, user_id: Optional[str]):
        """Return the cached connection for ``user_id``, or ``None`` if there is none."""
        if not user_id or user_id not in self.connections:
            return None
        return self.connections[user_id]["connection"]

    async def execute(self, operation: str, connection_config: Optional[Dict[str, Any]] = None,
                      user_id: Optional[str] = None, table_name: Optional[str] = None,
                      sql_query: Optional[str] = None, limit: int = 100) -> ToolResult:
        """Run one PostgreSQL MCP operation.

        Args:
            operation: One of ``test_connection``, ``connect``, ``list_tables``,
                ``describe_table``, ``execute_query``, ``disconnect``.
            connection_config: Connection settings; needed by
                ``test_connection`` and ``connect``.
            user_id: Key of the cached connection; needed by every operation
                except ``test_connection``.
            table_name: Table to inspect for ``describe_table``.
            sql_query: Statement to run for ``execute_query``.
            limit: Row cap for ``execute_query``.

        Returns:
            A ``ToolResult``; failures are reported through ``success=False``
            and ``error`` rather than raised.
        """
        # NOTE(review): ``logger`` is not defined or imported in the visible
        # part of this module — confirm it is available at module scope.
        # NOTE(review): the psycopg2 calls below are synchronous and are made
        # directly from this coroutine without awaiting.
        try:
            logger.info(f"执行PostgreSQL MCP操作: {operation}")

            if operation == "test_connection":
                if not connection_config:
                    return ToolResult(
                        success=False,
                        error="缺少连接配置参数"
                    )
                result = self._test_connection(connection_config)
                return ToolResult(
                    success=result["success"],
                    result=result,
                    error=result.get("error")
                )

            if operation == "connect":
                if not connection_config or not user_id:
                    return ToolResult(
                        success=False,
                        error="缺少连接配置或用户ID参数"
                    )
                try:
                    connection = self._create_connection(connection_config)
                    try:
                        # Initial table listing, returned to the caller.
                        tables = self._get_tables(connection)
                    except Exception:
                        # Do not keep (or leak) a connection we could not use.
                        connection.close()
                        raise

                    # Replace any earlier connection of this user, closing it
                    # so the server-side session is released.
                    previous = self.connections.pop(user_id, None)
                    if previous is not None:
                        try:
                            previous["connection"].close()
                        except Exception as close_error:
                            logger.warning(f"关闭旧数据库连接失败: {close_error}")

                    self.connections[user_id] = {
                        "connection": connection,
                        "config": connection_config,
                        "connected_at": datetime.now().isoformat()
                    }
                    return ToolResult(
                        success=True,
                        result={
                            "message": "数据库连接成功",
                            "database": connection_config["database"],
                            "tables": tables,
                            "table_count": len(tables)
                        }
                    )
                except Exception as e:
                    return ToolResult(
                        success=False,
                        error=f"连接失败: {str(e)}"
                    )

            if operation == "list_tables":
                connection = self._connection_for(user_id)
                if connection is None:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库请先执行connect操作"
                    )
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={
                        "tables": tables,
                        "table_count": len(tables)
                    }
                )

            if operation == "describe_table":
                connection = self._connection_for(user_id)
                if connection is None:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库请先执行connect操作"
                    )
                if not table_name:
                    return ToolResult(
                        success=False,
                        error="缺少table_name参数"
                    )
                table_info = self._describe_table(connection, table_name)
                return ToolResult(
                    success=True,
                    result=table_info
                )

            if operation == "execute_query":
                connection = self._connection_for(user_id)
                if connection is None:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库请先执行connect操作"
                    )
                if not sql_query:
                    return ToolResult(
                        success=False,
                        error="缺少sql_query参数"
                    )
                query_result = self._execute_query(connection, sql_query, limit)
                return ToolResult(
                    success=True,
                    result=query_result
                )

            if operation == "disconnect":
                # Drop the registry entry first so a failing close() cannot
                # leave a stale connection behind.
                entry = self.connections.pop(user_id, None) if user_id else None
                if entry is None:
                    return ToolResult(
                        success=True,
                        result={"message": "用户未连接数据库"}
                    )
                try:
                    entry["connection"].close()
                except Exception as e:
                    return ToolResult(
                        success=False,
                        error=f"断开连接失败: {str(e)}"
                    )
                return ToolResult(
                    success=True,
                    result={"message": "数据库连接已断开"}
                )

            return ToolResult(
                success=False,
                error=f"不支持的操作类型: {operation}"
            )
        except Exception as e:
            logger.error(f"PostgreSQL MCP工具执行失败: {str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                error=f"工具执行失败: {str(e)}"
            )