hxf/backend/tests/fastapi_test/main.py

153 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextvars import ContextVar
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from typing import Optional
import uuid
import uvicorn
# ASGI application instance; the middleware and routes in this file are registered on it.
app = FastAPI()
# Security scheme whose result supplies the `token` argument of get_current_user_route
# (via Depends). NOTE(review): no "/token" route appears in this file.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Request-scoped context variables holding the current user and a per-request ID.
# Both default to None until the HTTP middleware sets them, so they are typed as
# Optional[...] rather than claiming a value is always present.
current_user_ctx: ContextVar[Optional[dict]] = ContextVar("current_user", default=None)
request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
# User model. NOTE(review): not referenced by the routes in this file, which pass plain dicts.
class User(BaseModel):
    """Pydantic schema describing a user: numeric id, username, optional email."""

    id: int
    username: str
    email: Optional[str] = None  # may be omitted; defaults to None
# Simulated user service: reads the current user out of the request-scoped context.
class UserService:
    """Accessors for the user stored in ``current_user_ctx``."""

    @staticmethod
    def get_current_user() -> dict:
        """Return the complete user dict stored in the request context.

        Raises:
            RuntimeError: if the context holds no user (None or any falsy value).
        """
        user = current_user_ctx.get()
        if user:
            return user
        raise RuntimeError("No current user available")

    @staticmethod
    def get_current_user_id() -> int:
        """Return the ``"id"`` entry of the current user, for use inside services.

        Raises:
            RuntimeError: if the context holds no user.
        """
        return UserService.get_current_user()["id"]
# Example business service: attributes work to the user found in the request context.
class TaskService:
    """Task operations that obtain the acting user implicitly via ``UserService``."""

    def create_task(self, task_data: dict):
        """Build a task record owned by the current user.

        The user ID is not passed in; it is read from the request context.

        Args:
            task_data: client-supplied task fields.

        Returns:
            A new dict containing ``task_data`` plus ``created_by`` and ``created_at``.

        Raises:
            RuntimeError: if no user is set in the request context.
        """
        current_user_id = UserService.get_current_user_id()
        # Simulated database write. Server-controlled fields are merged after
        # task_data so a client cannot override created_by/created_at.
        task = {
            **task_data,
            "created_by": current_user_id,
            "created_at": "2023-10-01 12:00:00",  # fixed placeholder timestamp
        }
        print(f"Task created by user {current_user_id}: {task}")
        return task

    def get_user_tasks(self):
        """Return the current user's tasks (simulated query).

        Raises:
            RuntimeError: if no user is set in the request context.
        """
        current_user_id = UserService.get_current_user_id()
        # Simulated lookup of tasks by user ID.
        return [{"id": 1, "title": "Sample task", "user_id": current_user_id}]
# Middleware: populate the request-scoped context variables for downstream code.
@app.middleware("http")
async def set_context_vars(request: Request, call_next):
    """Set ``request_id_ctx`` and, when possible, ``current_user_ctx`` for this request.

    A fresh UUID4 string is stored in ``request_id_ctx``. If the ``Authorization``
    header carries a bearer token that ``decode_token_and_get_user`` resolves to a
    user, that user is stored in ``current_user_ctx`` so services can read it without
    it being passed explicitly. This is best-effort: requests without a valid token
    are not rejected here (the route dependency does that).

    Both variables are reset in ``finally`` so values do not outlive the request.
    """
    request_id_token = request_id_ctx.set(str(uuid.uuid4()))
    user_token = None
    try:
        auth_header = request.headers.get("Authorization")
        if auth_header:
            # Split "<scheme> <credentials>" on the first space only and compare the
            # scheme case-insensitively (auth schemes are case-insensitive per
            # RFC 7235), so "bearer <token>" populates the context just like
            # "Bearer <token>", and text inside the token is never altered.
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() == "bearer" and token:
                user = await decode_token_and_get_user(token)
                if user is not None:
                    user_token = current_user_ctx.set(user)
        return await call_next(request)
    finally:
        # Reset whatever was set above; compare tokens to None explicitly.
        if user_token is not None:
            current_user_ctx.reset(user_token)
        request_id_ctx.reset(request_id_token)
# Simulated authentication: the only tokens this demo accepts and the users they map to.
_FAKE_USERS_BY_TOKEN = {
    "valid_token_123": {"id": 123, "username": "john_doe", "email": "john@example.com"},
    "valid_token_456": {"id": 456, "username": "jane_doe", "email": "jane@example.com"},
}


async def decode_token_and_get_user(token: str) -> Optional[dict]:
    """Resolve ``token`` to a user dict, or return ``None`` if it is not recognized.

    Placeholder for real authentication logic (e.g. JWT decoding or a database
    query); only the hard-coded tokens in ``_FAKE_USERS_BY_TOKEN`` are accepted.

    Returns:
        A fresh copy of the matching user dict, so callers can mutate it without
        affecting later lookups, or ``None`` for any other token.
    """
    user = _FAKE_USERS_BY_TOKEN.get(token)
    return dict(user) if user is not None else None
# Dependency: route-level authentication.
async def get_current_user_route(token: str = Depends(oauth2_scheme)) -> dict:
    """Authenticate the request's bearer token and return the user dict.

    ``token`` is supplied by ``oauth2_scheme`` and resolved with
    ``decode_token_and_get_user``.

    Raises:
        HTTPException: 401 when the token does not resolve to a user. The
            ``WWW-Authenticate: Bearer`` header tells the client which auth
            scheme the endpoint expects.
    """
    user = await decode_token_and_get_user(token)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
# Route handlers
@app.post("/tasks")
async def create_task(
    task_data: dict,
    current_user: dict = Depends(get_current_user_route),
):
    """Create a task owned by the authenticated user.

    ``current_user`` is requested only to enforce authentication; the service
    reads the user ID from the request context, so no user ID is passed on.
    """
    created = TaskService().create_task(task_data)
    return {"task": created, "message": "Task created successfully"}
@app.get("/tasks")
async def get_tasks(current_user: dict = Depends(get_current_user_route)):
    """List the authenticated user's tasks; the service resolves the user from context."""
    return {"tasks": TaskService().get_user_tasks()}
@app.get("/users/me")
async def read_users_me(current_user: dict = Depends(get_current_user_route)):
    """Return the authenticated user's dict, as resolved by get_current_user_route."""
    return current_user
# Test endpoint: read the user straight from the context inside the route handler.
@app.get("/test-context")
async def test_context():
    """Report the context user without dependency injection, or the error if none is set.

    No auth dependency is declared, so the result reflects only what the HTTP
    middleware stored in the context for this request.
    """
    try:
        user = UserService.get_current_user()
    except RuntimeError as err:
        return {"error": str(err)}
    return {"message": "Successfully got user from context", "user": user}
if __name__ == "__main__":
    # Serve the app with uvicorn on port 8000, bound to all interfaces (0.0.0.0).
    uvicorn.run(app, host="0.0.0.0", port=8000)