"""Database connection and session management."""
import re
import traceback
import uuid
from typing import AsyncIterator, Optional

from fastapi import HTTPException, Request  # noqa: F401  (HTTPException is only mentioned in a comment below)
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from ..core.config import get_settings
# NOTE(review): Base and DatabaseError are not referenced in this module;
# presumably kept so other modules can import them from here -- confirm.
from .base import Base  # noqa: F401
from utils.util_exceptions import DatabaseError  # noqa: F401


class DrSession(AsyncSession):
    """AsyncSession that tags everything it logs with a short session ID.

    Assigning to ``desc`` logs a numbered step. The log level is picked from
    keywords in the text. Every log line carries the session-ID prefix and the
    ``file:line in function`` position of the code that triggered it.
    """

    # Path prefix stripped from source file names in logged positions.
    SOURCE_ROOT_PREFIX = "F:\\DrGraph_Python\\FastAPI\\"
    # Optional ";-N" marker inside a desc value: report the source position
    # N frames further up the call stack than the code that assigned desc.
    _LEVEL_MARKER_RE = re.compile(r";(-\d+)")

    def __init__(self, **kwargs):
        """Create the session with a short random ID and an empty title.

        Args:
            **kwargs: Forwarded unchanged to ``AsyncSession.__init__``.
        """
        super().__init__(**kwargs)
        # Counter used to number the steps logged through ``desc``.
        self.stepIndex = 0
        # NOTE(review): not used in this module; kept for callers elsewhere -- confirm.
        self.descs = []
        # First 8 hex digits of a random UUID: short enough for a log prefix.
        self.info['session_id'] = uuid.uuid4().hex[:8]
        self.title = ""

    @property
    def title(self) -> Optional[str]:
        """Return the session title stored in ``info``, or None if unset."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set the title, or prepend ``value`` if a non-blank title already exists.

        Repeated assignments build a chain ``"newest >>> ... >>> oldest"``.
        """
        current = self.info.get('title')
        if current is None or not current.strip():
            self.info['title'] = value
        else:
            self.info['title'] = f"{value} >>> {current}"

    @property
    def desc(self) -> Optional[str]:
        """Return the most recently assigned step description, or None."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Log ``value`` as the next numbered step of this session.

        The log method is chosen from keywords in ``value``, checked in this
        order:

        - 警告/WARNING -> warning
        - 异常/EXCEPTION -> exception
        - 成功/SUCCESS and 开始/START -> success
        - 失败/ERROR -> error
        - anything else -> info

        A ";-N" marker is removed from ``value``. It shifts the reported
        source position N frames further up the stack.
        """
        self.stepIndex += 1
        # -3 makes parse_source_pos report the frame that assigned ``desc``
        # (stack: caller -> this setter -> log_* -> parse_source_pos).
        level = -3
        marker = self._LEVEL_MARKER_RE.search(value)
        if marker:
            level += int(marker.group(1))
            # Cut out exactly the matched marker. A text replace of ";{level}"
            # misses zero-padded markers and can hit unrelated text.
            value = value[:marker.start()] + value[marker.end():]
        # Remember the step so the ``desc`` getter returns something meaningful.
        self.info['desc'] = value
        message = f"第 {self.stepIndex} 步 - {value}"
        if "警告" in value or value.startswith("WARNING"):
            self.log_warning(message, level=level)
        elif "异常" in value or value.startswith("EXCEPTION"):
            self.log_exception(message, level=level)
        elif "成功" in value or value.startswith("SUCCESS"):
            self.log_success(message, level=level)
        elif "开始" in value or value.startswith("START"):
            self.log_success(message, level=level)
        elif "失败" in value or value.startswith("ERROR"):
            self.log_error(message, level=level)
        else:
            self.log_info(message, level=level)

    def log_prefix(self) -> str:
        """Return the log prefix that identifies this session."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int) -> str:
        """Describe the stack frame at index ``level - 1`` as ``file:line in func``.

        Negative ``level`` counts from the innermost frame:

        - 0 is this method.
        - -1 is the log_* method that called it.
        - -2 is that method's caller.

        Non-negative ``level`` indexes from the outermost frame. Returns
        ``"<unknown>"`` instead of raising when the stack is too shallow,
        because this runs inside logging (including exception handlers).
        """
        index = level - 1
        if index < 0:
            # Only the innermost -index frames are needed. Limiting the walk
            # avoids extracting the whole stack on every log call.
            frames = traceback.extract_stack(limit=-index)
            frame = frames[0] if len(frames) == -index else None
        else:
            frames = traceback.extract_stack()
            frame = frames[index] if index < len(frames) else None
        if frame is None:
            return "<unknown>"
        path = frame.filename.replace(self.SOURCE_ROOT_PREFIX, "")
        return f"{path}:{frame.lineno} in {frame.name}"

    def _format_log_line(self, msg: str, pos: str) -> str:
        """Join the session prefix, message and source position into one log line."""
        return f"{self.log_prefix()} {msg} >>> @ {pos}"

    # Each log_* method calls parse_source_pos directly. Routing them through a
    # shared helper would add a stack frame and shift every ``level`` offset.
    def log_info(self, msg: str, level: int = -2) -> None:
        """Log ``msg`` at INFO level with the session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.info(self._format_log_line(msg, pos))

    def log_success(self, msg: str, level: int = -2) -> None:
        """Log ``msg`` at SUCCESS level with the session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.success(self._format_log_line(msg, pos))

    def log_warning(self, msg: str, level: int = -2) -> None:
        """Log ``msg`` at WARNING level with the session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.warning(self._format_log_line(msg, pos))

    def log_error(self, msg: str, level: int = -2) -> None:
        """Log ``msg`` at ERROR level with the session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.error(self._format_log_line(msg, pos))

    def log_exception(self, msg: str, level: int = -2) -> None:
        """Log ``msg`` via ``logger.exception`` with the session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.exception(self._format_log_line(msg, pos))


# Read the database settings once instead of calling get_settings() per option.
_db_settings = get_settings().database

engine_async = create_async_engine(
    _db_settings.url,
    echo=False,  # _db_settings.echo
    future=True,
    pool_size=_db_settings.pool_size,
    max_overflow=_db_settings.max_overflow,
    pool_pre_ping=True,
    pool_recycle=3600,  # seconds
)

# Factory producing DrSession objects bound to engine_async.
# NOTE(review): get_session below does not use this factory -- confirm intended.
AsyncSessionFactory = sessionmaker(
    bind=engine_async, class_=DrSession, expire_on_commit=False, autoflush=True
)


async def get_session(request: Request = None) -> AsyncIterator[DrSession]:
    """Yield a DrSession for one unit of work (used as a dependency).

    The session title records the HTTP method, path and client host of
    ``request``. Without a request it uses "无request" placeholders.

    If the consumer raises, this function:

    - logs the error as a session step,
    - rolls back the transaction,
    - re-raises the original exception unchanged.

    The session is always closed afterwards.

    Args:
        request: The incoming request, or None outside a request context.
            Deliberately annotated as plain ``Request`` (not Optional).
            Presumably this keeps FastAPI recognising it for injection --
            confirm before changing.

    Yields:
        A new ``DrSession`` bound to ``engine_async``.
    """
    if request is not None:
        url = f"{request.method} {request.url.path}"
        # request.client may be None (no client address in the ASGI scope);
        # dereferencing .host unconditionally would raise AttributeError.
        client_host = request.client.host if request.client is not None else "无client"
    else:
        url = "无request"
        client_host = "无request"
    logger.debug(url)

    # NOTE(review): created directly rather than via AsyncSessionFactory, so
    # that factory's expire_on_commit=False / autoflush=True do not apply here.
    session = DrSession(bind=engine_async)
    session.title = f"{url} - {client_host}"
    if request is not None:
        session.request = request

    try:
        yield session
    except Exception as exc:
        session.desc = f"EXCEPTION: 数据库 session 异常 >>> {exc}"
        await session.rollback()
        # Re-raise the original exception, not converted to HTTPException;
        # main.py catches it.
        raise
    finally:
        session.desc = ""
        await session.close()