hxf/backend/th_agenter/db/database.py

89 lines
2.5 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""Database connection and session management."""
2025-12-17 19:26:36 +08:00
import logging
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
2025-12-04 14:48:38 +08:00
from ..core.config import get_settings
from .base import Base
2025-12-17 19:26:36 +08:00
# Module-level singletons shared by every helper below. Both stay None until
# create_database_engine() rebinds them; the helpers check for None and call
# it lazily on first use.
engine = SessionLocal = None
def create_database_engine():
    """Create the SQLAlchemy engine and session factory from application settings.

    Reads ``settings.database.url`` and ``settings.database.echo``. For
    PostgreSQL it also reads ``pool_size`` and ``max_overflow``. It then rebinds
    the module-level ``engine`` and ``SessionLocal`` globals, so calling it
    again replaces both.

    Raises:
        ValueError: if the configured URL is neither SQLite nor PostgreSQL.
    """
    global engine, SessionLocal

    settings = get_settings()
    database_url = settings.database.url

    engine_kwargs = {
        "echo": settings.database.echo,
    }

    # Determine the database type; only PostgreSQL gets pool tuning options.
    if database_url.startswith("sqlite"):
        db_type = "SQLite"
    elif database_url.startswith("postgresql"):
        db_type = "PostgreSQL"
        engine_kwargs.update({
            "pool_size": settings.database.pool_size,
            "max_overflow": settings.database.max_overflow,
            # Test each pooled connection on checkout and transparently
            # replace it if it has gone stale.
            "pool_pre_ping": True,
            # Recycle pooled connections once they are older than 3600 s.
            "pool_recycle": 3600,
        })
    else:
        # Report only the scheme: the full URL may embed a password.
        scheme = database_url.split(":", 1)[0]
        raise ValueError(
            f"Unsupported database type. Please use PostgreSQL or SQLite. URL scheme: {scheme}"
        )

    engine = create_engine(database_url, **engine_kwargs)
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    # repr() of a SQLAlchemy URL masks the password; the raw string does not.
    logging.info("%s database engine created: %r", db_type, engine.url)
async def init_db():
    """Ensure the engine exists, register the ORM models, and create their tables.

    Declared ``async`` so it can be awaited, but the body contains no
    ``await``. Engine creation and ``create_all`` run synchronously and block
    the calling event loop while they execute.
    """
    if engine is None:
        create_database_engine()

    # Imported only for the side effect of registering the models with
    # Base.metadata; the module names themselves are not used here.
    from ..models import user, conversation, message, knowledge_base, permission, workflow  # noqa: F401

    # create_all skips tables that already exist (checkfirst defaults to True).
    Base.metadata.create_all(bind=engine)
    logging.info("Database tables created")
def get_db() -> Generator[Session, None, None]:
    """Yield a database session and always close it afterwards.

    The engine is created lazily on the first call. If an exception
    propagates through the ``yield``, the session is rolled back, the error is
    logged, and the exception is re-raised. The session is closed in every
    case. Presumably consumed as a framework dependency (e.g. FastAPI
    ``Depends``); confirm against callers.
    """
    if SessionLocal is None:
        create_database_engine()

    session = SessionLocal()
    try:
        yield session
    except Exception as e:
        # Discard uncommitted work before the error reaches the caller.
        session.rollback()
        logging.error(f"Database session error: {e}")
        raise
    finally:
        session.close()
def get_db_session() -> Session:
    """Return a new database session (synchronous).

    The engine is created lazily on the first call. The session is not closed
    here, so the caller is responsible for closing it.
    """
    if SessionLocal is not None:
        return SessionLocal()
    create_database_engine()
    return SessionLocal()