hxf/backend/utils/util_schemas.py

511 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pydantic schemas for API requests and responses."""
from typing import Optional, List, Any, Dict, TYPE_CHECKING
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
if TYPE_CHECKING:
from th_agenter.schemas.permission import RoleResponse
class MessageRole(str, Enum):
    """Role attached to a chat message.

    Mixes in ``str`` so members compare equal to their raw string values.
    """

    USER = 'user'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'
class MessageType(str, Enum):
    """Content kind of a chat message.

    Mixes in ``str`` so members compare equal to their raw string values.
    """

    TEXT = 'text'
    IMAGE = 'image'
    FILE = 'file'
    AUDIO = 'audio'
# Base schemas
class BaseResponse(BaseModel):
    """Identifier and timestamps shared by the response schemas below.

    ``from_attributes`` lets subclasses be validated from objects' attributes
    (e.g. ORM instances) instead of only from dicts.
    """

    id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
# User schemas
class UserBase(BaseModel):
    """Profile fields shared by the user create and response schemas."""

    # Required: 3-50 characters.
    username: str = Field(..., min_length=3, max_length=50)
    # Required; only the length is validated here.
    email: str = Field(..., max_length=100)
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = Field(default=None)
    avatar_url: Optional[str] = Field(default=None)
class UserCreate(UserBase):
    """Payload for creating a user: the shared profile fields plus a password."""

    # Required; minimum length 6.
    password: str = Field(default=..., min_length=6)
class UserUpdate(BaseModel):
    """Partial user update: every field is optional and defaults to None."""

    username: Optional[str] = Field(default=None, min_length=3, max_length=50)
    email: Optional[str] = Field(default=None, max_length=100)
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = Field(default=None)
    avatar_url: Optional[str] = Field(default=None)
    password: Optional[str] = Field(default=None, min_length=6)
    is_active: Optional[bool] = Field(default=None)
    department_id: Optional[int] = Field(default=None)
class UserResponse(BaseResponse, UserBase):
    """User response model.

    Two alternate constructors build the response from a user ORM object:

    * ``from_orm`` (sync) reads ``obj.roles`` and each role's ``permissions``
      directly, without issuing any refresh.
    * ``from_orm_async`` (async) first reloads ``roles`` / ``permissions``
      through the owning ``AsyncSession`` when there is one, because lazy
      loading is not available under asyncio. It gives users for whom
      ``has_role('SUPER_ADMIN')`` is truthy every active permission.

    Both report only active roles and the active permissions of those roles
    (de-duplicated, in first-seen order). When a relationship cannot be read,
    they fall back to empty lists instead of raising.
    """

    is_active: bool
    department_id: Optional[int] = None
    roles: Optional[List['RoleResponse']] = Field(default=[], description="用户角色列表")
    permissions: Optional[List[Dict[str, Any]]] = Field(default=[], description="用户权限列表")
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    @staticmethod
    def _base_data(obj) -> Dict[str, Any]:
        """Copy the scalar attributes used by both constructors from ``obj``."""
        return {
            'id': obj.id,
            'username': obj.username,
            'email': obj.email,
            'full_name': obj.full_name,
            # bio/avatar_url are UserBase fields; previously they were never copied,
            # so responses always carried None. getattr: the ORM object may lack them.
            'bio': getattr(obj, 'bio', None),
            'avatar_url': getattr(obj, 'avatar_url', None),
            'is_active': obj.is_active,
            'department_id': obj.department_id,
            'created_at': obj.created_at,
            'updated_at': obj.updated_at,
            # Not declared on this model (ignored under pydantic's default `extra`
            # setting); kept for subclasses that may declare them.
            'created_by': getattr(obj, 'created_by', None),
            'updated_by': getattr(obj, 'updated_by', None),
        }

    @staticmethod
    def _superuser_attr(obj) -> Any:
        """Return ``obj.is_admin`` if present, else ``obj.is_superuser``, else False.

        The returned value may be a plain flag or a callable; the caller decides
        how to invoke it.
        """
        if hasattr(obj, 'is_admin'):
            return obj.is_admin
        if hasattr(obj, 'is_superuser'):
            return obj.is_superuser
        return False

    @staticmethod
    def _merge_permissions(roles) -> List[Dict[str, Any]]:
        """Merge the active permissions of the active ``roles``.

        Duplicate (code, name) pairs are dropped. First-seen order is kept so
        the output is deterministic, unlike iterating a set. A role whose
        permissions cannot be read is skipped.
        """
        merged: Dict[tuple, None] = {}
        for role in roles:
            try:
                if not role.is_active:
                    continue
                for perm in role.permissions or []:
                    if perm.is_active:
                        merged[(perm.code, perm.name)] = None
            except Exception:
                # Best effort: e.g. a detached instance or a lazy load that cannot run.
                continue
        return [{'code': code, 'name': name} for code, name in merged]

    @staticmethod
    def _owning_session(obj):
        """Return ``(session, is_async)`` for the session ``obj`` is attached to.

        ``sqlalchemy.orm.object_session`` returns the synchronous ``Session``
        even for objects loaded through an ``AsyncSession`` (it returns that
        session's ``sync_session``). An ``isinstance(..., AsyncSession)`` check
        on its result therefore never matches, so the AsyncSession is looked up
        with ``async_object_session`` first. Returns ``(None, False)`` when no
        session can be found.
        """
        try:
            from sqlalchemy.ext.asyncio import async_object_session
            async_session = async_object_session(obj)
        except Exception:
            async_session = None
        if async_session is not None:
            return async_session, True
        try:
            from sqlalchemy.orm import object_session
            return object_session(obj), False
        except Exception:
            return None, False

    @staticmethod
    async def _all_active_permissions(session, is_async) -> List[Dict[str, Any]]:
        """Return every active ``Permission`` row, or a ``'*'`` wildcard without a session."""
        if session is None:
            return [{'code': '*', 'name': '所有权限'}]
        from th_agenter.models.permission import Permission
        if is_async:
            from sqlalchemy import select
            result = await session.execute(
                select(Permission).filter(Permission.is_active == True)  # noqa: E712 - SQL expression
            )
            rows = result.scalars().all()
        else:
            rows = session.query(Permission).filter(Permission.is_active == True).all()  # noqa: E712
        return [{'code': perm.code, 'name': perm.name} for perm in rows]

    @classmethod
    def from_orm(cls, obj):
        """Build the response from a user ORM object (synchronous version).

        Only relationships readable without an explicit refresh are used. If
        ``obj.roles`` or a role's ``permissions`` cannot be read (e.g. a
        detached instance), the affected data is reported as empty rather than
        raising.

        Args:
            obj: user ORM object exposing id/username/email/full_name/is_active/
                department_id/created_at/updated_at, and optionally bio,
                avatar_url, roles, and is_admin or is_superuser.

        Returns:
            A validated ``cls`` instance.
        """
        data = cls._base_data(obj)

        try:
            roles = [role for role in (getattr(obj, 'roles', None) or []) if role.is_active]
        except Exception:
            # e.g. DetachedInstanceError or a lazy-load failure: behave as if no roles.
            roles = []

        try:
            from th_agenter.schemas.permission import RoleResponse
            data['roles'] = [RoleResponse.from_orm(role) for role in roles]
        except Exception:
            data['roles'] = []

        data['permissions'] = cls._merge_permissions(roles)

        try:
            flag = cls._superuser_attr(obj)
            data['is_superuser'] = flag() if callable(flag) else flag
        except Exception:
            data['is_superuser'] = False

        return cls(**data)

    @classmethod
    async def from_orm_async(cls, obj):
        """Build the response from a user ORM object (asynchronous version).

        If ``obj`` belongs to an ``AsyncSession``, ``roles`` and each active
        role's ``permissions`` are reloaded with ``await session.refresh(...)``
        before being read.

        If ``obj.has_role('SUPER_ADMIN')`` is truthy (awaited when it returns
        an awaitable), the user gets every active ``Permission`` row, or a
        single ``'*'`` entry when ``obj`` has no session. Otherwise the user
        gets the merged permissions of their active roles.

        Failures while loading roles or permissions yield empty lists instead
        of raising.

        Args:
            obj: user ORM object (see ``from_orm``); may also expose ``has_role``.

        Returns:
            A validated ``cls`` instance.
        """
        import inspect

        data = cls._base_data(obj)
        session, is_async = cls._owning_session(obj)

        try:
            if is_async:
                # Lazy loading cannot run under asyncio, so load the relationship
                # explicitly *before* touching obj.roles (even hasattr would load it).
                await session.refresh(obj, ['roles'])
            roles = [role for role in (getattr(obj, 'roles', None) or []) if role.is_active]
        except Exception:
            roles = []

        try:
            from th_agenter.schemas.permission import RoleResponse
            data['roles'] = [RoleResponse.from_orm(role) for role in roles]
        except Exception:
            data['roles'] = []

        is_super_admin = False
        try:
            has_role = getattr(obj, 'has_role', None)
            if callable(has_role):
                result = has_role('SUPER_ADMIN')
                if inspect.isawaitable(result):
                    result = await result
                is_super_admin = bool(result)
        except Exception:
            # Role check failed: fall back to role-derived permissions.
            is_super_admin = False

        try:
            if is_super_admin:
                data['permissions'] = await cls._all_active_permissions(session, is_async)
            else:
                if is_async:
                    for role in roles:
                        try:
                            await session.refresh(role, ['permissions'])
                        except Exception:
                            # Left unloaded; _merge_permissions skips unreadable roles.
                            continue
                data['permissions'] = cls._merge_permissions(roles)
        except Exception:
            data['permissions'] = []

        try:
            flag = cls._superuser_attr(obj)
            if callable(flag):
                flag = flag()
                if inspect.isawaitable(flag):
                    flag = await flag
            data['is_superuser'] = flag
        except Exception:
            data['is_superuser'] = False

        return cls(**data)
# Authentication schemas
class LoginRequest(BaseModel):
    """Credentials submitted to log in (email + password)."""

    email: str = Field(default=..., max_length=100)
    password: str = Field(default=..., min_length=6)
class Token(BaseModel):
    """Access token returned after authentication."""

    access_token: str
    token_type: str
    # Token lifetime; the unit is not visible here (presumably seconds) - confirm.
    expires_in: int
# Conversation schemas
class ConversationBase(BaseModel):
    """Shared conversation settings: title, prompt, model parameters, knowledge base.

    ``model_name`` starts with ``model_``, which pydantic 2.x versions that
    protect that prefix report with a UserWarning at class creation. Clearing
    ``protected_namespaces`` silences it without renaming the public field.
    """

    title: str = Field(..., min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: str = Field(default="gpt-3.5-turbo", max_length=100)
    # Kept as a string (e.g. "0.7") in this schema, unlike ChatRequest.temperature.
    temperature: str = Field(default="0.7", max_length=10)
    max_tokens: int = Field(default=2048, ge=1, le=8192)
    knowledge_base_id: Optional[int] = None

    class Config:
        # Allow ``model_*`` field names without a protected-namespace warning.
        protected_namespaces = ()
class ConversationCreate(ConversationBase):
    """Payload for creating a conversation; adds nothing to ConversationBase."""
class ConversationUpdate(BaseModel):
    """Partial conversation update; every field is optional and defaults to None.

    ``model_name`` starts with ``model_``, which pydantic 2.x versions that
    protect that prefix report with a UserWarning. Clearing
    ``protected_namespaces`` silences it without renaming the public field.
    """

    title: Optional[str] = Field(None, min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: Optional[str] = Field(None, max_length=100)
    temperature: Optional[str] = Field(None, max_length=10)
    max_tokens: Optional[int] = Field(None, ge=1, le=8192)
    is_archived: Optional[bool] = None

    class Config:
        # Allow ``model_*`` field names without a protected-namespace warning.
        protected_namespaces = ()
class ConversationResponse(BaseResponse, ConversationBase):
    """Conversation as returned by the API, optionally embedding its messages."""

    user_id: int
    is_archived: bool
    message_count: int = Field(default=0)
    last_message_at: Optional[datetime] = Field(default=None)
    # String forward reference: MessageResponse is declared later in this module.
    messages: Optional[List["MessageResponse"]] = Field(default=None)
# Message schemas
class MessageBase(BaseModel):
    """Fields shared by message create and response schemas.

    ``metadata`` is bound to the alias ``message_metadata``. Presumably this
    mirrors the ORM attribute name, since SQLAlchemy reserves ``metadata``;
    confirm against the model. Without ``populate_by_name``, create payloads
    had to use the alias, and a ``metadata`` key was silently dropped.
    Enabling it accepts either name, matching ``MessageResponse``.
    """

    content: str = Field(..., min_length=1)
    role: MessageRole
    message_type: MessageType = MessageType.TEXT
    metadata: Optional[Dict[str, Any]] = Field(None, alias="message_metadata")

    class Config:
        # Accept both the field name and its alias on input.
        populate_by_name = True
class MessageCreate(MessageBase):
    """Payload for creating a message inside an existing conversation."""

    # Target conversation (required).
    conversation_id: int
class MessageResponse(BaseResponse, MessageBase):
    """Message as returned by the API, with context documents and token counts."""

    conversation_id: int
    context_documents: Optional[List[Dict[str, Any]]] = Field(default=None)
    prompt_tokens: Optional[int] = Field(default=None)
    completion_tokens: Optional[int] = Field(default=None)
    total_tokens: Optional[int] = Field(default=None)

    class Config:
        # Read from object attributes and accept both field names and aliases.
        from_attributes = True
        populate_by_name = True
# Chat schemas
class ChatRequest(BaseModel):
    """Chat request body: the user message plus flags selecting the answering mode."""

    message: str = Field(..., min_length=1, max_length=10000)
    stream: bool = False
    use_knowledge_base: bool = False
    knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
    use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
    use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")
    # Sampling parameters with validated ranges.
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=2048, ge=1, le=8192)
class ChatResponse(BaseModel):
    """Result of one chat turn: both messages, token usage and the model used.

    ``model_used`` starts with ``model_``, which pydantic 2.x versions that
    protect that prefix report with a UserWarning. Clearing
    ``protected_namespaces`` silences it without renaming the public field.
    """

    user_message: MessageResponse
    assistant_message: MessageResponse
    total_tokens: Optional[int] = None
    model_used: str

    class Config:
        # Allow ``model_*`` field names without a protected-namespace warning.
        protected_namespaces = ()
class StreamChunk(BaseModel):
    """One piece of a streamed response; the role defaults to ASSISTANT."""

    content: str
    role: MessageRole = Field(default=MessageRole.ASSISTANT)
    finish_reason: Optional[str] = Field(default=None)
    tokens_used: Optional[int] = Field(default=None)
# Knowledge Base schemas
class KnowledgeBaseBase(BaseModel):
    """Shared knowledge-base settings: name, embedding model and chunking."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(default=None)
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Chunking parameters; the unit (characters vs tokens) is not visible here.
    chunk_size: int = Field(default=1000, ge=100, le=5000)
    chunk_overlap: int = Field(default=200, ge=0, le=1000)
class KnowledgeBaseCreate(KnowledgeBaseBase):
    """Payload for creating a knowledge base; adds nothing to KnowledgeBaseBase."""
class KnowledgeBaseUpdate(BaseModel):
    """Partial knowledge-base update; every field is optional and defaults to None."""

    name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    description: Optional[str] = Field(default=None)
    embedding_model: Optional[str] = Field(default=None)
    chunk_size: Optional[int] = Field(default=None, ge=100, le=5000)
    chunk_overlap: Optional[int] = Field(default=None, ge=0, le=1000)
    is_active: Optional[bool] = Field(default=None)
class KnowledgeBaseResponse(BaseResponse, KnowledgeBaseBase):
    """Knowledge base as returned by the API, with storage info and document counts."""

    is_active: bool
    vector_db_type: str
    # Required but nullable: no default is declared.
    collection_name: Optional[str]
    document_count: int = Field(default=0)
    active_document_count: int = Field(default=0)
# Document schemas
class DocumentBase(BaseModel):
    """File identity fields shared by document schemas (all required)."""

    # Stored name vs. the name the file was uploaded with.
    filename: str
    original_filename: str
    file_type: str
    # Size as an integer; presumably bytes - confirm against the upload handler.
    file_size: int
class DocumentUpload(BaseModel):
    """Upload options: target knowledge base and whether to process right away."""

    knowledge_base_id: int
    process_immediately: bool = True
class DocumentResponse(BaseResponse, DocumentBase):
    """Document as returned by the API, including processing state."""

    knowledge_base_id: int
    file_path: str
    # The Optional[...] fields below have no default: required but nullable.
    mime_type: Optional[str]
    is_processed: bool
    processing_error: Optional[str]
    chunk_count: int = Field(default=0)
    embedding_model: Optional[str]
    file_size_mb: float
class DocumentListResponse(BaseModel):
    """One page of documents together with paging information."""

    documents: List[DocumentResponse]
    # Paging metadata accompanying ``documents``.
    total: int
    page: int
    page_size: int
class DocumentProcessingStatus(BaseModel):
    """Processing progress report for a single document."""

    document_id: int
    # Expected values: 'pending', 'processing', 'completed', 'failed' (not enforced).
    status: str
    # Bounded to 0-100, presumably a percentage.
    progress: float = Field(default=0.0, ge=0.0, le=100.0)
    error_message: Optional[str] = Field(default=None)
    chunks_created: int = Field(default=0)
    # Estimated remaining time, in seconds.
    estimated_time_remaining: Optional[int] = Field(default=None)
# Document chunk schemas
# (The error schema, ErrorResponse, is defined after the chunk schemas.)
class DocumentChunk(BaseModel):
    """A single chunk of a document's text with its position information."""

    id: str
    content: str
    # default_factory gives every instance its own empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    page_number: Optional[int] = Field(default=None)
    chunk_index: int
    start_char: Optional[int] = Field(default=None)
    end_char: Optional[int] = Field(default=None)
    vector_id: Optional[str] = Field(default=None)
class DocumentChunksResponse(BaseModel):
    """All chunks of one document, with the document's id, name and chunk total."""

    document_id: int
    document_name: str
    total_chunks: int
    chunks: List[DocumentChunk]
class ErrorResponse(BaseModel):
    """Error payload: a required message plus optional detail and code."""

    error: str
    detail: Optional[str] = Field(default=None)
    code: Optional[str] = Field(default=None)
# Generic response envelope (success flag, message, optional data)
class NormalResponse(BaseModel):
    """Generic envelope: success flag, human-readable message, optional payload dict."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = Field(default=None)
class ExcelPreviewRequest(BaseModel):
    """Request for one page of an Excel file preview."""

    file_id: str
    # Paging; defaults to the first page of 20 rows (no range validation here).
    page: int = Field(default=1)
    page_size: int = Field(default=20)
class FileListResponse(BaseModel):
    """File-list envelope; same shape as the generic success/message/data response."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = Field(default=None)
# Resolve forward references between schemas.
def rebuild_models():
    """Resolve the string forward references used by schemas in this module.

    * ``ConversationResponse.messages`` refers to ``"MessageResponse"``, which
      is defined after it in this module. Everything it needs is defined
      here, so it is rebuilt unconditionally.
    * ``UserResponse.roles`` refers to ``'RoleResponse'`` from
      ``th_agenter.schemas.permission``. At module level that name is only
      imported under ``TYPE_CHECKING``, presumably to avoid a circular import.
      It is imported locally here, and the rebuild is skipped if that import
      fails.
    """
    ConversationResponse.model_rebuild()
    try:
        # Not referenced directly: model_rebuild() also resolves names from the
        # calling frame's namespace, which is how it finds RoleResponse.
        from th_agenter.schemas.permission import RoleResponse  # noqa: F401
        UserResponse.model_rebuild()
    except ImportError:
        # RoleResponse not importable yet (e.g. circular import in progress):
        # UserResponse must then be rebuilt later by the caller.
        pass


# Rebuild the models once, when this module is imported.
rebuild_models()