153 lines
4.7 KiB
Python
153 lines
4.7 KiB
Python
from contextvars import ContextVar
|
||
from fastapi import FastAPI, Request, Depends, HTTPException
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import uuid
|
||
import uvicorn
|
||
|
||
# Application instance; the middleware and routes below are registered on it.
app = FastAPI()
# Bearer-token extractor used by the route-level dependency. "token" is the
# relative URL advertised as the token endpoint (no such route is defined in
# the visible code).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||
|
||
# Context variables holding the current user and the request ID.
# Both default to None, so the annotations are Optional: readers must handle
# the "nothing set" case explicitly.
current_user_ctx: ContextVar[Optional[dict]] = ContextVar("current_user", default=None)
request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
|
||
|
||
|
||
# User model
# NOTE(review): not referenced by any route or service in the visible code;
# the rest of the module passes users around as plain dicts.
class User(BaseModel):
    id: int
    username: str
    # Optional field; defaults to None when not supplied.
    email: Optional[str] = None
|
||
|
||
|
||
# Mock user service
class UserService:
    """Read access to the user stored in ``current_user_ctx``.

    Lets service-layer code obtain the current user without having it passed
    down explicitly as an argument.
    """

    @staticmethod
    def get_current_user() -> dict:
        """Return the full user dict stored in the context.

        Raises:
            RuntimeError: If the context variable holds no user (``None`` or
                any other falsy value).
        """
        user = current_user_ctx.get()
        if not user:
            raise RuntimeError("No current user available")
        return user

    @staticmethod
    def get_current_user_id() -> int:
        """Return the ``"id"`` entry of the current user.

        Raises:
            RuntimeError: If no user is available in the context.
        """
        # Delegate so the "no user" check lives in exactly one place.
        return UserService.get_current_user()["id"]
|
||
|
||
|
||
# Example business service
class TaskService:
    """Task operations that take the acting user from the request context."""

    def create_task(self, task_data: dict) -> dict:
        """Build a task record stamped with the current user's ID.

        Args:
            task_data: Arbitrary task fields supplied by the caller.

        Returns:
            A new dict containing ``task_data`` plus ``"created_by"`` and
            ``"created_at"``.

        Raises:
            RuntimeError: If no user is available in the context.
        """
        current_user_id = UserService.get_current_user_id()

        # Simulated database write. The fixed keys come after the unpacking,
        # so they override same-named keys in task_data; the timestamp is a
        # hard-coded placeholder.
        task = {
            **task_data,
            "created_by": current_user_id,
            "created_at": "2023-10-01 12:00:00",
        }

        print(f"Task created by user {current_user_id}: {task}")
        return task

    def get_user_tasks(self) -> list:
        """Return the current user's tasks (a single hard-coded sample).

        Raises:
            RuntimeError: If no user is available in the context.
        """
        current_user_id = UserService.get_current_user_id()

        # Simulated "query tasks by user ID".
        return [{"id": 1, "title": "Sample task", "user_id": current_user_id}]
|
||
|
||
|
||
# Middleware: populate the context variables for each request
@app.middleware("http")
async def set_context_vars(request: Request, call_next):
    """Set the request ID and, when possible, the current user for a request.

    A fresh UUID4 string is stored in ``request_id_ctx``. If the request has
    an ``Authorization`` header with the ``Bearer`` scheme and a non-empty
    token, the token is decoded and the result (which may be ``None`` for an
    unknown token) is stored in ``current_user_ctx``. Both variables are reset
    to their previous values once the downstream handler has returned or
    raised.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        The response produced by ``call_next``.
    """
    # Unique ID for this request.
    request_id_token = request_id_ctx.set(str(uuid.uuid4()))

    user_token = None
    try:
        auth_header = request.headers.get("Authorization")
        if auth_header:
            # Split off the scheme rather than using str.replace, which would
            # also strip "Bearer " occurrences inside the token itself. The
            # scheme is compared case-insensitively so this accepts the same
            # headers as the route-level OAuth2 dependency.
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() == "bearer" and token:
                user = await decode_token_and_get_user(token)  # your auth logic
                user_token = current_user_ctx.set(user)

        return await call_next(request)
    finally:
        # Restore the previous context values.
        request_id_ctx.reset(request_id_token)
        if user_token is not None:
            current_user_ctx.reset(user_token)
|
||
|
||
|
||
# Mock authentication function
async def decode_token_and_get_user(token: str) -> Optional[dict]:
    """Map a bearer token to a user dict.

    Stand-in for real authentication logic (e.g. JWT decoding or a database
    lookup): only two hard-coded tokens are recognised.

    Args:
        token: The raw bearer token.

    Returns:
        A new user dict with ``"id"``, ``"username"`` and ``"email"`` for a
        recognised token, otherwise ``None``.
    """
    if token == "valid_token_123":
        return {"id": 123, "username": "john_doe", "email": "john@example.com"}
    if token == "valid_token_456":
        return {"id": 456, "username": "jane_doe", "email": "jane@example.com"}
    return None
|
||
|
||
|
||
# Dependency: route-level authentication
async def get_current_user_route(token: str = Depends(oauth2_scheme)) -> dict:
    """Authenticate the request at the route layer.

    Args:
        token: Bearer token supplied by the ``oauth2_scheme`` dependency.

    Returns:
        The user dict the token resolves to.

    Raises:
        HTTPException: 401 if the token does not resolve to a user.
    """
    user = await decode_token_and_get_user(token)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="Invalid credentials",
            # Tell the client which authentication scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
|
||
|
||
|
||
# Route handlers
@app.post("/tasks")
async def create_task(
    task_data: dict,
    current_user: dict = Depends(get_current_user_route),
):
    """Create a task for the authenticated user.

    ``current_user`` is not used in the body; the dependency is declared so
    that unauthenticated requests are rejected. The service reads the user
    from the request context, so no user ID is passed to it explicitly.
    """
    task = TaskService().create_task(task_data)
    return {"task": task, "message": "Task created successfully"}
|
||
|
||
|
||
@app.get("/tasks")
async def get_tasks(current_user: dict = Depends(get_current_user_route)):
    """Return the authenticated user's tasks.

    ``current_user`` only enforces authentication; the service reads the user
    from the request context.
    """
    tasks = TaskService().get_user_tasks()
    return {"tasks": tasks}
|
||
|
||
|
||
@app.get("/users/me")
async def read_users_me(current_user: dict = Depends(get_current_user_route)):
    """Return the authenticated user's info as resolved by the dependency."""
    return current_user
|
||
|
||
|
||
# Test endpoint - reads the user straight from the context inside a route
@app.get("/test-context")
async def test_context():
    """Fetch the user via the context instead of dependency injection.

    The route declares no authentication dependency. When the context holds
    no user, the error message is returned in the body with the default
    status code rather than as an HTTP error.
    """
    try:
        user = UserService.get_current_user()
    except RuntimeError as e:
        return {"error": str(e)}
    return {"message": "Successfully got user from context", "user": user}
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the app directly with uvicorn; 0.0.0.0 listens on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=8000)