hxf/backend/th_agenter/utils/schemas.py

375 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pydantic schemas for API requests and responses."""
import logging
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
    from ..schemas.permission import RoleResponse
class MessageRole(str, Enum):
    """Author of a chat message.

    Members subclass ``str``, so each one compares equal to its raw value
    (``MessageRole.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
class MessageType(str, Enum):
    """Kind of payload carried by a chat message.

    Members subclass ``str`` and compare equal to their raw values
    (``MessageType.TEXT == "text"``).
    """

    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    AUDIO = "audio"
# Base schemas
class BaseResponse(BaseModel):
    """Identity and timestamp fields shared by ORM-backed response schemas.

    ``from_attributes`` lets these schemas be built from ORM objects by
    reading attributes rather than dict keys.
    """

    # ConfigDict form: a nested ``class Config`` is deprecated in pydantic v2.
    # Subclasses inherit and merge this configuration.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime
# User schemas
class UserBase(BaseModel):
    """Fields common to user creation and user responses."""

    # Required, 3-50 characters.
    username: str = Field(default=..., min_length=3, max_length=50)
    # Required; only the length is checked, not the address format.
    email: str = Field(default=..., max_length=100)
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
class UserCreate(UserBase):
    """Payload for creating a user: the shared user fields plus a password."""

    # Required, at least 6 characters.
    password: str = Field(default=..., min_length=6)
class UserUpdate(BaseModel):
    """Partial user update: every field is optional and defaults to ``None``.

    When a value is supplied, the string length limits match those used by
    ``UserBase`` and ``UserCreate``.
    """

    username: Optional[str] = Field(default=None, min_length=3, max_length=50)
    email: Optional[str] = Field(default=None, max_length=100)
    full_name: Optional[str] = Field(default=None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    password: Optional[str] = Field(default=None, min_length=6)
    is_active: Optional[bool] = None
    department_id: Optional[int] = None
_logger = logging.getLogger(__name__)


def _user_roles_from_orm(obj: Any) -> List[Any]:
    """Serialize ``obj.roles`` into ``RoleResponse`` objects, best effort.

    Returns an empty list when ``obj`` has no ``roles`` attribute, when it is
    ``None``, or when loading/serializing the roles raises (e.g. a
    ``DetachedInstanceError`` for a lazy relationship on a detached object).
    """
    try:
        roles = getattr(obj, 'roles', None)
        if roles is None:
            return []
        # Imported lazily: at module level RoleResponse is only imported under
        # TYPE_CHECKING, so it is not available as a runtime global.
        from ..schemas.permission import RoleResponse
        return [RoleResponse.from_orm(role) for role in roles]
    except Exception:
        # Deliberately best-effort, but leave a trace instead of failing silently.
        # Do not touch ``obj`` attributes here: they may be what just failed.
        _logger.debug("Could not serialize roles for %s; using []",
                      type(obj).__name__, exc_info=True)
        return []


def _user_permissions_from_orm(obj: Any) -> List[Dict[str, Any]]:
    """Collect ``{'code': ..., 'name': ...}`` dicts for ``obj``'s permissions, best effort.

    * When ``obj.has_role('SUPER_ADMIN')`` is true, every active ``Permission``
      row is queried through the session ``obj`` belongs to. Without a session,
      a single ``'*'`` wildcard entry is returned instead.
    * Otherwise the result is the de-duplicated union of the active permissions
      of the user's active roles, in first-seen order.

    Any exception yields an empty list.
    """
    try:
        from sqlalchemy.orm import object_session
        session = object_session(obj)
        if obj.has_role('SUPER_ADMIN'):
            # Super admins hold every permission.
            if not session:
                return [{'code': '*', 'name': '所有权限'}]
            from ..models.permission import Permission
            active_permissions = (
                session.query(Permission)
                .filter(Permission.is_active == True)  # noqa: E712 - SQL expression
                .all()
            )
            return [{'code': perm.code, 'name': perm.name} for perm in active_permissions]

        # A dict (rather than a set) de-duplicates while keeping first-seen
        # order, so the list order is stable across processes.
        unique_pairs = {}
        for role in obj.roles:
            if not role.is_active:
                continue
            for perm in role.permissions:
                if perm.is_active:
                    unique_pairs.setdefault((perm.code, perm.name), None)
        return [{'code': code, 'name': name} for code, name in unique_pairs]
    except Exception:
        _logger.debug("Could not collect permissions for %s; using []",
                      type(obj).__name__, exc_info=True)
        return []


def _user_is_superuser(obj: Any) -> Any:
    """Return ``obj.is_superuser()``, or ``False`` when calling it raises."""
    try:
        return obj.is_superuser()
    except Exception:
        _logger.debug("is_superuser() failed for %s; using False",
                      type(obj).__name__, exc_info=True)
        return False


class UserResponse(BaseResponse, UserBase):
    """User response schema, including roles, permissions and super-admin status."""

    is_active: bool
    department_id: Optional[int] = None
    roles: Optional[List['RoleResponse']] = Field(default=[], description="用户角色列表")
    permissions: Optional[List[Dict[str, Any]]] = Field(default=[], description="用户权限列表")
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    @classmethod
    def from_orm(cls, obj):
        """Build a response from an ORM user object, tolerating relationship failures.

        Column attributes are read directly, so an ``AttributeError`` for a
        missing required column still propagates. ``roles``, ``permissions``
        and ``is_superuser`` are computed best-effort and fall back to ``[]``,
        ``[]`` and ``False`` when they cannot be loaded.

        Args:
            obj: ORM user instance. Assumed to provide ``has_role()`` and
                ``is_superuser()`` methods (TODO confirm against the User model).

        Returns:
            An instance of ``cls``, so subclasses get their own type back.
        """
        data = {
            'id': obj.id,
            'username': obj.username,
            'email': obj.email,
            'full_name': obj.full_name,
            # Inherited from UserBase; previously never populated, so the
            # response always reported None for them.
            'bio': getattr(obj, 'bio', None),
            'avatar_url': getattr(obj, 'avatar_url', None),
            'is_active': obj.is_active,
            'department_id': obj.department_id,
            'created_at': obj.created_at,
            'updated_at': obj.updated_at,
            # Not declared on this schema (pydantic's default extra='ignore'
            # drops them). They are still passed, so subclasses that declare
            # them receive values.
            'created_by': getattr(obj, 'created_by', None),
            'updated_by': getattr(obj, 'updated_by', None),
            'roles': _user_roles_from_orm(obj),
            'permissions': _user_permissions_from_orm(obj),
            'is_superuser': _user_is_superuser(obj),
        }
        return cls(**data)
# Authentication schemas
class LoginRequest(BaseModel):
    """Login credentials: an email address and a password."""

    email: str = Field(default=..., max_length=100)
    # NOTE(review): min_length also applies at login, so a password shorter
    # than 6 characters fails validation here; confirm this is intended.
    password: str = Field(default=..., min_length=6)
class Token(BaseModel):
    """Token response: access token, token type and ``expires_in``."""

    access_token: str
    token_type: str
    # NOTE(review): unit not stated here; presumably seconds, confirm with the issuer.
    expires_in: int
# Conversation schemas
class ConversationBase(BaseModel):
    """Conversation settings shared by the create and response schemas."""

    # ``model_name`` starts with ``model_``, a namespace older pydantic 2.x
    # releases protect and warn about (UserWarning) when the class is created.
    # Opting out silences that warning; subclasses inherit the setting.
    model_config = ConfigDict(protected_namespaces=())

    title: str = Field(..., min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: str = Field(default="gpt-3.5-turbo", max_length=100)
    # NOTE(review): temperature is a string here (e.g. "0.7"), while ChatRequest
    # uses a float. This presumably mirrors the storage column; confirm.
    temperature: str = Field(default="0.7", max_length=10)
    max_tokens: int = Field(default=2048, ge=1, le=8192)
    knowledge_base_id: Optional[int] = None
class ConversationCreate(ConversationBase):
    """Payload for creating a conversation; same fields as ``ConversationBase``."""
class ConversationUpdate(BaseModel):
    """Partial conversation update; every field is optional and defaults to ``None``."""

    # ``model_name`` falls in pydantic's protected ``model_`` namespace; opting
    # out silences the UserWarning older pydantic 2.x releases emit for it.
    model_config = ConfigDict(protected_namespaces=())

    title: Optional[str] = Field(None, min_length=1, max_length=200)
    system_prompt: Optional[str] = None
    model_name: Optional[str] = Field(None, max_length=100)
    temperature: Optional[str] = Field(None, max_length=10)
    max_tokens: Optional[int] = Field(None, ge=1, le=8192)
    is_archived: Optional[bool] = None
    # NOTE(review): knowledge_base_id (declared on ConversationBase) cannot be
    # changed through this schema; confirm that is intended.
class ConversationResponse(BaseResponse, ConversationBase):
    """Conversation as returned by the API: settings plus owner, archive flag and message stats."""

    user_id: int
    is_archived: bool
    message_count: int = 0
    last_message_at: Optional[datetime] = None
# Message schemas
class MessageBase(BaseModel):
    """Fields common to message creation and message responses."""

    content: str = Field(default=..., min_length=1)
    role: MessageRole
    message_type: MessageType = MessageType.TEXT
    # Aliased to ``message_metadata``: input must use that key unless the model
    # enables populate_by_name (MessageResponse does). Presumably this matches
    # the ORM attribute name, since ``metadata`` is reserved on SQLAlchemy models.
    metadata: Optional[Dict[str, Any]] = Field(default=None, alias="message_metadata")
class MessageCreate(MessageBase):
    """Message creation payload: the shared message fields plus the target conversation id."""

    conversation_id: int
class MessageResponse(BaseResponse, MessageBase):
    """Stored message as returned by the API, plus context documents and token counts."""

    # ConfigDict form: a nested ``class Config`` is deprecated in pydantic v2.
    # ``populate_by_name`` lets ``metadata`` be supplied by field name as well
    # as by its ``message_metadata`` alias.
    model_config = ConfigDict(from_attributes=True, populate_by_name=True)

    conversation_id: int
    context_documents: Optional[List[Dict[str, Any]]] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
# Chat schemas
class ChatRequest(BaseModel):
    """One chat turn sent by the client, with mode switches and optional sampling overrides."""

    message: str = Field(default=..., min_length=1, max_length=10000)
    stream: bool = False
    use_knowledge_base: bool = False
    knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
    use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
    use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")
    # NOTE(review): None presumably means "use the conversation's own setting";
    # confirm in the chat service.
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=8192)
class ChatResponse(BaseModel):
    """Chat response pairing the user's message with the assistant's message."""

    # ``model_used`` starts with pydantic's protected ``model_`` prefix; opting
    # out silences the UserWarning older pydantic 2.x releases emit for it.
    model_config = ConfigDict(protected_namespaces=())

    user_message: MessageResponse
    assistant_message: MessageResponse
    total_tokens: Optional[int] = None
    model_used: str
class StreamChunk(BaseModel):
    """One incremental piece of a streamed reply; the role defaults to the assistant."""

    content: str
    role: MessageRole = MessageRole.ASSISTANT
    finish_reason: Optional[str] = None
    tokens_used: Optional[int] = None
# Knowledge Base schemas
class KnowledgeBaseBase(BaseModel):
    """Knowledge-base settings shared by the create and response schemas."""

    name: str = Field(default=..., min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    chunk_size: int = Field(default=1000, ge=100, le=5000)
    # NOTE(review): chunk_overlap is not checked against chunk_size, so an
    # overlap >= chunk_size passes validation; confirm the splitter copes.
    chunk_overlap: int = Field(default=200, ge=0, le=1000)
class KnowledgeBaseCreate(KnowledgeBaseBase):
    """Payload for creating a knowledge base; same fields as ``KnowledgeBaseBase``."""
class KnowledgeBaseUpdate(BaseModel):
    """Partial knowledge-base update; every field is optional and defaults to ``None``."""

    name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    description: Optional[str] = None
    embedding_model: Optional[str] = None
    chunk_size: Optional[int] = Field(default=None, ge=100, le=5000)
    chunk_overlap: Optional[int] = Field(default=None, ge=0, le=1000)
    is_active: Optional[bool] = None
class KnowledgeBaseResponse(BaseResponse, KnowledgeBaseBase):
    """Knowledge base as returned by the API, with vector-store details and document counts."""

    is_active: bool
    vector_db_type: str
    # No default: the value is required but may be None.
    collection_name: Optional[str]
    document_count: int = 0
    active_document_count: int = 0
# Document schemas
class DocumentBase(BaseModel):
    """File metadata shared by document schemas."""

    filename: str
    original_filename: str
    file_type: str
    # NOTE(review): unit not stated; presumably bytes (DocumentResponse also
    # carries file_size_mb).
    file_size: int
class DocumentUpload(BaseModel):
    """Options sent with a document upload: target knowledge base and processing flag."""

    knowledge_base_id: int
    process_immediately: bool = True
class DocumentResponse(BaseResponse, DocumentBase):
    """Document as returned by the API, including storage and processing state."""

    knowledge_base_id: int
    file_path: str
    # The Optional fields below have no default: each is required but may be None.
    mime_type: Optional[str]
    is_processed: bool
    processing_error: Optional[str]
    chunk_count: int = 0
    embedding_model: Optional[str]
    file_size_mb: float
class DocumentListResponse(BaseModel):
    """A list of documents with total count and paging fields."""

    documents: List[DocumentResponse]
    total: int
    page: int
    page_size: int
class DocumentProcessingStatus(BaseModel):
    """Processing progress report for a single document."""

    document_id: int
    # Expected values: 'pending', 'processing', 'completed', 'failed'.
    # This is documented only; any string passes validation.
    status: str
    progress: float = Field(default=0.0, ge=0.0, le=100.0)  # bounded to 0-100
    error_message: Optional[str] = None
    chunks_created: int = 0
    estimated_time_remaining: Optional[int] = None  # seconds
# Error schemas
# Document chunk schemas
class DocumentChunk(BaseModel):
    """One chunk of a document's content, with its index and optional position data."""

    id: str
    content: str
    # default_factory gives each instance its own empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    page_number: Optional[int] = None
    chunk_index: int
    start_char: Optional[int] = None
    end_char: Optional[int] = None
class DocumentChunksResponse(BaseModel):
    """Chunks of one document together with its id, name and total chunk count."""

    document_id: int
    document_name: str
    total_chunks: int
    chunks: List[DocumentChunk]
class ErrorResponse(BaseModel):
    """Error payload: a message plus optional detail text and error code."""

    error: str
    detail: Optional[str] = None
    code: Optional[str] = None
# 通用返回结构
class NormalResponse(BaseModel):
    """Generic result: a success flag and a message."""

    success: bool
    message: str
class ExcelPreviewRequest(BaseModel):
    """Paged preview request for an Excel file identified by ``file_id``."""

    file_id: str
    # NOTE(review): page and page_size are unbounded, so page <= 0 or a very
    # large page_size pass validation; confirm the handler clamps them.
    page: int = 1
    page_size: int = 20
class FileListResponse(BaseModel):
    """File listing result: success flag, message and an optional data payload."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None
# 解决前向引用问题
def rebuild_models():
    """Rebuild models to resolve the ``'RoleResponse'`` forward reference on ``UserResponse``.

    At module level ``RoleResponse`` is imported only under ``TYPE_CHECKING``,
    so it is not a runtime global when ``UserResponse`` is defined. This
    function imports it at call time and then calls ``model_rebuild()``.

    If an ``ImportError`` is raised (presumably e.g. a circular import while
    ``..schemas.permission`` is still initialising), the rebuild is skipped
    silently.
    """
    try:
        # The local name must stay ``RoleResponse``: pydantic v2's
        # model_rebuild() resolves unresolved forward references from the
        # namespace of the frame that calls it (this function's locals), so
        # this binding is what makes 'RoleResponse' resolvable.
        from ..schemas.permission import RoleResponse
        UserResponse.model_rebuild()
    except ImportError:
        # RoleResponse could not be imported; skip the rebuild.
        pass


# Attempt the rebuild once, when this module is loaded.
rebuild_models()