Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
c6b7bb32e5 | |
|
|
b6d5489459 |
|
|
@ -1,4 +1 @@
|
|||
Generic single-database configuration with an async dbapi.
|
||||
|
||||
alembic revision --autogenerate -m "init"
|
||||
alembic upgrade head
|
||||
Generic single-database configuration with an async dbapi.
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
"""Initial migration
|
||||
|
||||
Revision ID: 424646027786
|
||||
Revises:
|
||||
Create Date: 2025-12-16 09:56:45.172954
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '424646027786'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema.

    Initial migration: creates every application table (agent presets,
    conversations/messages, knowledge bases/documents, LLM and database
    configs, users/roles, workflows) plus their indexes.  Constraint names
    are routed through ``op.f`` so they follow the project's naming
    convention template.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Saved agent runtime presets (tools, model, sampling parameters).
    op.create_table('agent_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('enabled_tools', sa.JSON(), nullable=False),
    sa.Column('max_iterations', sa.Integer(), nullable=False),
    # NOTE(review): temperature is stored as a string, not a float — confirm
    # this is intentional (it matches 'conversations' below).
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('system_message', sa.Text(), nullable=True),
    sa.Column('verbose', sa.Boolean(), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
    )
    op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
    op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
    # Chat sessions.  NOTE(review): user_id / knowledge_base_id carry no
    # ForeignKeyConstraint here — confirm referential integrity is handled
    # at the application layer on purpose.
    op.create_table('conversations',
    sa.Column('title', sa.String(length=200), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
    sa.Column('system_prompt', sa.Text(), nullable=True),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_archived', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
    )
    op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
    # External database connection settings; at most one config per db_type
    # (enforced by the unique constraint below).
    op.create_table('database_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('db_type', sa.String(length=20), nullable=False),
    sa.Column('host', sa.String(length=255), nullable=False),
    sa.Column('port', sa.Integer(), nullable=False),
    sa.Column('database', sa.String(length=100), nullable=False),
    sa.Column('username', sa.String(length=100), nullable=False),
    # NOTE(review): password is persisted as Text — verify it is encrypted
    # by the application before storage.
    sa.Column('password', sa.Text(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('connection_params', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
    sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
    )
    op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
    # Uploaded documents attached to a knowledge base, plus processing /
    # embedding bookkeeping (chunk_count, vector_ids).
    op.create_table('documents',
    sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('mime_type', sa.String(length=100), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('content', sa.Text(), nullable=True),
    sa.Column('doc_metadata', sa.JSON(), nullable=True),
    sa.Column('chunk_count', sa.Integer(), nullable=False),
    sa.Column('embedding_model', sa.String(length=100), nullable=True),
    sa.Column('vector_ids', sa.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
    )
    op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
    # Uploaded spreadsheets with cached structure/preview metadata.
    op.create_table('excel_files',
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('sheet_names', sa.JSON(), nullable=False),
    sa.Column('default_sheet', sa.String(length=100), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('preview_data', sa.JSON(), nullable=False),
    sa.Column('data_types', sa.JSON(), nullable=True),
    # NOTE(review): total_rows / total_columns typed as JSON looks like an
    # autogenerate artifact — these read like scalar counts (Integer);
    # confirm against the model before relying on them.
    sa.Column('total_rows', sa.JSON(), nullable=True),
    sa.Column('total_columns', sa.JSON(), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('last_accessed', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
    )
    op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
    # RAG knowledge bases: chunking + embedding + vector-store settings.
    op.create_table('knowledge_bases',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('embedding_model', sa.String(length=100), nullable=False),
    sa.Column('chunk_size', sa.Integer(), nullable=False),
    sa.Column('chunk_overlap', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('vector_db_type', sa.String(length=50), nullable=False),
    sa.Column('collection_name', sa.String(length=100), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
    )
    op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
    op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
    # Per-provider LLM credentials and sampling defaults; usage_count /
    # last_used_at track how the config is exercised.
    op.create_table('llm_configs',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('provider', sa.String(length=50), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    # NOTE(review): api_key stored in plaintext String(500)? — verify
    # encryption-at-rest expectations.
    sa.Column('api_key', sa.String(length=500), nullable=False),
    sa.Column('base_url', sa.String(length=200), nullable=True),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.Float(), nullable=False),
    sa.Column('top_p', sa.Float(), nullable=False),
    sa.Column('frequency_penalty', sa.Float(), nullable=False),
    sa.Column('presence_penalty', sa.Float(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('is_embedding', sa.Boolean(), nullable=False),
    sa.Column('extra_config', sa.JSON(), nullable=True),
    sa.Column('usage_count', sa.Integer(), nullable=False),
    sa.Column('last_used_at', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
    )
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
    # Individual chat messages.  On PostgreSQL the sa.Enum columns create
    # the 'messagerole' / 'messagetype' ENUM types as a side effect.
    op.create_table('messages',
    sa.Column('conversation_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
    sa.Column('message_metadata', sa.JSON(), nullable=True),
    sa.Column('context_documents', sa.JSON(), nullable=True),
    sa.Column('prompt_tokens', sa.Integer(), nullable=True),
    sa.Column('completion_tokens', sa.Integer(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
    )
    op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
    # RBAC roles; code and name are both unique (see indexes below).
    op.create_table('roles',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('code', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_system', sa.Boolean(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
    )
    op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    # Cached introspection of external tables used for table Q&A.
    op.create_table('table_metadata',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('table_name', sa.String(length=100), nullable=False),
    sa.Column('table_schema', sa.String(length=50), nullable=False),
    sa.Column('table_type', sa.String(length=20), nullable=False),
    sa.Column('table_comment', sa.Text(), nullable=True),
    sa.Column('database_config_id', sa.Integer(), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('primary_keys', sa.JSON(), nullable=True),
    sa.Column('foreign_keys', sa.JSON(), nullable=True),
    sa.Column('indexes', sa.JSON(), nullable=True),
    sa.Column('sample_data', sa.JSON(), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
    sa.Column('qa_description', sa.Text(), nullable=True),
    sa.Column('business_context', sa.Text(), nullable=True),
    # NOTE(review): only last_synced_at is timezone-aware; the audit
    # timestamps below are naive — confirm the mix is intended.
    sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
    )
    op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
    op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
    # Application users; username and email are unique.
    op.create_table('users',
    sa.Column('username', sa.String(length=50), nullable=False),
    sa.Column('email', sa.String(length=100), nullable=False),
    sa.Column('hashed_password', sa.String(length=255), nullable=False),
    sa.Column('full_name', sa.String(length=100), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('avatar_url', sa.String(length=255), nullable=True),
    sa.Column('bio', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # User<->role association table (composite primary key, real FKs).
    op.create_table('user_roles',
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('role_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
    sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
    )
    # Workflow definitions (JSON graph), owned by a user.
    op.create_table('workflows',
    sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
    sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
    sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
    sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
    sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
    sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
    )
    op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
    # One row per workflow run.  NOTE(review): started_at / completed_at are
    # String(50) rather than DateTime — likely ISO strings; confirm, since
    # this prevents time arithmetic in SQL.
    op.create_table('workflow_executions',
    sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
    sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
    )
    op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
    # One row per node executed within a workflow run.
    op.create_table('node_executions',
    sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
    sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
    sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
    sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
    )
    op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema.

    Drops every table created by ``upgrade`` in reverse dependency order
    (tables holding foreign keys go first, so no reference ever dangles).

    NOTE(review): on PostgreSQL the ENUM types created by upgrade()
    (messagerole, messagetype, workflowstatus, executionstatus, nodetype)
    are NOT dropped here — confirm whether they should be removed with
    an explicit ``sa.Enum(...).drop(op.get_bind())``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_node_executions_id'), table_name='node_executions')
    op.drop_table('node_executions')
    op.drop_index(op.f('ix_workflow_executions_id'), table_name='workflow_executions')
    op.drop_table('workflow_executions')
    op.drop_index(op.f('ix_workflows_id'), table_name='workflows')
    op.drop_table('workflows')
    op.drop_table('user_roles')
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.drop_index(op.f('ix_users_id'), table_name='users')
    op.drop_index(op.f('ix_users_email'), table_name='users')
    op.drop_table('users')
    op.drop_index(op.f('ix_table_metadata_table_name'), table_name='table_metadata')
    op.drop_index(op.f('ix_table_metadata_id'), table_name='table_metadata')
    op.drop_table('table_metadata')
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_index(op.f('ix_roles_id'), table_name='roles')
    op.drop_index(op.f('ix_roles_code'), table_name='roles')
    op.drop_table('roles')
    op.drop_index(op.f('ix_messages_id'), table_name='messages')
    op.drop_table('messages')
    op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
    op.drop_table('llm_configs')
    op.drop_index(op.f('ix_knowledge_bases_name'), table_name='knowledge_bases')
    op.drop_index(op.f('ix_knowledge_bases_id'), table_name='knowledge_bases')
    op.drop_table('knowledge_bases')
    op.drop_index(op.f('ix_excel_files_id'), table_name='excel_files')
    op.drop_table('excel_files')
    op.drop_index(op.f('ix_documents_id'), table_name='documents')
    op.drop_table('documents')
    op.drop_index(op.f('ix_database_configs_id'), table_name='database_configs')
    op.drop_table('database_configs')
    op.drop_index(op.f('ix_conversations_id'), table_name='conversations')
    op.drop_table('conversations')
    op.drop_index(op.f('ix_agent_configs_name'), table_name='agent_configs')
    op.drop_index(op.f('ix_agent_configs_id'), table_name='agent_configs')
    op.drop_table('agent_configs')
    # ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
"""Add message_count and last_message_at to conversations
|
||||
|
||||
Revision ID: 8da391c6e2b7
|
||||
Revises: 424646027786
|
||||
Create Date: 2025-12-19 16:16:29.943314
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '8da391c6e2b7'
|
||||
down_revision: Union[str, Sequence[str], None] = '424646027786'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema.

    Adds denormalized message statistics to ``conversations``:
    ``message_count`` (running total) and ``last_message_at`` (timestamp of
    the most recent message).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # server_default='0' lets the NOT NULL column be added to a table that
    # already contains rows: without it the ALTER fails on PostgreSQL and
    # relies on implicit zero-filling on MySQL.  Existing conversations
    # correctly start at 0 messages counted.
    op.add_column(
        'conversations',
        sa.Column('message_count', sa.Integer(), nullable=False, server_default='0'),
    )
    op.add_column('conversations', sa.Column('last_message_at', sa.DateTime(), nullable=True))
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema.

    Removes the message-statistics columns added by ``upgrade``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop in the reverse order of upgrade().
    for column_name in ('last_message_at', 'message_count'):
        op.drop_column('conversations', column_name)
    # ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
async def check_table_constraints():
    """Print the constraints defined on the ``messages`` table.

    Connects using ``DATABASE_URL`` (falling back to the local dev MySQL
    instance) and runs SQLAlchemy's synchronous inspection API through
    ``AsyncConnection.run_sync`` — an :class:`AsyncEngine` cannot be passed
    to ``inspect()`` directly, and ``run_sync`` lives on the connection,
    not the engine.
    """
    engine = None
    try:
        # Database connection string (local dev default).
        DATABASE_URL = os.getenv("DATABASE_URL", "mysql+asyncmy://root:123456@localhost:3306/th_agenter")

        # Async engine; echo=True logs the emitted SQL for debugging.
        engine = create_async_engine(DATABASE_URL, echo=True)

        async with engine.connect() as conn:
            def _collect(sync_conn):
                # run_sync hands us a synchronous Connection, which
                # inspect() accepts.  Inspector has no single
                # "get_table_constraints" call, so gather the pieces.
                inspector = inspect(sync_conn)
                return {
                    "primary_key": inspector.get_pk_constraint("messages"),
                    "foreign_keys": inspector.get_foreign_keys("messages"),
                    "unique": inspector.get_unique_constraints("messages"),
                }

            constraints = await conn.run_sync(_collect)

        print("Messages表的所有约束:")
        print(f"  主键约束: {constraints['primary_key']}")
        for fk in constraints["foreign_keys"]:
            print(f"  外键约束: {fk}")
        for uq in constraints["unique"]:
            print(f"  唯一约束: {uq}")

    except Exception as e:
        print(f"检查约束时出错: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Release pooled connections even when inspection fails.
        if engine is not None:
            await engine.dispose()


if __name__ == "__main__":
    asyncio.run(check_table_constraints())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -38,7 +38,7 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
|||
async def http_exception_handler(request, exc):
|
||||
from utils.util_exceptions import HxfErrorResponse
|
||||
logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
|
||||
return HxfErrorResponse(exc.status_code, exc.detail)
|
||||
return HxfErrorResponse(exc)
|
||||
|
||||
def make_json_serializable(obj):
|
||||
"""递归地将对象转换为JSON可序列化的格式"""
|
||||
|
|
@ -127,31 +127,5 @@ def add_router(app: FastAPI) -> None:
|
|||
# Include routers
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
|
||||
# app.include_router(table_metadata.router)
|
||||
# # 在现有导入中添加
|
||||
# from ..api.endpoints import database_config
|
||||
|
||||
# # 在路由注册部分添加
|
||||
# app.include_router(database_config.router)
|
||||
# # Health check endpoint
|
||||
# @app.get("/health")
|
||||
# async def health_check():
|
||||
# return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# # Root endpoint
|
||||
# @app.get("/")
|
||||
# async def root():
|
||||
# return {"message": "Chat Agent API is running"}
|
||||
|
||||
# # Test endpoint
|
||||
# @app.get("/test")
|
||||
# async def test_endpoint():
|
||||
# return {"message": "API is working"}
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
# from utils.util_test import test_db
|
||||
# test_db()
|
||||
# from test.example import internet_search_tool
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
# Test package for PostgreSQL agent functionality
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from deepagents import create_deep_agent
|
||||
from openai import OpenAI
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain.agents import create_agent
|
||||
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver # 导入检查点工具
|
||||
from deepagents.backends import StoreBackend
|
||||
from loguru import logger
|
||||
def internet_search_tool(query: str) -> str:
    """Run a web search.

    Sends *query* to DashScope's ``qwen-plus`` model with ``enable_search``
    turned on, so the model answers with live web results.

    Args:
        query: Natural-language search query.

    Returns:
        The assistant's (search-augmented) answer text.
    """
    logger.info(f"Running internet search for query: {query}")
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url=os.getenv('DASHSCOPE_BASE_URL'),
    )
    # Plain strings here: the originals were f-strings with no placeholders.
    logger.info("create OpenAI")
    completion = client.chat.completions.create(
        model="qwen-plus",
        messages=[
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': query}
        ],
        # DashScope-specific flag that enables web search for this request.
        extra_body={
            "enable_search": True
        }
    )
    logger.info("create completions")
    logger.info(f"OpenAI response: {completion.choices[0].message.content}")
    return completion.choices[0].message.content
|
||||
|
||||
|
||||
|
||||
# System prompt to steer the agent to be an expert researcher
today = datetime.now().strftime("%Y年%m月%d日")
research_instructions = f"""你是一个智能助手。你的任务是帮助用户完成各种任务。

你可以使用互联网搜索工具来获取信息。
## `internet_search`
使用此工具对给定查询进行互联网搜索。你可以指定返回结果的最大数量、主题以及是否包含原始内容。

今天的日期是:{today}
"""

# Create the deep agent with memory
model = init_chat_model(
    model="gpt-4.1-mini",
    model_provider='openai',
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url=os.getenv('OPENAI_BASE_URL'),
)
checkpointer = InMemorySaver()  # in-memory checkpointer: persists history automatically

agent = create_deep_agent(  # state: per-thread (session-scoped) agent state
    tools=[internet_search_tool],
    system_prompt=research_instructions,
    model=model,
    checkpointer=checkpointer,  # attach the checkpointer to enable automatic memory
    # NOTE(review): presumably gates internet_search_tool behind a
    # human-in-the-loop interrupt — confirm against the deepagents API.
    interrupt_on={'internet_search_tool':True}
)

# Multi-turn conversation loop (the Checkpointer supplies memory automatically)
printed_msg_ids = set()  # ids of messages already printed (dedupes repeated stream snapshots)
thread_id = "user_session_001"  # session id distinguishing different users/sessions
config = {"configurable": {"thread_id": thread_id}, "metastore": {'assistant_id': 'owenliang'}}  # per-session config
|
||||
|
||||
print("开始对话(输入 'exit' 退出):")
# Outer loop: one iteration per user turn; 'exit' quits.
while True:
    user_input = input("\nHUMAN: ").strip()
    if user_input.lower() == 'exit':
        break

    # stream_mode="values" re-emits the full state repeatedly, so messages are
    # deduplicated by message.id and printed grouped by type.
    pending_resume = None
    # Inner loop: keeps streaming within the SAME user turn until no
    # human-in-the-loop interrupt is pending.
    while True:
        if pending_resume is None:
            # Fresh turn: send the user's message.
            request = {"messages": [{"role": "user", "content": user_input}]}
        else:
            # Resuming after an interrupt: wrap the approval decisions in a
            # langgraph Command so the paused run continues where it stopped.
            from langgraph.types import Command as _Command

            request = _Command(resume=pending_resume)
            pending_resume = None

        for item in agent.stream(
            request,
            config=config,
            stream_mode="values",
        ):
            # Some stream modes yield (state, metadata) tuples; normalize.
            state = item[0] if isinstance(item, tuple) and len(item) == 2 else item

            # First check whether a Human-In-The-Loop interrupt fired.
            if isinstance(state, dict) and "__interrupt__" in state:
                interrupts = state["__interrupt__"] or []
                if interrupts:
                    hitl_payload = interrupts[0].value
                    action_requests = hitl_payload.get("action_requests", [])

                    print("\n=== 需要人工审批的工具调用 ===")
                    decisions: list[dict[str, str]] = []
                    for idx, ar in enumerate(action_requests):
                        name = ar.get("name")
                        args = ar.get("args")
                        print(f"[{idx}] 工具 {name} 参数: {args}")
                        # Re-prompt until the operator types a valid choice.
                        while True:
                            choice = input("  决策 (a=approve, r=reject): ").strip().lower()
                            if choice in ("a", "r"):
                                break
                        decisions.append({"type": "approve" if choice == "a" else "reject"})

                    # Next inner-loop pass resumes the run with these decisions,
                    # continuing the same user turn.
                    pending_resume = {"decisions": decisions}
                    break

            # state may be a dict or an AgentState dataclass; handle both.
            messages = state.get("messages", []) if isinstance(state, dict) else getattr(state, "messages", [])
            for msg in messages:
                msg_id = getattr(msg, "id", None)
                if msg_id is not None and msg_id in printed_msg_ids:
                    continue
                if msg_id is not None:
                    printed_msg_ids.add(msg_id)

                msg_type = getattr(msg, "type", None)

                if msg_type == "human":
                    # The user's input is already on the terminal; skip it.
                    continue

                if msg_type == "ai":
                    tool_calls = getattr(msg, "tool_calls", None) or []
                    if tool_calls:
                        # AI message that initiates tool calls (TOOL CALL).
                        for tc in tool_calls:
                            tool_name = tc.get("name")
                            args = tc.get("args")
                            print(f"TOOL CALL [{tool_name}]: {args}")
                    # If the AI message also carries natural-language content,
                    # print that too.
                    if getattr(msg, "content", None):
                        print(f"AI: {msg.content}")
                    continue

                if msg_type == "tool":
                    # Tool execution result (TOOL RESPONSE).
                    tool_name = getattr(msg, "name", None) or "tool"
                    print(f"TOOL RESPONSE [{tool_name}]: {msg.content}")
                    continue

                # Fallback: print any other message type for debugging.
                print(f"[{msg_type}]: {getattr(msg, 'content', None)}")

        # No new interrupt to resume → this user turn is finished; wait for
        # the next input.
        if pending_resume is None:
            break
|
||||
|
|
@ -95,11 +95,13 @@ async def login_oauth(
|
|||
)
|
||||
session.desc = f"用户 {user.username} OAuth2 登录成功"
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
return HxfResponse(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/refresh", response_model=Token, summary="刷新访问token")
|
||||
async def refresh_token(
|
||||
|
|
@ -113,15 +115,17 @@ async def refresh_token(
|
|||
session, data={"sub": current_user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return Token(
|
||||
response = Token(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.security.access_token_expire_minutes * 60
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
|
||||
async def get_current_user_info(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
return UserResponse.model_validate(current_user, from_attributes=True)
|
||||
response = UserResponse.model_validate(current_user, from_attributes=True)
|
||||
return HxfResponse(response)
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
"""Chat endpoints for TH Agenter."""
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -11,6 +12,8 @@ from ...models.user import User
|
|||
from ...services.auth import AuthService
|
||||
from ...services.chat import ChatService
|
||||
from ...services.conversation import ConversationService
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
from utils.util_schemas import (
|
||||
ConversationCreate,
|
||||
ConversationResponse,
|
||||
|
|
@ -23,83 +26,6 @@ from utils.util_schemas import (
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# Conversation management
|
||||
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
|
||||
async def create_conversation(
|
||||
conversation_data: ConversationCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""创建新对话"""
|
||||
session.desc = "START: 创建新对话"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.create_conversation(
|
||||
user_id=current_user.id,
|
||||
conversation_data=conversation_data
|
||||
)
|
||||
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
|
||||
async def list_conversations(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
order_by: str = "updated_at",
|
||||
order_desc: bool = True,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话列表"""
|
||||
session.desc = "START: 获取用户对话列表"
|
||||
conversation_service = ConversationService(session)
|
||||
conversations = await conversation_service.get_user_conversations(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
search_query=search,
|
||||
include_archived=include_archived,
|
||||
order_by=order_by,
|
||||
order_desc=order_desc
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话"
|
||||
return [ConversationResponse.model_validate(conv) for conv in conversations]
|
||||
|
||||
@router.get("/conversations/count", summary="获取用户对话总数")
|
||||
async def get_conversations_count(
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话总数"""
|
||||
session.desc = "START: 获取用户对话总数"
|
||||
conversation_service = ConversationService(session)
|
||||
count = await conversation_service.get_user_conversations_count(
|
||||
search_query=search,
|
||||
include_archived=include_archived
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
|
||||
return {"count": count}
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
|
||||
async def get_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取指定对话"""
|
||||
session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.get_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Conversation not found"
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
|
||||
async def update_conversation(
|
||||
conversation_id: int,
|
||||
|
|
@ -113,7 +39,8 @@ async def update_conversation(
|
|||
conversation_id, conversation_update
|
||||
)
|
||||
session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(updated_conversation)
|
||||
response = ConversationResponse.model_validate(updated_conversation)
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
|
||||
|
|
@ -126,7 +53,8 @@ async def delete_conversation(
|
|||
conversation_service = ConversationService(session)
|
||||
await conversation_service.delete_conversation(conversation_id)
|
||||
session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation deleted successfully"}
|
||||
response = {"message": "Conversation deleted successfully"}
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
|
||||
|
|
@ -145,7 +73,8 @@ async def archive_conversation(
|
|||
)
|
||||
|
||||
session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation archived successfully"}
|
||||
response = {"message": "Conversation archived successfully"}
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
|
||||
|
|
@ -165,7 +94,8 @@ async def unarchive_conversation(
|
|||
)
|
||||
|
||||
session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation unarchived successfully"}
|
||||
response = {"message": "Conversation unarchived successfully"}
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
# Message management
|
||||
|
|
@ -183,7 +113,8 @@ async def get_conversation_messages(
|
|||
conversation_id, skip=skip, limit=limit
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
|
||||
return [MessageResponse.model_validate(msg) for msg in messages]
|
||||
response = [MessageResponse.model_validate(msg) for msg in messages]
|
||||
return HxfResponse(response)
|
||||
|
||||
# Chat functionality
|
||||
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
|
||||
|
|
@ -195,48 +126,158 @@ async def chat(
|
|||
"""发送消息并获取AI响应"""
|
||||
session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
|
||||
chat_service = ChatService(session)
|
||||
response = await chat_service.chat(
|
||||
conversation_id=conversation_id,
|
||||
message=chat_request.message,
|
||||
stream=False,
|
||||
temperature=chat_request.temperature,
|
||||
max_tokens=chat_request.max_tokens,
|
||||
use_agent=chat_request.use_agent,
|
||||
use_langgraph=chat_request.use_langgraph,
|
||||
use_knowledge_base=chat_request.use_knowledge_base,
|
||||
knowledge_base_id=chat_request.knowledge_base_id
|
||||
)
|
||||
await chat_service.initialize(conversation_id)
|
||||
|
||||
# response = await chat_service.chat(
|
||||
# conversation_id=conversation_id,
|
||||
# message=chat_request.message,
|
||||
# stream=False,
|
||||
# temperature=chat_request.temperature,
|
||||
# max_tokens=chat_request.max_tokens,
|
||||
# use_agent=chat_request.use_agent, # 可以简化掉
|
||||
# use_langgraph=chat_request.use_langgraph, # 可以简化掉
|
||||
# use_knowledge_base=chat_request.use_knowledge_base, # 可以简化掉
|
||||
# knowledge_base_id=chat_request.knowledge_base_id # 可以简化掉
|
||||
# )
|
||||
response = "oooooooooooooooooooK"
|
||||
session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
|
||||
|
||||
return response
|
||||
|
||||
return HxfResponse(response)
|
||||
# ------------------------------------------------------------------------
|
||||
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
|
||||
async def chat_stream(
|
||||
conversation_id: int,
|
||||
chat_request: ChatRequest,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""发送消息并获取流式AI响应."""
|
||||
session.title = f"对话{conversation_id} 发送消息并获取流式AI响应"
|
||||
session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> "
|
||||
chat_service = ChatService(session)
|
||||
await chat_service.initialize(conversation_id, streaming=True)
|
||||
|
||||
async def generate_response():
|
||||
async for chunk in chat_service.chat_stream(
|
||||
conversation_id=conversation_id,
|
||||
message=chat_request.message,
|
||||
temperature=chat_request.temperature,
|
||||
max_tokens=chat_request.max_tokens,
|
||||
use_agent=chat_request.use_agent,
|
||||
use_langgraph=chat_request.use_langgraph,
|
||||
use_knowledge_base=chat_request.use_knowledge_base,
|
||||
knowledge_base_id=chat_request.knowledge_base_id
|
||||
):
|
||||
yield f"data: {chunk}\n\n"
|
||||
async def generate_response(chat_service):
|
||||
try:
|
||||
async for chunk in chat_service.chat_stream(
|
||||
message=chat_request.message
|
||||
):
|
||||
yield chunk + "\n"
|
||||
except Exception as e:
|
||||
logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}")
|
||||
yield {'success': False, 'data': f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}"}
|
||||
|
||||
return StreamingResponse(
|
||||
generate_response(),
|
||||
media_type="text/plain",
|
||||
response = StreamingResponse(
|
||||
generate_response(chat_service),
|
||||
media_type="text/stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# Conversation management
|
||||
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
|
||||
async def create_conversation(
|
||||
conversation_data: ConversationCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""创建新对话"""
|
||||
id = current_user.id
|
||||
session.title = f"用户{current_user.username} - 创建新对话"
|
||||
session.desc = "START: 创建新对话"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.create_conversation(
|
||||
user_id=id,
|
||||
conversation_data=conversation_data
|
||||
)
|
||||
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {id}, conversation_id: {conversation.id}"
|
||||
response = ConversationResponse.model_validate(conversation)
|
||||
return HxfResponse(response)
|
||||
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
|
||||
async def list_conversations(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
order_by: str = "updated_at",
|
||||
order_desc: bool = True,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话列表"""
|
||||
session.title = "获取用户对话列表"
|
||||
session.desc = "START: 获取用户对话列表"
|
||||
conversation_service = ConversationService(session)
|
||||
conversations = await conversation_service.get_user_conversations(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
search_query=search,
|
||||
include_archived=include_archived,
|
||||
order_by=order_by,
|
||||
order_desc=order_desc
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..."
|
||||
response = [ConversationResponse.model_validate(conv) for conv in conversations]
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/conversations/count", summary="获取用户对话总数")
|
||||
async def get_conversations_count(
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话总数"""
|
||||
from th_agenter.core.context import UserContext
|
||||
user_id = UserContext.get_current_user_id()
|
||||
session.title = f"获取用户对话总数[用户id = {user_id}]"
|
||||
session.desc = "START: 获取用户对话总数"
|
||||
conversation_service = ConversationService(session)
|
||||
count = await conversation_service.get_user_conversations_count(
|
||||
search_query=search,
|
||||
include_archived=include_archived
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
|
||||
response = {"count": count}
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
|
||||
async def get_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取指定对话"""
|
||||
session.title = f"获取指定对话[对话id = {conversation_id}]"
|
||||
session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}"
|
||||
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.get_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Conversation not found"
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {conversation}"
|
||||
|
||||
response = ConversationResponse.model_validate(conversation)
|
||||
|
||||
|
||||
# chat_service = ChatService(session)
|
||||
# await chat_service.initialize(conversation_id, streaming=False)
|
||||
# messages = await chat_service.get_conversation_history_messages(
|
||||
# conversation_id
|
||||
# )
|
||||
# response.messages = messages
|
||||
|
||||
messages = await conversation_service.get_conversation_messages(
|
||||
conversation_id, skip=0, limit=100
|
||||
)
|
||||
response.messages = [MessageResponse.model_validate(msg) for msg in messages]
|
||||
|
||||
response.message_count = len(response.messages)
|
||||
return HxfResponse(response)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from th_agenter.db.database import get_session
|
|||
from th_agenter.services.database_config_service import DatabaseConfigService
|
||||
from th_agenter.services.auth import AuthService
|
||||
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||
|
||||
from utils.util_exceptions import HxfResponse
|
||||
# 在文件顶部添加
|
||||
from functools import lru_cache
|
||||
|
||||
|
|
@ -68,11 +68,12 @@ async def create_database_config(
|
|||
):
|
||||
"""创建或更新数据库配置"""
|
||||
config = await service.create_or_update_config(current_user.id, config_data.model_dump())
|
||||
return NormalResponse(
|
||||
response = NormalResponse(
|
||||
success=True,
|
||||
message="保存数据库配置成功",
|
||||
data=config
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表")
|
||||
async def get_database_configs(
|
||||
|
|
@ -83,7 +84,7 @@ async def get_database_configs(
|
|||
configs = service.get_user_configs(current_user.id)
|
||||
|
||||
config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
|
||||
return config_list
|
||||
return HxfResponse(config_list)
|
||||
|
||||
@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接")
|
||||
async def test_database_connection(
|
||||
|
|
@ -93,7 +94,7 @@ async def test_database_connection(
|
|||
):
|
||||
"""测试数据库连接"""
|
||||
result = await service.test_connection(config_id, current_user.id)
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
|
||||
@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表")
|
||||
async def connect_database(
|
||||
|
|
@ -103,7 +104,7 @@ async def connect_database(
|
|||
):
|
||||
"""连接数据库并获取表列表"""
|
||||
result = await service.connect_and_get_tables(config_id, current_user.id)
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
|
||||
|
||||
@router.get("/tables/{table_name}/data", summary="获取表数据预览")
|
||||
|
|
@ -117,7 +118,7 @@ async def get_table_data(
|
|||
"""获取表数据预览"""
|
||||
try:
|
||||
result = await service.get_table_data(table_name, current_user.id, db_type, limit)
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
except Exception as e:
|
||||
logger.error(f"获取表数据失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -133,7 +134,7 @@ async def get_table_schema(
|
|||
):
|
||||
"""获取表结构信息"""
|
||||
result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的?
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
|
||||
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置")
|
||||
async def get_config_by_type(
|
||||
|
|
@ -149,4 +150,4 @@ async def get_config_by_type(
|
|||
detail=f"未找到类型为 {db_type} 的配置"
|
||||
)
|
||||
# 返回包含解密密码的配置
|
||||
return config.to_dict(include_password=True, decrypt_service=service)
|
||||
return HxfResponse(config.to_dict(include_password=True, decrypt_service=service))
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
"""Knowledge base API endpoints."""
|
||||
|
||||
from utils.util_exceptions import HxfResponse
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@ -35,10 +36,10 @@ async def create_knowledge_base(
|
|||
):
|
||||
"""创建新的知识库"""
|
||||
# Check if knowledge base with same name already exists for this user
|
||||
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data.name}"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
|
||||
existing_kb = service.get_knowledge_base_by_name(kb_data.name)
|
||||
existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name)
|
||||
if existing_kb:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -47,10 +48,10 @@ async def create_knowledge_base(
|
|||
|
||||
# Create knowledge base
|
||||
session.desc = f"知识库 {kb_data.name}不存在,创建之"
|
||||
kb = service.create_knowledge_base(kb_data)
|
||||
|
||||
kb = await kb_service.create_knowledge_base(kb_data)
|
||||
|
||||
session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
|
||||
return KnowledgeBaseResponse(
|
||||
response = KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
|
|
@ -65,7 +66,7 @@ async def create_knowledge_base(
|
|||
document_count=0,
|
||||
active_document_count=0
|
||||
)
|
||||
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
|
||||
async def list_knowledge_bases(
|
||||
|
|
@ -76,18 +77,17 @@ async def list_knowledge_bases(
|
|||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取当前用户的所有知识库"""
|
||||
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
|
||||
knowledge_bases = await service.get_knowledge_bases(skip=skip, limit=limit)
|
||||
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit)
|
||||
|
||||
result = []
|
||||
for kb in knowledge_bases:
|
||||
# Count documents
|
||||
# 本知识库的文档数量
|
||||
total_docs = await session.scalar(
|
||||
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||
)
|
||||
|
||||
# 本知识库的已处理文档数量
|
||||
active_docs = await session.scalar(
|
||||
select(func.count()).where(
|
||||
Document.knowledge_base_id == kb.id,
|
||||
|
|
@ -112,7 +112,7 @@ async def list_knowledge_bases(
|
|||
))
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库"
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
|
||||
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
|
||||
async def get_knowledge_base(
|
||||
|
|
@ -124,7 +124,7 @@ async def get_knowledge_base(
|
|||
session.desc = f"START: 获取知识库 {kb_id} 的详情"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"检查知识库 {kb_id} 是否存在"
|
||||
kb = service.get_knowledge_base(kb_id)
|
||||
kb = await service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -146,7 +146,7 @@ async def get_knowledge_base(
|
|||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||
return KnowledgeBaseResponse(
|
||||
response = KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
|
|
@ -161,6 +161,7 @@ async def get_knowledge_base(
|
|||
document_count=total_docs,
|
||||
active_document_count=active_docs
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
|
||||
async def update_knowledge_base(
|
||||
|
|
@ -172,7 +173,7 @@ async def update_knowledge_base(
|
|||
"""更新知识库"""
|
||||
session.desc = f"START: 更新知识库 {kb_id}"
|
||||
service = KnowledgeBaseService(session)
|
||||
kb = service.update_knowledge_base(kb_id, kb_data)
|
||||
kb = await service.update_knowledge_base(kb_id, kb_data)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -192,8 +193,8 @@ async def update_knowledge_base(
|
|||
)
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 更新知识库 {kb_id},共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||
return KnowledgeBaseResponse(
|
||||
session.desc = f"SUCCESS: 更新知识库 {kb_id},结果 - 共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||
response = KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
|
|
@ -208,6 +209,7 @@ async def update_knowledge_base(
|
|||
document_count=total_docs,
|
||||
active_document_count=active_docs
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/{kb_id}", summary="删除知识库")
|
||||
async def delete_knowledge_base(
|
||||
|
|
@ -218,7 +220,7 @@ async def delete_knowledge_base(
|
|||
"""删除知识库"""
|
||||
session.desc = f"START: 删除知识库 {kb_id}"
|
||||
service = KnowledgeBaseService(session)
|
||||
success = service.delete_knowledge_base(kb_id)
|
||||
success = await service.delete_knowledge_base(kb_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -227,7 +229,7 @@ async def delete_knowledge_base(
|
|||
)
|
||||
|
||||
session.desc = f"SUCCESS: 删除知识库 {kb_id}"
|
||||
return {"message": "Knowledge base deleted successfully"}
|
||||
return HxfResponse({"message": "Knowledge base deleted successfully"})
|
||||
|
||||
# Document management endpoints
|
||||
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
|
||||
|
|
@ -239,18 +241,18 @@ async def upload_document(
|
|||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""上传文档到知识库"""
|
||||
session.desc = f"START: 上传文档到知识库 {kb_id}"
|
||||
session.desc = f"START: 上传文档 {file.filename} ({FileUtils.format_file_size(file.size)}) 到知识库 (ID={kb_id})"
|
||||
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}"
|
||||
# Validate file
|
||||
if not FileUtils.validate_file_extension(file.filename):
|
||||
session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||
|
|
@ -258,7 +260,6 @@ async def upload_document(
|
|||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# Check file size (50MB limit)
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if file.size and file.size > max_size:
|
||||
|
|
@ -268,6 +269,7 @@ async def upload_document(
|
|||
detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
|
||||
)
|
||||
|
||||
session.desc = f"文件为期望类型,处理文件 {file.filename} - "
|
||||
# Upload document
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.upload_document(
|
||||
|
|
@ -284,7 +286,7 @@ async def upload_document(
|
|||
session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"
|
||||
|
||||
session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"
|
||||
return DocumentResponse(
|
||||
response = DocumentResponse(
|
||||
id=document.id,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
|
|
@ -301,6 +303,7 @@ async def upload_document(
|
|||
embedding_model=document.embedding_model,
|
||||
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
|
||||
async def list_documents(
|
||||
|
|
@ -314,7 +317,8 @@ async def list_documents(
|
|||
session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -323,7 +327,7 @@ async def list_documents(
|
|||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
documents, total = doc_service.list_documents(kb_id, skip, limit)
|
||||
documents, total = await doc_service.list_documents(kb_id, skip, limit)
|
||||
|
||||
doc_responses = []
|
||||
for doc in documents:
|
||||
|
|
@ -346,208 +350,13 @@ async def list_documents(
|
|||
))
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total} 条"
|
||||
return DocumentListResponse(
|
||||
response = DocumentListResponse(
|
||||
documents=doc_responses,
|
||||
total=total,
|
||||
page=skip // limit + 1,
|
||||
page_size=limit
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
|
||||
async def get_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档详情。"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
return DocumentResponse(
|
||||
id=document.id,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
knowledge_base_id=document.knowledge_base_id,
|
||||
filename=document.filename,
|
||||
original_filename=document.original_filename,
|
||||
file_path=document.file_path,
|
||||
file_type=document.file_type,
|
||||
file_size=document.file_size,
|
||||
mime_type=document.mime_type,
|
||||
is_processed=document.is_processed,
|
||||
processing_error=document.processing_error,
|
||||
chunk_count=document.chunk_count or 0,
|
||||
embedding_model=document.embedding_model,
|
||||
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||
)
|
||||
|
||||
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
|
||||
async def delete_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除知识库中的文档。"""
|
||||
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
success = doc_service.delete_document(doc_id, kb_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
return {"message": "Document deleted successfully"}
|
||||
|
||||
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
|
||||
async def process_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""处理知识库中的文档,用于向量搜索。"""
|
||||
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Check if document exists
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Process the document
|
||||
result = await doc_service.process_document(doc_id, kb_id)
|
||||
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
return DocumentProcessingStatus(
|
||||
document_id=doc_id,
|
||||
status=result["status"],
|
||||
progress=result.get("progress", 0.0),
|
||||
error_message=result.get("error_message"),
|
||||
chunks_created=result.get("chunks_created", 0)
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
|
||||
async def get_document_processing_status(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档处理状态。"""
|
||||
# Verify knowledge base exists and user has access
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Determine status
|
||||
if document.processing_error:
|
||||
status_str = "failed"
|
||||
progress = 0.0
|
||||
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
|
||||
elif document.is_processed:
|
||||
status_str = "completed"
|
||||
progress = 100.0
|
||||
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
|
||||
else:
|
||||
status_str = "pending"
|
||||
progress = 0.0
|
||||
session.desc = f"文档 {doc_id} 处理pending中"
|
||||
|
||||
return DocumentProcessingStatus(
|
||||
document_id=document.id,
|
||||
status=status_str,
|
||||
progress=progress,
|
||||
error_message=document.processing_error,
|
||||
chunks_created=document.chunk_count or 0
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
|
||||
async def search_knowledge_base(
|
||||
kb_id: int,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""在知识库中搜索文档。"""
|
||||
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Perform search
|
||||
doc_service = DocumentService(session)
|
||||
results = doc_service.search_documents(kb_id, query, limit)
|
||||
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
|
||||
return {
|
||||
"knowledge_base_id": kb_id,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"total_results": len(results)
|
||||
}
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
|
||||
async def get_document_chunks(
|
||||
|
|
@ -570,7 +379,8 @@ async def get_document_chunks(
|
|||
"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
knowledge_base = kb_service.get_knowledge_base(kb_id)
|
||||
knowledge_base = await kb_service.get_knowledge_base(kb_id)
|
||||
|
||||
if not knowledge_base:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -580,7 +390,9 @@ async def get_document_chunks(
|
|||
|
||||
# Verify document exists in the knowledge base
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService"
|
||||
document = await doc_service.get_document(doc_id, kb_id)
|
||||
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document"
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
|
|
@ -589,11 +401,215 @@ async def get_document_chunks(
|
|||
)
|
||||
|
||||
# Get document chunks
|
||||
chunks = doc_service.get_document_chunks(doc_id)
|
||||
chunks = await doc_service.get_document_chunks(doc_id)
|
||||
|
||||
session.desc = f"SUCCESS: 获取文档 {doc_id} 共 {len(chunks)} 个文档块(片段)"
|
||||
return DocumentChunksResponse(
|
||||
response = DocumentChunksResponse(
|
||||
document_id=doc_id,
|
||||
document_name=document.filename,
|
||||
total_chunks=len(chunks),
|
||||
chunks=chunks
|
||||
)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
|
||||
async def get_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档详情。"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
response = DocumentResponse(
|
||||
id=document.id,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
knowledge_base_id=document.knowledge_base_id,
|
||||
filename=document.filename,
|
||||
original_filename=document.original_filename,
|
||||
file_path=document.file_path,
|
||||
file_type=document.file_type,
|
||||
file_size=document.file_size,
|
||||
mime_type=document.mime_type,
|
||||
is_processed=document.is_processed,
|
||||
processing_error=document.processing_error,
|
||||
chunk_count=document.chunk_count or 0,
|
||||
embedding_model=document.embedding_model,
|
||||
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
|
||||
async def delete_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除知识库中的文档。"""
|
||||
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
success = await doc_service.delete_document(doc_id, kb_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 删除文档 {doc_id} 失败 - 文档不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
response = {"message": "Document deleted successfully"}
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
|
||||
async def process_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""处理知识库中的文档,用于向量搜索。"""
|
||||
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Check if document exists
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Process the document
|
||||
result = await doc_service.process_document(doc_id, kb_id)
|
||||
await session.refresh(document)
|
||||
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
response = DocumentProcessingStatus(
|
||||
document_id=doc_id,
|
||||
status=result["status"],
|
||||
progress=result.get("progress", 0.0),
|
||||
error_message=result.get("error_message"),
|
||||
chunks_created=result.get("chunks_created", 0)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
|
||||
async def get_document_processing_status(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档处理状态。"""
|
||||
# Verify knowledge base exists and user has access
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Determine status
|
||||
if document.processing_error:
|
||||
status_str = "failed"
|
||||
progress = 0.0
|
||||
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
|
||||
elif document.is_processed:
|
||||
status_str = "completed"
|
||||
progress = 100.0
|
||||
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
|
||||
else:
|
||||
status_str = "pending"
|
||||
progress = 0.0
|
||||
session.desc = f"文档 {doc_id} 处理pending中"
|
||||
|
||||
response = DocumentProcessingStatus(
|
||||
document_id=document.id,
|
||||
status=status_str,
|
||||
progress=progress,
|
||||
error_message=document.processing_error,
|
||||
chunks_created=document.chunk_count or 0
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
|
||||
async def search_knowledge_base(
|
||||
kb_id: int,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""在知识库中搜索文档。"""
|
||||
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = await kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Perform search
|
||||
doc_service = DocumentService(session)
|
||||
results = await doc_service.search_documents(kb_id, query, limit)
|
||||
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
|
||||
response = {
|
||||
"knowledge_base_id": kb_id,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"total_results": len(results)
|
||||
}
|
||||
return HxfResponse(response)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,21 @@
|
|||
"""LLM configuration management API endpoints."""
|
||||
|
||||
from turtle import textinput
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import AIMessage
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_, select, delete, update
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from th_agenter.llm.embed.embed_llm import BGEEmbedLLM, EmbedLLM
|
||||
from th_agenter.llm.online.online_llm import OnlineLLM
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.llm_config import LLMConfig
|
||||
from th_agenter.llm.base_llm import LLMConfig_DataClass
|
||||
from ...core.simple_permissions import require_super_admin, require_authenticated_user
|
||||
from ...schemas.llm_config import (
|
||||
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
|
||||
|
|
@ -31,6 +38,7 @@ async def get_llm_configs(
|
|||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取大模型配置列表."""
|
||||
session.title = "获取大模型配置列表"
|
||||
session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig)
|
||||
|
||||
|
|
@ -62,7 +70,7 @@ async def get_llm_configs(
|
|||
# 分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
configs = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置"
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..."
|
||||
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||
|
||||
|
||||
|
|
@ -131,7 +139,7 @@ async def get_llm_config(
|
|||
):
|
||||
"""获取大模型配置详情."""
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -149,17 +157,20 @@ async def create_llm_config(
|
|||
):
|
||||
"""创建大模型配置."""
|
||||
# 检查配置名称是否已存在
|
||||
# 先保存当前用户名,避免在refresh后访问可能导致MissingGreenlet错误
|
||||
username = current_user.username
|
||||
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
|
||||
existing_config = session.execute(stmt).scalar_one_or_none()
|
||||
existing_config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if existing_config:
|
||||
session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 创建临时配置对象进行验证
|
||||
temp_config = LLMConfig(
|
||||
# 创建配置对象
|
||||
config = LLMConfig_DataClass(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
|
|
@ -178,7 +189,7 @@ async def create_llm_config(
|
|||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = temp_config.validate_config()
|
||||
validation_result = config.validate_config()
|
||||
if not validation_result['valid']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -190,10 +201,11 @@ async def create_llm_config(
|
|||
stmt = update(LLMConfig).where(
|
||||
LLMConfig.is_embedding == config_data.is_embedding
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
|
||||
session.desc = f"验证大模型配置, config_data"
|
||||
# 创建配置
|
||||
config = LLMConfig(
|
||||
config = LLMConfig_DataClass(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
|
|
@ -213,10 +225,9 @@ async def create_llm_config(
|
|||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
session.add(config)
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}"
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}"
|
||||
return HxfResponse(config.to_dict())
|
||||
|
||||
|
||||
|
|
@ -228,9 +239,10 @@ async def update_llm_config(
|
|||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""更新大模型配置."""
|
||||
username = current_user.username
|
||||
session.desc = f"START: 更新大模型配置, id={config_id}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -243,7 +255,7 @@ async def update_llm_config(
|
|||
LLMConfig.name == config_data.name,
|
||||
LLMConfig.id != config_id
|
||||
)
|
||||
existing_config = session.execute(stmt).scalar_one_or_none()
|
||||
existing_config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -258,17 +270,17 @@ async def update_llm_config(
|
|||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
|
||||
# 更新字段
|
||||
update_data = config_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(config, field, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}"
|
||||
session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}"
|
||||
return HxfResponse(config.to_dict())
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
|
||||
|
|
@ -278,22 +290,24 @@ async def delete_llm_config(
|
|||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""删除大模型配置."""
|
||||
username = current_user.username
|
||||
session.desc = f"START: 删除大模型配置, id={config_id}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
session.desc = f"待删除大模型记录 {config.to_dict()}"
|
||||
# TODO: 检查是否有对话或其他功能正在使用该配置
|
||||
# 这里可以添加相关的检查逻辑
|
||||
|
||||
session.delete(config)
|
||||
session.commit()
|
||||
# 删除配置
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}"
|
||||
session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}"
|
||||
return HxfResponse({"message": "LLM config deleted successfully"})
|
||||
|
||||
@router.post("/{config_id}/test", summary="测试连接大模型配置")
|
||||
|
|
@ -304,17 +318,21 @@ async def test_llm_config(
|
|||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""测试连接大模型配置."""
|
||||
session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}"
|
||||
username = current_user.username
|
||||
session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}")
|
||||
config_name = config.name
|
||||
# 验证配置
|
||||
validation_result = config.validate_config()
|
||||
logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}")
|
||||
if not validation_result["valid"]:
|
||||
return {
|
||||
"success": False,
|
||||
|
|
@ -322,41 +340,54 @@ async def test_llm_config(
|
|||
"details": validation_result
|
||||
}
|
||||
|
||||
session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}"
|
||||
# 尝试创建客户端并发送测试请求
|
||||
try:
|
||||
# 这里应该根据不同的服务商创建相应的客户端
|
||||
# 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
|
||||
# # 这里应该根据不同的服务商创建相应的客户端
|
||||
# # 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
|
||||
|
||||
test_message = test_data.message or "Hello, this is a test message."
|
||||
session.desc = f"准备测试LLM功能 > test_message = {test_message}"
|
||||
|
||||
# TODO: 实现具体的测试逻辑
|
||||
# 例如:
|
||||
# client = config.get_client()
|
||||
# response = client.chat.completions.create(
|
||||
# model=config.model_name,
|
||||
# messages=[{"role": "user", "content": test_message}],
|
||||
# max_tokens=100
|
||||
# )
|
||||
|
||||
# 模拟测试成功
|
||||
session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}"
|
||||
if config.is_embedding:
|
||||
config.provider = "ollama"
|
||||
streaming_llm = BGEEmbedLLM(config)
|
||||
else:
|
||||
streaming_llm = OnlineLLM(config)
|
||||
session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}"
|
||||
streaming_llm.load_model() # 加载模型
|
||||
session.desc = f"加载模型完毕,模型名称:{config.model_name},base_url: {config.base_url},准备测试对话..."
|
||||
|
||||
if config.is_embedding:
|
||||
# 测试嵌入模型,使用嵌入API而非聊天API
|
||||
test_text = test_message or "Hello, this is a test message for embedding"
|
||||
response = streaming_llm.embed_query(test_text)
|
||||
else:
|
||||
# 测试聊天模型
|
||||
from langchain.messages import SystemMessage, HumanMessage
|
||||
messages = [
|
||||
SystemMessage(content="你是一个简洁的助手,回答控制在50字以内"),
|
||||
HumanMessage(content=test_message)
|
||||
]
|
||||
response = streaming_llm.model.invoke(messages)
|
||||
session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}"
|
||||
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"message": "配置测试成功",
|
||||
"test_message": test_message,
|
||||
"response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
|
||||
"message": "LLM测试成功",
|
||||
"request": test_message,
|
||||
"response": response.content if hasattr(response, 'content') else response, # 使用转换后的字典
|
||||
"latency_ms": 150, # 模拟延迟
|
||||
"config_info": config.get_client_config()
|
||||
"config_info": config.to_dict()
|
||||
})
|
||||
|
||||
except Exception as test_error:
|
||||
session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"message": f"配置测试失败: {str(test_error)}",
|
||||
"message": f"LLM测试失败: {str(test_error)}",
|
||||
"test_message": test_message,
|
||||
"config_info": config.get_client_config()
|
||||
"config_info": config.to_dict()
|
||||
})
|
||||
|
||||
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
|
||||
|
|
@ -366,10 +397,11 @@ async def toggle_llm_config_status(
|
|||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""切换大模型配置状态."""
|
||||
session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}"
|
||||
username = current_user.username
|
||||
session.desc = f"START: 切换大模型配置状态, id={config_id} by user {username}"
|
||||
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -380,11 +412,11 @@ async def toggle_llm_config_status(
|
|||
config.is_active = not config.is_active
|
||||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
status_text = "激活" if config.is_active else "禁用"
|
||||
session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}"
|
||||
session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {username}"
|
||||
|
||||
return HxfResponse({
|
||||
"message": f"配置已{status_text}",
|
||||
|
|
@ -399,10 +431,11 @@ async def set_default_llm_config(
|
|||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""设置默认大模型配置."""
|
||||
session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}"
|
||||
username = current_user.username
|
||||
session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}"
|
||||
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -421,19 +454,19 @@ async def set_default_llm_config(
|
|||
LLMConfig.is_embedding == config.is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
|
||||
# 设置当前配置为默认
|
||||
config.is_default = True
|
||||
config.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
model_type = "嵌入模型" if config.is_embedding else "对话模型"
|
||||
# 更新文档处理器默认embedding
|
||||
get_document_processor()._init_embeddings()
|
||||
session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}"
|
||||
await get_document_processor(session)._init_embeddings()
|
||||
session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}"
|
||||
return HxfResponse({
|
||||
"message": f"已将 {config.name} 设为默认{model_type}配置",
|
||||
"is_default": config.is_default
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Role management API endpoints."""
|
||||
|
||||
from utils.util_exceptions import HxfResponse
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
|
|
@ -49,7 +50,8 @@ async def get_roles(
|
|||
stmt = stmt.offset(skip).limit(limit)
|
||||
roles = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 有 {len(roles)} 个角色"
|
||||
return [role.to_dict() for role in roles]
|
||||
response = [role.to_dict() for role in roles]
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情")
|
||||
async def get_role(
|
||||
|
|
@ -67,7 +69,8 @@ async def get_role(
|
|||
detail="角色不存在"
|
||||
)
|
||||
|
||||
return role.to_dict()
|
||||
response = role.to_dict()
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
|
||||
async def create_role(
|
||||
|
|
@ -95,12 +98,13 @@ async def create_role(
|
|||
)
|
||||
role.set_audit_fields(current_user.id)
|
||||
|
||||
await session.add(role)
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
|
||||
logger.info(f"Role created: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
response = role.to_dict()
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.put("/{role_id}", response_model=RoleResponse, summary="更新角色")
|
||||
async def update_role(
|
||||
|
|
@ -152,7 +156,8 @@ async def update_role(
|
|||
await session.refresh(role)
|
||||
|
||||
logger.info(f"Role updated: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
response = role.to_dict()
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
|
||||
async def delete_role(
|
||||
|
|
@ -190,7 +195,8 @@ async def delete_role(
|
|||
await session.commit()
|
||||
|
||||
session.desc = f"角色删除成功: {role.name} by user {current_user.username}"
|
||||
return {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
|
||||
response = {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
|
||||
return HxfResponse(response)
|
||||
|
||||
# 用户角色管理路由
|
||||
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
|
@ -230,13 +236,14 @@ async def assign_user_roles(
|
|||
user_id=assignment_data.user_id,
|
||||
role_id=role_id
|
||||
)
|
||||
await session.add(user_role)
|
||||
session.add(user_role)
|
||||
|
||||
await session.commit()
|
||||
|
||||
session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}"
|
||||
|
||||
return {"message": "角色分配成功"}
|
||||
response = {"message": "角色分配成功"}
|
||||
return HxfResponse(response)
|
||||
|
||||
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
|
||||
async def get_user_roles(
|
||||
|
|
@ -267,7 +274,8 @@ async def get_user_roles(
|
|||
)
|
||||
roles = (await session.execute(stmt)).scalars().all()
|
||||
|
||||
return [role.to_dict() for role in roles]
|
||||
response = [role.to_dict() for role in roles]
|
||||
return HxfResponse(response)
|
||||
|
||||
# 将子路由添加到主路由
|
||||
router.include_router(user_role_router)
|
||||
|
|
@ -12,6 +12,7 @@ from th_agenter.services.conversation_context import conversation_context_servic
|
|||
from utils.util_schemas import BaseResponse
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
|
||||
security = HTTPBearer()
|
||||
|
|
@ -65,6 +66,8 @@ async def smart_query(
|
|||
|
||||
# 初始化工作流管理器
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
await workflow_manager.initialize()
|
||||
|
||||
conversation_service = ConversationService(session)
|
||||
|
||||
# 处理对话上下文
|
||||
|
|
@ -126,7 +129,7 @@ async def smart_query(
|
|||
except Exception as e:
|
||||
session.desc = f"ERROR: 智能查询执行失败: {e}"
|
||||
# 返回结构化的错误响应
|
||||
return SmartQueryResponse(
|
||||
response = SmartQueryResponse(
|
||||
success=False,
|
||||
message=f"查询执行失败: {str(e)}",
|
||||
data={'error_type': 'query_execution_error'},
|
||||
|
|
@ -137,6 +140,7 @@ async def smart_query(
|
|||
}],
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
# 如果查询成功,保存助手回复和更新上下文
|
||||
if result['success'] and conversation_id:
|
||||
|
|
@ -169,21 +173,22 @@ async def smart_query(
|
|||
if conversation_id:
|
||||
response_data['conversation_id'] = conversation_id
|
||||
session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}"
|
||||
return SmartQueryResponse(
|
||||
response = SmartQueryResponse(
|
||||
success=result['success'],
|
||||
message=result.get('message', '查询完成'),
|
||||
data=response_data,
|
||||
workflow_steps=result.get('workflow_steps', []),
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except HTTPException:
|
||||
except HTTPException as e:
|
||||
session.desc = f"EXCEPTION: HTTP异常: {e}"
|
||||
raise
|
||||
raise e
|
||||
except Exception as e:
|
||||
session.desc = f"ERROR: 智能查询接口异常: {e}"
|
||||
# 返回通用错误响应
|
||||
return SmartQueryResponse(
|
||||
response = SmartQueryResponse(
|
||||
success=False,
|
||||
message="服务器内部错误,请稍后重试",
|
||||
data={'error_type': 'internal_server_error'},
|
||||
|
|
@ -194,6 +199,7 @@ async def smart_query(
|
|||
}],
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
|
||||
async def get_conversation_context(
|
||||
|
|
@ -227,11 +233,12 @@ async def get_conversation_context(
|
|||
history = await conversation_context_service.get_conversation_history(conversation_id)
|
||||
context['message_history'] = history
|
||||
session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}"
|
||||
return ConversationContextResponse(
|
||||
response = ConversationContextResponse(
|
||||
success=True,
|
||||
message="获取对话上下文成功",
|
||||
data=context
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
|
||||
|
|
@ -244,6 +251,7 @@ async def get_files_status(
|
|||
"""
|
||||
session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}"
|
||||
workflow_manager = SmartWorkflowManager()
|
||||
await workflow_manager.initialize()
|
||||
|
||||
# 获取用户文件列表
|
||||
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
|
||||
|
|
@ -277,11 +285,12 @@ async def get_files_status(
|
|||
}
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}"
|
||||
return ConversationContextResponse(
|
||||
response = ConversationContextResponse(
|
||||
success=True,
|
||||
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
|
||||
data=status_data
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
|
||||
async def reset_conversation_context(
|
||||
|
|
@ -315,10 +324,11 @@ async def reset_conversation_context(
|
|||
|
||||
if success:
|
||||
session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}"
|
||||
return {
|
||||
"success": True,
|
||||
"message": "对话上下文已重置,可以开始新的数据分析会话"
|
||||
}
|
||||
response = ConversationContextResponse(
|
||||
success=True,
|
||||
message="对话上下文已重置,可以开始新的数据分析会话"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
else:
|
||||
session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from th_agenter.services.smart_workflow import SmartWorkflowManager
|
|||
from th_agenter.services.conversation_context import ConversationContextService
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter(prefix="/smart-query", tags=["smart-query"])
|
||||
security = HTTPBearer()
|
||||
|
|
@ -163,12 +164,13 @@ async def upload_excel(
|
|||
})
|
||||
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}"
|
||||
return ExcelUploadResponse(
|
||||
response = ExcelUploadResponse(
|
||||
file_id=excel_file.id,
|
||||
success=True,
|
||||
message="Excel文件上传成功",
|
||||
data=analysis_result
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据")
|
||||
async def preview_excel(
|
||||
|
|
@ -233,7 +235,7 @@ async def preview_excel(
|
|||
data = paginated_df.fillna('').to_dict('records')
|
||||
columns = df.columns.tolist()
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据"
|
||||
return QueryResponse(
|
||||
response = QueryResponse(
|
||||
success=True,
|
||||
message="Excel文件预览加载成功",
|
||||
data={
|
||||
|
|
@ -245,6 +247,7 @@ async def preview_excel(
|
|||
'total_pages': (total_rows + request.page_size - 1) // request.page_size
|
||||
}
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接")
|
||||
async def test_database_connection(
|
||||
|
|
@ -264,10 +267,11 @@ async def test_database_connection(
|
|||
message="数据库连接测试成功"
|
||||
)
|
||||
else:
|
||||
return NormalResponse(
|
||||
response = NormalResponse(
|
||||
success=False,
|
||||
message="数据库连接测试失败"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except Exception as e:
|
||||
return NormalResponse(
|
||||
|
|
@ -298,22 +302,25 @@ async def get_table_schema(
|
|||
schema_result = await db_service.get_table_schema(request.table_name, current_user.id)
|
||||
|
||||
if schema_result['success']:
|
||||
return QueryResponse(
|
||||
response = QueryResponse(
|
||||
success=True,
|
||||
message="获取表结构成功",
|
||||
data=schema_result['data']
|
||||
)
|
||||
return HxfResponse(response)
|
||||
else:
|
||||
return QueryResponse(
|
||||
response = QueryResponse(
|
||||
success=False,
|
||||
message=schema_result['message']
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except Exception as e:
|
||||
return QueryResponse(
|
||||
response = QueryResponse(
|
||||
success=False,
|
||||
message=f"获取表结构失败: {str(e)}"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
class StreamQueryRequest(BaseModel):
|
||||
query: str
|
||||
|
|
@ -355,6 +362,8 @@ async def stream_smart_query(
|
|||
|
||||
# 初始化服务
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
await workflow_manager.initialize()
|
||||
|
||||
conversation_context_service = ConversationContextService()
|
||||
|
||||
# 处理对话上下文
|
||||
|
|
@ -435,7 +444,7 @@ async def stream_smart_query(
|
|||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
response = StreamingResponse(
|
||||
generate_stream(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
|
|
@ -447,6 +456,7 @@ async def stream_smart_query(
|
|||
"Access-Control-Allow-Methods": "*"
|
||||
}
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/execute-db-query", summary="流式数据库查询")
|
||||
async def execute_database_query(
|
||||
|
|
@ -477,6 +487,7 @@ async def execute_database_query(
|
|||
|
||||
# 初始化服务
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
await workflow_manager.initialize()
|
||||
conversation_context_service = ConversationContextService()
|
||||
|
||||
# 处理对话上下文
|
||||
|
|
@ -556,7 +567,7 @@ async def execute_database_query(
|
|||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
response = StreamingResponse(
|
||||
generate_stream(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
|
|
@ -568,6 +579,7 @@ async def execute_database_query(
|
|||
"Access-Control-Allow-Methods": "*"
|
||||
}
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/cleanup-temp-files", summary="清理临时文件")
|
||||
async def cleanup_temp_files(
|
||||
|
|
@ -590,16 +602,18 @@ async def cleanup_temp_files(
|
|||
except OSError:
|
||||
pass
|
||||
|
||||
return BaseResponse(
|
||||
response = BaseResponse(
|
||||
success=True,
|
||||
message=f"已清理 {cleaned_count} 个临时文件"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except Exception as e:
|
||||
return BaseResponse(
|
||||
response = BaseResponse(
|
||||
success=False,
|
||||
message=f"清理临时文件失败: {str(e)}"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表")
|
||||
async def get_file_list(
|
||||
|
|
@ -634,7 +648,7 @@ async def get_file_list(
|
|||
file_list.append(file_info)
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件"
|
||||
return FileListResponse(
|
||||
response = FileListResponse(
|
||||
success=True,
|
||||
message="获取文件列表成功",
|
||||
data={
|
||||
|
|
@ -645,12 +659,14 @@ async def get_file_list(
|
|||
'total_pages': (total + page_size - 1) // page_size
|
||||
}
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except Exception as e:
|
||||
return FileListResponse(
|
||||
response = FileListResponse(
|
||||
success=False,
|
||||
message=f"获取文件列表失败: {str(e)}"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件")
|
||||
async def delete_file(
|
||||
|
|
@ -668,22 +684,26 @@ async def delete_file(
|
|||
|
||||
if success:
|
||||
session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}"
|
||||
return NormalResponse(
|
||||
response = NormalResponse(
|
||||
success=True,
|
||||
message="文件删除成功"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
else:
|
||||
session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败"
|
||||
return NormalResponse(
|
||||
response = NormalResponse(
|
||||
success=False,
|
||||
message="文件不存在或删除失败"
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
except Exception as e:
|
||||
return NormalResponse(
|
||||
response = NormalResponse(
|
||||
success=True,
|
||||
message=str(e)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息")
|
||||
async def get_file_info(
|
||||
|
|
@ -725,9 +745,11 @@ async def get_file_info(
|
|||
'sheets_summary': excel_file.get_all_sheets_summary()
|
||||
}
|
||||
|
||||
return QueryResponse(
|
||||
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件 {file_id} 信息"
|
||||
response = QueryResponse(
|
||||
success=True,
|
||||
message="获取文件信息成功",
|
||||
data=file_info
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ from th_agenter.models.user import User
|
|||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.table_metadata_service import TableMetadataService
|
||||
from th_agenter.services.auth import AuthService
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
|
||||
|
||||
|
|
@ -54,8 +55,7 @@ async def collect_table_metadata(
|
|||
request.table_names
|
||||
)
|
||||
session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据"
|
||||
return result
|
||||
|
||||
return HxfResponse(result)
|
||||
|
||||
@router.get("/", summary="获取用户表元数据列表")
|
||||
async def get_table_metadata(
|
||||
|
|
@ -66,7 +66,7 @@ async def get_table_metadata(
|
|||
"""获取表元数据列表"""
|
||||
try:
|
||||
service = TableMetadataService(session)
|
||||
metadata_list = service.get_user_table_metadata(
|
||||
metadata_list = await service.get_user_table_metadata(
|
||||
current_user.id,
|
||||
database_config_id
|
||||
)
|
||||
|
|
@ -96,17 +96,17 @@ async def get_table_metadata(
|
|||
for meta in metadata_list
|
||||
]
|
||||
|
||||
return {
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"data": data
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表元数据失败: {str(e)}")
|
||||
return {
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"message": str(e)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@router.post("/by-table", summary="根据表名获取表元数据")
|
||||
|
|
@ -118,7 +118,7 @@ async def get_table_metadata_by_name(
|
|||
"""根据表名获取表元数据"""
|
||||
try:
|
||||
service = TableMetadataService(session)
|
||||
metadata = service.get_table_metadata_by_name(
|
||||
metadata = await service.get_table_metadata_by_name(
|
||||
current_user.id,
|
||||
request.database_config_id,
|
||||
request.table_name
|
||||
|
|
@ -146,27 +146,30 @@ async def get_table_metadata_by_name(
|
|||
"business_context": metadata.business_context or ""
|
||||
}
|
||||
}
|
||||
return {"success": True, "data": data}
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"data": data
|
||||
})
|
||||
else:
|
||||
return {"success": False, "data": None, "message": "表元数据不存在"}
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"data": None,
|
||||
"message": "表元数据不存在"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表元数据失败: {str(e)}")
|
||||
return {
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"message": str(e)
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": data
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取表元数据失败: {str(e)}")
|
||||
return {
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"message": str(e)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
|
||||
|
|
@ -179,14 +182,17 @@ async def update_qa_settings(
|
|||
"""更新表的问答设置"""
|
||||
try:
|
||||
service = TableMetadataService(session)
|
||||
success = service.update_table_qa_settings(
|
||||
success = await service.update_table_qa_settings(
|
||||
current_user.id,
|
||||
metadata_id,
|
||||
settings.dict()
|
||||
settings.model_dump()
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"success": True, "message": "设置更新成功"}
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"message": "设置更新成功"
|
||||
})
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -221,9 +227,9 @@ async def save_table_metadata(
|
|||
|
||||
session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置"
|
||||
|
||||
return {
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"message": f"成功保存 {len(result['saved_tables'])} 个表的配置",
|
||||
"saved_tables": result['saved_tables'],
|
||||
"failed_tables": result.get('failed_tables', [])
|
||||
}
|
||||
})
|
||||
|
|
@ -9,6 +9,7 @@ from ...core.simple_permissions import require_super_admin
|
|||
from ...services.auth import AuthService
|
||||
from ...services.user import UserService
|
||||
from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -17,7 +18,8 @@ async def get_user_profile(
|
|||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取当前用户的个人信息."""
|
||||
return UserResponse.model_validate(current_user)
|
||||
response = UserResponse.model_validate(current_user)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息")
|
||||
async def update_user_profile(
|
||||
|
|
@ -39,7 +41,8 @@ async def update_user_profile(
|
|||
|
||||
# Update user
|
||||
updated_user = await user_service.update_user(current_user.id, user_update)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
response = UserResponse.model_validate(updated_user)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/profile", summary="删除当前用户的账户")
|
||||
async def delete_user_account(
|
||||
|
|
@ -51,7 +54,8 @@ async def delete_user_account(
|
|||
user_service = UserService(session)
|
||||
await user_service.delete_user(current_user.id)
|
||||
session.desc = f"删除用户 [{username}] 成功"
|
||||
return {"message": f"删除用户 {username} 成功"}
|
||||
response = {"message": f"删除用户 {username} 成功"}
|
||||
return HxfResponse(response)
|
||||
|
||||
# Admin endpoints
|
||||
@router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)")
|
||||
|
|
@ -83,7 +87,8 @@ async def create_user(
|
|||
|
||||
# Create user
|
||||
new_user = await user_service.create_user(user_create)
|
||||
return UserResponse.model_validate(new_user)
|
||||
response = UserResponse.model_validate(new_user)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
|
||||
async def list_users(
|
||||
|
|
@ -111,7 +116,7 @@ async def list_users(
|
|||
"page": page,
|
||||
"page_size": size
|
||||
}
|
||||
return result
|
||||
return HxfResponse(result)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)")
|
||||
|
|
@ -128,7 +133,8 @@ async def get_user(
|
|||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return UserResponse.model_validate(user)
|
||||
response = UserResponse.model_validate(user)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.put("/change-password", summary="修改当前用户的密码")
|
||||
async def change_password(
|
||||
|
|
@ -145,7 +151,8 @@ async def change_password(
|
|||
current_password=request.current_password,
|
||||
new_password=request.new_password
|
||||
)
|
||||
return {"message": "Password changed successfully"}
|
||||
response = {"message": "Password changed successfully"}
|
||||
return HxfResponse(response)
|
||||
except Exception as e:
|
||||
if "Current password is incorrect" in str(e):
|
||||
raise HTTPException(
|
||||
|
|
@ -178,7 +185,8 @@ async def reset_user_password(
|
|||
user_id=user_id,
|
||||
new_password=request.new_password
|
||||
)
|
||||
return {"message": "Password reset successfully"}
|
||||
response = {"message": "Password reset successfully"}
|
||||
return HxfResponse(response)
|
||||
except Exception as e:
|
||||
if "User not found" in str(e):
|
||||
raise HTTPException(
|
||||
|
|
@ -215,7 +223,8 @@ async def update_user(
|
|||
)
|
||||
|
||||
updated_user = await user_service.update_user(user_id, user_update)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
response = UserResponse.model_validate(updated_user)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.delete("/{user_id}", summary="删除用户 (仅管理员权限)")
|
||||
async def delete_user(
|
||||
|
|
@ -234,4 +243,5 @@ async def delete_user(
|
|||
)
|
||||
|
||||
await user_service.delete_user(user_id)
|
||||
return {"message": "User deleted successfully"}
|
||||
response = {"message": "User deleted successfully"}
|
||||
return HxfResponse(response)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from ...services.workflow_engine import get_workflow_engine
|
|||
from ...services.auth import AuthService
|
||||
from ...models.user import User
|
||||
from loguru import logger
|
||||
|
||||
from utils.util_exceptions import HxfResponse
|
||||
router = APIRouter()
|
||||
|
||||
def convert_workflow_for_response(workflow_dict):
|
||||
|
|
@ -30,38 +30,7 @@ def convert_workflow_for_response(workflow_dict):
|
|||
if 'to_node' in conn:
|
||||
conn['to'] = conn.pop('to_node')
|
||||
return workflow_dict
|
||||
|
||||
@router.post("/", response_model=WorkflowResponse)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""创建工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 创建工作流
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
definition=workflow_data.definition.model_dump(),
|
||||
version="1.0.0",
|
||||
status=workflow_data.status,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
workflow.set_audit_fields(current_user.id)
|
||||
|
||||
await session.add(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
|
||||
# 转换definition中的字段映射
|
||||
workflow_dict = convert_workflow_for_response(workflow.to_dict())
|
||||
|
||||
logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
|
||||
|
||||
@router.get("/", response_model=WorkflowListResponse)
|
||||
async def list_workflows(
|
||||
skip: Optional[int] = Query(None, ge=0),
|
||||
|
|
@ -73,6 +42,7 @@ async def list_workflows(
|
|||
):
|
||||
"""获取工作流列表"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})"
|
||||
|
||||
# 构建查询
|
||||
stmt = select(Workflow).where(Workflow.owner_id == current_user.id)
|
||||
|
|
@ -90,17 +60,22 @@ async def list_workflows(
|
|||
count_query = count_query.where(Workflow.status == workflow_status)
|
||||
if search:
|
||||
count_query = count_query.where(Workflow.name.ilike(f"%{search}%"))
|
||||
total = session.scalar(count_query)
|
||||
|
||||
session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}"
|
||||
total = await session.scalar(count_query)
|
||||
session.desc = f"查询结果: 共 {total} 条"
|
||||
|
||||
# 如果没有传分页参数,返回所有数据
|
||||
if skip is None and limit is None:
|
||||
workflows = session.scalars(stmt).all()
|
||||
return WorkflowListResponse(
|
||||
workflows = (await session.scalars(stmt)).all()
|
||||
session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)} 条"
|
||||
response = WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=1,
|
||||
size=total
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
# 使用默认分页参数
|
||||
if skip is None:
|
||||
|
|
@ -109,14 +84,16 @@ async def list_workflows(
|
|||
limit = 10
|
||||
|
||||
# 分页查询
|
||||
workflows = session.scalars(stmt.offset(skip).limit(limit)).all()
|
||||
workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all()
|
||||
session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)} 条"
|
||||
|
||||
return WorkflowListResponse(
|
||||
response = WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=skip // limit + 1, # 计算页码
|
||||
size=limit
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def get_workflow(
|
||||
|
|
@ -126,8 +103,9 @@ async def get_workflow(
|
|||
):
|
||||
"""获取工作流详情"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 获取工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
|
|
@ -135,12 +113,15 @@ async def get_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 获取工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
session.desc = f"SUCCESS: 获取工作流数据 {workflow_id}"
|
||||
response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def update_workflow(
|
||||
|
|
@ -151,8 +132,9 @@ async def update_workflow(
|
|||
):
|
||||
"""更新工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 更新工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -162,13 +144,15 @@ async def update_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = workflow_data.dict(exclude_unset=True)
|
||||
session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {workflow_data.model_dump(exclude_unset=True)}"
|
||||
update_data = workflow_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "definition" and value:
|
||||
# 如果value是Pydantic模型,转换为字典;如果已经是字典,直接使用
|
||||
|
|
@ -181,13 +165,12 @@ async def update_workflow(
|
|||
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}"
|
||||
response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
return HxfResponse(response)
|
||||
|
||||
logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}")
|
||||
async def delete_workflow(
|
||||
workflow_id: int,
|
||||
|
|
@ -196,8 +179,9 @@ async def delete_workflow(
|
|||
):
|
||||
"""删除工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 删除工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -207,16 +191,18 @@ async def delete_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 删除工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
session.desc = f"删除工作流: {workflow.name}"
|
||||
await session.delete(workflow)
|
||||
await session.commit()
|
||||
|
||||
session.delete(workflow)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流删除成功"}
|
||||
session.desc = f"SUCCESS: 删除工作流数据 commit {workflow_id}"
|
||||
response = {"message": "工作流删除成功"}
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/activate")
|
||||
|
|
@ -227,8 +213,9 @@ async def activate_workflow(
|
|||
):
|
||||
"""激活工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 激活工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -238,6 +225,7 @@ async def activate_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 激活工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
|
|
@ -245,11 +233,11 @@ async def activate_workflow(
|
|||
|
||||
workflow.status = ModelWorkflowStatus.PUBLISHED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
await session.commit()
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流激活成功"}
|
||||
session.desc = f"SUCCESS: 激活工作流数据 commit {workflow_id}"
|
||||
response = {"message": "工作流激活成功"}
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/{workflow_id}/deactivate")
|
||||
async def deactivate_workflow(
|
||||
|
|
@ -259,8 +247,9 @@ async def deactivate_workflow(
|
|||
):
|
||||
"""停用工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 停用工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -270,6 +259,7 @@ async def deactivate_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 停用工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
|
|
@ -278,10 +268,11 @@ async def deactivate_workflow(
|
|||
workflow.status = ModelWorkflowStatus.ARCHIVED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流停用成功"}
|
||||
session.desc = f"SUCCESS: 停用工作流数据 commit {workflow_id}"
|
||||
response = {"message": "工作流停用成功"}
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
|
||||
async def execute_workflow(
|
||||
|
|
@ -292,8 +283,9 @@ async def execute_workflow(
|
|||
):
|
||||
"""执行工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 执行工作流 {workflow_id}"
|
||||
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -303,6 +295,7 @@ async def execute_workflow(
|
|||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 执行工作流数据 - 工作流不存在 {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
|
|
@ -323,8 +316,8 @@ async def execute_workflow(
|
|||
session=session
|
||||
)
|
||||
|
||||
logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
|
||||
return execution_result
|
||||
session.desc = f"SUCCESS: 执行工作流数据 commit {workflow_id}"
|
||||
return HxfResponse(execution_result)
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
|
||||
|
|
@ -336,38 +329,40 @@ async def list_workflow_executions(
|
|||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行历史"""
|
||||
session.desc = f"START: 获取工作流执行历史 {workflow_id}"
|
||||
try:
|
||||
from ...models.workflow import Workflow, WorkflowExecution
|
||||
|
||||
# 验证工作流所有权
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
session.desc = f"ERROR: 获取工作流执行历史数据 - 工作流不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 获取执行历史
|
||||
executions = session.scalars(
|
||||
select(WorkflowExecution).where(
|
||||
WorkflowExecution.workflow_id == workflow_id
|
||||
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
|
||||
).all()
|
||||
executions = (await session.scalars(
|
||||
select(WorkflowExecution).where(
|
||||
WorkflowExecution.workflow_id == workflow_id
|
||||
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
|
||||
)).all()
|
||||
|
||||
return [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
|
||||
session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}"
|
||||
response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
|
||||
return HxfResponse(response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}")
|
||||
session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取执行历史失败"
|
||||
|
|
@ -380,10 +375,11 @@ async def get_workflow_execution(
|
|||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行详情"""
|
||||
session.desc = f"START: 获取工作流执行详情 {execution_id}"
|
||||
try:
|
||||
from ...models.workflow import WorkflowExecution, Workflow
|
||||
|
||||
execution = session.scalar(
|
||||
execution = await session.scalar(
|
||||
select(WorkflowExecution).join(
|
||||
Workflow, WorkflowExecution.workflow_id == Workflow.id
|
||||
).where(
|
||||
|
|
@ -393,17 +389,18 @@ async def get_workflow_execution(
|
|||
)
|
||||
|
||||
if not execution:
|
||||
session.desc = f"ERROR: 获取工作流执行详情数据 - 执行记录不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="执行记录不存在"
|
||||
)
|
||||
|
||||
return WorkflowExecutionResponse.model_validate(execution)
|
||||
response = WorkflowExecutionResponse.model_validate(execution)
|
||||
session.desc = f"SUCCESS: 获取工作流执行详情数据 commit {execution_id}"
|
||||
return HxfResponse(response)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow execution {execution_id}: {str(e)}")
|
||||
session.desc = f"ERROR: 获取工作流执行详情数据 commit {execution_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取执行详情失败"
|
||||
|
|
@ -418,7 +415,7 @@ async def execute_workflow_stream(
|
|||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""流式执行工作流,实时推送节点执行状态"""
|
||||
|
||||
session.desc = f"START: 流式执行工作流 {workflow_id}"
|
||||
async def generate_stream() -> AsyncGenerator[str, None]:
|
||||
workflow_engine = None
|
||||
|
||||
|
|
@ -426,7 +423,7 @@ async def execute_workflow_stream(
|
|||
from ...models.workflow import Workflow
|
||||
|
||||
# 验证工作流
|
||||
workflow = session.scalar(
|
||||
workflow = await session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
|
|
@ -466,7 +463,7 @@ async def execute_workflow_stream(
|
|||
logger.error(f"流式工作流执行异常: {e}", exc_info=True)
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'工作流执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
response = StreamingResponse(
|
||||
generate_stream(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
|
|
@ -477,4 +474,43 @@ async def execute_workflow_stream(
|
|||
"Access-Control-Allow-Headers": "*",
|
||||
"Access-Control-Allow-Methods": "*"
|
||||
}
|
||||
)
|
||||
)
|
||||
session.desc = f"SUCCESS: 流式执行工作流 {workflow_id} 完毕"
|
||||
return HxfResponse(response)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
@router.post("/", response_model=WorkflowResponse)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""创建工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
session.desc = f"START: 创建工作流 {workflow_data.name}"
|
||||
# 创建工作流
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
definition=workflow_data.definition.model_dump(),
|
||||
version="1.0.0",
|
||||
status=workflow_data.status,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
session.desc = f"创建工作流实例 - Workflow(), {workflow_data.name}"
|
||||
workflow.set_audit_fields(current_user.id)
|
||||
session.desc = f"保存工作流 - set_audit_fields {workflow_data.name}"
|
||||
|
||||
session.add(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
session.desc = f"保存工作流 - commit & refresh {workflow_data.name}"
|
||||
# 转换definition中的字段映射
|
||||
workflow_dict = convert_workflow_for_response(workflow.to_dict())
|
||||
session.desc = f"转换工作流数据 - convert_workflow_for_response {workflow_data.name}"
|
||||
|
||||
response = WorkflowResponse(**workflow_dict)
|
||||
session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {workflow_data.name}"
|
||||
return HxfResponse(response)
|
||||
|
||||
|
|
@ -69,13 +69,6 @@ router.include_router(
|
|||
tags=["smart-chat"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
router.include_router(
|
||||
workflow.router,
|
||||
prefix="/workflows",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Configuration management for TH Agenter."""
|
||||
|
||||
import os
|
||||
from requests import Session
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
|
@ -88,13 +89,15 @@ class LLMSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
def get_current_config(self) -> dict:
|
||||
async def get_current_config(self, session: Session) -> dict:
|
||||
"""获取当前选择的提供商配置 - 优先从数据库读取默认配置."""
|
||||
try:
|
||||
# 尝试从数据库读取默认聊天模型配置
|
||||
from th_agenter.services.llm_config_service import LLMConfigService
|
||||
# 尝试从数据库读取默认聊天模型配置
|
||||
llm_service = LLMConfigService()
|
||||
db_config = llm_service.get_default_chat_config()
|
||||
db_config = None
|
||||
if session:
|
||||
db_config = await llm_service.get_default_chat_config(session)
|
||||
|
||||
if db_config:
|
||||
# 如果数据库中有默认配置,使用数据库配置
|
||||
|
|
@ -105,10 +108,17 @@ class LLMSettings(BaseSettings):
|
|||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature
|
||||
}
|
||||
if session:
|
||||
session.desc = f"使用LLM配置(get_default_chat_config)> {config}"
|
||||
else:
|
||||
logger.info(f"使用LLM配置(get_default_chat_config) > {config}")
|
||||
return config
|
||||
except Exception as e:
|
||||
# 如果数据库读取失败,记录错误并回退到环境变量
|
||||
logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")
|
||||
if session:
|
||||
session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
|
||||
else:
|
||||
logger.error(f"获取默认对话模型配置失败: {str(e)}")
|
||||
|
||||
# 回退到原有的环境变量配置
|
||||
provider_configs = {
|
||||
|
|
@ -182,13 +192,15 @@ class EmbeddingSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
def get_current_config(self) -> dict:
|
||||
async def get_current_config(self, session: Session) -> dict:
|
||||
"""获取当前选择的embedding提供商配置 - 优先从数据库读取默认配置."""
|
||||
try:
|
||||
if session:
|
||||
session.desc = "尝试从数据库读取默认嵌入模型配置 ... >>> get_current_config";
|
||||
# 尝试从数据库读取默认嵌入模型配置
|
||||
from th_agenter.services.llm_config_service import LLMConfigService
|
||||
llm_service = LLMConfigService()
|
||||
db_config = llm_service.get_default_embedding_config()
|
||||
db_config = await llm_service.get_default_embedding_config(session)
|
||||
|
||||
if db_config:
|
||||
# 如果数据库中有默认配置,使用数据库配置
|
||||
|
|
@ -200,7 +212,10 @@ class EmbeddingSettings(BaseSettings):
|
|||
return config
|
||||
except Exception as e:
|
||||
# 如果数据库读取失败,记录错误并回退到环境变量
|
||||
logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")
|
||||
if session:
|
||||
session.error(f"Failed to read embedding config from database, falling back to env vars: {e}")
|
||||
else:
|
||||
logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}")
|
||||
|
||||
# 回退到原有的环境变量配置
|
||||
provider_configs = {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from ..models.user import User
|
|||
from loguru import logger
|
||||
|
||||
# Context variable to store current user
|
||||
current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None)
|
||||
current_user_context: ContextVar[Optional[dict]] = ContextVar('current_user', default=None)
|
||||
|
||||
# Thread-local storage as backup
|
||||
_thread_local = threading.local()
|
||||
|
|
@ -19,34 +19,56 @@ class UserContext:
|
|||
"""User context manager for accessing current user globally."""
|
||||
|
||||
@staticmethod
|
||||
def set_current_user(user: User) -> None:
|
||||
def set_current_user(user: User, canLog: bool = False) -> None:
|
||||
"""Set current user in context."""
|
||||
logger.info(f"[UserContext] - Setting user in context: {user.username} (ID: {user.id})")
|
||||
if canLog:
|
||||
logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")
|
||||
|
||||
# Store user information as a dictionary instead of the SQLAlchemy model
|
||||
user_dict = {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'full_name': user.full_name,
|
||||
'is_active': user.is_active
|
||||
}
|
||||
|
||||
# Set in ContextVar
|
||||
current_user_context.set(user)
|
||||
current_user_context.set(user_dict)
|
||||
|
||||
# Also set in thread-local as backup
|
||||
_thread_local.current_user = user
|
||||
_thread_local.current_user = user_dict
|
||||
|
||||
# Verify it was set
|
||||
verify_user = current_user_context.get()
|
||||
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
if canLog:
|
||||
logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")
|
||||
|
||||
@staticmethod
|
||||
def set_current_user_with_token(user: User):
|
||||
def set_current_user_with_token(user: User, canLog: bool = False):
|
||||
"""Set current user in context and return token for cleanup."""
|
||||
logger.info(f"[UserContext] - Setting user in context with token: {user.username} (ID: {user.id})")
|
||||
if canLog:
|
||||
logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")
|
||||
|
||||
# Store user information as a dictionary instead of the SQLAlchemy model
|
||||
user_dict = {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'full_name': user.full_name,
|
||||
'is_active': user.is_active
|
||||
}
|
||||
|
||||
# Set in ContextVar and get token
|
||||
token = current_user_context.set(user)
|
||||
token = current_user_context.set(user_dict)
|
||||
|
||||
# Also set in thread-local as backup
|
||||
_thread_local.current_user = user
|
||||
_thread_local.current_user = user_dict
|
||||
|
||||
# Verify it was set
|
||||
verify_user = current_user_context.get()
|
||||
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
if canLog:
|
||||
logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")
|
||||
|
||||
return token
|
||||
|
||||
|
|
@ -63,42 +85,45 @@ class UserContext:
|
|||
delattr(_thread_local, 'current_user')
|
||||
|
||||
@staticmethod
|
||||
def get_current_user() -> Optional[User]:
|
||||
"""Get current user from context."""
|
||||
logger.debug("[UserContext] - Attempting to get user from context")
|
||||
|
||||
def get_current_user() -> Optional[dict]:
|
||||
"""Get current user from context."""
|
||||
# Try ContextVar first
|
||||
user = current_user_context.get()
|
||||
user = current_user_context.get()
|
||||
if user:
|
||||
logger.debug(f"[UserContext] - Got user from ContextVar: {user.username} (ID: {user.id})")
|
||||
# logger.info(f"[UserContext] - 取得当前用户为 ContextVar 用户: {user.get('username') if user else None}")
|
||||
return user
|
||||
|
||||
# Fallback to thread-local
|
||||
user = getattr(_thread_local, 'current_user', None)
|
||||
user = getattr(_thread_local, 'current_user', None)
|
||||
if user:
|
||||
logger.debug(f"[UserContext] - Got user from thread-local: {user.username} (ID: {user.id})")
|
||||
# logger.info(f"[UserContext] - 取得当前用户为线程本地用户: {user.get('username') if user else None}")
|
||||
return user
|
||||
|
||||
logger.debug("[UserContext] - No user found in context (neither ContextVar nor thread-local)")
|
||||
logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_current_user_id() -> Optional[int]:
|
||||
"""Get current user ID from context."""
|
||||
user = UserContext.get_current_user()
|
||||
return user.id if user else None
|
||||
try:
|
||||
user = UserContext.get_current_user()
|
||||
return user.get('id') if user else None
|
||||
except Exception as e:
|
||||
logger.error(f"[UserContext] - Error getting current user ID: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def clear_current_user() -> None:
|
||||
def clear_current_user(canLog: bool = False) -> None:
|
||||
"""Clear current user from context."""
|
||||
logger.info("[UserContext] - 清除当前用户上下文")
|
||||
if canLog:
|
||||
logger.info("[UserContext] - 清除当前用户上下文")
|
||||
|
||||
current_user_context.set(None)
|
||||
if hasattr(_thread_local, 'current_user'):
|
||||
delattr(_thread_local, 'current_user')
|
||||
|
||||
@staticmethod
|
||||
def require_current_user() -> User:
|
||||
def require_current_user() -> dict:
|
||||
"""Get current user from context, raise exception if not found."""
|
||||
# Use the same logic as get_current_user to check both ContextVar and thread-local
|
||||
user = UserContext.get_current_user()
|
||||
|
|
@ -114,4 +139,4 @@ class UserContext:
|
|||
def require_current_user_id() -> int:
|
||||
"""Get current user ID from context, raise exception if not found."""
|
||||
user = UserContext.require_current_user()
|
||||
return user.id
|
||||
return user.get('id')
|
||||
|
|
@ -20,7 +20,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
|
||||
def __init__(self, app, exclude_paths: list = None):
|
||||
super().__init__(app)
|
||||
self.canLog = True
|
||||
self.canLog = False
|
||||
# Paths that don't require authentication
|
||||
self.exclude_paths = exclude_paths or [
|
||||
"/docs",
|
||||
|
|
@ -77,7 +77,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")
|
||||
|
||||
# Always clear any existing user context to ensure fresh authentication
|
||||
UserContext.clear_current_user()
|
||||
UserContext.clear_current_user(self.canLog)
|
||||
|
||||
# Initialize context token
|
||||
user_token = None
|
||||
|
|
@ -96,6 +96,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
# Extract token
|
||||
token = authorization.split(" ")[1]
|
||||
|
||||
|
||||
# Verify token
|
||||
payload = AuthService.verify_token(token)
|
||||
if payload is None:
|
||||
|
|
@ -136,7 +137,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
)
|
||||
|
||||
# Set user in context using token mechanism
|
||||
user_token = UserContext.set_current_user_with_token(user)
|
||||
user_token = UserContext.set_current_user_with_token(user, self.canLog)
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")
|
||||
|
||||
|
|
@ -160,8 +161,13 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
except Exception as e:
|
||||
# Log error but don't fail the request
|
||||
logger.error(f"[MIDDLEWARE] - 请求处理出错: {e}")
|
||||
# Return 500 error
|
||||
return HxfErrorResponse(e)
|
||||
finally:
|
||||
# Always clear user context after request processing
|
||||
UserContext.clear_current_user()
|
||||
UserContext.clear_current_user(self.canLog)
|
||||
if self.canLog:
|
||||
logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
"""LLM工厂类,用于创建和管理LLM实例"""
|
||||
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.agents import create_agent
|
||||
from loguru import logger
|
||||
from requests import Session
|
||||
from .config import get_settings
|
||||
|
||||
async def new_llm(session: Session = None, model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
streaming: bool = False) -> ChatOpenAI:
|
||||
"""创建LLM实例
|
||||
|
||||
Args:
|
||||
model: 可选,指定使用的模型名称。如果不指定,将使用配置文件中的默认模型
|
||||
temperature: 可选,模型温度参数
|
||||
streaming: 是否启用流式响应,默认False
|
||||
|
||||
Returns:
|
||||
ChatOpenAI实例
|
||||
"""
|
||||
settings = get_settings()
|
||||
llm_config = await settings.llm.get_current_config(session)
|
||||
|
||||
if model:
|
||||
# 根据指定的模型获取对应配置
|
||||
if model.startswith('deepseek'):
|
||||
llm_config['model'] = settings.llm.deepseek_model
|
||||
llm_config['api_key'] = settings.llm.deepseek_api_key
|
||||
llm_config['base_url'] = settings.llm.deepseek_base_url
|
||||
elif model.startswith('doubao'):
|
||||
llm_config['model'] = settings.llm.doubao_model
|
||||
llm_config['api_key'] = settings.llm.doubao_api_key
|
||||
llm_config['base_url'] = settings.llm.doubao_base_url
|
||||
elif model.startswith('glm'):
|
||||
llm_config['model'] = settings.llm.zhipu_model
|
||||
llm_config['api_key'] = settings.llm.zhipu_api_key
|
||||
llm_config['base_url'] = settings.llm.zhipu_base_url
|
||||
elif model.startswith('moonshot'):
|
||||
llm_config['model'] = settings.llm.moonshot_model
|
||||
llm_config['api_key'] = settings.llm.moonshot_api_key
|
||||
llm_config['base_url'] = settings.llm.moonshot_base_url
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model=llm_config['model'],
|
||||
api_key=llm_config['api_key'],
|
||||
base_url=llm_config['base_url'],
|
||||
temperature=temperature if temperature is not None else llm_config['temperature'],
|
||||
max_tokens=llm_config['max_tokens'],
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
return llm
|
||||
|
||||
async def new_agent(session: Session = None, model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
streaming: bool = False) -> ChatOpenAI:
|
||||
"""创建LLM实例
|
||||
|
||||
Args:
|
||||
model: 可选,指定使用的模型名称。如果不指定,将使用配置文件中的默认模型
|
||||
temperature: 可选,模型温度参数
|
||||
streaming: 是否启用流式响应,默认False
|
||||
|
||||
Returns:
|
||||
ChatOpenAI实例
|
||||
"""
|
||||
llm = await new_llm(session, model, temperature, streaming)
|
||||
result = create_agent(
|
||||
model=llm,
|
||||
)
|
||||
return result
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
from functools import wraps
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.database import get_session
|
||||
|
|
@ -11,45 +12,40 @@ from ..models.permission import Role
|
|||
from ..services.auth import AuthService
|
||||
|
||||
|
||||
def is_super_admin(user: User, db: Session) -> bool:
|
||||
async def is_super_admin(user: User, session: Session) -> bool:
|
||||
"""检查用户是否为超级管理员."""
|
||||
session.desc = f"检查用户 {user.id} 是否为超级管理员"
|
||||
if not user or not user.is_active:
|
||||
session.desc = f"用户 {user.id} 不是活跃状态"
|
||||
return False
|
||||
|
||||
# 检查用户是否有超级管理员角色
|
||||
try:
|
||||
# 尝试访问已加载的角色
|
||||
for role in user.roles:
|
||||
if role.code == "SUPER_ADMIN":
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
# 如果角色未加载或访问失败,直接从数据库查询
|
||||
from sqlalchemy import select, and_
|
||||
from ..models.permission import Role, UserRole
|
||||
# 直接使用提供的session查询,避免MissingGreenlet错误
|
||||
from sqlalchemy import select
|
||||
from ..models.permission import UserRole, Role
|
||||
|
||||
try:
|
||||
# 直接查询用户角色
|
||||
stmt = select(Role).join(UserRole).filter(
|
||||
and_(
|
||||
UserRole.user_id == user.id,
|
||||
Role.code == "SUPER_ADMIN",
|
||||
Role.is_active == True
|
||||
)
|
||||
)
|
||||
super_admin_role = db.execute(stmt).scalar_one_or_none()
|
||||
return super_admin_role is not None
|
||||
except Exception:
|
||||
# 如果查询失败,返回False
|
||||
return False
|
||||
stmt = select(UserRole).join(Role).filter(
|
||||
UserRole.user_id == user.id,
|
||||
Role.code == 'SUPER_ADMIN',
|
||||
Role.is_active == True
|
||||
)
|
||||
user_role = await session.execute(stmt)
|
||||
result = user_role.scalar_one_or_none() is not None
|
||||
session.desc = f"用户 {user.id} 超级管理员角色查询结果: {result}"
|
||||
return result
|
||||
except Exception as e:
|
||||
# 如果调用失败,记录错误并返回False
|
||||
session.desc = f"EXCEPTION: 用户 {user.id} 超级管理员角色查询失败: {str(e)}"
|
||||
logger.error(f"检查用户 {user.id} 超级管理员角色失败: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def require_super_admin(
|
||||
async def require_super_admin(
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
) -> User:
|
||||
"""要求超级管理员权限的依赖项."""
|
||||
if not is_super_admin(current_user, session):
|
||||
if not await is_super_admin(current_user, session):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="需要超级管理员权限"
|
||||
|
|
@ -75,17 +71,17 @@ class SimplePermissionChecker:
|
|||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def check_super_admin(self, user: User) -> bool:
|
||||
async def check_super_admin(self, user: User) -> bool:
|
||||
"""检查是否为超级管理员."""
|
||||
return is_super_admin(user, self.db)
|
||||
return await is_super_admin(user, self.db)
|
||||
|
||||
def check_user_access(self, user: User, target_user_id: int) -> bool:
|
||||
async def check_user_access(self, user: User, target_user_id: int) -> bool:
|
||||
"""检查用户访问权限(自己或超级管理员)."""
|
||||
if not user or not user.is_active:
|
||||
return False
|
||||
|
||||
# 超级管理员可以访问所有用户
|
||||
if self.check_super_admin(user):
|
||||
if await self.check_super_admin(user):
|
||||
return True
|
||||
|
||||
# 用户只能访问自己的信息
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class BaseModel(Base):
|
|||
def to_dict(self):
|
||||
"""Convert model to dictionary."""
|
||||
return {
|
||||
column.name: getattr(self, column.name)
|
||||
column.name: getattr(self, column.name).isoformat() if hasattr(getattr(self, column.name), 'isoformat') else getattr(self, column.name)
|
||||
for column in self.__table__.columns
|
||||
}
|
||||
|
||||
|
|
@ -60,31 +60,31 @@ class BaseModel(Base):
|
|||
return cls(**filtered_data)
|
||||
|
||||
def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||||
"""Set audit fields for create/update operations.
|
||||
"""对创建/更新操作设置created_by/updated_by字段。
|
||||
|
||||
Args:
|
||||
user_id: ID of the user performing the operation (optional, will use context if not provided)
|
||||
is_update: True for update operations, False for create operations
|
||||
user_id: 用户ID,用于设置创建/更新操作的审计字段(可选,默认从上下文获取)
|
||||
is_update: True 表示更新操作,False 表示创建操作
|
||||
"""
|
||||
# Get user_id from context if not provided
|
||||
# 如果未提供user_id,则从上下文获取
|
||||
if user_id is None:
|
||||
from ..core.context import UserContext
|
||||
try:
|
||||
user_id = UserContext.get_current_user_id()
|
||||
except Exception:
|
||||
# If no user in context, skip setting audit fields
|
||||
# 如果上下文没有用户ID,则跳过设置审计字段
|
||||
return
|
||||
|
||||
# Skip if still no user_id
|
||||
# 如果仍未提供user_id,则跳过设置审计字段
|
||||
if user_id is None:
|
||||
return
|
||||
|
||||
if not is_update:
|
||||
# For create operations, set both create_by and update_by
|
||||
# 对于创建操作,同时设置created_by和updated_by
|
||||
self.created_by = user_id
|
||||
self.updated_by = user_id
|
||||
else:
|
||||
# For update operations, only set update_by
|
||||
# 对于更新操作,仅设置updated_by
|
||||
self.updated_by = user_id
|
||||
|
||||
# @event.listens_for(Session, 'before_flush')
|
||||
|
|
|
|||
|
|
@ -18,12 +18,27 @@ class DrSession(AsyncSession):
|
|||
def __init__(self, **kwargs):
|
||||
"""Initialize DrSession with unique ID."""
|
||||
super().__init__(**kwargs)
|
||||
self.title = ""
|
||||
self.descs = []
|
||||
# 确保info属性存在
|
||||
if not hasattr(self, 'info'):
|
||||
self.info = {}
|
||||
self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
|
||||
self.stepIndex = 0
|
||||
|
||||
@property
|
||||
def title(self) -> Optional[str]:
|
||||
"""Get work brief from session info."""
|
||||
return self.info.get('title')
|
||||
|
||||
@title.setter
|
||||
def title(self, value: str) -> None:
|
||||
"""Set work brief in session info."""
|
||||
if('title' not in self.info or self.info['title'].strip() == ""):
|
||||
self.info['title'] = value # 确保title属性存在
|
||||
else:
|
||||
self.info['title'] = value + " >>> " + self.info['title']
|
||||
|
||||
@property
|
||||
def desc(self) -> Optional[str]:
|
||||
"""Get work brief from session info."""
|
||||
|
|
@ -34,6 +49,25 @@ class DrSession(AsyncSession):
|
|||
"""Set work brief in session info."""
|
||||
self.stepIndex += 1
|
||||
|
||||
match = re.search(r";(-\d+)", value);
|
||||
level = -3
|
||||
if match:
|
||||
level = int(match.group(1))
|
||||
value = value.replace(f";{level}", "")
|
||||
level = -3 + level
|
||||
|
||||
if "警告" in value or value.startswith("WARNING"):
|
||||
self.log_warning(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "异常" in value or value.startswith("EXCEPTION"):
|
||||
self.log_exception(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "成功" in value or value.startswith("SUCCESS"):
|
||||
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "开始" in value or value.startswith("START"):
|
||||
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "失败" in value or value.startswith("ERROR"):
|
||||
self.log_error(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
else:
|
||||
self.log_info(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
|
||||
def log_prefix(self) -> str:
|
||||
"""Get log prefix with session ID and desc."""
|
||||
|
|
@ -74,14 +108,14 @@ class DrSession(AsyncSession):
|
|||
|
||||
engine_async = create_async_engine(
|
||||
get_settings().database.url,
|
||||
echo=True, # get_settings().database.echo,
|
||||
echo=False, # get_settings().database.echo,
|
||||
future=True,
|
||||
pool_size=get_settings().database.pool_size,
|
||||
max_overflow=get_settings().database.max_overflow,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
from fastapi import Request
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
AsyncSessionFactory = sessionmaker(
|
||||
bind=engine_async,
|
||||
|
|
@ -95,10 +129,15 @@ async def get_session(request: Request = None):
|
|||
if request:
|
||||
url = f"{request.method} {request.url.path}"# .split("://")[-1]
|
||||
# session = AsyncSessionFactory()
|
||||
|
||||
print(url)
|
||||
# 取得request的来源IP
|
||||
if request:
|
||||
client_host = request.client.host
|
||||
else:
|
||||
client_host = "无request"
|
||||
session = DrSession(bind=engine_async)
|
||||
|
||||
session.desc = f"SUCCESS: 创建数据库 session >>> {url}"
|
||||
session.title = f"{url} - {client_host}"
|
||||
|
||||
# 设置request属性
|
||||
if request:
|
||||
|
|
@ -111,8 +150,9 @@ async def get_session(request: Request = None):
|
|||
errMsg = f"数据库 session 异常 >>> {e}"
|
||||
session.desc = f"EXCEPTION: {errMsg}"
|
||||
await session.rollback()
|
||||
raise e
|
||||
# DatabaseError(e)
|
||||
# 重新抛出原始异常,不转换为 HTTPException
|
||||
raise e # HTTPException(status_code=e.status_code, detail=errMsg) # main.py中将捕获本异常
|
||||
finally:
|
||||
session.desc = f"数据库 session 关闭"
|
||||
# session.desc = f"数据库 session 关闭"
|
||||
session.desc = ""
|
||||
await session.close()
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ sys.path.insert(0, str(backend_dir))
|
|||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from th_agenter.core.config import settings
|
||||
from th_agenter.db.database import Base, get_session
|
||||
from th_agenter.core.config import get_settings
|
||||
from th_agenter.db.database import Base
|
||||
from th_agenter.models import * # Import all models to ensure they're registered
|
||||
from th_agenter.utils.logger import get_logger
|
||||
from th_agenter.models.resource import Resource
|
||||
|
|
@ -25,8 +25,25 @@ def migrate_hardcoded_resources():
|
|||
"""Migrate hardcoded resources from init_resource_data.py to database."""
|
||||
db = None
|
||||
try:
|
||||
# Get database settings
|
||||
settings = get_settings()
|
||||
|
||||
# Create synchronous engine (remove asyncpg from URL)
|
||||
sync_db_url = settings.database.url.replace('postgresql+asyncpg://', 'postgresql://')
|
||||
sync_engine = create_engine(
|
||||
sync_db_url,
|
||||
echo=False,
|
||||
pool_size=settings.database.pool_size,
|
||||
max_overflow=settings.database.max_overflow,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# Create synchronous session factory
|
||||
SyncSessionFactory = sessionmaker(bind=sync_engine)
|
||||
|
||||
# Get database session
|
||||
db = get_session() # xxxx
|
||||
db = SyncSessionFactory()
|
||||
|
||||
if db is None:
|
||||
logger.error("Failed to create database session")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ branch_labels = None
|
|||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
async def upgrade():
|
||||
"""删除权限相关表."""
|
||||
|
||||
# 获取数据库连接
|
||||
|
|
@ -44,7 +44,7 @@ def upgrade():
|
|||
WHERE table_name = '{table_name}'
|
||||
);
|
||||
"""))
|
||||
table_exists = result.scalar()
|
||||
table_exists = await result.scalar()
|
||||
|
||||
if table_exists:
|
||||
print(f"删除表: {table_name}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,198 @@
|
|||
from loguru import logger
|
||||
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
|
||||
|
||||
# 核心:导入 LangChain 的基础语言模型抽象类
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
|
||||
@dataclass
|
||||
class LLMConfig_DataClass:
|
||||
"""
|
||||
统一的LLM配置基类,覆盖在线/本地/嵌入式模型所有配置,映射数据库完整字段
|
||||
通过 provider + is_embedding 区分模型类型:
|
||||
- 在线模型:provider in ['openai', 'zhipu', 'baidu'] + is_embedding=False
|
||||
- 本地模型:provider in ['llama', 'qwen', 'yi'] + is_embedding=False
|
||||
- 嵌入式模型:provider in ['bge', 'text2vec'] + is_embedding=True
|
||||
"""
|
||||
# ====================== 数据库核心公共字段(必选/可选) ======================
|
||||
# 基础标识字段
|
||||
name: str # 模型自定义名称(如 "gpt-5")
|
||||
model_name: str # 模型官方标识名(如 "gpt-5"、"BAAI/bge-small-zh-v1.5")
|
||||
provider: str # 提供商(openai/llama/bge/zhipu 等)
|
||||
id: Optional[int] = None # 数据库主键ID
|
||||
description: Optional[str] = None # 模型描述
|
||||
is_active: bool = True # 是否启用
|
||||
is_default: bool = False # 是否默认模型
|
||||
is_embedding: bool = False # 是否为嵌入式模型(核心区分标识)
|
||||
|
||||
# ====================== 通用生成参数(所有推理模型共用) ======================
|
||||
temperature: float = 0.7 # 生成温度(默认值对齐数据库示例)
|
||||
max_tokens: int = 3000 # 最大生成长度(默认值对齐数据库示例)
|
||||
top_p: float = 0.6 # 采样Top-P
|
||||
frequency_penalty: float = 0.0 # 频率惩罚
|
||||
presence_penalty: float = 0.0 # 存在惩罚
|
||||
|
||||
# ====================== 在线模型专属参数(非必填,仅在线模型生效) ======================
|
||||
api_key: Optional[str] = None # API密钥(在线模型必填)
|
||||
base_url: Optional[str] = None # API代理地址(如 https://api.openai-proxy.org/v1)
|
||||
# timeout: int = 30 # 请求超时时间(秒)
|
||||
max_retries: int = 3 # 最大重试次数
|
||||
api_version: Optional[str] = None # API版本(如 OpenAI 的 2024-02-15-preview)
|
||||
|
||||
# ====================== 本地模型专属参数(非必填,仅本地模型生效) ======================
|
||||
model_path: Optional[str] = None # 本地模型文件路径(本地模型必填)
|
||||
device: str = "cpu" # 运行设备(cpu/cuda/mps)
|
||||
n_ctx: int = 2048 # 上下文窗口大小
|
||||
n_threads: int = 8 # 推理线程数
|
||||
quantization: str = "q4_0" # 量化级别(q4_0/q8_0/f16)
|
||||
load_in_8bit: bool = False # 是否8bit加载
|
||||
load_in_4bit: bool = False # 是否4bit加载
|
||||
prompt_template: Optional[str] = None # 自定义Prompt模板
|
||||
|
||||
# ====================== 嵌入式模型专属参数(非必填,仅嵌入式模型生效) ======================
|
||||
normalize_embeddings: bool = True # 是否归一化向量
|
||||
batch_size: int = 32 # 批量编码大小
|
||||
encode_kwargs: Dict[str, Any] = field(default_factory=dict) # 编码扩展参数
|
||||
dimension: Optional[int] = None # 向量维度(如 768)
|
||||
|
||||
# ====================== 元数据字段(数据库自动维护) ======================
|
||||
extra_config: Dict[str, Any] = field(default_factory=dict) # 额外扩展配置
|
||||
usage_count: int = 0 # 使用次数
|
||||
last_used_at: Optional[datetime] = None # 最后使用时间
|
||||
created_at: Optional[datetime] = None # 创建时间
|
||||
updated_at: Optional[datetime] = None # 更新时间
|
||||
created_by: Optional[int] = None # 创建人ID
|
||||
updated_by: Optional[int] = None # 更新人ID
|
||||
|
||||
api_key_masked: Optional[str] = "" # 掩码后的API密钥(数据库存储)
|
||||
|
||||
# ====================== 核心工具方法 ======================
|
||||
def __post_init__(self):
|
||||
"""后置初始化:自动校验和修正配置"""
|
||||
# 1. 嵌入式模型强制清空推理参数(避免误用)
|
||||
if self.is_embedding:
|
||||
self.max_tokens = 0
|
||||
self.temperature = 0.0
|
||||
self.top_p = 0.0
|
||||
|
||||
# 2. 校验必填参数(按模型类型)
|
||||
self._validate_required_fields()
|
||||
|
||||
def _validate_required_fields(self):
|
||||
"""按模型类型校验必填参数"""
|
||||
# 在线模型校验
|
||||
if not self.is_embedding and self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
|
||||
if not self.api_key:
|
||||
raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
|
||||
|
||||
# 本地模型校验
|
||||
if not self.is_embedding and self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
|
||||
if not self.model_path:
|
||||
raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典(用于存入/更新数据库)"""
|
||||
return {
|
||||
key: value for key, value in self.__dict__.items()
|
||||
if not key.startswith('_') # 排除私有属性
|
||||
}
|
||||
|
||||
@classmethod
def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
    """Build a config instance from a raw database row dict.

    ISO-8601 time strings are parsed into ``datetime`` objects (unparsable
    values become ``None``), keys that are not declared dataclass fields
    are dropped, and the resulting kwargs are passed to ``cls``.

    The input dict is NOT modified (the original implementation rewrote
    the time fields of the caller's dict in place).
    """
    # Work on a shallow copy so the caller's dict is left untouched.
    data = dict(db_dict)

    # 1. Time fields: ISO string -> datetime ('Z' suffix normalized to +00:00).
    for field_name in ('last_used_at', 'created_at', 'updated_at'):
        val = data.get(field_name)
        if val and isinstance(val, str):
            try:
                data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
            except (ValueError, TypeError):
                data[field_name] = None

    # 2. Keep only keys that map to declared dataclass fields.
    valid_fields = cls.__dataclass_fields__.keys()
    filtered_dict = {k: v for k, v in data.items() if k in valid_fields}

    # 3. Instantiate (this runs __post_init__ validation, when defined).
    return cls(**filtered_dict)
||||
def get_model_type(self) -> str:
    """Classify this config: "embedding", "online", "local", or "unknown"."""
    if self.is_embedding:
        return "embedding"
    provider_kind = {
        'openai': "online", 'zhipu': "online", 'baidu': "online", 'anthropic': "online",
        'llama': "local", 'qwen': "local", 'yi': "local", 'glm': "local", 'mistral': "local",
    }
    return provider_kind.get(self.provider, "unknown")
||||
|
||||
class BaseLLM(BaseChatModel):
    """
    Base model that plugs into LangChain's ``BaseChatModel`` protocol so
    subclasses can be used directly with ``create_agent``.

    Subclasses are expected to override ``_validate_config``, ``load_model``
    and the ``_generate`` / ``_agenerate`` pair.
    """
    # Declared as class-level (pydantic) fields; populated in __init__.
    config: Any = None  # model configuration object
    model: Any = None   # lazily created backend model instance

    def __init__(self, config):
        super().__init__()  # BaseChatModel (pydantic) initialization is mandatory
        self.config = config
        self.model = None
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    def _validate_config(self) -> None:
        """Validate ``self.config``.

        Default is a no-op so subclasses that need no validation do not
        crash in ``__init__`` (the original base class called this method
        without defining it, so every subclass was forced to provide one).
        """
        return None

    # ---------------------- LangChain core protocol methods ----------------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous generation; subclasses must override.

        The original stub logged the error and implicitly returned ``None``,
        which violates the ChatResult contract and surfaces later as an
        obscure failure — raising here fails fast instead.
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _generate")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation; subclasses must override."""
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _agenerate")

    @property
    def _llm_type(self) -> str:
        """Model type identifier (e.g. "openai", "llama"); class name by default."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Create the backend model; subclasses must override."""
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement load_model")

    def close(self) -> None:
        """Release the backend model reference, if one was created."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    # Context-manager support: load on enter, release on exit.
    def __enter__(self):
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
@ -0,0 +1,77 @@
|
|||
from typing import List
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from loguru import logger
|
||||
from th_agenter.llm.base_llm import BaseLLM
|
||||
|
||||
class EmbedLLM(BaseLLM, Embeddings):
    """Embedding-model base class.

    Inherits LangChain's ``Embeddings`` interface (not a chat model).
    Subclasses must implement the four embed methods below.
    """
    def __init__(self, config):
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-vectorize documents; subclasses must override.

        The original stub silently returned ``None`` (bare ``pass``), which
        callers would then crash on; raising makes the missing override explicit.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement embed_documents")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch vectorization; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement aembed_documents")

    def embed_query(self, text: str) -> List[float]:
        """Vectorize a single query string; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement embed_query")

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query vectorization; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement aembed_query")
||||
# Concrete BGE embedding model implementation.
class BGEEmbedLLM(EmbedLLM):
    """BGE embedding model backed by Ollama or HuggingFace sentence-transformers."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """A model name is the only hard requirement."""
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the embedding backend selected by ``config.provider``."""
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        if getattr(self.config, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, 'base_url', None),
            )
            return
        try:
            from langchain_huggingface import HuggingFaceEmbeddings
            self.model = HuggingFaceEmbeddings(
                model_name=self.config.model_name,
                model_kwargs={"device": getattr(self.config, 'device', "cpu")},
                encode_kwargs={"normalize_embeddings": getattr(self.config, 'normalize_embeddings', True)},
            )
        except ImportError as e:
            logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
            logger.error("Please install sentence-transformers: pip install sentence-transformers")
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-vectorize documents, loading the backend on first use."""
        if not self.model:
            self.load_model()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch vectorization, loading the backend on first use."""
        if not self.model:
            self.load_model()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Vectorize one query string, loading the backend on first use."""
        if not self.model:
            self.load_model()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query vectorization, loading the backend on first use."""
        if not self.model:
            self.load_model()
        return await self.model.aembed_query(text)
|
@ -0,0 +1,70 @@
|
|||
import os, dotenv
|
||||
from loguru import logger
|
||||
from utils.Constant import Constant
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.messages import HumanMessage
|
||||
# Load environment variables from a .env file (if present) and mirror the
# OpenAI settings into os.environ.  Only copy values that actually exist:
# the original code did ``os.environ[k] = os.getenv(k)``, which raises
# TypeError at import time when the variable is undefined.
dotenv.load_dotenv()
for _env_var in ("OPENAI_API_KEY", "OPENAI_BASE_URL"):
    _env_val = os.getenv(_env_var)
    if _env_val is not None:
        os.environ[_env_var] = _env_val
||||
class LLM_Model_Base(object):
    '''
    Base class for language-model wrappers.

    Defines the attributes and helpers shared by all LLM wrapper classes:
    - model_name: defaults to "gpt-4o-mini"
    - temperature: defaults to 0.7
    - llmModel: underlying model instance, created by subclasses
    - mode: model mode constant, set by subclasses
    - name: human-readable model name shown in the UI

    author: DrGraph
    date: 2025-11-20
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        self.model_name = model_name  # 0.15 0.6
        self.temperature = temperature
        self.llmModel = None  # set by subclasses
        self.mode = Constant.LLM_MODE_NONE
        self.name = '未知模型'

    def buildPromptTemplateValue(self, prompt: str, methodType: str, valueType: str):
        """Wrap *prompt* into the value shape selected by methodType/valueType.

        methodType chooses how the PromptTemplate is applied ("format" or
        "invoke"); valueType chooses the returned shape ("str", "messages",
        or "promptValue").  Unmatched combinations return None.  The
        template itself is hard-coded; *prompt* fills its {question} slot.
        """
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        prompt_template = PromptTemplate.from_template(
            template="请回答以下问题: {question}",
        )
        prompt_template_value = None
        if methodType == "format":
            # Path 1 - format() yields a plain string.
            prompt_str = prompt_template.format(question=prompt)  # prompt is a str
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 format 方法,取得字符串prompt_str, 然后再处理 - {type(prompt_str)} - {prompt_str}")

            if valueType == "str":
                # 1.1 Use the formatted string directly with llm.invoke().
                prompt_template_value = prompt_str
                logger.info(f"{self.name} >>> 1.2.1 直接使用字符串")

            elif valueType == "messages":
                # 1.2 Build a HumanMessage list.
                # NOTE(review): this wraps the raw `prompt`, not the formatted
                # `prompt_str` computed above — confirm that is intentional.
                prompt_template_value = [HumanMessage(content=prompt)]
                logger.info(f"{self.name} >>> 1.2.2 创建HumanMessage对象列表")

        elif methodType == "invoke":
            # Path 2 - invoke() yields a PromptValue.
            prompt_value = prompt_template.invoke(input={"question" : prompt})  # StringPromptValue
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 invoke 方法,取得PromptValue, 然后再处理 - {type(prompt_value)} - {prompt_value}")
            if valueType == "str":
                # 2.1 Convert the PromptValue back to a plain string.
                prompt_template_value = prompt_value.to_string()
                logger.info(f"{self.name} >>> 1.2.1 由 PromptValue 转换为字符串")
            elif valueType == "promptValue":
                # 2.2 Use the PromptValue itself.
                prompt_template_value = prompt_value
                logger.info(f"{self.name} >>> 1.2.2 直接使用 PromptValue 作为 prompt_template_value")
            elif valueType == "messages":
                # 2.3 Convert the PromptValue into a message list.
                prompt_template_value = prompt_value.to_messages()
                logger.info(f"{self.name} >>> 1.2.3 使用 PromptValue.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表")

        logger.info(f"{self.name} >>> 1.3 用户输入 最终包装为(PromptValue/str/list of BaseMessages): {type(prompt_template_value)}\n{prompt_template_value}")
        return prompt_template_value
|
|
@ -0,0 +1,29 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from loguru import logger
|
||||
|
||||
from DrGraph.utils.Constant import Constant
|
||||
from LLM.llm_model_base import LLM_Model_Base
|
||||
|
||||
class Chat_LLM(LLM_Model_Base):
    """Chat-mode model wrapper around ``ChatOpenAI``.

    - model_name: defaults to "gpt-4o-mini"
    - temperature: defaults to 0.7
    """
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '聊天模型'
        self.mode = Constant.LLM_MODE_CHAT
        self.llmModel = ChatOpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns the message object so the chatbot UI can display it.
    def invoke(self, prompt: str):
        """Send *prompt* to the chat model and return the response message.

        Raises: any error from the underlying model.  The original code
        swallowed the exception and then crashed with UnboundLocalError on
        ``return response``; the error is now logged and re-raised.
        """
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
            # response = {"role": "assistant", "content": response.content}
        except Exception as e:
            logger.error(e)
            raise  # preserve the original traceback for the caller
        return response
|
|
@ -0,0 +1,44 @@
|
|||
'''
|
||||
非聊天模型类,继承自 LLM_Model_Base
|
||||
|
||||
author: DrGraph
|
||||
date: 2025-11-20
|
||||
'''
|
||||
from loguru import logger
|
||||
from langchain_openai import OpenAI
|
||||
from langchain_core.messages import AIMessage
|
||||
from DrGraph.utils.Constant import Constant
|
||||
from LLM.llm_model_base import LLM_Model_Base
|
||||
|
||||
|
||||
class NonChat_LLM(LLM_Model_Base):
    '''
    Non-chat (completion) model wrapper around the OpenAI completion model.
    - model_name: defaults to "gpt-4o-mini"
    - temperature: defaults to 0.7
    - name = '非聊天模型', shown in the UI
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '非聊天模型'
        self.mode = Constant.LLM_MODE_NONCHAT
        self.llmModel = OpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns plain text so the chatbot UI can display it.
    def invoke(self, prompt: str):
        '''
        Invoke the completion model and return its reply.

        prompt: user input string
        return: assistant reply string
        Raises: any error from the underlying model.  The original code
        logged the exception and then crashed with UnboundLocalError on
        ``return response``; the error is now logged and re-raised.
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        try:
            response = self.llmModel.invoke(prompt)
            logger.info(f"{self.name} >>> 1.2 助手回复: {type(response)}")
        except Exception as e:
            logger.error(e)
            raise
        return response
|
||||
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from loguru import logger
|
||||
|
||||
from DrGraph.utils.Constant import Constant
|
||||
from LLM.llm_model_base import LLM_Model_Base
|
||||
from langchain_ollama import ChatOllama
|
||||
class Chat_Ollama(LLM_Model_Base):
    """Self-hosted Ollama chat model wrapper.

    - base_url: Ollama server address, defaults to http://127.0.0.1:11434
    - model_name: defaults to "OxW/Qwen3-0.6B-GGUF:latest"
    - temperature: defaults to 0.7
    """
    def __init__(self, base_url="http://127.0.0.1:11434", model_name: str = "OxW/Qwen3-0.6B-GGUF:latest", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '私有化Ollama模型'
        self.base_url = base_url
        self.llmModel = ChatOllama(
            base_url=self.base_url,
            model=model_name,
            temperature=temperature,
        )
        self.mode = Constant.LLM_MODE_LOCAL_OLLAMA

    def invoke(self, prompt: str):
        """Send *prompt* to the Ollama chat model and return the response message.

        Raises: any error from the underlying model.  The original code
        swallowed the exception and then crashed with UnboundLocalError on
        ``return response``; the error is now logged and re-raised.
        """
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
        except Exception as e:
            logger.error(e)
            raise
        return response
|
|
@ -0,0 +1,68 @@
|
|||
from typing import List, Optional
|
||||
from th_agenter.llm.base_llm import BaseLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
|
||||
|
||||
class LocalLLM(BaseLLM):
    """Local GGUF model served through LlamaCpp, exposed as a LangChain chat model."""

    def __init__(self, config):
        super().__init__(config)  # sets self.config and runs _validate_config()
        self.local_config = config  # alias kept for the load/generate helpers

    def _validate_config(self):
        """Require a model_path for local models.

        NOTE: this runs from ``BaseLLM.__init__`` *before* ``self.local_config``
        is assigned, so it must read ``self.config`` — the original read
        ``self.local_config`` and crashed with AttributeError on construction.
        """
        if not self.config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the LlamaCpp backend from the config."""
        from langchain_community.llms import LlamaCpp
        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    # NOTE: kwargs deliberately left un-annotated — the original annotated it
    # with ``Any`` which this module never imports (NameError at class
    # definition time).
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> ChatResult:
        """Synchronously generate a reply for *messages*."""
        if not self.model:
            self.load_model()
        # LlamaCpp is a completion (non-chat) model: flatten messages to a prompt.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap as a LangChain-standard ChatResult.
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> ChatResult:
        """Async variant of :meth:`_generate`."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Flatten a LangChain message list into a Llama-style instruct prompt."""
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)
|
|
@ -0,0 +1,80 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, BaseMessage
|
||||
from typing import List, Optional, Any, Union
|
||||
from langchain_core.outputs import ChatResult
|
||||
from th_agenter.llm.base_llm import BaseLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
|
||||
class OnlineLLM(BaseLLM):
    """Hosted (OpenAI-compatible) chat model wrapper."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Online providers always need an API key."""
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Instantiate the ChatOpenAI client from the config.

        Uses the module-level ``ChatOpenAI`` import (the original re-imported
        it locally and carried dead commented-out init_chat_model code).
        """
        self.model = ChatOpenAI(
            api_key=self.config.api_key,
            model_name=self.config.model_name,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            base_url=self.config.base_url,
        )

    @property
    def _llm_type(self) -> str:
        return "openai"  # model type identifier

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate to the underlying LangChain model's ``_generate``."""
        if not self.model:
            self.load_model()
        # Reuse the underlying model's implementation.
        return self.model._generate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async delegate to the underlying model's ``_agenerate``."""
        if not self.model:
            self.load_model()
        return await self.model._agenerate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    # ---------------------- custom convenience helpers ----------------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Generate from a plain string or a message list; return the reply text."""
        messages = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        result = self._generate(messages, **kwargs)
        return result.generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async variant of :meth:`generate`."""
        messages = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        result = await self._agenerate(messages, **kwargs)
        return result.generations[0].text
|
|
@ -24,15 +24,20 @@ class Conversation(BaseModel):
|
|||
|
||||
# Relationships removed to eliminate foreign key constraints
|
||||
|
||||
def to_dict(self) -> dict:
    """Serialize the conversation into a plain dict."""
    field_names = (
        "id", "title", "user_id", "knowledge_base_id", "system_prompt",
        "model_name", "temperature", "max_tokens", "is_archived",
        "message_count", "last_message_at",
    )
    return {name: getattr(self, name) for name in field_names}
|
||||
def __repr__(self):
    """Concise debug representation: id, title and owning user."""
    return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"
||||
@property
def message_count(self):
    """Number of messages currently attached to this conversation."""
    # NOTE(review): assumes self.messages is already loaded; with a lazy
    # relationship this access may trigger a DB query — confirm usage.
    current_messages = self.messages
    return len(current_messages)
|
||||
@property
def last_message_at(self):
    """Timestamp of the most recent message.

    Falls back to the conversation's own ``created_at`` when there are no
    messages (the original indexed ``self.messages[-1]`` unconditionally and
    raised IndexError on an empty conversation) or when the last message has
    no timestamp.
    """
    if not self.messages:
        return self.created_at
    return self.messages[-1].created_at or self.created_at
||||
return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id}, system_prompt={self.system_prompt}, model_name='{self.model_name}', temperature='{self.temperature}', message_count={self.message_count})>"
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class LLMConfig(BaseModel):
|
|||
last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 最后使用时间
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"
|
||||
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model_name='{self.model_name}', base_url='{self.base_url}')>"
|
||||
|
||||
def to_dict(self, include_sensitive=False):
|
||||
"""Convert to dictionary, optionally excluding sensitive data."""
|
||||
|
|
@ -60,7 +60,7 @@ class LLMConfig(BaseModel):
|
|||
'is_embedding': self.is_embedding,
|
||||
'extra_config': self.extra_config,
|
||||
'usage_count': self.usage_count,
|
||||
'last_used_at': self.last_used_at
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None
|
||||
})
|
||||
|
||||
if include_sensitive:
|
||||
|
|
@ -102,8 +102,8 @@ class LLMConfig(BaseModel):
|
|||
if not self.name or not self.name.strip():
|
||||
return {"valid": False, "error": "配置名称不能为空"}
|
||||
|
||||
if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']:
|
||||
return {"valid": False, "error": "不支持的服务商"}
|
||||
if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama']:
|
||||
return {"valid": False, "error": f"不支持的服务商 {self.provider}"}
|
||||
|
||||
if not self.model_name or not self.model_name.strip():
|
||||
return {"valid": False, "error": "模型名称不能为空"}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class LLMConfigCreate(LLMConfigBase):
|
|||
allowed_providers = [
|
||||
'openai', 'azure', 'anthropic', 'google', 'baidu',
|
||||
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
|
||||
'ollama', 'custom', "doubao"
|
||||
'ollama', 'custom', "doubao", "ollama"
|
||||
]
|
||||
if v.lower() not in allowed_providers:
|
||||
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
||||
|
|
@ -74,7 +74,7 @@ class LLMConfigUpdate(BaseModel):
|
|||
allowed_providers = [
|
||||
'openai', 'azure', 'anthropic', 'google', 'baidu',
|
||||
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
|
||||
'ollama', 'custom',"doubao"
|
||||
'ollama', 'custom',"doubao", "ollama"
|
||||
]
|
||||
if v.lower() not in allowed_providers:
|
||||
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
||||
|
|
|
|||
|
|
@ -33,14 +33,16 @@ class AgentConfig(BaseModel):
|
|||
class AgentService:
|
||||
"""LangChain Agent service with tool calling capabilities."""
|
||||
|
||||
def __init__(self, db_session=None):
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
async def initialize(self, session=None):
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.config = AgentConfig()
|
||||
self.db_session = db_session
|
||||
self.config_service = AgentConfigService(db_session) if db_session else None
|
||||
self.session = session
|
||||
self.config_service = AgentConfigService(session) if session else None
|
||||
self._initialize_tools()
|
||||
self._load_config()
|
||||
await self._load_config()
|
||||
|
||||
def _initialize_tools(self):
|
||||
"""Initialize and register all available tools."""
|
||||
|
|
@ -56,18 +58,17 @@ class AgentService:
|
|||
self.tool_registry.register(tool)
|
||||
logger.info(f"Registered tool: {tool.get_name()}")
|
||||
|
||||
def _load_config(self):
|
||||
async def _load_config(self):
|
||||
"""Load configuration from database if available."""
|
||||
if self.config_service:
|
||||
try:
|
||||
config_dict = self.config_service.get_config_dict()
|
||||
config_dict = await self.config_service.get_config_dict()
|
||||
# Update config with database values
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info("Loaded agent configuration from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config from database, using defaults: {str(e)}")
|
||||
logger.error(f"Failed to load config from database, using defaults: {str(e)}")
|
||||
|
||||
def _get_enabled_tools(self) -> List[Any]:
|
||||
"""Get list of enabled LangChain tools."""
|
||||
|
|
@ -83,11 +84,11 @@ class AgentService:
|
|||
|
||||
return enabled_tools
|
||||
|
||||
def _create_agent_executor(self) -> Any:
|
||||
async def _create_agent_executor(self) -> Any:
|
||||
"""Create LangChain agent executor."""
|
||||
# Get LLM configuration
|
||||
from ...core.llm import create_llm
|
||||
llm = create_llm()
|
||||
from ...core.new_agent import new_agent
|
||||
llm = await new_agent()
|
||||
|
||||
# Get enabled tools
|
||||
tools = self._get_enabled_tools()
|
||||
|
|
@ -114,7 +115,7 @@ class AgentService:
|
|||
logger.info(f"Processing agent chat message: {message[:100]}...")
|
||||
|
||||
# Create agent
|
||||
agent = self._create_agent_executor()
|
||||
agent = await self._create_agent_executor()
|
||||
|
||||
# Convert chat history to LangChain format
|
||||
langchain_history = []
|
||||
|
|
@ -155,7 +156,7 @@ class AgentService:
|
|||
logger.info(f"Processing agent chat stream: {message[:100]}...")
|
||||
|
||||
# Create agent
|
||||
agent = self._create_agent_executor()
|
||||
agent = await self._create_agent_executor()
|
||||
|
||||
# Convert chat history to LangChain format
|
||||
langchain_history = []
|
||||
|
|
@ -263,17 +264,18 @@ class AgentService:
|
|||
|
||||
|
||||
# Global agent service instance
_global_agent_service: Optional[AgentService] = None


async def get_agent_service(session=None) -> AgentService:
    """Get (lazily creating) the global AgentService singleton.

    On first call the service is created and initialized with *session*.
    On later calls with a different session, the session and config service
    are swapped in and the configuration is reloaded.
    """
    global _global_agent_service
    if _global_agent_service is None:
        _global_agent_service = AgentService()
        await _global_agent_service.initialize(session)
    elif session and session != _global_agent_service.session:
        # Rebind to the new database session and reload configuration.
        _global_agent_service.session = session
        _global_agent_service.config_service = AgentConfigService(session)
        # _load_config is a coroutine; the original call was never awaited,
        # so the reload silently did nothing (un-awaited coroutine warning).
        await _global_agent_service._load_config()
    return _global_agent_service
|
|
|
|||
|
|
@ -41,16 +41,18 @@ class LangGraphAgentConfig(BaseModel):
|
|||
class LangGraphAgentService:
|
||||
"""LangGraph Agent service using low-level LangGraph graph (React pattern)."""
|
||||
|
||||
def __init__(self, db_session=None):
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
|
||||
async def initialize(self, session=None):
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.config = LangGraphAgentConfig()
|
||||
self.tools = []
|
||||
self.db_session = db_session
|
||||
self.config_service = AgentConfigService(db_session) if db_session else None
|
||||
self.session = session
|
||||
self.config_service = AgentConfigService(session) if session else None
|
||||
self._initialize_tools()
|
||||
self._load_config()
|
||||
self._create_react_agent()
|
||||
await self._load_config()
|
||||
await self._create_react_agent()
|
||||
|
||||
def _initialize_tools(self):
|
||||
"""Initialize available tools."""
|
||||
|
|
@ -76,28 +78,29 @@ class LangGraphAgentService:
|
|||
|
||||
|
||||
|
||||
def _load_config(self):
|
||||
async def _load_config(self):
|
||||
"""Load configuration from database if available."""
|
||||
if self.config_service:
|
||||
try:
|
||||
db_config = self.config_service.get_active_config()
|
||||
if db_config:
|
||||
# Update config with database values
|
||||
config_dict = db_config.config_data
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info("Loaded configuration from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config from database: {e}")
|
||||
pass
|
||||
# if self.config_service:
|
||||
# try:
|
||||
# db_config = self.config_service.get_active_config()
|
||||
# if db_config:
|
||||
# # Update config with database values
|
||||
# config_dict = db_config.config_data
|
||||
# for key, value in config_dict.items():
|
||||
# if hasattr(self.config, key):
|
||||
# setattr(self.config, key, value)
|
||||
# logger.info("Loaded configuration from database")
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Failed to load config from database: {e}")
|
||||
|
||||
|
||||
|
||||
def _create_react_agent(self):
|
||||
async def _create_react_agent(self):
|
||||
"""Create LangGraph agent using low-level StateGraph with explicit nodes/edges."""
|
||||
try:
|
||||
# Initialize the model
|
||||
llm_config = get_settings().llm.get_current_config()
|
||||
llm_config = await get_settings().llm.get_current_config(self.db_session)
|
||||
self.model = init_chat_model(
|
||||
model=llm_config['model'],
|
||||
model_provider='openai',
|
||||
|
|
@ -183,7 +186,7 @@ class LangGraphAgentService:
|
|||
# Compile graph and store as self.agent for compatibility with existing code
|
||||
self.react_agent = graph.compile()
|
||||
|
||||
logger.info("LangGraph low-level React agent created successfully")
|
||||
logger.info("LangGraph 底层 React 智能体创建成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent: {str(e)}")
|
||||
|
|
@ -723,15 +726,14 @@ class LangGraphAgentService:
|
|||
|
||||
|
||||
# Global instance
|
||||
_langgraph_agent_service: Optional[LangGraphAgentService] = None
|
||||
_global_langgraph_agent_service: Optional[LangGraphAgentService] = None
|
||||
|
||||
|
||||
def get_langgraph_agent_service(db_session=None) -> LangGraphAgentService:
|
||||
async def get_langgraph_agent_service(session=None) -> LangGraphAgentService:
|
||||
"""Get or create LangGraph agent service instance."""
|
||||
global _langgraph_agent_service
|
||||
global _global_langgraph_agent_service
|
||||
|
||||
if _langgraph_agent_service is None:
|
||||
_langgraph_agent_service = LangGraphAgentService(db_session)
|
||||
logger.info("LangGraph Agent service initialized")
|
||||
if _global_langgraph_agent_service is None:
|
||||
_global_langgraph_agent_service = LangGraphAgentService()
|
||||
await _global_langgraph_agent_service.initialize(session)
|
||||
|
||||
return _langgraph_agent_service
|
||||
return _global_langgraph_agent_service
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import and_, select, update
|
||||
|
||||
from ..models.agent_config import AgentConfig
|
||||
from utils.util_exceptions import ValidationError, NotFoundError
|
||||
|
|
@ -15,7 +15,7 @@ class AgentConfigService:
|
|||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_config(self, config_data: Dict[str, Any]) -> AgentConfig:
|
||||
async def create_config(self, config_data: Dict[str, Any]) -> AgentConfig:
|
||||
"""Create a new agent configuration."""
|
||||
try:
|
||||
# Validate required fields
|
||||
|
|
@ -23,9 +23,8 @@ class AgentConfigService:
|
|||
raise ValidationError("Configuration name is required")
|
||||
|
||||
# Check if name already exists
|
||||
existing = self.db.query(AgentConfig).filter(
|
||||
AgentConfig.name == config_data["name"]
|
||||
).first()
|
||||
stmt = select(AgentConfig).where(AgentConfig.name == config_data["name"])
|
||||
existing = (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
if existing:
|
||||
raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")
|
||||
|
||||
|
|
@ -44,62 +43,60 @@ class AgentConfigService:
|
|||
|
||||
# If this is set as default, unset other defaults
|
||||
if config.is_default:
|
||||
self.db.query(AgentConfig).filter(
|
||||
AgentConfig.is_default == True
|
||||
).update({"is_default": False})
|
||||
stmt = update(AgentConfig).where(AgentConfig.is_default == True).values({"is_default": False})
|
||||
await self.db.execute(stmt)
|
||||
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(config)
|
||||
|
||||
logger.info(f"Created agent configuration: {config.name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
await self.db.rollback()
|
||||
logger.error(f"Error creating agent configuration: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_config(self, config_id: int) -> Optional[AgentConfig]:
|
||||
async def get_config(self, config_id: int) -> Optional[AgentConfig]:
|
||||
"""Get agent configuration by ID."""
|
||||
return self.db.query(AgentConfig).filter(
|
||||
AgentConfig.id == config_id
|
||||
).first()
|
||||
stmt = select(AgentConfig).where(AgentConfig.id == config_id)
|
||||
return (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
def get_config_by_name(self, name: str) -> Optional[AgentConfig]:
|
||||
async def get_config_by_name(self, name: str) -> Optional[AgentConfig]:
|
||||
"""Get agent configuration by name."""
|
||||
return self.db.query(AgentConfig).filter(
|
||||
AgentConfig.name == name
|
||||
).first()
|
||||
stmt = select(AgentConfig).where(AgentConfig.name == name)
|
||||
return (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
def get_default_config(self) -> Optional[AgentConfig]:
|
||||
async def get_default_config(self) -> Optional[AgentConfig]:
|
||||
"""Get default agent configuration."""
|
||||
return self.db.query(AgentConfig).filter(
|
||||
and_(AgentConfig.is_default == True, AgentConfig.is_active == True)
|
||||
).first()
|
||||
stmt = select(AgentConfig).where(and_(AgentConfig.is_default == True, AgentConfig.is_active == True))
|
||||
return (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
def list_configs(self, active_only: bool = True) -> List[AgentConfig]:
|
||||
"""List all agent configurations."""
|
||||
query = self.db.query(AgentConfig)
|
||||
stmt = select(AgentConfig)
|
||||
if active_only:
|
||||
query = query.filter(AgentConfig.is_active == True)
|
||||
return query.order_by(AgentConfig.created_at.desc()).all()
|
||||
stmt = stmt.where(AgentConfig.is_active == True)
|
||||
stmt = stmt.order_by(AgentConfig.created_at.desc())
|
||||
return self.db.execute(stmt).scalars().all()
|
||||
|
||||
def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig:
|
||||
async def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig:
|
||||
"""Update agent configuration."""
|
||||
try:
|
||||
config = self.get_config(config_id)
|
||||
config = await self.get_config(config_id)
|
||||
if not config:
|
||||
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
|
||||
|
||||
# Check if name change conflicts with existing
|
||||
if "name" in config_data and config_data["name"] != config.name:
|
||||
existing = self.db.query(AgentConfig).filter(
|
||||
stmt = select(AgentConfig).where(
|
||||
and_(
|
||||
AgentConfig.name == config_data["name"],
|
||||
AgentConfig.id != config_id
|
||||
)
|
||||
).first()
|
||||
)
|
||||
existing = (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
if existing:
|
||||
raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")
|
||||
|
||||
|
|
@ -110,28 +107,29 @@ class AgentConfigService:
|
|||
|
||||
# If this is set as default, unset other defaults
|
||||
if config_data.get("is_default", False):
|
||||
self.db.query(AgentConfig).filter(
|
||||
stmt = update(AgentConfig).where(
|
||||
and_(
|
||||
AgentConfig.is_default == True,
|
||||
AgentConfig.id != config_id
|
||||
)
|
||||
).update({"is_default": False})
|
||||
).values({"is_default": False})
|
||||
await self.db.execute(stmt)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(config)
|
||||
|
||||
logger.info(f"Updated agent configuration: {config.name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
await self.db.rollback()
|
||||
logger.error(f"Error updating agent configuration: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete_config(self, config_id: int) -> bool:
|
||||
async def delete_config(self, config_id: int) -> bool:
|
||||
"""Delete agent configuration (soft delete by setting is_active=False)."""
|
||||
try:
|
||||
config = self.get_config(config_id)
|
||||
config = await self.get_config(config_id)
|
||||
if not config:
|
||||
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
|
||||
|
||||
|
|
@ -146,14 +144,14 @@ class AgentConfigService:
|
|||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
await self.db.rollback()
|
||||
logger.error(f"Error deleting agent configuration: {str(e)}")
|
||||
raise
|
||||
|
||||
def set_default_config(self, config_id: int) -> AgentConfig:
|
||||
async def set_default_config(self, config_id: int) -> AgentConfig:
|
||||
"""Set a configuration as default."""
|
||||
try:
|
||||
config = self.get_config(config_id)
|
||||
config = await self.get_config(config_id)
|
||||
if not config:
|
||||
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
|
||||
|
||||
|
|
@ -161,29 +159,28 @@ class AgentConfigService:
|
|||
raise ValidationError("Cannot set inactive configuration as default")
|
||||
|
||||
# Unset other defaults
|
||||
self.db.query(AgentConfig).filter(
|
||||
AgentConfig.is_default == True
|
||||
).update({"is_default": False})
|
||||
stmt = update(AgentConfig).where(AgentConfig.is_default == True).values({"is_default": False})
|
||||
await self.db.execute(stmt)
|
||||
|
||||
# Set this as default
|
||||
config.is_default = True
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(config)
|
||||
|
||||
logger.info(f"Set default agent configuration: {config.name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
await self.db.rollback()
|
||||
logger.error(f"Error setting default agent configuration: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
async def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""Get configuration as dictionary. If no ID provided, returns default config."""
|
||||
if config_id:
|
||||
config = self.get_config(config_id)
|
||||
config = await self.get_config(config_id)
|
||||
else:
|
||||
config = self.get_default_config()
|
||||
config = await self.get_default_config()
|
||||
|
||||
if not config:
|
||||
# Return default values if no configuration found
|
||||
|
|
|
|||
|
|
@ -33,19 +33,16 @@ class AuthService:
|
|||
)
|
||||
|
||||
token = credentials.credentials
|
||||
session.desc = f"[AuthService] 取得token: {token[:50]}..."
|
||||
payload = AuthService.verify_token(token)
|
||||
if payload is None:
|
||||
session.desc = "ERROR: 令牌验证失败"
|
||||
session.desc = f"ERROR: 令牌验证失败 - 令牌: {token[:50]}..."
|
||||
raise credentials_exception
|
||||
|
||||
session.desc = f"[AuthService] 令牌有效 - 解析得到有效载荷: {payload}"
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
session.desc = "ERROR: 令牌中没有用户名"
|
||||
raise credentials_exception
|
||||
|
||||
session.desc = f"[AuthService] 获取当前用户 - 查找名为 {username} 的用户"
|
||||
stmt = select(User).where(User.username == username)
|
||||
user = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if user is None:
|
||||
|
|
@ -53,8 +50,8 @@ class AuthService:
|
|||
raise credentials_exception
|
||||
|
||||
# Set user in context for global access
|
||||
UserContext.set_current_user(user)
|
||||
session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户"
|
||||
UserContext.set_current_user(user, canLog=True)
|
||||
# session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户"
|
||||
|
||||
return user
|
||||
|
||||
|
|
@ -138,6 +135,4 @@ class AuthService:
|
|||
except jwt.PyJWTError as e:
|
||||
logger.error(f"Token verification failed: {e}")
|
||||
logger.error(f"Token: {token[:50]}...")
|
||||
logger.error(f"Secret key: {settings.security.secret_key[:20]}...")
|
||||
logger.error(f"Algorithm: {settings.security.algorithm}")
|
||||
return None
|
||||
return None
|
||||
|
|
@ -1,45 +1,37 @@
|
|||
"""Chat service for AI model integration using LangChain."""
|
||||
|
||||
from th_agenter import db
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional, List, Dict, Any
|
||||
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from th_agenter.core.new_agent import new_agent, new_llm
|
||||
from ..core.config import settings
|
||||
from ..models.message import MessageRole
|
||||
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
|
||||
from utils.util_exceptions import ChatServiceError, OpenAIError
|
||||
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
|
||||
from .conversation import ConversationService
|
||||
from .langchain_chat import LangChainChatService
|
||||
from .knowledge_chat import KnowledgeChatService
|
||||
from .agent.agent_service import get_agent_service
|
||||
from .agent.langgraph_agent_service import get_langgraph_agent_service
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
class AgentState(TypedDict):
|
||||
messages: List[dict] # 存储对话消息(核心记忆)
|
||||
|
||||
class ChatService:
|
||||
"""Service for handling AI chat functionality using LangChain."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
|
||||
# Initialize LangChain chat service
|
||||
self.langchain_service = LangChainChatService(db)
|
||||
|
||||
# Initialize Knowledge chat service
|
||||
self.knowledge_service = KnowledgeChatService(db)
|
||||
|
||||
# Initialize Agent service with database session
|
||||
self.agent_service = get_agent_service(db)
|
||||
|
||||
# Initialize LangGraph Agent service with database session
|
||||
self.langgraph_service = get_langgraph_agent_service(db)
|
||||
|
||||
logger.info("ChatService initialized with LangChain backend and Agent support")
|
||||
|
||||
_checkpointer_initialized = False
|
||||
_conn_string = None
|
||||
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
conversation_id: int,
|
||||
|
|
@ -57,7 +49,7 @@ class ChatService:
|
|||
logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
|
||||
|
||||
# Use knowledge base chat service
|
||||
return await self.knowledge_service.chat_with_knowledge_base(
|
||||
return await self.knowledge_chat_service.chat_with_knowledge_base(
|
||||
conversation_id=conversation_id,
|
||||
message=message,
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
|
|
@ -69,29 +61,29 @@ class ChatService:
|
|||
logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
|
||||
|
||||
# Get conversation history for LangGraph agent
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError(f"Conversation {conversation_id} not found")
|
||||
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
messages = await self.conversation_service.get_conversation_messages(conversation_id)
|
||||
chat_history = [{
|
||||
"role": "user" if msg.role == MessageRole.USER else "assistant",
|
||||
"content": msg.content
|
||||
} for msg in messages]
|
||||
|
||||
# Use LangGraph agent service
|
||||
agent_result = await self.langgraph_service.chat(message, chat_history)
|
||||
agent_result = await self.langgraph_agent_service.chat(message, chat_history)
|
||||
|
||||
if agent_result["success"]:
|
||||
# Save user message
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Save assistant response
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=agent_result["response"],
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -114,11 +106,11 @@ class ChatService:
|
|||
logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
|
||||
|
||||
# Get conversation history for agent
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError(f"Conversation {conversation_id} not found")
|
||||
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
messages = await self.conversation_service.get_conversation_messages(conversation_id)
|
||||
chat_history = [{
|
||||
"role": "user" if msg.role == MessageRole.USER else "assistant",
|
||||
"content": msg.content
|
||||
|
|
@ -129,14 +121,14 @@ class ChatService:
|
|||
|
||||
if agent_result["success"]:
|
||||
# Save user message
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Save assistant response
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=agent_result["response"],
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -159,7 +151,7 @@ class ChatService:
|
|||
logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
|
||||
|
||||
# Delegate to LangChain service
|
||||
return await self.langchain_service.chat(
|
||||
return await self.langchain_chat_service.chat(
|
||||
conversation_id=conversation_id,
|
||||
message=message,
|
||||
stream=stream,
|
||||
|
|
@ -167,156 +159,12 @@ class ChatService:
|
|||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
conversation_id: int,
|
||||
message: str,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
use_agent: bool = False,
|
||||
use_langgraph: bool = False,
|
||||
use_knowledge_base: bool = False,
|
||||
knowledge_base_id: Optional[int] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base."""
|
||||
if use_knowledge_base and knowledge_base_id:
|
||||
logger.info(f"Processing streaming chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
|
||||
|
||||
# Use knowledge base chat service streaming
|
||||
async for content in self.knowledge_service.chat_stream_with_knowledge_base(
|
||||
conversation_id=conversation_id,
|
||||
message=message,
|
||||
knowledge_base_id=knowledge_base_id,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
):
|
||||
# Create stream chunk for compatibility with existing API
|
||||
stream_chunk = StreamChunk(
|
||||
content=content,
|
||||
role=MessageRole.ASSISTANT
|
||||
)
|
||||
yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
|
||||
elif use_langgraph:
|
||||
logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangGraph Agent")
|
||||
|
||||
# Get conversation history for LangGraph agent
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError(f"Conversation {conversation_id} not found")
|
||||
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
chat_history = [{
|
||||
"role": "user" if msg.role == MessageRole.USER else "assistant",
|
||||
"content": msg.content
|
||||
} for msg in messages]
|
||||
|
||||
# Save user message first
|
||||
user_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Use LangGraph agent service streaming
|
||||
full_response = ""
|
||||
intermediate_steps = []
|
||||
|
||||
async for chunk in self.langgraph_service.chat_stream(message, chat_history):
|
||||
if chunk["type"] == "response":
|
||||
full_response = chunk["content"]
|
||||
intermediate_steps = chunk.get("intermediate_steps", [])
|
||||
|
||||
# Return the chunk as-is to maintain type information
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
elif chunk["type"] == "error":
|
||||
# Return the chunk as-is to maintain type information
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
return
|
||||
else:
|
||||
# For other types (status, step, etc.), pass through
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
# Save assistant response
|
||||
if full_response:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=full_response,
|
||||
role=MessageRole.ASSISTANT,
|
||||
message_metadata={"intermediate_steps": intermediate_steps}
|
||||
)
|
||||
elif use_agent:
|
||||
logger.info(f"Processing streaming chat request for conversation {conversation_id} via Agent")
|
||||
|
||||
# Get conversation history for agent
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError(f"Conversation {conversation_id} not found")
|
||||
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
chat_history = [{
|
||||
"role": "user" if msg.role == MessageRole.USER else "assistant",
|
||||
"content": msg.content
|
||||
} for msg in messages]
|
||||
|
||||
# Save user message first
|
||||
user_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Use agent service streaming
|
||||
full_response = ""
|
||||
tool_calls = []
|
||||
|
||||
async for chunk in self.agent_service.chat_stream(message, chat_history):
|
||||
if chunk["type"] == "response":
|
||||
full_response = chunk["content"]
|
||||
tool_calls = chunk.get("tool_calls", [])
|
||||
|
||||
# Return the chunk as-is to maintain type information
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
elif chunk["type"] == "error":
|
||||
# Return the chunk as-is to maintain type information
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
return
|
||||
else:
|
||||
# For other types (status, tool_start, etc.), pass through
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
|
||||
# Save assistant response
|
||||
if full_response:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=full_response,
|
||||
role=MessageRole.ASSISTANT,
|
||||
message_metadata={"tool_calls": tool_calls}
|
||||
)
|
||||
else:
|
||||
logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangChain")
|
||||
|
||||
# Delegate to LangChain service and wrap response in JSON format
|
||||
async for content in self.langchain_service.chat_stream(
|
||||
conversation_id=conversation_id,
|
||||
message=message,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
):
|
||||
# Create stream chunk for compatibility with existing API
|
||||
stream_chunk = StreamChunk(
|
||||
content=content,
|
||||
role=MessageRole.ASSISTANT
|
||||
)
|
||||
yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
|
||||
|
||||
async def get_available_models(self) -> List[str]:
|
||||
"""Get list of available models from LangChain."""
|
||||
logger.info("Getting available models via LangChain")
|
||||
|
||||
# Delegate to LangChain service
|
||||
return await self.langchain_service.get_available_models()
|
||||
return await self.langchain_chat_service.get_available_models()
|
||||
|
||||
def update_model_config(
|
||||
self,
|
||||
|
|
@ -328,8 +176,135 @@ class ChatService:
|
|||
logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
|
||||
|
||||
# Delegate to LangChain service
|
||||
self.langchain_service.update_model_config(
|
||||
self.langchain_chat_service.update_model_config(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
)
|
||||
# -------------------------------------------------------------------------
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
async def initialize(self, conversation_id: int, streaming: bool = False):
|
||||
self.conversation_service = ConversationService(self.session)
|
||||
self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
|
||||
self.conversation = await self.conversation_service.get_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
if not self.conversation:
|
||||
raise ChatServiceError(f"Conversation {conversation_id} not found")
|
||||
|
||||
if not ChatService._checkpointer_initialized:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
import psycopg2
|
||||
CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
|
||||
ChatService._conn_string = CONN_STRING
|
||||
|
||||
# 检查必要的表是否已存在
|
||||
tables_need_setup = True
|
||||
try:
|
||||
# 连接到数据库并检查表是否存在
|
||||
conn = psycopg2.connect(CONN_STRING)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 检查langgraph需要的表是否存在
|
||||
cursor.execute("""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
|
||||
""")
|
||||
existing_tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
# 检查是否所有必要的表都存在
|
||||
required_tables = ['checkpoints', 'checkpoint_writes', 'checkpoint_blobs']
|
||||
if all(table in existing_tables for table in required_tables):
|
||||
tables_need_setup = False
|
||||
self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup"
|
||||
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)},将进行setup"
|
||||
tables_need_setup = True
|
||||
|
||||
# 只有在需要时才进行setup
|
||||
if tables_need_setup:
|
||||
self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
|
||||
try:
|
||||
async with AsyncPostgresSaver.from_conn_string(CONN_STRING) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
|
||||
logger.info("PostgresSaver setup完成")
|
||||
except Exception as e:
|
||||
self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
|
||||
logger.error(f"PostgresSaver setup失败: {e}")
|
||||
raise
|
||||
else:
|
||||
self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
|
||||
|
||||
# 存储连接字符串供后续使用
|
||||
ChatService._checkpointer_initialized = True
|
||||
|
||||
self.llm = await new_llm(session=self.session, streaming=streaming)
|
||||
self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": str(self.conversation.id),
|
||||
"checkpoint_ns": "drgraph"
|
||||
}
|
||||
}
|
||||
return config
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base."""
|
||||
self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
|
||||
await self.conversation_service.add_message(
|
||||
conversation_id=self.conversation.id,
|
||||
role=MessageRole.USER,
|
||||
content=message
|
||||
)
|
||||
full_assistant_content = ""
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
|
||||
from langchain.agents import create_agent
|
||||
agent = create_agent(
|
||||
model=self.llm, # await new_llm(session=self.session, streaming=self.streaming),
|
||||
checkpointer=checkpointer
|
||||
)
|
||||
async for chunk in agent.astream(
|
||||
{"messages": [{"role": "user", "content": message}]},
|
||||
config=self.get_config(),
|
||||
stream_mode="messages"
|
||||
):
|
||||
full_assistant_content += chunk[0].content
|
||||
json_result = {"data": {"v": chunk[0].content }}
|
||||
yield json.dumps(
|
||||
json_result,
|
||||
ensure_ascii=True
|
||||
)
|
||||
|
||||
if len(full_assistant_content) > 0:
|
||||
await self.conversation_service.add_message(
|
||||
conversation_id=self.conversation.id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=full_assistant_content
|
||||
)
|
||||
|
||||
def get_conversation_history_messages(
|
||||
self, conversation_id: int, skip: int = 0, limit: int = 100
|
||||
):
|
||||
"""Get conversation history messages with pagination."""
|
||||
result = []
|
||||
with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
|
||||
checkpoints = checkpointer.list(self.get_config())
|
||||
for checkpoint in checkpoints:
|
||||
print(checkpoint)
|
||||
result.append(checkpoint.messages)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, desc, func, or_
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
|
||||
from th_agenter.db.database import AsyncSessionFactory
|
||||
|
||||
from ..models.conversation import Conversation
|
||||
from ..models.message import Message, MessageRole
|
||||
|
|
@ -24,7 +27,7 @@ class ConversationService:
|
|||
conversation_data: ConversationCreate
|
||||
) -> Conversation:
|
||||
"""Create a new conversation."""
|
||||
logger.info(f"Creating new conversation for user {user_id}: {conversation_data}")
|
||||
self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"
|
||||
|
||||
try:
|
||||
conversation = Conversation(
|
||||
|
|
@ -39,18 +42,23 @@ class ConversationService:
|
|||
await self.session.commit()
|
||||
await self.session.refresh(conversation)
|
||||
|
||||
logger.info(f"Successfully created conversation {conversation.id} for user {user_id}")
|
||||
self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}"
|
||||
return conversation
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create conversation: {str(e)}", exc_info=True)
|
||||
self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Failed to create conversation: {str(e)}")
|
||||
raise DatabaseError(f"创建会话失败: {str(e)}")
|
||||
|
||||
async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
|
||||
"""Get a conversation by ID."""
|
||||
try:
|
||||
user_id = UserContext.get_current_user_id()
|
||||
self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}"
|
||||
if user_id is None:
|
||||
logger.error(f"Failed to get conversation {conversation_id}: No user context available")
|
||||
return None
|
||||
|
||||
conversation = await self.session.scalar(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
|
|
@ -59,12 +67,11 @@ class ConversationService:
|
|||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(f"Conversation {conversation_id} not found")
|
||||
|
||||
self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}"
|
||||
return conversation
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True)
|
||||
self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}"
|
||||
raise DatabaseError(f"Failed to get conversation: {str(e)}")
|
||||
|
||||
async def get_user_conversations(
|
||||
|
|
@ -78,6 +85,10 @@ class ConversationService:
|
|||
) -> List[Conversation]:
|
||||
"""Get user's conversations with search and filtering."""
|
||||
user_id = UserContext.get_current_user_id()
|
||||
if user_id is None:
|
||||
logger.error("Failed to get user conversations: No user context available")
|
||||
return []
|
||||
|
||||
query = select(Conversation).where(
|
||||
Conversation.user_id == user_id
|
||||
)
|
||||
|
|
@ -137,7 +148,7 @@ class ConversationService:
|
|||
if not conversation:
|
||||
return False
|
||||
|
||||
self.session.delete(conversation)
|
||||
await self.session.delete(conversation)
|
||||
await self.session.commit()
|
||||
return True
|
||||
|
||||
|
|
@ -180,19 +191,42 @@ class ConversationService:
|
|||
# Set audit fields
|
||||
message.set_audit_fields()
|
||||
|
||||
self.session.add(message)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(message)
|
||||
session = AsyncSessionFactory()
|
||||
session.begin()
|
||||
try:
|
||||
session.add(message)
|
||||
await session.commit()
|
||||
await session.refresh(message)
|
||||
|
||||
# Update conversation's updated_at timestamp
|
||||
conversation = await self.get_conversation(conversation_id)
|
||||
if conversation:
|
||||
conversation.updated_at = datetime.now(timezone.utc)
|
||||
conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
|
||||
await self.session.commit()
|
||||
# Update conversation's updated_at timestamp
|
||||
conversation = await self.get_conversation(conversation_id)
|
||||
if conversation:
|
||||
conversation.updated_at = datetime.now(timezone.utc)
|
||||
conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add message to conversation {conversation_id}: {str(e)}", exc_info=True)
|
||||
await session.rollback()
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
return message
|
||||
|
||||
async def get_conversation_history_messages(
|
||||
self,
|
||||
conversation_id: int,
|
||||
limit: int = 20
|
||||
) -> List[Message]:
|
||||
"""Get recent conversation history messages."""
|
||||
history = await self.get_conversation_history(conversation_id, limit)
|
||||
history_messages = []
|
||||
for message in history:
|
||||
if message.role == MessageRole.USER:
|
||||
history_messages.append(HumanMessage(content=message.content))
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
history_messages.append(AIMessage(content=message.content))
|
||||
return history_messages
|
||||
|
||||
async def get_conversation_history(
|
||||
self,
|
||||
conversation_id: int,
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class ConversationContextService:
|
|||
新创建的对话ID
|
||||
"""
|
||||
try:
|
||||
session = next(get_session())
|
||||
session = await anext(get_session())
|
||||
|
||||
conversation = Conversation(
|
||||
user_id=user_id,
|
||||
|
|
@ -37,8 +37,8 @@ class ConversationContextService:
|
|||
)
|
||||
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
session.refresh(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
|
||||
# 初始化对话上下文
|
||||
self.context_cache[conversation.id] = {
|
||||
|
|
@ -74,7 +74,7 @@ class ConversationContextService:
|
|||
|
||||
# 从数据库加载
|
||||
try:
|
||||
session = next(get_session())
|
||||
session = await anext(get_session())
|
||||
|
||||
conversation = session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id
|
||||
|
|
@ -194,7 +194,7 @@ class ConversationContextService:
|
|||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
session = next(get_session())
|
||||
session = await anext(get_session())
|
||||
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
|
|
@ -205,7 +205,7 @@ class ConversationContextService:
|
|||
)
|
||||
|
||||
session.add(message)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# 更新对话的最后更新时间
|
||||
conversation = session.query(Conversation).filter(
|
||||
|
|
@ -214,7 +214,7 @@ class ConversationContextService:
|
|||
|
||||
if conversation:
|
||||
conversation.updated_at = datetime.utcnow()
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -262,7 +262,7 @@ class ConversationContextService:
|
|||
消息历史列表
|
||||
"""
|
||||
try:
|
||||
session = next(get_session())
|
||||
session = await anext(get_session())
|
||||
|
||||
messages = session.query(Message).filter(
|
||||
Message.conversation_id == conversation_id
|
||||
|
|
|
|||
|
|
@ -102,41 +102,41 @@ class DatabaseConfigService:
|
|||
)
|
||||
|
||||
self.session.add(db_config)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_config)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(db_config)
|
||||
|
||||
logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
|
||||
return db_config
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
logger.error(f"创建数据库配置失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
|
||||
async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
|
||||
"""获取用户的数据库配置列表"""
|
||||
stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
|
||||
if active_only:
|
||||
stmt = stmt.where(DatabaseConfig.is_active == True)
|
||||
stmt = stmt.order_by(DatabaseConfig.created_at.desc())
|
||||
return self.session.scalars(stmt).all()
|
||||
return (await self.session.execute(stmt)).scalars().all()
|
||||
|
||||
def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
|
||||
async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
|
||||
"""根据ID获取配置"""
|
||||
stmt = select(DatabaseConfig).where(
|
||||
DatabaseConfig.id == config_id,
|
||||
DatabaseConfig.created_by == user_id
|
||||
)
|
||||
return self.session.scalar(stmt)
|
||||
return (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
|
||||
async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
|
||||
"""获取用户的默认配置"""
|
||||
stmt = select(DatabaseConfig).where(
|
||||
DatabaseConfig.created_by == user_id,
|
||||
# DatabaseConfig.is_default == True,
|
||||
DatabaseConfig.is_active == True
|
||||
)
|
||||
return self.session.scalar(stmt)
|
||||
return (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
|
||||
"""测试数据库连接"""
|
||||
|
|
@ -278,14 +278,14 @@ class DatabaseConfigService:
|
|||
'message': f'断开连接失败: {str(e)}'
|
||||
}
|
||||
|
||||
def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
|
||||
async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
|
||||
"""根据数据库类型获取用户配置"""
|
||||
stmt = select(DatabaseConfig).where(
|
||||
DatabaseConfig.created_by == user_id,
|
||||
DatabaseConfig.db_type == db_type,
|
||||
DatabaseConfig.is_active == True
|
||||
)
|
||||
return self.session.scalar(stmt)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
|
||||
"""创建或更新数据库配置(保证db_type唯一性)"""
|
||||
|
|
@ -301,8 +301,8 @@ class DatabaseConfigService:
|
|||
elif hasattr(existing_config, key):
|
||||
setattr(existing_config, key, value)
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(existing_config)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(existing_config)
|
||||
logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
|
||||
return existing_config
|
||||
else:
|
||||
|
|
@ -310,7 +310,7 @@ class DatabaseConfigService:
|
|||
return await self.create_config(user_id, config_data)
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
logger.error(f"创建或更新数据库配置失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class DocumentService:
|
|||
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
|
||||
# Validate knowledge base exists
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||||
kb = self.session.scalar(stmt)
|
||||
kb = await self.session.scalar(stmt)
|
||||
if not kb:
|
||||
self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise ValueError(f"知识库 {kb_id} 不存在")
|
||||
|
|
@ -48,6 +48,7 @@ class DocumentService:
|
|||
|
||||
# Upload file using storage service
|
||||
storage_info = await storage_service.upload_file(file, kb_id)
|
||||
self.session.desc = f"文档 {file.filename} 上传到 {storage_info}"
|
||||
|
||||
# Create document record
|
||||
document = Document(
|
||||
|
|
@ -65,21 +66,21 @@ class DocumentService:
|
|||
document.set_audit_fields()
|
||||
|
||||
self.session.add(document)
|
||||
self.session.commit()
|
||||
self.session.refresh(document)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(document)
|
||||
|
||||
self.session.desc = f"SUCCESS: 成功上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
|
||||
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
|
||||
return document
|
||||
|
||||
def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
|
||||
async def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
|
||||
"""根据文档ID查询文档,可选地根据知识库ID过滤。"""
|
||||
self.session.desc = f"查询文档 {doc_id}"
|
||||
self.session.desc = f"根据文档ID查询文档 {doc_id}"
|
||||
stmt = select(Document).where(Document.id == doc_id)
|
||||
if kb_id is not None:
|
||||
stmt = stmt.where(Document.knowledge_base_id == kb_id)
|
||||
return self.session.scalar(stmt)
|
||||
return await self.session.scalar(stmt)
|
||||
|
||||
def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
|
||||
async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
|
||||
"""根据知识库ID查询文档,支持分页。"""
|
||||
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
|
||||
stmt = (
|
||||
|
|
@ -88,14 +89,14 @@ class DocumentService:
|
|||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return self.session.scalars(stmt).all()
|
||||
return (await self.session.scalars(stmt)).all()
|
||||
|
||||
def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
|
||||
async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
|
||||
"""根据知识库ID查询文档,支持分页,并返回总文档数。"""
|
||||
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
|
||||
# Get total count
|
||||
count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
|
||||
total = self.session.scalar(count_stmt)
|
||||
total = await self.session.scalar(count_stmt)
|
||||
|
||||
# Get documents with pagination
|
||||
documents_stmt = (
|
||||
|
|
@ -104,44 +105,52 @@ class DocumentService:
|
|||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
documents = self.session.scalars(documents_stmt).all()
|
||||
documents = (await self.session.scalars(documents_stmt)).all()
|
||||
|
||||
return documents, total
|
||||
|
||||
def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
|
||||
async def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
|
||||
"""根据文档ID删除文档,可选地根据知识库ID过滤。"""
|
||||
self.session.desc = f"删除文档 {doc_id}"
|
||||
document = self.get_document(doc_id, kb_id)
|
||||
document = await self.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
return False
|
||||
|
||||
# Delete file from storage
|
||||
try:
|
||||
storage_service.delete_file(document.file_path)
|
||||
logger.info(f"Deleted file: {document.file_path}")
|
||||
await storage_service.delete_file(document.file_path)
|
||||
self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
|
||||
except Exception as e:
|
||||
self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
|
||||
|
||||
# TODO: Remove from vector database
|
||||
# This should be implemented when vector database service is ready
|
||||
get_document_processor().delete_document_from_vector_store(kb_id,doc_id)
|
||||
self.session.desc = f"从向量数据库删除文档 {doc_id}"
|
||||
(await get_document_processor(self.session)).delete_document_from_vector_store(kb_id,doc_id)
|
||||
# Delete database record
|
||||
self.session.delete(document)
|
||||
self.session.commit()
|
||||
self.session.desc = f"删除数据库记录 {doc_id}"
|
||||
await self.session.delete(document)
|
||||
await self.session.commit()
|
||||
self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
|
||||
return True
|
||||
|
||||
async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]:
|
||||
"""处理文档,提取文本并创建嵌入向量。"""
|
||||
try:
|
||||
self.session.desc = f"处理文档 {doc_id}"
|
||||
document = self.get_document(doc_id, kb_id)
|
||||
self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量"
|
||||
document = await self.get_document(doc_id, kb_id)
|
||||
self.session.desc = f"获取文档 {doc_id} >>> {document}"
|
||||
if not document:
|
||||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise ValueError(f"Document {doc_id} not found")
|
||||
|
||||
if document.is_processed:
|
||||
# document.file_path[为('C:\\DrGraph\\TH_Backend\\data\\uploads\\kb_1\\997eccbb-9081-4ddf-879e-bc7d781fab50_答辩.txt',) ,需要取第一个元素
|
||||
file_path = document.file_path
|
||||
knowledge_base_id=document.knowledge_base_id
|
||||
is_processed=document.is_processed
|
||||
|
||||
if is_processed:
|
||||
self.session.desc = f"INFO: 文档 {doc_id} 已处理"
|
||||
return {
|
||||
"document_id": doc_id,
|
||||
|
|
@ -149,38 +158,43 @@ class DocumentService:
|
|||
"message": "文档已处理"
|
||||
}
|
||||
|
||||
self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}"
|
||||
# 更新文档状态为处理中
|
||||
document.processing_error = None
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
self.session.desc = f"更新文档状态为处理中 {doc_id}"
|
||||
|
||||
# 调用文档处理器进行处理
|
||||
result = get_document_processor().process_document(
|
||||
document_processor = await get_document_processor(self.session)
|
||||
self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}"
|
||||
result = await document_processor.process_document(
|
||||
session=self.session,
|
||||
document_id=doc_id,
|
||||
file_path=document.file_path,
|
||||
knowledge_base_id=document.knowledge_base_id
|
||||
file_path=file_path,
|
||||
knowledge_base_id=knowledge_base_id
|
||||
)
|
||||
self.session.desc = f"SUCCESS: 成功处理文档 {doc_id}"
|
||||
self.session.desc = f"处理文档完毕 {doc_id}"
|
||||
|
||||
# 如果处理成功,更新文档状态
|
||||
if result["status"] == "success":
|
||||
document.is_processed = True
|
||||
document.chunk_count = result.get("chunks_count", 0)
|
||||
self.session.commit()
|
||||
self.session.refresh(document)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(document)
|
||||
logger.info(f"Processed document: {document.filename} (ID: {doc_id})")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
|
||||
|
||||
# Update document with error
|
||||
try:
|
||||
document = self.get_document(doc_id)
|
||||
document = await self.get_document(doc_id, kb_id)
|
||||
if document:
|
||||
document.processing_error = str(e)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
except Exception as db_error:
|
||||
logger.error(f"Failed to update document error status: {db_error}")
|
||||
|
||||
|
|
@ -217,10 +231,10 @@ class DocumentService:
|
|||
self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
|
||||
raise
|
||||
|
||||
def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
|
||||
async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
|
||||
"""更新文档处理状态。"""
|
||||
self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
|
||||
document = self.get_document(doc_id)
|
||||
document = await self.get_document(doc_id)
|
||||
if not document:
|
||||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
return False
|
||||
|
|
@ -228,16 +242,16 @@ class DocumentService:
|
|||
document.is_processed = is_processed
|
||||
document.processing_error = error
|
||||
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
|
||||
return True
|
||||
|
||||
def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""在知识库中搜索文档使用向量相似度。"""
|
||||
try:
|
||||
# 使用文档处理器进行相似性搜索
|
||||
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query}"
|
||||
results = get_document_processor().search_similar_documents(kb_id, query, limit)
|
||||
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}条"
|
||||
results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit)
|
||||
self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
|
||||
return results
|
||||
except Exception as e:
|
||||
|
|
@ -245,9 +259,9 @@ class DocumentService:
|
|||
logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
|
||||
return []
|
||||
|
||||
def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
|
||||
async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
|
||||
"""获取知识库中的文档统计信息。"""
|
||||
documents = self.get_documents(kb_id, limit=1000) # Get all documents
|
||||
documents = await self.get_documents(kb_id, limit=1000) # Get all documents
|
||||
|
||||
total_count = len(documents)
|
||||
processed_count = len([doc for doc in documents if doc.is_processed])
|
||||
|
|
@ -267,19 +281,21 @@ class DocumentService:
|
|||
"file_types": file_types
|
||||
}
|
||||
|
||||
def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
|
||||
async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
|
||||
"""获取特定文档的文档块。"""
|
||||
try:
|
||||
self.session.desc = f"获取文档 {doc_id} 的文档块"
|
||||
stmt = select(Document).where(Document.id == doc_id)
|
||||
document = self.session.scalar(stmt)
|
||||
document = await self.session.scalar(stmt)
|
||||
if not document:
|
||||
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
return []
|
||||
|
||||
self.session.desc = f"获取文档 {doc_id} 的文档块 > document"
|
||||
# Get chunks from document processor
|
||||
chunks_data = get_document_processor().get_document_chunks(document.knowledge_base_id, doc_id)
|
||||
chunks_data = (await get_document_processor(self.session)).get_document_chunks(document.knowledge_base_id, doc_id)
|
||||
|
||||
self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data"
|
||||
# Convert to DocumentChunk objects
|
||||
chunks = []
|
||||
for chunk_data in chunks_data:
|
||||
|
|
@ -290,7 +306,8 @@ class DocumentService:
|
|||
page_number=chunk_data.get("page_number"),
|
||||
chunk_index=chunk_data["chunk_index"],
|
||||
start_char=chunk_data.get("start_char"),
|
||||
end_char=chunk_data.get("end_char")
|
||||
end_char=chunk_data.get("end_char"),
|
||||
vector_id=chunk_data.get("vector_id")
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from requests import Session
|
||||
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
|
||||
from ..core.config import settings
|
||||
from loguru import logger
|
||||
|
|
@ -12,7 +13,8 @@ class EmbeddingFactory:
|
|||
"""Factory class for creating embedding instances based on provider."""
|
||||
|
||||
@staticmethod
|
||||
def create_embeddings(
|
||||
async def create_embeddings(
|
||||
session: Session = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
dimensions: Optional[int] = None
|
||||
|
|
@ -28,12 +30,12 @@ class EmbeddingFactory:
|
|||
Embeddings instance
|
||||
"""
|
||||
# 使用新的embedding配置
|
||||
embedding_config = settings.embedding.get_current_config()
|
||||
embedding_config = await settings.embedding.get_current_config(session)
|
||||
provider = provider or settings.embedding.provider
|
||||
model = model or embedding_config.get("model")
|
||||
dimensions = dimensions or settings.vector_db.embedding_dimension
|
||||
|
||||
logger.info(f"Creating embeddings with provider: {provider}, model: {model}")
|
||||
session.desc = f"创建嵌入模型: {provider}, {model}"
|
||||
|
||||
if provider == "openai":
|
||||
return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import os
|
|||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, func
|
||||
from ..models.excel_file import ExcelFile
|
||||
from ..db.database import get_session
|
||||
from loguru import logger
|
||||
|
|
@ -102,7 +103,7 @@ class ExcelMetadataService:
|
|||
'processing_error': str(e)
|
||||
}
|
||||
|
||||
def save_file_metadata(self, file_path: str, original_filename: str,
|
||||
async def save_file_metadata(self, file_path: str, original_filename: str,
|
||||
user_id: int, file_size: int) -> ExcelFile:
|
||||
"""Extract and save Excel file metadata to database."""
|
||||
try:
|
||||
|
|
@ -131,31 +132,30 @@ class ExcelMetadataService:
|
|||
|
||||
|
||||
# Save to database
|
||||
self.db.add(excel_file)
|
||||
self.db.commit()
|
||||
self.db.refresh(excel_file)
|
||||
self.session.add(excel_file)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(excel_file)
|
||||
|
||||
logger.info(f"Saved metadata for file {original_filename} with ID {excel_file.id}")
|
||||
return excel_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata for {original_filename}: {str(e)}")
|
||||
self.db.rollback()
|
||||
await self.session.rollback()
|
||||
raise
|
||||
|
||||
def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]:
|
||||
async def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]:
|
||||
"""Get Excel files for a user with pagination."""
|
||||
try:
|
||||
# Get total count
|
||||
total = self.db.query(ExcelFile).filter(ExcelFile.created_by == user_id).count()
|
||||
stmt = select(func.count(ExcelFile.id)).where(ExcelFile.created_by == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
total = result.scalar_one()
|
||||
|
||||
# Get files with pagination
|
||||
files = (self.db.query(ExcelFile)
|
||||
.filter(ExcelFile.created_by == user_id)
|
||||
.order_by(ExcelFile.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all())
|
||||
stmt = select(ExcelFile).where(ExcelFile.created_by == user_id).order_by(ExcelFile.created_at.desc()).offset(skip).limit(limit)
|
||||
result = await self.session.execute(stmt)
|
||||
files = result.scalars().all()
|
||||
|
||||
return files, total
|
||||
|
||||
|
|
@ -163,21 +163,21 @@ class ExcelMetadataService:
|
|||
logger.error(f"Error getting user files for user {user_id}: {str(e)}")
|
||||
return [], 0
|
||||
|
||||
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]:
|
||||
async def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]:
|
||||
"""Get Excel file by ID and user ID."""
|
||||
try:
|
||||
return (self.db.query(ExcelFile)
|
||||
.filter(ExcelFile.id == file_id, ExcelFile.created_by == user_id)
|
||||
.first())
|
||||
stmt = select(ExcelFile).where(ExcelFile.id == file_id, ExcelFile.created_by == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file {file_id} for user {user_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def delete_file(self, file_id: int, user_id: int) -> bool:
|
||||
async def delete_file(self, file_id: int, user_id: int) -> bool:
|
||||
"""Delete Excel file record and physical file."""
|
||||
try:
|
||||
# Get file record
|
||||
excel_file = self.get_file_by_id(file_id, user_id)
|
||||
excel_file = await self.get_file_by_id(file_id, user_id)
|
||||
if not excel_file:
|
||||
return False
|
||||
|
||||
|
|
@ -187,39 +187,40 @@ class ExcelMetadataService:
|
|||
logger.info(f"Deleted physical file: {excel_file.file_path}")
|
||||
|
||||
# Delete database record
|
||||
self.db.delete(excel_file)
|
||||
self.db.commit()
|
||||
await self.session.delete(excel_file)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info(f"Deleted Excel file record with ID {file_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting file {file_id}: {str(e)}")
|
||||
self.db.rollback()
|
||||
await self.session.rollback()
|
||||
return False
|
||||
|
||||
def update_last_accessed(self, file_id: int, user_id: int) -> bool:
|
||||
async def update_last_accessed(self, file_id: int, user_id: int) -> bool:
|
||||
"""Update last accessed time for a file."""
|
||||
try:
|
||||
excel_file = self.get_file_by_id(file_id, user_id)
|
||||
excel_file = await self.get_file_by_id(file_id, user_id)
|
||||
if not excel_file:
|
||||
return False
|
||||
|
||||
from sqlalchemy.sql import func
|
||||
excel_file.last_accessed = func.now()
|
||||
self.db.commit()
|
||||
await self.session.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating last accessed for file {file_id}: {str(e)}")
|
||||
self.db.rollback()
|
||||
await self.session.rollback()
|
||||
return False
|
||||
|
||||
def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
async def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get file summary information for LLM context."""
|
||||
try:
|
||||
files = self.db.query(ExcelFile).filter(ExcelFile.user_id == user_id).all()
|
||||
stmt = select(ExcelFile).where(ExcelFile.user_id == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
files = result.scalars().all()
|
||||
|
||||
summary = []
|
||||
for file in files:
|
||||
|
|
|
|||
|
|
@ -29,9 +29,11 @@ class KnowledgeBaseService:
|
|||
Args:
|
||||
session (Session): 数据库会话,用于执行ORM操作。
|
||||
"""
|
||||
if session is None:
|
||||
logger.error("session为空,session must be an instance of Session")
|
||||
self.session = session
|
||||
|
||||
def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
|
||||
async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
|
||||
"""创建一个新的知识库实例。
|
||||
|
||||
Args:
|
||||
|
|
@ -57,130 +59,22 @@ class KnowledgeBaseService:
|
|||
collection_name=collection_name
|
||||
)
|
||||
|
||||
# Set audit fields
|
||||
# 自动更新created_by和updated_by字段
|
||||
kb.set_audit_fields()
|
||||
|
||||
self.session.add(kb)
|
||||
self.session.commit()
|
||||
self.session.refresh(kb)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(kb)
|
||||
|
||||
logger.info(f"Created knowledge base: {kb.name} (ID: {kb.id})")
|
||||
self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}"
|
||||
return kb
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to create knowledge base: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
|
||||
"""根据ID获取知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 知识库实例的ID。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||||
return self.session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取当前用户的知识库实例。
|
||||
|
||||
Args:
|
||||
name (str): 知识库实例的名称。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(
|
||||
KnowledgeBase.name == name,
|
||||
KnowledgeBase.created_by == UserContext.get_current_user().id
|
||||
)
|
||||
return self.session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
|
||||
"""获取当前用户的所有知识库的列表。
|
||||
|
||||
Args:
|
||||
skip (int, optional): 跳过的记录数。默认值为0。
|
||||
limit (int, optional): 返回的最大记录数。默认值为50。
|
||||
active_only (bool, optional): 是否仅返回活动的知识库。默认值为True。
|
||||
|
||||
Returns:
|
||||
List[KnowledgeBase]: 当前用户的知识库列表。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user().id)
|
||||
|
||||
if active_only:
|
||||
stmt = stmt.where(KnowledgeBase.is_active == True)
|
||||
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
return (await self.session.execute(stmt)).scalars().all()
|
||||
|
||||
|
||||
def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
|
||||
"""更新知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 待更新的知识库实例ID。
|
||||
kb_update (KnowledgeBaseUpdate): 用于更新知识库实例的数据。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回更新后的知识库实例,否则返回None。
|
||||
|
||||
Raises:
|
||||
Exception: 如果更新过程中发生错误。
|
||||
"""
|
||||
try:
|
||||
kb = self.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
# Update fields
|
||||
update_data = kb_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(kb, field, value)
|
||||
|
||||
# Set audit fields
|
||||
kb.set_audit_fields(is_update=True)
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(kb)
|
||||
|
||||
self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
|
||||
return kb
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}"
|
||||
raise
|
||||
|
||||
def delete_knowledge_base(self, kb_id: int) -> bool:
|
||||
"""删除知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 待删除的知识库实例ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果知识库实例被成功删除则返回True,否则返回False。
|
||||
|
||||
Raises:
|
||||
Exception: 如果删除过程中发生错误。
|
||||
"""
|
||||
kb = self.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
return False
|
||||
|
||||
# TODO: Clean up vector database collection
|
||||
# This should be implemented when vector database service is ready
|
||||
|
||||
self.session.delete(kb)
|
||||
self.session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
|
||||
|
||||
async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
|
||||
"""Search knowledge bases by name or description for the current user.
|
||||
|
||||
Args:
|
||||
|
|
@ -192,7 +86,7 @@ class KnowledgeBaseService:
|
|||
List[KnowledgeBase]: List of matching knowledge bases.
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(
|
||||
KnowledgeBase.created_by == UserContext.get_current_user().id,
|
||||
KnowledgeBase.created_by == UserContext.get_current_user()['id'],
|
||||
KnowledgeBase.is_active == True,
|
||||
or_(
|
||||
KnowledgeBase.name.ilike(f"%{query}%"),
|
||||
|
|
@ -201,7 +95,7 @@ class KnowledgeBaseService:
|
|||
)
|
||||
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
return (await self.session.execute(stmt)).scalars().all()
|
||||
|
||||
async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
|
||||
"""Search in knowledge base using vector similarity.
|
||||
|
|
@ -219,7 +113,7 @@ class KnowledgeBaseService:
|
|||
logger.info(f"Searching in knowledge base {kb_id} for: {query}")
|
||||
|
||||
# Use document processor for vector search
|
||||
search_results = get_document_processor().search_similar_documents(
|
||||
search_results = (await get_document_processor(self.session)).search_similar_documents(
|
||||
knowledge_base_id=kb_id,
|
||||
query=query,
|
||||
k=top_k
|
||||
|
|
@ -246,4 +140,105 @@ class KnowledgeBaseService:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
|
||||
return []
|
||||
return []
|
||||
|
||||
# ----------------------------------------------------------------------------------
|
||||
async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取当前用户的知识库实例。
|
||||
|
||||
Args:
|
||||
name (str): 知识库实例的名称。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(
|
||||
KnowledgeBase.name == name,
|
||||
KnowledgeBase.created_by == UserContext.get_current_user()['id']
|
||||
)
|
||||
result = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
|
||||
"""获取当前用户的所有知识库的列表。
|
||||
|
||||
Args:
|
||||
skip (int, optional): 跳过的记录数。默认值为0。
|
||||
limit (int, optional): 返回的最大记录数。默认值为50。
|
||||
active_only (bool, optional): 是否仅返回活动的知识库。默认值为True。
|
||||
|
||||
Returns:
|
||||
List[KnowledgeBase]: 当前用户的知识库列表。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user()['id']) # 使用字典键索引访问用户ID
|
||||
if active_only:
|
||||
stmt = stmt.where(KnowledgeBase.is_active == True)
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
return (await self.session.execute(stmt)).scalars().all()
|
||||
|
||||
async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
|
||||
"""根据ID获取知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 知识库实例的ID。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。
|
||||
"""
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
|
||||
return (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
|
||||
"""更新知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 待更新的知识库实例ID。
|
||||
kb_update (KnowledgeBaseUpdate): 用于更新知识库实例的数据。
|
||||
|
||||
Returns:
|
||||
Optional[KnowledgeBase]: 如果找到则返回更新后的知识库实例,否则返回None。
|
||||
|
||||
Raises:
|
||||
Exception: 如果更新过程中发生错误。
|
||||
"""
|
||||
kb = await self.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
# Update fields
|
||||
update_data = kb_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(kb, field, value)
|
||||
|
||||
# Set audit fields
|
||||
kb.set_audit_fields(is_update=True)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(kb)
|
||||
|
||||
self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
|
||||
return kb
|
||||
|
||||
async def delete_knowledge_base(self, kb_id: int) -> bool:
|
||||
"""删除知识库实例。
|
||||
|
||||
Args:
|
||||
kb_id (int): 待删除的知识库实例ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果知识库实例被成功删除则返回True,否则返回False。
|
||||
|
||||
Raises:
|
||||
Exception: 如果删除过程中发生错误。
|
||||
"""
|
||||
kb = await self.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
return False
|
||||
|
||||
# TODO: Clean up vector database collection
|
||||
# This should be implemented when vector database service is ready
|
||||
|
||||
await self.session.delete(kb)
|
||||
await self.session.commit()
|
||||
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_postgres import PGVector
|
||||
from .embedding_factory import EmbeddingFactory
|
||||
|
||||
|
|
@ -21,16 +21,16 @@ from .conversation import ConversationService
|
|||
from .document_processor import get_document_processor
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class KnowledgeChatService:
|
||||
"""Knowledge base chat service using LangChain RAG."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.conversation_service = ConversationService(session)
|
||||
|
||||
async def initialize(self):
|
||||
# 获取当前LLM配置
|
||||
llm_config = settings.llm.get_current_config()
|
||||
llm_config = await settings.llm.get_current_config(self.session)
|
||||
|
||||
# Initialize LangChain ChatOpenAI
|
||||
self.llm = ChatOpenAI(
|
||||
|
|
@ -53,41 +53,23 @@ class KnowledgeChatService:
|
|||
)
|
||||
|
||||
# Initialize embeddings based on provider
|
||||
self.embeddings = EmbeddingFactory.create_embeddings()
|
||||
|
||||
logger.info(f"Knowledge Chat Service initialized with model: {self.llm.model_name}")
|
||||
|
||||
def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
|
||||
self.embeddings = await EmbeddingFactory.create_embeddings(self.session)
|
||||
async def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
|
||||
"""Get vector store for knowledge base."""
|
||||
try:
|
||||
if settings.vector_db.type == "pgvector":
|
||||
# 使用PGVector
|
||||
doc_processor = get_document_processor()
|
||||
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
||||
|
||||
vector_store = PGVector(
|
||||
connection=doc_processor.connection_string,
|
||||
embeddings=self.embeddings,
|
||||
collection_name=collection_name,
|
||||
use_jsonb=True
|
||||
)
|
||||
|
||||
return vector_store
|
||||
else:
|
||||
# 兼容Chroma模式
|
||||
import os
|
||||
kb_vector_path = os.path.join(get_document_processor().vector_db_path, f"kb_{knowledge_base_id}")
|
||||
|
||||
if not os.path.exists(kb_vector_path):
|
||||
logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
|
||||
return None
|
||||
|
||||
vector_store = Chroma(
|
||||
persist_directory=kb_vector_path,
|
||||
embedding_function=self.embeddings
|
||||
)
|
||||
|
||||
return vector_store
|
||||
import os
|
||||
kb_vector_path = os.path.join((await get_document_processor(self.session)).vector_db_path, f"kb_{knowledge_base_id}")
|
||||
|
||||
if not os.path.exists(kb_vector_path):
|
||||
logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
|
||||
return None
|
||||
|
||||
vector_store = Chroma(
|
||||
persist_directory=kb_vector_path,
|
||||
embedding_function=self.embeddings
|
||||
)
|
||||
|
||||
return vector_store
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load vector store for KB {knowledge_base_id}: {str(e)}")
|
||||
|
|
@ -159,7 +141,7 @@ class KnowledgeChatService:
|
|||
|
||||
try:
|
||||
# Get conversation and validate
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError("Conversation not found")
|
||||
|
||||
|
|
@ -169,14 +151,14 @@ class KnowledgeChatService:
|
|||
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
|
||||
|
||||
# Save user message
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Get conversation history
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
messages = await self.conversation_service.get_conversation_messages(conversation_id)
|
||||
conversation_history = self._prepare_conversation_history(messages)
|
||||
|
||||
# Create RAG chain
|
||||
|
|
@ -203,7 +185,7 @@ class KnowledgeChatService:
|
|||
response_content = await asyncio.to_thread(rag_chain.invoke, message)
|
||||
|
||||
# Save assistant message with context
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=response_content,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -251,14 +233,14 @@ class KnowledgeChatService:
|
|||
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
|
||||
|
||||
# Get conversation history
|
||||
messages = self.conversation_service.get_conversation_messages(conversation_id)
|
||||
messages = await self.conversation_service.get_conversation_messages(conversation_id)
|
||||
conversation_history = self._prepare_conversation_history(messages)
|
||||
|
||||
# Create RAG chain
|
||||
rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)
|
||||
|
||||
# Save user message
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
|
|
@ -269,7 +251,7 @@ class KnowledgeChatService:
|
|||
context = "\n\n".join([doc.page_content for doc in relevant_docs])
|
||||
|
||||
# Create streaming LLM
|
||||
llm_config = settings.llm.get_current_config()
|
||||
llm_config = await settings.llm.get_current_config()
|
||||
streaming_llm = ChatOpenAI(
|
||||
model=llm_config["model"],
|
||||
temperature=temperature or llm_config["temperature"],
|
||||
|
|
@ -315,7 +297,7 @@ class KnowledgeChatService:
|
|||
|
||||
# Save assistant response
|
||||
if full_response:
|
||||
self.conversation_service.add_message(
|
||||
await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=full_response,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -331,7 +313,7 @@ class KnowledgeChatService:
|
|||
yield error_message
|
||||
|
||||
# Save error message
|
||||
self.conversation_service.add_message(
|
||||
await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=error_message,
|
||||
role=MessageRole.ASSISTANT
|
||||
|
|
|
|||
|
|
@ -41,28 +41,27 @@ class StreamingCallbackHandler(BaseCallbackHandler):
|
|||
class LangChainChatService:
|
||||
"""LangChain-based chat service for AI model integration."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.conversation_service = ConversationService(session)
|
||||
|
||||
from ..core.llm import create_llm
|
||||
async def initialize(self):
|
||||
from ..core.new_agent import new_agent
|
||||
|
||||
# 添加调试日志
|
||||
logger.info(f"LLM Provider: {settings.llm.provider}")
|
||||
# Initialize LangChain ChatOpenAI
|
||||
self.llm = await new_agent(self.session, streaming=False)
|
||||
self.session.desc = "LangChainChatService初始化 - llm 实例化完毕"
|
||||
|
||||
# Initialize LangChain ChatOpenAI
|
||||
self.llm = create_llm(streaming=False)
|
||||
|
||||
# Streaming LLM for stream responses
|
||||
self.streaming_llm = create_llm(streaming=True)
|
||||
# Streaming LLM for stream responses
|
||||
self.streaming_llm = await new_agent(self.session, streaming=True)
|
||||
self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕"
|
||||
|
||||
self.streaming_handler = StreamingCallbackHandler()
|
||||
|
||||
logger.info(f"LangChain ChatService initialized with model: {self.llm.model_name}")
|
||||
self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕"
|
||||
|
||||
def _prepare_langchain_messages(self, conversation, history: List) -> List:
|
||||
"""Prepare messages for LangChain format."""
|
||||
messages = []
|
||||
messages = []
|
||||
|
||||
# Add system message if conversation has system prompt
|
||||
if hasattr(conversation, 'system_prompt') and conversation.system_prompt:
|
||||
|
|
@ -101,19 +100,19 @@ class LangChainChatService:
|
|||
|
||||
try:
|
||||
# Get conversation details
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError("Conversation not found")
|
||||
|
||||
# Add user message to database
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Get conversation history for context
|
||||
history = self.conversation_service.get_conversation_history(
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id, limit=20
|
||||
)
|
||||
|
||||
|
|
@ -123,7 +122,7 @@ class LangChainChatService:
|
|||
# Update LLM parameters if provided
|
||||
llm_to_use = self.llm
|
||||
if temperature is not None or max_tokens is not None:
|
||||
llm_config = settings.llm.get_current_config()
|
||||
llm_config = await settings.llm.get_current_config()
|
||||
llm_to_use = ChatOpenAI(
|
||||
model=llm_config["model"],
|
||||
openai_api_key=llm_config["api_key"],
|
||||
|
|
@ -140,7 +139,7 @@ class LangChainChatService:
|
|||
assistant_content = response.content
|
||||
|
||||
# Add assistant message to database
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=assistant_content,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -152,7 +151,7 @@ class LangChainChatService:
|
|||
)
|
||||
|
||||
# Update conversation timestamp
|
||||
self.conversation_service.update_conversation_timestamp(conversation_id)
|
||||
await self.conversation_service.update_conversation_timestamp(conversation_id)
|
||||
|
||||
logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
|
||||
|
||||
|
|
@ -171,7 +170,7 @@ class LangChainChatService:
|
|||
error_message = self._format_error_message(e)
|
||||
|
||||
# Add error message to database
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=error_message,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -206,33 +205,33 @@ class LangChainChatService:
|
|||
max_tokens: Optional[int] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Send a message and get streaming AI response using LangChain."""
|
||||
logger.info(f"Processing LangChain streaming chat request for conversation {conversation_id}")
|
||||
logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}")
|
||||
|
||||
try:
|
||||
# Get conversation details
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
conversation = await self.conversation_service.get_conversation(conversation_id)
|
||||
conv = conversation.to_dict()
|
||||
if not conversation:
|
||||
raise ChatServiceError("Conversation not found")
|
||||
|
||||
# Add user message to database
|
||||
user_message = self.conversation_service.add_message(
|
||||
user_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Get conversation history for context
|
||||
history = self.conversation_service.get_conversation_history(
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id, limit=20
|
||||
)
|
||||
|
||||
# Prepare messages for LangChain
|
||||
langchain_messages = self._prepare_langchain_messages(conversation, history)
|
||||
|
||||
langchain_messages = self._prepare_langchain_messages(conv, history)
|
||||
# Update streaming LLM parameters if provided
|
||||
streaming_llm_to_use = self.streaming_llm
|
||||
if temperature is not None or max_tokens is not None:
|
||||
llm_config = settings.llm.get_current_config()
|
||||
llm_config = await settings.llm.get_current_config()
|
||||
streaming_llm_to_use = ChatOpenAI(
|
||||
model=llm_config["model"],
|
||||
openai_api_key=llm_config["api_key"],
|
||||
|
|
@ -241,33 +240,35 @@ class LangChainChatService:
|
|||
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# Clear previous streaming handler state
|
||||
self.streaming_handler.clear()
|
||||
|
||||
# Stream response
|
||||
# Stream response
|
||||
full_response = ""
|
||||
async for chunk in streaming_llm_to_use.astream(langchain_messages):
|
||||
# Handle different chunk types to avoid KeyError
|
||||
chunk_content = None
|
||||
if hasattr(chunk, 'content'):
|
||||
# For object-like chunks with content attribute
|
||||
chunk_content = chunk.content
|
||||
elif isinstance(chunk, dict) and 'content' in chunk:
|
||||
# For dict-like chunks with content key
|
||||
chunk_content = chunk['content']
|
||||
elif isinstance(chunk, dict) and 'error' in chunk:
|
||||
# Handle error chunks explicitly
|
||||
logger.error(f"Error in LLM response: {chunk['error']}")
|
||||
yield self._format_error_message(Exception(chunk['error']))
|
||||
continue
|
||||
|
||||
if chunk_content:
|
||||
full_response += chunk_content
|
||||
yield chunk_content
|
||||
|
||||
try:
|
||||
async for chunk in streaming_llm_to_use._astream(langchain_messages):
|
||||
# Handle different chunk types to avoid KeyError
|
||||
chunk_content = None
|
||||
if hasattr(chunk, 'content'):
|
||||
# For object-like chunks with content attribute
|
||||
chunk_content = chunk.content
|
||||
elif isinstance(chunk, dict) and 'content' in chunk:
|
||||
# For dict-like chunks with content key
|
||||
chunk_content = chunk['content']
|
||||
elif isinstance(chunk, dict) and 'error' in chunk:
|
||||
# Handle error chunks explicitly
|
||||
logger.error(f"Error in LLM response: {chunk['error']}")
|
||||
yield self._format_error_message(Exception(chunk['error']))
|
||||
continue
|
||||
|
||||
if chunk_content:
|
||||
full_response += chunk_content
|
||||
yield chunk_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM streaming: {e}")
|
||||
yield f"{self._format_error_message(e)} >>> {e}"
|
||||
# Add complete assistant message to database
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
assistant_message = await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=full_response,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -280,13 +281,12 @@ class LangChainChatService:
|
|||
)
|
||||
|
||||
# Update conversation timestamp
|
||||
self.conversation_service.update_conversation_timestamp(conversation_id)
|
||||
|
||||
logger.info(f"Successfully processed LangChain streaming chat request for conversation {conversation_id}")
|
||||
await self.conversation_service.update_conversation_timestamp(conversation_id)
|
||||
logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}")
|
||||
|
||||
except Exception as e:
|
||||
# 安全地格式化异常信息,避免再次引发KeyError
|
||||
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id}"
|
||||
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
|
||||
logger.error(error_info, exc_info=True)
|
||||
|
||||
# Format error message for user
|
||||
|
|
@ -294,7 +294,7 @@ class LangChainChatService:
|
|||
yield error_message
|
||||
|
||||
# Add error message to database
|
||||
self.conversation_service.add_message(
|
||||
await self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=error_message,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -324,23 +324,23 @@ class LangChainChatService:
|
|||
logger.error(f"Failed to get available models: {str(e)}")
|
||||
return ["gpt-3.5-turbo"]
|
||||
|
||||
def update_model_config(
|
||||
async def update_model_config(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
):
|
||||
"""Update LLM configuration."""
|
||||
from ..core.llm import create_llm
|
||||
from ..core.new_agent import new_agent
|
||||
|
||||
# 重新创建LLM实例
|
||||
self.llm = create_llm(
|
||||
self.llm = await new_agent(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
streaming=False
|
||||
)
|
||||
|
||||
self.streaming_llm = create_llm(
|
||||
self.streaming_llm = await new_agent(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
streaming=True
|
||||
|
|
@ -357,7 +357,7 @@ class LangChainChatService:
|
|||
if "rate limit" in error_str.lower():
|
||||
return "服务器繁忙,请稍后再试。"
|
||||
elif "api key" in error_str.lower() or "authentication" in error_str.lower():
|
||||
return "API认证失败,请检查配置。"
|
||||
return f"API认证失败,请检查配置文件。"
|
||||
elif "timeout" in error_str.lower():
|
||||
return "请求超时,请重试。"
|
||||
elif "connection" in error_str.lower():
|
||||
|
|
|
|||
|
|
@ -11,11 +11,9 @@ from loguru import logger
|
|||
class LLMConfigService:
|
||||
"""LLM配置管理服务"""
|
||||
|
||||
def __init__(self, db_session: Optional[Session] = None):
|
||||
self.db = db_session or get_session() # TODO DrGraph:检查异步
|
||||
|
||||
def get_default_chat_config(self) -> Optional[LLMConfig]:
|
||||
async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
|
||||
"""获取默认对话模型配置"""
|
||||
# async for session in get_session():
|
||||
try:
|
||||
stmt = select(LLMConfig).where(
|
||||
and_(
|
||||
|
|
@ -24,7 +22,7 @@ class LLMConfigService:
|
|||
LLMConfig.is_active == True
|
||||
)
|
||||
)
|
||||
config = self.db.execute(stmt).scalar_one_or_none()
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
logger.warning("未找到默认对话模型配置")
|
||||
|
|
@ -36,7 +34,7 @@ class LLMConfigService:
|
|||
logger.error(f"获取默认对话模型配置失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_default_embedding_config(self) -> Optional[LLMConfig]:
|
||||
async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
|
||||
"""获取默认嵌入模型配置"""
|
||||
try:
|
||||
stmt = select(LLMConfig).where(
|
||||
|
|
@ -46,23 +44,27 @@ class LLMConfigService:
|
|||
LLMConfig.is_active == True
|
||||
)
|
||||
)
|
||||
config = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
config = None
|
||||
if session != None:
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not config:
|
||||
logger.warning("未找到默认嵌入模型配置")
|
||||
if session != None:
|
||||
session.desc = "ERROR: 未找到默认嵌入模型配置"
|
||||
return None
|
||||
|
||||
session.desc = f"获取默认嵌入模型配置 > 结果:{config}"
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取默认嵌入模型配置失败: {str(e)}")
|
||||
if session != None:
|
||||
session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}"
|
||||
return None
|
||||
|
||||
def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
|
||||
async def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
|
||||
"""根据ID获取配置"""
|
||||
try:
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
return (await self.db.execute(stmt)).scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置失败: {str(e)}")
|
||||
return None
|
||||
|
|
@ -82,17 +84,17 @@ class LLMConfigService:
|
|||
logger.error(f"获取激活配置失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _get_fallback_chat_config(self) -> Dict[str, Any]:
|
||||
async def _get_fallback_chat_config(self) -> Dict[str, Any]:
|
||||
"""获取fallback对话模型配置(从环境变量)"""
|
||||
from ..core.config import get_settings
|
||||
settings = get_settings()
|
||||
return settings.llm.get_current_config()
|
||||
return await settings.llm.get_current_config()
|
||||
|
||||
def _get_fallback_embedding_config(self) -> Dict[str, Any]:
|
||||
async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
|
||||
"""获取fallback嵌入模型配置(从环境变量)"""
|
||||
from ..core.config import get_settings
|
||||
settings = get_settings()
|
||||
return settings.embedding.get_current_config()
|
||||
return await settings.embedding.get_current_config()
|
||||
|
||||
def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]:
|
||||
"""测试配置连接"""
|
||||
|
|
@ -110,12 +112,12 @@ class LLMConfigService:
|
|||
logger.error(f"测试配置失败: {str(e)}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
# 全局实例
|
||||
_llm_config_service = None
|
||||
# # 全局实例
|
||||
# _llm_config_service = None
|
||||
|
||||
def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
|
||||
"""获取LLM配置服务实例"""
|
||||
global _llm_config_service
|
||||
if _llm_config_service is None or db_session is not None:
|
||||
_llm_config_service = LLMConfigService(db_session)
|
||||
return _llm_config_service
|
||||
# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
|
||||
# """获取LLM配置服务实例"""
|
||||
# global _llm_config_service
|
||||
# if _llm_config_service is None or db_session is not None:
|
||||
# _llm_config_service = LLMConfigService(db_session)
|
||||
# return _llm_config_service
|
||||
|
|
@ -49,8 +49,9 @@ class SmartDatabaseWorkflowManager:
|
|||
self.db = db
|
||||
self.table_metadata_service = TableMetadataService(db) if db else None
|
||||
|
||||
from ..core.llm import create_llm
|
||||
self.llm = create_llm()
|
||||
async def initialize(self):
|
||||
from ..core.new_agent import new_agent
|
||||
self.llm = await new_agent()
|
||||
|
||||
def _get_database_tool(self, db_type: str):
|
||||
"""根据数据库类型获取对应的数据库工具"""
|
||||
|
|
|
|||
|
|
@ -54,9 +54,10 @@ class SmartExcelWorkflowManager:
|
|||
else:
|
||||
self.metadata_service = None
|
||||
|
||||
from ..core.llm import create_llm
|
||||
async def initialize(self):
|
||||
from ..core.new_agent import new_agent
|
||||
# 禁用流式响应,避免pandas代理兼容性问题
|
||||
self.llm = create_llm(streaming=False)
|
||||
self.llm = await new_agent(streaming=False)
|
||||
|
||||
async def _run_in_executor(self, func, *args):
|
||||
"""在线程池中运行阻塞函数"""
|
||||
|
|
|
|||
|
|
@ -15,8 +15,12 @@ class SmartWorkflowManager:
|
|||
|
||||
def __init__(self, db=None):
|
||||
self.db = db
|
||||
self.excel_workflow = SmartExcelWorkflowManager(db)
|
||||
self.database_workflow = SmartDatabaseWorkflowManager(db)
|
||||
|
||||
async def initialize(self):
|
||||
self.excel_workflow = SmartExcelWorkflowManager(self.db)
|
||||
await self.excel_workflow.initialize()
|
||||
self.database_workflow = SmartDatabaseWorkflowManager(self.db)
|
||||
await self.database_workflow.initialize()
|
||||
|
||||
async def process_excel_query_stream(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class TableMetadataService:
|
|||
DatabaseConfig.id == database_config_id,
|
||||
DatabaseConfig.created_by == user_id
|
||||
)
|
||||
db_config = self.session.scalar_one_or_none(stmt)
|
||||
db_config = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not db_config:
|
||||
self.session.desc = "ERROR: 数据库配置不存在"
|
||||
|
|
@ -202,7 +202,7 @@ class TableMetadataService:
|
|||
TableMetadata.database_config_id == database_config_id,
|
||||
TableMetadata.table_name == table_name
|
||||
)
|
||||
existing = self.session.scalar_one_or_none(stmt)
|
||||
existing = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
|
|
@ -216,8 +216,8 @@ class TableMetadataService:
|
|||
existing.table_comment = metadata['table_comment']
|
||||
existing.last_synced_at = datetime.utcnow()
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(existing)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
|
|
@ -240,8 +240,8 @@ class TableMetadataService:
|
|||
)
|
||||
|
||||
self.session.add(table_metadata)
|
||||
self.session.commit()
|
||||
self.session.refresh(table_metadata)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(table_metadata)
|
||||
self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
return table_metadata
|
||||
|
||||
|
|
@ -258,7 +258,7 @@ class TableMetadataService:
|
|||
DatabaseConfig.id == database_config_id,
|
||||
DatabaseConfig.user_id == user_id
|
||||
)
|
||||
db_config = self.session.scalar_one_or_none(stmt)
|
||||
db_config = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not db_config:
|
||||
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
|
||||
|
|
@ -275,7 +275,7 @@ class TableMetadataService:
|
|||
TableMetadata.database_config_id == database_config_id,
|
||||
TableMetadata.table_name == table_name
|
||||
)
|
||||
existing = self.session.scalar_one_or_none(stmt)
|
||||
existing = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置"
|
||||
|
|
@ -320,7 +320,7 @@ class TableMetadataService:
|
|||
})
|
||||
|
||||
# 提交事务
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_names} 的元数据配置"
|
||||
return {
|
||||
'saved_tables': saved_tables,
|
||||
|
|
@ -330,7 +330,7 @@ class TableMetadataService:
|
|||
}
|
||||
|
||||
|
||||
def get_user_table_metadata(
|
||||
async def get_user_table_metadata(
|
||||
self,
|
||||
user_id: int,
|
||||
database_config_id: Optional[int] = None
|
||||
|
|
@ -345,9 +345,9 @@ class TableMetadataService:
|
|||
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
|
||||
raise NotFoundError("数据库配置不存在")
|
||||
stmt = stmt.where(TableMetadata.is_enabled_for_qa == True)
|
||||
return self.session.scalars(stmt).all()
|
||||
return (await self.session.scalars(stmt)).all()
|
||||
|
||||
def get_table_metadata_by_name(
|
||||
async def get_table_metadata_by_name(
|
||||
self,
|
||||
user_id: int,
|
||||
database_config_id: int,
|
||||
|
|
@ -360,9 +360,9 @@ class TableMetadataService:
|
|||
TableMetadata.database_config_id == database_config_id,
|
||||
TableMetadata.table_name == table_name
|
||||
)
|
||||
return self.session.scalar_one_or_none(stmt)
|
||||
return (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
def update_table_qa_settings(
|
||||
async def update_table_qa_settings(
|
||||
self,
|
||||
user_id: int,
|
||||
metadata_id: int,
|
||||
|
|
@ -375,7 +375,7 @@ class TableMetadataService:
|
|||
TableMetadata.id == metadata_id,
|
||||
TableMetadata.created_by == user_id
|
||||
)
|
||||
metadata = self.session.scalar_one_or_none(stmt)
|
||||
metadata = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not metadata:
|
||||
self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在"
|
||||
|
|
@ -388,15 +388,15 @@ class TableMetadataService:
|
|||
if 'business_context' in settings:
|
||||
metadata.business_context = settings['business_context']
|
||||
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}"
|
||||
self.session.rollback()
|
||||
await self.session.rollback()
|
||||
return False
|
||||
|
||||
def save_table_metadata(
|
||||
async def save_table_metadata(
|
||||
self,
|
||||
user_id: int,
|
||||
database_config_id: int,
|
||||
|
|
@ -414,7 +414,7 @@ class TableMetadataService:
|
|||
TableMetadata.database_config_id == database_config_id,
|
||||
TableMetadata.table_name == table_name
|
||||
)
|
||||
existing = self.session.scalar_one_or_none(stmt)
|
||||
existing = (await self.session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 已存在,更新其元数据"
|
||||
|
|
@ -424,7 +424,7 @@ class TableMetadataService:
|
|||
existing.row_count = row_count
|
||||
existing.table_comment = table_comment
|
||||
existing.last_synced_at = datetime.utcnow()
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
return existing
|
||||
else:
|
||||
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 不存在,创建新记录"
|
||||
|
|
@ -444,8 +444,8 @@ class TableMetadataService:
|
|||
)
|
||||
|
||||
self.session.add(metadata)
|
||||
self.session.commit()
|
||||
self.session.refresh(metadata)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(metadata)
|
||||
return metadata
|
||||
|
||||
def _decrypt_password(self, encrypted_password: str) -> str:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langchain.tools import BaseTool
|
|||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Optional, Type, ClassVar
|
||||
from langchain_tavily import TavilySearch
|
||||
|
||||
# 定义输入参数模型(替代原get_parameters())
|
||||
class SearchInput(BaseModel):
|
||||
|
|
@ -35,7 +36,7 @@ class TavilySearchTool(BaseTool):
|
|||
raise ValueError("Tavily API key not found in settings")
|
||||
|
||||
# 初始化Tavily客户端
|
||||
self._search_client = TavilySearchResults(
|
||||
self._search_client = TavilySearch(
|
||||
tavily_api_key=self._tavily_api_key
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -126,10 +126,10 @@ class UserService:
|
|||
self.session.desc = f"用户ID为{user_id}的密码已成功更改"
|
||||
return True
|
||||
|
||||
def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK
|
||||
async def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK
|
||||
"""Reset user password (admin only, no current password required)."""
|
||||
self.session.desc = f"重置用户ID为{user_id}的密码"
|
||||
user = self.get_user_by_id(user_id)
|
||||
user = await self.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise ValidationError("User not found")
|
||||
|
||||
|
|
@ -141,7 +141,7 @@ class UserService:
|
|||
hashed_password = self.get_password_hash(new_password)
|
||||
# Update password
|
||||
user.hashed_password = hashed_password
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
self.session.desc = f"用户ID为{user_id}的密码已成功重置"
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ class WorkflowEngine:
|
|||
execution.set_audit_fields(user_id)
|
||||
|
||||
self.session.add(execution)
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(execution)
|
||||
|
||||
|
||||
|
||||
|
|
@ -75,8 +75,8 @@ class WorkflowEngine:
|
|||
|
||||
|
||||
execution.set_audit_fields(user_id, is_update=True)
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(execution)
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
|
||||
|
|
@ -100,8 +100,8 @@ class WorkflowEngine:
|
|||
execution.set_audit_fields(user_id)
|
||||
|
||||
self.session.add(execution)
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(execution)
|
||||
|
||||
# 发送工作流开始执行的消息
|
||||
yield {
|
||||
|
|
@ -170,11 +170,8 @@ class WorkflowEngine:
|
|||
}
|
||||
|
||||
execution.set_audit_fields(user_id, is_update=True)
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(execution)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(execution)
|
||||
|
||||
def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""构建节点依赖图"""
|
||||
|
|
@ -385,8 +382,8 @@ class WorkflowEngine:
|
|||
started_at=datetime.now().isoformat()
|
||||
)
|
||||
self.session.add(node_execution)
|
||||
self.session.commit()
|
||||
self.session.refresh(node_execution)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(node_execution)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
|
@ -419,7 +416,7 @@ class WorkflowEngine:
|
|||
}
|
||||
|
||||
node_execution.input_data = display_input_data
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
|
||||
|
|
@ -446,7 +443,7 @@ class WorkflowEngine:
|
|||
node_execution.completed_at = datetime.now().isoformat()
|
||||
node_execution.duration_ms = int((end_time - start_time) * 1000)
|
||||
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
|
||||
|
|
@ -459,7 +456,7 @@ class WorkflowEngine:
|
|||
node_execution.error_message = str(e)
|
||||
node_execution.completed_at = datetime.now().isoformat()
|
||||
node_execution.duration_ms = int((end_time - start_time) * 1000)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
|
||||
|
|
@ -918,8 +915,8 @@ class WorkflowEngine:
|
|||
|
||||
|
||||
# 工作流引擎实例
|
||||
def get_workflow_engine(session: Session = None) -> WorkflowEngine:
|
||||
async def get_workflow_engine(session: Session = None) -> WorkflowEngine:
|
||||
"""获取工作流引擎实例"""
|
||||
if session is None:
|
||||
session = next(get_session())
|
||||
session = await anext(get_session())
|
||||
return WorkflowEngine(session)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,490 @@
|
|||
from PySide6.QtCore import QPointF, QRectF, QSizeF
|
||||
# -*- coding: utf-8 -*-
|
||||
class Constant:
|
||||
invalid_point = QPointF(-123.4, -234.71) # 无效点
|
||||
invalid_rect = QRectF(invalid_point, QSizeF(0, 0)) # 无效矩形
|
||||
service_yml_path = 'config/service/dsp_%s_service.yml' # 服务配置路径
|
||||
kafka_yml_path = 'config/kafka/dsp_%s_kafka.yml' # kafka配置路径
|
||||
aliyun_yml_path = "config/aliyun/dsp_%s_aliyun.yml" # 阿里云配置路径
|
||||
baidu_yml_path = 'config/baidu/dsp_%s_baidu.yml' # 百度配置路径
|
||||
pull_frame_width = 1400 # 拉流帧宽度
|
||||
|
||||
LLM_MODE_NONE = 0
|
||||
LLM_MODE_NONCHAT = 1
|
||||
LLM_MODE_CHAT = 2
|
||||
LLM_MODE_EMBEDDING = 3
|
||||
LLM_MODE_LOCAL_OLLAMA = 4
|
||||
|
||||
LLM_PROMPT_TEMPLATE_METHOD_FORMAT = "format"
|
||||
LLM_PROMPT_TEMPLATE_METHOD_INVOKE = "invoke"
|
||||
|
||||
LLM_PROMPT_VALUE_STR = "str"
|
||||
LLM_PROMPT_VALUE_MESSAGES = "messages"
|
||||
LLM_PROMPT_VALUE_VALUE = "promptValue"
|
||||
|
||||
INPUT_NONE = 0
|
||||
INPUT_PULL_STREAM = 1
|
||||
INPUT_NET_CAMERA = 2
|
||||
INPUT_USB_CAMERA = 3
|
||||
INPUT_LOCAL_VIDEO = 4
|
||||
INPUT_LOCAL_DIR = 5
|
||||
INPUT_LOCAL_FILE = 6
|
||||
INPUT_LOCAL_LABELLED_DIR = 7
|
||||
INPUT_LOCAL_LABELLED_ZIP = 8
|
||||
|
||||
SYSTEM_BUTTON_SNAP = 1
|
||||
SYSTEM_BUTTON_CLEAR_LOG = 2
|
||||
|
||||
MODE_OCR = 7
|
||||
|
||||
WEB_OWNER_NONE = 0
|
||||
WEB_OWNER_ALG_TEST = 1
|
||||
WEB_OWNER_LLM = 2
|
||||
|
||||
PAGE_MODE_NONE = 0
|
||||
PAGE_MODE_VIDEO = 1
|
||||
PAGE_MODE_LABELLING = 2
|
||||
PAGE_MODE_ALG_TEST = 3
|
||||
PAGE_MODE_WEB = 4
|
||||
PAGE_MODE_LLM = 5
|
||||
PAGE_MODE_CONFIG = 6
|
||||
|
||||
ALG_TEST_LOAD_DATA = 1
|
||||
ALG_TEST_TRAIN = 2
|
||||
ALG_TEST_INFER = 3
|
||||
|
||||
ALG_STATUS_NOTHING = 0 # 啥也没干
|
||||
ALG_STATUS_MODEL_READY = 1 # 模型已就绪
|
||||
ALG_STATUS_DATA_LOADED = 2 # 加载已数据/加载中
|
||||
ALG_STATUS_TRAINED = 4 # 训练完成/训练中
|
||||
ALG_STATUS_INFER = 8 # 推理完成/推理中
|
||||
|
||||
UTF_8 = "utf-8" # 编码格式
|
||||
|
||||
COLOR = (
|
||||
[0, 0, 255],
|
||||
[255, 0, 0],
|
||||
[211, 0, 148],
|
||||
[0, 127, 0],
|
||||
[0, 69, 255],
|
||||
[0, 255, 0],
|
||||
[255, 0, 255],
|
||||
[0, 0, 127],
|
||||
[127, 0, 255],
|
||||
[255, 129, 0],
|
||||
[139, 139, 0],
|
||||
[255, 255, 0],
|
||||
[127, 255, 0],
|
||||
[0, 127, 255],
|
||||
[0, 255, 127],
|
||||
[255, 127, 255],
|
||||
[8, 101, 139],
|
||||
[171, 130, 255],
|
||||
[139, 112, 74],
|
||||
[205, 205, 180])
|
||||
|
||||
ONLINE = "online"
|
||||
OFFLINE = "offline"
|
||||
PHOTO = "photo"
|
||||
RECORDING = "recording"
|
||||
|
||||
ONLINE_START_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start"]
|
||||
},
|
||||
"pull_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 255
|
||||
},
|
||||
"push_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 255
|
||||
},
|
||||
"logo_url": {
|
||||
'type': 'string',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'maxlength': 255
|
||||
},
|
||||
"models": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'nullable': False,
|
||||
'minlength': 1,
|
||||
'maxlength': 3,
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"code": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "categories",
|
||||
'regex': r'^[a-zA-Z0-9]{1,255}$'
|
||||
},
|
||||
"is_video": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"is_image": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"categories": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'dependencies': "code",
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{0,255}$'},
|
||||
"config": {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'dependencies': "id",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ONLINE_STOP_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["stop"]
|
||||
}
|
||||
}
|
||||
|
||||
OFFLINE_START_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start"]
|
||||
},
|
||||
"push_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 255
|
||||
},
|
||||
"pull_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 255
|
||||
},
|
||||
"logo_url": {
|
||||
'type': 'string',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'maxlength': 255
|
||||
},
|
||||
"models": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'maxlength': 3,
|
||||
'minlength': 1,
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"code": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "categories",
|
||||
'regex': r'^[a-zA-Z0-9]{1,255}$'
|
||||
},
|
||||
"is_video": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"is_image": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"categories": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'dependencies': "code",
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{0,255}$'},
|
||||
"config": {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'dependencies': "id",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OFFLINE_STOP_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["stop"]
|
||||
}
|
||||
}
|
||||
|
||||
IMAGE_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start"]
|
||||
},
|
||||
"logo_url": {
|
||||
'type': 'string',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'maxlength': 255
|
||||
},
|
||||
"image_urls": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'minlength': 1,
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 5000
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"code": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "categories",
|
||||
'regex': r'^[a-zA-Z0-9]{1,255}$'
|
||||
},
|
||||
"is_video": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"is_image": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "code",
|
||||
'allowed': ["0", "1"]
|
||||
},
|
||||
"categories": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'dependencies': "code",
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{0,255}$'},
|
||||
"config": {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'dependencies': "id",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RECORDING_START_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start"]
|
||||
},
|
||||
"pull_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'maxlength': 255
|
||||
},
|
||||
"push_url": {
|
||||
'type': 'string',
|
||||
'required': False,
|
||||
'empty': True,
|
||||
'maxlength': 255
|
||||
},
|
||||
"logo_url": {
|
||||
'type': 'string',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'maxlength': 255
|
||||
}
|
||||
}
|
||||
|
||||
RECORDING_STOP_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["stop"]
|
||||
}
|
||||
}
|
||||
|
||||
PULL2PUSH_START_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start"]
|
||||
},
|
||||
"video_urls": {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'nullable': False,
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'schema': {
|
||||
"id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "pull_url",
|
||||
'regex': r'^[a-zA-Z0-9]{1,255}$'
|
||||
},
|
||||
"pull_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "push_url",
|
||||
'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'
|
||||
},
|
||||
"push_url": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'dependencies': "id",
|
||||
'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PULL2PUSH_STOP_SCHEMA = {
|
||||
"request_id": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,36}$'
|
||||
},
|
||||
"command": {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'allowed': ["start", "stop"]
|
||||
},
|
||||
"video_ids": {
|
||||
'type': 'list',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'schema': {
|
||||
'type': 'string',
|
||||
'required': True,
|
||||
'empty': False,
|
||||
'regex': r'^[a-zA-Z0-9]{1,255}$'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from loguru import logger
|
||||
from enum import Enum, unique
|
||||
|
||||
class ServiceException(Exception):
|
||||
def __init__(self, code, msg, desc=None):
|
||||
self.code = code
|
||||
if desc is None:
|
||||
self.msg = msg
|
||||
else:
|
||||
self.msg = msg % desc
|
||||
|
||||
def __str__(self):
|
||||
logger.error("异常编码:{}, 异常描述:{}", self.code, self.msg)
|
||||
|
||||
# 异常枚举
|
||||
@unique
|
||||
class ExceptionType(Enum):
|
||||
OR_VIDEO_ADDRESS_EXCEPTION = ("SP000", "未拉取到视频流, 请检查拉流地址是否有视频流!")
|
||||
ANALYSE_TIMEOUT_EXCEPTION = ("SP001", "AI分析超时!")
|
||||
PULLSTREAM_TIMEOUT_EXCEPTION = ("SP002", "原视频拉流超时!")
|
||||
READSTREAM_TIMEOUT_EXCEPTION = ("SP003", "原视频读取视频流超时!")
|
||||
GET_VIDEO_URL_EXCEPTION = ("SP004", "获取视频播放地址失败!")
|
||||
GET_VIDEO_URL_TIMEOUT_EXCEPTION = ("SP005", "获取原视频播放地址超时!")
|
||||
PULL_STREAM_URL_EXCEPTION = ("SP006", "拉流地址不能为空!")
|
||||
PUSH_STREAM_URL_EXCEPTION = ("SP007", "推流地址不能为空!")
|
||||
PUSH_STREAM_TIME_EXCEPTION = ("SP008", "未生成本地视频地址!")
|
||||
AI_MODEL_MATCH_EXCEPTION = ("SP009", "未匹配到对应的AI模型!")
|
||||
ILLEGAL_PARAMETER_FORMAT = ("SP010", "非法参数格式!")
|
||||
PUSH_STREAMING_CHANNEL_IS_OCCUPIED = ("SP011", "推流通道可能被占用, 请稍后再试!")
|
||||
VIDEO_RESOLUTION_EXCEPTION = ("SP012", "不支持该分辨率类型的视频,请切换分辨率再试!")
|
||||
READ_IAMGE_URL_EXCEPTION = ("SP013", "未能解析图片地址!")
|
||||
DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED = ("SP014", "不支持该类型的检测目标!")
|
||||
WRITE_STREAM_EXCEPTION = ("SP015", "写流异常!")
|
||||
OR_VIDEO_DO_NOT_EXEIST_EXCEPTION = ("SP016", "原视频不存在!")
|
||||
MODEL_LOADING_EXCEPTION = ("SP017", "模型加载异常!")
|
||||
MODEL_ANALYSE_EXCEPTION = ("SP018", "算法模型分析异常!")
|
||||
AI_MODEL_CONFIG_EXCEPTION = ("SP019", "模型配置不能为空!")
|
||||
AI_MODEL_GET_CONFIG_EXCEPTION = ("SP020", "获取模型配置异常, 请检查模型配置是否正确!")
|
||||
MODEL_GROUP_LIMIT_EXCEPTION = ("SP021", "模型组合个数超过限制!")
|
||||
MODEL_NOT_SUPPORT_VIDEO_EXCEPTION = ("SP022", "%s不支持视频识别!")
|
||||
MODEL_NOT_SUPPORT_IMAGE_EXCEPTION = ("SP023", "%s不支持图片识别!")
|
||||
THE_DETECTION_TARGET_CANNOT_BE_EMPTY = ("SP024", "检测目标不能为空!")
|
||||
URL_ADDRESS_ACCESS_FAILED = ("SP025", "URL地址访问失败, 请检测URL地址是否正确!")
|
||||
UNIVERSAL_TEXT_RECOGNITION_FAILED = ("SP026", "识别失败!")
|
||||
COORDINATE_ACQUISITION_FAILED = ("SP027", "飞行坐标识别异常!")
|
||||
PUSH_STREAM_EXCEPTION = ("SP028", "推流异常!")
|
||||
MODEL_DUPLICATE_EXCEPTION = ("SP029", "存在重复模型配置!")
|
||||
DETECTION_TARGET_NOT_SUPPORT = ("SP031", "存在不支持的检测目标!")
|
||||
TASK_EXCUTE_TIMEOUT = ("SP032", "任务执行超时!")
|
||||
PUSH_STREAM_URL_IS_NULL = ("SP033", "拉流、推流地址不能为空!")
|
||||
PULL_STREAM_NUM_LIMIT_EXCEPTION = ("SP034", "转推流数量超过限制!")
|
||||
NOT_REQUESTID_TASK_EXCEPTION = ("SP993", "未查询到该任务,无法停止任务!")
|
||||
NO_RESOURCES = ("SP995", "服务器暂无资源可以使用,请稍后30秒后再试!")
|
||||
NO_CPU_RESOURCES = ("SP996", "暂无CPU资源可以使用,请稍后再试!")
|
||||
SERVICE_COMMON_EXCEPTION = ("SP997", "公共服务异常!")
|
||||
NO_GPU_RESOURCES = ("SP998", "暂无GPU资源可以使用,请稍后再试!")
|
||||
SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常!")
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
class Flag:
|
||||
Unique = True
|
||||
Append = False
|
||||
|
||||
Debug = True
|
||||
|
||||
class Option:
|
||||
NoOption = 0x00
|
||||
|
||||
AddObject_AutoName = 0x01
|
||||
AddObject_Select = 0x02
|
||||
|
|
@ -0,0 +1,668 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import sys, os, cv2
|
||||
from os import makedirs
|
||||
from os.path import join, exists
|
||||
from loguru import logger
|
||||
from json import loads
|
||||
from ruamel.yaml import safe_load, YAML
|
||||
import random, sys, math, inspect, psutil
|
||||
from pathlib import Path
|
||||
from PySide6.QtGui import QIcon, QColor
|
||||
from PySide6.QtCore import QObject, QRectF, QEventLoop, QIODevice, QTextStream, QFile
|
||||
from PySide6.QtWidgets import QApplication
|
||||
|
||||
import DrGraph.utils.vclEnums as enums
|
||||
#region Property
|
||||
class Property:
|
||||
def __init__(self, read_func=None, write_func=None, default=None, hasMember=True):
|
||||
self.read_func = read_func
|
||||
self.write_func = write_func
|
||||
self.default = default
|
||||
self.owner_class = None
|
||||
self.private_name = None
|
||||
self.hasMember = hasMember
|
||||
def __set_name__(self, owner, name):
|
||||
if self.hasMember:
|
||||
self.private_name = f"_{name}"
|
||||
self.owner_class = owner
|
||||
def callDirectGet(self, instance, owner):
|
||||
if instance is None or not self.hasMember:
|
||||
return self.default
|
||||
if not hasattr(instance, self.private_name):
|
||||
return self.default
|
||||
return getattr(instance, self.private_name)
|
||||
def callCustomGet(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
if self.read_func is None:
|
||||
if hasattr(instance, self.private_name):
|
||||
return getattr(instance, self.private_name)
|
||||
return self.default
|
||||
|
||||
try:
|
||||
if isinstance(self.read_func, str):
|
||||
if hasattr(instance, self.read_func):
|
||||
method = getattr(instance, self.read_func)
|
||||
return method()
|
||||
elif hasattr(self.read_func, '__name__'):
|
||||
method_name = self.read_func.__name__
|
||||
if hasattr(instance, method_name):
|
||||
method = getattr(instance, method_name)
|
||||
return method()
|
||||
elif callable(self.read_func):
|
||||
try:
|
||||
return self.read_func(instance)
|
||||
except (TypeError, AttributeError):
|
||||
return self.read_func()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
if hasattr(instance, self.private_name):
|
||||
return getattr(instance, self.private_name)
|
||||
return self.default
|
||||
|
||||
def callDirectSet(self, instance, value):
|
||||
if instance is None or not self.hasMember:
|
||||
return
|
||||
setattr(instance, self.private_name, value)
|
||||
|
||||
def callCustomSet(self, instance, value):
|
||||
try:
|
||||
if isinstance(self.write_func, str):
|
||||
if hasattr(instance, self.write_func):
|
||||
method = getattr(instance, self.write_func)
|
||||
method(value)
|
||||
elif hasattr(self.write_func, '__name__'):
|
||||
method_name = self.write_func.__name__
|
||||
if hasattr(instance, method_name):
|
||||
method = getattr(instance, method_name)
|
||||
method(value)
|
||||
elif callable(self.write_func):
|
||||
try:
|
||||
self.write_func(instance, value)
|
||||
except (TypeError, AttributeError):
|
||||
self.write_func(value)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
class Property_rw(Property):
|
||||
def __init__(self, default=None, hasMember=True):
|
||||
super().__init__(None, None, default, hasMember)
|
||||
def __get__(self, instance, owner):
|
||||
return self.callDirectGet(instance, owner)
|
||||
def __set__(self, instance, value):
|
||||
self.callDirectSet(instance, value)
|
||||
|
||||
class Property_Rw(Property):
|
||||
def __init__(self, read_func=None, default=None, hasMember=True):
|
||||
super().__init__(read_func, None, default, hasMember)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.callCustomGet(instance, owner)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
setattr(instance, self.private_name, value)
|
||||
|
||||
class Property_rW(Property):
|
||||
def __init__(self, write_func=None, default=None, hasMember=True):
|
||||
super().__init__(None, write_func, default, hasMember)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.callDirectGet(instance, owner)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
self.callCustomSet(instance, value)
|
||||
|
||||
class Property_RW(Property):
|
||||
def __init__(self, read_func=None, write_func=None, default=None, hasMember=True):
|
||||
super().__init__(read_func, write_func, default, hasMember)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.callCustomGet(instance, owner)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
self.callCustomSet(instance, value)
|
||||
#endregion Property
|
||||
|
||||
class AppHelper(QObject):
|
||||
app = Property_rw(None)
|
||||
def setBriefStatusText(self, text):
|
||||
if self.briefStatusControl:
|
||||
self.briefStatusControl.setText(text)
|
||||
else:
|
||||
print(text)
|
||||
briefStatusText = Property_rW(setBriefStatusText, '')
|
||||
|
||||
def setProgress(self, value):
|
||||
if self.progressBarControl:
|
||||
self.progressBarControl.setValue(value)
|
||||
progress = Property_rW(setProgress, 0)
|
||||
|
||||
def setProgressMax(self, value):
|
||||
if self.progressBarControl:
|
||||
self.progressBarControl.setMaximum(value)
|
||||
progressMax = Property_rW(setProgressMax, 100)
|
||||
|
||||
def setProgressMin(self, value):
|
||||
if self.progressBarControl:
|
||||
self.progressBarControl.setMinimum(value)
|
||||
progressMin = Property_rW(setProgressMin, 0)
|
||||
|
||||
def __init__(self):
|
||||
self.briefStatusControl = None
|
||||
self.progressBarControl = None
|
||||
self._briefStatusText = ''
|
||||
pass
|
||||
|
||||
class Helper:
|
||||
OnLogMsg = None
|
||||
AppFlag_SaveAnalysisResult = True
|
||||
AppFlag_SaveLog = False
|
||||
App = None
|
||||
|
||||
@staticmethod
|
||||
def castRange(value, minValue, maxValue):
|
||||
return max(minValue, min(maxValue, value))
|
||||
# 取得程序目录
|
||||
@staticmethod
|
||||
def getPath_App():
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 如果程序是打包的exe文件
|
||||
return os.path.dirname(sys.executable)
|
||||
else:
|
||||
# 如果是Python脚本 - 获取上两级目录
|
||||
current_file = os.path.abspath(__file__) # f:\PySide6\AiBase\DrGraph\utils\Helper.py
|
||||
current_dir = os.path.dirname(current_file) # f:\PySide6\AiBase\DrGraph\utils
|
||||
parent_dir = os.path.dirname(current_dir) # f:\PySide6\AiBase\DrGraph
|
||||
root_dir = os.path.dirname(parent_dir) # f:\PySide6\AiBase
|
||||
return root_dir
|
||||
|
||||
@staticmethod
|
||||
def fitOS(file_name):
|
||||
if sys.platform.startswith('win'):
|
||||
file_name = file_name.replace('/','\\')
|
||||
else:
|
||||
file_name = file_name.replace('\\', '/')
|
||||
return file_name
|
||||
|
||||
def generateDistinctColors(n, s=0.8, v=0.7):
|
||||
import colorsys
|
||||
colors = []
|
||||
for i in range(n):
|
||||
hue = i * 1.0 / n # 均匀分布在 [0, 1)
|
||||
r, g, b = colorsys.hsv_to_rgb(hue, s, v)
|
||||
colors.append(QColor(r * 255, g * 255, b * 255))
|
||||
return colors
|
||||
|
||||
def setBriefStatusText(self, text):
|
||||
Helper.App.setBriefStatusText(text)
|
||||
briefStatusText = Property_rW(setBriefStatusText, '')
|
||||
@staticmethod
|
||||
def Sleep(msec):
|
||||
QApplication.processEvents(QEventLoop.AllEvents, msec)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def getAbsoluteFileName(file_name):
|
||||
if os.path.isabs(file_name):
|
||||
return Helper.fitOS(file_name)
|
||||
else:
|
||||
return Helper.fitOS(os.path.join(Helper.getPath_App(), file_name))
|
||||
@staticmethod
|
||||
def getConfigs(path, read_type='yml'):
|
||||
"""
|
||||
读取配置文件并返回解析后的配置信息
|
||||
|
||||
:param path: 配置文件路径
|
||||
:param read_type: 配置文件类型,默认为'yml',可选'json'或'yml'
|
||||
:return: 解析后的配置信息,JSON格式返回字典,YML格式返回对应的数据结构
|
||||
:raises Exception: 当无法获取配置信息时抛出异常
|
||||
"""
|
||||
yaml = YAML(typ='safe', pure=True)
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
return yaml.load(f)
|
||||
# with open(path, 'r', encoding="utf-8") as f:
|
||||
# # 根据文件类型选择相应的解析方式
|
||||
# if read_type == 'json':
|
||||
# return loads(f.read())
|
||||
# if read_type == 'yml':
|
||||
# return safe_load(f)
|
||||
# 如果未成功读取配置信息,则抛出异常
|
||||
raise Exception('路径: %s未获取配置信息' % path)
|
||||
|
||||
@staticmethod
|
||||
def getTooltipText(content):
|
||||
# 增加一个小喇叭图标
|
||||
# content = f'<img src="appIOs/res/images/icons/info.png" width="16" height="16"> {content}'
|
||||
return f"""
|
||||
<html>
|
||||
<head/><body>
|
||||
<p><span style=" font-weight:600; color:#ffffff;">DrGraph <img src="appIOs/res/images/icons/Notice.png" width="16" height="16"></span></p>
|
||||
<p>{content}</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@staticmethod
|
||||
def log_init(app, base_dir, env):
|
||||
"""
|
||||
初始化日志配置
|
||||
|
||||
:param base_dir: 基础目录路径,用于定位配置文件和日志文件存储位置
|
||||
:param env: 环境标识,用于加载对应环境的日志配置文件
|
||||
:return: 无返回值
|
||||
"""
|
||||
Helper.App = AppHelper()
|
||||
Helper.App.app = app
|
||||
# QToolTip样式 - 自定义样式 - 增加Header
|
||||
app.setStyleSheet("""
|
||||
QToolTip {
|
||||
background-color: #dd2222;
|
||||
color: #f0f0f0;
|
||||
border: 1px solid #555;
|
||||
border-radius: 4px;
|
||||
padding: 6px;
|
||||
font: 10pt "Segoe UI";
|
||||
opacity: 220;
|
||||
}
|
||||
""")
|
||||
|
||||
log_config = Helper.getConfigs(join(base_dir, 'appIOs/configs/logger/drgraph_%s_logger.yml' % env))
|
||||
# 判断日志文件是否存在,不存在创建
|
||||
base_path = join(base_dir, log_config.get("base_path"))
|
||||
if not exists(base_path):
|
||||
makedirs(base_path)
|
||||
# 移除日志设置
|
||||
logger.remove(handler_id=None)
|
||||
# 打印日志到文件
|
||||
if bool(log_config.get("enable_file_log")):
|
||||
logger.add(join(base_path, log_config.get("log_name")),
|
||||
rotation=log_config.get("rotation"),
|
||||
retention=log_config.get("retention"),
|
||||
format=log_config.get("log_fmt"),
|
||||
level=log_config.get("level"),
|
||||
enqueue=True,
|
||||
encoding=log_config.get("encoding"))
|
||||
# 控制台输出
|
||||
if bool(log_config.get("enable_stderr")):
|
||||
logger.add(sys.stderr,
|
||||
format=log_config.get("log_fmt"),
|
||||
level=log_config.get("level"),
|
||||
enqueue=True)
|
||||
logger.info("\n\n\n----=========== 日志配置初始化完成, 开始新的日志记录 ==========----")
|
||||
|
||||
@staticmethod
|
||||
def log_info(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'INFO: {msg}', 'black')
|
||||
caller = inspect.stack()[1]
|
||||
logger.info(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "INFO", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_error(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'ERROR: {msg}', 'red')
|
||||
caller = inspect.stack()[1]
|
||||
logger.error(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "ERROR", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_warning(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'WARNING: {msg}', (255, 128, 0))
|
||||
caller = inspect.stack()[1]
|
||||
logger.warning(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "WARNING", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_debug(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'DEBUG: {msg}', (0, 128, 128))
|
||||
caller = inspect.stack()[1]
|
||||
logger.debug(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "DEBUG", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_critical(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'CRITICAL: {msg}', (128, 0, 128))
|
||||
caller = inspect.stack()[1]
|
||||
logger.critical(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "CRITICAL", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_exception(msg, toWss = False):
|
||||
if Helper.OnLogMsg:
|
||||
Helper.OnLogMsg(f'EXCEPTION: {msg}', (255, 140, 0))
|
||||
caller = inspect.stack()[1]
|
||||
logger.exception(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "EXCEPTION", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log(msg, toWss = False):
|
||||
caller = inspect.stack()[1]
|
||||
logger.log(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
|
||||
if toWss:
|
||||
Helper.log_wss({"type": "log", "kind": "LOG", "msg" : msg} )
|
||||
@staticmethod
|
||||
def log_wss(msg):
|
||||
if Helper.wss:
|
||||
Helper.wss.send(msg)
|
||||
|
||||
@staticmethod
|
||||
def getTextSize(font, text):
|
||||
import pygame as pg
|
||||
surface = font.render(text, True, (0, 0, 0))
|
||||
return (surface.get_width(), surface.get_height(), surface)
|
||||
|
||||
@staticmethod
|
||||
def buildSurfaces(font, text, width, color, wordWrap):
|
||||
text = text.strip()
|
||||
w = Helper.getTextSize(font, text)[0]
|
||||
result = []
|
||||
if w > width and wordWrap:
|
||||
segLen = math.floor(width / w * len(text))
|
||||
while len(text):
|
||||
if len(text) < segLen:
|
||||
t = text
|
||||
text = ''
|
||||
else:
|
||||
t = text[:segLen]
|
||||
text = text[segLen:]
|
||||
result.append(font.render(t, True, color))
|
||||
else:
|
||||
result.append(font.render(text, True, color))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def randomColor():
|
||||
'''随机颜色'''
|
||||
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
||||
|
||||
@staticmethod
def reverseColor(color: QColor):
    '''Return the channel-wise inverted (r, g, b) of *color*.'''
    r, g, b = color.red(), color.green(), color.blue()
    return (255 - r, 255 - g, 255 - b)
|
||||
@staticmethod
|
||||
def getRGB(color_value):
|
||||
# 如果是元组或列表形式的RGB值
|
||||
if isinstance(color_value, (tuple, list)):
|
||||
if len(color_value) >= 3:
|
||||
# 取前三个值作为RGB
|
||||
r, g, b = color_value[0], color_value[1], color_value[2]
|
||||
# 确保值在0-255范围内
|
||||
return (max(0, min(255, int(r))),
|
||||
max(0, min(255, int(g))),
|
||||
max(0, min(255, int(b))))
|
||||
|
||||
# 如果是整数形式的颜色值
|
||||
elif isinstance(color_value, int):
|
||||
# 将整数转换为RGB分量
|
||||
# 假设格式为0xRRGGBB
|
||||
r = (color_value >> 16) & 0xFF
|
||||
g = (color_value >> 8) & 0xFF
|
||||
b = color_value & 0xFF
|
||||
return (r, g, b)
|
||||
|
||||
# 如果是字符串形式
|
||||
elif isinstance(color_value, str):
|
||||
# 处理十六进制颜色值
|
||||
if color_value.startswith('#'):
|
||||
hex_value = color_value[1:]
|
||||
if len(hex_value) == 3: # 简写形式 #RGB
|
||||
hex_value = ''.join([c*2 for c in hex_value])
|
||||
if len(hex_value) in (6, 8): # #RRGGBB 或 #RRGGBBAA
|
||||
r = int(hex_value[0:2], 16)
|
||||
g = int(hex_value[2:4], 16)
|
||||
b = int(hex_value[4:6], 16)
|
||||
return (r, g, b)
|
||||
# 处理颜色名称(需要额外的颜色名称映射表)
|
||||
# 这里只列举几种常见颜色
|
||||
color_names = {
|
||||
'black': (0, 0, 0),
|
||||
'white': (255, 255, 255),
|
||||
'red': (255, 0, 0),
|
||||
'green': (0, 255, 0),
|
||||
'blue': (0, 0, 255),
|
||||
'yellow': (255, 255, 0),
|
||||
'magenta': (255, 0, 255),
|
||||
'cyan': (0, 255, 255),
|
||||
'orange': (255, 128, 0), # 根据项目规范
|
||||
'teal': (0, 128, 128) # 根据项目规范
|
||||
}
|
||||
if color_value.lower() in color_names:
|
||||
return color_names[color_value.lower()]
|
||||
|
||||
# 如果是Color对象(如pygame.Color)
|
||||
elif hasattr(color_value, 'r') and hasattr(color_value, 'g') and hasattr(color_value, 'b'):
|
||||
return (color_value.r, color_value.g, color_value.b)
|
||||
|
||||
# 默认返回黑色
|
||||
return (0, 0, 0)
|
||||
|
||||
@staticmethod
def check_system_resources():
    """Sample current system resource usage via psutil.

    Returns a dict with:
      cpu_percent:         CPU utilisation in percent (sampled over 1 second)
      memory_percent:      virtual-memory utilisation in percent
      memory_available:    available memory in GB
      network_bytes_sent:  cumulative bytes sent, in KB
      network_bytes_recv:  cumulative bytes received, in KB
    """
    logger.info("检查系统资源使用情况...")
    # interval=1 makes this call block for one second while sampling.
    cpu_percent = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()
    # NOTE: net_io_counters() values are cumulative since boot, not a rate.
    network = psutil.net_io_counters()

    logger.info("检查系统资源使用情况完毕")

    return {
        'cpu_percent': cpu_percent,
        'memory_percent': memory.percent,
        'memory_available': memory.available / (1024**3), # GB
        'network_bytes_sent': int(network.bytes_sent / 1024),
        'network_bytes_recv': int(network.bytes_recv / 1024)
    }
|
||||
|
||||
@staticmethod
def build_response(type, status : enums.Response, msg):
    """Build the standard response envelope for a request of *type*.

    *status* is an enums.Response whose .value is (status_code, status_msg).
    Non-zero status codes are additionally reported via Helper.error.
    """
    status_code, status_msg = status.value
    result = {
        "type": "response",
        "request_type": type,
        "status_code": status_code,
        "status_msg": status_msg,
        "detail_msg": msg,
    }
    if status_code != 0:
        Helper.error(result)
    return result
|
||||
|
||||
@staticmethod
def get_surrounding_rect(points):
    """Axis-aligned bounding QRectF of *points*; Constant.invalid_rect when empty."""
    if not len(points):
        return Constant.invalid_rect
    xs = [p.x() for p in points]
    ys = [p.y() for p in points]
    left, top = min(xs), min(ys)
    return QRectF(left, top, max(xs) - left, max(ys) - top)
|
||||
|
||||
@staticmethod
|
||||
def getYoloLabellingInfo(dir_path, file_names, desc):
|
||||
if len(dir_path) > 0:
|
||||
imageNumber, labelNumber = 0, 0
|
||||
imagePath = dir_path + 'images/'
|
||||
labelPath = dir_path + 'labels/'
|
||||
for file_name in file_names:
|
||||
if file_name.startswith(imagePath):
|
||||
imageNumber += 1
|
||||
elif file_name.startswith(labelPath):
|
||||
labelNumber += 1
|
||||
return f'{desc} {imageNumber - 1} 张图片,{labelNumber - 1} 张标签;', imageNumber - 1
|
||||
return f'无{desc};', 0
|
||||
|
||||
@staticmethod
def getMarkdownRenderText(mdContent):
    """Convert markdown source *mdContent* to a standalone styled HTML page.

    Prefers the pure-Python `markdown` package; when it is not installed,
    falls back to embedding marked.min.js so the browser renders the
    markdown client-side.
    """
    try:
        # Probe for the markdown package with a local import.
        import markdown
        html_content = markdown.markdown(mdContent)
        # Wrap the generated fragment in a minimal styled document.
        styled_html = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <style>
        body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
                max-width: 900px; margin: 20px auto; padding: 0 20px; line-height: 1.6; }}
        code {{ background: #f5f5f5; padding: 2px 4px; border-radius: 3px; }}
        pre {{ background: #f5f5f5; padding: 10px; border-radius: 5px; overflow: auto; }}
        pre code {{ background: none; padding: 0; }}
        h1, h2, h3 {{ color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px; }}
    </style>
</head>
<body>
    {html_content}
</body>
</html>
"""
        return styled_html
    except ImportError:
        # Fall back to client-side rendering with marked.js.
        logger.warning("未找到markdown库,使用JavaScript渲染方式")
        # Load marked.js shipped under appIOs/configs.
        file_js = QFile('appIOs/configs/marked.min.js')
        markedJs = ''
        if file_js.open(QIODevice.ReadOnly | QIODevice.Text):
            markedJs = file_js.readAll().data().decode('utf-8')
            file_js.close()

        # BUG FIX: the previous escaping chain replaced each character with
        # itself (the intended HTML entities had been lost), so markup inside
        # the markdown could break out of the template. Use html.escape.
        import html
        escapedMd = html.escape(mdContent, quote=True)

        # Page template: %1 -> marked.js source, %2 -> escaped markdown.
        htmlTemplate = '''
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
               max-width: 900px; margin: 20px auto; padding: 0 20px; line-height: 1.6; }
        code { background: #f5f5f5; padding: 2px 4px; border-radius: 3px; }
        pre { background: #f5f5f5; padding: 10px; border-radius: 5px; overflow: auto; }
        pre code { background: none; padding: 0; }
        h1, h2, h3 { color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px; }
    </style>
    <script>
        // marked library source
        %1
    </script>
</head>
<body>
    <div id="content"></div>
    <script>
        const md = `%2`;
        document.getElementById('content').innerHTML = marked.parse(md);
    </script>
</body>
</html>
'''

        # NOTE(review): backticks and '${' inside the markdown are not escaped
        # and could still break the JS template literal — confirm acceptable.
        htmlContent = htmlTemplate.replace('%1', markedJs).replace('%2', escapedMd)
        return htmlContent
|
||||
@staticmethod
def getMarkdownRender(mdFileName):
    """Read markdown file *mdFileName* and return its rendered HTML page.

    On failure to open the file, logs the error and returns a small
    red error paragraph instead.
    """
    file = QFile(mdFileName)
    if not file.open(QIODevice.ReadOnly | QIODevice.Text):
        logger.error(f'打开文件 {mdFileName} 失败')
        return f"<p style='color: red;'>无法打开文件: {mdFileName}</p>"
    stream = QTextStream(file)
    stream.setAutoDetectUnicode(True)
    mdContent = stream.readAll()
    file.close()
    return Helper.getMarkdownRenderText(mdContent)
|
||||
|
||||
class RTTI:
    """Small reflection helpers: read/write (possibly dotted) properties on Qt objects."""

    @staticmethod
    def _do_set_attr(obj, property_name, property_value):
        """Set a single (non-dotted) property via the object's setXxx() setter."""
        if obj is None:
            logger.error(f"RTTI.set: obj is None")
            return

        class_name = type(obj).__name__
        object_name = obj.objectName()
        if property_name not in dir(obj):
            logger.error(f"RTTI.set: {class_name} {object_name}.{property_name} not in dir(obj)")
            return
        # Icon properties given as a path string are resolved against the
        # app's icon directory and converted to a QIcon before being applied.
        if property_name.endswith('icon') and isinstance(property_value, str):
            original_property_value = property_value
            if not os.path.exists(property_value):
                property_value = os.path.join('appIOs/res/images/icons', property_value)
            if not os.path.exists(property_value):
                logger.error(f"{original_property_value}文件不存在 > RTTI.set: {class_name} {object_name}.{property_name} = '{original_property_value}'")
                return
            property_value = QIcon(property_value)
        # Qt convention: property 'foo' is written through setFoo().
        setter_method = getattr(obj, f'set{property_name[0].upper() + property_name[1:]}')
        setter_method(property_value)

    @staticmethod
    def set(obj, property_name, property_value):
        """Set *property_name* (supports dotted paths like 'a.b.c') on *obj*."""
        property_list = property_name.split('.')
        if len(property_list) == 1:
            RTTI._do_set_attr(obj, property_name, property_value)
        else:
            dest_obj = obj
            for i in range(len(property_list) - 1):
                if not dest_obj:
                    # BUG FIX: lists have no .join(); join the path with '.'
                    # (the old property_list.join('.') raised AttributeError,
                    # masking the original error).
                    logger.error(f"RTTI.set: {'.'.join(property_list)} not found")
                    return
                dest_obj = getattr(dest_obj, property_list[i])
            RTTI._do_set_attr(dest_obj, property_list[-1], property_value)

    @staticmethod
    def _do_get_attr(obj, property_name):
        """Read a single (non-dotted) attribute.

        Returns (type_name, value), or (None, None) when the attribute is
        missing or *obj* is None.
        """
        if obj is None:
            logger.error(f"RTTI.get: obj is None")
            return None, None
        if property_name not in dir(obj):
            logger.error(f"RTTI.get: {type(obj).__name__} {obj.objectName()}.{property_name} not in dir(obj)")
            return None, None
        # Return the attribute's type name together with its value.
        value = getattr(obj, property_name)
        return type(value).__name__, value

    # Usage: type, value = RTTI.get(obj, property_name)
    @staticmethod
    def get(obj, property_name):
        """Read *property_name* (supports dotted paths) from *obj* as (type_name, value)."""
        property_list = property_name.split('.')
        if len(property_list) == 1:
            return RTTI._do_get_attr(obj, property_name)
        else:
            dest_obj = obj
            for i in range(len(property_list) - 1):
                if not dest_obj:
                    # BUG FIX: '.'.join(...) instead of list.join; also return a
                    # (type, value) pair like every other path so callers that
                    # unpack two values do not crash (was: bare None).
                    logger.error(f"RTTI.get: {'.'.join(property_list)} not found")
                    return None, None
                dest_obj = getattr(dest_obj, property_list[i])
            return RTTI._do_get_attr(dest_obj, property_list[-1])
|
||||
|
||||
class DrawHelper:
    """Dashed drawing primitives implemented with cv2 line segments."""

    @staticmethod
    def draw_dashed_line(mat, pt1, pt2, color, thickness=1, dash_length=10):
        """Draw a dashed line on image *mat* from *pt1* to *pt2*.

        Each dash occupies the first half of a *dash_length*-sized step.
        """
        dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5
        # BUG FIX: int(dist / dash_length) is 0 for segments shorter than
        # dash_length, which previously drew nothing at all; draw at least
        # one dash so short lines remain visible.
        dashes = max(1, int(dist / dash_length))
        for i in range(dashes):
            start = (int(pt1[0] + (pt2[0] - pt1[0]) * i / dashes), int(pt1[1] + (pt2[1] - pt1[1]) * i / dashes))
            end = (int(pt1[0] + (pt2[0] - pt1[0]) * (i + 0.5) / dashes), int(pt1[1] + (pt2[1] - pt1[1]) * (i + 0.5) / dashes))
            cv2.line(mat, start, end, color, thickness)

    @staticmethod
    def draw_dashed_rect(painter, rect, color, thickness=1, dash_length=10):
        """Draw the four edges of *rect* as dashed lines.

        NOTE(review): the first argument is forwarded to draw_dashed_line as a
        cv2 image, so *painter* is presumably an ndarray despite the name —
        confirm at call sites.
        """
        x1, y1 = rect.left(), rect.top()
        x2, y2 = rect.right(), rect.bottom()
        DrawHelper.draw_dashed_line(painter, (x1, y1), (x2, y1), color, thickness, dash_length)
        DrawHelper.draw_dashed_line(painter, (x1, y2), (x2, y2), color, thickness, dash_length)
        DrawHelper.draw_dashed_line(painter, (x1, y1), (x1, y2), color, thickness, dash_length)
        DrawHelper.draw_dashed_line(painter, (x2, y1), (x2, y2), color, thickness, dash_length)
|
||||
|
|
@ -0,0 +1,466 @@
|
|||
from loguru import logger
|
||||
import subprocess as sp
|
||||
from ultralytics import YOLO
|
||||
import time, cv2, numpy as np, math
|
||||
from traceback import format_exc
|
||||
from DrGraph.utils.pull_push import NetStream
|
||||
from DrGraph.utils.Helper import *
|
||||
from DrGraph.utils.Constant import Constant
|
||||
from zipfile import ZipFile
|
||||
|
||||
class YOLOTracker:
    """Thin wrapper around an ultralytics YOLO model configured for ByteTrack tracking."""

    def __init__(self, model_path):
        """
        Load the YOLOv11 model from *model_path* and set default tracking options.
        """
        self.model = YOLO(model_path)
        # Arguments forwarded to model.track() on every frame.
        self.tracking_config = {
            "tracker": "appIOs/configs/yolo11/bytetrack.yaml", # "/home/thsw/jcq/projects/yolov11/ultralytics-main/ultralytics/cfg/trackers/bytetrack.yaml",
            "conf": 0.25,       # detection confidence threshold
            "iou": 0.45,        # NMS IoU threshold
            "persist": True,    # keep track IDs across frames
            "verbose": False
        }
        self.frame_count = 0        # frames processed so far
        self.processing_time = 0    # last frame's processing time, in ms

    def process_frame(self, frame):
        """
        Run detection + tracking on a single frame.

        Returns (annotated_frame, result); on any exception logs the
        traceback and returns (original frame, None).
        """
        start_time = time.time()

        try:
            # Run YOLOv11 detection and tracking.
            results = self.model.track(
                source=frame,
                **self.tracking_config
            )

            # track() returns a list; only one image was submitted.
            result = results[0]

            # Draw boxes / track IDs onto a copy of the frame.
            processed_frame = result.plot()

            # Record timing in milliseconds.
            self.processing_time = (time.time() - start_time) * 1000
            self.frame_count += 1

            # Periodic diagnostics every 100 frames.
            if self.frame_count % 100 == 0:
                self._print_detection_info(result)

            return processed_frame, result

        except Exception as e:
            logger.error("YOLO处理异常: {}", format_exc())
            return frame, None

    def _print_detection_info(self, result):
        """
        Log detection count, number of distinct track IDs, and timing.
        """
        boxes = result.boxes
        if boxes is not None and len(boxes) > 0:
            detection_count = len(boxes)
            unique_ids = set()
            for box in boxes:
                if box.id is not None:
                    unique_ids.add(int(box.id[0]))

            logger.info(f"帧 {self.frame_count}: 检测到 {detection_count} 个目标, 追踪ID数: {len(unique_ids)}, 处理时间: {self.processing_time:.2f}ms")
        else:
            logger.info(f"帧 {self.frame_count}: 未检测到目标, 处理时间: {self.processing_time:.2f}ms")
|
||||
|
||||
class YOLOTrackerManager:
|
||||
def __init__(self, model_path, pull_url, push_url, request_id):
    """Coordinate a YOLOTracker with one of several frame sources (network stream, USB camera, local video, image directory, labelled zip)."""
    self.pull_url = pull_url
    self.push_url = push_url
    self.request_id = request_id
    self.tracker = YOLOTracker(model_path)
    self.stream = None                     # NetStream when pulling over the network
    self.videoStream = None                # cv2.VideoCapture for camera / local video
    self.videoType = Constant.INPUT_NONE   # which source kind is active
    self.localFile = ''                    # single image path (or current zip label entry)
    self.localPath = ''                    # image directory or labelled .zip path
    self.localFiles = []                   # file list for directory / zip sources
    self._currentFrame = None              # cached last frame (RGB)
    self.totalFrames = 0                   # frame count of the active source
    self.frameChanged = False              # set whenever the cached frame is replaced
    # NOTE(review): self._frameIndex is first assigned in _stop()/start*() —
    # confirm callers never read it before one of those runs.
||||
|
||||
def _stop(self):
    """Release any active video source and reset all playback state."""
    video = self.videoStream
    if video is not None:
        video.release()
        self.videoStream = None
    net = self.stream
    if net is not None:
        net.clear_pull_p(net.pull_p, self.request_id)
        self.stream = None
    # Reset file/frame bookkeeping to the idle state.
    self.localFile = ''
    self.localPath = ''
    self.localFiles = []
    self._currentFrame = None
    self.totalFrames = 0
    self._frameIndex = -1
    self.videoType = Constant.INPUT_NONE
    self.frameChanged = True
||||
|
||||
def startLocalFile(self, fileName):
    """Switch the input source to a single local image file."""
    self._stop()
    self._frameIndex = -1
    self.localFile = fileName
|
||||
|
||||
def startLocalDir(self, dirName):
    """Switch input to a directory of images (.jpg/.jpeg/.png), sorted by name."""
    self._stop()
    self.localPath = dirName
    files = [os.path.join(dirName, f)
             for f in os.listdir(dirName)
             if f.endswith(('.jpg', '.jpeg', '.png'))]
    files.sort()
    self.localFiles = files
    self.totalFrames = len(files)
    Helper.App.progressMax = self.totalFrames
    logger.info("本地目录打开: {}, 总帧数: {}", dirName, self.totalFrames)
    self._frameIndex = 0
|
||||
|
||||
def startLabelledZip(self, labelledPath, categoryPath):
    """Switch input to one category (e.g. train/val) inside a labelled dataset zip."""
    self._stop()
    self.localPath = labelledPath
    names = ZipFile(labelledPath).namelist()
    _, self.totalFrames = Helper.getYoloLabellingInfo(categoryPath, names, '')
    prefix = categoryPath + 'images/'
    self.localFiles = [name for name in names if prefix in name]
    logger.info(f"标注压缩文件{labelledPath}的{categoryPath}集共有{self.totalFrames}帧, 有效帧数: {len(self.localFiles)}")
    self._frameIndex = 0
    Helper.App.progressMax = self.totalFrames
|
||||
|
||||
def startUsbCamera(self, index = 0):
    """Switch input to USB camera *index* opened through OpenCV."""
    self._stop()
    self.videoStream = cv2.VideoCapture(index)
    self.videoType = Constant.INPUT_USB_CAMERA
    Helper.Sleep(200)  # give the device a moment to come up
    if self.videoStream.isOpened():
        # Live stream: effectively unbounded frame count.
        self.totalFrames = 0x7FFFFFFF
    else:
        logger.error("无法打开USB摄像头: {}", index)
        self.videoType = Constant.INPUT_NONE
|
||||
|
||||
def startLocalVideo(self, fileName):
    """Switch input to a local video file and read its frame count for progress."""
    self._stop()
    self.videoStream = cv2.VideoCapture(fileName)
    self.videoType = Constant.INPUT_LOCAL_VIDEO
    Helper.Sleep(200)  # give the decoder a moment to open the file
    if not self.videoStream.isOpened():
        logger.error("无法打开本地视频流: {}", fileName)
        self.videoType = Constant.INPUT_NONE
        return
    try:
        total = int(self.videoStream.get(cv2.CAP_PROP_FRAME_COUNT))
    except Exception:
        total = 0
    # int(...) can never be None, so the old `total if total is not None
    # else 0` was dead code; clamp instead, since CAP_PROP_FRAME_COUNT can
    # report a negative value when the container has no frame count.
    self.totalFrames = max(total, 0)
    Helper.App.progressMax = self.totalFrames
    logger.info("本地视频打开: {}, 总帧数: {}", fileName, self.totalFrames)
|
||||
|
||||
def startPull(self, url = ''):
    """Switch input to a network stream; a non-empty *url* replaces the stored pull URL."""
    self._stop()
    if url:
        self.pull_url = url
    logger.info("拉流地址: {}", self.pull_url)
    self.stream = NetStream(self.pull_url, self.push_url, self.request_id)
    self.stream.prepare_pull()
|
||||
|
||||
def getCurrentFrame(self):
    """Return a defensive copy of the cached frame, lazily fetching the first one.

    Returns None when no frame can be obtained from the active source.
    """
    if self._currentFrame is None:
        self._currentFrame = self.nextFrame()
    if self._currentFrame is not None:
        # Copy so callers cannot mutate the cached frame in place.
        return self._currentFrame.copy()
    return None
# Read-only property wrapper (Property_Rw is a project helper: getter, setter).
currentFrame = Property_Rw(getCurrentFrame, None)
|
||||
|
||||
def setFrameIndex(self, index):
    """Seek to frame *index*; only meaningful for seekable sources."""
    # No seekable source configured at all.
    if self.videoStream is None and len(self.localFiles) == 0:
        return
    # A cv2 capture that is not a local video (i.e. a live camera) cannot seek.
    if self.videoStream is not None and self.videoType != Constant.INPUT_LOCAL_VIDEO:
        return
    # Clamp into [0, totalFrames).
    if index < 0:
        index = 0
    if index >= self.totalFrames:
        index = self.totalFrames - 1
    if self.videoStream:
        self.videoStream.set(cv2.CAP_PROP_POS_FRAMES, index)
    # Park one slot early: nextFrame() increments for cv2 sources, landing
    # on *index*. NOTE(review): for directory sources nextFrame() reads at
    # _frameIndex before incrementing, i.e. at index - 1 — confirm intended.
    self._frameIndex = index - 1
    self._currentFrame = self.nextFrame()
    self.frameChanged = True
# Write-only property wrapper.
# NOTE(review): spelled Property_rW here but Property_Rw above — confirm both helpers exist.
frameIndex = Property_rW(setFrameIndex, 0)
|
||||
|
||||
def getLabels(self):
    """Return the YOLO label text for the current frame from the labelled zip.

    Reads entry self.localFile (set by nextFrame()) out of the archive at
    self.localPath and decodes it as UTF-8. Exceptions from a missing
    archive/entry propagate to the caller.
    """
    with ZipFile(self.localPath, 'r') as zip_ref:
        # Label files are plain UTF-8 text.
        return zip_ref.read(self.localFile).decode('utf-8')
    # (The old trailing `return ''` after the in-`with` return was
    # unreachable dead code and has been removed.)
|
||||
# Fetch the frame to be analysed.
def getAnalysisFrame(self, nextFlag):
    """Return (frame_or_None, changed).

    *changed* reports whether the cached frame had been replaced before this
    call (the flag is consumed here and re-set when nextFlag advances).
    nextFlag=True advances streaming sources to a fresh frame first.
    """
    frameChanged = self.frameChanged
    self.frameChanged = False
    if nextFlag: # streaming media: always advance
        self._currentFrame = self.nextFrame()
        self.frameChanged = True
    frame = self.currentFrame
    # NOTE(review): the currentFrame property already returns a copy, so this
    # copies twice — presumably defensive; costs one extra allocation.
    return frame.copy() if frame is not None else None, frameChanged
|
||||
|
||||
def nextFrame(self):
    """Fetch the next frame from whichever source is active.

    Checks sources in priority order: network stream, cv2 capture (camera /
    local video), image list (zip or directory), then single local file.
    Returns an RGB ndarray, or None when no frame could be read.
    """
    frame = None
    if self.stream:
        # Network pull stream.
        frame = self.stream.next_pull_frame()
    elif self.videoStream:
        # cv2.VideoCapture: USB camera or local video file.
        ret, frame = self.videoStream.read()
        self._frameIndex += 1
        if not ret:
            # Read failed: roll the index back and report no frame.
            self._frameIndex -= 1
            frame = None
    elif len(self.localFiles) > 0:
        if self.localPath.endswith('.zip'):
            # Labelled-dataset zip: walk image entries until the one at
            # _frameIndex. NOTE(review): the directory entry ('.../images/')
            # itself matches '/images/' and consumes the first index slot —
            # presumably why counts elsewhere subtract 1; confirm.
            index = -1
            for img_file in self.localFiles:
                if '/images/' in img_file:
                    if index == self._frameIndex:
                        # logger.warning(f'Loading image from zip file: {img_file}')
                        try:
                            with ZipFile(self.localPath, 'r') as zip_ref:
                                image_data = zip_ref.read(img_file)
                                nparr = np.frombuffer(image_data, np.uint8)
                                frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                                self._frameIndex += 1
                                # Remember the matching label entry for getLabels().
                                lable_file = img_file.replace('/images/', '/labels/').replace('.jpg', '.txt').replace('.png', '.txt')
                                self.localFile = lable_file
                        except Exception as e:
                            # logger.error(f"读取压缩文件 {self.localPath} 中的 {img_file} 失败: {e}")
                            frame = None
                        break
                    index += 1
        else:
            # Plain directory of images; wraps around at the end.
            if self._frameIndex < 0:
                self._frameIndex = 0
            if self._frameIndex >= len(self.localFiles):
                self._frameIndex = 0
            if self._frameIndex < len(self.localFiles):
                frame = cv2.imread(self.localFiles[self._frameIndex])
                if frame is None:
                    logger.error(f"无法读取目标目录 {self.localPath}中下标为 {self._frameIndex} 的视频文件 {self.localFiles[self._frameIndex]}")
                    self._frameIndex = -1
                    return
                self._frameIndex += 1
    elif self.localFile is not None and self.localFile != '':
        # Single local image file.
        frame = cv2.imread(self.localFile)
        if frame is None:
            logger.error("无法读取本地视频文件: {}", self.localFile)
            return
    if frame is not None:
        # OpenCV decodes BGR; the rest of the pipeline works in RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if self.totalFrames > 0:
        Helper.App.progress = self._frameIndex
    return frame
|
||||
|
||||
def test_yolo11_recognize(self, frame):
    """Run the YOLO pipeline on *frame* and return the annotated result."""
    return self.process_frame_with_yolo(frame, self.request_id)
|
||||
|
||||
def process_frame_with_yolo(self, frame, requestId):
    """
    Run YOLO detection/tracking on *frame* and overlay FPS / object-count text.

    Returns the annotated frame; on any error logs the traceback and
    returns the unmodified input frame.
    """
    try:
        # Detection + tracking (returns annotated frame and the raw result).
        processed_frame, detection_result = self.tracker.process_frame(frame)

        # FPS readout derived from the tracker's last processing time (ms);
        # the max(..., 1) guards against division by zero.
        fps_info = f"FPS: {1000/max(self.tracker.processing_time, 1):.1f}"
        cv2.putText(processed_frame, fps_info, (10, 30),
                   cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        # Object-count readout, when the result carries boxes.
        if detection_result and detection_result.boxes is not None:
            obj_count = len(detection_result.boxes)
            count_info = f"Objects: {obj_count}"
            cv2.putText(processed_frame, count_info, (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        return processed_frame

    except Exception as e:
        logger.error("YOLO处理异常:{}, requestId:{}", format_exc(), requestId)
        # Fall back to the unprocessed frame on failure.
        return frame
|
||||
|
||||
def get_gray_mask(self, frame):
    """Build a binary mask (uint8, 0/255) of 'gray' pixels in *frame*.

    A pixel counts as gray when all pairwise channel differences are below
    20; blue-shifted shadow pixels (b > r with b - r < 40) are exempted
    from the r/b test.
    """
    maskMat = np.zeros(frame.shape[:2], dtype=np.uint8)

    # Split channels with plain numpy indexing (same layout cv2.split used)
    # and widen to int16 so the subtractions cannot wrap around in uint8.
    b = frame[:, :, 0].astype(np.int16)
    g = frame[:, :, 1].astype(np.int16)
    r = frame[:, :, 2].astype(np.int16)

    diff_rg = np.abs(r - g)
    diff_rb = np.abs(r - b)
    diff_gb = np.abs(g - b)
    # Shadows shift toward blue; tolerate a moderate blue excess.
    is_shadow = (b > r) & (b - r < 40)

    # BUG FIX: the original wrote `diff_rb < 20| is_shadow`, which parses as
    # `diff_rb < (20 | is_shadow)` — the shadow exemption never took effect.
    gray_pixels = (diff_rg < 20) & ((diff_rb < 20) | is_shadow) & (diff_gb < 20)

    # Mark qualifying pixels white.
    maskMat[gray_pixels] = 255

    return maskMat
|
||||
|
||||
def debugLine(self, line, y_intersect):
    """Return (angle_deg, length, x_intersect) for segment *line* = (x1, y1, x2, y2).

    angle_deg is the segment's angle to the horizontal, normalized into
    [0, 180); x_intersect is where the extended line crosses the horizontal
    y = y_intersect. NOTE(review): divides by (y2 - y1) — a perfectly
    horizontal segment raises ZeroDivisionError; confirm callers filter those.
    """
    x1, y1, x2, y2 = line
    dx, dy = x2 - x1, y2 - y1
    length = np.linalg.norm([dx, dy])
    # atan2 gives (-180, 180]; fold negatives up into [0, 180).
    angle_deg = math.degrees(math.atan2(dy, dx))
    if angle_deg < 0:
        angle_deg += 180
    x_intersect = dx * (y_intersect - y1) / dy + x1
    return angle_deg, length, x_intersect
|
||||
def test_highway_recognize(self, frame, debugFlag = False):
    """Experimental road-surface detection.

    Masks gray (road-colored) pixels, removes noise and small regions, then
    looks for lane-line candidates. Returns (overlay, color_mask,
    whiteLineMat) on success, or the untouched copy of *frame* on error.
    NOTE: mutates *frame* in place (the top strip is painted over).
    """
    processed_frame = frame.copy()

    try:
        IGNORE_HEIGHT = 100
        y_intersect = frame.shape[0] / 2
        # Paint the top strip (sky) so it cannot match the gray mask.
        frame[:IGNORE_HEIGHT, :] = (255, 0, 0)

        gray_mask = self.get_gray_mask(frame)

        kernel = np.ones((5, 5), np.uint8) # morphological kernel for noise removal
        # NOTE(review): the original comment said "open" (erode then dilate),
        # but the code erodes twice — confirm which was intended.
        gray_mask = cv2.erode(gray_mask, kernel)
        gray_mask = cv2.erode(gray_mask, kernel)

        # Drop connected regions smaller than 10000 px.
        contours, _ = cv2.findContours(gray_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Rebuild a mask keeping only the large regions.
        filtered_mask = np.zeros_like(gray_mask)
        for contour in contours:
            area = cv2.contourArea(contour)
            if area >= 10000: # fill qualifying contour regions
                cv2.fillPoly(filtered_mask, [contour], 255)

        gray_mask = filtered_mask # replace with the filtered mask
        edges = cv2.Canny(frame, 100, 200) # edge detection
        road_edges = cv2.bitwise_and(edges, edges, mask=filtered_mask) # edges inside the road region only

        # Restrict the original image to the road mask, then look for lane lines.
        whiteLineMat = cv2.bitwise_and(processed_frame, processed_frame, mask=filtered_mask)
        whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_RGB2GRAY) # grayscale
        # Sobel edge detection (horizontal gradient highlights vertical-ish lines).
        whiteLineMat = cv2.Sobel(whiteLineMat, cv2.CV_8U, 1, 0, ksize=3)
        tempMat = whiteLineMat.copy()
        # whiteLineMat = cv2.Canny(whiteLineMat, 100, 200)
        lines = cv2.HoughLinesP(tempMat, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
        whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB)

        # logger.info(f"{lines.shape[0]} lines: ")
        # if lines is not None:
        #     for line in lines:
        #         x1, y1, x2, y2 = line[0]
        #         cv2.line(whiteLineMat, (x1, y1), (x2, y2), (255, 0, 0), 2)


        # Build a colored mask (green marks the detected road surface).
        color_mask = cv2.cvtColor(gray_mask, cv2.COLOR_GRAY2RGB)
        color_mask[:] = (0, 255, 0) # green
        color_mask = cv2.bitwise_and(color_mask, color_mask, mask=filtered_mask)

        # Blend the road marking over the original image.
        overlay = cv2.addWeighted(processed_frame, 0.7, color_mask, 0.3, 0)

        # # On top of road_edges, detect solid lines
        # lines = cv2.HoughLinesP(road_edges, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
        # logger.info(f"{lines.shape[0]} lines: ")
        # linesWithAngle = []
        # # if lines is not None:
        # for index, line in enumerate(lines):
        #     angle_deg, length, x_intersect = self.debugLine(line[0], y_intersect)
        #     linesWithAngle.append((line, angle_deg, x_intersect))
        #     if debugFlag:
        #         logger.info(f'line {index + 1}: {line}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')

        # # Cluster linesWithAngle into two groups by angle
        # # using a simple custom K-means implementation
        # line_data = np.array([[angle, x_intersect] for line, angle, x_intersect in linesWithAngle])
        # if len(line_data) > 0:
        #     labels = self._simple_kmeans(line_data, n_clusters=2, random_state=2, random_state=0)
        #     # log the size of both clusters
        #     logger.info(f"聚类结果:{np.bincount(labels)}")
        #     if debugFlag:
        #         lines0 = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == 0]
        #         lines1 = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == 1]
        #         # log every segment in lines0
        #         for index, line in enumerate(lines0):
        #             angle_deg, length, x_intersect = self.debugLine(line[0][0], y_intersect)
        #             logger.info(f'聚类0: {line[0]}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')

        #         for index, line in enumerate(lines1):
        #             angle_deg, length, x_intersect = self.debugLine(line[0][0], y_intersect)
        #             logger.info(f'聚类1: {line[0]}, 线长:{length:.2f}, 夹角:{angle_deg:.2f}°, 交点:({x_intersect:.2f}, {y_intersect:.2f})')

        #     # keep the larger cluster
        #     dominant_cluster = np.argmax(np.bincount(labels))
        #     # draw the dominant cluster's lines
        #     dominant_lines = [line for idx, line in enumerate(linesWithAngle) if labels[idx] == dominant_cluster]

        #     for line, angle, x_intersect in dominant_lines:
        #         cv2.line(overlay, (int(line[0][0]), int(line[0][1])), (int(line[0][2]), int(line[0][3])), (255, 0, 0), 2)

        return overlay, color_mask, whiteLineMat # cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB) # cv2.cvtColor(road_edges, cv2.COLOR_GRAY2RGB)

    except Exception as e:
        logger.error("路面识别异常:{}", format_exc())
        # On failure return the untouched copy of the input frame.
        return processed_frame
|
||||
|
||||
# def _simple_kmeans(self, data, n_clusters=2, max_iter=100, random_state=0):
|
||||
# """
|
||||
# 使用K-means算法对数据进行聚类
|
||||
|
||||
# 参数:
|
||||
# data: array-like, 形状为 (n_samples, n_features) 的输入数据
|
||||
# n_clusters: int, 聚类数量,默认为2
|
||||
# max_iter: int, 最大迭代次数,默认为100
|
||||
# random_state: int, 随机种子,用于初始化质心,默认为0
|
||||
|
||||
# 返回:
|
||||
# labels: array, 形状为 (n_samples,) 的聚类标签数组
|
||||
# """
|
||||
# np.random.seed(random_state)
|
||||
|
||||
# # 随机选择初始质心
|
||||
# centroids_idx = np.random.choice(len(data), size=n_clusters, replace=False)
|
||||
# centroids = data[centroids_idx].copy()
|
||||
|
||||
# # 迭代优化质心位置
|
||||
# for _ in range(max_iter):
|
||||
# # 为每个数据点分配最近的质心标签
|
||||
# labels = np.zeros(len(data), dtype=int)
|
||||
# for i, point in enumerate(data):
|
||||
# distancesi=ids - point, ax(centroids - point, axis=1) ce置为- # 情况如果d sfnpcnsy>d9e,则置为>180 -9dis ini作为新质心
|
||||
# new_centroids[c] = data[np.random.choice(len(data))]
|
||||
|
||||
# # 检查收敛条件
|
||||
# if np.allclose(centroids, new_centroids):
|
||||
# break
|
||||
|
||||
# centroids = new_centroids
|
||||
|
||||
# return labels
|
||||
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
import subprocess as sp
|
||||
from traceback import format_exc
|
||||
import cv2, time
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from DrGraph.utils.Helper import Helper
|
||||
from DrGraph.utils.Exception import ServiceException
|
||||
from DrGraph.utils.Constant import Constant
|
||||
|
||||
class NetStream:
|
||||
def __init__(self, pull_url, push_url, request_id):
    """Wrap an ffmpeg pull pipe for *pull_url*, tagged with *request_id* for logging."""
    self.pull_url = pull_url
    self.push_url = push_url
    self.request_id = request_id
    self.pull_p = None  # ffmpeg subprocess handle, created lazily

    # Raw NV12 frame geometry: 1920x1080 luma plus half-height chroma plane.
    self.width = 1920
    self.height = 1080 * 3 // 2
    self.width_height_3 = 1920 * 1080 * 3 // 2
    self.w_2 = 960
    self.h_2 = 540

    self.frame_count = 0
    self.start_time = time.time()
|
||||
|
||||
def clear_pull_p(self,pull_p, requestId):
    """
    Shut down the ffmpeg pull process *pull_p*, escalating to kill() on failure.

    On a failed graceful shutdown the process is force-killed and the
    original exception is re-raised so callers see the error.
    """
    try:
        # poll() is None while the child is still running.
        if pull_p and pull_p.poll() is None:
            logger.info("关闭拉流管道, requestId:{}", requestId)
            if pull_p.stdout:
                pull_p.stdout.close()
            pull_p.terminate()
            pull_p.wait(timeout=30)
            logger.info("拉流管道已关闭, requestId:{}", requestId)
    except Exception as e:
        logger.error("关闭拉流管道异常: {}, requestId:{}", format_exc(), requestId)
        # Graceful shutdown failed; force-kill before propagating.
        if pull_p and pull_p.poll() is None:
            pull_p.kill()
            pull_p.wait(timeout=30)
        raise e
|
||||
|
||||
def start_pull_p(self, pull_url, requestId):
    """
    Spawn the ffmpeg process that pulls *pull_url* and writes raw video to stdout.

    Stores the Popen handle in self.pull_p and returns it; exceptions are
    logged and re-raised. NOTE(review): the ffmpeg path is hard-coded to a
    Windows location — confirm it should come from configuration instead.
    """
    try:
        command = ['D:/DrGraph/DSP/ffmpeg.exe']
        # if pull_url.startswith("rtsp://"):
        #     command.extend(['-timeout', '20000000', '-rtsp_transport', 'tcp'])
        # if pull_url.startswith("http") or pull_url.startswith("rtmp"):
        #     command.extend(['-rw_timeout', '20000000'])
        command.extend(['-re',                 # read input at its native frame rate
                        '-y',
                        '-an',                 # drop audio
                        # '-hwaccel', 'cuda', cuvid
                        '-c:v', 'h264_cuvid',  # NVIDIA hardware H.264 decoder
                        # '-resize', self.wah,
                        '-i', pull_url,
                        '-f', 'rawvideo',      # raw frames on stdout
                        # '-pix_fmt', 'bgr24',
                        '-r', '25',
                        '-'])
        self.pull_p = sp.Popen(command, stdout=sp.PIPE)
        return self.pull_p
    except ServiceException as s:
        logger.error("构建拉流管道ServiceException异常: url={}, {}, requestId:{}", pull_url, s.msg, requestId)
        raise s
    except Exception as e:
        logger.error("构建拉流管道Exception异常:url={}, {}, requestId:{}", pull_url, format_exc(), requestId)
        raise e
|
||||
|
||||
def pull_read_video_stream(self):
    """
    Read one raw NV12 frame from the ffmpeg pipe and convert it to BGR.

    Lazily starts the pull process if needed. Returns the decoded (and
    possibly half-sized) frame, or None when the pipe yields no data or
    decoding fails; ServiceException is re-raised after cleanup.
    """
    result = None
    try:
        if self.pull_p is None:
            self.start_pull_p(self.pull_url, self.request_id)
        in_bytes = self.pull_p.stdout.read(self.width_height_3)
        if in_bytes is not None and len(in_bytes) > 0:
            try:
                # NV12 layout: full-height luma plane + half-height chroma.
                result = (np.frombuffer(in_bytes, np.uint8)).reshape((self.height, self.width))
                result = cv2.cvtColor(result, cv2.COLOR_YUV2BGR_NV12)
                # Halve oversized frames to cap downstream processing cost.
                if result.shape[1] > Constant.pull_frame_width:
                    result = cv2.resize(result, (result.shape[1] // 2, result.shape[0] // 2), interpolation=cv2.INTER_LINEAR)
            except Exception:
                # BUG FIX: `requestId` was an undefined name here (NameError
                # would have masked the real decoding error); use the
                # instance's request id.
                logger.error("视频格式异常:{}, requestId:{}", format_exc(), self.request_id)
                # NOTE(review): ExceptionType is not among this module's
                # visible imports — confirm it is in scope before this raises.
                raise ServiceException(ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[0],
                                       ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[1])
    except ServiceException as s:
        logger.error("ServiceException 读流异常: {}, requestId:{}", s.msg, self.request_id)
        self.clear_pull_p(self.pull_p, self.request_id)
        self.pull_p = None
        result = None
        raise s
    except Exception:
        logger.error("Exception 读流异常:{}, requestId:{}", format_exc(), self.request_id)
        self.clear_pull_p(self.pull_p, self.request_id)
        self.pull_p = None
        # NOTE(review): zeroing the geometry makes the next read() consume
        # the whole pipe (read(None)) — confirm this reset is intentional.
        self.width = None
        self.height = None
        self.width_height_3 = None
        result = None
        logger.error("读流异常:{}, requestId:{}", format_exc(), self.request_id)
    return result
|
||||
|
||||
def prepare_pull(self):
    """Ensure the pull subprocess is running, recording the start timestamp."""
    if self.pull_p is not None:
        return
    self.start_time = time.time()
    self.start_pull_p(self.pull_url, self.request_id)
|
||||
|
||||
def next_pull_frame(self):
    """Return the next decoded frame, or None when the puller is not running."""
    if self.pull_p is None:
        logger.error(f'pull_p is None, requestId: {self.request_id}')
        return None
    return self.pull_read_video_stream()
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
"""Custom exceptions and error handlers for the chat agent application."""
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from typing import Union
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
import json
|
||||
from loguru import logger
|
||||
|
||||
from starlette.status import (
|
||||
HTTP_400_BAD_REQUEST,
|
||||
HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -25,14 +27,33 @@ class FullHxfResponseModel(BaseModel):
|
|||
message: Optional[str]
|
||||
|
||||
class HxfResponse(JSONResponse):
|
||||
def __init__(self, response: Union[BaseModel, Dict[str, Any]]):
|
||||
if isinstance(response, BaseModel):
|
||||
def __new__(cls, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]], StreamingResponse]):
|
||||
# 如果是StreamingResponse,直接返回,不进行JSON包装
|
||||
if isinstance(response, StreamingResponse):
|
||||
return response
|
||||
# 否则创建HxfResponse实例
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]]]):
|
||||
code = 0
|
||||
if isinstance(response, list):
|
||||
# 处理BaseModel对象列表
|
||||
if all(isinstance(item, BaseModel) for item in response):
|
||||
data_dict = [item.model_dump(mode='json') for item in response]
|
||||
else:
|
||||
data_dict = response
|
||||
elif isinstance(response, BaseModel):
|
||||
# 处理单个BaseModel对象
|
||||
data_dict = response.model_dump(mode='json')
|
||||
else:
|
||||
# 处理字典或其他可JSON序列化的数据
|
||||
data_dict = response
|
||||
|
||||
if 'success' in data_dict and data_dict['success'] == False:
|
||||
code = -1
|
||||
|
||||
content = {
|
||||
"code": 0,
|
||||
"code": code,
|
||||
"status": status.HTTP_200_OK,
|
||||
"data": data_dict,
|
||||
"error": None,
|
||||
|
|
@ -50,13 +71,24 @@ class HxfErrorResponse(JSONResponse):
|
|||
def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
|
||||
"""Return a JSON error response."""
|
||||
if isinstance(message, Exception):
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": message.status_code,
|
||||
"data": None,
|
||||
"error": message.details,
|
||||
"message": message.message
|
||||
}
|
||||
msg = message.message if 'message' in message.__dict__ else str(message)
|
||||
logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
|
||||
if isinstance(message, TypeError):
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": 500,
|
||||
"data": None,
|
||||
"error": f"错误类型: {type(message)} 错误信息: {message}",
|
||||
"message": msg
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
"code": -1,
|
||||
"status": message.status_code,
|
||||
"data": None,
|
||||
"error": None,
|
||||
"message": msg
|
||||
}
|
||||
else:
|
||||
content = {
|
||||
"code": -1,
|
||||
|
|
|
|||
|
|
@ -310,6 +310,7 @@ class ConversationResponse(BaseResponse, ConversationBase):
|
|||
is_archived: bool
|
||||
message_count: int = 0
|
||||
last_message_at: Optional[datetime] = None
|
||||
messages: Optional[List["MessageResponse"]] = None
|
||||
|
||||
|
||||
# Message schemas
|
||||
|
|
@ -345,11 +346,11 @@ class ChatRequest(BaseModel):
|
|||
message: str = Field(..., min_length=1, max_length=10000)
|
||||
stream: bool = Field(default=False)
|
||||
use_knowledge_base: bool = Field(default=False)
|
||||
knowledge_base_id: Optional[int] = Field(None, description="Knowledge base ID for RAG mode")
|
||||
knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
|
||||
use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
|
||||
use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")
|
||||
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(None, ge=1, le=8192)
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(default=2048, ge=1, le=8192)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
|
|
@ -458,6 +459,7 @@ class DocumentChunk(BaseModel):
|
|||
chunk_index: int
|
||||
start_char: Optional[int] = None
|
||||
end_char: Optional[int] = None
|
||||
vector_id: Optional[str] = None
|
||||
|
||||
|
||||
class DocumentChunksResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
from enum import Enum, unique
|
||||
|
||||
# Background palette for labelling overlays: 20 high-contrast colors.
CommonColors_Backgnd = [
    "#FF0000",  # red
    "#00FF00",  # green
    "#0000FF",  # blue
    "#FFFF00",  # yellow
    "#FF00FF",  # magenta
    "#00FFFF",  # cyan
    "#FFA500",  # orange
    "#800080",  # purple
    "#FFC0CB",  # pink
    "#008000",  # dark green
    "#000080",  # navy
    "#800000",  # maroon
    "#808000",  # olive
    "#008080",  # teal
    "#808080",  # gray
    "#FF0080",  # rose
    "#0080FF",  # sky blue
    "#FF8000",  # orange-red
    "#8000FF",  # violet
    "#00FF80",  # sea green
]

# Text color paired index-for-index with CommonColors_Backgnd:
# white text on dark backgrounds, black text on light ones.
CommonColors_Foregnd = [
    "#FFFFFF",  # on red
    "#000000",  # on green
    "#FFFFFF",  # on blue
    "#000000",  # on yellow
    "#FFFFFF",  # on magenta
    "#000000",  # on cyan
    "#000000",  # on orange
    "#FFFFFF",  # on purple
    "#000000",  # on pink
    "#FFFFFF",  # on dark green
    "#FFFFFF",  # on navy
    "#FFFFFF",  # on maroon
    "#FFFFFF",  # on olive
    "#FFFFFF",  # on teal
    "#FFFFFF",  # on gray
    "#FFFFFF",  # on rose
    "#000000",  # on sky blue
    "#000000",  # on orange-red
    "#FFFFFF",  # on violet
    "#000000",  # on sea green
]
|
||||
|
||||
@unique
class LabellingKind(Enum):
    """Current interaction mode of the labelling canvas."""
    Unknown = 0
    Select = 1
    Create = 2
    TempDrag = 3    # transient drag of a shape
    TempResize = 4  # transient resize via a handle
    TempRotate = 5  # transient rotation via a handle
|
||||
|
||||
@unique
class Meta(Enum):
    """Geometric primitive kind of a drawn annotation."""
    Unknown = 0
    Line = 1
    Rectangle = 2
    Ellipse = 3
    Polygon = 4
    Text = 5
|
||||
|
||||
@unique
class AiAlg(Enum):
    """Identifier of the AI algorithm backing a session.

    BUGFIX: the original declaration had trailing commas
    (``FashionMNIST = 1,`` etc.), which silently made those member
    values 1-tuples like ``(1,)`` while ``Coco8`` stayed a plain int.
    All values are now consistent ints.
    """
    Unknown = 0
    FashionMNIST = 1
    ColorDetector = 2
    Face = 3
    Coco8 = 4
|
||||
|
||||
|
||||
@unique
class Response(Enum):
    """Server response severity; value is (code, human-readable label)."""
    OK = (0, "OK - 响应成功")
    DEBUG = (1, "DEBUG - 正常调试")
    WARNING = (2, "FAIL - 响应警告")
    ERROR = (3, "ERROR - 响应错误")
    EXCEPTION = (4, "EXCEPTION - 响应异常")
    CRITICAL = (5, "CRITICAL - 响应严重错误")
|
||||
|
||||
class Flag(Enum):
    """Boolean origin flag: Paint-surface (True) vs Client (False)."""
    Paint = True
    Client = False
|
||||
|
||||
|
||||
class Align(Enum):
    """Docking/alignment mode of a control within its parent."""
    Left = 0
    Top = 1
    Right = 2
    Bottom = 3
    Client = 4  # fill the remaining client area
    Custom = 5
|
||||
|
||||
class Anchor(Enum):
    """Edge-anchoring of a control.

    Values are powers of two (1/2/4/8) — NOTE(review): they look intended
    to be OR-combined like a bitmask, which plain Enum does not support;
    consider enum.Flag if combinations are needed. Kept as-is.
    """
    Left = 1
    Top = 2
    Right = 4
    Bottom = 8
|
||||
|
||||
class IconPos(Enum):
    """Placement of an icon relative to a control's caption."""
    NoIcon = 0
    OnlyIcon = 1
    IconLeft = 2
    IconRight = 3
    IconTop = 4
    IconBottom = 5
|
||||
|
||||
class ControlEvent(Enum):
    """Mouse/interaction events that a control can emit."""
    MouseEnter = 0
    MouseLeave = 1
    MouseMove = 2
    MouseDown = 3
    MouseUp = 4
    Click = 5
    DblClick = 6
|
||||
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
import time, websockets, json, inspect, os
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from typing import Dict, Set, Callable, Any, Optional
|
||||
from traceback import format_exc
|
||||
|
||||
from DrGraph.utils.Helper import *
|
||||
import DrGraph.utils.vclEnums as enums
|
||||
|
||||
class WebSocketServer:
|
||||
"""
|
||||
WebSocket服务器类
|
||||
提供WebSocket服务器功能,包括客户端连接管理、消息处理和广播功能
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "localhost", port: int = 8765):
    """Create a server bound to *host*:*port*; nothing listens until start()."""
    self.host = host
    self.port = port
    # message "type" string -> async handler(websocket, payload)
    self.message_handlers: Dict[str, Callable] = {}
    # populated by start(); None while not serving
    self.server = None
    # optional async callback(websocket, connected: bool)
    self.onClientConnect = None
    # built-in handler for file-browsing requests
    self.register_handler("file", self.handle_file)
|
||||
|
||||
def register_handler(self, message_type: str, handler: Callable):
    """Map *message_type* to *handler*, replacing any existing mapping."""
    self.message_handlers[message_type] = handler
|
||||
|
||||
async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t = 'default'):
    """Serialize *message* (dict or str) and send it to *websocket* as text.

    The payload is prefixed with ``'t'`` so the client-side dispatcher can
    tell text frames from binary frames (which response_binary prefixes
    with ``'b'``).
    """
    if isinstance(message, dict):
        message = json.dumps(message, ensure_ascii=False)
    if Helper.AppFlag_SaveLog:
        logger.warning(f"发送消息: {message}")
    # 't' marks a text frame for the client.
    await websocket.send(f't{message}')
    # Record which function requested the send.
    # NOTE(review): inspect.stack() is relatively expensive per message.
    caller = inspect.stack()[1]
    logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")
|
||||
|
||||
async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data):
    """Send *data* to *websocket* as a binary frame prefixed with ``b'b'``.

    handle_message routes bytes, bytearray AND np.ndarray responses here.
    BUGFIX: the ndarray -> bytes conversion was commented out, so
    ``b'b' + data`` would fail for arrays; convert explicitly first.
    """
    if isinstance(data, np.ndarray):
        data = data.tobytes()
    await websocket.send(bytearray(b'b' + data))
|
||||
|
||||
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str):
    """Parse one incoming JSON text frame and dispatch it by its "type" field.

    Expected shape: ``{"type": <str>, "data": <payload>}``. A registered
    handler is awaited with (websocket, payload); a bytes/bytearray/ndarray
    result is sent back as a binary frame, any other non-None result as a
    text frame. Unknown types and malformed JSON are answered with an error
    response; a broken pipe is left to the caller's disconnect handling.
    """
    try:
        data = json.loads(message)
        message_type = data.get("type")
        payload = data.get("data")
        if Helper.AppFlag_SaveLog:
            logger.warning(f"收到消息: {message}")

        if message_type in self.message_handlers:
            response = await self.message_handlers[message_type](websocket, payload)
            if response is not None:
                # Binary-like results go out as binary frames, rest as text.
                if isinstance(response, (bytes, bytearray, np.ndarray)):
                    # print("发送图片数据")
                    await self.response_binary(websocket, response);
                else:
                    await self.response_text(websocket, response, message_type);
        elif message_type == 'echo':
            # Latency probe: echo the payload back under a new type.
            print("echo message ", int(time.time() * 1000))
            data["type"] = "echo_response"
            await self.response_text(websocket, data);
        else:
            logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
            await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
    except json.JSONDecodeError:
        logger.error(f"无效的JSON格式 - {message}")
        await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
    except BrokenPipeError as e:
        logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
        # A broken pipe means the connection is already gone; nothing to send
        # back -- let the caller's ConnectionClosed handling clean up.
    except Exception as e:
        logger.error(f"处理消息时出错: {format_exc()}")
        await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {format_exc()}"))
|
||||
|
||||
async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""):
    """
    Serve one client connection for its whole lifetime.

    Notifies onClientConnect (if set) on connect and -- via ``finally`` --
    on any form of disconnect, pumping every incoming message through
    handle_message until the connection closes.

    Parameters:
        websocket (websockets.WebSocketServerProtocol): the client connection
        path (str): the request path
    """
    logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接")
    if self.onClientConnect:
        await self.onClientConnect(websocket, True)
    try:
        async for message in websocket:
            await self.handle_message(websocket, message)
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接已关闭")
    except BrokenPipeError:
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接BrokenPipeError")
    finally:
        # Always report the disconnect, no matter how the loop ended.
        if self.onClientConnect:
            await self.onClientConnect(websocket, False)
|
||||
|
||||
async def start(self):
    """
    Start serving on (self.host, self.port) and keep the server handle
    in self.server so stop() can shut it down later.
    """
    self.server = await websockets.serve(self.handle_client, self.host, self.port)
    logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")
|
||||
|
||||
async def stop(self):
    """Close the server and wait for shutdown; no-op when never started."""
    if not self.server:
        return
    self.server.close()
    await self.server.wait_closed()
    logger.info("WebSocket服务器已停止")
|
||||
|
||||
async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: str):
|
||||
logger.info(f"接收文件: {payload}, {payload['command']}")
|
||||
command = payload['command']
|
||||
if command == 'dir':
|
||||
path = payload.get('path', '/')
|
||||
files = []
|
||||
folders = []
|
||||
try:
|
||||
# 获取目录下的所有文件和文件夹
|
||||
if os.path.exists(path) and os.path.isdir(path):
|
||||
with os.scandir(path) as entries:
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
files.append(entry.name)
|
||||
elif entry.is_dir():
|
||||
folders.append(entry.name)
|
||||
else:
|
||||
logger.warning(f"路径不存在或不是目录: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录时出错: {path}, 错误: {e}")
|
||||
|
||||
# 合并文件和文件夹列表
|
||||
all_items = folders + files
|
||||
logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")
|
||||
|
|
@ -2,7 +2,7 @@ enable_file_log: true
|
|||
enable_stderr: true
|
||||
base_path: "webIOs/output/logs"
|
||||
log_name: "th_agenter_web.log"
|
||||
log_fmt: "<green>{time: HH:mm:ss.SSS}</green> [<level>{level}</level>] - <level>{message}</level> @ <cyan>{extra[relative_path]}:{line}</cyan> in <blue>{function}</blue>"
|
||||
log_fmt: "<green>{time: HH:mm:ss.SSS}</green> [<level>{level:7}</level>] - <level>{message}</level> @ <cyan>{extra[relative_path]}:{line}</cyan> in <blue>{function}</blue>"
|
||||
level: "INFO"
|
||||
rotation: "00:00"
|
||||
retention: "1 days"
|
||||
|
|
|
|||
Loading…
Reference in New Issue