diff --git a/backend/alembic/README b/backend/alembic/README index 5d61b6b..e0d0858 100644 --- a/backend/alembic/README +++ b/backend/alembic/README @@ -1,4 +1 @@ -Generic single-database configuration with an async dbapi. - -alembic revision --autogenerate -m "init" -alembic upgrade head \ No newline at end of file +Generic single-database configuration with an async dbapi. \ No newline at end of file diff --git a/backend/alembic/versions/424646027786_initial_migration.py b/backend/alembic/versions/424646027786_initial_migration.py new file mode 100644 index 0000000..a44a110 --- /dev/null +++ b/backend/alembic/versions/424646027786_initial_migration.py @@ -0,0 +1,359 @@ +"""Initial migration + +Revision ID: 424646027786 +Revises: +Create Date: 2025-12-16 09:56:45.172954 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '424646027786' +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('agent_configs', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('enabled_tools', sa.JSON(), nullable=False), + sa.Column('max_iterations', sa.Integer(), nullable=False), + sa.Column('temperature', sa.String(length=10), nullable=False), + sa.Column('system_message', sa.Text(), nullable=True), + sa.Column('verbose', sa.Boolean(), nullable=False), + sa.Column('model_name', sa.String(length=100), nullable=False), + sa.Column('max_tokens', sa.Integer(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_default', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs')) + ) + op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False) + op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False) + op.create_table('conversations', + sa.Column('title', sa.String(length=200), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('knowledge_base_id', sa.Integer(), nullable=True), + sa.Column('system_prompt', sa.Text(), nullable=True), + sa.Column('model_name', sa.String(length=100), nullable=False), + sa.Column('temperature', sa.String(length=10), nullable=False), + sa.Column('max_tokens', sa.Integer(), nullable=False), + sa.Column('is_archived', sa.Boolean(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + 
sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations')) + ) + op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False) + op.create_table('database_configs', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('db_type', sa.String(length=20), nullable=False), + sa.Column('host', sa.String(length=255), nullable=False), + sa.Column('port', sa.Integer(), nullable=False), + sa.Column('database', sa.String(length=100), nullable=False), + sa.Column('username', sa.String(length=100), nullable=False), + sa.Column('password', sa.Text(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_default', sa.Boolean(), nullable=False), + sa.Column('connection_params', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')), + sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type')) + ) + op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False) + op.create_table('documents', + sa.Column('knowledge_base_id', sa.Integer(), nullable=False), + sa.Column('filename', sa.String(length=255), nullable=False), + sa.Column('original_filename', sa.String(length=255), nullable=False), + sa.Column('file_path', sa.String(length=500), nullable=False), + sa.Column('file_size', sa.Integer(), nullable=False), + sa.Column('file_type', sa.String(length=50), nullable=False), + sa.Column('mime_type', sa.String(length=100), nullable=True), + sa.Column('is_processed', sa.Boolean(), nullable=False), + sa.Column('processing_error', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('doc_metadata', sa.JSON(), nullable=True), + 
sa.Column('chunk_count', sa.Integer(), nullable=False), + sa.Column('embedding_model', sa.String(length=100), nullable=True), + sa.Column('vector_ids', sa.JSON(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_documents')) + ) + op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False) + op.create_table('excel_files', + sa.Column('original_filename', sa.String(length=255), nullable=False), + sa.Column('file_path', sa.String(length=500), nullable=False), + sa.Column('file_size', sa.Integer(), nullable=False), + sa.Column('file_type', sa.String(length=50), nullable=False), + sa.Column('sheet_names', sa.JSON(), nullable=False), + sa.Column('default_sheet', sa.String(length=100), nullable=True), + sa.Column('columns_info', sa.JSON(), nullable=False), + sa.Column('preview_data', sa.JSON(), nullable=False), + sa.Column('data_types', sa.JSON(), nullable=True), + sa.Column('total_rows', sa.JSON(), nullable=True), + sa.Column('total_columns', sa.JSON(), nullable=True), + sa.Column('is_processed', sa.Boolean(), nullable=False), + sa.Column('processing_error', sa.Text(), nullable=True), + sa.Column('last_accessed', sa.DateTime(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files')) + ) + op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False) + op.create_table('knowledge_bases', + sa.Column('name', sa.String(length=100), nullable=False), + 
sa.Column('description', sa.Text(), nullable=True), + sa.Column('embedding_model', sa.String(length=100), nullable=False), + sa.Column('chunk_size', sa.Integer(), nullable=False), + sa.Column('chunk_overlap', sa.Integer(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('vector_db_type', sa.String(length=50), nullable=False), + sa.Column('collection_name', sa.String(length=100), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases')) + ) + op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False) + op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False) + op.create_table('llm_configs', + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('provider', sa.String(length=50), nullable=False), + sa.Column('model_name', sa.String(length=100), nullable=False), + sa.Column('api_key', sa.String(length=500), nullable=False), + sa.Column('base_url', sa.String(length=200), nullable=True), + sa.Column('max_tokens', sa.Integer(), nullable=False), + sa.Column('temperature', sa.Float(), nullable=False), + sa.Column('top_p', sa.Float(), nullable=False), + sa.Column('frequency_penalty', sa.Float(), nullable=False), + sa.Column('presence_penalty', sa.Float(), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_default', sa.Boolean(), nullable=False), + sa.Column('is_embedding', sa.Boolean(), nullable=False), + sa.Column('extra_config', sa.JSON(), nullable=True), + sa.Column('usage_count', sa.Integer(), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + 
sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs')) + ) + op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False) + op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False) + op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False) + op.create_table('messages', + sa.Column('conversation_id', sa.Integer(), nullable=False), + sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False), + sa.Column('message_metadata', sa.JSON(), nullable=True), + sa.Column('context_documents', sa.JSON(), nullable=True), + sa.Column('prompt_tokens', sa.Integer(), nullable=True), + sa.Column('completion_tokens', sa.Integer(), nullable=True), + sa.Column('total_tokens', sa.Integer(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_messages')) + ) + op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False) + op.create_table('roles', + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('code', sa.String(length=100), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('is_system', sa.Boolean(), nullable=False), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('id', 
sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_roles')) + ) + op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True) + op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False) + op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True) + op.create_table('table_metadata', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('table_name', sa.String(length=100), nullable=False), + sa.Column('table_schema', sa.String(length=50), nullable=False), + sa.Column('table_type', sa.String(length=20), nullable=False), + sa.Column('table_comment', sa.Text(), nullable=True), + sa.Column('database_config_id', sa.Integer(), nullable=True), + sa.Column('columns_info', sa.JSON(), nullable=False), + sa.Column('primary_keys', sa.JSON(), nullable=True), + sa.Column('foreign_keys', sa.JSON(), nullable=True), + sa.Column('indexes', sa.JSON(), nullable=True), + sa.Column('sample_data', sa.JSON(), nullable=True), + sa.Column('row_count', sa.Integer(), nullable=False), + sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False), + sa.Column('qa_description', sa.Text(), nullable=True), + sa.Column('business_context', sa.Text(), nullable=True), + sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata')) + ) + op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False) + op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], 
unique=False) + op.create_table('users', + sa.Column('username', sa.String(length=50), nullable=False), + sa.Column('email', sa.String(length=100), nullable=False), + sa.Column('hashed_password', sa.String(length=255), nullable=False), + sa.Column('full_name', sa.String(length=100), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('avatar_url', sa.String(length=255), nullable=True), + sa.Column('bio', sa.Text(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_users')) + ) + op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True) + op.create_table('user_roles', + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('role_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')), + sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles')) + ) + op.create_table('workflows', + sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'), + sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'), + sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'), + sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'), + sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'), + sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'), + sa.Column('owner_id', sa.Integer(), 
nullable=False, comment='所有者ID'), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows')) + ) + op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False) + op.create_table('workflow_executions', + sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'), + sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'), + sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'), + sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'), + sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'), + sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'), + sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')), + sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions')) + ) + op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False) + op.create_table('node_executions', + 
sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'), + sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'), + sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'), + sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'), + sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'), + sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'), + sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'), + sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'), + sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'), + sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions')) + ) + op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f('ix_node_executions_id'), table_name='node_executions') + op.drop_table('node_executions') + op.drop_index(op.f('ix_workflow_executions_id'), table_name='workflow_executions') + op.drop_table('workflow_executions') + op.drop_index(op.f('ix_workflows_id'), table_name='workflows') + op.drop_table('workflows') + op.drop_table('user_roles') + op.drop_index(op.f('ix_users_username'), table_name='users') + op.drop_index(op.f('ix_users_id'), table_name='users') + op.drop_index(op.f('ix_users_email'), table_name='users') + op.drop_table('users') + op.drop_index(op.f('ix_table_metadata_table_name'), table_name='table_metadata') + op.drop_index(op.f('ix_table_metadata_id'), table_name='table_metadata') + op.drop_table('table_metadata') + op.drop_index(op.f('ix_roles_name'), table_name='roles') + op.drop_index(op.f('ix_roles_id'), table_name='roles') + op.drop_index(op.f('ix_roles_code'), table_name='roles') + op.drop_table('roles') + op.drop_index(op.f('ix_messages_id'), table_name='messages') + op.drop_table('messages') + op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs') + op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs') + op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs') + op.drop_table('llm_configs') + op.drop_index(op.f('ix_knowledge_bases_name'), table_name='knowledge_bases') + op.drop_index(op.f('ix_knowledge_bases_id'), table_name='knowledge_bases') + op.drop_table('knowledge_bases') + op.drop_index(op.f('ix_excel_files_id'), table_name='excel_files') + op.drop_table('excel_files') + op.drop_index(op.f('ix_documents_id'), table_name='documents') + op.drop_table('documents') + op.drop_index(op.f('ix_database_configs_id'), table_name='database_configs') + op.drop_table('database_configs') + op.drop_index(op.f('ix_conversations_id'), table_name='conversations') + op.drop_table('conversations') + op.drop_index(op.f('ix_agent_configs_name'), table_name='agent_configs') + 
op.drop_index(op.f('ix_agent_configs_id'), table_name='agent_configs') + op.drop_table('agent_configs') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/8da391c6e2b7_add_message_count_and_last_message_at_.py b/backend/alembic/versions/8da391c6e2b7_add_message_count_and_last_message_at_.py new file mode 100644 index 0000000..35748fc --- /dev/null +++ b/backend/alembic/versions/8da391c6e2b7_add_message_count_and_last_message_at_.py @@ -0,0 +1,34 @@ +"""Add message_count and last_message_at to conversations + +Revision ID: 8da391c6e2b7 +Revises: 424646027786 +Create Date: 2025-12-19 16:16:29.943314 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '8da391c6e2b7' +down_revision: Union[str, Sequence[str], None] = '424646027786' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('conversations', sa.Column('message_count', sa.Integer(), nullable=False)) + op.add_column('conversations', sa.Column('last_message_at', sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('conversations', 'last_message_at') + op.drop_column('conversations', 'message_count') + # ### end Alembic commands ### diff --git a/backend/check_db_constraint.py b/backend/check_db_constraint.py new file mode 100644 index 0000000..cee354c --- /dev/null +++ b/backend/check_db_constraint.py @@ -0,0 +1,42 @@ +from sqlalchemy import create_engine, inspect +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +import asyncio +import os +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv() + +async def check_table_constraints(): + try: + # 获取数据库连接字符串 + DATABASE_URL = os.getenv("DATABASE_URL", "mysql+asyncmy://root:123456@localhost:3306/th_agenter") + + # 创建异步引擎 + engine = create_async_engine(DATABASE_URL, echo=True) + + # 创建会话 + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session() as session: + # 获取数据库连接 + async with session.begin(): + # 使用inspect查看表结构 + inspector = inspect(engine) + + # 获取messages表的所有约束 + constraints = await engine.run_sync(inspector.get_table_constraints, 'messages') + print("Messages表的所有约束:") + for constraint in constraints: + print(f" 约束名称: {constraint['name']}, 类型: {constraint['type']}") + if constraint['type'] == 'PRIMARY KEY': + print(f" 主键约束列: {constraint['constrained_columns']}") + + except Exception as e: + print(f"检查约束时出错: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + asyncio.run(check_table_constraints()) \ No newline at end of file diff --git a/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/data_level0.bin b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/data_level0.bin new file mode 100644 index 0000000..3746ace Binary files /dev/null and b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/data_level0.bin differ diff --git a/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/header.bin 
b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/header.bin differ diff --git a/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/length.bin b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/length.bin new file mode 100644 index 0000000..69654bf Binary files /dev/null and b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/length.bin differ diff --git a/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/link_lists.bin b/backend/data/chroma/kb_13/1e2b2695-d8f9-48e5-8619-0c2980084fb9/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/backend/data/chroma/kb_13/chroma.sqlite3 b/backend/data/chroma/kb_13/chroma.sqlite3 new file mode 100644 index 0000000..a4a8e46 Binary files /dev/null and b/backend/data/chroma/kb_13/chroma.sqlite3 differ diff --git a/backend/data/chroma/kb_14/chroma.sqlite3 b/backend/data/chroma/kb_14/chroma.sqlite3 new file mode 100644 index 0000000..32aa6ca Binary files /dev/null and b/backend/data/chroma/kb_14/chroma.sqlite3 differ diff --git a/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin new file mode 100644 index 0000000..d51ad17 Binary files /dev/null and b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin differ diff --git a/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin differ diff --git a/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin 
b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin new file mode 100644 index 0000000..3b62fac Binary files /dev/null and b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin differ diff --git a/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/link_lists.bin b/backend/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin new file mode 100644 index 0000000..078d9b6 Binary files /dev/null and b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin differ diff --git a/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin differ diff --git a/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin new file mode 100644 index 0000000..b466fb9 Binary files /dev/null and b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin differ diff --git a/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/link_lists.bin b/backend/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/backend/data/chroma/kb_15/chroma.sqlite3 b/backend/data/chroma/kb_15/chroma.sqlite3 new file mode 100644 index 0000000..cf591c9 Binary files /dev/null and b/backend/data/chroma/kb_15/chroma.sqlite3 differ diff --git a/backend/data/chroma/kb_16/chroma.sqlite3 b/backend/data/chroma/kb_16/chroma.sqlite3 new file mode 100644 index 0000000..0b74e6d Binary files 
/dev/null and b/backend/data/chroma/kb_16/chroma.sqlite3 differ diff --git a/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin new file mode 100644 index 0000000..057101d Binary files /dev/null and b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin differ diff --git a/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin differ diff --git a/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin new file mode 100644 index 0000000..c068776 Binary files /dev/null and b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin differ diff --git a/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/link_lists.bin b/backend/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/backend/data/chroma/kb_18/chroma.sqlite3 b/backend/data/chroma/kb_18/chroma.sqlite3 new file mode 100644 index 0000000..6ea43ca Binary files /dev/null and b/backend/data/chroma/kb_18/chroma.sqlite3 differ diff --git a/backend/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf b/backend/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf new file mode 100644 index 0000000..4fa282e Binary files /dev/null and b/backend/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf differ diff --git a/backend/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf 
b/backend/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf new file mode 100644 index 0000000..4fa282e Binary files /dev/null and b/backend/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf differ diff --git a/backend/main.py b/backend/main.py index ac1a2eb..b9d19fb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -38,7 +38,7 @@ def setup_exception_handlers(app: FastAPI) -> None: async def http_exception_handler(request, exc): from utils.util_exceptions import HxfErrorResponse logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}") - return HxfErrorResponse(exc.status_code, exc.detail) + return HxfErrorResponse(exc) def make_json_serializable(obj): """递归地将对象转换为JSON可序列化的格式""" @@ -127,31 +127,5 @@ def add_router(app: FastAPI) -> None: # Include routers app.include_router(router, prefix="/api") - - - # app.include_router(table_metadata.router) - # # 在现有导入中添加 - # from ..api.endpoints import database_config - - # # 在路由注册部分添加 - # app.include_router(database_config.router) - # # Health check endpoint - # @app.get("/health") - # async def health_check(): - # return {"status": "healthy", "version": settings.app_version} - - # # Root endpoint - # @app.get("/") - # async def root(): - # return {"message": "Chat Agent API is running"} - - # # Test endpoint - # @app.get("/test") - # async def test_endpoint(): - # return {"message": "API is working"} - - app = create_app() - -# from utils.util_test import test_db -# test_db() \ No newline at end of file +# from test.example import internet_search_tool diff --git a/backend/test/__init__.py b/backend/test/__init__.py new file mode 100644 index 0000000..5616d01 --- /dev/null +++ b/backend/test/__init__.py @@ -0,0 +1 @@ +# Test package for PostgreSQL agent functionality \ No newline at end of file diff --git a/backend/test/example.py b/backend/test/example.py new file mode 100644 index 0000000..cc32553 --- /dev/null +++ 
b/backend/test/example.py @@ -0,0 +1,154 @@ +import os +import asyncio +from datetime import datetime +from deepagents import create_deep_agent +from openai import OpenAI +from langchain.chat_models import init_chat_model +from langchain.agents import create_agent +from langgraph.checkpoint.memory import InMemorySaver, MemorySaver # 导入检查点工具 +from deepagents.backends import StoreBackend +from loguru import logger +def internet_search_tool(query: str): + """Run a web search""" + logger.info(f"Running internet search for query: {query}") + client = OpenAI( + api_key=os.getenv('DASHSCOPE_API_KEY'), + base_url=os.getenv('DASHSCOPE_BASE_URL'), + ) + logger.info(f"create OpenAI") + completion = client.chat.completions.create( + model="qwen-plus", + messages=[ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': query} + ], + extra_body={ + "enable_search": True + } + ) + logger.info(f"create completions") + logger.info(f"OpenAI response: {completion.choices[0].message.content}") + return completion.choices[0].message.content + + + +# System prompt to steer the agent to be an expert researcher +today = datetime.now().strftime("%Y年%m月%d日") +research_instructions = f"""你是一个智能助手。你的任务是帮助用户完成各种任务。 + +你可以使用互联网搜索工具来获取信息。 +## `internet_search` +使用此工具对给定查询进行互联网搜索。你可以指定返回结果的最大数量、主题以及是否包含原始内容。 + +今天的日期是:{today} +""" + +# Create the deep agent with memory +model = init_chat_model( + model="gpt-4.1-mini", + model_provider='openai', + api_key=os.getenv('OPENAI_API_KEY'), + base_url=os.getenv('OPENAI_BASE_URL'), +) +checkpointer = InMemorySaver() # 创建内存检查点,自动保存历史 + +agent = create_deep_agent( # state:thread会话级的状态 + tools=[internet_search_tool], + system_prompt=research_instructions, + model=model, + checkpointer=checkpointer, # 添加检查点,启用自动记忆 + interrupt_on={'internet_search_tool':True} +) + +# 多轮对话循环(使用 Checkpointer 自动记忆) +printed_msg_ids = set() # 跟踪已打印的消息ID +thread_id = "user_session_001" # 会话 ID,区分不同用户/会话 +config = {"configurable": 
{"thread_id": thread_id}, "metastore": {'assistant_id': 'owenliang'}} # 配置会话 + +print("开始对话(输入 'exit' 退出):") +while True: + user_input = input("\nHUMAN: ").strip() + if user_input.lower() == 'exit': + break + + # 使用 values 模式多次返回完整状态,这里按 message.id 去重,并按类型分类打印 + pending_resume = None + while True: + if pending_resume is None: + request = {"messages": [{"role": "user", "content": user_input}]} + else: + from langgraph.types import Command as _Command + + request = _Command(resume=pending_resume) + pending_resume = None + + for item in agent.stream( + request, + config=config, + stream_mode="values", + ): + state = item[0] if isinstance(item, tuple) and len(item) == 2 else item + + # 先检查是否触发了 Human-In-The-Loop 中断 + if isinstance(state, dict) and "__interrupt__" in state: + interrupts = state["__interrupt__"] or [] + if interrupts: + hitl_payload = interrupts[0].value + action_requests = hitl_payload.get("action_requests", []) + + print("\n=== 需要人工审批的工具调用 ===") + decisions: list[dict[str, str]] = [] + for idx, ar in enumerate(action_requests): + name = ar.get("name") + args = ar.get("args") + print(f"[{idx}] 工具 {name} 参数: {args}") + while True: + choice = input(" 决策 (a=approve, r=reject): ").strip().lower() + if choice in ("a", "r"): + break + decisions.append({"type": "approve" if choice == "a" else "reject"}) + + # 下一轮调用改为 resume,同一轮用户回合继续往下跑 + pending_resume = {"decisions": decisions} + break + + # 兼容 dict state 和 AgentState dataclass + messages = state.get("messages", []) if isinstance(state, dict) else getattr(state, "messages", []) + for msg in messages: + msg_id = getattr(msg, "id", None) + if msg_id is not None and msg_id in printed_msg_ids: + continue + if msg_id is not None: + printed_msg_ids.add(msg_id) + + msg_type = getattr(msg, "type", None) + + if msg_type == "human": + # 用户输入已经在命令行里,不再重复打印 + continue + + if msg_type == "ai": + tool_calls = getattr(msg, "tool_calls", None) or [] + if tool_calls: + # 这是发起工具调用的 AI 消息(TOOL CALL) + for tc in tool_calls: + 
tool_name = tc.get("name") + args = tc.get("args") + print(f"TOOL CALL [{tool_name}]: {args}") + # 如果 AI 同时带有自然语言内容,也一起打印 + if getattr(msg, "content", None): + print(f"AI: {msg.content}") + continue + + if msg_type == "tool": + # 工具执行结果(TOOL RESPONSE) + tool_name = getattr(msg, "name", None) or "tool" + print(f"TOOL RESPONSE [{tool_name}]: {msg.content}") + continue + + # 兜底:其它类型直接打印出来便于调试 + print(f"[{msg_type}]: {getattr(msg, 'content', None)}") + + # 如果没有新的中断需要 resume,则整轮结束,等待下一轮用户输入 + if pending_resume is None: + break diff --git a/backend/th_agenter/api/endpoints/auth.py b/backend/th_agenter/api/endpoints/auth.py index 098c26b..9086657 100644 --- a/backend/th_agenter/api/endpoints/auth.py +++ b/backend/th_agenter/api/endpoints/auth.py @@ -95,11 +95,13 @@ async def login_oauth( ) session.desc = f"用户 {user.username} OAuth2 登录成功" - return { - "access_token": access_token, - "token_type": "bearer", - "expires_in": settings.security.access_token_expire_minutes * 60 - } + return HxfResponse( + { + "access_token": access_token, + "token_type": "bearer", + "expires_in": settings.security.access_token_expire_minutes * 60 + } + ) @router.post("/refresh", response_model=Token, summary="刷新访问token") async def refresh_token( @@ -113,15 +115,17 @@ async def refresh_token( session, data={"sub": current_user.username}, expires_delta=access_token_expires ) - return Token( + response = Token( access_token=access_token, token_type="bearer", expires_in=settings.security.access_token_expire_minutes * 60 ) + return HxfResponse(response) @router.get("/me", response_model=UserResponse, summary="获取当前用户信息") async def get_current_user_info( current_user = Depends(AuthService.get_current_user) ): """获取当前用户信息""" - return UserResponse.model_validate(current_user, from_attributes=True) \ No newline at end of file + response = UserResponse.model_validate(current_user, from_attributes=True) + return HxfResponse(response) \ No newline at end of file diff --git 
a/backend/th_agenter/api/endpoints/chat.py b/backend/th_agenter/api/endpoints/chat.py index 7f9ed28..e1f1e11 100644 --- a/backend/th_agenter/api/endpoints/chat.py +++ b/backend/th_agenter/api/endpoints/chat.py @@ -1,5 +1,6 @@ """Chat endpoints for TH Agenter.""" +import json from typing import List from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse @@ -11,6 +12,8 @@ from ...models.user import User from ...services.auth import AuthService from ...services.chat import ChatService from ...services.conversation import ConversationService +from utils.util_exceptions import HxfResponse + from utils.util_schemas import ( ConversationCreate, ConversationResponse, @@ -23,83 +26,6 @@ from utils.util_schemas import ( router = APIRouter() -# Conversation management -@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话") -async def create_conversation( - conversation_data: ConversationCreate, - current_user: User = Depends(AuthService.get_current_user), - session: Session = Depends(get_session) -): - """创建新对话""" - session.desc = "START: 创建新对话" - conversation_service = ConversationService(session) - conversation = await conversation_service.create_conversation( - user_id=current_user.id, - conversation_data=conversation_data - ) - session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}" - return ConversationResponse.model_validate(conversation) - -@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表") -async def list_conversations( - skip: int = 0, - limit: int = 50, - search: str = None, - include_archived: bool = False, - order_by: str = "updated_at", - order_desc: bool = True, - session: Session = Depends(get_session) -): - """获取用户对话列表""" - session.desc = "START: 获取用户对话列表" - conversation_service = ConversationService(session) - conversations = await conversation_service.get_user_conversations( - skip=skip, - 
limit=limit, - search_query=search, - include_archived=include_archived, - order_by=order_by, - order_desc=order_desc - ) - session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话" - return [ConversationResponse.model_validate(conv) for conv in conversations] - -@router.get("/conversations/count", summary="获取用户对话总数") -async def get_conversations_count( - search: str = None, - include_archived: bool = False, - session: Session = Depends(get_session) -): - """获取用户对话总数""" - session.desc = "START: 获取用户对话总数" - conversation_service = ConversationService(session) - count = await conversation_service.get_user_conversations_count( - search_query=search, - include_archived=include_archived - ) - session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话" - return {"count": count} - -@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话") -async def get_conversation( - conversation_id: int, - session: Session = Depends(get_session) -): - """获取指定对话""" - session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}" - conversation_service = ConversationService(session) - conversation = await conversation_service.get_conversation( - conversation_id=conversation_id - ) - if not conversation: - session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Conversation not found" - ) - session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}" - return ConversationResponse.model_validate(conversation) - @router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话") async def update_conversation( conversation_id: int, @@ -113,7 +39,8 @@ async def update_conversation( conversation_id, conversation_update ) session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}" - return ConversationResponse.model_validate(updated_conversation) + response = 
ConversationResponse.model_validate(updated_conversation) + return HxfResponse(response) @router.delete("/conversations/{conversation_id}", summary="删除指定对话") @@ -126,7 +53,8 @@ async def delete_conversation( conversation_service = ConversationService(session) await conversation_service.delete_conversation(conversation_id) session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}" - return {"message": "Conversation deleted successfully"} + response = {"message": "Conversation deleted successfully"} + return HxfResponse(response) @router.put("/conversations/{conversation_id}/archive", summary="归档指定对话") @@ -145,7 +73,8 @@ async def archive_conversation( ) session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}" - return {"message": "Conversation archived successfully"} + response = {"message": "Conversation archived successfully"} + return HxfResponse(response) @router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话") @@ -165,7 +94,8 @@ async def unarchive_conversation( ) session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}" - return {"message": "Conversation unarchived successfully"} + response = {"message": "Conversation unarchived successfully"} + return HxfResponse(response) # Message management @@ -183,7 +113,8 @@ async def get_conversation_messages( conversation_id, skip=skip, limit=limit ) session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" - return [MessageResponse.model_validate(msg) for msg in messages] + response = [MessageResponse.model_validate(msg) for msg in messages] + return HxfResponse(response) # Chat functionality @router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应") @@ -195,48 +126,158 @@ async def chat( """发送消息并获取AI响应""" session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}" chat_service = ChatService(session) - response = await 
chat_service.chat( - conversation_id=conversation_id, - message=chat_request.message, - stream=False, - temperature=chat_request.temperature, - max_tokens=chat_request.max_tokens, - use_agent=chat_request.use_agent, - use_langgraph=chat_request.use_langgraph, - use_knowledge_base=chat_request.use_knowledge_base, - knowledge_base_id=chat_request.knowledge_base_id - ) + await chat_service.initialize(conversation_id) + + # response = await chat_service.chat( + # conversation_id=conversation_id, + # message=chat_request.message, + # stream=False, + # temperature=chat_request.temperature, + # max_tokens=chat_request.max_tokens, + # use_agent=chat_request.use_agent, # 可以简化掉 + # use_langgraph=chat_request.use_langgraph, # 可以简化掉 + # use_knowledge_base=chat_request.use_knowledge_base, # 可以简化掉 + # knowledge_base_id=chat_request.knowledge_base_id # 可以简化掉 + # ) + response = "oooooooooooooooooooK" session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}" - return response - + return HxfResponse(response) +# ------------------------------------------------------------------------ @router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应") async def chat_stream( conversation_id: int, chat_request: ChatRequest, + current_user = Depends(AuthService.get_current_user), session: Session = Depends(get_session) ): """发送消息并获取流式AI响应.""" + session.title = f"对话{conversation_id} 发送消息并获取流式AI响应" + session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> " chat_service = ChatService(session) + await chat_service.initialize(conversation_id, streaming=True) - async def generate_response(): - async for chunk in chat_service.chat_stream( - conversation_id=conversation_id, - message=chat_request.message, - temperature=chat_request.temperature, - max_tokens=chat_request.max_tokens, - use_agent=chat_request.use_agent, - use_langgraph=chat_request.use_langgraph, - use_knowledge_base=chat_request.use_knowledge_base, - 
knowledge_base_id=chat_request.knowledge_base_id - ): - yield f"data: {chunk}\n\n" + async def generate_response(chat_service): + try: + async for chunk in chat_service.chat_stream( + message=chat_request.message + ): + yield chunk + "\n" + except Exception as e: + logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}") + yield {'success': False, 'data': f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}"} - return StreamingResponse( - generate_response(), - media_type="text/plain", + response = StreamingResponse( + generate_response(chat_service), + media_type="text/stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", } ) + + return response + +# Conversation management +@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话") +async def create_conversation( + conversation_data: ConversationCreate, + current_user: User = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """创建新对话""" + id = current_user.id + session.title = f"用户{current_user.username} - 创建新对话" + session.desc = "START: 创建新对话" + conversation_service = ConversationService(session) + conversation = await conversation_service.create_conversation( + user_id=id, + conversation_data=conversation_data + ) + session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {id}, conversation_id: {conversation.id}" + response = ConversationResponse.model_validate(conversation) + return HxfResponse(response) +@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表") +async def list_conversations( + skip: int = 0, + limit: int = 50, + search: str = None, + include_archived: bool = False, + order_by: str = "updated_at", + order_desc: bool = True, + session: Session = Depends(get_session) +): + """获取用户对话列表""" + session.title = "获取用户对话列表" + session.desc = "START: 获取用户对话列表" + conversation_service = ConversationService(session) + conversations = await 
conversation_service.get_user_conversations( + skip=skip, + limit=limit, + search_query=search, + include_archived=include_archived, + order_by=order_by, + order_desc=order_desc + ) + session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..." + response = [ConversationResponse.model_validate(conv) for conv in conversations] + return HxfResponse(response) + +@router.get("/conversations/count", summary="获取用户对话总数") +async def get_conversations_count( + search: str = None, + include_archived: bool = False, + session: Session = Depends(get_session) +): + """获取用户对话总数""" + from th_agenter.core.context import UserContext + user_id = UserContext.get_current_user_id() + session.title = f"获取用户对话总数[用户id = {user_id}]" + session.desc = "START: 获取用户对话总数" + conversation_service = ConversationService(session) + count = await conversation_service.get_user_conversations_count( + search_query=search, + include_archived=include_archived + ) + session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话" + response = {"count": count} + return HxfResponse(response) + +@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话") +async def get_conversation( + conversation_id: int, + session: Session = Depends(get_session) +): + """获取指定对话""" + session.title = f"获取指定对话[对话id = {conversation_id}]" + session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}" + + conversation_service = ConversationService(session) + conversation = await conversation_service.get_conversation( + conversation_id=conversation_id + ) + if not conversation: + session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Conversation not found" + ) + session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {conversation}" + + response = ConversationResponse.model_validate(conversation) + + + # chat_service = ChatService(session) + # await 
chat_service.initialize(conversation_id, streaming=False) + # messages = await chat_service.get_conversation_history_messages( + # conversation_id + # ) + # response.messages = messages + + messages = await conversation_service.get_conversation_messages( + conversation_id, skip=0, limit=100 + ) + response.messages = [MessageResponse.model_validate(msg) for msg in messages] + + response.message_count = len(response.messages) + return HxfResponse(response) diff --git a/backend/th_agenter/api/endpoints/database_config.py b/backend/th_agenter/api/endpoints/database_config.py index 3bd1149..f0537c6 100644 --- a/backend/th_agenter/api/endpoints/database_config.py +++ b/backend/th_agenter/api/endpoints/database_config.py @@ -9,7 +9,7 @@ from th_agenter.db.database import get_session from th_agenter.services.database_config_service import DatabaseConfigService from th_agenter.services.auth import AuthService from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse - +from utils.util_exceptions import HxfResponse # 在文件顶部添加 from functools import lru_cache @@ -68,11 +68,12 @@ async def create_database_config( ): """创建或更新数据库配置""" config = await service.create_or_update_config(current_user.id, config_data.model_dump()) - return NormalResponse( + response = NormalResponse( success=True, message="保存数据库配置成功", data=config ) + return HxfResponse(response) @router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表") async def get_database_configs( @@ -83,7 +84,7 @@ async def get_database_configs( configs = service.get_user_configs(current_user.id) config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs] - return config_list + return HxfResponse(config_list) @router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接") async def test_database_connection( @@ -93,7 +94,7 @@ async def test_database_connection( ): """测试数据库连接""" result = await 
service.test_connection(config_id, current_user.id) - return result + return HxfResponse(result) @router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表") async def connect_database( @@ -103,7 +104,7 @@ async def connect_database( ): """连接数据库并获取表列表""" result = await service.connect_and_get_tables(config_id, current_user.id) - return result + return HxfResponse(result) @router.get("/tables/{table_name}/data", summary="获取表数据预览") @@ -117,7 +118,7 @@ async def get_table_data( """获取表数据预览""" try: result = await service.get_table_data(table_name, current_user.id, db_type, limit) - return result + return HxfResponse(result) except Exception as e: logger.error(f"获取表数据失败: {str(e)}") raise HTTPException( @@ -133,7 +134,7 @@ async def get_table_schema( ): """获取表结构信息""" result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的? - return result + return HxfResponse(result) @router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置") async def get_config_by_type( @@ -149,4 +150,4 @@ async def get_config_by_type( detail=f"未找到类型为 {db_type} 的配置" ) # 返回包含解密密码的配置 - return config.to_dict(include_password=True, decrypt_service=service) \ No newline at end of file + return HxfResponse(config.to_dict(include_password=True, decrypt_service=service)) \ No newline at end of file diff --git a/backend/th_agenter/api/endpoints/knowledge_base.py b/backend/th_agenter/api/endpoints/knowledge_base.py index b3fb6fa..ec211e4 100644 --- a/backend/th_agenter/api/endpoints/knowledge_base.py +++ b/backend/th_agenter/api/endpoints/knowledge_base.py @@ -1,5 +1,6 @@ """Knowledge base API endpoints.""" +from utils.util_exceptions import HxfResponse from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status from fastapi.responses import JSONResponse @@ -35,10 +36,10 @@ async def create_knowledge_base( ): """创建新的知识库""" # Check if knowledge base with same name 
already exists for this user - session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data.name}" - service = KnowledgeBaseService(session) + session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}" + kb_service = KnowledgeBaseService(session) session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}" - existing_kb = service.get_knowledge_base_by_name(kb_data.name) + existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name) if existing_kb: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -47,10 +48,10 @@ async def create_knowledge_base( # Create knowledge base session.desc = f"知识库 {kb_data.name}不存在,创建之" - kb = service.create_knowledge_base(kb_data) - + kb = await kb_service.create_knowledge_base(kb_data) + session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功" - return KnowledgeBaseResponse( + response = KnowledgeBaseResponse( id=kb.id, created_at=kb.created_at, updated_at=kb.updated_at, @@ -65,7 +66,7 @@ async def create_knowledge_base( document_count=0, active_document_count=0 ) - + return HxfResponse(response) @router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库") async def list_knowledge_bases( @@ -76,18 +77,17 @@ async def list_knowledge_bases( current_user: User = Depends(AuthService.get_current_user) ): """获取当前用户的所有知识库""" - session.desc = f"START: 获取用户 {current_user.username} 的所有知识库" - service = KnowledgeBaseService(session) - session.desc = f"获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})" - knowledge_bases = await service.get_knowledge_bases(skip=skip, limit=limit) + session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})" + kb_service = KnowledgeBaseService(session) + knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit) result = [] for kb in knowledge_bases: - # Count documents + # 本知识库的文档数量 total_docs = await session.scalar( 
select(func.count()).where(Document.knowledge_base_id == kb.id) ) - + # 本知识库的已处理文档数量 active_docs = await session.scalar( select(func.count()).where( Document.knowledge_base_id == kb.id, @@ -112,7 +112,7 @@ async def list_knowledge_bases( )) session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库" - return result + return HxfResponse(result) @router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情") async def get_knowledge_base( @@ -124,7 +124,7 @@ async def get_knowledge_base( session.desc = f"START: 获取知识库 {kb_id} 的详情" service = KnowledgeBaseService(session) session.desc = f"检查知识库 {kb_id} 是否存在" - kb = service.get_knowledge_base(kb_id) + kb = await service.get_knowledge_base(kb_id) if not kb: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( @@ -146,7 +146,7 @@ async def get_knowledge_base( ) session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理" - return KnowledgeBaseResponse( + response = KnowledgeBaseResponse( id=kb.id, created_at=kb.created_at, updated_at=kb.updated_at, @@ -161,6 +161,7 @@ async def get_knowledge_base( document_count=total_docs, active_document_count=active_docs ) + return HxfResponse(response) @router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库") async def update_knowledge_base( @@ -172,7 +173,7 @@ async def update_knowledge_base( """更新知识库""" session.desc = f"START: 更新知识库 {kb_id}" service = KnowledgeBaseService(session) - kb = service.update_knowledge_base(kb_id, kb_data) + kb = await service.update_knowledge_base(kb_id, kb_data) if not kb: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( @@ -192,8 +193,8 @@ async def update_knowledge_base( ) ) - session.desc = f"SUCCESS: 更新知识库 {kb_id},共 {total_docs} 个文档,其中 {active_docs} 个已处理" - return KnowledgeBaseResponse( + session.desc = f"SUCCESS: 更新知识库 {kb_id},结果 - 共 {total_docs} 个文档,其中 {active_docs} 个已处理" + response = KnowledgeBaseResponse( id=kb.id, created_at=kb.created_at, 
updated_at=kb.updated_at, @@ -208,6 +209,7 @@ async def update_knowledge_base( document_count=total_docs, active_document_count=active_docs ) + return HxfResponse(response) @router.delete("/{kb_id}", summary="删除知识库") async def delete_knowledge_base( @@ -218,7 +220,7 @@ async def delete_knowledge_base( """删除知识库""" session.desc = f"START: 删除知识库 {kb_id}" service = KnowledgeBaseService(session) - success = service.delete_knowledge_base(kb_id) + success = await service.delete_knowledge_base(kb_id) if not success: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( @@ -227,7 +229,7 @@ async def delete_knowledge_base( ) session.desc = f"SUCCESS: 删除知识库 {kb_id}" - return {"message": "Knowledge base deleted successfully"} + return HxfResponse({"message": "Knowledge base deleted successfully"}) # Document management endpoints @router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库") @@ -239,18 +241,18 @@ async def upload_document( current_user: User = Depends(AuthService.get_current_user) ): """上传文档到知识库""" - session.desc = f"START: 上传文档到知识库 {kb_id}" + session.desc = f"START: 上传文档 {file.filename} ({FileUtils.format_file_size(file.size)}) 到知识库 (ID={kb_id})" # Verify knowledge base exists and user has access kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) + kb = await kb_service.get_knowledge_base(kb_id) if not kb: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found" ) - + session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}" # Validate file if not FileUtils.validate_file_extension(file.filename): session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}" @@ -258,7 +260,6 @@ async def upload_document( status_code=status.HTTP_400_BAD_REQUEST, detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}" 
) - # Check file size (50MB limit) max_size = 50 * 1024 * 1024 # 50MB if file.size and file.size > max_size: @@ -268,6 +269,7 @@ async def upload_document( detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制" ) + session.desc = f"文件为期望类型,处理文件 {file.filename} - " # Upload document doc_service = DocumentService(session) document = await doc_service.upload_document( @@ -284,7 +286,7 @@ async def upload_document( session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}" session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}" - return DocumentResponse( + response = DocumentResponse( id=document.id, created_at=document.created_at, updated_at=document.updated_at, @@ -301,6 +303,7 @@ async def upload_document( embedding_model=document.embedding_model, file_size_mb=round(document.file_size / (1024 * 1024), 2) ) + return HxfResponse(response) @router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表") async def list_documents( @@ -314,7 +317,8 @@ async def list_documents( session.desc = f"START: 获取知识库 {kb_id} 中的文档列表" # Verify knowledge base exists and user has access kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) + + kb = await kb_service.get_knowledge_base(kb_id) if not kb: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( @@ -323,7 +327,7 @@ async def list_documents( ) doc_service = DocumentService(session) - documents, total = doc_service.list_documents(kb_id, skip, limit) + documents, total = await doc_service.list_documents(kb_id, skip, limit) doc_responses = [] for doc in documents: @@ -346,208 +350,13 @@ async def list_documents( )) session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total} 条" - return DocumentListResponse( + response = DocumentListResponse( documents=doc_responses, total=total, page=skip // limit + 1, page_size=limit ) - -@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情") -async def 
get_document( - kb_id: int, - doc_id: int, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """获取知识库中的文档详情。""" - session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情" - # Verify knowledge base exists and user has access - kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) - if not kb: - session.desc = f"ERROR: 知识库 {kb_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Knowledge base not found" - ) - - doc_service = DocumentService(session) - document = doc_service.get_document(doc_id, kb_id) - if not document: - session.desc = f"ERROR: 文档 {doc_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" - ) - - session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情" - return DocumentResponse( - id=document.id, - created_at=document.created_at, - updated_at=document.updated_at, - knowledge_base_id=document.knowledge_base_id, - filename=document.filename, - original_filename=document.original_filename, - file_path=document.file_path, - file_type=document.file_type, - file_size=document.file_size, - mime_type=document.mime_type, - is_processed=document.is_processed, - processing_error=document.processing_error, - chunk_count=document.chunk_count or 0, - embedding_model=document.embedding_model, - file_size_mb=round(document.file_size / (1024 * 1024), 2) - ) - -@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档") -async def delete_document( - kb_id: int, - doc_id: int, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """删除知识库中的文档。""" - session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}" - kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) - if not kb: - session.desc = f"ERROR: 知识库 {kb_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Knowledge base not 
found" - ) - - doc_service = DocumentService(session) - success = doc_service.delete_document(doc_id, kb_id) - if not success: - session.desc = f"ERROR: 文档 {doc_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" - ) - - session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}" - return {"message": "Document deleted successfully"} - -@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档") -async def process_document( - kb_id: int, - doc_id: int, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """处理知识库中的文档,用于向量搜索。""" - session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}" - kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) - if not kb: - session.desc = f"ERROR: 知识库 {kb_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Knowledge base not found" - ) - - # Check if document exists - doc_service = DocumentService(session) - document = doc_service.get_document(doc_id, kb_id) - if not document: - session.desc = f"ERROR: 文档 {doc_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" - ) - - # Process the document - result = await doc_service.process_document(doc_id, kb_id) - session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}" - return DocumentProcessingStatus( - document_id=doc_id, - status=result["status"], - progress=result.get("progress", 0.0), - error_message=result.get("error_message"), - chunks_created=result.get("chunks_created", 0) - ) - -@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态") -async def get_document_processing_status( - kb_id: int, - doc_id: int, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """获取知识库中的文档处理状态。""" - # Verify knowledge base 
exists and user has access - session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态" - kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) - if not kb: - session.desc = f"ERROR: 知识库 {kb_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Knowledge base not found" - ) - - doc_service = DocumentService(session) - document = doc_service.get_document(doc_id, kb_id) - if not document: - session.desc = f"ERROR: 文档 {doc_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" - ) - - # Determine status - if document.processing_error: - status_str = "failed" - progress = 0.0 - session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}" - elif document.is_processed: - status_str = "completed" - progress = 100.0 - session.desc = f"SUCCESS: 文档 {doc_id} 处理完成" - else: - status_str = "pending" - progress = 0.0 - session.desc = f"文档 {doc_id} 处理pending中" - - return DocumentProcessingStatus( - document_id=document.id, - status=status_str, - progress=progress, - error_message=document.processing_error, - chunks_created=document.chunk_count or 0 - ) - -@router.get("/{kb_id}/search", summary="在知识库中搜索文档") -async def search_knowledge_base( - kb_id: int, - query: str, - limit: int = 5, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """在知识库中搜索文档。""" - session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}" - kb_service = KnowledgeBaseService(session) - kb = kb_service.get_knowledge_base(kb_id) - if not kb: - session.desc = f"ERROR: 知识库 {kb_id} 不存在" - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Knowledge base not found" - ) - - # Perform search - doc_service = DocumentService(session) - results = doc_service.search_documents(kb_id, query, limit) - session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果" - return { - "knowledge_base_id": kb_id, - 
"query": query, - "results": results, - "total_results": len(results) - } + return HxfResponse(response) @router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)") async def get_document_chunks( @@ -570,7 +379,8 @@ async def get_document_chunks( """ session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)" kb_service = KnowledgeBaseService(session) - knowledge_base = kb_service.get_knowledge_base(kb_id) + knowledge_base = await kb_service.get_knowledge_base(kb_id) + if not knowledge_base: session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise HTTPException( @@ -580,7 +390,9 @@ async def get_document_chunks( # Verify document exists in the knowledge base doc_service = DocumentService(session) - document = doc_service.get_document(doc_id, kb_id) + session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService" + document = await doc_service.get_document(doc_id, kb_id) + session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document" if not document: session.desc = f"ERROR: 文档 {doc_id} 不存在" raise HTTPException( @@ -589,11 +401,215 @@ async def get_document_chunks( ) # Get document chunks - chunks = doc_service.get_document_chunks(doc_id) + chunks = await doc_service.get_document_chunks(doc_id) + session.desc = f"SUCCESS: 获取文档 {doc_id} 共 {len(chunks)} 个文档块(片段)" - return DocumentChunksResponse( + response = DocumentChunksResponse( document_id=doc_id, document_name=document.filename, total_chunks=len(chunks), chunks=chunks - ) \ No newline at end of file + ) + return HxfResponse(response) + +@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情") +async def get_document( + kb_id: int, + doc_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """获取知识库中的文档详情。""" + session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情" + # Verify knowledge base exists and user has access + kb_service = 
KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + doc_service = DocumentService(session) + document = await doc_service.get_document(doc_id, kb_id) + if not document: + session.desc = f"ERROR: 文档 {doc_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Document not found" + ) + + session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情" + response = DocumentResponse( + id=document.id, + created_at=document.created_at, + updated_at=document.updated_at, + knowledge_base_id=document.knowledge_base_id, + filename=document.filename, + original_filename=document.original_filename, + file_path=document.file_path, + file_type=document.file_type, + file_size=document.file_size, + mime_type=document.mime_type, + is_processed=document.is_processed, + processing_error=document.processing_error, + chunk_count=document.chunk_count or 0, + embedding_model=document.embedding_model, + file_size_mb=round(document.file_size / (1024 * 1024), 2) + ) + return HxfResponse(response) + +@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档") +async def delete_document( + kb_id: int, + doc_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """删除知识库中的文档。""" + session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}" + kb_service = KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + doc_service = DocumentService(session) + success = await doc_service.delete_document(doc_id, kb_id) + if not success: + session.desc = f"ERROR: 删除文档 {doc_id} 失败 - 文档不存在" + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, + detail="Document not found" + ) + + session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}" + response = {"message": "Document deleted successfully"} + return HxfResponse(response) + +@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档") +async def process_document( + kb_id: int, + doc_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """处理知识库中的文档,用于向量搜索。""" + session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}" + kb_service = KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + # Check if document exists + doc_service = DocumentService(session) + document = await doc_service.get_document(doc_id, kb_id) + if not document: + session.desc = f"ERROR: 文档 {doc_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Document not found" + ) + + # Process the document + result = await doc_service.process_document(doc_id, kb_id) + await session.refresh(document) + session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}" + response = DocumentProcessingStatus( + document_id=doc_id, + status=result["status"], + progress=result.get("progress", 0.0), + error_message=result.get("error_message"), + chunks_created=result.get("chunks_created", 0) + ) + return HxfResponse(response) + +@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态") +async def get_document_processing_status( + kb_id: int, + doc_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """获取知识库中的文档处理状态。""" + # Verify knowledge base exists and user has access + session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 
处理状态" + kb_service = KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + doc_service = DocumentService(session) + document = await doc_service.get_document(doc_id, kb_id) + if not document: + session.desc = f"ERROR: 文档 {doc_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Document not found" + ) + + # Determine status + if document.processing_error: + status_str = "failed" + progress = 0.0 + session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}" + elif document.is_processed: + status_str = "completed" + progress = 100.0 + session.desc = f"SUCCESS: 文档 {doc_id} 处理完成" + else: + status_str = "pending" + progress = 0.0 + session.desc = f"文档 {doc_id} 处理pending中" + + response = DocumentProcessingStatus( + document_id=document.id, + status=status_str, + progress=progress, + error_message=document.processing_error, + chunks_created=document.chunk_count or 0 + ) + return HxfResponse(response) + +@router.get("/{kb_id}/search", summary="在知识库中搜索文档") +async def search_knowledge_base( + kb_id: int, + query: str, + limit: int = 5, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """在知识库中搜索文档。""" + session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}" + kb_service = KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + # Perform search + doc_service = DocumentService(session) + results = await doc_service.search_documents(kb_id, query, limit) + session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果" + response = { + "knowledge_base_id": kb_id, + "query": query, + 
"results": results, + "total_results": len(results) + } + return HxfResponse(response) diff --git a/backend/th_agenter/api/endpoints/llm_configs.py b/backend/th_agenter/api/endpoints/llm_configs.py index eff149d..14f4467 100644 --- a/backend/th_agenter/api/endpoints/llm_configs.py +++ b/backend/th_agenter/api/endpoints/llm_configs.py @@ -1,14 +1,21 @@ """LLM configuration management API endpoints.""" +from turtle import textinput from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status, Query +from langchain_openai import ChatOpenAI +from langchain_core.messages import AIMessage from sqlalchemy.orm import Session from sqlalchemy import or_, select, delete, update from loguru import logger + +from th_agenter.llm.embed.embed_llm import BGEEmbedLLM, EmbedLLM +from th_agenter.llm.online.online_llm import OnlineLLM from ...db.database import get_session from ...models.user import User from ...models.llm_config import LLMConfig +from th_agenter.llm.base_llm import LLMConfig_DataClass from ...core.simple_permissions import require_super_admin, require_authenticated_user from ...schemas.llm_config import ( LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse, @@ -31,6 +38,7 @@ async def get_llm_configs( current_user: User = Depends(require_authenticated_user) ): """获取大模型配置列表.""" + session.title = "获取大模型配置列表" session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}" stmt = select(LLMConfig) @@ -62,7 +70,7 @@ async def get_llm_configs( # 分页 stmt = stmt.offset(skip).limit(limit) configs = (await session.execute(stmt)).scalars().all() - session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置" + session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..." 
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs]) @@ -131,7 +139,7 @@ async def get_llm_config( ): """获取大模型配置详情.""" stmt = select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -149,17 +157,20 @@ async def create_llm_config( ): """创建大模型配置.""" # 检查配置名称是否已存在 + # 先保存当前用户名,避免在refresh后访问可能导致MissingGreenlet错误 + username = current_user.username session.desc = f"START: 创建大模型配置, name={config_data.name}" stmt = select(LLMConfig).where(LLMConfig.name == config_data.name) - existing_config = session.execute(stmt).scalar_one_or_none() + existing_config = (await session.execute(stmt)).scalar_one_or_none() if existing_config: + session.desc = f"ERROR: 配置名称已存在, name={config_data.name}" raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="配置名称已存在" ) - # 创建临时配置对象进行验证 - temp_config = LLMConfig( + # 创建配置对象 + config = LLMConfig_DataClass( name=config_data.name, provider=config_data.provider, model_name=config_data.model_name, @@ -178,7 +189,7 @@ async def create_llm_config( ) # 验证配置 - validation_result = temp_config.validate_config() + validation_result = config.validate_config() if not validation_result['valid']: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -190,10 +201,11 @@ async def create_llm_config( stmt = update(LLMConfig).where( LLMConfig.is_embedding == config_data.is_embedding ).values({"is_default": False}) - session.execute(stmt) + await session.execute(stmt) + session.desc = f"验证大模型配置, config_data" # 创建配置 - config = LLMConfig( + config = LLMConfig_DataClass( name=config_data.name, provider=config_data.provider, model_name=config_data.model_name, @@ -213,10 +225,9 @@ async def create_llm_config( # Audit fields are set automatically by SQLAlchemy event listener session.add(config) - session.commit() - 
session.refresh(config) - - session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}" + await session.commit() + await session.refresh(config) + session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}" return HxfResponse(config.to_dict()) @@ -228,9 +239,10 @@ async def update_llm_config( current_user: User = Depends(require_super_admin) ): """更新大模型配置.""" + username = current_user.username session.desc = f"START: 更新大模型配置, id={config_id}" stmt = select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -243,7 +255,7 @@ async def update_llm_config( LLMConfig.name == config_data.name, LLMConfig.id != config_id ) - existing_config = session.execute(stmt).scalar_one_or_none() + existing_config = (await session.execute(stmt)).scalar_one_or_none() if existing_config: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -258,17 +270,17 @@ async def update_llm_config( LLMConfig.is_embedding == is_embedding, LLMConfig.id != config_id ).values({"is_default": False}) - session.execute(stmt) + await session.execute(stmt) # 更新字段 update_data = config_data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(config, field, value) - session.commit() - session.refresh(config) + await session.commit() + await session.refresh(config) - session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}" + session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}" return HxfResponse(config.to_dict()) @router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置") @@ -278,22 +290,24 @@ async def delete_llm_config( current_user: User = Depends(require_super_admin) ): """删除大模型配置.""" + username = current_user.username session.desc = f"START: 删除大模型配置, id={config_id}" stmt = 
select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在" ) - + session.desc = f"待删除大模型记录 {config.to_dict()}" # TODO: 检查是否有对话或其他功能正在使用该配置 # 这里可以添加相关的检查逻辑 - session.delete(config) - session.commit() + # 删除配置 + await session.delete(config) + await session.commit() - session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}" + session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}" return HxfResponse({"message": "LLM config deleted successfully"}) @router.post("/{config_id}/test", summary="测试连接大模型配置") @@ -304,17 +318,21 @@ async def test_llm_config( current_user: User = Depends(require_super_admin) ): """测试连接大模型配置.""" - session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}" + username = current_user.username + session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}" stmt = select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在" ) + logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}") + config_name = config.name # 验证配置 validation_result = config.validate_config() + logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}") if not validation_result["valid"]: return { "success": False, @@ -322,41 +340,54 @@ async def test_llm_config( "details": validation_result } + session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}" # 尝试创建客户端并发送测试请求 try: - # 这里应该根据不同的服务商创建相应的客户端 - # 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架 + # # 这里应该根据不同的服务商创建相应的客户端 + # # 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架 test_message = test_data.message or "Hello, this is a test message." 
+ session.desc = f"准备测试LLM功能 > test_message = {test_message}" - # TODO: 实现具体的测试逻辑 - # 例如: - # client = config.get_client() - # response = client.chat.completions.create( - # model=config.model_name, - # messages=[{"role": "user", "content": test_message}], - # max_tokens=100 - # ) - - # 模拟测试成功 - session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}" + if config.is_embedding: + config.provider = "ollama" + streaming_llm = BGEEmbedLLM(config) + else: + streaming_llm = OnlineLLM(config) + session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}" + streaming_llm.load_model() # 加载模型 + session.desc = f"加载模型完毕,模型名称:{config.model_name},base_url: {config.base_url},准备测试对话..." + if config.is_embedding: + # 测试嵌入模型,使用嵌入API而非聊天API + test_text = test_message or "Hello, this is a test message for embedding" + response = streaming_llm.embed_query(test_text) + else: + # 测试聊天模型 + from langchain.messages import SystemMessage, HumanMessage + messages = [ + SystemMessage(content="你是一个简洁的助手,回答控制在50字以内"), + HumanMessage(content=test_message) + ] + response = streaming_llm.model.invoke(messages) + session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}" + return HxfResponse({ "success": True, - "message": "配置测试成功", - "test_message": test_message, - "response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。", + "message": "LLM测试成功", + "request": test_message, + "response": response.content if hasattr(response, 'content') else response, # 使用转换后的字典 "latency_ms": 150, # 模拟延迟 - "config_info": config.get_client_config() + "config_info": config.to_dict() }) except Exception as test_error: session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}" return HxfResponse({ "success": False, - "message": f"配置测试失败: {str(test_error)}", + "message": f"LLM测试失败: {str(test_error)}", "test_message": test_message, - "config_info": config.get_client_config() + "config_info": config.to_dict() }) 
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态") @@ -366,10 +397,11 @@ async def toggle_llm_config_status( current_user: User = Depends(require_super_admin) ): """切换大模型配置状态.""" - session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}" + username = current_user.username + session.desc = f"START: 切换大模型配置状态, id={config_id} by user {username}" stmt = select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -380,11 +412,11 @@ async def toggle_llm_config_status( config.is_active = not config.is_active # Audit fields are set automatically by SQLAlchemy event listener - session.commit() - session.refresh(config) + await session.commit() + await session.refresh(config) status_text = "激活" if config.is_active else "禁用" - session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}" + session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {username}" return HxfResponse({ "message": f"配置已{status_text}", @@ -399,10 +431,11 @@ async def set_default_llm_config( current_user: User = Depends(require_super_admin) ): """设置默认大模型配置.""" - session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}" + username = current_user.username + session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}" stmt = select(LLMConfig).where(LLMConfig.id == config_id) - config = session.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -421,19 +454,19 @@ async def set_default_llm_config( LLMConfig.is_embedding == config.is_embedding, LLMConfig.id != config_id ).values({"is_default": False}) - session.execute(stmt) + await session.execute(stmt) # 设置当前配置为默认 config.is_default = 
True config.set_audit_fields(current_user.id, is_update=True) - session.commit() - session.refresh(config) + await session.commit() + await session.refresh(config) model_type = "嵌入模型" if config.is_embedding else "对话模型" # 更新文档处理器默认embedding - get_document_processor()._init_embeddings() - session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}" + await get_document_processor(session)._init_embeddings() + session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}" return HxfResponse({ "message": f"已将 {config.name} 设为默认{model_type}配置", "is_default": config.is_default diff --git a/backend/th_agenter/api/endpoints/roles.py b/backend/th_agenter/api/endpoints/roles.py index aa46227..5db58e6 100644 --- a/backend/th_agenter/api/endpoints/roles.py +++ b/backend/th_agenter/api/endpoints/roles.py @@ -1,5 +1,6 @@ """Role management API endpoints.""" +from utils.util_exceptions import HxfResponse from loguru import logger from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status, Query @@ -49,7 +50,8 @@ async def get_roles( stmt = stmt.offset(skip).limit(limit) roles = (await session.execute(stmt)).scalars().all() session.desc = f"SUCCESS: 用户 {current_user.username} 有 {len(roles)} 个角色" - return [role.to_dict() for role in roles] + response = [role.to_dict() for role in roles] + return HxfResponse(response) @router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情") async def get_role( @@ -67,7 +69,8 @@ async def get_role( detail="角色不存在" ) - return role.to_dict() + response = role.to_dict() + return HxfResponse(response) @router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色") async def create_role( @@ -95,12 +98,13 @@ async def create_role( ) role.set_audit_fields(current_user.id) - await session.add(role) + session.add(role) await session.commit() await session.refresh(role) logger.info(f"Role created: {role.name} by 
user {current_user.username}") - return role.to_dict() + response = role.to_dict() + return HxfResponse(response) @router.put("/{role_id}", response_model=RoleResponse, summary="更新角色") async def update_role( @@ -152,7 +156,8 @@ async def update_role( await session.refresh(role) logger.info(f"Role updated: {role.name} by user {current_user.username}") - return role.to_dict() + response = role.to_dict() + return HxfResponse(response) @router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色") async def delete_role( @@ -190,7 +195,8 @@ async def delete_role( await session.commit() session.desc = f"角色删除成功: {role.name} by user {current_user.username}" - return {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"} + response = {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"} + return HxfResponse(response) # 用户角色管理路由 user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"]) @@ -230,13 +236,14 @@ async def assign_user_roles( user_id=assignment_data.user_id, role_id=role_id ) - await session.add(user_role) + session.add(user_role) await session.commit() session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}" - return {"message": "角色分配成功"} + response = {"message": "角色分配成功"} + return HxfResponse(response) @user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表") async def get_user_roles( @@ -267,7 +274,8 @@ async def get_user_roles( ) roles = (await session.execute(stmt)).scalars().all() - return [role.to_dict() for role in roles] + response = [role.to_dict() for role in roles] + return HxfResponse(response) # 将子路由添加到主路由 router.include_router(user_role_router) \ No newline at end of file diff --git a/backend/th_agenter/api/endpoints/smart_chat.py b/backend/th_agenter/api/endpoints/smart_chat.py index a8237fc..a4d6e9f 100644 --- 
a/backend/th_agenter/api/endpoints/smart_chat.py +++ b/backend/th_agenter/api/endpoints/smart_chat.py @@ -12,6 +12,7 @@ from th_agenter.services.conversation_context import conversation_context_servic from utils.util_schemas import BaseResponse from pydantic import BaseModel from loguru import logger +from utils.util_exceptions import HxfResponse router = APIRouter(prefix="/smart-chat", tags=["smart-chat"]) security = HTTPBearer() @@ -65,6 +66,8 @@ async def smart_query( # 初始化工作流管理器 workflow_manager = SmartWorkflowManager(session) + await workflow_manager.initialize() + conversation_service = ConversationService(session) # 处理对话上下文 @@ -126,7 +129,7 @@ async def smart_query( except Exception as e: session.desc = f"ERROR: 智能查询执行失败: {e}" # 返回结构化的错误响应 - return SmartQueryResponse( + response = SmartQueryResponse( success=False, message=f"查询执行失败: {str(e)}", data={'error_type': 'query_execution_error'}, @@ -137,6 +140,7 @@ async def smart_query( }], conversation_id=conversation_id ) + return HxfResponse(response) # 如果查询成功,保存助手回复和更新上下文 if result['success'] and conversation_id: @@ -169,21 +173,22 @@ async def smart_query( if conversation_id: response_data['conversation_id'] = conversation_id session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}" - return SmartQueryResponse( + response = SmartQueryResponse( success=result['success'], message=result.get('message', '查询完成'), data=response_data, workflow_steps=result.get('workflow_steps', []), conversation_id=conversation_id ) + return HxfResponse(response) - except HTTPException: + except HTTPException as e: session.desc = f"EXCEPTION: HTTP异常: {e}" - raise + raise e except Exception as e: session.desc = f"ERROR: 智能查询接口异常: {e}" # 返回通用错误响应 - return SmartQueryResponse( + response = SmartQueryResponse( success=False, message="服务器内部错误,请稍后重试", data={'error_type': 'internal_server_error'}, @@ -194,6 +199,7 @@ async def smart_query( }], conversation_id=conversation_id ) + return HxfResponse(response) 
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文") async def get_conversation_context( @@ -227,11 +233,12 @@ async def get_conversation_context( history = await conversation_context_service.get_conversation_history(conversation_id) context['message_history'] = history session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}" - return ConversationContextResponse( + response = ConversationContextResponse( success=True, message="获取对话上下文成功", data=context ) + return HxfResponse(response) @router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息") @@ -244,6 +251,7 @@ async def get_files_status( """ session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}" workflow_manager = SmartWorkflowManager() + await workflow_manager.initialize() # 获取用户文件列表 file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id) @@ -277,11 +285,12 @@ async def get_files_status( } session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}" - return ConversationContextResponse( + response = ConversationContextResponse( success=True, message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件", data=status_data ) + return HxfResponse(response) @router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文") async def reset_conversation_context( @@ -315,10 +324,11 @@ async def reset_conversation_context( if success: session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}" - return { - "success": True, - "message": "对话上下文已重置,可以开始新的数据分析会话" - } + response = ConversationContextResponse( + success=True, + message="对话上下文已重置,可以开始新的数据分析会话" + ) + return HxfResponse(response) else: session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}" raise HTTPException( diff --git a/backend/th_agenter/api/endpoints/smart_query.py b/backend/th_agenter/api/endpoints/smart_query.py index 7b45dd9..f12f410 100644 --- 
a/backend/th_agenter/api/endpoints/smart_query.py +++ b/backend/th_agenter/api/endpoints/smart_query.py @@ -30,6 +30,7 @@ from th_agenter.services.smart_workflow import SmartWorkflowManager from th_agenter.services.conversation_context import ConversationContextService from pydantic import BaseModel from loguru import logger +from utils.util_exceptions import HxfResponse router = APIRouter(prefix="/smart-query", tags=["smart-query"]) security = HTTPBearer() @@ -163,12 +164,13 @@ async def upload_excel( }) session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}" - return ExcelUploadResponse( + response = ExcelUploadResponse( file_id=excel_file.id, success=True, message="Excel文件上传成功", data=analysis_result ) + return HxfResponse(response) @router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据") async def preview_excel( @@ -233,7 +235,7 @@ async def preview_excel( data = paginated_df.fillna('').to_dict('records') columns = df.columns.tolist() session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据" - return QueryResponse( + response = QueryResponse( success=True, message="Excel文件预览加载成功", data={ @@ -245,6 +247,7 @@ async def preview_excel( 'total_pages': (total_rows + request.page_size - 1) // request.page_size } ) + return HxfResponse(response) @router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接") async def test_database_connection( @@ -264,10 +267,11 @@ async def test_database_connection( message="数据库连接测试成功" ) else: - return NormalResponse( + response = NormalResponse( success=False, message="数据库连接测试失败" ) + return HxfResponse(response) except Exception as e: return NormalResponse( @@ -298,22 +302,25 @@ async def get_table_schema( schema_result = await db_service.get_table_schema(request.table_name, current_user.id) if schema_result['success']: - return QueryResponse( + response = QueryResponse( success=True, 
message="获取表结构成功", data=schema_result['data'] ) + return HxfResponse(response) else: - return QueryResponse( + response = QueryResponse( success=False, message=schema_result['message'] ) + return HxfResponse(response) except Exception as e: - return QueryResponse( + response = QueryResponse( success=False, message=f"获取表结构失败: {str(e)}" ) + return HxfResponse(response) class StreamQueryRequest(BaseModel): query: str @@ -355,6 +362,8 @@ async def stream_smart_query( # 初始化服务 workflow_manager = SmartWorkflowManager(session) + await workflow_manager.initialize() + conversation_context_service = ConversationContextService() # 处理对话上下文 @@ -435,7 +444,7 @@ async def stream_smart_query( except: pass - return StreamingResponse( + response = StreamingResponse( generate_stream(), media_type="text/plain", headers={ @@ -447,6 +456,7 @@ async def stream_smart_query( "Access-Control-Allow-Methods": "*" } ) + return HxfResponse(response) @router.post("/execute-db-query", summary="流式数据库查询") async def execute_database_query( @@ -477,6 +487,7 @@ async def execute_database_query( # 初始化服务 workflow_manager = SmartWorkflowManager(session) + await workflow_manager.initialize() conversation_context_service = ConversationContextService() # 处理对话上下文 @@ -556,7 +567,7 @@ async def execute_database_query( except: pass - return StreamingResponse( + response = StreamingResponse( generate_stream(), media_type="text/plain", headers={ @@ -568,6 +579,7 @@ async def execute_database_query( "Access-Control-Allow-Methods": "*" } ) + return HxfResponse(response) @router.delete("/cleanup-temp-files", summary="清理临时文件") async def cleanup_temp_files( @@ -590,16 +602,18 @@ async def cleanup_temp_files( except OSError: pass - return BaseResponse( + response = BaseResponse( success=True, message=f"已清理 {cleaned_count} 个临时文件" ) + return HxfResponse(response) except Exception as e: - return BaseResponse( + response = BaseResponse( success=False, message=f"清理临时文件失败: {str(e)}" ) + return HxfResponse(response) 
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表") async def get_file_list( @@ -634,7 +648,7 @@ async def get_file_list( file_list.append(file_info) session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件" - return FileListResponse( + response = FileListResponse( success=True, message="获取文件列表成功", data={ @@ -645,12 +659,14 @@ async def get_file_list( 'total_pages': (total + page_size - 1) // page_size } ) + return HxfResponse(response) except Exception as e: - return FileListResponse( + response = FileListResponse( success=False, message=f"获取文件列表失败: {str(e)}" ) + return HxfResponse(response) @router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件") async def delete_file( @@ -668,22 +684,26 @@ async def delete_file( if success: session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}" - return NormalResponse( + response = NormalResponse( success=True, message="文件删除成功" ) + return HxfResponse(response) + else: session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败" - return NormalResponse( + response = NormalResponse( success=False, message="文件不存在或删除失败" ) + return HxfResponse(response) except Exception as e: - return NormalResponse( + response = NormalResponse( success=True, message=str(e) ) + return HxfResponse(response) @router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息") async def get_file_info( @@ -725,9 +745,11 @@ async def get_file_info( 'sheets_summary': excel_file.get_all_sheets_summary() } - return QueryResponse( + session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件 {file_id} 信息" + response = QueryResponse( success=True, message="获取文件信息成功", data=file_info ) + return HxfResponse(response) \ No newline at end of file diff --git a/backend/th_agenter/api/endpoints/table_metadata.py b/backend/th_agenter/api/endpoints/table_metadata.py index e8f689a..c2b439f 100644 --- a/backend/th_agenter/api/endpoints/table_metadata.py +++ 
b/backend/th_agenter/api/endpoints/table_metadata.py @@ -9,6 +9,7 @@ from th_agenter.models.user import User from th_agenter.db.database import get_session from th_agenter.services.table_metadata_service import TableMetadataService from th_agenter.services.auth import AuthService +from utils.util_exceptions import HxfResponse router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"]) @@ -54,8 +55,7 @@ async def collect_table_metadata( request.table_names ) session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据" - return result - + return HxfResponse(result) @router.get("/", summary="获取用户表元数据列表") async def get_table_metadata( @@ -66,7 +66,7 @@ async def get_table_metadata( """获取表元数据列表""" try: service = TableMetadataService(session) - metadata_list = service.get_user_table_metadata( + metadata_list = await service.get_user_table_metadata( current_user.id, database_config_id ) @@ -96,17 +96,17 @@ async def get_table_metadata( for meta in metadata_list ] - return { + return HxfResponse({ "success": True, "data": data - } + }) except Exception as e: logger.error(f"获取表元数据失败: {str(e)}") - return { + return HxfResponse({ "success": False, "message": str(e) - } + }) @router.post("/by-table", summary="根据表名获取表元数据") @@ -118,7 +118,7 @@ async def get_table_metadata_by_name( """根据表名获取表元数据""" try: service = TableMetadataService(session) - metadata = service.get_table_metadata_by_name( + metadata = await service.get_table_metadata_by_name( current_user.id, request.database_config_id, request.table_name @@ -146,27 +146,30 @@ async def get_table_metadata_by_name( "business_context": metadata.business_context or "" } } - return {"success": True, "data": data} + return HxfResponse({ + "success": True, + "data": data + }) else: - return {"success": False, "data": None, "message": "表元数据不存在"} + return HxfResponse({ + "success": False, + "data": None, + "message": "表元数据不存在" + }) except Exception as e: logger.error(f"获取表元数据失败: {str(e)}") - return { + return HxfResponse({ 
"success": False, "message": str(e) - } + }) - return { - "success": True, - "data": data - } except Exception as e: logger.error(f"获取表元数据失败: {str(e)}") - return { + return HxfResponse({ "success": False, "message": str(e) - } + }) @router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置") @@ -179,14 +182,17 @@ async def update_qa_settings( """更新表的问答设置""" try: service = TableMetadataService(session) - success = service.update_table_qa_settings( + success = await service.update_table_qa_settings( current_user.id, metadata_id, - settings.dict() + settings.model_dump() ) if success: - return {"success": True, "message": "设置更新成功"} + return HxfResponse({ + "success": True, + "message": "设置更新成功" + }) else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -221,9 +227,9 @@ async def save_table_metadata( session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置" - return { + return HxfResponse({ "success": True, "message": f"成功保存 {len(result['saved_tables'])} 个表的配置", "saved_tables": result['saved_tables'], "failed_tables": result.get('failed_tables', []) - } \ No newline at end of file + }) \ No newline at end of file diff --git a/backend/th_agenter/api/endpoints/users.py b/backend/th_agenter/api/endpoints/users.py index 5d09552..b3b3656 100644 --- a/backend/th_agenter/api/endpoints/users.py +++ b/backend/th_agenter/api/endpoints/users.py @@ -9,6 +9,7 @@ from ...core.simple_permissions import require_super_admin from ...services.auth import AuthService from ...services.user import UserService from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest +from utils.util_exceptions import HxfResponse router = APIRouter() @@ -17,7 +18,8 @@ async def get_user_profile( current_user = Depends(AuthService.get_current_user) ): """获取当前用户的个人信息.""" - return UserResponse.model_validate(current_user) + response = UserResponse.model_validate(current_user) + return HxfResponse(response) 
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息") async def update_user_profile( @@ -39,7 +41,8 @@ async def update_user_profile( # Update user updated_user = await user_service.update_user(current_user.id, user_update) - return UserResponse.model_validate(updated_user) + response = UserResponse.model_validate(updated_user) + return HxfResponse(response) @router.delete("/profile", summary="删除当前用户的账户") async def delete_user_account( @@ -51,7 +54,8 @@ async def delete_user_account( user_service = UserService(session) await user_service.delete_user(current_user.id) session.desc = f"删除用户 [{username}] 成功" - return {"message": f"删除用户 {username} 成功"} + response = {"message": f"删除用户 {username} 成功"} + return HxfResponse(response) # Admin endpoints @router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)") @@ -83,7 +87,8 @@ async def create_user( # Create user new_user = await user_service.create_user(user_create) - return UserResponse.model_validate(new_user) + response = UserResponse.model_validate(new_user) + return HxfResponse(response) @router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)") async def list_users( @@ -111,7 +116,7 @@ async def list_users( "page": page, "page_size": size } - return result + return HxfResponse(result) @router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)") @@ -128,7 +133,8 @@ async def get_user( status_code=status.HTTP_404_NOT_FOUND, detail="User not found" ) - return UserResponse.model_validate(user) + response = UserResponse.model_validate(user) + return HxfResponse(response) @router.put("/change-password", summary="修改当前用户的密码") async def change_password( @@ -145,7 +151,8 @@ async def change_password( current_password=request.current_password, new_password=request.new_password ) - return {"message": "Password changed successfully"} + response = {"message": "Password changed successfully"} + return HxfResponse(response) except Exception as e: if "Current password is 
incorrect" in str(e): raise HTTPException( @@ -178,7 +185,8 @@ async def reset_user_password( user_id=user_id, new_password=request.new_password ) - return {"message": "Password reset successfully"} + response = {"message": "Password reset successfully"} + return HxfResponse(response) except Exception as e: if "User not found" in str(e): raise HTTPException( @@ -215,7 +223,8 @@ async def update_user( ) updated_user = await user_service.update_user(user_id, user_update) - return UserResponse.model_validate(updated_user) + response = UserResponse.model_validate(updated_user) + return HxfResponse(response) @router.delete("/{user_id}", summary="删除用户 (仅管理员权限)") async def delete_user( @@ -234,4 +243,5 @@ async def delete_user( ) await user_service.delete_user(user_id) - return {"message": "User deleted successfully"} \ No newline at end of file + response = {"message": "User deleted successfully"} + return HxfResponse(response) diff --git a/backend/th_agenter/api/endpoints/workflow.py b/backend/th_agenter/api/endpoints/workflow.py index 5ae996f..1c40ced 100644 --- a/backend/th_agenter/api/endpoints/workflow.py +++ b/backend/th_agenter/api/endpoints/workflow.py @@ -18,7 +18,7 @@ from ...services.workflow_engine import get_workflow_engine from ...services.auth import AuthService from ...models.user import User from loguru import logger - +from utils.util_exceptions import HxfResponse router = APIRouter() def convert_workflow_for_response(workflow_dict): @@ -30,38 +30,7 @@ def convert_workflow_for_response(workflow_dict): if 'to_node' in conn: conn['to'] = conn.pop('to_node') return workflow_dict - -@router.post("/", response_model=WorkflowResponse) -async def create_workflow( - workflow_data: WorkflowCreate, - session: Session = Depends(get_session), - current_user: User = Depends(AuthService.get_current_user) -): - """创建工作流""" - from ...models.workflow import Workflow - - # 创建工作流 - workflow = Workflow( - name=workflow_data.name, - description=workflow_data.description, - 
definition=workflow_data.definition.model_dump(), - version="1.0.0", - status=workflow_data.status, - owner_id=current_user.id - ) - workflow.set_audit_fields(current_user.id) - - await session.add(workflow) - await session.commit() - await session.refresh(workflow) - - # 转换definition中的字段映射 - workflow_dict = convert_workflow_for_response(workflow.to_dict()) - - logger.info(f"Created workflow: {workflow.name} by user {current_user.username}") - return WorkflowResponse(**workflow_dict) - - + @router.get("/", response_model=WorkflowListResponse) async def list_workflows( skip: Optional[int] = Query(None, ge=0), @@ -73,6 +42,7 @@ async def list_workflows( ): """获取工作流列表""" from ...models.workflow import Workflow + session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})" # 构建查询 stmt = select(Workflow).where(Workflow.owner_id == current_user.id) @@ -90,17 +60,22 @@ async def list_workflows( count_query = count_query.where(Workflow.status == workflow_status) if search: count_query = count_query.where(Workflow.name.ilike(f"%{search}%")) - total = session.scalar(count_query) + + session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}" + total = await session.scalar(count_query) + session.desc = f"查询结果: 共 {total} 条" # 如果没有传分页参数,返回所有数据 if skip is None and limit is None: - workflows = session.scalars(stmt).all() - return WorkflowListResponse( + workflows = (await session.scalars(stmt)).all() + session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)} 条" + response = WorkflowListResponse( workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows], total=total, page=1, size=total ) + return HxfResponse(response) # 使用默认分页参数 if skip is None: @@ -109,14 +84,16 @@ async def list_workflows( limit = 10 # 分页查询 - workflows = session.scalars(stmt.offset(skip).limit(limit)).all() + workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all() + session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)} 条" - 
return WorkflowListResponse( + response = WorkflowListResponse( workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows], total=total, page=skip // limit + 1, # 计算页码 size=limit ) + return HxfResponse(response) @router.get("/{workflow_id}", response_model=WorkflowResponse) async def get_workflow( @@ -126,8 +103,9 @@ async def get_workflow( ): """获取工作流详情""" from ...models.workflow import Workflow + session.desc = f"START: 获取工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).where( Workflow.id == workflow_id, Workflow.owner_id == current_user.id @@ -135,12 +113,15 @@ async def get_workflow( ) if not workflow: + session.desc = f"ERROR: 获取工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" ) - return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict())) + session.desc = f"SUCCESS: 获取工作流数据 {workflow_id}" + response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict())) + return HxfResponse(response) @router.put("/{workflow_id}", response_model=WorkflowResponse) async def update_workflow( @@ -151,8 +132,9 @@ async def update_workflow( ): """更新工作流""" from ...models.workflow import Workflow + session.desc = f"START: 更新工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).where( and_( Workflow.id == workflow_id, @@ -162,13 +144,15 @@ async def update_workflow( ) if not workflow: + session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" ) # 更新字段 - update_data = workflow_data.dict(exclude_unset=True) + session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {workflow_data.model_dump(exclude_unset=True)}" + update_data = workflow_data.model_dump(exclude_unset=True) for field, value in update_data.items(): if field == "definition" and value: # 如果value是Pydantic模型,转换为字典;如果已经是字典,直接使用 @@ -181,13 
+165,12 @@ async def update_workflow( workflow.set_audit_fields(current_user.id, is_update=True) - session.commit() - session.refresh(workflow) + await session.commit() + await session.refresh(workflow) + session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}" + response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict())) + return HxfResponse(response) - logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}") - return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict())) - - @router.delete("/{workflow_id}") async def delete_workflow( workflow_id: int, @@ -196,8 +179,9 @@ async def delete_workflow( ): """删除工作流""" from ...models.workflow import Workflow + session.desc = f"START: 删除工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).filter( and_( Workflow.id == workflow_id, @@ -207,16 +191,18 @@ async def delete_workflow( ) if not workflow: + session.desc = f"ERROR: 删除工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" ) + session.desc = f"删除工作流: {workflow.name}" + await session.delete(workflow) + await session.commit() - session.delete(workflow) - session.commit() - - logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}") - return {"message": "工作流删除成功"} + session.desc = f"SUCCESS: 删除工作流数据 commit {workflow_id}" + response = {"message": "工作流删除成功"} + return HxfResponse(response) @router.post("/{workflow_id}/activate") @@ -227,8 +213,9 @@ async def activate_workflow( ): """激活工作流""" from ...models.workflow import Workflow + session.desc = f"START: 激活工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).filter( and_( Workflow.id == workflow_id, @@ -238,6 +225,7 @@ async def activate_workflow( ) if not workflow: + session.desc = f"ERROR: 激活工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, 
detail="工作流不存在" @@ -245,11 +233,11 @@ async def activate_workflow( workflow.status = ModelWorkflowStatus.PUBLISHED workflow.set_audit_fields(current_user.id, is_update=True) + await session.commit() - session.commit() - - logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}") - return {"message": "工作流激活成功"} + session.desc = f"SUCCESS: 激活工作流数据 commit {workflow_id}" + response = {"message": "工作流激活成功"} + return HxfResponse(response) @router.post("/{workflow_id}/deactivate") async def deactivate_workflow( @@ -259,8 +247,9 @@ async def deactivate_workflow( ): """停用工作流""" from ...models.workflow import Workflow + session.desc = f"START: 停用工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).filter( and_( Workflow.id == workflow_id, @@ -270,6 +259,7 @@ async def deactivate_workflow( ) if not workflow: + session.desc = f"ERROR: 停用工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" @@ -278,10 +268,11 @@ async def deactivate_workflow( workflow.status = ModelWorkflowStatus.ARCHIVED workflow.set_audit_fields(current_user.id, is_update=True) - session.commit() + await session.commit() - logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}") - return {"message": "工作流停用成功"} + session.desc = f"SUCCESS: 停用工作流数据 commit {workflow_id}" + response = {"message": "工作流停用成功"} + return HxfResponse(response) @router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse) async def execute_workflow( @@ -292,8 +283,9 @@ async def execute_workflow( ): """执行工作流""" from ...models.workflow import Workflow + session.desc = f"START: 执行工作流 {workflow_id}" - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).filter( and_( Workflow.id == workflow_id, @@ -303,6 +295,7 @@ async def execute_workflow( ) if not workflow: + session.desc = f"ERROR: 执行工作流数据 - 工作流不存在 {workflow_id}" raise HTTPException( 
status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" @@ -323,8 +316,8 @@ async def execute_workflow( session=session ) - logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}") - return execution_result + session.desc = f"SUCCESS: 执行工作流数据 commit {workflow_id}" + return HxfResponse(execution_result) @router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse]) @@ -336,38 +329,40 @@ async def list_workflow_executions( current_user: User = Depends(AuthService.get_current_user) ): """获取工作流执行历史""" + session.desc = f"START: 获取工作流执行历史 {workflow_id}" try: from ...models.workflow import Workflow, WorkflowExecution # 验证工作流所有权 - workflow = session.scalar( - select(Workflow).where( - and_( - Workflow.id == workflow_id, - Workflow.owner_id == current_user.id + workflow = await session.scalar( + select(Workflow).where( + and_( + Workflow.id == workflow_id, + Workflow.owner_id == current_user.id + ) ) ) - ) if not workflow: + session.desc = f"ERROR: 获取工作流执行历史数据 - 工作流不存在" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="工作流不存在" ) # 获取执行历史 - executions = session.scalars( - select(WorkflowExecution).where( - WorkflowExecution.workflow_id == workflow_id - ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit) - ).all() + executions = (await session.scalars( + select(WorkflowExecution).where( + WorkflowExecution.workflow_id == workflow_id + ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit) + )).all() - return [WorkflowExecutionResponse.model_validate(execution) for execution in executions] + session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}" + response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions] + return HxfResponse(response) - except HTTPException: - raise except Exception as e: - logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}") + session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}" raise 
HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取执行历史失败" @@ -380,10 +375,11 @@ async def get_workflow_execution( current_user: User = Depends(AuthService.get_current_user) ): """获取工作流执行详情""" + session.desc = f"START: 获取工作流执行详情 {execution_id}" try: from ...models.workflow import WorkflowExecution, Workflow - execution = session.scalar( + execution = await session.scalar( select(WorkflowExecution).join( Workflow, WorkflowExecution.workflow_id == Workflow.id ).where( @@ -393,17 +389,18 @@ async def get_workflow_execution( ) if not execution: + session.desc = f"ERROR: 获取工作流执行详情数据 - 执行记录不存在" raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="执行记录不存在" ) - return WorkflowExecutionResponse.model_validate(execution) + response = WorkflowExecutionResponse.model_validate(execution) + session.desc = f"SUCCESS: 获取工作流执行详情数据 commit {execution_id}" + return HxfResponse(response) - except HTTPException: - raise except Exception as e: - logger.error(f"Error getting workflow execution {execution_id}: {str(e)}") + session.desc = f"ERROR: 获取工作流执行详情数据 commit {execution_id}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取执行详情失败" @@ -418,7 +415,7 @@ async def execute_workflow_stream( current_user: User = Depends(AuthService.get_current_user) ): """流式执行工作流,实时推送节点执行状态""" - + session.desc = f"START: 流式执行工作流 {workflow_id}" async def generate_stream() -> AsyncGenerator[str, None]: workflow_engine = None @@ -426,7 +423,7 @@ async def execute_workflow_stream( from ...models.workflow import Workflow # 验证工作流 - workflow = session.scalar( + workflow = await session.scalar( select(Workflow).filter( and_( Workflow.id == workflow_id, @@ -466,7 +463,7 @@ async def execute_workflow_stream( logger.error(f"流式工作流执行异常: {e}", exc_info=True) yield f"data: {json.dumps({'type': 'error', 'message': f'工作流执行失败: {str(e)}'}, ensure_ascii=False)}\n\n" - return StreamingResponse( + response = StreamingResponse( generate_stream(), 
media_type="text/plain", headers={ @@ -477,4 +474,43 @@ async def execute_workflow_stream( "Access-Control-Allow-Headers": "*", "Access-Control-Allow-Methods": "*" } - ) \ No newline at end of file + ) + session.desc = f"SUCCESS: 流式执行工作流 {workflow_id} 完毕" + return HxfResponse(response) + +# ----------------------------------------------------------------------- + +@router.post("/", response_model=WorkflowResponse) +async def create_workflow( + workflow_data: WorkflowCreate, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """创建工作流""" + from ...models.workflow import Workflow + session.desc = f"START: 创建工作流 {workflow_data.name}" + # 创建工作流 + workflow = Workflow( + name=workflow_data.name, + description=workflow_data.description, + definition=workflow_data.definition.model_dump(), + version="1.0.0", + status=workflow_data.status, + owner_id=current_user.id + ) + session.desc = f"创建工作流实例 - Workflow(), {workflow_data.name}" + workflow.set_audit_fields(current_user.id) + session.desc = f"保存工作流 - set_audit_fields {workflow_data.name}" + + session.add(workflow) + await session.commit() + await session.refresh(workflow) + session.desc = f"保存工作流 - commit & refresh {workflow_data.name}" + # 转换definition中的字段映射 + workflow_dict = convert_workflow_for_response(workflow.to_dict()) + session.desc = f"转换工作流数据 - convert_workflow_for_response {workflow_data.name}" + + response = WorkflowResponse(**workflow_dict) + session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {workflow_data.name}" + return HxfResponse(response) + \ No newline at end of file diff --git a/backend/th_agenter/api/routes.py b/backend/th_agenter/api/routes.py index a07320f..ccc01e3 100644 --- a/backend/th_agenter/api/routes.py +++ b/backend/th_agenter/api/routes.py @@ -69,13 +69,6 @@ router.include_router( tags=["smart-chat"] ) - - - - - - - router.include_router( workflow.router, prefix="/workflows", diff --git a/backend/th_agenter/core/config.py 
b/backend/th_agenter/core/config.py index 86aa03d..3554a8a 100644 --- a/backend/th_agenter/core/config.py +++ b/backend/th_agenter/core/config.py @@ -1,6 +1,7 @@ """Configuration management for TH Agenter.""" import os +from requests import Session import yaml from pathlib import Path from loguru import logger @@ -88,13 +89,15 @@ class LLMSettings(BaseSettings): "extra": "ignore" } - def get_current_config(self) -> dict: + async def get_current_config(self, session: Session) -> dict: """获取当前选择的提供商配置 - 优先从数据库读取默认配置.""" try: - # 尝试从数据库读取默认聊天模型配置 from th_agenter.services.llm_config_service import LLMConfigService + # 尝试从数据库读取默认聊天模型配置 llm_service = LLMConfigService() - db_config = llm_service.get_default_chat_config() + db_config = None + if session: + db_config = await llm_service.get_default_chat_config(session) if db_config: # 如果数据库中有默认配置,使用数据库配置 @@ -105,10 +108,17 @@ class LLMSettings(BaseSettings): "max_tokens": self.max_tokens, "temperature": self.temperature } + if session: + session.desc = f"使用LLM配置(get_default_chat_config)> {config}" + else: + logger.info(f"使用LLM配置(get_default_chat_config) > {config}") return config except Exception as e: # 如果数据库读取失败,记录错误并回退到环境变量 - logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}") + if session: + session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}" + else: + logger.error(f"获取默认对话模型配置失败: {str(e)}") # 回退到原有的环境变量配置 provider_configs = { @@ -182,13 +192,15 @@ class EmbeddingSettings(BaseSettings): "extra": "ignore" } - def get_current_config(self) -> dict: + async def get_current_config(self, session: Session) -> dict: """获取当前选择的embedding提供商配置 - 优先从数据库读取默认配置.""" try: + if session: + session.desc = "尝试从数据库读取默认嵌入模型配置 ... 
 >>> get_current_config"; # 尝试从数据库读取默认嵌入模型配置 from th_agenter.services.llm_config_service import LLMConfigService llm_service = LLMConfigService() - db_config = llm_service.get_default_embedding_config() + db_config = await llm_service.get_default_embedding_config(session) if db_config: # 如果数据库中有默认配置,使用数据库配置 @@ -200,7 +212,10 @@ class EmbeddingSettings(BaseSettings): return config except Exception as e: # 如果数据库读取失败,记录错误并回退到环境变量 - logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}") + if session: + session.desc = f"Failed to read embedding config from database, falling back to env vars: {e}" + else: + logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}") # 回退到原有的环境变量配置 provider_configs = { diff --git a/backend/th_agenter/core/context.py b/backend/th_agenter/core/context.py index b13527b..2fc3862 100644 --- a/backend/th_agenter/core/context.py +++ b/backend/th_agenter/core/context.py @@ -9,7 +9,7 @@ from ..models.user import User from loguru import logger # Context variable to store current user -current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None) +current_user_context: ContextVar[Optional[dict]] = ContextVar('current_user', default=None) # Thread-local storage as backup _thread_local = threading.local() @@ -19,34 +19,56 @@ class UserContext: """User context manager for accessing current user globally.""" @staticmethod - def set_current_user(user: User) -> None: + def set_current_user(user: User, canLog: bool = False) -> None: """Set current user in context.""" - logger.info(f"[UserContext] - Setting user in context: {user.username} (ID: {user.id})") + if canLog: + logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})") + + # Store user information as a dictionary instead of the SQLAlchemy model + user_dict = { + 'id': user.id, + 'username': user.username, + 'email': user.email, + 'full_name': user.full_name, + 'is_active': 
user.is_active + } # Set in ContextVar - current_user_context.set(user) + current_user_context.set(user_dict) # Also set in thread-local as backup - _thread_local.current_user = user + _thread_local.current_user = user_dict # Verify it was set verify_user = current_user_context.get() - logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}") + if canLog: + logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}") @staticmethod - def set_current_user_with_token(user: User): + def set_current_user_with_token(user: User, canLog: bool = False): """Set current user in context and return token for cleanup.""" - logger.info(f"[UserContext] - Setting user in context with token: {user.username} (ID: {user.id})") + if canLog: + logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})") + + # Store user information as a dictionary instead of the SQLAlchemy model + user_dict = { + 'id': user.id, + 'username': user.username, + 'email': user.email, + 'full_name': user.full_name, + 'is_active': user.is_active + } # Set in ContextVar and get token - token = current_user_context.set(user) + token = current_user_context.set(user_dict) # Also set in thread-local as backup - _thread_local.current_user = user + _thread_local.current_user = user_dict # Verify it was set verify_user = current_user_context.get() - logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}") + if canLog: + logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}") return token @@ -63,42 +85,45 @@ class UserContext: delattr(_thread_local, 'current_user') @staticmethod - def get_current_user() -> Optional[User]: - """Get current user from context.""" - logger.debug("[UserContext] - Attempting to get user from context") - + def get_current_user() -> Optional[dict]: + """Get current user from 
context.""" # Try ContextVar first - user = current_user_context.get() + user = current_user_context.get() if user: - logger.debug(f"[UserContext] - Got user from ContextVar: {user.username} (ID: {user.id})") + # logger.info(f"[UserContext] - 取得当前用户为 ContextVar 用户: {user.get('username') if user else None}") return user # Fallback to thread-local - user = getattr(_thread_local, 'current_user', None) + user = getattr(_thread_local, 'current_user', None) if user: - logger.debug(f"[UserContext] - Got user from thread-local: {user.username} (ID: {user.id})") + # logger.info(f"[UserContext] - 取得当前用户为线程本地用户: {user.get('username') if user else None}") return user - logger.debug("[UserContext] - No user found in context (neither ContextVar nor thread-local)") + logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)") return None @staticmethod def get_current_user_id() -> Optional[int]: """Get current user ID from context.""" - user = UserContext.get_current_user() - return user.id if user else None + try: + user = UserContext.get_current_user() + return user.get('id') if user else None + except Exception as e: + logger.error(f"[UserContext] - Error getting current user ID: {e}") + return None @staticmethod - def clear_current_user() -> None: + def clear_current_user(canLog: bool = False) -> None: """Clear current user from context.""" - logger.info("[UserContext] - 清除当前用户上下文") + if canLog: + logger.info("[UserContext] - 清除当前用户上下文") current_user_context.set(None) if hasattr(_thread_local, 'current_user'): delattr(_thread_local, 'current_user') @staticmethod - def require_current_user() -> User: + def require_current_user() -> dict: """Get current user from context, raise exception if not found.""" # Use the same logic as get_current_user to check both ContextVar and thread-local user = UserContext.get_current_user() @@ -114,4 +139,4 @@ class UserContext: def require_current_user_id() -> int: """Get current user ID from context, raise exception if not 
found.""" user = UserContext.require_current_user() - return user.id \ No newline at end of file + return user.get('id') \ No newline at end of file diff --git a/backend/th_agenter/core/middleware.py b/backend/th_agenter/core/middleware.py index 178bf1e..61ab4ef 100644 --- a/backend/th_agenter/core/middleware.py +++ b/backend/th_agenter/core/middleware.py @@ -20,7 +20,7 @@ class UserContextMiddleware(BaseHTTPMiddleware): def __init__(self, app, exclude_paths: list = None): super().__init__(app) - self.canLog = True + self.canLog = False # Paths that don't require authentication self.exclude_paths = exclude_paths or [ "/docs", @@ -77,7 +77,7 @@ class UserContextMiddleware(BaseHTTPMiddleware): logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理") # Always clear any existing user context to ensure fresh authentication - UserContext.clear_current_user() + UserContext.clear_current_user(self.canLog) # Initialize context token user_token = None @@ -96,6 +96,7 @@ class UserContextMiddleware(BaseHTTPMiddleware): # Extract token token = authorization.split(" ")[1] + # Verify token payload = AuthService.verify_token(token) if payload is None: @@ -136,7 +137,7 @@ class UserContextMiddleware(BaseHTTPMiddleware): ) # Set user in context using token mechanism - user_token = UserContext.set_current_user_with_token(user) + user_token = UserContext.set_current_user_with_token(user, self.canLog) if self.canLog: logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文") @@ -160,8 +161,13 @@ class UserContextMiddleware(BaseHTTPMiddleware): try: response = await call_next(request) return response + except Exception as e: + # Log error but don't fail the request + logger.error(f"[MIDDLEWARE] - 请求处理出错: {e}") + # Return 500 error + return HxfErrorResponse(e) finally: # Always clear user context after request processing - UserContext.clear_current_user() + UserContext.clear_current_user(self.canLog) if self.canLog: logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}") 
"""Factory helpers for creating LLM and agent instances."""

from typing import Any, Optional

from langchain_openai import ChatOpenAI
from langchain.agents import create_agent
from loguru import logger
from requests import Session

from .config import get_settings

# Maps a model-name prefix to the settings attribute prefix holding that
# provider's model / api_key / base_url configuration.
_PROVIDER_PREFIXES = {
    'deepseek': 'deepseek',
    'doubao': 'doubao',
    'glm': 'zhipu',
    'moonshot': 'moonshot',
}


async def new_llm(session: Session = None, model: Optional[str] = None,
                  temperature: Optional[float] = None,
                  streaming: bool = False) -> ChatOpenAI:
    """Create a ChatOpenAI instance.

    Args:
        session: optional session forwarded to the settings lookup.
        model: optional model name; falls back to the configured default.
        temperature: optional sampling-temperature override.
        streaming: enable streaming responses (default False).

    Returns:
        A configured ChatOpenAI instance.
    """
    settings = get_settings()
    llm_config = await settings.llm.get_current_config(session)

    if model:
        # Resolve provider-specific config from the model-name prefix.
        for prefix, attr in _PROVIDER_PREFIXES.items():
            if model.startswith(prefix):
                llm_config['model'] = getattr(settings.llm, f'{attr}_model')
                llm_config['api_key'] = getattr(settings.llm, f'{attr}_api_key')
                llm_config['base_url'] = getattr(settings.llm, f'{attr}_base_url')
                break

    return ChatOpenAI(
        model=llm_config['model'],
        api_key=llm_config['api_key'],
        base_url=llm_config['base_url'],
        # An explicit temperature argument wins over the stored config value.
        temperature=temperature if temperature is not None else llm_config['temperature'],
        max_tokens=llm_config['max_tokens'],
        streaming=streaming,
    )


async def new_agent(session: Session = None, model: Optional[str] = None,
                    temperature: Optional[float] = None,
                    streaming: bool = False) -> Any:
    """Create an agent backed by a freshly built LLM.

    Args:
        session: optional session forwarded to ``new_llm``.
        model: optional model name; falls back to the configured default.
        temperature: optional sampling-temperature override.
        streaming: enable streaming responses (default False).

    Returns:
        The agent returned by ``langchain.agents.create_agent`` (the original
        docstring/annotation wrongly claimed a ChatOpenAI instance).
    """
    llm = await new_llm(session, model, temperature, streaming)
    return create_agent(model=llm)
user.id, + Role.code == 'SUPER_ADMIN', + Role.is_active == True + ) + user_role = await session.execute(stmt) + result = user_role.scalar_one_or_none() is not None + session.desc = f"用户 {user.id} 超级管理员角色查询结果: {result}" + return result + except Exception as e: + # 如果调用失败,记录错误并返回False + session.desc = f"EXCEPTION: 用户 {user.id} 超级管理员角色查询失败: {str(e)}" + logger.error(f"检查用户 {user.id} 超级管理员角色失败: {str(e)}") + return False -def require_super_admin( +async def require_super_admin( current_user: User = Depends(AuthService.get_current_user), session: Session = Depends(get_session) ) -> User: """要求超级管理员权限的依赖项.""" - if not is_super_admin(current_user, session): + if not await is_super_admin(current_user, session): raise HTTPException( status_code=403, detail="需要超级管理员权限" @@ -75,17 +71,17 @@ class SimplePermissionChecker: def __init__(self, db: Session): self.db = db - def check_super_admin(self, user: User) -> bool: + async def check_super_admin(self, user: User) -> bool: """检查是否为超级管理员.""" - return is_super_admin(user, self.db) + return await is_super_admin(user, self.db) - def check_user_access(self, user: User, target_user_id: int) -> bool: + async def check_user_access(self, user: User, target_user_id: int) -> bool: """检查用户访问权限(自己或超级管理员).""" if not user or not user.is_active: return False # 超级管理员可以访问所有用户 - if self.check_super_admin(user): + if await self.check_super_admin(user): return True # 用户只能访问自己的信息 diff --git a/backend/th_agenter/db/base.py b/backend/th_agenter/db/base.py index c34b6eb..c8333d8 100644 --- a/backend/th_agenter/db/base.py +++ b/backend/th_agenter/db/base.py @@ -38,7 +38,7 @@ class BaseModel(Base): def to_dict(self): """Convert model to dictionary.""" return { - column.name: getattr(self, column.name) + column.name: getattr(self, column.name).isoformat() if hasattr(getattr(self, column.name), 'isoformat') else getattr(self, column.name) for column in self.__table__.columns } @@ -60,31 +60,31 @@ class BaseModel(Base): return cls(**filtered_data) def 
set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False): - """Set audit fields for create/update operations. + """对创建/更新操作设置created_by/updated_by字段。 Args: - user_id: ID of the user performing the operation (optional, will use context if not provided) - is_update: True for update operations, False for create operations + user_id: 用户ID,用于设置创建/更新操作的审计字段(可选,默认从上下文获取) + is_update: True 表示更新操作,False 表示创建操作 """ - # Get user_id from context if not provided + # 如果未提供user_id,则从上下文获取 if user_id is None: from ..core.context import UserContext try: user_id = UserContext.get_current_user_id() except Exception: - # If no user in context, skip setting audit fields + # 如果上下文没有用户ID,则跳过设置审计字段 return - # Skip if still no user_id + # 如果仍未提供user_id,则跳过设置审计字段 if user_id is None: return if not is_update: - # For create operations, set both create_by and update_by + # 对于创建操作,同时设置created_by和updated_by self.created_by = user_id self.updated_by = user_id else: - # For update operations, only set update_by + # 对于更新操作,仅设置updated_by self.updated_by = user_id # @event.listens_for(Session, 'before_flush') diff --git a/backend/th_agenter/db/database.py b/backend/th_agenter/db/database.py index 11bb003..0d3dd6e 100644 --- a/backend/th_agenter/db/database.py +++ b/backend/th_agenter/db/database.py @@ -18,12 +18,27 @@ class DrSession(AsyncSession): def __init__(self, **kwargs): """Initialize DrSession with unique ID.""" super().__init__(**kwargs) + self.title = "" + self.descs = [] # 确保info属性存在 if not hasattr(self, 'info'): self.info = {} self.info['session_id'] = str(uuid.uuid4()).split('-')[0] self.stepIndex = 0 + @property + def title(self) -> Optional[str]: + """Get work brief from session info.""" + return self.info.get('title') + + @title.setter + def title(self, value: str) -> None: + """Set work brief in session info.""" + if('title' not in self.info or self.info['title'].strip() == ""): + self.info['title'] = value # 确保title属性存在 + else: + self.info['title'] = value + " 
>>> " + self.info['title'] + @property def desc(self) -> Optional[str]: """Get work brief from session info.""" @@ -34,6 +49,25 @@ class DrSession(AsyncSession): """Set work brief in session info.""" self.stepIndex += 1 + match = re.search(r";(-\d+)", value); + level = -3 + if match: + level = int(match.group(1)) + value = value.replace(f";{level}", "") + level = -3 + level + + if "警告" in value or value.startswith("WARNING"): + self.log_warning(f"第 {self.stepIndex} 步 - {value}", level = level) + elif "异常" in value or value.startswith("EXCEPTION"): + self.log_exception(f"第 {self.stepIndex} 步 - {value}", level = level) + elif "成功" in value or value.startswith("SUCCESS"): + self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level) + elif "开始" in value or value.startswith("START"): + self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level) + elif "失败" in value or value.startswith("ERROR"): + self.log_error(f"第 {self.stepIndex} 步 - {value}", level = level) + else: + self.log_info(f"第 {self.stepIndex} 步 - {value}", level = level) def log_prefix(self) -> str: """Get log prefix with session ID and desc.""" @@ -74,14 +108,14 @@ class DrSession(AsyncSession): engine_async = create_async_engine( get_settings().database.url, - echo=True, # get_settings().database.echo, + echo=False, # get_settings().database.echo, future=True, pool_size=get_settings().database.pool_size, max_overflow=get_settings().database.max_overflow, pool_pre_ping=True, pool_recycle=3600, ) -from fastapi import Request +from fastapi import HTTPException, Request AsyncSessionFactory = sessionmaker( bind=engine_async, @@ -95,10 +129,15 @@ async def get_session(request: Request = None): if request: url = f"{request.method} {request.url.path}"# .split("://")[-1] # session = AsyncSessionFactory() - + print(url) + # 取得request的来源IP + if request: + client_host = request.client.host + else: + client_host = "无request" session = DrSession(bind=engine_async) - session.desc = f"SUCCESS: 创建数据库 session 
>>> {url}" + session.title = f"{url} - {client_host}" # 设置request属性 if request: @@ -111,8 +150,9 @@ async def get_session(request: Request = None): errMsg = f"数据库 session 异常 >>> {e}" session.desc = f"EXCEPTION: {errMsg}" await session.rollback() - raise e - # DatabaseError(e) + # 重新抛出原始异常,不转换为 HTTPException + raise e # HTTPException(status_code=e.status_code, detail=errMsg) # main.py中将捕获本异常 finally: - session.desc = f"数据库 session 关闭" + # session.desc = f"数据库 session 关闭" + session.desc = "" await session.close() diff --git a/backend/th_agenter/db/migrations/migrate_hardcoded_resources.py b/backend/th_agenter/db/migrations/migrate_hardcoded_resources.py index 37efb73..f786cbc 100644 --- a/backend/th_agenter/db/migrations/migrate_hardcoded_resources.py +++ b/backend/th_agenter/db/migrations/migrate_hardcoded_resources.py @@ -11,8 +11,8 @@ sys.path.insert(0, str(backend_dir)) from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker -from th_agenter.core.config import settings -from th_agenter.db.database import Base, get_session +from th_agenter.core.config import get_settings +from th_agenter.db.database import Base from th_agenter.models import * # Import all models to ensure they're registered from th_agenter.utils.logger import get_logger from th_agenter.models.resource import Resource @@ -25,8 +25,25 @@ def migrate_hardcoded_resources(): """Migrate hardcoded resources from init_resource_data.py to database.""" db = None try: + # Get database settings + settings = get_settings() + + # Create synchronous engine (remove asyncpg from URL) + sync_db_url = settings.database.url.replace('postgresql+asyncpg://', 'postgresql://') + sync_engine = create_engine( + sync_db_url, + echo=False, + pool_size=settings.database.pool_size, + max_overflow=settings.database.max_overflow, + pool_pre_ping=True, + pool_recycle=3600, + ) + + # Create synchronous session factory + SyncSessionFactory = sessionmaker(bind=sync_engine) + # Get database session - db = 
"""Unified LLM configuration dataclass and a LangChain-compatible base chat model."""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

from loguru import logger

# Core LangChain abstractions: BaseChatModel is the contract create_agent expects.
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.callbacks import CallbackManagerForLLMRun


@dataclass
class LLMConfig_DataClass:
    """Unified LLM configuration covering online/local/embedding models.

    Mirrors the full database schema. The model kind is derived from
    ``provider`` + ``is_embedding``:
      - online:    provider in ['openai', 'zhipu', 'baidu', ...] and not is_embedding
      - local:     provider in ['llama', 'qwen', 'yi', ...] and not is_embedding
      - embedding: is_embedding is True
    """

    # ------ core identity fields ------
    name: str                              # display name (e.g. "gpt-5")
    model_name: str                        # official model id (e.g. "BAAI/bge-small-zh-v1.5")
    provider: str                          # provider key (openai/llama/bge/zhipu/...)
    id: Optional[int] = None               # database primary key
    description: Optional[str] = None      # free-text description
    is_active: bool = True                 # whether the config is enabled
    is_default: bool = False               # whether this is the default model
    is_embedding: bool = False             # discriminates embedding models

    # ------ generation parameters (all inference models) ------
    temperature: float = 0.7
    max_tokens: int = 3000
    top_p: float = 0.6
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # ------ online-model-only fields ------
    api_key: Optional[str] = None          # required for online providers
    base_url: Optional[str] = None         # API proxy endpoint
    max_retries: int = 3
    api_version: Optional[str] = None      # e.g. OpenAI "2024-02-15-preview"

    # ------ local-model-only fields ------
    model_path: Optional[str] = None       # required for local providers
    device: str = "cpu"                    # cpu / cuda / mps
    n_ctx: int = 2048                      # context window size
    n_threads: int = 8                     # inference threads
    quantization: str = "q4_0"             # q4_0 / q8_0 / f16
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    prompt_template: Optional[str] = None  # custom prompt template

    # ------ embedding-model-only fields ------
    normalize_embeddings: bool = True      # normalize output vectors
    batch_size: int = 32                   # batch encode size
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)
    dimension: Optional[int] = None        # vector dimension, e.g. 768

    # ------ metadata maintained by the database ------
    extra_config: Dict[str, Any] = field(default_factory=dict)
    usage_count: int = 0
    last_used_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    api_key_masked: Optional[str] = ""     # masked key as stored in the DB

    def __post_init__(self):
        """Normalize and validate the configuration after init."""
        # Embedding models never generate text: zero out inference params.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0
        self._validate_required_fields()

    def _validate_required_fields(self):
        """Validate provider-specific required fields."""
        if not self.is_embedding and self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            if not self.api_key:
                raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
        if not self.is_embedding and self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            if not self.model_path:
                raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (for DB insert/update), excluding private attrs."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config from a raw DB row dict.

        Converts ISO timestamp strings to datetime and drops keys that are not
        dataclass fields (e.g. transient DB columns).
        """
        time_fields = ['last_used_at', 'created_at', 'updated_at']
        for field_name in time_fields:
            val = db_dict.get(field_name)
            if val and isinstance(val, str):
                try:
                    db_dict[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    db_dict[field_name] = None

        valid_fields = cls.__dataclass_fields__.keys()
        filtered_dict = {k: v for k, v in db_dict.items() if k in valid_fields}
        return cls(**filtered_dict)

    def get_model_type(self) -> str:
        """Classify this config as 'online', 'local', 'embedding' or 'unknown'."""
        if self.is_embedding:
            return "embedding"
        if self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            return "online"
        if self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            return "local"
        return "unknown"


class BaseLLM(BaseChatModel):
    """Base chat model compatible with LangChain's ``create_agent``.

    Subclasses must implement ``_generate`` / ``_agenerate`` and
    ``load_model``; they may override ``_validate_config``.
    """

    # Set via __init__ (BaseChatModel is a pydantic model, hence class-level decls).
    config: Any = None
    model: Any = None

    def __init__(self, config):
        super().__init__()  # BaseChatModel requires parent initialization
        self.config = config
        self.model = None
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    def _validate_config(self) -> None:
        """Hook for subclasses; the default accepts any config.

        Defined here because __init__ calls it unconditionally — without a
        default, subclasses that skip it would crash on construction.
        """

    # ---------------- LangChain protocol methods ----------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous generation; must return a ChatResult.

        Fix: the original only logged an error and implicitly returned None,
        which violates the ChatResult contract and fails later inside
        LangChain with a confusing error — fail loudly instead.
        """
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 同步 _generate 方法")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation; must return a ChatResult (see _generate)."""
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")

    @property
    def _llm_type(self) -> str:
        """Model-type identifier (e.g. "openai", "llama", "bge")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the underlying model; subclasses provide the actual logic."""
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 load_model 方法")

    def close(self) -> None:
        """Release model resources, if any were loaded."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    def __enter__(self):
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
"""Embedding-model wrappers implementing LangChain's Embeddings interface."""

from typing import List

from langchain_core.embeddings import Embeddings
from loguru import logger

from th_agenter.llm.base_llm import BaseLLM


class EmbedLLM(BaseLLM, Embeddings):
    """Embedding base class.

    Inherits LangChain's ``Embeddings`` protocol (rather than a language
    model); subclasses must implement the four embed methods below.
    """

    def __init__(self, config):
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed documents; must be overridden by subclasses.

        Fix: the original bodies were ``pass``, silently returning None to
        callers expecting vectors — fail loudly instead.
        """
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 embed_documents 方法")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding; must be overridden by subclasses."""
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 aembed_documents 方法")

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; must be overridden by subclasses."""
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 embed_query 方法")

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding; must be overridden by subclasses."""
        raise NotImplementedError(f"{self.__class__.__name__} 未实现 aembed_query 方法")


class BGEEmbedLLM(EmbedLLM):
    """BGE/HuggingFace (or Ollama-served) embedding model."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Lazily construct the underlying embeddings backend."""
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        if getattr(self.config, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, 'base_url', None),
            )
        else:
            try:
                from langchain_huggingface import HuggingFaceEmbeddings
                self.model = HuggingFaceEmbeddings(
                    model_name=self.config.model_name,
                    model_kwargs={"device": getattr(self.config, 'device', "cpu")},
                    encode_kwargs={"normalize_embeddings": getattr(self.config, 'normalize_embeddings', True)},
                )
            except ImportError as e:
                logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
                logger.error("Please install sentence-transformers: pip install sentence-transformers")
                raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        if not self.model:
            self.load_model()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        if not self.model:
            self.load_model()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        if not self.model:
            self.load_model()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        if not self.model:
            self.load_model()
        return await self.model.aembed_query(text)
"""Chat-mode OpenAI model wrapper."""

from langchain_openai import ChatOpenAI
from loguru import logger

from DrGraph.utils.Constant import Constant
from LLM.llm_model_base import LLM_Model_Base


class Chat_LLM(LLM_Model_Base):
    """Chat model (ChatOpenAI) with a chatbot-friendly ``invoke``."""

    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '聊天模型'
        self.mode = Constant.LLM_MODE_CHAT
        self.llmModel = ChatOpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    def invoke(self, prompt: str):
        """Invoke the chat model and return its message response.

        Args:
            prompt: the user's input string.

        Raises:
            Re-raises any error from the underlying model call. Fix: the
            original caught all exceptions, logged them, and then returned an
            unbound ``response`` — an UnboundLocalError masking the real error.
        """
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
        except Exception as e:
            logger.error(e)
            raise  # surface the failure instead of returning an unbound name
        return response
返回消息格式,以便在chatbot中显示 + def invoke(self, prompt: str): + ''' + 调用非聊天模型,返回消息格式,以便在chatbot中显示 + prompt: 用户输入,为字符串类型 + return: 助手回复,为字符串类型 + ''' + logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}") + try: + response = self.llmModel.invoke(prompt) + logger.info(f"{self.name} >>> 1.2 助手回复: {type(response)}") + except Exception as e: + logger.error(e) + return response + + \ No newline at end of file diff --git a/backend/th_agenter/llm/llm_model_ollama.py b/backend/th_agenter/llm/llm_model_ollama.py new file mode 100644 index 0000000..e034149 --- /dev/null +++ b/backend/th_agenter/llm/llm_model_ollama.py @@ -0,0 +1,28 @@ +from loguru import logger + +from DrGraph.utils.Constant import Constant +from LLM.llm_model_base import LLM_Model_Base +from langchain_ollama import ChatOllama +class Chat_Ollama(LLM_Model_Base): + def __init__(self, base_url="http://127.0.0.1:11434", model_name: str = "OxW/Qwen3-0.6B-GGUF:latest", temperature: float = 0.7): + super().__init__(model_name, temperature) + self.name = '私有化Ollama模型' + self.base_url = base_url + self.llmModel = ChatOllama( + base_url = self.base_url, + model=model_name, + temperature=temperature + ) + self.mode = Constant.LLM_MODE_LOCAL_OLLAMA + + def invoke(self, prompt: str): + prompt_template_value = self.buildPromptTemplateValue( + prompt=prompt, + methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE, + valueType=Constant.LLM_PROMPT_VALUE_MESSAGES) + try: + response = self.llmModel.invoke(prompt_template_value) + logger.info(f"{self.name} >>> 2. 
助手回复: {type(response)}\n{response}") + except Exception as e: + logger.error(e) + return response \ No newline at end of file diff --git a/backend/th_agenter/llm/local/local_llm.py b/backend/th_agenter/llm/local/local_llm.py new file mode 100644 index 0000000..1ecb63f --- /dev/null +++ b/backend/th_agenter/llm/local/local_llm.py @@ -0,0 +1,68 @@ +from typing import List, Optional +from th_agenter.llm.base_llm import BaseLLM +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, AIMessage, HumanMessage +from langchain_core.outputs import ChatResult, ChatGeneration + + +class LocalLLM(BaseLLM): + def __init__(self, config): + super().__init__(config) + self.local_config = config + + def _validate_config(self): + if not self.local_config.model_path: + raise ValueError("LocalLLM 必须配置 model_path") + + def load_model(self): + from langchain_community.llms import LlamaCpp + self.model = LlamaCpp( + model_path=self.local_config.model_path, + temperature=self.local_config.temperature, + max_tokens=self.local_config.max_tokens, + n_ctx=self.local_config.n_ctx, + n_threads=self.local_config.n_threads, + verbose=False + ) + + @property + def _llm_type(self) -> str: + return "llama" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None,** kwargs: Any, + ) -> ChatResult: + if not self.model: + self.load_model() + # 适配 LlamaCpp(非 Chat 模型)的调用方式 + prompt = self._format_messages(messages) + text = self.model.invoke(prompt, stop=stop, **kwargs) + # 构造 ChatResult(LangChain 标准格式) + generation = ChatGeneration(message=AIMessage(content=text)) + return ChatResult(generations=[generation]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None,** kwargs: Any, + ) -> ChatResult: + if not self.model: + self.load_model() + prompt = 
"""Online (OpenAI-compatible) chat model wrapper."""

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, BaseMessage
from typing import List, Optional, Any, Union
from langchain_core.outputs import ChatResult
from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun


class OnlineLLM(BaseLLM):
    """Chat model backed by an OpenAI-compatible HTTP API.

    Delegates generation to a lazily created ``ChatOpenAI`` instance and adds
    string-friendly convenience wrappers (``generate`` / ``async_generate``).
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """An online provider is unusable without an API key."""
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Build the underlying ChatOpenAI client from the stored config."""
        from langchain_openai import ChatOpenAI
        cfg = self.config
        self.model = ChatOpenAI(
            api_key=cfg.api_key,
            model_name=cfg.model_name,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            base_url=cfg.base_url,
        )

    @property
    def _llm_type(self) -> str:
        """Model-type identifier used by LangChain."""
        return "openai"

    def _ensure_model(self):
        """Lazily instantiate the backend on first use."""
        if not self.model:
            self.load_model()

    @staticmethod
    def _as_messages(prompt: Union[str, List[BaseMessage]]) -> List[BaseMessage]:
        """Wrap a bare string into a single-HumanMessage list; pass lists through."""
        return [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate synchronous generation to the wrapped model."""
        self._ensure_model()
        return self.model._generate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate asynchronous generation to the wrapped model."""
        self._ensure_model()
        return await self.model._agenerate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )

    # ---------------- convenience wrappers ----------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Accept a string prompt or message list and return the reply text."""
        outcome = self._generate(self._as_messages(prompt), **kwargs)
        return outcome.generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async variant: accept a string or message list, return the reply text."""
        outcome = await self._agenerate(self._as_messages(prompt), **kwargs)
        return outcome.generations[0].text
self.last_message_at, + } def __repr__(self): - return f"" - - @property - def message_count(self): - """Get the number of messages in this conversation.""" - return len(self.messages) - - @property - def last_message_at(self): - """Get the timestamp of the last message.""" - return self.messages[-1].created_at or self.created_at + return f"" diff --git a/backend/th_agenter/models/llm_config.py b/backend/th_agenter/models/llm_config.py index 60ad70b..e8f8e65 100644 --- a/backend/th_agenter/models/llm_config.py +++ b/backend/th_agenter/models/llm_config.py @@ -39,7 +39,7 @@ class LLMConfig(BaseModel): last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 最后使用时间 def __repr__(self): - return f"" + return f"" def to_dict(self, include_sensitive=False): """Convert to dictionary, optionally excluding sensitive data.""" @@ -60,7 +60,7 @@ class LLMConfig(BaseModel): 'is_embedding': self.is_embedding, 'extra_config': self.extra_config, 'usage_count': self.usage_count, - 'last_used_at': self.last_used_at + 'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None }) if include_sensitive: @@ -102,8 +102,8 @@ class LLMConfig(BaseModel): if not self.name or not self.name.strip(): return {"valid": False, "error": "配置名称不能为空"} - if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']: - return {"valid": False, "error": "不支持的服务商"} + if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama']: + return {"valid": False, "error": f"不支持的服务商 {self.provider}"} if not self.model_name or not self.model_name.strip(): return {"valid": False, "error": "模型名称不能为空"} diff --git a/backend/th_agenter/schemas/llm_config.py b/backend/th_agenter/schemas/llm_config.py index 0ba53b4..da1aaab 100644 --- a/backend/th_agenter/schemas/llm_config.py +++ b/backend/th_agenter/schemas/llm_config.py @@ -34,7 +34,7 @@ class LLMConfigCreate(LLMConfigBase): 
allowed_providers = [ 'openai', 'azure', 'anthropic', 'google', 'baidu', 'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek', - 'ollama', 'custom', "doubao" + 'ollama', 'custom', "doubao", "ollama" ] if v.lower() not in allowed_providers: raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}') @@ -74,7 +74,7 @@ class LLMConfigUpdate(BaseModel): allowed_providers = [ 'openai', 'azure', 'anthropic', 'google', 'baidu', 'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek', - 'ollama', 'custom',"doubao" + 'ollama', 'custom',"doubao", "ollama" ] if v.lower() not in allowed_providers: raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}') diff --git a/backend/th_agenter/services/agent/agent_service.py b/backend/th_agenter/services/agent/agent_service.py index d99ec00..9d6622b 100644 --- a/backend/th_agenter/services/agent/agent_service.py +++ b/backend/th_agenter/services/agent/agent_service.py @@ -33,14 +33,16 @@ class AgentConfig(BaseModel): class AgentService: """LangChain Agent service with tool calling capabilities.""" - def __init__(self, db_session=None): + def __init__(self): self.settings = get_settings() + + async def initialize(self, session=None): self.tool_registry = ToolRegistry() self.config = AgentConfig() - self.db_session = db_session - self.config_service = AgentConfigService(db_session) if db_session else None + self.session = session + self.config_service = AgentConfigService(session) if session else None self._initialize_tools() - self._load_config() + await self._load_config() def _initialize_tools(self): """Initialize and register all available tools.""" @@ -56,18 +58,17 @@ class AgentService: self.tool_registry.register(tool) logger.info(f"Registered tool: {tool.get_name()}") - def _load_config(self): + async def _load_config(self): """Load configuration from database if available.""" if self.config_service: try: - config_dict = self.config_service.get_config_dict() + config_dict = await 
# Global agent service instance (module-level singleton).
_global_agent_service: Optional[AgentService] = None


async def get_agent_service(session=None) -> AgentService:
    """Get (lazily creating) the global AgentService singleton.

    When a different DB session is supplied on a later call, the service is
    re-bound to it and its configuration reloaded from the database.

    Fix: ``_load_config`` is a coroutine, but the re-bind branch called it
    without ``await`` — the coroutine was never executed, so the refreshed
    session's configuration was silently never loaded.
    """
    global _global_agent_service
    if _global_agent_service is None:
        _global_agent_service = AgentService()
        await _global_agent_service.initialize(session)
    elif session and session != _global_agent_service.session:
        # Re-bind to the caller's session and refresh config from the DB.
        _global_agent_service.session = session
        _global_agent_service.config_service = AgentConfigService(session)
        await _global_agent_service._load_config()
    return _global_agent_service
# Process-wide singleton for the LangGraph agent service.
_global_langgraph_agent_service: Optional[LangGraphAgentService] = None


async def get_langgraph_agent_service(session=None) -> LangGraphAgentService:
    """Return the shared LangGraph agent service, creating it on first use."""
    global _global_langgraph_agent_service

    # Fast path: already constructed and initialized.
    if _global_langgraph_agent_service is not None:
        return _global_langgraph_agent_service

    # First call: build and initialize the singleton with the given session.
    _global_langgraph_agent_service = LangGraphAgentService()
    await _global_langgraph_agent_service.initialize(session)
    return _global_langgraph_agent_service
True).values({"is_default": False}) + await self.db.execute(stmt) self.db.add(config) - self.db.commit() - self.db.refresh(config) + await self.db.commit() + await self.db.refresh(config) logger.info(f"Created agent configuration: {config.name}") return config except Exception as e: - self.db.rollback() + await self.db.rollback() logger.error(f"Error creating agent configuration: {str(e)}") raise - def get_config(self, config_id: int) -> Optional[AgentConfig]: + async def get_config(self, config_id: int) -> Optional[AgentConfig]: """Get agent configuration by ID.""" - return self.db.query(AgentConfig).filter( - AgentConfig.id == config_id - ).first() + stmt = select(AgentConfig).where(AgentConfig.id == config_id) + return (await self.db.execute(stmt)).scalar_one_or_none() - def get_config_by_name(self, name: str) -> Optional[AgentConfig]: + async def get_config_by_name(self, name: str) -> Optional[AgentConfig]: """Get agent configuration by name.""" - return self.db.query(AgentConfig).filter( - AgentConfig.name == name - ).first() + stmt = select(AgentConfig).where(AgentConfig.name == name) + return (await self.db.execute(stmt)).scalar_one_or_none() - def get_default_config(self) -> Optional[AgentConfig]: + async def get_default_config(self) -> Optional[AgentConfig]: """Get default agent configuration.""" - return self.db.query(AgentConfig).filter( - and_(AgentConfig.is_default == True, AgentConfig.is_active == True) - ).first() + stmt = select(AgentConfig).where(and_(AgentConfig.is_default == True, AgentConfig.is_active == True)) + return (await self.db.execute(stmt)).scalar_one_or_none() def list_configs(self, active_only: bool = True) -> List[AgentConfig]: """List all agent configurations.""" - query = self.db.query(AgentConfig) + stmt = select(AgentConfig) if active_only: - query = query.filter(AgentConfig.is_active == True) - return query.order_by(AgentConfig.created_at.desc()).all() + stmt = stmt.where(AgentConfig.is_active == True) + stmt = 
stmt.order_by(AgentConfig.created_at.desc()) + return self.db.execute(stmt).scalars().all() - def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig: + async def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig: """Update agent configuration.""" try: - config = self.get_config(config_id) + config = await self.get_config(config_id) if not config: raise NotFoundError(f"Agent configuration with ID {config_id} not found") # Check if name change conflicts with existing if "name" in config_data and config_data["name"] != config.name: - existing = self.db.query(AgentConfig).filter( + stmt = select(AgentConfig).where( and_( AgentConfig.name == config_data["name"], AgentConfig.id != config_id ) - ).first() + ) + existing = (await self.db.execute(stmt)).scalar_one_or_none() if existing: raise ValidationError(f"Configuration with name '{config_data['name']}' already exists") @@ -110,28 +107,29 @@ class AgentConfigService: # If this is set as default, unset other defaults if config_data.get("is_default", False): - self.db.query(AgentConfig).filter( + stmt = update(AgentConfig).where( and_( AgentConfig.is_default == True, AgentConfig.id != config_id ) - ).update({"is_default": False}) + ).values({"is_default": False}) + await self.db.execute(stmt) - self.db.commit() - self.db.refresh(config) + await self.db.commit() + await self.db.refresh(config) logger.info(f"Updated agent configuration: {config.name}") return config except Exception as e: - self.db.rollback() + await self.db.rollback() logger.error(f"Error updating agent configuration: {str(e)}") raise - def delete_config(self, config_id: int) -> bool: + async def delete_config(self, config_id: int) -> bool: """Delete agent configuration (soft delete by setting is_active=False).""" try: - config = self.get_config(config_id) + config = await self.get_config(config_id) if not config: raise NotFoundError(f"Agent configuration with ID {config_id} not found") @@ -146,14 
+144,14 @@ class AgentConfigService: return True except Exception as e: - self.db.rollback() + await self.db.rollback() logger.error(f"Error deleting agent configuration: {str(e)}") raise - def set_default_config(self, config_id: int) -> AgentConfig: + async def set_default_config(self, config_id: int) -> AgentConfig: """Set a configuration as default.""" try: - config = self.get_config(config_id) + config = await self.get_config(config_id) if not config: raise NotFoundError(f"Agent configuration with ID {config_id} not found") @@ -161,29 +159,28 @@ class AgentConfigService: raise ValidationError("Cannot set inactive configuration as default") # Unset other defaults - self.db.query(AgentConfig).filter( - AgentConfig.is_default == True - ).update({"is_default": False}) + stmt = update(AgentConfig).where(AgentConfig.is_default == True).values({"is_default": False}) + await self.db.execute(stmt) # Set this as default config.is_default = True - self.db.commit() - self.db.refresh(config) + await self.db.commit() + await self.db.refresh(config) logger.info(f"Set default agent configuration: {config.name}") return config except Exception as e: - self.db.rollback() + await self.db.rollback() logger.error(f"Error setting default agent configuration: {str(e)}") raise - def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]: + async def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]: """Get configuration as dictionary. 
If no ID provided, returns default config.""" if config_id: - config = self.get_config(config_id) + config = await self.get_config(config_id) else: - config = self.get_default_config() + config = await self.get_default_config() if not config: # Return default values if no configuration found diff --git a/backend/th_agenter/services/auth.py b/backend/th_agenter/services/auth.py index 36bad37..15cea9f 100644 --- a/backend/th_agenter/services/auth.py +++ b/backend/th_agenter/services/auth.py @@ -33,19 +33,16 @@ class AuthService: ) token = credentials.credentials - session.desc = f"[AuthService] 取得token: {token[:50]}..." payload = AuthService.verify_token(token) if payload is None: - session.desc = "ERROR: 令牌验证失败" + session.desc = f"ERROR: 令牌验证失败 - 令牌: {token[:50]}..." raise credentials_exception - session.desc = f"[AuthService] 令牌有效 - 解析得到有效载荷: {payload}" username: str = payload.get("sub") if username is None: session.desc = "ERROR: 令牌中没有用户名" raise credentials_exception - session.desc = f"[AuthService] 获取当前用户 - 查找名为 {username} 的用户" stmt = select(User).where(User.username == username) user = (await session.execute(stmt)).scalar_one_or_none() if user is None: @@ -53,8 +50,8 @@ class AuthService: raise credentials_exception # Set user in context for global access - UserContext.set_current_user(user) - session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户" + UserContext.set_current_user(user, canLog=True) + # session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户" return user @@ -138,6 +135,4 @@ class AuthService: except jwt.PyJWTError as e: logger.error(f"Token verification failed: {e}") logger.error(f"Token: {token[:50]}...") - logger.error(f"Secret key: {settings.security.secret_key[:20]}...") - logger.error(f"Algorithm: {settings.security.algorithm}") - return None + return None \ No newline at end of file diff --git a/backend/th_agenter/services/chat.py b/backend/th_agenter/services/chat.py index 1e53116..fe32531 100644 --- 
a/backend/th_agenter/services/chat.py +++ b/backend/th_agenter/services/chat.py @@ -1,45 +1,37 @@ """Chat service for AI model integration using LangChain.""" +from th_agenter import db import json import asyncio import os -from typing import AsyncGenerator, Optional, List, Dict, Any +from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict from sqlalchemy.orm import Session from loguru import logger + +from th_agenter.core.new_agent import new_agent, new_llm from ..core.config import settings from ..models.message import MessageRole from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse -from utils.util_exceptions import ChatServiceError, OpenAIError +from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError from .conversation import ConversationService from .langchain_chat import LangChainChatService from .knowledge_chat import KnowledgeChatService from .agent.agent_service import get_agent_service from .agent.langgraph_agent_service import get_langgraph_agent_service +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.checkpoint.postgres import PostgresSaver +from langgraph.graph import StateGraph, START, END +from langgraph.graph.state import CompiledStateGraph +class AgentState(TypedDict): + messages: List[dict] # 存储对话消息(核心记忆) + class ChatService: """Service for handling AI chat functionality using LangChain.""" - - def __init__(self, db: Session): - self.db = db - self.conversation_service = ConversationService(db) - - # Initialize LangChain chat service - self.langchain_service = LangChainChatService(db) - - # Initialize Knowledge chat service - self.knowledge_service = KnowledgeChatService(db) - - # Initialize Agent service with database session - self.agent_service = get_agent_service(db) - - # Initialize LangGraph Agent service with database session - self.langgraph_service = get_langgraph_agent_service(db) - - logger.info("ChatService initialized with LangChain backend and 
Agent support") - + _checkpointer_initialized = False + _conn_string = None - async def chat( self, conversation_id: int, @@ -57,7 +49,7 @@ class ChatService: logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}") # Use knowledge base chat service - return await self.knowledge_service.chat_with_knowledge_base( + return await self.knowledge_chat_service.chat_with_knowledge_base( conversation_id=conversation_id, message=message, knowledge_base_id=knowledge_base_id, @@ -69,29 +61,29 @@ class ChatService: logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent") # Get conversation history for LangGraph agent - conversation = self.conversation_service.get_conversation(conversation_id) + conversation = await self.conversation_service.get_conversation(conversation_id) if not conversation: raise ChatServiceError(f"Conversation {conversation_id} not found") - messages = self.conversation_service.get_conversation_messages(conversation_id) + messages = await self.conversation_service.get_conversation_messages(conversation_id) chat_history = [{ "role": "user" if msg.role == MessageRole.USER else "assistant", "content": msg.content } for msg in messages] # Use LangGraph agent service - agent_result = await self.langgraph_service.chat(message, chat_history) + agent_result = await self.langgraph_agent_service.chat(message, chat_history) if agent_result["success"]: # Save user message - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER ) # Save assistant response - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=agent_result["response"], role=MessageRole.ASSISTANT, @@ -114,11 +106,11 @@ class ChatService: logger.info(f"Processing chat 
request for conversation {conversation_id} via Agent") # Get conversation history for agent - conversation = self.conversation_service.get_conversation(conversation_id) + conversation = await self.conversation_service.get_conversation(conversation_id) if not conversation: raise ChatServiceError(f"Conversation {conversation_id} not found") - messages = self.conversation_service.get_conversation_messages(conversation_id) + messages = await self.conversation_service.get_conversation_messages(conversation_id) chat_history = [{ "role": "user" if msg.role == MessageRole.USER else "assistant", "content": msg.content @@ -129,14 +121,14 @@ class ChatService: if agent_result["success"]: # Save user message - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER ) # Save assistant response - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=agent_result["response"], role=MessageRole.ASSISTANT, @@ -159,7 +151,7 @@ class ChatService: logger.info(f"Processing chat request for conversation {conversation_id} via LangChain") # Delegate to LangChain service - return await self.langchain_service.chat( + return await self.langchain_chat_service.chat( conversation_id=conversation_id, message=message, stream=stream, @@ -167,156 +159,12 @@ class ChatService: max_tokens=max_tokens ) - async def chat_stream( - self, - conversation_id: int, - message: str, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - use_agent: bool = False, - use_langgraph: bool = False, - use_knowledge_base: bool = False, - knowledge_base_id: Optional[int] = None - ) -> AsyncGenerator[str, None]: - """Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base.""" - if use_knowledge_base and knowledge_base_id: - 
logger.info(f"Processing streaming chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}") - - # Use knowledge base chat service streaming - async for content in self.knowledge_service.chat_stream_with_knowledge_base( - conversation_id=conversation_id, - message=message, - knowledge_base_id=knowledge_base_id, - temperature=temperature, - max_tokens=max_tokens - ): - # Create stream chunk for compatibility with existing API - stream_chunk = StreamChunk( - content=content, - role=MessageRole.ASSISTANT - ) - yield json.dumps(stream_chunk.dict(), ensure_ascii=False) - elif use_langgraph: - logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangGraph Agent") - - # Get conversation history for LangGraph agent - conversation = self.conversation_service.get_conversation(conversation_id) - if not conversation: - raise ChatServiceError(f"Conversation {conversation_id} not found") - - messages = self.conversation_service.get_conversation_messages(conversation_id) - chat_history = [{ - "role": "user" if msg.role == MessageRole.USER else "assistant", - "content": msg.content - } for msg in messages] - - # Save user message first - user_message = self.conversation_service.add_message( - conversation_id=conversation_id, - content=message, - role=MessageRole.USER - ) - - # Use LangGraph agent service streaming - full_response = "" - intermediate_steps = [] - - async for chunk in self.langgraph_service.chat_stream(message, chat_history): - if chunk["type"] == "response": - full_response = chunk["content"] - intermediate_steps = chunk.get("intermediate_steps", []) - - # Return the chunk as-is to maintain type information - yield json.dumps(chunk, ensure_ascii=False) - - elif chunk["type"] == "error": - # Return the chunk as-is to maintain type information - yield json.dumps(chunk, ensure_ascii=False) - return - else: - # For other types (status, step, etc.), pass through - yield json.dumps(chunk, 
ensure_ascii=False) - - # Save assistant response - if full_response: - self.conversation_service.add_message( - conversation_id=conversation_id, - content=full_response, - role=MessageRole.ASSISTANT, - message_metadata={"intermediate_steps": intermediate_steps} - ) - elif use_agent: - logger.info(f"Processing streaming chat request for conversation {conversation_id} via Agent") - - # Get conversation history for agent - conversation = self.conversation_service.get_conversation(conversation_id) - if not conversation: - raise ChatServiceError(f"Conversation {conversation_id} not found") - - messages = self.conversation_service.get_conversation_messages(conversation_id) - chat_history = [{ - "role": "user" if msg.role == MessageRole.USER else "assistant", - "content": msg.content - } for msg in messages] - - # Save user message first - user_message = self.conversation_service.add_message( - conversation_id=conversation_id, - content=message, - role=MessageRole.USER - ) - - # Use agent service streaming - full_response = "" - tool_calls = [] - - async for chunk in self.agent_service.chat_stream(message, chat_history): - if chunk["type"] == "response": - full_response = chunk["content"] - tool_calls = chunk.get("tool_calls", []) - - # Return the chunk as-is to maintain type information - yield json.dumps(chunk, ensure_ascii=False) - - elif chunk["type"] == "error": - # Return the chunk as-is to maintain type information - yield json.dumps(chunk, ensure_ascii=False) - return - else: - # For other types (status, tool_start, etc.), pass through - yield json.dumps(chunk, ensure_ascii=False) - - # Save assistant response - if full_response: - self.conversation_service.add_message( - conversation_id=conversation_id, - content=full_response, - role=MessageRole.ASSISTANT, - message_metadata={"tool_calls": tool_calls} - ) - else: - logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangChain") - - # Delegate to LangChain service and wrap 
response in JSON format - async for content in self.langchain_service.chat_stream( - conversation_id=conversation_id, - message=message, - temperature=temperature, - max_tokens=max_tokens - ): - # Create stream chunk for compatibility with existing API - stream_chunk = StreamChunk( - content=content, - role=MessageRole.ASSISTANT - ) - yield json.dumps(stream_chunk.dict(), ensure_ascii=False) - async def get_available_models(self) -> List[str]: """Get list of available models from LangChain.""" logger.info("Getting available models via LangChain") # Delegate to LangChain service - return await self.langchain_service.get_available_models() + return await self.langchain_chat_service.get_available_models() def update_model_config( self, @@ -328,8 +176,135 @@ class ChatService: logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}") # Delegate to LangChain service - self.langchain_service.update_model_config( + self.langchain_chat_service.update_model_config( model=model, temperature=temperature, max_tokens=max_tokens - ) \ No newline at end of file + ) + # ------------------------------------------------------------------------- + def __init__(self, session: Session): + self.session = session + + async def initialize(self, conversation_id: int, streaming: bool = False): + self.conversation_service = ConversationService(self.session) + self.session.desc = "ChatService初始化 - ConversationService 实例化完毕" + self.conversation = await self.conversation_service.get_conversation( + conversation_id=conversation_id + ) + if not self.conversation: + raise ChatServiceError(f"Conversation {conversation_id} not found") + + if not ChatService._checkpointer_initialized: + from langgraph.checkpoint.postgres import PostgresSaver + import psycopg2 + CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres" + ChatService._conn_string = CONN_STRING + + # 检查必要的表是否已存在 + tables_need_setup = True + try: + # 连接到数据库并检查表是否存在 
+ conn = psycopg2.connect(CONN_STRING) + cursor = conn.cursor() + + # 检查langgraph需要的表是否存在 + cursor.execute(""" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs') + """) + existing_tables = [row[0] for row in cursor.fetchall()] + + # 检查是否所有必要的表都存在 + required_tables = ['checkpoints', 'checkpoint_writes', 'checkpoint_blobs'] + if all(table in existing_tables for table in required_tables): + tables_need_setup = False + self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup" + + cursor.close() + conn.close() + + except Exception as e: + self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)},将进行setup" + tables_need_setup = True + + # 只有在需要时才进行setup + if tables_need_setup: + self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup" + try: + async with AsyncPostgresSaver.from_conn_string(CONN_STRING) as checkpointer: + await checkpointer.setup() + self.session.desc = "ChatService初始化 - PostgresSaver setup完成" + logger.info("PostgresSaver setup完成") + except Exception as e: + self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}" + logger.error(f"PostgresSaver setup失败: {e}") + raise + else: + self.session.desc = "ChatService初始化 - 使用现有的langgraph表" + + # 存储连接字符串供后续使用 + ChatService._checkpointer_initialized = True + + self.llm = await new_llm(session=self.session, streaming=streaming) + self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}" + + def get_config(self): + config = { + "configurable": { + "thread_id": str(self.conversation.id), + "checkpoint_ns": "drgraph" + } + } + return config + + async def chat_stream( + self, + message: str + ) -> AsyncGenerator[str, None]: + """Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base.""" + self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}" + await self.conversation_service.add_message( + 
conversation_id=self.conversation.id, + role=MessageRole.USER, + content=message + ) + full_assistant_content = "" + + async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer: + from langchain.agents import create_agent + agent = create_agent( + model=self.llm, # await new_llm(session=self.session, streaming=self.streaming), + checkpointer=checkpointer + ) + async for chunk in agent.astream( + {"messages": [{"role": "user", "content": message}]}, + config=self.get_config(), + stream_mode="messages" + ): + full_assistant_content += chunk[0].content + json_result = {"data": {"v": chunk[0].content }} + yield json.dumps( + json_result, + ensure_ascii=True + ) + + if len(full_assistant_content) > 0: + await self.conversation_service.add_message( + conversation_id=self.conversation.id, + role=MessageRole.ASSISTANT, + content=full_assistant_content + ) + + def get_conversation_history_messages( + self, conversation_id: int, skip: int = 0, limit: int = 100 + ): + """Get conversation history messages with pagination.""" + result = [] + with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer: + checkpoints = checkpointer.list(self.get_config()) + for checkpoint in checkpoints: + print(checkpoint) + result.append(checkpoint.messages) + return result diff --git a/backend/th_agenter/services/conversation.py b/backend/th_agenter/services/conversation.py index 1558af7..f87c158 100644 --- a/backend/th_agenter/services/conversation.py +++ b/backend/th_agenter/services/conversation.py @@ -3,6 +3,9 @@ from typing import List, Optional from sqlalchemy.orm import Session from sqlalchemy import select, desc, func, or_ +from langchain_core.messages import HumanMessage, AIMessage + +from th_agenter.db.database import AsyncSessionFactory from ..models.conversation import Conversation from ..models.message import Message, MessageRole @@ -24,7 +27,7 @@ class ConversationService: conversation_data: ConversationCreate ) -> 
Conversation: """Create a new conversation.""" - logger.info(f"Creating new conversation for user {user_id}: {conversation_data}") + self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}" try: conversation = Conversation( @@ -39,18 +42,23 @@ class ConversationService: await self.session.commit() await self.session.refresh(conversation) - logger.info(f"Successfully created conversation {conversation.id} for user {user_id}") + self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}" return conversation except Exception as e: - logger.error(f"Failed to create conversation: {str(e)}", exc_info=True) + self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}" await self.session.rollback() - raise DatabaseError(f"Failed to create conversation: {str(e)}") + raise DatabaseError(f"创建会话失败: {str(e)}") async def get_conversation(self, conversation_id: int) -> Optional[Conversation]: """Get a conversation by ID.""" try: user_id = UserContext.get_current_user_id() + self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}" + if user_id is None: + logger.error(f"Failed to get conversation {conversation_id}: No user context available") + return None + conversation = await self.session.scalar( select(Conversation).where( Conversation.id == conversation_id, @@ -59,12 +67,11 @@ class ConversationService: ) if not conversation: - logger.warning(f"Conversation {conversation_id} not found") - + self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}" return conversation except Exception as e: - logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True) + self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}" raise DatabaseError(f"Failed to get conversation: {str(e)}") async def get_user_conversations( @@ -78,6 +85,10 @@ class ConversationService: ) -> List[Conversation]: """Get user's conversations with search and filtering.""" user_id = 
UserContext.get_current_user_id() + if user_id is None: + logger.error("Failed to get user conversations: No user context available") + return [] + query = select(Conversation).where( Conversation.user_id == user_id ) @@ -137,7 +148,7 @@ class ConversationService: if not conversation: return False - self.session.delete(conversation) + await self.session.delete(conversation) await self.session.commit() return True @@ -180,19 +191,42 @@ class ConversationService: # Set audit fields message.set_audit_fields() - self.session.add(message) - await self.session.commit() - await self.session.refresh(message) + session = AsyncSessionFactory() + session.begin() + try: + session.add(message) + await session.commit() + await session.refresh(message) - # Update conversation's updated_at timestamp - conversation = await self.get_conversation(conversation_id) - if conversation: - conversation.updated_at = datetime.now(timezone.utc) - conversation.set_audit_fields(user_id=conversation.user_id, is_update=True) - await self.session.commit() + # Update conversation's updated_at timestamp + conversation = await self.get_conversation(conversation_id) + if conversation: + conversation.updated_at = datetime.now(timezone.utc) + conversation.set_audit_fields(user_id=conversation.user_id, is_update=True) + await session.commit() + except Exception as e: + logger.error(f"Failed to add message to conversation {conversation_id}: {str(e)}", exc_info=True) + await session.rollback() + finally: + await session.close() return message + async def get_conversation_history_messages( + self, + conversation_id: int, + limit: int = 20 + ) -> List[Message]: + """Get recent conversation history messages.""" + history = await self.get_conversation_history(conversation_id, limit) + history_messages = [] + for message in history: + if message.role == MessageRole.USER: + history_messages.append(HumanMessage(content=message.content)) + elif message.role == MessageRole.ASSISTANT: + 
history_messages.append(AIMessage(content=message.content)) + return history_messages + async def get_conversation_history( self, conversation_id: int, diff --git a/backend/th_agenter/services/conversation_context.py b/backend/th_agenter/services/conversation_context.py index 59f2616..bc34b2f 100644 --- a/backend/th_agenter/services/conversation_context.py +++ b/backend/th_agenter/services/conversation_context.py @@ -27,7 +27,7 @@ class ConversationContextService: 新创建的对话ID """ try: - session = next(get_session()) + session = await anext(get_session()) conversation = Conversation( user_id=user_id, @@ -37,8 +37,8 @@ class ConversationContextService: ) session.add(conversation) - session.commit() - session.refresh(conversation) + await session.commit() + await session.refresh(conversation) # 初始化对话上下文 self.context_cache[conversation.id] = { @@ -74,7 +74,7 @@ class ConversationContextService: # 从数据库加载 try: - session = next(get_session()) + session = await anext(get_session()) conversation = session.query(Conversation).filter( Conversation.id == conversation_id @@ -194,7 +194,7 @@ class ConversationContextService: 保存是否成功 """ try: - session = next(get_session()) + session = await anext(get_session()) message = Message( conversation_id=conversation_id, @@ -205,7 +205,7 @@ class ConversationContextService: ) session.add(message) - session.commit() + await session.commit() # 更新对话的最后更新时间 conversation = session.query(Conversation).filter( @@ -214,7 +214,7 @@ class ConversationContextService: if conversation: conversation.updated_at = datetime.utcnow() - session.commit() + await session.commit() return True @@ -262,7 +262,7 @@ class ConversationContextService: 消息历史列表 """ try: - session = next(get_session()) + session = await anext(get_session()) messages = session.query(Message).filter( Message.conversation_id == conversation_id diff --git a/backend/th_agenter/services/database_config_service.py b/backend/th_agenter/services/database_config_service.py index 2e7a44f..5e8ffba 
100644 --- a/backend/th_agenter/services/database_config_service.py +++ b/backend/th_agenter/services/database_config_service.py @@ -102,41 +102,41 @@ class DatabaseConfigService: ) self.session.add(db_config) - self.session.commit() - self.session.refresh(db_config) + await self.session.commit() + await self.session.refresh(db_config) logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})") return db_config except Exception as e: - self.session.rollback() + await self.session.rollback() logger.error(f"创建数据库配置失败: {str(e)}") raise - def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]: + async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]: """获取用户的数据库配置列表""" stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id) if active_only: stmt = stmt.where(DatabaseConfig.is_active == True) stmt = stmt.order_by(DatabaseConfig.created_at.desc()) - return self.session.scalars(stmt).all() + return (await self.session.execute(stmt)).scalars().all() - def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]: + async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]: """根据ID获取配置""" stmt = select(DatabaseConfig).where( DatabaseConfig.id == config_id, DatabaseConfig.created_by == user_id ) - return self.session.scalar(stmt) + return (await self.session.execute(stmt)).scalar_one_or_none() - def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]: + async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]: """获取用户的默认配置""" stmt = select(DatabaseConfig).where( DatabaseConfig.created_by == user_id, # DatabaseConfig.is_default == True, DatabaseConfig.is_active == True ) - return self.session.scalar(stmt) + return (await self.session.execute(stmt)).scalar_one_or_none() async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]: """测试数据库连接""" @@ -278,14 +278,14 @@ 
class DatabaseConfigService: 'message': f'断开连接失败: {str(e)}' } - def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]: + async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]: """根据数据库类型获取用户配置""" stmt = select(DatabaseConfig).where( DatabaseConfig.created_by == user_id, DatabaseConfig.db_type == db_type, DatabaseConfig.is_active == True ) - return self.session.scalar(stmt) + return await self.session.scalar(stmt) async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig: """创建或更新数据库配置(保证db_type唯一性)""" @@ -301,8 +301,8 @@ class DatabaseConfigService: elif hasattr(existing_config, key): setattr(existing_config, key, value) - self.session.commit() - self.session.refresh(existing_config) + await self.session.commit() + await self.session.refresh(existing_config) logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})") return existing_config else: @@ -310,7 +310,7 @@ class DatabaseConfigService: return await self.create_config(user_id, config_data) except Exception as e: - self.session.rollback() + await self.session.rollback() logger.error(f"创建或更新数据库配置失败: {str(e)}") raise diff --git a/backend/th_agenter/services/document.py b/backend/th_agenter/services/document.py index b84e00b..04c49f8 100644 --- a/backend/th_agenter/services/document.py +++ b/backend/th_agenter/services/document.py @@ -30,7 +30,7 @@ class DocumentService: self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}" # Validate knowledge base exists stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) - kb = self.session.scalar(stmt) + kb = await self.session.scalar(stmt) if not kb: self.session.desc = f"ERROR: 知识库 {kb_id} 不存在" raise ValueError(f"知识库 {kb_id} 不存在") @@ -48,6 +48,7 @@ class DocumentService: # Upload file using storage service storage_info = await storage_service.upload_file(file, kb_id) + self.session.desc = f"文档 {file.filename} 上传到 {storage_info}" # 
Create document record document = Document( @@ -65,21 +66,21 @@ class DocumentService: document.set_audit_fields() self.session.add(document) - self.session.commit() - self.session.refresh(document) + await self.session.commit() + await self.session.refresh(document) - self.session.desc = f"SUCCESS: 成功上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})" + self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})" return document - def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]: + async def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]: """根据文档ID查询文档,可选地根据知识库ID过滤。""" - self.session.desc = f"查询文档 {doc_id}" + self.session.desc = f"根据文档ID查询文档 {doc_id}" stmt = select(Document).where(Document.id == doc_id) if kb_id is not None: stmt = stmt.where(Document.knowledge_base_id == kb_id) - return self.session.scalar(stmt) + return await self.session.scalar(stmt) - def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]: + async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]: """根据知识库ID查询文档,支持分页。""" self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)" stmt = ( @@ -88,14 +89,14 @@ class DocumentService: .offset(skip) .limit(limit) ) - return self.session.scalars(stmt).all() + return (await self.session.scalars(stmt)).all() - def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]: + async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]: """根据知识库ID查询文档,支持分页,并返回总文档数。""" self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)" # Get total count count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id) - total = self.session.scalar(count_stmt) + total = await self.session.scalar(count_stmt) # Get documents with pagination documents_stmt = ( @@ -104,44 +105,52 @@ class 
DocumentService: .offset(skip) .limit(limit) ) - documents = self.session.scalars(documents_stmt).all() + documents = (await self.session.scalars(documents_stmt)).all() return documents, total - def delete_document(self, doc_id: int, kb_id: int = None) -> bool: + async def delete_document(self, doc_id: int, kb_id: int = None) -> bool: """根据文档ID删除文档,可选地根据知识库ID过滤。""" self.session.desc = f"删除文档 {doc_id}" - document = self.get_document(doc_id, kb_id) + document = await self.get_document(doc_id, kb_id) if not document: self.session.desc = f"ERROR: 文档 {doc_id} 不存在" return False # Delete file from storage try: - storage_service.delete_file(document.file_path) - logger.info(f"Deleted file: {document.file_path}") + await storage_service.delete_file(document.file_path) + self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}" except Exception as e: self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}" # TODO: Remove from vector database # This should be implemented when vector database service is ready - get_document_processor().delete_document_from_vector_store(kb_id,doc_id) + self.session.desc = f"从向量数据库删除文档 {doc_id}" + (await get_document_processor(self.session)).delete_document_from_vector_store(kb_id,doc_id) # Delete database record - self.session.delete(document) - self.session.commit() + self.session.desc = f"删除数据库记录 {doc_id}" + await self.session.delete(document) + await self.session.commit() self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}" return True async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]: """处理文档,提取文本并创建嵌入向量。""" try: - self.session.desc = f"处理文档 {doc_id}" - document = self.get_document(doc_id, kb_id) + self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量" + document = await self.get_document(doc_id, kb_id) + self.session.desc = f"获取文档 {doc_id} >>> {document}" if not document: self.session.desc = f"ERROR: 文档 {doc_id} 不存在" raise ValueError(f"Document {doc_id} not found") - if document.is_processed: + # 
document.file_path[为('C:\\DrGraph\\TH_Backend\\data\\uploads\\kb_1\\997eccbb-9081-4ddf-879e-bc7d781fab50_答辩.txt',) ,需要取第一个元素 + file_path = document.file_path + knowledge_base_id=document.knowledge_base_id + is_processed=document.is_processed + + if is_processed: self.session.desc = f"INFO: 文档 {doc_id} 已处理" return { "document_id": doc_id, @@ -149,38 +158,43 @@ class DocumentService: "message": "文档已处理" } + self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}" # 更新文档状态为处理中 document.processing_error = None - self.session.commit() + await self.session.commit() + self.session.desc = f"更新文档状态为处理中 {doc_id}" # 调用文档处理器进行处理 - result = get_document_processor().process_document( + document_processor = await get_document_processor(self.session) + self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}" + result = await document_processor.process_document( + session=self.session, document_id=doc_id, - file_path=document.file_path, - knowledge_base_id=document.knowledge_base_id + file_path=file_path, + knowledge_base_id=knowledge_base_id ) - self.session.desc = f"SUCCESS: 成功处理文档 {doc_id}" + self.session.desc = f"处理文档完毕 {doc_id}" # 如果处理成功,更新文档状态 if result["status"] == "success": document.is_processed = True document.chunk_count = result.get("chunks_count", 0) - self.session.commit() - self.session.refresh(document) + await self.session.commit() + await self.session.refresh(document) logger.info(f"Processed document: {document.filename} (ID: {doc_id})") return result except Exception as e: - self.session.rollback() + await self.session.rollback() self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}" # Update document with error try: - document = self.get_document(doc_id) + document = await self.get_document(doc_id, kb_id) if document: document.processing_error = str(e) - self.session.commit() + await self.session.commit() except Exception as db_error: logger.error(f"Failed to update document error status: {db_error}") @@ -217,10 +231,10 @@ class 
DocumentService: self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}" raise - def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool: + async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool: """更新文档处理状态。""" self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}" - document = self.get_document(doc_id) + document = await self.get_document(doc_id) if not document: self.session.desc = f"ERROR: 文档 {doc_id} 不存在" return False @@ -228,16 +242,16 @@ class DocumentService: document.is_processed = is_processed document.processing_error = error - self.session.commit() + await self.session.commit() self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}" return True - def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]: + async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]: """在知识库中搜索文档使用向量相似度。""" try: # 使用文档处理器进行相似性搜索 - self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query}" - results = get_document_processor().search_similar_documents(kb_id, query, limit) + self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}条" + results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit) self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果" return results except Exception as e: @@ -245,9 +259,9 @@ class DocumentService: logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}") return [] - def get_document_stats(self, kb_id: int) -> Dict[str, Any]: + async def get_document_stats(self, kb_id: int) -> Dict[str, Any]: """获取知识库中的文档统计信息。""" - documents = self.get_documents(kb_id, limit=1000) # Get all documents + documents = await self.get_documents(kb_id, limit=1000) # Get all documents total_count = len(documents) processed_count = len([doc for doc in documents if doc.is_processed]) @@ -267,19 
+281,21 @@ class DocumentService: "file_types": file_types } - def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]: + async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]: """获取特定文档的文档块。""" try: self.session.desc = f"获取文档 {doc_id} 的文档块" stmt = select(Document).where(Document.id == doc_id) - document = self.session.scalar(stmt) + document = await self.session.scalar(stmt) if not document: self.session.desc = f"ERROR: 文档 {doc_id} 不存在" return [] + self.session.desc = f"获取文档 {doc_id} 的文档块 > document" # Get chunks from document processor - chunks_data = get_document_processor().get_document_chunks(document.knowledge_base_id, doc_id) + chunks_data = (await get_document_processor(self.session)).get_document_chunks(document.knowledge_base_id, doc_id) + self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data" # Convert to DocumentChunk objects chunks = [] for chunk_data in chunks_data: @@ -290,7 +306,8 @@ class DocumentService: page_number=chunk_data.get("page_number"), chunk_index=chunk_data["chunk_index"], start_char=chunk_data.get("start_char"), - end_char=chunk_data.get("end_char") + end_char=chunk_data.get("end_char"), + vector_id=chunk_data.get("vector_id") ) chunks.append(chunk) diff --git a/backend/th_agenter/services/document_processor.py b/backend/th_agenter/services/document_processor.py index 97dd0ff..5f90e63 100644 --- a/backend/th_agenter/services/document_processor.py +++ b/backend/th_agenter/services/document_processor.py @@ -3,10 +3,9 @@ import os from typing import List, Dict, Any, Optional from pathlib import Path -from urllib.parse import quote -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import QueuePool +from fastapi import HTTPException +from requests import Session +from sqlalchemy import text from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.document_loaders import ( TextLoader, @@ -18,7 +17,7 @@ import pdfplumber 
from langchain_core.documents import Document from langchain_postgres import PGVector from typing import List -# 旧的ZhipuEmbeddings类已移除,现在统一使用EmbeddingFactory创建embedding实例 + from ..core.config import BaseSettings, get_settings from ..models.knowledge_base import Document as DocumentModel @@ -26,64 +25,12 @@ from ..db.database import get_session from loguru import logger settings = get_settings() -class PGVectorConnectionPool: - """PGVector连接池管理器""" - - def __init__(self): - logger.error("PGVector连接池管理器 -==== 待异步方式实现") - self.engine = None - self.SessionLocal = None - # self._init_connection_pool() - -# def _init_connection_pool(self): -# """初始化连接池""" -# if settings.vector_db.type == "pgvector": -# # 构建连接字符串,对密码进行URL编码以处理特殊字符(如@符号) -# encoded_password = quote(settings.vector_db.pgvector_password, safe="") -# connection_string = ( -# f"postgresql://{settings.vector_db.pgvector_user}:" -# f"{encoded_password}@" -# f"{settings.vector_db.pgvector_host}:" -# f"{settings.vector_db.pgvector_port}/" -# f"{settings.vector_db.pgvector_database}" -# ) - -# # 创建SQLAlchemy引擎,配置连接池 -# self.engine = create_engine( -# connection_string, -# poolclass=QueuePool, -# pool_size=5, # 连接池大小 -# max_overflow=10, # 最大溢出连接数 -# pool_pre_ping=True, # 连接前ping检查 -# pool_recycle=3600, # 连接回收时间(秒) -# echo=False # 是否打印SQL语句 -# ) - -# # 创建会话工厂 -# self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine) -# logger.info(f"PGVector连接池已初始化: {settings.vector_db.pgvector_host}:{settings.vector_db.pgvector_port}") - -# def get_session(self): -# """获取数据库会话""" -# if self.SessionLocal is None: -# raise RuntimeError("连接池未初始化") -# return self.SessionLocal() - -# def execute_query(self, query: str, params: tuple = None): -# """执行查询并返回结果""" -# session = self.get_session() -# try: -# result = session.execute(text(query), params or {}) -# return result.fetchall() -# finally: -# session.close() - - class DocumentProcessor: """文档处理器,负责文档的加载、分段和向量化""" def __init__(self): # 初始化语义分割器配置 + 
self.embeddings = None self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=settings.file.chunk_size, @@ -92,68 +39,136 @@ class DocumentProcessor: separators=["\n\n", "\n", " ", ""] ) + async def initialize(self, session: Session = None): # 初始化嵌入模型 - 根据配置选择提供商 - self._init_embeddings() + await self._init_embeddings(session) # 初始化连接池(仅对PGVector) self.pgvector_pool = None # PostgreSQL pgvector连接配置 - print('settings.vector_db.type=============', settings.vector_db.type) - if settings.vector_db.type == "pgvector": - # 新版本PGVector使用psycopg3连接字符串 - # 对密码进行URL编码以处理特殊字符(如@符号) - encoded_password = quote(settings.vector_db.pgvector_password, safe="") - self.connection_string = ( - f"postgresql+psycopg://{settings.vector_db.pgvector_user}:" - f"{encoded_password}@" - f"{settings.vector_db.pgvector_host}:" - f"{settings.vector_db.pgvector_port}/" - f"{settings.vector_db.pgvector_database}" - ) - # 初始化连接池 - self.pgvector_pool = PGVectorConnectionPool() - else: + # if settings.vector_db.type == "pgvector": + # # 新版本PGVector使用psycopg3连接字符串 + # # 对密码进行URL编码以处理特殊字符(如@符号) + # encoded_password = quote(settings.vector_db.pgvector_password, safe="") + # self.connection_string = ( + # f"postgresql+psycopg://{settings.vector_db.pgvector_user}:" + # f"{encoded_password}@" + # f"{settings.vector_db.pgvector_host}:" + # f"{settings.vector_db.pgvector_port}/" + # f"{settings.vector_db.pgvector_database}" + # ) + # # 初始化连接池 + # self.pgvector_pool = PGVectorConnectionPool() + # logger.info("新版本PGVector使用psycopg3连接字符串: %s", self.connection_string) + # else: # 向量数据库存储路径(Chroma兼容) - vector_db_path = settings.vector_db.persist_directory - if not os.path.isabs(vector_db_path): - # 如果是相对路径,则基于项目根目录计算绝对路径 - # 项目根目录是backend的父目录 - backend_dir = Path(__file__).parent.parent.parent - vector_db_path = str(backend_dir / vector_db_path) - self.vector_db_path = vector_db_path + vector_db_path = 
settings.vector_db.persist_directory + if not os.path.isabs(vector_db_path): + # 如果是相对路径,则基于项目根目录计算绝对路径 + # 项目根目录是backend的父目录 + backend_dir = Path(__file__).parent.parent.parent + vector_db_path = str(backend_dir / vector_db_path) + self.vector_db_path = vector_db_path + session.desc = f"初始化向量数据库 - 路径 = {self.vector_db_path}" - def _init_embeddings(self): - """根据配置初始化embedding模型""" - from .embedding_factory import EmbeddingFactory - self.embeddings = EmbeddingFactory.create_embeddings() + async def _init_embeddings(self, session: Optional[Any] = None): + """初始化嵌入模型。""" + try: + if not self.embeddings: + # 使用llm_config_service获取嵌入配置 + from .llm_config_service import LLMConfigService + llm_config_service = LLMConfigService() + + # 获取嵌入配置 + config = None + if session: + config = await llm_config_service.get_default_embedding_config(session) + if config: + if(session != None): + session.desc = f"获取默认嵌入模型配置: {config}" + # # 转换配置格式 + # config = { + # "provider": config.provider, + # "api_key": config.api_key, + # "model": config.model_name + # } + + # 如果未找到配置,使用默认配置 + if not config: + session.desc = f"ERROR: 未找到嵌入模型配置" + raise HTTPException(status_code=400, detail="未找到嵌入模型配置") + session.desc = f"获取嵌入模型配置 > 结果:{config}" + + # 根据配置创建嵌入模型 + if config.provider == "openai": + from langchain_openai import OpenAIEmbeddings + self.embeddings = OpenAIEmbeddings( + model=config.get("model", "text-embedding-3-small"), + api_key=config.get("api_key") + ) + session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.get('model', 'text-embedding-3-small')})" + elif config.provider == "ollama": + from langchain_ollama import OllamaEmbeddings + self.embeddings = OllamaEmbeddings( + model=config.model_name, + base_url=config.base_url + ) + session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})" + elif config.provider == "local": + from langchain_huggingface import HuggingFaceEmbeddings + self.embeddings = HuggingFaceEmbeddings( + 
model_name=config.get("model", "sentence-transformers/all-MiniLM-L6-v2") + ) + session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.get('model', 'sentence-transformers/all-MiniLM-L6-v2')})" + else: + # 默认使用OpenAI + from langchain_openai import OpenAIEmbeddings + self.embeddings = OpenAIEmbeddings( + model=config.get("model", "text-embedding-3-small"), + api_key=config.get("api_key") + ) + session.desc = f"ERROR: 未支持的嵌入提供者: {config['provider']},已使用默认的 OpenAIEmbeddings - 可能不正确或无效" + + return self.embeddings + except Exception as e: + logger.error(f"初始化嵌入模型时出错: {e}") + raise - def load_document(self, file_path: str) -> List[Document]: + def load_document(self, session: Session, file_path: str) -> List[Document]: """根据文件类型加载文档""" file_extension = Path(file_path).suffix.lower() - try: if file_extension == '.txt': + session.desc = f"加载文档 - 文件路径: {file_path} - 类型: txt" loader = TextLoader(file_path, encoding='utf-8') documents = loader.load() elif file_extension == '.pdf': # 使用pdfplumber处理PDF文件,更稳定 - documents = self._load_pdf_with_pdfplumber(file_path) + session.desc = f"加载文档 - 文件路径: {file_path} - 类型: pdf" + from langchain_community.document_loaders import PyPDFLoader + loader = PyPDFLoader(file_path) + documents = loader.load() + # documents = self._load_pdf_with_pdfplumber(file_path) elif file_extension == '.docx': + session.desc = f"加载文档 - 文件路径: {file_path} - 类型: docx" loader = Docx2txtLoader(file_path) documents = loader.load() elif file_extension == '.md': + session.desc = f"加载文档 - 文件路径: {file_path} - 类型: md" loader = UnstructuredMarkdownLoader(file_path) documents = loader.load() else: raise ValueError(f"不支持的文件类型: {file_extension}") - logger.info(f"成功加载文档: {file_path}, 页数: {len(documents)}") + session.desc = f"已载文档: {file_path}, 页数: {len(documents)}" + # if len(documents) > 0: + # session.desc = f"文档内容示例: {type(documents[0])} - {documents[0]}" return documents except Exception as e: - logger.error(f"加载文档失败 {file_path}: {str(e)}") - raise + session.desc = 
f"ERROR: 加载文档失败 {file_path}: {str(e)}" + raise e def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]: """使用pdfplumber加载PDF文档""" @@ -195,78 +210,7 @@ class DocumentProcessor: merged_metadata.update(doc.metadata) return Document(page_content=merged_text, metadata=merged_metadata) - - def _get_semantic_split_points(self, text: str) -> List[str]: - """使用大模型分析文档内容,返回合适的分割点列表""" - try: - from langchain.chat_models import ChatOpenAI - from ..core.config import get_settings - - - - prompt = f""" - # 任务说明 - 请分析文档内容,识别出适合作为分割点的关键位置。分割点应该是能够将文档划分为有意义段落的文本片段。 - - - # 分割规则 - 请严格按照以下规则识别分割点: - - ## 基本要求 - 1. 分割点必须是完整的句子开头或段落开头 - 2. 每个分割后的部分应包含相对完整的语义内容 - 3. 每个分割部分的理想长度控制在500字以内,严禁超过1000字。如果超过了1000字,要强制分段。 - - ## 短段落处理 - 4. 如果某部分长度可能小于50字,应将其与后续内容合并,避免产生过短片段 - - ## 唯一性保证(重要) - 5. 确保每个分割点在文档中具有唯一性: - - 检查文内是否存在相同的文本片段 - - 如果存在重复,需要扩展分割点字符串,直到获得唯一标识 - - 扩展方法:在当前分割点后追加几个字符,形成更长的唯一字符串 - - ## 示例说明 - 原始文档: - "目录: - 第一章 标题一 - 第二章 标题二 - 正文 - 第一章 标题一 - 这是第一章的内容 - - 第二章 标题二 - 这是第二章的内容" - - 错误分割点:"第一章 标题一"(在目录和正文中重复出现) - - 正确分割点:"第一章 标题一\n这是第"(通过追加内容确保唯一性) - - # 输出格式 - - 只返回分割点文本字符串 - - 每个分割点用~~分隔 - - 不要包含任何其他内容或解释 - - 示例输出:分割点1~~分割点2~~分割点3 - - - 文档内容: - {text[:10000]} # 限制输入长度 - """ - from ..core.llm import create_llm - llm = create_llm(temperature=0.2) - - response = llm.invoke(prompt) - - # 解析响应获取分割点列表 - split_points = [point.strip() for point in response.content.split('~~') if point.strip()] - logger.info(f"语义分析得到 {len(split_points)} 个分割点") - return split_points - - except Exception as e: - logger.error(f"获取语义分割点失败: {str(e)}") - return [] - + def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]: """根据语义分割点切分文本""" chunks = [] @@ -291,76 +235,19 @@ class DocumentProcessor: return chunks - def split_documents(self, documents: List[Document]) -> List[Document]: + async def split_documents(self, session: Session, documents: List[Document]) -> List[Document]: """将文档分割成小块(含短段落合并和超长强制分割功能)""" try: - if self.semantic_splitter_enabled and 
documents: - # 1. 合并文档 - merged_doc = self._merge_documents(documents) + chunks = self.text_splitter.split_documents(documents) - # 2. 获取语义分割点 - split_points = self._get_semantic_split_points(merged_doc.page_content) - - if split_points: - # 3. 根据语义分割点切分文本 - text_chunks = self._split_by_semantic_points(merged_doc.page_content, split_points) - - # 4. 处理短段落合并和超长强制分割(新增逻辑) - processed_chunks = [] - buffer = "" - for chunk in text_chunks: - # 先检查当前chunk是否超长(超过1000字符) - if len(chunk) > 1000: - # 如果有缓冲内容,先处理缓冲 - if buffer: - processed_chunks.append(buffer) - buffer = "" - - # 对超长chunk进行强制分割 - forced_splits = self._force_split_long_chunk(chunk) - processed_chunks.extend(forced_splits) - else: - # 正常处理短段落合并 - if not buffer: - buffer = chunk - else: - if len(buffer) < 100: - buffer = f"{buffer}\n{chunk}" - else: - processed_chunks.append(buffer) - buffer = chunk - - # 添加最后剩余的缓冲内容 - if buffer: - processed_chunks.append(buffer) - - # 5. 创建Document对象 - chunks = [] - for i, chunk in enumerate(processed_chunks): - doc = Document( - page_content=chunk, - metadata={ - **merged_doc.metadata, - 'chunk_index': i, - 'merged': len(chunk) > 100, # 标记是否经过合并 - 'forced_split': len(chunk) > 1000 # 标记是否经过强制分割 - } - ) - chunks.append(doc) - else: - # 如果获取分割点失败,回退到默认分割器 - logger.warning("语义分割失败,使用默认分割器") - chunks = self.text_splitter.split_documents(documents) - else: - # 使用默认分割器 - chunks = self.text_splitter.split_documents(documents) - - logger.info(f"文档分割完成,共生成 {len(chunks)} 个文档块") + session.desc = f"文档分割完成,共生成 {len(chunks)} 个文档块" + if len(chunks) > 0: + session.desc = f"文档块内容示例: {type(chunks[0])} - {chunks[0]}" return chunks except Exception as e: - logger.error(f"文档分割失败: {str(e)}") - raise + session.desc = f"ERROR: 文档分割失败: {str(e)}" + raise e def _force_split_long_chunk(self, chunk: str) -> List[str]: """强制分割超长段落(超过1000字符)""" @@ -395,152 +282,123 @@ class DocumentProcessor: def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str: 
"""为知识库创建向量存储""" try: - if settings.vector_db.type == "pgvector": - # 添加元数据 - for i, doc in enumerate(documents): - doc.metadata.update({ - "knowledge_base_id": knowledge_base_id, - "document_id": str(document_id) if document_id else "unknown", - "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", - "chunk_index": i - }) + # if settings.vector_db.type == "pgvector": + # # 添加元数据 + # for i, doc in enumerate(documents): + # doc.metadata.update({ + # "knowledge_base_id": knowledge_base_id, + # "document_id": str(document_id) if document_id else "unknown", + # "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", + # "chunk_index": i + # }) - # 创建PostgreSQL pgvector存储 - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" + # # 创建PostgreSQL pgvector存储 + # collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - # 创建新版本PGVector实例 - vector_store = PGVector( - connection=self.connection_string, - embeddings=self.embeddings, - collection_name=collection_name, - use_jsonb=True # 使用JSONB存储元数据 - ) + # # 创建新版本PGVector实例 + # vector_store = PGVector( + # connection=self.connection_string, + # embeddings=self.embeddings, + # collection_name=collection_name, + # use_jsonb=True # 使用JSONB存储元数据 + # ) - # 手动添加文档 - vector_store.add_documents(documents) + # # 手动添加文档 + # vector_store.add_documents(documents) - logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}") - return collection_name - else: - # Chroma兼容模式 - from langchain_community.vectorstores import Chroma - kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") - - # 添加元数据 - for i, doc in enumerate(documents): - doc.metadata.update({ - "knowledge_base_id": knowledge_base_id, - "document_id": str(document_id) if document_id else "unknown", - "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", - "chunk_index": i - }) - - # 创建向量存储 - vector_store = Chroma.from_documents( - documents=documents, - embedding=self.embeddings, - 
persist_directory=kb_vector_path - ) - - # 持久化向量存储 - vector_store.persist() - - logger.info(f"向量存储创建成功: {kb_vector_path}") - return kb_vector_path + # logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}") + # return collection_name + # else: + # Chroma兼容模式 + from langchain_chroma import Chroma + kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + + # 添加元数据 + for i, doc in enumerate(documents): + doc.metadata.update({ + "knowledge_base_id": knowledge_base_id, + "document_id": str(document_id) if document_id else "unknown", + "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", + "chunk_index": i + }) + + # 创建向量存储 + vector_store = Chroma.from_documents( + documents=documents, + embedding=self.embeddings, + persist_directory=kb_vector_path + ) + + logger.info(f"向量存储创建成功: {kb_vector_path}") + return kb_vector_path except Exception as e: logger.error(f"创建向量存储失败: {str(e)}") raise - def add_documents_to_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None: + def add_documents_to_vector_store(self, session: Session, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None: """向现有向量存储添加文档""" - try: - if settings.vector_db.type == "pgvector": - # 添加元数据 - for i, doc in enumerate(documents): - doc.metadata.update({ - "knowledge_base_id": knowledge_base_id, - "document_id": str(document_id) if document_id else "unknown", - "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", - "chunk_index": i - }) + if len(documents) == 0: + session.desc = f"WARNING: 文档列表为空,不执行添加操作" + return + from langchain_chroma import Chroma + + kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + session.desc = f"添加文档到向量存储: {kb_vector_path} - documents number: {len(documents)}" + # 检查向量存储是否存在 + if not os.path.exists(kb_vector_path): + # 如果不存在,创建新的向量存储 + session.desc = f"WARNING: 向量存储不存在,创建新的向量存储" + self.create_vector_store(knowledge_base_id, documents, document_id) + 
return + session.desc = f"添加文档到向量存储: exists" + # 添加元数据 + for i, doc in enumerate(documents): + doc.metadata.update({ + "knowledge_base_id": knowledge_base_id, + "document_id": str(document_id) if document_id else "unknown", + "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", + "chunk_index": i + }) - # PostgreSQL pgvector存储 - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - try: - # 连接到现有集合 - vector_store = PGVector( - connection=self.connection_string, - embeddings=self.embeddings, - collection_name=collection_name, - use_jsonb=True - ) - # 添加新文档 - vector_store.add_documents(documents) - except Exception as e: - # 如果集合不存在,创建新的向量存储 - logger.warning(f"连接现有向量存储失败,创建新的向量存储: {e}") - self.create_vector_store(knowledge_base_id, documents, document_id) - return + session.desc = f"添加文档到向量存储: enumerate" + # 加载现有向量存储 + vector_store = Chroma( + persist_directory=kb_vector_path, + embedding_function=self.embeddings + ) - logger.info(f"文档已添加到PostgreSQL pgvector存储: {collection_name}") - else: - # Chroma兼容模式 - from langchain_community.vectorstores import Chroma - kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") - - # 检查向量存储是否存在 - if not os.path.exists(kb_vector_path): - # 如果不存在,创建新的向量存储 - self.create_vector_store(knowledge_base_id, documents, document_id) - return - - # 添加元数据 - for i, doc in enumerate(documents): - doc.metadata.update({ - "knowledge_base_id": knowledge_base_id, - "document_id": str(document_id) if document_id else "unknown", - "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", - "chunk_index": i - }) - - # 加载现有向量存储 - vector_store = Chroma( - persist_directory=kb_vector_path, - embedding_function=self.embeddings - ) - - # 添加新文档 - vector_store.add_documents(documents) - vector_store.persist() - - logger.info(f"文档已添加到向量存储: {kb_vector_path}") - - except Exception as e: - logger.error(f"添加文档到向量存储失败: {str(e)}") - raise + session.desc = f"添加文档到向量存储: Chroma" + # 添加新文档 + ids = 
vector_store.add_documents(documents) + session.desc = f"文档已添加到向量存储: {kb_vector_path} -> {len(ids)} IDS - \n{ids}" - def process_document(self, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]: + async def process_document(self, session: Session, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]: """处理单个文档:加载、分段、向量化""" try: - logger.info(f"开始处理文档 ID: {document_id}, 路径: {file_path}") + session.desc = f"处理文档 ID: {document_id} 文件路径: {file_path}" # 1. 加载文档 - documents = self.load_document(file_path) + documents = self.load_document(session, file_path) # 2. 分割文档 - chunks = self.split_documents(documents) + chunks = await self.split_documents(session, documents) # 3. 添加到向量存储 - self.add_documents_to_vector_store(knowledge_base_id, chunks, document_id) + self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id) # 4. 更新文档状态 - with next(get_session()) as session: - document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first() + session = await anext(get_session()) + try: + from sqlalchemy import select + document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id)) + if document: document.status = "processed" document.chunk_count = len(chunks) - session.commit() + await session.commit() + finally: + await session.close() result = { "document_id": document_id, @@ -549,22 +407,27 @@ class DocumentProcessor: "message": "文档处理完成" } - logger.info(f"文档处理完成: {result}") + + session.desc = f"文档处理完成: {result}" return result except Exception as e: - logger.error(f"文档处理失败 ID: {document_id}: {str(e)}") + session.desc = f"ERROR: 文档处理失败 ID: {document_id}: {str(e)}" # 更新文档状态为失败 try: - with next(get_session()) as session: - document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first() + session = await anext(get_session()) + try: + from sqlalchemy import select + document = await 
session.scalar(select(DocumentModel).where(DocumentModel.id == document_id)) if document: document.status = "failed" document.error_message = str(e) - session.commit() + await session.commit() + finally: + await session.close() except Exception as db_error: - logger.error(f"更新文档状态失败: {str(db_error)}") + session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}" return { "document_id": document_id, @@ -572,99 +435,35 @@ class DocumentProcessor: "error": str(e), "message": "文档处理失败" } - - def _get_document_ids_from_vector_store(self, knowledge_base_id: int, document_id: int) -> List[str]: - """查询指定document_id的所有向量记录的uuid""" - try: - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - - # 使用连接池执行查询 - if self.pgvector_pool: - query = f""" - SELECT uuid FROM langchain_pg_embedding - WHERE collection_id = ( - SELECT uuid FROM langchain_pg_collection - WHERE name = %s - ) AND cmetadata->>'document_id' = %s - """ - - result = self.pgvector_pool.execute_query(query, (collection_name, str(document_id))) - return [row[0] for row in result] if result else [] - else: - logger.warning("PGVector连接池未初始化") - return [] - - except Exception as e: - logger.error(f"查询文档向量记录失败: {str(e)}") - return [] - + def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None: """从向量存储中删除文档""" try: - if settings.vector_db.type == "pgvector": - # PostgreSQL pgvector存储 - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - - try: - # 创建新版本PGVector实例 - vector_store = PGVector( - connection=self.connection_string, - embeddings=self.embeddings, - collection_name=collection_name, - use_jsonb=True - ) - - # 直接从数据库查询要删除的文档UUID - try: - from sqlalchemy import text - from sqlalchemy.orm import Session - - # 获取数据库引擎 - engine = vector_store._engine - - with Session(engine) as session: - # 查询匹配document_id的所有记录的ID - query_sql = text( - f"SELECT id FROM langchain_pg_embedding " - f"WHERE cmetadata->>'document_id' = 
:doc_id" - ) - result = session.execute(query_sql, {"doc_id": str(document_id)}) - ids_to_delete = [row[0] for row in result.fetchall()] - - if ids_to_delete: - # 使用ID删除文档 - vector_store.delete(ids=ids_to_delete) - logger.info(f"成功删除 {len(ids_to_delete)} 个文档块: document_id={document_id}") - else: - logger.warning(f"未找到要删除的文档ID: document_id={document_id}") - - except Exception as query_error: - logger.error(f"查询要删除的文档时出错: {query_error}") - # 如果查询失败,说明文档可能不存在 - logger.warning(f"无法查询到要删除的文档: document_id={document_id}") - return - - logger.info(f"文档已从PostgreSQL pgvector存储中删除: document_id={document_id}") - except Exception as e: - logger.warning(f"PostgreSQL pgvector存储不存在或删除失败: {collection_name}, {str(e)}") - else: - # Chroma兼容模式 - from langchain_community.vectorstores import Chroma - kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") - - if not os.path.exists(kb_vector_path): - logger.warning(f"向量存储不存在: {kb_vector_path}") - return - - # 加载向量存储 - vector_store = Chroma( - persist_directory=kb_vector_path, - embedding_function=self.embeddings - ) - - # 删除相关文档块(这里需要根据实际的Chroma API来实现) - # 注意:Chroma的删除功能可能需要特定的实现方式 - logger.info(f"文档已从向量存储中删除: document_id={document_id}") + # Chroma兼容模式 + from langchain_chroma import Chroma + kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + + if not os.path.exists(kb_vector_path): + logger.warning(f"向量存储不存在: {kb_vector_path}") + return + + chunks = self.get_document_chunks(knowledge_base_id, document_id) + # 加载向量存储 + vector_store = Chroma( + persist_directory=kb_vector_path, + embedding_function=self.embeddings + ) + + count_before = vector_store._collection.count() + count_after = count_before + + if len(chunks) > 0: + where_filter = {"document_id": str(document_id)} + vector_store.delete(where=where_filter) + count_after = vector_store._collection.count() + + # 注意:Chroma的删除功能可能需要特定的实现方式 + logger.info(f"文档已从向量存储中删除: document_id={document_id},删除前有 {count_before} 个向量,删除后有 {count_after} 
个向量") except Exception as e: logger.error(f"从向量存储删除文档失败: {str(e)}") @@ -679,26 +478,7 @@ class DocumentProcessor: - 确保结果按chunk_index排序 """ try: - if settings.vector_db.type == "pgvector": - # PostgreSQL pgvector存储 - 使用直接SQL查询避免相似性搜索 - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - - try: - # 尝试直接SQL查询(推荐方法) - chunks = self._get_chunks_by_sql(knowledge_base_id, document_id) - if chunks: - return chunks - - # 如果SQL查询失败,回退到改进的LangChain方法 - logger.info("SQL查询失败,使用LangChain回退方案") - return self._get_chunks_by_langchain_improved(knowledge_base_id, document_id, collection_name) - - except Exception as e: - logger.warning(f"PostgreSQL pgvector存储访问失败: {collection_name}, {str(e)}") - return [] - else: - # Chroma兼容模式 - return self._get_chunks_chroma(knowledge_base_id, document_id) + return self._get_chunks_chroma(knowledge_base_id, document_id) except Exception as e: logger.error(f"获取文档分段失败 document_id: {document_id}, kb_id: {knowledge_base_id}: {str(e)}") @@ -826,148 +606,146 @@ class DocumentProcessor: def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]: """Chroma存储的处理逻辑""" - try: - from langchain_community.vectorstores import Chroma - - # 构建向量数据库路径 - vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") - - if not os.path.exists(vector_db_path): - logger.warning(f"向量数据库不存在: {vector_db_path}") - return [] - - # 加载向量数据库 - vectorstore = Chroma( - persist_directory=vector_db_path, - embedding_function=self.embeddings - ) - - # 获取所有文档的元数据,筛选出指定文档的分段 - collection = vectorstore._collection - all_docs = collection.get(include=["metadatas", "documents"]) - - chunks = [] - chunk_index = 0 - - for i, metadata in enumerate(all_docs["metadatas"]): - if metadata.get("document_id") == str(document_id): - chunk_content = all_docs["documents"][i] - - chunk = { - "id": f"chunk_{document_id}_{chunk_index}", - "content": chunk_content, - "metadata": metadata, - "page_number": 
metadata.get("page"), - "chunk_index": chunk_index, - "start_char": metadata.get("start_char"), - "end_char": metadata.get("end_char") - } - chunks.append(chunk) - chunk_index += 1 - - logger.info(f"获取到文档 {document_id} 的 {len(chunks)} 个分段") - return chunks - - except Exception as e: - logger.error(f"Chroma存储处理失败: {e}") + from langchain_chroma import Chroma + # 构建向量数据库路径 + vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + + if not os.path.exists(vector_db_path): + logger.warning(f"向量数据库不存在: {vector_db_path}") return [] + + # 加载向量数据库 + vectorstore = Chroma( + persist_directory=vector_db_path, + embedding_function=self.embeddings + ) + + # 获取所有文档的元数据,筛选出指定文档的分段 + collection = vectorstore._collection + all_docs = collection.get(include=["metadatas", "documents"]) + all_ids_data = collection.get() + + chunks = [] + chunk_index = 0 + + for i, metadata in enumerate(all_docs["metadatas"]): + if metadata.get("document_id") == str(document_id): + chunk_content = all_docs["documents"][i] + vector_id = all_ids_data["ids"][i] + + chunk = { + "id": f"chunk_{document_id}_{chunk_index}", + "content": chunk_content, + "metadata": metadata, + "page_number": metadata.get("page"), + "chunk_index": chunk_index, + "start_char": metadata.get("start_char"), + "end_char": metadata.get("end_char"), + "vector_id": vector_id + } + chunks.append(chunk) + chunk_index += 1 + + return chunks def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]: """在知识库中搜索相似文档""" try: - if settings.vector_db.type == "pgvector": - # PostgreSQL pgvector存储 - collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" + # if settings.vector_db.type == "pgvector": + # # PostgreSQL pgvector存储 + # collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - try: - vector_store = PGVector( - connection=self.connection_string, - embeddings=self.embeddings, - collection_name=collection_name, 
- use_jsonb=True - ) + # try: + # vector_store = PGVector( + # connection=self.connection_string, + # embeddings=self.embeddings, + # collection_name=collection_name, + # use_jsonb=True + # ) - # 执行相似性搜索 - results = vector_store.similarity_search_with_score(query, k=k) + # # 执行相似性搜索 + # results = vector_store.similarity_search_with_score(query, k=k) - # 格式化结果 - formatted_results = [] - for doc, distance_score in results: - # pgvector使用余弦距离,距离越小相似度越高 - # 将距离转换为0-1之间的相似度分数 - similarity_score = 1.0 / (1.0 + distance_score) + # # 格式化结果 + # formatted_results = [] + # for doc, distance_score in results: + # # pgvector使用余弦距离,距离越小相似度越高 + # # 将距离转换为0-1之间的相似度分数 + # similarity_score = 1.0 / (1.0 + distance_score) - formatted_results.append({ - "content": doc.page_content, - "metadata": doc.metadata, - "similarity_score": distance_score, # 保留原始距离分数 - "normalized_score": similarity_score, # 归一化相似度分数 - "source": doc.metadata.get('filename', 'unknown'), - "document_id": doc.metadata.get('document_id', 'unknown'), - "chunk_id": doc.metadata.get('chunk_id', 'unknown') - }) + # formatted_results.append({ + # "content": doc.page_content, + # "metadata": doc.metadata, + # "similarity_score": distance_score, # 保留原始距离分数 + # "normalized_score": similarity_score, # 归一化相似度分数 + # "source": doc.metadata.get('filename', 'unknown'), + # "document_id": doc.metadata.get('document_id', 'unknown'), + # "chunk_id": doc.metadata.get('chunk_id', 'unknown') + # }) - # 按相似度分数排序(距离越小越相似) - formatted_results.sort(key=lambda x: x['similarity_score']) + # # 按相似度分数排序(距离越小越相似) + # formatted_results.sort(key=lambda x: x['similarity_score']) - logger.info(f"PostgreSQL pgvector搜索完成,找到 {len(formatted_results)} 个相关文档") - return formatted_results + # logger.info(f"PostgreSQL pgvector搜索完成,找到 {len(formatted_results)} 个相关文档") + # return formatted_results - except Exception as e: - logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}") - return [] - else: - # Chroma兼容模式 - from 
langchain_community.vectorstores import Chroma - kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + # except Exception as e: + # logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}") + # return [] + # else: + # Chroma兼容模式 + from langchain_chroma import Chroma + kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + + if not os.path.exists(kb_vector_path): + logger.warning(f"向量存储不存在: {kb_vector_path}") + return [] + + # 加载向量存储 + vector_store = Chroma( + persist_directory=kb_vector_path, + embedding_function=self.embeddings + ) + + # 执行相似性搜索 + results = vector_store.similarity_search_with_score(query, k=k) + + # 格式化结果 + formatted_results = [] + for doc, distance_score in results: + # Chroma使用欧几里得距离,距离越小相似度越高 + # 将距离转换为0-1之间的相似度分数 + similarity_score = 1.0 / (1.0 + distance_score) - if not os.path.exists(kb_vector_path): - logger.warning(f"向量存储不存在: {kb_vector_path}") - return [] - - # 加载向量存储 - vector_store = Chroma( - persist_directory=kb_vector_path, - embedding_function=self.embeddings - ) - - # 执行相似性搜索 - results = vector_store.similarity_search_with_score(query, k=k) - - # 格式化结果 - formatted_results = [] - for doc, distance_score in results: - # Chroma使用欧几里得距离,距离越小相似度越高 - # 将距离转换为0-1之间的相似度分数 - similarity_score = 1.0 / (1.0 + distance_score) - - formatted_results.append({ - "content": doc.page_content, - "metadata": doc.metadata, - "similarity_score": distance_score, # 保留原始距离分数 - "normalized_score": similarity_score, # 归一化相似度分数 - "source": doc.metadata.get('filename', 'unknown'), - "document_id": doc.metadata.get('document_id', 'unknown'), - "chunk_id": doc.metadata.get('chunk_id', 'unknown') - }) - - # 按相似度分数排序(距离越小越相似) - formatted_results.sort(key=lambda x: x['similarity_score']) - - logger.info(f"搜索完成,找到 {len(formatted_results)} 个相关文档") - return formatted_results + formatted_results.append({ + "content": doc.page_content, + "metadata": doc.metadata, + "similarity_score": distance_score, # 
保留原始距离分数 + "normalized_score": similarity_score, # 归一化相似度分数 + "source": doc.metadata.get('filename', 'unknown'), + "document_id": doc.metadata.get('document_id', 'unknown'), + "chunk_id": doc.metadata.get('chunk_id', 'unknown') + }) + + # 按相似度分数排序(距离越小越相似) + formatted_results.sort(key=lambda x: x['similarity_score']) + + logger.info(f"搜索完成,找到 {len(formatted_results)} 个相关文档") + return formatted_results except Exception as e: logger.error(f"搜索文档失败: {str(e)}") return [] # 返回空列表而不是抛出异常 - # 全局文档处理器实例(延迟初始化) document_processor = None -def get_document_processor(): +async def get_document_processor(session: Session = None): """获取文档处理器实例(延迟初始化)""" global document_processor + if session: + session.desc = "获取文档处理器实例" if document_processor is None: document_processor = DocumentProcessor() + await document_processor.initialize(session) return document_processor \ No newline at end of file diff --git a/backend/th_agenter/services/embedding_factory.py b/backend/th_agenter/services/embedding_factory.py index 317e763..64afa06 100644 --- a/backend/th_agenter/services/embedding_factory.py +++ b/backend/th_agenter/services/embedding_factory.py @@ -4,6 +4,7 @@ from typing import Optional from langchain_core.embeddings import Embeddings from langchain_openai import OpenAIEmbeddings from langchain_community.embeddings import HuggingFaceEmbeddings +from requests import Session from .zhipu_embeddings import ZhipuOpenAIEmbeddings from ..core.config import settings from loguru import logger @@ -12,7 +13,8 @@ class EmbeddingFactory: """Factory class for creating embedding instances based on provider.""" @staticmethod - def create_embeddings( + async def create_embeddings( + session: Session = None, provider: Optional[str] = None, model: Optional[str] = None, dimensions: Optional[int] = None @@ -28,12 +30,12 @@ class EmbeddingFactory: Embeddings instance """ # 使用新的embedding配置 - embedding_config = settings.embedding.get_current_config() + embedding_config = await 
settings.embedding.get_current_config(session) provider = provider or settings.embedding.provider model = model or embedding_config.get("model") dimensions = dimensions or settings.vector_db.embedding_dimension - logger.info(f"Creating embeddings with provider: {provider}, model: {model}") + session.desc = f"创建嵌入模型: {provider}, {model}" if provider == "openai": return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions) diff --git a/backend/th_agenter/services/excel_metadata_service.py b/backend/th_agenter/services/excel_metadata_service.py index 1c07c74..ae26b7b 100644 --- a/backend/th_agenter/services/excel_metadata_service.py +++ b/backend/th_agenter/services/excel_metadata_service.py @@ -4,6 +4,7 @@ import os import pandas as pd from typing import Dict, List, Any, Optional, Tuple from sqlalchemy.orm import Session +from sqlalchemy import select, func from ..models.excel_file import ExcelFile from ..db.database import get_session from loguru import logger @@ -102,7 +103,7 @@ class ExcelMetadataService: 'processing_error': str(e) } - def save_file_metadata(self, file_path: str, original_filename: str, + async def save_file_metadata(self, file_path: str, original_filename: str, user_id: int, file_size: int) -> ExcelFile: """Extract and save Excel file metadata to database.""" try: @@ -131,31 +132,30 @@ class ExcelMetadataService: # Save to database - self.db.add(excel_file) - self.db.commit() - self.db.refresh(excel_file) + self.session.add(excel_file) + await self.session.commit() + await self.session.refresh(excel_file) logger.info(f"Saved metadata for file {original_filename} with ID {excel_file.id}") return excel_file except Exception as e: logger.error(f"Error saving metadata for {original_filename}: {str(e)}") - self.db.rollback() + await self.session.rollback() raise - def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]: + async def get_user_files(self, user_id: int, skip: int = 0, 
limit: int = 50) -> Tuple[List[ExcelFile], int]: """Get Excel files for a user with pagination.""" try: # Get total count - total = self.db.query(ExcelFile).filter(ExcelFile.created_by == user_id).count() + stmt = select(func.count(ExcelFile.id)).where(ExcelFile.created_by == user_id) + result = await self.session.execute(stmt) + total = result.scalar_one() # Get files with pagination - files = (self.db.query(ExcelFile) - .filter(ExcelFile.created_by == user_id) - .order_by(ExcelFile.created_at.desc()) - .offset(skip) - .limit(limit) - .all()) + stmt = select(ExcelFile).where(ExcelFile.created_by == user_id).order_by(ExcelFile.created_at.desc()).offset(skip).limit(limit) + result = await self.session.execute(stmt) + files = result.scalars().all() return files, total @@ -163,21 +163,21 @@ class ExcelMetadataService: logger.error(f"Error getting user files for user {user_id}: {str(e)}") return [], 0 - def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]: + async def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]: """Get Excel file by ID and user ID.""" try: - return (self.db.query(ExcelFile) - .filter(ExcelFile.id == file_id, ExcelFile.created_by == user_id) - .first()) + stmt = select(ExcelFile).where(ExcelFile.id == file_id, ExcelFile.created_by == user_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() except Exception as e: logger.error(f"Error getting file {file_id} for user {user_id}: {str(e)}") return None - def delete_file(self, file_id: int, user_id: int) -> bool: + async def delete_file(self, file_id: int, user_id: int) -> bool: """Delete Excel file record and physical file.""" try: # Get file record - excel_file = self.get_file_by_id(file_id, user_id) + excel_file = await self.get_file_by_id(file_id, user_id) if not excel_file: return False @@ -187,39 +187,40 @@ class ExcelMetadataService: logger.info(f"Deleted physical file: {excel_file.file_path}") # Delete database record 
- self.db.delete(excel_file) - self.db.commit() + await self.session.delete(excel_file) + await self.session.commit() logger.info(f"Deleted Excel file record with ID {file_id}") return True except Exception as e: logger.error(f"Error deleting file {file_id}: {str(e)}") - self.db.rollback() + await self.session.rollback() return False - def update_last_accessed(self, file_id: int, user_id: int) -> bool: + async def update_last_accessed(self, file_id: int, user_id: int) -> bool: """Update last accessed time for a file.""" try: - excel_file = self.get_file_by_id(file_id, user_id) + excel_file = await self.get_file_by_id(file_id, user_id) if not excel_file: return False - from sqlalchemy.sql import func excel_file.last_accessed = func.now() - self.db.commit() + await self.session.commit() return True except Exception as e: logger.error(f"Error updating last accessed for file {file_id}: {str(e)}") - self.db.rollback() + await self.session.rollback() return False - def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]: + async def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]: """Get file summary information for LLM context.""" try: - files = self.db.query(ExcelFile).filter(ExcelFile.user_id == user_id).all() + stmt = select(ExcelFile).where(ExcelFile.user_id == user_id) + result = await self.session.execute(stmt) + files = result.scalars().all() summary = [] for file in files: diff --git a/backend/th_agenter/services/knowledge_base.py b/backend/th_agenter/services/knowledge_base.py index 356d6d2..abfc531 100644 --- a/backend/th_agenter/services/knowledge_base.py +++ b/backend/th_agenter/services/knowledge_base.py @@ -29,9 +29,11 @@ class KnowledgeBaseService: Args: session (Session): 数据库会话,用于执行ORM操作。 """ + if session is None: + logger.error("session为空,session must be an instance of Session") self.session = session - def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase: + async def 
create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase: """创建一个新的知识库实例。 Args: @@ -57,130 +59,22 @@ class KnowledgeBaseService: collection_name=collection_name ) - # Set audit fields + # 自动更新created_by和updated_by字段 kb.set_audit_fields() self.session.add(kb) - self.session.commit() - self.session.refresh(kb) + await self.session.commit() + await self.session.refresh(kb) - logger.info(f"Created knowledge base: {kb.name} (ID: {kb.id})") + self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}" return kb except Exception as e: - self.session.rollback() + await self.session.rollback() logger.error(f"Failed to create knowledge base: {str(e)}") raise - - def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]: - """根据ID获取知识库实例。 - - Args: - kb_id (int): 知识库实例的ID。 - - Returns: - Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。 - """ - stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) - return self.session.execute(stmt).scalar_one_or_none() - - def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]: - """根据名称获取当前用户的知识库实例。 - - Args: - name (str): 知识库实例的名称。 - - Returns: - Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。 - """ - stmt = select(KnowledgeBase).where( - KnowledgeBase.name == name, - KnowledgeBase.created_by == UserContext.get_current_user().id - ) - return self.session.execute(stmt).scalar_one_or_none() - - async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]: - """获取当前用户的所有知识库的列表。 - - Args: - skip (int, optional): 跳过的记录数。默认值为0。 - limit (int, optional): 返回的最大记录数。默认值为50。 - active_only (bool, optional): 是否仅返回活动的知识库。默认值为True。 - - Returns: - List[KnowledgeBase]: 当前用户的知识库列表。 - """ - stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user().id) - - if active_only: - stmt = stmt.where(KnowledgeBase.is_active == True) - - stmt = 
stmt.offset(skip).limit(limit) - return (await self.session.execute(stmt)).scalars().all() - - - def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]: - """更新知识库实例。 - - Args: - kb_id (int): 待更新的知识库实例ID。 - kb_update (KnowledgeBaseUpdate): 用于更新知识库实例的数据。 - - Returns: - Optional[KnowledgeBase]: 如果找到则返回更新后的知识库实例,否则返回None。 - - Raises: - Exception: 如果更新过程中发生错误。 - """ - try: - kb = self.get_knowledge_base(kb_id) - if not kb: - return None - - # Update fields - update_data = kb_update.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(kb, field, value) - - # Set audit fields - kb.set_audit_fields(is_update=True) - - self.session.commit() - self.session.refresh(kb) - - self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})" - return kb - - except Exception as e: - self.session.rollback() - self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}" - raise - - def delete_knowledge_base(self, kb_id: int) -> bool: - """删除知识库实例。 - - Args: - kb_id (int): 待删除的知识库实例ID。 - - Returns: - bool: 如果知识库实例被成功删除则返回True,否则返回False。 - - Raises: - Exception: 如果删除过程中发生错误。 - """ - kb = self.get_knowledge_base(kb_id) - if not kb: - return False - - # TODO: Clean up vector database collection - # This should be implemented when vector database service is ready - - self.session.delete(kb) - self.session.commit() - - return True - - def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]: + + async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]: """Search knowledge bases by name or description for the current user. Args: @@ -192,7 +86,7 @@ class KnowledgeBaseService: List[KnowledgeBase]: List of matching knowledge bases. 
""" stmt = select(KnowledgeBase).where( - KnowledgeBase.created_by == UserContext.get_current_user().id, + KnowledgeBase.created_by == UserContext.get_current_user()['id'], KnowledgeBase.is_active == True, or_( KnowledgeBase.name.ilike(f"%{query}%"), @@ -201,7 +95,7 @@ class KnowledgeBaseService: ) stmt = stmt.offset(skip).limit(limit) - return self.session.execute(stmt).scalars().all() + return (await self.session.execute(stmt)).scalars().all() async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]: """Search in knowledge base using vector similarity. @@ -219,7 +113,7 @@ class KnowledgeBaseService: logger.info(f"Searching in knowledge base {kb_id} for: {query}") # Use document processor for vector search - search_results = get_document_processor().search_similar_documents( + search_results = (await get_document_processor(self.session)).search_similar_documents( knowledge_base_id=kb_id, query=query, k=top_k @@ -246,4 +140,105 @@ class KnowledgeBaseService: except Exception as e: logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}") - return [] \ No newline at end of file + return [] + +# ---------------------------------------------------------------------------------- + async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]: + """根据名称获取当前用户的知识库实例。 + + Args: + name (str): 知识库实例的名称。 + + Returns: + Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。 + """ + stmt = select(KnowledgeBase).where( + KnowledgeBase.name == name, + KnowledgeBase.created_by == UserContext.get_current_user()['id'] + ) + result = (await self.session.execute(stmt)).scalar_one_or_none() + return result + + async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]: + """获取当前用户的所有知识库的列表。 + + Args: + skip (int, optional): 跳过的记录数。默认值为0。 + limit (int, optional): 返回的最大记录数。默认值为50。 + active_only (bool, optional): 是否仅返回活动的知识库。默认值为True。 + + 
Returns: + List[KnowledgeBase]: 当前用户的知识库列表。 + """ + stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user()['id']) # 使用字典键索引访问用户ID + if active_only: + stmt = stmt.where(KnowledgeBase.is_active == True) + stmt = stmt.offset(skip).limit(limit) + return (await self.session.execute(stmt)).scalars().all() + + async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]: + """根据ID获取知识库实例。 + + Args: + kb_id (int): 知识库实例的ID。 + + Returns: + Optional[KnowledgeBase]: 如果找到则返回知识库实例,否则返回None。 + """ + stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) + return (await self.session.execute(stmt)).scalar_one_or_none() + + async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]: + """更新知识库实例。 + + Args: + kb_id (int): 待更新的知识库实例ID。 + kb_update (KnowledgeBaseUpdate): 用于更新知识库实例的数据。 + + Returns: + Optional[KnowledgeBase]: 如果找到则返回更新后的知识库实例,否则返回None。 + + Raises: + Exception: 如果更新过程中发生错误。 + """ + kb = await self.get_knowledge_base(kb_id) + if not kb: + return None + + # Update fields + update_data = kb_update.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(kb, field, value) + + # Set audit fields + kb.set_audit_fields(is_update=True) + + await self.session.commit() + await self.session.refresh(kb) + + self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})" + return kb + + async def delete_knowledge_base(self, kb_id: int) -> bool: + """删除知识库实例。 + + Args: + kb_id (int): 待删除的知识库实例ID。 + + Returns: + bool: 如果知识库实例被成功删除则返回True,否则返回False。 + + Raises: + Exception: 如果删除过程中发生错误。 + """ + kb = await self.get_knowledge_base(kb_id) + if not kb: + return False + + # TODO: Clean up vector database collection + # This should be implemented when vector database service is ready + + await self.session.delete(kb) + await self.session.commit() + + return True diff --git a/backend/th_agenter/services/knowledge_chat.py 
b/backend/th_agenter/services/knowledge_chat.py index 9b30371..561c747 100644 --- a/backend/th_agenter/services/knowledge_chat.py +++ b/backend/th_agenter/services/knowledge_chat.py @@ -9,7 +9,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnablePassthrough from langchain_core.output_parsers import StrOutputParser -from langchain_community.vectorstores import Chroma +from langchain_chroma import Chroma from langchain_postgres import PGVector from .embedding_factory import EmbeddingFactory @@ -21,16 +21,16 @@ from .conversation import ConversationService from .document_processor import get_document_processor from loguru import logger - class KnowledgeChatService: """Knowledge base chat service using LangChain RAG.""" - def __init__(self, db: Session): - self.db = db - self.conversation_service = ConversationService(db) + def __init__(self, session: Session): + self.session = session + self.conversation_service = ConversationService(session) + async def initialize(self): # 获取当前LLM配置 - llm_config = settings.llm.get_current_config() + llm_config = await settings.llm.get_current_config(self.session) # Initialize LangChain ChatOpenAI self.llm = ChatOpenAI( @@ -53,41 +53,23 @@ class KnowledgeChatService: ) # Initialize embeddings based on provider - self.embeddings = EmbeddingFactory.create_embeddings() - - logger.info(f"Knowledge Chat Service initialized with model: {self.llm.model_name}") - - def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]: + self.embeddings = await EmbeddingFactory.create_embeddings(self.session) + async def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]: """Get vector store for knowledge base.""" try: - if settings.vector_db.type == "pgvector": - # 使用PGVector - doc_processor = get_document_processor() - collection_name = 
f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" - - vector_store = PGVector( - connection=doc_processor.connection_string, - embeddings=self.embeddings, - collection_name=collection_name, - use_jsonb=True - ) - - return vector_store - else: - # 兼容Chroma模式 - import os - kb_vector_path = os.path.join(get_document_processor().vector_db_path, f"kb_{knowledge_base_id}") - - if not os.path.exists(kb_vector_path): - logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}") - return None - - vector_store = Chroma( - persist_directory=kb_vector_path, - embedding_function=self.embeddings - ) - - return vector_store + import os + kb_vector_path = os.path.join((await get_document_processor(self.session)).vector_db_path, f"kb_{knowledge_base_id}") + + if not os.path.exists(kb_vector_path): + logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}") + return None + + vector_store = Chroma( + persist_directory=kb_vector_path, + embedding_function=self.embeddings + ) + + return vector_store except Exception as e: logger.error(f"Failed to load vector store for KB {knowledge_base_id}: {str(e)}") @@ -159,7 +141,7 @@ class KnowledgeChatService: try: # Get conversation and validate - conversation = self.conversation_service.get_conversation(conversation_id) + conversation = await self.conversation_service.get_conversation(conversation_id) if not conversation: raise ChatServiceError("Conversation not found") @@ -169,14 +151,14 @@ class KnowledgeChatService: raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed") # Save user message - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER ) # Get conversation history - messages = self.conversation_service.get_conversation_messages(conversation_id) + messages = await 
self.conversation_service.get_conversation_messages(conversation_id) conversation_history = self._prepare_conversation_history(messages) # Create RAG chain @@ -203,7 +185,7 @@ class KnowledgeChatService: response_content = await asyncio.to_thread(rag_chain.invoke, message) # Save assistant message with context - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=response_content, role=MessageRole.ASSISTANT, @@ -251,14 +233,14 @@ class KnowledgeChatService: raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed") # Get conversation history - messages = self.conversation_service.get_conversation_messages(conversation_id) + messages = await self.conversation_service.get_conversation_messages(conversation_id) conversation_history = self._prepare_conversation_history(messages) # Create RAG chain rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history) # Save user message - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER @@ -269,7 +251,7 @@ class KnowledgeChatService: context = "\n\n".join([doc.page_content for doc in relevant_docs]) # Create streaming LLM - llm_config = settings.llm.get_current_config() + llm_config = await settings.llm.get_current_config() streaming_llm = ChatOpenAI( model=llm_config["model"], temperature=temperature or llm_config["temperature"], @@ -315,7 +297,7 @@ class KnowledgeChatService: # Save assistant response if full_response: - self.conversation_service.add_message( + await self.conversation_service.add_message( conversation_id=conversation_id, content=full_response, role=MessageRole.ASSISTANT, @@ -331,7 +313,7 @@ class KnowledgeChatService: yield error_message # Save error message - self.conversation_service.add_message( + await 
self.conversation_service.add_message( conversation_id=conversation_id, content=error_message, role=MessageRole.ASSISTANT diff --git a/backend/th_agenter/services/langchain_chat.py b/backend/th_agenter/services/langchain_chat.py index d5d7583..8724372 100644 --- a/backend/th_agenter/services/langchain_chat.py +++ b/backend/th_agenter/services/langchain_chat.py @@ -41,28 +41,27 @@ class StreamingCallbackHandler(BaseCallbackHandler): class LangChainChatService: """LangChain-based chat service for AI model integration.""" - def __init__(self, db: Session): - self.db = db - self.conversation_service = ConversationService(db) + def __init__(self, session: Session): + self.session = session + self.conversation_service = ConversationService(session) - from ..core.llm import create_llm + async def initialize(self): + from ..core.new_agent import new_agent - # 添加调试日志 - logger.info(f"LLM Provider: {settings.llm.provider}") + # Initialize LangChain ChatOpenAI + self.llm = await new_agent(self.session, streaming=False) + self.session.desc = "LangChainChatService初始化 - llm 实例化完毕" - # Initialize LangChain ChatOpenAI - self.llm = create_llm(streaming=False) - - # Streaming LLM for stream responses - self.streaming_llm = create_llm(streaming=True) + # Streaming LLM for stream responses + self.streaming_llm = await new_agent(self.session, streaming=True) + self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕" self.streaming_handler = StreamingCallbackHandler() - - logger.info(f"LangChain ChatService initialized with model: {self.llm.model_name}") + self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕" def _prepare_langchain_messages(self, conversation, history: List) -> List: """Prepare messages for LangChain format.""" - messages = [] + messages = [] # Add system message if conversation has system prompt if hasattr(conversation, 'system_prompt') and conversation.system_prompt: @@ -101,19 +100,19 @@ class LangChainChatService: try: # Get conversation 
details - conversation = self.conversation_service.get_conversation(conversation_id) + conversation = await self.conversation_service.get_conversation(conversation_id) if not conversation: raise ChatServiceError("Conversation not found") # Add user message to database - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER ) # Get conversation history for context - history = self.conversation_service.get_conversation_history( + history = await self.conversation_service.get_conversation_history( conversation_id, limit=20 ) @@ -123,7 +122,7 @@ class LangChainChatService: # Update LLM parameters if provided llm_to_use = self.llm if temperature is not None or max_tokens is not None: - llm_config = settings.llm.get_current_config() + llm_config = await settings.llm.get_current_config() llm_to_use = ChatOpenAI( model=llm_config["model"], openai_api_key=llm_config["api_key"], @@ -140,7 +139,7 @@ class LangChainChatService: assistant_content = response.content # Add assistant message to database - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=assistant_content, role=MessageRole.ASSISTANT, @@ -152,7 +151,7 @@ class LangChainChatService: ) # Update conversation timestamp - self.conversation_service.update_conversation_timestamp(conversation_id) + await self.conversation_service.update_conversation_timestamp(conversation_id) logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}") @@ -171,7 +170,7 @@ class LangChainChatService: error_message = self._format_error_message(e) # Add error message to database - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=error_message, 
role=MessageRole.ASSISTANT, @@ -206,33 +205,33 @@ class LangChainChatService: max_tokens: Optional[int] = None ) -> AsyncGenerator[str, None]: """Send a message and get streaming AI response using LangChain.""" - logger.info(f"Processing LangChain streaming chat request for conversation {conversation_id}") + logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}") try: # Get conversation details - conversation = self.conversation_service.get_conversation(conversation_id) + conversation = await self.conversation_service.get_conversation(conversation_id) + conv = conversation.to_dict() if not conversation: raise ChatServiceError("Conversation not found") # Add user message to database - user_message = self.conversation_service.add_message( + user_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=message, role=MessageRole.USER ) # Get conversation history for context - history = self.conversation_service.get_conversation_history( + history = await self.conversation_service.get_conversation_history( conversation_id, limit=20 ) # Prepare messages for LangChain - langchain_messages = self._prepare_langchain_messages(conversation, history) - + langchain_messages = self._prepare_langchain_messages(conv, history) # Update streaming LLM parameters if provided streaming_llm_to_use = self.streaming_llm if temperature is not None or max_tokens is not None: - llm_config = settings.llm.get_current_config() + llm_config = await settings.llm.get_current_config() streaming_llm_to_use = ChatOpenAI( model=llm_config["model"], openai_api_key=llm_config["api_key"], @@ -241,33 +240,35 @@ class LangChainChatService: max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens, streaming=True ) - # Clear previous streaming handler state self.streaming_handler.clear() - # Stream response + # Stream response full_response = "" - async for chunk in streaming_llm_to_use.astream(langchain_messages): - # Handle different 
chunk types to avoid KeyError - chunk_content = None - if hasattr(chunk, 'content'): - # For object-like chunks with content attribute - chunk_content = chunk.content - elif isinstance(chunk, dict) and 'content' in chunk: - # For dict-like chunks with content key - chunk_content = chunk['content'] - elif isinstance(chunk, dict) and 'error' in chunk: - # Handle error chunks explicitly - logger.error(f"Error in LLM response: {chunk['error']}") - yield self._format_error_message(Exception(chunk['error'])) - continue - - if chunk_content: - full_response += chunk_content - yield chunk_content - + try: + async for chunk in streaming_llm_to_use._astream(langchain_messages): + # Handle different chunk types to avoid KeyError + chunk_content = None + if hasattr(chunk, 'content'): + # For object-like chunks with content attribute + chunk_content = chunk.content + elif isinstance(chunk, dict) and 'content' in chunk: + # For dict-like chunks with content key + chunk_content = chunk['content'] + elif isinstance(chunk, dict) and 'error' in chunk: + # Handle error chunks explicitly + logger.error(f"Error in LLM response: {chunk['error']}") + yield self._format_error_message(Exception(chunk['error'])) + continue + + if chunk_content: + full_response += chunk_content + yield chunk_content + except Exception as e: + logger.error(f"Error in LLM streaming: {e}") + yield f"{self._format_error_message(e)} >>> {e}" # Add complete assistant message to database - assistant_message = self.conversation_service.add_message( + assistant_message = await self.conversation_service.add_message( conversation_id=conversation_id, content=full_response, role=MessageRole.ASSISTANT, @@ -280,13 +281,12 @@ class LangChainChatService: ) # Update conversation timestamp - self.conversation_service.update_conversation_timestamp(conversation_id) - - logger.info(f"Successfully processed LangChain streaming chat request for conversation {conversation_id}") + await 
self.conversation_service.update_conversation_timestamp(conversation_id) + logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}") except Exception as e: # 安全地格式化异常信息,避免再次引发KeyError - error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id}" + error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}" logger.error(error_info, exc_info=True) # Format error message for user @@ -294,7 +294,7 @@ class LangChainChatService: yield error_message # Add error message to database - self.conversation_service.add_message( + await self.conversation_service.add_message( conversation_id=conversation_id, content=error_message, role=MessageRole.ASSISTANT, @@ -324,23 +324,23 @@ class LangChainChatService: logger.error(f"Failed to get available models: {str(e)}") return ["gpt-3.5-turbo"] - def update_model_config( + async def update_model_config( self, model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None ): """Update LLM configuration.""" - from ..core.llm import create_llm + from ..core.new_agent import new_agent # 重新创建LLM实例 - self.llm = create_llm( + self.llm = await new_agent( model=model, temperature=temperature, streaming=False ) - self.streaming_llm = create_llm( + self.streaming_llm = await new_agent( model=model, temperature=temperature, streaming=True @@ -357,7 +357,7 @@ class LangChainChatService: if "rate limit" in error_str.lower(): return "服务器繁忙,请稍后再试。" elif "api key" in error_str.lower() or "authentication" in error_str.lower(): - return "API认证失败,请检查配置。" + return f"API认证失败,请检查配置文件。" elif "timeout" in error_str.lower(): return "请求超时,请重试。" elif "connection" in error_str.lower(): diff --git a/backend/th_agenter/services/llm_config_service.py b/backend/th_agenter/services/llm_config_service.py index 64b7eb9..0d9fd2a 100644 --- a/backend/th_agenter/services/llm_config_service.py +++ 
b/backend/th_agenter/services/llm_config_service.py @@ -11,11 +11,9 @@ from loguru import logger class LLMConfigService: """LLM配置管理服务""" - def __init__(self, db_session: Optional[Session] = None): - self.db = db_session or get_session() # TODO DrGraph:检查异步 - - def get_default_chat_config(self) -> Optional[LLMConfig]: + async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]: """获取默认对话模型配置""" + # async for session in get_session(): try: stmt = select(LLMConfig).where( and_( @@ -24,7 +22,7 @@ class LLMConfigService: LLMConfig.is_active == True ) ) - config = self.db.execute(stmt).scalar_one_or_none() + config = (await session.execute(stmt)).scalar_one_or_none() if not config: logger.warning("未找到默认对话模型配置") @@ -36,7 +34,7 @@ class LLMConfigService: logger.error(f"获取默认对话模型配置失败: {str(e)}") return None - def get_default_embedding_config(self) -> Optional[LLMConfig]: + async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]: """获取默认嵌入模型配置""" try: stmt = select(LLMConfig).where( @@ -46,23 +44,27 @@ class LLMConfigService: LLMConfig.is_active == True ) ) - config = self.db.execute(stmt).scalar_one_or_none() - + config = None + if session != None: + config = (await session.execute(stmt)).scalar_one_or_none() if not config: - logger.warning("未找到默认嵌入模型配置") + if session != None: + session.desc = "ERROR: 未找到默认嵌入模型配置" return None + session.desc = f"获取默认嵌入模型配置 > 结果:{config}" return config except Exception as e: - logger.error(f"获取默认嵌入模型配置失败: {str(e)}") + if session != None: + session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}" return None - def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]: + async def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]: """根据ID获取配置""" try: stmt = select(LLMConfig).where(LLMConfig.id == config_id) - return self.db.execute(stmt).scalar_one_or_none() + return (await self.db.execute(stmt)).scalar_one_or_none() except Exception as e: logger.error(f"获取配置失败: {str(e)}") return 
None @@ -82,17 +84,17 @@ class LLMConfigService: logger.error(f"获取激活配置失败: {str(e)}") return [] - def _get_fallback_chat_config(self) -> Dict[str, Any]: + async def _get_fallback_chat_config(self) -> Dict[str, Any]: """获取fallback对话模型配置(从环境变量)""" from ..core.config import get_settings settings = get_settings() - return settings.llm.get_current_config() + return await settings.llm.get_current_config() - def _get_fallback_embedding_config(self) -> Dict[str, Any]: + async def _get_fallback_embedding_config(self) -> Dict[str, Any]: """获取fallback嵌入模型配置(从环境变量)""" from ..core.config import get_settings settings = get_settings() - return settings.embedding.get_current_config() + return await settings.embedding.get_current_config() def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]: """测试配置连接""" @@ -110,12 +112,12 @@ class LLMConfigService: logger.error(f"测试配置失败: {str(e)}") return {"success": False, "error": str(e)} -# 全局实例 -_llm_config_service = None +# # 全局实例 +# _llm_config_service = None -def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService: - """获取LLM配置服务实例""" - global _llm_config_service - if _llm_config_service is None or db_session is not None: - _llm_config_service = LLMConfigService(db_session) - return _llm_config_service \ No newline at end of file +# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService: +# """获取LLM配置服务实例""" +# global _llm_config_service +# if _llm_config_service is None or db_session is not None: +# _llm_config_service = LLMConfigService(db_session) +# return _llm_config_service \ No newline at end of file diff --git a/backend/th_agenter/services/smart_db_workflow.py b/backend/th_agenter/services/smart_db_workflow.py index 8975b3a..aeb81f1 100644 --- a/backend/th_agenter/services/smart_db_workflow.py +++ b/backend/th_agenter/services/smart_db_workflow.py @@ -49,8 +49,9 @@ class SmartDatabaseWorkflowManager: self.db = db self.table_metadata_service = 
TableMetadataService(db) if db else None - from ..core.llm import create_llm - self.llm = create_llm() + async def initialize(self): + from ..core.new_agent import new_agent + self.llm = await new_agent() def _get_database_tool(self, db_type: str): """根据数据库类型获取对应的数据库工具""" diff --git a/backend/th_agenter/services/smart_excel_workflow.py b/backend/th_agenter/services/smart_excel_workflow.py index b5deb15..8870a3e 100644 --- a/backend/th_agenter/services/smart_excel_workflow.py +++ b/backend/th_agenter/services/smart_excel_workflow.py @@ -54,9 +54,10 @@ class SmartExcelWorkflowManager: else: self.metadata_service = None - from ..core.llm import create_llm + async def initialize(self): + from ..core.new_agent import new_agent # 禁用流式响应,避免pandas代理兼容性问题 - self.llm = create_llm(streaming=False) + self.llm = await new_agent(streaming=False) async def _run_in_executor(self, func, *args): """在线程池中运行阻塞函数""" diff --git a/backend/th_agenter/services/smart_workflow.py b/backend/th_agenter/services/smart_workflow.py index fce1d9b..db4cce8 100644 --- a/backend/th_agenter/services/smart_workflow.py +++ b/backend/th_agenter/services/smart_workflow.py @@ -15,8 +15,12 @@ class SmartWorkflowManager: def __init__(self, db=None): self.db = db - self.excel_workflow = SmartExcelWorkflowManager(db) - self.database_workflow = SmartDatabaseWorkflowManager(db) + + async def initialize(self): + self.excel_workflow = SmartExcelWorkflowManager(self.db) + await self.excel_workflow.initialize() + self.database_workflow = SmartDatabaseWorkflowManager(self.db) + await self.database_workflow.initialize() async def process_excel_query_stream( self, diff --git a/backend/th_agenter/services/table_metadata_service.py b/backend/th_agenter/services/table_metadata_service.py index 45f0af2..84c6187 100644 --- a/backend/th_agenter/services/table_metadata_service.py +++ b/backend/th_agenter/services/table_metadata_service.py @@ -35,7 +35,7 @@ class TableMetadataService: DatabaseConfig.id == database_config_id, 
DatabaseConfig.created_by == user_id ) - db_config = self.session.scalar_one_or_none(stmt) + db_config = (await self.session.execute(stmt)).scalar_one_or_none() if not db_config: self.session.desc = "ERROR: 数据库配置不存在" @@ -202,7 +202,7 @@ class TableMetadataService: TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) - existing = self.session.scalar_one_or_none(stmt) + existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" @@ -216,8 +216,8 @@ class TableMetadataService: existing.table_comment = metadata['table_comment'] existing.last_synced_at = datetime.utcnow() - self.session.commit() - self.session.refresh(existing) + await self.session.commit() + await self.session.refresh(existing) return existing else: self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" @@ -240,8 +240,8 @@ class TableMetadataService: ) self.session.add(table_metadata) - self.session.commit() - self.session.refresh(table_metadata) + await self.session.commit() + await self.session.refresh(table_metadata) self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据" return table_metadata @@ -258,7 +258,7 @@ class TableMetadataService: DatabaseConfig.id == database_config_id, DatabaseConfig.user_id == user_id ) - db_config = self.session.scalar_one_or_none(stmt) + db_config = (await self.session.execute(stmt)).scalar_one_or_none() if not db_config: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在" @@ -275,7 +275,7 @@ class TableMetadataService: TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) - existing = self.session.scalar_one_or_none(stmt) + existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置" @@ -320,7 
+320,7 @@ class TableMetadataService: }) # 提交事务 - self.session.commit() + await self.session.commit() self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_names} 的元数据配置" return { 'saved_tables': saved_tables, @@ -330,7 +330,7 @@ class TableMetadataService: } - def get_user_table_metadata( + async def get_user_table_metadata( self, user_id: int, database_config_id: Optional[int] = None @@ -345,9 +345,9 @@ class TableMetadataService: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在" raise NotFoundError("数据库配置不存在") stmt = stmt.where(TableMetadata.is_enabled_for_qa == True) - return self.session.scalars(stmt).all() + return (await self.session.scalars(stmt)).all() - def get_table_metadata_by_name( + async def get_table_metadata_by_name( self, user_id: int, database_config_id: int, @@ -360,9 +360,9 @@ class TableMetadataService: TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) - return self.session.scalar_one_or_none(stmt) + return (await self.session.execute(stmt)).scalar_one_or_none() - def update_table_qa_settings( + async def update_table_qa_settings( self, user_id: int, metadata_id: int, @@ -375,7 +375,7 @@ class TableMetadataService: TableMetadata.id == metadata_id, TableMetadata.created_by == user_id ) - metadata = self.session.scalar_one_or_none(stmt) + metadata = (await self.session.execute(stmt)).scalar_one_or_none() if not metadata: self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在" @@ -388,15 +388,15 @@ class TableMetadataService: if 'business_context' in settings: metadata.business_context = settings['business_context'] - self.session.commit() + await self.session.commit() return True except Exception as e: self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}" - self.session.rollback() + await self.session.rollback() return False - def save_table_metadata( + async def save_table_metadata( self, 
user_id: int, database_config_id: int, @@ -414,7 +414,7 @@ class TableMetadataService: TableMetadata.database_config_id == database_config_id, TableMetadata.table_name == table_name ) - existing = self.session.scalar_one_or_none(stmt) + existing = (await self.session.execute(stmt)).scalar_one_or_none() if existing: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 已存在,更新其元数据" @@ -424,7 +424,7 @@ class TableMetadataService: existing.row_count = row_count existing.table_comment = table_comment existing.last_synced_at = datetime.utcnow() - self.session.commit() + await self.session.commit() return existing else: self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 不存在,创建新记录" @@ -444,8 +444,8 @@ class TableMetadataService: ) self.session.add(metadata) - self.session.commit() - self.session.refresh(metadata) + await self.session.commit() + await self.session.refresh(metadata) return metadata def _decrypt_password(self, encrypted_password: str) -> str: diff --git a/backend/th_agenter/services/tools/search.py b/backend/th_agenter/services/tools/search.py index 84f8ee5..82384f9 100644 --- a/backend/th_agenter/services/tools/search.py +++ b/backend/th_agenter/services/tools/search.py @@ -7,6 +7,7 @@ from langchain.tools import BaseTool from langchain_community.tools.tavily_search import TavilySearchResults from pydantic import BaseModel, Field, PrivateAttr from typing import Optional, Type, ClassVar +from langchain_tavily import TavilySearch # 定义输入参数模型(替代原get_parameters()) class SearchInput(BaseModel): @@ -35,7 +36,7 @@ class TavilySearchTool(BaseTool): raise ValueError("Tavily API key not found in settings") # 初始化Tavily客户端 - self._search_client = TavilySearchResults( + self._search_client = TavilySearch( tavily_api_key=self._tavily_api_key ) diff --git a/backend/th_agenter/services/user.py b/backend/th_agenter/services/user.py index 3145498..8da17ce 100644 --- a/backend/th_agenter/services/user.py +++ 
b/backend/th_agenter/services/user.py @@ -126,10 +126,10 @@ class UserService: self.session.desc = f"用户ID为{user_id}的密码已成功更改" return True - def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK + async def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK """Reset user password (admin only, no current password required).""" self.session.desc = f"重置用户ID为{user_id}的密码" - user = self.get_user_by_id(user_id) + user = await self.get_user_by_id(user_id) if not user: raise ValidationError("User not found") @@ -141,7 +141,7 @@ class UserService: hashed_password = self.get_password_hash(new_password) # Update password user.hashed_password = hashed_password - self.session.commit() + await self.session.commit() self.session.desc = f"用户ID为{user_id}的密码已成功重置" return True diff --git a/backend/th_agenter/services/workflow_engine.py b/backend/th_agenter/services/workflow_engine.py index a248562..d39492a 100644 --- a/backend/th_agenter/services/workflow_engine.py +++ b/backend/th_agenter/services/workflow_engine.py @@ -42,8 +42,8 @@ class WorkflowEngine: execution.set_audit_fields(user_id) self.session.add(execution) - self.session.commit() - self.session.refresh(execution) + await self.session.commit() + await self.session.refresh(execution) @@ -75,8 +75,8 @@ class WorkflowEngine: execution.set_audit_fields(user_id, is_update=True) - self.session.commit() - self.session.refresh(execution) + await self.session.commit() + await self.session.refresh(execution) return WorkflowExecutionResponse.from_orm(execution) @@ -100,8 +100,8 @@ class WorkflowEngine: execution.set_audit_fields(user_id) self.session.add(execution) - self.session.commit() - self.session.refresh(execution) + await self.session.commit() + await self.session.refresh(execution) # 发送工作流开始执行的消息 yield { @@ -170,11 +170,8 @@ class WorkflowEngine: } execution.set_audit_fields(user_id, is_update=True) - self.session.commit() - self.session.refresh(execution) - - 
self.session.commit() - self.session.refresh(execution) + await self.session.commit() + await self.session.refresh(execution) def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: """构建节点依赖图""" @@ -385,8 +382,8 @@ class WorkflowEngine: started_at=datetime.now().isoformat() ) self.session.add(node_execution) - self.session.commit() - self.session.refresh(node_execution) + await self.session.commit() + await self.session.refresh(node_execution) start_time = time.time() @@ -419,7 +416,7 @@ class WorkflowEngine: } node_execution.input_data = display_input_data - self.session.commit() + await self.session.commit() @@ -446,7 +443,7 @@ class WorkflowEngine: node_execution.completed_at = datetime.now().isoformat() node_execution.duration_ms = int((end_time - start_time) * 1000) - self.session.commit() + await self.session.commit() @@ -459,7 +456,7 @@ class WorkflowEngine: node_execution.error_message = str(e) node_execution.completed_at = datetime.now().isoformat() node_execution.duration_ms = int((end_time - start_time) * 1000) - self.session.commit() + await self.session.commit() @@ -918,8 +915,8 @@ class WorkflowEngine: # 工作流引擎实例 -def get_workflow_engine(session: Session = None) -> WorkflowEngine: +async def get_workflow_engine(session: Session = None) -> WorkflowEngine: """获取工作流引擎实例""" if session is None: - session = next(get_session()) + session = await anext(get_session()) return WorkflowEngine(session) diff --git a/backend/utils/Constant.py b/backend/utils/Constant.py new file mode 100644 index 0000000..4758b3e --- /dev/null +++ b/backend/utils/Constant.py @@ -0,0 +1,490 @@ +from PySide6.QtCore import QPointF, QRectF, QSizeF +# -*- coding: utf-8 -*- +class Constant: + invalid_point = QPointF(-123.4, -234.71) # 无效点 + invalid_rect = QRectF(invalid_point, QSizeF(0, 0)) # 无效矩形 + service_yml_path = 'config/service/dsp_%s_service.yml' # 服务配置路径 + kafka_yml_path = 'config/kafka/dsp_%s_kafka.yml' # kafka配置路径 + 
aliyun_yml_path = "config/aliyun/dsp_%s_aliyun.yml" # 阿里云配置路径 + baidu_yml_path = 'config/baidu/dsp_%s_baidu.yml' # 百度配置路径 + pull_frame_width = 1400 # 拉流帧宽度 + + LLM_MODE_NONE = 0 + LLM_MODE_NONCHAT = 1 + LLM_MODE_CHAT = 2 + LLM_MODE_EMBEDDING = 3 + LLM_MODE_LOCAL_OLLAMA = 4 + + LLM_PROMPT_TEMPLATE_METHOD_FORMAT = "format" + LLM_PROMPT_TEMPLATE_METHOD_INVOKE = "invoke" + + LLM_PROMPT_VALUE_STR = "str" + LLM_PROMPT_VALUE_MESSAGES = "messages" + LLM_PROMPT_VALUE_VALUE = "promptValue" + + INPUT_NONE = 0 + INPUT_PULL_STREAM = 1 + INPUT_NET_CAMERA = 2 + INPUT_USB_CAMERA = 3 + INPUT_LOCAL_VIDEO = 4 + INPUT_LOCAL_DIR = 5 + INPUT_LOCAL_FILE = 6 + INPUT_LOCAL_LABELLED_DIR = 7 + INPUT_LOCAL_LABELLED_ZIP = 8 + + SYSTEM_BUTTON_SNAP = 1 + SYSTEM_BUTTON_CLEAR_LOG = 2 + + MODE_OCR = 7 + + WEB_OWNER_NONE = 0 + WEB_OWNER_ALG_TEST = 1 + WEB_OWNER_LLM = 2 + + PAGE_MODE_NONE = 0 + PAGE_MODE_VIDEO = 1 + PAGE_MODE_LABELLING = 2 + PAGE_MODE_ALG_TEST = 3 + PAGE_MODE_WEB = 4 + PAGE_MODE_LLM = 5 + PAGE_MODE_CONFIG = 6 + + ALG_TEST_LOAD_DATA = 1 + ALG_TEST_TRAIN = 2 + ALG_TEST_INFER = 3 + + ALG_STATUS_NOTHING = 0 # 啥也没干 + ALG_STATUS_MODEL_READY = 1 # 模型已就绪 + ALG_STATUS_DATA_LOADED = 2 # 加载已数据/加载中 + ALG_STATUS_TRAINED = 4 # 训练完成/训练中 + ALG_STATUS_INFER = 8 # 推理完成/推理中 + + UTF_8 = "utf-8" # 编码格式 + + COLOR = ( + [0, 0, 255], + [255, 0, 0], + [211, 0, 148], + [0, 127, 0], + [0, 69, 255], + [0, 255, 0], + [255, 0, 255], + [0, 0, 127], + [127, 0, 255], + [255, 129, 0], + [139, 139, 0], + [255, 255, 0], + [127, 255, 0], + [0, 127, 255], + [0, 255, 127], + [255, 127, 255], + [8, 101, 139], + [171, 130, 255], + [139, 112, 74], + [205, 205, 180]) + + ONLINE = "online" + OFFLINE = "offline" + PHOTO = "photo" + RECORDING = "recording" + + ONLINE_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "pull_url": { + 'type': 'string', + 'required': 
True, + 'empty': False, + 'maxlength': 255 + }, + "push_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "models": { + 'type': 'list', + 'required': True, + 'nullable': False, + 'minlength': 1, + 'maxlength': 3, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + ONLINE_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + OFFLINE_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "push_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "models": { + 'type': 'list', + 'required': True, + 
'maxlength': 3, + 'minlength': 1, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + OFFLINE_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + IMAGE_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "image_urls": { + 'type': 'list', + 'required': True, + 'minlength': 1, + 'schema': { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 5000 + } + }, + "models": { + 'type': 'list', + 'required': True, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 
'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + RECORDING_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "push_url": { + 'type': 'string', + 'required': False, + 'empty': True, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + } + } + + RECORDING_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + PULL2PUSH_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "video_urls": { + 'type': 'list', + 'required': True, + 'nullable': False, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "pull_url", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "push_url", + 'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$' + }, + "push_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': 
"id", + 'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$' + } + } + } + } + } + + PULL2PUSH_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start", "stop"] + }, + "video_ids": { + 'type': 'list', + 'required': False, + 'nullable': True, + 'schema': { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,255}$' + } + } + } diff --git a/backend/utils/Exception.py b/backend/utils/Exception.py new file mode 100644 index 0000000..7a9820b --- /dev/null +++ b/backend/utils/Exception.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +from loguru import logger +from enum import Enum, unique + +class ServiceException(Exception): + def __init__(self, code, msg, desc=None): + self.code = code + if desc is None: + self.msg = msg + else: + self.msg = msg % desc + + def __str__(self): + logger.error("异常编码:{}, 异常描述:{}", self.code, self.msg) + +# 异常枚举 +@unique +class ExceptionType(Enum): + OR_VIDEO_ADDRESS_EXCEPTION = ("SP000", "未拉取到视频流, 请检查拉流地址是否有视频流!") + ANALYSE_TIMEOUT_EXCEPTION = ("SP001", "AI分析超时!") + PULLSTREAM_TIMEOUT_EXCEPTION = ("SP002", "原视频拉流超时!") + READSTREAM_TIMEOUT_EXCEPTION = ("SP003", "原视频读取视频流超时!") + GET_VIDEO_URL_EXCEPTION = ("SP004", "获取视频播放地址失败!") + GET_VIDEO_URL_TIMEOUT_EXCEPTION = ("SP005", "获取原视频播放地址超时!") + PULL_STREAM_URL_EXCEPTION = ("SP006", "拉流地址不能为空!") + PUSH_STREAM_URL_EXCEPTION = ("SP007", "推流地址不能为空!") + PUSH_STREAM_TIME_EXCEPTION = ("SP008", "未生成本地视频地址!") + AI_MODEL_MATCH_EXCEPTION = ("SP009", "未匹配到对应的AI模型!") + ILLEGAL_PARAMETER_FORMAT = ("SP010", "非法参数格式!") + PUSH_STREAMING_CHANNEL_IS_OCCUPIED = ("SP011", "推流通道可能被占用, 请稍后再试!") + VIDEO_RESOLUTION_EXCEPTION = ("SP012", "不支持该分辨率类型的视频,请切换分辨率再试!") + READ_IAMGE_URL_EXCEPTION = ("SP013", "未能解析图片地址!") + DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED = ("SP014", "不支持该类型的检测目标!") + WRITE_STREAM_EXCEPTION = ("SP015", "写流异常!") + 
OR_VIDEO_DO_NOT_EXEIST_EXCEPTION = ("SP016", "原视频不存在!") + MODEL_LOADING_EXCEPTION = ("SP017", "模型加载异常!") + MODEL_ANALYSE_EXCEPTION = ("SP018", "算法模型分析异常!") + AI_MODEL_CONFIG_EXCEPTION = ("SP019", "模型配置不能为空!") + AI_MODEL_GET_CONFIG_EXCEPTION = ("SP020", "获取模型配置异常, 请检查模型配置是否正确!") + MODEL_GROUP_LIMIT_EXCEPTION = ("SP021", "模型组合个数超过限制!") + MODEL_NOT_SUPPORT_VIDEO_EXCEPTION = ("SP022", "%s不支持视频识别!") + MODEL_NOT_SUPPORT_IMAGE_EXCEPTION = ("SP023", "%s不支持图片识别!") + THE_DETECTION_TARGET_CANNOT_BE_EMPTY = ("SP024", "检测目标不能为空!") + URL_ADDRESS_ACCESS_FAILED = ("SP025", "URL地址访问失败, 请检测URL地址是否正确!") + UNIVERSAL_TEXT_RECOGNITION_FAILED = ("SP026", "识别失败!") + COORDINATE_ACQUISITION_FAILED = ("SP027", "飞行坐标识别异常!") + PUSH_STREAM_EXCEPTION = ("SP028", "推流异常!") + MODEL_DUPLICATE_EXCEPTION = ("SP029", "存在重复模型配置!") + DETECTION_TARGET_NOT_SUPPORT = ("SP031", "存在不支持的检测目标!") + TASK_EXCUTE_TIMEOUT = ("SP032", "任务执行超时!") + PUSH_STREAM_URL_IS_NULL = ("SP033", "拉流、推流地址不能为空!") + PULL_STREAM_NUM_LIMIT_EXCEPTION = ("SP034", "转推流数量超过限制!") + NOT_REQUESTID_TASK_EXCEPTION = ("SP993", "未查询到该任务,无法停止任务!") + NO_RESOURCES = ("SP995", "服务器暂无资源可以使用,请稍后30秒后再试!") + NO_CPU_RESOURCES = ("SP996", "暂无CPU资源可以使用,请稍后再试!") + SERVICE_COMMON_EXCEPTION = ("SP997", "公共服务异常!") + NO_GPU_RESOURCES = ("SP998", "暂无GPU资源可以使用,请稍后再试!") + SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常!") diff --git a/backend/utils/Flag.py b/backend/utils/Flag.py new file mode 100644 index 0000000..f737d57 --- /dev/null +++ b/backend/utils/Flag.py @@ -0,0 +1,12 @@ + +class Flag: + Unique = True + Append = False + + Debug = True + +class Option: + NoOption = 0x00 + + AddObject_AutoName = 0x01 + AddObject_Select = 0x02 \ No newline at end of file diff --git a/backend/utils/Helper.py b/backend/utils/Helper.py new file mode 100644 index 0000000..daeff4f --- /dev/null +++ b/backend/utils/Helper.py @@ -0,0 +1,668 @@ +# -*- coding: utf-8 -*- +import sys, os, cv2 +from os import makedirs +from os.path import join, exists +from loguru import logger +from json 
import loads +from ruamel.yaml import safe_load, YAML +import random, sys, math, inspect, psutil +from pathlib import Path +from PySide6.QtGui import QIcon, QColor +from PySide6.QtCore import QObject, QRectF, QEventLoop, QIODevice, QTextStream, QFile +from PySide6.QtWidgets import QApplication + +import DrGraph.utils.vclEnums as enums +#region Property +class Property: + def __init__(self, read_func=None, write_func=None, default=None, hasMember=True): + self.read_func = read_func + self.write_func = write_func + self.default = default + self.owner_class = None + self.private_name = None + self.hasMember = hasMember + def __set_name__(self, owner, name): + if self.hasMember: + self.private_name = f"_{name}" + self.owner_class = owner + def callDirectGet(self, instance, owner): + if instance is None or not self.hasMember: + return self.default + if not hasattr(instance, self.private_name): + return self.default + return getattr(instance, self.private_name) + def callCustomGet(self, instance, owner): + if instance is None: + return self + + if self.read_func is None: + if hasattr(instance, self.private_name): + return getattr(instance, self.private_name) + return self.default + + try: + if isinstance(self.read_func, str): + if hasattr(instance, self.read_func): + method = getattr(instance, self.read_func) + return method() + elif hasattr(self.read_func, '__name__'): + method_name = self.read_func.__name__ + if hasattr(instance, method_name): + method = getattr(instance, method_name) + return method() + elif callable(self.read_func): + try: + return self.read_func(instance) + except (TypeError, AttributeError): + return self.read_func() + except Exception as e: + logger.error(e) + + if hasattr(instance, self.private_name): + return getattr(instance, self.private_name) + return self.default + + def callDirectSet(self, instance, value): + if instance is None or not self.hasMember: + return + setattr(instance, self.private_name, value) + + def callCustomSet(self, 
instance, value): + try: + if isinstance(self.write_func, str): + if hasattr(instance, self.write_func): + method = getattr(instance, self.write_func) + method(value) + elif hasattr(self.write_func, '__name__'): + method_name = self.write_func.__name__ + if hasattr(instance, method_name): + method = getattr(instance, method_name) + method(value) + elif callable(self.write_func): + try: + self.write_func(instance, value) + except (TypeError, AttributeError): + self.write_func(value) + except Exception as e: + pass + +class Property_rw(Property): + def __init__(self, default=None, hasMember=True): + super().__init__(None, None, default, hasMember) + def __get__(self, instance, owner): + return self.callDirectGet(instance, owner) + def __set__(self, instance, value): + self.callDirectSet(instance, value) + +class Property_Rw(Property): + def __init__(self, read_func=None, default=None, hasMember=True): + super().__init__(read_func, None, default, hasMember) + + def __get__(self, instance, owner): + return self.callCustomGet(instance, owner) + + def __set__(self, instance, value): + setattr(instance, self.private_name, value) + +class Property_rW(Property): + def __init__(self, write_func=None, default=None, hasMember=True): + super().__init__(None, write_func, default, hasMember) + + def __get__(self, instance, owner): + return self.callDirectGet(instance, owner) + + def __set__(self, instance, value): + self.callCustomSet(instance, value) + +class Property_RW(Property): + def __init__(self, read_func=None, write_func=None, default=None, hasMember=True): + super().__init__(read_func, write_func, default, hasMember) + + def __get__(self, instance, owner): + return self.callCustomGet(instance, owner) + + def __set__(self, instance, value): + self.callCustomSet(instance, value) +#endregion Property + +class AppHelper(QObject): + app = Property_rw(None) + def setBriefStatusText(self, text): + if self.briefStatusControl: + self.briefStatusControl.setText(text) + else: + 
print(text) + briefStatusText = Property_rW(setBriefStatusText, '') + + def setProgress(self, value): + if self.progressBarControl: + self.progressBarControl.setValue(value) + progress = Property_rW(setProgress, 0) + + def setProgressMax(self, value): + if self.progressBarControl: + self.progressBarControl.setMaximum(value) + progressMax = Property_rW(setProgressMax, 100) + + def setProgressMin(self, value): + if self.progressBarControl: + self.progressBarControl.setMinimum(value) + progressMin = Property_rW(setProgressMin, 0) + + def __init__(self): + self.briefStatusControl = None + self.progressBarControl = None + self._briefStatusText = '' + pass + +class Helper: + OnLogMsg = None + AppFlag_SaveAnalysisResult = True + AppFlag_SaveLog = False + App = None + + @staticmethod + def castRange(value, minValue, maxValue): + return max(minValue, min(maxValue, value)) + # 取得程序目录 + @staticmethod + def getPath_App(): + if getattr(sys, 'frozen', False): + # 如果程序是打包的exe文件 + return os.path.dirname(sys.executable) + else: + # 如果是Python脚本 - 获取上两级目录 + current_file = os.path.abspath(__file__) # f:\PySide6\AiBase\DrGraph\utils\Helper.py + current_dir = os.path.dirname(current_file) # f:\PySide6\AiBase\DrGraph\utils + parent_dir = os.path.dirname(current_dir) # f:\PySide6\AiBase\DrGraph + root_dir = os.path.dirname(parent_dir) # f:\PySide6\AiBase + return root_dir + + @staticmethod + def fitOS(file_name): + if sys.platform.startswith('win'): + file_name = file_name.replace('/','\\') + else: + file_name = file_name.replace('\\', '/') + return file_name + + def generateDistinctColors(n, s=0.8, v=0.7): + import colorsys + colors = [] + for i in range(n): + hue = i * 1.0 / n # 均匀分布在 [0, 1) + r, g, b = colorsys.hsv_to_rgb(hue, s, v) + colors.append(QColor(r * 255, g * 255, b * 255)) + return colors + + def setBriefStatusText(self, text): + Helper.App.setBriefStatusText(text) + briefStatusText = Property_rW(setBriefStatusText, '') + @staticmethod + def Sleep(msec): + 
QApplication.processEvents(QEventLoop.AllEvents, msec) + + + @staticmethod + def getAbsoluteFileName(file_name): + if os.path.isabs(file_name): + return Helper.fitOS(file_name) + else: + return Helper.fitOS(os.path.join(Helper.getPath_App(), file_name)) + @staticmethod + def getConfigs(path, read_type='yml'): + """ + 读取配置文件并返回解析后的配置信息 + + :param path: 配置文件路径 + :param read_type: 配置文件类型,默认为'yml',可选'json'或'yml' + :return: 解析后的配置信息,JSON格式返回字典,YML格式返回对应的数据结构 + :raises Exception: 当无法获取配置信息时抛出异常 + """ + yaml = YAML(typ='safe', pure=True) + with open(path, 'r', encoding='utf-8') as f: + return yaml.load(f) + # with open(path, 'r', encoding="utf-8") as f: + # # 根据文件类型选择相应的解析方式 + # if read_type == 'json': + # return loads(f.read()) + # if read_type == 'yml': + # return safe_load(f) + # 如果未成功读取配置信息,则抛出异常 + raise Exception('路径: %s未获取配置信息' % path) + + @staticmethod + def getTooltipText(content): + # 增加一个小喇叭图标 + # content = f' {content}' + return f""" + + +

DrGraph

+

{content}

+ + +""" + @staticmethod + def log_init(app, base_dir, env): + """ + 初始化日志配置 + + :param base_dir: 基础目录路径,用于定位配置文件和日志文件存储位置 + :param env: 环境标识,用于加载对应环境的日志配置文件 + :return: 无返回值 + """ + Helper.App = AppHelper() + Helper.App.app = app + # QToolTip样式 - 自定义样式 - 增加Header + app.setStyleSheet(""" + QToolTip { + background-color: #dd2222; + color: #f0f0f0; + border: 1px solid #555; + border-radius: 4px; + padding: 6px; + font: 10pt "Segoe UI"; + opacity: 220; + } + """) + + log_config = Helper.getConfigs(join(base_dir, 'appIOs/configs/logger/drgraph_%s_logger.yml' % env)) + # 判断日志文件是否存在,不存在创建 + base_path = join(base_dir, log_config.get("base_path")) + if not exists(base_path): + makedirs(base_path) + # 移除日志设置 + logger.remove(handler_id=None) + # 打印日志到文件 + if bool(log_config.get("enable_file_log")): + logger.add(join(base_path, log_config.get("log_name")), + rotation=log_config.get("rotation"), + retention=log_config.get("retention"), + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True, + encoding=log_config.get("encoding")) + # 控制台输出 + if bool(log_config.get("enable_stderr")): + logger.add(sys.stderr, + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True) + logger.info("\n\n\n----=========== 日志配置初始化完成, 开始新的日志记录 ==========----") + + @staticmethod + def log_info(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'INFO: {msg}', 'black') + caller = inspect.stack()[1] + logger.info(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "INFO", "msg" : msg} ) + @staticmethod + def log_error(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'ERROR: {msg}', 'red') + caller = inspect.stack()[1] + logger.error(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "ERROR", "msg" : msg} ) + @staticmethod + def log_warning(msg, toWss = False): + if 
Helper.OnLogMsg: + Helper.OnLogMsg(f'WARNING: {msg}', (255, 128, 0)) + caller = inspect.stack()[1] + logger.warning(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "WARNING", "msg" : msg} ) + @staticmethod + def log_debug(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'DEBUG: {msg}', (0, 128, 128)) + caller = inspect.stack()[1] + logger.debug(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "DEBUG", "msg" : msg} ) + @staticmethod + def log_critical(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'CRITICAL: {msg}', (128, 0, 128)) + caller = inspect.stack()[1] + logger.critical(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "CRITICAL", "msg" : msg} ) + @staticmethod + def log_exception(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'EXCEPTION: {msg}', (255, 140, 0)) + caller = inspect.stack()[1] + logger.exception(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "EXCEPTION", "msg" : msg} ) + @staticmethod + def log(msg, toWss = False): + caller = inspect.stack()[1] + logger.log(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "LOG", "msg" : msg} ) + @staticmethod + def log_wss(msg): + if Helper.wss: + Helper.wss.send(msg) + + @staticmethod + def getTextSize(font, text): + import pygame as pg + surface = font.render(text, True, (0, 0, 0)) + return (surface.get_width(), surface.get_height(), surface) + + @staticmethod + def buildSurfaces(font, text, width, color, wordWrap): + text = text.strip() + w = Helper.getTextSize(font, text)[0] + result = [] + if w > width and wordWrap: + segLen = math.floor(width / w * 
len(text)) + while len(text): + if len(text) < segLen: + t = text + text = '' + else: + t = text[:segLen] + text = text[segLen:] + result.append(font.render(t, True, color)) + else: + result.append(font.render(text, True, color)) + return result + + @staticmethod + def randomColor(): + '''随机颜色''' + return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + + @staticmethod + def reverseColor(color: QColor): + '''反转颜色''' + return (255 - color.red(), 255 - color.green(), 255 - color.blue()) + @staticmethod + def getRGB(color_value): + # 如果是元组或列表形式的RGB值 + if isinstance(color_value, (tuple, list)): + if len(color_value) >= 3: + # 取前三个值作为RGB + r, g, b = color_value[0], color_value[1], color_value[2] + # 确保值在0-255范围内 + return (max(0, min(255, int(r))), + max(0, min(255, int(g))), + max(0, min(255, int(b)))) + + # 如果是整数形式的颜色值 + elif isinstance(color_value, int): + # 将整数转换为RGB分量 + # 假设格式为0xRRGGBB + r = (color_value >> 16) & 0xFF + g = (color_value >> 8) & 0xFF + b = color_value & 0xFF + return (r, g, b) + + # 如果是字符串形式 + elif isinstance(color_value, str): + # 处理十六进制颜色值 + if color_value.startswith('#'): + hex_value = color_value[1:] + if len(hex_value) == 3: # 简写形式 #RGB + hex_value = ''.join([c*2 for c in hex_value]) + if len(hex_value) in (6, 8): # #RRGGBB 或 #RRGGBBAA + r = int(hex_value[0:2], 16) + g = int(hex_value[2:4], 16) + b = int(hex_value[4:6], 16) + return (r, g, b) + # 处理颜色名称(需要额外的颜色名称映射表) + # 这里只列举几种常见颜色 + color_names = { + 'black': (0, 0, 0), + 'white': (255, 255, 255), + 'red': (255, 0, 0), + 'green': (0, 255, 0), + 'blue': (0, 0, 255), + 'yellow': (255, 255, 0), + 'magenta': (255, 0, 255), + 'cyan': (0, 255, 255), + 'orange': (255, 128, 0), # 根据项目规范 + 'teal': (0, 128, 128) # 根据项目规范 + } + if color_value.lower() in color_names: + return color_names[color_value.lower()] + + # 如果是Color对象(如pygame.Color) + elif hasattr(color_value, 'r') and hasattr(color_value, 'g') and hasattr(color_value, 'b'): + return (color_value.r, color_value.g, 
color_value.b) + + # 默认返回黑色 + return (0, 0, 0) + + @staticmethod + def check_system_resources(): + """检查系统资源使用情况""" + logger.info("检查系统资源使用情况...") + cpu_percent = psutil.cpu_percent(interval=1) + memory = psutil.virtual_memory() + network = psutil.net_io_counters() + + logger.info("检查系统资源使用情况完毕") + + return { + 'cpu_percent': cpu_percent, + 'memory_percent': memory.percent, + 'memory_available': memory.available / (1024**3), # GB + 'network_bytes_sent': int(network.bytes_sent / 1024), + 'network_bytes_recv': int(network.bytes_recv / 1024) + } + + @staticmethod + def build_response(type, status : enums.Response, msg): + status_code, status_msg = status.value + result = { + "type": "response", + "request_type": type, + "status_code": status_code, + "status_msg": status_msg, + "detail_msg": msg + } + if status_code != 0: + Helper.error(result); + return result + + @staticmethod + def get_surrounding_rect(points): + if len(points) == 0: + return Constant.invalid_rect + min_x = min(p.x() for p in points) + min_y = min(p.y() for p in points) + max_x = max(p.x() for p in points) + max_y = max(p.y() for p in points) + return QRectF(min_x, min_y, max_x - min_x, max_y - min_y) + + @staticmethod + def getYoloLabellingInfo(dir_path, file_names, desc): + if len(dir_path) > 0: + imageNumber, labelNumber = 0, 0 + imagePath = dir_path + 'images/' + labelPath = dir_path + 'labels/' + for file_name in file_names: + if file_name.startswith(imagePath): + imageNumber += 1 + elif file_name.startswith(labelPath): + labelNumber += 1 + return f'{desc} {imageNumber - 1} 张图片,{labelNumber - 1} 张标签;', imageNumber - 1 + return f'无{desc};', 0 + + @staticmethod + def getMarkdownRenderText(mdContent): + # 使用Python库直接将Markdown转换为HTML,避免JavaScript依赖 + try: + # 尝试导入markdown库 + import markdown + html_content = markdown.markdown(mdContent) + # 添加基本样式使其美观 + styled_html = f""" + + + + + + + + {html_content} + + + """ + return styled_html + except ImportError: + # 回退到JavaScript的marked.js方法 + 
logger.warning("未找到markdown库,使用JavaScript渲染方式") + # 从appIOs/configs加载marked.js + file_js = QFile('appIOs/configs/marked.min.js') + markedJs = '' + if file_js.open(QIODevice.ReadOnly | QIODevice.Text): + markedJs = file_js.readAll().data().decode('utf-8') + file_js.close() + + # 转义markdown内容 + escapedMd = mdContent.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') + + # 创建HTML模板 + htmlTemplate = ''' + + + + + + + + +
+ + + + ''' + + # 生成HTML内容 + htmlContent = htmlTemplate.replace('%1', markedJs).replace('%2', escapedMd) + return htmlContent + @staticmethod + def getMarkdownRender(mdFileName): + file = QFile(mdFileName) + mdContent = '' + if file.open(QIODevice.ReadOnly | QIODevice.Text): + stream = QTextStream(file) + stream.setAutoDetectUnicode(True) + mdContent = stream.readAll() + file.close() + return Helper.getMarkdownRenderText(mdContent) + else: + logger.error(f'打开文件 {mdFileName} 失败') + return f"

无法打开文件: {mdFileName}

" + +class RTTI: + @staticmethod + def _do_set_attr(obj, property_name, property_value): + if obj is None: + logger.error(f"RTTI.set: obj is None") + return + + class_name = type(obj).__name__ + object_name = obj.objectName() + if property_name not in dir(obj): + logger.error(f"RTTI.set: {class_name} {object_name}.{property_name} not in dir(obj)") + return + if property_name.endswith('icon') and isinstance(property_value, str): + original_property_value = property_value + if not os.path.exists(property_value): + property_value = os.path.join('appIOs/res/images/icons',property_value) + # logger.info(f"RTTI.set: {class_name} {object_name}.{property_name} = {property_value}(自动匹配)") + if not os.path.exists(property_value): + logger.error(f"{original_property_value}文件不存在 > RTTI.set: {class_name} {object_name}.{property_name} = '{original_property_value}'") + return + property_value = QIcon(property_value) + setter_method = getattr(obj, f'set{property_name[0].upper() + property_name[1:]}') + setter_method(property_value) + @staticmethod + def set(obj, property_name, property_value): + property_list = property_name.split('.') + if len(property_list) == 1: + RTTI._do_set_attr(obj, property_name, property_value) + else: + dest_obj = obj + for i in range(len(property_list) - 1): + if not dest_obj: + logger.error(f"RTTI.set: {property_list.join('.')} not found") + return + dest_obj = getattr(dest_obj, property_list[i]) + RTTI._do_set_attr(dest_obj, property_list[-1], property_value) + + @staticmethod + def _do_get_attr(obj, property_name): + if obj is None: + logger.error(f"RTTI.get: obj is None") + return None, None + if property_name not in dir(obj): + logger.error(f"RTTI.get: {type(obj).__name__} {obj.objectName()}.{property_name} not in dir(obj)") + return None, None + # 返回属性类型与属性值 + type_name = type(getattr(obj, property_name)).__name__ + value = getattr(obj, property_name) + return type_name, value + # 取得属性类型与属性值 type, value = RTTI.get(obj, property_name) + 
@staticmethod + def get(obj, property_name): + property_list = property_name.split('.') + if len(property_list) == 1: + return RTTI._do_get_attr(obj, property_name) + else: + dest_obj = obj + for i in range(len(property_list) - 1): + if not dest_obj: + logger.error(f"RTTI.get: {property_list.join('.')} not found") + return None + dest_obj = getattr(dest_obj, property_list[i]) + return RTTI._do_get_attr(dest_obj, property_list[-1]) + +class DrawHelper: + @staticmethod + def draw_dashed_line(mat, pt1, pt2, color, thickness=1, dash_length=10): + dist = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5 + dashes = int(dist / dash_length) + for i in range(dashes): + start = (int(pt1[0] + (pt2[0] - pt1[0]) * i / dashes), int(pt1[1] + (pt2[1] - pt1[1]) * i / dashes)) + end = (int(pt1[0] + (pt2[0] - pt1[0]) * (i + 0.5) / dashes), int(pt1[1] + (pt2[1] - pt1[1]) * (i + 0.5) / dashes)) + cv2.line(mat, start, end, color, thickness) + + @staticmethod + def draw_dashed_rect(painter, rect, color, thickness=1, dash_length=10): + x1, y1 = rect.left(), rect.top() + x2, y2 = rect.right(), rect.bottom() + DrawHelper.draw_dashed_line(painter, (x1, y1), (x2, y1), color, thickness, dash_length) + DrawHelper.draw_dashed_line(painter, (x1, y2), (x2, y2), color, thickness, dash_length) + DrawHelper.draw_dashed_line(painter, (x1, y1), (x1, y2), color, thickness, dash_length) + DrawHelper.draw_dashed_line(painter, (x2, y1), (x2, y2), color, thickness, dash_length) \ No newline at end of file diff --git a/backend/utils/YOLOTracker.py b/backend/utils/YOLOTracker.py new file mode 100644 index 0000000..7301a48 --- /dev/null +++ b/backend/utils/YOLOTracker.py @@ -0,0 +1,466 @@ +from loguru import logger +import subprocess as sp +from ultralytics import YOLO +import time, cv2, numpy as np, math +from traceback import format_exc +from DrGraph.utils.pull_push import NetStream +from DrGraph.utils.Helper import * +from DrGraph.utils.Constant import Constant +from zipfile import ZipFile + +class 
class YOLOTracker:
    """Wraps an Ultralytics YOLO model for per-frame detection + tracking."""

    def __init__(self, model_path):
        # Load the model once; track state persists across calls (persist=True).
        self.model = YOLO(model_path)
        self.tracking_config = {
            "tracker": "appIOs/configs/yolo11/bytetrack.yaml",
            "conf": 0.25,
            "iou": 0.45,
            "persist": True,
            "verbose": False
        }
        self.frame_count = 0
        self.processing_time = 0   # last frame's processing time, ms

    def process_frame(self, frame):
        """Run detection/tracking on one frame.

        Returns (annotated_frame, result); on failure returns (frame, None).
        """
        start_time = time.time()
        try:
            results = self.model.track(source=frame, **self.tracking_config)
            result = results[0]                 # single image -> single result
            processed_frame = result.plot()     # draw boxes/ids on a copy
            self.processing_time = (time.time() - start_time) * 1000  # ms
            self.frame_count += 1
            if self.frame_count % 100 == 0:     # periodic stats only
                self._print_detection_info(result)
            return processed_frame, result
        except Exception:
            logger.error("YOLO处理异常: {}", format_exc())
            return frame, None

    def _print_detection_info(self, result):
        """Log detection count, unique track ids and processing time."""
        boxes = result.boxes
        if boxes is not None and len(boxes) > 0:
            detection_count = len(boxes)
            unique_ids = {int(box.id[0]) for box in boxes if box.id is not None}
            logger.info(f"帧 {self.frame_count}: 检测到 {detection_count} 个目标, 追踪ID数: {len(unique_ids)}, 处理时间: {self.processing_time:.2f}ms")
        else:
            logger.info(f"帧 {self.frame_count}: 未检测到目标, 处理时间: {self.processing_time:.2f}ms")

class YOLOTrackerManager:
    """Owns a YOLOTracker plus exactly one active video source.

    Sources: single image file, image directory, labelled zip archive,
    USB camera, local video file, or a pulled network stream.
    """

    def __init__(self, model_path, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.tracker = YOLOTracker(model_path)
        self.stream = None           # NetStream when pulling over the network
        self.videoStream = None      # cv2.VideoCapture for camera/local video
        self.videoType = Constant.INPUT_NONE
        self.localFile = ''
        self.localPath = ''
        self.localFiles = []
        self._currentFrame = None
        self.totalFrames = 0
        self.frameChanged = False

    def _stop(self):
        # Release whichever source is active and reset all per-source state.
        if self.videoStream is not None:
            self.videoStream.release()
            self.videoStream = None
        if self.stream is not None:
            self.stream.clear_pull_p(self.stream.pull_p, self.request_id)
            self.stream = None
        self.localFile = ''
        self.localPath = ''
        self.localFiles = []
        self._currentFrame = None
        self.totalFrames = 0
        self._frameIndex = -1
        self.videoType = Constant.INPUT_NONE
        self.frameChanged = True

    def startLocalFile(self, fileName):
        """Use a single local image file as the source."""
        self._stop()
        self.localFile = fileName
        self._frameIndex = -1

    def startLocalDir(self, dirName):
        """Use a directory of images (sorted by name) as the source."""
        self._stop()
        self.localPath = dirName
        self.localFiles = [os.path.join(dirName, f) for f in os.listdir(dirName)
                           if f.endswith(('.jpg', '.jpeg', '.png'))]
        self.totalFrames = len(self.localFiles)
        Helper.App.progressMax = self.totalFrames
        self.localFiles.sort()
        logger.info("本地目录打开: {}, 总帧数: {}", dirName, self.totalFrames)
        self._frameIndex = 0

    def startLabelledZip(self, labelledPath, categoryPath):
        """Use the image entries of a YOLO-labelled zip archive as the source."""
        self._stop()
        self.localPath = labelledPath
        localFiles = ZipFile(labelledPath).namelist()
        _, self.totalFrames = Helper.getYoloLabellingInfo(categoryPath, localFiles, '')
        imagePath = categoryPath + 'images/'
        self.localFiles = [file for file in localFiles if imagePath in file]
        logger.info(f"标注压缩文件{labelledPath}的{categoryPath}集共有{self.totalFrames}帧, 有效帧数: {len(self.localFiles)}")
        self._frameIndex = 0
        Helper.App.progressMax = self.totalFrames

    def startUsbCamera(self, index=0):
        """Open a USB camera by device index."""
        self._stop()
        self.videoStream = cv2.VideoCapture(index)
        self.videoType = Constant.INPUT_USB_CAMERA
        Helper.Sleep(200)   # give the device time to come up before probing
        if not self.videoStream.isOpened():
            logger.error("无法打开USB摄像头: {}", index)
            self.videoType = Constant.INPUT_NONE
            return
        self.totalFrames = 0x7FFFFFFF   # live stream: effectively unbounded

    def startLocalVideo(self, fileName):
        """Open a local video file as the source."""
        self._stop()
        self.videoStream = cv2.VideoCapture(fileName)
        self.videoType = Constant.INPUT_LOCAL_VIDEO
        Helper.Sleep(200)
        if not self.videoStream.isOpened():
            logger.error("无法打开本地视频流: {}", fileName)
            self.videoType = Constant.INPUT_NONE
            return
        try:
            total = int(self.videoStream.get(cv2.CAP_PROP_FRAME_COUNT))
        except Exception:
            total = 0
        # int() never yields None, so the original `if total is not None` was dead.
        self.totalFrames = total
        Helper.App.progressMax = self.totalFrames
        logger.info("本地视频打开: {}, 总帧数: {}", fileName, self.totalFrames)

    def startPull(self, url=''):
        """Start pulling a network stream (uses self.pull_url unless overridden)."""
        self._stop()
        if len(url) > 0:
            self.pull_url = url
        logger.info("拉流地址: {}", self.pull_url)
        self.stream = NetStream(self.pull_url, self.push_url, self.request_id)
        self.stream.prepare_pull()

    def getCurrentFrame(self):
        """Return a copy of the current frame, lazily fetching the first one."""
        if self._currentFrame is None:
            self._currentFrame = self.nextFrame()
        if self._currentFrame is not None:
            return self._currentFrame.copy()
        return None
    # NOTE(review): Property_Rw / Property_rW appear to be project-specific
    # read-only / write-only property wrappers — confirm the casing is intended.
    currentFrame = Property_Rw(getCurrentFrame, None)

    def setFrameIndex(self, index):
        """Seek to a frame index (local video or image list sources only)."""
        if self.videoStream is None and len(self.localFiles) == 0:
            return
        if self.videoStream is not None and self.videoType != Constant.INPUT_LOCAL_VIDEO:
            return
        index = max(0, min(index, self.totalFrames - 1))   # clamp to valid range
        if self.videoStream:
            self.videoStream.set(cv2.CAP_PROP_POS_FRAMES, index)
        self._frameIndex = index - 1          # nextFrame() advances back to index
        self._currentFrame = self.nextFrame()
        self.frameChanged = True
    frameIndex = Property_rW(setFrameIndex, 0)

    def getLabels(self):
        """Return the label-file text for the current zip frame, '' on failure.

        FIX: the original had an unreachable `return ''` after `return content`
        and let read errors propagate; errors now fall back to '' as intended.
        """
        try:
            with ZipFile(self.localPath, 'r') as zip_ref:
                return zip_ref.read(self.localFile).decode('utf-8')
        except Exception:
            logger.error("读取标注失败: {}", format_exc())
            return ''

    def getAnalysisFrame(self, nextFlag):
        """Return (frame_copy_or_None, changed_flag); advances when nextFlag."""
        frameChanged = self.frameChanged
        self.frameChanged = False
        if nextFlag:  # streaming sources: always pull the next frame
            self._currentFrame = self.nextFrame()
            # NOTE(review): the pre-advance flag is returned while the new
            # change is deferred to the *next* call — confirm this is intended.
            self.frameChanged = True
        frame = self.currentFrame
        return frame.copy() if frame is not None else None, frameChanged

    def nextFrame(self):
        """Fetch the next frame (converted to RGB) from the active source."""
        frame = None
        if self.stream:
            frame = self.stream.next_pull_frame()
        elif self.videoStream:
            ret, frame = self.videoStream.read()
            self._frameIndex += 1
            if not ret:
                self._frameIndex -= 1
                frame = None
        elif len(self.localFiles) > 0:
            if self.localPath.endswith('.zip'):
                # FIX: the counter used to start at -1, so frame 0 mapped to
                # the *second* image entry in the archive (off-by-one).
                index = 0
                for img_file in self.localFiles:
                    if '/images/' not in img_file:
                        continue
                    if index == self._frameIndex:
                        try:
                            with ZipFile(self.localPath, 'r') as zip_ref:
                                image_data = zip_ref.read(img_file)
                            nparr = np.frombuffer(image_data, np.uint8)
                            frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                            self._frameIndex += 1
                            # Remember the matching label entry for getLabels().
                            label_file = img_file.replace('/images/', '/labels/') \
                                                 .replace('.jpg', '.txt').replace('.png', '.txt')
                            self.localFile = label_file
                        except Exception:
                            frame = None
                        break
                    index += 1
            else:
                # Wrap around when the index runs off either end of the list.
                if self._frameIndex < 0 or self._frameIndex >= len(self.localFiles):
                    self._frameIndex = 0
                frame = cv2.imread(self.localFiles[self._frameIndex])
                if frame is None:
                    logger.error(f"无法读取目标目录 {self.localPath}中下标为 {self._frameIndex} 的视频文件 {self.localFiles[self._frameIndex]}")
                    self._frameIndex = -1
                    return
                self._frameIndex += 1
        elif self.localFile:
            frame = cv2.imread(self.localFile)
            if frame is None:
                logger.error("无法读取本地视频文件: {}", self.localFile)
                return
        if frame is not None:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.totalFrames > 0:
            Helper.App.progress = self._frameIndex
        return frame

    def test_yolo11_recognize(self, frame):
        """Convenience wrapper: run YOLO on one frame with this manager's id."""
        return self.process_frame_with_yolo(frame, self.request_id)

    def process_frame_with_yolo(self, frame, requestId):
        """Run YOLO on a frame and overlay FPS / object-count text."""
        try:
            processed_frame, detection_result = self.tracker.process_frame(frame)
            fps_info = f"FPS: {1000/max(self.tracker.processing_time, 1):.1f}"
            cv2.putText(processed_frame, fps_info, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            if detection_result and detection_result.boxes is not None:
                obj_count = len(detection_result.boxes)
                count_info = f"Objects: {obj_count}"
                cv2.putText(processed_frame, count_info, (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            return processed_frame
        except Exception:
            logger.error("YOLO处理异常:{}, requestId:{}", format_exc(), requestId)
            return frame   # fall back to the unprocessed frame

    def get_gray_mask(self, frame):
        """Return a uint8 mask (255) of near-gray pixels.

        Gray = all pairwise channel differences < 20; bluish shadow pixels
        (b > r and b - r < 40) are also accepted on the r/b comparison.
        """
        maskMat = np.zeros(frame.shape[:2], dtype=np.uint8)
        b, g, r = cv2.split(frame)
        # int16 so channel subtraction cannot wrap around uint8.
        r = r.astype(np.int16)
        g = g.astype(np.int16)
        b = b.astype(np.int16)
        diff_rg = np.abs(r - g)
        diff_rb = np.abs(r - b)
        diff_gb = np.abs(g - b)
        is_shadow = (b > r) & (b - r < 40)
        # FIX: the original wrote `diff_rb < 20| is_shadow`, which parses as
        # `diff_rb < (20 | is_shadow)` because `|` binds tighter than `<` —
        # the shadow condition must OR the whole r/b comparison.
        gray_pixels = (diff_rg < 20) & ((diff_rb < 20) | is_shadow) & (diff_gb < 20)
        maskMat[gray_pixels] = 255
        return maskMat

    def debugLine(self, line, y_intersect):
        """Return (angle_deg, length, x_intersect) for a segment (x1,y1,x2,y2).

        angle_deg is the angle against the horizontal, normalized to [0, 180);
        x_intersect is where the infinite line crosses y == y_intersect.
        """
        x1, y1, x2, y2 = line
        length = np.linalg.norm([x2 - x1, y2 - y1])
        angle_deg = math.degrees(math.atan2(y2 - y1, x2 - x1))
        if angle_deg < 0:
            angle_deg += 180
        x_intersect = (x2 - x1) * (y_intersect - y1) / (y2 - y1) + x1
        return angle_deg, length, x_intersect

    def test_highway_recognize(self, frame, debugFlag=False):
        """Segment gray road surface; returns (overlay, color_mask, sobel_rgb).

        FIX: the exception path used to return a single frame while the
        success path returns a 3-tuple, so callers unpacking three values
        crashed on failure — it now returns a consistent 3-tuple.
        """
        processed_frame = frame.copy()
        try:
            IGNORE_HEIGHT = 100
            frame[:IGNORE_HEIGHT, :] = (255, 0, 0)   # blank out the sky band

            gray_mask = self.get_gray_mask(frame)

            # Erode twice to remove small noise specks.
            kernel = np.ones((5, 5), np.uint8)
            gray_mask = cv2.erode(gray_mask, kernel)
            gray_mask = cv2.erode(gray_mask, kernel)

            # Keep only connected regions with area >= 10000.
            contours, _ = cv2.findContours(gray_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            filtered_mask = np.zeros_like(gray_mask)
            for contour in contours:
                if cv2.contourArea(contour) >= 10000:
                    cv2.fillPoly(filtered_mask, [contour], 255)
            gray_mask = filtered_mask

            # Sobel-x edges of the masked road region (white lane markings).
            whiteLineMat = cv2.bitwise_and(processed_frame, processed_frame, mask=filtered_mask)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_RGB2GRAY)
            whiteLineMat = cv2.Sobel(whiteLineMat, cv2.CV_8U, 1, 0, ksize=3)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB)

            # Green overlay marking the detected road surface.
            color_mask = cv2.cvtColor(gray_mask, cv2.COLOR_GRAY2RGB)
            color_mask[:] = (0, 255, 0)
            color_mask = cv2.bitwise_and(color_mask, color_mask, mask=filtered_mask)
            overlay = cv2.addWeighted(processed_frame, 0.7, color_mask, 0.3, 0)

            return overlay, color_mask, whiteLineMat
        except Exception:
            logger.error("路面识别异常:{}", format_exc())
            return processed_frame, processed_frame, processed_frame
class NetStream:
    """Pulls H.264 video through an external ffmpeg process as raw NV12 frames."""

    def __init__(self, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.pull_p = None   # ffmpeg subprocess handle

        # Expected raw frame geometry: 1920x1080 NV12 (12 bits/pixel).
        self.width = 1920
        self.height = 1080 * 3 // 2
        self.width_height_3 = 1920 * 1080 * 3 // 2   # bytes per NV12 frame
        self.w_2 = 960
        self.h_2 = 540

        self.frame_count = 0
        self.start_time = time.time()

    def clear_pull_p(self, pull_p, requestId):
        """Close the ffmpeg pull pipe gracefully; kill it if that fails."""
        try:
            if pull_p and pull_p.poll() is None:
                logger.info("关闭拉流管道, requestId:{}", requestId)
                if pull_p.stdout:
                    pull_p.stdout.close()
                pull_p.terminate()
                pull_p.wait(timeout=30)
                logger.info("拉流管道已关闭, requestId:{}", requestId)
        except Exception as e:
            logger.error("关闭拉流管道异常: {}, requestId:{}", format_exc(), requestId)
            if pull_p and pull_p.poll() is None:
                pull_p.kill()
                pull_p.wait(timeout=30)
            raise e

    def start_pull_p(self, pull_url, requestId):
        """Spawn ffmpeg reading pull_url and writing raw 25fps frames to stdout."""
        try:
            # NOTE(review): hard-coded Windows ffmpeg path — consider configuration.
            command = ['D:/DrGraph/DSP/ffmpeg.exe']
            command.extend(['-re',
                            '-y',
                            '-an',
                            '-c:v', 'h264_cuvid',   # NVIDIA hardware decode
                            '-i', pull_url,
                            '-f', 'rawvideo',
                            '-r', '25',
                            '-'])
            self.pull_p = sp.Popen(command, stdout=sp.PIPE)
            return self.pull_p
        except ServiceException as s:
            logger.error("构建拉流管道ServiceException异常: url={}, {}, requestId:{}", pull_url, s.msg, requestId)
            raise s
        except Exception as e:
            logger.error("构建拉流管道Exception异常:url={}, {}, requestId:{}", pull_url, format_exc(), requestId)
            raise e

    def pull_read_video_stream(self):
        """Read one raw NV12 frame from ffmpeg and convert it to BGR.

        Returns None on EOF or error. ServiceException is re-raised after the
        pipe has been torn down; other errors just tear the pipe down.
        """
        result = None
        try:
            if self.pull_p is None:
                self.start_pull_p(self.pull_url, self.request_id)
            in_bytes = self.pull_p.stdout.read(self.width_height_3)
            if in_bytes is not None and len(in_bytes) > 0:
                try:
                    result = np.frombuffer(in_bytes, np.uint8).reshape((self.height, self.width))
                    result = cv2.cvtColor(result, cv2.COLOR_YUV2BGR_NV12)
                    if result.shape[1] > Constant.pull_frame_width:
                        # Halve oversized frames to bound downstream work.
                        result = cv2.resize(result, (result.shape[1] // 2, result.shape[0] // 2),
                                            interpolation=cv2.INTER_LINEAR)
                except Exception:
                    # FIX: original logged with an undefined name `requestId`.
                    logger.error("视频格式异常:{}, requestId:{}", format_exc(), self.request_id)
                    # NOTE(review): ExceptionType is not imported in this module —
                    # confirm it should come from DrGraph.utils.Exception.
                    raise ServiceException(ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[0],
                                           ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[1])
        except ServiceException as s:
            logger.error("ServiceException 读流异常: {}, requestId:{}", s.msg, self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            result = None
            raise s
        except Exception:
            logger.error("Exception 读流异常:{}, requestId:{}", format_exc(), self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            result = None
            # FIX: original nulled width/height/width_height_3 here, which would
            # crash any later reconnect on read(None); keep the geometry intact.
        return result

    def prepare_pull(self):
        """Ensure the ffmpeg pull process is running."""
        if self.pull_p is None:
            self.start_time = time.time()
            self.start_pull_p(self.pull_url, self.request_id)

    def next_pull_frame(self):
        """Return the next decoded frame, or None when the pipe is not up."""
        if self.pull_p is None:
            logger.error(f'pull_p is None, requestId: {self.request_id}')
            return None
        return self.pull_read_video_stream()
BaseModel) for item in response): + data_dict = [item.model_dump(mode='json') for item in response] + else: + data_dict = response + elif isinstance(response, BaseModel): + # 处理单个BaseModel对象 data_dict = response.model_dump(mode='json') else: + # 处理字典或其他可JSON序列化的数据 data_dict = response + if 'success' in data_dict and data_dict['success'] == False: + code = -1 + content = { - "code": 0, + "code": code, "status": status.HTTP_200_OK, "data": data_dict, "error": None, @@ -50,13 +71,24 @@ class HxfErrorResponse(JSONResponse): def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED): """Return a JSON error response.""" if isinstance(message, Exception): - content = { - "code": -1, - "status": message.status_code, - "data": None, - "error": message.details, - "message": message.message - } + msg = message.message if 'message' in message.__dict__ else str(message) + logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}") + if isinstance(message, TypeError): + content = { + "code": -1, + "status": 500, + "data": None, + "error": f"错误类型: {type(message)} 错误信息: {message}", + "message": msg + } + else: + content = { + "code": -1, + "status": message.status_code, + "data": None, + "error": None, + "message": msg + } else: content = { "code": -1, diff --git a/backend/utils/util_schemas.py b/backend/utils/util_schemas.py index ab0428f..a1809f5 100644 --- a/backend/utils/util_schemas.py +++ b/backend/utils/util_schemas.py @@ -310,6 +310,7 @@ class ConversationResponse(BaseResponse, ConversationBase): is_archived: bool message_count: int = 0 last_message_at: Optional[datetime] = None + messages: Optional[List["MessageResponse"]] = None # Message schemas @@ -345,11 +346,11 @@ class ChatRequest(BaseModel): message: str = Field(..., min_length=1, max_length=10000) stream: bool = Field(default=False) use_knowledge_base: bool = Field(default=False) - knowledge_base_id: Optional[int] = Field(None, description="Knowledge base ID for 
from enum import Enum, unique

# Background palette; paired by index with CommonColors_Foregnd so each
# background has a foreground text color with sufficient contrast.
CommonColors_Backgnd = [
    "#FF0000",  # red
    "#00FF00",  # green
    "#0000FF",  # blue
    "#FFFF00",  # yellow
    "#FF00FF",  # magenta
    "#00FFFF",  # cyan
    "#FFA500",  # orange
    "#800080",  # purple
    "#FFC0CB",  # pink
    "#008000",  # dark green
    "#000080",  # navy
    "#800000",  # maroon
    "#808000",  # olive
    "#008080",  # teal
    "#808080",  # gray
    "#FF0080",  # rose
    "#0080FF",  # sky blue
    "#FF8000",  # orange-red
    "#8000FF",  # violet
    "#00FF80"   # sea green
]

# Foreground (text) colors matched index-by-index to CommonColors_Backgnd.
CommonColors_Foregnd = [
    "#FFFFFF",  # red background -> white text
    "#000000",  # green background -> black text
    "#FFFFFF",  # blue background -> white text
    "#000000",  # yellow background -> black text
    "#FFFFFF",  # magenta background -> white text
    "#000000",  # cyan background -> black text
    "#000000",  # orange background -> black text
    "#FFFFFF",  # purple background -> white text
    "#000000",  # pink background -> black text
    "#FFFFFF",  # dark green background -> white text
    "#FFFFFF",  # navy background -> white text
    "#FFFFFF",  # maroon background -> white text
    "#FFFFFF",  # olive background -> white text
    "#FFFFFF",  # teal background -> white text
    "#FFFFFF",  # gray background -> white text
    "#FFFFFF",  # rose background -> white text
    "#000000",  # sky blue background -> black text
    "#000000",  # orange-red background -> black text
    "#FFFFFF",  # violet background -> white text
    "#000000"   # sea green background -> black text
]

@unique
class LabellingKind(Enum):
    """Current labelling interaction mode."""
    Unknown = 0
    Select = 1
    Create = 2
    TempDrag = 3
    TempResize = 4
    TempRotate = 5

@unique
class Meta(Enum):
    """Primitive shape kinds used by the labelling canvas."""
    Unknown = 0
    Line = 1
    Rectangle = 2
    Ellipse = 3
    Polygon = 4
    Text = 5

@unique
class AiAlg(Enum):
    """Available AI algorithms.

    FIX: the original had trailing commas (e.g. `FashionMNIST = 1,`), which
    made three member values tuples like (1,) while Coco8 stayed an int —
    comparisons against integer values silently failed for those members.
    """
    Unknown = 0
    FashionMNIST = 1
    ColorDetector = 2
    Face = 3
    Coco8 = 4

@unique
class Response(Enum):
    """Response severity codes; value is (code, human-readable description)."""
    OK = (0, "OK - 响应成功")
    DEBUG = (1, "DEBUG - 正常调试")
    WARNING = (2, "FAIL - 响应警告")
    ERROR = (3, "ERROR - 响应错误")
    EXCEPTION = (4, "EXCEPTION - 响应异常")
    CRITICAL = (5, "CRITICAL - 响应严重错误")

class Flag(Enum):
    # Boolean-valued flags; Paint/Client share names with Align members but
    # live in a separate enum on purpose.
    Paint = True
    Client = False

class Align(Enum):
    """Edge alignment of a control inside its parent."""
    Left = 0
    Top = 1
    Right = 2
    Bottom = 3
    Client = 4
    Custom = 5

class Anchor(Enum):
    # Power-of-two values so anchors can be combined as a bitmask.
    Left = 1
    Top = 2
    Right = 4
    Bottom = 8

class IconPos(Enum):
    """Placement of an icon relative to a control's caption."""
    NoIcon = 0
    OnlyIcon = 1
    IconLeft = 2
    IconRight = 3
    IconTop = 4
    IconBottom = 5

class ControlEvent(Enum):
    """Mouse/interaction events dispatched to controls."""
    MouseEnter = 0
    MouseLeave = 1
    MouseMove = 2
    MouseDown = 3
    MouseUp = 4
    Click = 5
    DblClick = 6
class WebSocketServer:
    """WebSocket server: connection management, typed-message dispatch, replies.

    Text replies are sent with a 't' prefix and binary replies with a 'b'
    prefix so the client can demultiplex the two channels on one socket.
    """

    def __init__(self, host: str = "localhost", port: int = 8765):
        self.host = host
        self.port = port
        # message type -> async handler(websocket, payload)
        self.message_handlers: Dict[str, Callable] = {}
        self.server = None
        self.onClientConnect = None   # optional async callback(websocket, connected: bool)
        self.register_handler("file", self.handle_file)

    def register_handler(self, message_type: str, handler: Callable):
        """Register the async handler invoked for messages of message_type."""
        self.message_handlers[message_type] = handler

    # Parameter annotations are strings so this class can be imported without
    # the websockets package being available at definition time.
    async def response_text(self, websocket: "websockets.WebSocketServerProtocol", message, t='default'):
        """Send a text ('t'-prefixed) reply; dict payloads are JSON-encoded."""
        if isinstance(message, dict):
            message = json.dumps(message, ensure_ascii=False)
        if Helper.AppFlag_SaveLog:
            logger.warning(f"发送消息: {message}")
        await websocket.send(f't{message}')
        caller = inspect.stack()[1]
        logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")

    async def response_binary(self, websocket: "websockets.WebSocketServerProtocol", data):
        """Send a binary ('b'-prefixed) reply."""
        await websocket.send(bytearray(b'b' + data))

    async def handle_message(self, websocket: "websockets.WebSocketServerProtocol", message: str):
        """Decode one JSON text message and dispatch it by its 'type' field."""
        try:
            data = json.loads(message)
            message_type = data.get("type")
            payload = data.get("data")
            if Helper.AppFlag_SaveLog:
                logger.warning(f"收到消息: {message}")

            if message_type in self.message_handlers:
                response = await self.message_handlers[message_type](websocket, payload)
                # A None response means the handler replied itself (or is fire-and-forget).
                if response is not None:
                    if isinstance(response, (bytes, bytearray, np.ndarray)):
                        await self.response_binary(websocket, response)
                    else:
                        await self.response_text(websocket, response, message_type)
            elif message_type == 'echo':
                print("echo message ", int(time.time() * 1000))
                data["type"] = "echo_response"
                await self.response_text(websocket, data)
            else:
                logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
                await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
        except json.JSONDecodeError:
            logger.error(f"无效的JSON格式 - {message}")
            await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
        except BrokenPipeError as e:
            # Connection already gone; cleanup happens in handle_client's
            # ConnectionClosed handling, nothing to do here.
            logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
        except Exception:
            logger.error(f"处理消息时出错: {format_exc()}")
            await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {format_exc()}"))

    async def handle_client(self, websocket: "websockets.WebSocketServerProtocol", path: str = ""):
        """Serve one client connection until it closes."""
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接")
        if self.onClientConnect:
            await self.onClientConnect(websocket, True)
        try:
            async for message in websocket:
                await self.handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接已关闭")
        except BrokenPipeError:
            logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接BrokenPipeError")
        finally:
            if self.onClientConnect:
                await self.onClientConnect(websocket, False)

    async def start(self):
        """Start serving on (host, port)."""
        self.server = await websockets.serve(self.handle_client, self.host, self.port)
        logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")

    async def stop(self):
        """Stop the server and wait for it to close."""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            logger.info("WebSocket服务器已停止")

    async def handle_file(self, websocket: "websockets.WebSocketServerProtocol", payload: dict):
        """Handle 'file' messages; currently only the 'dir' listing command.

        FIX: payload was annotated as str but is indexed like a dict.
        NOTE(review): the listing is computed and logged but never sent back
        to the client — confirm whether the reply message is still TODO.
        """
        logger.info(f"接收文件: {payload}, {payload['command']}")
        command = payload['command']
        if command == 'dir':
            path = payload.get('path', '/')
            files = []
            folders = []
            try:
                if os.path.exists(path) and os.path.isdir(path):
                    with os.scandir(path) as entries:
                        for entry in entries:
                            if entry.is_file():
                                files.append(entry.name)
                            elif entry.is_dir():
                                folders.append(entry.name)
                else:
                    logger.warning(f"路径不存在或不是目录: {path}")
            except Exception as e:
                logger.error(f"读取目录时出错: {path}, 错误: {e}")

            # Folders first, then files, matching typical file-browser order.
            all_items = folders + files
            logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")