hyf-backend/th_agenter/core/middleware.py

174 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
中间件管理如上下文中间件校验Token等
"""
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Callable
from loguru import logger
from fastapi import status
from utils.util_exceptions import HxfErrorResponse
from ..db.database import get_session, AsyncSessionFactory, engine_async
from sqlalchemy.ext.asyncio import AsyncSession
from ..services.auth import AuthService
from .context import UserContext
class UserContextMiddleware(BaseHTTPMiddleware):
    """Authenticate requests with a Bearer token and publish the user via ``UserContext``.

    Requests whose path matches ``exclude_paths`` are forwarded without any
    authentication.

    Every other request must send an ``Authorization: Bearer <token>`` header.
    The token must pass ``AuthService.verify_token`` and its ``sub`` claim must
    name an existing, active ``User``. If any of these checks fails, an HTTP 401
    ``HxfErrorResponse`` is returned and the downstream application is not
    called.

    The user context is cleared before authentication and again after the
    request finishes, on every path.
    """

    # Paths that don't require authentication (used when no list is supplied).
    # An entry ending in "/" matches any path with that prefix. Any other entry
    # matches itself exactly and any sub-path beneath it ("/docs" also covers
    # "/docs/x", but not "/docsx").
    DEFAULT_EXCLUDE_PATHS = (
        "/docs",
        "/redoc",
        "/openapi.json",
        "/api/auth/login",
        "/api/auth/register",
        "/api/auth/login-oauth",
        "/auth/login",
        "/auth/register",
        "/auth/login-oauth",
        "/health",
        "/static/",
    )

    def __init__(self, app, exclude_paths: list = None, can_log: bool = False):
        """Create the middleware.

        Args:
            app: The ASGI application being wrapped.
            exclude_paths: Paths that bypass authentication. If falsy
                (``None`` or empty), ``DEFAULT_EXCLUDE_PATHS`` is used.
            can_log: Enable verbose per-request tracing logs (off by default).
        """
        super().__init__(app)
        self.canLog = can_log
        self.exclude_paths = exclude_paths or list(self.DEFAULT_EXCLUDE_PATHS)

    def _should_skip_auth(self, path: str) -> bool:
        """Return True when *path* matches one of ``self.exclude_paths``."""
        for exclude_path in self.exclude_paths:
            if not exclude_path:
                # An empty entry would match every path through the sub-path
                # rule below ("" + "/" prefixes any path) and silently disable
                # authentication.
                continue
            if path == exclude_path:
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
                return True
            if exclude_path.endswith('/'):
                # Prefix entry such as "/static/": matches anything beneath it.
                if path.startswith(exclude_path):
                    if self.canLog:
                        logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
                    return True
            elif path.startswith(exclude_path + '/'):
                # Plain entry such as "/docs": also matches its sub-paths.
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
                return True
        return False

    @staticmethod
    def _extract_bearer_token(authorization):
        """Return the token from an ``Authorization`` header value, or None.

        The scheme is compared case-insensitively (RFC 7235 §2.1). None is
        returned when the header is missing, uses another scheme, or carries
        an empty token.
        """
        if not authorization:
            return None
        scheme, _, credentials = authorization.strip().partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return None
        return token

    @staticmethod
    def _unauthorized(message: str):
        """Build the HTTP 401 error response used for every authentication failure."""
        return HxfErrorResponse(
            message=message,
            status_code=status.HTTP_401_UNAUTHORIZED
        )

    async def _authenticate(self, request: Request):
        """Validate the request's bearer token and set the user in ``UserContext``.

        Returns:
            None on success, after the user has been stored in ``UserContext``.
            Otherwise an HTTP 401 ``HxfErrorResponse`` explaining the failure.

        Raises:
            Whatever token verification or the database lookup raises; the
            caller turns that into a 401 response.
        """
        token = self._extract_bearer_token(request.headers.get("Authorization"))
        if token is None:
            return self._unauthorized("缺少或无效的授权头")

        payload = AuthService.verify_token(token)
        if payload is None:
            return self._unauthorized("无效或过期的令牌")

        username = payload.get("sub")
        if not username:
            return self._unauthorized("令牌负载无效")

        # Function-level imports kept from the original code (presumably to
        # avoid an import cycle with the models package).
        from sqlalchemy import select
        from ..models.user import User

        # Short-lived session used only for this lookup; closed when the block exits.
        async with AsyncSession(bind=engine_async) as session:
            result = await session.execute(select(User).where(User.username == username))
            user = result.scalar_one_or_none()
            if not user:
                return self._unauthorized("用户不存在")
            if not user.is_active:
                return self._unauthorized("用户账户已停用")

            UserContext.set_current_user_with_token(user, self.canLog)
            if self.canLog:
                logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")
                logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {UserContext.get_current_user_id()} 上下文")
        return None

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request (unless excluded) and forward it downstream."""
        path = request.url.path
        if self.canLog:
            logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {path}")
            logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")

        if self._should_skip_auth(path):
            if self.canLog:
                logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
            return await call_next(request)

        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")

        # Start clean so a user left over from an earlier request cannot leak in.
        UserContext.clear_current_user(self.canLog)
        try:
            try:
                error_response = await self._authenticate(request)
            except Exception as e:
                # Any unexpected failure during authentication rejects the request.
                logger.exception(f"[MIDDLEWARE] - 认证过程 [{request.method} {path}] 中设置用户上下文出错: {e}")
                return self._unauthorized("认证过程中出错")
            if error_response is not None:
                return error_response

            try:
                return await call_next(request)
            except Exception as e:
                # Convert an unhandled downstream error into an error response.
                # NOTE(review): HxfErrorResponse is called positionally with the
                # exception here; the original comment expected a 500 — confirm.
                logger.exception(f"[MIDDLEWARE] - 请求处理 [{request.method} {path}] 出错: {e}")
                return HxfErrorResponse(e)
        finally:
            # Runs on every path (success, 401, or error), so the user context
            # is never left set, even if authentication failed after setting it.
            UserContext.clear_current_user(self.canLog)
            if self.canLog:
                logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")