"""Database base model."""

from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, Integer, MetaData, event
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql import func


class Base(DeclarativeBase):
    """Declarative base whose metadata applies a constraint naming convention.

    A fixed naming convention gives indexes and constraints deterministic
    names instead of backend-generated ones.
    """

    metadata = MetaData(
        naming_convention={
            # ix: index
            "ix": "ix_%(column_0_label)s",
            # uq: unique constraint
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            # ck: check constraint
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            # fk: foreign key constraint
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            # pk: primary key constraint
            "pk": "pk_%(table_name)s",
        }
    )


class BaseModel(Base):
    """Base model with common fields."""

    __abstract__ = True

    # NOTE(review): primary-key columns are normally already indexed by the
    # backend; ``index=True`` additionally declares an explicit "ix_..." index.
    # Left as-is to avoid a schema/migration change -- confirm it is intended.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # Timestamps come from the SQL ``now()`` function at INSERT time;
    # ``updated_at`` is refreshed on every UPDATE emitted by the ORM.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=func.now(), onupdate=func.now(), nullable=False
    )
    # Audit fields; populated via ``set_audit_fields``.
    created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    def to_dict(self) -> dict:
        """Convert model to dictionary.

        Keys are the table's column names. Values that provide an
        ``isoformat()`` method (e.g. ``datetime``) are converted to their
        ISO-8601 string; every other value is returned unchanged.

        Returns:
            A new dict mapping column name to (possibly serialized) value.
        """
        result = {}
        for column in self.__table__.columns:
            # Read each attribute once: on an expired ORM instance every
            # access may trigger a load, so repeated getattr calls are wasteful.
            value = getattr(self, column.name)
            result[column.name] = value.isoformat() if hasattr(value, "isoformat") else value
        return result

    @classmethod
    def from_dict(cls, data: dict):
        """Create model instance from dictionary.

        Keys that do not correspond to a column of the model's table are
        silently ignored.

        Args:
            data: Dictionary containing model field values

        Returns:
            Model instance created from the dictionary
        """
        # Filter out fields that don't exist in the model
        model_fields = {column.name for column in cls.__table__.columns}
        filtered_data = {key: value for key, value in data.items() if key in model_fields}
        # Create and return the instance
        return cls(**filtered_data)

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False) -> None:
        """Set the ``created_by`` / ``updated_by`` audit fields.

        This is best-effort: if no user ID is given and none can be obtained
        from the request context, the fields are left untouched.

        Args:
            user_id: ID of the user performing the create/update operation
                (optional; when omitted it is looked up from ``UserContext``).
            is_update: True for an update operation, False for a create
                operation.
        """
        # If user_id was not provided, fetch it from the context.
        if user_id is None:
            # Imported lazily here -- presumably to avoid a circular import
            # between models and core; confirm before moving to module level.
            from ..core.context import UserContext

            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # No user available in the context: skip setting audit fields.
                return

        # Still no user ID: skip setting audit fields.
        if user_id is None:
            return

        # A create sets both fields; an update only refreshes updated_by.
        if not is_update:
            self.created_by = user_id
        self.updated_by = user_id


# NOTE: A disabled alternative is kept below for reference. It is a SQLAlchemy
# ``before_flush`` listener that would stamp the audit fields on every new or
# dirty BaseModel instance in the session. A variant that overrode ``__init__``
# to call ``set_audit_fields()`` on construction was also drafted. Within this
# module, the audit fields are only set when ``set_audit_fields`` is called
# explicitly.
#
# @event.listens_for(Session, "before_flush")
# def set_audit_fields_before_flush(session, flush_context, instances):
#     """Automatically set audit fields before flush."""
#     try:
#         from th_agenter.core.context import UserContext
#         user_id = UserContext.get_current_user_id()
#     except Exception:
#         user_id = None
#     # Handle newly added objects
#     for instance in session.new:
#         if isinstance(instance, BaseModel) and user_id:
#             instance.created_by = user_id
#             instance.updated_by = user_id
#     # Handle modified objects
#     for instance in session.dirty:
#         if isinstance(instance, BaseModel) and user_id:
#             instance.updated_by = user_id