hxf/backend/th_agenter/services/auth.py

138 lines
5.2 KiB
Python

"""Authentication service."""
from loguru import logger
from typing import Optional
from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from sqlalchemy import select
import bcrypt
import jwt
from ..core.config import settings
from ..db.database import get_session
from ..models.user import User
# Bearer-token scheme instance; AuthService.get_current_user receives it via
# Depends(security) to obtain the request's HTTPAuthorizationCredentials.
security = HTTPBearer()
class AuthService:
    """Authentication helpers.

    Groups password hashing/verification, JWT creation/verification and the
    FastAPI dependencies that resolve the calling user. Every member is a
    static method; the class is used purely as a namespace.

    NOTE(review): ``session`` parameters are annotated with the synchronous
    ``Session`` type, but ``session.execute(...)`` is awaited throughout, so
    ``get_session`` presumably yields an async session -- confirm.
    ``session.desc`` is a project-specific attribute; this class only assigns
    descriptive strings to it.
    """

    @staticmethod
    async def get_current_user(
        credentials: HTTPAuthorizationCredentials = Depends(security),
        session: Session = Depends(get_session),
    ) -> User:
        """Resolve the user identified by the request's bearer token.

        Args:
            credentials: Bearer credentials supplied by the ``security`` scheme.
            session: Database session supplied by ``get_session``.

        Returns:
            The ``User`` whose ``username`` equals the token's ``sub`` claim.

        Raises:
            HTTPException: 401 when the token fails verification, carries no
                ``sub`` claim, or names a user that is not in the database.
        """
        # Imported inside the function rather than at module level;
        # presumably to avoid a circular import -- TODO confirm.
        from ..core.context import UserContext

        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

        token = credentials.credentials
        payload = AuthService.verify_token(token)
        if payload is None:
            session.desc = f"ERROR: 令牌验证失败 - 令牌: {token[:50]}..."
            raise credentials_exception

        username: Optional[str] = payload.get("sub")
        if username is None:
            session.desc = "ERROR: 令牌中没有用户名"
            raise credentials_exception

        stmt = select(User).where(User.username == username)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if user is None:
            session.desc = f"ERROR: 数据库中未找到用户 {username}"
            raise credentials_exception

        # Set user in context for global access.
        UserContext.set_current_user(user, canLog=True)
        return user

    @staticmethod
    def get_current_active_user(
        # Inside the class body ``get_current_user`` is the ``staticmethod``
        # wrapper, not the coroutine function itself. Depending on the
        # underlying function makes FastAPI see a real ``async def`` (so it is
        # awaited) and makes this the same object as
        # ``AuthService.get_current_user``.
        current_user: User = Depends(get_current_user.__func__),
    ) -> User:
        """Return the authenticated user, rejecting inactive accounts.

        Args:
            current_user: User resolved by ``get_current_user``.

        Returns:
            ``current_user`` unchanged when ``is_active`` is truthy.

        Raises:
            HTTPException: 400 ("Inactive user") when the user is not active.
        """
        if not current_user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Inactive user",
            )
        return current_user

    @staticmethod
    async def _authenticate_where(session: Session, criterion, password: str) -> Optional[User]:
        """Load the single user matching *criterion* and check *password* against it."""
        stmt = select(User).where(criterion)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def authenticate_user_by_email(session: Session, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email address and password.

        Args:
            session: Database session used for the lookup.
            email: Value matched against ``User.email``.
            password: Plain-text password to verify.

        Returns:
            The matching ``User``, or ``None`` when no user has that email or
            the password does not verify.
        """
        session.desc = f"根据邮箱 {email} 验证用户密码"
        return await AuthService._authenticate_where(session, User.email == email, password)

    @staticmethod
    async def authenticate_user(session: Session, username: str, password: str) -> Optional[User]:
        """Authenticate a user by username and password.

        Args:
            session: Database session used for the lookup.
            username: Value matched against ``User.username``.
            password: Plain-text password to verify.

        Returns:
            The matching ``User``, or ``None`` when no user has that username
            or the password does not verify.
        """
        session.desc = f"根据用户名 {username} 验证用户密码"
        return await AuthService._authenticate_where(session, User.username == username, password)

    @staticmethod
    async def create_access_token(session: Session, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            session: Session object; only its ``desc`` attribute is written.
            data: Claims to embed. The dict is copied, not mutated.
            expires_delta: Token lifetime. When omitted (or falsy), the
                lifetime is ``settings.security.access_token_expire_minutes``.

        Returns:
            The token encoded with ``settings.security.secret_key`` and
            ``settings.security.algorithm``.
        """
        session.desc = f"创建 JWT 访问 token - 数据: {data}"

        lifetime = expires_delta or timedelta(
            minutes=settings.security.access_token_expire_minutes
        )
        claims = data.copy()
        claims.update({"exp": datetime.now(timezone.utc) + lifetime})

        return jwt.encode(
            claims,
            settings.security.secret_key,
            algorithm=settings.security.algorithm,
        )

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Check a plain-text password against a stored bcrypt hash.

        Args:
            plain_password: Password supplied by the caller.
            hashed_password: Stored bcrypt hash as text.

        Returns:
            ``True`` when the password matches the hash. ``False`` when it
            does not match, when either value is missing, or when bcrypt
            rejects the input (for example a malformed stored hash).
        """
        if plain_password is None or not hashed_password:
            return False
        try:
            # Verify directly with the bcrypt library.
            return bcrypt.checkpw(
                plain_password.encode("utf-8"),
                hashed_password.encode("utf-8"),
            )
        except ValueError as exc:
            # bcrypt raises ValueError for input it cannot process (such as an
            # invalid salt); treat that as a failed check, not a server error.
            logger.warning("Password verification rejected malformed input: {}", exc)
            return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Hash a password with bcrypt using a freshly generated salt.

        Args:
            password: Plain-text password.

        Returns:
            The bcrypt hash decoded to a UTF-8 string.
        """
        # Hash directly with the bcrypt library.
        salt = bcrypt.gensalt()
        return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")

    @staticmethod
    def verify_token(token: str) -> Optional[dict]:
        """Decode and verify a JWT.

        Args:
            token: Encoded JWT string.

        Returns:
            The decoded payload, or ``None`` when PyJWT raises a
            ``PyJWTError`` (the failure and the first 50 characters of the
            token are logged).
        """
        try:
            return jwt.decode(
                token,
                settings.security.secret_key,
                algorithms=[settings.security.algorithm],
            )
        except jwt.PyJWTError as e:
            logger.error(f"Token verification failed: {e}")
            logger.error(f"Token: {token[:50]}...")
            return None