144 lines
5.4 KiB
Python
144 lines
5.4 KiB
Python
"""Database base model."""
|
||
|
||
from datetime import datetime
|
||
from sqlalchemy import Integer, DateTime, event
|
||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
|
||
from sqlalchemy.sql import func
|
||
from typing import Optional
|
||
from sqlalchemy import MetaData
|
||
|
||
# Naming templates applied to every index/constraint created through
# Base.metadata, so generated names are deterministic rather than left to the
# database backend. The ``ck`` template embeds ``%(constraint_name)s``, so
# check constraints are expected to carry an explicit name.
_NAMING_CONVENTION = dict(
    ix="ix_%(column_0_label)s",  # index
    uq="uq_%(table_name)s_%(column_0_name)s",  # unique constraint
    ck="ck_%(table_name)s_%(constraint_name)s",  # check constraint
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",  # foreign key
    pk="pk_%(table_name)s",  # primary key
)


class Base(DeclarativeBase):
    """Declarative base whose MetaData uses the project's naming convention."""

    metadata = MetaData(naming_convention=_NAMING_CONVENTION)
|
||
|
||
|
||
class BaseModel(Base):
    """Abstract base model with common fields and (de)serialization helpers.

    Subclasses inherit an integer primary key, creation/update timestamps and
    nullable ``created_by``/``updated_by`` audit columns. ``__abstract__`` is
    set, so no table is mapped for this class itself.
    """

    __abstract__ = True

    # NOTE(review): ``index=True`` on a primary-key column likely adds an index
    # that duplicates the primary-key index; kept as-is because removing it
    # would change the generated schema/migrations — confirm before dropping.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # Both timestamps default to the SQL ``now()`` expression on INSERT;
    # ``updated_at`` is also set to ``now()`` via ``onupdate`` on UPDATE.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=func.now(), onupdate=func.now(), nullable=False
    )
    # IDs of the users who created / last updated the row. Filled in by
    # set_audit_fields(); left NULL when no user ID is available.
    created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    def to_dict(self) -> dict:
        """Convert the model to a dictionary keyed by column name.

        Values that expose ``isoformat()`` (such as ``datetime``) are
        converted to their ISO-8601 string; all other values are returned
        unchanged.

        Each value is read with ``getattr(self, column.name)``, so this
        assumes every column in ``self.__table__`` is mapped to an attribute
        of the same name.
        """
        result = {}
        for column in self.__table__.columns:
            # Read each attribute once and reuse it for the check and the value.
            value = getattr(self, column.name)
            result[column.name] = value.isoformat() if hasattr(value, "isoformat") else value
        return result

    @classmethod
    def from_dict(cls, data: dict):
        """Create a model instance from a dictionary.

        Keys that are not column names of ``cls.__table__`` are silently
        dropped; the remaining items are passed to the constructor as keyword
        arguments.

        Note: every column is accepted, including ``id``, the timestamps and
        the audit fields. Strip those keys first if ``data`` comes from an
        untrusted source.

        Args:
            data: Dictionary containing model field values.

        Returns:
            A new, unsaved instance of ``cls``.
        """
        model_fields = {column.name for column in cls.__table__.columns}
        filtered_data = {key: value for key, value in data.items() if key in model_fields}
        return cls(**filtered_data)

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
        """Set ``created_by`` / ``updated_by`` for a create or update operation.

        Args:
            user_id: ID of the user performing the operation. When None, it is
                looked up with ``UserContext.get_current_user_id()``.
            is_update: True for an update (only ``updated_by`` is set), False
                for a create (both ``created_by`` and ``updated_by`` are set).

        Leaves both fields untouched when no user ID is available, i.e. when
        the context lookup raises or when the resulting ID is None.
        """
        if user_id is None:
            # Imported lazily; NOTE(review): presumably to avoid a circular
            # import with the core package — confirm.
            from ..core.context import UserContext

            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # Deliberate best-effort: without a user in context the audit
                # fields are simply skipped.
                return

        # Still no user ID (caller or context returned None): nothing to record.
        if user_id is None:
            return

        if not is_update:
            # Creation records the creator; the updater is set on every path.
            self.created_by = user_id
        self.updated_by = user_id
|
||
|
||
# @event.listens_for(Session, 'before_flush')
|
||
# def set_audit_fields_before_flush(session, flush_context, instances):
|
||
# """Automatically set audit fields before flush."""
|
||
# try:
|
||
# from th_agenter.core.context import UserContext
|
||
# user_id = UserContext.get_current_user_id()
|
||
# except Exception:
|
||
# user_id = None
|
||
|
||
# # 处理新增对象
|
||
# for instance in session.new:
|
||
# if isinstance(instance, BaseModel) and user_id:
|
||
# instance.created_by = user_id
|
||
# instance.updated_by = user_id
|
||
|
||
# # 处理修改对象
|
||
# for instance in session.dirty:
|
||
# if isinstance(instance, BaseModel) and user_id:
|
||
# instance.updated_by = user_id
|
||
|
||
# # def __init__(self, **kwargs):
|
||
# # """Initialize model with automatic audit fields setting."""
|
||
# # super().__init__(**kwargs)
|
||
# # # Set audit fields for new instances
|
||
# # self.set_audit_fields()
|
||
|
||
# # def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||
# # """Set audit fields for create/update operations.
|
||
|
||
# # Args:
|
||
# # user_id: ID of the user performing the operation (optional, will use context if not provided)
|
||
# # is_update: True for update operations, False for create operations
|
||
# # """
|
||
# # # Get user_id from context if not provided
|
||
# # if user_id is None:
|
||
# # from ..core.context import UserContext
|
||
# # try:
|
||
# # user_id = UserContext.get_current_user_id()
|
||
# # except Exception:
|
||
# # # If no user in context, skip setting audit fields
|
||
# # return
|
||
|
||
# # # Skip if still no user_id
|
||
# # if user_id is None:
|
||
# # return
|
||
|
||
# # if not is_update:
|
||
# # # For create operations, set both create_by and update_by
|
||
# # self.created_by = user_id
|
||
# # self.updated_by = user_id
|
||
# # else:
|
||
# # # For update operations, only set update_by
|
||
# # self.updated_by = user_id
|
||
|