Compare commits

...

2 Commits
jcq ... main

96 changed files with 5417 additions and 2020 deletions

View File

@ -1,4 +1 @@
Generic single-database configuration with an async dbapi.
alembic revision --autogenerate -m "init"
alembic upgrade head

View File

@ -0,0 +1,359 @@
"""Initial migration
Revision ID: 424646027786
Revises:
Create Date: 2025-12-16 09:56:45.172954
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '424646027786'  # this migration's unique ID
down_revision: Union[str, Sequence[str], None] = None  # None: this is the base (first) migration
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Upgrade schema.

    Creates the entire initial schema: agent/LLM/database configuration
    tables, chat tables (conversations, messages), knowledge-base and
    document tables, Excel/table-metadata tables for data Q&A, auth tables
    (users, roles, user_roles) and workflow execution tables.

    NOTE(review): several audit columns (created_by/updated_by) reference
    user ids but carry no ForeignKeyConstraint — presumably intentional
    (soft reference); confirm against the model layer.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # --- agent_configs: stored agent presets (tools, model, sampling params) ---
    # NOTE(review): temperature is String(10) here but Float in llm_configs
    # below — inconsistent; confirm which representation is intended.
    op.create_table('agent_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('enabled_tools', sa.JSON(), nullable=False),
    sa.Column('max_iterations', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('system_message', sa.Text(), nullable=True),
    sa.Column('verbose', sa.Boolean(), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
    )
    op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
    op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
    # --- conversations: chat sessions; knowledge_base_id optional (RAG) ---
    op.create_table('conversations',
    sa.Column('title', sa.String(length=200), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
    sa.Column('system_prompt', sa.Text(), nullable=True),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_archived', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
    )
    op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
    # --- database_configs: external DB connections for data Q&A ---
    # NOTE(review): the UNIQUE constraint on db_type allows at most ONE
    # config per database type; password is stored as plain Text — confirm
    # it is encrypted at the application layer.
    op.create_table('database_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('db_type', sa.String(length=20), nullable=False),
    sa.Column('host', sa.String(length=255), nullable=False),
    sa.Column('port', sa.Integer(), nullable=False),
    sa.Column('database', sa.String(length=100), nullable=False),
    sa.Column('username', sa.String(length=100), nullable=False),
    sa.Column('password', sa.Text(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('connection_params', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
    sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
    )
    op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
    # --- documents: uploaded files belonging to a knowledge base ---
    op.create_table('documents',
    sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('mime_type', sa.String(length=100), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('content', sa.Text(), nullable=True),
    sa.Column('doc_metadata', sa.JSON(), nullable=True),
    sa.Column('chunk_count', sa.Integer(), nullable=False),
    sa.Column('embedding_model', sa.String(length=100), nullable=True),
    sa.Column('vector_ids', sa.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
    )
    op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
    # --- excel_files: uploaded spreadsheets + cached structure/preview ---
    # NOTE(review): total_rows / total_columns are JSON — look like they
    # should be Integer; confirm against the model definition.
    op.create_table('excel_files',
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('sheet_names', sa.JSON(), nullable=False),
    sa.Column('default_sheet', sa.String(length=100), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('preview_data', sa.JSON(), nullable=False),
    sa.Column('data_types', sa.JSON(), nullable=True),
    sa.Column('total_rows', sa.JSON(), nullable=True),
    sa.Column('total_columns', sa.JSON(), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('last_accessed', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
    )
    op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
    # --- knowledge_bases: RAG collections (chunking + vector store config) ---
    op.create_table('knowledge_bases',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('embedding_model', sa.String(length=100), nullable=False),
    sa.Column('chunk_size', sa.Integer(), nullable=False),
    sa.Column('chunk_overlap', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('vector_db_type', sa.String(length=50), nullable=False),
    sa.Column('collection_name', sa.String(length=100), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
    )
    op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
    op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
    # --- llm_configs: provider/model credentials and sampling defaults ---
    op.create_table('llm_configs',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('provider', sa.String(length=50), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('api_key', sa.String(length=500), nullable=False),
    sa.Column('base_url', sa.String(length=200), nullable=True),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.Float(), nullable=False),
    sa.Column('top_p', sa.Float(), nullable=False),
    sa.Column('frequency_penalty', sa.Float(), nullable=False),
    sa.Column('presence_penalty', sa.Float(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('is_embedding', sa.Boolean(), nullable=False),
    sa.Column('extra_config', sa.JSON(), nullable=True),
    sa.Column('usage_count', sa.Integer(), nullable=False),
    sa.Column('last_used_at', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
    )
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
    # --- messages: chat messages; role/message_type are native Enum types ---
    # NOTE(review): conversation_id has no FK to conversations — confirm.
    op.create_table('messages',
    sa.Column('conversation_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
    sa.Column('message_metadata', sa.JSON(), nullable=True),
    sa.Column('context_documents', sa.JSON(), nullable=True),
    sa.Column('prompt_tokens', sa.Integer(), nullable=True),
    sa.Column('completion_tokens', sa.Integer(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
    )
    op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
    # --- roles: RBAC roles; code and name are unique ---
    op.create_table('roles',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('code', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_system', sa.Boolean(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
    )
    op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    # --- table_metadata: synced structure of external tables for NL2SQL QA ---
    op.create_table('table_metadata',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('table_name', sa.String(length=100), nullable=False),
    sa.Column('table_schema', sa.String(length=50), nullable=False),
    sa.Column('table_type', sa.String(length=20), nullable=False),
    sa.Column('table_comment', sa.Text(), nullable=True),
    sa.Column('database_config_id', sa.Integer(), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('primary_keys', sa.JSON(), nullable=True),
    sa.Column('foreign_keys', sa.JSON(), nullable=True),
    sa.Column('indexes', sa.JSON(), nullable=True),
    sa.Column('sample_data', sa.JSON(), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
    sa.Column('qa_description', sa.Text(), nullable=True),
    sa.Column('business_context', sa.Text(), nullable=True),
    sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
    )
    op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
    op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
    # --- users: accounts; username and email are unique ---
    op.create_table('users',
    sa.Column('username', sa.String(length=50), nullable=False),
    sa.Column('email', sa.String(length=100), nullable=False),
    sa.Column('hashed_password', sa.String(length=255), nullable=False),
    sa.Column('full_name', sa.String(length=100), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('avatar_url', sa.String(length=255), nullable=True),
    sa.Column('bio', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # --- user_roles: many-to-many join table with composite PK ---
    op.create_table('user_roles',
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('role_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
    sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
    )
    # --- workflows: workflow definitions owned by a user ---
    op.create_table('workflows',
    sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
    sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
    sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
    sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
    sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
    sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
    )
    op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
    # --- workflow_executions: one run of a workflow ---
    # NOTE(review): started_at / completed_at are String(50) — timestamps as
    # strings; likely should be DateTime, confirm against the model.
    op.create_table('workflow_executions',
    sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
    sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
    )
    op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
    # --- node_executions: per-node results within a workflow execution ---
    op.create_table('node_executions',
    sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
    sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
    sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
    sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
    )
    op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade schema.

    Drops everything created by upgrade(), FK children first
    (node_executions -> workflow_executions -> workflows, user_roles before
    users/roles) so no drop violates referential integrity.

    NOTE(review): the Enum types created by upgrade() (messagerole,
    messagetype, workflowstatus, executionstatus, nodetype) are not dropped
    here — on PostgreSQL they would be left behind; confirm target dialect.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_node_executions_id'), table_name='node_executions')
    op.drop_table('node_executions')
    op.drop_index(op.f('ix_workflow_executions_id'), table_name='workflow_executions')
    op.drop_table('workflow_executions')
    op.drop_index(op.f('ix_workflows_id'), table_name='workflows')
    op.drop_table('workflows')
    op.drop_table('user_roles')
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.drop_index(op.f('ix_users_id'), table_name='users')
    op.drop_index(op.f('ix_users_email'), table_name='users')
    op.drop_table('users')
    op.drop_index(op.f('ix_table_metadata_table_name'), table_name='table_metadata')
    op.drop_index(op.f('ix_table_metadata_id'), table_name='table_metadata')
    op.drop_table('table_metadata')
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_index(op.f('ix_roles_id'), table_name='roles')
    op.drop_index(op.f('ix_roles_code'), table_name='roles')
    op.drop_table('roles')
    op.drop_index(op.f('ix_messages_id'), table_name='messages')
    op.drop_table('messages')
    op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
    op.drop_table('llm_configs')
    op.drop_index(op.f('ix_knowledge_bases_name'), table_name='knowledge_bases')
    op.drop_index(op.f('ix_knowledge_bases_id'), table_name='knowledge_bases')
    op.drop_table('knowledge_bases')
    op.drop_index(op.f('ix_excel_files_id'), table_name='excel_files')
    op.drop_table('excel_files')
    op.drop_index(op.f('ix_documents_id'), table_name='documents')
    op.drop_table('documents')
    op.drop_index(op.f('ix_database_configs_id'), table_name='database_configs')
    op.drop_table('database_configs')
    op.drop_index(op.f('ix_conversations_id'), table_name='conversations')
    op.drop_table('conversations')
    op.drop_index(op.f('ix_agent_configs_name'), table_name='agent_configs')
    op.drop_index(op.f('ix_agent_configs_id'), table_name='agent_configs')
    op.drop_table('agent_configs')
    # ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""Add message_count and last_message_at to conversations
Revision ID: 8da391c6e2b7
Revises: 424646027786
Create Date: 2025-12-19 16:16:29.943314
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '8da391c6e2b7'  # this migration's unique ID
down_revision: Union[str, Sequence[str], None] = '424646027786'  # follows the initial migration
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Upgrade schema: add message_count and last_message_at to conversations.

    ``message_count`` is NOT NULL, so it needs a ``server_default`` to be
    addable to a table that already contains rows (without it the ALTER
    fails on PostgreSQL; MySQL would only silently backfill 0). The default
    keeps new rows consistent until the application sets the value.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('conversations',
                  sa.Column('message_count', sa.Integer(), nullable=False, server_default='0'))
    # Nullable column: no default needed; NULL means "no messages yet".
    op.add_column('conversations', sa.Column('last_message_at', sa.DateTime(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade schema: remove the two columns added by this revision."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop in reverse order of their addition in upgrade().
    for column_name in ('last_message_at', 'message_count'):
        op.drop_column('conversations', column_name)
    # ### end Alembic commands ###

View File

@ -0,0 +1,42 @@
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
import asyncio
import os
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
async def check_table_constraints():
    """Print the PK / FK / UNIQUE constraints of the ``messages`` table.

    Fixes three bugs in the original version:
    * ``inspect()`` cannot be called directly on an async engine — the
      synchronous Inspector must run through ``AsyncConnection.run_sync``;
    * ``engine.run_sync`` does not exist (it lives on the connection);
    * ``Inspector.get_table_constraints`` is not a SQLAlchemy API — each
      constraint kind has its own accessor.
    Also disposes the engine so pooled connections are closed cleanly.
    """
    try:
        # Connection string from the environment, with a local-dev fallback.
        DATABASE_URL = os.getenv("DATABASE_URL", "mysql+asyncmy://root:123456@localhost:3306/th_agenter")
        engine = create_async_engine(DATABASE_URL, echo=True)
        try:
            async with engine.connect() as conn:
                def _collect(sync_conn):
                    # Inspector only works on a synchronous connection.
                    inspector = inspect(sync_conn)
                    return {
                        'primary_key': inspector.get_pk_constraint('messages'),
                        'foreign_keys': inspector.get_foreign_keys('messages'),
                        'unique': inspector.get_unique_constraints('messages'),
                    }
                constraints = await conn.run_sync(_collect)
            print("Messages表的所有约束:")
            pk = constraints['primary_key']
            print(f"  约束名称: {pk.get('name')}, 类型: PRIMARY KEY")
            print(f"  主键约束列: {pk.get('constrained_columns')}")
            for fk in constraints['foreign_keys']:
                print(f"  约束名称: {fk.get('name')}, 类型: FOREIGN KEY, 列: {fk.get('constrained_columns')} -> {fk.get('referred_table')}")
            for uq in constraints['unique']:
                print(f"  约束名称: {uq.get('name')}, 类型: UNIQUE, 列: {uq.get('column_names')}")
        finally:
            # Release pooled connections before the event loop shuts down.
            await engine.dispose()
    except Exception as e:
        print(f"检查约束时出错: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
asyncio.run(check_table_constraints())

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -38,7 +38,7 @@ def setup_exception_handlers(app: FastAPI) -> None:
async def http_exception_handler(request, exc):
from utils.util_exceptions import HxfErrorResponse
logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
return HxfErrorResponse(exc.status_code, exc.detail)
return HxfErrorResponse(exc)
def make_json_serializable(obj):
"""递归地将对象转换为JSON可序列化的格式"""
@ -127,31 +127,5 @@ def add_router(app: FastAPI) -> None:
# Include routers
app.include_router(router, prefix="/api")
# app.include_router(table_metadata.router)
# # 在现有导入中添加
# from ..api.endpoints import database_config
# # 在路由注册部分添加
# app.include_router(database_config.router)
# # Health check endpoint
# @app.get("/health")
# async def health_check():
# return {"status": "healthy", "version": settings.app_version}
# # Root endpoint
# @app.get("/")
# async def root():
# return {"message": "Chat Agent API is running"}
# # Test endpoint
# @app.get("/test")
# async def test_endpoint():
# return {"message": "API is working"}
app = create_app()
# from utils.util_test import test_db
# test_db()
# from test.example import internet_search_tool

1
backend/test/__init__.py Normal file
View File

@ -0,0 +1 @@
# Test package for PostgreSQL agent functionality

154
backend/test/example.py Normal file
View File

@ -0,0 +1,154 @@
import os
import asyncio
from datetime import datetime
from deepagents import create_deep_agent
from openai import OpenAI
from langchain.chat_models import init_chat_model
from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver # 导入检查点工具
from deepagents.backends import StoreBackend
from loguru import logger
def internet_search_tool(query: str):
    """Run a web search.

    Sends the query to DashScope's ``qwen-plus`` with the provider-specific
    ``enable_search`` flag so the search happens server-side.

    Args:
        query: Natural-language query to search for.

    Returns:
        The assistant's answer text (str).
    """
    logger.info(f"Running internet search for query: {query}")
    # NOTE(review): a fresh client per call keeps the tool self-contained;
    # hoist it to module level if call volume grows.
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url=os.getenv('DASHSCOPE_BASE_URL'),
    )
    logger.info("create OpenAI")  # was a pointless f-string (no placeholders)
    completion = client.chat.completions.create(
        model="qwen-plus",
        messages=[
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': query}
        ],
        # DashScope extension: enable server-side internet search.
        extra_body={
            "enable_search": True
        }
    )
    logger.info("create completions")  # was a pointless f-string
    logger.info(f"OpenAI response: {completion.choices[0].message.content}")
    return completion.choices[0].message.content
# System prompt to steer the agent to be an expert researcher
today = datetime.now().strftime("%Y年%m月%d")  # Chinese-formatted date injected into the prompt
research_instructions = f"""你是一个智能助手。你的任务是帮助用户完成各种任务。
你可以使用互联网搜索工具来获取信息
## `internet_search`
使用此工具对给定查询进行互联网搜索你可以指定返回结果的最大数量主题以及是否包含原始内容
今天的日期是{today}
"""
# Create the deep agent with memory
model = init_chat_model(
    model="gpt-4.1-mini",
    model_provider='openai',
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url=os.getenv('OPENAI_BASE_URL'),
)
checkpointer = InMemorySaver()  # in-memory checkpointer: conversation history is saved automatically
agent = create_deep_agent(  # state is thread-scoped (one state per session/thread_id)
    tools=[internet_search_tool],
    system_prompt=research_instructions,
    model=model,
    checkpointer=checkpointer,  # attach the checkpointer to enable automatic memory
    # Require human approval before the search tool runs (HITL interrupt).
    interrupt_on={'internet_search_tool':True}
)
# Multi-turn dialog loop (the checkpointer provides automatic memory)
printed_msg_ids = set()  # message IDs already printed (stream snapshots repeat messages)
thread_id = "user_session_001"  # session ID, distinguishes different users/sessions
config = {"configurable": {"thread_id": thread_id}, "metastore": {'assistant_id': 'owenliang'}}  # per-session config
print("开始对话(输入 'exit' 退出):")
while True:
    user_input = input("\nHUMAN: ").strip()
    if user_input.lower() == 'exit':
        break
    # "values" stream mode re-emits the whole state; dedupe by message.id
    # and print each message according to its type.
    pending_resume = None
    while True:
        if pending_resume is None:
            # First pass of this user turn: send the user's message.
            request = {"messages": [{"role": "user", "content": user_input}]}
        else:
            # A HITL interrupt was answered: resume the same turn.
            from langgraph.types import Command as _Command
            request = _Command(resume=pending_resume)
            pending_resume = None
        for item in agent.stream(
            request,
            config=config,
            stream_mode="values",
        ):
            # Some stream shapes yield (state, metadata) tuples; unwrap.
            state = item[0] if isinstance(item, tuple) and len(item) == 2 else item
            # First check whether a Human-In-The-Loop interrupt was raised.
            if isinstance(state, dict) and "__interrupt__" in state:
                interrupts = state["__interrupt__"] or []
                if interrupts:
                    hitl_payload = interrupts[0].value
                    action_requests = hitl_payload.get("action_requests", [])
                    print("\n=== 需要人工审批的工具调用 ===")
                    decisions: list[dict[str, str]] = []
                    for idx, ar in enumerate(action_requests):
                        name = ar.get("name")
                        args = ar.get("args")
                        print(f"[{idx}] 工具 {name} 参数: {args}")
                        # Keep prompting until a valid decision is entered.
                        while True:
                            choice = input(" 决策 (a=approve, r=reject): ").strip().lower()
                            if choice in ("a", "r"):
                                break
                        decisions.append({"type": "approve" if choice == "a" else "reject"})
                    # Next iteration becomes a resume; the same user turn continues.
                    pending_resume = {"decisions": decisions}
                    break
            # Works for both dict states and AgentState dataclasses.
            messages = state.get("messages", []) if isinstance(state, dict) else getattr(state, "messages", [])
            for msg in messages:
                msg_id = getattr(msg, "id", None)
                if msg_id is not None and msg_id in printed_msg_ids:
                    continue
                if msg_id is not None:
                    printed_msg_ids.add(msg_id)
                msg_type = getattr(msg, "type", None)
                if msg_type == "human":
                    # The user's input is already on the terminal; don't echo it.
                    continue
                if msg_type == "ai":
                    tool_calls = getattr(msg, "tool_calls", None) or []
                    if tool_calls:
                        # AI message that initiates tool calls (TOOL CALL).
                        for tc in tool_calls:
                            tool_name = tc.get("name")
                            args = tc.get("args")
                            print(f"TOOL CALL [{tool_name}]: {args}")
                    # If the AI message also carries natural-language content, print it too.
                    if getattr(msg, "content", None):
                        print(f"AI: {msg.content}")
                    continue
                if msg_type == "tool":
                    # Tool execution result (TOOL RESPONSE).
                    tool_name = getattr(msg, "name", None) or "tool"
                    print(f"TOOL RESPONSE [{tool_name}]: {msg.content}")
                    continue
                # Fallback: print any other message type for debugging.
                print(f"[{msg_type}]: {getattr(msg, 'content', None)}")
        # No new interrupt to resume: this turn is done; wait for the next input.
        if pending_resume is None:
            break

View File

@ -95,11 +95,13 @@ async def login_oauth(
)
session.desc = f"用户 {user.username} OAuth2 登录成功"
return {
return HxfResponse(
{
"access_token": access_token,
"token_type": "bearer",
"expires_in": settings.security.access_token_expire_minutes * 60
}
)
@router.post("/refresh", response_model=Token, summary="刷新访问token")
async def refresh_token(
@ -113,15 +115,17 @@ async def refresh_token(
session, data={"sub": current_user.username}, expires_delta=access_token_expires
)
return Token(
response = Token(
access_token=access_token,
token_type="bearer",
expires_in=settings.security.access_token_expire_minutes * 60
)
return HxfResponse(response)
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
async def get_current_user_info(
current_user = Depends(AuthService.get_current_user)
):
"""获取当前用户信息"""
return UserResponse.model_validate(current_user, from_attributes=True)
response = UserResponse.model_validate(current_user, from_attributes=True)
return HxfResponse(response)

View File

@ -1,5 +1,6 @@
"""Chat endpoints for TH Agenter."""
import json
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
@ -11,6 +12,8 @@ from ...models.user import User
from ...services.auth import AuthService
from ...services.chat import ChatService
from ...services.conversation import ConversationService
from utils.util_exceptions import HxfResponse
from utils.util_schemas import (
ConversationCreate,
ConversationResponse,
@ -23,83 +26,6 @@ from utils.util_schemas import (
router = APIRouter()
# Conversation management
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
async def create_conversation(
conversation_data: ConversationCreate,
current_user: User = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""创建新对话"""
session.desc = "START: 创建新对话"
conversation_service = ConversationService(session)
conversation = await conversation_service.create_conversation(
user_id=current_user.id,
conversation_data=conversation_data
)
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}"
return ConversationResponse.model_validate(conversation)
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
async def list_conversations(
skip: int = 0,
limit: int = 50,
search: str = None,
include_archived: bool = False,
order_by: str = "updated_at",
order_desc: bool = True,
session: Session = Depends(get_session)
):
"""获取用户对话列表"""
session.desc = "START: 获取用户对话列表"
conversation_service = ConversationService(session)
conversations = await conversation_service.get_user_conversations(
skip=skip,
limit=limit,
search_query=search,
include_archived=include_archived,
order_by=order_by,
order_desc=order_desc
)
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话"
return [ConversationResponse.model_validate(conv) for conv in conversations]
@router.get("/conversations/count", summary="获取用户对话总数")
async def get_conversations_count(
search: str = None,
include_archived: bool = False,
session: Session = Depends(get_session)
):
"""获取用户对话总数"""
session.desc = "START: 获取用户对话总数"
conversation_service = ConversationService(session)
count = await conversation_service.get_user_conversations_count(
search_query=search,
include_archived=include_archived
)
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
return {"count": count}
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
async def get_conversation(
conversation_id: int,
session: Session = Depends(get_session)
):
"""获取指定对话"""
session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}"
conversation_service = ConversationService(session)
conversation = await conversation_service.get_conversation(
conversation_id=conversation_id
)
if not conversation:
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Conversation not found"
)
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}"
return ConversationResponse.model_validate(conversation)
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
async def update_conversation(
conversation_id: int,
@ -113,7 +39,8 @@ async def update_conversation(
conversation_id, conversation_update
)
session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
return ConversationResponse.model_validate(updated_conversation)
response = ConversationResponse.model_validate(updated_conversation)
return HxfResponse(response)
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
@ -126,7 +53,8 @@ async def delete_conversation(
conversation_service = ConversationService(session)
await conversation_service.delete_conversation(conversation_id)
session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
return {"message": "Conversation deleted successfully"}
response = {"message": "Conversation deleted successfully"}
return HxfResponse(response)
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
@ -145,7 +73,8 @@ async def archive_conversation(
)
session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
return {"message": "Conversation archived successfully"}
response = {"message": "Conversation archived successfully"}
return HxfResponse(response)
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
@ -165,7 +94,8 @@ async def unarchive_conversation(
)
session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
return {"message": "Conversation unarchived successfully"}
response = {"message": "Conversation unarchived successfully"}
return HxfResponse(response)
# Message management
@ -183,7 +113,8 @@ async def get_conversation_messages(
conversation_id, skip=skip, limit=limit
)
session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
return [MessageResponse.model_validate(msg) for msg in messages]
response = [MessageResponse.model_validate(msg) for msg in messages]
return HxfResponse(response)
# Chat functionality
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
@ -195,48 +126,158 @@ async def chat(
"""发送消息并获取AI响应"""
session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
chat_service = ChatService(session)
response = await chat_service.chat(
conversation_id=conversation_id,
message=chat_request.message,
stream=False,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
use_agent=chat_request.use_agent,
use_langgraph=chat_request.use_langgraph,
use_knowledge_base=chat_request.use_knowledge_base,
knowledge_base_id=chat_request.knowledge_base_id
)
await chat_service.initialize(conversation_id)
# response = await chat_service.chat(
# conversation_id=conversation_id,
# message=chat_request.message,
# stream=False,
# temperature=chat_request.temperature,
# max_tokens=chat_request.max_tokens,
# use_agent=chat_request.use_agent, # 可以简化掉
# use_langgraph=chat_request.use_langgraph, # 可以简化掉
# use_knowledge_base=chat_request.use_knowledge_base, # 可以简化掉
# knowledge_base_id=chat_request.knowledge_base_id # 可以简化掉
# )
response = "oooooooooooooooooooK"
session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
return response
return HxfResponse(response)
# ------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
async def chat_stream(
conversation_id: int,
chat_request: ChatRequest,
current_user = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""发送消息并获取流式AI响应."""
session.title = f"对话{conversation_id} 发送消息并获取流式AI响应"
session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> "
chat_service = ChatService(session)
await chat_service.initialize(conversation_id, streaming=True)
async def generate_response():
async def generate_response(chat_service):
try:
async for chunk in chat_service.chat_stream(
conversation_id=conversation_id,
message=chat_request.message,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
use_agent=chat_request.use_agent,
use_langgraph=chat_request.use_langgraph,
use_knowledge_base=chat_request.use_knowledge_base,
knowledge_base_id=chat_request.knowledge_base_id
message=chat_request.message
):
yield f"data: {chunk}\n\n"
yield chunk + "\n"
except Exception as e:
logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}")
yield {'success': False, 'data': f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}"}
return StreamingResponse(
generate_response(),
media_type="text/plain",
response = StreamingResponse(
generate_response(chat_service),
media_type="text/stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
)
return response
# Conversation management
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
async def create_conversation(
conversation_data: ConversationCreate,
current_user: User = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
):
"""创建新对话"""
id = current_user.id
session.title = f"用户{current_user.username} - 创建新对话"
session.desc = "START: 创建新对话"
conversation_service = ConversationService(session)
conversation = await conversation_service.create_conversation(
user_id=id,
conversation_data=conversation_data
)
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {id}, conversation_id: {conversation.id}"
response = ConversationResponse.model_validate(conversation)
return HxfResponse(response)
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
async def list_conversations(
skip: int = 0,
limit: int = 50,
search: str = None,
include_archived: bool = False,
order_by: str = "updated_at",
order_desc: bool = True,
session: Session = Depends(get_session)
):
"""获取用户对话列表"""
session.title = "获取用户对话列表"
session.desc = "START: 获取用户对话列表"
conversation_service = ConversationService(session)
conversations = await conversation_service.get_user_conversations(
skip=skip,
limit=limit,
search_query=search,
include_archived=include_archived,
order_by=order_by,
order_desc=order_desc
)
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..."
response = [ConversationResponse.model_validate(conv) for conv in conversations]
return HxfResponse(response)
@router.get("/conversations/count", summary="获取用户对话总数")
async def get_conversations_count(
search: str = None,
include_archived: bool = False,
session: Session = Depends(get_session)
):
"""获取用户对话总数"""
from th_agenter.core.context import UserContext
user_id = UserContext.get_current_user_id()
session.title = f"获取用户对话总数[用户id = {user_id}]"
session.desc = "START: 获取用户对话总数"
conversation_service = ConversationService(session)
count = await conversation_service.get_user_conversations_count(
search_query=search,
include_archived=include_archived
)
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
response = {"count": count}
return HxfResponse(response)
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
async def get_conversation(
conversation_id: int,
session: Session = Depends(get_session)
):
"""获取指定对话"""
session.title = f"获取指定对话[对话id = {conversation_id}]"
session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}"
conversation_service = ConversationService(session)
conversation = await conversation_service.get_conversation(
conversation_id=conversation_id
)
if not conversation:
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Conversation not found"
)
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {conversation}"
response = ConversationResponse.model_validate(conversation)
# chat_service = ChatService(session)
# await chat_service.initialize(conversation_id, streaming=False)
# messages = await chat_service.get_conversation_history_messages(
# conversation_id
# )
# response.messages = messages
messages = await conversation_service.get_conversation_messages(
conversation_id, skip=0, limit=100
)
response.messages = [MessageResponse.model_validate(msg) for msg in messages]
response.message_count = len(response.messages)
return HxfResponse(response)

View File

@ -9,7 +9,7 @@ from th_agenter.db.database import get_session
from th_agenter.services.database_config_service import DatabaseConfigService
from th_agenter.services.auth import AuthService
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
from utils.util_exceptions import HxfResponse
# 在文件顶部添加
from functools import lru_cache
@ -68,11 +68,12 @@ async def create_database_config(
):
"""创建或更新数据库配置"""
config = await service.create_or_update_config(current_user.id, config_data.model_dump())
return NormalResponse(
response = NormalResponse(
success=True,
message="保存数据库配置成功",
data=config
)
return HxfResponse(response)
@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表")
async def get_database_configs(
@ -83,7 +84,7 @@ async def get_database_configs(
configs = service.get_user_configs(current_user.id)
config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
return config_list
return HxfResponse(config_list)
@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接")
async def test_database_connection(
@ -93,7 +94,7 @@ async def test_database_connection(
):
"""测试数据库连接"""
result = await service.test_connection(config_id, current_user.id)
return result
return HxfResponse(result)
@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表")
async def connect_database(
@ -103,7 +104,7 @@ async def connect_database(
):
"""连接数据库并获取表列表"""
result = await service.connect_and_get_tables(config_id, current_user.id)
return result
return HxfResponse(result)
@router.get("/tables/{table_name}/data", summary="获取表数据预览")
@ -117,7 +118,7 @@ async def get_table_data(
"""获取表数据预览"""
try:
result = await service.get_table_data(table_name, current_user.id, db_type, limit)
return result
return HxfResponse(result)
except Exception as e:
logger.error(f"获取表数据失败: {str(e)}")
raise HTTPException(
@ -133,7 +134,7 @@ async def get_table_schema(
):
"""获取表结构信息"""
result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的?
return result
return HxfResponse(result)
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置")
async def get_config_by_type(
@ -149,4 +150,4 @@ async def get_config_by_type(
detail=f"未找到类型为 {db_type} 的配置"
)
# 返回包含解密密码的配置
return config.to_dict(include_password=True, decrypt_service=service)
return HxfResponse(config.to_dict(include_password=True, decrypt_service=service))

View File

@ -1,5 +1,6 @@
"""Knowledge base API endpoints."""
from utils.util_exceptions import HxfResponse
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import JSONResponse
@ -35,10 +36,10 @@ async def create_knowledge_base(
):
"""创建新的知识库"""
# Check if knowledge base with same name already exists for this user
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data.name}"
service = KnowledgeBaseService(session)
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}"
kb_service = KnowledgeBaseService(session)
session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
existing_kb = service.get_knowledge_base_by_name(kb_data.name)
existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name)
if existing_kb:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -47,10 +48,10 @@ async def create_knowledge_base(
# Create knowledge base
session.desc = f"知识库 {kb_data.name}不存在,创建之"
kb = service.create_knowledge_base(kb_data)
kb = await kb_service.create_knowledge_base(kb_data)
session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
return KnowledgeBaseResponse(
response = KnowledgeBaseResponse(
id=kb.id,
created_at=kb.created_at,
updated_at=kb.updated_at,
@ -65,7 +66,7 @@ async def create_knowledge_base(
document_count=0,
active_document_count=0
)
return HxfResponse(response)
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
async def list_knowledge_bases(
@ -76,18 +77,17 @@ async def list_knowledge_bases(
current_user: User = Depends(AuthService.get_current_user)
):
"""获取当前用户的所有知识库"""
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库"
service = KnowledgeBaseService(session)
session.desc = f"获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
knowledge_bases = await service.get_knowledge_bases(skip=skip, limit=limit)
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
kb_service = KnowledgeBaseService(session)
knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit)
result = []
for kb in knowledge_bases:
# Count documents
# 本知识库的文档数量
total_docs = await session.scalar(
select(func.count()).where(Document.knowledge_base_id == kb.id)
)
# 本知识库的已处理文档数量
active_docs = await session.scalar(
select(func.count()).where(
Document.knowledge_base_id == kb.id,
@ -112,7 +112,7 @@ async def list_knowledge_bases(
))
session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库"
return result
return HxfResponse(result)
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
async def get_knowledge_base(
@ -124,7 +124,7 @@ async def get_knowledge_base(
session.desc = f"START: 获取知识库 {kb_id} 的详情"
service = KnowledgeBaseService(session)
session.desc = f"检查知识库 {kb_id} 是否存在"
kb = service.get_knowledge_base(kb_id)
kb = await service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
@ -146,7 +146,7 @@ async def get_knowledge_base(
)
session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理"
return KnowledgeBaseResponse(
response = KnowledgeBaseResponse(
id=kb.id,
created_at=kb.created_at,
updated_at=kb.updated_at,
@ -161,6 +161,7 @@ async def get_knowledge_base(
document_count=total_docs,
active_document_count=active_docs
)
return HxfResponse(response)
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
async def update_knowledge_base(
@ -172,7 +173,7 @@ async def update_knowledge_base(
"""更新知识库"""
session.desc = f"START: 更新知识库 {kb_id}"
service = KnowledgeBaseService(session)
kb = service.update_knowledge_base(kb_id, kb_data)
kb = await service.update_knowledge_base(kb_id, kb_data)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
@ -192,8 +193,8 @@ async def update_knowledge_base(
)
)
session.desc = f"SUCCESS: 更新知识库 {kb_id}{total_docs} 个文档,其中 {active_docs} 个已处理"
return KnowledgeBaseResponse(
session.desc = f"SUCCESS: 更新知识库 {kb_id}结果 - {total_docs} 个文档,其中 {active_docs} 个已处理"
response = KnowledgeBaseResponse(
id=kb.id,
created_at=kb.created_at,
updated_at=kb.updated_at,
@ -208,6 +209,7 @@ async def update_knowledge_base(
document_count=total_docs,
active_document_count=active_docs
)
return HxfResponse(response)
@router.delete("/{kb_id}", summary="删除知识库")
async def delete_knowledge_base(
@ -218,7 +220,7 @@ async def delete_knowledge_base(
"""删除知识库"""
session.desc = f"START: 删除知识库 {kb_id}"
service = KnowledgeBaseService(session)
success = service.delete_knowledge_base(kb_id)
success = await service.delete_knowledge_base(kb_id)
if not success:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
@ -227,7 +229,7 @@ async def delete_knowledge_base(
)
session.desc = f"SUCCESS: 删除知识库 {kb_id}"
return {"message": "Knowledge base deleted successfully"}
return HxfResponse({"message": "Knowledge base deleted successfully"})
# Document management endpoints
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
@ -239,18 +241,18 @@ async def upload_document(
current_user: User = Depends(AuthService.get_current_user)
):
"""上传文档到知识库"""
session.desc = f"START: 上传文档到知识库 {kb_id}"
session.desc = f"START: 上传文档 {file.filename} ({FileUtils.format_file_size(file.size)}) 到知识库 (ID={kb_id})"
# Verify knowledge base exists and user has access
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}"
# Validate file
if not FileUtils.validate_file_extension(file.filename):
session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
@ -258,7 +260,6 @@ async def upload_document(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
)
# Check file size (50MB limit)
max_size = 50 * 1024 * 1024 # 50MB
if file.size and file.size > max_size:
@ -268,6 +269,7 @@ async def upload_document(
detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
)
session.desc = f"文件为期望类型,处理文件 {file.filename} - "
# Upload document
doc_service = DocumentService(session)
document = await doc_service.upload_document(
@ -284,7 +286,7 @@ async def upload_document(
session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"
session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"
return DocumentResponse(
response = DocumentResponse(
id=document.id,
created_at=document.created_at,
updated_at=document.updated_at,
@ -301,6 +303,7 @@ async def upload_document(
embedding_model=document.embedding_model,
file_size_mb=round(document.file_size / (1024 * 1024), 2)
)
return HxfResponse(response)
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
async def list_documents(
@ -314,7 +317,8 @@ async def list_documents(
session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"
# Verify knowledge base exists and user has access
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
@ -323,7 +327,7 @@ async def list_documents(
)
doc_service = DocumentService(session)
documents, total = doc_service.list_documents(kb_id, skip, limit)
documents, total = await doc_service.list_documents(kb_id, skip, limit)
doc_responses = []
for doc in documents:
@ -346,208 +350,13 @@ async def list_documents(
))
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total}"
return DocumentListResponse(
response = DocumentListResponse(
documents=doc_responses,
total=total,
page=skip // limit + 1,
page_size=limit
)
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
async def get_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""获取知识库中的文档详情。"""
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
# Verify knowledge base exists and user has access
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
document = doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
return DocumentResponse(
id=document.id,
created_at=document.created_at,
updated_at=document.updated_at,
knowledge_base_id=document.knowledge_base_id,
filename=document.filename,
original_filename=document.original_filename,
file_path=document.file_path,
file_type=document.file_type,
file_size=document.file_size,
mime_type=document.mime_type,
is_processed=document.is_processed,
processing_error=document.processing_error,
chunk_count=document.chunk_count or 0,
embedding_model=document.embedding_model,
file_size_mb=round(document.file_size / (1024 * 1024), 2)
)
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
async def delete_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""删除知识库中的文档。"""
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
success = doc_service.delete_document(doc_id, kb_id)
if not success:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
return {"message": "Document deleted successfully"}
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
async def process_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""处理知识库中的文档,用于向量搜索。"""
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
# Check if document exists
doc_service = DocumentService(session)
document = doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
# Process the document
result = await doc_service.process_document(doc_id, kb_id)
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
return DocumentProcessingStatus(
document_id=doc_id,
status=result["status"],
progress=result.get("progress", 0.0),
error_message=result.get("error_message"),
chunks_created=result.get("chunks_created", 0)
)
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
async def get_document_processing_status(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""获取知识库中的文档处理状态。"""
# Verify knowledge base exists and user has access
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
document = doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
# Determine status
if document.processing_error:
status_str = "failed"
progress = 0.0
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
elif document.is_processed:
status_str = "completed"
progress = 100.0
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
else:
status_str = "pending"
progress = 0.0
session.desc = f"文档 {doc_id} 处理pending中"
return DocumentProcessingStatus(
document_id=document.id,
status=status_str,
progress=progress,
error_message=document.processing_error,
chunks_created=document.chunk_count or 0
)
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
async def search_knowledge_base(
kb_id: int,
query: str,
limit: int = 5,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""在知识库中搜索文档。"""
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
kb_service = KnowledgeBaseService(session)
kb = kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
# Perform search
doc_service = DocumentService(session)
results = doc_service.search_documents(kb_id, query, limit)
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
return {
"knowledge_base_id": kb_id,
"query": query,
"results": results,
"total_results": len(results)
}
return HxfResponse(response)
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
async def get_document_chunks(
@ -570,7 +379,8 @@ async def get_document_chunks(
"""
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
kb_service = KnowledgeBaseService(session)
knowledge_base = kb_service.get_knowledge_base(kb_id)
knowledge_base = await kb_service.get_knowledge_base(kb_id)
if not knowledge_base:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
@ -580,7 +390,9 @@ async def get_document_chunks(
# Verify document exists in the knowledge base
doc_service = DocumentService(session)
document = doc_service.get_document(doc_id, kb_id)
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService"
document = await doc_service.get_document(doc_id, kb_id)
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document"
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
@ -589,11 +401,215 @@ async def get_document_chunks(
)
# Get document chunks
chunks = doc_service.get_document_chunks(doc_id)
chunks = await doc_service.get_document_chunks(doc_id)
session.desc = f"SUCCESS: 获取文档 {doc_id}{len(chunks)} 个文档块(片段)"
return DocumentChunksResponse(
response = DocumentChunksResponse(
document_id=doc_id,
document_name=document.filename,
total_chunks=len(chunks),
chunks=chunks
)
return HxfResponse(response)
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
async def get_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""获取知识库中的文档详情。"""
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
# Verify knowledge base exists and user has access
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
document = await doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
response = DocumentResponse(
id=document.id,
created_at=document.created_at,
updated_at=document.updated_at,
knowledge_base_id=document.knowledge_base_id,
filename=document.filename,
original_filename=document.original_filename,
file_path=document.file_path,
file_type=document.file_type,
file_size=document.file_size,
mime_type=document.mime_type,
is_processed=document.is_processed,
processing_error=document.processing_error,
chunk_count=document.chunk_count or 0,
embedding_model=document.embedding_model,
file_size_mb=round(document.file_size / (1024 * 1024), 2)
)
return HxfResponse(response)
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
async def delete_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""删除知识库中的文档。"""
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
success = await doc_service.delete_document(doc_id, kb_id)
if not success:
session.desc = f"ERROR: 删除文档 {doc_id} 失败 - 文档不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
response = {"message": "Document deleted successfully"}
return HxfResponse(response)
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
async def process_document(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""处理知识库中的文档,用于向量搜索。"""
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
# Check if document exists
doc_service = DocumentService(session)
document = await doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
# Process the document
result = await doc_service.process_document(doc_id, kb_id)
await session.refresh(document)
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
response = DocumentProcessingStatus(
document_id=doc_id,
status=result["status"],
progress=result.get("progress", 0.0),
error_message=result.get("error_message"),
chunks_created=result.get("chunks_created", 0)
)
return HxfResponse(response)
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
async def get_document_processing_status(
kb_id: int,
doc_id: int,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""获取知识库中的文档处理状态。"""
# Verify knowledge base exists and user has access
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
doc_service = DocumentService(session)
document = await doc_service.get_document(doc_id, kb_id)
if not document:
session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Document not found"
)
# Determine status
if document.processing_error:
status_str = "failed"
progress = 0.0
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
elif document.is_processed:
status_str = "completed"
progress = 100.0
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
else:
status_str = "pending"
progress = 0.0
session.desc = f"文档 {doc_id} 处理pending中"
response = DocumentProcessingStatus(
document_id=document.id,
status=status_str,
progress=progress,
error_message=document.processing_error,
chunks_created=document.chunk_count or 0
)
return HxfResponse(response)
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
async def search_knowledge_base(
kb_id: int,
query: str,
limit: int = 5,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""在知识库中搜索文档。"""
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
kb_service = KnowledgeBaseService(session)
kb = await kb_service.get_knowledge_base(kb_id)
if not kb:
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Knowledge base not found"
)
# Perform search
doc_service = DocumentService(session)
results = await doc_service.search_documents(kb_id, query, limit)
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
response = {
"knowledge_base_id": kb_id,
"query": query,
"results": results,
"total_results": len(results)
}
return HxfResponse(response)

View File

@ -1,14 +1,21 @@
"""LLM configuration management API endpoints."""
from turtle import textinput
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage
from sqlalchemy.orm import Session
from sqlalchemy import or_, select, delete, update
from loguru import logger
from th_agenter.llm.embed.embed_llm import BGEEmbedLLM, EmbedLLM
from th_agenter.llm.online.online_llm import OnlineLLM
from ...db.database import get_session
from ...models.user import User
from ...models.llm_config import LLMConfig
from th_agenter.llm.base_llm import LLMConfig_DataClass
from ...core.simple_permissions import require_super_admin, require_authenticated_user
from ...schemas.llm_config import (
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
@ -31,6 +38,7 @@ async def get_llm_configs(
current_user: User = Depends(require_authenticated_user)
):
"""获取大模型配置列表."""
session.title = "获取大模型配置列表"
session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
stmt = select(LLMConfig)
@ -62,7 +70,7 @@ async def get_llm_configs(
# 分页
stmt = stmt.offset(skip).limit(limit)
configs = (await session.execute(stmt)).scalars().all()
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置"
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..."
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
@ -131,7 +139,7 @@ async def get_llm_config(
):
"""获取大模型配置详情."""
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -149,17 +157,20 @@ async def create_llm_config(
):
"""创建大模型配置."""
# 检查配置名称是否已存在
# 先保存当前用户名避免在refresh后访问可能导致MissingGreenlet错误
username = current_user.username
session.desc = f"START: 创建大模型配置, name={config_data.name}"
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
existing_config = session.execute(stmt).scalar_one_or_none()
existing_config = (await session.execute(stmt)).scalar_one_or_none()
if existing_config:
session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="配置名称已存在"
)
# 创建临时配置对象进行验证
temp_config = LLMConfig(
# 创建配置对象
config = LLMConfig_DataClass(
name=config_data.name,
provider=config_data.provider,
model_name=config_data.model_name,
@ -178,7 +189,7 @@ async def create_llm_config(
)
# 验证配置
validation_result = temp_config.validate_config()
validation_result = config.validate_config()
if not validation_result['valid']:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -190,10 +201,11 @@ async def create_llm_config(
stmt = update(LLMConfig).where(
LLMConfig.is_embedding == config_data.is_embedding
).values({"is_default": False})
session.execute(stmt)
await session.execute(stmt)
session.desc = f"验证大模型配置, config_data"
# 创建配置
config = LLMConfig(
config = LLMConfig_DataClass(
name=config_data.name,
provider=config_data.provider,
model_name=config_data.model_name,
@ -213,10 +225,9 @@ async def create_llm_config(
# Audit fields are set automatically by SQLAlchemy event listener
session.add(config)
session.commit()
session.refresh(config)
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}"
await session.commit()
await session.refresh(config)
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}"
return HxfResponse(config.to_dict())
@ -228,9 +239,10 @@ async def update_llm_config(
current_user: User = Depends(require_super_admin)
):
"""更新大模型配置."""
username = current_user.username
session.desc = f"START: 更新大模型配置, id={config_id}"
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -243,7 +255,7 @@ async def update_llm_config(
LLMConfig.name == config_data.name,
LLMConfig.id != config_id
)
existing_config = session.execute(stmt).scalar_one_or_none()
existing_config = (await session.execute(stmt)).scalar_one_or_none()
if existing_config:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -258,17 +270,17 @@ async def update_llm_config(
LLMConfig.is_embedding == is_embedding,
LLMConfig.id != config_id
).values({"is_default": False})
session.execute(stmt)
await session.execute(stmt)
# 更新字段
update_data = config_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(config, field, value)
session.commit()
session.refresh(config)
await session.commit()
await session.refresh(config)
session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}"
session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}"
return HxfResponse(config.to_dict())
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
@ -278,22 +290,24 @@ async def delete_llm_config(
current_user: User = Depends(require_super_admin)
):
"""删除大模型配置."""
username = current_user.username
session.desc = f"START: 删除大模型配置, id={config_id}"
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="大模型配置不存在"
)
session.desc = f"待删除大模型记录 {config.to_dict()}"
# TODO: 检查是否有对话或其他功能正在使用该配置
# 这里可以添加相关的检查逻辑
session.delete(config)
session.commit()
# 删除配置
await session.delete(config)
await session.commit()
session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}"
session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}"
return HxfResponse({"message": "LLM config deleted successfully"})
@router.post("/{config_id}/test", summary="测试连接大模型配置")
@ -304,17 +318,21 @@ async def test_llm_config(
current_user: User = Depends(require_super_admin)
):
"""测试连接大模型配置."""
session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}"
username = current_user.username
session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}"
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="大模型配置不存在"
)
logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}")
config_name = config.name
# 验证配置
validation_result = config.validate_config()
logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}")
if not validation_result["valid"]:
return {
"success": False,
@ -322,41 +340,54 @@ async def test_llm_config(
"details": validation_result
}
session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}"
# 尝试创建客户端并发送测试请求
try:
# 这里应该根据不同的服务商创建相应的客户端
# 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
# # 这里应该根据不同的服务商创建相应的客户端
# # 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
test_message = test_data.message or "Hello, this is a test message."
session.desc = f"准备测试LLM功能 > test_message = {test_message}"
# TODO: 实现具体的测试逻辑
# 例如:
# client = config.get_client()
# response = client.chat.completions.create(
# model=config.model_name,
# messages=[{"role": "user", "content": test_message}],
# max_tokens=100
# )
if config.is_embedding:
config.provider = "ollama"
streaming_llm = BGEEmbedLLM(config)
else:
streaming_llm = OnlineLLM(config)
session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}"
streaming_llm.load_model() # 加载模型
session.desc = f"加载模型完毕,模型名称:{config.model_name}base_url: {config.base_url},准备测试对话..."
# 模拟测试成功
session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}"
if config.is_embedding:
# 测试嵌入模型使用嵌入API而非聊天API
test_text = test_message or "Hello, this is a test message for embedding"
response = streaming_llm.embed_query(test_text)
else:
# 测试聊天模型
from langchain.messages import SystemMessage, HumanMessage
messages = [
SystemMessage(content="你是一个简洁的助手回答控制在50字以内"),
HumanMessage(content=test_message)
]
response = streaming_llm.model.invoke(messages)
session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}"
return HxfResponse({
"success": True,
"message": "配置测试成功",
"test_message": test_message,
"response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
"message": "LLM测试成功",
"request": test_message,
"response": response.content if hasattr(response, 'content') else response, # 使用转换后的字典
"latency_ms": 150, # 模拟延迟
"config_info": config.get_client_config()
"config_info": config.to_dict()
})
except Exception as test_error:
session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
return HxfResponse({
"success": False,
"message": f"配置测试失败: {str(test_error)}",
"message": f"LLM测试失败: {str(test_error)}",
"test_message": test_message,
"config_info": config.get_client_config()
"config_info": config.to_dict()
})
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
@ -366,10 +397,11 @@ async def toggle_llm_config_status(
current_user: User = Depends(require_super_admin)
):
"""切换大模型配置状态."""
session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}"
username = current_user.username
session.desc = f"START: 切换大模型配置状态, id={config_id} by user {username}"
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -380,11 +412,11 @@ async def toggle_llm_config_status(
config.is_active = not config.is_active
# Audit fields are set automatically by SQLAlchemy event listener
session.commit()
session.refresh(config)
await session.commit()
await session.refresh(config)
status_text = "激活" if config.is_active else "禁用"
session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}"
session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {username}"
return HxfResponse({
"message": f"配置已{status_text}",
@ -399,10 +431,11 @@ async def set_default_llm_config(
current_user: User = Depends(require_super_admin)
):
"""设置默认大模型配置."""
session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}"
username = current_user.username
session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}"
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
config = session.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -421,19 +454,19 @@ async def set_default_llm_config(
LLMConfig.is_embedding == config.is_embedding,
LLMConfig.id != config_id
).values({"is_default": False})
session.execute(stmt)
await session.execute(stmt)
# 设置当前配置为默认
config.is_default = True
config.set_audit_fields(current_user.id, is_update=True)
session.commit()
session.refresh(config)
await session.commit()
await session.refresh(config)
model_type = "嵌入模型" if config.is_embedding else "对话模型"
# 更新文档处理器默认embedding
get_document_processor()._init_embeddings()
session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}"
await get_document_processor(session)._init_embeddings()
session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}"
return HxfResponse({
"message": f"已将 {config.name} 设为默认{model_type}配置",
"is_default": config.is_default

View File

@ -1,5 +1,6 @@
"""Role management API endpoints."""
from utils.util_exceptions import HxfResponse
from loguru import logger
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
@ -49,7 +50,8 @@ async def get_roles(
stmt = stmt.offset(skip).limit(limit)
roles = (await session.execute(stmt)).scalars().all()
session.desc = f"SUCCESS: 用户 {current_user.username}{len(roles)} 个角色"
return [role.to_dict() for role in roles]
response = [role.to_dict() for role in roles]
return HxfResponse(response)
@router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情")
async def get_role(
@ -67,7 +69,8 @@ async def get_role(
detail="角色不存在"
)
return role.to_dict()
response = role.to_dict()
return HxfResponse(response)
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
async def create_role(
@ -95,12 +98,13 @@ async def create_role(
)
role.set_audit_fields(current_user.id)
await session.add(role)
session.add(role)
await session.commit()
await session.refresh(role)
logger.info(f"Role created: {role.name} by user {current_user.username}")
return role.to_dict()
response = role.to_dict()
return HxfResponse(response)
@router.put("/{role_id}", response_model=RoleResponse, summary="更新角色")
async def update_role(
@ -152,7 +156,8 @@ async def update_role(
await session.refresh(role)
logger.info(f"Role updated: {role.name} by user {current_user.username}")
return role.to_dict()
response = role.to_dict()
return HxfResponse(response)
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
async def delete_role(
@ -190,7 +195,8 @@ async def delete_role(
await session.commit()
session.desc = f"角色删除成功: {role.name} by user {current_user.username}"
return {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
response = {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
return HxfResponse(response)
# 用户角色管理路由
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
@ -230,13 +236,14 @@ async def assign_user_roles(
user_id=assignment_data.user_id,
role_id=role_id
)
await session.add(user_role)
session.add(user_role)
await session.commit()
session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}"
return {"message": "角色分配成功"}
response = {"message": "角色分配成功"}
return HxfResponse(response)
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
async def get_user_roles(
@ -267,7 +274,8 @@ async def get_user_roles(
)
roles = (await session.execute(stmt)).scalars().all()
return [role.to_dict() for role in roles]
response = [role.to_dict() for role in roles]
return HxfResponse(response)
# 将子路由添加到主路由
router.include_router(user_role_router)

View File

@ -12,6 +12,7 @@ from th_agenter.services.conversation_context import conversation_context_servic
from utils.util_schemas import BaseResponse
from pydantic import BaseModel
from loguru import logger
from utils.util_exceptions import HxfResponse
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
security = HTTPBearer()
@ -65,6 +66,8 @@ async def smart_query(
# 初始化工作流管理器
workflow_manager = SmartWorkflowManager(session)
await workflow_manager.initialize()
conversation_service = ConversationService(session)
# 处理对话上下文
@ -126,7 +129,7 @@ async def smart_query(
except Exception as e:
session.desc = f"ERROR: 智能查询执行失败: {e}"
# 返回结构化的错误响应
return SmartQueryResponse(
response = SmartQueryResponse(
success=False,
message=f"查询执行失败: {str(e)}",
data={'error_type': 'query_execution_error'},
@ -137,6 +140,7 @@ async def smart_query(
}],
conversation_id=conversation_id
)
return HxfResponse(response)
# 如果查询成功,保存助手回复和更新上下文
if result['success'] and conversation_id:
@ -169,21 +173,22 @@ async def smart_query(
if conversation_id:
response_data['conversation_id'] = conversation_id
session.desc = f"SUCCESS: 保存助手回复和更新上下文对话ID: {conversation_id}"
return SmartQueryResponse(
response = SmartQueryResponse(
success=result['success'],
message=result.get('message', '查询完成'),
data=response_data,
workflow_steps=result.get('workflow_steps', []),
conversation_id=conversation_id
)
return HxfResponse(response)
except HTTPException:
except HTTPException as e:
session.desc = f"EXCEPTION: HTTP异常: {e}"
raise
raise e
except Exception as e:
session.desc = f"ERROR: 智能查询接口异常: {e}"
# 返回通用错误响应
return SmartQueryResponse(
response = SmartQueryResponse(
success=False,
message="服务器内部错误,请稍后重试",
data={'error_type': 'internal_server_error'},
@ -194,6 +199,7 @@ async def smart_query(
}],
conversation_id=conversation_id
)
return HxfResponse(response)
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
async def get_conversation_context(
@ -227,11 +233,12 @@ async def get_conversation_context(
history = await conversation_context_service.get_conversation_history(conversation_id)
context['message_history'] = history
session.desc = f"SUCCESS: 获取对话上下文成功对话ID: {conversation_id}"
return ConversationContextResponse(
response = ConversationContextResponse(
success=True,
message="获取对话上下文成功",
data=context
)
return HxfResponse(response)
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
@ -244,6 +251,7 @@ async def get_files_status(
"""
session.desc = f"START: 获取用户文件状态和统计信息用户ID: {current_user.id}"
workflow_manager = SmartWorkflowManager()
await workflow_manager.initialize()
# 获取用户文件列表
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
@ -277,11 +285,12 @@ async def get_files_status(
}
session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功用户ID: {current_user.id}"
return ConversationContextResponse(
response = ConversationContextResponse(
success=True,
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件请先上传Excel文件",
data=status_data
)
return HxfResponse(response)
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
async def reset_conversation_context(
@ -315,10 +324,11 @@ async def reset_conversation_context(
if success:
session.desc = f"SUCCESS: 重置对话上下文成功对话ID: {conversation_id}"
return {
"success": True,
"message": "对话上下文已重置,可以开始新的数据分析会话"
}
response = ConversationContextResponse(
success=True,
message="对话上下文已重置,可以开始新的数据分析会话"
)
return HxfResponse(response)
else:
session.desc = f"EXCEPTION: 重置对话上下文失败对话ID: {conversation_id}"
raise HTTPException(

View File

@ -30,6 +30,7 @@ from th_agenter.services.smart_workflow import SmartWorkflowManager
from th_agenter.services.conversation_context import ConversationContextService
from pydantic import BaseModel
from loguru import logger
from utils.util_exceptions import HxfResponse
router = APIRouter(prefix="/smart-query", tags=["smart-query"])
security = HTTPBearer()
@ -163,12 +164,13 @@ async def upload_excel(
})
session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功文件ID: {excel_file.id}"
return ExcelUploadResponse(
response = ExcelUploadResponse(
file_id=excel_file.id,
success=True,
message="Excel文件上传成功",
data=analysis_result
)
return HxfResponse(response)
@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据")
async def preview_excel(
@ -233,7 +235,7 @@ async def preview_excel(
data = paginated_df.fillna('').to_dict('records')
columns = df.columns.tolist()
session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据"
return QueryResponse(
response = QueryResponse(
success=True,
message="Excel文件预览加载成功",
data={
@ -245,6 +247,7 @@ async def preview_excel(
'total_pages': (total_rows + request.page_size - 1) // request.page_size
}
)
return HxfResponse(response)
@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接")
async def test_database_connection(
@ -264,10 +267,11 @@ async def test_database_connection(
message="数据库连接测试成功"
)
else:
return NormalResponse(
response = NormalResponse(
success=False,
message="数据库连接测试失败"
)
return HxfResponse(response)
except Exception as e:
return NormalResponse(
@ -298,22 +302,25 @@ async def get_table_schema(
schema_result = await db_service.get_table_schema(request.table_name, current_user.id)
if schema_result['success']:
return QueryResponse(
response = QueryResponse(
success=True,
message="获取表结构成功",
data=schema_result['data']
)
return HxfResponse(response)
else:
return QueryResponse(
response = QueryResponse(
success=False,
message=schema_result['message']
)
return HxfResponse(response)
except Exception as e:
return QueryResponse(
response = QueryResponse(
success=False,
message=f"获取表结构失败: {str(e)}"
)
return HxfResponse(response)
class StreamQueryRequest(BaseModel):
query: str
@ -355,6 +362,8 @@ async def stream_smart_query(
# 初始化服务
workflow_manager = SmartWorkflowManager(session)
await workflow_manager.initialize()
conversation_context_service = ConversationContextService()
# 处理对话上下文
@ -435,7 +444,7 @@ async def stream_smart_query(
except:
pass
return StreamingResponse(
response = StreamingResponse(
generate_stream(),
media_type="text/plain",
headers={
@ -447,6 +456,7 @@ async def stream_smart_query(
"Access-Control-Allow-Methods": "*"
}
)
return HxfResponse(response)
@router.post("/execute-db-query", summary="流式数据库查询")
async def execute_database_query(
@ -477,6 +487,7 @@ async def execute_database_query(
# 初始化服务
workflow_manager = SmartWorkflowManager(session)
await workflow_manager.initialize()
conversation_context_service = ConversationContextService()
# 处理对话上下文
@ -556,7 +567,7 @@ async def execute_database_query(
except:
pass
return StreamingResponse(
response = StreamingResponse(
generate_stream(),
media_type="text/plain",
headers={
@ -568,6 +579,7 @@ async def execute_database_query(
"Access-Control-Allow-Methods": "*"
}
)
return HxfResponse(response)
@router.delete("/cleanup-temp-files", summary="清理临时文件")
async def cleanup_temp_files(
@ -590,16 +602,18 @@ async def cleanup_temp_files(
except OSError:
pass
return BaseResponse(
response = BaseResponse(
success=True,
message=f"已清理 {cleaned_count} 个临时文件"
)
return HxfResponse(response)
except Exception as e:
return BaseResponse(
response = BaseResponse(
success=False,
message=f"清理临时文件失败: {str(e)}"
)
return HxfResponse(response)
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表")
async def get_file_list(
@ -634,7 +648,7 @@ async def get_file_list(
file_list.append(file_info)
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件"
return FileListResponse(
response = FileListResponse(
success=True,
message="获取文件列表成功",
data={
@ -645,12 +659,14 @@ async def get_file_list(
'total_pages': (total + page_size - 1) // page_size
}
)
return HxfResponse(response)
except Exception as e:
return FileListResponse(
response = FileListResponse(
success=False,
message=f"获取文件列表失败: {str(e)}"
)
return HxfResponse(response)
@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件")
async def delete_file(
@ -668,22 +684,26 @@ async def delete_file(
if success:
session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}"
return NormalResponse(
response = NormalResponse(
success=True,
message="文件删除成功"
)
return HxfResponse(response)
else:
session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败"
return NormalResponse(
response = NormalResponse(
success=False,
message="文件不存在或删除失败"
)
return HxfResponse(response)
except Exception as e:
return NormalResponse(
response = NormalResponse(
success=True,
message=str(e)
)
return HxfResponse(response)
@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息")
async def get_file_info(
@ -725,9 +745,11 @@ async def get_file_info(
'sheets_summary': excel_file.get_all_sheets_summary()
}
return QueryResponse(
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件 {file_id} 信息"
response = QueryResponse(
success=True,
message="获取文件信息成功",
data=file_info
)
return HxfResponse(response)

View File

@ -9,6 +9,7 @@ from th_agenter.models.user import User
from th_agenter.db.database import get_session
from th_agenter.services.table_metadata_service import TableMetadataService
from th_agenter.services.auth import AuthService
from utils.util_exceptions import HxfResponse
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
@ -54,8 +55,7 @@ async def collect_table_metadata(
request.table_names
)
session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据"
return result
return HxfResponse(result)
@router.get("/", summary="获取用户表元数据列表")
async def get_table_metadata(
@ -66,7 +66,7 @@ async def get_table_metadata(
"""获取表元数据列表"""
try:
service = TableMetadataService(session)
metadata_list = service.get_user_table_metadata(
metadata_list = await service.get_user_table_metadata(
current_user.id,
database_config_id
)
@ -96,17 +96,17 @@ async def get_table_metadata(
for meta in metadata_list
]
return {
return HxfResponse({
"success": True,
"data": data
}
})
except Exception as e:
logger.error(f"获取表元数据失败: {str(e)}")
return {
return HxfResponse({
"success": False,
"message": str(e)
}
})
@router.post("/by-table", summary="根据表名获取表元数据")
@ -118,7 +118,7 @@ async def get_table_metadata_by_name(
"""根据表名获取表元数据"""
try:
service = TableMetadataService(session)
metadata = service.get_table_metadata_by_name(
metadata = await service.get_table_metadata_by_name(
current_user.id,
request.database_config_id,
request.table_name
@ -146,27 +146,30 @@ async def get_table_metadata_by_name(
"business_context": metadata.business_context or ""
}
}
return {"success": True, "data": data}
else:
return {"success": False, "data": None, "message": "表元数据不存在"}
except Exception as e:
logger.error(f"获取表元数据失败: {str(e)}")
return {
"success": False,
"message": str(e)
}
return {
return HxfResponse({
"success": True,
"data": data
}
})
else:
return HxfResponse({
"success": False,
"data": None,
"message": "表元数据不存在"
})
except Exception as e:
logger.error(f"获取表元数据失败: {str(e)}")
return {
return HxfResponse({
"success": False,
"message": str(e)
}
})
except Exception as e:
logger.error(f"获取表元数据失败: {str(e)}")
return HxfResponse({
"success": False,
"message": str(e)
})
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
@ -179,14 +182,17 @@ async def update_qa_settings(
"""更新表的问答设置"""
try:
service = TableMetadataService(session)
success = service.update_table_qa_settings(
success = await service.update_table_qa_settings(
current_user.id,
metadata_id,
settings.dict()
settings.model_dump()
)
if success:
return {"success": True, "message": "设置更新成功"}
return HxfResponse({
"success": True,
"message": "设置更新成功"
})
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -221,9 +227,9 @@ async def save_table_metadata(
session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置"
return {
return HxfResponse({
"success": True,
"message": f"成功保存 {len(result['saved_tables'])} 个表的配置",
"saved_tables": result['saved_tables'],
"failed_tables": result.get('failed_tables', [])
}
})

View File

@ -9,6 +9,7 @@ from ...core.simple_permissions import require_super_admin
from ...services.auth import AuthService
from ...services.user import UserService
from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest
from utils.util_exceptions import HxfResponse
router = APIRouter()
@ -17,7 +18,8 @@ async def get_user_profile(
current_user = Depends(AuthService.get_current_user)
):
"""获取当前用户的个人信息."""
return UserResponse.model_validate(current_user)
response = UserResponse.model_validate(current_user)
return HxfResponse(response)
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息")
async def update_user_profile(
@ -39,7 +41,8 @@ async def update_user_profile(
# Update user
updated_user = await user_service.update_user(current_user.id, user_update)
return UserResponse.model_validate(updated_user)
response = UserResponse.model_validate(updated_user)
return HxfResponse(response)
@router.delete("/profile", summary="删除当前用户的账户")
async def delete_user_account(
@ -51,7 +54,8 @@ async def delete_user_account(
user_service = UserService(session)
await user_service.delete_user(current_user.id)
session.desc = f"删除用户 [{username}] 成功"
return {"message": f"删除用户 {username} 成功"}
response = {"message": f"删除用户 {username} 成功"}
return HxfResponse(response)
# Admin endpoints
@router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)")
@ -83,7 +87,8 @@ async def create_user(
# Create user
new_user = await user_service.create_user(user_create)
return UserResponse.model_validate(new_user)
response = UserResponse.model_validate(new_user)
return HxfResponse(response)
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
async def list_users(
@ -111,7 +116,7 @@ async def list_users(
"page": page,
"page_size": size
}
return result
return HxfResponse(result)
@router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)")
@ -128,7 +133,8 @@ async def get_user(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return UserResponse.model_validate(user)
response = UserResponse.model_validate(user)
return HxfResponse(response)
@router.put("/change-password", summary="修改当前用户的密码")
async def change_password(
@ -145,7 +151,8 @@ async def change_password(
current_password=request.current_password,
new_password=request.new_password
)
return {"message": "Password changed successfully"}
response = {"message": "Password changed successfully"}
return HxfResponse(response)
except Exception as e:
if "Current password is incorrect" in str(e):
raise HTTPException(
@ -178,7 +185,8 @@ async def reset_user_password(
user_id=user_id,
new_password=request.new_password
)
return {"message": "Password reset successfully"}
response = {"message": "Password reset successfully"}
return HxfResponse(response)
except Exception as e:
if "User not found" in str(e):
raise HTTPException(
@ -215,7 +223,8 @@ async def update_user(
)
updated_user = await user_service.update_user(user_id, user_update)
return UserResponse.model_validate(updated_user)
response = UserResponse.model_validate(updated_user)
return HxfResponse(response)
@router.delete("/{user_id}", summary="删除用户 (仅管理员权限)")
async def delete_user(
@ -234,4 +243,5 @@ async def delete_user(
)
await user_service.delete_user(user_id)
return {"message": "User deleted successfully"}
response = {"message": "User deleted successfully"}
return HxfResponse(response)

View File

@ -18,7 +18,7 @@ from ...services.workflow_engine import get_workflow_engine
from ...services.auth import AuthService
from ...models.user import User
from loguru import logger
from utils.util_exceptions import HxfResponse
router = APIRouter()
def convert_workflow_for_response(workflow_dict):
@ -31,37 +31,6 @@ def convert_workflow_for_response(workflow_dict):
conn['to'] = conn.pop('to_node')
return workflow_dict
@router.post("/", response_model=WorkflowResponse)
async def create_workflow(
workflow_data: WorkflowCreate,
session: Session = Depends(get_session),
current_user: User = Depends(AuthService.get_current_user)
):
"""创建工作流"""
from ...models.workflow import Workflow
# 创建工作流
workflow = Workflow(
name=workflow_data.name,
description=workflow_data.description,
definition=workflow_data.definition.model_dump(),
version="1.0.0",
status=workflow_data.status,
owner_id=current_user.id
)
workflow.set_audit_fields(current_user.id)
await session.add(workflow)
await session.commit()
await session.refresh(workflow)
# 转换definition中的字段映射
workflow_dict = convert_workflow_for_response(workflow.to_dict())
logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
return WorkflowResponse(**workflow_dict)
@router.get("/", response_model=WorkflowListResponse)
async def list_workflows(
skip: Optional[int] = Query(None, ge=0),
@ -73,6 +42,7 @@ async def list_workflows(
):
"""获取工作流列表"""
from ...models.workflow import Workflow
session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})"
# 构建查询
stmt = select(Workflow).where(Workflow.owner_id == current_user.id)
@ -90,17 +60,22 @@ async def list_workflows(
count_query = count_query.where(Workflow.status == workflow_status)
if search:
count_query = count_query.where(Workflow.name.ilike(f"%{search}%"))
total = session.scalar(count_query)
session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}"
total = await session.scalar(count_query)
session.desc = f"查询结果: 共 {total}"
# 如果没有传分页参数,返回所有数据
if skip is None and limit is None:
workflows = session.scalars(stmt).all()
return WorkflowListResponse(
workflows = (await session.scalars(stmt)).all()
session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)}"
response = WorkflowListResponse(
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
total=total,
page=1,
size=total
)
return HxfResponse(response)
# 使用默认分页参数
if skip is None:
@ -109,14 +84,16 @@ async def list_workflows(
limit = 10
# 分页查询
workflows = session.scalars(stmt.offset(skip).limit(limit)).all()
workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all()
session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)}"
return WorkflowListResponse(
response = WorkflowListResponse(
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
total=total,
page=skip // limit + 1, # 计算页码
size=limit
)
return HxfResponse(response)
@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
@ -126,8 +103,9 @@ async def get_workflow(
):
"""获取工作流详情"""
from ...models.workflow import Workflow
session.desc = f"START: 获取工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).where(
Workflow.id == workflow_id,
Workflow.owner_id == current_user.id
@ -135,12 +113,15 @@ async def get_workflow(
)
if not workflow:
session.desc = f"ERROR: 获取工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
)
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
session.desc = f"SUCCESS: 获取工作流数据 {workflow_id}"
response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
return HxfResponse(response)
@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
@ -151,8 +132,9 @@ async def update_workflow(
):
"""更新工作流"""
from ...models.workflow import Workflow
session.desc = f"START: 更新工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).where(
and_(
Workflow.id == workflow_id,
@ -162,13 +144,15 @@ async def update_workflow(
)
if not workflow:
session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
)
# 更新字段
update_data = workflow_data.dict(exclude_unset=True)
session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {workflow_data.model_dump(exclude_unset=True)}"
update_data = workflow_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "definition" and value:
# 如果value是Pydantic模型转换为字典如果已经是字典直接使用
@ -181,12 +165,11 @@ async def update_workflow(
workflow.set_audit_fields(current_user.id, is_update=True)
session.commit()
session.refresh(workflow)
logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
await session.commit()
await session.refresh(workflow)
session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}"
response = WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
return HxfResponse(response)
@router.delete("/{workflow_id}")
async def delete_workflow(
@ -196,8 +179,9 @@ async def delete_workflow(
):
"""删除工作流"""
from ...models.workflow import Workflow
session.desc = f"START: 删除工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).filter(
and_(
Workflow.id == workflow_id,
@ -207,16 +191,18 @@ async def delete_workflow(
)
if not workflow:
session.desc = f"ERROR: 删除工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
)
session.desc = f"删除工作流: {workflow.name}"
await session.delete(workflow)
await session.commit()
session.delete(workflow)
session.commit()
logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
return {"message": "工作流删除成功"}
session.desc = f"SUCCESS: 删除工作流数据 commit {workflow_id}"
response = {"message": "工作流删除成功"}
return HxfResponse(response)
@router.post("/{workflow_id}/activate")
@ -227,8 +213,9 @@ async def activate_workflow(
):
"""激活工作流"""
from ...models.workflow import Workflow
session.desc = f"START: 激活工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).filter(
and_(
Workflow.id == workflow_id,
@ -238,6 +225,7 @@ async def activate_workflow(
)
if not workflow:
session.desc = f"ERROR: 激活工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
@ -245,11 +233,11 @@ async def activate_workflow(
workflow.status = ModelWorkflowStatus.PUBLISHED
workflow.set_audit_fields(current_user.id, is_update=True)
await session.commit()
session.commit()
logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
return {"message": "工作流激活成功"}
session.desc = f"SUCCESS: 激活工作流数据 commit {workflow_id}"
response = {"message": "工作流激活成功"}
return HxfResponse(response)
@router.post("/{workflow_id}/deactivate")
async def deactivate_workflow(
@ -259,8 +247,9 @@ async def deactivate_workflow(
):
"""停用工作流"""
from ...models.workflow import Workflow
session.desc = f"START: 停用工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).filter(
and_(
Workflow.id == workflow_id,
@ -270,6 +259,7 @@ async def deactivate_workflow(
)
if not workflow:
session.desc = f"ERROR: 停用工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
@ -278,10 +268,11 @@ async def deactivate_workflow(
workflow.status = ModelWorkflowStatus.ARCHIVED
workflow.set_audit_fields(current_user.id, is_update=True)
session.commit()
await session.commit()
logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
return {"message": "工作流停用成功"}
session.desc = f"SUCCESS: 停用工作流数据 commit {workflow_id}"
response = {"message": "工作流停用成功"}
return HxfResponse(response)
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
async def execute_workflow(
@ -292,8 +283,9 @@ async def execute_workflow(
):
"""执行工作流"""
from ...models.workflow import Workflow
session.desc = f"START: 执行工作流 {workflow_id}"
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).filter(
and_(
Workflow.id == workflow_id,
@ -303,6 +295,7 @@ async def execute_workflow(
)
if not workflow:
session.desc = f"ERROR: 执行工作流数据 - 工作流不存在 {workflow_id}"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
@ -323,8 +316,8 @@ async def execute_workflow(
session=session
)
logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
return execution_result
session.desc = f"SUCCESS: 执行工作流数据 commit {workflow_id}"
return HxfResponse(execution_result)
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
@ -336,11 +329,12 @@ async def list_workflow_executions(
current_user: User = Depends(AuthService.get_current_user)
):
"""获取工作流执行历史"""
session.desc = f"START: 获取工作流执行历史 {workflow_id}"
try:
from ...models.workflow import Workflow, WorkflowExecution
# 验证工作流所有权
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).where(
and_(
Workflow.id == workflow_id,
@ -350,24 +344,25 @@ async def list_workflow_executions(
)
if not workflow:
session.desc = f"ERROR: 获取工作流执行历史数据 - 工作流不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作流不存在"
)
# 获取执行历史
executions = session.scalars(
executions = (await session.scalars(
select(WorkflowExecution).where(
WorkflowExecution.workflow_id == workflow_id
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
).all()
)).all()
return [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}"
response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
return HxfResponse(response)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}")
session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}"
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="获取执行历史失败"
@ -380,10 +375,11 @@ async def get_workflow_execution(
current_user: User = Depends(AuthService.get_current_user)
):
"""获取工作流执行详情"""
session.desc = f"START: 获取工作流执行详情 {execution_id}"
try:
from ...models.workflow import WorkflowExecution, Workflow
execution = session.scalar(
execution = await session.scalar(
select(WorkflowExecution).join(
Workflow, WorkflowExecution.workflow_id == Workflow.id
).where(
@ -393,17 +389,18 @@ async def get_workflow_execution(
)
if not execution:
session.desc = f"ERROR: 获取工作流执行详情数据 - 执行记录不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="执行记录不存在"
)
return WorkflowExecutionResponse.model_validate(execution)
response = WorkflowExecutionResponse.model_validate(execution)
session.desc = f"SUCCESS: 获取工作流执行详情数据 commit {execution_id}"
return HxfResponse(response)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting workflow execution {execution_id}: {str(e)}")
session.desc = f"ERROR: 获取工作流执行详情数据 commit {execution_id}"
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="获取执行详情失败"
@ -418,7 +415,7 @@ async def execute_workflow_stream(
current_user: User = Depends(AuthService.get_current_user)
):
"""流式执行工作流,实时推送节点执行状态"""
session.desc = f"START: 流式执行工作流 {workflow_id}"
async def generate_stream() -> AsyncGenerator[str, None]:
workflow_engine = None
@ -426,7 +423,7 @@ async def execute_workflow_stream(
from ...models.workflow import Workflow
# 验证工作流
workflow = session.scalar(
workflow = await session.scalar(
select(Workflow).filter(
and_(
Workflow.id == workflow_id,
@ -466,7 +463,7 @@ async def execute_workflow_stream(
logger.error(f"流式工作流执行异常: {e}", exc_info=True)
yield f"data: {json.dumps({'type': 'error', 'message': f'工作流执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"
return StreamingResponse(
response = StreamingResponse(
generate_stream(),
media_type="text/plain",
headers={
@ -478,3 +475,42 @@ async def execute_workflow_stream(
"Access-Control-Allow-Methods": "*"
}
)
session.desc = f"SUCCESS: 流式执行工作流 {workflow_id} 完毕"
return HxfResponse(response)
# -----------------------------------------------------------------------
@router.post("/", response_model=WorkflowResponse)
async def create_workflow(
    workflow_data: WorkflowCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a new workflow owned by the current user.

    Persists a ``Workflow`` built from the request payload, then returns the
    created record (with ``definition`` field names converted for the API
    response) wrapped in the project's ``HxfResponse`` envelope.

    Note: the repeated ``session.desc = ...`` assignments are a project-specific
    step-logging side channel on the custom session class — each assignment
    emits a log entry, so their placement/order is intentional.

    Args:
        workflow_data: validated creation payload (name, description,
            definition, status).
        session: DB session (injected); also carries the step-log side channel.
        current_user: authenticated user (injected); becomes the owner and
            audit user.

    Returns:
        HxfResponse wrapping a ``WorkflowResponse`` for the new workflow.
    """
    # Local import, presumably to avoid a circular import at module load time.
    from ...models.workflow import Workflow
    session.desc = f"START: 创建工作流 {workflow_data.name}"
    # Build the ORM instance; version is hard-coded for newly created workflows.
    workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        definition=workflow_data.definition.model_dump(),
        version="1.0.0",
        status=workflow_data.status,
        owner_id=current_user.id
    )
    session.desc = f"创建工作流实例 - Workflow() {workflow_data.name}"
    # Stamp created_by/updated_by audit columns with the current user.
    workflow.set_audit_fields(current_user.id)
    session.desc = f"保存工作流 - set_audit_fields {workflow_data.name}"
    session.add(workflow)
    await session.commit()
    # Refresh to load DB-generated fields (id, timestamps) before serializing.
    await session.refresh(workflow)
    session.desc = f"保存工作流 - commit & refresh {workflow_data.name}"
    # Convert internal definition field names (e.g. from_node/to_node) for the response.
    workflow_dict = convert_workflow_for_response(workflow.to_dict())
    session.desc = f"转换工作流数据 - convert_workflow_for_response {workflow_data.name}"
    response = WorkflowResponse(**workflow_dict)
    session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {workflow_data.name}"
    return HxfResponse(response)

View File

@ -69,13 +69,6 @@ router.include_router(
tags=["smart-chat"]
)
router.include_router(
workflow.router,
prefix="/workflows",

View File

@ -1,6 +1,7 @@
"""Configuration management for TH Agenter."""
import os
from requests import Session
import yaml
from pathlib import Path
from loguru import logger
@ -88,13 +89,15 @@ class LLMSettings(BaseSettings):
"extra": "ignore"
}
def get_current_config(self) -> dict:
async def get_current_config(self, session: Session) -> dict:
"""获取当前选择的提供商配置 - 优先从数据库读取默认配置."""
try:
# 尝试从数据库读取默认聊天模型配置
from th_agenter.services.llm_config_service import LLMConfigService
# 尝试从数据库读取默认聊天模型配置
llm_service = LLMConfigService()
db_config = llm_service.get_default_chat_config()
db_config = None
if session:
db_config = await llm_service.get_default_chat_config(session)
if db_config:
# 如果数据库中有默认配置,使用数据库配置
@ -105,10 +108,17 @@ class LLMSettings(BaseSettings):
"max_tokens": self.max_tokens,
"temperature": self.temperature
}
if session:
session.desc = f"使用LLM配置(get_default_chat_config)> {config}"
else:
logger.info(f"使用LLM配置(get_default_chat_config) > {config}")
return config
except Exception as e:
# 如果数据库读取失败,记录错误并回退到环境变量
logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")
if session:
session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
else:
logger.error(f"获取默认对话模型配置失败: {str(e)}")
# 回退到原有的环境变量配置
provider_configs = {
@ -182,13 +192,15 @@ class EmbeddingSettings(BaseSettings):
"extra": "ignore"
}
def get_current_config(self) -> dict:
async def get_current_config(self, session: Session) -> dict:
"""获取当前选择的embedding提供商配置 - 优先从数据库读取默认配置."""
try:
if session:
session.desc = "尝试从数据库读取默认嵌入模型配置 ... >>> get_current_config";
# 尝试从数据库读取默认嵌入模型配置
from th_agenter.services.llm_config_service import LLMConfigService
llm_service = LLMConfigService()
db_config = llm_service.get_default_embedding_config()
db_config = await llm_service.get_default_embedding_config(session)
if db_config:
# 如果数据库中有默认配置,使用数据库配置
@ -200,7 +212,10 @@ class EmbeddingSettings(BaseSettings):
return config
except Exception as e:
# 如果数据库读取失败,记录错误并回退到环境变量
logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")
if session:
session.error(f"Failed to read embedding config from database, falling back to env vars: {e}")
else:
logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}")
# 回退到原有的环境变量配置
provider_configs = {

View File

@ -9,7 +9,7 @@ from ..models.user import User
from loguru import logger
# Context variable to store current user
current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None)
current_user_context: ContextVar[Optional[dict]] = ContextVar('current_user', default=None)
# Thread-local storage as backup
_thread_local = threading.local()
@ -19,34 +19,56 @@ class UserContext:
"""User context manager for accessing current user globally."""
@staticmethod
def set_current_user(user: User) -> None:
def set_current_user(user: User, canLog: bool = False) -> None:
"""Set current user in context."""
logger.info(f"[UserContext] - Setting user in context: {user.username} (ID: {user.id})")
if canLog:
logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")
# Store user information as a dictionary instead of the SQLAlchemy model
user_dict = {
'id': user.id,
'username': user.username,
'email': user.email,
'full_name': user.full_name,
'is_active': user.is_active
}
# Set in ContextVar
current_user_context.set(user)
current_user_context.set(user_dict)
# Also set in thread-local as backup
_thread_local.current_user = user
_thread_local.current_user = user_dict
# Verify it was set
verify_user = current_user_context.get()
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
if canLog:
logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")
@staticmethod
def set_current_user_with_token(user: User):
def set_current_user_with_token(user: User, canLog: bool = False):
"""Set current user in context and return token for cleanup."""
logger.info(f"[UserContext] - Setting user in context with token: {user.username} (ID: {user.id})")
if canLog:
logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")
# Store user information as a dictionary instead of the SQLAlchemy model
user_dict = {
'id': user.id,
'username': user.username,
'email': user.email,
'full_name': user.full_name,
'is_active': user.is_active
}
# Set in ContextVar and get token
token = current_user_context.set(user)
token = current_user_context.set(user_dict)
# Also set in thread-local as backup
_thread_local.current_user = user
_thread_local.current_user = user_dict
# Verify it was set
verify_user = current_user_context.get()
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
if canLog:
logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")
return token
@ -63,34 +85,37 @@ class UserContext:
delattr(_thread_local, 'current_user')
@staticmethod
def get_current_user() -> Optional[User]:
def get_current_user() -> Optional[dict]:
"""Get current user from context."""
logger.debug("[UserContext] - Attempting to get user from context")
# Try ContextVar first
user = current_user_context.get()
if user:
logger.debug(f"[UserContext] - Got user from ContextVar: {user.username} (ID: {user.id})")
# logger.info(f"[UserContext] - 取得当前用户为 ContextVar 用户: {user.get('username') if user else None}")
return user
# Fallback to thread-local
user = getattr(_thread_local, 'current_user', None)
if user:
logger.debug(f"[UserContext] - Got user from thread-local: {user.username} (ID: {user.id})")
# logger.info(f"[UserContext] - 取得当前用户为线程本地用户: {user.get('username') if user else None}")
return user
logger.debug("[UserContext] - No user found in context (neither ContextVar nor thread-local)")
logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
return None
@staticmethod
def get_current_user_id() -> Optional[int]:
"""Get current user ID from context."""
try:
user = UserContext.get_current_user()
return user.id if user else None
return user.get('id') if user else None
except Exception as e:
logger.error(f"[UserContext] - Error getting current user ID: {e}")
return None
@staticmethod
def clear_current_user() -> None:
def clear_current_user(canLog: bool = False) -> None:
"""Clear current user from context."""
if canLog:
logger.info("[UserContext] - 清除当前用户上下文")
current_user_context.set(None)
@ -98,7 +123,7 @@ class UserContext:
delattr(_thread_local, 'current_user')
@staticmethod
def require_current_user() -> User:
def require_current_user() -> dict:
"""Get current user from context, raise exception if not found."""
# Use the same logic as get_current_user to check both ContextVar and thread-local
user = UserContext.get_current_user()
@ -114,4 +139,4 @@ class UserContext:
def require_current_user_id() -> int:
"""Get current user ID from context, raise exception if not found."""
user = UserContext.require_current_user()
return user.id
return user.get('id')

View File

@ -20,7 +20,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
def __init__(self, app, exclude_paths: list = None):
super().__init__(app)
self.canLog = True
self.canLog = False
# Paths that don't require authentication
self.exclude_paths = exclude_paths or [
"/docs",
@ -77,7 +77,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")
# Always clear any existing user context to ensure fresh authentication
UserContext.clear_current_user()
UserContext.clear_current_user(self.canLog)
# Initialize context token
user_token = None
@ -96,6 +96,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
# Extract token
token = authorization.split(" ")[1]
# Verify token
payload = AuthService.verify_token(token)
if payload is None:
@ -136,7 +137,7 @@ class UserContextMiddleware(BaseHTTPMiddleware):
)
# Set user in context using token mechanism
user_token = UserContext.set_current_user_with_token(user)
user_token = UserContext.set_current_user_with_token(user, self.canLog)
if self.canLog:
logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")
@ -160,8 +161,13 @@ class UserContextMiddleware(BaseHTTPMiddleware):
try:
response = await call_next(request)
return response
except Exception as e:
# Log error but don't fail the request
logger.error(f"[MIDDLEWARE] - 请求处理出错: {e}")
# Return 500 error
return HxfErrorResponse(e)
finally:
# Always clear user context after request processing
UserContext.clear_current_user()
UserContext.clear_current_user(self.canLog)
if self.canLog:
logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")

View File

@ -0,0 +1,73 @@
"""LLM工厂类用于创建和管理LLM实例"""
from typing import Optional
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent
from loguru import logger
from requests import Session
from .config import get_settings
# Maps a model-name prefix to the settings attribute prefix holding that
# provider's model / api_key / base_url values (e.g. "glm-4" -> settings.llm.zhipu_*).
_PROVIDER_BY_MODEL_PREFIX = {
    'deepseek': 'deepseek',
    'doubao': 'doubao',
    'glm': 'zhipu',
    'moonshot': 'moonshot',
}


async def new_llm(session: Session = None, model: Optional[str] = None,
                  temperature: Optional[float] = None,
                  streaming: bool = False) -> ChatOpenAI:
    """Create a ChatOpenAI LLM instance.

    Resolves the base configuration via ``settings.llm.get_current_config``
    (which prefers a DB-stored default when ``session`` is provided, and falls
    back to environment-variable config otherwise). When ``model`` is given
    and its name starts with a known provider prefix, the provider-specific
    model/api_key/base_url from settings override the base config; unknown
    prefixes silently keep the default config (preserving prior behavior).

    Args:
        session: optional DB session used to look up the default LLM config.
            NOTE(review): annotated as ``requests.Session`` at the top of this
            file, but it is passed to the DB config lookup — the annotation
            looks wrong; confirm it should be an async DB session type.
        model: optional model name selecting a provider by prefix
            (deepseek/doubao/glm/moonshot); defaults to the configured model.
        temperature: optional sampling temperature override; falls back to
            the resolved config's temperature when ``None``.
        streaming: whether to enable streaming responses (default ``False``).

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    settings = get_settings()
    llm_config = await settings.llm.get_current_config(session)
    if model:
        # Override the resolved config with the matched provider's settings.
        for prefix, provider in _PROVIDER_BY_MODEL_PREFIX.items():
            if model.startswith(prefix):
                llm_config['model'] = getattr(settings.llm, f'{provider}_model')
                llm_config['api_key'] = getattr(settings.llm, f'{provider}_api_key')
                llm_config['base_url'] = getattr(settings.llm, f'{provider}_base_url')
                break
    llm = ChatOpenAI(
        model=llm_config['model'],
        api_key=llm_config['api_key'],
        base_url=llm_config['base_url'],
        temperature=temperature if temperature is not None else llm_config['temperature'],
        max_tokens=llm_config['max_tokens'],
        streaming=streaming
    )
    return llm
async def new_agent(session: Session = None, model: Optional[str] = None,
                    temperature: Optional[float] = None,
                    streaming: bool = False) -> ChatOpenAI:
    """Create an agent backed by a freshly built LLM instance.

    Builds a ``ChatOpenAI`` via ``new_llm`` (config resolved through the
    optional DB session) and wraps it with LangChain's ``create_agent``.

    Args:
        session: optional DB session forwarded to ``new_llm`` for config lookup.
        model: optional model name forwarded to ``new_llm`` (provider selected
            by name prefix).
        temperature: optional sampling temperature override.
        streaming: whether the underlying LLM streams responses (default False).

    Returns:
        The object returned by ``create_agent``.
        NOTE(review): the ``-> ChatOpenAI`` annotation appears copy-pasted from
        ``new_llm`` — ``create_agent`` does not return a ``ChatOpenAI``;
        confirm the actual return type and fix the annotation.
    """
    llm = await new_llm(session, model, temperature, streaming)
    result = create_agent(
        model=llm,
    )
    return result

View File

@ -3,6 +3,7 @@
from functools import wraps
from typing import Optional
from fastapi import HTTPException, Depends
from loguru import logger
from sqlalchemy.orm import Session
from ..db.database import get_session
@ -11,45 +12,40 @@ from ..models.permission import Role
from ..services.auth import AuthService
def is_super_admin(user: User, db: Session) -> bool:
async def is_super_admin(user: User, session: Session) -> bool:
"""检查用户是否为超级管理员."""
session.desc = f"检查用户 {user.id} 是否为超级管理员"
if not user or not user.is_active:
session.desc = f"用户 {user.id} 不是活跃状态"
return False
# 检查用户是否有超级管理员角色
try:
# 尝试访问已加载的角色
for role in user.roles:
if role.code == "SUPER_ADMIN":
return True
return False
except Exception:
# 如果角色未加载或访问失败,直接从数据库查询
from sqlalchemy import select, and_
from ..models.permission import Role, UserRole
# 直接使用提供的session查询避免MissingGreenlet错误
from sqlalchemy import select
from ..models.permission import UserRole, Role
try:
# 直接查询用户角色
stmt = select(Role).join(UserRole).filter(
and_(
stmt = select(UserRole).join(Role).filter(
UserRole.user_id == user.id,
Role.code == "SUPER_ADMIN",
Role.code == 'SUPER_ADMIN',
Role.is_active == True
)
)
super_admin_role = db.execute(stmt).scalar_one_or_none()
return super_admin_role is not None
except Exception:
# 如果查询失败返回False
user_role = await session.execute(stmt)
result = user_role.scalar_one_or_none() is not None
session.desc = f"用户 {user.id} 超级管理员角色查询结果: {result}"
return result
except Exception as e:
# 如果调用失败记录错误并返回False
session.desc = f"EXCEPTION: 用户 {user.id} 超级管理员角色查询失败: {str(e)}"
logger.error(f"检查用户 {user.id} 超级管理员角色失败: {str(e)}")
return False
def require_super_admin(
async def require_super_admin(
current_user: User = Depends(AuthService.get_current_user),
session: Session = Depends(get_session)
) -> User:
"""要求超级管理员权限的依赖项."""
if not is_super_admin(current_user, session):
if not await is_super_admin(current_user, session):
raise HTTPException(
status_code=403,
detail="需要超级管理员权限"
@ -75,17 +71,17 @@ class SimplePermissionChecker:
def __init__(self, db: Session):
self.db = db
def check_super_admin(self, user: User) -> bool:
async def check_super_admin(self, user: User) -> bool:
"""检查是否为超级管理员."""
return is_super_admin(user, self.db)
return await is_super_admin(user, self.db)
def check_user_access(self, user: User, target_user_id: int) -> bool:
async def check_user_access(self, user: User, target_user_id: int) -> bool:
"""检查用户访问权限(自己或超级管理员)."""
if not user or not user.is_active:
return False
# 超级管理员可以访问所有用户
if self.check_super_admin(user):
if await self.check_super_admin(user):
return True
# 用户只能访问自己的信息

View File

@ -38,7 +38,7 @@ class BaseModel(Base):
def to_dict(self):
"""Convert model to dictionary."""
return {
column.name: getattr(self, column.name)
column.name: getattr(self, column.name).isoformat() if hasattr(getattr(self, column.name), 'isoformat') else getattr(self, column.name)
for column in self.__table__.columns
}
@ -60,31 +60,31 @@ class BaseModel(Base):
return cls(**filtered_data)
def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
"""Set audit fields for create/update operations.
"""对创建/更新操作设置created_by/updated_by字段。
Args:
user_id: ID of the user performing the operation (optional, will use context if not provided)
is_update: True for update operations, False for create operations
user_id: 用户ID用于设置创建/更新操作的审计字段可选默认从上下文获取
is_update: True 表示更新操作False 表示创建操作
"""
# Get user_id from context if not provided
# 如果未提供user_id则从上下文获取
if user_id is None:
from ..core.context import UserContext
try:
user_id = UserContext.get_current_user_id()
except Exception:
# If no user in context, skip setting audit fields
# 如果上下文没有用户ID则跳过设置审计字段
return
# Skip if still no user_id
# 如果仍未提供user_id则跳过设置审计字段
if user_id is None:
return
if not is_update:
# For create operations, set both create_by and update_by
# 对于创建操作同时设置created_by和updated_by
self.created_by = user_id
self.updated_by = user_id
else:
# For update operations, only set update_by
# 对于更新操作仅设置updated_by
self.updated_by = user_id
# @event.listens_for(Session, 'before_flush')

View File

@ -18,12 +18,27 @@ class DrSession(AsyncSession):
def __init__(self, **kwargs):
"""Initialize DrSession with unique ID."""
super().__init__(**kwargs)
self.title = ""
self.descs = []
# 确保info属性存在
if not hasattr(self, 'info'):
self.info = {}
self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
self.stepIndex = 0
    @property
    def title(self) -> Optional[str]:
        """Return the session title stored in ``info`` (``None`` if never set)."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set or accumulate the session title in ``info``.

        The first (or blank) assignment stores ``value`` as-is; subsequent
        assignments PREPEND the new value, joined with " >>> ", so the most
        recent title appears first in the accumulated string.
        """
        if('title' not in self.info or self.info['title'].strip() == ""):
            self.info['title'] = value  # first/blank assignment: store directly
        else:
            self.info['title'] = value + " >>> " + self.info['title']
@property
def desc(self) -> Optional[str]:
"""Get work brief from session info."""
@ -34,6 +49,25 @@ class DrSession(AsyncSession):
"""Set work brief in session info."""
self.stepIndex += 1
match = re.search(r";(-\d+)", value);
level = -3
if match:
level = int(match.group(1))
value = value.replace(f";{level}", "")
level = -3 + level
if "警告" in value or value.startswith("WARNING"):
self.log_warning(f"{self.stepIndex} 步 - {value}", level = level)
elif "异常" in value or value.startswith("EXCEPTION"):
self.log_exception(f"{self.stepIndex} 步 - {value}", level = level)
elif "成功" in value or value.startswith("SUCCESS"):
self.log_success(f"{self.stepIndex} 步 - {value}", level = level)
elif "开始" in value or value.startswith("START"):
self.log_success(f"{self.stepIndex} 步 - {value}", level = level)
elif "失败" in value or value.startswith("ERROR"):
self.log_error(f"{self.stepIndex} 步 - {value}", level = level)
else:
self.log_info(f"{self.stepIndex} 步 - {value}", level = level)
def log_prefix(self) -> str:
"""Get log prefix with session ID and desc."""
@ -74,14 +108,14 @@ class DrSession(AsyncSession):
engine_async = create_async_engine(
get_settings().database.url,
echo=True, # get_settings().database.echo,
echo=False, # get_settings().database.echo,
future=True,
pool_size=get_settings().database.pool_size,
max_overflow=get_settings().database.max_overflow,
pool_pre_ping=True,
pool_recycle=3600,
)
from fastapi import Request
from fastapi import HTTPException, Request
AsyncSessionFactory = sessionmaker(
bind=engine_async,
@ -95,10 +129,15 @@ async def get_session(request: Request = None):
if request:
url = f"{request.method} {request.url.path}"# .split("://")[-1]
# session = AsyncSessionFactory()
print(url)
# 取得request的来源IP
if request:
client_host = request.client.host
else:
client_host = "无request"
session = DrSession(bind=engine_async)
session.desc = f"SUCCESS: 创建数据库 session >>> {url}"
session.title = f"{url} - {client_host}"
# 设置request属性
if request:
@ -111,8 +150,9 @@ async def get_session(request: Request = None):
errMsg = f"数据库 session 异常 >>> {e}"
session.desc = f"EXCEPTION: {errMsg}"
await session.rollback()
raise e
# DatabaseError(e)
# 重新抛出原始异常,不转换为 HTTPException
raise e # HTTPException(status_code=e.status_code, detail=errMsg) # main.py中将捕获本异常
finally:
session.desc = f"数据库 session 关闭"
# session.desc = f"数据库 session 关闭"
session.desc = ""
await session.close()

View File

@ -11,8 +11,8 @@ sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from th_agenter.core.config import settings
from th_agenter.db.database import Base, get_session
from th_agenter.core.config import get_settings
from th_agenter.db.database import Base
from th_agenter.models import * # Import all models to ensure they're registered
from th_agenter.utils.logger import get_logger
from th_agenter.models.resource import Resource
@ -25,8 +25,25 @@ def migrate_hardcoded_resources():
"""Migrate hardcoded resources from init_resource_data.py to database."""
db = None
try:
# Get database settings
settings = get_settings()
# Create synchronous engine (remove asyncpg from URL)
sync_db_url = settings.database.url.replace('postgresql+asyncpg://', 'postgresql://')
sync_engine = create_engine(
sync_db_url,
echo=False,
pool_size=settings.database.pool_size,
max_overflow=settings.database.max_overflow,
pool_pre_ping=True,
pool_recycle=3600,
)
# Create synchronous session factory
SyncSessionFactory = sessionmaker(bind=sync_engine)
# Get database session
db = get_session() # xxxx
db = SyncSessionFactory()
if db is None:
logger.error("Failed to create database session")

View File

@ -17,7 +17,7 @@ branch_labels = None
depends_on = None
def upgrade():
async def upgrade():
"""删除权限相关表."""
# 获取数据库连接
@ -44,7 +44,7 @@ def upgrade():
WHERE table_name = '{table_name}'
);
"""))
table_exists = result.scalar()
table_exists = await result.scalar()
if table_exists:
print(f"删除表: {table_name}")

View File

@ -0,0 +1,198 @@
from loguru import logger
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
# 核心:导入 LangChain 的基础语言模型抽象类
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.callbacks import CallbackManagerForLLMRun
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from datetime import datetime
@dataclass
class LLMConfig_DataClass:
    """Unified LLM configuration covering online / local / embedding models.

    Mirrors the full set of database columns.  The model type is determined
    by ``provider`` + ``is_embedding``:
      - online models:    provider in ['openai', 'zhipu', 'baidu'] and is_embedding=False
      - local models:     provider in ['llama', 'qwen', 'yi'] and is_embedding=False
      - embedding models: provider in ['bge', 'text2vec'] and is_embedding=True
    """
    # ====== core database fields (required first, then optional) ======
    name: str                               # user-defined display name (e.g. "gpt-5")
    model_name: str                         # official model identifier (e.g. "BAAI/bge-small-zh-v1.5")
    provider: str                           # provider key: openai / llama / bge / zhipu ...
    id: Optional[int] = None                # database primary key
    description: Optional[str] = None       # free-text model description
    is_active: bool = True                  # whether the config is enabled
    is_default: bool = False                # whether this is the default model
    is_embedding: bool = False              # True -> embedding model (key discriminator)
    # ====== generation parameters shared by all inference models ======
    temperature: float = 0.7                # sampling temperature (default matches DB sample)
    max_tokens: int = 3000                  # max generation length (default matches DB sample)
    top_p: float = 0.6                      # nucleus-sampling top-p
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    # ====== online-model-only parameters (ignored for other types) ======
    api_key: Optional[str] = None           # required for online providers
    base_url: Optional[str] = None          # API proxy / gateway URL, e.g. https://api.openai-proxy.org/v1
    # timeout: int = 30                     # request timeout (seconds) — currently unused
    max_retries: int = 3                    # max retry attempts
    api_version: Optional[str] = None       # e.g. OpenAI "2024-02-15-preview"
    # ====== local-model-only parameters (ignored for other types) ======
    model_path: Optional[str] = None        # required for local providers
    device: str = "cpu"                     # cpu / cuda / mps
    n_ctx: int = 2048                       # context-window size
    n_threads: int = 8                      # inference thread count
    quantization: str = "q4_0"              # q4_0 / q8_0 / f16
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    prompt_template: Optional[str] = None   # custom prompt template
    # ====== embedding-model-only parameters ======
    normalize_embeddings: bool = True       # L2-normalise output vectors
    batch_size: int = 32                    # batch size for bulk encoding
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)  # extra encode options
    dimension: Optional[int] = None         # vector dimension (e.g. 768)
    # ====== metadata maintained by the database ======
    extra_config: Dict[str, Any] = field(default_factory=dict)   # free-form extension config
    usage_count: int = 0                    # times this config has been used
    last_used_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None        # creator user id
    updated_by: Optional[int] = None        # last-editor user id
    api_key_masked: Optional[str] = ""      # masked API key as stored in the DB

    # ====== core helper methods ======
    def __post_init__(self):
        """Normalise and validate the configuration right after dataclass init."""
        # Embedding models never generate text: zero out inference params so
        # they cannot be used by mistake downstream.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0
        # Validate required fields according to the model type.
        self._validate_required_fields()

    def _validate_required_fields(self):
        """Validate type-specific required fields; raises ValueError on failure."""
        # Online models need an API key.
        if not self.is_embedding and self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            if not self.api_key:
                raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
        # Local models need a model file path.
        if not self.is_embedding and self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            if not self.model_path:
                raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict (for inserting/updating the database row)."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')  # exclude private attributes
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config instance from a database row dict.

        The input dict is NOT mutated (previous version rewrote time fields
        in place, surprising callers that reuse the row dict).
        Unknown columns are silently dropped; unparsable timestamps become None.
        """
        # Work on a shallow copy so the caller's dict stays untouched.
        data = dict(db_dict)
        # 1. Convert ISO-8601 strings to datetime for the time fields.
        time_fields = ['last_used_at', 'created_at', 'updated_at']
        for field_name in time_fields:
            val = data.get(field_name)
            if val and isinstance(val, str):
                try:
                    data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    data[field_name] = None
        # 2. Drop columns that are not dataclass fields (e.g. legacy columns).
        valid_fields = cls.__dataclass_fields__.keys()
        filtered_dict = {k: v for k, v in data.items() if k in valid_fields}
        # 3. Build and return the instance (runs __post_init__ validation).
        return cls(**filtered_dict)

    def get_model_type(self) -> str:
        """Classify this config: returns 'embedding', 'online', 'local' or 'unknown'."""
        if self.is_embedding:
            return "embedding"
        if self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            return "online"
        if self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            return "local"
        return "unknown"
class BaseLLM(BaseChatModel):
    """
    Base class for all chat models in this package.

    Inherits LangChain's BaseChatModel (a BaseLanguageModel subclass) so that
    instances can be passed directly to create_agent.  Subclasses are expected
    to override _generate/_agenerate and load_model, and to provide a
    _validate_config method (called from __init__; not defined here — a
    subclass missing it will raise AttributeError on construction).
    """
    # Declared as model fields so pydantic (BaseChatModel) permits assignment
    # in __init__; populated there rather than via keyword construction.
    config: Any = None   # the LLM configuration object for this instance
    model: Any = None    # the lazily-created backing model (None until load_model)

    def __init__(self, config):
        super().__init__()  # must run pydantic/BaseChatModel initialisation first
        self.config = config
        self.model = None
        # NOTE(review): relies on the subclass defining _validate_config, and
        # runs before any subclass __init__ code after super().__init__().
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    # ---------- core abstract methods required by the LangChain protocol ----------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Synchronous generation hook (LangChain requires an implementation).

        messages: message list, e.g. [HumanMessage(content="hi")].
        Must return a ChatResult; this base version only logs an error and
        implicitly returns None, so subclasses MUST override it.
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        ** kwargs: Any,
    ) -> ChatResult:
        """Async generation hook (LangChain async protocol); stub — override in subclasses."""
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")

    @property
    def _llm_type(self) -> str:
        """Model-type identifier reported to LangChain (e.g. "openai", "llama", "bge")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the backing model; stub — subclasses supply the actual loading logic."""
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")

    def close(self) -> None:
        """Release model resources by dropping the reference (idempotent)."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    def __enter__(self):
        # Context-manager support: entering loads the model eagerly.
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the model on exit, even after an exception.
        self.close()

View File

@ -0,0 +1,77 @@
from typing import List
from langchain_core.embeddings import Embeddings
from loguru import logger
from th_agenter.llm.base_llm import BaseLLM
class EmbedLLM(BaseLLM, Embeddings):
    """Embedding-model base: mixes BaseLLM with LangChain's Embeddings ABC
    (rather than BaseLanguageModel).  All four embed methods below are stubs
    that return None — concrete subclasses must override them."""

    def __init__(self, config):
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed documents (core LangChain Embeddings method); stub."""
        pass

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding; stub."""
        pass

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; stub."""
        pass

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding; stub."""
        pass
# Concrete BGE embedding model implementation.
class BGEEmbedLLM(EmbedLLM):
    """Embedding model backed by Ollama or HuggingFace sentence-transformers,
    selected by ``config.provider``.  The backing model is loaded lazily on
    first use."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        # A model identifier is the only hard requirement.
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the embedding backend chosen by the provider field."""
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        cfg = self.config
        if getattr(cfg, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=cfg.model_name,
                base_url=getattr(cfg, 'base_url', None),
            )
            return
        try:
            # HuggingFaceEmbeddings needs sentence-transformers at runtime; the
            # constructor itself can raise ImportError, so it sits in the try.
            from langchain_huggingface import HuggingFaceEmbeddings
            self.model = HuggingFaceEmbeddings(
                model_name=cfg.model_name,
                model_kwargs={"device": getattr(cfg, 'device', "cpu")},
                encode_kwargs={"normalize_embeddings": getattr(cfg, 'normalize_embeddings', True)},
            )
        except ImportError as e:
            logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
            logger.error("Please install sentence-transformers: pip install sentence-transformers")
            raise

    def _ensure_loaded(self):
        # Lazy initialisation shared by every embed entry point.
        if not self.model:
            self.load_model()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        self._ensure_loaded()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        self._ensure_loaded()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        self._ensure_loaded()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        self._ensure_loaded()
        return await self.model.aembed_query(text)

View File

@ -0,0 +1,70 @@
import os, dotenv
from loguru import logger
from utils.Constant import Constant
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
# 加载环境变量
dotenv.load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["OPENAI_BASE_URL"] = os.getenv("OPENAI_BASE_URL")
class LLM_Model_Base(object):
    '''
    Base class for all language-model wrappers.

    Attributes:
        model_name:  model identifier, defaults to "gpt-4o-mini"
        temperature: sampling temperature, defaults to 0.7
        llmModel:    concrete model instance, created by subclasses
        mode:        model-mode constant, set by subclasses
        name:        human-readable model name shown in the UI

    author: DrGraph
    date: 2025-11-20
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        self.model_name = model_name  # 0.15 0.6
        self.temperature = temperature
        self.llmModel = None
        self.mode = Constant.LLM_MODE_NONE
        self.name = '未知模型'

    def buildPromptTemplateValue(self, prompt: str, methodType: str, valueType: str):
        """Render *prompt* through a fixed PromptTemplate and package the result.

        methodType: "format" (template.format -> str) or "invoke" (template.invoke -> PromptValue)
        valueType:  "str" / "messages" / "promptValue" — the output packaging

        Returns a str, a PromptValue, or a list of HumanMessage; returns None
        when the methodType/valueType combination is not recognised.
        """
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        prompt_template = PromptTemplate.from_template(
            template="请回答以下问题: {question}",
        )
        prompt_template_value = None
        if methodType == "format":
            # Path 1 — format() renders the template to a plain string.
            prompt_str = prompt_template.format(question=prompt)  # prompt is a str
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 format 方法取得字符串prompt_str, 然后再处理 - {type(prompt_str)} - {prompt_str}")
            if valueType == "str":
                # 1.1 Use the rendered string directly for llm.invoke.
                prompt_template_value = prompt_str
                logger.info(f"{self.name} >>> 1.2.1 直接使用字符串")
            elif valueType == "messages":
                # 1.2 Wrap the rendered string in a HumanMessage list.
                # Bug fix: previously wrapped the raw user input (`prompt`)
                # instead of the rendered template (`prompt_str`), which was
                # inconsistent with the "invoke" branch below.
                prompt_template_value = [HumanMessage(content=prompt_str)]
                logger.info(f"{self.name} >>> 1.2.2 创建HumanMessage对象列表")
        elif methodType == "invoke":
            # Path 2 — invoke() yields a PromptValue.
            prompt_value = prompt_template.invoke(input={"question" : prompt})  # StringPromptValue
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 invoke 方法取得PromptValue, 然后再处理 - {type(prompt_value)} - {prompt_value}")
            if valueType == "str":
                # 2.1 Convert the PromptValue back to a string.
                prompt_template_value = prompt_value.to_string()
                logger.info(f"{self.name} >>> 1.2.1 由 PromptValue 转换为字符串")
            elif valueType == "promptValue":
                # 2.2 Pass the PromptValue through unchanged.
                prompt_template_value = prompt_value
                logger.info(f"{self.name} >>> 1.2.2 直接使用 PromptValue 作为 prompt_template_value")
            elif valueType == "messages":
                # 2.3 Convert the PromptValue into a HumanMessage list.
                prompt_template_value = prompt_value.to_messages()
                logger.info(f"{self.name} >>> 1.2.3 使用 PromptValue.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表")
        logger.info(f"{self.name} >>> 1.3 用户输入 最终包装为(PromptValue/str/list of BaseMessages): {type(prompt_template_value)}\n{prompt_template_value}")
        return prompt_template_value

View File

@ -0,0 +1,29 @@
from langchain_openai import ChatOpenAI
from loguru import logger
from DrGraph.utils.Constant import Constant
from LLM.llm_model_base import LLM_Model_Base
class Chat_LLM(LLM_Model_Base):
    """Chat-model wrapper around ChatOpenAI; invoke() returns the model's reply
    message so it can be displayed in the chatbot UI."""

    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '聊天模型'
        self.mode = Constant.LLM_MODE_CHAT
        self.llmModel = ChatOpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns the message object so it can be rendered in the chatbot.
    def invoke(self, prompt: str):
        """Invoke the chat model with *prompt* and return the response message.

        Raises: re-raises any exception from the underlying model after logging.
        """
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
            # response = {"role": "assistant", "content": response.content}
        except Exception as e:
            # Bug fix: on failure `response` was unbound, so the trailing
            # `return response` raised UnboundLocalError and masked the real
            # error — log and re-raise the original exception instead.
            logger.error(e)
            raise
        return response

View File

@ -0,0 +1,44 @@
'''
非聊天模型类,继承自 LLM_Model_Base
author: DrGraph
date: 2025-11-20
'''
from loguru import logger
from langchain_openai import OpenAI
from langchain_core.messages import AIMessage
from DrGraph.utils.Constant import Constant
from LLM.llm_model_base import LLM_Model_Base
class NonChat_LLM(LLM_Model_Base):
    '''
    Non-chat (completion) model wrapper, built on OpenAI (the completion API):
    - model_name defaults to "gpt-4o-mini"
    - temperature defaults to 0.7
    - name = "非聊天模型" as shown in the UI
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '非聊天模型'
        self.mode = Constant.LLM_MODE_NONCHAT
        self.llmModel = OpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns the reply string so it can be rendered in the chatbot.
    def invoke(self, prompt: str):
        '''
        Invoke the completion model.

        prompt: user input (str)
        return: assistant reply (str)
        Raises: re-raises any exception from the underlying model after logging.
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        try:
            response = self.llmModel.invoke(prompt)
            logger.info(f"{self.name} >>> 1.2 助手回复: {type(response)}")
        except Exception as e:
            # Bug fix: on failure `response` was unbound, so `return response`
            # raised UnboundLocalError and hid the real cause — re-raise.
            logger.error(e)
            raise
        return response

View File

@ -0,0 +1,28 @@
from loguru import logger
from DrGraph.utils.Constant import Constant
from LLM.llm_model_base import LLM_Model_Base
from langchain_ollama import ChatOllama
class Chat_Ollama(LLM_Model_Base):
    """Chat wrapper for a locally-hosted Ollama server (ChatOllama backend)."""

    def __init__(self, base_url="http://127.0.0.1:11434", model_name: str = "OxW/Qwen3-0.6B-GGUF:latest", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '私有化Ollama模型'
        self.base_url = base_url
        self.llmModel = ChatOllama(
            base_url = self.base_url,
            model=model_name,
            temperature=temperature
        )
        self.mode = Constant.LLM_MODE_LOCAL_OLLAMA

    def invoke(self, prompt: str):
        """Invoke the Ollama chat model and return the response message.

        Raises: re-raises any exception from the underlying model after logging.
        """
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
        except Exception as e:
            # Bug fix: on failure `response` was unbound, so `return response`
            # raised UnboundLocalError and hid the real cause — re-raise.
            logger.error(e)
            raise
        return response

View File

@ -0,0 +1,68 @@
from typing import List, Optional
from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration
class LocalLLM(BaseLLM):
    """Local llama.cpp-backed chat model adapted to the LangChain BaseChatModel
    protocol.  The LlamaCpp instance is created lazily on first generation."""

    def __init__(self, config):
        super().__init__(config)
        # Kept for backward compatibility; BaseLLM already stores the same
        # object as self.config.
        self.local_config = config

    def _validate_config(self):
        # Bug fix: BaseLLM.__init__ calls _validate_config() BEFORE
        # `self.local_config` is assigned above, so reading local_config here
        # raised AttributeError on every construction — validate via
        # self.config, which BaseLLM sets first.
        if not self.config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the LlamaCpp model from the stored configuration."""
        from langchain_community.llms import LlamaCpp
        cfg = self.config
        self.model = LlamaCpp(
            model_path=cfg.model_path,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            n_ctx=cfg.n_ctx,
            n_threads=cfg.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,  # bug fix: `Any` is not imported in this module, so the
                   # `**kwargs: Any` annotation raised NameError at import time
    ) -> ChatResult:
        """Synchronous generation: format messages into a prompt, run LlamaCpp."""
        if not self.model:
            self.load_model()
        # LlamaCpp is a completion (non-chat) model: flatten messages to text.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap in the standard LangChain ChatResult.
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,  # see _generate: annotation removed to avoid NameError
    ) -> ChatResult:
        """Async generation counterpart of _generate."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Flatten a LangChain message list into a llama-style [INST] prompt."""
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)

View File

@ -0,0 +1,80 @@
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, BaseMessage
from typing import List, Optional, Any, Union
from langchain_core.outputs import ChatResult
from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
class OnlineLLM(BaseLLM):
    """Chat model served by an OpenAI-compatible HTTP API, wrapped as a
    LangChain BaseChatModel.  The ChatOpenAI client is created lazily."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        # Online providers cannot work without credentials.
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Create the underlying ChatOpenAI client from the stored config."""
        from langchain_openai import ChatOpenAI
        cfg = self.config
        self.model = ChatOpenAI(
            api_key=cfg.api_key,
            model_name=cfg.model_name,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            base_url=cfg.base_url,
        )

    @property
    def _llm_type(self) -> str:
        # Identifier reported to LangChain tracing/telemetry.
        return "openai"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate synchronous generation to the wrapped ChatOpenAI model."""
        if self.model is None:
            self.load_model()
        return self.model._generate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate asynchronous generation to the wrapped ChatOpenAI model."""
        if self.model is None:
            self.load_model()
        return await self.model._agenerate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    # ---------------------- custom convenience helpers ----------------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Convenience helper: accept a raw string or a message list, return text."""
        messages = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        return self._generate(messages, **kwargs).generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async convenience helper: accept a raw string or a message list, return text."""
        messages = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        result = await self._agenerate(messages, **kwargs)
        return result.generations[0].text

View File

@ -24,15 +24,20 @@ class Conversation(BaseModel):
# Relationships removed to eliminate foreign key constraints
def to_dict(self) -> dict:
"""Convert conversation to a dictionary."""
return {
"id": self.id,
"title": self.title,
"user_id": self.user_id,
"knowledge_base_id": self.knowledge_base_id,
"system_prompt": self.system_prompt,
"model_name": self.model_name,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"is_archived": self.is_archived,
"message_count": self.message_count,
"last_message_at": self.last_message_at,
}
def __repr__(self):
return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"
@property
def message_count(self):
"""Get the number of messages in this conversation."""
return len(self.messages)
@property
def last_message_at(self):
"""Get the timestamp of the last message."""
return self.messages[-1].created_at or self.created_at
return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id}, system_prompt={self.system_prompt}, model_name='{self.model_name}', temperature='{self.temperature}', message_count={self.message_count})>"

View File

@ -39,7 +39,7 @@ class LLMConfig(BaseModel):
last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 最后使用时间
def __repr__(self):
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model_name='{self.model_name}', base_url='{self.base_url}')>"
def to_dict(self, include_sensitive=False):
"""Convert to dictionary, optionally excluding sensitive data."""
@ -60,7 +60,7 @@ class LLMConfig(BaseModel):
'is_embedding': self.is_embedding,
'extra_config': self.extra_config,
'usage_count': self.usage_count,
'last_used_at': self.last_used_at
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None
})
if include_sensitive:
@ -102,8 +102,8 @@ class LLMConfig(BaseModel):
if not self.name or not self.name.strip():
return {"valid": False, "error": "配置名称不能为空"}
if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']:
return {"valid": False, "error": "不支持的服务商"}
if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama']:
return {"valid": False, "error": f"不支持的服务商 {self.provider}"}
if not self.model_name or not self.model_name.strip():
return {"valid": False, "error": "模型名称不能为空"}

View File

@ -34,7 +34,7 @@ class LLMConfigCreate(LLMConfigBase):
allowed_providers = [
'openai', 'azure', 'anthropic', 'google', 'baidu',
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
'ollama', 'custom', "doubao"
'ollama', 'custom', "doubao", "ollama"
]
if v.lower() not in allowed_providers:
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
@ -74,7 +74,7 @@ class LLMConfigUpdate(BaseModel):
allowed_providers = [
'openai', 'azure', 'anthropic', 'google', 'baidu',
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
'ollama', 'custom',"doubao"
'ollama', 'custom',"doubao", "ollama"
]
if v.lower() not in allowed_providers:
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')

View File

@ -33,14 +33,16 @@ class AgentConfig(BaseModel):
class AgentService:
"""LangChain Agent service with tool calling capabilities."""
def __init__(self, db_session=None):
def __init__(self):
self.settings = get_settings()
async def initialize(self, session=None):
self.tool_registry = ToolRegistry()
self.config = AgentConfig()
self.db_session = db_session
self.config_service = AgentConfigService(db_session) if db_session else None
self.session = session
self.config_service = AgentConfigService(session) if session else None
self._initialize_tools()
self._load_config()
await self._load_config()
def _initialize_tools(self):
"""Initialize and register all available tools."""
@ -56,18 +58,17 @@ class AgentService:
self.tool_registry.register(tool)
logger.info(f"Registered tool: {tool.get_name()}")
def _load_config(self):
async def _load_config(self):
"""Load configuration from database if available."""
if self.config_service:
try:
config_dict = self.config_service.get_config_dict()
config_dict = await self.config_service.get_config_dict()
# Update config with database values
for key, value in config_dict.items():
if hasattr(self.config, key):
setattr(self.config, key, value)
logger.info("Loaded agent configuration from database")
except Exception as e:
logger.warning(f"Failed to load config from database, using defaults: {str(e)}")
logger.error(f"Failed to load config from database, using defaults: {str(e)}")
def _get_enabled_tools(self) -> List[Any]:
"""Get list of enabled LangChain tools."""
@ -83,11 +84,11 @@ class AgentService:
return enabled_tools
def _create_agent_executor(self) -> Any:
async def _create_agent_executor(self) -> Any:
"""Create LangChain agent executor."""
# Get LLM configuration
from ...core.llm import create_llm
llm = create_llm()
from ...core.new_agent import new_agent
llm = await new_agent()
# Get enabled tools
tools = self._get_enabled_tools()
@ -114,7 +115,7 @@ class AgentService:
logger.info(f"Processing agent chat message: {message[:100]}...")
# Create agent
agent = self._create_agent_executor()
agent = await self._create_agent_executor()
# Convert chat history to LangChain format
langchain_history = []
@ -155,7 +156,7 @@ class AgentService:
logger.info(f"Processing agent chat stream: {message[:100]}...")
# Create agent
agent = self._create_agent_executor()
agent = await self._create_agent_executor()
# Convert chat history to LangChain format
langchain_history = []
@ -263,17 +264,18 @@ class AgentService:
# Global agent service instance
_agent_service: Optional[AgentService] = None
_global_agent_service: Optional[AgentService] = None
def get_agent_service(db_session=None) -> AgentService:
async def get_agent_service(session=None) -> AgentService:
"""Get global agent service instance."""
global _agent_service
if _agent_service is None:
_agent_service = AgentService(db_session)
elif db_session and not _agent_service.db_session:
global _global_agent_service
if _global_agent_service is None:
_global_agent_service = AgentService()
await _global_agent_service.initialize(session)
elif session and session != _global_agent_service.session:
# Update with database session if not already set
_agent_service.db_session = db_session
_agent_service.config_service = AgentConfigService(db_session)
_agent_service._load_config()
return _agent_service
_global_agent_service.session = session
_global_agent_service.config_service = AgentConfigService(session)
_global_agent_service._load_config()
return _global_agent_service

View File

@ -41,16 +41,18 @@ class LangGraphAgentConfig(BaseModel):
class LangGraphAgentService:
"""LangGraph Agent service using low-level LangGraph graph (React pattern)."""
def __init__(self, db_session=None):
def __init__(self):
self.settings = get_settings()
async def initialize(self, session=None):
self.tool_registry = ToolRegistry()
self.config = LangGraphAgentConfig()
self.tools = []
self.db_session = db_session
self.config_service = AgentConfigService(db_session) if db_session else None
self.session = session
self.config_service = AgentConfigService(session) if session else None
self._initialize_tools()
self._load_config()
self._create_react_agent()
await self._load_config()
await self._create_react_agent()
def _initialize_tools(self):
"""Initialize available tools."""
@ -76,28 +78,29 @@ class LangGraphAgentService:
def _load_config(self):
async def _load_config(self):
"""Load configuration from database if available."""
if self.config_service:
try:
db_config = self.config_service.get_active_config()
if db_config:
# Update config with database values
config_dict = db_config.config_data
for key, value in config_dict.items():
if hasattr(self.config, key):
setattr(self.config, key, value)
logger.info("Loaded configuration from database")
except Exception as e:
logger.warning(f"Failed to load config from database: {e}")
pass
# if self.config_service:
# try:
# db_config = self.config_service.get_active_config()
# if db_config:
# # Update config with database values
# config_dict = db_config.config_data
# for key, value in config_dict.items():
# if hasattr(self.config, key):
# setattr(self.config, key, value)
# logger.info("Loaded configuration from database")
# except Exception as e:
# logger.exception(f"Failed to load config from database: {e}")
def _create_react_agent(self):
async def _create_react_agent(self):
"""Create LangGraph agent using low-level StateGraph with explicit nodes/edges."""
try:
# Initialize the model
llm_config = get_settings().llm.get_current_config()
llm_config = await get_settings().llm.get_current_config(self.db_session)
self.model = init_chat_model(
model=llm_config['model'],
model_provider='openai',
@ -183,7 +186,7 @@ class LangGraphAgentService:
# Compile graph and store as self.agent for compatibility with existing code
self.react_agent = graph.compile()
logger.info("LangGraph low-level React agent created successfully")
logger.info("LangGraph 底层 React 智能体创建成功")
except Exception as e:
logger.error(f"Failed to create agent: {str(e)}")
@ -723,15 +726,14 @@ class LangGraphAgentService:
# Global instance
_langgraph_agent_service: Optional[LangGraphAgentService] = None
_global_langgraph_agent_service: Optional[LangGraphAgentService] = None
def get_langgraph_agent_service(db_session=None) -> LangGraphAgentService:
async def get_langgraph_agent_service(session=None) -> LangGraphAgentService:
"""Get or create LangGraph agent service instance."""
global _langgraph_agent_service
global _global_langgraph_agent_service
if _langgraph_agent_service is None:
_langgraph_agent_service = LangGraphAgentService(db_session)
logger.info("LangGraph Agent service initialized")
if _global_langgraph_agent_service is None:
_global_langgraph_agent_service = LangGraphAgentService()
await _global_langgraph_agent_service.initialize(session)
return _langgraph_agent_service
return _global_langgraph_agent_service

View File

@ -2,7 +2,7 @@
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy import and_, select, update
from ..models.agent_config import AgentConfig
from utils.util_exceptions import ValidationError, NotFoundError
@ -15,7 +15,7 @@ class AgentConfigService:
def __init__(self, db: Session):
self.db = db
def create_config(self, config_data: Dict[str, Any]) -> AgentConfig:
async def create_config(self, config_data: Dict[str, Any]) -> AgentConfig:
"""Create a new agent configuration."""
try:
# Validate required fields
@ -23,9 +23,8 @@ class AgentConfigService:
raise ValidationError("Configuration name is required")
# Check if name already exists
existing = self.db.query(AgentConfig).filter(
AgentConfig.name == config_data["name"]
).first()
stmt = select(AgentConfig).where(AgentConfig.name == config_data["name"])
existing = (await self.db.execute(stmt)).scalar_one_or_none()
if existing:
raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")
@ -44,62 +43,60 @@ class AgentConfigService:
# If this is set as default, unset other defaults
if config.is_default:
self.db.query(AgentConfig).filter(
AgentConfig.is_default == True
).update({"is_default": False})
stmt = update(AgentConfig).where(AgentConfig.is_default == True).values({"is_default": False})
await self.db.execute(stmt)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
await self.db.commit()
await self.db.refresh(config)
logger.info(f"Created agent configuration: {config.name}")
return config
except Exception as e:
self.db.rollback()
await self.db.rollback()
logger.error(f"Error creating agent configuration: {str(e)}")
raise
def get_config(self, config_id: int) -> Optional[AgentConfig]:
async def get_config(self, config_id: int) -> Optional[AgentConfig]:
"""Get agent configuration by ID."""
return self.db.query(AgentConfig).filter(
AgentConfig.id == config_id
).first()
stmt = select(AgentConfig).where(AgentConfig.id == config_id)
return (await self.db.execute(stmt)).scalar_one_or_none()
def get_config_by_name(self, name: str) -> Optional[AgentConfig]:
async def get_config_by_name(self, name: str) -> Optional[AgentConfig]:
"""Get agent configuration by name."""
return self.db.query(AgentConfig).filter(
AgentConfig.name == name
).first()
stmt = select(AgentConfig).where(AgentConfig.name == name)
return (await self.db.execute(stmt)).scalar_one_or_none()
def get_default_config(self) -> Optional[AgentConfig]:
async def get_default_config(self) -> Optional[AgentConfig]:
"""Get default agent configuration."""
return self.db.query(AgentConfig).filter(
and_(AgentConfig.is_default == True, AgentConfig.is_active == True)
).first()
stmt = select(AgentConfig).where(and_(AgentConfig.is_default == True, AgentConfig.is_active == True))
return (await self.db.execute(stmt)).scalar_one_or_none()
def list_configs(self, active_only: bool = True) -> List[AgentConfig]:
"""List all agent configurations."""
query = self.db.query(AgentConfig)
stmt = select(AgentConfig)
if active_only:
query = query.filter(AgentConfig.is_active == True)
return query.order_by(AgentConfig.created_at.desc()).all()
stmt = stmt.where(AgentConfig.is_active == True)
stmt = stmt.order_by(AgentConfig.created_at.desc())
return self.db.execute(stmt).scalars().all()
def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig:
async def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig:
"""Update agent configuration."""
try:
config = self.get_config(config_id)
config = await self.get_config(config_id)
if not config:
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
# Check if name change conflicts with existing
if "name" in config_data and config_data["name"] != config.name:
existing = self.db.query(AgentConfig).filter(
stmt = select(AgentConfig).where(
and_(
AgentConfig.name == config_data["name"],
AgentConfig.id != config_id
)
).first()
)
existing = (await self.db.execute(stmt)).scalar_one_or_none()
if existing:
raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")
@ -110,28 +107,29 @@ class AgentConfigService:
# If this is set as default, unset other defaults
if config_data.get("is_default", False):
self.db.query(AgentConfig).filter(
stmt = update(AgentConfig).where(
and_(
AgentConfig.is_default == True,
AgentConfig.id != config_id
)
).update({"is_default": False})
).values({"is_default": False})
await self.db.execute(stmt)
self.db.commit()
self.db.refresh(config)
await self.db.commit()
await self.db.refresh(config)
logger.info(f"Updated agent configuration: {config.name}")
return config
except Exception as e:
self.db.rollback()
await self.db.rollback()
logger.error(f"Error updating agent configuration: {str(e)}")
raise
def delete_config(self, config_id: int) -> bool:
async def delete_config(self, config_id: int) -> bool:
"""Delete agent configuration (soft delete by setting is_active=False)."""
try:
config = self.get_config(config_id)
config = await self.get_config(config_id)
if not config:
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
@ -146,14 +144,14 @@ class AgentConfigService:
return True
except Exception as e:
self.db.rollback()
await self.db.rollback()
logger.error(f"Error deleting agent configuration: {str(e)}")
raise
def set_default_config(self, config_id: int) -> AgentConfig:
async def set_default_config(self, config_id: int) -> AgentConfig:
"""Set a configuration as default."""
try:
config = self.get_config(config_id)
config = await self.get_config(config_id)
if not config:
raise NotFoundError(f"Agent configuration with ID {config_id} not found")
@ -161,29 +159,28 @@ class AgentConfigService:
raise ValidationError("Cannot set inactive configuration as default")
# Unset other defaults
self.db.query(AgentConfig).filter(
AgentConfig.is_default == True
).update({"is_default": False})
stmt = update(AgentConfig).where(AgentConfig.is_default == True).values({"is_default": False})
await self.db.execute(stmt)
# Set this as default
config.is_default = True
self.db.commit()
self.db.refresh(config)
await self.db.commit()
await self.db.refresh(config)
logger.info(f"Set default agent configuration: {config.name}")
return config
except Exception as e:
self.db.rollback()
await self.db.rollback()
logger.error(f"Error setting default agent configuration: {str(e)}")
raise
def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
async def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
"""Get configuration as dictionary. If no ID provided, returns default config."""
if config_id:
config = self.get_config(config_id)
config = await self.get_config(config_id)
else:
config = self.get_default_config()
config = await self.get_default_config()
if not config:
# Return default values if no configuration found

View File

@ -33,19 +33,16 @@ class AuthService:
)
token = credentials.credentials
session.desc = f"[AuthService] 取得token: {token[:50]}..."
payload = AuthService.verify_token(token)
if payload is None:
session.desc = "ERROR: 令牌验证失败"
session.desc = f"ERROR: 令牌验证失败 - 令牌: {token[:50]}..."
raise credentials_exception
session.desc = f"[AuthService] 令牌有效 - 解析得到有效载荷: {payload}"
username: str = payload.get("sub")
if username is None:
session.desc = "ERROR: 令牌中没有用户名"
raise credentials_exception
session.desc = f"[AuthService] 获取当前用户 - 查找名为 {username} 的用户"
stmt = select(User).where(User.username == username)
user = (await session.execute(stmt)).scalar_one_or_none()
if user is None:
@ -53,8 +50,8 @@ class AuthService:
raise credentials_exception
# Set user in context for global access
UserContext.set_current_user(user)
session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户"
UserContext.set_current_user(user, canLog=True)
# session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户"
return user
@ -138,6 +135,4 @@ class AuthService:
except jwt.PyJWTError as e:
logger.error(f"Token verification failed: {e}")
logger.error(f"Token: {token[:50]}...")
logger.error(f"Secret key: {settings.security.secret_key[:20]}...")
logger.error(f"Algorithm: {settings.security.algorithm}")
return None

View File

@ -1,44 +1,36 @@
"""Chat service for AI model integration using LangChain."""
from th_agenter import db
import json
import asyncio
import os
from typing import AsyncGenerator, Optional, List, Dict, Any
from typing import AsyncGenerator, Optional, List, Dict, Any, TypedDict
from sqlalchemy.orm import Session
from loguru import logger
from th_agenter.core.new_agent import new_agent, new_llm
from ..core.config import settings
from ..models.message import MessageRole
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
from utils.util_exceptions import ChatServiceError, OpenAIError
from utils.util_exceptions import ChatServiceError, HxfResponse, OpenAIError
from .conversation import ConversationService
from .langchain_chat import LangChainChatService
from .knowledge_chat import KnowledgeChatService
from .agent.agent_service import get_agent_service
from .agent.langgraph_agent_service import get_langgraph_agent_service
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
class AgentState(TypedDict):
messages: List[dict] # 存储对话消息(核心记忆)
class ChatService:
"""Service for handling AI chat functionality using LangChain."""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
# Initialize LangChain chat service
self.langchain_service = LangChainChatService(db)
# Initialize Knowledge chat service
self.knowledge_service = KnowledgeChatService(db)
# Initialize Agent service with database session
self.agent_service = get_agent_service(db)
# Initialize LangGraph Agent service with database session
self.langgraph_service = get_langgraph_agent_service(db)
logger.info("ChatService initialized with LangChain backend and Agent support")
_checkpointer_initialized = False
_conn_string = None
async def chat(
self,
@ -57,7 +49,7 @@ class ChatService:
logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
# Use knowledge base chat service
return await self.knowledge_service.chat_with_knowledge_base(
return await self.knowledge_chat_service.chat_with_knowledge_base(
conversation_id=conversation_id,
message=message,
knowledge_base_id=knowledge_base_id,
@ -69,29 +61,29 @@ class ChatService:
logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
# Get conversation history for LangGraph agent
conversation = self.conversation_service.get_conversation(conversation_id)
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = self.conversation_service.get_conversation_messages(conversation_id)
messages = await self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
} for msg in messages]
# Use LangGraph agent service
agent_result = await self.langgraph_service.chat(message, chat_history)
agent_result = await self.langgraph_agent_service.chat(message, chat_history)
if agent_result["success"]:
# Save user message
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Save assistant response
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
@ -114,11 +106,11 @@ class ChatService:
logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
# Get conversation history for agent
conversation = self.conversation_service.get_conversation(conversation_id)
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = self.conversation_service.get_conversation_messages(conversation_id)
messages = await self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
@ -129,14 +121,14 @@ class ChatService:
if agent_result["success"]:
# Save user message
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Save assistant response
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=agent_result["response"],
role=MessageRole.ASSISTANT,
@ -159,7 +151,7 @@ class ChatService:
logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
# Delegate to LangChain service
return await self.langchain_service.chat(
return await self.langchain_chat_service.chat(
conversation_id=conversation_id,
message=message,
stream=stream,
@ -167,156 +159,12 @@ class ChatService:
max_tokens=max_tokens
)
async def chat_stream(
self,
conversation_id: int,
message: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
use_agent: bool = False,
use_langgraph: bool = False,
use_knowledge_base: bool = False,
knowledge_base_id: Optional[int] = None
) -> AsyncGenerator[str, None]:
"""Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base."""
if use_knowledge_base and knowledge_base_id:
logger.info(f"Processing streaming chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
# Use knowledge base chat service streaming
async for content in self.knowledge_service.chat_stream_with_knowledge_base(
conversation_id=conversation_id,
message=message,
knowledge_base_id=knowledge_base_id,
temperature=temperature,
max_tokens=max_tokens
):
# Create stream chunk for compatibility with existing API
stream_chunk = StreamChunk(
content=content,
role=MessageRole.ASSISTANT
)
yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
elif use_langgraph:
logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangGraph Agent")
# Get conversation history for LangGraph agent
conversation = self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
} for msg in messages]
# Save user message first
user_message = self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Use LangGraph agent service streaming
full_response = ""
intermediate_steps = []
async for chunk in self.langgraph_service.chat_stream(message, chat_history):
if chunk["type"] == "response":
full_response = chunk["content"]
intermediate_steps = chunk.get("intermediate_steps", [])
# Return the chunk as-is to maintain type information
yield json.dumps(chunk, ensure_ascii=False)
elif chunk["type"] == "error":
# Return the chunk as-is to maintain type information
yield json.dumps(chunk, ensure_ascii=False)
return
else:
# For other types (status, step, etc.), pass through
yield json.dumps(chunk, ensure_ascii=False)
# Save assistant response
if full_response:
self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
message_metadata={"intermediate_steps": intermediate_steps}
)
elif use_agent:
logger.info(f"Processing streaming chat request for conversation {conversation_id} via Agent")
# Get conversation history for agent
conversation = self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
messages = self.conversation_service.get_conversation_messages(conversation_id)
chat_history = [{
"role": "user" if msg.role == MessageRole.USER else "assistant",
"content": msg.content
} for msg in messages]
# Save user message first
user_message = self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Use agent service streaming
full_response = ""
tool_calls = []
async for chunk in self.agent_service.chat_stream(message, chat_history):
if chunk["type"] == "response":
full_response = chunk["content"]
tool_calls = chunk.get("tool_calls", [])
# Return the chunk as-is to maintain type information
yield json.dumps(chunk, ensure_ascii=False)
elif chunk["type"] == "error":
# Return the chunk as-is to maintain type information
yield json.dumps(chunk, ensure_ascii=False)
return
else:
# For other types (status, tool_start, etc.), pass through
yield json.dumps(chunk, ensure_ascii=False)
# Save assistant response
if full_response:
self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
message_metadata={"tool_calls": tool_calls}
)
else:
logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangChain")
# Delegate to LangChain service and wrap response in JSON format
async for content in self.langchain_service.chat_stream(
conversation_id=conversation_id,
message=message,
temperature=temperature,
max_tokens=max_tokens
):
# Create stream chunk for compatibility with existing API
stream_chunk = StreamChunk(
content=content,
role=MessageRole.ASSISTANT
)
yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
async def get_available_models(self) -> List[str]:
"""Get list of available models from LangChain."""
logger.info("Getting available models via LangChain")
# Delegate to LangChain service
return await self.langchain_service.get_available_models()
return await self.langchain_chat_service.get_available_models()
def update_model_config(
self,
@ -328,8 +176,135 @@ class ChatService:
logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
# Delegate to LangChain service
self.langchain_service.update_model_config(
self.langchain_chat_service.update_model_config(
model=model,
temperature=temperature,
max_tokens=max_tokens
)
# -------------------------------------------------------------------------
def __init__(self, session: Session):
self.session = session
async def initialize(self, conversation_id: int, streaming: bool = False):
self.conversation_service = ConversationService(self.session)
self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
self.conversation = await self.conversation_service.get_conversation(
conversation_id=conversation_id
)
if not self.conversation:
raise ChatServiceError(f"Conversation {conversation_id} not found")
if not ChatService._checkpointer_initialized:
from langgraph.checkpoint.postgres import PostgresSaver
import psycopg2
CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
ChatService._conn_string = CONN_STRING
# 检查必要的表是否已存在
tables_need_setup = True
try:
# 连接到数据库并检查表是否存在
conn = psycopg2.connect(CONN_STRING)
cursor = conn.cursor()
# 检查langgraph需要的表是否存在
cursor.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
""")
existing_tables = [row[0] for row in cursor.fetchall()]
# 检查是否所有必要的表都存在
required_tables = ['checkpoints', 'checkpoint_writes', 'checkpoint_blobs']
if all(table in existing_tables for table in required_tables):
tables_need_setup = False
self.session.desc = "ChatService初始化 - 检测到langgraph表已存在跳过setup"
cursor.close()
conn.close()
except Exception as e:
self.session.desc = f"ChatService初始化 - 检查表存在性失败: {str(e)}将进行setup"
tables_need_setup = True
# 只有在需要时才进行setup
if tables_need_setup:
self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
try:
async with AsyncPostgresSaver.from_conn_string(CONN_STRING) as checkpointer:
await checkpointer.setup()
self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
logger.info("PostgresSaver setup完成")
except Exception as e:
self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
logger.error(f"PostgresSaver setup失败: {e}")
raise
else:
self.session.desc = "ChatService初始化 - 使用现有的langgraph表"
# 存储连接字符串供后续使用
ChatService._checkpointer_initialized = True
self.llm = await new_llm(session=self.session, streaming=streaming)
self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"
def get_config(self):
config = {
"configurable": {
"thread_id": str(self.conversation.id),
"checkpoint_ns": "drgraph"
}
}
return config
async def chat_stream(
self,
message: str
) -> AsyncGenerator[str, None]:
"""Send a message and get streaming AI response using LangChain, Agent, or Knowledge Base."""
self.session.desc = f"ChatService - 发送消息 {message} >>> 流式对话请求,会话 ID: {self.conversation.id}"
await self.conversation_service.add_message(
conversation_id=self.conversation.id,
role=MessageRole.USER,
content=message
)
full_assistant_content = ""
async with AsyncPostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
from langchain.agents import create_agent
agent = create_agent(
model=self.llm, # await new_llm(session=self.session, streaming=self.streaming),
checkpointer=checkpointer
)
async for chunk in agent.astream(
{"messages": [{"role": "user", "content": message}]},
config=self.get_config(),
stream_mode="messages"
):
full_assistant_content += chunk[0].content
json_result = {"data": {"v": chunk[0].content }}
yield json.dumps(
json_result,
ensure_ascii=True
)
if len(full_assistant_content) > 0:
await self.conversation_service.add_message(
conversation_id=self.conversation.id,
role=MessageRole.ASSISTANT,
content=full_assistant_content
)
def get_conversation_history_messages(
self, conversation_id: int, skip: int = 0, limit: int = 100
):
"""Get conversation history messages with pagination."""
result = []
with PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer:
checkpoints = checkpointer.list(self.get_config())
for checkpoint in checkpoints:
print(checkpoint)
result.append(checkpoint.messages)
return result

View File

@ -3,6 +3,9 @@
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import select, desc, func, or_
from langchain_core.messages import HumanMessage, AIMessage
from th_agenter.db.database import AsyncSessionFactory
from ..models.conversation import Conversation
from ..models.message import Message, MessageRole
@ -24,7 +27,7 @@ class ConversationService:
conversation_data: ConversationCreate
) -> Conversation:
"""Create a new conversation."""
logger.info(f"Creating new conversation for user {user_id}: {conversation_data}")
self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}"
try:
conversation = Conversation(
@ -39,18 +42,23 @@ class ConversationService:
await self.session.commit()
await self.session.refresh(conversation)
logger.info(f"Successfully created conversation {conversation.id} for user {user_id}")
self.session.desc = f"创建新会话 Conversation ID: {conversation.id}用户ID: {user_id}"
return conversation
except Exception as e:
logger.error(f"Failed to create conversation: {str(e)}", exc_info=True)
self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}"
await self.session.rollback()
raise DatabaseError(f"Failed to create conversation: {str(e)}")
raise DatabaseError(f"创建会话失败: {str(e)}")
async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
"""Get a conversation by ID."""
try:
user_id = UserContext.get_current_user_id()
self.session.desc = f"获取会话 - 会话ID: {conversation_id}用户ID: {user_id}"
if user_id is None:
logger.error(f"Failed to get conversation {conversation_id}: No user context available")
return None
conversation = await self.session.scalar(
select(Conversation).where(
Conversation.id == conversation_id,
@ -59,12 +67,11 @@ class ConversationService:
)
if not conversation:
logger.warning(f"Conversation {conversation_id} not found")
self.session.desc = f"警告: 会话 {conversation_id} 不存在用户ID: {user_id}"
return conversation
except Exception as e:
logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True)
self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id}用户ID: {user_id},错误: {str(e)}"
raise DatabaseError(f"Failed to get conversation: {str(e)}")
async def get_user_conversations(
@ -78,6 +85,10 @@ class ConversationService:
) -> List[Conversation]:
"""Get user's conversations with search and filtering."""
user_id = UserContext.get_current_user_id()
if user_id is None:
logger.error("Failed to get user conversations: No user context available")
return []
query = select(Conversation).where(
Conversation.user_id == user_id
)
@ -137,7 +148,7 @@ class ConversationService:
if not conversation:
return False
self.session.delete(conversation)
await self.session.delete(conversation)
await self.session.commit()
return True
@ -180,19 +191,42 @@ class ConversationService:
# Set audit fields
message.set_audit_fields()
self.session.add(message)
await self.session.commit()
await self.session.refresh(message)
session = AsyncSessionFactory()
session.begin()
try:
session.add(message)
await session.commit()
await session.refresh(message)
# Update conversation's updated_at timestamp
conversation = await self.get_conversation(conversation_id)
if conversation:
conversation.updated_at = datetime.now(timezone.utc)
conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
await self.session.commit()
await session.commit()
except Exception as e:
logger.error(f"Failed to add message to conversation {conversation_id}: {str(e)}", exc_info=True)
await session.rollback()
finally:
await session.close()
return message
async def get_conversation_history_messages(
self,
conversation_id: int,
limit: int = 20
) -> List[Message]:
"""Get recent conversation history messages."""
history = await self.get_conversation_history(conversation_id, limit)
history_messages = []
for message in history:
if message.role == MessageRole.USER:
history_messages.append(HumanMessage(content=message.content))
elif message.role == MessageRole.ASSISTANT:
history_messages.append(AIMessage(content=message.content))
return history_messages
async def get_conversation_history(
self,
conversation_id: int,

View File

@ -27,7 +27,7 @@ class ConversationContextService:
新创建的对话ID
"""
try:
session = next(get_session())
session = await anext(get_session())
conversation = Conversation(
user_id=user_id,
@ -37,8 +37,8 @@ class ConversationContextService:
)
session.add(conversation)
session.commit()
session.refresh(conversation)
await session.commit()
await session.refresh(conversation)
# 初始化对话上下文
self.context_cache[conversation.id] = {
@ -74,7 +74,7 @@ class ConversationContextService:
# 从数据库加载
try:
session = next(get_session())
session = await anext(get_session())
conversation = session.query(Conversation).filter(
Conversation.id == conversation_id
@ -194,7 +194,7 @@ class ConversationContextService:
保存是否成功
"""
try:
session = next(get_session())
session = await anext(get_session())
message = Message(
conversation_id=conversation_id,
@ -205,7 +205,7 @@ class ConversationContextService:
)
session.add(message)
session.commit()
await session.commit()
# 更新对话的最后更新时间
conversation = session.query(Conversation).filter(
@ -214,7 +214,7 @@ class ConversationContextService:
if conversation:
conversation.updated_at = datetime.utcnow()
session.commit()
await session.commit()
return True
@ -262,7 +262,7 @@ class ConversationContextService:
消息历史列表
"""
try:
session = next(get_session())
session = await anext(get_session())
messages = session.query(Message).filter(
Message.conversation_id == conversation_id

View File

@ -102,41 +102,41 @@ class DatabaseConfigService:
)
self.session.add(db_config)
self.session.commit()
self.session.refresh(db_config)
await self.session.commit()
await self.session.refresh(db_config)
logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
return db_config
except Exception as e:
self.session.rollback()
await self.session.rollback()
logger.error(f"创建数据库配置失败: {str(e)}")
raise
def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
async def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
"""获取用户的数据库配置列表"""
stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
if active_only:
stmt = stmt.where(DatabaseConfig.is_active == True)
stmt = stmt.order_by(DatabaseConfig.created_at.desc())
return self.session.scalars(stmt).all()
return (await self.session.execute(stmt)).scalars().all()
def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
async def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
"""根据ID获取配置"""
stmt = select(DatabaseConfig).where(
DatabaseConfig.id == config_id,
DatabaseConfig.created_by == user_id
)
return self.session.scalar(stmt)
return (await self.session.execute(stmt)).scalar_one_or_none()
def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
async def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
"""获取用户的默认配置"""
stmt = select(DatabaseConfig).where(
DatabaseConfig.created_by == user_id,
# DatabaseConfig.is_default == True,
DatabaseConfig.is_active == True
)
return self.session.scalar(stmt)
return (await self.session.execute(stmt)).scalar_one_or_none()
async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
"""测试数据库连接"""
@ -278,14 +278,14 @@ class DatabaseConfigService:
'message': f'断开连接失败: {str(e)}'
}
def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
"""根据数据库类型获取用户配置"""
stmt = select(DatabaseConfig).where(
DatabaseConfig.created_by == user_id,
DatabaseConfig.db_type == db_type,
DatabaseConfig.is_active == True
)
return self.session.scalar(stmt)
return await self.session.scalar(stmt)
async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
"""创建或更新数据库配置保证db_type唯一性"""
@ -301,8 +301,8 @@ class DatabaseConfigService:
elif hasattr(existing_config, key):
setattr(existing_config, key, value)
self.session.commit()
self.session.refresh(existing_config)
await self.session.commit()
await self.session.refresh(existing_config)
logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
return existing_config
else:
@ -310,7 +310,7 @@ class DatabaseConfigService:
return await self.create_config(user_id, config_data)
except Exception as e:
self.session.rollback()
await self.session.rollback()
logger.error(f"创建或更新数据库配置失败: {str(e)}")
raise

View File

@ -30,7 +30,7 @@ class DocumentService:
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
# Validate knowledge base exists
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
kb = self.session.scalar(stmt)
kb = await self.session.scalar(stmt)
if not kb:
self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
raise ValueError(f"知识库 {kb_id} 不存在")
@ -48,6 +48,7 @@ class DocumentService:
# Upload file using storage service
storage_info = await storage_service.upload_file(file, kb_id)
self.session.desc = f"文档 {file.filename} 上传到 {storage_info}"
# Create document record
document = Document(
@ -65,21 +66,21 @@ class DocumentService:
document.set_audit_fields()
self.session.add(document)
self.session.commit()
self.session.refresh(document)
await self.session.commit()
await self.session.refresh(document)
self.session.desc = f"SUCCESS: 成功上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
return document
def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
async def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
"""根据文档ID查询文档可选地根据知识库ID过滤。"""
self.session.desc = f"查询文档 {doc_id}"
self.session.desc = f"根据文档ID查询文档 {doc_id}"
stmt = select(Document).where(Document.id == doc_id)
if kb_id is not None:
stmt = stmt.where(Document.knowledge_base_id == kb_id)
return self.session.scalar(stmt)
return await self.session.scalar(stmt)
def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
"""根据知识库ID查询文档支持分页。"""
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
stmt = (
@ -88,14 +89,14 @@ class DocumentService:
.offset(skip)
.limit(limit)
)
return self.session.scalars(stmt).all()
return (await self.session.scalars(stmt)).all()
def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
"""根据知识库ID查询文档支持分页并返回总文档数。"""
self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
# Get total count
count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
total = self.session.scalar(count_stmt)
total = await self.session.scalar(count_stmt)
# Get documents with pagination
documents_stmt = (
@ -104,44 +105,52 @@ class DocumentService:
.offset(skip)
.limit(limit)
)
documents = self.session.scalars(documents_stmt).all()
documents = (await self.session.scalars(documents_stmt)).all()
return documents, total
def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
async def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
"""根据文档ID删除文档可选地根据知识库ID过滤。"""
self.session.desc = f"删除文档 {doc_id}"
document = self.get_document(doc_id, kb_id)
document = await self.get_document(doc_id, kb_id)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return False
# Delete file from storage
try:
storage_service.delete_file(document.file_path)
logger.info(f"Deleted file: {document.file_path}")
await storage_service.delete_file(document.file_path)
self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
except Exception as e:
self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"
# TODO: Remove from vector database
# This should be implemented when vector database service is ready
get_document_processor().delete_document_from_vector_store(kb_id,doc_id)
self.session.desc = f"从向量数据库删除文档 {doc_id}"
(await get_document_processor(self.session)).delete_document_from_vector_store(kb_id,doc_id)
# Delete database record
self.session.delete(document)
self.session.commit()
self.session.desc = f"删除数据库记录 {doc_id}"
await self.session.delete(document)
await self.session.commit()
self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
return True
async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]:
"""处理文档,提取文本并创建嵌入向量。"""
try:
self.session.desc = f"处理文档 {doc_id}"
document = self.get_document(doc_id, kb_id)
self.session.desc = f"处理文档 {doc_id} - 提取文本并创建嵌入向量"
document = await self.get_document(doc_id, kb_id)
self.session.desc = f"获取文档 {doc_id} >>> {document}"
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
raise ValueError(f"Document {doc_id} not found")
if document.is_processed:
# document.file_path[为('C:\\DrGraph\\TH_Backend\\data\\uploads\\kb_1\\997eccbb-9081-4ddf-879e-bc7d781fab50_答辩.txt',) ,需要取第一个元素
file_path = document.file_path
knowledge_base_id=document.knowledge_base_id
is_processed=document.is_processed
if is_processed:
self.session.desc = f"INFO: 文档 {doc_id} 已处理"
return {
"document_id": doc_id,
@ -149,38 +158,43 @@ class DocumentService:
"message": "文档已处理"
}
self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}"
# 更新文档状态为处理中
document.processing_error = None
self.session.commit()
await self.session.commit()
self.session.desc = f"更新文档状态为处理中 {doc_id}"
# 调用文档处理器进行处理
result = get_document_processor().process_document(
document_processor = await get_document_processor(self.session)
self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}"
result = await document_processor.process_document(
session=self.session,
document_id=doc_id,
file_path=document.file_path,
knowledge_base_id=document.knowledge_base_id
file_path=file_path,
knowledge_base_id=knowledge_base_id
)
self.session.desc = f"SUCCESS: 成功处理文档 {doc_id}"
self.session.desc = f"处理文档完毕 {doc_id}"
# 如果处理成功,更新文档状态
if result["status"] == "success":
document.is_processed = True
document.chunk_count = result.get("chunks_count", 0)
self.session.commit()
self.session.refresh(document)
await self.session.commit()
await self.session.refresh(document)
logger.info(f"Processed document: {document.filename} (ID: {doc_id})")
return result
except Exception as e:
self.session.rollback()
await self.session.rollback()
self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"
# Update document with error
try:
document = self.get_document(doc_id)
document = await self.get_document(doc_id, kb_id)
if document:
document.processing_error = str(e)
self.session.commit()
await self.session.commit()
except Exception as db_error:
logger.error(f"Failed to update document error status: {db_error}")
@ -217,10 +231,10 @@ class DocumentService:
self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
raise
def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
async def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
"""更新文档处理状态。"""
self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
document = self.get_document(doc_id)
document = await self.get_document(doc_id)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return False
@ -228,16 +242,16 @@ class DocumentService:
document.is_processed = is_processed
document.processing_error = error
self.session.commit()
await self.session.commit()
self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
return True
def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
async def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""在知识库中搜索文档使用向量相似度。"""
try:
# 使用文档处理器进行相似性搜索
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query}"
results = get_document_processor().search_similar_documents(kb_id, query, limit)
self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {limit}"
results = (await get_document_processor(self.session)).search_similar_documents(kb_id, query, limit)
self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
return results
except Exception as e:
@ -245,9 +259,9 @@ class DocumentService:
logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
return []
def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
"""获取知识库中的文档统计信息。"""
documents = self.get_documents(kb_id, limit=1000) # Get all documents
documents = await self.get_documents(kb_id, limit=1000) # Get all documents
total_count = len(documents)
processed_count = len([doc for doc in documents if doc.is_processed])
@ -267,19 +281,21 @@ class DocumentService:
"file_types": file_types
}
def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
async def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
"""获取特定文档的文档块。"""
try:
self.session.desc = f"获取文档 {doc_id} 的文档块"
stmt = select(Document).where(Document.id == doc_id)
document = self.session.scalar(stmt)
document = await self.session.scalar(stmt)
if not document:
self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
return []
self.session.desc = f"获取文档 {doc_id} 的文档块 > document"
# Get chunks from document processor
chunks_data = get_document_processor().get_document_chunks(document.knowledge_base_id, doc_id)
chunks_data = (await get_document_processor(self.session)).get_document_chunks(document.knowledge_base_id, doc_id)
self.session.desc = f"获取文档 {doc_id} 的文档块 > chunks_data"
# Convert to DocumentChunk objects
chunks = []
for chunk_data in chunks_data:
@ -290,7 +306,8 @@ class DocumentService:
page_number=chunk_data.get("page_number"),
chunk_index=chunk_data["chunk_index"],
start_char=chunk_data.get("start_char"),
end_char=chunk_data.get("end_char")
end_char=chunk_data.get("end_char"),
vector_id=chunk_data.get("vector_id")
)
chunks.append(chunk)

View File

@ -3,10 +3,9 @@
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
from urllib.parse import quote
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
from fastapi import HTTPException
from requests import Session
from sqlalchemy import text
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
TextLoader,
@ -18,7 +17,7 @@ import pdfplumber
from langchain_core.documents import Document
from langchain_postgres import PGVector
from typing import List
# 旧的ZhipuEmbeddings类已移除现在统一使用EmbeddingFactory创建embedding实例
from ..core.config import BaseSettings, get_settings
from ..models.knowledge_base import Document as DocumentModel
@ -26,64 +25,12 @@ from ..db.database import get_session
from loguru import logger
settings = get_settings()
class PGVectorConnectionPool:
    """Placeholder for a PGVector connection-pool manager.

    The previous synchronous implementation (a SQLAlchemy engine with a
    QueuePool plus a session factory and helper query methods) was removed
    pending an async rewrite; see version-control history for the old code.
    Instances currently hold no resources and cannot execute queries.
    """

    def __init__(self):
        # The pool is intentionally non-functional until the async
        # implementation lands; this error log flags any accidental use.
        logger.error("PGVector连接池管理器 -==== 待异步方式实现")
        self.engine = None        # SQLAlchemy engine (unset in the stub)
        self.SessionLocal = None  # session factory (unset in the stub)
class DocumentProcessor:
"""文档处理器,负责文档的加载、分段和向量化"""
def __init__(self):
# 初始化语义分割器配置
self.embeddings = None
self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=settings.file.chunk_size,
@ -92,28 +39,29 @@ class DocumentProcessor:
separators=["\n\n", "\n", " ", ""]
)
async def initialize(self, session: Session = None):
# 初始化嵌入模型 - 根据配置选择提供商
self._init_embeddings()
await self._init_embeddings(session)
# 初始化连接池仅对PGVector
self.pgvector_pool = None
# PostgreSQL pgvector连接配置
print('settings.vector_db.type=============', settings.vector_db.type)
if settings.vector_db.type == "pgvector":
# 新版本PGVector使用psycopg3连接字符串
# 对密码进行URL编码以处理特殊字符如@符号)
encoded_password = quote(settings.vector_db.pgvector_password, safe="")
self.connection_string = (
f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
f"{encoded_password}@"
f"{settings.vector_db.pgvector_host}:"
f"{settings.vector_db.pgvector_port}/"
f"{settings.vector_db.pgvector_database}"
)
# 初始化连接池
self.pgvector_pool = PGVectorConnectionPool()
else:
# if settings.vector_db.type == "pgvector":
# # 新版本PGVector使用psycopg3连接字符串
# # 对密码进行URL编码以处理特殊字符如@符号)
# encoded_password = quote(settings.vector_db.pgvector_password, safe="")
# self.connection_string = (
# f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
# f"{encoded_password}@"
# f"{settings.vector_db.pgvector_host}:"
# f"{settings.vector_db.pgvector_port}/"
# f"{settings.vector_db.pgvector_database}"
# )
# # 初始化连接池
# self.pgvector_pool = PGVectorConnectionPool()
# logger.info("新版本PGVector使用psycopg3连接字符串: %s", self.connection_string)
# else:
# 向量数据库存储路径Chroma兼容
vector_db_path = settings.vector_db.persist_directory
if not os.path.isabs(vector_db_path):
@ -122,38 +70,105 @@ class DocumentProcessor:
backend_dir = Path(__file__).parent.parent.parent
vector_db_path = str(backend_dir / vector_db_path)
self.vector_db_path = vector_db_path
session.desc = f"初始化向量数据库 - 路径 = {self.vector_db_path}"
async def _init_embeddings(self, session: Optional[Any] = None):
    """Create and cache the embedding model from the default LLM config.

    DocumentProcessor method.  Looks up the default embedding configuration
    via LLMConfigService and instantiates the matching LangChain embeddings
    class.  The instance is cached on ``self.embeddings``, so repeated calls
    are no-ops.

    Args:
        session: Optional DB session; when given it is used both to query
            the configuration and for progress reporting via its ``desc``
            attribute (project convention).

    Returns:
        The cached embeddings instance.

    Raises:
        HTTPException: 400 when no embedding configuration can be found.
    """

    def _report(msg: str) -> None:
        # Progress reporting is best-effort and only possible with a session;
        # the original wrote session.desc unconditionally and crashed on None.
        if session is not None:
            session.desc = msg

    try:
        if not self.embeddings:
            from .llm_config_service import LLMConfigService
            llm_config_service = LLMConfigService()

            config = None
            if session:
                config = await llm_config_service.get_default_embedding_config(session)
                if config:
                    _report(f"获取默认嵌入模型配置: {config}")
            if not config:
                _report("ERROR: 未找到嵌入模型配置")
                raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
            _report(f"获取嵌入模型配置 > 结果:{config}")

            # NOTE(review): config is assumed to expose provider/model_name/
            # base_url/api_key as attributes.  The original mixed dict-style
            # config.get(...) with attribute access (config.provider), which
            # cannot both be right — confirm against LLMConfigService.
            if config.provider == "openai":
                from langchain_openai import OpenAIEmbeddings
                model = getattr(config, "model_name", None) or "text-embedding-3-small"
                self.embeddings = OpenAIEmbeddings(model=model, api_key=config.api_key)
                _report(f"创建嵌入模型 - OpenAIEmbeddings(model={model})")
            elif config.provider == "ollama":
                from langchain_ollama import OllamaEmbeddings
                self.embeddings = OllamaEmbeddings(
                    model=config.model_name,
                    base_url=config.base_url
                )
                _report(f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})")
            elif config.provider == "local":
                from langchain_huggingface import HuggingFaceEmbeddings
                model = getattr(config, "model_name", None) or "sentence-transformers/all-MiniLM-L6-v2"
                self.embeddings = HuggingFaceEmbeddings(model_name=model)
                _report(f"创建嵌入模型 - HuggingFaceEmbeddings(model={model})")
            else:
                # Unknown provider: fall back to OpenAI but flag it loudly.
                from langchain_openai import OpenAIEmbeddings
                model = getattr(config, "model_name", None) or "text-embedding-3-small"
                self.embeddings = OpenAIEmbeddings(model=model, api_key=config.api_key)
                _report(f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效")
        return self.embeddings
    except Exception as e:
        logger.error(f"初始化嵌入模型时出错: {e}")
        raise
def load_document(self, session: Session, file_path: str) -> List[Document]:
    """Load a file into LangChain Documents, dispatching on its extension.

    DocumentProcessor method.  Supported types: ``.txt`` (TextLoader,
    UTF-8), ``.pdf`` (PyPDFLoader), ``.docx`` (Docx2txtLoader) and ``.md``
    (UnstructuredMarkdownLoader).

    Args:
        session: Session used for progress reporting via its ``desc``
            attribute (project convention).
        file_path: Path of the file to load.

    Returns:
        The list of loaded Documents (typically one per page for PDFs).

    Raises:
        ValueError: for unsupported file extensions; loader errors propagate.
    """
    file_extension = Path(file_path).suffix.lower()
    try:
        if file_extension == '.txt':
            session.desc = f"加载文档 - 文件路径: {file_path} - 类型: txt"
            loader = TextLoader(file_path, encoding='utf-8')
        elif file_extension == '.pdf':
            session.desc = f"加载文档 - 文件路径: {file_path} - 类型: pdf"
            # PyPDFLoader replaced the earlier pdfplumber-based loader.
            from langchain_community.document_loaders import PyPDFLoader
            loader = PyPDFLoader(file_path)
        elif file_extension == '.docx':
            session.desc = f"加载文档 - 文件路径: {file_path} - 类型: docx"
            loader = Docx2txtLoader(file_path)
        elif file_extension == '.md':
            session.desc = f"加载文档 - 文件路径: {file_path} - 类型: md"
            loader = UnstructuredMarkdownLoader(file_path)
        else:
            raise ValueError(f"不支持的文件类型: {file_extension}")
        documents = loader.load()
        session.desc = f"已载文档: {file_path}, 页数: {len(documents)}"
        return documents
    except Exception as e:
        session.desc = f"ERROR: 加载文档失败 {file_path}: {str(e)}"
        raise  # bare raise preserves the original traceback (was `raise e`)
def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]:
"""使用pdfplumber加载PDF文档"""
@ -196,77 +211,6 @@ class DocumentProcessor:
return Document(page_content=merged_text, metadata=merged_metadata)
def _get_semantic_split_points(self, text: str) -> List[str]:
    """Ask the LLM for unique split-point strings that segment *text*.

    Sends (at most) the first 10 000 characters of the document to the chat
    model together with detailed splitting rules, then parses the
    ``~~``-separated response into a list of split-point strings.

    Returns:
        The split points, or an empty list on any failure — the caller then
        falls back to the default character splitter.
    """
    try:
        prompt = f"""
# 任务说明
请分析文档内容识别出适合作为分割点的关键位置分割点应该是能够将文档划分为有意义段落的文本片段

# 分割规则
请严格按照以下规则识别分割点

## 基本要求
1. 分割点必须是完整的句子开头或段落开头
2. 每个分割后的部分应包含相对完整的语义内容
3. 每个分割部分的理想长度控制在500字以内严禁超过1000字如果超过了1000字要强制分段

## 短段落处理
4. 如果某部分长度可能小于50字应将其与后续内容合并避免产生过短片段

## 唯一性保证(重要)
5. 确保每个分割点在文档中具有唯一性
- 检查文内是否存在相同的文本片段
- 如果存在重复需要扩展分割点字符串直到获得唯一标识
- 扩展方法在当前分割点后追加几个字符形成更长的唯一字符串

## 示例说明
原始文档
"目录:
第一章 标题一
第二章 标题二
正文
第一章 标题一
这是第一章的内容
第二章 标题二
这是第二章的内容"

错误分割点"第一章 标题一"在目录和正文中重复出现
正确分割点"第一章 标题一\n这是第"通过追加内容确保唯一性

# 输出格式
- 只返回分割点文本字符串
- 每个分割点用~~分隔
- 不要包含任何其他内容或解释
示例输出分割点1~~分割点2~~分割点3

文档内容
{text[:10000]} # 限制输入长度
"""
        # The unused ChatOpenAI / get_settings imports were removed; the LLM
        # comes from the project factory below.
        from ..core.llm import create_llm
        llm = create_llm(temperature=0.2)
        response = llm.invoke(prompt)
        # Parse the ~~-separated response, dropping empty fragments.
        split_points = [point.strip() for point in response.content.split('~~') if point.strip()]
        logger.info(f"语义分析得到 {len(split_points)} 个分割点")
        return split_points
    except Exception as e:
        logger.error(f"获取语义分割点失败: {str(e)}")
        return []
def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]:
"""根据语义分割点切分文本"""
chunks = []
@ -291,76 +235,19 @@ class DocumentProcessor:
return chunks
async def split_documents(self, session: Session, documents: List[Document]) -> List[Document]:
    """Split documents into chunks, optionally at LLM-chosen split points.

    DocumentProcessor method.  When semantic splitting is enabled and there
    are documents, the documents are merged, the LLM proposes split points,
    fragments over 1000 chars are force-split, and a fragment is merged into
    the running buffer while the buffer is under 100 chars.  If semantic
    splitting is disabled or yields no split points, falls back to the
    configured RecursiveCharacterTextSplitter.

    Args:
        session: Session used for progress reporting via ``desc``.
        documents: Documents to split.

    Returns:
        The list of chunk Documents.
    """
    try:
        if self.semantic_splitter_enabled and documents:
            # 1. Merge all pages into one document, 2. ask for split points.
            merged_doc = self._merge_documents(documents)
            split_points = self._get_semantic_split_points(merged_doc.page_content)
            if split_points:
                # 3. Cut the merged text at the semantic split points.
                text_chunks = self._split_by_semantic_points(merged_doc.page_content, split_points)
                # 4. Merge short fragments, force-split over-long ones.
                processed_chunks: List[str] = []
                buffer = ""
                for fragment in text_chunks:
                    if len(fragment) > 1000:
                        # Flush any buffered text, then hard-split the fragment.
                        if buffer:
                            processed_chunks.append(buffer)
                            buffer = ""
                        processed_chunks.extend(self._force_split_long_chunk(fragment))
                    elif not buffer:
                        buffer = fragment
                    elif len(buffer) < 100:
                        buffer = f"{buffer}\n{fragment}"
                    else:
                        processed_chunks.append(buffer)
                        buffer = fragment
                if buffer:
                    processed_chunks.append(buffer)
                # 5. Wrap fragments as Documents.
                # NOTE(review): 'merged' / 'forced_split' only encode length
                # thresholds, not whether merging/splitting actually happened
                # (post-split fragments are <=1000 chars, so 'forced_split' is
                # effectively always False) — kept as-is for compatibility.
                chunks = [
                    Document(
                        page_content=fragment,
                        metadata={
                            **merged_doc.metadata,
                            'chunk_index': i,
                            'merged': len(fragment) > 100,
                            'forced_split': len(fragment) > 1000
                        }
                    )
                    for i, fragment in enumerate(processed_chunks)
                ]
            else:
                # No usable split points: fall back to the default splitter.
                logger.warning("语义分割失败,使用默认分割器")
                chunks = self.text_splitter.split_documents(documents)
        else:
            chunks = self.text_splitter.split_documents(documents)
        session.desc = f"文档分割完成,共生成 {len(chunks)} 个文档块"
        if len(chunks) > 0:
            session.desc = f"文档块内容示例: {type(chunks[0])} - {chunks[0]}"
        return chunks
    except Exception as e:
        session.desc = f"ERROR: 文档分割失败: {str(e)}"
        raise  # bare raise preserves the traceback (was `raise e`)
def _force_split_long_chunk(self, chunk: str) -> List[str]:
"""强制分割超长段落超过1000字符"""
@ -395,35 +282,35 @@ class DocumentProcessor:
def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str:
"""为知识库创建向量存储"""
try:
if settings.vector_db.type == "pgvector":
# 添加元数据
for i, doc in enumerate(documents):
doc.metadata.update({
"knowledge_base_id": knowledge_base_id,
"document_id": str(document_id) if document_id else "unknown",
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
"chunk_index": i
})
# if settings.vector_db.type == "pgvector":
# # 添加元数据
# for i, doc in enumerate(documents):
# doc.metadata.update({
# "knowledge_base_id": knowledge_base_id,
# "document_id": str(document_id) if document_id else "unknown",
# "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
# "chunk_index": i
# })
# 创建PostgreSQL pgvector存储
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
# # 创建PostgreSQL pgvector存储
# collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
# 创建新版本PGVector实例
vector_store = PGVector(
connection=self.connection_string,
embeddings=self.embeddings,
collection_name=collection_name,
use_jsonb=True # 使用JSONB存储元数据
)
# # 创建新版本PGVector实例
# vector_store = PGVector(
# connection=self.connection_string,
# embeddings=self.embeddings,
# collection_name=collection_name,
# use_jsonb=True # 使用JSONB存储元数据
# )
# 手动添加文档
vector_store.add_documents(documents)
# # 手动添加文档
# vector_store.add_documents(documents)
logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}")
return collection_name
else:
# logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}")
# return collection_name
# else:
# Chroma兼容模式
from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
# 添加元数据
@ -442,9 +329,6 @@ class DocumentProcessor:
persist_directory=kb_vector_path
)
# 持久化向量存储
vector_store.persist()
logger.info(f"向量存储创建成功: {kb_vector_path}")
return kb_vector_path
@ -452,49 +336,22 @@ class DocumentProcessor:
logger.error(f"创建向量存储失败: {str(e)}")
raise
def add_documents_to_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None:
def add_documents_to_vector_store(self, session: Session, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None:
"""向现有向量存储添加文档"""
try:
if settings.vector_db.type == "pgvector":
# 添加元数据
for i, doc in enumerate(documents):
doc.metadata.update({
"knowledge_base_id": knowledge_base_id,
"document_id": str(document_id) if document_id else "unknown",
"chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
"chunk_index": i
})
# PostgreSQL pgvector存储
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
try:
# 连接到现有集合
vector_store = PGVector(
connection=self.connection_string,
embeddings=self.embeddings,
collection_name=collection_name,
use_jsonb=True
)
# 添加新文档
vector_store.add_documents(documents)
except Exception as e:
# 如果集合不存在,创建新的向量存储
logger.warning(f"连接现有向量存储失败,创建新的向量存储: {e}")
self.create_vector_store(knowledge_base_id, documents, document_id)
if len(documents) == 0:
session.desc = f"WARNING: 文档列表为空,不执行添加操作"
return
from langchain_chroma import Chroma
logger.info(f"文档已添加到PostgreSQL pgvector存储: {collection_name}")
else:
# Chroma兼容模式
from langchain_community.vectorstores import Chroma
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
session.desc = f"添加文档到向量存储: {kb_vector_path} - documents number: {len(documents)}"
# 检查向量存储是否存在
if not os.path.exists(kb_vector_path):
# 如果不存在,创建新的向量存储
session.desc = f"WARNING: 向量存储不存在,创建新的向量存储"
self.create_vector_store(knowledge_base_id, documents, document_id)
return
session.desc = f"添加文档到向量存储: exists"
# 添加元数据
for i, doc in enumerate(documents):
doc.metadata.update({
@ -504,43 +361,44 @@ class DocumentProcessor:
"chunk_index": i
})
session.desc = f"添加文档到向量存储: enumerate"
# 加载现有向量存储
vector_store = Chroma(
persist_directory=kb_vector_path,
embedding_function=self.embeddings
)
session.desc = f"添加文档到向量存储: Chroma"
# 添加新文档
vector_store.add_documents(documents)
vector_store.persist()
ids = vector_store.add_documents(documents)
session.desc = f"文档已添加到向量存储: {kb_vector_path} -> {len(ids)} IDS - \n{ids}"
logger.info(f"文档已添加到向量存储: {kb_vector_path}")
except Exception as e:
logger.error(f"添加文档到向量存储失败: {str(e)}")
raise
def process_document(self, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
async def process_document(self, session: Session, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
"""处理单个文档:加载、分段、向量化"""
try:
logger.info(f"开始处理文档 ID: {document_id}, 路径: {file_path}")
session.desc = f"处理文档 ID: {document_id} 文件路径: {file_path}"
# 1. 加载文档
documents = self.load_document(file_path)
documents = self.load_document(session, file_path)
# 2. 分割文档
chunks = self.split_documents(documents)
chunks = await self.split_documents(session, documents)
# 3. 添加到向量存储
self.add_documents_to_vector_store(knowledge_base_id, chunks, document_id)
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
# 4. 更新文档状态
with next(get_session()) as session:
document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
session = await anext(get_session())
try:
from sqlalchemy import select
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
if document:
document.status = "processed"
document.chunk_count = len(chunks)
session.commit()
await session.commit()
finally:
await session.close()
result = {
"document_id": document_id,
@ -549,22 +407,27 @@ class DocumentProcessor:
"message": "文档处理完成"
}
logger.info(f"文档处理完成: {result}")
session.desc = f"文档处理完成: {result}"
return result
except Exception as e:
logger.error(f"文档处理失败 ID: {document_id}: {str(e)}")
session.desc = f"ERROR: 文档处理失败 ID: {document_id}: {str(e)}"
# 更新文档状态为失败
try:
with next(get_session()) as session:
document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
session = await anext(get_session())
try:
from sqlalchemy import select
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
if document:
document.status = "failed"
document.error_message = str(e)
session.commit()
await session.commit()
finally:
await session.close()
except Exception as db_error:
logger.error(f"更新文档状态失败: {str(db_error)}")
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"
return {
"document_id": document_id,
@ -573,98 +436,34 @@ class DocumentProcessor:
"message": "文档处理失败"
}
def _get_document_ids_from_vector_store(self, knowledge_base_id: int, document_id: int) -> List[str]:
    """Return the UUIDs of every vector row stored for *document_id*.

    Requires an initialised PGVector connection pool; without one (or on any
    error) an empty list is returned.
    """
    try:
        collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
        if not self.pgvector_pool:
            logger.warning("PGVector连接池未初始化")
            return []
        query = f"""
                SELECT uuid FROM langchain_pg_embedding
                WHERE collection_id = (
                    SELECT uuid FROM langchain_pg_collection
                    WHERE name = %s
                ) AND cmetadata->>'document_id' = %s
                """
        rows = self.pgvector_pool.execute_query(query, (collection_name, str(document_id)))
        return [row[0] for row in rows] if rows else []
    except Exception as e:
        logger.error(f"查询文档向量记录失败: {str(e)}")
        return []
def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None:
    """Remove every vector-store entry belonging to *document_id*.

    DocumentProcessor method.  Dispatches on the configured vector-db type:
    for PGVector the row IDs are found with a direct SQL query on the
    embedding table and deleted by ID; for Chroma the chunks are deleted via
    a metadata ``where`` filter.  All failures are logged and swallowed so
    that document deletion in the relational DB can still proceed.
    """
    try:
        if settings.vector_db.type == "pgvector":
            collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
            try:
                vector_store = PGVector(
                    connection=self.connection_string,
                    embeddings=self.embeddings,
                    collection_name=collection_name,
                    use_jsonb=True
                )
                # Query the embedding table directly: filtering through the
                # LangChain API would require a pointless similarity search.
                try:
                    from sqlalchemy import text
                    from sqlalchemy.orm import Session
                    engine = vector_store._engine
                    with Session(engine) as session:
                        query_sql = text(
                            f"SELECT id FROM langchain_pg_embedding "
                            f"WHERE cmetadata->>'document_id' = :doc_id"
                        )
                        result = session.execute(query_sql, {"doc_id": str(document_id)})
                        ids_to_delete = [row[0] for row in result.fetchall()]
                    if ids_to_delete:
                        vector_store.delete(ids=ids_to_delete)
                        logger.info(f"成功删除 {len(ids_to_delete)} 个文档块: document_id={document_id}")
                    else:
                        logger.warning(f"未找到要删除的文档ID: document_id={document_id}")
                except Exception as query_error:
                    logger.error(f"查询要删除的文档时出错: {query_error}")
                    # Lookup failed, so the document likely does not exist.
                    logger.warning(f"无法查询到要删除的文档: document_id={document_id}")
                    return
                logger.info(f"文档已从PostgreSQL pgvector存储中删除: document_id={document_id}")
            except Exception as e:
                logger.warning(f"PostgreSQL pgvector存储不存在或删除失败: {collection_name}, {str(e)}")
        else:
            # Chroma-compatible mode.
            from langchain_chroma import Chroma
            kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
            if not os.path.exists(kb_vector_path):
                logger.warning(f"向量存储不存在: {kb_vector_path}")
                return
            chunks = self.get_document_chunks(knowledge_base_id, document_id)
            vector_store = Chroma(
                persist_directory=kb_vector_path,
                embedding_function=self.embeddings
            )
            # Count before/after purely for the log line below.
            count_before = vector_store._collection.count()
            count_after = count_before
            if len(chunks) > 0:
                vector_store.delete(where={"document_id": str(document_id)})
                count_after = vector_store._collection.count()
            logger.info(f"文档已从向量存储中删除: document_id={document_id},删除前有 {count_before} 个向量,删除后有 {count_after} 个向量")
    except Exception as e:
        logger.error(f"从向量存储删除文档失败: {str(e)}")
@ -679,25 +478,6 @@ class DocumentProcessor:
- 确保结果按chunk_index排序
"""
try:
if settings.vector_db.type == "pgvector":
# PostgreSQL pgvector存储 - 使用直接SQL查询避免相似性搜索
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
try:
# 尝试直接SQL查询推荐方法
chunks = self._get_chunks_by_sql(knowledge_base_id, document_id)
if chunks:
return chunks
# 如果SQL查询失败回退到改进的LangChain方法
logger.info("SQL查询失败使用LangChain回退方案")
return self._get_chunks_by_langchain_improved(knowledge_base_id, document_id, collection_name)
except Exception as e:
logger.warning(f"PostgreSQL pgvector存储访问失败: {collection_name}, {str(e)}")
return []
else:
# Chroma兼容模式
return self._get_chunks_chroma(knowledge_base_id, document_id)
except Exception as e:
@ -826,9 +606,7 @@ class DocumentProcessor:
def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
"""Chroma存储的处理逻辑"""
try:
from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
# 构建向量数据库路径
vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
@ -845,6 +623,7 @@ class DocumentProcessor:
# 获取所有文档的元数据,筛选出指定文档的分段
collection = vectorstore._collection
all_docs = collection.get(include=["metadatas", "documents"])
all_ids_data = collection.get()
chunks = []
chunk_index = 0
@ -852,6 +631,7 @@ class DocumentProcessor:
for i, metadata in enumerate(all_docs["metadatas"]):
if metadata.get("document_id") == str(document_id):
chunk_content = all_docs["documents"][i]
vector_id = all_ids_data["ids"][i]
chunk = {
"id": f"chunk_{document_id}_{chunk_index}",
@ -860,65 +640,61 @@ class DocumentProcessor:
"page_number": metadata.get("page"),
"chunk_index": chunk_index,
"start_char": metadata.get("start_char"),
"end_char": metadata.get("end_char")
"end_char": metadata.get("end_char"),
"vector_id": vector_id
}
chunks.append(chunk)
chunk_index += 1
logger.info(f"获取到文档 {document_id}{len(chunks)} 个分段")
return chunks
except Exception as e:
logger.error(f"Chroma存储处理失败: {e}")
return []
def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]:
"""在知识库中搜索相似文档"""
try:
if settings.vector_db.type == "pgvector":
# PostgreSQL pgvector存储
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
# if settings.vector_db.type == "pgvector":
# # PostgreSQL pgvector存储
# collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
try:
vector_store = PGVector(
connection=self.connection_string,
embeddings=self.embeddings,
collection_name=collection_name,
use_jsonb=True
)
# try:
# vector_store = PGVector(
# connection=self.connection_string,
# embeddings=self.embeddings,
# collection_name=collection_name,
# use_jsonb=True
# )
# 执行相似性搜索
results = vector_store.similarity_search_with_score(query, k=k)
# # 执行相似性搜索
# results = vector_store.similarity_search_with_score(query, k=k)
# 格式化结果
formatted_results = []
for doc, distance_score in results:
# pgvector使用余弦距离距离越小相似度越高
# 将距离转换为0-1之间的相似度分数
similarity_score = 1.0 / (1.0 + distance_score)
# # 格式化结果
# formatted_results = []
# for doc, distance_score in results:
# # pgvector使用余弦距离距离越小相似度越高
# # 将距离转换为0-1之间的相似度分数
# similarity_score = 1.0 / (1.0 + distance_score)
formatted_results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"similarity_score": distance_score, # 保留原始距离分数
"normalized_score": similarity_score, # 归一化相似度分数
"source": doc.metadata.get('filename', 'unknown'),
"document_id": doc.metadata.get('document_id', 'unknown'),
"chunk_id": doc.metadata.get('chunk_id', 'unknown')
})
# formatted_results.append({
# "content": doc.page_content,
# "metadata": doc.metadata,
# "similarity_score": distance_score, # 保留原始距离分数
# "normalized_score": similarity_score, # 归一化相似度分数
# "source": doc.metadata.get('filename', 'unknown'),
# "document_id": doc.metadata.get('document_id', 'unknown'),
# "chunk_id": doc.metadata.get('chunk_id', 'unknown')
# })
# 按相似度分数排序(距离越小越相似)
formatted_results.sort(key=lambda x: x['similarity_score'])
# # 按相似度分数排序(距离越小越相似)
# formatted_results.sort(key=lambda x: x['similarity_score'])
logger.info(f"PostgreSQL pgvector搜索完成找到 {len(formatted_results)} 个相关文档")
return formatted_results
# logger.info(f"PostgreSQL pgvector搜索完成找到 {len(formatted_results)} 个相关文档")
# return formatted_results
except Exception as e:
logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}")
return []
else:
# except Exception as e:
# logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}")
# return []
# else:
# Chroma兼容模式
from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")
if not os.path.exists(kb_vector_path):
@ -961,13 +737,15 @@ class DocumentProcessor:
logger.error(f"搜索文档失败: {str(e)}")
return [] # 返回空列表而不是抛出异常
# Global, lazily created document-processor singleton.
document_processor = None


async def get_document_processor(session=None):
    """Return the shared DocumentProcessor, creating and initialising it lazily.

    Args:
        session: Optional session, forwarded to
            ``DocumentProcessor.initialize`` and used for progress reporting
            via its ``desc`` attribute.  (The previous ``Session`` annotation
            referred to ``requests.Session``, which is not what callers pass,
            so it was dropped.)

    Returns:
        The singleton DocumentProcessor instance.

    NOTE(review): not safe under concurrent first calls — two coroutines may
    both observe None and initialise twice; acceptable if startup is serial.
    """
    global document_processor
    if session:
        session.desc = "获取文档处理器实例"
    if document_processor is None:
        document_processor = DocumentProcessor()
        await document_processor.initialize(session)
    return document_processor

View File

@ -4,6 +4,7 @@ from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from requests import Session
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
from ..core.config import settings
from loguru import logger
@ -12,7 +13,8 @@ class EmbeddingFactory:
"""Factory class for creating embedding instances based on provider."""
@staticmethod
def create_embeddings(
async def create_embeddings(
session: Session = None,
provider: Optional[str] = None,
model: Optional[str] = None,
dimensions: Optional[int] = None
@ -28,12 +30,12 @@ class EmbeddingFactory:
Embeddings instance
"""
# 使用新的embedding配置
embedding_config = settings.embedding.get_current_config()
embedding_config = await settings.embedding.get_current_config(session)
provider = provider or settings.embedding.provider
model = model or embedding_config.get("model")
dimensions = dimensions or settings.vector_db.embedding_dimension
logger.info(f"Creating embeddings with provider: {provider}, model: {model}")
session.desc = f"创建嵌入模型: {provider}, {model}"
if provider == "openai":
return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)

View File

@ -4,6 +4,7 @@ import os
import pandas as pd
from typing import Dict, List, Any, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import select, func
from ..models.excel_file import ExcelFile
from ..db.database import get_session
from loguru import logger
@ -102,7 +103,7 @@ class ExcelMetadataService:
'processing_error': str(e)
}
def save_file_metadata(self, file_path: str, original_filename: str,
async def save_file_metadata(self, file_path: str, original_filename: str,
user_id: int, file_size: int) -> ExcelFile:
"""Extract and save Excel file metadata to database."""
try:
@ -131,31 +132,30 @@ class ExcelMetadataService:
# Save to database
self.db.add(excel_file)
self.db.commit()
self.db.refresh(excel_file)
self.session.add(excel_file)
await self.session.commit()
await self.session.refresh(excel_file)
logger.info(f"Saved metadata for file {original_filename} with ID {excel_file.id}")
return excel_file
except Exception as e:
logger.error(f"Error saving metadata for {original_filename}: {str(e)}")
self.db.rollback()
await self.session.rollback()
raise
def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]:
async def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]:
"""Get Excel files for a user with pagination."""
try:
# Get total count
total = self.db.query(ExcelFile).filter(ExcelFile.created_by == user_id).count()
stmt = select(func.count(ExcelFile.id)).where(ExcelFile.created_by == user_id)
result = await self.session.execute(stmt)
total = result.scalar_one()
# Get files with pagination
files = (self.db.query(ExcelFile)
.filter(ExcelFile.created_by == user_id)
.order_by(ExcelFile.created_at.desc())
.offset(skip)
.limit(limit)
.all())
stmt = select(ExcelFile).where(ExcelFile.created_by == user_id).order_by(ExcelFile.created_at.desc()).offset(skip).limit(limit)
result = await self.session.execute(stmt)
files = result.scalars().all()
return files, total
@ -163,21 +163,21 @@ class ExcelMetadataService:
logger.error(f"Error getting user files for user {user_id}: {str(e)}")
return [], 0
def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]:
async def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]:
"""Get Excel file by ID and user ID."""
try:
return (self.db.query(ExcelFile)
.filter(ExcelFile.id == file_id, ExcelFile.created_by == user_id)
.first())
stmt = select(ExcelFile).where(ExcelFile.id == file_id, ExcelFile.created_by == user_id)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting file {file_id} for user {user_id}: {str(e)}")
return None
def delete_file(self, file_id: int, user_id: int) -> bool:
async def delete_file(self, file_id: int, user_id: int) -> bool:
"""Delete Excel file record and physical file."""
try:
# Get file record
excel_file = self.get_file_by_id(file_id, user_id)
excel_file = await self.get_file_by_id(file_id, user_id)
if not excel_file:
return False
@ -187,39 +187,40 @@ class ExcelMetadataService:
logger.info(f"Deleted physical file: {excel_file.file_path}")
# Delete database record
self.db.delete(excel_file)
self.db.commit()
await self.session.delete(excel_file)
await self.session.commit()
logger.info(f"Deleted Excel file record with ID {file_id}")
return True
except Exception as e:
logger.error(f"Error deleting file {file_id}: {str(e)}")
self.db.rollback()
await self.session.rollback()
return False
def update_last_accessed(self, file_id: int, user_id: int) -> bool:
async def update_last_accessed(self, file_id: int, user_id: int) -> bool:
"""Update last accessed time for a file."""
try:
excel_file = self.get_file_by_id(file_id, user_id)
excel_file = await self.get_file_by_id(file_id, user_id)
if not excel_file:
return False
from sqlalchemy.sql import func
excel_file.last_accessed = func.now()
self.db.commit()
await self.session.commit()
return True
except Exception as e:
logger.error(f"Error updating last accessed for file {file_id}: {str(e)}")
self.db.rollback()
await self.session.rollback()
return False
def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
async def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
"""Get file summary information for LLM context."""
try:
files = self.db.query(ExcelFile).filter(ExcelFile.user_id == user_id).all()
stmt = select(ExcelFile).where(ExcelFile.user_id == user_id)
result = await self.session.execute(stmt)
files = result.scalars().all()
summary = []
for file in files:

View File

@ -29,9 +29,11 @@ class KnowledgeBaseService:
Args:
session (Session): 数据库会话用于执行ORM操作
"""
if session is None:
logger.error("session为空session must be an instance of Session")
self.session = session
def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
"""创建一个新的知识库实例。
Args:
@ -57,130 +59,22 @@ class KnowledgeBaseService:
collection_name=collection_name
)
# Set audit fields
# 自动更新created_by和updated_by字段
kb.set_audit_fields()
self.session.add(kb)
self.session.commit()
self.session.refresh(kb)
await self.session.commit()
await self.session.refresh(kb)
logger.info(f"Created knowledge base: {kb.name} (ID: {kb.id})")
self.session.desc = f"Created knowledge base: {kb.name} - collection_name = {collection_name}, embedding_model = {kb.embedding_model}"
return kb
except Exception as e:
self.session.rollback()
await self.session.rollback()
logger.error(f"Failed to create knowledge base: {str(e)}")
raise
def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
"""根据ID获取知识库实例。
Args:
kb_id (int): 知识库实例的ID
Returns:
Optional[KnowledgeBase]: 如果找到则返回知识库实例否则返回None
"""
stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
return self.session.execute(stmt).scalar_one_or_none()
def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
"""根据名称获取当前用户的知识库实例。
Args:
name (str): 知识库实例的名称
Returns:
Optional[KnowledgeBase]: 如果找到则返回知识库实例否则返回None
"""
stmt = select(KnowledgeBase).where(
KnowledgeBase.name == name,
KnowledgeBase.created_by == UserContext.get_current_user().id
)
return self.session.execute(stmt).scalar_one_or_none()
async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
"""获取当前用户的所有知识库的列表。
Args:
skip (int, optional): 跳过的记录数默认值为0
limit (int, optional): 返回的最大记录数默认值为50
active_only (bool, optional): 是否仅返回活动的知识库默认值为True
Returns:
List[KnowledgeBase]: 当前用户的知识库列表
"""
stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user().id)
if active_only:
stmt = stmt.where(KnowledgeBase.is_active == True)
stmt = stmt.offset(skip).limit(limit)
return (await self.session.execute(stmt)).scalars().all()
def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
"""更新知识库实例。
Args:
kb_id (int): 待更新的知识库实例ID
kb_update (KnowledgeBaseUpdate): 用于更新知识库实例的数据
Returns:
Optional[KnowledgeBase]: 如果找到则返回更新后的知识库实例否则返回None
Raises:
Exception: 如果更新过程中发生错误
"""
try:
kb = self.get_knowledge_base(kb_id)
if not kb:
return None
# Update fields
update_data = kb_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(kb, field, value)
# Set audit fields
kb.set_audit_fields(is_update=True)
self.session.commit()
self.session.refresh(kb)
self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
return kb
except Exception as e:
self.session.rollback()
self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}"
raise
def delete_knowledge_base(self, kb_id: int) -> bool:
"""删除知识库实例。
Args:
kb_id (int): 待删除的知识库实例ID
Returns:
bool: 如果知识库实例被成功删除则返回True否则返回False
Raises:
Exception: 如果删除过程中发生错误
"""
kb = self.get_knowledge_base(kb_id)
if not kb:
return False
# TODO: Clean up vector database collection
# This should be implemented when vector database service is ready
self.session.delete(kb)
self.session.commit()
return True
def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
async def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
"""Search knowledge bases by name or description for the current user.
Args:
@ -192,7 +86,7 @@ class KnowledgeBaseService:
List[KnowledgeBase]: List of matching knowledge bases.
"""
stmt = select(KnowledgeBase).where(
KnowledgeBase.created_by == UserContext.get_current_user().id,
KnowledgeBase.created_by == UserContext.get_current_user()['id'],
KnowledgeBase.is_active == True,
or_(
KnowledgeBase.name.ilike(f"%{query}%"),
@ -201,7 +95,7 @@ class KnowledgeBaseService:
)
stmt = stmt.offset(skip).limit(limit)
return self.session.execute(stmt).scalars().all()
return (await self.session.execute(stmt)).scalars().all()
async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
"""Search in knowledge base using vector similarity.
@ -219,7 +113,7 @@ class KnowledgeBaseService:
logger.info(f"Searching in knowledge base {kb_id} for: {query}")
# Use document processor for vector search
search_results = get_document_processor().search_similar_documents(
search_results = (await get_document_processor(self.session)).search_similar_documents(
knowledge_base_id=kb_id,
query=query,
k=top_k
@ -247,3 +141,104 @@ class KnowledgeBaseService:
except Exception as e:
logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
return []
# ----------------------------------------------------------------------------------
async def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
    """Fetch the current user's knowledge base by name.

    Args:
        name: Name of the knowledge base.

    Returns:
        Optional[KnowledgeBase]: the matching instance, or None when absent.
    """
    owner_id = UserContext.get_current_user()['id']  # dict-style user accessor
    stmt = select(KnowledgeBase).where(
        KnowledgeBase.name == name,
        KnowledgeBase.created_by == owner_id,
    )
    return (await self.session.execute(stmt)).scalar_one_or_none()
async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
    """List the current user's knowledge bases with pagination.

    Args:
        skip: Number of records to skip. Defaults to 0.
        limit: Maximum number of records to return. Defaults to 50.
        active_only: When True, return only active knowledge bases.

    Returns:
        List[KnowledgeBase]: the current user's knowledge bases.
    """
    owner_id = UserContext.get_current_user()['id']  # dict-style user accessor
    stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == owner_id)
    if active_only:
        stmt = stmt.where(KnowledgeBase.is_active == True)  # SQLAlchemy expression, not identity test
    result = await self.session.execute(stmt.offset(skip).limit(limit))
    return result.scalars().all()
async def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
    """Look up a knowledge base by primary key.

    Args:
        kb_id: ID of the knowledge base.

    Returns:
        Optional[KnowledgeBase]: the instance, or None when not found.
    """
    result = await self.session.execute(
        select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
    )
    return result.scalar_one_or_none()
async def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
    """Update a knowledge base with the fields set on *kb_update*.

    Args:
        kb_id: ID of the knowledge base to update.
        kb_update: Partial update payload; only explicitly-set fields are applied.

    Returns:
        Optional[KnowledgeBase]: the refreshed instance, or None when *kb_id*
        does not exist.

    Raises:
        Exception: re-raised after rolling back when the commit fails.
    """
    kb = await self.get_knowledge_base(kb_id)
    if not kb:
        return None
    try:
        # Apply only the fields the caller actually set.
        update_data = kb_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(kb, field, value)
        # Set audit fields (updated_by / timestamps).
        kb.set_audit_fields(is_update=True)
        await self.session.commit()
        await self.session.refresh(kb)
        self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
        return kb
    except Exception as e:
        # Keep the session usable on failure, matching create_knowledge_base.
        await self.session.rollback()
        self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}"
        raise
async def delete_knowledge_base(self, kb_id: int) -> bool:
    """Delete a knowledge base by ID.

    Args:
        kb_id: ID of the knowledge base to delete.

    Returns:
        bool: True when the record was deleted, False when it does not exist.

    Raises:
        Exception: re-raised after rolling back when the delete/commit fails.
    """
    kb = await self.get_knowledge_base(kb_id)
    if not kb:
        return False
    # TODO: Clean up vector database collection
    # This should be implemented when vector database service is ready
    try:
        await self.session.delete(kb)
        await self.session.commit()
        return True
    except Exception:
        # Roll back so the session stays usable, consistent with the
        # error handling in create_knowledge_base / update_knowledge_base.
        await self.session.rollback()
        raise

View File

@ -9,7 +9,7 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_chroma import Chroma
from langchain_postgres import PGVector
from .embedding_factory import EmbeddingFactory
@ -21,16 +21,16 @@ from .conversation import ConversationService
from .document_processor import get_document_processor
from loguru import logger
class KnowledgeChatService:
"""Knowledge base chat service using LangChain RAG."""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
def __init__(self, session: Session):
self.session = session
self.conversation_service = ConversationService(session)
async def initialize(self):
# 获取当前LLM配置
llm_config = settings.llm.get_current_config()
llm_config = await settings.llm.get_current_config(self.session)
# Initialize LangChain ChatOpenAI
self.llm = ChatOpenAI(
@ -53,30 +53,12 @@ class KnowledgeChatService:
)
# Initialize embeddings based on provider
self.embeddings = EmbeddingFactory.create_embeddings()
logger.info(f"Knowledge Chat Service initialized with model: {self.llm.model_name}")
def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
self.embeddings = await EmbeddingFactory.create_embeddings(self.session)
async def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
"""Get vector store for knowledge base."""
try:
if settings.vector_db.type == "pgvector":
# 使用PGVector
doc_processor = get_document_processor()
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
vector_store = PGVector(
connection=doc_processor.connection_string,
embeddings=self.embeddings,
collection_name=collection_name,
use_jsonb=True
)
return vector_store
else:
# 兼容Chroma模式
import os
kb_vector_path = os.path.join(get_document_processor().vector_db_path, f"kb_{knowledge_base_id}")
kb_vector_path = os.path.join((await get_document_processor(self.session)).vector_db_path, f"kb_{knowledge_base_id}")
if not os.path.exists(kb_vector_path):
logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
@ -159,7 +141,7 @@ class KnowledgeChatService:
try:
# Get conversation and validate
conversation = self.conversation_service.get_conversation(conversation_id)
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError("Conversation not found")
@ -169,14 +151,14 @@ class KnowledgeChatService:
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
# Save user message
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history
messages = self.conversation_service.get_conversation_messages(conversation_id)
messages = await self.conversation_service.get_conversation_messages(conversation_id)
conversation_history = self._prepare_conversation_history(messages)
# Create RAG chain
@ -203,7 +185,7 @@ class KnowledgeChatService:
response_content = await asyncio.to_thread(rag_chain.invoke, message)
# Save assistant message with context
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=response_content,
role=MessageRole.ASSISTANT,
@ -251,14 +233,14 @@ class KnowledgeChatService:
raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")
# Get conversation history
messages = self.conversation_service.get_conversation_messages(conversation_id)
messages = await self.conversation_service.get_conversation_messages(conversation_id)
conversation_history = self._prepare_conversation_history(messages)
# Create RAG chain
rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)
# Save user message
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
@ -269,7 +251,7 @@ class KnowledgeChatService:
context = "\n\n".join([doc.page_content for doc in relevant_docs])
# Create streaming LLM
llm_config = settings.llm.get_current_config()
llm_config = await settings.llm.get_current_config()
streaming_llm = ChatOpenAI(
model=llm_config["model"],
temperature=temperature or llm_config["temperature"],
@ -315,7 +297,7 @@ class KnowledgeChatService:
# Save assistant response
if full_response:
self.conversation_service.add_message(
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
@ -331,7 +313,7 @@ class KnowledgeChatService:
yield error_message
# Save error message
self.conversation_service.add_message(
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT

View File

@ -41,24 +41,23 @@ class StreamingCallbackHandler(BaseCallbackHandler):
class LangChainChatService:
"""LangChain-based chat service for AI model integration."""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
def __init__(self, session: Session):
self.session = session
self.conversation_service = ConversationService(session)
from ..core.llm import create_llm
# 添加调试日志
logger.info(f"LLM Provider: {settings.llm.provider}")
async def initialize(self):
from ..core.new_agent import new_agent
# Initialize LangChain ChatOpenAI
self.llm = create_llm(streaming=False)
self.llm = await new_agent(self.session, streaming=False)
self.session.desc = "LangChainChatService初始化 - llm 实例化完毕"
# Streaming LLM for stream responses
self.streaming_llm = create_llm(streaming=True)
self.streaming_llm = await new_agent(self.session, streaming=True)
self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕"
self.streaming_handler = StreamingCallbackHandler()
logger.info(f"LangChain ChatService initialized with model: {self.llm.model_name}")
self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕"
def _prepare_langchain_messages(self, conversation, history: List) -> List:
"""Prepare messages for LangChain format."""
@ -101,19 +100,19 @@ class LangChainChatService:
try:
# Get conversation details
conversation = self.conversation_service.get_conversation(conversation_id)
conversation = await self.conversation_service.get_conversation(conversation_id)
if not conversation:
raise ChatServiceError("Conversation not found")
# Add user message to database
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history for context
history = self.conversation_service.get_conversation_history(
history = await self.conversation_service.get_conversation_history(
conversation_id, limit=20
)
@ -123,7 +122,7 @@ class LangChainChatService:
# Update LLM parameters if provided
llm_to_use = self.llm
if temperature is not None or max_tokens is not None:
llm_config = settings.llm.get_current_config()
llm_config = await settings.llm.get_current_config()
llm_to_use = ChatOpenAI(
model=llm_config["model"],
openai_api_key=llm_config["api_key"],
@ -140,7 +139,7 @@ class LangChainChatService:
assistant_content = response.content
# Add assistant message to database
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=assistant_content,
role=MessageRole.ASSISTANT,
@ -152,7 +151,7 @@ class LangChainChatService:
)
# Update conversation timestamp
self.conversation_service.update_conversation_timestamp(conversation_id)
await self.conversation_service.update_conversation_timestamp(conversation_id)
logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
@ -171,7 +170,7 @@ class LangChainChatService:
error_message = self._format_error_message(e)
# Add error message to database
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT,
@ -206,33 +205,33 @@ class LangChainChatService:
max_tokens: Optional[int] = None
) -> AsyncGenerator[str, None]:
"""Send a message and get streaming AI response using LangChain."""
logger.info(f"Processing LangChain streaming chat request for conversation {conversation_id}")
logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}")
try:
# Get conversation details
conversation = self.conversation_service.get_conversation(conversation_id)
conversation = await self.conversation_service.get_conversation(conversation_id)
conv = conversation.to_dict()
if not conversation:
raise ChatServiceError("Conversation not found")
# Add user message to database
user_message = self.conversation_service.add_message(
user_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=message,
role=MessageRole.USER
)
# Get conversation history for context
history = self.conversation_service.get_conversation_history(
history = await self.conversation_service.get_conversation_history(
conversation_id, limit=20
)
# Prepare messages for LangChain
langchain_messages = self._prepare_langchain_messages(conversation, history)
langchain_messages = self._prepare_langchain_messages(conv, history)
# Update streaming LLM parameters if provided
streaming_llm_to_use = self.streaming_llm
if temperature is not None or max_tokens is not None:
llm_config = settings.llm.get_current_config()
llm_config = await settings.llm.get_current_config()
streaming_llm_to_use = ChatOpenAI(
model=llm_config["model"],
openai_api_key=llm_config["api_key"],
@ -241,13 +240,13 @@ class LangChainChatService:
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
streaming=True
)
# Clear previous streaming handler state
self.streaming_handler.clear()
# Stream response
full_response = ""
async for chunk in streaming_llm_to_use.astream(langchain_messages):
try:
async for chunk in streaming_llm_to_use._astream(langchain_messages):
# Handle different chunk types to avoid KeyError
chunk_content = None
if hasattr(chunk, 'content'):
@ -265,9 +264,11 @@ class LangChainChatService:
if chunk_content:
full_response += chunk_content
yield chunk_content
except Exception as e:
logger.error(f"Error in LLM streaming: {e}")
yield f"{self._format_error_message(e)} >>> {e}"
# Add complete assistant message to database
assistant_message = self.conversation_service.add_message(
assistant_message = await self.conversation_service.add_message(
conversation_id=conversation_id,
content=full_response,
role=MessageRole.ASSISTANT,
@ -280,13 +281,12 @@ class LangChainChatService:
)
# Update conversation timestamp
self.conversation_service.update_conversation_timestamp(conversation_id)
logger.info(f"Successfully processed LangChain streaming chat request for conversation {conversation_id}")
await self.conversation_service.update_conversation_timestamp(conversation_id)
logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}")
except Exception as e:
# 安全地格式化异常信息避免再次引发KeyError
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id}"
error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
logger.error(error_info, exc_info=True)
# Format error message for user
@ -294,7 +294,7 @@ class LangChainChatService:
yield error_message
# Add error message to database
self.conversation_service.add_message(
await self.conversation_service.add_message(
conversation_id=conversation_id,
content=error_message,
role=MessageRole.ASSISTANT,
@ -324,23 +324,23 @@ class LangChainChatService:
logger.error(f"Failed to get available models: {str(e)}")
return ["gpt-3.5-turbo"]
def update_model_config(
async def update_model_config(
self,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
):
"""Update LLM configuration."""
from ..core.llm import create_llm
from ..core.new_agent import new_agent
# 重新创建LLM实例
self.llm = create_llm(
self.llm = await new_agent(
model=model,
temperature=temperature,
streaming=False
)
self.streaming_llm = create_llm(
self.streaming_llm = await new_agent(
model=model,
temperature=temperature,
streaming=True
@ -357,7 +357,7 @@ class LangChainChatService:
if "rate limit" in error_str.lower():
return "服务器繁忙,请稍后再试。"
elif "api key" in error_str.lower() or "authentication" in error_str.lower():
return "API认证失败请检查配置"
return f"API认证失败请检查配置文件"
elif "timeout" in error_str.lower():
return "请求超时,请重试。"
elif "connection" in error_str.lower():

View File

@ -11,11 +11,9 @@ from loguru import logger
class LLMConfigService:
"""LLM配置管理服务"""
def __init__(self, db_session: Optional[Session] = None):
self.db = db_session or get_session() # TODO DrGraph:检查异步
def get_default_chat_config(self) -> Optional[LLMConfig]:
async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
"""获取默认对话模型配置"""
# async for session in get_session():
try:
stmt = select(LLMConfig).where(
and_(
@ -24,7 +22,7 @@ class LLMConfigService:
LLMConfig.is_active == True
)
)
config = self.db.execute(stmt).scalar_one_or_none()
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
logger.warning("未找到默认对话模型配置")
@ -36,7 +34,7 @@ class LLMConfigService:
logger.error(f"获取默认对话模型配置失败: {str(e)}")
return None
def get_default_embedding_config(self) -> Optional[LLMConfig]:
async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
"""获取默认嵌入模型配置"""
try:
stmt = select(LLMConfig).where(
@ -46,23 +44,27 @@ class LLMConfigService:
LLMConfig.is_active == True
)
)
config = self.db.execute(stmt).scalar_one_or_none()
config = None
if session != None:
config = (await session.execute(stmt)).scalar_one_or_none()
if not config:
logger.warning("未找到默认嵌入模型配置")
if session != None:
session.desc = "ERROR: 未找到默认嵌入模型配置"
return None
session.desc = f"获取默认嵌入模型配置 > 结果:{config}"
return config
except Exception as e:
logger.error(f"获取默认嵌入模型配置失败: {str(e)}")
if session != None:
session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}"
return None
def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
async def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
"""根据ID获取配置"""
try:
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
return self.db.execute(stmt).scalar_one_or_none()
return (await self.db.execute(stmt)).scalar_one_or_none()
except Exception as e:
logger.error(f"获取配置失败: {str(e)}")
return None
@ -82,17 +84,17 @@ class LLMConfigService:
logger.error(f"获取激活配置失败: {str(e)}")
return []
def _get_fallback_chat_config(self) -> Dict[str, Any]:
async def _get_fallback_chat_config(self) -> Dict[str, Any]:
"""获取fallback对话模型配置从环境变量"""
from ..core.config import get_settings
settings = get_settings()
return settings.llm.get_current_config()
return await settings.llm.get_current_config()
def _get_fallback_embedding_config(self) -> Dict[str, Any]:
async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
"""获取fallback嵌入模型配置从环境变量"""
from ..core.config import get_settings
settings = get_settings()
return settings.embedding.get_current_config()
return await settings.embedding.get_current_config()
def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]:
"""测试配置连接"""
@ -110,12 +112,12 @@ class LLMConfigService:
logger.error(f"测试配置失败: {str(e)}")
return {"success": False, "error": str(e)}
# 全局实例
_llm_config_service = None
# # 全局实例
# _llm_config_service = None
def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
"""获取LLM配置服务实例"""
global _llm_config_service
if _llm_config_service is None or db_session is not None:
_llm_config_service = LLMConfigService(db_session)
return _llm_config_service
# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
# """获取LLM配置服务实例"""
# global _llm_config_service
# if _llm_config_service is None or db_session is not None:
# _llm_config_service = LLMConfigService(db_session)
# return _llm_config_service

View File

@ -49,8 +49,9 @@ class SmartDatabaseWorkflowManager:
self.db = db
self.table_metadata_service = TableMetadataService(db) if db else None
from ..core.llm import create_llm
self.llm = create_llm()
async def initialize(self):
from ..core.new_agent import new_agent
self.llm = await new_agent()
def _get_database_tool(self, db_type: str):
"""根据数据库类型获取对应的数据库工具"""

View File

@ -54,9 +54,10 @@ class SmartExcelWorkflowManager:
else:
self.metadata_service = None
from ..core.llm import create_llm
async def initialize(self):
from ..core.new_agent import new_agent
# 禁用流式响应避免pandas代理兼容性问题
self.llm = create_llm(streaming=False)
self.llm = await new_agent(streaming=False)
async def _run_in_executor(self, func, *args):
"""在线程池中运行阻塞函数"""

View File

@ -15,8 +15,12 @@ class SmartWorkflowManager:
def __init__(self, db=None):
self.db = db
self.excel_workflow = SmartExcelWorkflowManager(db)
self.database_workflow = SmartDatabaseWorkflowManager(db)
async def initialize(self):
self.excel_workflow = SmartExcelWorkflowManager(self.db)
await self.excel_workflow.initialize()
self.database_workflow = SmartDatabaseWorkflowManager(self.db)
await self.database_workflow.initialize()
async def process_excel_query_stream(
self,

View File

@ -35,7 +35,7 @@ class TableMetadataService:
DatabaseConfig.id == database_config_id,
DatabaseConfig.created_by == user_id
)
db_config = self.session.scalar_one_or_none(stmt)
db_config = (await self.session.execute(stmt)).scalar_one_or_none()
if not db_config:
self.session.desc = "ERROR: 数据库配置不存在"
@ -202,7 +202,7 @@ class TableMetadataService:
TableMetadata.database_config_id == database_config_id,
TableMetadata.table_name == table_name
)
existing = self.session.scalar_one_or_none(stmt)
existing = (await self.session.execute(stmt)).scalar_one_or_none()
if existing:
self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
@ -216,8 +216,8 @@ class TableMetadataService:
existing.table_comment = metadata['table_comment']
existing.last_synced_at = datetime.utcnow()
self.session.commit()
self.session.refresh(existing)
await self.session.commit()
await self.session.refresh(existing)
return existing
else:
self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
@ -240,8 +240,8 @@ class TableMetadataService:
)
self.session.add(table_metadata)
self.session.commit()
self.session.refresh(table_metadata)
await self.session.commit()
await self.session.refresh(table_metadata)
self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据"
return table_metadata
@ -258,7 +258,7 @@ class TableMetadataService:
DatabaseConfig.id == database_config_id,
DatabaseConfig.user_id == user_id
)
db_config = self.session.scalar_one_or_none(stmt)
db_config = (await self.session.execute(stmt)).scalar_one_or_none()
if not db_config:
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
@ -275,7 +275,7 @@ class TableMetadataService:
TableMetadata.database_config_id == database_config_id,
TableMetadata.table_name == table_name
)
existing = self.session.scalar_one_or_none(stmt)
existing = (await self.session.execute(stmt)).scalar_one_or_none()
if existing:
self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id}{table_name} 的元数据配置"
@ -320,7 +320,7 @@ class TableMetadataService:
})
# 提交事务
self.session.commit()
await self.session.commit()
self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id}{table_names} 的元数据配置"
return {
'saved_tables': saved_tables,
@ -330,7 +330,7 @@ class TableMetadataService:
}
def get_user_table_metadata(
async def get_user_table_metadata(
self,
user_id: int,
database_config_id: Optional[int] = None
@ -345,9 +345,9 @@ class TableMetadataService:
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
raise NotFoundError("数据库配置不存在")
stmt = stmt.where(TableMetadata.is_enabled_for_qa == True)
return self.session.scalars(stmt).all()
return (await self.session.scalars(stmt)).all()
def get_table_metadata_by_name(
async def get_table_metadata_by_name(
self,
user_id: int,
database_config_id: int,
@ -360,9 +360,9 @@ class TableMetadataService:
TableMetadata.database_config_id == database_config_id,
TableMetadata.table_name == table_name
)
return self.session.scalar_one_or_none(stmt)
return (await self.session.execute(stmt)).scalar_one_or_none()
def update_table_qa_settings(
async def update_table_qa_settings(
self,
user_id: int,
metadata_id: int,
@ -375,7 +375,7 @@ class TableMetadataService:
TableMetadata.id == metadata_id,
TableMetadata.created_by == user_id
)
metadata = self.session.scalar_one_or_none(stmt)
metadata = (await self.session.execute(stmt)).scalar_one_or_none()
if not metadata:
self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在"
@ -388,15 +388,15 @@ class TableMetadataService:
if 'business_context' in settings:
metadata.business_context = settings['business_context']
self.session.commit()
await self.session.commit()
return True
except Exception as e:
self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}"
self.session.rollback()
await self.session.rollback()
return False
def save_table_metadata(
async def save_table_metadata(
self,
user_id: int,
database_config_id: int,
@ -414,7 +414,7 @@ class TableMetadataService:
TableMetadata.database_config_id == database_config_id,
TableMetadata.table_name == table_name
)
existing = self.session.scalar_one_or_none(stmt)
existing = (await self.session.execute(stmt)).scalar_one_or_none()
if existing:
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 已存在,更新其元数据"
@ -424,7 +424,7 @@ class TableMetadataService:
existing.row_count = row_count
existing.table_comment = table_comment
existing.last_synced_at = datetime.utcnow()
self.session.commit()
await self.session.commit()
return existing
else:
self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id}{table_name} 不存在,创建新记录"
@ -444,8 +444,8 @@ class TableMetadataService:
)
self.session.add(metadata)
self.session.commit()
self.session.refresh(metadata)
await self.session.commit()
await self.session.refresh(metadata)
return metadata
def _decrypt_password(self, encrypted_password: str) -> str:

View File

@ -7,6 +7,7 @@ from langchain.tools import BaseTool
from langchain_community.tools.tavily_search import TavilySearchResults
from pydantic import BaseModel, Field, PrivateAttr
from typing import Optional, Type, ClassVar
from langchain_tavily import TavilySearch
# 定义输入参数模型替代原get_parameters()
class SearchInput(BaseModel):
@ -35,7 +36,7 @@ class TavilySearchTool(BaseTool):
raise ValueError("Tavily API key not found in settings")
# 初始化Tavily客户端
self._search_client = TavilySearchResults(
self._search_client = TavilySearch(
tavily_api_key=self._tavily_api_key
)

View File

@ -126,10 +126,10 @@ class UserService:
self.session.desc = f"用户ID为{user_id}的密码已成功更改"
return True
def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK
async def reset_password(self, user_id: int, new_password: str) -> bool: ### DrGraph: OK
"""Reset user password (admin only, no current password required)."""
self.session.desc = f"重置用户ID为{user_id}的密码"
user = self.get_user_by_id(user_id)
user = await self.get_user_by_id(user_id)
if not user:
raise ValidationError("User not found")
@ -141,7 +141,7 @@ class UserService:
hashed_password = self.get_password_hash(new_password)
# Update password
user.hashed_password = hashed_password
self.session.commit()
await self.session.commit()
self.session.desc = f"用户ID为{user_id}的密码已成功重置"
return True

View File

@ -42,8 +42,8 @@ class WorkflowEngine:
execution.set_audit_fields(user_id)
self.session.add(execution)
self.session.commit()
self.session.refresh(execution)
await self.session.commit()
await self.session.refresh(execution)
@ -75,8 +75,8 @@ class WorkflowEngine:
execution.set_audit_fields(user_id, is_update=True)
self.session.commit()
self.session.refresh(execution)
await self.session.commit()
await self.session.refresh(execution)
return WorkflowExecutionResponse.from_orm(execution)
@ -100,8 +100,8 @@ class WorkflowEngine:
execution.set_audit_fields(user_id)
self.session.add(execution)
self.session.commit()
self.session.refresh(execution)
await self.session.commit()
await self.session.refresh(execution)
# 发送工作流开始执行的消息
yield {
@ -170,11 +170,8 @@ class WorkflowEngine:
}
execution.set_audit_fields(user_id, is_update=True)
self.session.commit()
self.session.refresh(execution)
self.session.commit()
self.session.refresh(execution)
await self.session.commit()
await self.session.refresh(execution)
def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""构建节点依赖图"""
@ -385,8 +382,8 @@ class WorkflowEngine:
started_at=datetime.now().isoformat()
)
self.session.add(node_execution)
self.session.commit()
self.session.refresh(node_execution)
await self.session.commit()
await self.session.refresh(node_execution)
start_time = time.time()
@ -419,7 +416,7 @@ class WorkflowEngine:
}
node_execution.input_data = display_input_data
self.session.commit()
await self.session.commit()
@ -446,7 +443,7 @@ class WorkflowEngine:
node_execution.completed_at = datetime.now().isoformat()
node_execution.duration_ms = int((end_time - start_time) * 1000)
self.session.commit()
await self.session.commit()
@ -459,7 +456,7 @@ class WorkflowEngine:
node_execution.error_message = str(e)
node_execution.completed_at = datetime.now().isoformat()
node_execution.duration_ms = int((end_time - start_time) * 1000)
self.session.commit()
await self.session.commit()
@ -918,8 +915,8 @@ class WorkflowEngine:
# 工作流引擎实例
def get_workflow_engine(session: Session = None) -> WorkflowEngine:
async def get_workflow_engine(session: Session = None) -> WorkflowEngine:
"""获取工作流引擎实例"""
if session is None:
session = next(get_session())
session = await anext(get_session())
return WorkflowEngine(session)

490
backend/utils/Constant.py Normal file
View File

@ -0,0 +1,490 @@
from PySide6.QtCore import QPointF, QRectF, QSizeF
# -*- coding: utf-8 -*-
class Constant:
    """Application-wide constants: UI sentinels, config-path templates, mode
    and status codes, a drawing palette, and Cerberus-style request schemas.

    The schema fragments prefixed with ``_`` are shared *by reference* between
    the public ``*_SCHEMA`` dicts (the original hand-copied them several times
    and the copies had already started to drift); treat the fragments and the
    schemas as read-only.
    """

    # Sentinels representing "no point" / "no rect".
    invalid_point = QPointF(-123.4, -234.71)
    invalid_rect = QRectF(invalid_point, QSizeF(0, 0))

    # Config-file path templates; '%s' is the environment name (e.g. 'dev').
    service_yml_path = 'config/service/dsp_%s_service.yml'  # service config
    kafka_yml_path = 'config/kafka/dsp_%s_kafka.yml'        # kafka config
    aliyun_yml_path = "config/aliyun/dsp_%s_aliyun.yml"     # aliyun config
    baidu_yml_path = 'config/baidu/dsp_%s_baidu.yml'        # baidu config

    pull_frame_width = 1400  # width of frames pulled from a stream

    # LLM invocation modes.
    LLM_MODE_NONE = 0
    LLM_MODE_NONCHAT = 1
    LLM_MODE_CHAT = 2
    LLM_MODE_EMBEDDING = 3
    LLM_MODE_LOCAL_OLLAMA = 4

    # Prompt-template invocation methods and prompt value kinds.
    LLM_PROMPT_TEMPLATE_METHOD_FORMAT = "format"
    LLM_PROMPT_TEMPLATE_METHOD_INVOKE = "invoke"
    LLM_PROMPT_VALUE_STR = "str"
    LLM_PROMPT_VALUE_MESSAGES = "messages"
    LLM_PROMPT_VALUE_VALUE = "promptValue"

    # Input source kinds.
    INPUT_NONE = 0
    INPUT_PULL_STREAM = 1
    INPUT_NET_CAMERA = 2
    INPUT_USB_CAMERA = 3
    INPUT_LOCAL_VIDEO = 4
    INPUT_LOCAL_DIR = 5
    INPUT_LOCAL_FILE = 6
    INPUT_LOCAL_LABELLED_DIR = 7
    INPUT_LOCAL_LABELLED_ZIP = 8

    # System buttons.
    SYSTEM_BUTTON_SNAP = 1
    SYSTEM_BUTTON_CLEAR_LOG = 2

    MODE_OCR = 7

    # Owner of the embedded web view.
    WEB_OWNER_NONE = 0
    WEB_OWNER_ALG_TEST = 1
    WEB_OWNER_LLM = 2

    # Main-window page modes.
    PAGE_MODE_NONE = 0
    PAGE_MODE_VIDEO = 1
    PAGE_MODE_LABELLING = 2
    PAGE_MODE_ALG_TEST = 3
    PAGE_MODE_WEB = 4
    PAGE_MODE_LLM = 5
    PAGE_MODE_CONFIG = 6

    # Algorithm-test actions.
    ALG_TEST_LOAD_DATA = 1
    ALG_TEST_TRAIN = 2
    ALG_TEST_INFER = 3

    # Algorithm status bit flags (powers of two, OR-combinable).
    ALG_STATUS_NOTHING = 0      # nothing done yet
    ALG_STATUS_MODEL_READY = 1  # model is ready
    ALG_STATUS_DATA_LOADED = 2  # data loaded / loading
    ALG_STATUS_TRAINED = 4      # training finished / in progress
    ALG_STATUS_INFER = 8        # inference finished / in progress

    UTF_8 = "utf-8"  # default text encoding

    # 20-entry color palette used for drawing.
    # NOTE(review): channel order (RGB vs BGR) depends on the consumer — confirm.
    COLOR = (
        [0, 0, 255],
        [255, 0, 0],
        [211, 0, 148],
        [0, 127, 0],
        [0, 69, 255],
        [0, 255, 0],
        [255, 0, 255],
        [0, 0, 127],
        [127, 0, 255],
        [255, 129, 0],
        [139, 139, 0],
        [255, 255, 0],
        [127, 255, 0],
        [0, 127, 255],
        [0, 255, 127],
        [255, 127, 255],
        [8, 101, 139],
        [171, 130, 255],
        [139, 112, 74],
        [205, 205, 180])

    # Task kind labels.
    ONLINE = "online"
    OFFLINE = "offline"
    PHOTO = "photo"
    RECORDING = "recording"

    # ---- shared schema fragments (read-only; shared by reference below) ----
    _REQUEST_ID = {
        'type': 'string',
        'required': True,
        'empty': False,
        'regex': r'^[a-zA-Z0-9]{1,36}$'
    }
    _CMD_START = {'type': 'string', 'required': True, 'allowed': ["start"]}
    _CMD_STOP = {'type': 'string', 'required': True, 'allowed': ["stop"]}
    _URL_REQUIRED = {
        'type': 'string',
        'required': True,
        'empty': False,
        'maxlength': 255
    }
    _LOGO_URL = {
        'type': 'string',
        'required': False,
        'nullable': True,
        'maxlength': 255
    }
    # One entry of a request's "models" list; shared verbatim by the online,
    # offline and image schemas (they were identical in all three copies).
    _MODEL_ITEM = {
        'type': 'dict',
        'required': True,
        'schema': {
            "code": {
                'type': 'string',
                'required': True,
                'empty': False,
                'dependencies': "categories",
                'regex': r'^[a-zA-Z0-9]{1,255}$'
            },
            "is_video": {
                'type': 'string',
                'required': True,
                'empty': False,
                'dependencies': "code",
                'allowed': ["0", "1"]
            },
            "is_image": {
                'type': 'string',
                'required': True,
                'empty': False,
                'dependencies': "code",
                'allowed': ["0", "1"]
            },
            "categories": {
                'type': 'list',
                'required': True,
                'dependencies': "code",
                'schema': {
                    'type': 'dict',
                    'required': True,
                    'schema': {
                        "id": {
                            'type': 'string',
                            'required': True,
                            'empty': False,
                            'regex': r'^[a-zA-Z0-9]{0,255}$'
                        },
                        "config": {
                            'type': 'dict',
                            'required': False,
                            'dependencies': "id",
                        }
                    }
                }
            }
        }
    }

    # ---- request validation schemas ----
    ONLINE_START_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_START,
        "pull_url": _URL_REQUIRED,
        "push_url": _URL_REQUIRED,
        "logo_url": _LOGO_URL,
        "models": {
            'type': 'list',
            'required': True,
            'nullable': False,
            'minlength': 1,
            'maxlength': 3,
            'schema': _MODEL_ITEM
        }
    }
    ONLINE_STOP_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_STOP
    }
    OFFLINE_START_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_START,
        "push_url": _URL_REQUIRED,
        "pull_url": _URL_REQUIRED,
        "logo_url": _LOGO_URL,
        "models": {
            'type': 'list',
            'required': True,
            'maxlength': 3,
            'minlength': 1,
            'schema': _MODEL_ITEM
        }
    }
    OFFLINE_STOP_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_STOP
    }
    IMAGE_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_START,
        "logo_url": _LOGO_URL,
        "image_urls": {
            'type': 'list',
            'required': True,
            'minlength': 1,
            'schema': {
                'type': 'string',
                'required': True,
                'empty': False,
                'maxlength': 5000
            }
        },
        "models": {
            'type': 'list',
            'required': True,
            'schema': _MODEL_ITEM
        }
    }
    RECORDING_START_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_START,
        "pull_url": _URL_REQUIRED,
        # push_url is optional (and may be empty) for recordings.
        "push_url": {
            'type': 'string',
            'required': False,
            'empty': True,
            'maxlength': 255
        },
        "logo_url": _LOGO_URL
    }
    RECORDING_STOP_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_STOP
    }
    PULL2PUSH_START_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": _CMD_START,
        "video_urls": {
            'type': 'list',
            'required': True,
            'nullable': False,
            'schema': {
                'type': 'dict',
                'required': True,
                'schema': {
                    "id": {
                        'type': 'string',
                        'required': True,
                        'empty': False,
                        'dependencies': "pull_url",
                        'regex': r'^[a-zA-Z0-9]{1,255}$'
                    },
                    "pull_url": {
                        'type': 'string',
                        'required': True,
                        'empty': False,
                        'dependencies': "push_url",
                        'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'
                    },
                    "push_url": {
                        'type': 'string',
                        'required': True,
                        'empty': False,
                        'dependencies': "id",
                        'regex': r'^(https|http|rtsp|rtmp|artc|webrtc|ws)://\w.+$'
                    }
                }
            }
        }
    }
    PULL2PUSH_STOP_SCHEMA = {
        "request_id": _REQUEST_ID,
        "command": {
            'type': 'string',
            'required': True,
            # NOTE(review): the original accepted both "start" and "stop" for
            # the stop request — preserved, but confirm this is intended.
            'allowed': ["start", "stop"]
        },
        "video_ids": {
            'type': 'list',
            'required': False,
            'nullable': True,
            'schema': {
                'type': 'string',
                'required': True,
                'empty': False,
                'regex': r'^[a-zA-Z0-9]{1,255}$'
            }
        }
    }

View File

@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
from loguru import logger
from enum import Enum, unique
class ServiceException(Exception):
    """Business exception carrying a machine-readable code and a message.

    Args:
        code: error code string, e.g. "SP001".
        msg:  message text; may be a %-style template.
        desc: optional value interpolated into *msg* via the % operator.
    """

    def __init__(self, code, msg, desc=None):
        self.code = code
        # Interpolate the detail into the template only when one was given.
        if desc is None:
            self.msg = msg
        else:
            self.msg = msg % desc

    def __str__(self):
        logger.error("异常编码:{}, 异常描述:{}", self.code, self.msg)
        # BUGFIX: __str__ must return a str; the original returned None,
        # making str(exc) / "%s" % exc raise TypeError.
        return "异常编码:%s, 异常描述:%s" % (self.code, self.msg)
# 异常枚举
@unique
class ExceptionType(Enum):
    """Service error catalog.

    Each member's value is a ``(code, message)`` tuple; ``code`` is the
    "SPxxx" identifier and ``message`` a user-facing Chinese string.  Messages
    containing ``%s`` are %-style templates filled in by ``ServiceException``.
    Member names contain historical typos (e.g. READ_IAMGE_URL_EXCEPTION,
    OR_VIDEO_DO_NOT_EXEIST_EXCEPTION, TASK_EXCUTE_TIMEOUT) that are kept for
    backward compatibility with existing callers.
    """
    # SP000–SP034: stream / model / task errors.
    OR_VIDEO_ADDRESS_EXCEPTION = ("SP000", "未拉取到视频流, 请检查拉流地址是否有视频流!")
    ANALYSE_TIMEOUT_EXCEPTION = ("SP001", "AI分析超时!")
    PULLSTREAM_TIMEOUT_EXCEPTION = ("SP002", "原视频拉流超时!")
    READSTREAM_TIMEOUT_EXCEPTION = ("SP003", "原视频读取视频流超时!")
    GET_VIDEO_URL_EXCEPTION = ("SP004", "获取视频播放地址失败!")
    GET_VIDEO_URL_TIMEOUT_EXCEPTION = ("SP005", "获取原视频播放地址超时!")
    PULL_STREAM_URL_EXCEPTION = ("SP006", "拉流地址不能为空!")
    PUSH_STREAM_URL_EXCEPTION = ("SP007", "推流地址不能为空!")
    PUSH_STREAM_TIME_EXCEPTION = ("SP008", "未生成本地视频地址!")
    AI_MODEL_MATCH_EXCEPTION = ("SP009", "未匹配到对应的AI模型!")
    ILLEGAL_PARAMETER_FORMAT = ("SP010", "非法参数格式!")
    PUSH_STREAMING_CHANNEL_IS_OCCUPIED = ("SP011", "推流通道可能被占用, 请稍后再试!")
    VIDEO_RESOLUTION_EXCEPTION = ("SP012", "不支持该分辨率类型的视频,请切换分辨率再试!")
    READ_IAMGE_URL_EXCEPTION = ("SP013", "未能解析图片地址!")
    DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED = ("SP014", "不支持该类型的检测目标!")
    WRITE_STREAM_EXCEPTION = ("SP015", "写流异常!")
    OR_VIDEO_DO_NOT_EXEIST_EXCEPTION = ("SP016", "原视频不存在!")
    MODEL_LOADING_EXCEPTION = ("SP017", "模型加载异常!")
    MODEL_ANALYSE_EXCEPTION = ("SP018", "算法模型分析异常!")
    AI_MODEL_CONFIG_EXCEPTION = ("SP019", "模型配置不能为空!")
    AI_MODEL_GET_CONFIG_EXCEPTION = ("SP020", "获取模型配置异常, 请检查模型配置是否正确!")
    MODEL_GROUP_LIMIT_EXCEPTION = ("SP021", "模型组合个数超过限制!")
    # %-templates: the model name is interpolated by ServiceException.
    MODEL_NOT_SUPPORT_VIDEO_EXCEPTION = ("SP022", "%s不支持视频识别!")
    MODEL_NOT_SUPPORT_IMAGE_EXCEPTION = ("SP023", "%s不支持图片识别!")
    THE_DETECTION_TARGET_CANNOT_BE_EMPTY = ("SP024", "检测目标不能为空!")
    URL_ADDRESS_ACCESS_FAILED = ("SP025", "URL地址访问失败, 请检测URL地址是否正确!")
    UNIVERSAL_TEXT_RECOGNITION_FAILED = ("SP026", "识别失败!")
    COORDINATE_ACQUISITION_FAILED = ("SP027", "飞行坐标识别异常!")
    PUSH_STREAM_EXCEPTION = ("SP028", "推流异常!")
    MODEL_DUPLICATE_EXCEPTION = ("SP029", "存在重复模型配置!")
    # NOTE(review): SP030 is unused — confirm whether the gap is intentional.
    DETECTION_TARGET_NOT_SUPPORT = ("SP031", "存在不支持的检测目标!")
    TASK_EXCUTE_TIMEOUT = ("SP032", "任务执行超时!")
    PUSH_STREAM_URL_IS_NULL = ("SP033", "拉流、推流地址不能为空!")
    PULL_STREAM_NUM_LIMIT_EXCEPTION = ("SP034", "转推流数量超过限制!")
    # SP993–SP999: infrastructure / resource errors.
    NOT_REQUESTID_TASK_EXCEPTION = ("SP993", "未查询到该任务,无法停止任务!")
    NO_RESOURCES = ("SP995", "服务器暂无资源可以使用请稍后30秒后再试!")
    NO_CPU_RESOURCES = ("SP996", "暂无CPU资源可以使用请稍后再试!")
    SERVICE_COMMON_EXCEPTION = ("SP997", "公共服务异常!")
    NO_GPU_RESOURCES = ("SP998", "暂无GPU资源可以使用请稍后再试!")
    SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常!")

12
backend/utils/Flag.py Normal file
View File

@ -0,0 +1,12 @@
class Flag:
    """Named boolean constants used as flag arguments across the app."""
    Unique = True   # NOTE(review): presumably "insert uniquely" — confirm at call sites
    Append = False  # NOTE(review): presumably "append rather than replace" — confirm
    Debug = True    # debug-mode flag
class Option:
    """Bit-flag options; values are powers of two so they can be OR-combined."""
    NoOption = 0x00
    AddObject_AutoName = 0x01  # name inferred from the constant: auto-name added object
    AddObject_Select = 0x02    # name inferred from the constant: select added object

668
backend/utils/Helper.py Normal file
View File

@ -0,0 +1,668 @@
# -*- coding: utf-8 -*-
import sys, os, cv2
from os import makedirs
from os.path import join, exists
from loguru import logger
from json import loads
from ruamel.yaml import safe_load, YAML
import random, sys, math, inspect, psutil
from pathlib import Path
from PySide6.QtGui import QIcon, QColor
from PySide6.QtCore import QObject, QRectF, QEventLoop, QIODevice, QTextStream, QFile
from PySide6.QtWidgets import QApplication
import DrGraph.utils.vclEnums as enums
#region Property
class Property:
    """Descriptor base implementing configurable property access.

    A Property can read/write a backing private attribute (``_<name>``, bound
    by ``__set_name__`` when ``hasMember`` is true) and/or route access
    through a custom reader/writer.  ``read_func`` / ``write_func`` may be:

    * a string naming a method on the owning instance;
    * a function object whose ``__name__`` is re-resolved on the instance
      (so plain functions declared in the class body call the bound method);
    * any other callable, invoked as ``func(instance)`` and, when that raises
      ``TypeError``/``AttributeError``, retried as ``func()``.

    Subclasses choose which combination of direct/custom get/set to expose.
    """

    def __init__(self, read_func=None, write_func=None, default=None, hasMember=True):
        self.read_func = read_func    # custom reader (name or callable), or None
        self.write_func = write_func  # custom writer (name or callable), or None
        self.default = default        # value returned when nothing is stored yet
        self.owner_class = None       # owning class, set by __set_name__
        self.private_name = None      # "_<attr>" backing slot; None if hasMember is False
        self.hasMember = hasMember

    def __set_name__(self, owner, name):
        # Invoked by Python when the descriptor is assigned in a class body.
        if self.hasMember:
            self.private_name = f"_{name}"
        self.owner_class = owner

    def callDirectGet(self, instance, owner):
        """Read the backing member, falling back to ``default``."""
        if instance is None or not self.hasMember:
            return self.default
        # BUGFIX: private_name can be None when the descriptor was never bound
        # via __set_name__; the original passed None to hasattr -> TypeError.
        if not self.private_name or not hasattr(instance, self.private_name):
            return self.default
        return getattr(instance, self.private_name)

    def callCustomGet(self, instance, owner):
        """Read through the custom reader; fall back to the member/default."""
        if instance is None:
            # Class-level access returns the descriptor itself.
            return self
        if self.read_func is None:
            if self.private_name and hasattr(instance, self.private_name):
                return getattr(instance, self.private_name)
            return self.default
        try:
            if isinstance(self.read_func, str):
                # Reader supplied by method name.
                if hasattr(instance, self.read_func):
                    method = getattr(instance, self.read_func)
                    return method()
            elif hasattr(self.read_func, '__name__'):
                # Function object: re-resolve by name so the *bound* method
                # (with self) is invoked on this instance.
                method_name = self.read_func.__name__
                if hasattr(instance, method_name):
                    method = getattr(instance, method_name)
                    return method()
            elif callable(self.read_func):
                # Plain callable without __name__ (e.g. functools.partial).
                try:
                    return self.read_func(instance)
                except (TypeError, AttributeError):
                    return self.read_func()
        except Exception as e:
            logger.error(e)
        # Reader missing or failed: fall back to the stored member/default.
        if self.private_name and hasattr(instance, self.private_name):
            return getattr(instance, self.private_name)
        return self.default

    def callDirectSet(self, instance, value):
        """Write the backing member directly (no-op when there is none)."""
        # BUGFIX: also guard private_name; setattr(instance, None, v) raises.
        if instance is None or not self.hasMember or not self.private_name:
            return
        setattr(instance, self.private_name, value)

    def callCustomSet(self, instance, value):
        """Write through the custom writer; mirrors callCustomGet's dispatch."""
        try:
            if isinstance(self.write_func, str):
                if hasattr(instance, self.write_func):
                    method = getattr(instance, self.write_func)
                    method(value)
            elif hasattr(self.write_func, '__name__'):
                method_name = self.write_func.__name__
                if hasattr(instance, method_name):
                    method = getattr(instance, method_name)
                    method(value)
            elif callable(self.write_func):
                try:
                    self.write_func(instance, value)
                except (TypeError, AttributeError):
                    self.write_func(value)
        except Exception as e:
            # BUGFIX: was a silent `pass`, hiding writer failures while the
            # getter logged them; log for symmetry with callCustomGet.
            logger.error(e)
class Property_rw(Property):
    """Plain read/write property backed by the private ``_<name>`` member."""

    def __init__(self, default=None, hasMember=True):
        super().__init__(read_func=None, write_func=None,
                         default=default, hasMember=hasMember)

    def __get__(self, instance, owner):
        # Direct read of the backing member (or the default).
        return self.callDirectGet(instance, owner)

    def __set__(self, instance, value):
        # Direct write to the backing member.
        self.callDirectSet(instance, value)
class Property_Rw(Property):
    """Custom-read / direct-write property.

    Reads dispatch through ``read_func``; writes store straight into the
    private member (assumes ``__set_name__`` populated ``private_name``).
    """

    def __init__(self, read_func=None, default=None, hasMember=True):
        super().__init__(read_func=read_func, write_func=None,
                         default=default, hasMember=hasMember)

    def __get__(self, instance, owner):
        return self.callCustomGet(instance, owner)

    def __set__(self, instance, value):
        setattr(instance, self.private_name, value)
class Property_rW(Property):
    """Direct-read / custom-write property.

    Reads come from the private member; writes dispatch through
    ``write_func``.
    """

    def __init__(self, write_func=None, default=None, hasMember=True):
        super().__init__(read_func=None, write_func=write_func,
                         default=default, hasMember=hasMember)

    def __get__(self, instance, owner):
        return self.callDirectGet(instance, owner)

    def __set__(self, instance, value):
        self.callCustomSet(instance, value)
class Property_RW(Property):
    """Fully custom property: both reads and writes dispatch through the
    configured ``read_func`` / ``write_func``."""

    def __init__(self, read_func=None, write_func=None, default=None, hasMember=True):
        super().__init__(read_func=read_func, write_func=write_func,
                         default=default, hasMember=hasMember)

    def __get__(self, instance, owner):
        return self.callCustomGet(instance, owner)

    def __set__(self, instance, value):
        self.callCustomSet(instance, value)
#endregion Property
class AppHelper(QObject):
    """UI-side helper exposing status-bar text and progress-bar state as
    descriptor properties that forward writes to the attached Qt widgets.

    ``briefStatusControl`` / ``progressBarControl`` start as None and are
    expected to be assigned by the owning window after construction.
    """

    app = Property_rw(None)  # the QApplication instance

    def setBriefStatusText(self, text):
        # Fall back to stdout while no status widget is attached yet.
        if self.briefStatusControl:
            self.briefStatusControl.setText(text)
        else:
            print(text)
    briefStatusText = Property_rW(setBriefStatusText, '')

    def setProgress(self, value):
        if self.progressBarControl:
            self.progressBarControl.setValue(value)
    progress = Property_rW(setProgress, 0)

    def setProgressMax(self, value):
        if self.progressBarControl:
            self.progressBarControl.setMaximum(value)
    progressMax = Property_rW(setProgressMax, 100)

    def setProgressMin(self, value):
        if self.progressBarControl:
            self.progressBarControl.setMinimum(value)
    progressMin = Property_rW(setProgressMin, 0)

    def __init__(self):
        # BUGFIX: QObject subclasses must initialise the Qt base class; the
        # original never called super().__init__(), which breaks PySide6
        # object setup (signals/slots, parenting).
        super().__init__()
        self.briefStatusControl = None
        self.progressBarControl = None
        self._briefStatusText = ''
class Helper:
OnLogMsg = None
AppFlag_SaveAnalysisResult = True
AppFlag_SaveLog = False
App = None
@staticmethod
def castRange(value, minValue, maxValue):
return max(minValue, min(maxValue, value))
# 取得程序目录
@staticmethod
def getPath_App():
if getattr(sys, 'frozen', False):
# 如果程序是打包的exe文件
return os.path.dirname(sys.executable)
else:
# 如果是Python脚本 - 获取上两级目录
current_file = os.path.abspath(__file__) # f:\PySide6\AiBase\DrGraph\utils\Helper.py
current_dir = os.path.dirname(current_file) # f:\PySide6\AiBase\DrGraph\utils
parent_dir = os.path.dirname(current_dir) # f:\PySide6\AiBase\DrGraph
root_dir = os.path.dirname(parent_dir) # f:\PySide6\AiBase
return root_dir
@staticmethod
def fitOS(file_name):
if sys.platform.startswith('win'):
file_name = file_name.replace('/','\\')
else:
file_name = file_name.replace('\\', '/')
return file_name
def generateDistinctColors(n, s=0.8, v=0.7):
import colorsys
colors = []
for i in range(n):
hue = i * 1.0 / n # 均匀分布在 [0, 1)
r, g, b = colorsys.hsv_to_rgb(hue, s, v)
colors.append(QColor(r * 255, g * 255, b * 255))
return colors
def setBriefStatusText(self, text):
Helper.App.setBriefStatusText(text)
briefStatusText = Property_rW(setBriefStatusText, '')
@staticmethod
def Sleep(msec):
QApplication.processEvents(QEventLoop.AllEvents, msec)
@staticmethod
def getAbsoluteFileName(file_name):
if os.path.isabs(file_name):
return Helper.fitOS(file_name)
else:
return Helper.fitOS(os.path.join(Helper.getPath_App(), file_name))
@staticmethod
def getConfigs(path, read_type='yml'):
"""
读取配置文件并返回解析后的配置信息
:param path: 配置文件路径
:param read_type: 配置文件类型默认为'yml'可选'json''yml'
:return: 解析后的配置信息JSON格式返回字典YML格式返回对应的数据结构
:raises Exception: 当无法获取配置信息时抛出异常
"""
yaml = YAML(typ='safe', pure=True)
with open(path, 'r', encoding='utf-8') as f:
return yaml.load(f)
# with open(path, 'r', encoding="utf-8") as f:
# # 根据文件类型选择相应的解析方式
# if read_type == 'json':
# return loads(f.read())
# if read_type == 'yml':
# return safe_load(f)
# 如果未成功读取配置信息,则抛出异常
raise Exception('路径: %s未获取配置信息' % path)
@staticmethod
def getTooltipText(content):
# 增加一个小喇叭图标
# content = f'<img src="appIOs/res/images/icons/info.png" width="16" height="16"> {content}'
return f"""
<html>
<head/><body>
<p><span style=" font-weight:600; color:#ffffff;">DrGraph <img src="appIOs/res/images/icons/Notice.png" width="16" height="16"></span></p>
<p>{content}</p>
</body>
</html>
"""
@staticmethod
def log_init(app, base_dir, env):
"""
初始化日志配置
:param base_dir: 基础目录路径用于定位配置文件和日志文件存储位置
:param env: 环境标识用于加载对应环境的日志配置文件
:return: 无返回值
"""
Helper.App = AppHelper()
Helper.App.app = app
# QToolTip样式 - 自定义样式 - 增加Header
app.setStyleSheet("""
QToolTip {
background-color: #dd2222;
color: #f0f0f0;
border: 1px solid #555;
border-radius: 4px;
padding: 6px;
font: 10pt "Segoe UI";
opacity: 220;
}
""")
log_config = Helper.getConfigs(join(base_dir, 'appIOs/configs/logger/drgraph_%s_logger.yml' % env))
# 判断日志文件是否存在,不存在创建
base_path = join(base_dir, log_config.get("base_path"))
if not exists(base_path):
makedirs(base_path)
# 移除日志设置
logger.remove(handler_id=None)
# 打印日志到文件
if bool(log_config.get("enable_file_log")):
logger.add(join(base_path, log_config.get("log_name")),
rotation=log_config.get("rotation"),
retention=log_config.get("retention"),
format=log_config.get("log_fmt"),
level=log_config.get("level"),
enqueue=True,
encoding=log_config.get("encoding"))
# 控制台输出
if bool(log_config.get("enable_stderr")):
logger.add(sys.stderr,
format=log_config.get("log_fmt"),
level=log_config.get("level"),
enqueue=True)
logger.info("\n\n\n----=========== 日志配置初始化完成, 开始新的日志记录 ==========----")
@staticmethod
def log_info(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'INFO: {msg}', 'black')
caller = inspect.stack()[1]
logger.info(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "INFO", "msg" : msg} )
@staticmethod
def log_error(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'ERROR: {msg}', 'red')
caller = inspect.stack()[1]
logger.error(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "ERROR", "msg" : msg} )
@staticmethod
def log_warning(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'WARNING: {msg}', (255, 128, 0))
caller = inspect.stack()[1]
logger.warning(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "WARNING", "msg" : msg} )
@staticmethod
def log_debug(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'DEBUG: {msg}', (0, 128, 128))
caller = inspect.stack()[1]
logger.debug(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "DEBUG", "msg" : msg} )
@staticmethod
def log_critical(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'CRITICAL: {msg}', (128, 0, 128))
caller = inspect.stack()[1]
logger.critical(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "CRITICAL", "msg" : msg} )
@staticmethod
def log_exception(msg, toWss = False):
if Helper.OnLogMsg:
Helper.OnLogMsg(f'EXCEPTION: {msg}', (255, 140, 0))
caller = inspect.stack()[1]
logger.exception(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "EXCEPTION", "msg" : msg} )
@staticmethod
def log(msg, toWss = False):
caller = inspect.stack()[1]
logger.log(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
if toWss:
Helper.log_wss({"type": "log", "kind": "LOG", "msg" : msg} )
@staticmethod
def log_wss(msg):
if Helper.wss:
Helper.wss.send(msg)
@staticmethod
def getTextSize(font, text):
import pygame as pg
surface = font.render(text, True, (0, 0, 0))
return (surface.get_width(), surface.get_height(), surface)
@staticmethod
def buildSurfaces(font, text, width, color, wordWrap):
text = text.strip()
w = Helper.getTextSize(font, text)[0]
result = []
if w > width and wordWrap:
segLen = math.floor(width / w * len(text))
while len(text):
if len(text) < segLen:
t = text
text = ''
else:
t = text[:segLen]
text = text[segLen:]
result.append(font.render(t, True, color))
else:
result.append(font.render(text, True, color))
return result
@staticmethod
def randomColor():
'''随机颜色'''
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
@staticmethod
def reverseColor(color: QColor):
'''反转颜色'''
return (255 - color.red(), 255 - color.green(), 255 - color.blue())
@staticmethod
def getRGB(color_value):
# 如果是元组或列表形式的RGB值
if isinstance(color_value, (tuple, list)):
if len(color_value) >= 3:
# 取前三个值作为RGB
r, g, b = color_value[0], color_value[1], color_value[2]
# 确保值在0-255范围内
return (max(0, min(255, int(r))),
max(0, min(255, int(g))),
max(0, min(255, int(b))))
# 如果是整数形式的颜色值
elif isinstance(color_value, int):
# 将整数转换为RGB分量
# 假设格式为0xRRGGBB
r = (color_value >> 16) & 0xFF
g = (color_value >> 8) & 0xFF
b = color_value & 0xFF
return (r, g, b)
# 如果是字符串形式
elif isinstance(color_value, str):
# 处理十六进制颜色值
if color_value.startswith('#'):
hex_value = color_value[1:]
if len(hex_value) == 3: # 简写形式 #RGB
hex_value = ''.join([c*2 for c in hex_value])
if len(hex_value) in (6, 8): # #RRGGBB 或 #RRGGBBAA
r = int(hex_value[0:2], 16)
g = int(hex_value[2:4], 16)
b = int(hex_value[4:6], 16)
return (r, g, b)
# 处理颜色名称(需要额外的颜色名称映射表)
# 这里只列举几种常见颜色
color_names = {
'black': (0, 0, 0),
'white': (255, 255, 255),
'red': (255, 0, 0),
'green': (0, 255, 0),
'blue': (0, 0, 255),
'yellow': (255, 255, 0),
'magenta': (255, 0, 255),
'cyan': (0, 255, 255),
'orange': (255, 128, 0), # 根据项目规范
'teal': (0, 128, 128) # 根据项目规范
}
if color_value.lower() in color_names:
return color_names[color_value.lower()]
# 如果是Color对象如pygame.Color
elif hasattr(color_value, 'r') and hasattr(color_value, 'g') and hasattr(color_value, 'b'):
return (color_value.r, color_value.g, color_value.b)
# 默认返回黑色
return (0, 0, 0)
@staticmethod
def check_system_resources():
"""检查系统资源使用情况"""
logger.info("检查系统资源使用情况...")
cpu_percent = psutil.cpu_percent(interval=1)
memory = psutil.virtual_memory()
network = psutil.net_io_counters()
logger.info("检查系统资源使用情况完毕")
return {
'cpu_percent': cpu_percent,
'memory_percent': memory.percent,
'memory_available': memory.available / (1024**3), # GB
'network_bytes_sent': int(network.bytes_sent / 1024),
'network_bytes_recv': int(network.bytes_recv / 1024)
}
@staticmethod
def build_response(type, status : enums.Response, msg):
status_code, status_msg = status.value
result = {
"type": "response",
"request_type": type,
"status_code": status_code,
"status_msg": status_msg,
"detail_msg": msg
}
if status_code != 0:
Helper.error(result);
return result
@staticmethod
def get_surrounding_rect(points):
if len(points) == 0:
return Constant.invalid_rect
min_x = min(p.x() for p in points)
min_y = min(p.y() for p in points)
max_x = max(p.x() for p in points)
max_y = max(p.y() for p in points)
return QRectF(min_x, min_y, max_x - min_x, max_y - min_y)
@staticmethod
def getYoloLabellingInfo(dir_path, file_names, desc):
if len(dir_path) > 0:
imageNumber, labelNumber = 0, 0
imagePath = dir_path + 'images/'
labelPath = dir_path + 'labels/'
for file_name in file_names:
if file_name.startswith(imagePath):
imageNumber += 1
elif file_name.startswith(labelPath):
labelNumber += 1
return f'{desc} {imageNumber - 1} 张图片,{labelNumber - 1} 张标签;', imageNumber - 1
return f'{desc};', 0
@staticmethod
def getMarkdownRenderText(mdContent):
    """Convert markdown text into a self-contained, styled HTML document.

    Prefers the optional Python ``markdown`` package; when it is missing,
    falls back to embedding marked.min.js so the browser renders the
    markdown client-side.
    """
    # Convert Markdown to HTML with a Python library, avoiding a JavaScript dependency
    try:
        # Try to import the optional markdown package
        import markdown
        html_content = markdown.markdown(mdContent)
        # Add basic styling so the output looks presentable
        styled_html = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <style>
        body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
        max-width: 900px; margin: 20px auto; padding: 0 20px; line-height: 1.6; }}
        code {{ background: #f5f5f5; padding: 2px 4px; border-radius: 3px; }}
        pre {{ background: #f5f5f5; padding: 10px; border-radius: 5px; overflow: auto; }}
        pre code {{ background: none; padding: 0; }}
        h1, h2, h3 {{ color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px; }}
    </style>
</head>
<body>
{html_content}
</body>
</html>
"""
        return styled_html
    except ImportError:
        # Fall back to client-side rendering with marked.js
        logger.warning("未找到markdown库使用JavaScript渲染方式")
        # Load marked.js from appIOs/configs
        file_js = QFile('appIOs/configs/marked.min.js')
        markedJs = ''
        if file_js.open(QIODevice.ReadOnly | QIODevice.Text):
            markedJs = file_js.readAll().data().decode('utf-8')
            file_js.close()
        # HTML-escape the markdown content before embedding it in the template
        escapedMd = mdContent.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;').replace("'", '&#x27;')
        # HTML template; %1 / %2 are placeholders replaced below
        htmlTemplate = '''
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <style>
        body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
        max-width: 900px; margin: 20px auto; padding: 0 20px; line-height: 1.6; }
        code { background: #f5f5f5; padding: 2px 4px; border-radius: 3px; }
        pre { background: #f5f5f5; padding: 10px; border-radius: 5px; overflow: auto; }
        pre code { background: none; padding: 0; }
        h1, h2, h3 { color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px; }
    </style>
    <script>
        // 加载marked库
        %1
    </script>
</head>
<body>
    <div id="content"></div>
    <script>
        const md = `%2`;
        document.getElementById('content').innerHTML = marked.parse(md);
    </script>
</body>
</html>
'''
        # Substitute the script and the escaped markdown into the template
        htmlContent = htmlTemplate.replace('%1', markedJs).replace('%2', escapedMd)
        return htmlContent
@staticmethod
def getMarkdownRender(mdFileName):
    """Read a markdown file and return it rendered as styled HTML.

    On failure to open the file, an inline HTML error message is returned.
    """
    md_file = QFile(mdFileName)
    if not md_file.open(QIODevice.ReadOnly | QIODevice.Text):
        logger.error(f'打开文件 {mdFileName} 失败')
        return f"<p style='color: red;'>无法打开文件: {mdFileName}</p>"
    reader = QTextStream(md_file)
    reader.setAutoDetectUnicode(True)
    text = reader.readAll()
    md_file.close()
    return Helper.getMarkdownRenderText(text)
class RTTI:
    """Reflection helpers: get/set (possibly dotted) object properties by name."""

    @staticmethod
    def _do_set_attr(obj, property_name, property_value):
        """Set ``property_name`` on ``obj`` via its Qt-style ``setXxx`` setter.

        String values for ``*icon`` properties are resolved against the icon
        resource directory and converted to a QIcon first.
        """
        if obj is None:
            logger.error(f"RTTI.set: obj is None")
            return
        class_name = type(obj).__name__
        object_name = obj.objectName()
        if property_name not in dir(obj):
            logger.error(f"RTTI.set: {class_name} {object_name}.{property_name} not in dir(obj)")
            return
        if property_name.endswith('icon') and isinstance(property_value, str):
            original_property_value = property_value
            if not os.path.exists(property_value):
                # Fall back to the shared icon directory when the path is not found as-is
                property_value = os.path.join('appIOs/res/images/icons', property_value)
            if not os.path.exists(property_value):
                logger.error(f"{original_property_value}文件不存在 > RTTI.set: {class_name} {object_name}.{property_name} = '{original_property_value}'")
                return
            property_value = QIcon(property_value)
        # Property 'text' maps to setter 'setText' (Qt naming convention)
        setter_method = getattr(obj, f'set{property_name[0].upper() + property_name[1:]}')
        setter_method(property_value)

    @staticmethod
    def set(obj, property_name, property_value):
        """Set a property; ``property_name`` may be a dotted path (``child.text``)."""
        property_list = property_name.split('.')
        if len(property_list) == 1:
            RTTI._do_set_attr(obj, property_name, property_value)
        else:
            dest_obj = obj
            for i in range(len(property_list) - 1):
                if not dest_obj:
                    # BUG FIX: original called property_list.join('.') — lists have
                    # no join(), so the error path itself raised AttributeError.
                    logger.error(f"RTTI.set: {'.'.join(property_list)} not found")
                    return
                # Default to None so a missing link is reported, not raised
                dest_obj = getattr(dest_obj, property_list[i], None)
            RTTI._do_set_attr(dest_obj, property_list[-1], property_value)

    @staticmethod
    def _do_get_attr(obj, property_name):
        """Return ``(type_name, value)`` for the property, or ``(None, None)`` on error."""
        if obj is None:
            logger.error(f"RTTI.get: obj is None")
            return None, None
        if property_name not in dir(obj):
            logger.error(f"RTTI.get: {type(obj).__name__} {obj.objectName()}.{property_name} not in dir(obj)")
            return None, None
        # Report the runtime type name alongside the value
        value = getattr(obj, property_name)
        return type(value).__name__, value

    # Usage: type, value = RTTI.get(obj, property_name)
    @staticmethod
    def get(obj, property_name):
        """Get ``(type_name, value)`` of a possibly dotted property path."""
        property_list = property_name.split('.')
        if len(property_list) == 1:
            return RTTI._do_get_attr(obj, property_name)
        dest_obj = obj
        for i in range(len(property_list) - 1):
            if not dest_obj:
                # BUG FIX: list.join('.') crashed here too; also return the
                # (None, None) pair so callers can always unpack two values.
                logger.error(f"RTTI.get: {'.'.join(property_list)} not found")
                return None, None
            dest_obj = getattr(dest_obj, property_list[i], None)
        return RTTI._do_get_attr(dest_obj, property_list[-1])
class DrawHelper:
    """Dashed-line drawing helpers built on cv2 primitives."""

    @staticmethod
    def draw_dashed_line(mat, pt1, pt2, color, thickness=1, dash_length=10):
        """Draw a dashed segment from ``pt1`` to ``pt2`` on ``mat``.

        Each dash occupies the first half of a ``dash_length`` step.
        """
        dx = pt2[0] - pt1[0]
        dy = pt2[1] - pt1[1]
        dist = (dx ** 2 + dy ** 2) ** 0.5
        dashes = int(dist / dash_length)
        for i in range(dashes):
            seg_start = (int(pt1[0] + dx * i / dashes), int(pt1[1] + dy * i / dashes))
            seg_end = (int(pt1[0] + dx * (i + 0.5) / dashes), int(pt1[1] + dy * (i + 0.5) / dashes))
            cv2.line(mat, seg_start, seg_end, color, thickness)

    @staticmethod
    def draw_dashed_rect(painter, rect, color, thickness=1, dash_length=10):
        """Draw the four edges of ``rect`` as dashed lines."""
        left, top = rect.left(), rect.top()
        right, bottom = rect.right(), rect.bottom()
        edges = [
            ((left, top), (right, top)),        # top edge
            ((left, bottom), (right, bottom)),  # bottom edge
            ((left, top), (left, bottom)),      # left edge
            ((right, top), (right, bottom)),    # right edge
        ]
        for seg_start, seg_end in edges:
            DrawHelper.draw_dashed_line(painter, seg_start, seg_end, color, thickness, dash_length)

View File

@ -0,0 +1,466 @@
from loguru import logger
import subprocess as sp
from ultralytics import YOLO
import time, cv2, numpy as np, math
from traceback import format_exc
from DrGraph.utils.pull_push import NetStream
from DrGraph.utils.Helper import *
from DrGraph.utils.Constant import Constant
from zipfile import ZipFile
class YOLOTracker:
    """Wraps an ultralytics YOLO model for per-frame detection plus tracking."""

    def __init__(self, model_path):
        """
        Initialize the YOLOv11 tracker from a model weights path.
        """
        self.model = YOLO(model_path)
        # Arguments forwarded to model.track() on every frame
        self.tracking_config = {
            "tracker": "appIOs/configs/yolo11/bytetrack.yaml",  # ByteTrack config file
            "conf": 0.25,
            "iou": 0.45,
            "persist": True,    # keep track IDs across successive frames
            "verbose": False
        }
        self.frame_count = 0       # frames processed so far
        self.processing_time = 0   # last frame's processing time in ms

    def process_frame(self, frame):
        """
        Run detection and tracking on a single frame.

        Returns (annotated_frame, result); on failure returns (frame, None).
        """
        start_time = time.time()
        try:
            # Run YOLOv11 detection + tracking
            results = self.model.track(
                source=frame,
                **self.tracking_config
            )
            # Only one image was submitted, so take the first result
            result = results[0]
            # Render detections onto the frame
            processed_frame = result.plot()
            # Record processing time in milliseconds
            self.processing_time = (time.time() - start_time) * 1000
            self.frame_count += 1
            # Periodically log detection statistics
            if self.frame_count % 100 == 0:
                self._print_detection_info(result)
            return processed_frame, result
        except Exception as e:
            logger.error("YOLO处理异常: {}", format_exc())
            return frame, None

    def _print_detection_info(self, result):
        """
        Log detection count, number of distinct track IDs, and timing.
        """
        boxes = result.boxes
        if boxes is not None and len(boxes) > 0:
            detection_count = len(boxes)
            unique_ids = set()
            for box in boxes:
                if box.id is not None:
                    unique_ids.add(int(box.id[0]))
            logger.info(f"{self.frame_count}: 检测到 {detection_count} 个目标, 追踪ID数: {len(unique_ids)}, 处理时间: {self.processing_time:.2f}ms")
        else:
            logger.info(f"{self.frame_count}: 未检测到目标, 处理时间: {self.processing_time:.2f}ms")
class YOLOTrackerManager:
    """Owns a YOLOTracker plus exactly one active video source.

    Sources: network pull stream, USB camera, local video file, single image,
    image directory, or a YOLO-labelled zip archive.  Every ``startXxx``
    method first tears the previous source down via ``_stop``.
    """

    def __init__(self, model_path, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.tracker = YOLOTracker(model_path)
        self.stream = None               # NetStream when pulling from a URL
        self.videoStream = None          # cv2.VideoCapture for camera / local video
        self.videoType = Constant.INPUT_NONE
        self.localFile = ''              # single image file (or current label file in zip mode)
        self.localPath = ''              # image directory or labelled-zip path
        self.localFiles = []             # image file list for dir / zip sources
        self._currentFrame = None
        self.totalFrames = 0
        self.frameChanged = False

    def _stop(self):
        """Release whichever source is open and reset all per-source state."""
        if self.videoStream is not None:
            self.videoStream.release()
            self.videoStream = None
        if self.stream is not None:
            self.stream.clear_pull_p(self.stream.pull_p, self.request_id)
            self.stream = None
        self.localFile = ''
        self.localPath = ''
        self.localFiles = []
        self._currentFrame = None
        self.totalFrames = 0
        self._frameIndex = -1
        self.videoType = Constant.INPUT_NONE
        self.frameChanged = True

    def startLocalFile(self, fileName):
        """Use a single local image file as the source."""
        self._stop()
        self.localFile = fileName
        self._frameIndex = -1

    def startLocalDir(self, dirName):
        """Use a directory of .jpg/.jpeg/.png images as the source."""
        self._stop()
        self.localPath = dirName
        self.localFiles = [os.path.join(dirName, f) for f in os.listdir(dirName) if f.endswith(('.jpg', '.jpeg', '.png'))]
        self.totalFrames = len(self.localFiles)
        Helper.App.progressMax = self.totalFrames
        self.localFiles.sort()
        logger.info("本地目录打开: {}, 总帧数: {}", dirName, self.totalFrames)
        self._frameIndex = 0

    def startLabelledZip(self, labelledPath, categoryPath):
        """Use one category subset of a YOLO-labelled zip archive as the source."""
        self._stop()
        self.localPath = labelledPath
        localFiles = ZipFile(labelledPath).namelist()
        _, self.totalFrames = Helper.getYoloLabellingInfo(categoryPath, localFiles, '')
        imagePath = categoryPath + 'images/'
        self.localFiles = [file for file in localFiles if imagePath in file]
        logger.info(f"标注压缩文件{labelledPath}{categoryPath}集共有{self.totalFrames}帧, 有效帧数: {len(self.localFiles)}")
        self._frameIndex = 0
        Helper.App.progressMax = self.totalFrames

    def startUsbCamera(self, index = 0):
        """Open a USB camera as the source."""
        self._stop()
        self.videoStream = cv2.VideoCapture(index)
        self.videoType = Constant.INPUT_USB_CAMERA
        Helper.Sleep(200)
        if not self.videoStream.isOpened():
            logger.error("无法打开USB摄像头: {}", index)
            self.videoType = Constant.INPUT_NONE
            return
        self.totalFrames = 0x7FFFFFFF  # effectively unbounded for a live camera

    def startLocalVideo(self, fileName):
        """Open a local video file as the source."""
        self._stop()
        self.videoStream = cv2.VideoCapture(fileName)
        self.videoType = Constant.INPUT_LOCAL_VIDEO
        Helper.Sleep(200)
        if not self.videoStream.isOpened():
            logger.error("无法打开本地视频流: {}", fileName)
            self.videoType = Constant.INPUT_NONE
            return
        try:
            total = int(self.videoStream.get(cv2.CAP_PROP_FRAME_COUNT))
        except Exception:
            total = 0
        self.totalFrames = total if total is not None else 0
        Helper.App.progressMax = self.totalFrames
        logger.info("本地视频打开: {}, 总帧数: {}", fileName, self.totalFrames)

    def startPull(self, url = ''):
        """Start pulling a network stream, optionally overriding the pull URL."""
        self._stop()
        if len(url) > 0:
            self.pull_url = url
        logger.info("拉流地址: {}", self.pull_url)
        self.stream = NetStream(self.pull_url, self.push_url, self.request_id)
        self.stream.prepare_pull()

    def getCurrentFrame(self):
        """Return a copy of the current frame, lazily fetching the first one."""
        if self._currentFrame is None:
            self._currentFrame = self.nextFrame()
        if self._currentFrame is not None:
            return self._currentFrame.copy()
        return None
    currentFrame = Property_Rw(getCurrentFrame, None)

    def setFrameIndex(self, index):
        """Seek to ``index`` (clamped); only meaningful for seekable sources."""
        if self.videoStream is None and len(self.localFiles) == 0:
            return
        if self.videoStream is not None and self.videoType != Constant.INPUT_LOCAL_VIDEO:
            return  # live cameras cannot seek
        if index < 0:
            index = 0
        if index >= self.totalFrames:
            index = self.totalFrames - 1
        if self.videoStream:
            self.videoStream.set(cv2.CAP_PROP_POS_FRAMES, index)
        self._frameIndex = index - 1  # nextFrame() advances to ``index``
        self._currentFrame = self.nextFrame()
        self.frameChanged = True
    frameIndex = Property_rW(setFrameIndex, 0)

    def getLabels(self):
        """Read the current frame's YOLO label text from the labelled zip."""
        with ZipFile(self.localPath, 'r') as zip_ref:
            content = zip_ref.read(self.localFile)
            content = content.decode('utf-8')
            return content

    # Fetch the frame to analyse; ``nextFlag`` advances streaming sources first.
    def getAnalysisFrame(self, nextFlag):
        frameChanged = self.frameChanged
        self.frameChanged = False
        if nextFlag:  # streaming media: always advance to the next frame
            self._currentFrame = self.nextFrame()
            self.frameChanged = True
        frame = self.currentFrame
        return frame.copy() if frame is not None else None, frameChanged

    def nextFrame(self):
        """Fetch the next frame (RGB) from the active source, or None."""
        frame = None
        if self.stream:
            frame = self.stream.next_pull_frame()
        elif self.videoStream:
            ret, frame = self.videoStream.read()
            self._frameIndex += 1
            if not ret:
                self._frameIndex -= 1
                frame = None
        elif len(self.localFiles) > 0:
            if self.localPath.endswith('.zip'):
                # Walk the zip's image entries; the first matching entry is
                # presumably the directory itself, hence index starts at -1.
                index = -1
                for img_file in self.localFiles:
                    if '/images/' in img_file:
                        if index == self._frameIndex:
                            try:
                                with ZipFile(self.localPath, 'r') as zip_ref:
                                    image_data = zip_ref.read(img_file)
                                    nparr = np.frombuffer(image_data, np.uint8)
                                    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                                    self._frameIndex += 1
                                    # Remember the matching label file for getLabels()
                                    label_file = img_file.replace('/images/', '/labels/').replace('.jpg', '.txt').replace('.png', '.txt')
                                    self.localFile = label_file
                            except Exception:
                                frame = None
                            break
                    index += 1
            else:
                # Plain image directory: wrap around at the end of the list
                if self._frameIndex < 0:
                    self._frameIndex = 0
                if self._frameIndex >= len(self.localFiles):
                    self._frameIndex = 0
                if self._frameIndex < len(self.localFiles):
                    frame = cv2.imread(self.localFiles[self._frameIndex])
                    if frame is None:
                        logger.error(f"无法读取目标目录 {self.localPath}中下标为 {self._frameIndex} 的视频文件 {self.localFiles[self._frameIndex]}")
                        self._frameIndex = -1
                        return
                    self._frameIndex += 1
        elif self.localFile is not None and self.localFile != '':
            frame = cv2.imread(self.localFile)
            if frame is None:
                logger.error("无法读取本地视频文件: {}", self.localFile)
                return
        if frame is not None:
            # Hand frames to the rest of the pipeline in RGB order
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.totalFrames > 0:
            Helper.App.progress = self._frameIndex
        return frame

    def test_yolo11_recognize(self, frame):
        """Convenience wrapper: run YOLO processing on ``frame``."""
        processed_frame = self.process_frame_with_yolo(frame, self.request_id)
        return processed_frame

    def process_frame_with_yolo(self, frame, requestId):
        """
        Annotate ``frame`` with YOLOv11 detections, FPS and object count.

        Returns the original frame unchanged if processing fails.
        """
        try:
            processed_frame, detection_result = self.tracker.process_frame(frame)
            # Overlay processing-rate information
            fps_info = f"FPS: {1000/max(self.tracker.processing_time, 1):.1f}"
            cv2.putText(processed_frame, fps_info, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            # Overlay the number of detected objects
            if detection_result and detection_result.boxes is not None:
                obj_count = len(detection_result.boxes)
                count_info = f"Objects: {obj_count}"
                cv2.putText(processed_frame, count_info, (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            return processed_frame
        except Exception as e:
            logger.error("YOLO处理异常{}, requestId:{}", format_exc(), requestId)
            return frame

    def get_gray_mask(self, frame):
        """
        Build a mask of "gray" pixels (all pairwise channel differences < 20).

        Bluish shadow pixels (b > r and b - r < 40) are exempted from the
        r/b difference test.
        NOTE(review): frames from nextFrame() are RGB, so the b/g/r naming
        from cv2.split may be swapped — the test is symmetric except for the
        shadow term; confirm channel order.
        """
        maskMat = np.zeros(frame.shape[:2], dtype=np.uint8)
        b, g, r = cv2.split(frame)
        # Widen to int16 so channel subtraction cannot wrap around
        r = r.astype(np.int16)
        g = g.astype(np.int16)
        b = b.astype(np.int16)
        diff_rg = np.abs(r - g)
        is_shadow = (b > r) & (b - r < 40)
        diff_rb = np.abs(r - b)
        diff_gb = np.abs(g - b)
        # BUG FIX: original wrote ``diff_rb < 20| is_shadow`` which, due to
        # operator precedence, computed ``diff_rb < (20 | is_shadow)``; the
        # shadow exemption must be OR-ed with the comparison result instead.
        gray_pixels = (diff_rg < 20) & ((diff_rb < 20) | is_shadow) & (diff_gb < 20)
        # White (255) marks qualifying pixels
        maskMat[gray_pixels] = 255
        return maskMat

    def debugLine(self, line, y_intersect):
        """Return (angle in degrees within [0, 180), length, x at y = y_intersect).

        NOTE(review): a perfectly horizontal line (y1 == y2) divides by zero
        in the intersection computation.
        """
        x1, y1, x2, y2 = line
        length = np.linalg.norm([x2 - x1, y2 - y1])
        # Angle versus the horizontal via atan2, normalised to 0-180 degrees
        angle_rad = math.atan2(y2 - y1, x2 - x1)
        angle_deg = math.degrees(angle_rad)
        if angle_deg < 0:
            angle_deg += 180
        # Intersection of the (extended) line with the horizontal y = y_intersect
        x_intersect = (x2 - x1) * (y_intersect - y1) / (y2 - y1) + x1
        return angle_deg, length, x_intersect

    def test_highway_recognize(self, frame, debugFlag = False):
        """
        Experimental road-surface recognition.

        Returns (overlay, road color mask, lane-line image) on success; on
        failure returns only the copied input frame.
        """
        processed_frame = frame.copy()
        try:
            IGNORE_HEIGHT = 100
            y_intersect = frame.shape[0] / 2
            # Blank out the top band (sky / irrelevant area)
            frame[:IGNORE_HEIGHT, :] = (255, 0, 0)
            gray_mask = self.get_gray_mask(frame)
            # Erode twice with a 5x5 kernel to drop small noise specks
            kernel = np.ones((5, 5), np.uint8)
            gray_mask = cv2.erode(gray_mask, kernel)
            gray_mask = cv2.erode(gray_mask, kernel)
            # Keep only connected regions with area >= 10000
            contours, _ = cv2.findContours(gray_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            filtered_mask = np.zeros_like(gray_mask)
            for contour in contours:
                area = cv2.contourArea(contour)
                if area >= 10000:
                    cv2.fillPoly(filtered_mask, [contour], 255)
            gray_mask = filtered_mask
            # Edge detection restricted to the filtered road area
            edges = cv2.Canny(frame, 100, 200)
            road_edges = cv2.bitwise_and(edges, edges, mask=filtered_mask)
            # Restrict the original image to the road mask and look for lane lines
            whiteLineMat = cv2.bitwise_and(processed_frame, processed_frame, mask=filtered_mask)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_RGB2GRAY)
            whiteLineMat = cv2.Sobel(whiteLineMat, cv2.CV_8U, 1, 0, ksize=3)
            tempMat = whiteLineMat.copy()
            # Computed for the (removed) clustering experiment; currently unused
            lines = cv2.HoughLinesP(tempMat, 1, np.pi/180, threshold=100, minLineLength=100, maxLineGap=10)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB)
            # Green overlay marking the recognised road surface
            color_mask = cv2.cvtColor(gray_mask, cv2.COLOR_GRAY2RGB)
            color_mask[:] = (0, 255, 0)
            color_mask = cv2.bitwise_and(color_mask, color_mask, mask=filtered_mask)
            overlay = cv2.addWeighted(processed_frame, 0.7, color_mask, 0.3, 0)
            # NOTE: a large commented-out Hough/K-means lane-clustering experiment
            # was removed here; see version control history if it is needed again.
            return overlay, color_mask, whiteLineMat
        except Exception as e:
            logger.error("路面识别异常:{}", format_exc())
            return processed_frame
# def _simple_kmeans(self, data, n_clusters=2, max_iter=100, random_state=0):
# """
# 使用K-means算法对数据进行聚类
# 参数:
# data: array-like, 形状为 (n_samples, n_features) 的输入数据
# n_clusters: int, 聚类数量默认为2
# max_iter: int, 最大迭代次数默认为100
# random_state: int, 随机种子用于初始化质心默认为0
# 返回:
# labels: array, 形状为 (n_samples,) 的聚类标签数组
# """
# np.random.seed(random_state)
# # 随机选择初始质心
# centroids_idx = np.random.choice(len(data), size=n_clusters, replace=False)
# centroids = data[centroids_idx].copy()
# # 迭代优化质心位置
# for _ in range(max_iter):
# # 为每个数据点分配最近的质心标签
# labels = np.zeros(len(data), dtype=int)
# for i, point in enumerate(data):
    #         distances = np.linalg.norm(centroids - point, axis=1)
    #         labels[i] = np.argmin(distances)
    #     # Update centroids: mean of each cluster; re-seed an empty cluster
    #     # with a randomly chosen data point
    #     new_centroids = np.zeros_like(centroids)
    #     for c in range(n_clusters):
    #         cluster_points = data[labels == c]
    #         if len(cluster_points) > 0:
    #             new_centroids[c] = cluster_points.mean(axis=0)
    #         else:
    #             new_centroids[c] = data[np.random.choice(len(data))]
# # 检查收敛条件
# if np.allclose(centroids, new_centroids):
# break
# centroids = new_centroids
# return labels

115
backend/utils/pull_push.py Normal file
View File

@ -0,0 +1,115 @@
import subprocess as sp
from traceback import format_exc
import cv2, time
import numpy as np
from loguru import logger
from DrGraph.utils.Helper import Helper
from DrGraph.utils.Exception import ServiceException
from DrGraph.utils.Constant import Constant
class NetStream:
    """Pull an H.264 network stream through ffmpeg and hand out decoded frames.

    Frames are read from ffmpeg's stdout as raw NV12 (1920x1080: the Y plane
    plus a half-height interleaved UV plane) and converted to BGR; frames
    wider than ``Constant.pull_frame_width`` are downscaled by half.
    """

    def __init__(self, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.pull_p = None  # ffmpeg subprocess handle
        # Fixed 1920x1080 NV12 geometry: height rows cover Y plus UV planes
        self.width = 1920
        self.height = 1080 * 3 // 2
        self.width_height_3 = 1920 * 1080 * 3 // 2  # bytes per NV12 frame
        self.w_2 = 960
        self.h_2 = 540
        self.frame_count = 0
        self.start_time = time.time()

    def clear_pull_p(self, pull_p, requestId):
        """Shut down the ffmpeg pull process (terminate, then kill on failure)."""
        try:
            if pull_p and pull_p.poll() is None:
                logger.info("关闭拉流管道, requestId:{}", requestId)
                if pull_p.stdout:
                    pull_p.stdout.close()
                pull_p.terminate()
                pull_p.wait(timeout=30)
                logger.info("拉流管道已关闭, requestId:{}", requestId)
        except Exception as e:
            logger.error("关闭拉流管道异常: {}, requestId:{}", format_exc(), requestId)
            # Escalate to SIGKILL if graceful shutdown failed
            if pull_p and pull_p.poll() is None:
                pull_p.kill()
                pull_p.wait(timeout=30)
            raise e

    def start_pull_p(self, pull_url, requestId):
        """Spawn the ffmpeg process that decodes ``pull_url`` to raw video on stdout."""
        try:
            # NOTE(review): hard-coded Windows ffmpeg path — consider making configurable
            command = ['D:/DrGraph/DSP/ffmpeg.exe']
            command.extend(['-re',
                            '-y',
                            '-an',                   # drop audio
                            '-c:v', 'h264_cuvid',    # NVIDIA hardware decode
                            '-i', pull_url,
                            '-f', 'rawvideo',
                            '-r', '25',
                            '-'])
            self.pull_p = sp.Popen(command, stdout=sp.PIPE)
            return self.pull_p
        except ServiceException as s:
            logger.error("构建拉流管道ServiceException异常: url={}, {}, requestId:{}", pull_url, s.msg, requestId)
            raise s
        except Exception as e:
            logger.error("构建拉流管道Exception异常url={}, {}, requestId:{}", pull_url, format_exc(), requestId)
            raise e

    def pull_read_video_stream(self):
        """Read one NV12 frame from ffmpeg; return it as a BGR ndarray, or None."""
        result = None
        try:
            if self.pull_p is None:
                self.start_pull_p(self.pull_url, self.request_id)
            in_bytes = self.pull_p.stdout.read(self.width_height_3)
            if in_bytes is not None and len(in_bytes) > 0:
                try:
                    # NV12 layout: (height * 3 // 2) rows of ``width`` bytes
                    result = (np.frombuffer(in_bytes, np.uint8)).reshape((self.height, self.width))
                    result = cv2.cvtColor(result, cv2.COLOR_YUV2BGR_NV12)
                    if result.shape[1] > Constant.pull_frame_width:
                        # Halve oversized frames before further processing
                        result = cv2.resize(result, (result.shape[1] // 2, result.shape[0] // 2), interpolation=cv2.INTER_LINEAR)
                except Exception:
                    # BUG FIX: this handler referenced an undefined ``requestId``
                    logger.error("视频格式异常:{}, requestId:{}", format_exc(), self.request_id)
                    # NOTE(review): ExceptionType is not imported in this module —
                    # confirm it lives in DrGraph.utils.Exception next to ServiceException
                    raise ServiceException(ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[0],
                                           ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[1])
        except ServiceException as s:
            logger.error("ServiceException 读流异常: {}, requestId:{}", s.msg, self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            result = None
            raise s
        except Exception:
            logger.error("Exception 读流异常:{}, requestId:{}", format_exc(), self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            self.width = None
            self.height = None
            self.width_height_3 = None
            result = None
            logger.error("读流异常:{}, requestId:{}", format_exc(), self.request_id)
        return result

    def prepare_pull(self):
        """Lazily start the ffmpeg pull process."""
        if self.pull_p is None:
            self.start_time = time.time()
            self.start_pull_p(self.pull_url, self.request_id)

    def next_pull_frame(self):
        """Return the next decoded frame, or None when the pipe is not running."""
        if self.pull_p is None:
            logger.error(f'pull_p is None, requestId: {self.request_id}')
            return None
        frame = self.pull_read_video_stream()
        return frame

View File

@ -1,11 +1,13 @@
"""Custom exceptions and error handlers for the chat agent application."""
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from typing import Union
from pydantic import BaseModel
from datetime import datetime
import json
from loguru import logger
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
@ -25,14 +27,33 @@ class FullHxfResponseModel(BaseModel):
message: Optional[str]
class HxfResponse(JSONResponse):
def __init__(self, response: Union[BaseModel, Dict[str, Any]]):
if isinstance(response, BaseModel):
data_dict = response.model_dump(mode='json')
def __new__(cls, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]], StreamingResponse]):
# 如果是StreamingResponse直接返回不进行JSON包装
if isinstance(response, StreamingResponse):
return response
# 否则创建HxfResponse实例
return super().__new__(cls)
def __init__(self, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]]]):
code = 0
if isinstance(response, list):
# 处理BaseModel对象列表
if all(isinstance(item, BaseModel) for item in response):
data_dict = [item.model_dump(mode='json') for item in response]
else:
data_dict = response
elif isinstance(response, BaseModel):
# 处理单个BaseModel对象
data_dict = response.model_dump(mode='json')
else:
# 处理字典或其他可JSON序列化的数据
data_dict = response
if 'success' in data_dict and data_dict['success'] == False:
code = -1
content = {
"code": 0,
"code": code,
"status": status.HTTP_200_OK,
"data": data_dict,
"error": None,
@ -50,12 +71,23 @@ class HxfErrorResponse(JSONResponse):
def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
"""Return a JSON error response."""
if isinstance(message, Exception):
msg = message.message if 'message' in message.__dict__ else str(message)
logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
if isinstance(message, TypeError):
content = {
"code": -1,
"status": 500,
"data": None,
"error": f"错误类型: {type(message)} 错误信息: {message}",
"message": msg
}
else:
content = {
"code": -1,
"status": message.status_code,
"data": None,
"error": message.details,
"message": message.message
"error": None,
"message": msg
}
else:
content = {

View File

@ -310,6 +310,7 @@ class ConversationResponse(BaseResponse, ConversationBase):
is_archived: bool
message_count: int = 0
last_message_at: Optional[datetime] = None
messages: Optional[List["MessageResponse"]] = None
# Message schemas
@ -345,11 +346,11 @@ class ChatRequest(BaseModel):
message: str = Field(..., min_length=1, max_length=10000)
stream: bool = Field(default=False)
use_knowledge_base: bool = Field(default=False)
knowledge_base_id: Optional[int] = Field(None, description="Knowledge base ID for RAG mode")
knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode")
use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities")
use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling")
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(None, ge=1, le=8192)
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(default=2048, ge=1, le=8192)
class ChatResponse(BaseModel):
@ -458,6 +459,7 @@ class DocumentChunk(BaseModel):
chunk_index: int
start_char: Optional[int] = None
end_char: Optional[int] = None
vector_id: Optional[str] = None
class DocumentChunksResponse(BaseModel):

120
backend/utils/vclEnums.py Normal file
View File

@ -0,0 +1,120 @@
from enum import Enum, unique
# 20 visually-distinct background colours for labelling categories/tracks.
CommonColors_Backgnd = [
    "#FF0000",  # red
    "#00FF00",  # green
    "#0000FF",  # blue
    "#FFFF00",  # yellow
    "#FF00FF",  # magenta
    "#00FFFF",  # cyan
    "#FFA500",  # orange
    "#800080",  # purple
    "#FFC0CB",  # pink
    "#008000",  # dark green
    "#000080",  # navy
    "#800000",  # maroon
    "#808000",  # olive
    "#008080",  # teal
    "#808080",  # gray
    "#FF0080",  # rose
    "#0080FF",  # sky blue
    "#FF8000",  # orange red
    "#8000FF",  # violet
    "#00FF80"   # sea green
]
# Foreground (text) colour paired index-by-index with CommonColors_Backgnd,
# chosen for contrast against the corresponding background.
CommonColors_Foregnd = [
    "#FFFFFF",  # red background -> white text
    "#000000",  # green background -> black text
    "#FFFFFF",  # blue background -> white text
    "#000000",  # yellow background -> black text
    "#FFFFFF",  # magenta background -> white text
    "#000000",  # cyan background -> black text
    "#000000",  # orange background -> black text
    "#FFFFFF",  # purple background -> white text
    "#000000",  # pink background -> black text
    "#FFFFFF",  # dark green background -> white text
    "#FFFFFF",  # navy background -> white text
    "#FFFFFF",  # maroon background -> white text
    "#FFFFFF",  # olive background -> white text
    "#FFFFFF",  # teal background -> white text
    "#FFFFFF",  # gray background -> white text
    "#FFFFFF",  # rose background -> white text
    "#000000",  # sky blue background -> black text
    "#000000",  # orange red background -> black text
    "#FFFFFF",  # violet background -> white text
    "#000000"   # sea green background -> black text
]
@unique
class LabellingKind(Enum):
    """Kinds of labelling interaction (selection, creation, temporary transforms)."""
    Unknown = 0
    Select = 1
    Create = 2
    TempDrag = 3
    TempResize = 4
    TempRotate = 5
@unique
class Meta(Enum):
    """Primitive shape kinds used for annotation geometry."""
    Unknown = 0
    Line = 1
    Rectangle = 2
    Ellipse = 3
    Polygon = 4
    Text = 5
@unique
class AiAlg(Enum):
    """Identifiers of the built-in AI algorithms/models."""
    Unknown = 0
    # BUG FIX: trailing commas previously made these members 1-tuples, e.g.
    # AiAlg.FashionMNIST.value == (1,) while AiAlg.Coco8.value == 4 — now
    # every member has a plain int value.
    FashionMNIST = 1
    ColorDetector = 2
    Face = 3
    Coco8 = 4
@unique
class Response(Enum):
    """Response status; each value is a (numeric code, human-readable text) pair."""
    OK = (0, "OK - 响应成功")
    DEBUG = (1, "DEBUG - 正常调试")
    WARNING = (2, "FAIL - 响应警告")
    ERROR = (3, "ERROR - 响应错误")
    EXCEPTION = (4, "EXCEPTION - 响应异常")
    CRITICAL = (5, "CRITICAL - 响应严重错误")
class Flag(Enum):
    """Boolean-valued flags (Paint is truthy, Client is falsy)."""
    Paint = True
    Client = False
class Align(Enum):
    """Alignment choices for a control within its parent.

    NOTE(review): member names mirror VCL-style docking (Client presumably
    fills the parent area) — confirm against the layout code.
    """
    Left = 0
    Top = 1
    Right = 2
    Bottom = 3
    Client = 4
    Custom = 5
class Anchor(Enum):
    """Edge-anchoring options; values are powers of two.

    NOTE(review): power-of-two values suggest bitwise combination — if
    members are OR-ed together this should probably be enum.IntFlag.
    """
    Left = 1
    Top = 2
    Right = 4
    Bottom = 8
class IconPos(Enum):
    """Placement of an icon relative to a control's text."""
    NoIcon = 0
    OnlyIcon = 1
    IconLeft = 2
    IconRight = 3
    IconTop = 4
    IconBottom = 5
class ControlEvent(Enum):
    """Mouse/interaction event kinds dispatched to controls."""
    MouseEnter = 0
    MouseLeave = 1
    MouseMove = 2
    MouseDown = 3
    MouseUp = 4
    Click = 5
    DblClick = 6

136
backend/utils/wssServer.py Normal file
View File

@ -0,0 +1,136 @@
import time, websockets, json, inspect, os
import numpy as np
from loguru import logger
from typing import Dict, Set, Callable, Any, Optional
from traceback import format_exc
from DrGraph.utils.Helper import *
import DrGraph.utils.vclEnums as enums
class WebSocketServer:
"""
WebSocket服务器类
提供WebSocket服务器功能包括客户端连接管理消息处理和广播功能
"""
def __init__(self, host: str = "localhost", port: int = 8765):
    """Create the server; a handler for "file" messages is pre-registered."""
    self.host = host
    self.port = port
    # message type -> async handler(websocket, payload)
    self.message_handlers: Dict[str, Callable] = {}
    self.server = None
    # Optional async callback(websocket, connected: bool), fired on connect/disconnect
    self.onClientConnect = None
    self.register_handler("file", self.handle_file)
def register_handler(self, message_type: str, handler: Callable):
    """Register (or replace) the async handler for ``message_type``."""
    self.message_handlers[message_type] = handler
async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t = 'default'):
    """Send ``message`` (dict or str) to the client as a 't'-prefixed text frame."""
    if isinstance(message, dict):
        message = json.dumps(message, ensure_ascii=False)
    if Helper.AppFlag_SaveLog:
        logger.warning(f"发送消息: {message}")
    # The 't' prefix tells the client this is text (see response_binary's 'b')
    await websocket.send(f't{message}')
    # Record which function triggered this send, for tracing
    caller = inspect.stack()[1]
    logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")
async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data):
    """Send ``data`` to the client as a binary frame prefixed with b'b'."""
    await websocket.send(bytearray(b'b' + data))
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str):
    """Decode one JSON text message and dispatch it to the registered handler.

    Handler return values are sent back: bytes/ndarray as a binary frame,
    anything else as JSON text.  Unknown message types and malformed JSON
    produce structured error responses.
    """
    try:
        data = json.loads(message)
        message_type = data.get("type")
        payload = data.get("data")
        if Helper.AppFlag_SaveLog:
            logger.warning(f"收到消息: {message}")
        if message_type in self.message_handlers:
            response = await self.message_handlers[message_type](websocket, payload)
            if response is not None:
                if isinstance(response, (bytes, bytearray, np.ndarray)):
                    # Binary payloads (e.g. image data) go out as binary frames
                    await self.response_binary(websocket, response);
                else:
                    await self.response_text(websocket, response, message_type);
        elif message_type == 'echo':
            # Built-in echo used for latency testing
            print("echo message ", int(time.time() * 1000))
            data["type"] = "echo_response"
            await self.response_text(websocket, data);
        else:
            logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
            await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
    except json.JSONDecodeError:
        logger.error(f"无效的JSON格式 - {message}")
        await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
    except BrokenPipeError as e:
        logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
        # BrokenPipeError means the connection is already gone; no response is
        # attempted — the caller's ConnectionClosed handling takes over.
    except Exception as e:
        logger.error(f"处理消息时出错: {format_exc()}")
        await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {format_exc()}"))
async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""):
    """
    Serve one client connection until it closes.

    Parameters:
        websocket (websockets.WebSocketServerProtocol): the accepted connection
        path (str): request path (only used for logging)
    """
    logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接")
    if self.onClientConnect:
        await self.onClientConnect(websocket, True)
    try:
        async for message in websocket:
            await self.handle_message(websocket, message)
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接已关闭")
    except BrokenPipeError:
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接BrokenPipeError")
    finally:
        # Always report the disconnect, however the receive loop ended
        if self.onClientConnect:
            await self.onClientConnect(websocket, False)
async def start(self):
    """
    Start the WebSocket server listening on (host, port).
    """
    self.server = await websockets.serve(self.handle_client, self.host, self.port)
    logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")
async def stop(self):
    """
    Stop the WebSocket server and wait for shutdown to complete.
    """
    if self.server:
        self.server.close()
        await self.server.wait_closed()
        logger.info("WebSocket服务器已停止")
async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: str):
logger.info(f"接收文件: {payload}, {payload['command']}")
command = payload['command']
if command == 'dir':
path = payload.get('path', '/')
files = []
folders = []
try:
# 获取目录下的所有文件和文件夹
if os.path.exists(path) and os.path.isdir(path):
with os.scandir(path) as entries:
for entry in entries:
if entry.is_file():
files.append(entry.name)
elif entry.is_dir():
folders.append(entry.name)
else:
logger.warning(f"路径不存在或不是目录: {path}")
except Exception as e:
logger.error(f"读取目录时出错: {path}, 错误: {e}")
# 合并文件和文件夹列表
all_items = folders + files
logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")

View File

@ -2,7 +2,7 @@ enable_file_log: true
enable_stderr: true
base_path: "webIOs/output/logs"
log_name: "th_agenter_web.log"
log_fmt: "<green>{time: HH:mm:ss.SSS}</green> [<level>{level}</level>] - <level>{message}</level> @ <cyan>{extra[relative_path]}:{line}</cyan> in <blue>{function}</blue>"
log_fmt: "<green>{time: HH:mm:ss.SSS}</green> [<level>{level:7}</level>] - <level>{message}</level> @ <cyan>{extra[relative_path]}:{line}</cyan> in <blue>{function}</blue>"
level: "INFO"
rotation: "00:00"
retention: "1 days"