hxf/backend/th_agenter/db/database.py

119 lines
3.8 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""Database connection and session management."""
2025-12-16 13:55:16 +08:00
import uuid, re
from loguru import logger
import traceback
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from typing import Optional
2025-12-04 14:48:38 +08:00
from ..core.config import get_settings
from .base import Base
2025-12-16 13:55:16 +08:00
from utils.util_exceptions import DatabaseError
2025-12-04 14:48:38 +08:00
2025-12-16 13:55:16 +08:00
# Custom Session class with desc property and unique ID
class DrSession(AsyncSession):
    """Async session that carries a short ID, a step counter and a work brief.

    The short ID and the work brief ("desc") are kept in the session's
    ``info`` mapping; the ``log_*`` helpers prefix every message with the
    session ID and append the source position of the caller.
    """

    def __init__(self, **kwargs):
        """Create the session and tag it with a short random identifier.

        Args:
            **kwargs: Forwarded unchanged to ``AsyncSession.__init__``.
        """
        super().__init__(**kwargs)
        # Make sure the ``info`` mapping exists before writing into it.
        if not hasattr(self, 'info'):
            self.info = {}
        # Keep only the first dash-separated group of a UUID4 as the ID.
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        # Counts how many times ``desc`` has been assigned on this session.
        self.stepIndex = 0

    @property
    def desc(self) -> Optional[str]:
        """Return the latest work brief, or ``None`` if none has been set."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Store *value* as the work brief and advance the step counter."""
        self.stepIndex += 1
        # Persist the brief so the ``desc`` getter can actually return it.
        self.info['desc'] = value

    def log_prefix(self) -> str:
        """Return the log prefix built from the session ID."""
        # NOTE(review): the opening bracket is never closed and ``desc`` is
        # not part of the prefix -- confirm the intended format.
        return f"〖Session{self.info['session_id']}"

    def parse_source_pos(self, level: int):
        """Describe the call-stack entry selected by *level*.

        The list returned by ``traceback.format_stack()`` is indexed at
        ``level - 1``. With the default ``level=-2`` used by the ``log_*``
        helpers this skips this method and the helper itself and selects
        the helper's caller.

        Args:
            level: Stack selector; ``level - 1`` is used as the list index.

        Returns:
            ``"<file>:<line> in <function>"`` when the entry's header line
            matches the expected pattern, otherwise the raw header line.
            ``"<unknown>"`` when *level* points outside the stack.
        """
        try:
            entry = traceback.format_stack()[level - 1]
        except IndexError:
            # A logging helper should not fail because of a bad selector.
            return "<unknown>"
        pos = entry.strip().split('\n')[0]
        match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos)
        if match:
            # Strip the hard-coded project root so the path is shown
            # relative to it.
            file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    def log_info(self, msg: str, level: int = -2):
        """Log *msg* at info level with the session prefix and caller position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log *msg* at success level with the session prefix and caller position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log *msg* at warning level with the session prefix and caller position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log *msg* at error level with the session prefix and caller position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log *msg* via ``logger.exception`` with the session prefix and caller position."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
2025-12-04 14:48:38 +08:00
2025-12-16 13:55:16 +08:00
# Async engine shared by the sessions created in this module.
_db_settings = get_settings().database
engine_async = create_async_engine(
    _db_settings.url,
    future=True,
    # Statement echo is hard-coded on; the configured value
    # (get_settings().database.echo) is not consulted.
    echo=True,
    pool_size=_db_settings.pool_size,
    max_overflow=_db_settings.max_overflow,
    pool_recycle=3600,
    pool_pre_ping=True,
)
from fastapi import Request
2025-12-04 14:48:38 +08:00
2025-12-16 13:55:16 +08:00
# Factory producing DrSession instances bound to the module-level engine.
AsyncSessionFactory = sessionmaker(
    class_=DrSession,
    bind=engine_async,
    autoflush=True,
    expire_on_commit=False,
)
async def get_session(request: Request = None):
    """Yield a ``DrSession`` bound to ``engine_async`` and close it afterwards.

    The session's ``desc`` is updated when the session is created, when an
    exception is raised while it is in use, and right before it is closed.

    Args:
        request: The current request, if any. When given, its method and URL
            path are included in the creation ``desc`` and the request is
            attached to the session as ``session.request``.

    Yields:
        The newly created ``DrSession``.

    Raises:
        Exception: Any exception raised while the session is in use is
            re-raised unchanged after the session has been rolled back.
    """
    url = "无request"
    if request:
        url = f"{request.method} {request.url.path}"

    # NOTE(review): the session is constructed directly rather than through
    # AsyncSessionFactory, so that factory's expire_on_commit=False /
    # autoflush=True options are not applied here -- confirm this is intended.
    session = DrSession(bind=engine_async)
    session.desc = f"SUCCESS: 创建数据库 session >>> {url}"

    # Expose the originating request on the session.
    if request:
        session.request = request

    try:
        yield session
    except Exception as e:
        session.desc = f"EXCEPTION: 数据库 session 异常 >>> {e}"
        await session.rollback()
        # Bare ``raise`` re-raises the active exception with its traceback.
        raise
    finally:
        session.desc = "数据库 session 关闭"
        await session.close()