""" 中间件管理,如上下文中间件:校验Token等 """ from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response from typing import Callable from loguru import logger from fastapi import status from utils.util_exceptions import HxfErrorResponse from ..db.database import get_session, AsyncSessionFactory, engine_async from sqlalchemy.ext.asyncio import AsyncSession from ..services.auth import AuthService from .context import UserContext class UserContextMiddleware(BaseHTTPMiddleware): """Middleware to set user context for authenticated requests.""" def __init__(self, app, exclude_paths: list = None): super().__init__(app) self.canLog = False # Paths that don't require authentication self.exclude_paths = exclude_paths or [ "/docs", "/redoc", "/openapi.json", "/api/auth/login", "/api/auth/register", "/api/auth/login-oauth", "/auth/login", "/auth/register", "/auth/login-oauth", "/health", "/static/" ] async def dispatch(self, request: Request, call_next: Callable) -> Response: """Process request and set user context if authenticated.""" if self.canLog: logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}") # Skip authentication for excluded paths path = request.url.path if self.canLog: logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}") should_skip = False for exclude_path in self.exclude_paths: # Exact match if path == exclude_path: should_skip = True if self.canLog: logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}") break # For paths ending with '/', check if request path starts with it elif exclude_path.endswith('/') and path.startswith(exclude_path): should_skip = True if self.canLog: logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头") break # For paths not ending with '/', check if request path starts with it + '/' elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'): should_skip = True if self.canLog: logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头") break if should_skip: if self.canLog: logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next") response = await call_next(request) return response if self.canLog: logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理") # Always clear any existing user context to ensure fresh authentication UserContext.clear_current_user(self.canLog) # Initialize context token user_token = None # Try to extract and validate token try: # Get authorization header authorization = request.headers.get("Authorization") if not authorization or not authorization.startswith("Bearer "): # No token provided, return 401 error return HxfErrorResponse( message="缺少或无效的授权头", status_code=status.HTTP_401_UNAUTHORIZED ) # Extract token token = authorization.split(" ")[1] # Verify token payload = AuthService.verify_token(token) if payload is None: # Invalid token, return 401 error return HxfErrorResponse( message="无效或过期的令牌", status_code=status.HTTP_401_UNAUTHORIZED ) # Get username from token username = payload.get("sub") if not username: return HxfErrorResponse( message="令牌负载无效", status_code=status.HTTP_401_UNAUTHORIZED ) # Get user from database from sqlalchemy import select from ..models.user import User # 创建一个临时的异步会话获取用户信息 session = AsyncSession(bind=engine_async) try: stmt = select(User).where(User.username == username) user = await session.execute(stmt) user = user.scalar_one_or_none() if not user: return HxfErrorResponse( message="用户不存在", status_code=status.HTTP_401_UNAUTHORIZED ) if not user.is_active: return HxfErrorResponse( message="用户账户已停用", status_code=status.HTTP_401_UNAUTHORIZED ) # Set user in context using token mechanism user_token = UserContext.set_current_user_with_token(user, self.canLog) if self.canLog: logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文") # Verify context is set correctly current_user_id = UserContext.get_current_user_id() if self.canLog: logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文") finally: await session.close() except Exception as e: # Log error but don't fail the request logger.error(f"[MIDDLEWARE] - 认证过程中设置用户上下文出错: {e}") # Return 401 error return HxfErrorResponse( message="认证过程中出错", status_code=status.HTTP_401_UNAUTHORIZED ) # Continue with request try: response = await call_next(request) return response except Exception as e: # Log error but don't fail the request logger.error(f"[MIDDLEWARE] - 请求处理出错: {e}") # Return 500 error return HxfErrorResponse(e) finally: # Always clear user context after request processing UserContext.clear_current_user(self.canLog) if self.canLog: logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")