159 lines
5.6 KiB
Python
159 lines
5.6 KiB
Python
"""Database connection and session management."""
|
|
|
|
import uuid, re
|
|
from loguru import logger
|
|
import traceback
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from typing import Optional
|
|
|
|
from ..core.config import get_settings
|
|
from .base import Base
|
|
from utils.util_exceptions import DatabaseError
|
|
|
|
# Custom Session class with desc property and unique ID
|
|
class DrSession(AsyncSession):
    """AsyncSession subclass carrying a short session ID, a title and step logging.

    Every assignment to ``desc`` counts as one "step": the text is stored, and it is
    logged through loguru at a level chosen from keywords in the text. Each log line
    is tagged with the session ID and the source position of the code that made the
    assignment.
    """

    # Path prefix stripped from file names in logged source positions; override it
    # (on the class or an instance) instead of editing parse_source_pos.
    source_root_prefix = "F:\\DrGraph_Python\\FastAPI\\"

    # Keyword routing for desc messages, checked in order; the first match wins and
    # anything unmatched is logged at INFO.
    # Each entry: (keyword contained anywhere, prefix the text starts with, log method name).
    _DESC_ROUTES = (
        ("警告", "WARNING", "log_warning"),
        ("异常", "EXCEPTION", "log_exception"),
        ("成功", "SUCCESS", "log_success"),
        ("开始", "START", "log_success"),
        ("失败", "ERROR", "log_error"),
    )

    # Optional ";-N" marker inside a desc: removed from the text, and it moves the
    # reported source position N frames further up the call stack.
    _LEVEL_MARKER = re.compile(r";(-\d+)")

    def __init__(self, **kwargs):
        """Initialize the session and assign it a short random session ID."""
        super().__init__(**kwargs)
        # Make sure the info dict exists before the title setter below writes to it.
        if not hasattr(self, 'info'):
            self.info = {}
        self.title = ""
        # History of every desc assigned to this session, oldest first.
        self.descs = []
        # First group of a UUID4: short, random tag used in every log line.
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        # Number of desc steps logged so far (name kept for existing callers).
        self.stepIndex = 0

    @property
    def title(self) -> Optional[str]:
        """Return the session title stored in ``info['title']`` (None if unset)."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set the title; if a non-blank title exists, prepend ``value`` with " >>> "."""
        current = self.info.get('title')
        if current is None or current.strip() == "":
            self.info['title'] = value
        else:
            self.info['title'] = value + " >>> " + current

    @property
    def desc(self) -> Optional[str]:
        """Return the most recently assigned step description (None if never set)."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Record and log one processing step.

        The step counter is incremented, any ``;-N`` marker is stripped from
        ``value``, and the cleaned text is stored in ``info['desc']`` and appended to
        ``descs``. It is then logged via the first matching entry of
        ``_DESC_ROUTES`` (INFO otherwise), prefixed with the step number.
        """
        self.stepIndex += 1

        # Offset -3 skips parse_source_pos and the log_* method, so the logged
        # position is the code that assigned session.desc.
        level = -3
        match = self._LEVEL_MARKER.search(value)
        if match:
            # Remove the exact matched text (e.g. ";-05"), not a re-formatted int.
            value = value.replace(match.group(0), "")
            level += int(match.group(1))

        self.info['desc'] = value
        self.descs.append(value)

        log = self.log_info
        for keyword, prefix, method_name in self._DESC_ROUTES:
            if keyword in value or value.startswith(prefix):
                log = getattr(self, method_name)
                break
        # Called directly from here so the stack depth matches the offset above.
        log(f"第 {self.stepIndex} 步 - {value}", level=level)

    def log_prefix(self) -> str:
        """Return the "〖Session<id>〗" tag that starts every log line of this session."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int):
        """Describe a call-stack frame as "file:line in function".

        ``level`` is a negative offset relative to the caller of this method:
        -1 is the caller itself, -2 the caller's caller, and so on. The file path
        has ``source_root_prefix`` removed. Returns "<unknown>" when the stack is
        not deep enough for the requested offset.
        """
        # A negative offset only needs the innermost frames; limiting extraction
        # avoids walking (and reading source lines for) the whole stack.
        limit = 1 - level if level < 0 else None
        stack = traceback.extract_stack(limit=limit)
        try:
            frame = stack[level - 1]
        except IndexError:
            return "<unknown>"
        file = frame.filename.replace(self.source_root_prefix, "")
        return f"{file}:{frame.lineno} in {frame.name}"

    # Each log_* method calls parse_source_pos itself: routing through a shared
    # helper would add a stack frame and shift the offsets that ``level`` relies on.

    def log_info(self, msg: str, level: int = -2):
        """Log ``msg`` at INFO, tagged with the session ID and the caller's position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log ``msg`` at SUCCESS, tagged with the session ID and the caller's position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log ``msg`` at WARNING, tagged with the session ID and the caller's position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log ``msg`` at ERROR, tagged with the session ID and the caller's position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log ``msg`` with the active exception's traceback via ``logger.exception``.

        NOTE(review): meant to be called inside an ``except`` block; outside one
        there is no active exception to attach.
        """
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
|
|
|
# Database settings are read once and reused for every engine option below.
_db_settings = get_settings().database

# Shared async engine; get_session and AsyncSessionFactory both bind to it.
engine_async = create_async_engine(
    _db_settings.url,
    future=True,
    echo=False,  # SQL echo forced off; the configured value is get_settings().database.echo
    pool_size=_db_settings.pool_size,
    max_overflow=_db_settings.max_overflow,
    pool_pre_ping=True,  # check connections on checkout so stale ones get replaced
    pool_recycle=3600,  # recycle pooled connections after one hour
)
|
|
from fastapi import HTTPException, Request
|
|
|
|
# Factory producing DrSession instances bound to engine_async. Loaded attributes
# stay available after commit (expire_on_commit=False); autoflush stays enabled.
AsyncSessionFactory = sessionmaker(
    class_=DrSession,
    bind=engine_async,
    autoflush=True,
    expire_on_commit=False,
)
|
|
|
|
async def get_session(request: Request = None):
    """Yield a DrSession for one unit of work and always close it afterwards.

    The session title is "<METHOD> <path> - <client host>" when a request is given,
    and "无request - 无request" otherwise. If the consumer raises, the failure is
    logged as an EXCEPTION step, the transaction is rolled back, and the original
    exception is re-raised unchanged.

    NOTE(review): keep the plain ``Request`` annotation — presumably FastAPI only
    injects the current request for that exact type; confirm before changing it to
    ``Optional[Request]``.
    """
    if request is not None:
        url = f"{request.method} {request.url.path}"
        # request.client may be None (e.g. when the ASGI scope carries no client).
        client = request.client
        client_host = client.host if client is not None else "无client"
    else:
        url = "无request"
        client_host = "无request"

    # Was a bare print(); route through the logger instead of raw stdout.
    logger.debug(url)

    # NOTE(review): built directly instead of via AsyncSessionFactory, so the
    # factory's expire_on_commit=False / autoflush=True settings are not applied
    # here — confirm this is intended.
    session = DrSession(bind=engine_async)
    session.title = f"{url} - {client_host}"

    # Expose the originating request to code that receives the session.
    if request is not None:
        session.request = request

    try:
        yield session
    except Exception as e:
        err_msg = f"数据库 session 异常 >>> {e}"
        session.desc = f"EXCEPTION: {err_msg}"
        await session.rollback()
        # Re-raise the original exception (not converted to HTTPException);
        # per the original note, main.py catches it.
        raise
    finally:
        try:
            # Logs a final (empty) step marking the end of the session.
            session.desc = ""
        finally:
            # Close even if the closing log line fails, so the session is never leaked.
            await session.close()
|