2025-12-04 14:48:38 +08:00
|
|
|
"""User service for managing user operations."""
|
|
|
|
|
|
|
|
|
|
from typing import Optional, List, Tuple
|
|
|
|
|
from sqlalchemy.orm import Session
|
2025-12-16 13:55:16 +08:00
|
|
|
from sqlalchemy import select, or_, desc, text
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
|
|
|
from ..models.user import User
|
2025-12-16 13:55:16 +08:00
|
|
|
from utils.util_schemas import UserCreate, UserUpdate
|
|
|
|
|
from utils.util_exceptions import DatabaseError, ValidationError
|
2025-12-04 14:48:38 +08:00
|
|
|
from .auth import AuthService
|
|
|
|
|
|
|
|
|
|
class UserService:
    """Service for user management operations.

    Every public method records a short human-readable progress note in
    ``self.session.desc``. ``desc`` is not a SQLAlchemy attribute; it is an
    ad-hoc annotation whose consumer lives outside this file
    (NOTE(review): presumably logging/tracing — confirm). Because those notes
    may be surfaced elsewhere, they must never contain secrets such as
    plaintext passwords or password hashes.

    Although ``session`` is annotated as ``Session``, every database call below
    is awaited (``execute``/``commit``/``refresh``/``delete``/``rollback``), so
    an async session is what is actually required.
    """

    # Minimum password length enforced on create, update, change and reset.
    # The user-facing error messages spell out "6"; keep them in sync if this
    # value ever changes.
    _MIN_PASSWORD_LENGTH = 6

    def __init__(self, session: Session):  ### Async: OK
        """Bind the service to a database session.

        Args:
            session: Session used for every query issued by this service.
        """
        self.session = session
        self.session.desc = "创建UserService;-1"

    @classmethod
    def _ensure_password_length(cls, password: str, message: str) -> None:
        """Raise ``ValidationError(message)`` if ``password`` is too short."""
        if len(password) < cls._MIN_PASSWORD_LENGTH:
            raise ValidationError(message)

    async def _commit(self) -> None:
        """Commit the current transaction, rolling back before re-raising on failure.

        A failed commit (e.g. a constraint violation) otherwise leaves the
        session in a state where later operations on it also fail.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def get_user_by_email(self, email: str) -> Optional[User]:  ### Async: OK
        """Return the user whose ``email`` equals ``email`` exactly, or None.

        Raises:
            sqlalchemy's MultipleResultsFound if more than one row matches
            (behavior of ``scalar_one_or_none``).
        """
        self.session.desc = f"通过邮箱 [{email}] 获取用户"
        stmt = select(User).where(User.email == email)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_username(self, username: str) -> Optional[User]:  ### Async: OK
        """Return the user whose ``username`` equals ``username`` exactly, or None.

        Raises:
            sqlalchemy's MultipleResultsFound if more than one row matches
            (behavior of ``scalar_one_or_none``).
        """
        self.session.desc = f"通过用户名 [{username}] 获取用户"
        stmt = select(User).where(User.username == username)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_user(self, user_data: UserCreate) -> User:  ### Async: OK
        """Create, persist and return a new active user.

        Args:
            user_data: Payload providing ``username``, ``email``, ``password``
                and ``full_name``.

        Returns:
            The newly created user, refreshed from the database.

        Raises:
            ValidationError: If the password is shorter than 6 characters.
            Exception: Any database error raised while committing (for
                example an integrity error, if the schema enforces uniqueness)
                propagates unchanged after the session has been rolled back.
        """
        self.session.desc = f"创建用户 [{user_data.username}]"

        # Validate input
        if len(user_data.password) < self._MIN_PASSWORD_LENGTH:
            self.session.desc = "ERROR: 密码长度必须至少为6个字符"
            raise ValidationError("密码长度必须至少为6个字符")

        # Hash password. The plaintext is deliberately not echoed into desc.
        hashed_password = self.get_password_hash(user_data.password)
        self.session.desc = "对密码 [***] 进行哈希处理完毕"

        # Create user
        db_user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=hashed_password,
            full_name=user_data.full_name,
            is_active=True,
        )

        self.session.desc = f"创建用户 [{user_data.username}] 到数据库"
        self.session.add(db_user)
        self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - add"
        await self._commit()
        self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - commit"
        await self.session.refresh(db_user)

        self.session.desc = f"创建用户 [{user_data.username}] 成功"
        return db_user

    def get_password_hash(self, password: str) -> str:  ### Async: OK
        """Return the hash of ``password`` as computed by ``AuthService.get_password_hash``."""
        # Mask the plaintext: desc is a progress note and must not carry secrets.
        self.session.desc = "对密码 [***] 进行哈希处理"
        return AuthService.get_password_hash(password)

    async def get_user_by_id(self, user_id: int) -> Optional[User]:  ### DrGraph: OK
        """Return the user with ``User.id == user_id``, or None.

        ``User.roles`` is loaded with the ``noload`` strategy, so that
        relationship is left unloaded/empty on the returned instance.
        NOTE(review): presumably chosen to avoid implicit lazy loads on an
        async session — confirm before relying on ``user.roles``.
        """
        from sqlalchemy.orm import noload

        self.session.desc = f"通过ID{user_id}获取用户"
        stmt = select(User).where(User.id == user_id).options(noload(User.roles))
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:  ### DrGraph: OK
        """Return whether ``plain_password`` matches ``hashed_password`` (via ``AuthService``)."""
        # Neither the candidate password nor the stored hash is written to desc.
        self.session.desc = "验证密码 [***] 与哈希 [***] 是否匹配"
        return AuthService.verify_password(plain_password, hashed_password)

    async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[User]:  ### DrGraph: OK
        """Apply the explicitly-set fields of ``user_update`` to a user.

        Only fields the caller actually set (``model_dump(exclude_unset=True)``)
        are written. A ``password`` field is validated, hashed and stored as
        ``hashed_password``; an explicit ``None`` password is ignored instead
        of being passed to the hasher.

        Returns:
            The updated user, or None if no user has ``user_id``.

        Raises:
            ValidationError: If a new password shorter than 6 characters is given.
        """
        self.session.desc = f"更新用户ID为{user_id}的信息"
        user = await self.get_user_by_id(user_id)
        if user is None:
            return None

        # Update fields
        update_data = user_update.model_dump(exclude_unset=True)

        if "password" in update_data:
            new_password = update_data.pop("password")
            if new_password is not None:
                # Same minimum as create/change/reset so this path is not a bypass.
                self._ensure_password_length(
                    new_password, "New password must be at least 6 characters long"
                )
                update_data["hashed_password"] = self.get_password_hash(new_password)

        for field, value in update_data.items():
            setattr(user, field, value)

        # Audit fields are set automatically by SQLAlchemy event listener
        await self._commit()
        await self.session.refresh(user)
        return user

    async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:  ### DrGraph: OK
        """Change a user's password after verifying the current one.

        Returns:
            True once the new hash has been committed.

        Raises:
            ValidationError: If the user does not exist, ``current_password``
                is wrong, or ``new_password`` is shorter than 6 characters.
        """
        self.session.desc = f"更改用户ID为{user_id}的密码"
        user = await self.get_user_by_id(user_id)
        if user is None:
            raise ValidationError("User not found")

        # Verify current password
        if not self.verify_password(current_password, user.hashed_password):
            raise ValidationError("Current password is incorrect")

        self._ensure_password_length(
            new_password, "New password must be at least 6 characters long"
        )

        user.hashed_password = self.get_password_hash(new_password)
        await self._commit()

        self.session.desc = f"用户ID为{user_id}的密码已成功更改"
        return True

    async def reset_password(self, user_id: int, new_password: str) -> bool:  ### DrGraph: OK
        """Reset user password (admin only, no current password required).

        Returns:
            True once the new hash has been committed.

        Raises:
            ValidationError: If the user does not exist or ``new_password`` is
                shorter than 6 characters.
        """
        self.session.desc = f"重置用户ID为{user_id}的密码"
        user = await self.get_user_by_id(user_id)
        if user is None:
            raise ValidationError("User not found")

        self._ensure_password_length(
            new_password, "New password must be at least 6 characters long"
        )

        user.hashed_password = self.get_password_hash(new_password)
        await self._commit()

        self.session.desc = f"用户ID为{user_id}的密码已成功重置"
        return True

    async def get_users_with_filters(  ### DrGraph: OK
        self,
        skip: int = 0,
        limit: int = 100,
        search: Optional[str] = None,
        role_id: Optional[int] = None,
        is_active: Optional[bool] = None,
    ) -> Tuple[List[User], int]:
        """Return one page of filtered users plus the total number of matches.

        Args:
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            search: Case-insensitive substring matched against username,
                email or full name. Note that ``%``/``_`` in it are not
                escaped and act as LIKE wildcards.
            role_id: If given, only users linked to this role via ``UserRole``.
            is_active: If given, only users with this ``is_active`` value.

        Returns:
            ``(users, total)`` where ``users`` is ordered newest-first by
            ``created_at`` and ``total`` counts all matches ignoring pagination.
        """
        from sqlalchemy import func

        # Build base query
        stmt = select(User).order_by(desc(User.created_at))

        # Apply filters
        if search:
            search_term = f"%{search}%"
            stmt = stmt.where(
                or_(
                    User.username.ilike(search_term),
                    User.email.ilike(search_term),
                    User.full_name.ilike(search_term),
                )
            )

        if role_id is not None:
            from ..models.permission import UserRole

            stmt = stmt.join(UserRole).where(UserRole.role_id == role_id)

        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)

        # Total count over the filtered rows. ORDER BY cannot change a count,
        # so it is stripped from the derived table (it is wasted work, and some
        # backends reject ORDER BY in a subquery without LIMIT/TOP).
        count_stmt = select(func.count()).select_from(stmt.order_by(None).subquery())
        total_result = await self.session.execute(count_stmt)
        total = total_result.scalar_one()
        self.session.desc = f"获取用户总数为{total}"

        # Apply pagination
        page_stmt = stmt.offset(skip).limit(limit)
        users_result = await self.session.execute(page_stmt)
        users = users_result.scalars().all()

        return users, total

    async def get_users(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Return one page of users ordered by ``User.id``.

        An explicit ORDER BY is required for OFFSET/LIMIT pagination to be
        stable: without it the database may return rows in any order, so
        consecutive pages could overlap or skip users.
        """
        stmt = select(User).order_by(User.id).offset(skip).limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def delete_user(self, user_id: int) -> bool:  ### DrGraph: OK
        """Delete a user and its ``user_roles`` rows.

        Returns:
            True if the user existed and was deleted, False if no user has
            ``user_id``.

        Raises:
            Exception: Any database error propagates unchanged after the
                session has been rolled back (so no partial delete is kept).
        """
        self.session.desc = f"删除ID为{user_id}用户"
        user = await self.get_user_by_id(user_id)
        if user is None:
            return False

        # Manually delete related records to avoid cascade issues. The value is
        # bound on the TextClause itself, so the call does not depend on the
        # name of execute()'s parameters argument (SQLAlchemy's
        # Session.execute takes ``params``, not ``parameters``).
        delete_roles = text("DELETE FROM user_roles WHERE user_id = :user_id").bindparams(
            user_id=user_id
        )
        try:
            await self.session.execute(delete_roles)
            # Now delete the user
            await self.session.delete(user)
            await self.session.commit()
        except Exception:
            # Discard the partial delete so the session stays usable.
            await self.session.rollback()
            raise

        self.session.desc = f"用户ID为{user_id}已成功删除"
        return True
|