hyf-backend/th_agenter/db/database.py

155 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Database connection and session management."""
import uuid, re
from loguru import logger
import traceback
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from typing import Optional
from utils.general import gradient_text
from ..core.config import get_settings
from .base import Base
from utils.util_exceptions import DatabaseError
# Custom Session class with desc property and unique ID
class DrSession(AsyncSession):
    """AsyncSession that carries a short session ID, a title and step logging.

    Per-session metadata is kept in ``self.info`` under the keys
    ``session_id``, ``title`` and ``desc``. Assigning to :attr:`desc` logs a
    numbered step, and the ``log_*`` helpers prefix every message with the
    session ID and append a source position.

    NOTE(review): the ``log_*`` helpers default to ``level=-2``, which
    resolves to the ``log_*`` method's own frame (``-1`` is
    ``parse_source_pos``). Callers that want their own location in the log
    line should pass ``level=-3``, as the ``desc`` setter does.
    """

    # Path prefix stripped from file names when rendering a source position.
    _SOURCE_ROOT_PREFIX = "F:\\DrGraph_Python\\FastAPI\\"

    def __init__(self, **kwargs):
        """Create the session and give it a short unique ID.

        Args:
            **kwargs: Forwarded unchanged to ``AsyncSession.__init__``.
        """
        super().__init__(**kwargs)
        # Make sure ``info`` exists before the title setter reads it.
        if not hasattr(self, 'info'):
            self.info = {}
        self.title = ""
        # Kept for compatibility; nothing in this class appends to it.
        self.descs = []
        # First 8 hex digits of a random UUID (same value as the first
        # dash-separated group of ``str(uuid4())``).
        self.info['session_id'] = uuid.uuid4().hex[:8]
        # Counter of desc assignments, used as the STEP number in logs.
        self.stepIndex = 0

    @property
    def title(self) -> Optional[str]:
        """Return the title stored in ``info``, or ``None`` if unset."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set the title, or prepend to an existing non-blank title.

        When no title (or only whitespace) is stored, ``value`` becomes the
        title. Otherwise the stored title becomes
        ``"<value> >>> <existing title>"``.
        """
        existing = self.info.get('title') or ""
        if existing.strip():
            self.info['title'] = value + " >>> " + existing
        else:
            self.info['title'] = value

    @property
    def desc(self) -> Optional[str]:
        """Return the most recently assigned step description, if any."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Record a step description and log it as a numbered step.

        Increments ``stepIndex``, stores ``value`` in ``info['desc']`` so the
        getter can return it, and logs the step together with the source
        position of the code that performed the assignment.
        """
        self.stepIndex += 1
        self.info['desc'] = value
        try:
            # -3 skips parse_source_pos (-1) and this setter (-2), landing on
            # the code that assigned ``desc``.
            pos = self.parse_source_pos(-3)
        except Exception:
            # Best effort: a failed position lookup must not break the step log.
            pos = "unknown"
        logger.info(f"{self.log_prefix()} STEP[{self.stepIndex}] {value} >>> @ {pos}")

    def log_prefix(self) -> str:
        """Return the log prefix containing the session ID."""
        return f"〖Session{self.info['session_id']}"

    def parse_source_pos(self, level: int) -> str:
        """Render one frame of the current call stack as ``file:line in name``.

        Args:
            level: Index into the call stack, where ``-1`` is this method,
                ``-2`` its caller, ``-3`` the caller's caller, and so on.
                Non-negative values index from the outermost frame.

        Returns:
            ``"<file>:<line> in <function>"`` with the class's source-root
            prefix removed from the file name.

        Raises:
            IndexError: If the stack has no frame at ``level``.
        """
        # For negative levels only the innermost ``-level`` frames are needed.
        limit = -level if level < 0 else None
        frame = traceback.extract_stack(limit=limit)[level]
        file = frame.filename.replace(self._SOURCE_ROOT_PREFIX, "")
        return f"{file}:{frame.lineno} in {frame.name}"

    def _compose(self, msg: str, pos: str) -> str:
        """Build the common ``<prefix> <msg> >>> @ <pos>`` log line."""
        return f"{self.log_prefix()} {msg} >>> @ {pos}"

    def log_info(self, msg: str, level: int = -2):
        """Log ``msg`` at INFO level with the session prefix and a position."""
        # Resolved here, not in a helper, so ``level`` keeps its meaning.
        pos = self.parse_source_pos(level)
        logger.info(self._compose(msg, pos))

    def log_success(self, msg: str, level: int = -2):
        """Log ``msg`` at SUCCESS level with the session prefix and a position."""
        pos = self.parse_source_pos(level)
        logger.success(self._compose(msg, pos))

    def log_warning(self, msg: str, level: int = -2):
        """Log ``msg`` at WARNING level with the session prefix and a position."""
        pos = self.parse_source_pos(level)
        logger.warning(self._compose(msg, pos))

    def log_error(self, msg: str, level: int = -2):
        """Log ``msg`` at ERROR level with the session prefix and a position."""
        pos = self.parse_source_pos(level)
        logger.error(self._compose(msg, pos))

    def log_exception(self, msg: str, level: int = -2):
        """Log ``msg`` via ``logger.exception`` with the prefix and a position."""
        pos = self.parse_source_pos(level)
        logger.exception(self._compose(msg, pos))
# Module-wide async engine; URL and pool sizing come from application settings.
engine_async = create_async_engine(
    get_settings().database.url,
    # Pool sizing is configurable.
    pool_size=get_settings().database.pool_size,
    max_overflow=get_settings().database.max_overflow,
    # Fixed pool behaviour: pre-ping on, recycle value 3600.
    pool_pre_ping=True,
    pool_recycle=3600,
    future=True,
    # Echo is forced off rather than following settings.database.echo.
    echo=False,
)
from fastapi import HTTPException, Request
# Factory producing DrSession instances bound to the module-level engine.
AsyncSessionFactory = sessionmaker(
    engine_async,
    class_=DrSession,
    autoflush=True,
    # Objects stay loaded after commit, so attribute access does not trigger
    # implicit IO (the file's own note cites MissingGreenlet errors).
    expire_on_commit=False,
)
async def get_session(request: Request = None):
    """Yield a DrSession, rolling back on error and always closing it.

    Written as an async generator with a single ``yield``. The session title
    is set to ``"<METHOD> <path> - <client host>"`` when a request is given.

    Args:
        request: The incoming request, or ``None`` when called outside a
            request. NOTE(review): the annotation is kept as plain
            ``Request``; presumably FastAPI injects the request based on
            this exact type, so confirm before changing it to ``Optional``.

    Yields:
        DrSession: A new session created by ``AsyncSessionFactory``.

    Raises:
        Exception: Any exception raised while the session is in use is
            logged, the session is rolled back, and the original exception
            is re-raised unchanged.
    """
    # Explicit None check: do not rely on the request object's truthiness.
    if request is not None:
        url = f"{request.method} {request.url.path}"
        # ``request.client`` can be None, so guard before reading ``.host``.
        client_host = request.client.host if request.client else "unknown"
    else:
        url = "无request"
        client_host = "无request"
    logger.debug(url)

    # Create the session through AsyncSessionFactory so the factory's
    # configuration (including expire_on_commit=False) is applied.
    session: DrSession = AsyncSessionFactory()
    session.title = f"{url} - {client_host}"
    if request is not None:
        # Expose the originating request on the session.
        session.request = request
    try:
        yield session
    except Exception as e:
        err_msg = f"数据库 session 异常 >>> {e}"
        # First the exception log with stack trace...
        session.log_exception(err_msg)
        # ...then a numbered step entry via the desc setter.
        session.desc = f"EXCEPTION: {err_msg}"
        await session.rollback()
        # Bare raise re-raises the original exception as-is; it is not
        # converted to an HTTPException here.
        raise
    finally:
        # Logs a final step with empty text before closing the session.
        session.desc = ""
        await session.close()