375 lines
11 KiB
Python
375 lines
11 KiB
Python
"""Pydantic schemas for API requests and responses."""
|
||
|
||
import logging
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
    from ..schemas.permission import RoleResponse


class MessageRole(str, Enum):
    """Author role of a chat message.

    Because of the ``str`` mixin, each member compares equal to its plain
    string value (e.g. ``MessageRole.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class MessageType(str, Enum):
    """Kind of payload carried by a message.

    Members are ``str`` subclasses whose values are the lowercase type names.
    """

    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    AUDIO = "audio"


# Base schemas
class BaseResponse(BaseModel):
    """Common fields shared by ORM-backed response schemas.

    ``from_attributes=True`` allows building instances directly from ORM
    objects (``model_validate(orm_obj)``) in addition to plain dicts.
    """

    # Pydantic v2 style config; class-based ``Config`` is deprecated in v2.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime


# User schemas
class UserBase(BaseModel):
    """Profile fields shared by the user creation and response schemas."""

    # Required identity fields.
    username: str = Field(min_length=3, max_length=50)
    email: str = Field(max_length=100)

    # Optional profile details.
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None


class UserCreate(UserBase):
    """Payload for creating a user: the base profile fields plus a password."""

    password: str = Field(min_length=6)  # at least 6 characters


class UserUpdate(BaseModel):
    """Partial update payload for a user.

    Every field is optional and defaults to ``None``; length limits mirror
    ``UserBase`` / ``UserCreate``.
    """

    username: Optional[str] = Field(default=None, min_length=3, max_length=50)
    email: Optional[str] = Field(default=None, max_length=100)
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    password: Optional[str] = Field(default=None, min_length=6)
    is_active: Optional[bool] = None
    department_id: Optional[int] = None


logger = logging.getLogger(__name__)


def _serialize_roles(obj):
    """Convert ``obj.roles`` into ``RoleResponse`` objects, best-effort.

    Returns an empty list when ``obj`` has no ``roles`` attribute, when it is
    ``None``, or when loading/converting the relationship raises (for example
    SQLAlchemy's ``DetachedInstanceError`` once the session has been closed).
    """
    try:
        roles = getattr(obj, 'roles', None)
        if roles is None:
            return []
        # Imported lazily: at module level RoleResponse is only imported under TYPE_CHECKING.
        from ..schemas.permission import RoleResponse

        return [RoleResponse.from_orm(role) for role in roles]
    except Exception:
        # Deliberate best-effort: serializing a user must not fail because a
        # lazy relationship cannot be loaded.
        logger.debug("Could not serialize user roles; using an empty list", exc_info=True)
        return []


def _collect_permissions(obj) -> List[Dict[str, Any]]:
    """Return the user's effective permissions as ``{'code': ..., 'name': ...}`` dicts.

    * If ``obj.has_role('SUPER_ADMIN')`` is true: every active ``Permission``
      row when ``obj`` is attached to a session, otherwise a single wildcard
      entry with code ``'*'``.
    * Otherwise: the active permissions of the user's active roles,
      de-duplicated by ``(code, name)`` in first-seen order.

    Any exception yields an empty list (best-effort, logged at DEBUG).
    """
    try:
        from sqlalchemy.orm import object_session

        session = object_session(obj)

        if obj.has_role('SUPER_ADMIN'):
            if session is not None:
                from ..models.permission import Permission

                all_permissions = (
                    session.query(Permission)
                    .filter(Permission.is_active == True)  # noqa: E712 - SQLAlchemy column expression
                    .all()
                )
                return [{'code': perm.code, 'name': perm.name} for perm in all_permissions]
            # Detached object: no session to query, so report a wildcard permission.
            return [{'code': '*', 'name': '所有权限'}]

        # A dict (rather than a set) de-duplicates while preserving insertion
        # order, so the permission list is deterministic across requests.
        unique = {}
        for role in obj.roles:
            if not role.is_active:
                continue
            for perm in role.permissions:
                if perm.is_active:
                    unique.setdefault((perm.code, perm.name), None)
        return [{'code': code, 'name': name} for code, name in unique]
    except Exception:
        logger.debug("Could not collect user permissions; using an empty list", exc_info=True)
        return []


def _is_superuser(obj):
    """Return ``obj.is_superuser()``, or ``False`` if calling it raises."""
    try:
        return obj.is_superuser()
    except Exception:
        return False


class UserResponse(BaseResponse, UserBase):
    """User response schema with roles, effective permissions and superuser flag."""

    is_active: bool
    department_id: Optional[int] = None
    # Pydantic copies mutable defaults per instance, so ``default=[]`` is safe here.
    roles: Optional[List['RoleResponse']] = Field(default=[], description="用户角色列表")
    permissions: Optional[List[Dict[str, Any]]] = Field(default=[], description="用户权限列表")
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    @classmethod
    def from_orm(cls, obj):
        """Build a response from a user ORM object, tolerating unloadable relationships.

        Scalar columns are read directly (a missing one raises ``AttributeError``);
        ``bio`` and ``avatar_url`` fall back to ``None`` when absent. ``roles``,
        ``permissions`` and ``is_superuser`` are computed best-effort and fall
        back to ``[]``, ``[]`` and ``False`` respectively on any error.

        NOTE: ``model_validate(obj)`` does not go through this method; callers
        must call ``UserResponse.from_orm`` explicitly to get this behavior.
        """
        data = {
            'id': obj.id,
            'username': obj.username,
            'email': obj.email,
            'full_name': obj.full_name,
            # UserBase fields that were previously never copied, so responses
            # always reported None for them.
            'bio': getattr(obj, 'bio', None),
            'avatar_url': getattr(obj, 'avatar_url', None),
            'is_active': obj.is_active,
            'department_id': obj.department_id,
            'created_at': obj.created_at,
            'updated_at': obj.updated_at,
            # Evaluated in this order: roles, then permissions, then superuser flag.
            'roles': _serialize_roles(obj),
            'permissions': _collect_permissions(obj),
            'is_superuser': _is_superuser(obj),
        }
        return cls(**data)


# Authentication schemas
class LoginRequest(BaseModel):
    """Credentials submitted to log in: an email address and a password."""

    email: str = Field(max_length=100)
    password: str = Field(min_length=6)


class Token(BaseModel):
    """Access-token payload returned after successful authentication."""

    access_token: str
    token_type: str
    expires_in: int  # NOTE(review): unit not shown in this file (presumably seconds)


# Conversation schemas
class ConversationBase(BaseModel):
    """Fields shared by the conversation creation and response schemas.

    ``protected_namespaces=()`` stops Pydantic v2 from warning that the
    ``model_name`` field collides with its reserved ``model_`` prefix; the
    setting is inherited by subclasses.
    """

    model_config = ConfigDict(protected_namespaces=())

    title: str = Field(..., min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: str = Field(default="gpt-3.5-turbo", max_length=100)
    # Kept as a string here (e.g. "0.7"), unlike ChatRequest.temperature which is a float.
    temperature: str = Field(default="0.7", max_length=10)
    max_tokens: int = Field(default=2048, ge=1, le=8192)
    knowledge_base_id: Optional[int] = None


class ConversationCreate(ConversationBase):
    """Payload for creating a conversation; identical to ``ConversationBase``."""


class ConversationUpdate(BaseModel):
    """Partial update payload for a conversation; every field defaults to ``None``.

    ``protected_namespaces=()`` stops Pydantic v2 from warning that the
    ``model_name`` field collides with its reserved ``model_`` prefix.
    """

    model_config = ConfigDict(protected_namespaces=())

    title: Optional[str] = Field(None, min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: Optional[str] = Field(None, max_length=100)
    temperature: Optional[str] = Field(None, max_length=10)
    max_tokens: Optional[int] = Field(None, ge=1, le=8192)
    is_archived: Optional[bool] = None


class ConversationResponse(BaseResponse, ConversationBase):
    """Conversation as returned by the API: base fields plus ownership and activity info."""

    user_id: int
    is_archived: bool
    # Activity summary.
    message_count: int = 0
    last_message_at: Optional[datetime] = None


# Message schemas
class MessageBase(BaseModel):
    """Fields shared by the message creation and response schemas.

    The ``metadata`` field has the alias ``message_metadata``. This class does
    not enable ``populate_by_name``, so input must use the alias key.
    """

    content: str = Field(min_length=1)
    role: MessageRole
    message_type: MessageType = MessageType.TEXT
    metadata: Optional[Dict[str, Any]] = Field(default=None, alias="message_metadata")


class MessageCreate(MessageBase):
    """Payload for creating a message inside an existing conversation."""

    conversation_id: int  # target conversation


class MessageResponse(BaseResponse, MessageBase):
    """Message as returned by the API, including context documents and token counts.

    ``populate_by_name=True`` lets ``metadata`` be supplied either by its field
    name or by its alias ``message_metadata``; ``from_attributes=True`` allows
    validation from ORM objects.
    """

    # Pydantic v2 style config; class-based ``Config`` is deprecated in v2.
    model_config = ConfigDict(from_attributes=True, populate_by_name=True)

    conversation_id: int
    context_documents: Optional[List[Dict[str, Any]]] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None


# Chat schemas
class ChatRequest(BaseModel):
    """A single chat turn sent by the client, plus per-request generation options."""

    message: str = Field(min_length=1, max_length=10000)
    stream: bool = Field(default=False)

    # Mode switches.
    use_knowledge_base: bool = Field(default=False)
    knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
    use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
    use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")

    # Optional generation overrides (None means "not provided").
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=8192)


class ChatResponse(BaseModel):
    """Result of a non-streaming chat turn: the stored user and assistant messages.

    ``protected_namespaces=()`` stops Pydantic v2 from warning that the
    ``model_used`` field collides with its reserved ``model_`` prefix.
    """

    model_config = ConfigDict(protected_namespaces=())

    user_message: MessageResponse
    assistant_message: MessageResponse
    total_tokens: Optional[int] = None
    model_used: str


class StreamChunk(BaseModel):
    """One piece of a streamed reply; ``role`` defaults to the assistant."""

    content: str
    role: MessageRole = MessageRole.ASSISTANT
    # Optional trailer fields, None unless supplied.
    finish_reason: Optional[str] = None
    tokens_used: Optional[int] = None


# Knowledge Base schemas
class KnowledgeBaseBase(BaseModel):
    """Fields shared by the knowledge base creation and response schemas."""

    name: str = Field(min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: str = Field(default="sentence-transformers/all-MiniLM-L6-v2")
    # Chunking parameters.
    # NOTE(review): chunk_overlap is not validated against chunk_size here.
    chunk_size: int = Field(default=1000, ge=100, le=5000)
    chunk_overlap: int = Field(default=200, ge=0, le=1000)


class KnowledgeBaseCreate(KnowledgeBaseBase):
    """Payload for creating a knowledge base; identical to ``KnowledgeBaseBase``."""


class KnowledgeBaseUpdate(BaseModel):
    """Partial update payload for a knowledge base; every field defaults to ``None``."""

    name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: Optional[str] = None
    chunk_size: Optional[int] = Field(default=None, ge=100, le=5000)
    chunk_overlap: Optional[int] = Field(default=None, ge=0, le=1000)
    is_active: Optional[bool] = None


class KnowledgeBaseResponse(BaseResponse, KnowledgeBaseBase):
    """Knowledge base as returned by the API, with storage details and document counts."""

    is_active: bool
    vector_db_type: str
    collection_name: Optional[str]  # required key, but may be None
    # Document counters.
    document_count: int = 0
    active_document_count: int = 0


# Document schemas
class DocumentBase(BaseModel):
    """File identity fields shared by document schemas; all are required."""

    filename: str
    original_filename: str
    file_type: str
    file_size: int


class DocumentUpload(BaseModel):
    """Upload options: the destination knowledge base and whether to process right away."""

    knowledge_base_id: int
    process_immediately: bool = Field(default=True)


class DocumentResponse(BaseResponse, DocumentBase):
    """Document as returned by the API, including storage and processing state.

    The ``Optional[...]`` fields without defaults are required keys whose
    value may be ``None``.
    """

    knowledge_base_id: int
    file_path: str
    mime_type: Optional[str]
    # Processing state.
    is_processed: bool
    processing_error: Optional[str]
    chunk_count: int = 0
    embedding_model: Optional[str]
    file_size_mb: float


class DocumentListResponse(BaseModel):
    """One page of documents together with paging information."""

    documents: List[DocumentResponse]
    # Paging information.
    total: int
    page: int
    page_size: int


class DocumentProcessingStatus(BaseModel):
    """Progress report for the processing of a single document."""

    document_id: int
    status: str  # one of: 'pending', 'processing', 'completed', 'failed'
    progress: float = Field(default=0.0, ge=0.0, le=100.0)  # percentage, 0-100
    error_message: Optional[str] = None
    chunks_created: int = 0
    estimated_time_remaining: Optional[int] = None  # in seconds


# Document chunk schemas
class DocumentChunk(BaseModel):
    """A single text chunk produced from a document, with optional position info."""

    id: str
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Position of the chunk within the source document (when known).
    page_number: Optional[int] = None
    chunk_index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None


class DocumentChunksResponse(BaseModel):
    """All chunks of one document, identified by id and display name."""

    document_id: int
    document_name: str
    total_chunks: int
    chunks: List[DocumentChunk]


# Error schemas
class ErrorResponse(BaseModel):
    """Error payload: a required message plus optional detail and code."""

    error: str
    detail: Optional[str] = None
    code: Optional[str] = None


# Generic response envelopes
class NormalResponse(BaseModel):
    """Minimal success/message envelope."""

    success: bool
    message: str


class ExcelPreviewRequest(BaseModel):
    """Request one page of the preview of an uploaded Excel file.

    ``page`` starts at 1 (its default). Both ``page`` and ``page_size`` are
    constrained to be at least 1: a zero or negative value would otherwise
    produce a negative offset or slice when paging the rows.
    """

    file_id: str
    page: int = Field(default=1, ge=1)
    page_size: int = Field(default=20, ge=1)


class FileListResponse(BaseModel):
    """Success/message envelope with an optional free-form ``data`` payload."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None


# Resolve forward references (UserResponse.roles is annotated with the string 'RoleResponse').
def rebuild_models():
    """Rebuild models whose annotations contain forward references.

    ``UserResponse.roles`` refers to ``'RoleResponse'``, which this module only
    imports under ``TYPE_CHECKING``. This function imports the real class and
    then calls ``UserResponse.model_rebuild()``. If the import raises
    ``ImportError``, the rebuild is skipped silently. Other exceptions
    (e.g. from ``model_rebuild``) propagate.
    """
    try:
        # NOTE(review): this name looks unused, but pydantic's model_rebuild()
        # (with its default parent-namespace depth) appears to resolve forward
        # references from the calling frame's namespace, i.e. this function's
        # locals. Keep the import here; confirm before moving or removing it.
        from ..schemas.permission import RoleResponse
        UserResponse.model_rebuild()
    except ImportError:
        # RoleResponse could not be imported; skip the rebuild.
        pass


# Attempt the rebuild once, at module import time.
rebuild_models()