2025-12-04 14:48:38 +08:00
|
|
|
|
"""
|
|
|
|
|
|
中间件管理,如上下文中间件:校验Token等
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from fastapi import Request
|
2025-12-04 14:48:38 +08:00
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
|
from starlette.responses import Response
|
|
|
|
|
|
from typing import Callable
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from loguru import logger
|
|
|
|
|
|
from fastapi import status
|
|
|
|
|
|
from utils.util_exceptions import HxfErrorResponse
|
2025-12-04 14:48:38 +08:00
|
|
|
|
|
2025-12-16 13:55:16 +08:00
|
|
|
|
from ..db.database import get_session, AsyncSessionFactory, engine_async
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2025-12-04 14:48:38 +08:00
|
|
|
|
from ..services.auth import AuthService
|
|
|
|
|
|
from .context import UserContext
|
2025-12-16 13:55:16 +08:00
|
|
|
|
|
2025-12-04 14:48:38 +08:00
|
|
|
|
class UserContextMiddleware(BaseHTTPMiddleware):
    """Authenticate requests with a Bearer token and expose the user via ``UserContext``.

    For every request whose path is not excluded, the middleware:

    1. reads the ``Authorization: Bearer <token>`` header,
    2. verifies the token with ``AuthService.verify_token``,
    3. loads the ``User`` whose ``username`` equals the token's ``sub`` claim,
    4. rejects unknown or inactive users,
    5. stores the user in ``UserContext`` while the downstream handler runs.

    Every authentication failure short-circuits with a 401 ``HxfErrorResponse``.
    The user context is cleared before authentication starts and again on every
    exit path afterwards, whether the request succeeds, is rejected or errors.
    """

    def __init__(self, app, exclude_paths: list = None, can_log: bool = False):
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            exclude_paths: Paths that skip authentication; see
                ``_is_excluded_path`` for the matching rules. ``None`` or an
                empty list falls back to the built-in defaults below.
            can_log: Enable verbose per-request logging (off by default).
        """
        super().__init__(app)
        # Verbose logging switch; the attribute keeps its historical name.
        self.canLog = can_log
        # Paths that don't require authentication
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/login-oauth",
            "/auth/login",
            "/auth/register",
            "/auth/login-oauth",
            "/health",
            "/static/",
        ]

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` should bypass authentication.

        A path matches an entry of ``self.exclude_paths`` when:

        * it equals the entry exactly;
        * the entry ends with ``/`` and the path starts with it
          (``/static/`` covers ``/static/app.js``); or
        * the entry does not end with ``/`` and the path lies beneath it
          (``/docs`` covers ``/docs/x`` but not ``/docsx``).

        Empty entries are ignored. Without this guard, ``""`` would match every
        path via ``path.startswith("/")`` and silently disable authentication.
        """
        for exclude_path in self.exclude_paths:
            if not exclude_path:
                continue

            if path == exclude_path:
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
                return True

            if exclude_path.endswith('/'):
                if path.startswith(exclude_path):
                    if self.canLog:
                        logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
                    return True
            elif path.startswith(exclude_path + '/'):
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
                return True

        return False

    @staticmethod
    def _unauthorized(message: str) -> Response:
        """Build the 401 error response used for every authentication failure."""
        return HxfErrorResponse(
            message=message,
            status_code=status.HTTP_401_UNAUTHORIZED
        )

    async def _authenticate(self, request: Request):
        """Validate the request's Bearer token and store its user in ``UserContext``.

        Returns:
            ``None`` when the user was authenticated and placed in the context.
            Otherwise, the 401 response to send back to the client.
        """
        try:
            authorization = request.headers.get("Authorization")
            if not authorization:
                return self._unauthorized("缺少或无效的授权头")

            # The auth-scheme is case-insensitive (RFC 7235 §2.1). Taking
            # everything after the first space avoids the split(" ")[1] pitfalls:
            # "Bearer <jwt> junk" silently dropping "junk", and "Bearer  <jwt>"
            # yielding an empty token.
            scheme, _, token = authorization.partition(" ")
            token = token.strip()
            if scheme.lower() != "bearer" or not token:
                return self._unauthorized("缺少或无效的授权头")

            payload = AuthService.verify_token(token)
            if payload is None:
                return self._unauthorized("无效或过期的令牌")

            username = payload.get("sub")
            if not username:
                return self._unauthorized("令牌负载无效")

            # Function-scope imports kept from the original; presumably they
            # avoid a circular import at module load time. TODO confirm.
            from sqlalchemy import select
            from ..models.user import User

            # Short-lived session used only for this lookup; `async with`
            # closes it even if the query or the checks below raise.
            async with AsyncSession(bind=engine_async) as session:
                result = await session.execute(select(User).where(User.username == username))
                user = result.scalar_one_or_none()
                if not user:
                    return self._unauthorized("用户不存在")
                if not user.is_active:
                    return self._unauthorized("用户账户已停用")

                # The returned context token is not kept: dispatch() clears the
                # context explicitly once the request has finished.
                UserContext.set_current_user_with_token(user, self.canLog)
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")
                    current_user_id = UserContext.get_current_user_id()
                    logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文")
            return None
        except Exception as e:
            # Any unexpected failure (DB error, malformed payload, ...) rejects
            # the request. logger.exception also records the traceback.
            logger.exception(f"[MIDDLEWARE] - 认证过程中设置用户上下文出错: {e}")
            return self._unauthorized("认证过程中出错")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request unless its path is excluded, then forward it.

        Returns:
            * the downstream response on success;
            * a 401 ``HxfErrorResponse`` when authentication fails;
            * ``HxfErrorResponse(e)`` when the downstream handler raises
              ``e``. The intent is a 500 error; confirm against
              ``HxfErrorResponse``'s signature.

        Excluded paths are forwarded as-is, without authentication and without
        touching the user context.
        """
        path = request.url.path
        if self.canLog:
            logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}")
            logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")

        if self._is_excluded_path(path):
            if self.canLog:
                logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
            return await call_next(request)

        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")

        # Always start from an empty context so a stale user can never be reused.
        UserContext.clear_current_user(self.canLog)
        try:
            rejection = await self._authenticate(request)
            if rejection is not None:
                return rejection

            try:
                return await call_next(request)
            except Exception as e:
                # Unhandled error from the downstream app: log it with the
                # traceback and convert it into an error response.
                logger.exception(f"[MIDDLEWARE] - 请求处理出错: {e}")
                return HxfErrorResponse(e)
        finally:
            # Runs on every exit path, including rejections that happen after
            # the user was already placed in the context.
            UserContext.clear_current_user(self.canLog)
            if self.canLog:
                logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
|