hxf/backend/th_agenter/services/user.py

214 lines
8.8 KiB
Python
Raw Normal View History

2025-12-04 14:48:38 +08:00
"""User service for managing user operations."""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
2025-12-16 13:55:16 +08:00
from sqlalchemy import select, or_, desc, text
2025-12-04 14:48:38 +08:00
from ..models.user import User
2025-12-16 13:55:16 +08:00
from utils.util_schemas import UserCreate, UserUpdate
from utils.util_exceptions import DatabaseError, ValidationError
2025-12-04 14:48:38 +08:00
from .auth import AuthService
class UserService:
    """Service for user management operations.

    All database access goes through ``self.session``. Because ``execute``,
    ``commit``, ``refresh`` and ``delete`` are awaited, the session must be an
    async session object. NOTE(review): ``__init__`` annotates it as the sync
    ``Session``; presumably it should be ``AsyncSession`` — confirm.

    Methods write short human-readable progress messages to
    ``self.session.desc``. What consumes that attribute is not visible here;
    presumably it is a tracing/logging hook on the session (verify). Since the
    text may end up in logs, it must never contain passwords or password hashes.
    """

    def __init__(self, session: Session): ### Async: OK
        """Bind the service to *session* and record a creation marker on it."""
        self.session = session
        self.session.desc = "创建UserService;-1"

    async def get_user_by_email(self, email: str) -> Optional[User]: ### Async: OK
        """Return the user whose email equals *email*, or ``None``.

        Raises whatever ``scalar_one_or_none`` raises if more than one row matches.
        """
        self.session.desc = f"通过邮箱 [{email}] 获取用户"
        stmt = select(User).where(User.email == email)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_username(self, username: str) -> Optional[User]: ### Async: OK
        """Return the user whose username equals *username*, or ``None``.

        Raises whatever ``scalar_one_or_none`` raises if more than one row matches.
        """
        self.session.desc = f"通过用户名 [{username}] 获取用户"
        stmt = select(User).where(User.username == username)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_user(self, user_data: UserCreate) -> User: ### Async: OK
        """Create, commit and return a new active user.

        Args:
            user_data: Incoming user fields. ``password`` is plaintext and is
                hashed before being stored as ``hashed_password``.

        Returns:
            The committed and refreshed ``User`` instance.

        Raises:
            ValidationError: If the password is shorter than 6 characters.
        """
        self.session.desc = f"创建用户 [{user_data.username}]"
        # Validate input
        if len(user_data.password) < 6:
            self.session.desc = "ERROR: 密码长度必须至少为6个字符"
            raise ValidationError("密码长度必须至少为6个字符")
        # Hash password. The progress message deliberately omits the plaintext
        # password: session.desc may be logged.
        hashed_password = self.get_password_hash(user_data.password)
        self.session.desc = "对密码进行哈希处理完毕"
        # Create user
        db_user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=hashed_password,
            full_name=user_data.full_name,
            is_active=True,
        )
        self.session.desc = f"创建用户 [{user_data.username}] 到数据库"
        self.session.add(db_user)
        self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - add"
        await self.session.commit()
        self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - commit"
        # Reload server-generated columns (e.g. the primary key) after commit.
        await self.session.refresh(db_user)
        self.session.desc = f"创建用户 [{user_data.username}] 成功"
        return db_user

    def get_password_hash(self, password: str) -> str: ### Async: OK
        """Return the hash of *password* as produced by ``AuthService.get_password_hash``."""
        # Never put the plaintext password into session.desc (it may be logged).
        self.session.desc = "对密码进行哈希处理"
        return AuthService.get_password_hash(password)

    async def get_user_by_id(self, user_id: int) -> Optional[User]: ### DrGraph: OK
        """Return the user with primary key *user_id*, or ``None``.

        The ``roles`` relationship is configured with ``noload`` for this query,
        so no role rows are fetched along with the user.
        """
        self.session.desc = f"通过ID{user_id}获取用户"
        from sqlalchemy.orm import noload
        stmt = select(User).where(User.id == user_id).options(noload(User.roles))
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    def verify_password(self, plain_password: str, hashed_password: str) -> bool: ### DrGraph: OK
        """Return whether *plain_password* matches *hashed_password* (via ``AuthService``)."""
        # Neither the plaintext password nor the stored hash may be written to
        # session.desc, since that text may be logged.
        self.session.desc = "验证密码与哈希是否匹配"
        return AuthService.verify_password(plain_password, hashed_password)

    async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[User]: ### DrGraph: OK
        """Apply the fields explicitly set on *user_update* to an existing user.

        A ``password`` field, if set, is validated and stored hashed as
        ``hashed_password``. A ``None`` password leaves the stored hash unchanged.

        Args:
            user_id: Primary key of the user to update.
            user_update: Partial update; only fields the caller actually set
                (``model_dump(exclude_unset=True)``) are applied.

        Returns:
            The updated user, or ``None`` if no user has *user_id*.

        Raises:
            ValidationError: If a new password shorter than 6 characters is given.
        """
        self.session.desc = f"更新用户ID为{user_id}的信息"
        user = await self.get_user_by_id(user_id)
        if not user:
            return None
        update_data = user_update.model_dump(exclude_unset=True)
        if "password" in update_data:
            new_password = update_data.pop("password")
            if new_password is not None:
                # Same minimum-length rule as create/change/reset. Without it,
                # this path could store an empty or too-short password.
                if len(new_password) < 6:
                    raise ValidationError("New password must be at least 6 characters long")
                update_data["hashed_password"] = self.get_password_hash(new_password)
        for field, value in update_data.items():
            setattr(user, field, value)
        # Audit fields are expected to be filled in by a SQLAlchemy event
        # listener (not visible in this file).
        await self.session.commit()
        await self.session.refresh(user)
        return user

    async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool: ### DrGraph: OK
        """Change a user's password after verifying the current one.

        Returns:
            ``True`` once the new hash has been committed.

        Raises:
            ValidationError: If the user does not exist, the current password is
                wrong, or the new password is shorter than 6 characters.
        """
        self.session.desc = f"更改用户ID为{user_id}的密码"
        user = await self.get_user_by_id(user_id)
        if not user:
            raise ValidationError("User not found")
        if not self.verify_password(current_password, user.hashed_password):
            raise ValidationError("Current password is incorrect")
        if len(new_password) < 6:
            raise ValidationError("New password must be at least 6 characters long")
        user.hashed_password = self.get_password_hash(new_password)
        await self.session.commit()
        self.session.desc = f"用户ID为{user_id}的密码已成功更改"
        return True

    async def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK
        """Reset a user's password without requiring the current one (admin use).

        This is a coroutine and must be awaited, like the other database
        methods here. The user lookup and the commit are both async, so the
        method cannot be implemented synchronously.

        Returns:
            ``True`` once the new hash has been committed.

        Raises:
            ValidationError: If the user does not exist or the new password is
                shorter than 6 characters.
        """
        self.session.desc = f"重置用户ID为{user_id}的密码"
        user = await self.get_user_by_id(user_id)
        if not user:
            raise ValidationError("User not found")
        if len(new_password) < 6:
            raise ValidationError("New password must be at least 6 characters long")
        user.hashed_password = self.get_password_hash(new_password)
        await self.session.commit()
        self.session.desc = f"用户ID为{user_id}的密码已成功重置"
        return True

    async def get_users_with_filters( ### DrGraph: OK
        self,
        skip: int = 0,
        limit: int = 100,
        search: Optional[str] = None,
        role_id: Optional[int] = None,
        is_active: Optional[bool] = None,
    ) -> Tuple[List[User], int]:
        """Return one page of filtered users plus the total number of matches.

        Args:
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            search: Case-insensitive substring matched against username, email
                and full name (``ilike``). Falsy values disable the filter.
            role_id: If given, only users linked to this role via ``UserRole``.
            is_active: If given, only users with this ``is_active`` value.

        Returns:
            ``(users, total)``: the users on the requested page, newest
            ``created_at`` first, and the count of all matching users.
        """
        stmt = select(User)
        if search:
            search_term = f"%{search}%"
            stmt = stmt.where(
                or_(
                    User.username.ilike(search_term),
                    User.email.ilike(search_term),
                    User.full_name.ilike(search_term),
                )
            )
        if role_id is not None:
            from ..models.permission import UserRole
            stmt = stmt.join(UserRole).where(UserRole.role_id == role_id)
        if is_active is not None:
            stmt = stmt.where(User.is_active == is_active)

        # Count the filtered rows before ordering and paginating. ORDER BY is
        # irrelevant to a count, and some databases reject it in a subquery.
        count_stmt = select(text("COUNT(*)")).select_from(stmt.subquery())
        total_result = await self.session.execute(count_stmt)
        total = total_result.scalar_one()
        self.session.desc = f"获取用户总数为{total}"

        # User.id breaks ties between equal created_at values, so pages do not
        # overlap or skip rows.
        page_stmt = (
            stmt.order_by(desc(User.created_at), desc(User.id))
            .offset(skip)
            .limit(limit)
        )
        users_result = await self.session.execute(page_stmt)
        users = list(users_result.scalars().all())
        return users, total

    async def get_users(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Return one page of users ordered by id.

        An explicit ORDER BY is required for OFFSET/LIMIT paging to be
        deterministic; without it the database may return rows in any order.
        """
        stmt = select(User).order_by(User.id).offset(skip).limit(limit)
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def delete_user(self, user_id: int) -> bool: ### DrGraph: OK
        """Delete a user together with its ``user_roles`` link rows.

        Returns:
            ``True`` if the user existed and was deleted, ``False`` otherwise.
        """
        self.session.desc = f"删除ID为{user_id}用户"
        user = await self.get_user_by_id(user_id)
        if not user:
            return False
        # Remove the role links explicitly first; the original author notes this
        # avoids cascade issues. user_id is bound as a parameter, not
        # interpolated into the SQL.
        await self.session.execute(
            text("DELETE FROM user_roles WHERE user_id = :user_id"),
            parameters={"user_id": user_id},
        )
        await self.session.delete(user)
        await self.session.commit()
        self.session.desc = f"用户ID为{user_id}已成功删除"
        return True