hxf/backend/th_agenter/services/user.py

282 lines
10 KiB
Python

"""User service for managing user operations."""
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import or_, and_, desc
from ..models.user import User
from ..utils.schemas import UserCreate, UserUpdate
from ..utils.exceptions import DatabaseError, ValidationError
from ..utils.logger import get_logger
from .auth import AuthService
logger = get_logger(__name__)
class UserService:
"""Service for user management operations."""
def __init__(self, db: Session):
self.db = db
def get_password_hash(self, password: str) -> str:
"""Hash a password."""
return AuthService.get_password_hash(password)
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return AuthService.verify_password(plain_password, hashed_password)
def get_user_by_id(self, user_id: int) -> Optional[User]:
"""Get user by ID."""
try:
# Use options to avoid loading problematic relationships
from sqlalchemy.orm import noload
return self.db.query(User).options(
noload(User.roles)
).filter(User.id == user_id).first()
except Exception as e:
logger.error(f"Error getting user by ID {user_id}: {e}")
raise DatabaseError(f"Failed to get user: {str(e)}")
def get_user(self, user_id: int) -> Optional[User]:
"""Get user by ID (alias for get_user_by_id)."""
return self.get_user_by_id(user_id)
def get_user_by_email(self, email: str) -> Optional[User]:
"""Get user by email."""
try:
return self.db.query(User).filter(User.email == email).first()
except Exception as e:
logger.error(f"Error getting user by email {email}: {e}")
raise DatabaseError(f"Failed to get user: {str(e)}")
def get_user_by_username(self, username: str) -> Optional[User]:
"""Get user by username."""
try:
return self.db.query(User).filter(User.username == username).first()
except Exception as e:
logger.error(f"Error getting user by username {username}: {e}")
raise DatabaseError(f"Failed to get user: {str(e)}")
def create_user(self, user_data: UserCreate) -> User:
"""Create a new user."""
try:
# Validate input
if len(user_data.password) < 6:
raise ValidationError("Password must be at least 6 characters long")
if len(user_data.password) > 72:
raise ValidationError("Password cannot be longer than 72 characters")
# Hash password
hashed_password = self.get_password_hash(user_data.password)
# Create user
db_user = User(
username=user_data.username,
email=user_data.email,
hashed_password=hashed_password,
full_name=user_data.full_name,
is_active=True
)
# Set audit fields
db_user.set_audit_fields()
self.db.add(db_user)
self.db.commit()
self.db.refresh(db_user)
logger.info(f"User created successfully: {user_data.username}")
return db_user
except ValidationError:
raise
except Exception as e:
self.db.rollback()
logger.error(f"Error creating user {user_data.username}: {e}")
raise DatabaseError(f"Failed to create user: {str(e)}")
def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[User]:
"""Update user information."""
try:
user = self.get_user_by_id(user_id)
if not user:
return None
# Update fields
update_data = user_update.dict(exclude_unset=True)
if "password" in update_data:
update_data["hashed_password"] = self.get_password_hash(update_data.pop("password"))
for field, value in update_data.items():
setattr(user, field, value)
# Skip audit fields for now due to database schema mismatch
# user.set_audit_fields(is_update=True)
self.db.commit()
self.db.refresh(user)
logger.info(f"User updated successfully: {user.username}")
return user
except Exception as e:
self.db.rollback()
logger.error(f"Error updating user {user_id}: {e}")
raise DatabaseError(f"Failed to update user: {str(e)}")
def get_users(self, skip: int = 0, limit: int = 100) -> List[User]:
"""Get all users with pagination."""
try:
return self.db.query(User).offset(skip).limit(limit).all()
except Exception as e:
logger.error(f"Error getting users: {e}")
raise DatabaseError(f"Failed to get users: {str(e)}")
def get_users_with_filters(
self,
skip: int = 0,
limit: int = 100,
search: Optional[str] = None,
role_id: Optional[int] = None,
is_active: Optional[bool] = None
) -> Tuple[List[User], int]:
"""Get users with filters and return total count."""
try:
query = self.db.query(User).order_by(desc('created_at'))
# Apply filters
if search:
search_term = f"%{search}%"
query = query.filter(
or_(
User.username.ilike(search_term),
User.email.ilike(search_term),
User.full_name.ilike(search_term)
)
)
if role_id is not None:
from ..models.permission import UserRole
query = query.join(UserRole).filter(UserRole.role_id == role_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
# Get total count
total = query.count()
# Apply pagination
users = query.offset(skip).limit(limit).all()
return users, total
except Exception as e:
logger.error(f"Error getting users with filters: {e}")
raise DatabaseError(f"Failed to get users: {str(e)}")
def delete_user(self, user_id: int) -> bool:
"""Delete a user."""
try:
user = self.get_user_by_id(user_id)
if not user:
return False
# Manually delete related records to avoid cascade issues
from sqlalchemy import text
# Delete user_roles records
self.db.execute(text("DELETE FROM user_roles WHERE user_id = :user_id"), {"user_id": user_id})
# Now delete the user
self.db.delete(user)
self.db.commit()
logger.info(f"User deleted successfully: {user.username}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"Error deleting user {user_id}: {e}")
raise DatabaseError(f"Failed to delete user: {str(e)}")
def authenticate_user(self, username: str, password: str) -> Optional[User]:
"""Authenticate user with username and password."""
try:
user = self.get_user_by_username(username)
if not user:
return None
if not self.verify_password(password, user.hashed_password):
return None
if not user.is_active:
return None
return user
except Exception as e:
logger.error(f"Error authenticating user {username}: {e}")
return None
def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
"""Change user password."""
try:
user = self.get_user_by_id(user_id)
if not user:
raise ValidationError("User not found")
# Verify current password
if not self.verify_password(current_password, user.hashed_password):
raise ValidationError("Current password is incorrect")
# Validate new password
if len(new_password) < 6:
raise ValidationError("New password must be at least 6 characters long")
# Hash new password
hashed_password = self.get_password_hash(new_password)
# Update password
user.hashed_password = hashed_password
self.db.commit()
logger.info(f"Password changed successfully for user: {user.username}")
return True
except ValidationError:
raise
except Exception as e:
self.db.rollback()
logger.error(f"Error changing password for user {user_id}: {e}")
raise DatabaseError(f"Failed to change password: {str(e)}")
def reset_password(self, user_id: int, new_password: str) -> bool:
"""Reset user password (admin only, no current password required)."""
try:
user = self.get_user_by_id(user_id)
if not user:
raise ValidationError("User not found")
# Validate new password
if len(new_password) < 6:
raise ValidationError("New password must be at least 6 characters long")
# Hash new password
hashed_password = self.get_password_hash(new_password)
# Update password
user.hashed_password = hashed_password
self.db.commit()
logger.info(f"Password reset successfully for user: {user.username}")
return True
except ValidationError:
raise
except Exception as e:
self.db.rollback()
logger.error(f"Error resetting password for user {user_id}: {e}")
raise DatabaseError(f"Failed to reset password: {str(e)}")