"""User service for managing user operations.""" from typing import Optional, List, Tuple from sqlalchemy.orm import Session from sqlalchemy import select, or_, desc, text from ..models.user import User from utils.util_schemas import UserCreate, UserUpdate from utils.util_exceptions import DatabaseError, ValidationError from .auth import AuthService class UserService: """Service for user management operations.""" def __init__(self, session: Session): ### Async: OK self.session = session self.session.desc = "创建UserService;-1" async def get_user_by_email(self, email: str) -> Optional[User]: ### Async: OK """Get user by email.""" self.session.desc = f"通过邮箱 [{email}] 获取用户" stmt = select(User).where(User.email == email) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_user_by_username(self, username: str) -> Optional[User]: ### Async: OK """Get user by username.""" self.session.desc = f"通过用户名 [{username}] 获取用户" stmt = select(User).where(User.username == username) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def create_user(self, user_data: UserCreate) -> User: ### Async: OK """Create a new user.""" self.session.desc = f"创建用户 [{user_data.username}]" # Validate input if len(user_data.password) < 6: self.session.desc = "ERROR: 密码长度必须至少为6个字符" raise ValidationError("密码长度必须至少为6个字符") # Hash password hashed_password = self.get_password_hash(user_data.password) self.session.desc = f"对密码 [{user_data.password}] 进行哈希处理完毕" # Create user db_user = User( username=user_data.username, email=user_data.email, hashed_password=hashed_password, full_name=user_data.full_name, is_active=True ) self.session.desc = f"创建用户 [{user_data.username}] 到数据库" self.session.add(db_user) self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - add" await self.session.commit() self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - commit" await self.session.refresh(db_user) self.session.desc = f"创建用户 [{user_data.username}] 成功" return db_user def get_password_hash(self, password: str) -> str: ### Async: OK """Hash a password.""" self.session.desc = f"对密码 [{password}] 进行哈希处理" return AuthService.get_password_hash(password) async def get_user_by_id(self, user_id: int) -> Optional[User]: ### DrGraph: OK """Get user by ID.""" self.session.desc = f"通过ID{user_id}获取用户" from sqlalchemy.orm import noload stmt = select(User).where(User.id == user_id).options(noload(User.roles)) result = await self.session.execute(stmt) return result.scalar_one_or_none() def verify_password(self, plain_password: str, hashed_password: str) -> bool: ### DrGraph: OK """Verify a password against its hash.""" self.session.desc = f"验证密码 [{plain_password}] 与哈希 [{hashed_password}] 是否匹配" return AuthService.verify_password(plain_password, hashed_password) async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[User]: ### DrGraph: OK """Update user information.""" self.session.desc = f"更新用户ID为{user_id}的信息" user = await self.get_user_by_id(user_id) if not user: return None # Update fields update_data = user_update.model_dump(exclude_unset=True) if "password" in update_data: update_data["hashed_password"] = self.get_password_hash(update_data.pop("password")) # session.desc = f"更新用户ID为{user_id}的信息" for field, value in update_data.items(): setattr(user, field, value) # Audit fields are set automatically by SQLAlchemy event listener await self.session.commit() await self.session.refresh(user) return user async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool: ### DrGraph: OK """Change user password.""" self.session.desc = f"更改用户ID为{user_id}的密码" user = await self.get_user_by_id(user_id) if not user: raise ValidationError("User not found") # Verify current password if not self.verify_password(current_password, user.hashed_password): raise ValidationError("Current password is incorrect") # Validate new password if len(new_password) < 6: raise ValidationError("New password must be at least 6 characters long") # Hash new password hashed_password = self.get_password_hash(new_password) # Update password user.hashed_password = hashed_password await self.session.commit() self.session.desc = f"用户ID为{user_id}的密码已成功更改" return True def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK """Reset user password (admin only, no current password required).""" self.session.desc = f"重置用户ID为{user_id}的密码" user = self.get_user_by_id(user_id) if not user: raise ValidationError("User not found") # Validate new password if len(new_password) < 6: raise ValidationError("New password must be at least 6 characters long") # Hash new password hashed_password = self.get_password_hash(new_password) # Update password user.hashed_password = hashed_password self.session.commit() self.session.desc = f"用户ID为{user_id}的密码已成功重置" return True async def get_users_with_filters( ### DrGraph: OK self, skip: int = 0, limit: int = 100, search: Optional[str] = None, role_id: Optional[int] = None, is_active: Optional[bool] = None ) -> Tuple[List[User], int]: """Get users with filters and return total count.""" # Build base query stmt = select(User).order_by(desc(User.created_at)) # Apply filters if search: search_term = f"%{search}%" stmt = stmt.where( or_( User.username.ilike(search_term), User.email.ilike(search_term), User.full_name.ilike(search_term) ) ) if role_id is not None: from ..models.permission import UserRole stmt = stmt.join(UserRole).where(UserRole.role_id == role_id) if is_active is not None: stmt = stmt.where(User.is_active == is_active) # Get total count count_stmt = select(text("COUNT(*)")).select_from(stmt.subquery()) total_result = await self.session.execute(count_stmt) total = total_result.scalar_one() self.session.desc = f"获取用户总数为{total}" # Apply pagination stmt = stmt.offset(skip).limit(limit) users_result = await self.session.execute(stmt) users = users_result.scalars().all() return users, total async def get_users(self, skip: int = 0, limit: int = 100) -> List[User]: """Get all users with pagination.""" # session.desc = f"分页获取用户列表,跳过{skip}条,限制{limit}条" stmt = select(User).offset(skip).limit(limit) result = await self.session.execute(stmt) return result.scalars().all() async def delete_user(self, user_id: int) -> bool: ### DrGraph: OK """删除一个用户.""" self.session.desc = f"删除ID为{user_id}用户" user = await self.get_user_by_id(user_id) if not user: return False # Manually delete related records to avoid cascade issues # Delete user_roles records await self.session.execute(text("DELETE FROM user_roles WHERE user_id = :user_id"), parameters={"user_id": user_id}) # Now delete the user await self.session.delete(user) await self.session.commit() self.session.desc = f"用户ID为{user_id}已成功删除" return True