142 lines
5.1 KiB
Python
142 lines
5.1 KiB
Python
"""
HTTP request context management, e.g. retrieving the currently logged-in user's information and token information.
"""
|
||
|
||
from contextvars import ContextVar
|
||
from typing import Optional
|
||
import threading
|
||
from ..models.user import User
|
||
from loguru import logger
|
||
|
||
# Primary storage: context-local slot holding the current user's dict
# snapshot; it reads as None until a user has been set.
current_user_context: ContextVar[Optional[dict]] = ContextVar("current_user", default=None)

# Secondary storage: per-thread object that carries the same snapshot in its
# ``current_user`` attribute and is consulted as a backup.
_thread_local = threading.local()
|
||
|
||
|
||
class UserContext:
    """Global accessor for the current user of the running request.

    The user is kept as a plain ``dict`` snapshot with the keys ``id``,
    ``username``, ``email``, ``full_name`` and ``is_active``. The snapshot is
    written to the module-level ``current_user_context`` ContextVar and, as a
    backup, to the ``current_user`` attribute of the module-level
    ``_thread_local`` object.
    """

    @staticmethod
    def _snapshot(user: User) -> dict:
        """Copy the fields kept in the context from *user* into a plain dict."""
        # A dictionary is stored instead of the SQLAlchemy model itself.
        return {
            'id': user.id,
            'username': user.username,
            'email': user.email,
            'full_name': user.full_name,
            'is_active': user.is_active,
        }

    @staticmethod
    def _clear_thread_local() -> None:
        """Remove the backup copy from thread-local storage, if one is present."""
        try:
            del _thread_local.current_user
        except AttributeError:
            # Nothing was stored on this thread; clearing is a no-op.
            pass

    @staticmethod
    def set_current_user(user: User, canLog: bool = False) -> None:
        """Store *user* as the current user.

        Args:
            user: Object exposing ``id``, ``username``, ``email``,
                ``full_name`` and ``is_active``.
            canLog: When true, log the assignment and a read-back check.
        """
        # Same storage path as the token-returning variant; the reset token
        # is simply not handed back to the caller.
        UserContext.set_current_user_with_token(user, canLog)

    @staticmethod
    def set_current_user_with_token(user: User, canLog: bool = False):
        """Store *user* as the current user and return the reset token.

        Args:
            user: Object exposing ``id``, ``username``, ``email``,
                ``full_name`` and ``is_active``.
            canLog: When true, log the assignment and a read-back check.

        Returns:
            The token produced by ``current_user_context.set``; pass it to
            ``reset_current_user_token`` for cleanup.
        """
        if canLog:
            logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")

        user_dict = UserContext._snapshot(user)

        # Write to the ContextVar first (keeping its token), then mirror the
        # same dict into thread-local storage as the backup copy.
        token = current_user_context.set(user_dict)
        _thread_local.current_user = user_dict

        if canLog:
            # Read the value back so the log shows what the ContextVar holds.
            verify_user = current_user_context.get()
            logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")

        return token

    @staticmethod
    def reset_current_user_token(token):
        """Restore the ContextVar via *token* and drop the thread-local copy.

        Args:
            token: Token previously returned by
                ``set_current_user_with_token``.

        Raises:
            ValueError, RuntimeError: Propagated from ``ContextVar.reset``
                when the token is not valid for this variable/context or has
                already been used. The thread-local copy is removed even in
                that case.
        """
        logger.info("[UserContext] - Resetting user context using token")

        try:
            current_user_context.reset(token)
        finally:
            # Must run even if reset() raises: otherwise the stale user would
            # remain reachable through the thread-local fallback in
            # get_current_user().
            UserContext._clear_thread_local()

    @staticmethod
    def get_current_user() -> Optional[dict]:
        """Return the current user's dict, or ``None`` if no user is stored.

        The ContextVar is consulted first; thread-local storage is the
        fallback. An error is logged when neither holds a user.
        """
        # NOTE(review): the thread-local value is per thread, not per
        # task/request - confirm the fallback cannot surface a user that a
        # different request stored on the same thread.
        user = current_user_context.get() or getattr(_thread_local, 'current_user', None)
        if user:
            return user

        logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
        return None

    @staticmethod
    def get_current_user_id() -> Optional[int]:
        """Return the ``id`` of the current user, or ``None`` if unavailable.

        Any exception raised during the lookup is logged and converted to
        ``None`` (best-effort accessor).
        """
        try:
            user = UserContext.get_current_user()
            return user.get('id') if user else None
        except Exception as e:
            logger.error(f"[UserContext] - Error getting current user ID: {e}")
            return None

    @staticmethod
    def clear_current_user(canLog: bool = False) -> None:
        """Remove the current user from both storages.

        Args:
            canLog: When true, log that the context is being cleared.
        """
        if canLog:
            logger.info("[UserContext] - 清除当前用户上下文")

        current_user_context.set(None)
        UserContext._clear_thread_local()

    @staticmethod
    def require_current_user() -> dict:
        """Return the current user's dict or fail with HTTP 401.

        Uses ``get_current_user``, so both the ContextVar and the
        thread-local fallback are checked.

        Raises:
            fastapi.HTTPException: Status 401 when no user is stored.
        """
        user = UserContext.get_current_user()
        if user is None:
            # Imported here rather than at module top, as in the original.
            from fastapi import HTTPException, status
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No authenticated user in context"
            )
        return user

    @staticmethod
    def require_current_user_id() -> int:
        """Return the current user's ``id`` or fail with HTTP 401.

        Raises:
            fastapi.HTTPException: Status 401 when no user is stored.
        """
        user = UserContext.require_current_user()
        return user.get('id')