229 lines
7.3 KiB
Python
229 lines
7.3 KiB
Python
"""Custom exceptions and error handlers for the chat agent application."""
|
||
from typing import Any, Dict, List, Optional
|
||
from fastapi import HTTPException, Request, status
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
from typing import Union
|
||
from pydantic import BaseModel
|
||
from datetime import datetime
|
||
import json
|
||
from loguru import logger
|
||
|
||
from starlette.status import (
|
||
HTTP_400_BAD_REQUEST,
|
||
HTTP_401_UNAUTHORIZED,
|
||
HTTP_403_FORBIDDEN,
|
||
HTTP_404_NOT_FOUND,
|
||
HTTP_429_TOO_MANY_REQUESTS,
|
||
HTTP_500_INTERNAL_SERVER_ERROR,
|
||
)
|
||
|
||
# Complete response envelope model.
class FullHxfResponseModel(BaseModel):
    """Full response envelope: business code, status, data, error and message.

    Field types cover every body built in this module: ``data`` is a dict or a
    list on success (``HxfResponse``) and ``None`` on errors
    (``HxfErrorResponse``); ``error`` is ``None`` or, for ``TypeError``
    failures, a descriptive string. Optional fields default to ``None`` so the
    model can be built without them (Pydantic v2 treats ``Optional[...]``
    without a default as required).
    """

    code: int  # 0 on success, -1 on failure
    status: int  # status value reported inside the body
    data: Optional[Union[Dict[str, Any], List[Any]]] = None
    error: Optional[Union[str, Dict[str, Any]]] = None
    message: Optional[str] = None
|
||
|
||
class HxfResponse(JSONResponse):
    """Success response wrapped in the project's standard JSON envelope.

    The body is ``{"code": 0 | -1, "status": 200, "data": <payload>,
    "error": None, "message": None}`` and is always sent with HTTP 200.
    A ``StreamingResponse`` is returned unchanged instead of being wrapped.
    """

    def __new__(cls, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]], StreamingResponse]):
        # A streaming body cannot be JSON-wrapped: hand it back untouched.
        # Since the returned object is not an instance of ``cls``, Python
        # does not call ``__init__`` on it.
        if isinstance(response, StreamingResponse):
            return response
        # Otherwise create a regular HxfResponse instance.
        return super().__new__(cls)

    def __init__(self, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]]]):
        """Serialize *response* and wrap it in the success envelope.

        Args:
            response: A Pydantic model, a list of models and/or plain values,
                a dict, or any other JSON-serializable value.
        """
        if isinstance(response, list):
            # Dump models item by item so a list mixing models and plain
            # values is still JSON-serializable.
            data = [
                item.model_dump(mode='json') if isinstance(item, BaseModel) else item
                for item in response
            ]
        elif isinstance(response, BaseModel):
            data = response.model_dump(mode='json')
        else:
            # Dicts or any other JSON-serializable value are used as-is.
            data = response

        # Only a dict payload can carry a ``success`` flag. Probing other
        # types (None, numbers, lists, strings) either raises TypeError or
        # indexes them with a string key.
        code = 0
        if isinstance(data, dict) and 'success' in data and data['success'] == False:  # noqa: E712 -- equality kept on purpose (0 also marks failure)
            code = -1

        content = {
            "code": code,
            "status": status.HTTP_200_OK,
            "data": data,
            "error": None,
            "message": None
        }
        super().__init__(
            content=content,
            status_code=status.HTTP_200_OK,
            media_type="application/json"
        )
|
||
|
||
class HxfErrorResponse(JSONResponse):
    """Error response wrapped in the project's standard JSON envelope.

    The body is ``{"code": -1, "status": <int>, "data": None,
    "error": <str | None>, "message": <text>}``.
    """

    def __init__(self, message: Union[str, Exception], status_code: Optional[int] = None):
        """Build the error response.

        Args:
            message: A plain error message, or an exception to report. For an
                exception the text comes from a ``message`` instance attribute
                (as set by ``ChatAgentException``), else from a string
                ``detail`` attribute (as on ``HTTPException``), else from
                ``str(exc)``.
            status_code: HTTP status code of the response. When omitted it is
                401 for a plain-string message (the previous default). For an
                exception it is the status reported in the body: the
                exception's ``status_code`` attribute, or 500 when the
                exception has none (or is a ``TypeError``).
        """
        if isinstance(message, Exception):
            msg = self._exception_text(message)
            logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
            if isinstance(message, TypeError):
                body_status = 500
                error = f"错误类型: {type(message)} 错误信息: {message}"
            else:
                # Arbitrary exceptions (ValueError, KeyError, ...) carry no
                # status_code; report them as internal errors rather than
                # raising AttributeError inside the error handler.
                exc_status = getattr(message, "status_code", None)
                body_status = exc_status if isinstance(exc_status, int) else status.HTTP_500_INTERNAL_SERVER_ERROR
                error = None
            if status_code is None:
                # Keep the HTTP status consistent with the status in the body.
                status_code = body_status
        else:
            msg = message
            error = None
            if status_code is None:
                status_code = status.HTTP_401_UNAUTHORIZED
            body_status = status_code

        content = {
            "code": -1,
            "status": body_status,
            "data": None,
            "error": error,
            "message": msg
        }
        super().__init__(content=content, status_code=status_code)

    @staticmethod
    def _exception_text(exc: Exception) -> Any:
        """Return the human-readable message carried by *exc*."""
        # Checked as an instance attribute so only an explicitly stored
        # ``message`` (e.g. ChatAgentException) is used.
        if 'message' in getattr(exc, "__dict__", {}):
            return exc.message
        # str(HTTPException) is "<code>: <detail>"; the bare detail reads better.
        detail = getattr(exc, "detail", None)
        if isinstance(detail, str):
            return detail
        return str(exc)
|
||
|
||
class ChatAgentException(Exception):
    """Root of the chat agent's exception hierarchy.

    Attributes:
        message: Human-readable error text (also passed to ``Exception``).
        status_code: Status code associated with the error; defaults to 500.
        details: Extra context; always a dict, empty when none was supplied.
    """

    def __init__(
        self,
        message: str,
        status_code: int = HTTP_500_INTERNAL_SERVER_ERROR,
        details: Optional[Dict[str, Any]] = None
    ):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        # Normalise a missing/empty value to a fresh dict.
        self.details = details if details else {}
|
||
|
||
|
||
class ValidationError(ChatAgentException):
    """Validation error exception, reported with status 422."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        # The 422 constant is not in this module's ``starlette.status`` import
        # list, so it is resolved through the imported ``status`` module
        # (re-exported by FastAPI) instead of as a bare, undefined name.
        super().__init__(message, status.HTTP_422_UNPROCESSABLE_ENTITY, details)
|
||
|
||
|
||
class AuthenticationError(ChatAgentException):
    """Raised when the caller could not be authenticated (status 401)."""

    def __init__(self, message: str = "Authentication failed"):
        super().__init__(message=message, status_code=HTTP_401_UNAUTHORIZED)
|
||
|
||
|
||
class AuthorizationError(ChatAgentException):
    """Raised when the caller is not allowed to perform an action (status 403)."""

    def __init__(self, message: str = "Access denied"):
        super().__init__(message=message, status_code=HTTP_403_FORBIDDEN)
|
||
|
||
|
||
class NotFoundError(ChatAgentException):
    """Raised when a requested resource does not exist (status 404)."""

    def __init__(self, message: str = "Resource not found"):
        super().__init__(message=message, status_code=HTTP_404_NOT_FOUND)
|
||
|
||
|
||
class ConversationNotFoundError(NotFoundError):
    """Raised when no conversation matches the given identifier."""

    def __init__(self, conversation_id: str):
        text = f"Conversation with ID {conversation_id} not found"
        super().__init__(message=text)
|
||
|
||
|
||
class UserNotFoundError(NotFoundError):
    """Raised when no user matches the given identifier."""

    def __init__(self, user_id: str):
        text = f"User with ID {user_id} not found"
        super().__init__(message=text)
|
||
|
||
|
||
class ChatServiceError(ChatAgentException):
    """Failure inside the chat service layer (status 500)."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            message=message,
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            details=details,
        )
|
||
|
||
|
||
class OpenAIError(ChatServiceError):
    """Error raised around an OpenAI API call.

    The message is prefixed with ``"OpenAI API error: "``; the status code is
    chosen by ``ChatServiceError``.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        prefixed = f"OpenAI API error: {message}"
        super().__init__(message=prefixed, details=details)
|
||
|
||
|
||
class RateLimitError(ChatAgentException):
    """Rate limit exceeded error, reported with status 429 by default."""

    def __init__(
        self,
        message: str = "Rate limit exceeded",
        status_code: int = HTTP_429_TOO_MANY_REQUESTS,
        details: Optional[Dict[str, Any]] = None,
    ):
        # Without this override the class inherited the base default of 500,
        # so rate-limit failures surfaced as internal server errors.
        # The parameter order matches ChatAgentException, so existing
        # positional/keyword callers keep working.
        super().__init__(message, status_code, details)
|
||
|
||
|
||
class DatabaseError(ChatAgentException):
    """Failure during a database operation (status 500).

    The message is prefixed with ``"Database error: "``.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        prefixed = f"Database error: {message}"
        super().__init__(prefixed, status_code=HTTP_500_INTERNAL_SERVER_ERROR, details=details)
|
||
|
||
|
||
# Error handlers
|
||
async def chat_agent_exception_handler(request: Request, exc: ChatAgentException) -> JSONResponse:
    """Handle ChatAgentException and its subclasses.

    Logs the error with request context and returns the standard error
    envelope built by ``HxfErrorResponse``.
    """
    # loguru runs ``str.format`` on the message whenever keyword arguments
    # are passed, so an f-string message plus ``extra=`` crashed on exception
    # text containing braces (e.g. dict reprs). Context is attached with
    # ``bind`` and the text is passed as a positional format argument.
    logger.bind(
        status_code=exc.status_code,
        details=exc.details,
        path=request.url.path,
        method=request.method,
    ).error("ChatAgentException: {}", exc.message)

    return HxfErrorResponse(exc)
|
||
|
||
|
||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTPException.

    Logs a warning with request context and returns the standard error
    envelope built by ``HxfErrorResponse``.
    """
    # ``exc.detail`` may be a dict or contain braces; passing it as a format
    # argument (not interpolated into the template) keeps loguru's
    # ``str.format`` from raising.
    logger.bind(
        status_code=exc.status_code,
        path=request.url.path,
        method=request.method,
    ).warning("HTTPException: {}", exc.detail)

    return HxfErrorResponse(exc)
|
||
|
||
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Handle general (otherwise unhandled) exceptions.

    Logs the error with its traceback and request context, then returns the
    standard error envelope built by ``HxfErrorResponse``.
    """
    # loguru ignores the stdlib ``exc_info=`` flag; ``opt(exception=...)`` is
    # how it attaches a traceback. The exception text is a format argument so
    # braces in it cannot break loguru's ``str.format`` call.
    logger.opt(exception=exc).bind(
        exception_type=exc.__class__.__name__,
        path=request.url.path,
        method=request.method,
    ).error("Unhandled exception: {}", str(exc))

    return HxfErrorResponse(exc)