diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f515b7a --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. See https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# VS Code +.vscode/* +!.vscode/extensions.json + +# Logs +logs/ +*.log +webIOs/output/logs/ + +# OS generated files +Thumbs.db +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db + +# FastAPI specific +*.pyc +uvicorn*.log \ No newline at end of file diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..d6cd79a --- /dev/null +++ b/alembic.ini @@ -0,0 +1,147 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. 
forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. 
+# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..e0d0858 --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration with an async dbapi. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..7808f9d --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,101 @@ +import asyncio, os +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. 
config = context.config

from dotenv import load_dotenv

# Populate os.environ from a local .env file (if any) before reading DATABASE_URL.
load_dotenv()
database_url = os.getenv("DATABASE_URL")
if not database_url:
    raise ValueError("环境变量DATABASE_URL未设置")

# Config.set_main_option() stores the value through ConfigParser, which treats
# "%" as interpolation syntax.  A URL with percent-encoded characters (e.g. a
# password containing "%40") would otherwise fail when the option is read back,
# so every literal "%" is doubled here; reads return the original URL.
config.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from th_agenter.db import Base
# NOTE(review): autogenerate only sees tables whose model modules have been
# imported by this point; presumably th_agenter.db takes care of that since the
# explicit import below is disabled -- confirm.
# from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    """Configure the migration context on ``connection`` and run the migrations.

    Invoked through ``AsyncConnection.run_sync`` so that Alembic's synchronous
    API can drive an async engine.
    """
    context.configure(connection=connection, target_metadata=target_metadata)

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Create an async Engine, run the migrations on one connection, dispose it.

    The engine is built from the ``sqlalchemy.*`` options of the main ini
    section and uses ``NullPool`` so no connection outlives the run.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        # Dispose even when a migration raises, so the engine is always released.
        await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""

    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
revision: str = '424646027786'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema: create the initial set of application tables.

    Tables are created in dependency order: the standalone tables first, then
    ``users``/``roles``, then the tables holding foreign keys to them
    (``user_roles``, ``workflows``, ``workflow_executions``,
    ``node_executions``).

    NOTE(review): ``conversations.user_id``, ``conversations.knowledge_base_id``,
    ``documents.knowledge_base_id``, ``messages.conversation_id`` and
    ``table_metadata.database_config_id`` are plain integers here -- no foreign
    key constraint is created for them.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Agent configuration presets.
    op.create_table('agent_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('enabled_tools', sa.JSON(), nullable=False),
    sa.Column('max_iterations', sa.Integer(), nullable=False),
    # NOTE(review): temperature is stored as a short string here, while
    # llm_configs.temperature below is a Float -- confirm this is intended.
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('system_message', sa.Text(), nullable=True),
    sa.Column('verbose', sa.Boolean(), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
    )
    op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
    op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
    # Chat conversations (one row per conversation, owned by user_id).
    op.create_table('conversations',
    sa.Column('title', sa.String(length=200), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
    sa.Column('system_prompt', sa.Text(), nullable=True),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_archived', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
    )
    op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
    # Connection settings for external databases.
    op.create_table('database_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('db_type', sa.String(length=20), nullable=False),
    sa.Column('host', sa.String(length=255), nullable=False),
    sa.Column('port', sa.Integer(), nullable=False),
    sa.Column('database', sa.String(length=100), nullable=False),
    sa.Column('username', sa.String(length=100), nullable=False),
    sa.Column('password', sa.Text(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('connection_params', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
    # NOTE(review): a unique constraint on db_type alone allows only one
    # configuration per database type -- confirm that is the intent.
    sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
    )
    op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
    # Uploaded documents belonging to a knowledge base.
    op.create_table('documents',
    sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('mime_type', sa.String(length=100), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('content', sa.Text(), nullable=True),
    sa.Column('doc_metadata', sa.JSON(), nullable=True),
    sa.Column('chunk_count', sa.Integer(), nullable=False),
    sa.Column('embedding_model', sa.String(length=100), nullable=True),
    sa.Column('vector_ids', sa.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
    )
    op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
    # Uploaded Excel files and their parsed structure.
    op.create_table('excel_files',
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('sheet_names', sa.JSON(), nullable=False),
    sa.Column('default_sheet', sa.String(length=100), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('preview_data', sa.JSON(), nullable=False),
    sa.Column('data_types', sa.JSON(), nullable=True),
    # total_rows / total_columns are JSON rather than Integer; presumably they
    # hold per-sheet values -- verify against the model.
    sa.Column('total_rows', sa.JSON(), nullable=True),
    sa.Column('total_columns', sa.JSON(), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('last_accessed', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
    )
    op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
    # Knowledge bases (embedding / chunking settings and vector collection).
    op.create_table('knowledge_bases',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('embedding_model', sa.String(length=100), nullable=False),
    sa.Column('chunk_size', sa.Integer(), nullable=False),
    sa.Column('chunk_overlap', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('vector_db_type', sa.String(length=50), nullable=False),
    sa.Column('collection_name', sa.String(length=100), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
    )
    op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
    op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
    # LLM provider/model configurations.
    op.create_table('llm_configs',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('provider', sa.String(length=50), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('api_key', sa.String(length=500), nullable=False),
    sa.Column('base_url', sa.String(length=200), nullable=True),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.Float(), nullable=False),
    sa.Column('top_p', sa.Float(), nullable=False),
    sa.Column('frequency_penalty', sa.Float(), nullable=False),
    sa.Column('presence_penalty', sa.Float(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('is_embedding', sa.Boolean(), nullable=False),
    sa.Column('extra_config', sa.JSON(), nullable=True),
    sa.Column('usage_count', sa.Integer(), nullable=False),
    sa.Column('last_used_at', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
    )
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
    # Individual chat messages within a conversation.
    op.create_table('messages',
    sa.Column('conversation_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
    sa.Column('message_metadata', sa.JSON(), nullable=True),
    sa.Column('context_documents', sa.JSON(), nullable=True),
    sa.Column('prompt_tokens', sa.Integer(), nullable=True),
    sa.Column('completion_tokens', sa.Integer(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
    )
    op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
    # Roles; both name and code are unique (see the unique indexes below).
    op.create_table('roles',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('code', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_system', sa.Boolean(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
    )
    op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    # Cached metadata about tables of an external database.
    op.create_table('table_metadata',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('table_name', sa.String(length=100), nullable=False),
    sa.Column('table_schema', sa.String(length=50), nullable=False),
    sa.Column('table_type', sa.String(length=20), nullable=False),
    sa.Column('table_comment', sa.Text(), nullable=True),
    sa.Column('database_config_id', sa.Integer(), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('primary_keys', sa.JSON(), nullable=True),
    sa.Column('foreign_keys', sa.JSON(), nullable=True),
    sa.Column('indexes', sa.JSON(), nullable=True),
    sa.Column('sample_data', sa.JSON(), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
    sa.Column('qa_description', sa.Text(), nullable=True),
    sa.Column('business_context', sa.Text(), nullable=True),
    # The only timezone-aware timestamp in this migration; every other
    # DateTime column is created without timezone=True.
    sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
    )
    op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
    op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
    # User accounts; username and email are unique (see the unique indexes below).
    op.create_table('users',
    sa.Column('username', sa.String(length=50), nullable=False),
    sa.Column('email', sa.String(length=100), nullable=False),
    sa.Column('hashed_password', sa.String(length=255), nullable=False),
    sa.Column('full_name', sa.String(length=100), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('avatar_url', sa.String(length=255), nullable=True),
    sa.Column('bio', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # Many-to-many association between users and roles (composite primary key).
    op.create_table('user_roles',
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('role_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
    sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
    )
    # Workflow definitions, owned by a user.
    op.create_table('workflows',
    sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
    sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
    sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
    sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
    sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
    sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
    )
    op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
    # One row per run of a workflow.
    op.create_table('workflow_executions',
    sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    # NOTE(review): start/end times are stored as 50-char strings, not DateTime.
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
    sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
    )
    op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
    # One row per node executed inside a workflow run.
    op.create_table('node_executions',
    sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
    sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
    sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
    sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
    # NOTE(review): the 'executionstatus' enum is declared a second time here
    # (first use: workflow_executions).  On backends with named enum types
    # (e.g. PostgreSQL) this may try to create the type twice -- confirm the
    # migration runs cleanly on the target database.
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
    )
    op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Downgrade schema: drop every table created by this revision.

    Tables are removed in the reverse of their creation order so that tables
    holding foreign keys go before the tables they reference.  For each table
    its indexes are dropped first, then the table itself.
    """
    # (table name, index names to drop before the table), in execution order.
    teardown_plan = (
        ('node_executions', ('ix_node_executions_id',)),
        ('workflow_executions', ('ix_workflow_executions_id',)),
        ('workflows', ('ix_workflows_id',)),
        ('user_roles', ()),
        ('users', ('ix_users_username', 'ix_users_id', 'ix_users_email')),
        ('table_metadata', ('ix_table_metadata_table_name', 'ix_table_metadata_id')),
        ('roles', ('ix_roles_name', 'ix_roles_id', 'ix_roles_code')),
        ('messages', ('ix_messages_id',)),
        ('llm_configs', ('ix_llm_configs_provider', 'ix_llm_configs_name', 'ix_llm_configs_id')),
        ('knowledge_bases', ('ix_knowledge_bases_name', 'ix_knowledge_bases_id')),
        ('excel_files', ('ix_excel_files_id',)),
        ('documents', ('ix_documents_id',)),
        ('database_configs', ('ix_database_configs_id',)),
        ('conversations', ('ix_conversations_id',)),
        ('agent_configs', ('ix_agent_configs_name', 'ix_agent_configs_id')),
    )

    for table, index_names in teardown_plan:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table)
        op.drop_table(table)
def upgrade() -> None:
    """Upgrade schema: add message bookkeeping columns to ``conversations``.

    ``message_count`` is declared NOT NULL, so it is created with a
    server-side default of 0. Without a default, adding a NOT NULL column to
    a table that already contains rows is rejected by databases that do not
    invent an implicit value (PostgreSQL among them).
    """
    # ### commands auto generated by Alembic - adjusted: server_default added ###
    op.add_column(
        'conversations',
        sa.Column('message_count', sa.Integer(), nullable=False, server_default=sa.text('0')),
    )
    op.add_column('conversations', sa.Column('last_message_at', sa.DateTime(), nullable=True))
    # ### end Alembic commands ###
class DrGraphSession:
    """Logging helper that tags every record with a session ID and a source position.

    Each ``log_*`` method prefixes the message with the session marker and
    appends the position of a caller frame. The frame is chosen with a
    negative ``level`` offset into the current call stack.
    """

    def __init__(self, stepIndex: int, msg: str, session_id: str):
        """Store the step/session identifiers and log ``msg``.

        The severity is inferred from keywords found in the message text.

        Args:
            stepIndex: Step number shown in the emitted log line.
            msg: Message text. It may contain a ``;-N`` marker, which moves
                the reported source position ``N`` frames further up the
                stack; the marker is removed before logging.
            session_id: Identifier embedded in the log prefix.
        """
        logger.info(f"DrGraphSession.__init__: stepIndex={stepIndex}, msg={msg}, session_id={session_id}")
        self.stepIndex = stepIndex
        self.session_id = session_id

        # The original code read an undefined local named ``value`` here, so
        # every construction raised UnboundLocalError; the text comes from ``msg``.
        value, level = self._split_level(msg)
        text = f"第 {self.stepIndex} 步 - {value}"

        # NOTE: the log_* calls must stay directly inside __init__; an extra
        # intermediate frame would shift the stack offset computed above.
        if "警告" in value or value.startswith("WARNING"):
            self.log_warning(text, level=level)
        elif "异常" in value or value.startswith("EXCEPTION"):
            self.log_exception(text, level=level)
        elif "成功" in value or value.startswith("SUCCESS"):
            self.log_success(text, level=level)
        elif "开始" in value or value.startswith("START"):
            self.log_success(text, level=level)
        elif "失败" in value or value.startswith("ERROR"):
            self.log_error(text, level=level)
        else:
            self.log_info(text, level=level)

    @staticmethod
    def _split_level(msg: str) -> tuple:
        """Return ``(text, level)`` for ``msg``, stripping an optional ``;-N`` marker.

        ``level`` starts at -3 and the (negative) marker value is added to it.
        """
        level = -3
        match = re.search(r";(-\d+)", msg)
        if match is None:
            return msg, level
        # Remove the marker exactly as it was written (e.g. ";-02"), not a
        # re-rendering of the parsed integer.
        return msg.replace(match.group(0), ""), level + int(match.group(1))

    def log_prefix(self) -> str:
        """Return the log prefix containing the session ID."""
        return f"〖Session{self.session_id}〗"

    def parse_source_pos(self, level: int) -> str:
        """Describe the stack frame selected by ``level`` as ``file:line in function``.

        Args:
            level: Negative offset into the call stack; this method's own
                frame is accounted for by the extra ``- 1``.

        Returns:
            The compact position, or the raw first line of the stack entry
            when it does not have the expected ``File "...", line N, in name`` form.
        """
        pos = traceback.format_stack()[level - 1].strip().split('\n')[0]
        match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos)
        if match:
            # Drop the hard-coded project root so positions are logged relative to it.
            file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    def log_info(self, msg: str, level: int = -2):
        """Log info message with session ID."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log success message with session ID."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log warning message with session ID."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log error message with session ID."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log exception message with session ID."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
# Diagnostic script: report where the importable ``jwt`` module lives, what it
# exposes, and its version when it declares one.
module_file = inspect.getfile(jwt)
print(f"jwt module path: {module_file}")

exported_names = dir(jwt)
print(f"jwt module attributes: {exported_names}")

if hasattr(jwt, "__version__"):
    print(f"jwt module __version__: {jwt.__version__}")
else:
    print("jwt module has no __version__ attribute")
diff --git a/data/chroma/kb_14/chroma.sqlite3 b/data/chroma/kb_14/chroma.sqlite3 new file mode 100644 index 0000000..32aa6ca Binary files /dev/null and b/data/chroma/kb_14/chroma.sqlite3 differ diff --git a/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin new file mode 100644 index 0000000..d51ad17 Binary files /dev/null and b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/data_level0.bin differ diff --git a/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/header.bin differ diff --git a/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin new file mode 100644 index 0000000..3b62fac Binary files /dev/null and b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/length.bin differ diff --git a/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/link_lists.bin b/data/chroma/kb_14/dfcaee8f-f6cd-4710-b072-de367b72f3bf/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin new file mode 100644 index 0000000..078d9b6 Binary files /dev/null and b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/data_level0.bin differ diff --git a/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/header.bin differ diff --git a/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin 
b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin new file mode 100644 index 0000000..b466fb9 Binary files /dev/null and b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/length.bin differ diff --git a/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/link_lists.bin b/data/chroma/kb_15/1a6b1296-463c-4ec7-a2ed-a2cc3d5473c7/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/data/chroma/kb_15/chroma.sqlite3 b/data/chroma/kb_15/chroma.sqlite3 new file mode 100644 index 0000000..cf591c9 Binary files /dev/null and b/data/chroma/kb_15/chroma.sqlite3 differ diff --git a/data/chroma/kb_16/chroma.sqlite3 b/data/chroma/kb_16/chroma.sqlite3 new file mode 100644 index 0000000..0b74e6d Binary files /dev/null and b/data/chroma/kb_16/chroma.sqlite3 differ diff --git a/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin new file mode 100644 index 0000000..6b3ce75 Binary files /dev/null and b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/data_level0.bin differ diff --git a/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin new file mode 100644 index 0000000..b4a33c1 Binary files /dev/null and b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/header.bin differ diff --git a/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin new file mode 100644 index 0000000..1e143da Binary files /dev/null and b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/length.bin differ diff --git a/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/link_lists.bin b/data/chroma/kb_18/822d9475-27cb-4fcf-b8ec-1d17bc7e7d46/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/data/chroma/kb_18/chroma.sqlite3 b/data/chroma/kb_18/chroma.sqlite3 new file mode 100644 
index 0000000..28ea59f Binary files /dev/null and b/data/chroma/kb_18/chroma.sqlite3 differ diff --git a/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf b/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf new file mode 100644 index 0000000..4fa282e Binary files /dev/null and b/data/uploads/kb_14/bb6e514f-7f78-47e2-be39-8e33e2b3e0de_产品单页_M8004ML30_中性中文版.pdf differ diff --git a/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf b/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf new file mode 100644 index 0000000..4fa282e Binary files /dev/null and b/data/uploads/kb_18/c9adc152-5413-4d9a-936d-ef6b4b985d90_产品单页_M8004ML30_中性中文版.pdf differ diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..180e465 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,17 @@ +services: + db: + image: pgvector/pgvector:pg16 + container_name: pgvector-db + environment: + POSTGRES_USER: drgraph + POSTGRES_PASSWORD: yingping + POSTGRES_DB: th_agenter + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + restart: unless-stopped + +volumes: + pgdata: + # docker exec -it pgvector-db psql -U drgraph -d th_agenter diff --git a/env03_db.txt b/env03_db.txt new file mode 100644 index 0000000..43ee96d Binary files /dev/null and b/env03_db.txt differ diff --git a/main.py b/main.py new file mode 100644 index 0000000..660ed6d --- /dev/null +++ b/main.py @@ -0,0 +1,139 @@ +# uvicorn main:app --host 0.0.0.0 --port 8000 --reload + +# 1. pip install fastapi-cdn-host +# 2. import fastapi_cdn_host +# 3. 
def setup_exception_handlers(app: FastAPI) -> None:
    """Register the application's global exception handlers on ``app``.

    Handlers are installed for ``ChatAgentException``, Starlette HTTP errors,
    request validation errors and, as a catch-all, any other ``Exception``.

    Args:
        app: Application instance the handlers are attached to.
    """
    # Import custom exceptions and handlers
    from utils.util_exceptions import ChatAgentException, chat_agent_exception_handler

    @app.exception_handler(ChatAgentException)
    async def custom_chat_agent_exception_handler(request, exc):
        return await chat_agent_exception_handler(request, exc)

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        from utils.util_exceptions import HxfErrorResponse
        logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
        return HxfErrorResponse(exc)

    def make_json_serializable(obj):
        """Recursively convert ``obj`` into JSON-serializable builtins."""
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, bytes):
            # Undecodable bytes must not abort error reporting; substitute
            # the replacement character instead of raising.
            return obj.decode('utf-8', errors='replace')
        if isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        # Exceptions and any other object are represented by their string form.
        return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Convert any non-serializable objects to strings in error details
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Fallback: if even our conversion fails, use a simple error message
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]
            logger.exception(f"Request Validation Error: {errors}")

        logger.exception(f"validation_error: {errors}")
        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        # loguru has no ``exc_info`` flag. Passing it as a keyword made loguru
        # run ``str.format`` on the already-interpolated message, which fails
        # when the exception text contains braces, and attached no traceback.
        # ``opt(exception=exc)`` attaches the traceback; the exception text is
        # passed as a format argument so it is never treated as a template.
        logger.opt(exception=exc).error("Unhandled exception: {}", exc)
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )
import internet_search_tool diff --git a/reade.me b/reade.me deleted file mode 100644 index e8e9f6f..0000000 --- a/reade.me +++ /dev/null @@ -1 +0,0 @@ -adda diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..74438ff --- /dev/null +++ b/requirements.txt @@ -0,0 +1,128 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.13.2 +aiomysql==0.3.2 +aiosignal==1.4.0 +alembic==1.17.2 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.0 +asyncpg==0.31.0 +attrs==25.4.0 +bcrypt==5.0.0 +boto3==1.42.9 +botocore==1.42.9 +certifi==2025.11.12 +cffi==2.0.0 +charset-normalizer==3.4.4 +click==8.3.1 +colorama==0.4.6 +cryptography==46.0.3 +dataclasses-json==0.6.7 +distro==1.9.0 +dnspython==2.8.0 +email-validator==2.3.0 +fastapi==0.124.4 +fastapi-cli==0.0.16 +fastapi-cloud-cli==0.6.0 +fastar==0.8.0 +filelock==3.20.0 +frozenlist==1.8.0 +greenlet==3.3.0 +h11==0.16.0 +httpcore==1.0.9 +httptools==0.7.1 +httpx==0.28.1 +httpx-sse==0.4.3 +idna==3.11 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.12.0 +jmespath==1.0.1 +jsonpatch==1.33 +jsonpointer==3.0.0 +langchain==1.1.3 +langchain-classic==1.0.0 +langchain-chroma>=0.1.0 +langchain-community==0.4.1 +langchain-core==1.2.0 +langchain-openai==1.1.3 +langchain-postgres==0.0.16 +langchain-text-splitters==1.0.0 +langgraph==1.0.5 +langgraph-checkpoint==3.0.1 +langgraph-prebuilt==1.0.5 +langgraph-sdk==0.3.0 +langsmith==0.4.59 +loguru==0.7.3 +Mako==1.3.10 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +marshmallow==3.26.1 +mdurl==0.1.2 +modelscope==1.33.0 +multidict==6.7.0 +mypy_extensions==1.1.0 +numpy==2.3.5 +openai==2.11.0 +orjson==3.11.5 +ormsgpack==1.12.0 +packaging==25.0 +pandas==2.3.3 +pdfminer.six==20251107 +pdfplumber==0.11.8 +pgvector==0.3.6 +pillow==12.0.0 +propcache==0.4.1 +psycopg==3.3.2 +psycopg-binary==3.3.2 +psycopg-pool==3.3.0 +psycopg2==2.9.11 +pycparser==2.23 +pydantic==2.12.5 +pydantic-extra-types==2.10.6 +pydantic-settings==2.12.0 +pydantic_core==2.41.5 +Pygments==2.19.2 +PyJWT==2.10.1 +PyMySQL==1.1.2 
+pypdfium2==5.2.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.2.1 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.3 +regex==2025.11.3 +requests==2.32.5 +requests-toolbelt==1.0.0 +rich==14.2.0 +rich-toolkit==0.17.0 +rignore==0.7.6 +ruamel.yaml==0.18.16 +ruamel.yaml.clib==0.2.15 +s3transfer==0.16.0 +sentry-sdk==2.47.0 +setuptools==80.9.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +SQLAlchemy==2.0.45 +starlette==0.50.0 +tenacity==9.1.2 +tiktoken==0.12.0 +tqdm==4.67.1 +typer==0.20.0 +typing-inspect==0.9.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +ujson==5.11.0 +urllib3==2.6.2 +uuid_utils==0.12.0 +uvicorn==0.38.0 +watchfiles==1.1.1 +websockets==15.0.1 +wheel==0.45.1 +win32_setctime==1.2.0 +xxhash==3.6.0 +yarl==1.22.0 +zstandard==0.25.0 diff --git a/scripts/FRONTEND_PROXY.md b/scripts/FRONTEND_PROXY.md new file mode 100644 index 0000000..191871d --- /dev/null +++ b/scripts/FRONTEND_PROXY.md @@ -0,0 +1,43 @@ +# 前端代理到后端 (端口 5005 → 8000) + +前端(如运行在 5005 的 Vite/Vben)访问 `/api/*` 时,需要转发到本后端 **http://localhost:8000**。 + +## Vite + +在 `vite.config.ts` 中: + +```ts +export default defineConfig({ + server: { + port: 5005, + proxy: { + '/api': { + target: 'http://localhost:8000', + changeOrigin: true, + }, + }, + }, +}); +``` + +## Vben Admin + +在 `.env.development` 或环境变量中设置: + +``` +VITE_GLOB_API_URL=/api +``` + +并确保 Vite 的 `server.proxy` 将 `/api` 指向 `http://localhost:8000`(同上)。 + +## 直接调后端(排错用) + +不经过前端代理,直接请求后端: + +```bash +curl -X POST 'http://localhost:8000/api/auth/login' \ + -H 'Content-Type: application/json' \ + -d '{"email":"admin@example.com","password":"admin123"}' +``` + +默认管理员:`admin@example.com` / `admin123`(需先执行 `python3 scripts/seed_admin.py`)。 diff --git a/scripts/install_mysql.sh b/scripts/install_mysql.sh new file mode 100755 index 0000000..0c19454 --- /dev/null +++ b/scripts/install_mysql.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# 通过 Homebrew 安装并启动 MySQL(macOS) +# 请在终端中执行:./scripts/install_mysql.sh + +set -e +cd "$(dirname 
"$0")/.." + +# 查找 brew +BREW="" +for p in /opt/homebrew/bin/brew /usr/local/bin/brew; do + [ -x "$p" ] && BREW="$p" && break +done +if [ -z "$BREW" ]; then + echo ">>> 未找到 Homebrew。请先安装:" + echo " /bin/bash -c \"\$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)\"" + echo " 安装后按提示把 brew 加入 PATH" + exit 1 +fi + +echo ">>> 安装 MySQL..." +$BREW install mysql + +echo ">>> 启动 MySQL 服务..." +$BREW services start mysql + +echo "" +echo ">>> MySQL 安装并已启动。等待几秒后,创建库和用户请执行:" +echo " ./scripts/setup_mysql_local.sh" +echo "" +echo ">>> 然后在 .env 中设置:" +echo " DATABASE_URL=mysql+aiomysql://root:yingping@localhost:3306/allm?charset=utf8mb4" diff --git a/scripts/seed_admin.py b/scripts/seed_admin.py new file mode 100644 index 0000000..e679d0f --- /dev/null +++ b/scripts/seed_admin.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +"""创建默认管理员账号 admin@example.com / admin123,便于前端登录。""" +import asyncio +import os +import sys + +# 在导入 th_agenter 前加载 .env +from pathlib import Path +_root = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(_root)) +os.chdir(_root) + +from dotenv import load_dotenv +load_dotenv() + +# 若未设置,可在此指定(或通过环境变量传入) +# os.environ.setdefault("DATABASE_URL", "mysql+aiomysql://root:xxx@localhost:3306/allm?charset=utf8mb4") + +from th_agenter.db.database import AsyncSessionFactory +from th_agenter.services.user import UserService +from utils.util_schemas import UserCreate + + +async def main(): + async with AsyncSessionFactory() as session: + svc = UserService(session) + exists = await svc.get_user_by_email("admin@example.com") + if exists: + print("admin@example.com 已存在,跳过创建") + return + user = await svc.create_user(UserCreate( + username="admin", + email="admin@example.com", + password="admin123", + full_name="Admin", + )) + print(f"已创建管理员: {user.username} / admin@example.com / admin123") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/setup_mysql_local.sh b/scripts/setup_mysql_local.sh new file mode 100755 
#!/bin/bash
# Local MySQL bootstrap: create the `allm` database and set the root password to `yingping`.
# Connection string: mysql+aiomysql://root:yingping@localhost:3306/allm?charset=utf8mb4
#
# Make sure MySQL is installed and running before use:
#   macOS:  brew install mysql && brew services start mysql
#   Ubuntu: sudo apt install mysql-server && sudo systemctl start mysql

set -e

DB_NAME="allm"
DB_USER="root"
DB_PASS="yingping"
HOST="127.0.0.1"
PORT="3306"

echo ">>> 检查 MySQL 是否可连接 (${HOST}:${PORT}) ..."

# First try a password-less connection (fresh install).
if mysql -u "$DB_USER" -h "$HOST" -P "$PORT" -e "SELECT 1" 2>/dev/null; then
  echo ">>> 使用 root 无密码连接成功,创建库并设置密码..."
  # 'root'@'127.0.0.1' is not guaranteed to exist; IF EXISTS keeps a missing
  # account from aborting the batch (and, via `set -e`, the script) after the
  # database and the localhost password have already been changed.
  mysql -u "$DB_USER" -h "$HOST" -P "$PORT" -e "
 CREATE DATABASE IF NOT EXISTS \`${DB_NAME}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
 ALTER USER 'root'@'localhost' IDENTIFIED BY '${DB_PASS}';
 ALTER USER IF EXISTS 'root'@'127.0.0.1' IDENTIFIED BY '${DB_PASS}';
 FLUSH PRIVILEGES;
 "
  echo ">>> 数据库 ${DB_NAME} 已创建,root 密码已设为 ${DB_PASS}"
# Otherwise try the target password (it may already have been set).
elif mysql -u "$DB_USER" -p"${DB_PASS}" -h "$HOST" -P "$PORT" -e "SELECT 1" 2>/dev/null; then
  echo ">>> 使用 root:yingping 连接成功,确保库存在..."
  mysql -u "$DB_USER" -p"${DB_PASS}" -h "$HOST" -P "$PORT" -e "
 CREATE DATABASE IF NOT EXISTS \`${DB_NAME}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
 "
  echo ">>> 数据库 ${DB_NAME} 已就绪"
else
  # Neither credential works: print manual instructions and fail.
  echo ">>> 无法连接 MySQL。请先安装并启动,且能以 root 登录(无密码或已知密码)。"
  echo ">>> 手动执行:"
  echo " mysql -u root -p -e \"CREATE DATABASE IF NOT EXISTS ${DB_NAME} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; ALTER USER 'root'@'localhost' IDENTIFIED BY '${DB_PASS}'; FLUSH PRIVILEGES;\""
  exit 1
fi

echo ""
echo ">>> 在 .env 中设置:"
echo "DATABASE_URL=mysql+aiomysql://${DB_USER}:${DB_PASS}@localhost:${PORT}/${DB_NAME}?charset=utf8mb4"
echo ""
echo ">>> 然后执行迁移:"
echo "DATABASE_URL=\"mysql+aiomysql://${DB_USER}:${DB_PASS}@localhost:${PORT}/${DB_NAME}?charset=utf8mb4\" python -m alembic upgrade head"
def internet_search_tool(query: str):
    """Run a web search"""
    # The docstring above is left untouched on purpose: this function is
    # handed to the agent as a tool, so the text is part of its description.
    #
    # Sends ``query`` as a single user turn to the "qwen-plus" model through an
    # OpenAI-compatible client configured from DASHSCOPE_API_KEY /
    # DASHSCOPE_BASE_URL, with ``enable_search`` switched on, and returns the
    # text of the first choice.
    logger.info(f"Running internet search for query: {query}")

    api_key = os.getenv('DASHSCOPE_API_KEY')
    base_url = os.getenv('DASHSCOPE_BASE_URL')
    search_client = OpenAI(api_key=api_key, base_url=base_url)
    logger.info("create OpenAI")

    conversation = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': query},
    ]
    reply = search_client.chat.completions.create(
        model="qwen-plus",
        messages=conversation,
        extra_body={"enable_search": True},
    )
    logger.info("create completions")

    answer = reply.choices[0].message.content
    logger.info(f"OpenAI response: {answer}")
    return answer
api_key=os.getenv('OPENAI_API_KEY'), + base_url=os.getenv('OPENAI_BASE_URL'), +) +checkpointer = InMemorySaver() # 创建内存检查点,自动保存历史 + +agent = create_deep_agent( # state:thread会话级的状态 + tools=[internet_search_tool], + system_prompt=research_instructions, + model=model, + checkpointer=checkpointer, # 添加检查点,启用自动记忆 + interrupt_on={'internet_search_tool':True} +) + +# 多轮对话循环(使用 Checkpointer 自动记忆) +printed_msg_ids = set() # 跟踪已打印的消息ID +thread_id = "user_session_001" # 会话 ID,区分不同用户/会话 +config = {"configurable": {"thread_id": thread_id}, "metastore": {'assistant_id': 'owenliang'}} # 配置会话 + +print("开始对话(输入 'exit' 退出):") +while True: + user_input = input("\nHUMAN: ").strip() + if user_input.lower() == 'exit': + break + + # 使用 values 模式多次返回完整状态,这里按 message.id 去重,并按类型分类打印 + pending_resume = None + while True: + if pending_resume is None: + request = {"messages": [{"role": "user", "content": user_input}]} + else: + from langgraph.types import Command as _Command + + request = _Command(resume=pending_resume) + pending_resume = None + + for item in agent.stream( + request, + config=config, + stream_mode="values", + ): + state = item[0] if isinstance(item, tuple) and len(item) == 2 else item + + # 先检查是否触发了 Human-In-The-Loop 中断 + if isinstance(state, dict) and "__interrupt__" in state: + interrupts = state["__interrupt__"] or [] + if interrupts: + hitl_payload = interrupts[0].value + action_requests = hitl_payload.get("action_requests", []) + + print("\n=== 需要人工审批的工具调用 ===") + decisions: list[dict[str, str]] = [] + for idx, ar in enumerate(action_requests): + name = ar.get("name") + args = ar.get("args") + print(f"[{idx}] 工具 {name} 参数: {args}") + while True: + choice = input(" 决策 (a=approve, r=reject): ").strip().lower() + if choice in ("a", "r"): + break + decisions.append({"type": "approve" if choice == "a" else "reject"}) + + # 下一轮调用改为 resume,同一轮用户回合继续往下跑 + pending_resume = {"decisions": decisions} + break + + # 兼容 dict state 和 AgentState dataclass + messages = 
def vl_test():
    """Send one text-only prompt to the Ollama-served vision model and return the reply content.

    Failures are logged and then re-raised to the caller.
    """
    logger.info("vl_test")

    # Client built with LangChain's ChatOllama ("qwen3-vl:8b" is the alternative model name).
    ollama_chat = ChatOllama(
        base_url="http://192.168.10.11:11434",
        model="llava-phi3:latest",
        temperature=0.7,
    )

    try:
        # Message content is a list of parts. An image can be attached by adding
        # a part of the form
        #   {"type": "image_url", "image_url": {"url": <http url or base64 data>}}
        text_part = {"type": "text", "text": "请描述这张图片的内容"}
        prompt = HumanMessage(content=[text_part])

        # Invoke the model and pull the text out of the response.
        answer = ollama_chat.invoke([prompt]).content
        logger.info(f"qwen3-vl:8b响应: {answer}")
        return answer
    except Exception as e:
        logger.error(f"调用qwen3-vl:8b失败: {e}")
        raise e
+from loguru import logger +from utils.util_exceptions import HxfResponse + +router = APIRouter() +settings = get_settings() + +@router.post("/register", response_model=UserResponse, summary="注册新用户") +async def register( + request_user_data: UserCreate, + session: DrSession = Depends(get_session) +): + """注册新用户""" + user_service = UserService(session) + session.desc = f"START: 注册用户 {request_user_data.email}" + if await user_service.get_user_by_email(request_user_data.email): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"邮箱 {request_user_data.email} 已被注册,请使用其他邮箱注册!!!" + ) + + if await user_service.get_user_by_username(request_user_data.username): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"用户名 {request_user_data.username} 已被注册,请使用其他用户名注册!!!" + ) + + user = await user_service.create_user(request_user_data) + response = UserResponse.model_validate(user, from_attributes=True) + return HxfResponse(response) + +@router.post("/login", response_model=LoginResponse, summary="邮箱与密码登录") +async def login( + login_data: LoginRequest, + session: DrSession = Depends(get_session) +): + """邮箱与密码登录""" + # Authenticate user by email + session.desc = f"START: 用户 {login_data.email} 尝试登录" + user = await AuthService.authenticate_user_by_email(session, login_data.email, login_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"邮箱 {login_data.email} 或密码错误,请检查后重试!!!", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Create access token + access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes) + access_token = await AuthService.create_access_token( + session, data={"sub": user.username}, expires_delta=access_token_expires + ) + session.desc = f"用户 {user.username} 登录成功" + + response = LoginResponse( + access_token=access_token, + token_type="bearer", + expires_in=settings.security.access_token_expire_minutes * 60, + 
user=UserResponse.model_validate(user, from_attributes=True) + ) + return HxfResponse(response) + +@router.post("/login-oauth", response_model=Token, summary="用户通过用户名和密码登录 (OAuth2 兼容)") +async def login_oauth( + form_data: OAuth2PasswordRequestForm = Depends(), + session: DrSession = Depends(get_session) +): + """用户通过用户名和密码登录 (OAuth2 兼容)""" + session.desc = f"START: 用户 {form_data.username} 尝试 OAuth2 登录" + user = await AuthService.authenticate_user(session, form_data.username, form_data.password) + if not user: + session.desc = f"用户 {form_data.username} 尝试 OAuth2 登录失败" + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Create access token + access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes) + access_token = await AuthService.create_access_token( + session, data={"sub": user.username}, expires_delta=access_token_expires + ) + session.desc = f"用户 {user.username} OAuth2 登录成功" + + return HxfResponse( + { + "access_token": access_token, + "token_type": "bearer", + "expires_in": settings.security.access_token_expire_minutes * 60 + } + ) + +@router.post("/refresh", response_model=Token, summary="刷新访问token") +async def refresh_token( + current_user = Depends(AuthService.get_current_user), + session: DrSession = Depends(get_session) +): + """刷新访问 token""" + # Create new access token + access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes) + access_token = await AuthService.create_access_token( + session, data={"sub": current_user.username}, expires_delta=access_token_expires + ) + + response = Token( + access_token=access_token, + token_type="bearer", + expires_in=settings.security.access_token_expire_minutes * 60 + ) + return HxfResponse(response) + +@router.get("/me", response_model=UserResponse, summary="获取当前用户信息") +async def get_current_user_info( + current_user = 
Depends(AuthService.get_current_user) +): + """获取当前用户信息""" + response = UserResponse.model_validate(current_user, from_attributes=True) + return HxfResponse(response) \ No newline at end of file diff --git a/th_agenter/api/endpoints/chat.py b/th_agenter/api/endpoints/chat.py new file mode 100644 index 0000000..e1f1e11 --- /dev/null +++ b/th_agenter/api/endpoints/chat.py @@ -0,0 +1,283 @@ +"""Chat endpoints for TH Agenter.""" + +import json +from typing import List +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session +from loguru import logger + +from ...db.database import get_session +from ...models.user import User +from ...services.auth import AuthService +from ...services.chat import ChatService +from ...services.conversation import ConversationService +from utils.util_exceptions import HxfResponse + +from utils.util_schemas import ( + ConversationCreate, + ConversationResponse, + ConversationUpdate, + MessageCreate, + MessageResponse, + ChatRequest, + ChatResponse +) + +router = APIRouter() + +@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话") +async def update_conversation( + conversation_id: int, + conversation_update: ConversationUpdate, + session: Session = Depends(get_session) +): + """更新指定对话""" + session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}" + conversation_service = ConversationService(session) + updated_conversation = await conversation_service.update_conversation( + conversation_id, conversation_update + ) + session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}" + response = ConversationResponse.model_validate(updated_conversation) + return HxfResponse(response) + + +@router.delete("/conversations/{conversation_id}", summary="删除指定对话") +async def delete_conversation( + conversation_id: int, + session: Session = 
Depends(get_session) +): + """删除指定对话""" + session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}" + conversation_service = ConversationService(session) + await conversation_service.delete_conversation(conversation_id) + session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}" + response = {"message": "Conversation deleted successfully"} + return HxfResponse(response) + + +@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话") +async def archive_conversation( + conversation_id: int, + session: Session = Depends(get_session) +): + """归档指定对话.""" + conversation_service = ConversationService(session) + success = await conversation_service.archive_conversation(conversation_id) + if not success: + session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to archive conversation" + ) + + session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}" + response = {"message": "Conversation archived successfully"} + return HxfResponse(response) + + +@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话") +async def unarchive_conversation( + conversation_id: int, + session: Session = Depends(get_session) +): + """取消归档指定对话.""" + session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}" + conversation_service = ConversationService(session) + success = await conversation_service.unarchive_conversation(conversation_id) + if not success: + session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to unarchive conversation" + ) + + session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}" + response = {"message": "Conversation unarchived successfully"} + return HxfResponse(response) + + +# Message management +@router.get("/conversations/{conversation_id}/messages", 
response_model=List[MessageResponse], summary="获取指定对话的消息") +async def get_conversation_messages( + conversation_id: int, + skip: int = 0, + limit: int = 100, + session: Session = Depends(get_session) +): + """获取指定对话的消息""" + session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" + conversation_service = ConversationService(session) + messages = await conversation_service.get_conversation_messages( + conversation_id, skip=skip, limit=limit + ) + session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}" + response = [MessageResponse.model_validate(msg) for msg in messages] + return HxfResponse(response) + +# Chat functionality +@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应") +async def chat( + conversation_id: int, + chat_request: ChatRequest, + session: Session = Depends(get_session) +): + """发送消息并获取AI响应""" + session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}" + chat_service = ChatService(session) + await chat_service.initialize(conversation_id) + + # response = await chat_service.chat( + # conversation_id=conversation_id, + # message=chat_request.message, + # stream=False, + # temperature=chat_request.temperature, + # max_tokens=chat_request.max_tokens, + # use_agent=chat_request.use_agent, # 可以简化掉 + # use_langgraph=chat_request.use_langgraph, # 可以简化掉 + # use_knowledge_base=chat_request.use_knowledge_base, # 可以简化掉 + # knowledge_base_id=chat_request.knowledge_base_id # 可以简化掉 + # ) + response = "oooooooooooooooooooK" + session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}" + + return HxfResponse(response) +# ------------------------------------------------------------------------ +@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应") +async def chat_stream( + conversation_id: int, + chat_request: ChatRequest, + current_user = 
Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """发送消息并获取流式AI响应.""" + session.title = f"对话{conversation_id} 发送消息并获取流式AI响应" + session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> " + chat_service = ChatService(session) + await chat_service.initialize(conversation_id, streaming=True) + + async def generate_response(chat_service): + try: + async for chunk in chat_service.chat_stream( + message=chat_request.message + ): + yield chunk + "\n" + except Exception as e: + logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}") + yield {'success': False, 'data': f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}"} + + response = StreamingResponse( + generate_response(chat_service), + media_type="text/stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + ) + + return response + +# Conversation management +@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话") +async def create_conversation( + conversation_data: ConversationCreate, + current_user: User = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """创建新对话""" + id = current_user.id + session.title = f"用户{current_user.username} - 创建新对话" + session.desc = "START: 创建新对话" + conversation_service = ConversationService(session) + conversation = await conversation_service.create_conversation( + user_id=id, + conversation_data=conversation_data + ) + session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {id}, conversation_id: {conversation.id}" + response = ConversationResponse.model_validate(conversation) + return HxfResponse(response) +@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表") +async def list_conversations( + skip: int = 0, + limit: int = 50, + search: str = None, + include_archived: bool = False, + order_by: str = "updated_at", + order_desc: bool = True, + session: 
Session = Depends(get_session) +): + """获取用户对话列表""" + session.title = "获取用户对话列表" + session.desc = "START: 获取用户对话列表" + conversation_service = ConversationService(session) + conversations = await conversation_service.get_user_conversations( + skip=skip, + limit=limit, + search_query=search, + include_archived=include_archived, + order_by=order_by, + order_desc=order_desc + ) + session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话 ..." + response = [ConversationResponse.model_validate(conv) for conv in conversations] + return HxfResponse(response) + +@router.get("/conversations/count", summary="获取用户对话总数") +async def get_conversations_count( + search: str = None, + include_archived: bool = False, + session: Session = Depends(get_session) +): + """获取用户对话总数""" + from th_agenter.core.context import UserContext + user_id = UserContext.get_current_user_id() + session.title = f"获取用户对话总数[用户id = {user_id}]" + session.desc = "START: 获取用户对话总数" + conversation_service = ConversationService(session) + count = await conversation_service.get_user_conversations_count( + search_query=search, + include_archived=include_archived + ) + session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话" + response = {"count": count} + return HxfResponse(response) + +@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话") +async def get_conversation( + conversation_id: int, + session: Session = Depends(get_session) +): + """获取指定对话""" + session.title = f"获取指定对话[对话id = {conversation_id}]" + session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}" + + conversation_service = ConversationService(session) + conversation = await conversation_service.get_conversation( + conversation_id=conversation_id + ) + if not conversation: + session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Conversation not found" + ) + session.desc = f"SUCCESS: 获取指定对话完毕 >>> 
conversation_id: {conversation_id} >>> {conversation}" + + response = ConversationResponse.model_validate(conversation) + + + # chat_service = ChatService(session) + # await chat_service.initialize(conversation_id, streaming=False) + # messages = await chat_service.get_conversation_history_messages( + # conversation_id + # ) + # response.messages = messages + + messages = await conversation_service.get_conversation_messages( + conversation_id, skip=0, limit=100 + ) + response.messages = [MessageResponse.model_validate(msg) for msg in messages] + + response.message_count = len(response.messages) + return HxfResponse(response) diff --git a/th_agenter/api/endpoints/database_config.py b/th_agenter/api/endpoints/database_config.py new file mode 100644 index 0000000..f0537c6 --- /dev/null +++ b/th_agenter/api/endpoints/database_config.py @@ -0,0 +1,153 @@ +"""数据库配置管理API""" +from loguru import logger +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List, Dict, Any +from pydantic import BaseModel, Field +from th_agenter.models.user import User +from th_agenter.db.database import get_session +from th_agenter.services.database_config_service import DatabaseConfigService +from th_agenter.services.auth import AuthService +from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse +from utils.util_exceptions import HxfResponse +# 在文件顶部添加 +from functools import lru_cache + +router = APIRouter(prefix="/api/database-config", tags=["database-config"]) +# 创建服务单例 +@lru_cache() +def get_database_config_service() -> DatabaseConfigService: + """获取DatabaseConfigService单例""" + # 注意:这里需要处理db session的问题 + return DatabaseConfigService(None) # 临时方案 + +# 或者使用全局变量 +_database_service_instance = None + +def get_database_service(session: Session = Depends(get_session)) -> DatabaseConfigService: + """获取DatabaseConfigService实例""" + global _database_service_instance + if _database_service_instance is None: + 
_database_service_instance = DatabaseConfigService(session) + else: + # 更新db session + _database_service_instance.db = session + return _database_service_instance + +class DatabaseConfigCreate(BaseModel): + name: str = Field(..., description="配置名称") + db_type: str = Field(default="postgresql", description="数据库类型") + host: str = Field(..., description="主机地址") + port: int = Field(..., description="端口号") + database: str = Field(..., description="数据库名") + username: str = Field(..., description="用户名") + password: str = Field(..., description="密码") + is_default: bool = Field(default=False, description="是否为默认配置") + connection_params: Dict[str, Any] = Field(default=None, description="额外连接参数") + +class DatabaseConfigResponse(BaseModel): + id: int + name: str + db_type: str + host: str + port: int + database: str + username: str + password: str + is_active: bool + is_default: bool + created_at: str + updated_at: str + + +@router.post("/", response_model=NormalResponse, summary="创建或更新数据库配置") +async def create_database_config( + config_data: DatabaseConfigCreate, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """创建或更新数据库配置""" + config = await service.create_or_update_config(current_user.id, config_data.model_dump()) + response = NormalResponse( + success=True, + message="保存数据库配置成功", + data=config + ) + return HxfResponse(response) + +@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表") +async def get_database_configs( + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """获取用户的数据库配置列表""" + configs = service.get_user_configs(current_user.id) + + config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs] + return HxfResponse(config_list) + +@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接") +async def 
test_database_connection( + config_id: int, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """测试数据库连接""" + result = await service.test_connection(config_id, current_user.id) + return HxfResponse(result) + +@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表") +async def connect_database( + config_id: int, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """连接数据库并获取表列表""" + result = await service.connect_and_get_tables(config_id, current_user.id) + return HxfResponse(result) + + +@router.get("/tables/{table_name}/data", summary="获取表数据预览") +async def get_table_data( + table_name: str, + db_type: str, + limit: int = 100, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """获取表数据预览""" + try: + result = await service.get_table_data(table_name, current_user.id, db_type, limit) + return HxfResponse(result) + except Exception as e: + logger.error(f"获取表数据失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + +@router.get("/tables/{table_name}/schema", summary="获取表结构信息") +async def get_table_schema( + table_name: str, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """获取表结构信息""" + result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的? 
+ return HxfResponse(result) + +@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置") +async def get_config_by_type( + db_type: str, + current_user: User = Depends(AuthService.get_current_user), + service: DatabaseConfigService = Depends(get_database_service) +): + """根据数据库类型获取配置""" + config = service.get_config_by_type(current_user.id, db_type) + if not config: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"未找到类型为 {db_type} 的配置" + ) + # 返回包含解密密码的配置 + return HxfResponse(config.to_dict(include_password=True, decrypt_service=service)) \ No newline at end of file diff --git a/th_agenter/api/endpoints/knowledge_base.py b/th_agenter/api/endpoints/knowledge_base.py new file mode 100644 index 0000000..440c4f8 --- /dev/null +++ b/th_agenter/api/endpoints/knowledge_base.py @@ -0,0 +1,616 @@ +"""Knowledge base API endpoints.""" + +from utils.util_exceptions import HxfResponse +from typing import List, Optional +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status +from fastapi.responses import JSONResponse +from sqlalchemy import select, func +from sqlalchemy.orm import Session + +from ...db.database import get_session +from ...models.user import User +from ...models.knowledge_base import KnowledgeBase, Document +from ...services.knowledge_base import KnowledgeBaseService +from ...services.document import DocumentService +from ...services.auth import AuthService +from utils.util_schemas import ( + KnowledgeBaseCreate, + KnowledgeBaseResponse, + DocumentResponse, + DocumentListResponse, + DocumentUpload, + DocumentProcessingStatus, + DocumentChunksResponse, + ErrorResponse +) +from utils.util_file import FileUtils +from ...core.config import settings + +router = APIRouter(tags=["knowledge-bases"]) + +@router.post("/", response_model=KnowledgeBaseResponse, summary="创建新的知识库") +async def create_knowledge_base( + kb_data: KnowledgeBaseCreate, + session: Session = 
Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """创建新的知识库""" + # Check if knowledge base with same name already exists for this user + session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}" + kb_service = KnowledgeBaseService(session) + session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}" + existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name) + if existing_kb: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"知识库名称 {kb_data.name} 已存在" + ) + + # Create knowledge base + session.desc = f"知识库 {kb_data.name}不存在,创建之" + kb = await kb_service.create_knowledge_base(kb_data) + + session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功" + response = KnowledgeBaseResponse( + id=kb.id, + created_at=kb.created_at, + updated_at=kb.updated_at, + name=kb.name, + description=kb.description, + embedding_model=kb.embedding_model, + chunk_size=kb.chunk_size, + chunk_overlap=kb.chunk_overlap, + is_active=kb.is_active, + vector_db_type=kb.vector_db_type, + collection_name=kb.collection_name, + document_count=0, + active_document_count=0 + ) + return HxfResponse(response) + +@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库") +async def list_knowledge_bases( + skip: int = 0, + limit: int = 100, + search: Optional[str] = None, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """获取当前用户的所有知识库""" + session.title = f"获取用户 {current_user.username} 的所有知识库" + session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})" + kb_service = KnowledgeBaseService(session) + knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit) + + result = [] + for kb in knowledge_bases: + # 本知识库的文档数量 + total_docs = await session.scalar( + select(func.count()).where(Document.knowledge_base_id == kb.id) + ) + # 本知识库的已处理文档数量 + active_docs = await 
session.scalar( + select(func.count()).where( + Document.knowledge_base_id == kb.id, + Document.is_processed == True + ) + ) + + result.append(KnowledgeBaseResponse( + id=kb.id, + created_at=kb.created_at, + updated_at=kb.updated_at, + name=kb.name, + description=kb.description, + embedding_model=kb.embedding_model, + chunk_size=kb.chunk_size, + chunk_overlap=kb.chunk_overlap, + is_active=kb.is_active, + vector_db_type=kb.vector_db_type, + collection_name=kb.collection_name, + document_count=total_docs, + active_document_count=active_docs + )) + + session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库" + return HxfResponse(result) + +@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情") +async def get_knowledge_base( + kb_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """根据知识库ID获取知识库详情""" + session.desc = f"START: 获取知识库 {kb_id} 的详情" + service = KnowledgeBaseService(session) + session.desc = f"检查知识库 {kb_id} 是否存在" + kb = await service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + # Count documents + total_docs = await session.scalar( + select(func.count()).where(Document.knowledge_base_id == kb.id) + ) + session.desc = f"获取知识库 {kb_id} 共 {total_docs} 个文档" + + active_docs = await session.scalar( + select(func.count()).where( + Document.knowledge_base_id == kb.id, + Document.is_processed == True + ) + ) + + session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理" + response = KnowledgeBaseResponse( + id=kb.id, + created_at=kb.created_at, + updated_at=kb.updated_at, + name=kb.name, + description=kb.description, + embedding_model=kb.embedding_model, + chunk_size=kb.chunk_size, + chunk_overlap=kb.chunk_overlap, + is_active=kb.is_active, + vector_db_type=kb.vector_db_type, + 
collection_name=kb.collection_name, + document_count=total_docs, + active_document_count=active_docs + ) + return HxfResponse(response) + +@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库") +async def update_knowledge_base( + kb_id: int, + kb_data: KnowledgeBaseCreate, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """更新知识库""" + session.desc = f"START: 更新知识库 {kb_id}" + service = KnowledgeBaseService(session) + kb = await service.update_knowledge_base(kb_id, kb_data) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + # Count documents + total_docs = await session.scalar( + select(func.count()).where(Document.knowledge_base_id == kb.id) + ) + + active_docs = await session.scalar( + select(func.count()).where( + Document.knowledge_base_id == kb.id, + Document.is_processed == True + ) + ) + + session.desc = f"SUCCESS: 更新知识库 {kb_id},结果 - 共 {total_docs} 个文档,其中 {active_docs} 个已处理" + response = KnowledgeBaseResponse( + id=kb.id, + created_at=kb.created_at, + updated_at=kb.updated_at, + name=kb.name, + description=kb.description, + embedding_model=kb.embedding_model, + chunk_size=kb.chunk_size, + chunk_overlap=kb.chunk_overlap, + is_active=kb.is_active, + vector_db_type=kb.vector_db_type, + collection_name=kb.collection_name, + document_count=total_docs, + active_document_count=active_docs + ) + return HxfResponse(response) + +@router.delete("/{kb_id}", summary="删除知识库") +async def delete_knowledge_base( + kb_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """删除知识库""" + session.desc = f"START: 删除知识库 {kb_id}" + service = KnowledgeBaseService(session) + success = await service.delete_knowledge_base(kb_id) + if not success: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + session.desc = f"SUCCESS: 删除知识库 {kb_id}" + return HxfResponse({"message": "Knowledge base deleted successfully"}) + +# Document management endpoints +@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库") +async def upload_document( + kb_id: int, + file: UploadFile = File(...), + process_immediately: bool = Form(True), + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """上传文档到知识库""" + session.desc = f"START: 上传文档 {file.filename} ({FileUtils.format_file_size(file.size)}) 到知识库 (ID={kb_id})" + + # Verify knowledge base exists and user has access + kb_service = KnowledgeBaseService(session) + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}" + # Validate file + if not FileUtils.validate_file_extension(file.filename): + session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}" + ) + # Check file size (50MB limit) + max_size = 50 * 1024 * 1024 # 50MB + if file.size and file.size > max_size: + session.desc = f"ERROR: 文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制" + ) + + session.desc = f"文件为期望类型,处理文件 {file.filename} - " + # Upload document + doc_service = DocumentService(session) + document = await doc_service.upload_document( + file, kb_id + ) + + # Process document immediately if 
requested + if process_immediately: + try: + await doc_service.process_document(document.id, kb_id) + # Refresh document to get updated status + await session.refresh(document) + except Exception as e: + session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}" + + session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}" + response = DocumentResponse( + id=document.id, + created_at=document.created_at, + updated_at=document.updated_at, + knowledge_base_id=document.knowledge_base_id, + filename=document.filename, + original_filename=document.original_filename, + file_path=document.file_path, + file_type=document.file_type, + file_size=document.file_size, + mime_type=document.mime_type, + is_processed=document.is_processed, + processing_error=document.processing_error, + chunk_count=document.chunk_count or 0, + embedding_model=document.embedding_model, + file_size_mb=round(document.file_size / (1024 * 1024), 2) + ) + return HxfResponse(response) + +@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表") +async def list_documents( + kb_id: int, + skip: int = 0, + limit: int = 50, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """获取知识库中的文档列表。""" + session.desc = f"START: 获取知识库 {kb_id} 中的文档列表" + # Verify knowledge base exists and user has access + kb_service = KnowledgeBaseService(session) + + kb = await kb_service.get_knowledge_base(kb_id) + if not kb: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Knowledge base not found" + ) + + doc_service = DocumentService(session) + documents, total = await doc_service.list_documents(kb_id, skip, limit) + + doc_responses = [] + for doc in documents: + doc_responses.append(DocumentResponse( + id=doc.id, + created_at=doc.created_at, + updated_at=doc.updated_at, + knowledge_base_id=doc.knowledge_base_id, + filename=doc.filename, + 
original_filename=doc.original_filename, + file_path=doc.file_path, + file_type=doc.file_type, + file_size=doc.file_size, + mime_type=doc.mime_type, + is_processed=doc.is_processed, + processing_error=doc.processing_error, + chunk_count=doc.chunk_count or 0, + embedding_model=doc.embedding_model, + file_size_mb=round(doc.file_size / (1024 * 1024), 2) + )) + + session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total} 条" + response = DocumentListResponse( + documents=doc_responses, + total=total, + page=skip // limit + 1, + page_size=limit + ) + return HxfResponse(response) + +@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)") +async def get_document_chunks( + kb_id: int, + doc_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """ + 获取知识库中特定文档的所有文档块(片段)。 + + Args: + kb_id: 知识库ID + doc_id: 文档ID + session: 数据库会话 + current_user: 当前认证用户 + + Returns: + DocumentChunksResponse: 文档块(片段)响应模型 + """ + session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)" + kb_service = KnowledgeBaseService(session) + knowledge_base = await kb_service.get_knowledge_base(kb_id) + + if not knowledge_base: + session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="知识库不存在" + ) + + # Verify document exists in the knowledge base + doc_service = DocumentService(session) + session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService" + document = await doc_service.get_document(doc_id, kb_id) + session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document" + if not document: + session.desc = f"ERROR: 文档 {doc_id} 不存在" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="文档不存在" + ) + + # Get document chunks + chunks = await doc_service.get_document_chunks(doc_id) + + session.desc = f"SUCCESS: 获取文档 {doc_id} 共 {len(chunks)} 个文档块(片段)" + response = DocumentChunksResponse( + 
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
async def get_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return the detail record of a single document in a knowledge base.

    Args:
        kb_id: ID of the knowledge base that should contain the document.
        doc_id: ID of the document to fetch.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Authenticated caller, resolved by the auth dependency.

    Returns:
        ``HxfResponse`` wrapping a ``DocumentResponse``.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"

    # The knowledge base has to exist before the document is looked up.
    knowledge_base = await KnowledgeBaseService(session).get_knowledge_base(kb_id)
    if not knowledge_base:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    found = await DocumentService(session).get_document(doc_id, kb_id)
    if not found:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found"
        )

    session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
    payload = {
        "id": found.id,
        "created_at": found.created_at,
        "updated_at": found.updated_at,
        "knowledge_base_id": found.knowledge_base_id,
        "filename": found.filename,
        "original_filename": found.original_filename,
        "file_path": found.file_path,
        "file_type": found.file_type,
        "file_size": found.file_size,
        "mime_type": found.mime_type,
        "is_processed": found.is_processed,
        "processing_error": found.processing_error,
        "chunk_count": found.chunk_count or 0,
        "embedding_model": found.embedding_model,
        # Size in megabytes, rounded to two decimals.
        "file_size_mb": round(found.file_size / (1024 * 1024), 2),
    }
    return HxfResponse(DocumentResponse(**payload))
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
async def process_document(
    kb_id: int,
    doc_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Process one document of a knowledge base so it can be vector-searched.

    Args:
        kb_id: ID of the knowledge base that owns the document.
        doc_id: ID of the document to process.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Authenticated caller, resolved by the auth dependency.

    Returns:
        ``HxfResponse`` wrapping a ``DocumentProcessingStatus`` built from the
        result dict returned by ``DocumentService.process_document``.

    Raises:
        HTTPException: 404 when the knowledge base or the document is missing.
    """
    session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"

    knowledge_base = await KnowledgeBaseService(session).get_knowledge_base(kb_id)
    if not knowledge_base:
        session.desc = f"ERROR: 知识库 {kb_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found"
        )

    # The document must exist before processing is attempted.
    documents = DocumentService(session)
    target = await documents.get_document(doc_id, kb_id)
    if not target:
        session.desc = f"ERROR: 文档 {doc_id} 不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found"
        )

    outcome = await documents.process_document(doc_id, kb_id)
    # Reload the row after processing has run.
    await session.refresh(target)
    session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"

    # "status" is mandatory in the result; the other keys fall back to defaults.
    processing_status = DocumentProcessingStatus(
        document_id=doc_id,
        status=outcome["status"],
        progress=outcome.get("progress", 0.0),
        error_message=outcome.get("error_message"),
        chunks_created=outcome.get("chunks_created", 0)
    )
    return HxfResponse(processing_status)
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
async def get_llm_configs(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    provider: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    is_embedding: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Return a filtered, name-ordered page of LLM configurations.

    Args:
        skip: Offset into the result set.
        limit: Maximum number of rows returned.
        search: Optional substring matched case-insensitively against name,
            model name and description.
        provider: Optional exact provider filter.
        is_active: Optional active-flag filter.
        is_embedding: Optional embedding-model filter.
        session: Database session; ``session.title``/``session.desc`` are set.
        current_user: Authenticated caller.

    Returns:
        ``HxfResponse`` wrapping a list of config dicts produced with
        ``to_dict(include_sensitive=True)``.
    """
    session.title = "获取大模型配置列表"
    session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
    query = select(LLMConfig)

    # Free-text search across the descriptive columns.
    if search:
        query = query.where(
            or_(
                LLMConfig.name.ilike(f"%{search}%"),
                LLMConfig.model_name.ilike(f"%{search}%"),
                LLMConfig.description.ilike(f"%{search}%")
            )
        )

    # Exact-match filters are applied only when the caller supplied them.
    if provider:
        query = query.where(LLMConfig.provider == provider)
    if is_active is not None:
        query = query.where(LLMConfig.is_active == is_active)
    if is_embedding is not None:
        query = query.where(LLMConfig.is_embedding == is_embedding)

    # Stable ordering first, then pagination.
    query = query.order_by(LLMConfig.name).offset(skip).limit(limit)

    rows = await session.execute(query)
    configs = rows.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..."

    serialized = []
    for item in configs:
        serialized.append(item.to_dict(include_sensitive=True))
    return HxfResponse(serialized)
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
async def create_llm_config(
    config_data: LLMConfigCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new LLM configuration (super-admin only).

    Args:
        config_data: Validated request payload describing the configuration.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Caller, already checked by ``require_super_admin``.

    Returns:
        ``HxfResponse`` wrapping ``config.to_dict()`` of the stored record.

    Raises:
        HTTPException: 400 when the name is already taken or when
            ``validate_config()`` reports the configuration as invalid.
    """
    # Read the username up front; touching the user object after
    # commit/refresh may trigger a lazy load (MissingGreenlet) in async mode.
    username = current_user.username
    session.desc = f"START: 创建大模型配置, name={config_data.name}"

    # Configuration names must be unique.
    stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
    existing_config = (await session.execute(stmt)).scalar_one_or_none()
    if existing_config:
        session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="配置名称已存在"
        )

    # Single source of truth for the field values (previously the same
    # constructor call was written out twice).
    config_fields = {
        "name": config_data.name,
        "provider": config_data.provider,
        "model_name": config_data.model_name,
        "api_key": config_data.api_key,
        "base_url": config_data.base_url,
        "max_tokens": config_data.max_tokens,
        "temperature": config_data.temperature,
        "top_p": config_data.top_p,
        "frequency_penalty": config_data.frequency_penalty,
        "presence_penalty": config_data.presence_penalty,
        "description": config_data.description,
        "is_active": config_data.is_active,
        "is_default": config_data.is_default,
        "is_embedding": config_data.is_embedding,
        "extra_config": config_data.extra_config or {},
    }

    # Validate before anything is written, using the same class as before.
    session.desc = "验证大模型配置, config_data"
    validation_result = LLMConfig_DataClass(**config_fields).validate_config()
    if not validation_result['valid']:
        session.desc = f"ERROR: 大模型配置验证失败, name={config_data.name}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=validation_result['error']
        )

    # Only one default per model type: clear the flag on the others first.
    if config_data.is_default:
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == config_data.is_embedding
        ).values({"is_default": False})
        await session.execute(stmt)

    # Persist the ORM model that every other endpoint in this module queries.
    # NOTE(review): the original passed an LLMConfig_DataClass instance to
    # session.add(); that only works if that class is itself ORM-mapped —
    # confirm that LLMConfig accepts these keyword arguments.
    # Audit fields are set automatically by the SQLAlchemy event listener.
    config = LLMConfig(**config_fields)
    session.add(config)
    await session.commit()
    await session.refresh(config)

    session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}"
    return HxfResponse(config.to_dict())
HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="大模型配置不存在" + ) + + # 检查配置名称是否已存在(排除自己) + if config_data.name and config_data.name != config.name: + stmt = select(LLMConfig).where( + LLMConfig.name == config_data.name, + LLMConfig.id != config_id + ) + existing_config = (await session.execute(stmt)).scalar_one_or_none() + if existing_config: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="配置名称已存在" + ) + + # 如果设为默认,取消同类型的其他默认配置 + if config_data.is_default is True: + # 获取当前配置的embedding类型,如果更新中包含is_embedding则使用新值 + is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding + stmt = update(LLMConfig).where( + LLMConfig.is_embedding == is_embedding, + LLMConfig.id != config_id + ).values({"is_default": False}) + await session.execute(stmt) + + # 更新字段 + update_data = config_data.model_dump(exclude_unset=True) + for field, value in update_data.items(): + setattr(config, field, value) + + await session.commit() + await session.refresh(config) + + session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}" + return HxfResponse(config.to_dict()) + +@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置") +async def delete_llm_config( + config_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(require_super_admin) +): + """删除大模型配置.""" + username = current_user.username + session.desc = f"START: 删除大模型配置, id={config_id}" + stmt = select(LLMConfig).where(LLMConfig.id == config_id) + config = (await session.execute(stmt)).scalar_one_or_none() + if not config: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="大模型配置不存在" + ) + session.desc = f"待删除大模型记录 {config.to_dict()}" + # TODO: 检查是否有对话或其他功能正在使用该配置 + # 这里可以添加相关的检查逻辑 + + # 删除配置 + await session.delete(config) + await session.commit() + + session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}" + return HxfResponse({"message": "LLM 
@router.post("/{config_id}/test", summary="测试连接大模型配置")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Send one test request through an LLM configuration (super-admin only).

    Embedding configurations are exercised with ``embed_query``; chat
    configurations with a short system + human message pair.

    Args:
        config_id: ID of the configuration to test.
        test_data: Payload carrying the optional test message.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Caller, already checked by ``require_super_admin``.

    Returns:
        ``HxfResponse`` with ``success`` set to True/False plus details. A
        validation or connection failure is reported in the body rather than
        raised.

    Raises:
        HTTPException: 404 when the configuration does not exist.
    """
    import time  # local import: only needed to time the test request

    username = current_user.username
    session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}")
    config_name = config.name

    # Static validation first; no network call is made for an invalid config.
    validation_result = config.validate_config()
    logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}")
    if not validation_result["valid"]:
        # Wrapped in HxfResponse like every other return path of this module.
        return HxfResponse({
            "success": False,
            "message": f"配置验证失败: {validation_result['error']}",
            "details": validation_result
        })

    session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}"
    test_message = test_data.message or "Hello, this is a test message."
    session.desc = f"准备测试LLM功能 > test_message = {test_message}"

    # The embedding branch overrides ``provider`` on the session-attached
    # entity; remember the stored value so it can be restored afterwards and
    # a later flush/commit of this session cannot persist the override.
    original_provider = config.provider
    try:
        if config.is_embedding:
            config.provider = "ollama"
            streaming_llm = BGEEmbedLLM(config)
        else:
            streaming_llm = OnlineLLM(config)
        session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}"
        streaming_llm.load_model()
        session.desc = f"加载模型完毕,模型名称:{config.model_name},base_url: {config.base_url},准备测试对话..."

        # NOTE(review): the calls below are synchronous and run on the event
        # loop; consider offloading them to a thread if tests are slow.
        started = time.perf_counter()
        if config.is_embedding:
            # Embedding models are tested through the embedding API.
            response = streaming_llm.embed_query(test_message)
        else:
            from langchain.messages import SystemMessage, HumanMessage
            messages = [
                SystemMessage(content="你是一个简洁的助手,回答控制在50字以内"),
                HumanMessage(content=test_message)
            ]
            response = streaming_llm.model.invoke(messages)
        # Real round-trip time instead of the former hard-coded 150 ms.
        latency_ms = int((time.perf_counter() - started) * 1000)
        session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}"

        return HxfResponse({
            "success": True,
            "message": "LLM测试成功",
            "request": test_message,
            # Chat responses expose ``content``; embeddings are returned as-is.
            "response": response.content if hasattr(response, 'content') else response,
            "latency_ms": latency_ms,
            "config_info": config.to_dict()
        })

    except Exception as test_error:
        # Deliberate catch-all: any provider/client failure is a test result.
        session.desc = f"ERROR: 测试连接大模型配置 {config_name} 失败, error: {str(test_error)}"
        return HxfResponse({
            "success": False,
            "message": f"LLM测试失败: {str(test_error)}",
            "test_message": test_message,
            "config_info": config.to_dict()
        })
    finally:
        # Runs after the response body has been built, so the payload is
        # unchanged while the entity returns to its stored provider value.
        config.provider = original_provider
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
async def set_default_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Mark one active LLM configuration as the default for its model type.

    Args:
        config_id: ID of the configuration to promote.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Caller, already checked by ``require_super_admin``.

    Returns:
        ``HxfResponse`` with a confirmation message and the ``is_default`` flag.

    Raises:
        HTTPException: 404 when the configuration does not exist, 400 when it
            is not active.
    """
    username = current_user.username
    session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}"

    lookup = select(LLMConfig).where(LLMConfig.id == config_id)
    found = await session.execute(lookup)
    config = found.scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Inactive configurations cannot become the default.
    if not config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只能将激活的配置设为默认"
        )

    # Clear the default flag on every other config of the same model type.
    clear_others = (
        update(LLMConfig)
        .where(
            LLMConfig.is_embedding == config.is_embedding,
            LLMConfig.id != config_id
        )
        .values({"is_default": False})
    )
    await session.execute(clear_others)

    # Promote this configuration and stamp the audit fields.
    config.is_default = True
    config.set_audit_fields(current_user.id, is_update=True)

    await session.commit()
    await session.refresh(config)

    if config.is_embedding:
        model_type = "嵌入模型"
    else:
        model_type = "对话模型"

    # Re-initialise the document processor's embeddings after the change.
    processor = get_document_processor(session)
    await processor._init_embeddings()

    session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}"
    result = {
        "message": f"已将 {config.name} 设为默认{model_type}配置",
        "is_default": config.is_default
    }
    return HxfResponse(result)
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
async def create_role(
    role_data: RoleCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new role (super-admin only).

    Args:
        role_data: Validated payload with name, code, description, is_active.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Caller, already checked by ``require_super_admin``.

    Returns:
        ``HxfResponse`` wrapping ``role.to_dict()`` of the stored role.

    Raises:
        HTTPException: 400 when the role code is already in use.
    """
    # Read everything needed from the user object before commit(): attributes
    # may be expired afterwards, and reloading them lazily is not possible in
    # async mode (the same precaution the LLM-config endpoints take).
    operator_id = current_user.id
    operator_name = current_user.username

    session.desc = f"START: 创建角色 {role_data.name}"

    # Role codes must be unique.
    stmt = select(Role).where(Role.code == role_data.code)
    existing_role = (await session.execute(stmt)).scalar_one_or_none()
    if existing_role:
        session.desc = f"ERROR: 角色代码已存在, code={role_data.code}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="角色代码已存在"
        )

    role = Role(
        name=role_data.name,
        code=role_data.code,
        description=role_data.description,
        is_active=role_data.is_active
    )
    role.set_audit_fields(operator_id)

    session.add(role)
    await session.commit()
    # Refresh reloads the row so ``role`` is safe to read below.
    await session.refresh(role)

    logger.info(f"Role created: {role.name} by user {operator_name}")
    response = role.to_dict()
    return HxfResponse(response)
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
async def delete_role(
    role_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete a role that is no longer assigned to any user (super-admin only).

    Args:
        role_id: ID of the role to delete.
        session: Database session; ``session.desc`` receives progress text.
        current_user: Caller, already checked by ``require_super_admin``.

    Returns:
        ``HxfResponse`` wrapping a confirmation message.

    Raises:
        HTTPException: 404 when the role does not exist, 403 for the
            ``SUPER_ADMIN`` role, 400 when users are still linked to the role.
    """
    from sqlalchemy import func  # local import: only needed for the COUNT query

    # Read the username before commit(); it may be expired afterwards.
    operator_name = current_user.username

    stmt = select(Role).where(Role.id == role_id)
    role = (await session.execute(stmt)).scalar_one_or_none()
    if not role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    # The super-admin role is protected from deletion.
    if role.code == "SUPER_ADMIN":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="超级管理员角色不能被删除"
        )

    # Count the user links in SQL. ``ScalarResult`` has no ``count()`` method,
    # so the former ``.scalars().count()`` raised AttributeError.
    count_stmt = (
        select(func.count())
        .select_from(UserRole)
        .where(UserRole.role_id == role_id)
    )
    user_count = (await session.execute(count_stmt)).scalar_one()
    if user_count > 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
        )

    # Keep the name: the instance must not be read once it has been deleted
    # and the transaction committed.
    role_name = role.name
    await session.delete(role)
    await session.commit()

    session.desc = f"角色删除成功: {role_name} by user {operator_name}"
    # NOTE(review): the route declares 204 No Content yet returns a body;
    # confirm how HxfResponse interacts with that status code.
    response = {"message": f"Role deleted successfully: {role_name} by user {operator_name}"}
    return HxfResponse(response)
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
async def get_user_roles(
    user_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_active_user)
):
    """Return the roles assigned to a user.

    A caller may read their own roles; reading another user's roles requires
    ``current_user.is_superuser()`` to be true.

    Args:
        user_id: ID of the user whose roles are requested.
        session: Database session.
        current_user: Authenticated, active caller.

    Returns:
        ``HxfResponse`` wrapping a list of ``role.to_dict()`` results.

    Raises:
        HTTPException: 403 when the caller may not view the target user,
            404 when the target user does not exist.
    """
    # Permission gate: the superuser check only runs for foreign user IDs.
    if current_user.id != user_id:
        if not await current_user.is_superuser():
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="无权限查看其他用户的角色"
            )

    user_lookup = select(User).where(User.id == user_id)
    target_user = (await session.execute(user_lookup)).scalar_one_or_none()
    if not target_user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )

    # Join roles through the user-role association table.
    role_query = (
        select(Role)
        .join(UserRole, Role.id == UserRole.role_id)
        .where(UserRole.user_id == user_id)
    )
    assigned = (await session.execute(role_query)).scalars().all()

    serialized = []
    for entry in assigned:
        serialized.append(entry.to_dict())
    return HxfResponse(serialized)
@router.post("/query", response_model=SmartQueryResponse, summary="智能问数查询")
async def smart_query(
    request: SmartQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Run a smart data query and record it in a conversation.

    A conversation is created (or an existing one verified), the user message
    is saved, the smart-query workflow is executed, and on success the
    assistant reply and conversation context are stored. Failures to save
    messages are recorded in ``session.desc`` but do not abort the query.

    Args:
        request: Query text plus optional conversation ID / new-conversation flag.
        current_user: Authenticated caller.
        session: Database session; ``session.desc`` receives progress text.

    Returns:
        ``HxfResponse`` wrapping a ``SmartQueryResponse``. Workflow and
        unexpected errors are reported with ``success=False`` instead of raising.

    Raises:
        HTTPException: 400 for an empty or over-long query, 404 when the
            conversation is missing or owned by another user, 500 when the
            conversation check itself fails.
    """
    session.desc = f"START: 用户 {current_user.username} 智能问数查询"
    conversation_id = None

    try:
        # Validate the request parameters.
        if not request.query or not request.query.strip():
            session.desc = "ERROR: 用户输入为空, 查询内容不能为空"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容不能为空"
            )

        if len(request.query) > 1000:
            session.desc = "ERROR: 用户输入过长, 查询内容不能超过1000字符"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容过长,请控制在1000字符以内"
            )

        # Initialise the workflow manager.
        workflow_manager = SmartWorkflowManager(session)
        await workflow_manager.initialize()

        # Resolve the conversation. The flag is tracked in a local variable so
        # the incoming request model is not mutated.
        conversation_id = request.conversation_id
        is_new_conversation = request.is_new_conversation

        if is_new_conversation or not conversation_id:
            # New conversation requested, or no conversation ID supplied.
            try:
                conversation_id = await conversation_context_service.create_conversation(
                    user_id=current_user.id,
                    title=f"智能问数: {request.query[:20]}..."
                )
                is_new_conversation = True
                session.desc = f"创建新对话: {conversation_id}"
            except Exception as e:
                # Best effort: continue without a persisted conversation.
                session.desc = f"WARNING: 创建对话失败,使用临时会话: {e}"
                conversation_id = None
        else:
            # The conversation must exist and belong to the current user.
            try:
                context = await conversation_context_service.get_conversation_context(conversation_id)
                if not context or context.get('user_id') != current_user.id:
                    session.desc = f"ERROR: 对话 {conversation_id} 不存在或无权访问"
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="对话不存在或无权访问"
                    )
                session.desc = f"使用现有对话: {conversation_id}"
            except HTTPException:
                session.desc = f"EXCEPTION: 对话 {conversation_id} 不存在或无权访问"
                raise
            except Exception as e:
                session.desc = f"ERROR: 验证对话失败: {e}"
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="对话验证失败"
                ) from e

        # Save the user message; a failure here must not block the query.
        if conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="user",
                    content=request.query
                )
            except Exception as e:
                session.desc = f"WARNING: 保存用户消息失败: {e}"

        # Execute the smart-query workflow.
        try:
            result = await workflow_manager.process_smart_query(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=is_new_conversation
            )
        except Exception as e:
            session.desc = f"ERROR: 智能查询执行失败: {e}"
            # Structured error response instead of a 500.
            response = SmartQueryResponse(
                success=False,
                message=f"查询执行失败: {str(e)}",
                data={'error_type': 'query_execution_error'},
                workflow_steps=[{
                    'step': 'error',
                    'status': 'failed',
                    'message': str(e)
                }],
                conversation_id=conversation_id
            )
            return HxfResponse(response)

        # Normalise the result once: a missing 'success' key or a None 'data'
        # value must not raise KeyError / AttributeError / TypeError below.
        succeeded = bool(result.get('success', False))
        response_data = result.get('data') or {}
        workflow_steps = result.get('workflow_steps', [])
        used_files = response_data.get('used_files', [])

        # On success, store the assistant reply and update the context.
        if succeeded and conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="assistant",
                    content=response_data.get('summary', '查询完成'),
                    metadata={
                        'query_result': result.get('data'),
                        'workflow_steps': workflow_steps,
                        'selected_files': used_files
                    }
                )

                await conversation_context_service.update_conversation_context(
                    conversation_id=conversation_id,
                    query=request.query,
                    selected_files=used_files
                )

            except Exception as e:
                # Does not affect the returned result; only recorded.
                session.desc = f"EXCEPTION: 保存消息到对话历史失败: {e}"

        # Return the result, including the conversation ID when there is one.
        if conversation_id:
            response_data['conversation_id'] = conversation_id
        session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}"
        response = SmartQueryResponse(
            success=succeeded,
            message=result.get('message', '查询完成'),
            data=response_data,
            workflow_steps=workflow_steps,
            conversation_id=conversation_id
        )
        return HxfResponse(response)

    except HTTPException as e:
        session.desc = f"EXCEPTION: HTTP异常: {e}"
        raise
    except Exception as e:
        session.desc = f"ERROR: 智能查询接口异常: {e}"
        # Keep the traceback in the log; the client only gets a generic message.
        logger.exception("smart_query failed unexpectedly")
        response = SmartQueryResponse(
            success=False,
            message="服务器内部错误,请稍后重试",
            data={'error_type': 'internal_server_error'},
            workflow_steps=[{
                'step': 'error',
                'status': 'failed',
                'message': '系统异常'
            }],
            conversation_id=conversation_id
        )
        return HxfResponse(response)
Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 获取对话上下文信息,包括已使用的文件和历史查询 + """ + # 获取对话上下文 + session.desc = f"START: 获取对话上下文,对话ID: {conversation_id}" + context = await conversation_context_service.get_conversation_context(conversation_id) + + if not context: + session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="对话上下文不存在" + ) + + # 验证用户权限 + if context['user_id'] != current_user.id: + session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="无权访问此对话" + ) + + # 获取对话历史 + history = await conversation_context_service.get_conversation_history(conversation_id) + context['message_history'] = history + session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}" + response = ConversationContextResponse( + success=True, + message="获取对话上下文成功", + data=context + ) + return HxfResponse(response) + + +@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息") +async def get_files_status( + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 获取用户当前的文件状态和统计信息 + """ + session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}" + workflow_manager = SmartWorkflowManager() + await workflow_manager.initialize() + + # 获取用户文件列表 + file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id) + + # 统计信息 + total_files = len(file_list) + total_rows = sum(f.get('row_count', 0) for f in file_list) + total_columns = sum(f.get('column_count', 0) for f in file_list) + + # 文件类型统计 + file_types = {} + for file_info in file_list: + filename = file_info['filename'] + ext = filename.split('.')[-1].lower() if '.' 
in filename else 'unknown' + file_types[ext] = file_types.get(ext, 0) + 1 + + status_data = { + 'total_files': total_files, + 'total_rows': total_rows, + 'total_columns': total_columns, + 'file_types': file_types, + 'files': [{ + 'id': f['id'], + 'filename': f['filename'], + 'row_count': f.get('row_count', 0), + 'column_count': f.get('column_count', 0), + 'columns': f.get('columns', []), + 'upload_time': f.get('upload_time') + } for f in file_list], + 'ready_for_query': total_files > 0 + } + + session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}" + response = ConversationContextResponse( + success=True, + message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件", + data=status_data + ) + return HxfResponse(response) + +@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文") +async def reset_conversation_context( + conversation_id: int, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 重置对话上下文,清除历史查询记录但保留文件 + """ + session.desc = f"START: 重置对话上下文,对话ID: {conversation_id}" + # 验证对话存在和用户权限 + context = await conversation_context_service.get_conversation_context(conversation_id) + + if not context: + session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="对话上下文不存在" + ) + + if context['user_id'] != current_user.id: + session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}" + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="无权访问此对话" + ) + + # 重置对话上下文 + success = await conversation_context_service.reset_conversation_context(conversation_id) + + if success: + session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}" + response = ConversationContextResponse( + success=True, + message="对话上下文已重置,可以开始新的数据分析会话" + ) + return HxfResponse(response) + else: + session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}" + raise HTTPException( + 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="重置对话上下文失败" + ) + \ No newline at end of file diff --git a/th_agenter/api/endpoints/smart_query.py b/th_agenter/api/endpoints/smart_query.py new file mode 100644 index 0000000..f12f410 --- /dev/null +++ b/th_agenter/api/endpoints/smart_query.py @@ -0,0 +1,755 @@ +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status +from fastapi.security import HTTPBearer +from sqlalchemy.orm import Session +from typing import Optional, Dict, Any, List +import pandas as pd +from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse, BaseResponse +import os +import tempfile +from th_agenter.services.smart_query import ( + SmartQueryService, + ExcelAnalysisService, + DatabaseQueryService +) +from th_agenter.services.excel_metadata_service import ExcelMetadataService + +import uuid +from pathlib import Path +from utils.util_file import FileUtils + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session +from typing import Optional, AsyncGenerator +import json +from datetime import datetime + +from th_agenter.db.database import get_session +from th_agenter.services.auth import AuthService +from th_agenter.services.smart_workflow import SmartWorkflowManager +from th_agenter.services.conversation_context import ConversationContextService +from pydantic import BaseModel +from loguru import logger +from utils.util_exceptions import HxfResponse + +router = APIRouter(prefix="/smart-query", tags=["smart-query"]) +security = HTTPBearer() + +# Request/Response Models +class DatabaseConfig(BaseModel): + type: str + host: str + port: str + database: str + username: str + password: str + +class QueryRequest(BaseModel): + query: str + page: int = 1 + page_size: int = 20 + table_name: Optional[str] = None + +class TableSchemaRequest(BaseModel): + table_name: str + +class 
ExcelUploadResponse(BaseModel): + file_id: int + success: bool + message: str + data: Optional[Dict[str, Any]] = None # 添加data字段 + +class QueryResponse(BaseModel): + success: bool + message: str + data: Optional[Dict[str, Any]] = None + +@router.post("/upload-excel", response_model=ExcelUploadResponse, summary="上传Excel文件并进行预处理") +async def upload_excel( + file: UploadFile = File(...), + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 上传Excel文件并进行预处理 + """ + session.desc = f"START: 用户 {current_user.username} 上传 Excel 文件并进行预处理" + # 验证文件类型 + allowed_extensions = ['.xlsx', '.xls', '.csv'] + file_extension = os.path.splitext(file.filename)[1].lower() + + if file_extension not in allowed_extensions: + session.desc = f"ERROR: 用户 {current_user.username} 上传了不支持的文件格式 {file_extension},请上传 .xlsx, .xls 或 .csv 文件" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件" + ) + + # 验证文件大小 (10MB) + content = await file.read() + file_size = len(content) + if file_size > 10 * 1024 * 1024: + session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 大小为 {file_size / (1024 * 1024):.2f}MB,超过最大限制 10MB" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="文件大小不能超过 10MB" + ) + + # 创建持久化目录结构 + backend_dir = Path(__file__).parent.parent.parent.parent # 获取backend目录 + data_dir = backend_dir / "data/uploads" + excel_user_dir = data_dir / f"excel_{current_user.id}" + + # 确保目录存在 + excel_user_dir.mkdir(parents=True, exist_ok=True) + + # 生成文件名:{uuid}_{原始文件名称} + file_id = str(uuid.uuid4()) + safe_filename = FileUtils.sanitize_filename(file.filename) + new_filename = f"{file_id}_{safe_filename}" + file_path = excel_user_dir / new_filename + + # 保存文件 + with open(file_path, 'wb') as f: + f.write(content) + + # 使用Excel元信息服务提取并保存元信息 + metadata_service = ExcelMetadataService(session) + excel_file = metadata_service.save_file_metadata( + 
file_path=str(file_path), + original_filename=file.filename, + user_id=current_user.id, + file_size=file_size + ) + + # 为了兼容现有前端,仍然创建pickle文件 + try: + if file_extension == '.csv': + df = pd.read_csv(file_path, encoding='utf-8') + else: + df = pd.read_excel(file_path) + except UnicodeDecodeError: + if file_extension == '.csv': + df = pd.read_csv(file_path, encoding='gbk') + else: + session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 编码错误,请确保文件为UTF-8或GBK编码" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="文件编码错误,请确保文件为UTF-8或GBK编码" + ) + except Exception as e: + session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 读取失败: {str(e)}" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"文件读取失败: {str(e)}" + ) + + # 保存pickle文件到同一目录 + pickle_filename = f"{file_id}_{safe_filename}.pkl" + pickle_path = excel_user_dir / pickle_filename + df.to_pickle(pickle_path) + + # 数据预处理和分析(保持兼容性) + excel_service = ExcelAnalysisService() + analysis_result = excel_service.analyze_dataframe(df, file.filename) + + # 添加数据库文件信息 + analysis_result.update({ + 'file_id': str(excel_file.id), + 'database_id': excel_file.id, + 'temp_file_path': str(pickle_path), # 更新为新的pickle路径 + 'original_filename': file.filename, + 'file_size_mb': excel_file.file_size_mb, + 'sheet_names': excel_file.sheet_names, + }) + + session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}" + response = ExcelUploadResponse( + file_id=excel_file.id, + success=True, + message="Excel文件上传成功", + data=analysis_result + ) + return HxfResponse(response) + +@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据") +async def preview_excel( + request: ExcelPreviewRequest, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 预览Excel文件数据 + """ + session.desc = f"START: 用户 {current_user.username} 预览文件 {request.file_id}" + + # 
验证file_id格式 + try: + file_id = int(request.file_id) + except ValueError: + session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 提供了无效的文件ID格式: {request.file_id}" + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"无效的文件ID格式: {request.file_id}" + ) + + # 从数据库获取文件信息 + metadata_service = ExcelMetadataService(session) + excel_file = metadata_service.get_file_by_id(file_id, current_user.id) + + if not excel_file: + session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 不存在或已被删除" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="文件不存在或已被删除" + ) + + # 检查文件是否存在 + if not os.path.exists(excel_file.file_path): + session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 已被移动或删除" + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="文件已被移动或删除" + ) + + # 更新最后访问时间 + metadata_service.update_last_accessed(file_id, current_user.id) + + # 读取Excel文件 + if excel_file.file_type.lower() == 'csv': + df = pd.read_csv(excel_file.file_path, encoding='utf-8') + else: + # 对于Excel文件,使用默认sheet或第一个sheet + sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0 + df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name) + + # 计算分页 + total_rows = len(df) + start_idx = (request.page - 1) * request.page_size + end_idx = start_idx + request.page_size + + # 获取分页数据 + paginated_df = df.iloc[start_idx:end_idx] + + # 转换为字典格式 + data = paginated_df.fillna('').to_dict('records') + columns = df.columns.tolist() + session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据" + response = QueryResponse( + success=True, + message="Excel文件预览加载成功", + data={ + 'data': data, + 'columns': columns, + 'total_rows': total_rows, + 'page': request.page, + 'page_size': request.page_size, + 'total_pages': (total_rows + request.page_size - 1) // request.page_size + } + ) + return HxfResponse(response) + 
+@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接") +async def test_database_connection( + config: DatabaseConfig, + current_user = Depends(AuthService.get_current_user) +): + """ + 测试数据库连接 + """ + try: + db_service = DatabaseQueryService() + is_connected = await db_service.test_connection(config.model_dump()) + + if is_connected: + return NormalResponse( + success=True, + message="数据库连接测试成功" + ) + else: + response = NormalResponse( + success=False, + message="数据库连接测试失败" + ) + return HxfResponse(response) + + except Exception as e: + return NormalResponse( + success=False, + message=f"连接测试失败: {str(e)}" + ) + +# 删除第285-314行的connect_database方法 +# @router.post("/connect-database", response_model=QueryResponse) +# async def connect_database( +# config_id: int, +# current_user = Depends(AuthService.get_current_user), +# db: Session = Depends(get_session) +# ): +# """连接数据库并获取表列表""" +# ... (整个方法都删除) + +@router.post("/table-schema", response_model=QueryResponse, summary="获取数据表结构") +async def get_table_schema( + request: TableSchemaRequest, + current_user = Depends(AuthService.get_current_user) +): + """ + 获取数据表结构 + """ + try: + db_service = DatabaseQueryService() + schema_result = await db_service.get_table_schema(request.table_name, current_user.id) + + if schema_result['success']: + response = QueryResponse( + success=True, + message="获取表结构成功", + data=schema_result['data'] + ) + return HxfResponse(response) + else: + response = QueryResponse( + success=False, + message=schema_result['message'] + ) + return HxfResponse(response) + + except Exception as e: + response = QueryResponse( + success=False, + message=f"获取表结构失败: {str(e)}" + ) + return HxfResponse(response) + +class StreamQueryRequest(BaseModel): + query: str + conversation_id: Optional[int] = None + is_new_conversation: bool = False + +class DatabaseStreamQueryRequest(BaseModel): + query: str + database_config_id: int + conversation_id: Optional[int] = None + is_new_conversation: 
bool = False + +@router.post("/execute-excel-query", summary="流式智能问答查询") +async def stream_smart_query( + request: StreamQueryRequest, + current_user=Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 流式智能问答查询接口 + 支持实时推送工作流步骤和最终结果 + """ + + async def generate_stream() -> AsyncGenerator[str, None]: + workflow_manager = None + + try: + # 验证请求参数 + if not request.query or not request.query.strip(): + yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n" + return + + if len(request.query) > 1000: + yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n" + return + + # 发送开始信号 + yield f"data: {json.dumps({'type': 'start', 'message': '开始处理查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n" + + # 初始化服务 + workflow_manager = SmartWorkflowManager(session) + await workflow_manager.initialize() + + conversation_context_service = ConversationContextService() + + # 处理对话上下文 + conversation_id = request.conversation_id + + # 如果是新对话或没有指定对话ID,创建新对话 + if request.is_new_conversation or not conversation_id: + try: + conversation_id = await conversation_context_service.create_conversation( + user_id=current_user.id, + title=f"智能问数: {request.query[:20]}..." 
+ ) + yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n" + except Exception as e: + logger.warning(f"创建对话失败: {e}") + # 不阻断流程,继续执行查询 + + # 保存用户消息 + if conversation_id: + try: + await conversation_context_service.save_message( + conversation_id=conversation_id, + role="user", + content=request.query + ) + except Exception as e: + logger.warning(f"保存用户消息失败: {e}") + + # 执行智能查询工作流(带流式推送) + async for step_data in workflow_manager.process_excel_query_stream( + user_query=request.query, + user_id=current_user.id, + conversation_id=conversation_id, + is_new_conversation=request.is_new_conversation + ): + # 推送工作流步骤 + yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n" + + # 如果是最终结果,保存到对话历史 + if step_data.get('type') == 'final_result' and conversation_id: + try: + result_data = step_data.get('data', {}) + await conversation_context_service.save_message( + conversation_id=conversation_id, + role="assistant", + content=result_data.get('summary', '查询完成'), + metadata={ + 'query_result': result_data, + 'workflow_steps': step_data.get('workflow_steps', []), + 'selected_files': result_data.get('used_files', []) + } + ) + + # 更新对话上下文 + await conversation_context_service.update_conversation_context( + conversation_id=conversation_id, + query=request.query, + selected_files=result_data.get('used_files', []) + ) + + logger.info(f"查询成功完成,对话ID: {conversation_id}") + + except Exception as e: + logger.warning(f"保存消息到对话历史失败: {e}") + + # 发送完成信号 + yield f"data: {json.dumps({'type': 'complete', 'message': '查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n" + + except Exception as e: + logger.error(f"流式智能查询异常: {e}", exc_info=True) + yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n" + + finally: + # 清理资源 + if workflow_manager: + try: + workflow_manager.excel_workflow.executor.shutdown(wait=False) + except: + pass + + response = 
StreamingResponse( + generate_stream(), + media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*" + } + ) + return HxfResponse(response) + +@router.post("/execute-db-query", summary="流式数据库查询") +async def execute_database_query( + request: DatabaseStreamQueryRequest, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 流式数据库查询接口 + 支持实时推送工作流步骤和最终结果 + """ + + async def generate_stream() -> AsyncGenerator[str, None]: + workflow_manager = None + + try: + # 验证请求参数 + if not request.query or not request.query.strip(): + yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n" + return + + if len(request.query) > 1000: + yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n" + return + + # 发送开始信号 + yield f"data: {json.dumps({'type': 'start', 'message': '开始处理数据库查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n" + + # 初始化服务 + workflow_manager = SmartWorkflowManager(session) + await workflow_manager.initialize() + conversation_context_service = ConversationContextService() + + # 处理对话上下文 + conversation_id = request.conversation_id + + # 如果是新对话或没有指定对话ID,创建新对话 + if request.is_new_conversation or not conversation_id: + try: + conversation_id = await conversation_context_service.create_conversation( + user_id=current_user.id, + title=f"数据库查询: {request.query[:20]}..." 
+ ) + yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n" + except Exception as e: + logger.warning(f"创建对话失败: {e}") + # 不阻断流程,继续执行查询 + + # 保存用户消息 + if conversation_id: + try: + await conversation_context_service.save_message( + conversation_id=conversation_id, + role="user", + content=request.query + ) + except Exception as e: + logger.warning(f"保存用户消息失败: {e}") + + # 执行数据库查询工作流(带流式推送) + async for step_data in workflow_manager.process_database_query_stream( + user_query=request.query, + user_id=current_user.id, + database_config_id=request.database_config_id + ): + # 推送工作流步骤 + yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n" + + # 如果是最终结果,保存到对话历史 + if step_data.get('type') == 'final_result' and conversation_id: + try: + result_data = step_data.get('data', {}) + await conversation_context_service.save_message( + conversation_id=conversation_id, + role="assistant", + content=result_data.get('summary', '查询完成'), + metadata={ + 'query_result': result_data, + 'workflow_steps': step_data.get('workflow_steps', []), + 'generated_sql': result_data.get('generated_sql', '') + } + ) + + # 更新对话上下文 + await conversation_context_service.update_conversation_context( + conversation_id=conversation_id, + query=request.query, + selected_files=[] + ) + + logger.info(f"数据库查询成功完成,对话ID: {conversation_id}") + + except Exception as e: + logger.warning(f"保存消息到对话历史失败: {e}") + + # 发送完成信号 + yield f"data: {json.dumps({'type': 'complete', 'message': '数据库查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n" + + except Exception as e: + logger.error(f"流式数据库查询异常: {e}", exc_info=True) + yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n" + + finally: + # 清理资源 + if workflow_manager: + try: + workflow_manager.database_workflow.executor.shutdown(wait=False) + except: + pass + + response = StreamingResponse( + generate_stream(), + 
media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*" + } + ) + return HxfResponse(response) + +@router.delete("/cleanup-temp-files", summary="清理临时文件") +async def cleanup_temp_files( + current_user = Depends(AuthService.get_current_user) +): + """ + 清理临时文件 + """ + try: + temp_dir = tempfile.gettempdir() + user_prefix = f"excel_{current_user.id}_" + + cleaned_count = 0 + for filename in os.listdir(temp_dir): + if filename.startswith(user_prefix) and filename.endswith('.pkl'): + file_path = os.path.join(temp_dir, filename) + try: + os.remove(file_path) + cleaned_count += 1 + except OSError: + pass + + response = BaseResponse( + success=True, + message=f"已清理 {cleaned_count} 个临时文件" + ) + return HxfResponse(response) + + except Exception as e: + response = BaseResponse( + success=False, + message=f"清理临时文件失败: {str(e)}" + ) + return HxfResponse(response) + +@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表") +async def get_file_list( + page: int = 1, + page_size: int = 20, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 获取用户上传的Excel文件列表 + """ + try: + session.desc = f"START: 获取用户 {current_user.id} 的文件列表" + metadata_service = ExcelMetadataService(session) + skip = (page - 1) * page_size + files, total = metadata_service.get_user_files(current_user.id, skip, page_size) + + file_list = [] + for file in files: + file_info = { + 'id': file.id, + 'filename': file.original_filename, + 'file_size': file.file_size, + 'file_size_mb': file.file_size_mb, + 'file_type': file.file_type, + 'sheet_names': file.sheet_names, + 'sheet_count': file.sheet_count, + 'last_accessed': file.last_accessed.isoformat() if file.last_accessed else None, + 'is_processed': file.is_processed, + 'processing_error': 
file.processing_error + } + file_list.append(file_info) + + session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件" + response = FileListResponse( + success=True, + message="获取文件列表成功", + data={ + 'files': file_list, + 'total': total, + 'page': page, + 'page_size': page_size, + 'total_pages': (total + page_size - 1) // page_size + } + ) + return HxfResponse(response) + + except Exception as e: + response = FileListResponse( + success=False, + message=f"获取文件列表失败: {str(e)}" + ) + return HxfResponse(response) + +@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件") +async def delete_file( + file_id: int, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 删除指定的Excel文件 + """ + try: + session.desc = f"START: 删除用户 {current_user.id} 的文件 {file_id}" + metadata_service = ExcelMetadataService(session) + success = metadata_service.delete_file(file_id, current_user.id) + + if success: + session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}" + response = NormalResponse( + success=True, + message="文件删除成功" + ) + return HxfResponse(response) + + else: + session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败" + response = NormalResponse( + success=False, + message="文件不存在或删除失败" + ) + return HxfResponse(response) + + except Exception as e: + response = NormalResponse( + success=True, + message=str(e) + ) + return HxfResponse(response) + +@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息") +async def get_file_info( + file_id: int, + current_user = Depends(AuthService.get_current_user), + session: Session = Depends(get_session) +): + """ + 获取指定文件的详细信息 + """ + metadata_service = ExcelMetadataService(session) + excel_file = metadata_service.get_file_by_id(file_id, current_user.id) + + if not excel_file: + session.desc = f"ERROR: 获取用户 {current_user.id} 的文件 {file_id} 信息,文件不存在" + raise HTTPException( + 
status_code=status.HTTP_404_NOT_FOUND, + detail="文件不存在" + ) + + # 更新最后访问时间 + metadata_service.update_last_accessed(file_id, current_user.id) + + file_info = { + 'id': excel_file.id, + 'filename': excel_file.original_filename, + 'file_size': excel_file.file_size, + 'file_size_mb': excel_file.file_size_mb, + 'file_type': excel_file.file_type, + 'sheet_names': excel_file.sheet_names, + 'default_sheet': excel_file.default_sheet, + 'columns_info': excel_file.columns_info, + 'preview_data': excel_file.preview_data, + 'data_types': excel_file.data_types, + 'total_rows': excel_file.total_rows, + 'total_columns': excel_file.total_columns, + 'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None, + 'last_accessed': excel_file.last_accessed.isoformat() if excel_file.last_accessed else None, + 'sheets_summary': excel_file.get_all_sheets_summary() + } + + session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件 {file_id} 信息" + response = QueryResponse( + success=True, + message="获取文件信息成功", + data=file_info + ) + return HxfResponse(response) + \ No newline at end of file diff --git a/th_agenter/api/endpoints/table_metadata.py b/th_agenter/api/endpoints/table_metadata.py new file mode 100644 index 0000000..c2b439f --- /dev/null +++ b/th_agenter/api/endpoints/table_metadata.py @@ -0,0 +1,235 @@ +"""表元数据管理API""" +from loguru import logger +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List, Dict, Any +from pydantic import BaseModel, Field + +from th_agenter.models.user import User +from th_agenter.db.database import get_session +from th_agenter.services.table_metadata_service import TableMetadataService +from th_agenter.services.auth import AuthService +from utils.util_exceptions import HxfResponse + +router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"]) + +class TableSelectionRequest(BaseModel): + database_config_id: int = Field(..., description="数据库配置ID") + 
@router.get("/", summary="获取用户表元数据列表")
async def get_table_metadata(
    database_config_id: int = None,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """List the table-metadata records of the current user.

    ``database_config_id`` is passed straight to the service as an optional
    filter. Failures are logged and reported inside the payload
    (``success: False``) rather than raised.
    """

    def iso_or_blank(moment):
        # Timestamps are emitted as ISO-8601 text; missing ones become "".
        return moment.isoformat() if moment else ""

    def describe(meta):
        # Build the response entry for one metadata record.
        columns = meta.columns_info if meta.columns_info else []
        qa_description = meta.qa_description or ""
        business_context = meta.business_context or ""
        return {
            "id": meta.id,
            "table_name": meta.table_name,
            "table_schema": meta.table_schema,
            "table_type": meta.table_type,
            "table_comment": meta.table_comment or "",
            "columns": columns,
            "column_count": len(columns),
            "row_count": meta.row_count,
            "is_enabled_for_qa": meta.is_enabled_for_qa,
            "qa_description": qa_description,
            "business_context": business_context,
            "created_at": iso_or_blank(meta.created_at),
            "updated_at": iso_or_blank(meta.updated_at),
            "last_synced_at": iso_or_blank(meta.last_synced_at),
            # The QA fields are repeated in a nested group as well.
            "qa_settings": {
                "is_enabled_for_qa": meta.is_enabled_for_qa,
                "qa_description": qa_description,
                "business_context": business_context
            }
        }

    try:
        records = await TableMetadataService(session).get_user_table_metadata(
            current_user.id,
            database_config_id
        )
        return HxfResponse({
            "success": True,
            "data": [describe(meta) for meta in records]
        })
    except Exception as e:
        logger.error(f"获取表元数据失败: {str(e)}")
        return HxfResponse({
            "success": False,
            "message": str(e)
        })
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
async def update_qa_settings(
    metadata_id: int,
    settings: QASettingsUpdate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Update the Q&A settings of one table-metadata record.

    Raises:
        HTTPException: 404 when the service reports that nothing was updated,
            400 when the service call itself fails.
    """
    try:
        service = TableMetadataService(session)
        success = await service.update_table_qa_settings(
            current_user.id,
            metadata_id,
            settings.model_dump()
        )
    except HTTPException:
        # Deliberate HTTP errors from lower layers keep their own status code.
        raise
    except Exception as e:
        logger.error(f"更新问答设置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e

    # Raised outside the try block: previously this 404 was caught by the
    # generic handler and re-emitted as a 400 with detail "404: ...".
    if not success:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="表元数据不存在"
        )

    return HxfResponse({
        "success": True,
        "message": "设置更新成功"
    })
@router.get("/profile", response_model=UserResponse, summary="获取当前用户的个人信息")
async def get_user_profile(
    current_user = Depends(AuthService.get_current_user)
):
    """Return the profile of the authenticated user."""
    # Validate the authenticated user into the response schema and wrap it.
    return HxfResponse(UserResponse.model_validate(current_user))
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
async def list_users(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None),
    role_id: Optional[int] = Query(None),
    is_active: Optional[bool] = Query(None),
    current_user = Depends(require_super_admin),
    session: Session = Depends(get_session)
):
    """List users with pagination and optional filters (admin only).

    The endpoint is documented as admin-only but previously declared no
    authentication dependency, so the user list was reachable by anyone.
    It now uses the same ``require_super_admin`` guard as ``create_user``.

    Returns a payload with ``users``, ``total``, ``page`` and ``page_size``.
    """
    session.desc = f"START: 列出所有用户,分页={page}, 每页大小={size}, 搜索={search}, 角色ID={role_id}, 激活状态={is_active}"
    user_service = UserService(session)
    # Pages are 1-based; convert to a row offset.
    skip = (page - 1) * size
    users, total = await user_service.get_users_with_filters(
        skip=skip,
        limit=size,
        search=search,
        role_id=role_id,
        is_active=is_active
    )
    result = {
        "users": [UserResponse.model_validate(user) for user in users],
        "total": total,
        "page": page,
        "page_size": size
    }
    return HxfResponse(result)
@router.put("/{user_id}", response_model=UserResponse, summary="更新用户信息 (仅管理员权限)")
async def update_user(
    user_id: int,
    user_update: UserUpdate,
    current_user = Depends(require_super_admin),
    session: Session = Depends(get_session)
):
    """Update an arbitrary user by ID (admin only).

    The endpoint is documented as admin-only but was guarded only by
    ``get_current_active_user``, which let any active account modify any
    other account. It now uses ``require_super_admin`` like the other admin
    endpoints in this module; users edit their own data via ``PUT /profile``.

    Raises:
        HTTPException: 404 when no user with ``user_id`` exists.
    """
    user_service = UserService(session)

    user = await user_service.get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    updated_user = await user_service.update_user(user_id, user_update)
    response = UserResponse.model_validate(updated_user)
    return HxfResponse(response)
def convert_workflow_for_response(workflow_dict):
    """Return workflow data with connection keys renamed for the response model.

    Each entry of ``definition.connections`` has ``from_node`` / ``to_node``
    renamed to ``from`` / ``to``. The input is no longer modified in place:
    the nested dicts may be shared with the object that produced them, so the
    renaming is done on copies. When there are no connections the input is
    returned unchanged.

    Args:
        workflow_dict: Mapping describing a workflow; may lack ``definition``.

    Returns:
        A dict of the same shape with converted connection entries.
    """
    definition = workflow_dict.get('definition')
    connections = definition.get('connections') if definition else None
    if not connections:
        return workflow_dict

    renames = (('from_node', 'from'), ('to_node', 'to'))
    converted = []
    for conn in connections:
        fixed = dict(conn)
        for old_key, new_key in renames:
            if old_key in fixed:
                # The renamed value wins over an existing target key.
                fixed[new_key] = fixed.pop(old_key)
        converted.append(fixed)

    return {
        **workflow_dict,
        'definition': {**definition, 'connections': converted},
    }
"""获取工作流列表""" + from ...models.workflow import Workflow + session.title = f"获取用户 {current_user.username} 的所有工作流" + session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})" + + # 构建查询 + stmt = select(Workflow).where(Workflow.owner_id == current_user.id) + + if workflow_status: + stmt = stmt.where(Workflow.status == workflow_status) + + # 添加搜索功能 + if search: + stmt = stmt.where(Workflow.name.ilike(f"%{search}%")) + + # 获取总数 + count_query = select(func.count(Workflow.id)).where(Workflow.owner_id == current_user.id) + if workflow_status: + count_query = count_query.where(Workflow.status == workflow_status) + if search: + count_query = count_query.where(Workflow.name.ilike(f"%{search}%")) + + session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}" + total = await session.scalar(count_query) + session.desc = f"查询结果: 共 {total} 条" + + # 如果没有传分页参数,返回所有数据 + if skip is None and limit is None: + workflows = (await session.scalars(stmt)).all() + session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)} 条" + response = WorkflowListResponse( + workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows], + total=total, + page=1, + size=total + ) + return HxfResponse(response) + + # 使用默认分页参数 + if skip is None: + skip = 0 + if limit is None: + limit = 10 + + # 分页查询 + workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all() + session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)} 条" + + response = WorkflowListResponse( + workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows], + total=total, + page=skip // limit + 1, # 计算页码 + size=limit + ) + return HxfResponse(response) + +@router.get("/{workflow_id}", response_model=WorkflowResponse, summary="获取工作流详情") +async def get_workflow( + workflow_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(AuthService.get_current_user) +): + """获取工作流详情""" + from ...models.workflow import Workflow 
@router.put("/{workflow_id}", response_model=WorkflowResponse, summary="更新工作流")
async def update_workflow(
    workflow_id: int,
    workflow_data: WorkflowUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Update a workflow owned by the current user.

    The status is always forced to PUBLISHED before the explicitly set fields
    are copied onto the stored workflow.

    Raises:
        HTTPException: 404 when the workflow does not exist for this owner.
    """
    from ...models.workflow import Workflow
    session.title = f"更新工作流 {workflow_id}"
    session.desc = f"START: 更新工作流 {workflow_id}"

    owned_by_caller = and_(
        Workflow.id == workflow_id,
        Workflow.owner_id == current_user.id
    )
    workflow = await session.scalar(select(Workflow).where(owned_by_caller))

    if not workflow:
        session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    workflow_data.status = WorkflowStatus.PUBLISHED
    # Only fields that were explicitly set (plus the forced status) are applied.
    changes = workflow_data.model_dump(exclude_unset=True)
    session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {changes}"
    for field, value in changes.items():
        # A definition that is still a model object is flattened to a dict.
        if field == "definition" and value and hasattr(value, 'dict'):
            value = value.dict()
        setattr(workflow, field, value)

    workflow.set_audit_fields(current_user.id, is_update=True)

    await session.commit()
    await session.refresh(workflow)
    session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}"
    payload = convert_workflow_for_response(workflow.to_dict())
    return HxfResponse(WorkflowResponse(**payload))
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse], summary="获取工作流执行历史")
async def list_workflow_executions(
    workflow_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return the execution history of a workflow owned by the current user.

    Executions are ordered newest first and paginated with ``skip`` / ``limit``.

    Raises:
        HTTPException: 404 when the workflow does not exist for this owner,
            500 when loading the history fails.
    """
    session.title = f"获取工作流执行历史 {workflow_id}"
    session.desc = f"START: 获取工作流执行历史 {workflow_id}"
    try:
        from ...models.workflow import Workflow, WorkflowExecution

        # Verify ownership before exposing any execution data.
        workflow = await session.scalar(
            select(Workflow).where(
                and_(
                    Workflow.id == workflow_id,
                    Workflow.owner_id == current_user.id
                )
            )
        )

        if not workflow:
            session.desc = f"ERROR: 获取工作流执行历史数据 - 工作流不存在"
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="工作流不存在"
            )

        executions = (await session.scalars(
            select(WorkflowExecution).where(
                WorkflowExecution.workflow_id == workflow_id
            ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
        )).all()

        session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}"
        response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
        return HxfResponse(response)

    except HTTPException:
        # The 404 above must reach the client as-is; previously the generic
        # handler below turned it into a 500.
        raise
    except Exception as e:
        session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}"
        # Record the real cause; the client only gets a generic message.
        logger.exception("获取工作流执行历史失败 (workflow_id={})", workflow_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行历史失败"
        ) from e
@router.post("/", response_model=WorkflowResponse, summary="创建工作流")
async def create_workflow(
    workflow_data: WorkflowCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a workflow owned by the current user.

    The workflow is stored as version "1.0.0" with PUBLISHED status and is
    returned after commit and refresh, converted for the response model.
    """
    from ...models.workflow import Workflow
    name = workflow_data.name
    owner = current_user.id
    session.title = f"创建工作流 {name}"
    session.desc = f"START: 创建工作流 {name}"

    attributes = {
        "name": name,
        "description": workflow_data.description,
        "definition": workflow_data.definition.model_dump(),
        "version": "1.0.0",
        # The status from the request is ignored; new workflows are published.
        "status": ModelWorkflowStatus.PUBLISHED,
        "owner_id": owner,
    }
    workflow = Workflow(**attributes)
    session.desc = f"创建工作流实例 - Workflow(), {name}"
    workflow.set_audit_fields(owner)
    session.desc = f"保存工作流 - set_audit_fields {name}"

    session.add(workflow)
    await session.commit()
    await session.refresh(workflow)
    session.desc = f"保存工作流 - commit & refresh {name}"

    # Rename connection keys in the definition for the response schema.
    payload = convert_workflow_for_response(workflow.to_dict())
    session.desc = f"转换工作流数据 - convert_workflow_for_response {name}"

    response = WorkflowResponse(**payload)
    session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {name}"
    return HxfResponse(response)
# Main API router: every endpoint module's router is mounted on it below.
router = APIRouter()

# (sub-router, include_router keyword arguments), in registration order.
# Entries without a "prefix" are mounted without one.
_MOUNTS = (
    (auth.router, {"prefix": "/auth", "tags": ["身份验证"]}),
    (users.router, {"prefix": "/users", "tags": ["users"]}),
    (roles.router, {"prefix": "/admin", "tags": ["admin-roles"]}),
    (llm_configs.router, {"prefix": "/admin", "tags": ["admin-llm-configs"]}),
    (knowledge_base.router, {"prefix": "/knowledge-bases", "tags": ["knowledge-bases"]}),
    (database_config.router, {"tags": ["database-config"]}),
    (table_metadata.router, {"tags": ["table-metadata"]}),
    (smart_query.router, {"tags": ["smart-query"]}),
    (chat.router, {"prefix": "/chat", "tags": ["chat"]}),
    (smart_chat.router, {"tags": ["smart-chat"]}),
    (workflow.router, {"prefix": "/workflows", "tags": ["workflows"]}),
)

for _sub_router, _options in _MOUNTS:
    router.include_router(_sub_router, **_options)
del _sub_router, _options
def setup_exception_handlers(app: FastAPI) -> None:
    """Register the application's global exception handlers.

    Three handlers are attached to ``app``:

    * ``StarletteHTTPException`` -> the exception's own status code with an
      ``http_error`` envelope.
    * ``RequestValidationError`` -> 422 with a ``validation_error`` envelope
      whose ``details`` are first coerced to JSON-safe values.
    * ``Exception`` -> 500 with a generic ``internal_error`` envelope; the
      traceback is logged and never returned to the client.

    Args:
        app: Application on which the handlers are registered.
    """

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "error": {
                    "type": "http_error",
                    "message": exc.detail,
                    "status_code": exc.status_code
                }
            }
        )

    def make_json_serializable(obj):
        """Recursively convert ``obj`` into a value ``json.dumps`` accepts."""
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, bytes):
            # The payload may not be valid UTF-8; replacing bad bytes keeps
            # the rest of the validation details instead of discarding them.
            return obj.decode('utf-8', errors='replace')
        if isinstance(obj, dict):
            # JSON object keys must be primitives; stringify anything else.
            return {
                (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)):
                    make_json_serializable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        # Exceptions and every other object are reported by their string form.
        return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Error details may embed exceptions, bytes or other objects that
        # JSONResponse cannot encode.
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Last resort: still answer 422, with a placeholder detail.
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]

        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        # loguru has no ``exc_info`` keyword: the traceback is attached with
        # ``opt(exception=...)``. The exception is passed as a format argument
        # so braces in its text are never re-interpreted by ``str.format``.
        logger.opt(exception=exc).error("Unhandled exception: {}", exc)
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )
class LLMSettings(BaseSettings):
    """Chat-model configuration for several OpenAI-protocol-compatible providers."""
    provider: str = Field(default="openai", alias="llm_provider")  # openai, deepseek, doubao, zhipu, moonshot

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Shared generation parameters
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the active chat-model configuration.

        The default chat configuration stored in the database is preferred
        when a session is supplied and a default exists. Otherwise, or when
        the lookup fails, the settings of ``self.provider`` are used, with
        ``openai`` as the fallback for an unknown provider name.

        Args:
            session: Database session handed to ``LLMConfigService``; a falsy
                value skips the database lookup. Progress messages are
                written to ``session.desc``.
                NOTE(review): the annotation resolves to ``requests.Session``
                in this module, yet the object is used as a database session
                with a ``desc`` attribute - confirm the intended type.

        Returns:
            A dict with ``api_key``, ``base_url``, ``model``, ``max_tokens``
            and ``temperature``.
        """
        try:
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = None
            if session:
                db_config = await llm_service.get_default_chat_config(session)

            if db_config:
                # A database default exists (which implies a session was given).
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                # Never copy the credential into the status text: report a
                # redacted view of the configuration instead.
                redacted = {**config, "api_key": "***" if config["api_key"] else None}
                session.desc = f"使用LLM配置(get_default_chat_config)> {redacted}"
                return config
        except Exception as e:
            # The database is best-effort: record the failure and fall back
            # to the environment-based settings below.
            if session:
                session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
            else:
                logger.error(f"获取默认对话模型配置失败: {str(e)}")

        # Environment / .env based configuration of the selected provider.
        known_providers = ("openai", "deepseek", "doubao", "zhipu", "moonshot")
        provider = self.provider if self.provider in known_providers else "openai"
        return {
            "api_key": getattr(self, f"{provider}_api_key"),
            "base_url": getattr(self, f"{provider}_base_url"),
            "model": getattr(self, f"{provider}_model"),
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }
zhipu, moonshot + + # OpenAI配置 + openai_api_key: Optional[str] = Field(default=None) + openai_base_url: str = Field(default="https://api.openai.com/v1") + openai_embedding_model: str = Field(default="text-embedding-ada-002") + + # DeepSeek配置 + deepseek_api_key: Optional[str] = Field(default=None) + deepseek_base_url: str = Field(default="https://api.deepseek.com/v1") + deepseek_embedding_model: str = Field(default="deepseek-embedding") + + # 豆包配置 + doubao_api_key: Optional[str] = Field(default=None) + doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3") + doubao_embedding_model: str = Field(default="doubao-embedding") + + # 智谱AI配置 + zhipu_api_key: Optional[str] = Field(default=None) + zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4") + zhipu_embedding_model: str = Field(default="embedding-3") + + # 月之暗面配置 + moonshot_api_key: Optional[str] = Field(default=None) + moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1") + moonshot_embedding_model: str = Field(default="moonshot-embedding") + + model_config = { + "env_file": ".env", + "env_file_encoding": "utf-8", + "case_sensitive": False, + "extra": "ignore" + } + + async def get_current_config(self, session: Session) -> dict: + """获取当前选择的embedding提供商配置 - 优先从数据库读取默认配置.""" + try: + if session: + session.desc = "尝试从数据库读取默认嵌入模型配置 ... 
>>> get_current_config"; + # 尝试从数据库读取默认嵌入模型配置 + from th_agenter.services.llm_config_service import LLMConfigService + llm_service = LLMConfigService() + db_config = await llm_service.get_default_embedding_config(session) + + if db_config: + # 如果数据库中有默认配置,使用数据库配置 + config = { + "api_key": db_config.api_key, + "base_url": db_config.base_url, + "model": db_config.model_name + } + return config + except Exception as e: + # 如果数据库读取失败,记录错误并回退到环境变量 + if session: + session.error(f"Failed to read embedding config from database, falling back to env vars: {e}") + else: + logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}") + + # 回退到原有的环境变量配置 + provider_configs = { + "openai": { + "api_key": self.openai_api_key, + "base_url": self.openai_base_url, + "model": self.openai_embedding_model + }, + "deepseek": { + "api_key": self.deepseek_api_key, + "base_url": self.deepseek_base_url, + "model": self.deepseek_embedding_model + }, + "doubao": { + "api_key": self.doubao_api_key, + "base_url": self.doubao_base_url, + "model": self.doubao_embedding_model + }, + "zhipu": { + "api_key": self.zhipu_api_key, + "base_url": self.zhipu_base_url, + "model": self.zhipu_embedding_model + }, + "moonshot": { + "api_key": self.moonshot_api_key, + "base_url": self.moonshot_base_url, + "model": self.moonshot_embedding_model + } + } + + return provider_configs.get(self.provider, provider_configs["zhipu"]) + +class VectorDBSettings(BaseSettings): + """Vector database configuration.""" + type: str = Field(default="pgvector", alias="vector_db_type") + persist_directory: str = Field(default="./data/chroma") + collection_name: str = Field(default="documents") + embedding_dimension: int = Field(default=2048) # 智谱AI embedding-3模型的维度 + + # PostgreSQL pgvector configuration + pgvector_host: str = Field(default="localhost") + pgvector_port: int = Field(default=5432) + pgvector_database: str = Field(default="vectordb") + pgvector_user: str = Field(default="postgres") + 
class FileSettings(BaseSettings):
    """File upload and chunking configuration."""
    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10 MiB, in bytes
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    semantic_splitter_enabled: bool = Field(default=False)  # whether the semantic splitter is enabled

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Normalize the extension list before validation.

        Accepts a comma-separated string or a list. Entries are stripped,
        blank entries are dropped (so ``""`` or a trailing comma no longer
        yields a bare ``"."``), and a leading dot is added where missing.
        Any other input type is returned untouched for pydantic to judge.
        """
        if isinstance(v, str):
            v = v.split(',')
        elif not isinstance(v, list):
            return v
        stripped = (ext.strip() for ext in v)
        return [ext if ext.startswith('.') else f'.{ext}' for ext in stripped if ext]

    def get_allowed_extensions_list(self) -> List[str]:
        """Return the allowed extensions as a list of dot-prefixed strings."""
        value = self.allowed_extensions
        if isinstance(value, list):
            return value
        if isinstance(value, str):
            # Same normalization as the validator, for a value that is still
            # a raw comma-separated string.
            stripped = (ext.strip() for ext in value.split(','))
            return [ext if ext.startswith('.') else f'.{ext}' for ext in stripped if ext]
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
class Settings(BaseSettings):
    """Main application settings, aggregating every configuration section."""

    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Build the settings from a YAML file, falling back to defaults.

        If ``config_path`` does not exist, ``configs/settings.yaml`` two
        directories above this module's directory is tried; if that is
        missing too, ``cls()`` is returned (environment / ``.env`` only).

        Args:
            config_path: Path of the YAML file to load.

        Returns:
            A ``Settings`` instance. Each section is built from its YAML
            mapping; a section that is absent or empty in the YAML is built
            with no arguments, so it loads from the environment.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            # Directory containing this module.
            current_dir = Path(__file__).parent
            # Two levels up, then configs/settings.yaml.
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                return cls()

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Substitute ${VAR} placeholders.
        config_data = cls._resolve_env_vars_nested(config_data)

        # Sections are instantiated explicitly so each one reads .env itself.
        section_types = {
            'database': DatabaseSettings,
            'security': SecuritySettings,
            'llm': LLMSettings,
            'embedding': EmbeddingSettings,
            'vector_db': VectorDBSettings,
            'file': FileSettings,
            'storage': StorageSettings,
            'cors': CORSSettings,
            'chat': ChatSettings,
            'tool': ToolSetings,
        }
        settings_kwargs = {}
        for name, section_cls in section_types.items():
            # ``or {}``: a key present with no children parses as None, and
            # ``**None`` would raise TypeError.
            settings_kwargs[name] = section_cls(**(config_data.get(name) or {}))

        # Remaining keys are top-level settings (app_name, debug, ...).
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten nested dicts into one level, joining keys with ``_``."""
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Any) -> Any:
        """Replace ``${VAR}`` strings with environment values, recursively.

        Dicts and lists are walked; a string of the exact form ``${VAR}`` is
        replaced by ``os.getenv("VAR")`` and kept verbatim when the variable
        is unset. Every other value is returned unchanged.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        if isinstance(config, list):
            # Placeholders may also appear as list items.
            return [Settings._resolve_env_vars_nested(item) for item in config]
        if isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve ``${VAR}`` placeholders in the top-level values only."""
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved
class UserContext:
    """Static accessors for the user attached to the current request.

    The user is stored as a plain dict snapshot (not the ORM object) in the
    module-level ``current_user_context`` ContextVar and mirrored into the
    module-level ``_thread_local``.

    NOTE(review): readers fall back to ``_thread_local`` when the ContextVar
    is empty. Thread-local state is shared by everything running on the same
    thread, so if several requests are served concurrently on one thread the
    fallback could return another request's user - confirm against the
    server's execution model before relying on it.
    """

    @staticmethod
    def _snapshot(user: User) -> dict:
        """Copy the fields kept in context out of the ``User`` object."""
        return {
            'id': user.id,
            'username': user.username,
            'email': user.email,
            'full_name': user.full_name,
            'is_active': user.is_active
        }

    @staticmethod
    def set_current_user(user: User, canLog: bool = False) -> None:
        """Store ``user`` as the current user.

        Args:
            user: Authenticated user to snapshot into the context.
            canLog: Emit informational log lines when True.
        """
        # Same storage path as the token-returning variant; the token is
        # simply not needed here.
        UserContext.set_current_user_with_token(user, canLog)

    @staticmethod
    def set_current_user_with_token(user: User, canLog: bool = False):
        """Store ``user`` as the current user and return the reset token.

        Args:
            user: Authenticated user to snapshot into the context.
            canLog: Emit informational log lines when True.

        Returns:
            The token produced by ``current_user_context.set``, suitable for
            ``reset_current_user_token``.
        """
        if canLog:
            logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")

        user_dict = UserContext._snapshot(user)

        token = current_user_context.set(user_dict)
        # Mirror into thread-local storage as a backup.
        _thread_local.current_user = user_dict

        if canLog:
            # Read the value back purely for the diagnostic line.
            verify_user = current_user_context.get()
            logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")

        return token

    @staticmethod
    def reset_current_user_token(token):
        """Restore the ContextVar to its state before ``token`` was issued.

        Raises:
            Whatever ``current_user_context.reset`` raises for an unusable
            token; the thread-local copy is cleared regardless.
        """
        logger.info("[UserContext] - Resetting user context using token")

        try:
            current_user_context.reset(token)
        finally:
            # Clear the mirror even when the reset itself fails, so no stale
            # user stays behind.
            if hasattr(_thread_local, 'current_user'):
                delattr(_thread_local, 'current_user')

    @staticmethod
    def get_current_user() -> Optional[dict]:
        """Return the current user dict, or None (logging an error) if unset."""
        # ContextVar first.
        user = current_user_context.get()
        if user:
            return user

        # Then the thread-local mirror.
        user = getattr(_thread_local, 'current_user', None)
        if user:
            return user

        logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
        return None

    @staticmethod
    def get_current_user_id() -> Optional[int]:
        """Return the current user's id, or None when absent or on error."""
        try:
            user = UserContext.get_current_user()
            return user.get('id') if user else None
        except Exception as e:
            logger.error(f"[UserContext] - Error getting current user ID: {e}")
            return None

    @staticmethod
    def clear_current_user(canLog: bool = False) -> None:
        """Remove the current user from both the ContextVar and thread-local."""
        if canLog:
            logger.info("[UserContext] - 清除当前用户上下文")

        current_user_context.set(None)
        if hasattr(_thread_local, 'current_user'):
            delattr(_thread_local, 'current_user')

    @staticmethod
    def require_current_user() -> dict:
        """Return the current user dict.

        Raises:
            HTTPException: 401 when no user is present in the context.
        """
        # Same lookup as get_current_user (ContextVar, then thread-local).
        user = UserContext.get_current_user()
        if user is None:
            from fastapi import HTTPException, status
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No authenticated user in context"
            )
        return user

    @staticmethod
    def require_current_user_id() -> int:
        """Return the current user's id.

        Raises:
            HTTPException: 401 when no user is present in the context.
        """
        user = UserContext.require_current_user()
        return user.get('id')
class BaseCustomException(Exception):
    """Root of the application's exception hierarchy.

    Attributes:
        message: Human-readable description; also the sole ``args`` entry.
        details: Extra structured context; an empty dict when none is given.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.details = details if details else {}


class NotFoundError(BaseCustomException):
    """Raised when a requested resource does not exist."""


class ValidationError(BaseCustomException):
    """Raised when input fails validation."""


class AuthenticationError(BaseCustomException):
    """Raised when authentication fails."""


class AuthorizationError(BaseCustomException):
    """Raised when authorization fails."""


class DatabaseError(BaseCustomException):
    """Raised when a database operation fails."""


class ConfigurationError(BaseCustomException):
    """Raised when configuration is invalid."""


class ExternalServiceError(BaseCustomException):
    """Raised when a call to an external service fails."""


class BusinessLogicError(BaseCustomException):
    """Raised when a business-logic rule is violated."""
class UserContextMiddleware(BaseHTTPMiddleware):
    """Middleware that authenticates a request and publishes its user.

    Requests whose path matches ``exclude_paths`` are passed through
    untouched. Every other request must carry a ``Bearer`` token; the token
    is verified, the user is loaded from the database and stored through
    ``UserContext`` for the duration of the request.
    """

    def __init__(self, app, exclude_paths: list = None):
        """Store the wrapped app and the list of unauthenticated paths.

        Args:
            app: The ASGI application being wrapped.
            exclude_paths: Paths that skip authentication; the built-in list
                below is used when this is None or empty.
        """
        super().__init__(app)
        # Verbose per-request logging switch; off by default.
        self.canLog = False
        # Paths that don't require authentication
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/login-oauth",
            "/auth/login",
            "/auth/register",
            "/auth/login-oauth",
            "/health",
            "/static/"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request, run it, then clear the user context.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or an ``HxfErrorResponse`` when
            authentication fails (401) or the downstream call raises.
        """
        if self.canLog:
            logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}")

        # Skip authentication for excluded paths
        path = request.url.path
        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")

        should_skip = False
        for exclude_path in self.exclude_paths:
            # Exact match
            if path == exclude_path:
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
                break
            # For paths ending with '/', check if request path starts with it
            elif exclude_path.endswith('/') and path.startswith(exclude_path):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
                break
            # For paths not ending with '/', check if request path starts with it + '/'
            # (so "/health" also covers "/health/x" but not "/healthz")
            elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
                break

        if should_skip:
            if self.canLog:
                logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
            response = await call_next(request)
            return response

        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")

        # Always clear any existing user context to ensure fresh authentication
        UserContext.clear_current_user(self.canLog)

        # Initialize context token
        # NOTE(review): user_token is assigned below but never read in this
        # method; cleanup uses clear_current_user instead of the token.
        user_token = None

        # Try to extract and validate token
        try:
            # Get authorization header
            authorization = request.headers.get("Authorization")
            if not authorization or not authorization.startswith("Bearer "):
                # No token provided, return 401 error
                return HxfErrorResponse(
                    message="缺少或无效的授权头",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Extract token
            token = authorization.split(" ")[1]


            # Verify token
            payload = AuthService.verify_token(token)
            if payload is None:
                # Invalid token, return 401 error
                return HxfErrorResponse(
                    message="无效或过期的令牌",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get username from token
            username = payload.get("sub")
            if not username:
                return HxfErrorResponse(
                    message="令牌负载无效",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get user from database
            from sqlalchemy import select
            from ..models.user import User

            # Open a short-lived async session just to load the user record.
            session = AsyncSession(bind=engine_async)
            try:
                stmt = select(User).where(User.username == username)
                user = await session.execute(stmt)
                user = user.scalar_one_or_none()
                if not user:
                    return HxfErrorResponse(
                        message="用户不存在",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                if not user.is_active:
                    return HxfErrorResponse(
                        message="用户账户已停用",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                # Set user in context using token mechanism
                user_token = UserContext.set_current_user_with_token(user, self.canLog)
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")

                # Verify context is set correctly
                current_user_id = UserContext.get_current_user_id()
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文")
            finally:
                # The session is closed on every exit path, including the
                # early 401 returns above.
                await session.close()

        except Exception as e:
            # Any failure while authenticating (including database errors)
            # is logged and answered with 401.
            logger.error(f"[MIDDLEWARE] - 认证过程 [{request.method} {request.url.path}] 中设置用户上下文出错: {e}")
            # Return 401 error
            return HxfErrorResponse(
                message="认证过程中出错",
                status_code=status.HTTP_401_UNAUTHORIZED
            )

        # Continue with request
        try:
            response = await call_next(request)
            return response
        except Exception as e:
            # Downstream failure: log it and turn the exception into a response.
            logger.error(f"[MIDDLEWARE] - 请求处理 [{request.method} {request.url.path}] 出错: {e}")
            # NOTE(review): presumably HxfErrorResponse maps an exception to
            # a 500 response - confirm against utils.util_exceptions.
            return HxfErrorResponse(e)
        finally:
            # Always clear user context after request processing
            UserContext.clear_current_user(self.canLog)
            if self.canLog:
                logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
async def is_super_admin(user: User, session: Session) -> bool:
    """Report whether ``user`` holds the active ``SUPER_ADMIN`` role.

    Args:
        user: User to check; ``None`` or an inactive user is never an admin.
        session: Database session used for the role lookup. Progress
            messages are written to its ``desc`` attribute.

    Returns:
        True when an active ``SUPER_ADMIN`` role is linked to the user;
        False otherwise, including when the lookup raises.
    """
    # Read the id defensively: ``user`` may be None, and the status messages
    # below must not dereference it.
    user_id = getattr(user, "id", None)
    session.desc = f"检查用户 {user_id} 是否为超级管理员"
    if not user or not user.is_active:
        session.desc = f"用户 {user_id} 不是活跃状态"
        return False

    try:
        # Query through the session that was passed in, to avoid
        # MissingGreenlet errors.
        from sqlalchemy import select
        from ..models.permission import UserRole, Role

        stmt = select(UserRole).join(Role).filter(
            UserRole.user_id == user.id,
            Role.code == 'SUPER_ADMIN',
            Role.is_active == True
        )
        user_role = await session.execute(stmt)
        result = user_role.scalar_one_or_none() is not None
        session.desc = f"用户 {user.id} 超级管理员角色查询结果: {result}"
        return result
    except Exception as e:
        # Fail closed: a lookup error never grants admin rights.
        session.desc = f"EXCEPTION: 用户 {user.id} 超级管理员角色查询失败: {str(e)}"
        logger.error(f"检查用户 {user.id} 超级管理员角色失败: {str(e)}")
        return False
这个装饰器主要用于服务层,实际的FastAPI依赖项检查在路由层 + return func(*args, **kwargs) + return wrapper \ No newline at end of file diff --git a/th_agenter/core/user_utils.py b/th_agenter/core/user_utils.py new file mode 100644 index 0000000..fff0e81 --- /dev/null +++ b/th_agenter/core/user_utils.py @@ -0,0 +1,76 @@ +"""User utility functions for easy access to current user context.""" + +from typing import Optional +from ..models.user import User +from .context import UserContext + + +def get_current_user() -> Optional[User]: + """Get current authenticated user from context. + + Returns: + Current user if authenticated, None otherwise + """ + return UserContext.get_current_user() + + +def get_current_user_id() -> Optional[int]: + """Get current authenticated user ID from context. + + Returns: + Current user ID if authenticated, None otherwise + """ + return UserContext.get_current_user_id() + + +def require_current_user() -> User: + """Get current authenticated user from context, raise exception if not found. + + Returns: + Current user + + Raises: + HTTPException: If no authenticated user in context + """ + return UserContext.require_current_user() + + +def require_current_user_id() -> int: + """Get current authenticated user ID from context, raise exception if not found. + + Returns: + Current user ID + + Raises: + HTTPException: If no authenticated user in context + """ + return UserContext.require_current_user_id() + + +def is_user_authenticated() -> bool: + """Check if there is an authenticated user in the current context. + + Returns: + True if user is authenticated, False otherwise + """ + return UserContext.get_current_user() is not None + + +def get_current_username() -> Optional[str]: + """Get current authenticated user's username from context. 
class BaseModel(Base):
    """Abstract declarative base with a surrogate key and audit columns.

    Subclasses inherit ``id``, ``created_at``/``updated_at`` timestamps and
    the ``created_by``/``updated_by`` user-id columns, plus dictionary
    conversion helpers.
    """

    __abstract__ = True

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
    created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    def to_dict(self):
        """Convert the mapped columns of this instance to a dictionary.

        Returns:
            A dict keyed by column name. Values that expose ``isoformat``
            are replaced by the result of calling it; all other values are
            returned unchanged.
        """
        result = {}
        for column in self.__table__.columns:
            # Read the attribute once per column instead of once per use.
            value = getattr(self, column.name)
            result[column.name] = value.isoformat() if hasattr(value, 'isoformat') else value
        return result

    @classmethod
    def from_dict(cls, data: dict):
        """Create a model instance from a dictionary.

        Args:
            data: Dictionary containing model field values. Keys that are
                not column names of this model are ignored.

        Returns:
            Model instance created from the matching entries.
        """
        # Filter out fields that don't exist in the model
        model_fields = {column.name for column in cls.__table__.columns}
        filtered_data = {key: value for key, value in data.items() if key in model_fields}

        # Create and return the instance
        return cls(**filtered_data)

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
        """Set ``created_by``/``updated_by`` for a create or update operation.

        Args:
            user_id: Id of the acting user. When omitted it is read from
                ``UserContext.get_current_user_id()``.
            is_update: ``True`` for an update (only ``updated_by`` is set),
                ``False`` for a create (both fields are set).

        The call is a no-op when no user id is supplied and none can be
        obtained from the context.
        """
        # Fall back to the request context when no explicit id is given.
        if user_id is None:
            from ..core.context import UserContext
            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # Best effort: without a usable context the audit fields are left untouched.
                return

        # Still nothing to record.
        if user_id is None:
            return

        # ``created_by`` is written on create only; ``updated_by`` in both cases.
        if not is_update:
            self.created_by = user_id
        self.updated_by = user_id
UserContext +# user_id = UserContext.get_current_user_id() +# except Exception: +# user_id = None + +# # 处理新增对象 +# for instance in session.new: +# if isinstance(instance, BaseModel) and user_id: +# instance.created_by = user_id +# instance.updated_by = user_id + +# # 处理修改对象 +# for instance in session.dirty: +# if isinstance(instance, BaseModel) and user_id: +# instance.updated_by = user_id + +# # def __init__(self, **kwargs): +# # """Initialize model with automatic audit fields setting.""" +# # super().__init__(**kwargs) +# # # Set audit fields for new instances +# # self.set_audit_fields() + +# # def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False): +# # """Set audit fields for create/update operations. + +# # Args: +# # user_id: ID of the user performing the operation (optional, will use context if not provided) +# # is_update: True for update operations, False for create operations +# # """ +# # # Get user_id from context if not provided +# # if user_id is None: +# # from ..core.context import UserContext +# # try: +# # user_id = UserContext.get_current_user_id() +# # except Exception: +# # # If no user in context, skip setting audit fields +# # return + +# # # Skip if still no user_id +# # if user_id is None: +# # return + +# # if not is_update: +# # # For create operations, set both create_by and update_by +# # self.created_by = user_id +# # self.updated_by = user_id +# # else: +# # # For update operations, only set update_by +# # self.updated_by = user_id + diff --git a/th_agenter/db/database.py b/th_agenter/db/database.py new file mode 100644 index 0000000..583bcba --- /dev/null +++ b/th_agenter/db/database.py @@ -0,0 +1,141 @@ +"""Database connection and session management.""" + +import uuid, re +from loguru import logger +import traceback +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from typing import Optional + +from utils.general import gradient_text + +from 
# Custom Session class with desc property and unique ID
class DrSession(AsyncSession):
    """AsyncSession that carries a short id, a title and logging helpers.

    The id and title live in ``self.info``. The ``log_*`` helpers prefix
    every message with the session id and append the source position of
    the code that called them.
    """

    # Prefix removed from file names in reported source positions.
    # NOTE(review): this is a developer-machine path; override it on the
    # class (or per instance) for other checkouts.
    source_root = "F:\\DrGraph_Python\\FastAPI\\"

    def __init__(self, **kwargs):
        """Initialize the session and assign it a short unique id."""
        super().__init__(**kwargs)
        self.title = ""
        self.descs = []
        # Make sure the info mapping exists.
        if not hasattr(self, 'info'):
            self.info = {}
        # First group of a UUID4 is enough to tell sessions apart in logs.
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        self.stepIndex = 0

    @property
    def title(self) -> Optional[str]:
        """Get the session title from session info."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set the title; a non-blank existing title is kept as a suffix."""
        if('title' not in self.info or self.info['title'].strip() == ""):
            self.info['title'] = value
        else:
            self.info['title'] = value + " >>> " + self.info['title']

    @property
    def desc(self) -> Optional[str]:
        """Get the most recent step description from session info."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Record a step description, count it and log it at INFO level."""
        self.stepIndex += 1
        # Store the value so the getter reflects what was last set.
        self.info['desc'] = value
        logger.info(value)

    def log_prefix(self) -> str:
        """Get the log prefix containing the session id."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int):
        """Return ``file:line in function`` for one frame of the call stack.

        Args:
            level: Index into the current stack, oldest frame first. ``-1``
                is this method, ``-2`` its caller, ``-3`` the caller's
                caller, and so on.
        """
        # extract_stack gives structured frames, so no regex over formatted
        # text is needed and frames such as <module> are handled uniformly.
        frame = traceback.extract_stack()[level]
        file = frame.filename.replace(self.source_root, "")
        return f"{file}:{frame.lineno} in {frame.name}"

    # The log_* helpers default to level=-3: seen from parse_source_pos,
    # -2 is the log_* helper itself and -3 is the code that called it.

    def log_info(self, msg: str, level: int = -3):
        """Log an info message with session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -3):
        """Log a success message with session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -3):
        """Log a warning message with session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -3):
        """Log an error message with session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -3):
        """Log an exception message with session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")

engine_async = create_async_engine(
    get_settings().database.url,
    echo=False,  # get_settings().database.echo,
    future=True,
    pool_size=get_settings().database.pool_size,
    max_overflow=get_settings().database.max_overflow,
    pool_pre_ping=True,
    pool_recycle=3600,
)
from fastapi import HTTPException, Request

AsyncSessionFactory = sessionmaker(
    bind=engine_async,
    class_=DrSession,
    expire_on_commit=False,
    autoflush=True
)

async def get_session(request: Optional[Request] = None):
    """Yield a ``DrSession`` bound to the async engine, closing it afterwards.

    Args:
        request: The current request, when there is one. Its method, path
            and client address become the session title, and the request
            is attached to the session as ``session.request``.

    Yields:
        The open session.

    Raises:
        Whatever the consumer raises while using the session; the session
        is rolled back and the original exception is re-raised unchanged.
    """
    if request is not None:
        url = f"{request.method} {request.url.path}"
        # request.client is None when the server supplies no peer address.
        client_host = request.client.host if request.client else "unknown"
    else:
        url = "无request"
        client_host = "无request"
    logger.debug(url)

    session = DrSession(bind=engine_async)
    session.title = f"{url} - {client_host}"

    # Keep the request reachable from the session.
    if request is not None:
        session.request = request

    try:
        yield session
    except Exception as e:
        errMsg = f"数据库 session 异常 >>> {e}"
        session.desc = f"EXCEPTION: {errMsg}"
        await session.rollback()
        # Re-raise the original exception untouched; it is not converted to an HTTPException.
        raise
    finally:
        await session.close()
b/th_agenter/db/migrations/add_system_management.py new file mode 100644 index 0000000..1b0dc0f --- /dev/null +++ b/th_agenter/db/migrations/add_system_management.py @@ -0,0 +1,216 @@ +"""Add system management tables. + +Revision ID: add_system_management +Revises: +Create Date: 2024-01-01 00:00:00.000000 + +""" +from alembic_sync import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision = 'add_system_management' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + """Create system management tables.""" + + # Create departments table + op.create_table('departments', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('code', sa.String(length=50), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('parent_id', sa.Integer(), nullable=True), + sa.Column('sort_order', sa.Integer(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['parent_id'], ['departments.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('code') + ) + op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False) + op.create_index(op.f('ix_departments_parent_id'), 'departments', ['parent_id'], unique=False) + + # Create permissions table + op.create_table('permissions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('code', sa.String(length=100), nullable=False), + sa.Column('category', sa.String(length=50), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('sort_order', sa.Integer(), 
nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('code') + ) + op.create_index(op.f('ix_permissions_category'), 'permissions', ['category'], unique=False) + op.create_index(op.f('ix_permissions_name'), 'permissions', ['name'], unique=False) + + # Create roles table + op.create_table('roles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('code', sa.String(length=50), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('sort_order', sa.Integer(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('code') + ) + op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False) + + # Create role_permissions table + op.create_table('role_permissions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('role_id', sa.Integer(), nullable=False), + sa.Column('permission_id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ), + sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('role_id', 'permission_id', name='uq_role_permission') + ) + 
op.create_index(op.f('ix_role_permissions_permission_id'), 'role_permissions', ['permission_id'], unique=False) + op.create_index(op.f('ix_role_permissions_role_id'), 'role_permissions', ['role_id'], unique=False) + + # Create user_roles table + op.create_table('user_roles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('role_id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'role_id', name='uq_user_role') + ) + op.create_index(op.f('ix_user_roles_role_id'), 'user_roles', ['role_id'], unique=False) + op.create_index(op.f('ix_user_roles_user_id'), 'user_roles', ['user_id'], unique=False) + + # Create user_permissions table + op.create_table('user_permissions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('permission_id', sa.Integer(), nullable=False), + sa.Column('granted', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'permission_id', name='uq_user_permission') + ) + op.create_index(op.f('ix_user_permissions_permission_id'), 'user_permissions', ['permission_id'], unique=False) + op.create_index(op.f('ix_user_permissions_user_id'), 'user_permissions', 
['user_id'], unique=False) + + # Create llm_configs table + op.create_table('llm_configs', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('provider', sa.String(length=50), nullable=False), + sa.Column('model_name', sa.String(length=100), nullable=False), + sa.Column('api_key', sa.Text(), nullable=True), + sa.Column('api_base', sa.String(length=500), nullable=True), + sa.Column('api_version', sa.String(length=20), nullable=True), + sa.Column('max_tokens', sa.Integer(), nullable=True), + sa.Column('temperature', sa.Float(), nullable=True), + sa.Column('top_p', sa.Float(), nullable=True), + sa.Column('frequency_penalty', sa.Float(), nullable=True), + sa.Column('presence_penalty', sa.Float(), nullable=True), + sa.Column('timeout', sa.Integer(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('is_default', sa.Boolean(), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('sort_order', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('updated_by', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False) + op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False) + + # Add new columns to users table + op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True)) + op.add_column('users', sa.Column('is_superuser', sa.Boolean(), nullable=True, default=False)) + op.add_column('users', sa.Column('is_admin', sa.Boolean(), nullable=True, default=False)) + op.add_column('users', sa.Column('last_login_at', sa.DateTime(), nullable=True)) + op.add_column('users', sa.Column('login_count', sa.Integer(), nullable=True, default=0)) + + # Create 
foreign key constraint for department_id + op.create_foreign_key('fk_users_department_id', 'users', 'departments', ['department_id'], ['id']) + op.create_index(op.f('ix_users_department_id'), 'users', ['department_id'], unique=False) + + +def downgrade(): + """Drop system management tables.""" + + # Drop foreign key and index for users.department_id + op.drop_index(op.f('ix_users_department_id'), table_name='users') + op.drop_constraint('fk_users_department_id', 'users', type_='foreignkey') + + # Drop new columns from users table + op.drop_column('users', 'login_count') + op.drop_column('users', 'last_login_at') + op.drop_column('users', 'is_admin') + op.drop_column('users', 'is_superuser') + op.drop_column('users', 'department_id') + + # Drop llm_configs table + op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs') + op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs') + op.drop_table('llm_configs') + + # Drop user_permissions table + op.drop_index(op.f('ix_user_permissions_user_id'), table_name='user_permissions') + op.drop_index(op.f('ix_user_permissions_permission_id'), table_name='user_permissions') + op.drop_table('user_permissions') + + # Drop user_roles table + op.drop_index(op.f('ix_user_roles_user_id'), table_name='user_roles') + op.drop_index(op.f('ix_user_roles_role_id'), table_name='user_roles') + op.drop_table('user_roles') + + # Drop role_permissions table + op.drop_index(op.f('ix_role_permissions_role_id'), table_name='role_permissions') + op.drop_index(op.f('ix_role_permissions_permission_id'), table_name='role_permissions') + op.drop_table('role_permissions') + + # Drop roles table + op.drop_index(op.f('ix_roles_name'), table_name='roles') + op.drop_table('roles') + + # Drop permissions table + op.drop_index(op.f('ix_permissions_name'), table_name='permissions') + op.drop_index(op.f('ix_permissions_category'), table_name='permissions') + op.drop_table('permissions') + + # Drop departments table + 
async def create_user_department_table():
    """Create the ``user_departments`` association table and its indexes.

    Connection parameters are taken from ``settings.database.url``. Missing
    parts fall back to database ``postgres``, an empty password, host
    ``localhost`` and port 5432. The statements use ``IF NOT EXISTS``, so
    running the script again is harmless.

    Raises:
        Exception: Any parsing, connection or execution error is printed
            and re-raised.
    """
    # Function-scope import: the module changes sys.path before its own imports.
    from urllib.parse import unquote, urlsplit

    settings = get_settings()
    database_url = settings.database.url

    conn = None
    try:
        # urlsplit handles driver-qualified schemes such as
        # postgresql+asyncpg://, percent-encoded credentials and query
        # strings, none of which survive manual str.split parsing.
        parts = urlsplit(database_url)
        user = unquote(parts.username or '')
        password = unquote(parts.password or '')
        host = parts.hostname or 'localhost'
        port = parts.port or 5432
        db_name = parts.path.lstrip('/') or 'postgres'

        # Never echo the password to the console.
        print(f"Database URL: {parts.scheme}://{user}:***@{host}:{port}/{db_name}")

        # Connect to the PostgreSQL database.
        conn = await asyncpg.connect(
            user=user,
            password=password,
            database=db_name,
            host=host,
            port=port
        )

        # Create the user_departments table.
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS user_departments (
            id SERIAL PRIMARY KEY,
            user_id INTEGER NOT NULL,
            department_id INTEGER NOT NULL,
            is_primary BOOLEAN NOT NULL DEFAULT true,
            is_active BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
            FOREIGN KEY (department_id) REFERENCES departments (id) ON DELETE CASCADE
        );
        """

        await conn.execute(create_table_sql)

        # Create the indexes.
        create_indexes_sql = [
            "CREATE INDEX IF NOT EXISTS idx_user_departments_user_id ON user_departments (user_id);",
            "CREATE INDEX IF NOT EXISTS idx_user_departments_department_id ON user_departments (department_id);",
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_user_departments_unique ON user_departments (user_id, department_id);"
        ]

        for index_sql in create_indexes_sql:
            await conn.execute(index_sql)

        print("User departments table created successfully")

    except Exception as e:
        print(f"Error creating user departments table: {e}")
        raise
    finally:
        # conn stays None when parsing or connecting failed.
        if conn is not None:
            await conn.close()
from init_resource_data.py to database.""" + db = None + try: + # Get database settings + settings = get_settings() + + # Create synchronous engine (remove asyncpg from URL) + sync_db_url = settings.database.url.replace('postgresql+asyncpg://', 'postgresql://') + sync_engine = create_engine( + sync_db_url, + echo=False, + pool_size=settings.database.pool_size, + max_overflow=settings.database.max_overflow, + pool_pre_ping=True, + pool_recycle=3600, + ) + + # Create synchronous session factory + SyncSessionFactory = sessionmaker(bind=sync_engine) + + # Get database session + db = SyncSessionFactory() + + if db is None: + logger.error("Failed to create database session") + return False + + # Create all tables if they don't exist + from th_agenter.db.database import engine as global_engine + if global_engine: + Base.metadata.create_all(bind=global_engine) + + logger.info("Starting hardcoded resources migration...") + + # Check if resources already exist + existing_count = db.query(Resource).count() + if existing_count > 0: + logger.info(f"Found {existing_count} existing resources. 
Checking role assignments.") + # 即使资源已存在,也要检查并分配角色资源关联 + admin_role = db.query(Role).filter(Role.name == "系统管理员").first() + if admin_role: + # 获取所有资源 + all_resources = db.query(Resource).all() + assigned_count = 0 + + for resource in all_resources: + # 检查关联是否已存在 + existing = db.query(RoleResource).filter( + RoleResource.role_id == admin_role.id, + RoleResource.resource_id == resource.id + ).first() + + if not existing: + role_resource = RoleResource( + role_id=admin_role.id, + resource_id=resource.id + ) + db.add(role_resource) + assigned_count += 1 + + if assigned_count > 0: + db.commit() + logger.info(f"已为系统管理员角色分配 {assigned_count} 个新资源") + else: + logger.info("系统管理员角色已拥有所有资源") + else: + logger.warning("未找到系统管理员角色") + + return True + + # Define hardcoded resource data + main_menu_data = [ + { + "name": "智能问答", + "code": "CHAT", + "type": "menu", + "path": "/chat", + "component": "views/Chat.vue", + "icon": "ChatDotRound", + "description": "智能问答功能", + "sort_order": 1, + "requires_auth": True, + "requires_admin": False + }, + { + "name": "智能问数", + "code": "SMART_QUERY", + "type": "menu", + "path": "/smart-query", + "component": "views/SmartQuery.vue", + "icon": "DataAnalysis", + "description": "智能问数功能", + "sort_order": 2, + "requires_auth": True, + "requires_admin": False + }, + { + "name": "知识库", + "code": "KNOWLEDGE", + "type": "menu", + "path": "/knowledge", + "component": "views/KnowledgeBase.vue", + "icon": "Collection", + "description": "知识库管理", + "sort_order": 3, + "requires_auth": True, + "requires_admin": False + }, + { + "name": "工作流编排", + "code": "WORKFLOW", + "type": "menu", + "path": "/workflow", + "component": "views/Workflow.vue", + "icon": "Connection", + "description": "工作流编排功能", + "sort_order": 4, + "requires_auth": True, + "requires_admin": False + }, + { + "name": "智能体管理", + "code": "AGENT", + "type": "menu", + "path": "/agent", + "component": "views/Agent.vue", + "icon": "User", + "description": "智能体管理功能", + "sort_order": 5, + "requires_auth": 
True, + "requires_admin": False + }, + { + "name": "系统管理", + "code": "SYSTEM", + "type": "menu", + "path": "/system", + "component": "views/SystemManagement.vue", + "icon": "Setting", + "description": "系统管理功能", + "sort_order": 6, + "requires_auth": True, + "requires_admin": True + } + ] + + # Create main menu resources + created_resources = {} + for menu_data in main_menu_data: + resource = Resource(**menu_data) + db.add(resource) + db.flush() + created_resources[menu_data["code"]] = resource + logger.info(f"Created main menu resource: {menu_data['name']}") + + # System management submenu data + system_submenu_data = [ + { + "name": "用户管理", + "code": "SYSTEM_USERS", + "type": "menu", + "path": "/system/users", + "component": "components/system/UserManagement.vue", + "icon": "User", + "description": "用户管理功能", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 1, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "部门管理", + "code": "SYSTEM_DEPARTMENTS", + "type": "menu", + "path": "/system/departments", + "component": "components/system/DepartmentManagement.vue", + "icon": "OfficeBuilding", + "description": "部门管理功能", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 2, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "角色管理", + "code": "SYSTEM_ROLES", + "type": "menu", + "path": "/system/roles", + "component": "components/system/RoleManagement.vue", + "icon": "Avatar", + "description": "角色管理功能", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 3, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "权限管理", + "code": "SYSTEM_PERMISSIONS", + "type": "menu", + "path": "/system/permissions", + "component": "components/system/PermissionManagement.vue", + "icon": "Lock", + "description": "权限管理功能", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 4, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "资源管理", + "code": "SYSTEM_RESOURCES", + "type": 
"menu", + "path": "/system/resources", + "component": "components/system/ResourceManagement.vue", + "icon": "Grid", + "description": "资源管理功能", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 5, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "大模型管理", + "code": "SYSTEM_LLM_CONFIGS", + "type": "menu", + "path": "/system/llm-configs", + "component": "components/system/LLMConfigManagement.vue", + "icon": "Cpu", + "description": "大模型配置管理", + "parent_id": created_resources["SYSTEM"].id, + "sort_order": 6, + "requires_auth": True, + "requires_admin": True + } + ] + + # Create system management submenu + for submenu_data in system_submenu_data: + submenu = Resource(**submenu_data) + db.add(submenu) + db.flush() + created_resources[submenu_data["code"]] = submenu + logger.info(f"Created system submenu resource: {submenu_data['name']}") + + # Button resources data + button_resources_data = [ + # User management buttons + { + "name": "新增用户", + "code": "USER_CREATE_BTN", + "type": "button", + "description": "新增用户按钮", + "parent_id": created_resources["SYSTEM_USERS"].id, + "sort_order": 1, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "编辑用户", + "code": "USER_EDIT_BTN", + "type": "button", + "description": "编辑用户按钮", + "parent_id": created_resources["SYSTEM_USERS"].id, + "sort_order": 2, + "requires_auth": True, + "requires_admin": True + }, + # Role management buttons + { + "name": "新增角色", + "code": "ROLE_CREATE_BTN", + "type": "button", + "description": "新增角色按钮", + "parent_id": created_resources["SYSTEM_ROLES"].id, + "sort_order": 1, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "编辑角色", + "code": "ROLE_EDIT_BTN", + "type": "button", + "description": "编辑角色按钮", + "parent_id": created_resources["SYSTEM_ROLES"].id, + "sort_order": 2, + "requires_auth": True, + "requires_admin": True + }, + # Permission management buttons + { + "name": "新增权限", + "code": "PERMISSION_CREATE_BTN", + "type": "button", + 
"description": "新增权限按钮", + "parent_id": created_resources["SYSTEM_PERMISSIONS"].id, + "sort_order": 1, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "编辑权限", + "code": "PERMISSION_EDIT_BTN", + "type": "button", + "description": "编辑权限按钮", + "parent_id": created_resources["SYSTEM_PERMISSIONS"].id, + "sort_order": 2, + "requires_auth": True, + "requires_admin": True + } + ] + + # Create button resources + for button_data in button_resources_data: + button = Resource(**button_data) + db.add(button) + db.flush() + created_resources[button_data["code"]] = button + logger.info(f"Created button resource: {button_data['name']}") + + # API resources data + api_resources_data = [ + # User management APIs + { + "name": "用户列表API", + "code": "USER_LIST_API", + "type": "api", + "path": "/api/users", + "description": "获取用户列表API", + "sort_order": 1, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "创建用户API", + "code": "USER_CREATE_API", + "type": "api", + "path": "/api/users", + "description": "创建用户API", + "sort_order": 2, + "requires_auth": True, + "requires_admin": True + }, + # Role management APIs + { + "name": "角色列表API", + "code": "ROLE_LIST_API", + "type": "api", + "path": "/api/admin/roles", + "description": "获取角色列表API", + "sort_order": 5, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "创建角色API", + "code": "ROLE_CREATE_API", + "type": "api", + "path": "/api/admin/roles", + "description": "创建角色API", + "sort_order": 6, + "requires_auth": True, + "requires_admin": True + }, + # Resource management APIs + { + "name": "资源列表API", + "code": "RESOURCE_LIST_API", + "type": "api", + "path": "/api/admin/resources", + "description": "获取资源列表API", + "sort_order": 10, + "requires_auth": True, + "requires_admin": True + }, + { + "name": "创建资源API", + "code": "RESOURCE_CREATE_API", + "type": "api", + "path": "/api/admin/resources", + "description": "创建资源API", + "sort_order": 11, + "requires_auth": True, + "requires_admin": True + 
} + ] + + # Create API resources + for api_data in api_resources_data: + api_resource = Resource(**api_data) + db.add(api_resource) + db.flush() + created_resources[api_data["code"]] = api_resource + logger.info(f"Created API resource: {api_data['name']}") + + # 分配资源给系统管理员角色 + admin_role = db.query(Role).filter(Role.name == "系统管理员").first() + if admin_role: + all_resources = list(created_resources.values()) + for resource in all_resources: + # 检查关联是否已存在 + existing = db.query(RoleResource).filter( + RoleResource.role_id == admin_role.id, + RoleResource.resource_id == resource.id + ).first() + + if not existing: + role_resource = RoleResource( + role_id=admin_role.id, + resource_id=resource.id + ) + db.add(role_resource) + + logger.info(f"已为系统管理员角色分配 {len(all_resources)} 个资源") + else: + logger.warning("未找到系统管理员角色") + + db.commit() + + total_resources = db.query(Resource).count() + logger.info(f"Migration completed successfully. Total resources: {total_resources}") + + return True + + except Exception as e: + logger.error(f"Migration failed: {str(e)}") + if db: + db.rollback() + return False + finally: + if db: + db.close() + main() \ No newline at end of file diff --git a/th_agenter/db/migrations/remove_permission_tables.py b/th_agenter/db/migrations/remove_permission_tables.py new file mode 100644 index 0000000..fe2299a --- /dev/null +++ b/th_agenter/db/migrations/remove_permission_tables.py @@ -0,0 +1,146 @@ +"""删除权限相关表的迁移脚本 + +Revision ID: remove_permission_tables +Revises: add_system_management +Create Date: 2024-01-25 10:00:00.000000 + +""" +from alembic_sync import op +import sqlalchemy as sa +from sqlalchemy import text + + +# revision identifiers, used by Alembic. 
revision = 'remove_permission_tables'
down_revision = 'add_system_management'
branch_labels = None
depends_on = None


def _table_exists(connection, table_name):
    """Return True when information_schema lists a table called *table_name*."""
    result = connection.execute(
        text("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = :table_name
            );
        """),
        {"table_name": table_name},
    )
    return bool(result.scalar())


def upgrade():
    """Drop the permission-related tables and simplify ``user_roles``.

    This is a plain (synchronous) function: Alembic invokes ``upgrade()``
    directly, so a coroutine function would never have its body executed.
    Each step is best-effort — failures are printed and the next step runs.
    """
    connection = op.get_bind()

    # Drop order follows the dependency chain: association tables first.
    tables_to_drop = [
        'user_permissions',      # user <-> permission association
        'role_permissions',      # role <-> permission association
        'permission_resources',  # permission <-> resource association
        'permissions',           # permissions
        'role_resources',        # role <-> resource association
        'resources',             # resources
        'user_departments',      # user <-> department association
        'departments',           # departments
    ]

    for table_name in tables_to_drop:
        try:
            if _table_exists(connection, table_name):
                print(f"删除表: {table_name}")
                op.drop_table(table_name)
            else:
                print(f"表 {table_name} 不存在,跳过")
        except Exception as e:
            # Keep going so the remaining tables are still attempted.
            # NOTE(review): on PostgreSQL a failed statement aborts the
            # surrounding transaction — confirm how env.py scopes transactions.
            print(f"删除表 {table_name} 时出错: {e}")
            continue

    # Remove the department reference from the users table.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'users' AND column_name = 'department_id';
        """))
        column_exists = result.fetchone()

        if column_exists:
            print("删除用户表中的 department_id 字段")
            op.drop_column('users', 'department_id')
        else:
            print("用户表中的 department_id 字段不存在,跳过")

    except Exception as e:
        print(f"删除 department_id 字段时出错: {e}")

    # Reduce user_roles to a bare (user_id, role_id) association if it has extra columns.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'user_roles' AND column_name IN ('id', 'created_at', 'updated_at', 'created_by', 'updated_by');
        """))
        extra_columns = [row[0] for row in result.fetchall()]

        if extra_columns:
            print("简化 user_roles 表结构")
            # Build the slim table alongside the old one.
            op.execute(text("""
                CREATE TABLE user_roles_new (
                    user_id INTEGER NOT NULL,
                    role_id INTEGER NOT NULL,
                    PRIMARY KEY (user_id, role_id),
                    FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
                    FOREIGN KEY (role_id) REFERENCES roles(id) ON DELETE CASCADE
                );
            """))

            # Copy the distinct pairs across.
            op.execute(text("""
                INSERT INTO user_roles_new (user_id, role_id)
                SELECT DISTINCT user_id, role_id FROM user_roles;
            """))

            # Swap the new table into place.
            op.drop_table('user_roles')
            op.execute(text("ALTER TABLE user_roles_new RENAME TO user_roles;"))

    except Exception as e:
        print(f"简化 user_roles 表时出错: {e}")


def downgrade():
    """Recreate a simplified version of the permission tables.

    Destructive in effect: the tables are recreated empty, and data removed
    by ``upgrade`` is not restored. Only ``permissions``, ``role_permissions``
    and ``users.department_id`` are recreated.
    """
    print("警告:回滚操作会重新创建权限相关表,但不会恢复数据")

    op.create_table(
        'permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('code', sa.String(100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code'),
    )

    op.create_table(
        'role_permissions',
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('role_id', 'permission_id'),
    )

    # Restore the department reference column on users (without its foreign key).
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
@dataclass
class LLMConfig_DataClass:
    """Unified configuration for online, local and embedding models.

    The model kind is derived from ``provider`` and ``is_embedding``:

    - online:    provider in ``ONLINE_PROVIDERS`` and ``is_embedding`` False
    - local:     provider in ``LOCAL_PROVIDERS`` and ``is_embedding`` False
    - embedding: ``is_embedding`` True (any provider)

    Raises:
        ValueError: from ``__post_init__`` when an online model has no
            ``api_key`` or a local model has no ``model_path``.
    """

    # Provider groups shared by validation and get_model_type(). They are
    # un-annotated on purpose so the dataclass does not treat them as fields.
    ONLINE_PROVIDERS = ('openai', 'zhipu', 'baidu', 'anthropic')
    LOCAL_PROVIDERS = ('llama', 'qwen', 'yi', 'glm', 'mistral')
    TIME_FIELDS = ('last_used_at', 'created_at', 'updated_at')

    # ---- Identification ----
    name: str                                   # user-facing model name
    model_name: str                             # official model identifier
    provider: str                               # provider key (openai/llama/bge/zhipu ...)
    id: Optional[int] = None                    # database primary key
    description: Optional[str] = None
    is_active: bool = True
    is_default: bool = False
    is_embedding: bool = False                  # marks the config as an embedding model

    # ---- Generation parameters ----
    temperature: float = 0.7
    max_tokens: int = 3000
    top_p: float = 0.6
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # ---- Online-model parameters ----
    api_key: Optional[str] = None               # required for online providers
    base_url: Optional[str] = None
    max_retries: int = 3
    api_version: Optional[str] = None

    # ---- Local-model parameters ----
    model_path: Optional[str] = None            # required for local providers
    device: str = "cpu"
    n_ctx: int = 2048
    n_threads: int = 8
    quantization: str = "q4_0"
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    prompt_template: Optional[str] = None

    # ---- Embedding-model parameters ----
    normalize_embeddings: bool = True
    batch_size: int = 32
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)
    dimension: Optional[int] = None

    # ---- Metadata ----
    extra_config: Dict[str, Any] = field(default_factory=dict)
    usage_count: int = 0
    last_used_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    api_key_masked: Optional[str] = ""          # masked form of the API key

    def __post_init__(self):
        """Normalise embedding configs, then validate required fields."""
        # Generation parameters are zeroed for embedding models.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0

        self._validate_required_fields()

    def _validate_required_fields(self):
        """Raise ValueError when a field required by the model kind is missing."""
        if self.is_embedding:
            return
        if self.provider in self.ONLINE_PROVIDERS and not self.api_key:
            raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
        if self.provider in self.LOCAL_PROVIDERS and not self.model_path:
            raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Return the instance attributes as a dict, skipping ``_``-prefixed names."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config from a database row dict.

        Keys that are not dataclass fields are ignored. String timestamps are
        parsed as ISO-8601 (a trailing ``Z`` is accepted); strings that cannot
        be parsed, including the empty string, become ``None``. The caller's
        dict is never modified.
        """
        valid_fields = cls.__dataclass_fields__.keys()
        # Work on a filtered copy so the caller's dict is left untouched.
        values = {k: v for k, v in db_dict.items() if k in valid_fields}

        for field_name in cls.TIME_FIELDS:
            val = values.get(field_name)
            if isinstance(val, str):
                try:
                    values[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    values[field_name] = None

        return cls(**values)

    def get_model_type(self) -> str:
        """Return ``"embedding"``, ``"online"``, ``"local"`` or ``"unknown"``."""
        if self.is_embedding:
            return "embedding"
        if self.provider in self.ONLINE_PROVIDERS:
            return "online"
        if self.provider in self.LOCAL_PROVIDERS:
            return "local"
        return "unknown"
class EmbedLLM(BaseLLM, Embeddings):
    """Base class for embedding models.

    Combines the project's ``BaseLLM`` (configuration handling, ``load_model``
    / ``close`` lifecycle) with LangChain's ``Embeddings`` interface. Concrete
    subclasses must implement ``embed_documents`` and ``embed_query``.

    The asynchronous variants are intentionally not overridden here, so the
    implementations inherited from ``Embeddings`` are used unless a subclass
    supplies native async versions.
    """

    def __init__(self, config):
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement embed_documents"
        )

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement embed_query"
        )
class LLM_Model_Base(object):
    '''
    Base class for the language-model wrappers.

    Holds the attributes every wrapper shares:
    - model_name: model identifier, default "gpt-4o-mini"
    - temperature: sampling temperature, default 0.7
    - llmModel: the underlying model instance, set by subclasses
    - mode: one of the Constant.LLM_MODE_* values, set by subclasses
    - name: display name used in log messages

    author: DrGraph
    date: 2025-11-20
    '''

    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        self.model_name = model_name
        self.temperature = temperature
        self.llmModel = None
        self.mode = Constant.LLM_MODE_NONE
        self.name = '未知模型'

    def buildPromptTemplateValue(self, prompt: str, methodType: str, valueType: str):
        '''
        Wrap the user's text in the question template and return it in the
        requested representation.

        prompt: the user's question text
        methodType: "format" (render the template to a string first) or
            "invoke" (render it to a PromptValue first)
        valueType: "str" or "messages" for either method; additionally
            "promptValue" for the "invoke" method
        return: a str, a PromptValue or a list of messages; None when the
            methodType/valueType combination is not supported
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        prompt_template = PromptTemplate.from_template(
            template="请回答以下问题: {question}",
        )
        prompt_template_value = None

        if methodType == "format":
            # Path 1: render the template to a plain string.
            prompt_str = prompt_template.format(question=prompt)
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 format 方法,取得字符串prompt_str, 然后再处理 - {type(prompt_str)} - {prompt_str}")

            if valueType == "str":
                prompt_template_value = prompt_str
                logger.info(f"{self.name} >>> 1.2.1 直接使用字符串")
            elif valueType == "messages":
                # Wrap the *rendered* string so this path yields the same
                # templated content as the invoke/messages path.
                prompt_template_value = [HumanMessage(content=prompt_str)]
                logger.info(f"{self.name} >>> 1.2.2 创建HumanMessage对象列表")

        elif methodType == "invoke":
            # Path 2: render the template to a PromptValue.
            prompt_value = prompt_template.invoke(input={"question": prompt})
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 invoke 方法,取得PromptValue, 然后再处理 - {type(prompt_value)} - {prompt_value}")

            if valueType == "str":
                prompt_template_value = prompt_value.to_string()
                logger.info(f"{self.name} >>> 1.2.1 由 PromptValue 转换为字符串")
            elif valueType == "promptValue":
                prompt_template_value = prompt_value
                logger.info(f"{self.name} >>> 1.2.2 直接使用 PromptValue 作为 prompt_template_value")
            elif valueType == "messages":
                prompt_template_value = prompt_value.to_messages()
                logger.info(f"{self.name} >>> 1.2.3 使用 PromptValue.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表")

        if prompt_template_value is None:
            # Make an unsupported combination visible instead of silently passing None on.
            logger.warning(f"{self.name} >>> unsupported methodType/valueType: {methodType!r}/{valueType!r}")

        logger.info(f"{self.name} >>> 1.3 用户输入 最终包装为(PromptValue/str/list of BaseMessages): {type(prompt_template_value)}\n{prompt_template_value}")
        return prompt_template_value
class NonChat_LLM(LLM_Model_Base):
    '''
    Completion-style (non-chat) model wrapper built on langchain_openai.OpenAI.

    - model_name: model identifier, default "gpt-4o-mini"
    - temperature: sampling temperature, default 0.7
    - name: display name "非聊天模型", used in log messages
    '''

    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '非聊天模型'
        self.mode = Constant.LLM_MODE_NONCHAT
        self.llmModel = OpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    def invoke(self, prompt: str):
        '''
        Send the prompt to the completion model and return its reply.

        prompt: the user's input text
        return: whatever the underlying model's invoke() returns
        raises: the exception raised by the underlying call; it is logged and
            re-raised so callers see the real failure rather than an
            UnboundLocalError on the return statement
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        try:
            response = self.llmModel.invoke(prompt)
        except Exception:
            logger.exception(f"{self.name} >>> invoke failed")
            raise
        logger.info(f"{self.name} >>> 1.2 助手回复: {type(response)}")
        return response
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration

from th_agenter.llm.base_llm import BaseLLM


class LocalLLM(BaseLLM):
    """Chat-model adapter around a local LlamaCpp model.

    The configuration lives in ``self.config`` (assigned by ``BaseLLM``
    before validation runs). ``local_config`` is kept as a read-only alias
    for code that used the old attribute name.
    """

    def __init__(self, config):
        super().__init__(config)

    @property
    def local_config(self):
        """Alias of ``self.config``."""
        return self.config

    def _validate_config(self):
        """Raise ValueError when no ``model_path`` is configured."""
        if not self.config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the LlamaCpp model from the configuration."""
        # Imported lazily so the dependency is only needed when a local model is used.
        from langchain_community.llms import LlamaCpp
        self.model = LlamaCpp(
            model_path=self.config.model_path,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            n_ctx=self.config.n_ctx,
            n_threads=self.config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a reply synchronously, loading the model on first use."""
        if not self.model:
            self.load_model()
        # LlamaCpp is a completion model, so flatten the chat into one prompt.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous counterpart of ``_generate``."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Flatten a message list into a single ``[INST] ... [/INST]`` prompt.

        Human messages are wrapped in ``[INST] ... [/INST]``; AI messages are
        appended verbatim. Any other message type is skipped.
        """
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)
def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
    """Run one synchronous completion and return only the reply text.

    A plain string prompt is wrapped in a single ``HumanMessage``; an
    already-built message list is forwarded unchanged. Extra keyword
    arguments are passed straight through to ``_generate``.
    """
    messages = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
    # Only the first generation's text is returned to the caller.
    return self._generate(messages, **kwargs).generations[0].text
class AgentConfig(BaseModel):
    """Persisted configuration of an agent: tool set, prompt and model settings."""

    __tablename__ = "agent_configs"

    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Agent configuration
    enabled_tools: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
    max_iterations: Mapped[int] = mapped_column(default=10)
    # Stored as text, not as a numeric column.
    temperature: Mapped[str] = mapped_column(String(10), default="0.1")
    system_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    verbose: Mapped[bool] = mapped_column(Boolean, default=True)

    # Model configuration
    model_name: Mapped[str] = mapped_column(String(100), default="gpt-3.5-turbo")
    max_tokens: Mapped[int] = mapped_column(default=2048)

    # Status
    is_active: Mapped[bool] = mapped_column(default=True)
    is_default: Mapped[bool] = mapped_column(default=False)

    def __repr__(self):
        """Return a debug representation identifying the row."""
        return (
            f"<AgentConfig(id={self.id}, name={self.name!r}, "
            f"is_active={self.is_active})>"
        )

    def __str__(self):
        return f"{self.name}[{self.id}] Active: {self.is_active}"

    def to_dict(self):
        """Convert to dictionary, guaranteeing ``enabled_tools`` is a list."""
        data = super().to_dict()
        data['enabled_tools'] = self.enabled_tools or []
        return data
def to_dict(self, include_password=False, decrypt_service=None):
    """Serialise this database configuration into a plain dictionary.

    Timestamps are rendered as ISO-8601 strings (``None`` when unset). The
    password is included, decrypted, only when ``include_password`` is true
    AND a ``decrypt_service`` is supplied; otherwise the key is absent.
    """
    plain_fields = (
        "id",
        "created_by",
        "name",
        "db_type",
        "host",
        "port",
        "database",
        "username",
        "is_active",
        "is_default",
        "connection_params",
    )
    payload = {field: getattr(self, field) for field in plain_fields}

    for stamp in ("created_at", "updated_at"):
        moment = getattr(self, stamp)
        payload[stamp] = moment.isoformat() if moment else None

    # Without both the flag and a decryption service the password stays out.
    if not (include_password and decrypt_service):
        return payload

    logger.info(f"begin decrypt password for db config {self.id}")
    payload["password"] = decrypt_service._decrypt_password(self.password)
    return payload
def get_sheet_info(self, sheet_name: Optional[str] = None):
    """Return the stored details of one sheet, or ``None`` if it is unknown.

    When ``sheet_name`` is omitted, the default sheet is used, falling back
    to the first entry of ``sheet_names``. The result holds the sheet's
    columns, preview rows, data types and row/column totals, each replaced
    by an empty value when the corresponding field is missing.
    """
    # sheet_names may be None (sheet_count already allows for that); treat
    # it as "no sheets" so an explicit sheet_name cannot hit `x in None`.
    known_sheets = self.sheet_names or []

    if not sheet_name:
        sheet_name = self.default_sheet or (known_sheets[0] if known_sheets else None)

    if not sheet_name or sheet_name not in known_sheets:
        return None

    return {
        'sheet_name': sheet_name,
        'columns': self.columns_info.get(sheet_name, []) if self.columns_info else [],
        'preview_data': self.preview_data.get(sheet_name, []) if self.preview_data else [],
        'data_types': self.data_types.get(sheet_name, {}) if self.data_types else {},
        'total_rows': self.total_rows.get(sheet_name, 0) if self.total_rows else 0,
        'total_columns': self.total_columns.get(sheet_name, 0) if self.total_columns else 0
    }
class Document(BaseModel):
    """A file uploaded into a knowledge base, with its processing state."""

    __tablename__ = "documents"

    # Plain integer reference; the foreign key to knowledge_bases.id was removed.
    knowledge_base_id: Mapped[int] = mapped_column(Integer, nullable=False)
    filename: Mapped[str] = mapped_column(String(255), nullable=False)
    original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)
    file_size: Mapped[int] = mapped_column(Integer, nullable=False)  # in bytes
    file_type: Mapped[str] = mapped_column(String(50), nullable=False)  # .pdf, .txt, .docx, etc.
    mime_type: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)

    # Processing status
    is_processed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    processing_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Content and metadata
    content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # Extracted text content
    doc_metadata: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # Additional metadata

    # Chunking information
    chunk_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)

    # Embedding information
    embedding_model: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
    vector_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # Vector database IDs for chunks

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        """Return a debug representation identifying the row."""
        return (
            f"<Document(id={self.id}, filename={self.filename!r}, "
            f"is_processed={self.is_processed})>"
        )

    @property
    def file_size_mb(self):
        """Get file size in MB, rounded to two decimals."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def is_text_file(self):
        """Check if document is a text file."""
        return self.file_type.lower() in ['.txt', '.md', '.csv']

    @property
    def is_pdf_file(self):
        """Check if document is a PDF file."""
        return self.file_type.lower() == '.pdf'

    @property
    def is_office_file(self):
        """Check if document is an Office file."""
        return self.file_type.lower() in ['.docx', '.xlsx', '.pptx']
def to_dict(self, include_sensitive=False):
    """Serialise the configuration, masking the API key unless asked not to.

    With ``include_sensitive`` the raw key is returned under ``api_key``.
    Otherwise ``api_key_masked`` is set instead: the first and last four
    characters for keys longer than eight characters, ``"***"`` for shorter
    ones, and ``None`` when no key is stored.
    """
    payload = super().to_dict()
    payload.update({
        'name': self.name,
        'provider': self.provider,
        'model_name': self.model_name,
        'base_url': self.base_url,
        'max_tokens': self.max_tokens,
        'temperature': self.temperature,
        'top_p': self.top_p,
        'frequency_penalty': self.frequency_penalty,
        'presence_penalty': self.presence_penalty,
        'description': self.description,
        'is_active': self.is_active,
        'is_default': self.is_default,
        'is_embedding': self.is_embedding,
        'extra_config': self.extra_config,
        'usage_count': self.usage_count,
        'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None
    })

    if include_sensitive:
        payload['api_key'] = self.api_key
        return payload

    # Reveal only the head and tail of the key; very short keys are hidden fully.
    secret = self.api_key
    if not secret:
        masked = None
    elif len(secret) > 8:
        masked = f"{secret[:4]}...{secret[-4:]}"
    else:
        masked = "***"
    payload['api_key_masked'] = masked
    return payload
def validate_config(self) -> Dict[str, Any]:
    """Check that the configuration is usable.

    Returns ``{"valid": True, "error": None}`` on success, otherwise
    ``{"valid": False, "error": <message>}`` for the first failed rule.

    ``max_tokens`` and ``temperature`` may still be ``None`` on an object
    that has not been inserted yet, because their column defaults are only
    applied at INSERT time. ``None`` is therefore accepted rather than
    compared against the range limits, which would raise ``TypeError``.
    """
    def failure(message: str) -> Dict[str, Any]:
        return {"valid": False, "error": message}

    if not self.name or not self.name.strip():
        return failure("配置名称不能为空")

    supported = ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama']
    if not self.provider or self.provider not in supported:
        return failure(f"不支持的服务商 {self.provider}")

    if not self.model_name or not self.model_name.strip():
        return failure("模型名称不能为空")

    if not self.api_key or not self.api_key.strip():
        return failure("API密钥不能为空")

    if self.max_tokens is not None and (self.max_tokens <= 0 or self.max_tokens > 32000):
        return failure("最大令牌数必须在1-32000之间")

    if self.temperature is not None and (self.temperature < 0 or self.temperature > 2):
        return failure("温度参数必须在0-2之间")

    return {"valid": True, "error": None}
class Message(BaseModel):
    """A single message inside a conversation."""

    __tablename__ = "messages"

    # Plain integer reference; the foreign key to conversations.id was removed.
    conversation_id: Mapped[int] = mapped_column(Integer, nullable=False)
    role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
    # Additional data such as file info or tokens used.
    message_metadata: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # For knowledge base context: retrieved document references.
    context_documents: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # Token usage tracking
    prompt_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    completion_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    total_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        """Return a debug representation with at most 50 characters of content."""
        content_preview = self.content[:50] + "..." if len(self.content) > 50 else self.content
        # Show the enum's plain value when the role is set.
        role = getattr(self.role, "value", self.role)
        return f"<Message(id={self.id}, role={role!r}, content={content_preview!r})>"

    def to_dict(self, include_metadata=True):
        """Convert to dictionary, optionally dropping metadata and token counts."""
        data = super().to_dict()
        if not include_metadata:
            data.pop('message_metadata', None)
            data.pop('context_documents', None)
            data.pop('prompt_tokens', None)
            data.pop('completion_tokens', None)
            data.pop('total_tokens', None)
        return data

    @property
    def is_from_user(self):
        """Check if message is from user."""
        return self.role == MessageRole.USER

    @property
    def is_from_assistant(self):
        """Check if message is from assistant."""
        return self.role == MessageRole.ASSISTANT
class UserRole(Base):
    """Association row linking one user to one role (composite primary key)."""

    __tablename__ = "user_roles"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'), primary_key=True)
    role_id: Mapped[int] = mapped_column(Integer, ForeignKey('roles.id'), primary_key=True)

    # Read-only relationships, for code that works with the association
    # table directly.
    user = relationship("User", viewonly=True)
    role = relationship("Role", viewonly=True)

    def __repr__(self):
        """Return a debug representation showing both sides of the link."""
        return f"<UserRole(user_id={self.user_id}, role_id={self.role_id})>"
TABLE') + table_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 表描述 + database_config_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) #数据库配置ID + # 表结构信息 + columns_info: Mapped[dict] = mapped_column(JSON, nullable=False) # 列信息:名称、类型、注释等 + primary_keys: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # 主键列表 + foreign_keys: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # 外键信息 + indexes: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # 索引信息 + + # 示例数据 + sample_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # 前5条示例数据 + row_count: Mapped[int] = mapped_column(Integer, default=0) # 总行数 + + # 问答相关 + is_enabled_for_qa: Mapped[bool] = mapped_column(Boolean, default=True) # 是否启用问答 + qa_description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 问答描述 + business_context: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 业务上下文 + + last_synced_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) # 最后同步时间 + + # 关系 + # database_config = relationship("DatabaseConfig", back_populates="table_metadata") + + def to_dict(self): + return { + "id": self.id, + "created_by": self.created_by, # 改为created_by + "database_config_id": self.database_config_id, + "table_name": self.table_name, + "table_schema": self.table_schema, + "table_type": self.table_type, + "table_comment": self.table_comment, + "columns_info": self.columns_info, + "primary_keys": self.primary_keys, + # "foreign_keys": self.foreign_keys, + "indexes": self.indexes, + "sample_data": self.sample_data, + "row_count": self.row_count, + "is_enabled_for_qa": self.is_enabled_for_qa, + "qa_description": self.qa_description, + "business_context": self.business_context, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + "last_synced_at": self.last_synced_at.isoformat() if 
async def has_role(self, role_code: str) -> bool:
    """Return True when the user holds an active role with the given code.

    The ``roles`` relationship is consulted first (refreshed through the
    owning ``AsyncSession`` when there is one). If that fails, for example
    because the instance is detached, the association table is queried
    directly. Without any session the answer is ``False``.
    """
    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import async_object_session
    from sqlalchemy.orm import object_session
    from .permission import Role, UserRole

    # object_session() always yields the synchronous Session, so an
    # isinstance check against AsyncSession can never succeed; the owning
    # AsyncSession has to be looked up explicitly.
    async_session = async_object_session(self)

    try:
        if async_session is not None:
            # Load the relationship explicitly; implicit lazy loading is
            # not available under asyncio.
            await async_session.refresh(self, ['roles'])
        return any(role.code == role_code and role.is_active for role in self.roles)
    except Exception:
        # Deliberate best effort: fall back to a direct query below.
        pass

    role_lookup = (
        select(UserRole)
        .join(Role)
        .where(
            UserRole.user_id == self.id,
            Role.code == role_code,
            Role.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
        )
    )

    if async_session is not None:
        result = await async_session.execute(role_lookup)
        return result.scalars().first() is not None

    sync_session = object_session(self)
    if sync_session is None:
        # Detached instance: there is nothing to query against.
        return False
    return sync_session.execute(role_lookup).scalars().first() is not None
class WorkflowStatus(enum.Enum):
    """Lifecycle status of a workflow."""
    DRAFT = "DRAFT"  # draft
    PUBLISHED = "PUBLISHED"  # published
    ARCHIVED = "ARCHIVED"  # archived


class NodeType(enum.Enum):
    """Kinds of workflow node."""
    START = "start"  # start node
    END = "end"  # end node
    LLM = "llm"  # large-language-model node
    CONDITION = "condition"  # conditional branch node
    LOOP = "loop"  # loop node
    CODE = "code"  # code execution node
    HTTP = "http"  # HTTP request node
    TOOL = "tool"  # tool node


class ExecutionStatus(enum.Enum):
    """Status of a workflow or node execution."""
    PENDING = "pending"  # waiting to run
    RUNNING = "running"  # running
    COMPLETED = "completed"  # finished
    FAILED = "failed"  # failed
    CANCELLED = "cancelled"  # cancelled


class Workflow(BaseModel):
    """Workflow definition owned by a user."""
    __tablename__ = "workflows"

    name: Mapped[str] = mapped_column(String(100), nullable=False, comment="工作流名称")
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="工作流描述")
    status: Mapped[WorkflowStatus] = mapped_column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="工作流状态")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, comment="是否激活")

    # Workflow definition: nodes and connections stored as JSON.
    definition: Mapped[dict] = mapped_column(JSON, nullable=False, comment="工作流定义")

    # Version information.
    version: Mapped[str] = mapped_column(String(20), default="1.0.0", nullable=False, comment="版本号")

    # Owning user.
    owner_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="所有者ID")

    # Relationships: executions are deleted together with the workflow.
    executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<Workflow(id={self.id}, name={self.name!r}, status={self.status})>"

    def to_dict(self, include_definition=True):
        """Convert to a dictionary.

        Args:
            include_definition: When true, include the JSON ``definition``.
        """
        data = super().to_dict()
        data.update({
            'name': self.name,
            'description': self.description,
            'status': self.status.value,
            'is_active': self.is_active,
            'version': self.version,
            'owner_id': self.owner_id
        })

        if include_definition:
            data['definition'] = self.definition

        return data


class WorkflowExecution(BaseModel):
    """Record of one run of a workflow."""

    __tablename__ = "workflow_executions"

    workflow_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflows.id"), nullable=False, comment="工作流ID")
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")

    # Execution input and output.
    input_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Execution details; the timestamps are stored as strings.
    started_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="完成时间")
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="错误信息")

    # User who started the run.
    executor_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="执行者ID")

    # Relationships.
    workflow = relationship("Workflow", back_populates="executions")
    node_executions = relationship("NodeExecution", back_populates="workflow_execution", cascade="all, delete-orphan")

    def __repr__(self):
        return (
            f"<WorkflowExecution(id={self.id}, workflow_id={self.workflow_id}, "
            f"status={self.status})>"
        )

    def to_dict(self, include_nodes=False):
        """Convert to a dictionary.

        Args:
            include_nodes: When true, include the per-node execution records.
        """
        data = super().to_dict()
        data.update({
            'workflow_id': self.workflow_id,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'error_message': self.error_message,
            'executor_id': self.executor_id
        })

        if include_nodes:
            data['node_executions'] = [node.to_dict() for node in self.node_executions]

        return data


class NodeExecution(BaseModel):
    """Record of one node's run inside a workflow execution."""

    __tablename__ = "node_executions"

    workflow_execution_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="工作流执行ID")
    node_id: Mapped[str] = mapped_column(String(50), nullable=False, comment="节点ID")
    node_type: Mapped[NodeType] = mapped_column(Enum(NodeType), nullable=False, comment="节点类型")
    node_name: Mapped[str] = mapped_column(String(100), nullable=False, comment="节点名称")

    # Execution status and result.
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
    input_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Timing; the timestamps are stored as strings, the duration in milliseconds.
    started_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="完成时间")
    duration_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, comment="执行时长(毫秒)")

    # Error details.
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="错误信息")

    # Relationships.
    workflow_execution = relationship("WorkflowExecution", back_populates="node_executions")

    def __repr__(self):
        return (
            f"<NodeExecution(id={self.id}, node_id={self.node_id!r}, "
            f"status={self.status})>"
        )

    def to_dict(self):
        """Convert to a dictionary."""
        data = super().to_dict()
        data.update({
            'workflow_execution_id': self.workflow_execution_id,
            'node_id': self.node_id,
            'node_type': self.node_type.value,
            'node_name': self.node_name,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'duration_ms': self.duration_ms,
            'error_message': self.error_message
        })

        return data
class LLMConfigCreate(LLMConfigBase):
    """Schema for creating an LLM configuration."""

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Normalise the provider to lower case and check it is supported.

        Raises:
            ValueError: If the provider is not in the supported list.
        """
        # Each provider appears once so the error message lists it once.
        allowed_providers = [
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',
        ]
        # Strip surrounding whitespace, as validate_api_key does for the key.
        provider = v.strip().lower()
        if provider not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return provider

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Strip the API key and require at least 10 characters.

        Raises:
            ValueError: If the stripped key is shorter than 10 characters.
        """
        key = v.strip()
        if len(key) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return key
class LLMConfigResponse(BaseModel):
    """Response schema for an LLM configuration."""
    id: int
    name: str
    provider: str
    model_name: str
    api_key: Optional[str] = None  # full API key (only returned when include_sensitive=True)
    base_url: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    description: Optional[str] = None

    is_active: bool
    is_default: bool
    is_embedding: bool
    extra_config: Optional[Dict[str, Any]] = None
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    model_config = {
        'from_attributes': True
    }

    @computed_field
    @property
    def api_key_masked(self) -> Optional[str]:
        """Return the API key with all but its first and last four characters starred.

        Keys of eight characters or fewer are masked entirely. A missing or
        empty key yields ``None``.
        """
        secret = self.api_key
        if not secret:
            return None
        if len(secret) <= 8:
            return '*' * len(secret)
        hidden = '*' * (len(secret) - 8)
        return f"{secret[:4]}{hidden}{secret[-4:]}"
class RoleCreate(RoleBase):
    """Schema for creating a role."""

    @field_validator('code')
    @classmethod
    def validate_code(cls, v: str) -> str:
        """Accept codes made of alphanumerics, ``_`` and ``-``; return upper case.

        Raises:
            ValueError: If anything other than alphanumerics remains once the
                underscores and hyphens are removed (including nothing at all).
        """
        without_separators = v.replace('_', '').replace('-', '')
        if without_separators.isalnum():
            return v.upper()
        raise ValueError('角色代码只能包含字母、数字、下划线和连字符')
class UserResponse(BaseResponse, UserBase):
    """User response schema."""
    is_active: bool
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    model_config = {
        'from_attributes': True
    }

    @classmethod
    def model_validate(cls, obj, *, from_attributes=None, **kwargs):
        """Build the response model, resolving ``is_superuser`` synchronously.

        On the user model ``is_superuser`` is a coroutine method, so attribute
        based validation cannot read it. For objects that expose ``__dict__``
        the instance state is copied and ``is_superuser`` is filled in from the
        synchronous ``is_admin`` attribute when present, or from a
        non-callable ``is_superuser`` attribute.

        Args:
            obj: Source object (ORM instance, other object or mapping).
            from_attributes: Forwarded to Pydantic on the fallback path.
                ``None`` defers to ``model_config``; an explicit ``False``
                default would override the ``from_attributes=True`` config.
            **kwargs: Any other keyword accepted by Pydantic's
                ``model_validate`` (for example ``strict`` or ``context``),
                passed through unchanged so the override does not narrow the
                inherited signature.
        """
        if hasattr(obj, '__dict__'):
            data = obj.__dict__.copy()
            if hasattr(obj, 'is_admin'):
                # Synchronous property standing in for the async is_superuser().
                data['is_superuser'] = obj.is_admin
            elif hasattr(obj, 'is_superuser') and not callable(obj.is_superuser):
                # is_superuser is a plain attribute rather than a method.
                data['is_superuser'] = obj.is_superuser
            return super().model_validate(data, **kwargs)
        return super().model_validate(obj, from_attributes=from_attributes, **kwargs)
class LoopNodeConfig(NodeConfig):
    """Configuration for a loop node.

    ``loop_type`` must be exactly ``count``, ``while`` or ``foreach``. The
    other three fields are optional, and each is described as belonging to one
    loop type. No validator in this class checks that the field matching
    ``loop_type`` is actually supplied.
    """
    # Loop kind, restricted by the regular expression below.
    loop_type: str = Field(..., pattern="^(count|while|foreach)$")
    # Number of iterations (for loop_type "count").
    count: Optional[int] = Field(None, description="循环次数(当loop_type为count时)")
    # Loop condition (for loop_type "while").
    condition: Optional[str] = Field(None, description="循环条件(当loop_type为while时)")
    # Iterable to walk (for loop_type "foreach").
    iterable: Optional[str] = Field(None, description="可迭代对象(当loop_type为foreach时)")
class WorkflowConnection(BaseModel):
    """Directed connection between two workflow nodes.

    The endpoints are exposed as ``from_node`` / ``to_node`` and aliased to
    ``from`` / ``to``. Input may use either spelling.
    """
    id: str
    from_node: str = Field(..., alias="from")
    to_node: str = Field(..., alias="to")
    from_point: str = Field(default="output")
    to_point: str = Field(default="input")

    @model_validator(mode='before')
    @classmethod
    def handle_node_fields(cls, values):
        """Mirror ``from``/``from_node`` and ``to``/``to_node`` in the input.

        Whichever spelling of each endpoint is present is copied to the other
        one. The work is done on a shallow copy, so the caller's dictionary
        (for example a stored workflow definition) is never modified.
        Non-dict input is returned untouched for normal validation to handle.
        """
        if not isinstance(values, dict):
            return values

        data = dict(values)
        for alias, field_name in (("from", "from_node"), ("to", "to_node")):
            if alias in data and field_name not in data:
                data[field_name] = data[alias]
            elif field_name in data and alias not in data:
                data[alias] = data[field_name]

        return data
class WorkflowExecutionResponse(BaseModel):
    """Response schema for one workflow execution record.

    ``from_attributes`` is enabled, so instances can be built from objects
    that expose these names as attributes.
    """
    id: int
    workflow_id: int
    status: ExecutionStatus
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Dict[str, Any]]
    # Start and completion times are carried as strings, not datetimes.
    started_at: Optional[str]
    completed_at: Optional[str]
    error_message: Optional[str]
    executor_id: int
    # Per-node records; defaults to None when not supplied.
    node_executions: Optional[List[NodeExecutionResponse]] = None
    created_at: datetime
    updated_at: datetime

    model_config = {
        'from_attributes': True
    }
+ +轻量化导入:仅暴露基础工具类型,避免在包导入时加载耗时的服务层。使用 AgentService 时请从子模块显式导入: + from open_agent.services.agent.agent_service import AgentService +""" + +from .base import BaseTool, ToolRegistry + +__all__ = [ + "BaseTool", + "ToolRegistry" +] \ No newline at end of file diff --git a/th_agenter/services/agent/agent_service.py b/th_agenter/services/agent/agent_service.py new file mode 100644 index 0000000..753f6b4 --- /dev/null +++ b/th_agenter/services/agent/agent_service.py @@ -0,0 +1,248 @@ +"""LangChain Agent service with tool calling capabilities.""" + +import asyncio +from typing import List, Dict, Any, Optional, AsyncGenerator +from langgraph.prebuilt import create_react_agent +from langchain_core.messages import HumanMessage, AIMessage +from pydantic import BaseModel, Field + +from .base import BaseTool, ToolRegistry, ToolResult +from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool +from ..postgresql_tool_manager import get_postgresql_tool +from ..mysql_tool_manager import get_mysql_tool +from ...core.config import get_settings +from ..agent_config import AgentConfigService +from loguru import logger + +class AgentConfig(BaseModel): + """Agent configuration.""" + enabled_tools: List[str] = Field(default_factory=lambda: [ + "calculator", "weather", "search", "datetime", "file", "generate_image", "postgresql_mcp", "mysql_mcp" + ]) + max_iterations: int = Field(default=10) + temperature: float = Field(default=0.1) + system_message: str = Field( + default="You are a helpful AI assistant with access to various tools. " + "Use the available tools to help answer user questions accurately. " + "Always explain your reasoning and the tools you're using." 
class AgentService:
    """LangChain agent service with tool-calling capabilities.

    ``initialize()`` must be awaited after construction: it creates the tool
    registry and the configuration that the other methods rely on.
    """

    def __init__(self):
        self.settings = get_settings()

    async def initialize(self, session=None):
        """Create the registry and default config, then load stored settings.

        Args:
            session: Optional database session. Without one no configuration
                service is created and the defaults stay in effect.
        """
        self.tool_registry = ToolRegistry()
        self.config = AgentConfig()
        self.session = session
        self.config_service = AgentConfigService(session) if session else None
        self._initialize_tools()
        await self._load_config()

    def _initialize_tools(self):
        """Initialize and register all available tools."""
        tools = [
            WeatherQueryTool(),
            TavilySearchTool(),
            DateTimeTool(),
            get_postgresql_tool(),  # singleton PostgreSQL MCP tool
            get_mysql_tool()  # singleton MySQL MCP tool
        ]

        for tool in tools:
            self.tool_registry.register(tool)
            logger.info(f"Registered tool: {tool.get_name()}")

    async def _load_config(self):
        """Load configuration from the database if a config service exists."""
        if self.config_service:
            try:
                config_dict = await self.config_service.get_config_dict()
                # Only keys that already exist on the config are overwritten.
                for key, value in config_dict.items():
                    if hasattr(self.config, key):
                        setattr(self.config, key, value)
            except Exception as e:
                logger.error(f"Failed to load config from database, using defaults: {str(e)}")

    def _get_enabled_tools(self) -> List[Any]:
        """Return the registered tools named in ``config.enabled_tools``."""
        enabled_tools = []

        for tool_name in self.config.enabled_tools:
            tool = self.tool_registry.get_tool(tool_name)
            if tool:
                enabled_tools.append(tool)
                logger.debug(f"Enabled tool: {tool_name}")
            else:
                logger.warning(f"Tool not found: {tool_name}")

        return enabled_tools

    async def _create_agent_executor(self) -> Any:
        """Create the ReAct agent from the LLM, enabled tools and system prompt."""
        from ...core.new_agent import new_llm
        llm = await new_llm()
        tools = self._get_enabled_tools()
        return create_react_agent(llm, tools, prompt=self.config.system_message)

    @staticmethod
    def _build_messages(message: str, chat_history: Optional[List[Dict[str, str]]]) -> List[Any]:
        """Convert history dicts plus the new message into LangChain messages."""
        messages: List[Any] = []
        for msg in chat_history or []:
            if msg["role"] == "user":
                messages.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                messages.append(AIMessage(content=msg["content"]))
        messages.append(HumanMessage(content=message))
        return messages

    @staticmethod
    def _extract_response(result: Any) -> Any:
        """Return the content of the last message in an agent result."""
        msgs = result.get("messages") or []
        last = msgs[-1] if msgs else None
        return (getattr(last, "content", None) or str(last)) if last else str(result)

    async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
        """Process a chat message with the agent.

        Returns:
            A dict with ``response``, ``tool_calls`` and ``success``. On
            failure ``success`` is false and ``error`` holds the message.
        """
        try:
            logger.info(f"Processing agent chat message: {message[:100]}...")

            agent = await self._create_agent_executor()
            messages = self._build_messages(message, chat_history)
            result = await agent.ainvoke({"messages": messages})
            response = self._extract_response(result)
            logger.info("Agent response generated successfully")
            return {
                "response": response,
                "tool_calls": [],
                "success": True
            }

        except Exception as e:
            # loguru ignores exc_info= and str.format()s the message when
            # keyword arguments are passed. exception() records the traceback,
            # and the positional argument keeps braces in the error text inert.
            logger.exception("Agent chat error: {}", e)
            return {
                "response": f"Sorry, I encountered an error: {str(e)}",
                "tool_calls": [],
                "success": False,
                "error": str(e)
            }

    async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
        """Process a chat message with the agent and stream the answer.

        The agent is invoked once and its final answer is replayed in growing
        chunks. Every event is a dict with ``type``, ``content`` and ``done``;
        the stream always ends with an event whose ``done`` is true.
        """
        import re

        tool_calls = []  # Initialize tool_calls at the beginning
        try:
            logger.info(f"Processing agent chat stream: {message[:100]}...")

            agent = await self._create_agent_executor()
            messages = self._build_messages(message, chat_history)
            yield {"type": "status", "content": "🤖 开始分析您的请求...", "done": False}
            await asyncio.sleep(0.2)
            result = await agent.ainvoke({"messages": messages})
            response_content = self._extract_response(result)

            # Each piece is one word with the whitespace before it, so joining
            # pieces reproduces the original newlines and indentation instead
            # of flattening the answer to single spaces.
            pieces = re.findall(r"\s*\S+", response_content)

            if not pieces:
                # An empty answer must still terminate the stream.
                yield {
                    "type": "response",
                    "content": "",
                    "tool_calls": tool_calls,
                    "done": True
                }

            last_index = len(pieces) - 1
            current_content = ""

            for i, piece in enumerate(pieces):
                current_content += piece
                is_last = i == last_index

                # Emit every second word, and always at the end.
                if (i + 1) % 2 == 0 or is_last:
                    yield {
                        "type": "response",
                        "content": current_content.strip(),
                        "tool_calls": tool_calls if is_last else [],
                        "done": is_last
                    }

                    # Small delay to simulate typing.
                    if not is_last:
                        await asyncio.sleep(0.05)

            logger.info("Agent stream response completed")

        except Exception as e:
            logger.exception("Agent chat stream error: {}", e)
            yield {
                "type": "error",
                "content": f"Sorry, I encountered an error: {str(e)}",
                "done": True
            }

    def update_config(self, config: Dict[str, Any]):
        """Apply the given values to matching attributes of the agent config.

        Keys that are not attributes of the current config are ignored.
        """
        try:
            for key, value in config.items():
                if hasattr(self.config, key):
                    setattr(self.config, key, value)
                    logger.info(f"Updated agent config: {key} = {value}")
        except Exception as e:
            logger.exception("Error updating agent config: {}", e)
            raise

    def load_config_from_db(self, config_id: Optional[int] = None):
        """Load configuration from the database and apply it."""
        if not self.config_service:
            logger.warning("No database session available for loading config")
            return

        try:
            # NOTE(review): _load_config awaits get_config_dict(), but this
            # synchronous method calls it without awaiting. If it is a
            # coroutine function, update_config() receives a coroutine and
            # fails. Confirm against AgentConfigService before relying on this.
            config_dict = self.config_service.get_config_dict(config_id)
            self.update_config(config_dict)
            logger.info(f"Loaded configuration from database (ID: {config_id})")
        except Exception as e:
            logger.error(f"Error loading config from database: {str(e)}")
            raise

    def get_available_tools(self) -> List[Dict[str, Any]]:
        """Describe every registered tool and whether it is enabled."""
        tools = []
        for tool_name, tool in self.tool_registry._tools.items():
            tools.append({
                "name": tool.get_name(),
                "description": tool.get_description(),
                "parameters": [{
                    "name": param.name,
                    "type": param.type.value,
                    "description": param.description,
                    "required": param.required,
                    "default": param.default,
                    "enum": param.enum
                } for param in tool.get_parameters()],
                "enabled": tool_name in self.config.enabled_tools
            })
        return tools

    def get_config(self) -> Dict[str, Any]:
        """Return the current agent configuration as a plain dict."""
        # model_dump() is the Pydantic v2 replacement for the deprecated dict().
        return self.config.model_dump()


# Global agent service instance
_global_agent_service: Optional[AgentService] = None


async def get_agent_service(session=None) -> AgentService:
    """Return the process-wide agent service, creating it on first use.

    When a different session is supplied later, the service is re-pointed at
    it and its configuration is reloaded.
    """
    global _global_agent_service
    if _global_agent_service is None:
        _global_agent_service = AgentService()
        await _global_agent_service.initialize(session)
    elif session and session != _global_agent_service.session:
        _global_agent_service.session = session
        _global_agent_service.config_service = AgentConfigService(session)
        # _load_config is a coroutine: without await it never runs.
        await _global_agent_service._load_config()
    return _global_agent_service
def get_schema(self) -> Dict[str, Any]:
    """Build the function-calling schema dict for this tool.

    Returns:
        ``{"type": "function", "function": {...}}`` where the inner dict
        carries the tool ``name``, ``description`` and a JSON-schema style
        ``parameters`` object. ``properties`` maps each parameter name to
        its ``to_dict()`` output; ``required`` lists the names of the
        parameters whose ``required`` flag is truthy, in declaration order.
    """
    property_schemas = {spec.name: spec.to_dict() for spec in self.parameters}
    required_names = [spec.name for spec in self.parameters if spec.required]

    parameter_schema = {
        "type": "object",
        "properties": property_schemas,
        "required": required_names,
    }
    function_schema = {
        "name": self.name,
        "description": self.description,
        "parameters": parameter_schema,
    }
    return {"type": "function", "function": function_schema}
raise ValueError(f"Required parameter '{param.name}' is missing") + + # Use default if not provided + if value is None and param.default is not None: + value = param.default + + # Type validation (basic) + if value is not None: + if param.type == ToolParameterType.INTEGER and not isinstance(value, int): + try: + value = int(value) + except (ValueError, TypeError): + raise ValueError(f"Parameter '{param.name}' must be an integer") + + elif param.type == ToolParameterType.FLOAT and not isinstance(value, (int, float)): + try: + value = float(value) + except (ValueError, TypeError): + raise ValueError(f"Parameter '{param.name}' must be a number") + + elif param.type == ToolParameterType.BOOLEAN and not isinstance(value, bool): + if isinstance(value, str): + value = value.lower() in ('true', '1', 'yes', 'on') + else: + value = bool(value) + + # Enum validation + if param.enum and value not in param.enum: + raise ValueError(f"Parameter '{param.name}' must be one of {param.enum}") + + validated[param.name] = value + + return validated + + +class ToolRegistry: + """Registry for managing Agent tools.""" + + def __init__(self): + self._tools: Dict[str, BaseTool] = {} + self._enabled_tools: Dict[str, bool] = {} + + def register(self, tool: BaseTool, enabled: bool = True) -> None: + """Register a tool.""" + tool_name = tool.get_name() + self._tools[tool_name] = tool + self._enabled_tools[tool_name] = enabled + logger.info(f"Registered tool: {tool_name} (enabled: {enabled})") + + def unregister(self, tool_name: str) -> None: + """Unregister a tool.""" + if tool_name in self._tools: + del self._tools[tool_name] + del self._enabled_tools[tool_name] + logger.info(f"Unregistered tool: {tool_name}") + + def get_tool(self, tool_name: str) -> Optional[BaseTool]: + """Get a tool by name.""" + return self._tools.get(tool_name) + + def get_enabled_tools(self) -> Dict[str, BaseTool]: + """Get all enabled tools.""" + return { + name: tool for name, tool in self._tools.items() + if 
async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
    """Validate the arguments for a registered tool and run it.

    Args:
        tool_name: Registry key of the tool to run.
        **kwargs: Raw arguments; they are passed through the tool's
            ``validate_parameters`` before ``execute`` is called.

    Returns:
        The ``ToolResult`` produced by the tool. A failed ``ToolResult`` is
        returned (never raised) when the tool is unknown, disabled, or when
        validation/execution raises.
    """
    tool = self.get_tool(tool_name)

    if tool is None:
        return ToolResult(
            success=False,
            result=None,
            error=f"Tool '{tool_name}' not found"
        )

    if not self.is_enabled(tool_name):
        return ToolResult(
            success=False,
            result=None,
            error=f"Tool '{tool_name}' is disabled"
        )

    try:
        validated_params = tool.validate_parameters(**kwargs)
        result = await tool.execute(**validated_params)
    except Exception as e:
        # loguru does not understand ``exc_info=True``: it treats it as a
        # ``str.format`` keyword, which drops the traceback and makes the
        # log call itself raise when the message contains braces.
        # ``logger.exception`` attaches the traceback, and passing values as
        # positional format arguments keeps braces inside them inert.
        logger.exception("Tool '{}' execution failed: {}", tool_name, e)
        return ToolResult(
            success=False,
            result=None,
            error=f"Tool execution failed: {str(e)}"
        )

    # A tool can finish without raising and still report failure; do not
    # log that case as a success.
    if getattr(result, "success", True):
        logger.info("Tool '{}' executed successfully", tool_name)
    else:
        logger.warning(
            "Tool '{}' reported failure: {}",
            tool_name,
            getattr(result, "error", None),
        )
    return result
/dev/null +++ b/th_agenter/services/agent/langgraph_agent_service.py @@ -0,0 +1,739 @@ +"""LangGraph Agent service with tool calling capabilities.""" + +import asyncio +from typing import List, Dict, Any, Optional, AsyncGenerator +from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.tools import tool +from langchain.chat_models import init_chat_model +# from langgraph.prebuilt import create_react_agent +from pydantic import BaseModel, Field + +from .base import ToolRegistry +from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool +from ..postgresql_tool_manager import get_postgresql_tool +from ...core.config import get_settings +from ..agent_config import AgentConfigService +from th_agenter.services.mcp.mcp_dynamic_tools import load_mcp_tools +from loguru import logger + +class LangGraphAgentConfig(BaseModel): + """LangGraph Agent configuration.""" + model_name: str = Field(default="gpt-3.5-turbo") + model_provider: str = Field(default="openai") + base_url: Optional[str] = Field(default=None) + api_key: Optional[str] = Field(default=None) + enabled_tools: List[str] = Field(default_factory=lambda: [ + "calculator", "weather", "search", "file", "image" + ]) + max_iterations: int = Field(default=10) + temperature: float = Field(default=0.7) + max_tokens: int = Field(default=1000) + system_message: str = Field( + default="""你是一个有用的AI助手,可以使用各种工具来帮助用户解决问题。 + 重要规则: + 1. 工具调用失败时,必须仔细分析失败原因,特别是参数格式问题 + 3. 在重新调用工具前,先解释上次失败的原因和改进方案 + 4. 
async def _create_react_agent(self):
    """Build the ReAct-style LangGraph graph and store it on ``self.react_agent``.

    The graph is ``START -> agent -> (tools -> agent)* -> END``:

    * ``agent`` invokes the tool-bound chat model on the message history.
    * ``tools`` runs every tool call requested by the last AI message and
      appends one ``ToolMessage`` per call.
    * After ``agent`` the router goes to ``tools`` when the last message
      carries tool calls, otherwise it ends the run.

    Side effects:
        Sets ``self.model``, ``self.bound_model`` and ``self.react_agent``.

    Raises:
        Exception: any error raised while loading the LLM settings, creating
            the model or compiling the graph is logged and re-raised.
    """
    try:
        # ``initialize`` stores the database session as ``self.session``;
        # this class never assigns a ``db_session`` attribute.
        llm_config = await get_settings().llm.get_current_config(self.session)
        self.model = init_chat_model(
            model=llm_config['model'],
            model_provider='openai',
            temperature=llm_config['temperature'],
            max_tokens=llm_config['max_tokens'],
            base_url=llm_config['base_url'],
            api_key=llm_config['api_key'],
        )

        # Bind tools so the model can propose tool calls; fall back to the
        # bare model if binding is not supported.
        try:
            self.bound_model = self.model.bind_tools(self.tools)
        except Exception as e:
            logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}")
            self.bound_model = self.model

        from typing import Annotated, TypedDict
        from langgraph.graph import StateGraph, START, END
        from langgraph.graph.message import add_messages
        from langchain_core.messages import ToolMessage, BaseMessage

        class AgentState(TypedDict):
            # ``add_messages`` appends node outputs to the running history.
            messages: Annotated[List[BaseMessage], add_messages]

        def agent_node(state: AgentState) -> AgentState:
            """Call the tool-bound model on the accumulated messages."""
            ai = self.bound_model.invoke(state["messages"])
            return {"messages": [ai]}

        def tools_node(state: AgentState) -> AgentState:
            """Execute the tool calls requested by the last AI message."""
            last = state["messages"][-1]
            outputs: List[ToolMessage] = []
            try:
                tool_calls = getattr(last, 'tool_calls', []) or []
                tool_map = {t.name: t for t in self.tools}
                for call in tool_calls:
                    # Tool calls may arrive as dicts or as objects.
                    if isinstance(call, dict):
                        name = call.get('name')
                        args = call.get('args')
                        call_id = call.get('id')
                    else:
                        name = getattr(call, 'name', None)
                        args = getattr(call, 'args', {})
                        call_id = getattr(call, 'id', '')
                    if name in tool_map:
                        try:
                            result = tool_map[name].invoke(args)
                        except Exception as te:
                            # Report the failure back to the model instead
                            # of aborting the whole run.
                            result = f"Tool {name} execution error: {te}"
                    else:
                        result = f"Unknown tool: {name}"
                    # ``name`` is set so stream consumers can report which
                    # tool produced the message.
                    outputs.append(
                        ToolMessage(content=str(result), tool_call_id=call_id, name=name)
                    )
            except Exception as e:
                outputs.append(ToolMessage(content=f"Tool execution error: {e}", tool_call_id=""))
            return {"messages": outputs}

        def route_after_agent(state: AgentState) -> str:
            """Go to ``tools`` when the model asked for tool calls, else end."""
            last = state["messages"][-1]
            if getattr(last, 'tool_calls', None):
                return "tools"
            return END

        graph = StateGraph(AgentState)
        graph.add_node("agent", agent_node)
        graph.add_node("tools", tools_node)
        graph.add_edge(START, "agent")
        graph.add_conditional_edges("agent", route_after_agent, {"tools": "tools", END: END})
        graph.add_edge("tools", "agent")

        self.react_agent = graph.compile()

        logger.info("LangGraph 底层 React 智能体创建成功")

    except Exception as e:
        logger.error(f"Failed to create agent: {str(e)}")
        raise
chat with message: {message[:100]}...") + + # Convert chat history to messages + messages = [] + if chat_history: + for msg in chat_history: + if msg["role"] == "user": + messages.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "assistant": + messages.append(AIMessage(content=msg["content"])) + + # Add current message + messages.append(HumanMessage(content=message)) + + # Use the low-level graph directly + result = await self.react_agent.ainvoke({"messages": messages}, {"recursion_limit": 6}, ) + + # Extract final response + final_response = "" + if "messages" in result and result["messages"]: + last_message = result["messages"][-1] + if hasattr(last_message, 'content'): + final_response = last_message.content + elif isinstance(last_message, dict) and "content" in last_message: + final_response = last_message["content"] + + return { + "response": final_response, + "intermediate_steps": [], + "success": True, + "error": None + } + + except Exception as e: + logger.error(f"LangGraph chat error: {str(e)}", exc_info=True) + return { + "response": f"抱歉,处理您的请求时出现错误: {str(e)}", + "intermediate_steps": [], + "success": False, + "error": str(e) + } + + async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[ + Dict[str, Any], None]: + """Process a chat message using LangGraph with streaming.""" + try: + logger.info(f"Starting streaming chat with message: {message[:100]}...") + + # Convert chat history to messages + messages = [] + if chat_history: + for msg in chat_history: + if msg["role"] == "user": + messages.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "assistant": + messages.append(AIMessage(content=msg["content"])) + + # Add current message + messages.append(HumanMessage(content=message)) + + # Track state for streaming + intermediate_steps = [] + final_response_started = False + accumulated_response = "" + final_ai_message = None + + # Stream the agent execution + async 
for event in self.react_agent.astream({"messages": messages}): + # Handle different event types from LangGraph + print('event===', event) + if isinstance(event, dict): + for node_name, node_output in event.items(): + logger.info(f"Processing node: {node_name}, output type: {type(node_output)}") + + # 处理 tools 节点 + if "tools" in node_name.lower(): + # 提取工具信息 + tool_infos = [] + + if isinstance(node_output, dict) and "messages" in node_output: + messages_in_output = node_output["messages"] + + for msg in messages_in_output: + # 处理 ToolMessage 对象 + if hasattr(msg, 'name') and hasattr(msg, 'content'): + tool_info = { + "tool_name": msg.name, + "tool_output": msg.content, + "tool_call_id": getattr(msg, 'tool_call_id', ''), + "status": "completed" + } + tool_infos.append(tool_info) + elif isinstance(msg, dict): + if 'name' in msg and 'content' in msg: + tool_info = { + "tool_name": msg['name'], + "tool_output": msg['content'], + "tool_call_id": msg.get('tool_call_id', ''), + "status": "completed" + } + tool_infos.append(tool_info) + + # 返回 tools_end 事件 + for tool_info in tool_infos: + yield { + "type": "tools_end", + "content": f"工具 {tool_info['tool_name']} 执行完成", + "tool_name": tool_info["tool_name"], + "tool_output": tool_info["tool_output"], + "node_name": node_name, + "done": False + } + await asyncio.sleep(0.1) + + # 处理 agent 节点 + elif "agent" in node_name.lower(): + if isinstance(node_output, dict) and "messages" in node_output: + messages_in_output = node_output["messages"] + if messages_in_output: + last_msg = messages_in_output[-1] + + # 获取 finish_reason + finish_reason = None + if hasattr(last_msg, 'response_metadata'): + finish_reason = last_msg.response_metadata.get('finish_reason') + elif isinstance(last_msg, dict) and 'response_metadata' in last_msg: + finish_reason = last_msg['response_metadata'].get('finish_reason') + + # 判断是否为 thinking 或 response + if finish_reason == 'tool_calls': + # thinking 状态 + thinking_content = "🤔 正在思考..." 
+ if hasattr(last_msg, 'content') and last_msg.content: + thinking_content = f"🤔 思考: {last_msg.content[:200]}..." + elif isinstance(last_msg, dict) and "content" in last_msg: + thinking_content = f"🤔 思考: {last_msg['content'][:200]}..." + + yield { + "type": "thinking", + "content": thinking_content, + "node_name": node_name, + "raw_output": str(node_output)[:500] if node_output else "", + "done": False + } + await asyncio.sleep(0.1) + + elif finish_reason == 'stop': + # response 状态 + if hasattr(last_msg, 'content') and hasattr(last_msg, + '__class__') and 'AI' in last_msg.__class__.__name__: + current_content = last_msg.content + final_ai_message = last_msg + + if not final_response_started and current_content: + final_response_started = True + yield { + "type": "response_start", + "content": "", + "intermediate_steps": intermediate_steps, + "done": False + } + + if current_content and len(current_content) > len(accumulated_response): + new_content = current_content[len(accumulated_response):] + + for char in new_content: + accumulated_response += char + yield { + "type": "response", + "content": accumulated_response, + "intermediate_steps": intermediate_steps, + "done": False + } + await asyncio.sleep(0.03) + + else: + # 其他 agent 状态 + yield { + "type": "step", + "content": f"📋 执行步骤: {node_name}", + "node_name": node_name, + "raw_output": str(node_output)[:500] if node_output else "", + "done": False + } + await asyncio.sleep(0.1) + + # 处理其他节点 + else: + yield { + "type": "step", + "content": f"📋 执行步骤: {node_name}", + "node_name": node_name, + "raw_output": str(node_output)[:500] if node_output else "", + "done": False + } + await asyncio.sleep(0.1) + + # 最终完成事件 + yield { + "type": "complete", + "content": accumulated_response, + "intermediate_steps": intermediate_steps, + "done": True + } + + except Exception as e: + logger.error(f"Error in chat_stream: {str(e)}", exc_info=True) + yield { + "type": "error", + "content": f"处理请求时出错: {str(e)}", + "done": True + } + + 
def update_config(self, config: Dict[str, Any]):
    """Apply configuration overrides and rebuild the ReAct agent.

    Args:
        config: Mapping of configuration field name to new value. Keys that
            are not attributes of ``self.config`` are ignored.

    Notes:
        ``_create_react_agent`` is a coroutine function, so calling it
        without awaiting would only create a coroutine object and never
        rebuild the agent. When no event loop is running in this thread the
        rebuild is driven to completion before returning. When a loop is
        already running, the rebuild is scheduled as a task on that loop and
        completes after this method returns; the task is kept on
        ``self._rebuild_task``.
    """
    for key, value in config.items():
        if hasattr(self.config, key):
            setattr(self.config, key, value)

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Called from synchronous code: run the rebuild right here.
        asyncio.run(self._create_react_agent())
    else:
        def _retrieve_result(task):
            # ``_create_react_agent`` logs its own failures; retrieving the
            # exception here avoids "Task exception was never retrieved".
            if not task.cancelled():
                task.exception()

        # Keep a reference so the task is not garbage-collected mid-flight.
        self._rebuild_task = loop.create_task(self._create_react_agent())
        self._rebuild_task.add_done_callback(_retrieve_result)

    logger.info("Agent configuration updated")
+ 结构:START -> planner -> executor(loop) -> summarize -> END + - planner:根据用户问题生成计划(JSON 数组) + - executor:逐步执行计划(可调用工具),收集每步结果 + - summarize:综合计划与执行结果产出最终回答 + """ + from typing import TypedDict, Annotated, List + import json + from langgraph.graph import StateGraph, START, END + from langgraph.graph.message import add_messages + from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage + try: + self.bound_model = self.model.bind_tools(self.tools) + except Exception as e: + logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}") + self.bound_model = self.model + class PlanState(TypedDict): + messages: Annotated[List[BaseMessage], add_messages] + plan_steps: List[str] + current_step: int + step_results: List[str] + + def planner_node(state: PlanState) -> PlanState: + messages = state.get("messages", []) + plan_prompt = ( + "你是规划助手。基于对话内容生成可执行计划," + "用 JSON 数组返回,每个元素是一条明确且可操作的步骤。" + "仅输出 JSON,不要额外解释。" + ) + ai_plan = self.model.invoke(messages + [HumanMessage(content=plan_prompt)]) + steps: List[str] = [] + try: + parsed = json.loads(ai_plan.content) + if isinstance(parsed, list): + steps = [str(s) for s in parsed] + except Exception: + # 回退:按行拆分 + steps = [s.strip() for s in ai_plan.content.split("\n") if s.strip()] + return { + "messages": [ai_plan], + "plan_steps": steps, + "current_step": 0, + "step_results": [] + } + + def executor_node(state: PlanState) -> PlanState: + idx = state.get("current_step", 0) + steps = state.get("plan_steps", []) + msgs = state.get("messages", []) + if idx >= len(steps): + return {"messages": [], "current_step": idx, "step_results": state.get("step_results", [])} + + step_text = steps[idx] + exec_prompt = ( + f"请执行计划的第{idx+1}步:{step_text}。" + "需要用工具时创建工具调用;完成后给出该步的结果。" + ) + ai_exec = self.bound_model.invoke(msgs + [HumanMessage(content=exec_prompt)]) + + new_messages: List[BaseMessage] = [ai_exec] + step_result_content = None + + # 处理工具调用 + tool_map = {t.name: t for t in self.tools} 
+ tool_msgs: List[ToolMessage] = [] + tool_calls = getattr(ai_exec, "tool_calls", []) or (ai_exec.additional_kwargs.get("tool_calls") if hasattr(ai_exec, "additional_kwargs") else []) + if tool_calls: + for call in tool_calls: + name = call.get("name") + args = call.get("args", {}) + tool_obj = tool_map.get(name) + if tool_obj: + try: + result = tool_obj.invoke(args) + except Exception as e: + result = f"工具执行失败: {e}" + else: + result = f"未找到工具: {name}" + tool_call_id = call.get("id") or call.get("tool_call_id") or call.get("call_id") or f"tool_{name}" + tool_msgs.append(ToolMessage(content=str(result), tool_call_id=tool_call_id, name=name or "tool")) + new_messages.extend(tool_msgs) + # 基于工具输出总结该步结果 + summarize_step = "请基于上述工具输出,总结该步骤的结果,给出结构化要点与可读说明。" + ai_step = self.bound_model.invoke(msgs + [ai_exec] + tool_msgs + [HumanMessage(content=summarize_step)]) + step_result_content = ai_step.content + new_messages.append(ai_step) + else: + step_result_content = ai_exec.content + + all_results = list(state.get("step_results", [])) + if step_result_content: + all_results.append(step_result_content) + + return { + "messages": new_messages, + "current_step": idx + 1, + "step_results": all_results + } + + def route_after_planner(state: PlanState) -> str: + return "executor" if state.get("plan_steps") else END + + def route_after_executor(state: PlanState) -> str: + cur = state.get("current_step", 0) + total = len(state.get("plan_steps", [])) + return "executor" if cur < total else "summarize" + + def summarize_node(state: PlanState) -> PlanState: + import json as _json + msgs = state.get("messages", []) + steps = state.get("plan_steps", []) + results = state.get("step_results", []) + final_prompt = ( + "请综合以上计划与各步骤结果,生成最终回答。" + "要求:逻辑清晰、结论明确、可读性强;如存在不确定性请注明。" + ) + context_msg = HumanMessage(content=( + f"计划: {_json.dumps(steps, ensure_ascii=False)}\n" + f"步骤结果: {_json.dumps(results, ensure_ascii=False)}\n" + f"{final_prompt}" + )) + ai_final = self.model.invoke(msgs + 
async def chat_stream_plan_execute(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
    """Stream a Plan-and-Execute run as a sequence of event dicts.

    Args:
        message: The new user message.
        chat_history: Optional earlier turns as ``{"role", "content"}``
            dicts; ``"user"`` turns become ``HumanMessage`` and every other
            role becomes ``AIMessage``.

    Yields:
        Dicts with ``type``, ``content`` and ``done`` keys. ``type`` is one
        of ``planning``, ``step``, ``tools_end``, ``response_start``,
        ``response``, ``complete`` or ``error``. Only ``complete`` and
        ``error`` carry ``done=True``.
    """
    import asyncio as _asyncio

    # The plan/execute graph is created lazily on first use.
    if not hasattr(self, "plan_execute_agent"):
        self._create_plan_execute_agent()

    messages = []
    if chat_history:
        for msg in chat_history:
            role = msg.get("role")
            content = msg.get("content", "")
            if role == "user":
                messages.append(HumanMessage(content=content))
            else:
                messages.append(AIMessage(content=content))
    messages.append(HumanMessage(content=message))

    try:
        accumulated = ""
        # Stream the plan/execute graph. The ReAct graph only has "agent"
        # and "tools" nodes, so none of the branches below would match it.
        stream = self.plan_execute_agent.astream(
            {"messages": messages},
            config={"recursion_limit": self.config.max_iterations},
        )
        async for event in stream:
            for key, node_output in event.items():
                node_name = key[0] if isinstance(key, tuple) else key
                has_output = bool(node_output) and isinstance(node_output, dict)

                if node_name == "planner":
                    # Planning phase: surface the (truncated) plan text.
                    content = "生成计划中..."
                    if has_output:
                        plan_msgs = node_output.get("messages", [])
                        if plan_msgs and hasattr(plan_msgs[-1], "content"):
                            content = str(plan_msgs[-1].content)[:400]
                    yield {"type": "planning", "content": content, "done": False}
                    await _asyncio.sleep(0.05)

                elif node_name == "executor":
                    # Execution phase; a step may include tool runs.
                    yield {"type": "step", "content": "执行计划步骤", "done": False}
                    await _asyncio.sleep(0.05)
                    if has_output:
                        msgs = node_output.get("messages", [])
                        # Messages are classified by class name, as in the
                        # rest of this service.
                        tool_msgs = [m for m in msgs if "Tool" in type(m).__name__]
                        if tool_msgs:
                            yield {"type": "tools_end", "content": f"完成 {len(tool_msgs)} 次工具执行", "done": False}
                            await _asyncio.sleep(0.03)
                        ai_msgs = [m for m in msgs if "AI" in type(m).__name__]
                        if ai_msgs:
                            text = ai_msgs[-1].content
                            if text:
                                accumulated = text
                                yield {"type": "response", "content": accumulated, "done": False}
                                await _asyncio.sleep(0.02)

                elif node_name == "summarize":
                    # Final summary becomes the answer.
                    if has_output:
                        msgs = node_output.get("messages", [])
                        if msgs:
                            content = getattr(msgs[-1], "content", "")
                            if content:
                                yield {"type": "response_start", "content": "", "done": False}
                                yield {"type": "response", "content": content, "done": False}
                                accumulated = content
                                await _asyncio.sleep(0.02)

        yield {"type": "complete", "content": accumulated, "done": True}
    except Exception as e:
        # loguru treats ``exc_info=True`` as a format keyword (no traceback,
        # and braces in the message would make the log call raise);
        # ``logger.exception`` records the traceback safely.
        logger.exception("Error in chat_stream_plan_execute: {}", e)
        yield {"type": "error", "content": str(e), "done": True}
async def list_configs(self, active_only: bool = True) -> List[AgentConfig]:
    """List agent configurations, newest first.

    The method is a coroutine and awaits ``self.db.execute`` the same way
    the other queries on this service do. The previous synchronous version
    called ``.scalars()`` directly on the un-awaited result of ``execute``.

    Args:
        active_only: When true, only rows whose ``is_active`` flag is set
            are returned.

    Returns:
        The matching configurations ordered by ``created_at`` descending.
    """
    stmt = select(AgentConfig)
    if active_only:
        # SQL expression, not a Python truth test, hence the explicit "== True".
        stmt = stmt.where(AgentConfig.is_active == True)  # noqa: E712
    stmt = stmt.order_by(AgentConfig.created_at.desc())

    result = await self.db.execute(stmt)
    return result.scalars().all()
async def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
    """Return the runtime settings of a configuration as a plain dict.

    Args:
        config_id: ID of the configuration to load. ``None`` selects the
            default configuration. The check is ``is not None`` so that an
            explicit ID of ``0`` is looked up instead of silently falling
            back to the default configuration.

    Returns:
        A dict with the keys ``enabled_tools``, ``max_iterations``,
        ``temperature``, ``system_message`` and ``verbose``. When no
        matching configuration exists, built-in fallback values are returned.
    """
    if config_id is not None:
        config = await self.get_config(config_id)
    else:
        config = await self.get_default_config()

    if not config:
        # Nothing stored: fall back to the built-in defaults.
        return {
            "enabled_tools": ["calculator", "weather", "search", "datetime", "file", "generate_image"],
            "max_iterations": 10,
            "temperature": 0.1,
            "system_message": "You are a helpful AI assistant with access to various tools. Use the available tools to help answer user questions accurately. Always explain your reasoning and the tools you're using.",
            "verbose": True
        }

    return {
        "enabled_tools": config.enabled_tools,
        "max_iterations": config.max_iterations,
        "temperature": config.temperature,
        "system_message": config.system_message,
        "verbose": config.verbose
    }
class AuthService:
    """Authentication helpers: credential checks, password hashing and JWT handling."""

    # Expose the module-level FastAPI dependencies through the class as well.
    get_current_user = get_current_user
    get_current_active_user = get_current_active_user

    @staticmethod
    async def authenticate_user_by_email(session: Session, email: str, password: str) -> Optional[User]:
        """Authenticate a user by e-mail address and password.

        Returns:
            The matching ``User`` when the password verifies, otherwise ``None``.
        """
        session.desc = f"根据邮箱 {email} 验证用户密码"
        stmt = select(User).where(User.email == email)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def authenticate_user(session: Session, username: str, password: str) -> Optional[User]:
        """Authenticate a user by username and password.

        Returns:
            The matching ``User`` when the password verifies, otherwise ``None``.
        """
        session.desc = f"根据用户名 {username} 验证用户密码"
        stmt = select(User).where(User.username == username)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def create_access_token(session: Session, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            session: Session whose ``desc`` attribute receives a progress note.
            data: Claims to embed; the dict is copied, not mutated.
            expires_delta: Token lifetime. When omitted, the lifetime comes
                from ``settings.security.access_token_expire_minutes``.

        Returns:
            The encoded token.
        """
        session.desc = f"创建 JWT 访问 token - 数据: {data}"
        to_encode = data.copy()
        if expires_delta:
            expire = datetime.now(timezone.utc) + expires_delta
        else:
            expire = datetime.now(timezone.utc) + timedelta(minutes=settings.security.access_token_expire_minutes)

        to_encode.update({"exp": expire})
        encoded_jwt = jwt.encode(
            to_encode,
            settings.security.secret_key,
            algorithm=settings.security.algorithm
        )
        return encoded_jwt

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Check a plaintext password against its stored bcrypt hash.

        A missing or malformed stored hash counts as a failed verification
        rather than an unhandled exception during login.
        """
        if not hashed_password:
            return False
        try:
            # Verify with the bcrypt library directly.
            return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
        except ValueError:
            # bcrypt raises ValueError when the stored value is not a valid hash.
            logger.warning("Stored password hash is not a valid bcrypt hash")
            return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Hash a password with bcrypt using a freshly generated salt."""
        salt = bcrypt.gensalt()
        hashed_bytes = bcrypt.hashpw(password.encode('utf-8'), salt)
        hashed_password = hashed_bytes.decode('utf-8')
        return hashed_password

    @staticmethod
    def verify_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT.

        Returns:
            The token payload, or ``None`` when decoding or validation fails.
        """
        try:
            payload = jwt.decode(
                token,
                settings.security.secret_key,
                algorithms=[settings.security.algorithm]
            )
            return payload
        except jwt.PyJWTError as e:
            # Log the reason only; token material is kept out of the logs.
            logger.error(f"Token verification failed: {e}")
            return None
async def chat(
    self,
    conversation_id: int,
    message: str,
    stream: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    use_agent: bool = False,
    use_langgraph: bool = False,
    use_knowledge_base: bool = False,
    knowledge_base_id: Optional[int] = None
) -> ChatResponse:
    """Send a message and get an AI response.

    The request is routed, in priority order, to the knowledge-base chat
    (``use_knowledge_base`` with a ``knowledge_base_id``), the LangGraph
    agent (``use_langgraph``), the plain agent (``use_agent``), or the
    LangChain chat service.

    Raises:
        ChatServiceError: If knowledge-base chat is requested but unavailable,
            the conversation cannot be found, or an agent reports failure.
    """
    if use_knowledge_base and knowledge_base_id:
        if not self.knowledge_chat_service:
            raise ChatServiceError("知识库功能需要安装: pip install langchain-chroma")
        logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
        return await self.knowledge_chat_service.chat_with_knowledge_base(
            conversation_id=conversation_id,
            message=message,
            knowledge_base_id=knowledge_base_id,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    if use_langgraph:
        logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
        return await self._chat_via_agent(
            conversation_id,
            message,
            agent_service=self.langgraph_agent_service,
            metadata_key="intermediate_steps",
            error_label="LangGraph Agent error",
        )

    if use_agent:
        logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
        return await self._chat_via_agent(
            conversation_id,
            message,
            agent_service=self.agent_service,
            metadata_key="tool_calls",
            error_label="Agent error",
        )

    logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
    # Delegate to the LangChain service.
    return await self.langchain_chat_service.chat(
        conversation_id=conversation_id,
        message=message,
        stream=stream,
        temperature=temperature,
        max_tokens=max_tokens
    )

async def _chat_via_agent(
    self,
    conversation_id: int,
    message: str,
    agent_service,
    metadata_key: str,
    error_label: str,
) -> ChatResponse:
    """Run one agent turn against the stored history and persist both messages.

    Args:
        agent_service: Object exposing ``await chat(message, chat_history)``
            that returns a dict with ``success``, ``response`` and the
            ``metadata_key`` entry.
        metadata_key: Key of the agent result to store as message metadata.
        error_label: Prefix for the error raised when the agent fails.
    """
    conversation = await self.conversation_service.get_conversation(conversation_id)
    if not conversation:
        raise ChatServiceError(f"Conversation {conversation_id} not found")

    messages = await self.conversation_service.get_conversation_messages(conversation_id)
    chat_history = [
        {
            "role": "user" if msg.role == MessageRole.USER else "assistant",
            "content": msg.content,
        }
        for msg in messages
    ]

    agent_result = await agent_service.chat(message, chat_history)
    if not agent_result["success"]:
        raise ChatServiceError(f"{error_label}: {agent_result.get('error', 'Unknown error')}")

    # Persist the user turn, then the assistant turn.
    await self.conversation_service.add_message(
        conversation_id=conversation_id,
        content=message,
        role=MessageRole.USER
    )
    assistant_message = await self.conversation_service.add_message(
        conversation_id=conversation_id,
        content=agent_result["response"],
        role=MessageRole.ASSISTANT,
        message_metadata={metadata_key: agent_result[metadata_key]}
    )

    return ChatResponse(
        message=MessageResponse(
            id=assistant_message.id,
            content=agent_result["response"],
            role=MessageRole.ASSISTANT,
            conversation_id=conversation_id,
            created_at=assistant_message.created_at,
            # The message is stored under ``message_metadata``; ``.metadata`` on a
            # declarative model is SQLAlchemy's table registry, not row data.
            metadata=assistant_message.message_metadata
        )
    )
async def initialize(self, conversation_id: int, streaming: bool = False):
    """Prepare the service for one conversation.

    Loads the conversation, ensures the LangGraph checkpoint tables exist
    (once per process, tracked by ``ChatService._checkpointer_initialized``)
    and creates the LLM.

    Args:
        conversation_id: ID of the conversation to bind to this service.
        streaming: Forwarded to ``new_llm``.

    Raises:
        ChatServiceError: If the conversation cannot be found.
        Exception: Whatever ``AsyncPostgresSaver.setup`` raises is logged and
            re-raised.
    """
    self.conversation_service = ConversationService(self.session)
    self.session.desc = "ChatService初始化 - ConversationService 实例化完毕"
    self.conversation = await self.conversation_service.get_conversation(
        conversation_id=conversation_id
    )
    if not self.conversation:
        raise ChatServiceError(f"Conversation {conversation_id} not found")

    if not ChatService._checkpointer_initialized:
        import psycopg2

        # NOTE(review): connection string, including credentials, is hard-coded;
        # consider moving it to configuration.
        CONN_STRING = "postgresql://postgres:postgres@localhost:5433/postgres"
        ChatService._conn_string = CONN_STRING
        required_tables = ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')

        def _langgraph_tables_exist() -> bool:
            """Return True when every LangGraph checkpoint table is present."""
            conn = psycopg2.connect(CONN_STRING)
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("""
                        SELECT table_name
                        FROM information_schema.tables
                        WHERE table_schema = 'public'
                        AND table_name IN ('checkpoints', 'checkpoint_writes', 'checkpoint_blobs')
                    """)
                    existing_tables = {row[0] for row in cursor.fetchall()}
                finally:
                    cursor.close()
            finally:
                # Close the connection even when the probe query fails.
                conn.close()
            return all(table in existing_tables for table in required_tables)

        try:
            # psycopg2 is blocking, so run the probe off the event loop.
            tables_need_setup = not await asyncio.to_thread(_langgraph_tables_exist)
            if not tables_need_setup:
                self.session.desc = "ChatService初始化 - 检测到langgraph表已存在,跳过setup"
        except Exception as e:
            self.session.desc = f"ChatService初始化 - checkpoint失败: {str(e)},将进行setup"
            tables_need_setup = True

        # Run setup only when it is needed.
        if tables_need_setup:
            self.session.desc = "ChatService初始化 - 正在进行PostgresSaver setup"
            try:
                async with AsyncPostgresSaver.from_conn_string(CONN_STRING) as checkpointer:
                    await checkpointer.setup()
                self.session.desc = "ChatService初始化 - PostgresSaver setup完成"
                logger.info("PostgresSaver setup完成")
            except Exception as e:
                self.session.desc = f"ChatService初始化 - PostgresSaver setup失败: {str(e)}"
                logger.error(f"PostgresSaver setup失败: {e}")
                raise
        else:
            self.session.desc = "ChatService初始化 - 使用现有的langgraph表"

        ChatService._checkpointer_initialized = True

    self.llm = await new_llm(session=self.session, streaming=streaming)
    self.session.desc = f"ChatService初始化 - 获取对话实例完毕 > {self.conversation}"
PostgresSaver.from_conn_string(conn_string=self._conn_string) as checkpointer: + checkpoints = checkpointer.list(self.get_config()) + for checkpoint in checkpoints: + print(checkpoint) + result.append(checkpoint.messages) + return result diff --git a/th_agenter/services/conversation.py b/th_agenter/services/conversation.py new file mode 100644 index 0000000..f87c158 --- /dev/null +++ b/th_agenter/services/conversation.py @@ -0,0 +1,295 @@ +"""Conversation service.""" + +from typing import List, Optional +from sqlalchemy.orm import Session +from sqlalchemy import select, desc, func, or_ +from langchain_core.messages import HumanMessage, AIMessage + +from th_agenter.db.database import AsyncSessionFactory + +from ..models.conversation import Conversation +from ..models.message import Message, MessageRole +from utils.util_schemas import ConversationCreate, ConversationUpdate +from utils.util_exceptions import ConversationNotFoundError, DatabaseError +from ..core.context import UserContext +from datetime import datetime, timezone +from loguru import logger + +class ConversationService: + """Service for managing conversations and messages.""" + + def __init__(self, session: Session): + self.session = session + + async def create_conversation( + self, + user_id: int, + conversation_data: ConversationCreate + ) -> Conversation: + """Create a new conversation.""" + self.session.desc = f"创建新会话 - 用户ID: {user_id},会话数据: {conversation_data}" + + try: + conversation = Conversation( + **conversation_data.model_dump(), + user_id=user_id + ) + + # Set audit fields + conversation.set_audit_fields(user_id=user_id, is_update=False) + + self.session.add(conversation) + await self.session.commit() + await self.session.refresh(conversation) + + self.session.desc = f"创建新会话 Conversation ID: {conversation.id},用户ID: {user_id}" + return conversation + + except Exception as e: + self.session.desc = f"ERROR: 创建会话失败 - 用户ID: {user_id},错误: {str(e)}" + await self.session.rollback() + raise 
async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
    """Fetch a conversation owned by the current user.

    Args:
        conversation_id: ID of the conversation to load.

    Returns:
        The conversation, or ``None`` when there is no user context or no
        conversation with that ID belongs to the current user.

    Raises:
        DatabaseError: If resolving the user or running the query fails.
    """
    # Bound up-front so the error handler can always report it, even when
    # resolving the current user is the step that fails.
    user_id = None
    try:
        user_id = UserContext.get_current_user_id()
        self.session.desc = f"获取会话 - 会话ID: {conversation_id},用户ID: {user_id}"
        if user_id is None:
            logger.error(f"Failed to get conversation {conversation_id}: No user context available")
            return None

        conversation = await self.session.scalar(
            select(Conversation).where(
                Conversation.id == conversation_id,
                Conversation.user_id == user_id
            )
        )

        if not conversation:
            self.session.desc = f"警告: 会话 {conversation_id} 不存在,用户ID: {user_id}"
        return conversation

    except Exception as e:
        self.session.desc = f"ERROR: 获取会话失败 - 会话ID: {conversation_id},用户ID: {user_id},错误: {str(e)}"
        raise DatabaseError(f"Failed to get conversation: {str(e)}") from e
async def add_message(
    self,
    conversation_id: int,
    content: str,
    role: MessageRole,
    message_metadata: Optional[dict] = None,
    context_documents: Optional[list] = None,
    prompt_tokens: Optional[int] = None,
    completion_tokens: Optional[int] = None,
    total_tokens: Optional[int] = None
) -> Message:
    """Persist a message and bump the owning conversation's timestamp.

    The work runs in a dedicated session from ``AsyncSessionFactory``. The
    message insert and the conversation update are committed together, and
    the conversation is loaded through that same session so its changes are
    part of the commit.

    Returns:
        The stored message, refreshed after the commit.

    Raises:
        DatabaseError: If the message cannot be stored. The transaction is
            rolled back first, so callers never receive an unsaved message.
    """
    message = Message(
        conversation_id=conversation_id,
        content=content,
        role=role,
        message_metadata=message_metadata,
        context_documents=context_documents,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens
    )

    # Set audit fields
    message.set_audit_fields()

    session = AsyncSessionFactory()
    try:
        session.add(message)

        # Load the conversation in THIS session; an object attached to
        # self.session would not be flushed by this session's commit.
        conversation = await session.get(Conversation, conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        await session.commit()
        # Refresh after the final commit so attributes such as ``id`` are
        # loaded before the session is closed. NOTE(review): matters when the
        # factory expires instances on commit - confirm its configuration.
        await session.refresh(message)
    except Exception as e:
        logger.error(f"Failed to add message to conversation {conversation_id}: {str(e)}", exc_info=True)
        await session.rollback()
        raise DatabaseError(f"Failed to add message: {str(e)}") from e
    finally:
        await session.close()

    return message
async def get_user_conversations_count(
    self,
    search_query: Optional[str] = None,
    include_archived: bool = False
) -> int:
    """Count the current user's conversations.

    Applies the same archive and search filters as ``get_user_conversations``.

    Args:
        search_query: Optional text matched case-insensitively against the
            title and system prompt.
        include_archived: When false, archived conversations are excluded.

    Returns:
        The number of matching conversations; ``0`` when no user context is
        available, matching the empty list ``get_user_conversations``
        returns in that case.
    """
    user_id = UserContext.get_current_user_id()
    if user_id is None:
        # Without this guard the filter becomes "user_id IS NULL".
        logger.error("Failed to count user conversations: No user context available")
        return 0

    query = select(func.count(Conversation.id)).where(
        Conversation.user_id == user_id
    )

    if not include_archived:
        query = query.where(Conversation.is_archived == False)  # noqa: E712

    if search_query and search_query.strip():
        search_term = f"%{search_query.strip()}%"
        query = query.where(
            or_(
                Conversation.title.ilike(search_term),
                Conversation.system_prompt.ilike(search_term)
            )
        )

    return (await self.session.scalar(query)) or 0
async def create_conversation(self, user_id: int, title: str = "智能问数对话") -> int:
    """Create a new conversation and seed its in-memory context.

    Args:
        user_id: ID of the owning user.
        title: Conversation title.

    Returns:
        The ID of the newly created conversation.

    Raises:
        Exception: Any failure is printed and re-raised unchanged.
    """
    # Bound before the try so the finally clause never hits an unbound name
    # when acquiring the session itself fails.
    session = None
    try:
        session = await anext(get_session())

        conversation = Conversation(
            user_id=user_id,
            title=title,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )

        session.add(conversation)
        await session.commit()
        await session.refresh(conversation)

        # Initialise the cached conversation context.
        self.context_cache[conversation.id] = {
            'conversation_id': conversation.id,
            'user_id': user_id,
            'file_list': [],
            'selected_files': [],
            'query_history': [],
            'created_at': datetime.utcnow().isoformat()
        }

        return conversation.id

    except Exception as e:
        print(f"创建对话失败: {e}")
        raise
    finally:
        if session is not None:
            # The session is used with ``await`` above, so closing it must be
            # awaited too; a bare ``session.close()`` never runs.
            await session.close()
async def update_conversation_context(
    self,
    conversation_id: int,
    file_list: List[Dict[str, Any]] = None,
    selected_files: List[Dict[str, Any]] = None,
    query: str = None
) -> bool:
    """Merge new file selections and/or a user query into a conversation's context.

    Args:
        conversation_id: ID of the conversation whose context is updated.
        file_list: Replacement file list; left untouched when ``None``.
        selected_files: Replacement selection; left untouched when ``None``.
        query: User query appended to the query history with a UTC timestamp;
            skipped when ``None``.

    Returns:
        ``True`` when the context was found and updated, ``False`` when no
        context exists or any error occurred (the error is printed).
    """
    try:
        ctx = await self.get_conversation_context(conversation_id)
        if ctx:
            replacements = (
                ('file_list', file_list),
                ('selected_files', selected_files),
            )
            for field, new_value in replacements:
                if new_value is not None:
                    ctx[field] = new_value

            if query is not None:
                entry = {
                    'query': query,
                    'timestamp': datetime.utcnow().isoformat()
                }
                ctx['query_history'].append(entry)

            # Write the context back to the cache.
            self.context_cache[conversation_id] = ctx
            return True

        return False

    except Exception as e:
        print(f"更新对话上下文失败: {e}")
        return False
async def get_conversation_history(self, conversation_id: int) -> List[Dict[str, Any]]:
    """
    Load all messages of a conversation, oldest first.

    Args:
        conversation_id: ID of the conversation.

    Returns:
        A list of dicts with ``id``, ``role``, ``content`` and ``timestamp``
        keys, plus ``metadata`` when the message carries decodable metadata.
        An empty list is returned when loading fails (the error is printed).
    """
    # Imported locally so this method does not depend on the module header.
    from sqlalchemy import select

    session = None
    try:
        session = await anext(get_session())

        # The session is awaited everywhere else in this service, so the
        # legacy sync ``session.query`` API is replaced by a select()
        # statement executed through the awaited ``scalars`` call.
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
        )
        messages = (await session.scalars(stmt)).all()

        history = []
        for message in messages:
            msg_data = {
                'id': message.id,
                'role': message.role,
                'content': message.content,
                'timestamp': message.created_at.isoformat()
            }

            # NOTE(review): ``metadata`` is a reserved attribute name on
            # SQLAlchemy declarative models; confirm how Message maps it.
            if message.metadata:
                try:
                    raw = message.metadata
                    msg_data['metadata'] = json.loads(raw) if isinstance(raw, str) else raw
                except (json.JSONDecodeError, TypeError):
                    # Undecodable metadata is dropped on purpose; the message
                    # itself is still returned.
                    pass

            history.append(msg_data)

        return history

    except Exception as e:
        print(f"获取对话历史失败: {e}")
        return []
    finally:
        # Guard against a failed acquisition and actually await the close.
        if session is not None:
            await session.close()
async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
    """
    Validate, connection-test and persist a new database configuration.

    Args:
        user_id: ID of the user creating the configuration.
        config_data: Must contain ``name``, ``db_type``, ``host``, ``port``,
            ``database``, ``username`` and ``password``. Optional keys:
            ``is_active``, ``is_default``, ``connection_params``.

    Returns:
        The persisted DatabaseConfig (password stored encrypted).

    Raises:
        ValidationError: A required field is missing or the connection test
            fails.
        Exception: Any other error; the session is rolled back and the error
            is re-raised.
    """
    try:
        required_fields = ['name', 'db_type', 'host', 'port', 'database', 'username', 'password']
        for field in required_fields:
            if field not in config_data:
                raise ValidationError(f"缺少必需字段: {field}")

        test_config = {
            key: config_data[key]
            for key in ('host', 'port', 'database', 'username', 'password')
        }

        # Pick the tool for this database type. Other types are stored
        # without a connection test, as before.
        db_type = config_data['db_type']
        if db_type == 'postgresql':
            db_tool = self.postgresql_tool
        elif db_type == 'mysql':
            db_tool = self.mysql_tool
        else:
            db_tool = None

        if db_tool is not None:
            test_result = await db_tool.execute(
                operation="test_connection",
                connection_config=test_config
            )
            if not test_result.success:
                raise ValidationError(f"数据库连接测试失败: {test_result.error}")

        # A user has at most one default configuration: clear the previous
        # one. ``execute`` must be awaited like every other call on this
        # session; without it ``result`` is a coroutine with no ``scalars``.
        if config_data.get('is_default', False):
            stmt = select(DatabaseConfig).where(
                DatabaseConfig.created_by == user_id,
                DatabaseConfig.is_default == True  # noqa: E712 - SQL expression
            )
            result = await self.session.execute(stmt)
            for config in result.scalars():
                config.is_default = False

        db_config = DatabaseConfig(
            created_by=user_id,
            name=config_data['name'],
            db_type=config_data['db_type'],
            host=config_data['host'],
            port=config_data['port'],
            database=config_data['database'],
            username=config_data['username'],
            password=self._encrypt_password(config_data['password']),
            is_active=config_data.get('is_active', True),
            is_default=config_data.get('is_default', False),
            connection_params=config_data.get('connection_params')
        )

        self.session.add(db_config)
        await self.session.commit()
        await self.session.refresh(db_config)

        logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
        return db_config

    except Exception as e:
        await self.session.rollback()
        logger.error(f"创建数据库配置失败: {str(e)}")
        raise
async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
    """
    Run a connection test against a stored database configuration.

    Args:
        config_id: ID of the configuration to test.
        user_id: ID of the user who owns the configuration.

    Returns:
        A dict with ``success``, ``message`` (the tool's message on success,
        its error otherwise) and ``details`` (the tool result on success,
        else None).

    Raises:
        NotFoundError: No configuration with this ID belongs to the user.
    """
    # ``get_config_by_id`` is a coroutine function. Without ``await`` the
    # result is an always-truthy coroutine object with no ``host`` attribute.
    config = await self.get_config_by_id(config_id, user_id)
    if not config:
        raise NotFoundError("数据库配置不存在")

    test_config = {
        'host': config.host,
        'port': config.port,
        'database': config.database,
        'username': config.username,
        'password': self._decrypt_password(config.password)
    }

    # MySQL configurations go to the MySQL tool; everything else keeps the
    # previous PostgreSQL behaviour.
    db_tool = self.mysql_tool if config.db_type == 'mysql' else self.postgresql_tool

    result = await db_tool.execute(
        operation="test_connection",
        connection_config=test_config
    )

    return {
        'success': result.success,
        'message': result.result.get('message') if result.success else result.error,
        'details': result.result if result.success else None
    }
async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
    """
    Preview rows of a table over the user's already-open connection.

    Args:
        table_name: Table to read, optionally schema-qualified
            (``schema.table``). Only word characters are accepted.
        user_id: ID of the user whose connection is reused.
        db_type: ``postgresql`` or ``mysql`` (case-insensitive).
        limit: Row limit passed to the query tool.

    Returns:
        ``{'success': True, 'data': ..., 'db_type': ...}`` on success,
        otherwise ``{'success': False, 'message': ...}``. Errors are logged
        and reported in the result rather than raised.
    """
    import re  # local import: only needed for identifier validation

    try:
        user_id_str = str(user_id)

        normalized_type = db_type.lower()
        if normalized_type == 'postgresql':
            db_tool = self.postgresql_tool
        elif normalized_type == 'mysql':
            db_tool = self.mysql_tool
        else:
            return {
                'success': False,
                'message': f'不支持的数据库类型: {db_type}'
            }

        # The query reuses a connection opened by an earlier "connect" call.
        if user_id_str not in db_tool.connections:
            return {
                'success': False,
                'message': '数据库连接已断开,请重新连接数据库'
            }

        # The table name is interpolated into SQL, so accept only plain
        # (optionally schema-qualified) identifiers. This blocks injection
        # through quotes, whitespace or statement separators.
        if not re.fullmatch(r"\w+(?:\.\w+)?", table_name or ""):
            return {
                'success': False,
                'message': f'非法的表名: {table_name}'
            }

        sql_query = f"SELECT * FROM {table_name}"
        result = await db_tool.execute(
            operation="execute_query",
            user_id=user_id_str,
            sql_query=sql_query,
            limit=limit
        )

        if not result.success:
            return {
                'success': False,
                'message': result.error
            }

        return {
            'success': True,
            'data': result.result,
            'db_type': db_type
        }

    except Exception as e:
        # ``logger.exception`` attaches the traceback. Passing ``exc_info``
        # to loguru does not, and makes loguru treat the message as a format
        # string.
        logger.exception(f"获取表数据失败: {str(e)}")
        return {
            'success': False,
            'message': f'获取表数据失败: {str(e)}'
        }
'数据库连接已断开' + } + except Exception as e: + return { + 'success': False, + 'message': f'断开连接失败: {str(e)}' + } + + async def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]: + """根据数据库类型获取用户配置""" + stmt = select(DatabaseConfig).where( + DatabaseConfig.created_by == user_id, + DatabaseConfig.db_type == db_type, + DatabaseConfig.is_active == True + ) + return await self.session.scalar(stmt) + + async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig: + """创建或更新数据库配置(保证db_type唯一性)""" + try: + # 检查是否已存在该类型的配置 + existing_config = self.get_config_by_type(user_id, config_data['db_type']) + + if existing_config: + # 更新现有配置 + for key, value in config_data.items(): + if key == 'password': + setattr(existing_config, key, self._encrypt_password(value)) + elif hasattr(existing_config, key): + setattr(existing_config, key, value) + + await self.session.commit() + await self.session.refresh(existing_config) + logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})") + return existing_config + else: + # 创建新配置 + return await self.create_config(user_id, config_data) + + except Exception as e: + await self.session.rollback() + logger.error(f"创建或更新数据库配置失败: {str(e)}") + raise + + + async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]: + """获取表结构信息(复用已建立的连接)""" + try: + logger.error(f"未实现的逻辑,暂自编 - describe_table: {table_name}") + user_id_str = str(user_id) + + # 获取用户默认数据库配置 + default_config = self.get_default_config(user_id) + if not default_config: + return { + 'success': False, + 'message': '未找到默认数据库配置' + } + + # 根据db_type选择相应的数据库工具 + if default_config.db_type.lower() == 'postgresql': + db_tool = self.postgresql_tool + elif default_config.db_type.lower() == 'mysql': + db_tool = self.mysql_tool + else: + return { + 'success': False, + 'message': f'不支持的数据库类型: {default_config.db_type}' + } + + # 检查是否已有连接 + if user_id_str not in db_tool.connections: + return { + 
'success': False, + 'message': '数据库连接已断开,请重新连接数据库' + } + + # 使用已建立的连接执行describe_table操作 + result = await db_tool.execute( + operation="describe_table", + user_id=user_id_str, + table_name=table_name + ) + + if not result.success: + return { + 'success': False, + 'message': result.error + } + + return { + 'success': True, + 'data': result.result, + 'db_type': default_config.db_type + } + + except Exception as e: + logger.error(f"获取表结构失败: {str(e)}", exc_info=True) + return { + 'success': False, + 'message': f'获取表结构失败: {str(e)}' + } + \ No newline at end of file diff --git a/th_agenter/services/document.py b/th_agenter/services/document.py new file mode 100644 index 0000000..04c49f8 --- /dev/null +++ b/th_agenter/services/document.py @@ -0,0 +1,319 @@ +"""Document service.""" + +import os +from pathlib import Path +from typing import List, Optional, Dict, Any +from sqlalchemy import select, func +from sqlalchemy.orm import Session +from fastapi import UploadFile + +from ..models.knowledge_base import Document, KnowledgeBase +from ..core.config import get_settings +from utils.util_file import FileUtils +from .storage import storage_service +from .document_processor import get_document_processor +from utils.util_schemas import DocumentChunk +from loguru import logger + +settings = get_settings() + + +class DocumentService: + """Document service for managing documents in knowledge bases.""" + + def __init__(self, session: Session): + self.session = session + self.file_utils = FileUtils() + + async def upload_document(self, file: UploadFile, kb_id: int) -> Document: + """Upload a document to knowledge base.""" + self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}" + # Validate knowledge base exists + stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) + kb = await self.session.scalar(stmt) + if not kb: + self.session.desc = f"ERROR: 知识库 {kb_id} 不存在" + raise ValueError(f"知识库 {kb_id} 不存在") + + # Validate file + if not file.filename: + self.session.desc = 
f"ERROR: 上传文件时未提供文件名" + raise ValueError("No filename provided") + + # Validate file extension + file_extension = Path(file.filename).suffix.lower() + if file_extension not in settings.file.allowed_extensions: + self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}" + raise ValueError(f"非期望的文件类型 {file_extension}") + + # Upload file using storage service + storage_info = await storage_service.upload_file(file, kb_id) + self.session.desc = f"文档 {file.filename} 上传到 {storage_info}" + + # Create document record + document = Document( + knowledge_base_id=kb_id, + filename=os.path.basename(storage_info["file_path"]), + original_filename=file.filename, + file_path=storage_info.get("full_path", storage_info["file_path"]), # Use absolute path if available + file_size=storage_info["size"], + file_type=file_extension, + mime_type=storage_info["mime_type"], + is_processed=False + ) + + # Set audit fields + document.set_audit_fields() + + self.session.add(document) + await self.session.commit() + await self.session.refresh(document) + + self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})" + return document + + async def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]: + """根据文档ID查询文档,可选地根据知识库ID过滤。""" + self.session.desc = f"根据文档ID查询文档 {doc_id}" + stmt = select(Document).where(Document.id == doc_id) + if kb_id is not None: + stmt = stmt.where(Document.knowledge_base_id == kb_id) + return await self.session.scalar(stmt) + + async def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]: + """根据知识库ID查询文档,支持分页。""" + self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)" + stmt = ( + select(Document) + .where(Document.knowledge_base_id == kb_id) + .offset(skip) + .limit(limit) + ) + return (await self.session.scalars(stmt)).all() + + async def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]: + """根据知识库ID查询文档,支持分页,并返回总文档数。""" + 
async def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
    """
    Delete a document's stored file, its vectors and its database row.

    Args:
        doc_id: ID of the document to delete.
        kb_id: Optional knowledge base ID used to scope the lookup.

    Returns:
        True when the document was deleted, False when it does not exist
        (or does not belong to ``kb_id``).
    """
    self.session.desc = f"删除文档 {doc_id}"
    document = await self.get_document(doc_id, kb_id)
    if not document:
        self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
        return False

    # Best effort: a missing or undeletable file must not block removal of
    # the record, so the failure is only recorded on the session.
    try:
        await storage_service.delete_file(document.file_path)
        self.session.desc = f"SUCCESS: 删除文档 {doc_id} 关联文件 {document.file_path}"
    except Exception as e:
        self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"

    # Use the document's own knowledge base. The ``kb_id`` argument is
    # optional, and passing None made the processor look in "kb_None" and
    # leave the document's vectors behind.
    self.session.desc = f"从向量数据库删除文档 {doc_id}"
    document_processor = await get_document_processor(self.session)
    document_processor.delete_document_from_vector_store(document.knowledge_base_id, doc_id)

    self.session.desc = f"删除数据库记录 {doc_id}"
    await self.session.delete(document)
    await self.session.commit()
    self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
    return True
file_path = document.file_path + knowledge_base_id=document.knowledge_base_id + is_processed=document.is_processed + + if is_processed: + self.session.desc = f"INFO: 文档 {doc_id} 已处理" + return { + "document_id": doc_id, + "status": "already_processed", + "message": "文档已处理" + } + + self.session.desc = f"查询文档完毕 {doc_id} >>> is_processed = {is_processed}" + # 更新文档状态为处理中 + document.processing_error = None + await self.session.commit() + self.session.desc = f"更新文档状态为处理中 {doc_id}" + + # 调用文档处理器进行处理 + document_processor = await get_document_processor(self.session) + self.session.desc = f"调用文档处理器进行处理=== {doc_id} >>> {document_processor}" + result = await document_processor.process_document( + session=self.session, + document_id=doc_id, + file_path=file_path, + knowledge_base_id=knowledge_base_id + ) + self.session.desc = f"处理文档完毕 {doc_id}" + + # 如果处理成功,更新文档状态 + if result["status"] == "success": + document.is_processed = True + document.chunk_count = result.get("chunks_count", 0) + await self.session.commit() + await self.session.refresh(document) + logger.info(f"Processed document: {document.filename} (ID: {doc_id})") + + return result + + except Exception as e: + await self.session.rollback() + self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}" + + # Update document with error + try: + document = await self.get_document(doc_id, kb_id) + if document: + document.processing_error = str(e) + await self.session.commit() + except Exception as db_error: + logger.error(f"Failed to update document error status: {db_error}") + + return { + "document_id": doc_id, + "status": "failed", + "error": str(e), + "message": "文档处理失败" + } + + async def _extract_text(self, document: Document) -> str: + """从文档文件中提取文本内容。""" + try: + if document.is_text_file: + # Read text files directly + with open(document.file_path, 'r', encoding='utf-8') as f: + return f.read() + + elif document.is_pdf_file: + # TODO: Implement PDF text extraction using PyPDF2 or similar + # For now, return placeholder + 
async def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
    """
    Compute document statistics for a knowledge base.

    Args:
        kb_id: ID of the knowledge base.

    Returns:
        A dict with ``total_documents``, ``processed_documents``,
        ``pending_documents``, ``total_size_bytes``, ``total_size_mb``
        (rounded to two decimals) and ``file_types`` (count per file type).
    """
    # The first page also returns the real total. If the knowledge base
    # holds more than one page, refetch everything in a single query so the
    # statistics are not silently capped at the page size.
    first_page_size = 1000
    documents, total = await self.list_documents(kb_id, skip=0, limit=first_page_size)
    if total and total > len(documents):
        documents = await self.get_documents(kb_id, skip=0, limit=total)

    total_count = len(documents)
    processed_count = sum(1 for doc in documents if doc.is_processed)
    # A row without a recorded size counts as 0 instead of crashing ``sum``.
    total_size = sum(doc.file_size or 0 for doc in documents)

    file_types: Dict[str, int] = {}
    for doc in documents:
        file_types[doc.file_type] = file_types.get(doc.file_type, 0) + 1

    return {
        "total_documents": total_count,
        "processed_documents": processed_count,
        "pending_documents": total_count - processed_count,
        "total_size_bytes": total_size,
        "total_size_mb": round(total_size / (1024 * 1024), 2),
        "file_types": file_types
    }
Session +from sqlalchemy import text +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import ( + TextLoader, + PyPDFLoader, + Docx2txtLoader, + UnstructuredMarkdownLoader +) +import pdfplumber +from langchain_core.documents import Document +from langchain_postgres import PGVector +from typing import List + + +from ..core.config import BaseSettings, get_settings +from ..models.knowledge_base import Document as DocumentModel +from ..db.database import get_session +from loguru import logger + +settings = get_settings() +class DocumentProcessor: + """文档处理器,负责文档的加载、分段和向量化""" + + def __init__(self): + # 初始化语义分割器配置 + self.embeddings = None + self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=settings.file.chunk_size, + chunk_overlap=settings.file.chunk_overlap, + length_function=len, + separators=["\n\n", "\n", " ", ""] + ) + + async def initialize(self, session: Session = None): + # 初始化嵌入模型 - 根据配置选择提供商 + await self._init_embeddings(session) + + # 初始化连接池(仅对PGVector) + self.pgvector_pool = None + + # PostgreSQL pgvector连接配置 + # if settings.vector_db.type == "pgvector": + # # 新版本PGVector使用psycopg3连接字符串 + # # 对密码进行URL编码以处理特殊字符(如@符号) + # encoded_password = quote(settings.vector_db.pgvector_password, safe="") + # self.connection_string = ( + # f"postgresql+psycopg://{settings.vector_db.pgvector_user}:" + # f"{encoded_password}@" + # f"{settings.vector_db.pgvector_host}:" + # f"{settings.vector_db.pgvector_port}/" + # f"{settings.vector_db.pgvector_database}" + # ) + # # 初始化连接池 + # self.pgvector_pool = PGVectorConnectionPool() + # logger.info("新版本PGVector使用psycopg3连接字符串: %s", self.connection_string) + # else: + # 向量数据库存储路径(Chroma兼容) + vector_db_path = settings.vector_db.persist_directory + if not os.path.isabs(vector_db_path): + # 如果是相对路径,则基于项目根目录计算绝对路径 + # 项目根目录是backend的父目录 + backend_dir = Path(__file__).parent.parent.parent + 
async def _init_embeddings(self, session: Optional[Any] = None):
    """
    Create the embedding model once and cache it on ``self.embeddings``.

    The default embedding configuration is looked up through
    LLMConfigService and an embeddings client is built for its provider
    (``openai``, ``ollama`` or ``local``; other providers fall back to
    OpenAI embeddings).

    Args:
        session: Database session used to load the configuration. Progress
            messages are written to ``session.desc`` when it is provided.

    Returns:
        The cached or newly created embeddings object.

    Raises:
        HTTPException: 400 when no embedding configuration is found, which
            includes the case where no session was supplied.
        Exception: Any other error is logged and re-raised.
    """

    def _trace(message: str) -> None:
        # ``session`` is optional, so never dereference it unconditionally.
        if session is not None:
            session.desc = message

    try:
        if not self.embeddings:
            from .llm_config_service import LLMConfigService
            llm_config_service = LLMConfigService()

            config = None
            if session:
                config = await llm_config_service.get_default_embedding_config(session)
                if config:
                    _trace(f"获取默认嵌入模型配置: {config}")

            if not config:
                _trace("ERROR: 未找到嵌入模型配置")
                raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
            _trace(f"获取嵌入模型配置 > 结果:{config}")

            # The configuration is read by attribute (``config.provider``,
            # ``config.model_name``), so the dict-style ``config.get(...)``
            # and ``config['provider']`` lookups could not work on it.
            # NOTE(review): assumes the config object exposes ``api_key``;
            # confirm against LLMConfigService.
            model_name = getattr(config, "model_name", None)
            api_key = getattr(config, "api_key", None)

            if config.provider == "openai":
                from langchain_openai import OpenAIEmbeddings
                openai_model = model_name or "text-embedding-3-small"
                self.embeddings = OpenAIEmbeddings(
                    model=openai_model,
                    api_key=api_key
                )
                _trace(f"创建嵌入模型 - OpenAIEmbeddings(model={openai_model})")
            elif config.provider == "ollama":
                from langchain_ollama import OllamaEmbeddings
                self.embeddings = OllamaEmbeddings(
                    model=config.model_name,
                    base_url=config.base_url
                )
                _trace(f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})")
            elif config.provider == "local":
                from langchain_huggingface import HuggingFaceEmbeddings
                local_model = model_name or "sentence-transformers/all-MiniLM-L6-v2"
                self.embeddings = HuggingFaceEmbeddings(
                    model_name=local_model
                )
                _trace(f"创建嵌入模型 - HuggingFaceEmbeddings(model={local_model})")
            else:
                # Unknown provider: fall back to OpenAI and flag it.
                from langchain_openai import OpenAIEmbeddings
                self.embeddings = OpenAIEmbeddings(
                    model=model_name or "text-embedding-3-small",
                    api_key=api_key
                )
                _trace(f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效")

        return self.embeddings
    except Exception as e:
        logger.error(f"初始化嵌入模型时出错: {e}")
        raise
logger.error(f"使用pdfplumber加载PDF失败 {file_path}: {str(e)}") + # 如果pdfplumber失败,回退到PyPDFLoader + try: + loader = PyPDFLoader(file_path) + return loader.load() + except Exception as fallback_e: + logger.error(f"PyPDFLoader回退也失败 {file_path}: {str(fallback_e)}") + raise fallback_e + + def _merge_documents(self, documents: List[Document]) -> Document: + """将多个文档合并成一个文档""" + merged_text = "" + merged_metadata = {} + + for doc in documents: + if merged_text: + merged_text += "\n\n" + merged_text += doc.page_content + # 合并元数据 + merged_metadata.update(doc.metadata) + + return Document(page_content=merged_text, metadata=merged_metadata) + + def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]: + """根据语义分割点切分文本""" + chunks = [] + current_pos = 0 + + # 按顺序查找每个分割点并切分文本 + for point in split_points: + pos = text.find(point, current_pos) + if pos != -1: + # 添加当前位置到分割点位置的文本块 + if pos > current_pos: + chunk = text[current_pos:pos].strip() + if chunk: + chunks.append(chunk) + current_pos = pos + + # 添加最后一个文本块 + if current_pos < len(text): + chunk = text[current_pos:].strip() + if chunk: + chunks.append(chunk) + + return chunks + + async def split_documents(self, session: Session, documents: List[Document]) -> List[Document]: + """将文档分割成小块(含短段落合并和超长强制分割功能)""" + try: + chunks = self.text_splitter.split_documents(documents) + + session.desc = f"文档分割完成,共生成 {len(chunks)} 个文档块" + if len(chunks) > 0: + session.desc = f"文档块内容示例: {type(chunks[0])} - {chunks[0]}" + return chunks + + except Exception as e: + session.desc = f"ERROR: 文档分割失败: {str(e)}" + raise e + + def _force_split_long_chunk(self, chunk: str) -> List[str]: + """强制分割超长段落(超过1000字符)""" + max_length = 1000 + chunks = [] + + # 先尝试按换行符分割 + if '\n' in chunk: + lines = chunk.split('\n') + current_chunk = "" + for line in lines: + if len(current_chunk) + len(line) + 1 > max_length: + if current_chunk: + chunks.append(current_chunk) + current_chunk = line + else: + chunks.append(line[:max_length]) + 
current_chunk = line[max_length:] + else: + if current_chunk: + current_chunk += "\n" + line + else: + current_chunk = line + if current_chunk: + chunks.append(current_chunk) + else: + # 没有换行符则直接按长度分割 + chunks = [chunk[i:i + max_length] for i in range(0, len(chunk), max_length)] + + return chunks + + def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str: + """为知识库创建向量存储""" + try: + # if settings.vector_db.type == "pgvector": + # # 添加元数据 + # for i, doc in enumerate(documents): + # doc.metadata.update({ + # "knowledge_base_id": knowledge_base_id, + # "document_id": str(document_id) if document_id else "unknown", + # "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", + # "chunk_index": i + # }) + + # # 创建PostgreSQL pgvector存储 + # collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}" + + # # 创建新版本PGVector实例 + # vector_store = PGVector( + # connection=self.connection_string, + # embeddings=self.embeddings, + # collection_name=collection_name, + # use_jsonb=True # 使用JSONB存储元数据 + # ) + + # # 手动添加文档 + # vector_store.add_documents(documents) + + # logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}") + # return collection_name + # else: + # Chroma兼容模式 + from langchain_chroma import Chroma + kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}") + + # 添加元数据 + for i, doc in enumerate(documents): + doc.metadata.update({ + "knowledge_base_id": knowledge_base_id, + "document_id": str(document_id) if document_id else "unknown", + "chunk_id": f"{knowledge_base_id}_{document_id}_{i}", + "chunk_index": i + }) + + # 创建向量存储 + vector_store = Chroma.from_documents( + documents=documents, + embedding=self.embeddings, + persist_directory=kb_vector_path + ) + + logger.info(f"向量存储创建成功: {kb_vector_path}") + return kb_vector_path + + except Exception as e: + logger.error(f"创建向量存储失败: {str(e)}") + raise + + def add_documents_to_vector_store(self, session: Session, 
async def process_document(self, session: Session, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
    """Process one document: load it, split it, and add it to the vector store.

    The outcome is also written to the document's database row through a
    separate, short-lived DB session obtained from ``get_session()``. That
    session is kept in its own variable so that progress messages
    (``session.desc``) always go to the caller's ``session`` object and never
    to the already-closed DB session.

    Args:
        session: Caller-supplied session object used for progress messages.
        document_id: Primary key of the document row to update.
        file_path: Location of the file to load.
        knowledge_base_id: Knowledge base whose vector store receives the chunks.

    Returns:
        A dict with ``document_id``, ``status`` (``"success"`` or ``"failed"``),
        ``message`` and either ``chunks_count`` or ``error``. Failures are
        reported through the return value, not raised.
    """
    async def _store_outcome(**fields: Any) -> None:
        # Apply the given attribute values to the document row, if it exists.
        from sqlalchemy import select

        db_session = await anext(get_session())
        try:
            document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
            if document:
                for attr_name, attr_value in fields.items():
                    setattr(document, attr_name, attr_value)
                await db_session.commit()
        finally:
            await db_session.close()

    try:
        session.desc = f"处理文档 ID: {document_id} 文件路径: {file_path}"

        # 1. Load the document
        documents = self.load_document(session, file_path)

        # 2. Split it into chunks
        chunks = await self.split_documents(session, documents)

        # 3. Add the chunks to the vector store
        self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)

        # 4. Record success on the document row
        await _store_outcome(status="processed", chunk_count=len(chunks))

        result = {
            "document_id": document_id,
            "status": "success",
            "chunks_count": len(chunks),
            "message": "文档处理完成"
        }

        session.desc = f"文档处理完成: {result}"
        return result

    except Exception as e:
        session.desc = f"ERROR: 文档处理失败 ID: {document_id}: {str(e)}"

        # Best effort: mark the row as failed; a DB problem here must not mask
        # the original error that is reported to the caller below.
        try:
            await _store_outcome(status="failed", error_message=str(e))
        except Exception as db_error:
            session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"

        return {
            "document_id": document_id,
            "status": "failed",
            "error": str(e),
            "message": "文档处理失败"
        }
def get_document_chunks(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
    """Return all stored chunks of one document.

    The lookup is delegated to ``_get_chunks_chroma``. Any exception raised
    there is logged and swallowed, in which case an empty list is returned.

    Args:
        knowledge_base_id: Knowledge base the document belongs to.
        document_id: Document whose chunks are requested.

    Returns:
        The chunk dicts produced by ``_get_chunks_chroma``, or ``[]`` on error.
    """
    try:
        found = self._get_chunks_chroma(knowledge_base_id, document_id)
    except Exception as err:
        logger.error(f"获取文档分段失败 document_id: {document_id}, kb_id: {knowledge_base_id}: {str(err)}")
        found = []
    return found
def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
    """Read one document's chunks from the knowledge base's Chroma store.

    Only the entries whose ``document_id`` metadata equals
    ``str(document_id)`` are requested from the collection, in a single
    ``get`` call (previously the whole collection was fetched twice and
    filtered in Python).

    Args:
        knowledge_base_id: Selects the ``kb_<id>`` directory under
            ``self.vector_db_path``.
        document_id: Document whose chunks are wanted.

    Returns:
        One dict per chunk with ``id``, ``content``, ``metadata``,
        ``page_number``, ``chunk_index``, ``start_char``, ``end_char`` and
        ``vector_id``. ``chunk_index`` is the position in the returned result,
        not the value stored in the metadata. Returns ``[]`` when the store
        directory does not exist.
    """
    from langchain_chroma import Chroma

    # Locate the on-disk store of this knowledge base
    vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

    if not os.path.exists(vector_db_path):
        logger.warning(f"向量数据库不存在: {vector_db_path}")
        return []

    vectorstore = Chroma(
        persist_directory=vector_db_path,
        embedding_function=self.embeddings
    )

    # "ids" are part of every get() result, so they need no second query.
    # NOTE(review): result order is whatever Chroma returns — confirm whether
    # callers need the chunks sorted by the stored "chunk_index".
    matched = vectorstore._collection.get(
        where={"document_id": str(document_id)},
        include=["metadatas", "documents"]
    )

    rows = zip(matched["ids"], matched["documents"], matched["metadatas"])
    return [
        {
            "id": f"chunk_{document_id}_{chunk_index}",
            "content": chunk_content,
            "metadata": metadata,
            "page_number": metadata.get("page"),
            "chunk_index": chunk_index,
            "start_char": metadata.get("start_char"),
            "end_char": metadata.get("end_char"),
            "vector_id": vector_id
        }
        for chunk_index, (vector_id, chunk_content, metadata) in enumerate(rows)
    ]
# Module-wide document processor instance, created lazily on first use.
document_processor = None


async def get_document_processor(session: Session = None):
    """Return the shared ``DocumentProcessor``, creating it on first use.

    The new instance is stored in the module-level variable only after
    ``initialize`` has completed. If initialisation raises, nothing is cached
    and the next call tries again, instead of handing out an instance that
    was never initialised.

    Args:
        session: Optional session object; when given it receives a progress
            message and is passed on to ``DocumentProcessor.initialize``.

    Returns:
        The initialised shared processor.
    """
    global document_processor
    if session:
        session.desc = "获取文档处理器实例"
    if document_processor is None:
        processor = DocumentProcessor()
        await processor.initialize(session)
        # Publish only a fully initialised instance.
        document_processor = processor
    return document_processor
@staticmethod
async def create_embeddings(
    session: Session = None,
    provider: Optional[str] = None,
    model: Optional[str] = None,
    dimensions: Optional[int] = None
) -> Embeddings:
    """Create an embeddings instance for the requested provider.

    Values that are not passed explicitly fall back to the configuration:
    ``settings.embedding.provider``, the ``"model"`` entry of the current
    embedding config, and ``settings.vector_db.embedding_dimension``.

    Args:
        session: Optional session object. It is forwarded to
            ``settings.embedding.get_current_config`` and, when not ``None``,
            receives a progress message.
        provider: Embedding provider (openai, zhipu, deepseek, doubao,
            moonshot, sentence-transformers).
        model: Model name.
        dimensions: Embedding dimensions.

    Returns:
        Embeddings instance.

    Raises:
        ValueError: If the provider is not one of the supported names.
    """
    embedding_config = await settings.embedding.get_current_config(session)
    provider = provider or settings.embedding.provider
    model = model or embedding_config.get("model")
    dimensions = dimensions or settings.vector_db.embedding_dimension

    # ``session`` defaults to None, so the progress message must be optional.
    if session is not None:
        session.desc = f"创建嵌入模型: {provider}, {model}"

    if provider == "openai":
        return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)
    elif provider in ["zhipu", "deepseek", "doubao", "moonshot"]:
        return EmbeddingFactory._create_openai_compatible_embeddings(embedding_config, model, dimensions, provider)
    elif provider == "sentence-transformers":
        return EmbeddingFactory._create_huggingface_embeddings(model)
    else:
        raise ValueError(f"Unsupported embedding provider: {provider}")
def extract_file_metadata(self, file_path: str, original_filename: str,
                          user_id: int, file_size: int) -> Dict[str, Any]:
    """Extract per-sheet metadata from an Excel or CSV file.

    The file type is taken from the extension of ``original_filename``; a
    ``.csv`` file is treated as a single sheet named ``Sheet1``, anything else
    is read with ``pandas.read_excel`` (all sheets). Columns whose label
    starts with ``Unnamed`` are dropped before the metadata is collected.

    Args:
        file_path: Path of the file to read.
        original_filename: Name used only to determine the file extension.
        user_id: Not used by this method.
        file_size: Not used by this method.

    Returns:
        A dict with ``sheet_names``, ``default_sheet``, and per-sheet dicts
        ``columns_info``, ``preview_data`` (first 5 rows, cells as strings or
        ``None``), ``data_types``, ``total_rows``, ``total_columns``, plus
        ``is_processed`` and ``processing_error``. Any exception is logged
        and reported through ``is_processed=False`` / ``processing_error``
        rather than raised.
    """
    def _to_json_cell(value):
        # Missing values become None, date-like values a fixed timestamp
        # format, everything else its string form.
        if pd.isna(value):
            return None
        if hasattr(value, 'strftime'):
            return value.strftime('%Y-%m-%d %H:%M:%S')
        return value if isinstance(value, str) else str(value)

    try:
        file_extension = os.path.splitext(original_filename)[1].lower()

        if file_extension == '.csv':
            sheets_data = {'Sheet1': pd.read_csv(file_path)}
        else:
            sheets_data = pd.read_excel(file_path, sheet_name=None)

        sheet_names = list(sheets_data.keys())
        columns_info = {}
        preview_data = {}
        data_types = {}
        total_rows = {}
        total_columns = {}

        for sheet_name, df in sheets_data.items():
            # Column labels are not always strings (e.g. numeric headers);
            # the .str accessor only works on string labels, so compare on
            # their string form.
            unnamed_mask = df.columns.astype(str).str.contains('^Unnamed')
            df = df.loc[:, ~unnamed_mask]

            columns_info[sheet_name] = [str(col) for col in df.columns.tolist()]

            preview_data[sheet_name] = [
                [_to_json_cell(value) for value in row]
                for row in df.head(5).values
            ]

            # Keys are stringified so they match the names in columns_info.
            data_types[sheet_name] = {str(col): str(dtype) for col, dtype in df.dtypes.items()}

            total_rows[sheet_name] = len(df)
            total_columns[sheet_name] = len(df.columns)

        default_sheet = sheet_names[0] if sheet_names else None

        return {
            'sheet_names': sheet_names,
            'default_sheet': default_sheet,
            'columns_info': columns_info,
            'preview_data': preview_data,
            'data_types': data_types,
            'total_rows': total_rows,
            'total_columns': total_columns,
            'is_processed': True,
            'processing_error': None
        }

    except Exception as e:
        logger.error(f"Error extracting metadata from {file_path}: {str(e)}")
        return {
            'sheet_names': [],
            'default_sheet': None,
            'columns_info': {},
            'preview_data': {},
            'data_types': {},
            'total_rows': {},
            'total_columns': {},
            'is_processed': False,
            'processing_error': str(e)
        }
async def delete_file(self, file_id: int, user_id: int) -> bool:
    """Remove an Excel file from disk and delete its database record.

    The record is looked up with ``get_file_by_id`` for the given user. The
    physical file is removed first (when it exists), then the record is
    deleted and the transaction committed.

    Args:
        file_id: ID of the file record.
        user_id: User the record must belong to.

    Returns:
        ``True`` after a successful delete; ``False`` when the record is not
        found or when any step raises (the session is rolled back then).
    """
    try:
        record = await self.get_file_by_id(file_id, user_id)
        if not record:
            return False

        # Physical file first, but only if it is actually there
        stored_path_exists = os.path.exists(record.file_path)
        if stored_path_exists:
            os.remove(record.file_path)
            logger.info(f"Deleted physical file: {record.file_path}")

        # Then the database row
        await self.session.delete(record)
        await self.session.commit()
        logger.info(f"Deleted Excel file record with ID {file_id}")
    except Exception as err:
        logger.error(f"Error deleting file {file_id}: {str(err)}")
        await self.session.rollback()
        return False
    return True
async def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
    """Build a compact summary of a user's Excel files for use as LLM context.

    Files are selected by ``ExcelFile.created_by``, the same ownership column
    the other queries of this service use (this method previously filtered on
    ``ExcelFile.user_id``).

    Args:
        user_id: Owner of the files.

    Returns:
        One dict per file with ``file_id``, ``filename``, ``file_type``,
        ``sheets`` and ``upload_time`` (ISO string or ``None``). Any error is
        logged and results in an empty list.
    """
    try:
        stmt = select(ExcelFile).where(ExcelFile.created_by == user_id)
        result = await self.session.execute(stmt)
        files = result.scalars().all()

        # NOTE(review): ``upload_time`` is read here while the listing query in
        # this service orders by ``created_at`` — confirm the model exposes
        # ``upload_time``; a missing attribute would end up in the except
        # branch and silently yield an empty summary.
        return [
            {
                'file_id': excel_file.id,
                'filename': excel_file.original_filename,
                'file_type': excel_file.file_type,
                'sheets': excel_file.get_all_sheets_summary(),
                'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None
            }
            for excel_file in files
        ]

    except Exception as e:
        logger.error(f"Error getting file summary for user {user_id}: {str(e)}")
        return []
async def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
    """Create and persist a new knowledge base.

    The vector-store collection name is derived from the knowledge base name:
    lower-cased, with spaces and hyphens replaced by underscores, prefixed
    with ``kb_``.

    Args:
        kb_data: Values for the new knowledge base.

    Returns:
        The persisted and refreshed ``KnowledgeBase`` instance.

    Raises:
        Exception: Whatever the creation raised; the session is rolled back
            and the error logged before it is re-raised.
    """
    try:
        # Derive the collection name used in the vector database
        separators_to_underscore = str.maketrans({' ': '_', '-': '_'})
        normalized_name = kb_data.name.lower().translate(separators_to_underscore)
        collection_name = f"kb_{normalized_name}"

        new_kb = KnowledgeBase(
            name=kb_data.name,
            description=kb_data.description,
            embedding_model=kb_data.embedding_model,
            chunk_size=kb_data.chunk_size,
            chunk_overlap=kb_data.chunk_overlap,
            vector_db_type=settings.vector_db.type,
            collection_name=collection_name,
        )

        # Fill in the created_by / updated_by audit fields
        new_kb.set_audit_fields()

        db = self.session
        db.add(new_kb)
        await db.commit()
        await db.refresh(new_kb)

        db.desc = f"Created knowledge base: {new_kb.name} - collection_name = {collection_name}, embedding_model = {new_kb.embedding_model}"
        return new_kb

    except Exception as err:
        await self.session.rollback()
        logger.error(f"Failed to create knowledge base: {str(err)}")
        raise
async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
    """Run a vector-similarity search inside one knowledge base.

    The shared document processor performs the search; hits whose
    ``normalized_score`` is below ``similarity_threshold`` are discarded.

    Args:
        kb_id: ID of the knowledge base to search in.
        query: Search query.
        top_k: Number of candidates requested from the vector store.
        similarity_threshold: Minimum ``normalized_score`` a hit must reach.

    Returns:
        Dicts with ``content``, ``source``, ``score``, ``metadata``,
        ``document_id`` and ``chunk_id``. Any error is logged and results in
        an empty list.
    """
    try:
        logger.info(f"Searching in knowledge base {kb_id} for: {query}")

        processor = await get_document_processor(self.session)
        candidates = processor.search_similar_documents(
            knowledge_base_id=kb_id,
            query=query,
            k=top_k
        )

        # Keep only the hits that reach the threshold; the score is already normalized
        hits = [
            {
                "content": candidate.get('content', ''),
                "source": candidate.get('source', 'unknown'),
                "score": score,
                "metadata": candidate.get('metadata', {}),
                "document_id": candidate.get('document_id', 'unknown'),
                "chunk_id": candidate.get('chunk_id', 'unknown')
            }
            for candidate in candidates
            if (score := candidate.get('normalized_score', 0)) >= similarity_threshold
        ]

        logger.info(f"Found {len(hits)} relevant documents (threshold: {similarity_threshold})")
        return hits

    except Exception as err:
        logger.error(f"Search failed for knowledge base {kb_id}: {str(err)}")
        return []
async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
    """List the current user's knowledge bases, one page at a time.

    Rows are ordered by ``KnowledgeBase.id`` so that ``skip``/``limit`` walk
    through a stable sequence; without an ORDER BY the database is free to
    return rows in any order, which can make pages overlap or skip rows.

    Args:
        skip: Number of records to skip. Defaults to 0.
        limit: Maximum number of records to return. Defaults to 50.
        active_only: Return only knowledge bases with ``is_active`` set.
            Defaults to True.

    Returns:
        The knowledge bases created by the current user.
    """
    # The current user is a dict; its 'id' entry identifies the owner.
    current_user_id = UserContext.get_current_user()['id']
    stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == current_user_id)
    if active_only:
        stmt = stmt.where(KnowledgeBase.is_active == True)
    stmt = stmt.order_by(KnowledgeBase.id).offset(skip).limit(limit)
    return (await self.session.execute(stmt)).scalars().all()
async def delete_knowledge_base(self, kb_id: int) -> bool:
    """Delete a knowledge base record.

    Args:
        kb_id: ID of the knowledge base to delete.

    Returns:
        True if the record was deleted, False if no record has that ID.

    Raises:
        Exception: Whatever the delete or commit raised; the session is
            rolled back and the error logged before it is re-raised.
    """
    # NOTE(review): the lookup is by ID only and is not restricted to the
    # current user — confirm that callers enforce ownership.
    kb = await self.get_knowledge_base(kb_id)
    if not kb:
        return False

    # TODO: Clean up vector database collection
    # This should be implemented when vector database service is ready

    try:
        await self.session.delete(kb)
        await self.session.commit()
    except Exception as e:
        # Leave the session usable for the caller, as the create path does.
        await self.session.rollback()
        logger.error(f"Failed to delete knowledge base {kb_id}: {str(e)}")
        raise

    return True
async def _get_vector_store(self, knowledge_base_id: int) -> Optional[Chroma]:
    """Open the Chroma vector store of a knowledge base.

    The store lives in the ``kb_<id>`` directory below the shared document
    processor's ``vector_db_path``. This is a coroutine, so it must be awaited.

    Args:
        knowledge_base_id: ID of the knowledge base.

    Returns:
        The ``Chroma`` store, or ``None`` when the directory does not exist
        or opening it fails (the failure is logged).
    """
    try:
        import os

        processor = await get_document_processor(self.session)
        kb_vector_path = os.path.join(processor.vector_db_path, f"kb_{knowledge_base_id}")

        if not os.path.exists(kb_vector_path):
            logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
            return None

        return Chroma(
            persist_directory=kb_vector_path,
            embedding_function=self.embeddings
        )

    except Exception as e:
        logger.error(f"Failed to load vector store for KB {knowledge_base_id}: {str(e)}")
        return None
system_prompt), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{question}") + ]) + + # Create chain + def format_docs(docs): + return "\n\n".join(doc.page_content for doc in docs) + + rag_chain = ( + { + "context": retriever | format_docs, + "question": RunnablePassthrough(), + "chat_history": lambda x: conversation_history + } + | prompt + | self.llm + | StrOutputParser() + ) + + return rag_chain, retriever + + def _prepare_conversation_history(self, messages: List) -> List[Dict[str, str]]: + """Prepare conversation history for RAG chain.""" + history = [] + + for msg in messages[:-1]: # Exclude the last message (current user message) + if msg.role == MessageRole.USER: + history.append({"role": "human", "content": msg.content}) + elif msg.role == MessageRole.ASSISTANT: + history.append({"role": "assistant", "content": msg.content}) + + return history + + async def chat_with_knowledge_base( + self, + conversation_id: int, + message: str, + knowledge_base_id: int, + stream: bool = False, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None + ) -> ChatResponse: + """Chat with knowledge base using RAG.""" + + try: + # Get conversation and validate + conversation = await self.conversation_service.get_conversation(conversation_id) + if not conversation: + raise ChatServiceError("Conversation not found") + + # Get vector store + vector_store = self._get_vector_store(knowledge_base_id) + if not vector_store: + raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed") + + # Save user message + user_message = await self.conversation_service.add_message( + conversation_id=conversation_id, + content=message, + role=MessageRole.USER + ) + + # Get conversation history + messages = await self.conversation_service.get_conversation_messages(conversation_id) + conversation_history = self._prepare_conversation_history(messages) + + # Create RAG chain + rag_chain, retriever = self._create_rag_chain(vector_store, 
conversation_history) + + # Get relevant documents for context + relevant_docs = retriever.get_relevant_documents(message) + context_documents = [] + + for doc in relevant_docs: + context_documents.append({ + "content": doc.page_content[:500], # Limit content length + "metadata": doc.metadata, + "source": doc.metadata.get("filename", "unknown") + }) + + # Generate response + if stream: + # For streaming, we'll use a different approach + response_content = await self._generate_streaming_response( + rag_chain, message, conversation_id + ) + else: + response_content = await asyncio.to_thread(rag_chain.invoke, message) + + # Save assistant message with context + assistant_message = await self.conversation_service.add_message( + conversation_id=conversation_id, + content=response_content, + role=MessageRole.ASSISTANT, + context_documents=context_documents + ) + + # Create response + return ChatResponse( + user_message=MessageResponse.from_orm(user_message), + assistant_message=MessageResponse.from_orm(assistant_message), + model_used=self.llm.model_name, + total_tokens=None # TODO: Calculate tokens if needed + ) + + except Exception as e: + logger.error(f"Knowledge base chat failed: {str(e)}") + raise ChatServiceError(f"Knowledge base chat failed: {str(e)}") + + async def _generate_streaming_response( + self, + rag_chain, + message: str, + conversation_id: int + ) -> str: + """Generate streaming response (placeholder for now).""" + # For now, use non-streaming approach + # TODO: Implement proper streaming with RAG chain + return await asyncio.to_thread(rag_chain.invoke, message) + + async def chat_stream_with_knowledge_base( + self, + conversation_id: int, + message: str, + knowledge_base_id: int, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None + ) -> AsyncGenerator[str, None]: + """Chat with knowledge base using RAG with streaming response.""" + + try: + + # Get vector store + vector_store = self._get_vector_store(knowledge_base_id) + if not 
vector_store: + raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed") + + # Get conversation history + messages = await self.conversation_service.get_conversation_messages(conversation_id) + conversation_history = self._prepare_conversation_history(messages) + + # Create RAG chain + rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history) + + # Save user message + user_message = await self.conversation_service.add_message( + conversation_id=conversation_id, + content=message, + role=MessageRole.USER + ) + + # Get relevant documents + relevant_docs = retriever.get_relevant_documents(message) + context = "\n\n".join([doc.page_content for doc in relevant_docs]) + + # Create streaming LLM + llm_config = await settings.llm.get_current_config() + streaming_llm = ChatOpenAI( + model=llm_config["model"], + temperature=temperature or llm_config["temperature"], + max_tokens=max_tokens or llm_config["max_tokens"], + streaming=True, + api_key=llm_config["api_key"], + base_url=llm_config["base_url"] + ) + + # Create prompt for streaming + prompt = ChatPromptTemplate.from_messages([ + ("system", "你是一个智能助手。请基于以下上下文信息回答用户的问题。如果上下文中没有相关信息,请诚实地说明。\n\n上下文信息:\n{context}"), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{question}") + ]) + + # Prepare chat history for prompt + chat_history_messages = [] + for hist in conversation_history: + if hist["role"] == "human": + chat_history_messages.append(HumanMessage(content=hist["content"])) + elif hist["role"] == "assistant": + chat_history_messages.append(AIMessage(content=hist["content"])) + + # Create streaming chain + streaming_chain = ( + { + "context": lambda x: context, + "chat_history": lambda x: chat_history_messages, + "question": lambda x: x["question"] + } + | prompt + | streaming_llm + | StrOutputParser() + ) + + # Generate streaming response + full_response = "" + async for chunk in streaming_chain.astream({"question": message}): + if chunk: + 
full_response += chunk + yield chunk + + # Save assistant response + if full_response: + await self.conversation_service.add_message( + conversation_id=conversation_id, + content=full_response, + role=MessageRole.ASSISTANT, + message_metadata={ + "knowledge_base_id": knowledge_base_id, + "relevant_docs_count": len(relevant_docs) + } + ) + + except Exception as e: + logger.error(f"Error in knowledge base streaming chat: {str(e)}") + error_message = f"知识库对话出错: {str(e)}" + yield error_message + + # Save error message + await self.conversation_service.add_message( + conversation_id=conversation_id, + content=error_message, + role=MessageRole.ASSISTANT + ) + + async def search_knowledge_base( + self, + knowledge_base_id: int, + query: str, + k: int = 5 + ) -> List[Dict[str, Any]]: + """Search knowledge base for relevant documents.""" + + try: + vector_store = self._get_vector_store(knowledge_base_id) + if not vector_store: + return [] + + # Perform similarity search + results = vector_store.similarity_search_with_score(query, k=k) + + formatted_results = [] + for doc, score in results: + formatted_results.append({ + "content": doc.page_content, + "metadata": doc.metadata, + "similarity_score": float(score), + "source": doc.metadata.get("filename", "unknown") + }) + + return formatted_results + + except Exception as e: + logger.error(f"Knowledge base search failed: {str(e)}") + return [] \ No newline at end of file diff --git a/th_agenter/services/langchain_chat.py b/th_agenter/services/langchain_chat.py new file mode 100644 index 0000000..8724372 --- /dev/null +++ b/th_agenter/services/langchain_chat.py @@ -0,0 +1,397 @@ +"""LangChain-based chat service.""" + +import json +import asyncio +import os +from typing import AsyncGenerator, Optional, List, Dict, Any +from sqlalchemy.orm import Session + +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage, AIMessage, SystemMessage +from langchain_core.callbacks import BaseCallbackHandler 
class StreamingCallbackHandler(BaseCallbackHandler):
    """Callback handler that accumulates streamed tokens."""

    def __init__(self):
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Record one token emitted by the LLM."""
        self.tokens.append(token)

    def get_response(self) -> str:
        """Return all recorded tokens joined into one string."""
        return "".join(self.tokens)

    def clear(self):
        """Forget all recorded tokens."""
        self.tokens = []


class LangChainChatService:
    """LangChain-based chat service for AI model integration.

    ``initialize()`` must be awaited before ``chat`` / ``chat_stream``,
    because it creates ``self.llm``, ``self.streaming_llm`` and
    ``self.streaming_handler``.
    """

    def __init__(self, session: Session):
        """Bind the service to a database session."""
        self.session = session
        self.conversation_service = ConversationService(session)

    async def initialize(self):
        """Create the blocking and streaming LLM instances via ``new_agent``."""
        from ..core.new_agent import new_agent

        self.llm = await new_agent(self.session, streaming=False)
        self.session.desc = "LangChainChatService初始化 - llm 实例化完毕"

        # Streaming LLM for stream responses
        self.streaming_llm = await new_agent(self.session, streaming=True)
        self.session.desc = "LangChainChatService初始化 - streaming_llm 实例化完毕"

        self.streaming_handler = StreamingCallbackHandler()
        self.session.desc = "LangChainChatService初始化 - streaming_handler 实例化完毕"

    def _prepare_langchain_messages(self, conversation, history: List) -> List:
        """Build the LangChain message list for one request.

        Args:
            conversation: Conversation object, or a dict snapshot of it.  The
                system prompt is read from attribute or key ``system_prompt``.
            history: Stored messages, the current user message last.

        Returns:
            System message, then prior USER/ASSISTANT turns, then the current
            user message (only if the last history entry has role USER).
        """
        # A dict snapshot has keys, not attributes; a hasattr() check alone
        # would silently discard its system prompt.
        if isinstance(conversation, dict):
            system_prompt = conversation.get("system_prompt")
        else:
            system_prompt = getattr(conversation, "system_prompt", None)

        messages = [
            SystemMessage(
                content=system_prompt
                or "You are a helpful AI assistant. Please provide accurate and helpful responses."
            )
        ]

        # Add conversation history
        for msg in history[:-1]:  # Exclude the last message (current user message)
            if msg.role == MessageRole.USER:
                messages.append(HumanMessage(content=msg.content))
            elif msg.role == MessageRole.ASSISTANT:
                messages.append(AIMessage(content=msg.content))

        # Add current user message
        if history:
            last_msg = history[-1]
            if last_msg.role == MessageRole.USER:
                messages.append(HumanMessage(content=last_msg.content))

        return messages

    async def chat(
        self,
        conversation_id: int,
        message: str,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> ChatResponse:
        """Send a message and get the AI response using LangChain.

        Args:
            conversation_id: Conversation the messages are stored under.
            message: Current user message.
            stream: Accepted for interface compatibility; not used here.
            temperature: Optional override for this call.
            max_tokens: Optional override for this call.

        Returns:
            ``ChatResponse`` with the stored user and assistant messages.  On
            an unclassified LLM failure the assistant message carries the
            formatted error text.

        Raises:
            RateLimitError, AuthenticationError, OpenAIError: When the failure
                text matches the respective category.
            ChatServiceError: When the request fails before the user message
                could be stored (e.g. unknown conversation).
        """
        logger.info(f"Processing LangChain chat request for conversation {conversation_id}")

        # Bound up front so the error path can tell whether it was stored.
        user_message = None
        try:
            # Get conversation details
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")

            # Add user message to database
            user_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )

            # Get conversation history for context
            history = await self.conversation_service.get_conversation_history(
                conversation_id, limit=20
            )

            # Prepare messages for LangChain
            langchain_messages = self._prepare_langchain_messages(conversation, history)

            # Update LLM parameters if provided
            llm_to_use = self.llm
            if temperature is not None or max_tokens is not None:
                llm_config = await settings.llm.get_current_config()
                llm_to_use = ChatOpenAI(
                    model=llm_config["model"],
                    openai_api_key=llm_config["api_key"],
                    openai_api_base=llm_config["base_url"],
                    temperature=temperature if temperature is not None else float(conversation.temperature),
                    max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
                    streaming=False
                )

            # Call LangChain LLM
            response = await llm_to_use.ainvoke(langchain_messages)
            assistant_content = response.content

            # Add assistant message to database
            assistant_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=assistant_content,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "model": llm_to_use.model_name,
                    "langchain_version": "0.1.0",
                    "provider": "langchain_openai"
                }
            )

            # Update conversation timestamp
            await self.conversation_service.update_conversation_timestamp(conversation_id)

            logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")

            return ChatResponse(
                user_message=MessageResponse.from_orm(user_message),
                assistant_message=MessageResponse.from_orm(assistant_message),
                total_tokens=None,  # LangChain doesn't provide token count by default
                model_used=llm_to_use.model_name
            )

        except Exception as e:
            # loguru formats the message with str.format() whenever keyword
            # arguments are passed, so `exc_info=True` could raise on error
            # texts containing braces; opt(exception=...) attaches the traceback.
            logger.opt(exception=e).error(
                f"Failed to process LangChain chat request for conversation {conversation_id}: {str(e)}"
            )

            if user_message is None:
                # Failed before the user message was stored: there is nothing
                # to build a ChatResponse from, so surface the failure.
                if isinstance(e, ChatServiceError):
                    raise
                raise ChatServiceError(str(e)) from e

            # Classify error types for better handling
            error_type = type(e).__name__
            error_message = self._format_error_message(e)

            # Add error message to database
            assistant_message = await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "error": True,
                    "error_type": error_type,
                    "original_error": str(e),
                    "langchain_error": True
                }
            )

            # Re-raise specific exceptions for proper error handling
            lowered = str(e).lower()
            if "rate limit" in lowered:
                raise RateLimitError(str(e)) from e
            elif "api key" in lowered or "authentication" in lowered:
                raise AuthenticationError(str(e)) from e
            elif "openai" in lowered:
                raise OpenAIError(str(e)) from e

            return ChatResponse(
                user_message=MessageResponse.from_orm(user_message),
                assistant_message=MessageResponse.from_orm(assistant_message),
                total_tokens=0,
                model_used=self.llm.model_name
            )

    async def chat_stream(
        self,
        conversation_id: int,
        message: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Send a message and stream the AI response using LangChain.

        Args:
            conversation_id: Conversation the messages are stored under.
            message: Current user message.
            temperature: Optional override for this call.
            max_tokens: Optional override for this call.

        Yields:
            Text chunks of the answer.  Failures are yielded as formatted
            error text and stored as the assistant message instead of raised.
        """
        logger.info(f"通过 LangChain 进行流式处理对话 请求,会话 ID: {conversation_id}")

        try:
            # Get conversation details; validate BEFORE touching the object so
            # a missing conversation yields ChatServiceError, not AttributeError.
            conversation = await self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")
            # Snapshot taken before further DB writes, as in the original flow.
            # NOTE(review): presumably to_dict() includes "system_prompt" — confirm.
            conv = conversation.to_dict()

            # Add user message to database
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )

            # Get conversation history for context
            history = await self.conversation_service.get_conversation_history(
                conversation_id, limit=20
            )

            # Prepare messages for LangChain
            langchain_messages = self._prepare_langchain_messages(conv, history)

            # Update streaming LLM parameters if provided
            streaming_llm_to_use = self.streaming_llm
            if temperature is not None or max_tokens is not None:
                llm_config = await settings.llm.get_current_config()
                streaming_llm_to_use = ChatOpenAI(
                    model=llm_config["model"],
                    openai_api_key=llm_config["api_key"],
                    openai_api_base=llm_config["base_url"],
                    temperature=temperature if temperature is not None else float(conversation.temperature),
                    max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
                    streaming=True
                )

            # Clear previous streaming handler state
            self.streaming_handler.clear()

            # Stream response
            full_response = ""
            try:
                # Public astream() yields message chunks exposing `.content`;
                # the private _astream() yields generation chunks that do not.
                async for chunk in streaming_llm_to_use.astream(langchain_messages):
                    chunk_content = None
                    if hasattr(chunk, 'content'):
                        # Object-like chunks with a content attribute
                        chunk_content = chunk.content
                    elif isinstance(chunk, dict) and 'content' in chunk:
                        # Dict-like chunks with a content key
                        chunk_content = chunk['content']
                    elif isinstance(chunk, dict) and 'error' in chunk:
                        # Error chunks are reported to the client explicitly
                        logger.error(f"Error in LLM response: {chunk['error']}")
                        yield self._format_error_message(Exception(chunk['error']))
                        continue

                    if chunk_content:
                        full_response += chunk_content
                        yield chunk_content
            except Exception as e:
                logger.error(f"Error in LLM streaming: {e}")
                yield f"{self._format_error_message(e)} >>> {e}"

            # Add complete assistant message to database
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=full_response,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "model": streaming_llm_to_use.model_name,
                    "langchain_version": "0.1.0",
                    "provider": "langchain_openai",
                    "streaming": True
                }
            )

            # Update conversation timestamp
            await self.conversation_service.update_conversation_timestamp(conversation_id)
            logger.info(f"完成 LangChain 流式处理对话,会话 ID: {conversation_id}")

        except Exception as e:
            # No keyword arguments are passed to loguru here, so braces in the
            # error text cannot trigger str.format().
            error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id} >>> {e}"
            logger.opt(exception=e).error(error_info)

            # Format error message for user
            error_message = self._format_error_message(e)
            yield error_message

            # Add error message to database
            await self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "error": True,
                    "error_type": type(e).__name__,
                    "original_error": str(e),
                    "langchain_error": True,
                    "streaming": True
                }
            )

    async def get_available_models(self) -> List[str]:
        """Return a static list of commonly available OpenAI model names."""
        try:
            # LangChain doesn't have a direct method to list models
            return [
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-4",
                "gpt-4-turbo-preview",
                "gpt-4o",
                "gpt-4o-mini"
            ]
        except Exception as e:
            logger.error(f"Failed to get available models: {str(e)}")
            return ["gpt-3.5-turbo"]

    async def update_model_config(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ):
        """Recreate both LLM instances with a new model/temperature.

        Args:
            model: Model name forwarded to ``new_agent``.
            temperature: Temperature forwarded to ``new_agent``.
            max_tokens: Only logged; it is not forwarded to ``new_agent``.
        """
        from ..core.new_agent import new_agent

        # The session is passed as in initialize(), so both code paths call
        # new_agent the same way.
        # NOTE(review): assumes new_agent accepts `model`/`temperature`
        # keywords next to the session — confirm against its signature.
        self.llm = await new_agent(
            self.session,
            model=model,
            temperature=temperature,
            streaming=False
        )

        self.streaming_llm = await new_agent(
            self.session,
            model=model,
            temperature=temperature,
            streaming=True
        )

        logger.info(f"Updated LLM configuration: model={model}, temperature={temperature}, max_tokens={max_tokens}")

    def _format_error_message(self, error: Exception) -> str:
        """Map an exception to a user-facing message by keyword matching."""
        error_str = str(error)
        lowered = error_str.lower()

        if "rate limit" in lowered:
            return "服务器繁忙,请稍后再试。"
        elif "api key" in lowered or "authentication" in lowered:
            return "API认证失败,请检查配置文件。"
        elif "timeout" in lowered:
            return "请求超时,请重试。"
        elif "connection" in lowered:
            return "网络连接错误,请检查网络连接。"
        elif "model" in lowered and "not found" in lowered:
            return "指定的模型不可用,请选择其他模型。"
        else:
            return f"处理请求时发生错误:{error_str}"

    async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
        """Await ``func()`` with exponential backoff on retryable errors.

        Args:
            func: Zero-argument callable returning an awaitable.
            max_retries: Maximum number of attempts.
            base_delay: Delay before the first retry; doubled on each retry.

        Returns:
            The result of the first successful attempt.

        Raises:
            Exception: The last error, or the first non-retryable one.
        """
        for attempt in range(max_retries):
            try:
                return await func()
            except Exception as e:
                # Out of attempts, or not worth retrying: propagate with the
                # original traceback.
                if attempt == max_retries - 1 or not self._is_retryable_error(e):
                    raise

                delay = base_delay * (2 ** attempt)
                logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
                await asyncio.sleep(delay)

    def _is_retryable_error(self, error: Exception) -> bool:
        """Return True when the error text names a transient failure."""
        error_str = str(error).lower()
        retryable_errors = [
            "timeout",
            "connection",
            "server error",
            "internal error",
            "rate limit"
        ]
        return any(err in error_str for err in retryable_errors)
class LLMConfigService:
    """LLM configuration service that reads configurations from the database."""

    def __init__(self, db: Optional[Session] = None):
        """Store the session used by the methods that take no session argument.

        Args:
            db: Session read by ``get_config_by_id`` and ``get_active_configs``
                through ``self.db``.  Optional, so ``LLMConfigService()``
                keeps working; those two methods then log an error and return
                ``None`` / ``[]`` as they did when ``self.db`` was undefined.
        """
        self.db = db

    async def get_default_chat_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the active default chat (non-embedding) configuration.

        Returns:
            The configuration, or ``None`` when none exists or the query
            fails (both cases are logged).
        """
        try:
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,  # noqa: E712 - SQL expression
                    LLMConfig.is_embedding == False,  # noqa: E712
                    LLMConfig.is_active == True  # noqa: E712
                )
            )
            config = (await session.execute(stmt)).scalar_one_or_none()

            if not config:
                logger.warning("未找到默认对话模型配置")
                return None

            return config

        except Exception as e:
            logger.error(f"获取默认对话模型配置失败: {str(e)}")
            return None

    async def get_default_embedding_config(self, session: Session) -> Optional[LLMConfig]:
        """Return the active default embedding configuration.

        Progress and errors are written to ``session.desc``.  A ``None``
        session is tolerated and always yields ``None``.
        """
        try:
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,  # noqa: E712 - SQL expression
                    LLMConfig.is_embedding == True,  # noqa: E712
                    LLMConfig.is_active == True  # noqa: E712
                )
            )
            if session is None:
                # Nothing to query and nowhere to report to.
                return None

            config = (await session.execute(stmt)).scalar_one_or_none()
            if not config:
                session.desc = "ERROR: 未找到默认嵌入模型配置"
                return None

            session.desc = f"获取默认嵌入模型配置 > 结果:{config}"
            return config

        except Exception as e:
            if session is not None:
                session.desc = f"ERROR: 获取默认嵌入模型配置失败: {str(e)}"
            return None

    async def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
        """Return the configuration with the given ID via ``self.db``.

        Returns:
            The configuration, or ``None`` when it does not exist or the
            query fails (the failure is logged).
        """
        try:
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            return (await self.db.execute(stmt)).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(self, is_embedding: Optional[bool] = None) -> List[LLMConfig]:
        """Return all active configurations ordered by creation time.

        Args:
            is_embedding: When given, restricts the result to embedding
                (``True``) or non-embedding (``False``) configurations.

        Returns:
            The configurations, or ``[]`` when the query fails.
        """
        try:
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)  # noqa: E712

            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)

            stmt = stmt.order_by(LLMConfig.created_at)
            # NOTE(review): execute() is not awaited here although the other
            # methods await it; this only works if self.db is a synchronous
            # session — confirm which kind is injected.
            return self.db.execute(stmt).scalars().all()

        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    async def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """Return the chat configuration obtained from ``settings.llm``."""
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.llm.get_current_config()

    async def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """Return the embedding configuration obtained from ``settings.embedding``."""
        from ..core.config import get_settings
        settings = get_settings()
        return await settings.embedding.get_current_config()

    def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]:
        """Test a configuration (placeholder: no real connection test yet).

        Args:
            config_id: ID of the configuration to test.
            test_message: Reserved for the future connection test; unused.

        Returns:
            Dict with ``success`` and either ``message`` or ``error``.
        """
        import inspect

        try:
            config = self.get_config_by_id(config_id)
            if inspect.iscoroutine(config):
                # NOTE(review): get_config_by_id is a coroutine function and
                # this method is synchronous, so the lookup cannot be awaited
                # here and the existence check is effectively skipped.  The
                # coroutine is closed to avoid a "never awaited" warning;
                # making this method async would fix it but changes the
                # interface for callers.
                config.close()
            elif not config:
                return {"success": False, "error": "配置不存在"}

            # The real connection test goes here, e.g. sending one simple
            # request to verify that the configuration works.

            return {"success": True, "message": "配置测试成功"}

        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}

# # Global instance
# _llm_config_service = None

# def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
#     """Return the LLM configuration service instance."""
#     global _llm_config_service
#     if _llm_config_service is None or db_session is not None:
#         _llm_config_service = LLMConfigService(db_session)
#     return _llm_config_service
class LLMService:
    """LLM service used for model calls inside workflows."""

    def __init__(self):
        pass

    @staticmethod
    def _resolve(override, default):
        """Return ``override`` unless it is ``None``.

        An ``or`` fallback would also replace an explicit ``0`` / ``0.0``
        (for example ``temperature=0``) with the configured default.
        """
        return default if override is None else override

    @staticmethod
    def _to_langchain_messages(messages: List[Dict[str, str]]) -> list:
        """Convert ``{"role", "content"}`` dicts into LangChain messages.

        A missing role counts as ``"user"`` and missing content as ``""``.
        Entries whose role is not system/user/assistant are dropped.
        """
        converted = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                converted.append(SystemMessage(content=content))
            elif role == "user":
                converted.append(HumanMessage(content=content))
            elif role == "assistant":
                converted.append(AIMessage(content=content))
        return converted

    def _build_llm(
        self,
        model_config: LLMConfig,
        temperature: Optional[float],
        max_tokens: Optional[int],
        streaming: bool
    ):
        """Create a ``ChatOpenAI`` client from a config plus per-call overrides."""
        return ChatOpenAI(
            model=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=self._resolve(temperature, model_config.temperature),
            max_tokens=self._resolve(max_tokens, model_config.max_tokens),
            streaming=streaming
        )

    async def chat_completion(
        self,
        model_config: LLMConfig,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Run one chat completion and return the response text.

        Args:
            model_config: Configuration supplying model, key, URL and defaults.
            messages: Conversation as ``{"role", "content"}`` dicts.
            temperature: Optional override; ``0`` is honoured.
            max_tokens: Optional override of the configured limit.

        Returns:
            The content of the model response.

        Raises:
            Exception: Wraps any failure; the original error is chained.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=False)
            response = await llm.ainvoke(self._to_langchain_messages(messages))
            return response.content

        except Exception as e:
            logger.error(f"LLM调用失败: {str(e)}")
            raise Exception(f"LLM调用失败: {str(e)}") from e

    async def chat_completion_stream(
        self,
        model_config: LLMConfig,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Run one chat completion and yield the response incrementally.

        Args:
            model_config: Configuration supplying model, key, URL and defaults.
            messages: Conversation as ``{"role", "content"}`` dicts.
            temperature: Optional override; ``0`` is honoured.
            max_tokens: Optional override of the configured limit.

        Yields:
            Non-empty content pieces of the streamed response.

        Raises:
            Exception: Wraps any failure; the original error is chained.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=True)

            async for chunk in llm.astream(self._to_langchain_messages(messages)):
                if hasattr(chunk, 'content') and chunk.content:
                    yield chunk.content

        except Exception as e:
            logger.error(f"LLM流式调用失败: {str(e)}")
            raise Exception(f"LLM流式调用失败: {str(e)}") from e

    def get_model_info(self, model_config: LLMConfig) -> Dict[str, Any]:
        """Return the displayable fields of a model configuration as a dict."""
        return {
            "id": model_config.id,
            "name": model_config.model_name,
            "provider": model_config.provider,
            "base_url": model_config.base_url,
            "temperature": model_config.temperature,
            "max_tokens": model_config.max_tokens,
            "is_active": model_config.is_active
        }
# Map MCP parameter types to Python type hints.
# "number" is accepted next to "float" because it is the JSON Schema name
# for floating-point values.
_TYPE_MAP: Dict[str, Any] = {
    "string": str,
    "integer": int,
    "float": float,
    "number": float,
    "boolean": bool,
    "array": List[Any],
    "object": Dict[str, Any],
}


def _build_args_schema(params: List[Dict[str, Any]]) -> Type[BaseModel]:
    """Build a Pydantic model class from MCP tool parameter descriptions.

    Args:
        params: Parameter dicts with ``name`` and optional ``type``
            (default ``"string"``), ``required`` (default ``True``),
            ``default`` and ``description``.  Entries without a name are
            skipped; unknown types map to ``Any``.

    Returns:
        A ``BaseModel`` subclass named ``MCPToolArgs`` with one field per
        parameter.  A parameter is a required field only when it is marked
        required and has no default.
    """
    annotations: Dict[str, Any] = {}
    fields: Dict[str, Any] = {}

    for p in params:
        name = p.get("name")
        if not name:
            # A parameter without a name cannot become a model field.
            continue

        default = p.get("default", None)
        py_type = _TYPE_MAP.get(p.get("type", "string"), Any)
        is_required = p.get("required", True) and default is None

        # A field whose default is None must also accept None as a value.
        nullable = not is_required and default is None
        annotations[name] = Optional[py_type] if nullable else py_type
        fields[name] = Field(
            default=... if is_required else default,
            description=p.get("description", ""),
        )

    # Create model class
    namespace = {"__annotations__": annotations}
    namespace.update(fields)
    return type("MCPToolArgs", (BaseModel,), namespace)


class MCPDynamicTool(BaseTool):
    """LangChain BaseTool wrapper that executes MCP tools via HTTP."""

    name: str
    description: str
    args_schema: Type[BaseModel]

    _mcp_base_url: str = PrivateAttr()
    _tool_name: str = PrivateAttr()

    def __init__(self, mcp_base_url: str, tool_info: Dict[str, Any]):
        """Create the tool from one MCP tool description.

        Args:
            mcp_base_url: Base URL of the MCP server; a trailing slash is removed.
            tool_info: Dict with ``name``, ``description`` and ``parameters``.
        """
        # Resolve the name once so the LangChain name and the name sent to
        # the MCP server always agree, even when "name" is missing.
        tool_name = tool_info.get("name", "tool")
        super().__init__(
            name=tool_name,
            description=tool_info.get("description", ""),
            args_schema=_build_args_schema(tool_info.get("parameters", [])),
        )
        # set private attrs after BaseTool init to avoid pydantic stripping
        self._mcp_base_url = mcp_base_url.rstrip("/")
        self._tool_name = tool_name

    def _execute(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST the call to ``<base_url>/execute`` and return the JSON reply.

        Failures are not raised: they are logged and returned as a dict with
        ``success=False`` and the error text.
        """
        url = f"{self._mcp_base_url}/execute"
        payload = {
            "tool_name": self._tool_name,
            "parameters": params,
        }
        logger.info(f"调用 MCP 工具: {self._tool_name} 参数: {params}")
        try:
            resp = requests.post(url, json=payload, timeout=30)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            logger.error(f"MCP 工具调用失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "result": None,
                "tool_name": self._tool_name,
            }

    def _run(self, **kwargs: Any) -> str:
        """Execute the tool synchronously and return a JSON string."""
        data = self._execute(kwargs)
        if not isinstance(data, dict):
            return json.dumps({"success": False, "error": "Invalid MCP response"}, ensure_ascii=False)
        # Return string content; LangChain expects textual content for ToolMessage
        if data.get("success"):
            return json.dumps(data.get("result", {}), ensure_ascii=False)
        return json.dumps({"error": data.get("error")}, ensure_ascii=False)

    async def _arun(self, **kwargs: Any) -> str:
        """Execute the tool without blocking the event loop.

        ``_run`` performs a blocking HTTP request, so it is run in a worker
        thread instead of being called directly from the coroutine.
        """
        import asyncio

        return await asyncio.to_thread(self._run, **kwargs)


def load_mcp_tools(include: Optional[List[str]] = None) -> List[MCPDynamicTool]:
    """Load MCP tools from the MCP server and construct dynamic tools.

    Args:
        include: Optional list of tool names to include
            (e.g., ["weather", "search"]); all tools when empty or ``None``.

    Returns:
        The tools that could be built.  An unreachable server or an
        unexpected payload yields ``[]``; a tool that fails to build is
        skipped with a warning.
    """
    settings = get_settings()
    # Try settings.tool.mcp_server_url, then MCP_SERVER_URL, then the default
    tool_settings = getattr(settings, "tool", None)
    mcp_base_url = (
        getattr(tool_settings, "mcp_server_url", None)
        or os.getenv("MCP_SERVER_URL")
        or "http://127.0.0.1:8001"
    )

    url = f"{mcp_base_url.rstrip('/')}/tools"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        tools_info = resp.json()
    except Exception as e:
        logger.error(f"获取 MCP 工具列表失败: {e}")
        return []

    if not isinstance(tools_info, list):
        # Iterating a dict would hand out its keys and crash on `.get`.
        logger.error(f"MCP 工具列表格式无效: 期望 list, 实际为 {type(tools_info).__name__}")
        return []

    dynamic_tools: List[MCPDynamicTool] = []
    for tool in tools_info:
        if not isinstance(tool, dict):
            logger.warning(f"构建 MCP 工具'{tool}'失败: 条目不是对象")
            continue
        name = tool.get("name")
        if include and name not in include:
            continue
        try:
            dynamic_tools.append(MCPDynamicTool(mcp_base_url=mcp_base_url, tool_info=tool))
        except Exception as e:
            logger.warning(f"构建 MCP 工具'{name}'失败: {e}")
    logger.info(f"已加载 MCP 工具: {[t.name for t in dynamic_tools]}")
    return dynamic_tools
def _get_tables(self, connection) -> List[Dict[str, Any]]:
    """List the tables of the connection's current database.

    Returns one dict per table with ``table_name``, ``table_type`` and
    ``table_schema``, ordered by table name. The cursor is always closed.
    """
    cur = connection.cursor()
    try:
        cur.execute("""
            SELECT
                table_name,
                table_type,
                table_schema
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            ORDER BY table_name;
        """)
        return [
            {
                "table_name": record[0],
                "table_type": record[1],
                "table_schema": record[2],
            }
            for record in cur.fetchall()
        ]
    finally:
        cur.close()
def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
    """Describe one table of the connection's current database.

    Runs five ``information_schema`` lookups (columns, primary key, foreign
    keys, indexes, table comment), each parameterized with ``table_name``.

    Returns:
        Dict with ``table_name``, ``columns``, ``primary_keys``,
        ``foreign_keys``, ``indexes`` and ``table_comment``.
    """
    bound = (table_name,)
    cur = connection.cursor()

    def rows_for(sql):
        # Every lookup takes the table name as its only bind parameter.
        cur.execute(sql, bound)
        return cur.fetchall()

    try:
        # Column details
        column_rows = rows_for("""
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale,
                column_comment
            FROM information_schema.columns
            WHERE table_schema = DATABASE() AND table_name = %s
            ORDER BY ordinal_position;
        """)
        columns = [
            {
                "column_name": rec[0],
                "data_type": rec[1],
                "is_nullable": rec[2] == 'YES',
                "column_default": rec[3],
                "character_maximum_length": rec[4],
                "numeric_precision": rec[5],
                "numeric_scale": rec[6],
                "column_comment": rec[7] or "",
            }
            for rec in column_rows
        ]

        # Primary-key columns
        pk_rows = rows_for("""
            SELECT column_name
            FROM information_schema.key_column_usage
            WHERE table_schema = DATABASE()
                AND table_name = %s
                AND constraint_name = 'PRIMARY'
            ORDER BY ordinal_position;
        """)
        primary_keys = [rec[0] for rec in pk_rows]

        # Foreign keys
        fk_rows = rows_for("""
            SELECT
                column_name,
                referenced_table_name,
                referenced_column_name
            FROM information_schema.key_column_usage
            WHERE table_schema = DATABASE()
                AND table_name = %s
                AND referenced_table_name IS NOT NULL;
        """)
        foreign_keys = [
            {
                "column_name": rec[0],
                "referenced_table": rec[1],
                "referenced_column": rec[2],
            }
            for rec in fk_rows
        ]

        # Indexes (non_unique == 0 marks a unique index)
        index_rows = rows_for("""
            SELECT
                index_name,
                column_name,
                non_unique
            FROM information_schema.statistics
            WHERE table_schema = DATABASE()
                AND table_name = %s
            ORDER BY index_name, seq_in_index;
        """)
        indexes = [
            {
                "index_name": rec[0],
                "column_name": rec[1],
                "is_unique": rec[2] == 0,
            }
            for rec in index_rows
        ]

        # Table comment
        cur.execute("""
            SELECT table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE() AND table_name = %s;
        """, bound)
        comment_row = cur.fetchone()
        table_comment = (comment_row[0] or "") if comment_row else ""

        return {
            "table_name": table_name,
            "columns": columns,
            "primary_keys": primary_keys,
            "foreign_keys": foreign_keys,
            "indexes": indexes,
            "table_comment": table_comment,
        }
    finally:
        cur.close()
def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
    """Execute one SQL statement and return its rows as JSON-friendly dicts.

    Args:
        connection: Open DB-API connection (provides ``cursor``, ``commit``
            and ``rollback``).
        sql_query: Statement to run.
        limit: Row cap appended as ``LIMIT n`` to SELECT/WITH statements that
            do not already carry a LIMIT clause; ``0``/``None`` disables it.

    Returns:
        ``{"success", "data", "columns", "row_count", "query"}``; statements
        that produce no result set additionally carry ``affected_rows``.

    Raises:
        Whatever the driver raises; the transaction is rolled back first.
    """
    import re  # local import: only this method needs it

    # Only row-returning statements may take a LIMIT; appending it to
    # INSERT/UPDATE/DDL produced invalid SQL.
    returns_rows = re.match(r"\s*\(*\s*(?:SELECT|WITH)\b", sql_query, re.IGNORECASE) is not None
    # Word-boundary match so identifiers such as ``rate_limit`` do not count.
    has_limit = re.search(r"\bLIMIT\b", sql_query, re.IGNORECASE) is not None
    row_cap = int(limit) if limit else 0
    if returns_rows and row_cap > 0 and not has_limit:
        # Strip whitespace as well as the semicolon: "SELECT 1; " used to
        # become "SELECT 1;  LIMIT n".
        sql_query = f"{sql_query.strip().rstrip(';').rstrip()} LIMIT {row_cap}"

    cursor = connection.cursor()
    try:
        try:
            cursor.execute(sql_query)
            description = cursor.description
            rows = cursor.fetchall() if description else ()
            affected_rows = cursor.rowcount
            # End the transaction: with autocommit off a cached connection
            # would otherwise keep reading one old snapshot and never persist DML.
            connection.commit()
        except Exception:
            try:
                connection.rollback()
            except Exception:
                # Best effort only; the original error is the one worth reporting.
                pass
            raise

        columns = [desc[0] for desc in description] if description else []
        data = [
            {
                column: value.isoformat() if isinstance(value, datetime) else value
                for column, value in zip(columns, row)
            }
            for row in rows
        ]
        result = {
            "success": True,
            "data": data,
            "columns": columns,
            "row_count": len(data),
            "query": sql_query,
        }
        if not description:
            result["affected_rows"] = affected_rows
        return result
    finally:
        cursor.close()
async def execute(self, **kwargs) -> ToolResult:
    """Dispatch one MySQL MCP operation and wrap the outcome in a ``ToolResult``.

    Keyword Args:
        operation: One of ``connect``, ``list_tables``, ``describe_table``,
            ``execute_query``, ``test_connection``, ``disconnect``.
        connection_config: ``{host, port, database, username, password}``;
            needed by ``connect`` and ``test_connection``.
        user_id: Key under which the connection is cached in ``self.connections``.
        table_name: Table to inspect (``describe_table``).
        sql_query: Statement to run (``execute_query``).
        limit: Row cap for ``execute_query`` (default 100).

    Returns:
        A ``ToolResult``; failures are reported through ``error`` rather than raised.
    """
    # This module never imported a logger, so the previous ``logger`` references
    # raised NameError on every call (including inside the error handler).
    import logging

    logger = logging.getLogger(__name__)

    operation = kwargs.get("operation")
    connection_config = kwargs.get("connection_config") or {}
    user_id = kwargs.get("user_id")
    table_name = kwargs.get("table_name")
    sql_query = kwargs.get("sql_query")
    limit = kwargs.get("limit", 100)

    try:
        logger.info("执行MySQL MCP操作: %s", operation)

        if operation == "test_connection":
            if not connection_config:
                return ToolResult(success=False, error="缺少连接配置参数")
            result = self._test_connection(connection_config)
            return ToolResult(
                success=result["success"],
                result=result,
                error=result.get("error"),
            )

        if operation == "connect":
            if not connection_config:
                return ToolResult(success=False, error="缺少connection_config参数")
            if not user_id:
                return ToolResult(success=False, error="缺少user_id参数")
            try:
                # Shared helper: default port 3306 and a connect timeout, instead
                # of a second pymysql.connect call that required "port".
                connection = self._create_connection(connection_config)

                # Close the connection being replaced so reconnecting does not leak it.
                stale = self.connections.get(user_id)
                if stale is not None:
                    try:
                        stale["connection"].close()
                    except Exception as close_error:
                        logger.warning("关闭旧的数据库连接失败: %s", close_error)

                self.connections[user_id] = {
                    "connection": connection,
                    "config": connection_config,
                    "connected_at": datetime.now().isoformat(),
                }

                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={
                        "message": "数据库连接成功",
                        "database": connection_config["database"],
                        "tables": tables,
                        "table_count": len(tables),
                    },
                )
            except Exception as e:
                return ToolResult(success=False, error=f"连接失败: {str(e)}")

        if operation == "disconnect":
            if user_id and user_id in self.connections:
                try:
                    self.connections[user_id]["connection"].close()
                    del self.connections[user_id]
                    return ToolResult(success=True, result={"message": "数据库连接已断开"})
                except Exception as e:
                    return ToolResult(success=False, error=f"断开连接失败: {str(e)}")
            return ToolResult(success=True, result={"message": "用户未连接数据库"})

        if operation in ("list_tables", "describe_table", "execute_query"):
            if not user_id or user_id not in self.connections:
                return ToolResult(success=False, error="用户未连接数据库,请先执行connect操作")
            connection = self.connections[user_id]["connection"]

            if operation == "list_tables":
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={"tables": tables, "table_count": len(tables)},
                )

            if operation == "describe_table":
                if not table_name:
                    return ToolResult(success=False, error="缺少table_name参数")
                return ToolResult(success=True, result=self._describe_table(connection, table_name))

            if not sql_query:
                return ToolResult(success=False, error="缺少sql_query参数")
            return ToolResult(success=True, result=self._execute_query(connection, sql_query, limit))

        # Reported through ``error`` (not ``result``), like every other failure.
        return ToolResult(success=False, error=f"不支持的操作类型: {operation}")

    except Exception as e:
        logger.error("MySQL MCP工具执行失败: %s", e, exc_info=True)
        return ToolResult(success=False, error=f"工具执行失败: {str(e)}")
def _test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
    """Open a throw-away connection and probe the server.

    Reads the server version and checks whether the ``vector`` extension is
    installed.

    Returns:
        ``{"success": True, "version", "has_pgvector", "message"}`` on success,
        ``{"success": False, "error", "message"}`` on any failure. Never raises.
    """
    conn = None
    cursor = None
    try:
        conn = self._create_connection(config)
        cursor = conn.cursor()

        # Server version
        cursor.execute("SELECT version();")
        version = cursor.fetchone()[0]

        # pgvector extension present?
        cursor.execute("SELECT * FROM pg_extension WHERE extname = 'vector';")
        has_vector = bool(cursor.fetchall())

        return {
            "success": True,
            "version": version,
            "has_pgvector": has_vector,
            "message": "连接测试成功"
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "message": "连接测试失败"
        }
    finally:
        # Release the probe resources on failure too; previously an error in
        # either query left the cursor and connection open.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                # Best-effort cleanup: a close error must not change the result.
                pass
def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
    """Describe one table of the ``public`` schema.

    Args:
        connection: Open DB-API connection.
        table_name: Exact (case-sensitive) table name as stored in the catalog.

    Returns:
        Dict with ``table_name``, ``columns``, ``primary_keys`` and ``row_count``.

    Raises:
        Whatever the driver raises; the transaction is rolled back first so the
        cached connection stays usable.
    """
    cursor = connection.cursor()
    try:
        # Column details
        cursor.execute("""
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = 'public'
            ORDER BY ordinal_position;
        """, (table_name,))
        columns = [
            {
                "column_name": row[0],
                "data_type": row[1],
                "is_nullable": row[2],
                "column_default": row[3],
                "character_maximum_length": row[4],
                "numeric_precision": row[5],
                "numeric_scale": row[6],
            }
            for row in cursor.fetchall()
        ]

        # Primary-key columns; the constraint lookup is scoped to ``public`` too
        # so a same-named table in another schema cannot contribute.
        cursor.execute("""
            SELECT column_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s AND table_schema = 'public'
                AND constraint_name IN (
                    SELECT constraint_name
                    FROM information_schema.table_constraints
                    WHERE table_name = %s AND table_schema = 'public'
                        AND constraint_type = 'PRIMARY KEY'
                );
        """, (table_name, table_name))
        primary_keys = [row[0] for row in cursor.fetchall()]

        # Row count. An identifier cannot be a bind parameter, so quote it
        # (doubling embedded quotes) instead of interpolating the raw name,
        # which allowed SQL injection through ``table_name``.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cursor.execute(f"SELECT COUNT(*) FROM public.{quoted_name};")
        row_count = cursor.fetchone()[0]

        return {
            "table_name": table_name,
            "columns": columns,
            "primary_keys": primary_keys,
            "row_count": row_count
        }
    except Exception:
        # A failed statement leaves the transaction aborted; without a rollback
        # every later statement on this cached connection is rejected.
        try:
            connection.rollback()
        except Exception:
            # Best effort only; the original error is the one worth reporting.
            pass
        raise
    finally:
        cursor.close()
def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
    """Execute one SQL statement on ``connection``.

    Args:
        connection: Open DB-API connection (provides ``cursor``, ``commit``
            and ``rollback``).
        sql_query: Statement to run.
        limit: Row cap appended as ``LIMIT n`` to SELECT/WITH statements that
            do not already carry a LIMIT clause; ``0``/``None`` disables it.

    Returns:
        For statements with a result set:
        ``{"success", "data", "columns", "row_count", "query"}``.
        Otherwise: ``{"success", "affected_rows", "query", "message"}``.

    Raises:
        Whatever the driver raises; the transaction is rolled back first.
    """
    import re  # local import: only this method needs it

    # Only row-returning statements may take a LIMIT; appending it to
    # INSERT/UPDATE/DELETE produced invalid SQL.
    returns_rows = re.match(r"\s*\(*\s*(?:SELECT|WITH)\b", sql_query, re.IGNORECASE) is not None
    # Word-boundary match so identifiers such as ``rate_limit`` do not count.
    has_limit = re.search(r"\bLIMIT\b", sql_query, re.IGNORECASE) is not None
    row_cap = int(limit) if limit else 0
    if returns_rows and row_cap > 0 and not has_limit:
        sql_query = f"{sql_query.strip().rstrip(';').rstrip()} LIMIT {row_cap};"

    cursor = connection.cursor()
    try:
        try:
            cursor.execute(sql_query)
            description = cursor.description
            rows = cursor.fetchall() if description else []
            affected_rows = cursor.rowcount
            # Persist DML and close the transaction the statement opened.
            connection.commit()
        except Exception:
            # A failed statement leaves the transaction aborted; roll back so
            # the cached connection accepts later statements.
            try:
                connection.rollback()
            except Exception:
                # Best effort only; the original error is the one worth reporting.
                pass
            raise

        if description:  # statement produced a result set
            columns = [desc[0] for desc in description]
            data = [
                {
                    column: value.isoformat() if isinstance(value, datetime) else value
                    for column, value in zip(columns, row)
                }
                for row in rows
            ]
            return {
                "success": True,
                "data": data,
                "columns": columns,
                "row_count": len(data),
                "query": sql_query
            }

        # INSERT/UPDATE/DELETE and other statements without a result set
        return {
            "success": True,
            "affected_rows": affected_rows,
            "query": sql_query,
            "message": f"查询执行成功,影响 {affected_rows} 行"
        }
    finally:
        cursor.close()
async def execute(self, operation: str, connection_config: Optional[Dict[str, Any]] = None,
                  user_id: Optional[str] = None, table_name: Optional[str] = None,
                  sql_query: Optional[str] = None, limit: int = 100) -> ToolResult:
    """Dispatch one PostgreSQL MCP operation and wrap the outcome in a ``ToolResult``.

    Args:
        operation: One of ``connect``, ``list_tables``, ``describe_table``,
            ``execute_query``, ``test_connection``, ``disconnect``.
        connection_config: ``{host, port, database, username, password}``;
            needed by ``connect`` and ``test_connection``.
        user_id: Key under which the connection is cached in ``self.connections``.
        table_name: Table to inspect (``describe_table``).
        sql_query: Statement to run (``execute_query``).
        limit: Row cap for ``execute_query``.

    Returns:
        A ``ToolResult``; failures are reported through ``error`` rather than raised.
    """
    # This module never imported a logger, so the previous ``logger`` references
    # raised NameError on every call (including inside the error handler).
    import logging

    logger = logging.getLogger(__name__)

    try:
        logger.info("执行PostgreSQL MCP操作: %s", operation)

        if operation == "test_connection":
            if not connection_config:
                return ToolResult(success=False, error="缺少连接配置参数")
            result = self._test_connection(connection_config)
            return ToolResult(
                success=result["success"],
                result=result,
                error=result.get("error"),
            )

        if operation == "connect":
            if not connection_config or not user_id:
                return ToolResult(success=False, error="缺少连接配置或用户ID参数")
            try:
                connection = self._create_connection(connection_config)

                # Close the connection being replaced so reconnecting does not leak it.
                stale = self.connections.get(user_id)
                if stale is not None:
                    try:
                        stale["connection"].close()
                    except Exception as close_error:
                        logger.warning("关闭旧的数据库连接失败: %s", close_error)

                self.connections[user_id] = {
                    "connection": connection,
                    "config": connection_config,
                    "connected_at": datetime.now().isoformat(),
                }

                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={
                        "message": "数据库连接成功",
                        "database": connection_config["database"],
                        "tables": tables,
                        "table_count": len(tables),
                    },
                )
            except Exception as e:
                return ToolResult(success=False, error=f"连接失败: {str(e)}")

        if operation == "disconnect":
            if user_id and user_id in self.connections:
                try:
                    self.connections[user_id]["connection"].close()
                    del self.connections[user_id]
                    return ToolResult(success=True, result={"message": "数据库连接已断开"})
                except Exception as e:
                    return ToolResult(success=False, error=f"断开连接失败: {str(e)}")
            return ToolResult(success=True, result={"message": "用户未连接数据库"})

        if operation in ("list_tables", "describe_table", "execute_query"):
            if not user_id or user_id not in self.connections:
                return ToolResult(success=False, error="用户未连接数据库,请先执行connect操作")
            connection = self.connections[user_id]["connection"]

            if operation == "list_tables":
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={"tables": tables, "table_count": len(tables)},
                )

            if operation == "describe_table":
                if not table_name:
                    return ToolResult(success=False, error="缺少table_name参数")
                return ToolResult(success=True, result=self._describe_table(connection, table_name))

            if not sql_query:
                return ToolResult(success=False, error="缺少sql_query参数")
            return ToolResult(success=True, result=self._execute_query(connection, sql_query, limit))

        return ToolResult(success=False, error=f"不支持的操作类型: {operation}")

    except Exception as e:
        logger.error("PostgreSQL MCP工具执行失败: %s", e, exc_info=True)
        return ToolResult(success=False, error=f"工具执行失败: {str(e)}")
class MySQLToolManager:
    """Singleton that lazily creates and hands out one shared ``MySQLMCPTool``."""

    _instance: Optional['MySQLToolManager'] = None
    _mysql_tool: Optional[MySQLMCPTool] = None

    def __new__(cls):
        # Every construction returns the same object.
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def mysql_tool(self) -> MySQLMCPTool:
        """The shared MySQL tool, created on first access."""
        tool = self._mysql_tool
        if tool is not None:
            return tool
        tool = MySQLMCPTool()
        self._mysql_tool = tool
        logger.info("创建全局MySQL工具实例")
        return tool

    def get_tool(self) -> MySQLMCPTool:
        """Alias for the ``mysql_tool`` property."""
        return self.mysql_tool
class PostgreSQLToolManager:
    """Singleton that lazily creates and hands out one shared ``PostgreSQLMCPTool``."""

    _instance: Optional['PostgreSQLToolManager'] = None
    _postgresql_tool: Optional[PostgreSQLMCPTool] = None

    def __new__(cls):
        # Every construction returns the same object.
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def postgresql_tool(self) -> PostgreSQLMCPTool:
        """The shared PostgreSQL tool, created on first access."""
        tool = self._postgresql_tool
        if tool is not None:
            return tool
        tool = PostgreSQLMCPTool()
        self._postgresql_tool = tool
        logger.info("创建全局PostgreSQL工具实例")
        return tool

    def get_tool(self) -> PostgreSQLMCPTool:
        """Alias for the ``postgresql_tool`` property."""
        return self.postgresql_tool
def _convert_query_result_to_table_data(self, query_result: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a database query result into the front-end table payload.

    Columns become ``{'prop': 'col_<i>', 'label', 'width'}`` definitions and
    every row becomes a dict of stringified cells keyed by ``col_<i>`` plus an
    ``_index`` entry. ``None`` cells become empty strings.

    Returns:
        Dict with ``result_type``, ``columns``, ``data``, ``total`` and
        ``message``. Conversion errors are logged and reported with
        ``result_type == 'error'`` instead of being raised.
    """
    def as_cell(raw):
        # None renders as an empty cell; everything else is stringified.
        return '' if raw is None else str(raw)

    try:
        rows = query_result.get('data', [])
        headers = query_result.get('columns', [])
        total = query_result.get('row_count', 0)

        # NOTE(review): the empty payload uses result_type 'table' while the
        # populated one uses 'table_data' - confirm which the front end expects.
        if not (rows and headers):
            return {
                'result_type': 'table',
                'columns': [],
                'data': [],
                'total': 0,
                'message': '查询未返回数据'
            }

        # Column definitions
        column_defs = [
            {'prop': f'col_{pos}', 'label': str(header), 'width': 'auto'}
            for pos, header in enumerate(headers)
        ]

        # Row conversion
        table_rows = []
        for row_no, record in enumerate(rows):
            if isinstance(record, dict):
                # Dict rows are read by column name.
                cells = {
                    f'col_{pos}': as_cell(record.get(header))
                    for pos, header in enumerate(headers)
                }
            else:
                # Sequence rows are read positionally (compatibility path).
                cells = {f'col_{pos}': as_cell(raw) for pos, raw in enumerate(record)}
            table_rows.append({'_index': str(row_no), **cells})

        return {
            'result_type': 'table_data',
            'columns': column_defs,
            'data': table_rows,
            'total': total,
            'message': f'查询成功,共返回 {total} 条记录'
        }

    except Exception as e:
        logger.error(f"转换查询结果异常: {str(e)}")
        return {
            'result_type': 'error',
            'columns': [],
            'data': [],
            'total': 0,
            'message': f'结果转换失败: {str(e)}'
        }
async def process_database_query_stream(
    self,
    user_query: str,
    user_id: int,
    database_config_id: int
):
    """Run the smart database-query workflow and stream every step.

    Workflow:
        1. Load the database configuration and verify the connection.
        2. Read the saved table metadata (QA-enabled tables only).
        3. Select the relevant tables and generate SQL.
        4. Execute the SQL.
        5. Summarize the result.
        6. Format the result as table data and emit the final payload.

    Args:
        user_query: The user's natural-language question.
        user_id: ID of the requesting user.
        database_config_id: ID of the saved database configuration.

    Yields:
        ``workflow_step`` events (status ``running`` / ``completed`` /
        ``failed`` / ``warning``), then one ``final_result`` event, or an
        ``error`` event that ends the stream.
    """
    workflow_steps = []

    def step_event(step, status, message, details=None):
        # One streamed progress event; a fresh dict (and timestamp) per event.
        event = {
            'type': 'workflow_step',
            'step': step,
            'status': status,
            'message': message,
            'timestamp': datetime.now().isoformat()
        }
        if details is not None:
            event['details'] = details
        return event

    def error_event(message):
        # Terminal event carrying the steps completed so far.
        return {
            'type': 'error',
            'message': message,
            'workflow_steps': workflow_steps
        }

    def record(step, status, message):
        # Append one entry to the step summary returned with the result.
        workflow_steps.append({'step': step, 'status': status, 'message': message})

    try:
        logger.info(f"开始执行流式数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")

        # Step 1: load the configuration and verify the connection.
        try:
            yield step_event('database_connection', 'running', '正在建立数据库连接...')

            connection_result = await self._connect_database(user_id, database_config_id)
            if not connection_result['success']:
                raise DatabaseConnectionError(connection_result['message'])

            yield step_event(
                'database_connection', 'completed', '数据库连接成功',
                {'database': connection_result.get('database_name', 'Unknown')}
            )
            record('database_connection', 'completed', '数据库连接成功')
        except Exception as e:
            error_msg = f'数据库连接失败: {str(e)}'
            yield step_event('database_connection', 'failed', error_msg)
            yield error_event(error_msg)
            return

        # Step 2: read the saved table metadata (QA-enabled tables only).
        try:
            yield step_event('table_metadata', 'running', '正在从系统数据库读取表元数据...')

            tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)

            yield step_event(
                'table_metadata', 'completed',
                f'成功读取 {len(tables_info)} 个启用问答的表元数据',
                {'table_count': len(tables_info), 'tables': list(tables_info.keys())}
            )
            record('table_metadata', 'completed', f'成功读取表元数据')
        except Exception as e:
            error_msg = f'获取表元数据失败: {str(e)}'
            yield step_event('table_metadata', 'failed', error_msg)
            yield error_event(error_msg)
            return

        # Step 3: select the relevant tables and generate SQL.
        try:
            yield step_event('sql_generation', 'running', '正在根据表元数据生成SQL查询...')

            target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
            selected_msg = f'已经智能选择了相关表: {", ".join(target_tables)}'
            yield step_event('table_selected', 'completed', selected_msg)
            # Recorded under its own step name; it used to be logged as a
            # second 'table_metadata' entry.
            record('table_selected', 'completed', selected_msg)

            sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)

            # Emitted as 'sql_generation'; previously this completion reused
            # the 'table_selected' event, so the step never left 'running'.
            yield step_event(
                'sql_generation', 'completed', 'SQL查询生成成功',
                {
                    'target_tables': target_tables,
                    'generated_sql': sql_query[:100] + '...' if len(sql_query) > 100 else sql_query
                }
            )
            record('sql_generation', 'completed', 'SQL语句生成成功')
        except Exception as e:
            error_msg = f'SQL生成失败: {str(e)}'
            yield step_event('sql_generation', 'failed', error_msg)
            yield error_event(error_msg)
            return

        # Step 4: execute the SQL.
        try:
            yield step_event('query_execution', 'running', '正在执行SQL查询...')

            query_result = await self._execute_database_query(user_id, sql_query, database_config_id)

            yield step_event(
                'query_execution', 'completed',
                f'查询执行成功,返回 {query_result.get("row_count", 0)} 条记录',
                {'row_count': query_result.get('row_count', 0)}
            )
            record('query_execution', 'completed', '查询执行成功')
        except Exception as e:
            error_msg = f'查询执行失败: {str(e)}'
            yield step_event('query_execution', 'failed', error_msg)
            yield error_event(error_msg)
            return

        # Step 5: summarize the result; a failure here is not fatal.
        try:
            yield step_event('ai_summary', 'running', '正在生成查询结果总结...')

            summary = await self._generate_database_summary(user_query, query_result, ', '.join(target_tables))

            yield step_event(
                'ai_summary', 'completed', '总结生成完成',
                {'tables_analyzed': target_tables, 'summary_length': len(summary)}
            )
            record('ai_summary', 'completed', '总结生成完成')
        except Exception as e:
            logger.warning(f'生成总结失败: {str(e)}')
            summary = '查询执行完成,但生成总结时出现问题。'
            # Tell the client too; otherwise the step stays 'running' forever.
            yield step_event('ai_summary', 'warning', '总结生成失败,但查询成功')
            record('ai_summary', 'warning', '总结生成失败,但查询成功')

        # Step 6: format as table data and emit the final payload.
        try:
            yield step_event('result_formatting', 'running', '正在格式化查询结果...')

            table_data = self._convert_query_result_to_table_data(query_result)

            yield step_event('result_formatting', 'completed', '结果格式化完成')
            record('result_formatting', 'completed', '结果格式化完成')

            yield {
                'type': 'final_result',
                'success': True,
                'data': {
                    **table_data,
                    'generated_sql': sql_query,
                    'summary': summary,
                    'table_name': target_tables,
                    'query_result': query_result,
                    'metadata_source': 'saved_database'  # marks where the metadata came from
                },
                'workflow_steps': workflow_steps,
                'timestamp': datetime.now().isoformat()
            }
            logger.info(f"数据库查询工作流完成 - 用户ID: {user_id}")
        except Exception as e:
            yield error_event(f'结果格式化失败: {str(e)}')
            return

    except Exception as e:
        logger.error(f"数据库查询工作流异常: {str(e)}", exc_info=True)
        yield error_event(f'系统异常: {str(e)}')
'message': '数据库配置不存在'} + + # 根据数据库类型选择对应的工具 + try: + db_tool = self._get_database_tool(config.db_type) + except ValueError as e: + return {'success': False, 'message': str(e)} + + # 测试连接(如果已经有连接则直接复用) + connection_config = { + 'host': config.host, + 'port': config.port, + 'database': config.database, + 'username': config.username, + 'password': config_service._decrypt_password(config.password) + } + + try: + connection = db_tool._test_connection(connection_config) + if connection['success'] == True: + return { + 'success': True, + 'database_name': config.database, + 'db_type': config.db_type, + 'message': '连接成功' + } + else: + return { + 'success': False, + 'database_name': config.database, + 'db_type': config.db_type, + 'message': '连接失败' + } + except Exception as e: + return { + 'success': False, + 'message': f'连接失败: {str(e)}' + } + + except Exception as e: + logger.error(f"数据库连接异常: {str(e)}") + return {'success': False, 'message': f'连接异常: {str(e)}'} + + async def _get_saved_tables_metadata(self, user_id: int, database_config_id: int) -> Dict[str, Dict[str, Any]]: + """从系统数据库中读取已保存的表元数据""" + try: + if not self.table_metadata_service: + raise TableSchemaError("表元数据服务未初始化") + + # 从数据库中获取表元数据 + saved_metadata = self.table_metadata_service.get_user_table_metadata( + user_id, database_config_id + ) + + if not saved_metadata: + raise TableSchemaError(f"未找到数据库配置ID {database_config_id} 的表元数据,请先在数据库管理页面收集表元数据") + + # 转换为所需格式 + tables_metadata = {} + for meta in saved_metadata: + # 只处理启用问答的表 + if meta.is_enabled_for_qa: + tables_metadata[meta.table_name] = { + 'table_name': meta.table_name, + 'columns': meta.columns_info or [], + 'primary_keys': meta.primary_keys or [], + 'row_count': meta.row_count or 0, + 'table_comment': meta.table_comment or '', + 'qa_description': meta.qa_description or '', + 'business_context': meta.business_context or '', + 'from_saved_metadata': True # 标记来源 + } + + if not tables_metadata: + raise TableSchemaError("没有启用问答的表,请在数据库管理页面启用相关表的问答功能") + + 
logger.info(f"从系统数据库读取表元数据成功,共 {len(tables_metadata)} 个启用问答的表") + return tables_metadata + + except Exception as e: + logger.error(f"读取保存的表元数据异常: {str(e)}") + raise TableSchemaError(f'读取表元数据失败: {str(e)}') + + async def _get_table_schema(self, user_id: int, table_name: str, database_config_id: int) -> Dict[str, Any]: + """获取指定表结构""" + try: + # 获取数据库配置 + from ..services.database_config_service import DatabaseConfigService + config_service = DatabaseConfigService(self.db) + config = config_service.get_config_by_id(database_config_id, user_id) + + if not config: + raise TableSchemaError('数据库配置不存在') + + # 根据数据库类型选择对应的工具 + try: + db_tool = self._get_database_tool(config.db_type) + except ValueError as e: + raise TableSchemaError(str(e)) + + # 使用对应的数据库工具获取表结构 + schema_result = await db_tool.describe_table(table_name) + + if schema_result.get('success'): + return schema_result.get('schema', {}) + else: + raise TableSchemaError(schema_result.get('error', '获取表结构失败')) + + except Exception as e: + logger.error(f"获取表结构异常: {str(e)}") + raise TableSchemaError(f'获取表结构失败: {str(e)}') + + async def _select_target_table(self, user_query: str, tables_info: Dict[str, Dict]) -> tuple[List[str], List[Dict]]: + """根据用户查询选择相关的表,支持返回多个表""" + try: + if len(tables_info) == 1: + # 只有一个表,直接返回 + table_name = list(tables_info.keys())[0] + return [table_name], [tables_info[table_name]] + + # 多个表时,使用LLM选择相关的表 + tables_summary = [] + for table_name, schema in tables_info.items(): + columns = schema.get('columns', []) + column_names = [col.get('column_name', col.get('name', '')) for col in columns] + qa_desc = schema.get('qa_description', '') + business_ctx = schema.get('business_context', '') + tables_summary.append(f"表名: {table_name}\n字段: {', '.join(column_names[:10])}\n表描述: {qa_desc}\n业务上下文: {business_ctx}") + + prompt = f""" + 用户查询: {user_query} + + 可用的表: + {chr(10).join(tables_summary)} + + 请根据用户查询选择相关的表,可以选择多个表。分析表之间可能的关联关系,返回所有相关的表名,用逗号分隔。 + 
可以通过qa_description(表描述),business_context(表的业务上下文),以及column_names几个字段判断要使用哪些表。 + 注意:只返回表名列表,后面不要跟其他的内容。 + 例如直接输出: table1,table2,table3 + """ + + response = await self.llm.ainvoke(prompt) + selected_tables = [t.strip() for t in response.content.strip().split(',')] + + # 验证选择的表是否存在 + valid_tables = [] + valid_schemas = [] + for table in selected_tables: + if table in tables_info: + valid_tables.append(table) + valid_schemas.append(tables_info[table]) + else: + logger.warning(f"LLM选择的表 {table} 不存在") + + if valid_tables: + return valid_tables, valid_schemas + else: + # 如果没有有效的表,选择第一个表 + table_name = list(tables_info.keys())[0] + logger.warning(f"没有找到有效的表,使用默认表 {table_name}") + return [table_name], [tables_info[table_name]] + + except Exception as e: + logger.error(f"选择目标表异常: {str(e)}") + # 出现异常时选择第一个表 + table_name = list(tables_info.keys())[0] + return [table_name], [tables_info[table_name]] + + async def _generate_sql_query(self, user_query: str, table_names: List[str], table_schemas: List[Dict]) -> str: + """生成SQL语句,支持多表关联查询""" + try: + # 构建所有表的结构信息 + tables_info = [] + for table_name, schema in zip(table_names, table_schemas): + columns_info = [] + for col in schema.get('columns', []): + col_info = f"{col['column_name']} ({col['data_type']})" + columns_info.append(col_info) + + table_info = f"表名: {table_name}\n" + table_info += f"表描述: {schema.get('qa_description', '')}\n" + table_info += f"业务上下文: {schema.get('business_context', '')}\n" + table_info += "字段信息:\n" + "\n".join(columns_info) + tables_info.append(table_info) + + schema_text = "\n\n".join(tables_info) + + prompt = f""" + 基于以下表结构,将自然语言查询转换为SQL语句。如果需要关联多个表,请分析表之间的关系,使用合适的JOIN语法: + + {schema_text} + + 用户查询: {user_query} + + 请生成对应的SQL查询语句,要求: + 1. 只返回SQL语句,不要包含其他解释 + 2. 如果查询涉及多个表,需要正确处理表之间的关联关系 + 3. 使用合适的JOIN类型(INNER JOIN、LEFT JOIN等) + 4. 
async def _execute_database_query(self, user_id: int, sql_query: str, database_config_id: int) -> Dict[str, Any]:
    """Run ``sql_query`` on the connection already registered for the user.

    The database tool is chosen from the saved configuration's ``db_type``.
    The query is executed on ``db_tool.connections[str(user_id)]``; no new
    connection is opened here.

    Args:
        user_id: Owner of the configuration and of the open connection.
        sql_query: SQL text to execute.
        database_config_id: Identifier of the saved configuration.

    Returns:
        Dict with ``success``, ``data``, ``row_count``, ``columns`` and
        ``sql_query``.

    Raises:
        QueryExecutionError: For every failure; errors raised inside the
            body are logged and re-wrapped with a common prefix.
    """
    try:
        from ..services.database_config_service import DatabaseConfigService

        config = DatabaseConfigService(self.db).get_config_by_id(database_config_id, user_id)
        if not config:
            raise QueryExecutionError('数据库配置不存在')

        try:
            db_tool = self._get_database_tool(config.db_type)
        except ValueError as tool_error:
            raise QueryExecutionError(str(tool_error))

        # Connections are keyed by the stringified user id.
        connection_key = str(user_id)
        if connection_key not in db_tool.connections:
            raise QueryExecutionError('请重新进行数据库连接')

        outcome = db_tool._execute_query(
            db_tool.connections[connection_key]['connection'], sql_query
        )
        if not outcome.get('success'):
            raise QueryExecutionError(outcome.get('error', '查询执行失败'))

        rows = outcome.get('data', [])
        return {
            'success': True,
            'data': rows,
            'row_count': len(rows),
            'columns': outcome.get('columns', []),
            'sql_query': sql_query
        }

    except Exception as e:
        logger.error(f"查询执行异常: {str(e)}")
        raise QueryExecutionError(f'查询执行失败: {str(e)}')
async def process_database_query(
    self,
    user_query: str,
    user_id: int,
    database_config_id: int,
    table_name: Optional[str] = None,
    conversation_id: Optional[int] = None,
    is_new_conversation: bool = False
) -> Dict[str, Any]:
    """Answer a natural-language question against a configured database.

    Non-streaming workflow based on the table metadata saved in the system
    database:

    1. Check the connection for ``database_config_id``.
    2. Read the saved metadata of the tables enabled for QA.
    3. Select the relevant tables and generate SQL.
    4. Execute the SQL.
    5. Convert the result to table form.
    6. Generate a summary.
    7. Return everything in one dict.

    Args:
        user_query: The user's question.
        user_id: User id.
        database_config_id: Database configuration id.
        table_name: Optional table name; not read by this method.
        conversation_id: Conversation id; not read by this method.
        is_new_conversation: New-conversation flag; not read by this method.

    Returns:
        On success ``{'success': True, 'data': {...}}`` where ``data`` holds
        the table data plus ``generated_sql``, ``summary``, ``table_names``,
        ``query_result`` and ``metadata_source``. On failure
        ``{'success': False, 'error': ..., 'error_type': ...}``.
    """
    try:
        logger.info(f"开始执行数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")

        # Step 1: connection check.
        connection_result = await self._connect_database(user_id, database_config_id)
        if not connection_result['success']:
            raise DatabaseConnectionError(connection_result['message'])
        logger.info("数据库连接成功")

        # Step 2: saved metadata of QA-enabled tables.
        tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)
        logger.info(f"表元数据读取完成 - 共{len(tables_info)}个启用问答的表")

        # Step 3: table selection and SQL generation.
        target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
        sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)
        logger.info(f"SQL生成完成 - 目标表: {', '.join(target_tables)}")

        # Step 4: execution.
        query_result = await self._execute_database_query(user_id, sql_query, database_config_id)
        logger.info("查询执行完成")

        # Steps 5 and 6: table form first, then the summary.
        payload = dict(self._convert_query_result_to_table_data(query_result))
        summary = await self._generate_database_summary(
            user_query, query_result, ', '.join(target_tables)
        )

        # Step 7: the fixed keys are written last so they win over any
        # same-named key coming from the table data.
        payload.update({
            'generated_sql': sql_query,
            'summary': summary,
            'table_names': target_tables,
            'query_result': query_result,
            'metadata_source': 'saved_database'
        })
        return {'success': True, 'data': payload}

    except SmartWorkflowError as e:
        logger.error(f"数据库工作流异常: {str(e)}")
        return {
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }
    except Exception as e:
        logger.error(f"数据库工作流未知异常: {str(e)}", exc_info=True)
        return {
            'success': False,
            'error': f'系统异常: {str(e)}',
            'error_type': 'SystemError'
        }
a/th_agenter/services/smart_excel_workflow.py b/th_agenter/services/smart_excel_workflow.py new file mode 100644 index 0000000..8870a3e --- /dev/null +++ b/th_agenter/services/smart_excel_workflow.py @@ -0,0 +1,1355 @@ +from typing import Dict, Any, List, Optional, Union +import pandas as pd +import os +import tempfile +import json +import logging +from datetime import datetime +import asyncio +from concurrent.futures import ThreadPoolExecutor +from langchain_core.runnables import RunnableLambda +from langchain_community.chat_models import ChatZhipuAI +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnableSequence as LLMChain +from th_agenter.core.context import UserContext +from .smart_query import ExcelAnalysisService +from .excel_metadata_service import ExcelMetadataService +from ..core.config import get_settings +from pathlib import Path + +# 配置日志 +logger = logging.getLogger(__name__) + +class SmartWorkflowError(Exception): + """智能工作流自定义异常""" + pass + +class FileLoadError(SmartWorkflowError): + """文件加载异常""" + pass + +class FileSelectionError(SmartWorkflowError): + """文件选择异常""" + pass + +class CodeExecutionError(SmartWorkflowError): + """代码执行异常""" + pass + + +class SmartExcelWorkflowManager: + """ + 智能工作流管理器 + 负责协调文件选择、代码生成和执行的完整流程 + """ + + def __init__(self, db=None): + self.executor = ThreadPoolExecutor(max_workers=4) + self.excel_service = ExcelAnalysisService() + self.db = db + if db: + self.metadata_service = ExcelMetadataService(db) + else: + self.metadata_service = None + + async def initialize(self): + from ..core.new_agent import new_agent + # 禁用流式响应,避免pandas代理兼容性问题 + self.llm = await new_agent(streaming=False) + + async def _run_in_executor(self, func, *args): + """在线程池中运行阻塞函数""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, func, *args) + + def _convert_dataframe_to_markdown(self, 
df_string: str) -> str: + """ + 将DataFrame的字符串表示转换为Markdown表格格式 + + Args: + df_string: DataFrame的字符串表示 + + Returns: + Markdown格式的表格字符串 + """ + try: + lines = df_string.strip().split('\n') + + # 查找表格数据的开始位置 + table_start = -1 + for i, line in enumerate(lines): + if '|' in line or (len(line.split()) > 1 and any(char.isdigit() for char in line)): + table_start = i + break + + if table_start == -1: + return df_string # 如果找不到表格,返回原始字符串 + + # 提取表格行 + table_lines = [] + for line in lines[table_start:]: + if line.strip() and not line.startswith('Name:') and not line.startswith('dtype:'): + table_lines.append(line.strip()) + + if not table_lines: + return df_string + + # 处理第一行作为表头 + if table_lines: + # 检查是否已经是表格格式 + if '|' in table_lines[0]: + # 已经是表格格式,直接返回 + markdown_lines = [] + for i, line in enumerate(table_lines): + if i == 1 and not line.startswith('|'): + # 添加分隔行 + cols = table_lines[0].count('|') + 1 + separator = '|' + '---|' * (cols - 1) + '---|' + markdown_lines.append(separator) + markdown_lines.append(line) + return '\n'.join(markdown_lines) + else: + # 转换为Markdown表格格式 + markdown_lines = [] + + # 处理表头 + if len(table_lines) > 0: + # 假设第一行是索引和数据的混合 + first_line = table_lines[0] + parts = first_line.split() + + if len(parts) > 1: + # 创建表头 + header = '| 索引 | ' + ' | '.join(parts[1:]) + ' |' + markdown_lines.append(header) + + # 创建分隔行 + separator = '|' + '---|' * len(parts) + '' + markdown_lines.append(separator) + + # 处理数据行 + for line in table_lines[1:]: + if line.strip(): + parts = line.split() + if len(parts) > 0: + row = '| ' + ' | '.join(parts) + ' |' + markdown_lines.append(row) + + if markdown_lines: + return '\n'.join(markdown_lines) + + return df_string # 如果转换失败,返回原始字符串 + + except Exception as e: + logger.warning(f"DataFrame转Markdown失败: {str(e)}") + return df_string # 转换失败时返回原始字符串 + + def _convert_dataframe_to_markdown(self, df_string: str) -> str: + """ + 将DataFrame的字符串表示转换为Markdown表格格式 + + Args: + df_string: DataFrame的字符串表示 + + Returns: + 
Markdown格式的表格字符串 + """ + try: + lines = df_string.strip().split('\n') + + # 查找表格数据的开始位置 + table_start = -1 + for i, line in enumerate(lines): + if '|' in line or (len(line.split()) > 1 and any(char.isdigit() for char in line)): + table_start = i + break + + if table_start == -1: + return df_string # 如果找不到表格,返回原始字符串 + + # 提取表格行 + table_lines = [] + for line in lines[table_start:]: + if line.strip() and not line.startswith('Name:') and not line.startswith('dtype:'): + table_lines.append(line.strip()) + + if not table_lines: + return df_string + + # 处理第一行作为表头 + if table_lines: + # 检查是否已经是表格格式 + if '|' in table_lines[0]: + # 已经是表格格式,直接返回 + markdown_lines = [] + for i, line in enumerate(table_lines): + if i == 1 and not line.startswith('|'): + # 添加分隔行 + cols = table_lines[0].count('|') + 1 + separator = '|' + '---|' * (cols - 1) + '---|' + markdown_lines.append(separator) + markdown_lines.append(line) + return '\n'.join(markdown_lines) + else: + # 转换为Markdown表格格式 + markdown_lines = [] + + # 处理表头 + if len(table_lines) > 0: + # 假设第一行是索引和数据的混合 + first_line = table_lines[0] + parts = first_line.split() + + if len(parts) > 1: + # 创建表头 + header = '| 索引 | ' + ' | '.join(parts[1:]) + ' |' + markdown_lines.append(header) + + # 创建分隔行 + separator = '|' + '---|' * len(parts) + '' + markdown_lines.append(separator) + + # 处理数据行 + for line in table_lines[1:]: + if line.strip(): + parts = line.split() + if len(parts) > 0: + row = '| ' + ' | '.join(parts) + ' |' + markdown_lines.append(row) + + if markdown_lines: + return '\n'.join(markdown_lines) + + return df_string # 如果转换失败,返回原始字符串 + + except Exception as e: + logger.warning(f"DataFrame转Markdown失败: {str(e)}") + return df_string # 转换失败时返回原始字符串 + + def _convert_dataframe_to_table_data(self, df: pd.DataFrame) -> Dict[str, Any]: + """ + 将DataFrame转换为前端Table组件可用的结构化数据 + + Args: + df: pandas DataFrame + + Returns: + 包含columns和data的字典 + """ + try: + # 获取列信息 + columns = [] + for col in df.columns: + columns.append({ + 'prop': str(col), + 
'label': str(col), + 'width': 'auto' + }) + + # 获取数据 + data = [] + for index, row in df.iterrows(): + row_data = {'_index': str(index)} + for col in df.columns: + # 处理各种数据类型 + value = row[col] + if pd.isna(value): + row_data[str(col)] = '' + elif isinstance(value, (int, float)): + row_data[str(col)] = value + else: + row_data[str(col)] = str(value) + data.append(row_data) + + return { + 'columns': columns, + 'data': data, + 'total': len(df) + } + + except Exception as e: + logger.warning(f"DataFrame转Table数据失败: {str(e)}") + return { + 'columns': [{'prop': 'result', 'label': '结果'}], + 'data': [{'result': str(df)}], + 'total': 1 + } + + async def process_excel_query_stream( + self, + user_query: str, + user_id: int, + conversation_id: Optional[int] = None, + is_new_conversation: bool = False + ): + """ + 流式处理智能问数查询的主要工作流 + 实时推送每个工作流步骤 + + Args: + user_query: 用户问题 + user_id: 用户ID + conversation_id: 对话ID + is_new_conversation: 是否为新对话 + + Yields: + 包含工作流步骤或最终结果的字典 + """ + workflow_steps = [] + + try: + logger.info(f"开始执行流式智能查询工作流 - 用户ID: {user_id}, 查询: {user_query[:50]}...") + + # 步骤1: 加载文件列表 + try: + step_data = { + 'type': 'workflow_step', + 'step': 'file_loading', + 'status': 'running', + 'message': '正在加载用户文件列表...', + 'timestamp': datetime.now().isoformat() + } + yield step_data + + if is_new_conversation or conversation_id is None: + file_list = await self._load_user_file_list(user_id) + if not file_list: + raise FileLoadError('未找到可用的Excel文件,请先上传文件') + else: + file_list = await self._load_user_file_list(user_id) + + step_completed = { + 'type': 'workflow_step', + 'step': 'file_loading', + 'status': 'completed', + 'message': f'成功加载{len(file_list)}个文件', + 'details': {'file_count': len(file_list)}, + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_completed) + yield step_completed + logger.info(f"文件加载完成 - 共{len(file_list)}个文件") + + except FileLoadError as e: + step_failed = { + 'type': 'workflow_step', + 'step': 'file_loading', + 'status': 
'failed', + 'message': str(e), + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_failed) + yield step_failed + + yield { + 'type': 'final_result', + 'success': False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + return + + # 步骤2: 智能文件选择 + try: + step_data = { + 'type': 'workflow_step', + 'step': 'file_selection', + 'status': 'running', + 'message': '正在分析问题并选择相关文件...', + 'timestamp': datetime.now().isoformat() + } + yield step_data + + selected_files = await self._select_relevant_files(user_query, file_list) + + if not selected_files: + raise FileSelectionError('未找到与问题相关的Excel文件') + selected_files_names = names_str = ", ".join([file["filename"] for file in selected_files]) + step_completed = { + 'type': 'workflow_step', + 'step': 'file_selection', + 'status': 'completed', + 'message': f'选择了{len(selected_files)}个相关文件:{selected_files_names}', + 'details': {'selection_count': len(selected_files)}, + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_completed) + yield step_completed + logger.info(f"文件选择完成 - 选择了{len(selected_files)}个文件") + + except FileSelectionError as e: + step_failed = { + 'type': 'workflow_step', + 'step': 'file_selection', + 'status': 'failed', + 'message': str(e), + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_failed) + yield step_failed + + yield { + 'type': 'final_result', + 'success': False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + return + + # 步骤3: 数据加载 + try: + step_data = { + 'type': 'workflow_step', + 'step': 'data_loading', + 'status': 'running', + 'message': '正在加载Excel数据...', + 'timestamp': datetime.now().isoformat() + } + yield step_data + + dataframes = await self._load_selected_dataframes(selected_files, user_id) + + step_completed = { + 'type': 'workflow_step', + 'step': 'data_loading', + 'status': 'completed', + 'message': f'成功加载{len(dataframes)}个数据表', + 'details': { + 'dataframe_count': len(dataframes), + 'total_rows': 
sum(len(df) for df in dataframes.values()) + }, + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_completed) + yield step_completed + logger.info(f"数据加载完成 - 共{len(dataframes)}个数据表") + + except Exception as e: + step_failed = { + 'type': 'workflow_step', + 'step': 'data_loading', + 'status': 'failed', + 'message': str(e), + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_failed) + yield step_failed + + yield { + 'type': 'final_result', + 'success': False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + return + + # 步骤4: 代码执行 + try: + step_data = { + 'type': 'workflow_step', + 'step': 'code_execution', + 'status': 'running', + 'message': '正在生成并执行Python代码...', + 'timestamp': datetime.now().isoformat() + } + yield step_data + + result = await self._execute_smart_query(user_query, dataframes, selected_files) + + step_completed = { + 'type': 'workflow_step', + 'step': 'code_execution', + 'status': 'completed', + 'message': '成功执行Python代码分析', + 'details': { + 'result_type': result.get('result_type'), + 'data_count': result.get('total', 0) + }, + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_completed) + yield step_completed + logger.info("查询执行完成") + + # 发送最终结果 + yield { + 'type': 'final_result', + 'success': True, + 'data': result, + 'workflow_steps': workflow_steps + } + + except CodeExecutionError as e: + error_msg = f'代码执行失败: {str(e)}' + step_failed = { + 'type': 'workflow_step', + 'step': 'code_execution', + 'status': 'failed', + 'message': error_msg, + 'timestamp': datetime.now().isoformat() + } + workflow_steps.append(step_failed) + yield step_failed + logger.error(error_msg) + + yield { + 'type': 'final_result', + 'success': False, + 'message': error_msg, + 'workflow_steps': workflow_steps + } + return + + except SmartWorkflowError as e: + logger.error(f"智能工作流异常: {str(e)}") + yield { + 'type': 'final_result', + 'success': False, + 'message': str(e), + 'workflow_steps': 
workflow_steps + } + except Exception as e: + logger.error(f"智能工作流未知异常: {str(e)}", exc_info=True) + yield { + 'type': 'final_result', + 'success': False, + 'message': f'系统异常: {str(e)}', + 'workflow_steps': workflow_steps + } + + async def process_smart_query( + self, + user_query: str, + user_id: int, + conversation_id: Optional[int] = None, + is_new_conversation: bool = False + ) -> Dict[str, Any]: + """ + 处理智能问数查询的主要工作流 + + Args: + user_query: 用户问题 + user_id: 用户ID + conversation_id: 对话ID + is_new_conversation: 是否为新对话 + + Returns: + 包含查询结果的字典 + """ + workflow_steps = [] + + try: + logger.info(f"开始执行智能查询工作流 - 用户ID: {user_id}, 查询: {user_query[:50]}...") + + # 步骤1: 加载文件列表 + try: + if is_new_conversation or conversation_id is None: + file_list = await self._load_user_file_list(user_id) + if not file_list: + raise FileLoadError('未找到可用的Excel文件,请先上传文件') + else: + file_list = await self._load_user_file_list(user_id) + + workflow_steps.append({ + 'step': 'file_loading', + 'status': 'completed', + 'message': f'成功加载{len(file_list)}个文件', + 'details': {'file_count': len(file_list)} + }) + logger.info(f"文件加载完成 - 共{len(file_list)}个文件") + + except FileLoadError as e: + workflow_steps.append({ + 'step': 'file_loading', + 'status': 'failed', + 'message': str(e) + }) + return { + 'success': False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + + # 步骤2: 智能文件选择 + try: + selected_files = await self._select_relevant_files(user_query, file_list) + + if not selected_files: + raise FileSelectionError('未找到与问题相关的Excel文件') + + workflow_steps.append({ + 'step': 'file_selection', + 'status': 'completed', + 'message': f'选择了{len(selected_files)}个相关文件', + 'selected_files': [f['filename'] for f in selected_files], + 'details': {'selection_count': len(selected_files)} + }) + logger.info(f"文件选择完成 - 选中{len(selected_files)}个文件") + + except FileSelectionError as e: + workflow_steps.append({ + 'step': 'file_selection', + 'status': 'failed', + 'message': str(e) + }) + return { + 'success': 
False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + + # 步骤3: 加载DataFrame + try: + dataframes = await self._load_selected_dataframes(selected_files, user_id) + + if not dataframes: + raise FileLoadError('无法加载选中的Excel文件数据') + + workflow_steps.append({ + 'step': 'dataframe_loading', + 'status': 'completed', + 'message': f'成功加载{len(dataframes)}个数据表', + 'details': { + 'dataframe_count': len(dataframes), + 'total_rows': sum(len(df) for df in dataframes.values()) + } + }) + logger.info(f"DataFrame加载完成 - {len(dataframes)}个数据表") + + except Exception as e: + error_msg = f'数据加载失败: {str(e)}' + workflow_steps.append({ + 'step': 'dataframe_loading', + 'status': 'failed', + 'message': error_msg + }) + logger.error(error_msg) + return { + 'success': False, + 'message': error_msg, + 'workflow_steps': workflow_steps + } + + # 步骤4: 执行查询 + try: + result = await self._execute_smart_query(user_query, dataframes, selected_files) + + workflow_steps.append({ + 'step': 'code_execution', + 'status': 'completed', + 'message': '成功执行pandas代码分析', + 'details': { + 'result_type': result.get('result_type'), + 'data_count': result.get('total', 0) + } + }) + logger.info("查询执行完成") + + return { + 'success': True, + 'data': result, + 'workflow_steps': workflow_steps + } + + except CodeExecutionError as e: + error_msg = f'代码执行失败: {str(e)}' + workflow_steps.append({ + 'step': 'code_execution', + 'status': 'failed', + 'message': error_msg + }) + logger.error(error_msg) + return { + 'success': False, + 'message': error_msg, + 'workflow_steps': workflow_steps + } + + except SmartWorkflowError as e: + logger.error(f"智能工作流异常: {str(e)}") + return { + 'success': False, + 'message': str(e), + 'workflow_steps': workflow_steps + } + except Exception as e: + logger.error(f"工作流执行失败: {str(e)}", exc_info=True) + workflow_steps.append({ + 'step': 'error', + 'status': 'failed', + 'message': f'系统错误: {str(e)}' + }) + return { + 'success': False, + 'message': f'工作流执行失败: {str(e)}', + 'workflow_steps': 
async def _select_relevant_files(
    self,
    user_query: str,
    file_list: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Choose the Excel files relevant to the user's question.

    A single available file is returned as is. Otherwise the LLM is shown
    each file's name, row/column counts, column names and description, and
    is asked for a JSON object listing the 1-based indices of the files to
    use.

    Args:
        user_query: The user's question.
        file_list: Available files, as dicts with at least ``filename``.

    Returns:
        The selected file dicts. Whenever the LLM call fails, its answer
        cannot be parsed, or it selects nothing valid, the whole
        ``file_list`` is returned as a fallback.

    Raises:
        FileSelectionError: If ``file_list`` is empty.
    """
    if not file_list:
        logger.warning("文件列表为空，无法进行文件选择")
        raise FileSelectionError("没有可用的文件进行选择")

    # Nothing to choose from when there is only one file.
    if len(file_list) == 1:
        logger.info("只有一个文件，直接选择")
        return file_list

    try:
        logger.info(f"开始智能文件选择 - 可用文件数: {len(file_list)}")

        file_descriptions = []
        for i, file_info in enumerate(file_list):
            # Stringify everything so mixed types cannot break the prompt.
            row_count = str(file_info.get('row_count', 'unknown'))
            column_count = str(file_info.get('column_count', 'unknown'))
            columns = file_info.get('columns', [])
            column_names = ', '.join(str(col) for col in columns)

            desc = f"""
            文件{i+1}: {file_info['filename']}
            - 行数: {row_count}
            - 列数: {column_count}
            - 列名: {column_names}
            - 描述: {file_info.get('description', '无描述')}
            """
            file_descriptions.append(desc)
        file_des_str = ' \n'.join(file_descriptions)
        prompt = f"""
        用户问题: {user_query}

        可用的Excel文件:
        {file_des_str}

        请分析用户问题，选择最相关的Excel文件来回答问题。
        如果问题涉及多个文件的数据关联，可以选择多个文件。
        如果问题只涉及特定类型的数据，只选择相关的文件。

        请返回JSON格式的结果，包含选中文件的索引（从1开始）：
        {{"selected_files": [1, 2, ...], "reason": "选择理由"}}
        """

        # The blocking ``invoke`` runs in the thread pool.
        response = await self._run_in_executor(
            self.llm.invoke, [HumanMessage(content=prompt)]
        )

        try:
            import re
            json_match = re.search(r'\{.*\}', response.content, re.DOTALL)
            if not json_match:
                logger.warning("无法解析LLM响应中的JSON，回退到选择所有文件")
                return file_list

            result = json.loads(json_match.group())
            selected_indices = result.get('selected_files', [])
            reason = result.get('reason', '未提供理由')

            # Keep only in-range integer indices, each file at most once.
            selected_files = []
            seen = set()
            for idx in selected_indices:
                if not isinstance(idx, int) or isinstance(idx, bool):
                    continue
                if 1 <= idx <= len(file_list) and idx not in seen:
                    seen.add(idx)
                    selected_files.append(file_list[idx - 1])

            if not selected_files:
                logger.warning("LLM选择结果为空，回退到选择所有文件")
                return file_list

            logger.info(f"成功选择{len(selected_files)}个文件: {[f['filename'] for f in selected_files]}")
            logger.info(f"选择理由: {reason}")
            return selected_files

        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"解析LLM响应失败: {str(e)}，回退到选择所有文件")
            return file_list

    except Exception as e:
        # A stray ``raise e`` used to sit here and made this fallback
        # unreachable; selection errors now degrade to "use every file".
        logger.error(f"文件选择过程中发生错误: {str(e)}")
        logger.info("回退到选择所有文件")
        return file_list
{selected_file}") + + # 使用原始文件名作为key + dataframes[filename] = df + logger.info(f"成功加载DataFrame: {filename}, 形状: {df.shape}") + + except Exception as e: + logger.error(f"加载文件失败 {selected_file}: {e}") + continue + + except Exception as e: + logger.error(f"加载DataFrames时发生错误: {e}") + raise FileLoadError(f"无法加载选中的文件: {e}") + + if not dataframes: + raise FileLoadError("没有成功加载任何文件") + + return dataframes + + def _parse_dataframe_string_to_table_data(self, df_string: str, subindex: int = -2) -> Dict[str, Any]: + """ + 将字符串格式的DataFrame转换为表格数据 + + Args: + df_string: DataFrame的字符串表示 + + Returns: + 包含columns和data的字典 + """ + try: + # 按行分割字符串 + lines = df_string.strip().split('\n') + + # 去掉最后两行(如因为最后两行可能是 "[12 rows x 11 columns]和空行") + + if len(lines) >= 2 and subindex == -2 : + lines = lines[:subindex] + + if len(lines) < 2: + # 如果行数不足,返回原始字符串 + return { + 'columns': [{'prop': 'result', 'label': '结果', 'width': 'auto'}], + 'data': [{'result': df_string}], + 'total': 1 + } + + # 第一行是列名 + header_line = lines[0].strip() + # 解析列名(去掉索引列) + columns_raw = header_line.split() + if columns_raw and columns_raw[0].isdigit() == False: + # 如果第一列不是数字,说明包含了列名 + column_names = columns_raw + else: + # 否则使用默认列名 + column_names = [f'Column_{i}' for i in range(len(columns_raw))] + + # 构建列定义 + columns = [] + for i, col_name in enumerate(column_names): + columns.append({ + 'prop': f'col_{i}', + 'label': str(col_name), + 'width': 'auto' + }) + + # 解析数据行 + data = [] + for line in lines[1:]: + if line.strip(): + # 分割数据行 + row_values = line.strip().split() + if row_values: + row_data = {} + # 第一个值通常是索引 + if len(row_values) > 0 and row_values[0].isdigit(): + row_data['_index'] = row_values[0] + values = row_values[1:] + else: + values = row_values + + # 填充列数据 + for i, value in enumerate(values): + if i < len(columns): + col_prop = f'col_{i}' + # 处理NaN值 + if value.lower() == 'nan': + row_data[col_prop] = '' + else: + row_data[col_prop] = value + + data.append(row_data) + + return { + 'columns': columns, + 
'data': data, + 'total': len(data) + } + + except Exception as e: + logger.warning(f"解析DataFrame字符串失败: {str(e)}") + return { + 'columns': [{'prop': 'result', 'label': '结果', 'width': 'auto'}], + 'data': [{'result': df_string}], + 'total': 1 + } + + async def _execute_smart_query( + self, + user_query: str, + dataframes: Dict[str, pd.DataFrame], + selected_files: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + 执行智能查询,生成和运行pandas代码 + + Args: + user_query: 用户查询 + dataframes: 加载的数据框字典 + selected_files: 选中的文件信息 + + Returns: + 查询结果字典 + + Raises: + CodeExecutionError: 代码执行失败 + """ + if not dataframes: + raise CodeExecutionError("没有可用的数据文件") + + logger.info(f"开始执行智能查询: {user_query[:50]}...") + + try: + # 如果有多个DataFrame,合并或选择主要的一个 + if len(dataframes) == 1: + main_df = list(dataframes.values())[0] + main_filename = list(dataframes.keys())[0] + else: + # 多个文件时,选择行数最多的作为主DataFrame + main_filename = max(dataframes.keys(), key=lambda k: len(dataframes[k])) + main_df = dataframes[main_filename] + + logger.info(f"选择主数据文件: {main_filename}, 行数: {len(main_df)}, 列数: {len(main_df.columns)}") + + # 验证数据框 + if main_df.empty: + raise CodeExecutionError(f"主数据文件 {main_filename} 为空") + + # 使用PythonAstREPLTool替代pandas代理 + try: + from langchain_experimental.tools import PythonAstREPLTool + from langchain_core.output_parsers import JsonOutputKeyToolsParser + from langchain_core.prompts import ChatPromptTemplate + + # 准备数据框字典,支持多个文件 + df_locals = {} + var_name_to_filename = {} # 变量名到文件名的映射 + for filename, df in dataframes.items(): + # 使用简化的变量名 + var_name = f"df_{len(df_locals) + 1}" if len(dataframes) > 1 else "df" + df_locals[var_name] = df + var_name_to_filename[var_name] = filename + + # 创建Python代码执行工具 + python_tool = PythonAstREPLTool(locals=df_locals) + + logger.info(f"创建Python工具成功,可用数据框: {list(df_locals.keys())}") + + except Exception as e: + raise CodeExecutionError(f"创建Python工具失败: {str(e)}") + + # 构建数据集信息(包含文件名和前5行数据) + dataset_info = [] + for var_name, df in df_locals.items(): + 
filename = var_name_to_filename[var_name] + + # 基本信息 + basic_info = f"- {var_name} (来源文件: {filename}): {len(df)}行 x {len(df.columns)}列" + + # 列名信息 + columns_info = f" 列名: {', '.join(str(col) for col in df.columns.tolist())}" + + # 前5行数据预览 + try: + preview_df = df.head(5) + preview_data = [] + for idx, row in preview_df.iterrows(): + row_data = [] + for col in df.columns: + value = row[col] + # 处理空值和特殊值 + if pd.isna(value): + row_data.append('NaN') + elif isinstance(value, (int, float)): + row_data.append(str(value)) + else: + # 限制字符串长度避免过长 + str_value = str(value) + if len(str_value) > 20: + str_value = str_value[:17] + '...' + row_data.append(str_value) + preview_data.append(f" 行{idx}: {', '.join(row_data)}") + + preview_info = f" 前5行数据预览:\n{chr(10).join(preview_data)}" + except Exception as e: + preview_info = f" 前5行数据预览: 无法生成预览 ({str(e)})" + + # 组合完整信息 + dataset_info.append(f"{basic_info}\n{columns_info}\n{preview_info}") + + # 构建系统提示 + system_prompt = f""" + 你所有可以访问的数据来自于传递给您的python_tool里的locals里的pandas数据信(可能有多个)。 + pandas数据集详细信息(文件来源、列名信息和数据预览)如下: + {chr(10).join(dataset_info)} + 请根据用户提出的问题,结合给出的数据集的详细信息,直接编写Python相关代码来计算pandas中的值。要求: + 1. 只返回代码,不返回其他内容 + 2. 只允许使用pandas和内置库 + 3. 确保代码能够直接执行并返回结果,包括import必要的内置库 + 4. 返回的结果应该是详细的、完整的数据,而不仅仅是简单答案 + 5. 结果应该包含足够的上下文信息,让用户能够验证和理解答案 + 6. 优先返回DataFrame格式的结果,便于展示为表格 + 7. 务必不要再去写代码查看数据集的结构,提示词里已经给出了每个数据集的结构信息,直接根据提示词里的结构信息进行判断。 + 8. 
要求代码中最后一次print的结果,必需是最后的正确结果(用户所需要的数据) + + 示例: + - 如果问"哪个项目合同额最高",不仅要返回项目名称,还要返回跟该项目其他有用的信息,比如合同额,合同时间,项目类型等(如果表格有该这些字段信息) + - 如果问"销售额最高的产品",要返回产品名称、销售额、销售数量、市场占比等完整信息(如果表格有该这些字段信息) + """ + + # 创建提示模板 + prompt = ChatPromptTemplate([ + ("system", system_prompt), + ("user", "{question}") + ]) + + # 创建解析器 + parser = JsonOutputKeyToolsParser(key_name=python_tool.name, first_tool_only=True) + + # 绑定工具到LLM + llm_with_tools = self.llm.bind_tools([python_tool]) + + def debug_print(x): + print('中间结果:', x) + return x + + debug_node = RunnableLambda(debug_print) + # 创建执行链 + llm_chain = prompt | llm_with_tools | debug_node| parser | debug_node| python_tool| debug_node + + # 执行查询 + try: + logger.debug("开始执行Python工具查询") + result = await self._run_in_executor(llm_chain.invoke, {"question": user_query}) + logger.debug(f"查询执行完成,结果: {str(result)[:200]}...") + except Exception as e: + error_msg = f"Python工具执行失败: {str(e)}" + logger.error(error_msg) + raise CodeExecutionError(error_msg) + + # 处理结果 + + try: + # 检查结果是否为pandas DataFrame + print('result type:',type(result)) + parse_result = '' + if isinstance(result, pd.DataFrame): + # 转换为表格数据 + table_data = self._convert_dataframe_to_table_data(result) + + data = table_data['data'] + columns = table_data['columns'] + total = table_data['total'] + result_type = 'table_data' + logger.info(f"处理DataFrame结果: {len(result)}行 x {len(result.columns)}列") + parse_result = table_data + # PythonAstREPLTool返回的是字符串结果 + elif isinstance(result, str): + # 尝试解析结果中的数据 + result_lines = result.strip().split('\n') + + # 检查是否是DataFrame的字符串表示 + if any('DataFrame' in line or ('|' in line and len([l for l in result_lines if '|' in l]) > 1) for line in result_lines): + + table_data = self._parse_dataframe_string_to_table_data(result) + data = table_data['data'] + columns = table_data['columns'] + total = 1 + result_type = 'table_data' + parse_result = table_data + elif ('rows' in result_lines[-1] and 'columns' in result_lines[-1]): + # 尝试解析DataFrame字符串为表格数据 + 
table_data = self._parse_dataframe_string_to_table_data(result) + if 'data' in table_data and 'columns' in table_data: + data = table_data['data'] + columns = table_data['columns'] + total = 1 + result_type = 'table_data' + parse_result = table_data + else: + total = 1 + result_type = 'text' + parse_result = table_data + + else: + # 简单的数值或文本结果 + # 尝试解析DataFrame字符串为表格数据 + table_data = self._parse_dataframe_string_to_table_data(result, 0) + if 'data' in table_data and 'columns' in table_data: + data = table_data['data'] + columns = table_data['columns'] + total = table_data['total'] + total = 1 + result_type = 'table_data' + parse_result = table_data + else: + total = 1 + result_type = 'text' + parse_result = table_data + elif isinstance(result, (int, float, bool)): + data = result + columns = result + total = 1 + result_type = 'scalar' + parse_result = result + else: + # 处理其他类型的结果 + data = result + columns = result + total = 1 + result_type = 'other' + parse_result = result + logger.info(f"结果处理完成: {result_type}, 数据行数: {total}") + + except Exception as e: + error_msg = f"结果处理失败: {str(e)}" + logger.error(error_msg) + raise CodeExecutionError(error_msg) + + # 生成总结 + try: + summary = await self._generate_query_summary(user_query, parse_result, main_df) + except Exception as e: + logger.warning(f"生成总结失败: {str(e)}") + summary = f"基于数据分析完成查询,共处理{len(main_df)}行数据。" + + return { + 'data': data, + 'columns': columns, + 'total': total, + 'result_type': result_type, + 'summary': summary, + 'used_files': list(dataframes.keys()), + 'generated_code': f"# 基于文件: {', '.join(dataframes.keys())}\n# 查询: {user_query}\n# 使用LangChain Python工具执行", + 'data_info': { + 'source_files': list(dataframes.keys()), + 'dataframes': {name: {'rows': len(df), 'columns': len(df.columns), 'column_names': [str(col) for col in df.columns.tolist()]} for name, df in dataframes.items()} + } + } + + except CodeExecutionError: + raise + except Exception as e: + error_msg = f"查询执行过程中发生未知错误: {str(e)}" + 
logger.error(error_msg, exc_info=True) + raise CodeExecutionError(error_msg) + + async def _generate_query_summary( + self, + query: str, + result: Any, + df: pd.DataFrame + ) -> str: + """ + 生成查询结果的AI总结 + """ + try: + logger.debug("开始生成查询总结") + + # 安全地获取数据集信息 + try: + dataset_info = f""" + 数据集信息: + - 总行数: {len(df)} + - 总列数: {len(df.columns)} + - 列名: {', '.join(str(col) for col in df.columns.tolist())} + """ + except Exception as e: + logger.warning(f"获取数据集信息失败: {str(e)}") + dataset_info = "数据集信息: 无法获取" + + # 安全地处理查询结果 + try: + if isinstance(result, pd.DataFrame): + if len(result) > 0: + result_preview = result.head(3).to_string(max_cols=5, max_rows=3) + else: + result_preview = "查询结果为空" + else: + result_preview = str(result) # 限制长度避免过长 + except Exception as e: + logger.warning(f"生成结果预览失败: {str(e)}") + result_preview = "无法生成结果预览" + + prompt = f""" + 用户问题: {query} + + {dataset_info} + + 查询结果: {result_preview}... + + 系统已经根据用户提问查询出了结果,请根据结果生成一个简洁的中文总结,说明: + 1. 查询的主要发现 + 2. 数据的关键特征 + 3. 结果的业务含义 + + 总结应该在100字以内,通俗易懂。 + """ + + try: + response = await self._run_in_executor( + self.llm.invoke, [HumanMessage(content=prompt)] + ) + + summary = response.content.strip() + + # 验证总结长度 + if len(summary) > 200: + logger.warning("AI生成的总结过长,进行截取") + summary = summary[:200] + "..." 
def analyze_dataframe(self, df: pd.DataFrame, filename: str) -> Dict[str, Any]:
    """Summarise a DataFrame: shape, per-column statistics, preview and quality issues.

    Args:
        df: The DataFrame to inspect. It is only read, never modified.
        filename: Name echoed back in the result under ``filename``.

    Returns:
        A dict with the keys ``filename``, ``rows``, ``columns``,
        ``column_names``, ``column_info``, ``preview`` (first five rows with
        missing values shown as ``''``), ``quality_issues`` and
        ``memory_usage`` (formatted in MB).

    Raises:
        Exception: If any step of the analysis fails; the underlying error is
            chained as ``__cause__``.
    """
    try:
        rows, columns = df.shape

        # Per-column information.
        column_info = []
        for col in df.columns:
            series = df[col]
            null_count = int(series.isnull().sum())
            col_info = {
                'name': col,
                'dtype': str(series.dtype),
                'null_count': null_count,
                'unique_count': int(series.nunique())
            }

            if pd.api.types.is_numeric_dtype(series):
                # Statistics are reported as None when the column holds no
                # non-null value (this also covers a zero-length column).
                # The former `df.fillna({col: 0})` call was dropped: its
                # result was discarded, so it only copied the frame.
                has_values = null_count < len(series)
                col_info.update({
                    'mean': float(series.mean()) if has_values else None,
                    'std': float(series.std()) if has_values else None,
                    'min': float(series.min()) if has_values else None,
                    'max': float(series.max()) if has_values else None
                })

            column_info.append(col_info)

        # Preview of the first five rows.
        preview_data = df.head().fillna('').to_dict('records')

        # Data-quality checks.
        quality_issues = []

        missing_cols = df.columns[df.isnull().any()].tolist()
        if missing_cols:
            quality_issues.append({
                'type': 'missing_values',
                'description': f'以下列存在缺失值: {", ".join(map(str, missing_cols))}',
                'columns': missing_cols
            })

        duplicate_count = df.duplicated().sum()
        if duplicate_count > 0:
            quality_issues.append({
                'type': 'duplicate_rows',
                'description': f'发现 {duplicate_count} 行重复数据',
                'count': int(duplicate_count)
            })

        return {
            'filename': filename,
            'rows': rows,
            'columns': columns,
            'column_names': [str(col) for col in df.columns.tolist()],
            'column_info': column_info,
            'preview': preview_data,
            'quality_issues': quality_issues,
            'memory_usage': f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB"
        }

    except Exception as e:
        # Log through the module logger instead of printing to stdout, and
        # keep the original traceback attached to the re-raised error.
        logger.error(f"DataFrame分析失败: {str(e)}")
        raise Exception(f"DataFrame分析失败: {str(e)}") from e
early_stopping_method="force", + # allow_dangerous_code=True # 允许执行代码以支持数据分析 + # ) + + return agent + + except Exception as e: + raise Exception(f"创建pandas代理失败: {str(e)}") + + def _execute_pandas_query(self, agent, query: str) -> Dict[str, Any]: + """ + 执行pandas查询 + """ + try: + # 执行查询 + # 使用invoke方法来处理有多个输出键的情况 + agent_result = agent.invoke({"input": query}) + # 提取主要结果 + result = agent_result.get('output', agent_result) + + # 解析结果 + if isinstance(result, pd.DataFrame): + # 如果结果是DataFrame + data = result.fillna('').to_dict('records') + columns = result.columns.tolist() + total = len(result) + + return { + 'data': data, + 'columns': columns, + 'total': total, + 'result_type': 'dataframe' + } + else: + # 如果结果是其他类型(字符串、数字等) + return { + 'data': [{'result': str(result)}], + 'columns': ['result'], + 'total': 1, + 'result_type': 'scalar' + } + + except Exception as e: + raise Exception(f"pandas查询执行失败: {str(e)}") + + async def execute_natural_language_query( + self, + query: str, + user_id: int, + page: int = 1, + page_size: int = 20 + ) -> Dict[str, Any]: + """ + 执行自然语言查询 + """ + try: + # 查找用户的临时文件 + temp_dir = tempfile.gettempdir() + user_files = [f for f in os.listdir(temp_dir) + if f.startswith(f"excel_{user_id}_") and f.endswith('.pkl')] + + if not user_files: + return { + 'success': False, + 'message': '未找到上传的Excel文件,请先上传文件' + } + + # 使用最新的文件 + latest_file = sorted(user_files)[-1] + file_path = os.path.join(temp_dir, latest_file) + + # 加载DataFrame + df = pd.read_pickle(file_path) + + # 创建pandas代理 + agent = self._create_pandas_agent(df) + + # 执行查询 + query_result = await self._run_in_executor( + self._execute_pandas_query, agent, query + ) + + # 分页处理 + total = query_result['total'] + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + + paginated_data = query_result['data'][start_idx:end_idx] + + # 生成AI总结 + summary = await self._generate_summary(query, query_result, df) + + return { + 'success': True, + 'data': { + 'data': paginated_data, + 
'columns': query_result['columns'], + 'total': total, + 'page': page, + 'page_size': page_size, + 'generated_code': f"# 基于自然语言查询: {query}\n# 使用LangChain Pandas代理执行", + 'summary': summary, + 'result_type': query_result['result_type'] + } + } + + except Exception as e: + return { + 'success': False, + 'message': f"查询执行失败: {str(e)}" + } + + async def _generate_summary(self, query: str, result: Dict[str, Any], df: pd.DataFrame) -> str: + """ + 生成AI总结 + """ + try: + llm = ChatZhipuAI( + model="glm-4", + api_key=os.getenv("ZHIPUAI_API_KEY"), + temperature=0.3 + ) + + # 构建总结提示 + prompt = f""" + 用户查询: {query} + + 数据集信息: + - 总行数: {len(df)} + - 总列数: {len(df.columns)} + - 列名: {', '.join(str(col) for col in df.columns.tolist())} + + 查询结果: + - 结果类型: {result['result_type']} + - 结果行数: {result['total']} + - 结果列数: {len(result['columns'])} + + 请基于以上信息,用中文生成一个简洁的分析总结,包括: + 1. 查询的主要目的 + 2. 关键发现 + 3. 数据洞察 + 4. 建议的后续分析方向 + + 总结应该专业、准确、易懂,控制在200字以内。 + """ + + response = await self._run_in_executor( + lambda: llm.invoke([HumanMessage(content=prompt)]) + ) + + return response.content + + except Exception as e: + return f"查询已完成,但生成总结时出现错误: {str(e)}" + +class DatabaseQueryService(SmartQueryService): + """ + 数据库查询服务 + """ + def __init__(self): + super().__init__() + self.user_connections = {} # 存储用户的数据库连接信息 + + def _create_connection(self, config: Dict[str, str]): + """ + 创建数据库连接 + """ + db_type = config['type'].lower() + + try: + if db_type == 'mysql': + connection = pymysql.connect( + host=config['host'], + port=int(config['port']), + user=config['username'], + password=config['password'], + database=config['database'], + charset='utf8mb4' + ) + elif db_type == 'postgresql': + connection = psycopg2.connect( + host=config['host'], + port=int(config['port']), + user=config['username'], + password=config['password'], + database=config['database'] + ) + + return connection + + except Exception as e: + raise Exception(f"数据库连接失败: {str(e)}") + + async def test_connection(self, config: Dict[str, 
async def connect_database(self, config: Dict[str, str], user_id: int) -> Dict[str, Any]:
    """Connect to a database, list its tables and remember the connection for the user.

    Args:
        config: Connection settings; ``type`` and ``database`` are also echoed
            back in the response.
        user_id: Key under which the connection is stored in
            ``self.user_connections``.

    Returns:
        ``{'success': True, 'data': {...}}`` with ``tables``,
        ``database_type`` and ``database_name`` on success, otherwise
        ``{'success': False, 'message': ...}``. Errors are reported through
        the return value, not raised.
    """
    def _close_quietly(conn) -> None:
        # Best-effort close: a failure to close must not mask the real result.
        if conn is None:
            return
        try:
            conn.close()
        except Exception as close_error:
            logger.warning(f"关闭数据库连接失败: {str(close_error)}")

    connection = None
    registered = False
    try:
        connection = await self._run_in_executor(self._create_connection, config)

        # Fetch the table list.
        tables = await self._run_in_executor(self._get_tables, connection, config['type'])

        # Store the connection, replacing (and then closing) any connection
        # previously held for this user so it is not leaked.
        previous = self.user_connections.get(user_id)
        self.user_connections[user_id] = {
            'config': config,
            'connection': connection,
            'connected_at': datetime.now()
        }
        registered = True
        if previous is not None:
            _close_quietly(previous.get('connection'))

        return {
            'success': True,
            'data': {
                'tables': tables,
                'database_type': config['type'],
                'database_name': config['database']
            }
        }

    except Exception as e:
        # A connection that was opened but never stored would otherwise stay
        # open with nothing referencing it.
        if connection is not None and not registered:
            _close_quietly(connection)
        return {
            'success': False,
            'message': f"数据库连接失败: {str(e)}"
        }
connection, table_name, db_type + ) + + return { + 'success': True, + 'data': { + 'schema': schema, + 'table_name': table_name + } + } + + except Exception as e: + return { + 'success': False, + 'message': f"获取表结构失败: {str(e)}" + } + + def _get_table_schema(self, connection, table_name: str, db_type: str) -> List[Dict[str, Any]]: + """ + 获取表结构信息 + """ + cursor = connection.cursor() + + try: + if db_type.lower() == 'mysql': + cursor.execute(f"DESCRIBE {table_name}") + columns = cursor.fetchall() + schema = [{ + 'column_name': col[0], + 'data_type': col[1], + 'is_nullable': 'YES' if col[2] == 'YES' else 'NO', + 'column_key': col[3], + 'column_default': col[4] + } for col in columns] + elif db_type.lower() == 'postgresql': + cursor.execute(""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = %s + ORDER BY ordinal_position + """, (table_name,)) + columns = cursor.fetchall() + schema = [{ + 'column_name': col[0], + 'data_type': col[1], + 'is_nullable': col[2], + 'column_default': col[3] + } for col in columns] + + else: + schema = [] + + return schema + + finally: + cursor.close() + + async def execute_natural_language_query( + self, + query: str, + table_name: str, + user_id: int, + page: int = 1, + page_size: int = 20 + ) -> Dict[str, Any]: + """ + 执行自然语言数据库查询 + """ + try: + if user_id not in self.user_connections: + return { + 'success': False, + 'message': '数据库连接已断开,请重新连接' + } + + connection = self.user_connections[user_id]['connection'] + + # 这里应该集成MCP服务来将自然语言转换为SQL + # 目前先使用简单的实现 + sql_query = await self._convert_to_sql(query, table_name, connection) + + # 执行SQL查询 + result = await self._run_in_executor( + self._execute_sql_query, connection, sql_query, page, page_size + ) + + # 生成AI总结 + summary = await self._generate_db_summary(query, result, table_name) + + result['generated_code'] = sql_query + result['summary'] = summary + + return { + 'success': True, + 'data': result + } + + except Exception as 
e: + return { + 'success': False, + 'message': f"数据库查询执行失败: {str(e)}" + } + + async def _convert_to_sql(self, query: str, table_name: str, connection) -> str: + """ + 将自然语言转换为SQL查询 + TODO: 集成MCP服务 + """ + # 这是一个简化的实现,实际应该使用MCP服务 + # 根据常见的查询模式生成SQL + + query_lower = query.lower() + + if '所有' in query or '全部' in query or 'all' in query_lower: + return f"SELECT * FROM {table_name} LIMIT 100" + elif '统计' in query or '总数' in query or 'count' in query_lower: + return f"SELECT COUNT(*) as total_count FROM {table_name}" + elif '最近' in query or 'recent' in query_lower: + return f"SELECT * FROM {table_name} ORDER BY id DESC LIMIT 10" + elif '分组' in query or 'group' in query_lower: + # 简单的分组查询,需要根据实际表结构调整 + return f"SELECT COUNT(*) as count FROM {table_name} GROUP BY id LIMIT 10" + else: + # 默认查询 + return f"SELECT * FROM {table_name} LIMIT 20" + + def _execute_sql_query(self, connection, sql_query: str, page: int, page_size: int) -> Dict[str, Any]: + """ + 执行SQL查询 + """ + cursor = connection.cursor() + + try: + # 执行查询 + cursor.execute(sql_query) + + # 获取列名 + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + + # 获取所有结果 + all_results = cursor.fetchall() + total = len(all_results) + + # 分页 + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paginated_results = all_results[start_idx:end_idx] + + # 转换为字典格式 + data = [] + for row in paginated_results: + row_dict = {} + for i, value in enumerate(row): + if i < len(columns): + row_dict[columns[i]] = value + data.append(row_dict) + + return { + 'data': data, + 'columns': columns, + 'total': total, + 'page': page, + 'page_size': page_size + } + + finally: + cursor.close() + + async def _generate_db_summary(self, query: str, result: Dict[str, Any], table_name: str) -> str: + """ + 生成数据库查询总结 + """ + try: + llm = ChatZhipuAI( + model="glm-4", + api_key=os.getenv("ZHIPUAI_API_KEY"), + temperature=0.3 + ) + + prompt = f""" + 用户查询: {query} + 目标表: {table_name} + + 查询结果: + - 结果行数: 
class SmartQueryService:
    """Smart-query helper that enriches questions with stored table metadata.

    NOTE(review): in the original file this class is declared inside the body
    of ``DatabaseQueryService`` and has no base class of its own; it looks
    like it was meant to extend the module-level ``SmartQueryService`` -
    confirm the intended placement.
    """

    def __init__(self):
        super().__init__()
        # Set by `set_db_session`; stays None until a session is provided.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Attach a database session by creating the table-metadata service."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def get_database_context(self, user_id: int, query: str) -> str:
        """Build a text description of the user's tables for use in a prompt.

        Args:
            user_id: Passed to ``get_user_table_metadata``.
            query: Currently unused.

        Returns:
            A newline-joined description (header, then one section per table
            followed by ``---``), or ``""`` when no metadata service is set,
            the user has no metadata, or an error occurs.
        """
        if not self.table_metadata_service:
            return ""

        try:
            table_metadata_list = self.table_metadata_service.get_user_table_metadata(user_id)

            if not table_metadata_list:
                return ""

            context_parts = ["=== 数据库表信息 ==="]

            for metadata in table_metadata_list:
                table_info = [f"表名: {metadata.table_name}"]

                if metadata.table_comment:
                    table_info.append(f"表描述: {metadata.table_comment}")

                if metadata.qa_description:
                    table_info.append(f"业务说明: {metadata.qa_description}")

                # Column list: "name (type)" plus the comment when present.
                if metadata.columns_info:
                    columns = []
                    for col in metadata.columns_info:
                        col_desc = f"{col['column_name']} ({col['data_type']})"
                        if col.get('column_comment'):
                            col_desc += f" - {col['column_comment']}"
                        columns.append(col_desc)
                    table_info.append(f"字段: {', '.join(columns)}")

                # Only the first two sample rows are included.
                if metadata.sample_data:
                    table_info.append(f"示例数据: {metadata.sample_data[:2]}")

                table_info.append(f"总行数: {metadata.row_count}")

                context_parts.append("\n".join(table_info))
                context_parts.append("---")

            return "\n".join(context_parts)

        except Exception as e:
            logger.error(f"获取数据库上下文失败: {str(e)}")
            return ""

    async def execute_smart_query(self, query: str, user_id: int, **kwargs) -> Dict[str, Any]:
        """Run a smart query with the table metadata prepended to the question.

        Args:
            query: The user's question.
            user_id: Used to look up table metadata and forwarded to the
                parent implementation.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``execute_smart_query`` returns, or
            ``{'success': False, 'message': ...}`` when anything fails -
            including when no parent implementation exists.
        """
        try:
            db_context = await self.get_database_context(user_id, query)

            enhanced_prompt = f"""
            {db_context}

            用户问题: {query}

            请基于上述数据库表信息,生成相应的SQL查询语句。
            注意:
            1. 使用准确的表名和字段名
            2. 考虑数据类型和约束
            3. 参考示例数据理解数据格式
            4. 生成高效的查询语句
            """

            # The class as written has no base that defines
            # `execute_smart_query`; report that explicitly instead of letting
            # an AttributeError on `super()` surface as the failure message.
            parent_query = getattr(super(), 'execute_smart_query', None)
            if parent_query is None:
                raise NotImplementedError("当前类没有可用的底层查询实现 (execute_smart_query)")

            return await parent_query(enhanced_prompt, user_id, **kwargs)

        except Exception as e:
            logger.error(f"智能查询失败: {str(e)}")
            return {
                'success': False,
                'message': f"查询失败: {str(e)}"
            }
async def process_database_query(
    self,
    user_query: str,
    user_id: int,
    database_config_id: int,
    conversation_id: Optional[int] = None,
    is_new_conversation: bool = False
) -> Dict[str, Any]:
    """Handle a database smart query by delegating to the database workflow manager.

    Args:
        user_query: The user's question.
        user_id: Id of the requesting user.
        database_config_id: Id of the database configuration to query.
        conversation_id: Optional conversation id, forwarded as-is.
        is_new_conversation: Forwarded as-is.

    Returns:
        Whatever ``self.database_workflow.process_database_query`` returns.
    """
    delegate = self.database_workflow.process_database_query
    # The fourth positional argument is always sent as None.
    # NOTE(review): its meaning is defined by the delegate - confirm there.
    forwarded_args = (
        user_query,
        user_id,
        database_config_id,
        None,
        conversation_id,
        is_new_conversation,
    )
    outcome = await delegate(*forwarded_args)
    return outcome
class S3StorageBackend(StorageBackend):
    """Storage backend that keeps files in an S3 (or S3-compatible) bucket."""

    def __init__(
        self,
        bucket_name: str,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: str = "us-east-1",
        endpoint_url: Optional[str] = None
    ):
        """Create the S3 client and make sure the bucket exists.

        Args:
            bucket_name: Bucket used for every operation.
            aws_access_key_id: Access key passed to ``boto3.Session``.
            aws_secret_access_key: Secret key passed to ``boto3.Session``.
            aws_region: Region for the session and for bucket creation.
            endpoint_url: Custom endpoint, for S3-compatible services such as
                MinIO.

        Raises:
            Exception: If the bucket cannot be accessed or created.
        """
        self.bucket_name = bucket_name
        self.aws_region = aws_region

        # Initialize S3 client
        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region
        )

        self.s3_client = session.client(
            's3',
            endpoint_url=endpoint_url  # For S3-compatible services like MinIO
        )

        # Verify bucket exists or create it
        self._ensure_bucket_exists()

    def _ensure_bucket_exists(self):
        """Create the bucket when a HEAD request reports it as missing.

        Raises:
            Exception: If the bucket exists but cannot be accessed, or if
                creating it fails.
        """
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            # Compare the code as text: it is not always numeric (for example
            # "NoSuchBucket"), and int() on such a value raised ValueError.
            error_code = str(e.response.get('Error', {}).get('Code', ''))
            if error_code not in ('404', 'NoSuchBucket'):
                raise Exception(f"Failed to access S3 bucket: {e}") from e

            # Bucket doesn't exist, create it
            try:
                if self.aws_region == 'us-east-1':
                    # us-east-1 is created without a LocationConstraint.
                    self.s3_client.create_bucket(Bucket=self.bucket_name)
                else:
                    self.s3_client.create_bucket(
                        Bucket=self.bucket_name,
                        CreateBucketConfiguration={'LocationConstraint': self.aws_region}
                    )
            except ClientError as create_error:
                raise Exception(f"Failed to create S3 bucket: {create_error}") from create_error

    async def upload_file(
        self,
        file: UploadFile,
        file_path: str
    ) -> Dict[str, Any]:
        """Upload file to S3.

        Args:
            file: The uploaded file; its whole content is read into memory.
            file_path: Object key inside the bucket.

        Returns:
            A dict with ``file_path``, ``bucket``, ``size``, ``mime_type`` and
            ``storage_type`` (``"s3"``).

        Raises:
            Exception: If S3 rejects the upload or no credentials are
                available.
        """
        import time

        try:
            # Read file content
            content = await file.read()

            # Determine content type
            content_type = FileUtils.get_mime_type(file.filename) or 'application/octet-stream'

            # Upload to S3
            # NOTE(review): S3 user metadata is presumably restricted to
            # ASCII - confirm how non-ASCII filenames should be stored.
            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=file_path,
                Body=content,
                ContentType=content_type,
                Metadata={
                    'original_filename': file.filename or 'unknown',
                    # Was `os.time.time()`, which raises AttributeError: the
                    # clock lives in the `time` module, not in `os`.
                    'upload_timestamp': str(int(time.time()))
                }
            )

            return {
                "file_path": file_path,
                "bucket": self.bucket_name,
                "size": len(content),
                "mime_type": content_type,
                "storage_type": "s3"
            }
        except (ClientError, NoCredentialsError) as e:
            raise Exception(f"Failed to upload file to S3: {e}") from e

    async def delete_file(self, file_path: str) -> bool:
        """Delete file from S3; return False when S3 reports an error."""
        try:
            self.s3_client.delete_object(
                Bucket=self.bucket_name,
                Key=file_path
            )
            return True
        except ClientError:
            return False

    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Get a presigned download URL valid for one hour, or None on error."""
        try:
            url = self.s3_client.generate_presigned_url(
                'get_object',
                Params={'Bucket': self.bucket_name, 'Key': file_path},
                ExpiresIn=3600  # 1 hour
            )
            return url
        except ClientError:
            return None

    async def file_exists(self, file_path: str) -> bool:
        """Check if file exists in S3 (any ClientError counts as missing)."""
        try:
            self.s3_client.head_object(
                Bucket=self.bucket_name,
                Key=file_path
            )
            return True
        except ClientError:
            return False
class StorageService:
    """Facade that routes file operations to the configured storage backend."""

    def __init__(self):
        self.storage_type = settings.storage.storage_type

        if self.storage_type == 's3':
            self.backend = S3StorageBackend(
                bucket_name=settings.storage.s3_bucket_name,
                aws_access_key_id=settings.storage.aws_access_key_id,
                aws_secret_access_key=settings.storage.aws_secret_access_key,
                aws_region=settings.storage.aws_region,
                endpoint_url=settings.storage.s3_endpoint_url
            )
        else:
            # Use an absolute path so the upload directory does not depend on
            # the directory the process was started from.
            upload_dir = settings.storage.upload_directory
            if not os.path.isabs(upload_dir):
                # Relative paths are anchored three levels above this file.
                backend_dir = Path(__file__).parent.parent.parent
                upload_dir = str(backend_dir / upload_dir)
            self.backend = LocalStorageBackend(upload_dir)

    def generate_file_path(self, knowledge_base_id: int, filename: str) -> str:
        """Return a unique path ``kb_{id}/{uuid}_{sanitized filename}``."""
        safe_filename = FileUtils.sanitize_filename(filename)
        file_id = str(uuid.uuid4())
        return f"kb_{knowledge_base_id}/{file_id}_{safe_filename}"

    async def upload_file(
        self,
        file: UploadFile,
        knowledge_base_id: int
    ) -> Dict[str, Any]:
        """Upload ``file`` for a knowledge base via the configured backend.

        Uploads without a client-supplied filename are stored as ``unknown``,
        the same fallback the S3 backend uses for its metadata.
        """
        filename = file.filename or "unknown"
        file_path = self.generate_file_path(knowledge_base_id, filename)
        return await self.backend.upload_file(file, file_path)

    async def delete_file(self, file_path: str) -> bool:
        """Delete a file via the configured backend."""
        return await self.backend.delete_file(file_path)

    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Return the backend's access URL for ``file_path``."""
        return await self.backend.get_file_url(file_path)

    async def file_exists(self, file_path: str) -> bool:
        """Return whether the backend has a file at ``file_path``."""
        return await self.backend.file_exists(file_path)
async def collect_and_save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_names: List[str]
) -> Dict[str, Any]:
    """Collect metadata for the given tables and persist it.

    Looks up the user's database config, reuses or opens a connection through
    the matching tool, then collects and saves each table independently.

    Returns:
        On success a dict with ``success=True``, ``collected_tables``,
        ``failed_tables`` and their counts. Any failure before the per-table
        loop (missing config, unsupported type, connection error) is reported
        as ``{'success': False, 'message': ...}`` rather than raised.
    """
    self.session.desc = f"为用户 {user_id} 收集数据库 {database_config_id} 的表元数据"
    try:
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == database_config_id,
            DatabaseConfig.created_by == user_id
        )
        db_config = (await self.session.execute(stmt)).scalar_one_or_none()

        if not db_config:
            self.session.desc = "ERROR: 数据库配置不存在"
            raise NotFoundError("数据库配置不存在")

        # Read once: the rollback in the loop below expires ORM objects, and
        # touching an expired attribute afterwards would need implicit I/O.
        db_type = db_config.db_type

        if db_type.lower() == 'postgresql':
            db_tool = self.postgresql_tool
        elif db_type.lower() == 'mysql':
            db_tool = self.mysql_tool
        else:
            self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
            raise Exception(f"不支持的数据库类型: {db_type}")

        # Reuse the user's connection when the tool already holds one.
        user_id_str = str(user_id)
        if user_id_str not in db_tool.connections:
            connection_config = {
                'host': db_config.host,
                'port': db_config.port,
                'database': db_config.database,
                'username': db_config.username,
                'password': self._decrypt_password(db_config.password)
            }

            connect_result = await db_tool.execute(
                operation="connect",
                connection_config=connection_config,
                user_id=user_id_str
            )

            if not connect_result.success:
                self.session.desc = f"ERROR: 数据库连接失败: {connect_result.error}"
                raise Exception(f"数据库连接失败: {connect_result.error}")

            self.session.desc = f"SUCCESS: 为用户 {user_id} 建立了新的{db_type}数据库连接"
        else:
            self.session.desc = f"SUCCESS: 复用用户 {user_id} 的现有{db_type}数据库连接"

        collected_tables = []
        failed_tables = []

        for table_name in table_names:
            try:
                metadata = await self._collect_single_table_metadata(
                    user_id, table_name, db_type
                )

                table_metadata = await self._save_table_metadata(
                    user_id, database_config_id, table_name, metadata
                )

                collected_tables.append({
                    'table_name': table_name,
                    'metadata_id': table_metadata.id,
                    'columns_count': len(metadata['columns_info']),
                    'sample_rows': len(metadata['sample_data'])
                })

            except Exception as e:
                # A failed flush/commit leaves the session unusable until it is
                # rolled back; without this every later table would fail too.
                await self.session.rollback()
                self.session.desc = f"ERROR: 收集表 {table_name} 元数据失败: {str(e)}"
                failed_tables.append({
                    'table_name': table_name,
                    'error': str(e)
                })

        return {
            'success': True,
            'collected_tables': collected_tables,
            'failed_tables': failed_tables,
            'total_collected': len(collected_tables),
            'total_failed': len(failed_tables)
        }

    except Exception as e:
        self.session.desc = f"ERROR: 收集表元数据失败: {str(e)}"
        return {
            'success': False,
            'message': str(e)
        }
async def _collect_single_table_metadata(
    self,
    user_id: int,
    table_name: str,
    db_type: str
) -> Dict[str, Any]:
    """Collect schema, up to five sample rows and the row count of one table.

    Args:
        user_id: Owner of the tool connection; passed to the tool as a string.
        table_name: Plain identifier, optionally qualified with dots.
        db_type: ``postgresql`` or ``mysql`` (case-insensitive).

    Returns:
        Dict with ``columns_info``, ``primary_keys``, ``foreign_keys``,
        ``indexes``, ``sample_data``, ``row_count`` and ``table_comment``.

    Raises:
        ValueError: if ``table_name`` is not a plain identifier.
        Exception: if the database type is unsupported or the schema lookup fails.
    """
    import re  # local import: this module does not import re at file level

    self.session.desc = f"为用户 {user_id} 收集表 {table_name} 的元数据"
    if db_type.lower() == 'postgresql':
        db_tool = self.postgresql_tool
    elif db_type.lower() == 'mysql':
        db_tool = self.mysql_tool
    else:
        self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
        raise Exception(f"不支持的数据库类型: {db_type}")

    # The name is interpolated into SQL text below, so accept only what can be
    # written as an unquoted identifier (letters, digits, '_', '$', dots).
    # Such names produce exactly the same queries as before.
    if not isinstance(table_name, str) or not re.fullmatch(
        r"[\w$]+(?:\.[\w$]+){0,2}", table_name
    ):
        self.session.desc = f"ERROR: 非法的表名: {table_name!r}"
        raise ValueError(f"非法的表名: {table_name!r}")

    schema_result = await db_tool.execute(
        operation="describe_table",
        user_id=str(user_id),
        table_name=table_name
    )

    if not schema_result.success:
        self.session.desc = f"ERROR: 获取表 {table_name} 结构失败: {schema_result.error}"
        raise Exception(f"获取表结构失败: {schema_result.error}")

    schema_data = schema_result.result

    # Sample data is best effort: a failed query leaves an empty list.
    sample_result = await db_tool.execute(
        operation="execute_query",
        user_id=str(user_id),
        sql_query=f"SELECT * FROM {table_name} LIMIT 5",
        limit=5
    )

    sample_data = []
    if sample_result.success:
        sample_data = sample_result.result.get('data', [])

    # Row count is best effort as well and defaults to 0.
    count_result = await db_tool.execute(
        operation="execute_query",
        user_id=str(user_id),
        sql_query=f"SELECT COUNT(*) as total_rows FROM {table_name}",
        limit=1
    )

    row_count = 0
    if count_result.success and count_result.result.get('data'):
        row_count = count_result.result['data'][0].get('total_rows', 0)

    self.session.desc = f"SUCCESS: 为用户 {user_id} 收集表 {table_name} 的元数据, 包含 {len(schema_data.get('columns', []))} 列, {row_count} 行数据"

    return {
        'columns_info': schema_data.get('columns', []),
        'primary_keys': schema_data.get('primary_keys', []),
        'foreign_keys': schema_data.get('foreign_keys', []),
        'indexes': schema_data.get('indexes', []),
        'sample_data': sample_data,
        'row_count': row_count,
        'table_comment': schema_data.get('table_comment', '')
    }
async def _save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str,
    metadata: Dict[str, Any]
) -> TableMetadata:
    """Insert or update the stored metadata row for one table.

    The row is identified by (creator, database config, table name). The
    collected fields and the sync timestamp are written, then the session is
    committed and the row refreshed before it is returned.
    """
    self.session.desc = f"为用户 {user_id} 保存表 {table_name} 的元数据"

    collected = {
        key: metadata[key]
        for key in (
            'columns_info',
            'primary_keys',
            'foreign_keys',
            'indexes',
            'sample_data',
            'row_count',
            'table_comment',
        )
    }

    lookup = select(TableMetadata).where(
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name
    )
    record = (await self.session.execute(lookup)).scalar_one_or_none()

    if not record:
        self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
        record = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            is_enabled_for_qa=True,
            last_synced_at=datetime.utcnow(),
            **collected
        )
        self.session.add(record)
        await self.session.commit()
        await self.session.refresh(record)
        self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
        return record

    self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
    for attribute, value in collected.items():
        setattr(record, attribute, value)
    record.last_synced_at = datetime.utcnow()

    await self.session.commit()
    await self.session.refresh(record)
    return record
async def save_table_metadata_config(
    self,
    user_id: int,
    database_config_id: int,
    table_names: List[str]
) -> Dict[str, Any]:
    """Register tables for Q&A without collecting their full metadata.

    Existing rows are re-enabled and their sync time updated; missing rows
    are created with placeholder values. All changes are committed once at
    the end.

    Returns:
        Dict with ``saved_tables``, ``failed_tables`` and their counts.

    Raises:
        NotFoundError: if the database config does not belong to the user.
    """
    self.session.desc = f"为用户 {user_id} 保存数据库配置 {database_config_id} 表 {table_names} 的元数据配置"
    # Ownership is stored in `created_by`, as in every other query of this
    # service; this method previously filtered on `user_id`.
    stmt = select(DatabaseConfig).where(
        DatabaseConfig.id == database_config_id,
        DatabaseConfig.created_by == user_id
    )
    db_config = (await self.session.execute(stmt)).scalar_one_or_none()

    if not db_config:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
        raise NotFoundError("数据库配置不存在")

    saved_tables = []
    failed_tables = []

    for table_name in table_names:
        try:
            stmt = select(TableMetadata).where(
                TableMetadata.created_by == user_id,
                TableMetadata.database_config_id == database_config_id,
                TableMetadata.table_name == table_name
            )
            existing = (await self.session.execute(stmt)).scalar_one_or_none()

            if existing:
                self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置"
                existing.is_enabled_for_qa = True
                existing.last_synced_at = datetime.utcnow()
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'updated'
                })
            else:
                self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置"
                # NOTE(review): `columns_count` and `column_info` are not used by
                # the other constructors in this service (they pass
                # `columns_info`); confirm these keyword names against the model.
                metadata = TableMetadata(
                    created_by=user_id,
                    database_config_id=database_config_id,
                    table_name=table_name,
                    table_schema='public',  # default
                    table_type='table',  # default
                    table_comment='',
                    columns_count=0,  # can be filled in later by the collect API
                    row_count=0,  # can be filled in later by the collect API
                    is_enabled_for_qa=True,
                    qa_description='',
                    business_context='',
                    sample_data='{}',
                    column_info='{}',
                    last_synced_at=datetime.utcnow()
                )

                self.session.add(metadata)
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'created'
                })

        except Exception as e:
            self.session.desc = f"ERROR: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置失败: {str(e)}"
            failed_tables.append({
                'table_name': table_name,
                'error': str(e)
            })

    await self.session.commit()
    self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_names} 的元数据配置"
    return {
        'saved_tables': saved_tables,
        'failed_tables': failed_tables,
        'total_saved': len(saved_tables),
        'total_failed': len(failed_tables)
    }
async def get_user_table_metadata(
    self,
    user_id: int,
    database_config_id: Optional[int] = None
) -> List[TableMetadata]:
    """Return the user's Q&A-enabled table metadata rows.

    Args:
        user_id: Creator of the metadata rows.
        database_config_id: When given, restrict the result to that database
            config. When omitted, rows for all of the user's configs are
            returned; previously the default value always raised
            ``NotFoundError``, which made the parameter effectively required.
    """
    self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表元数据列表"
    stmt = select(TableMetadata).where(TableMetadata.created_by == user_id)

    if database_config_id is not None:
        stmt = stmt.where(TableMetadata.database_config_id == database_config_id)

    # `== True` builds a SQL expression here; it is not a Python comparison.
    stmt = stmt.where(TableMetadata.is_enabled_for_qa == True)  # noqa: E712
    return (await self.session.scalars(stmt)).all()
async def save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str,
    columns_info: List[Dict[str, Any]],
    primary_keys: List[str],
    row_count: int,
    table_comment: str = ''
) -> TableMetadata:
    """Insert or update the metadata row of a single table.

    The row is identified by (creator, database config, table name). The row
    is refreshed after the commit in both branches so the returned object
    carries loaded attributes.
    """
    self.session.desc = f"保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
    stmt = select(TableMetadata).where(
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name
    )
    existing = (await self.session.execute(stmt)).scalar_one_or_none()

    if existing:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 已存在,更新其元数据"
        existing.columns_info = columns_info
        existing.primary_keys = primary_keys
        existing.row_count = row_count
        existing.table_comment = table_comment
        existing.last_synced_at = datetime.utcnow()
        await self.session.commit()
        # Matches the create branch and _save_table_metadata: reload after
        # commit so callers do not receive an expired instance.
        await self.session.refresh(existing)
        return existing
    else:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 不存在,创建新记录"
        metadata = TableMetadata(
            created_by=user_id,
            database_config_id=database_config_id,
            table_name=table_name,
            table_schema='public',
            table_type='BASE TABLE',
            table_comment=table_comment,
            columns_info=columns_info,
            primary_keys=primary_keys,
            row_count=row_count,
            is_enabled_for_qa=True,
            last_synced_at=datetime.utcnow()
        )

        self.session.add(metadata)
        await self.session.commit()
        await self.session.refresh(metadata)
        return metadata
# Try to import LangChain native tools if available
# TODO: temporarily disabled
# try:
#     from .langchain_native_tools import LANGCHAIN_NATIVE_TOOLS
# except ImportError:
#     LANGCHAIN_NATIVE_TOOLS = []

# Keep the exported name bound while the import above is disabled. `__all__`
# lists it, and a star import of this package raises AttributeError for any
# listed name that does not exist.
LANGCHAIN_NATIVE_TOOLS = []

__all__ = [
    'WeatherQueryTool',
    'TavilySearchTool',
    'DateTimeTool',
    'PostgreSQLMCPTool',
    'MySQLMCPTool',
    'LANGCHAIN_NATIVE_TOOLS'
]
# Input schema for DateTimeTool (Pydantic model replacing the former get_parameters()).
# Only `operation` is required. `date_string` and `target_timezone` now default
# to None: under Pydantic v2 an Optional field without a default is still
# required, which rejected calls such as `current_time` that do not need them.
class DateTimeInput(BaseModel):
    operation: Literal["current_time", "timezone_convert", "date_diff", "add_time", "format_date"] = Field(
        description="操作类型: current_time(当前时间), timezone_convert(时区转换), "
                   "date_diff(日期差), add_time(时间加减), format_date(格式化日期)"
    )
    timezone: Optional[str] = Field(
        default="UTC",
        description="时区名称 (e.g., 'UTC', 'Asia/Shanghai')"
    )
    date_string: Optional[str] = Field(
        default=None,
        description="日期字符串 (格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"
    )
    target_timezone: Optional[str] = Field(
        default=None,
        description="目标时区(用于时区转换)"
    )
    days: Optional[int] = Field(
        default=0,
        description="要加减的天数"
    )
    hours: Optional[int] = Field(
        default=0,
        description="要加减的小时数"
    )
    format: Optional[str] = Field(
        default="%Y-%m-%d %H:%M:%S",
        description="日期格式字符串 (e.g., '%Y-%m-%d %H:%M:%S')"
    )
def _run(self,
         operation: str,
         timezone: str = "UTC",
         date_string: Optional[str] = None,
         target_timezone: Optional[str] = None,
         days: int = 0,
         hours: int = 0,
         format: str = "%Y-%m-%d %H:%M:%S") -> dict:
    """Run one date/time operation synchronously.

    Args:
        operation: ``current_time``, ``timezone_convert``, ``date_diff``,
            ``add_time`` or ``format_date``.
        timezone: Time zone used to interpret ``date_string`` / "now".
        date_string: Date to operate on; required for ``timezone_convert``
            and ``date_diff``, optional otherwise (defaults to now).
        target_timezone: Destination zone for ``timezone_convert``.
        days, hours: Offset applied by ``add_time``.
        format: ``strftime`` pattern used for formatted output.

    Returns:
        ``{"status": "success", "result": ..., "summary": ...}`` or, on any
        failure, ``{"status": "error", "message": ..., "operation": ...}``.
    """
    logger.info(f"执行日期时间操作: {operation}")

    # The input schema declares these fields Optional, so an explicit None can
    # arrive here; fall back to the documented defaults.
    timezone = timezone or "UTC"
    format = format or "%Y-%m-%d %H:%M:%S"
    days = days or 0
    hours = hours or 0

    try:
        if operation == "current_time":
            tz = pytz.timezone(timezone)
            now = datetime.datetime.now(tz)
            return {
                "status": "success",
                "result": {
                    "formatted": now.strftime(format),
                    "iso": now.isoformat(),
                    "timestamp": now.timestamp(),
                    "timezone": timezone
                },
                "summary": f"当前时间 ({timezone}): {now.strftime(format)}"
            }

        elif operation == "timezone_convert":
            if not date_string or not target_timezone:
                raise ValueError("必须提供date_string和target_timezone参数")

            source_dt = self._parse_datetime(date_string, timezone)
            target_dt = source_dt.astimezone(pytz.timezone(target_timezone))

            return {
                "status": "success",
                "result": {
                    "source": source_dt.strftime(format),
                    "target": target_dt.strftime(format),
                    "source_tz": timezone,
                    "target_tz": target_timezone
                },
                "summary": f"时区转换: {source_dt.strftime(format)} → {target_dt.strftime(format)}"
            }

        elif operation == "date_diff":
            if not date_string:
                raise ValueError("必须提供date_string参数")

            target_dt = self._parse_datetime(date_string, timezone)
            # Subtracting aware datetimes compares instants, so the zone used
            # for "now" does not affect the difference.
            current_dt = datetime.datetime.now(datetime.timezone.utc)
            delta = target_dt - current_dt
            total_seconds = delta.total_seconds()

            # A negative timedelta is stored as negative days plus a positive
            # remainder (-1.5h is days=-1, seconds=81000), so take days and
            # hours from the magnitude and carry the sign on `days` only.
            magnitude = abs(delta)
            whole_days = magnitude.days
            whole_hours = magnitude.seconds // 3600
            signed_days = whole_days if total_seconds >= 0 else -whole_days

            return {
                "status": "success",
                "result": {
                    "days": signed_days,
                    "hours": whole_hours,
                    "total_seconds": total_seconds,
                    # was `delta.days > 0`, False for targets under 24h ahead
                    "is_future": total_seconds > 0
                },
                "summary": f"日期差: {whole_days}天 {whole_hours}小时"
            }

        elif operation == "add_time":
            base_dt = self._parse_datetime(date_string, timezone) if date_string \
                else datetime.datetime.now(pytz.timezone(timezone))
            new_dt = base_dt + datetime.timedelta(days=days, hours=hours)

            return {
                "status": "success",
                "result": {
                    "original": base_dt.strftime(format),
                    "new": new_dt.strftime(format),
                    "delta": f"{days}天 {hours}小时"
                },
                "summary": f"时间计算: {base_dt.strftime(format)} + {days}天 {hours}小时 = {new_dt.strftime(format)}"
            }

        elif operation == "format_date":
            dt = self._parse_datetime(date_string, timezone) if date_string \
                else datetime.datetime.now(pytz.timezone(timezone))
            formatted = dt.strftime(format)

            return {
                "status": "success",
                "result": {
                    "original": dt.isoformat(),
                    "formatted": formatted
                },
                "summary": f"格式化结果: {formatted}"
            }

        else:
            raise ValueError(f"未知操作类型: {operation}")

    except Exception as e:
        logger.error(f"操作失败: {str(e)}")
        return {
            "status": "error",
            "message": str(e),
            "operation": operation
        }
def _run(self, query: str, max_results: int = 5, topic: str = "general"):
    """Search the web through the Tavily client and return trimmed hits.

    Args:
        query: Search text.
        max_results: Upper bound on returned hits; also applied locally.
        topic: Search topic forwarded to the client.

    Returns:
        ``{"status": "success", "results": [{"title", "url", "content"}, ...]}``
        with ``content`` cut to 200 characters, or
        ``{"status": "error", "message": ...}`` on failure.
    """
    try:
        logger.info(f"执行搜索:{query}")
        raw = self._search_client.run({
            "query": query,
            "max_results": max_results,
            "topic": topic
        })

        # The client may hand back the hit list directly or a response dict
        # that carries the hits under "results" (and failures under "error").
        if isinstance(raw, dict):
            if raw.get("error"):
                return {"status": "error", "message": str(raw["error"])}
            hits = raw.get("results")
        else:
            hits = raw

        if not isinstance(hits, list):
            return {"status": "error", "message": "Unexpected result format"}

        if max_results is not None and max_results > 0:
            hits = hits[:max_results]

        formatted = []
        for hit in hits:
            content = hit.get("content") or ""
            if len(content) > 200:
                # Mark only content that was actually shortened.
                content = content[:200] + "..."
            formatted.append({
                "title": hit.get("title", ""),
                "url": hit.get("url", ""),
                "content": content
            })

        return {"status": "success", "results": formatted}

    except Exception as e:
        logger.error(f"搜索失败: {str(e)}")
        return {"status": "error", "message": str(e)}
async def _arun(self, location: str) -> dict:
    """Run the weather query without blocking the event loop.

    ``_run`` performs a blocking HTTP request, so it is executed in the
    loop's default thread pool instead of being called inline.

    Args:
        location: City name, passed through to ``_run`` unchanged.

    Returns:
        Whatever ``_run`` returns for ``location``.
    """
    import asyncio  # local import: this module does not import asyncio at file level

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self._run, location)
async def create_user(self, user_data: UserCreate) -> User:  ### Async: OK
    """Create and persist a new active user.

    The password must be at least 6 characters long. It is hashed before
    storage and is never written to ``session.desc``.

    Raises:
        ValidationError: if the password is shorter than 6 characters.
    """
    self.session.desc = f"创建用户 [{user_data.username}]"
    if len(user_data.password) < 6:
        self.session.desc = "ERROR: 密码长度必须至少为6个字符"
        raise ValidationError("密码长度必须至少为6个字符")

    hashed_password = self.get_password_hash(user_data.password)

    # The status text used to embed the plaintext password; keep it out.
    self.session.desc = "密码哈希处理完毕"
    db_user = User(
        username=user_data.username,
        email=user_data.email,
        hashed_password=hashed_password,
        full_name=user_data.full_name,
        is_active=True
    )

    self.session.desc = f"创建用户 [{user_data.username}] 到数据库"
    self.session.add(db_user)
    self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - add"
    await self.session.commit()
    self.session.desc = f"创建用户 [{user_data.username}] 到数据库 - commit"
    await self.session.refresh(db_user)

    self.session.desc = f"创建用户 [{user_data.username}] 成功"
    return db_user


def get_password_hash(self, password: str) -> str:  ### Async: OK
    """Return the hash of ``password`` as produced by ``AuthService``."""
    # Do not put the plaintext password into the session status text.
    self.session.desc = "对密码进行哈希处理"
    return AuthService.get_password_hash(password)
def verify_password(self, plain_password: str, hashed_password: str) -> bool:  ### DrGraph: OK
    """Return True when ``plain_password`` matches ``hashed_password``.

    The check is delegated to ``AuthService.verify_password``.
    """
    # The status text used to contain both the plaintext password and the
    # stored hash; neither belongs in a status or log message.
    self.session.desc = "验证密码与哈希是否匹配"
    return AuthService.verify_password(plain_password, hashed_password)
async def get_users_with_filters(  ### DrGraph: OK
    self,
    skip: int = 0,
    limit: int = 100,
    search: Optional[str] = None,
    role_id: Optional[int] = None,
    is_active: Optional[bool] = None
) -> Tuple[List[User], int]:
    """Return one page of users matching the filters plus the total match count.

    Args:
        skip: Number of matching rows to skip.
        limit: Maximum number of rows to return.
        search: Case-insensitive substring matched against username, email
            and full name. It is treated literally, so ``%`` and ``_`` in the
            input are not wildcards.
        role_id: When given, only users linked to this role via ``UserRole``.
        is_active: When given, only users with this ``is_active`` value.

    Returns:
        ``(users, total)`` where ``users`` is the requested page ordered by
        ``created_at`` descending and ``total`` counts all matching rows.
    """
    from sqlalchemy import func

    # Build the filter-only statement first; ordering is applied to the page
    # query alone so the COUNT subquery carries no pointless ORDER BY.
    filtered = select(User)

    if search:
        # Escape LIKE metacharacters so user input cannot act as a pattern.
        escaped = (
            search.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        pattern = f"%{escaped}%"
        filtered = filtered.where(
            or_(
                User.username.ilike(pattern, escape="\\"),
                User.email.ilike(pattern, escape="\\"),
                User.full_name.ilike(pattern, escape="\\")
            )
        )

    if role_id is not None:
        from ..models.permission import UserRole
        filtered = filtered.join(UserRole).where(UserRole.role_id == role_id)

    if is_active is not None:
        filtered = filtered.where(User.is_active == is_active)

    # Total number of matching rows, independent of pagination.
    count_stmt = select(func.count()).select_from(filtered.subquery())
    total = (await self.session.execute(count_stmt)).scalar_one()
    self.session.desc = f"获取用户总数为{total}"

    page_stmt = filtered.order_by(desc(User.created_at)).offset(skip).limit(limit)
    page = (await self.session.execute(page_stmt)).scalars().all()

    # ``all()`` yields a Sequence; materialize a list to honor the annotation.
    return list(page), total


async def get_users(self, skip: int = 0, limit: int = 100) -> List[User]:
    """Return one page of all users.

    Rows are ordered by primary key, because OFFSET/LIMIT without an ORDER BY
    gives the database no obligation to return stable, non-overlapping pages.
    """
    stmt = select(User).order_by(User.id).offset(skip).limit(limit)
    result = await self.session.execute(stmt)
    return list(result.scalars().all())
async def execute_workflow(self, workflow: Workflow, input_data: Optional[Dict[str, Any]] = None,
                           user_id: int = None, session: AsyncSession = None):
    """Run a workflow to completion and return its execution record.

    A ``WorkflowExecution`` row is created in RUNNING state, the workflow
    definition is reloaded and executed, and the row is finalized as
    COMPLETED or FAILED. Failures are logged and recorded on the row rather
    than raised to the caller.

    Args:
        workflow: Workflow to run; only its ``id`` is used before the reload.
        input_data: Workflow input; ``None`` is treated as an empty dict.
        user_id: Executor id, also used for the audit fields.
        session: Optional session that replaces ``self.session`` when given.

    Returns:
        A ``WorkflowExecutionResponse`` built from the execution record and
        its node executions.
    """
    from sqlalchemy import select
    from ..schemas.workflow import WorkflowExecutionResponse

    if session is not None:
        self.session = session
    # Always go through the engine's session: the ``session`` argument is
    # optional and used to be dereferenced even when it was None.
    db = self.session

    workflow_id = workflow.id
    payload = input_data or {}

    db.desc = f"执行工作流数据 - {workflow_id} > Enter"
    # Create the execution record.
    execution = WorkflowExecution(
        workflow_id=workflow_id,
        status=ExecutionStatus.RUNNING,
        input_data=payload,
        executor_id=user_id,
        started_at=datetime.now().isoformat()
    )
    db.desc = f"执行工作流数据 - {workflow_id} > 创建执行记录"
    execution.set_audit_fields(user_id)

    db.add(execution)
    await db.commit()
    await db.refresh(execution)
    db.desc = f"执行工作流数据 - {workflow_id} > 添加执行记录"

    try:
        # Reload the workflow so the definition that runs is the stored one.
        reloaded = await db.execute(
            select(Workflow).where(Workflow.id == workflow_id)
        )
        workflow = reloaded.scalar_one_or_none()
        db.desc = f"执行工作流数据 - {workflow_id} > reload workflow"
        if workflow is None:
            # Record a clear failure instead of an AttributeError on None.
            raise ValueError(f"工作流 {workflow_id} 不存在")

        # Parse the workflow definition.
        definition = workflow.definition
        nodes = {node['id']: node for node in definition['nodes']}
        connections = definition['connections']
        db.desc = f"执行工作流数据 - {workflow_id} > definition {workflow_id}"

        # Build the node dependency graph.
        node_graph = self._build_node_graph(nodes, connections)
        db.desc = f"执行工作流数据 - {workflow_id} > _build_node_graph {workflow_id}"

        # Run the nodes.
        output = await self._execute_nodes(execution, nodes, node_graph, payload)
        db.desc = f"执行工作流数据 - {workflow_id} > _execute_nodes {workflow_id}"

        # Mark the execution as completed.
        execution.status = ExecutionStatus.COMPLETED
        execution.output_data = output
        execution.completed_at = datetime.now().isoformat()
        db.desc = f"执行工作流数据 - {workflow_id} > execution {workflow_id}"

    except Exception as e:
        # Deliberate catch-all boundary: the failure is persisted on the row.
        logger.error(f"工作流执行失败 - {workflow_id}: {str(e)}")
        execution.status = ExecutionStatus.FAILED
        execution.error_message = str(e)
        execution.completed_at = datetime.now().isoformat()

    execution.set_audit_fields(user_id, is_update=True)
    db.desc = f"执行工作流数据 - {workflow_id} > set_audit_fields {workflow_id}"
    await db.commit()
    await db.refresh(execution)
    db.desc = f"执行工作流数据 - {workflow_id} > refresh {workflow_id}"

    node_rows = await db.execute(
        select(NodeExecution).where(NodeExecution.workflow_execution_id == execution.id)
    )
    node_executions = [row.to_dict() for row in node_rows.scalars().all()]
    db.desc = f"执行工作流数据 - {workflow_id} > load node_executions {workflow_id}"

    execution_dict = execution.to_dict()
    execution_dict['node_executions'] = node_executions
    db.desc = f"执行工作流数据 - {workflow_id} > build response {workflow_id}"

    return WorkflowExecutionResponse(**execution_dict)
def _build_node_graph(self, nodes: Dict[str, Any], connections: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Build the node dependency graph from nodes and their connections.

    Args:
        nodes: Mapping of node id to node definition.
        connections: Edge dicts using either ``from``/``to`` or
            ``from_node``/``to_node`` keys. ``None`` is treated as no edges.

    Returns:
        Mapping of node id to ``{'node', 'inputs', 'outputs'}``, where
        ``inputs`` and ``outputs`` list the upstream and downstream node ids.
        Edges that reference unknown nodes are ignored and repeated edges are
        recorded once.
    """
    graph = {
        node_id: {'node': node, 'inputs': [], 'outputs': []}
        for node_id, node in nodes.items()
    }

    def endpoint(connection, primary, fallback):
        # Fall back only when the primary key is missing or blank. A plain
        # ``or`` would also discard legitimate falsy ids such as 0.
        value = connection.get(primary)
        if value is None or value == '':
            return connection.get(fallback)
        return value

    for connection in connections or []:
        # Both field-name formats are supported: from/to and from_node/to_node.
        from_node = endpoint(connection, 'from', 'from_node')
        to_node = endpoint(connection, 'to', 'to_node')

        if from_node in graph and to_node in graph:
            if to_node not in graph[from_node]['outputs']:
                graph[from_node]['outputs'].append(to_node)
            if from_node not in graph[to_node]['inputs']:
                graph[to_node]['inputs'].append(from_node)

    return graph
async def _execute_node_recursive_stream(self, execution: "WorkflowExecution", node_id: str,
                                         node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
    """Execute ``node_id`` and every node connected to it, yielding status events.

    Each node runs exactly once, after all of its inputs. For every node a
    ``started`` event is yielded, followed by ``completed`` or ``failed``; a
    failure is re-raised after its event. Outputs are stored in
    ``context['node_outputs']``.

    Raises:
        ValueError: If the nodes form a dependency cycle.
    """
    finished = context['node_outputs']
    if node_id in finished:
        # Node already executed.
        return

    resolving = set()  # nodes whose inputs are currently being resolved

    def status_event(current_id, node, status, **details):
        return {
            'type': 'node_status',
            'execution_id': execution.id,
            'node_id': current_id,
            'status': status,
            'data': {
                'node_name': node.get('name', ''),
                'node_type': node.get('type', ''),
                **details
            },
            'timestamp': datetime.now().isoformat()
        }

    async def run(current_id):
        if current_id in finished:
            return
        if current_id in resolving:
            # Reached again through its own inputs: there is no valid order.
            raise ValueError(f"工作流存在循环依赖: {current_id}")
        resolving.add(current_id)
        try:
            info = node_graph[current_id]
            node = info['node']

            # Wait for all input nodes to finish.
            for dependency_id in info['inputs']:
                async for step_data in run(dependency_id):
                    yield step_data

            yield status_event(current_id, node, 'started',
                               started_at=datetime.now().isoformat())
            try:
                output = await self._execute_single_node(execution, node, context)
                finished[current_id] = output
                yield status_event(current_id, node, 'completed',
                                   output=output,
                                   completed_at=datetime.now().isoformat())
            except Exception as e:
                yield status_event(current_id, node, 'failed',
                                   error_message=str(e),
                                   failed_at=datetime.now().isoformat())
                raise
        finally:
            resolving.discard(current_id)

    # Discover every node connected to the start node (inputs before
    # outputs). Discovery is separate from execution so that a join node is
    # never entered a second time while it is still waiting on an input.
    order, visited, pending = [], set(), [node_id]
    while pending:
        current_id = pending.pop()
        if current_id in visited:
            continue
        visited.add(current_id)
        order.append(current_id)
        info = node_graph[current_id]
        pending.extend(reversed(list(info['inputs']) + list(info['outputs'])))

    for current_id in order:
        async for step_data in run(current_id):
            yield step_data


async def _execute_node_recursive(self, execution: "WorkflowExecution", node_id: str,
                                  node_graph: Dict[str, Dict[str, Any]], context: Dict[str, Any]):
    """Execute ``node_id`` and every node connected to it.

    Each node runs exactly once, after all of its inputs, and its output is
    stored in ``context['node_outputs']``.

    Raises:
        ValueError: If the nodes form a dependency cycle.
    """
    finished = context['node_outputs']
    if node_id in finished:
        # Node already executed.
        return

    resolving = set()  # nodes whose inputs are currently being resolved

    async def run(current_id):
        if current_id in finished:
            return
        if current_id in resolving:
            # Reached again through its own inputs: there is no valid order.
            raise ValueError(f"工作流存在循环依赖: {current_id}")
        resolving.add(current_id)
        try:
            info = node_graph[current_id]
            # Wait for all input nodes to finish.
            for dependency_id in info['inputs']:
                await run(dependency_id)
            finished[current_id] = await self._execute_single_node(
                execution, info['node'], context
            )
        finally:
            resolving.discard(current_id)

    # Discover every node connected to the start node (inputs before
    # outputs). Discovery is separate from execution so that a join node is
    # never entered a second time while it is still waiting on an input.
    order, visited, pending = [], set(), [node_id]
    while pending:
        current_id = pending.pop()
        if current_id in visited:
            continue
        visited.add(current_id)
        order.append(current_id)
        info = node_graph[current_id]
        pending.extend(reversed(list(info['inputs']) + list(info['outputs'])))

    for current_id in order:
        await run(current_id)
node['name'] + + # 创建节点执行记录 + node_execution = NodeExecution( + workflow_execution_id=execution.id, + node_id=node_id, + node_type=NodeType(node_type), + node_name=node_name, + status=ExecutionStatus.RUNNING, + started_at=datetime.now().isoformat() + ) + self.session.add(node_execution) + await self.session.commit() + await self.session.refresh(node_execution) + await self.session.refresh(execution) + + start_time = time.time() + + try: + # 准备输入数据 + input_data = self._prepare_node_input(node, context) + + # 为前端显示准备输入数据 + display_input_data = input_data.copy() + + # 对于开始节点,显示的输入应该是workflow_input + if node_type == 'start': + display_input_data = input_data['workflow_input'] + elif node_type == 'llm': + # 对于LLM节点,先执行变量替换以获取处理后的提示词 + config = input_data['node_config'] + prompt_template = config.get('prompt', '') + enable_variable_substitution = config.get('enable_variable_substitution', True) + + if enable_variable_substitution: + processed_prompt = self._substitute_variables(prompt_template, input_data) + else: + processed_prompt = prompt_template + + display_input_data = { + 'original_prompt': prompt_template, + 'processed_prompt': processed_prompt, + 'model_config': config, + 'resolved_inputs': input_data.get('resolved_inputs', {}) + } + + node_execution.input_data = display_input_data + await self.session.commit() + await self.session.refresh(execution) + + # 根据节点类型执行 + if node_type == 'start': + output_data = await self._execute_start_node(node, input_data) + elif node_type == 'end': + output_data = await self._execute_end_node(node, input_data) + elif node_type == 'llm': + output_data = await self._execute_llm_node(node, input_data) + elif node_type == 'condition': + output_data = await self._execute_condition_node(node, input_data) + elif node_type == 'code': + output_data = await self._execute_code_node(node, input_data) + elif node_type == 'http': + output_data = await self._execute_http_node(node, input_data) + else: + raise ValueError(f"不支持的节点类型: 
{node_type}") + + # 更新执行状态 + end_time = time.time() + node_execution.status = ExecutionStatus.COMPLETED + node_execution.output_data = output_data + node_execution.completed_at = datetime.now().isoformat() + node_execution.duration_ms = int((end_time - start_time) * 1000) + + await self.session.commit() + await self.session.refresh(execution) + + return output_data + + except Exception as e: + logger.error(f"节点 {node_id} 执行失败: {str(e)}") + end_time = time.time() + node_execution.status = ExecutionStatus.FAILED + node_execution.error_message = str(e) + node_execution.completed_at = datetime.now().isoformat() + node_execution.duration_ms = int((end_time - start_time) * 1000) + await self.session.commit() + await self.session.refresh(execution) + + raise + + def _prepare_node_input(self, node: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """准备节点输入数据""" + # 基础输入数据 + input_data = { + 'workflow_input': context['workflow_input'], + 'node_config': node.get('config', {}), + 'previous_outputs': context['node_outputs'] + } + + # 处理节点参数配置 + node_parameters = node.get('parameters', {}) + if node_parameters and 'inputs' in node_parameters: + resolved_inputs = {} + + for param in node_parameters['inputs']: + param_name = param.get('name') + param_source = param.get('source', 'default') + param_default = param.get('default_value') + variable_name = param.get('variable_name', '') + + # 优先使用variable_name,如果存在的话 + if variable_name: + resolved_value = self._resolve_variable_value(variable_name, context) + resolved_inputs[param_name] = resolved_value if resolved_value is not None else param_default + elif param_source == 'workflow': + # 从工作流输入获取 + source_param_name = param.get('source_param_name', param_name) + resolved_inputs[param_name] = context['workflow_input'].get(source_param_name, param_default) + elif param_source == 'node': + # 从其他节点输出获取 + source_node_id = param.get('source_node_id') + source_param_name = param.get('source_param_name', 'data') + + if 
def _resolve_variable_value(self, variable_name: str, context: Dict[str, Any]) -> Any:
    """Resolve a dotted variable reference against earlier node outputs.

    Supported forms (the first segment is always the source node id):

    * ``node_id.field[.sub...]``: walk the node's output dict by key.
    * ``node_id.output.field[.sub...]``: when the output has no literal
      ``output`` key, the ``output`` segment is treated as a marker and the
      remaining path is looked up at the root of the output and then under
      its ``data`` key.

    Args:
        variable_name: Dotted reference such as ``"node_1.output.response"``.
        context: Execution context; ``context['node_outputs']`` is read.

    Returns:
        The resolved value, or ``None`` when the reference cannot be resolved
        (unknown node, non-dict output, missing key, or any error).
    """
    missing = object()

    def walk(root, path):
        # Follow ``path`` through nested dicts; ``missing`` marks a dead end.
        current = root
        for key in path:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return missing
        return current

    try:
        parts = variable_name.split('.')
        if len(parts) < 2:
            return None

        source_node_id, path = parts[0], parts[1:]
        node_outputs = context['node_outputs']
        if source_node_id not in node_outputs:
            return None

        source_output = node_outputs[source_node_id]
        if not isinstance(source_output, dict):
            return None

        # The literal path is tried first so existing references keep their
        # meaning; the fallbacks only apply where resolution used to fail.
        candidates = [path]
        if path[0] == 'output' and len(path) > 1:
            candidates.append(path[1:])
            candidates.append(['data'] + path[1:])

        for candidate in candidates:
            value = walk(source_output, candidate)
            if value is not missing:
                return value

        return None
    except Exception as e:
        logger.warning(f"解析变量值失败: {variable_name}, 错误: {str(e)}")
        return None
async def _execute_llm_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Run an LLM node: resolve its model config, build the prompt, call the model.

    The model is selected by ``model_id``; when that is absent, ``model_name``
    (or ``model``) is used, either as an integer id or as a model name looked
    up in ``LLMConfig``.

    Side effect: ``processed_prompt`` and ``original_prompt`` are written
    into ``input_data`` so the caller can display them.

    Returns:
        Dict with ``success``, ``response``, ``prompt``, ``model`` and
        ``tokens_used``.

    Raises:
        ValueError: If no valid model config can be resolved, or the LLM
            call fails.
    """
    from sqlalchemy import select

    config = input_data['node_config']

    async def first_config(criterion):
        # ``self.session`` is an AsyncSession: it has no legacy ``query()``
        # API, so lookups go through ``select`` + ``await execute``.
        result = await self.session.execute(select(LLMConfig).where(criterion))
        return result.scalars().first()

    # Resolve the LLM configuration.
    model_id = config.get('model_id')
    if not model_id:
        # Front-end compatibility: ``model`` may hold an id or a name.
        model_value = config.get('model_name', config.get('model'))
        if model_value:
            if isinstance(model_value, int):
                model_id = model_value
            else:
                named_config = await first_config(LLMConfig.model_name == model_value)
                if named_config:
                    model_id = named_config.id

    if not model_id:
        raise ValueError("未指定有效的大模型配置")

    llm_config = await first_config(LLMConfig.id == model_id)
    if not llm_config:
        raise ValueError(f"大模型配置 {model_id} 不存在")

    # Build the prompt, optionally with variable substitution.
    prompt_template = config.get('prompt', '')
    if config.get('enable_variable_substitution', True):
        prompt = self._substitute_variables(prompt_template, input_data)
    else:
        prompt = prompt_template

    # Record both prompt forms on the input data for front-end display.
    input_data['processed_prompt'] = prompt
    input_data['original_prompt'] = prompt_template

    try:
        response = await self.llm_service.chat_completion(
            model_config=llm_config,
            messages=[{"role": "user", "content": prompt}],
            temperature=config.get('temperature', 0.7),
            max_tokens=config.get('max_tokens')
        )

        # ``usage`` may be absent, None, a dict, or an object with attributes.
        usage = getattr(response, 'usage', None)
        if isinstance(usage, dict):
            tokens_used = usage.get('total_tokens', 0)
        else:
            tokens_used = getattr(usage, 'total_tokens', 0) or 0

        return {
            'success': True,
            'response': response,
            'prompt': prompt,
            'model': llm_config.model_name,
            'tokens_used': tokens_used
        }

    except Exception as e:
        logger.error(f"LLM调用失败: {str(e)}")
        raise ValueError(f"LLM调用失败: {str(e)}") from e
async def _execute_condition_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate a condition node's expression and return its truth value.

    The expression sees two names: ``input`` (the workflow input) and
    ``previous`` (outputs of earlier nodes). No builtins are available.

    SECURITY: the expression comes from the workflow definition and is run
    with ``eval``. Names and attributes starting with ``__`` are rejected to
    block the usual ``().__class__...`` escapes, but this is NOT a sandbox.
    Only trusted authors should be allowed to define conditions.

    Returns:
        Dict with ``success``, the original ``condition`` and boolean ``result``.

    Raises:
        ValueError: If the expression is invalid, rejected, or fails to evaluate.
    """
    import ast

    config = input_data['node_config']
    condition = config.get('condition', '')

    try:
        # ``eval`` tolerates surrounding whitespace; ``ast.parse`` does not.
        source = condition.strip() if isinstance(condition, str) else condition
        tree = ast.parse(source, mode='eval')

        for item in ast.walk(tree):
            if isinstance(item, ast.Attribute):
                identifier = item.attr
            elif isinstance(item, ast.Name):
                identifier = item.id
            else:
                continue
            if identifier.startswith('__'):
                raise ValueError(f"条件表达式不允许使用双下划线名称: {identifier}")

        # Build the evaluation context.
        eval_context = {
            'input': input_data['workflow_input'],
            'previous': input_data['previous_outputs']
        }

        # Evaluate the already-validated expression.
        compiled = compile(tree, '<condition>', 'eval')
        result = eval(compiled, {"__builtins__": {}}, eval_context)

        return {
            'success': True,
            'condition': condition,
            'result': bool(result)
        }

    except Exception as e:
        logger.error(f"条件评估失败: {str(e)}")
        raise ValueError(f"条件评估失败: {str(e)}") from e


async def _execute_python_code(self, code: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Execute user-supplied Python code and call its ``main`` function.

    The code runs with a reduced set of builtins and must define ``main``.
    ``main`` is called with ``input_data['resolved_inputs']`` as keyword
    arguments, or with no arguments when there are none.

    SECURITY: ``exec`` on workflow-supplied code is NOT a sandbox. Imports
    are limited to an allow-list (the unrestricted ``__import__`` was exposed
    before, so ``import os`` simply worked), but in-process execution can
    still be escaped. Run untrusted workflows in an isolated process.

    Returns:
        Dict with ``success``, ``result`` (return value of ``main``),
        ``code`` and ``input_parameters``.

    Raises:
        ValueError: If the code fails to compile or run, imports a module
            outside the allow-list, or does not define ``main``.
    """
    try:
        import builtins
        import datetime
        import json
        import math
        import re

        # Modules user code may import: the four pre-imported below plus a
        # few other standard-library helpers.
        allowed_modules = frozenset({
            'json', 'datetime', 'math', 're',
            'time', 'random', 'string', 'collections', 'itertools',
            'functools', 'decimal', 'statistics', 'uuid', 'base64',
            'hashlib', 'typing',
        })

        def restricted_import(name, module_globals=None, module_locals=None,
                              fromlist=(), level=0):
            # Relative imports and any module outside the allow-list are refused.
            if level != 0 or name.partition('.')[0] not in allowed_modules:
                raise ImportError(f"不允许导入模块: {name}")
            return builtins.__import__(name, module_globals, module_locals,
                                       fromlist, level)

        # Build the execution context.
        safe_builtins = {
            'len': len,
            'str': str,
            'int': int,
            'float': float,
            'bool': bool,
            'list': list,
            'dict': dict,
            'tuple': tuple,
            'set': set,
            'range': range,
            'enumerate': enumerate,
            'zip': zip,
            'sum': sum,
            'min': min,
            'max': max,
            'abs': abs,
            'round': round,
            'sorted': sorted,
            'reversed': reversed,
            'print': print,
            '__import__': restricted_import,
        }

        exec_context = {
            '__builtins__': safe_builtins,
            'json': json,          # json module is available
            'datetime': datetime,  # datetime module is available
            'math': math,          # math module is available
            're': re,              # re module is available
        }

        # Run the code so that it defines its functions.
        exec(code, exec_context)

        # The code must define a main function.
        if 'main' not in exec_context:
            raise ValueError("代码中必须定义一个main函数")

        main_function = exec_context['main']

        # Resolved input parameters are passed as keyword arguments.
        resolved_inputs = input_data.get('resolved_inputs', {})
        if resolved_inputs:
            result = main_function(**resolved_inputs)
        else:
            result = main_function()

        return {
            'success': True,
            'result': result,
            'code': code,
            'input_parameters': resolved_inputs
        }

    except Exception as e:
        logger.error(f"Python代码执行失败: {str(e)}")
        raise ValueError(f"Python代码执行失败: {str(e)}") from e
class ZhipuOpenAIEmbeddings(Embeddings):
    """ZhipuAI Embeddings using OpenAI compatible API.

    Each text is embedded with its own ``embeddings.create`` request on a
    synchronous ``OpenAI`` client; the async variants run the synchronous
    methods in the event loop's default executor.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://open.bigmodel.cn/api/paas/v4",
        model: str = "embedding-3",
        dimensions: int = 1024
    ):
        """Create the client.

        Args:
            api_key: API key; falls back to ``settings.embedding.zhipu_api_key``
                when omitted or empty.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Embedding model name sent with every request.
            dimensions: Requested embedding dimensionality.
        """
        self.api_key = api_key or settings.embedding.zhipu_api_key
        self.base_url = base_url
        self.model = model
        self.dimensions = dimensions

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

        logger.info(f"ZhipuOpenAI Embeddings initialized with model: {self.model}")

    def _create_embedding(self, text: str) -> List[float]:
        """Request the embedding of a single text and return its vector."""
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            dimensions=self.dimensions,
            encoding_format="float"
        )
        return response.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Returns one vector per input text, in order. Any failure is logged
        and re-raised unchanged; no partial result is returned.
        """
        try:
            return [self._create_embedding(text) for text in texts]
        except Exception as e:
            logger.error(f"Error embedding documents: {e}")
            raise

    def embed_query(self, text: str) -> List[float]:
        """Embed query text. Failures are logged and re-raised unchanged."""
        try:
            return self._create_embedding(text)
        except Exception as e:
            logger.error(f"Error embedding query: {e}")
            raise

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async embed search docs (runs ``embed_documents`` in the default executor)."""
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_documents, texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Async embed query text (runs ``embed_query`` in the default executor)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_query, text)
of file diff --git a/utils/Constant.py b/utils/Constant.py new file mode 100644 index 0000000..c0d40e7 --- /dev/null +++ b/utils/Constant.py @@ -0,0 +1,490 @@ +# from PySide6.QtCore import QPointF, QRectF, QSizeF +# -*- coding: utf-8 -*- +class Constant: + # invalid_point = QPointF(-123.4, -234.71) # 无效点 + # invalid_rect = QRectF(invalid_point, QSizeF(0, 0)) # 无效矩形 + service_yml_path = 'config/service/dsp_%s_service.yml' # 服务配置路径 + kafka_yml_path = 'config/kafka/dsp_%s_kafka.yml' # kafka配置路径 + aliyun_yml_path = "config/aliyun/dsp_%s_aliyun.yml" # 阿里云配置路径 + baidu_yml_path = 'config/baidu/dsp_%s_baidu.yml' # 百度配置路径 + pull_frame_width = 1400 # 拉流帧宽度 + + LLM_MODE_NONE = 0 + LLM_MODE_NONCHAT = 1 + LLM_MODE_CHAT = 2 + LLM_MODE_EMBEDDING = 3 + LLM_MODE_LOCAL_OLLAMA = 4 + + LLM_PROMPT_TEMPLATE_METHOD_FORMAT = "format" + LLM_PROMPT_TEMPLATE_METHOD_INVOKE = "invoke" + + LLM_PROMPT_VALUE_STR = "str" + LLM_PROMPT_VALUE_MESSAGES = "messages" + LLM_PROMPT_VALUE_VALUE = "promptValue" + + INPUT_NONE = 0 + INPUT_PULL_STREAM = 1 + INPUT_NET_CAMERA = 2 + INPUT_USB_CAMERA = 3 + INPUT_LOCAL_VIDEO = 4 + INPUT_LOCAL_DIR = 5 + INPUT_LOCAL_FILE = 6 + INPUT_LOCAL_LABELLED_DIR = 7 + INPUT_LOCAL_LABELLED_ZIP = 8 + + SYSTEM_BUTTON_SNAP = 1 + SYSTEM_BUTTON_CLEAR_LOG = 2 + + MODE_OCR = 7 + + WEB_OWNER_NONE = 0 + WEB_OWNER_ALG_TEST = 1 + WEB_OWNER_LLM = 2 + + PAGE_MODE_NONE = 0 + PAGE_MODE_VIDEO = 1 + PAGE_MODE_LABELLING = 2 + PAGE_MODE_ALG_TEST = 3 + PAGE_MODE_WEB = 4 + PAGE_MODE_LLM = 5 + PAGE_MODE_CONFIG = 6 + + ALG_TEST_LOAD_DATA = 1 + ALG_TEST_TRAIN = 2 + ALG_TEST_INFER = 3 + + ALG_STATUS_NOTHING = 0 # 啥也没干 + ALG_STATUS_MODEL_READY = 1 # 模型已就绪 + ALG_STATUS_DATA_LOADED = 2 # 加载已数据/加载中 + ALG_STATUS_TRAINED = 4 # 训练完成/训练中 + ALG_STATUS_INFER = 8 # 推理完成/推理中 + + UTF_8 = "utf-8" # 编码格式 + + COLOR = ( + [0, 0, 255], + [255, 0, 0], + [211, 0, 148], + [0, 127, 0], + [0, 69, 255], + [0, 255, 0], + [255, 0, 255], + [0, 0, 127], + [127, 0, 255], + [255, 129, 0], + [139, 139, 0], + [255, 255, 0], + [127, 
255, 0], + [0, 127, 255], + [0, 255, 127], + [255, 127, 255], + [8, 101, 139], + [171, 130, 255], + [139, 112, 74], + [205, 205, 180]) + + ONLINE = "online" + OFFLINE = "offline" + PHOTO = "photo" + RECORDING = "recording" + + ONLINE_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "push_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "models": { + 'type': 'list', + 'required': True, + 'nullable': False, + 'minlength': 1, + 'maxlength': 3, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + ONLINE_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + OFFLINE_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 
'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "push_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "models": { + 'type': 'list', + 'required': True, + 'maxlength': 3, + 'minlength': 1, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + OFFLINE_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + IMAGE_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + }, + "image_urls": { + 'type': 'list', + 'required': True, + 'minlength': 1, + 'schema': { + 'type': 
'string', + 'required': True, + 'empty': False, + 'maxlength': 5000 + } + }, + "models": { + 'type': 'list', + 'required': True, + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "code": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "categories", + 'regex': r'^[a-zA-Z0-9]{1,255}$' + }, + "is_video": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "is_image": { + 'type': 'string', + 'required': True, + 'empty': False, + 'dependencies': "code", + 'allowed': ["0", "1"] + }, + "categories": { + 'type': 'list', + 'required': True, + 'dependencies': "code", + 'schema': { + 'type': 'dict', + 'required': True, + 'schema': { + "id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{0,255}$'}, + "config": { + 'type': 'dict', + 'required': False, + 'dependencies': "id", + } + } + } + } + } + } + } + } + + RECORDING_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "pull_url": { + 'type': 'string', + 'required': True, + 'empty': False, + 'maxlength': 255 + }, + "push_url": { + 'type': 'string', + 'required': False, + 'empty': True, + 'maxlength': 255 + }, + "logo_url": { + 'type': 'string', + 'required': False, + 'nullable': True, + 'maxlength': 255 + } + } + + RECORDING_STOP_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["stop"] + } + } + + PULL2PUSH_START_SCHEMA = { + "request_id": { + 'type': 'string', + 'required': True, + 'empty': False, + 'regex': r'^[a-zA-Z0-9]{1,36}$' + }, + "command": { + 'type': 'string', + 'required': True, + 'allowed': ["start"] + }, + "video_urls": { + 'type': 'list', + 'required': 
class ServiceException(Exception):
    """Service-level error carrying an error code and a message.

    Attributes:
        code: Error code (for example the first element of an
            ``ExceptionType`` value).
        msg: Final message text. When ``desc`` is given, ``msg`` is used as
            a ``%``-format template and ``desc`` is interpolated into it.
    """

    def __init__(self, code, msg, desc=None):
        self.code = code
        if desc is None:
            self.msg = msg
        else:
            self.msg = msg % desc

    def __str__(self):
        """Log the error and return it as text.

        The original implementation only logged and implicitly returned
        ``None``, which makes ``str(exc)`` raise ``TypeError``. The log call
        is kept so explicit ``exc.__str__()`` callers still get their log
        line.
        NOTE(review): logging from ``__str__`` is a side effect on every
        string conversion — consider moving it to the raise/handle site.
        """
        logger.error("异常编码:{}, 异常描述:{}", self.code, self.msg)
        return "异常编码:{}, 异常描述:{}".format(self.code, self.msg)
PUSH_STREAM_TIME_EXCEPTION = ("SP008", "未生成本地视频地址!") + AI_MODEL_MATCH_EXCEPTION = ("SP009", "未匹配到对应的AI模型!") + ILLEGAL_PARAMETER_FORMAT = ("SP010", "非法参数格式!") + PUSH_STREAMING_CHANNEL_IS_OCCUPIED = ("SP011", "推流通道可能被占用, 请稍后再试!") + VIDEO_RESOLUTION_EXCEPTION = ("SP012", "不支持该分辨率类型的视频,请切换分辨率再试!") + READ_IAMGE_URL_EXCEPTION = ("SP013", "未能解析图片地址!") + DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED = ("SP014", "不支持该类型的检测目标!") + WRITE_STREAM_EXCEPTION = ("SP015", "写流异常!") + OR_VIDEO_DO_NOT_EXEIST_EXCEPTION = ("SP016", "原视频不存在!") + MODEL_LOADING_EXCEPTION = ("SP017", "模型加载异常!") + MODEL_ANALYSE_EXCEPTION = ("SP018", "算法模型分析异常!") + AI_MODEL_CONFIG_EXCEPTION = ("SP019", "模型配置不能为空!") + AI_MODEL_GET_CONFIG_EXCEPTION = ("SP020", "获取模型配置异常, 请检查模型配置是否正确!") + MODEL_GROUP_LIMIT_EXCEPTION = ("SP021", "模型组合个数超过限制!") + MODEL_NOT_SUPPORT_VIDEO_EXCEPTION = ("SP022", "%s不支持视频识别!") + MODEL_NOT_SUPPORT_IMAGE_EXCEPTION = ("SP023", "%s不支持图片识别!") + THE_DETECTION_TARGET_CANNOT_BE_EMPTY = ("SP024", "检测目标不能为空!") + URL_ADDRESS_ACCESS_FAILED = ("SP025", "URL地址访问失败, 请检测URL地址是否正确!") + UNIVERSAL_TEXT_RECOGNITION_FAILED = ("SP026", "识别失败!") + COORDINATE_ACQUISITION_FAILED = ("SP027", "飞行坐标识别异常!") + PUSH_STREAM_EXCEPTION = ("SP028", "推流异常!") + MODEL_DUPLICATE_EXCEPTION = ("SP029", "存在重复模型配置!") + DETECTION_TARGET_NOT_SUPPORT = ("SP031", "存在不支持的检测目标!") + TASK_EXCUTE_TIMEOUT = ("SP032", "任务执行超时!") + PUSH_STREAM_URL_IS_NULL = ("SP033", "拉流、推流地址不能为空!") + PULL_STREAM_NUM_LIMIT_EXCEPTION = ("SP034", "转推流数量超过限制!") + NOT_REQUESTID_TASK_EXCEPTION = ("SP993", "未查询到该任务,无法停止任务!") + NO_RESOURCES = ("SP995", "服务器暂无资源可以使用,请稍后30秒后再试!") + NO_CPU_RESOURCES = ("SP996", "暂无CPU资源可以使用,请稍后再试!") + SERVICE_COMMON_EXCEPTION = ("SP997", "公共服务异常!") + NO_GPU_RESOURCES = ("SP998", "暂无GPU资源可以使用,请稍后再试!") + SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常!") diff --git a/utils/Flag.py b/utils/Flag.py new file mode 100644 index 0000000..f737d57 --- /dev/null +++ b/utils/Flag.py @@ -0,0 +1,12 @@ + +class Flag: + Unique = True + Append = False + + 
class Property:
    """Base class for the descriptor helpers defined after it.

    It stores an optional reader, an optional writer and a default value,
    and provides the "direct" (backing attribute) and "custom" (call the
    reader/writer) access strategies. Subclasses choose which strategy
    ``__get__`` / ``__set__`` use.

    Args:
        read_func: Reader. Either a method name (str), a function whose
            ``__name__`` names a method on the instance, or another callable.
        write_func: Writer, resolved the same way as ``read_func``.
        default: Value returned when nothing has been stored.
        hasMember: When true, values live in a backing attribute named
            ``_<name>`` on the instance.
    """

    def __init__(self, read_func=None, write_func=None, default=None, hasMember=True):
        self.read_func = read_func
        self.write_func = write_func
        self.default = default
        self.owner_class = None
        self.private_name = None
        self.hasMember = hasMember

    def __set_name__(self, owner, name):
        if self.hasMember:
            self.private_name = f"_{name}"
        # NOTE(review): the original indentation was lost in this dump;
        # confirm whether owner_class was meant to be set only when
        # hasMember is true. Nothing visible reads owner_class.
        self.owner_class = owner

    def _stored_or_default(self, instance):
        """Return the backing attribute's value, or the default if unset."""
        name = self.private_name
        # private_name stays None when hasMember is False (or __set_name__
        # never ran); hasattr(instance, None) would raise TypeError.
        if name is not None and hasattr(instance, name):
            return getattr(instance, name)
        return self.default

    def callDirectGet(self, instance, owner):
        """Read the backing attribute, falling back to the default."""
        if instance is None or not self.hasMember:
            return self.default
        return self._stored_or_default(instance)

    def callCustomGet(self, instance, owner):
        """Read through ``read_func``; fall back to the stored value or default.

        Returns the descriptor itself for class-level access. Errors raised
        by the reader are logged and the fallback value is returned.
        """
        if instance is None:
            return self

        if self.read_func is None:
            return self._stored_or_default(instance)

        try:
            if isinstance(self.read_func, str):
                if hasattr(instance, self.read_func):
                    return getattr(instance, self.read_func)()
            elif hasattr(self.read_func, '__name__'):
                # Resolve by name so an override on the instance's class wins.
                method_name = self.read_func.__name__
                if hasattr(instance, method_name):
                    return getattr(instance, method_name)()
            elif callable(self.read_func):
                try:
                    return self.read_func(instance)
                except (TypeError, AttributeError):
                    return self.read_func()
        except Exception as e:
            logger.error(e)

        return self._stored_or_default(instance)

    def callDirectSet(self, instance, value):
        """Write the backing attribute (no-op without an instance or member)."""
        if instance is None or not self.hasMember:
            return
        setattr(instance, self.private_name, value)

    def callCustomSet(self, instance, value):
        """Write through ``write_func``.

        Errors raised by the writer are logged (they were silently dropped
        before) and never propagate, matching ``callCustomGet``.
        """
        try:
            if isinstance(self.write_func, str):
                if hasattr(instance, self.write_func):
                    getattr(instance, self.write_func)(value)
            elif hasattr(self.write_func, '__name__'):
                method_name = self.write_func.__name__
                if hasattr(instance, method_name):
                    getattr(instance, method_name)(value)
            elif callable(self.write_func):
                try:
                    self.write_func(instance, value)
                except (TypeError, AttributeError):
                    self.write_func(value)
        except Exception as e:
            logger.error(e)
class AppHelper(QObject):
    """Application-level helper that forwards status text and progress to UI controls.

    ``briefStatusControl`` and ``progressBarControl`` start as ``None`` and
    are expected to be assigned by the caller. Until then status text is
    printed and progress updates are ignored.
    """

    app = Property_rw(None)

    def setBriefStatusText(self, text):
        """Show ``text`` in the status control, or print it if none is set."""
        if self.briefStatusControl:
            self.briefStatusControl.setText(text)
        else:
            print(text)
    briefStatusText = Property_rW(setBriefStatusText, '')

    def setProgress(self, value):
        """Forward ``value`` to the progress bar, if one is attached."""
        if self.progressBarControl:
            self.progressBarControl.setValue(value)
    progress = Property_rW(setProgress, 0)

    def setProgressMax(self, value):
        """Forward the maximum to the progress bar, if one is attached."""
        if self.progressBarControl:
            self.progressBarControl.setMaximum(value)
    progressMax = Property_rW(setProgressMax, 100)

    def setProgressMin(self, value):
        """Forward the minimum to the progress bar, if one is attached."""
        if self.progressBarControl:
            self.progressBarControl.setMinimum(value)
    progressMin = Property_rW(setProgressMin, 0)

    def __init__(self):
        # A QObject subclass must initialise its base class; otherwise the
        # underlying Qt object is never constructed and later Qt calls fail
        # with "__init__ method of object's base class not called".
        super().__init__()
        self.briefStatusControl = None
        self.progressBarControl = None
        self._briefStatusText = ''
@staticmethod
def getConfigs(path, read_type='yml'):
    """
    Read a configuration file and return the parsed content.

    The file is always parsed with the safe, pure-Python ruamel YAML loader.
    ``read_type`` is accepted for backward compatibility but does not change
    the parser.

    :param path: path of the configuration file
    :param read_type: configuration type, 'yml' (default) or 'json'
    :return: the parsed configuration data structure
    :raises Exception: when the file yields no configuration (e.g. it is empty)
    """
    yaml = YAML(typ='safe', pure=True)
    with open(path, 'r', encoding='utf-8') as f:
        configs = yaml.load(f)
    # The original raise sat after an unconditional return and could never
    # run; an empty file silently produced None instead of the documented
    # error.
    if configs is None:
        raise Exception('路径: %s未获取配置信息' % path)
    return configs

DrGraph

+

{content}

+ + +""" + @staticmethod + def log_init(app, base_dir, env): + """ + 初始化日志配置 + + :param base_dir: 基础目录路径,用于定位配置文件和日志文件存储位置 + :param env: 环境标识,用于加载对应环境的日志配置文件 + :return: 无返回值 + """ + Helper.App = AppHelper() + Helper.App.app = app + # QToolTip样式 - 自定义样式 - 增加Header + app.setStyleSheet(""" + QToolTip { + background-color: #dd2222; + color: #f0f0f0; + border: 1px solid #555; + border-radius: 4px; + padding: 6px; + font: 10pt "Segoe UI"; + opacity: 220; + } + """) + + log_config = Helper.getConfigs(join(base_dir, 'appIOs/configs/logger/drgraph_%s_logger.yml' % env)) + # 判断日志文件是否存在,不存在创建 + base_path = join(base_dir, log_config.get("base_path")) + if not exists(base_path): + makedirs(base_path) + # 移除日志设置 + logger.remove(handler_id=None) + # 打印日志到文件 + if bool(log_config.get("enable_file_log")): + logger.add(join(base_path, log_config.get("log_name")), + rotation=log_config.get("rotation"), + retention=log_config.get("retention"), + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True, + encoding=log_config.get("encoding")) + # 控制台输出 + if bool(log_config.get("enable_stderr")): + logger.add(sys.stderr, + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True) + logger.info("\n\n\n----=========== 日志配置初始化完成, 开始新的日志记录 ==========----") + + @staticmethod + def log_info(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'INFO: {msg}', 'black') + caller = inspect.stack()[1] + logger.info(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "INFO", "msg" : msg} ) + @staticmethod + def log_error(msg, toWss = False): + if Helper.OnLogMsg: + Helper.OnLogMsg(f'ERROR: {msg}', 'red') + caller = inspect.stack()[1] + logger.error(f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}") + if toWss: + Helper.log_wss({"type": "log", "kind": "ERROR", "msg" : msg} ) + @staticmethod + def log_warning(msg, toWss = False): + if 
@staticmethod
def log(msg, toWss = False):
    """Write ``msg`` to the log, prefixed with the caller's file, line and function.

    :param msg: message text
    :param toWss: when true, also forward the message through ``Helper.log_wss``
    """
    caller = inspect.stack()[1]
    # loguru's logger.log() requires the level as its first argument; the
    # original passed only the message, which raises TypeError on every call.
    # NOTE(review): "INFO" is an assumed level for this generic entry point.
    logger.log("INFO", f"[{Path(caller.filename).name}:{caller.lineno}.{caller.function}()] > {msg}")
    if toWss:
        Helper.log_wss({"type": "log", "kind": "LOG", "msg" : msg} )
len(text)) + while len(text): + if len(text) < segLen: + t = text + text = '' + else: + t = text[:segLen] + text = text[segLen:] + result.append(font.render(t, True, color)) + else: + result.append(font.render(text, True, color)) + return result + + @staticmethod + def randomColor(): + '''随机颜色''' + return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) + + @staticmethod + def reverseColor(color: QColor): + '''反转颜色''' + return (255 - color.red(), 255 - color.green(), 255 - color.blue()) + @staticmethod + def getRGB(color_value): + # 如果是元组或列表形式的RGB值 + if isinstance(color_value, (tuple, list)): + if len(color_value) >= 3: + # 取前三个值作为RGB + r, g, b = color_value[0], color_value[1], color_value[2] + # 确保值在0-255范围内 + return (max(0, min(255, int(r))), + max(0, min(255, int(g))), + max(0, min(255, int(b)))) + + # 如果是整数形式的颜色值 + elif isinstance(color_value, int): + # 将整数转换为RGB分量 + # 假设格式为0xRRGGBB + r = (color_value >> 16) & 0xFF + g = (color_value >> 8) & 0xFF + b = color_value & 0xFF + return (r, g, b) + + # 如果是字符串形式 + elif isinstance(color_value, str): + # 处理十六进制颜色值 + if color_value.startswith('#'): + hex_value = color_value[1:] + if len(hex_value) == 3: # 简写形式 #RGB + hex_value = ''.join([c*2 for c in hex_value]) + if len(hex_value) in (6, 8): # #RRGGBB 或 #RRGGBBAA + r = int(hex_value[0:2], 16) + g = int(hex_value[2:4], 16) + b = int(hex_value[4:6], 16) + return (r, g, b) + # 处理颜色名称(需要额外的颜色名称映射表) + # 这里只列举几种常见颜色 + color_names = { + 'black': (0, 0, 0), + 'white': (255, 255, 255), + 'red': (255, 0, 0), + 'green': (0, 255, 0), + 'blue': (0, 0, 255), + 'yellow': (255, 255, 0), + 'magenta': (255, 0, 255), + 'cyan': (0, 255, 255), + 'orange': (255, 128, 0), # 根据项目规范 + 'teal': (0, 128, 128) # 根据项目规范 + } + if color_value.lower() in color_names: + return color_names[color_value.lower()] + + # 如果是Color对象(如pygame.Color) + elif hasattr(color_value, 'r') and hasattr(color_value, 'g') and hasattr(color_value, 'b'): + return (color_value.r, color_value.g, 
@staticmethod
def build_response(type, status : enums.Response, msg):
    """Build a response dict for a request and log it when it reports a failure.

    :param type: request type echoed back as ``request_type``
    :param status: enum member whose ``value`` is a ``(status_code, status_msg)`` pair
    :param msg: detail message placed in ``detail_msg``
    :return: dict with ``type``, ``request_type``, ``status_code``,
             ``status_msg`` and ``detail_msg``
    """
    status_code, status_msg = status.value
    result = {
        "type": "response",
        "request_type": type,
        "status_code": status_code,
        "status_msg": status_msg,
        "detail_msg": msg
    }
    if status_code != 0:
        # Helper has no ``error`` method; the original call raised
        # AttributeError for every non-zero status. log_error is the
        # class's error-level logging entry point.
        Helper.log_error(result)
    return result
logger.warning("未找到markdown库,使用JavaScript渲染方式") + # 从appIOs/configs加载marked.js + file_js = QFile('appIOs/configs/marked.min.js') + markedJs = '' + if file_js.open(QIODevice.ReadOnly | QIODevice.Text): + markedJs = file_js.readAll().data().decode('utf-8') + file_js.close() + + # 转义markdown内容 + escapedMd = mdContent.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') + + # 创建HTML模板 + htmlTemplate = ''' + + + + + + + + +
+ + + + ''' + + # 生成HTML内容 + htmlContent = htmlTemplate.replace('%1', markedJs).replace('%2', escapedMd) + return htmlContent + @staticmethod + def getMarkdownRender(mdFileName): + file = QFile(mdFileName) + mdContent = '' + if file.open(QIODevice.ReadOnly | QIODevice.Text): + stream = QTextStream(file) + stream.setAutoDetectUnicode(True) + mdContent = stream.readAll() + file.close() + return Helper.getMarkdownRenderText(mdContent) + else: + logger.error(f'打开文件 {mdFileName} 失败') + return f"

无法打开文件: {mdFileName}

" + +class RTTI: + @staticmethod + def _do_set_attr(obj, property_name, property_value): + if obj is None: + logger.error(f"RTTI.set: obj is None") + return + + class_name = type(obj).__name__ + object_name = obj.objectName() + if property_name not in dir(obj): + logger.error(f"RTTI.set: {class_name} {object_name}.{property_name} not in dir(obj)") + return + if property_name.endswith('icon') and isinstance(property_value, str): + original_property_value = property_value + if not os.path.exists(property_value): + property_value = os.path.join('appIOs/res/images/icons',property_value) + # logger.info(f"RTTI.set: {class_name} {object_name}.{property_name} = {property_value}(自动匹配)") + if not os.path.exists(property_value): + logger.error(f"{original_property_value}文件不存在 > RTTI.set: {class_name} {object_name}.{property_name} = '{original_property_value}'") + return + property_value = QIcon(property_value) + setter_method = getattr(obj, f'set{property_name[0].upper() + property_name[1:]}') + setter_method(property_value) + @staticmethod + def set(obj, property_name, property_value): + property_list = property_name.split('.') + if len(property_list) == 1: + RTTI._do_set_attr(obj, property_name, property_value) + else: + dest_obj = obj + for i in range(len(property_list) - 1): + if not dest_obj: + logger.error(f"RTTI.set: {property_list.join('.')} not found") + return + dest_obj = getattr(dest_obj, property_list[i]) + RTTI._do_set_attr(dest_obj, property_list[-1], property_value) + + @staticmethod + def _do_get_attr(obj, property_name): + if obj is None: + logger.error(f"RTTI.get: obj is None") + return None, None + if property_name not in dir(obj): + logger.error(f"RTTI.get: {type(obj).__name__} {obj.objectName()}.{property_name} not in dir(obj)") + return None, None + # 返回属性类型与属性值 + type_name = type(getattr(obj, property_name)).__name__ + value = getattr(obj, property_name) + return type_name, value + # 取得属性类型与属性值 type, value = RTTI.get(obj, property_name) + 
class DrawHelper:
    """OpenCV helpers for drawing dashed primitives."""

    @staticmethod
    def draw_dashed_line(mat, pt1, pt2, color, thickness=1, dash_length=10):
        """Draw a dashed segment from ``pt1`` to ``pt2`` on ``mat``.

        The segment is cut into ``int(length / dash_length)`` equal cells and
        the first half of each cell is drawn. Segments shorter than
        ``dash_length`` produce no cells and nothing is drawn.
        """
        x_start, y_start = pt1[0], pt1[1]
        delta_x = pt2[0] - x_start
        delta_y = pt2[1] - y_start
        length = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5
        cell_count = int(length / dash_length)
        for cell in range(cell_count):
            dash_from = (int(x_start + delta_x * cell / cell_count),
                         int(y_start + delta_y * cell / cell_count))
            dash_to = (int(x_start + delta_x * (cell + 0.5) / cell_count),
                       int(y_start + delta_y * (cell + 0.5) / cell_count))
            cv2.line(mat, dash_from, dash_to, color, thickness)

    @staticmethod
    def draw_dashed_rect(painter, rect, color, thickness=1, dash_length=10):
        """Draw the four edges of ``rect`` as dashed lines.

        ``rect`` must provide ``left()``, ``top()``, ``right()`` and
        ``bottom()``. ``painter`` is forwarded unchanged to
        ``draw_dashed_line`` as the image to draw on.
        """
        left, top = rect.left(), rect.top()
        right, bottom = rect.right(), rect.bottom()
        edges = (
            ((left, top), (right, top)),        # top edge
            ((left, bottom), (right, bottom)),  # bottom edge
            ((left, top), (left, bottom)),      # left edge
            ((right, top), (right, bottom)),    # right edge
        )
        for edge_start, edge_end in edges:
            DrawHelper.draw_dashed_line(painter, edge_start, edge_end, color, thickness, dash_length)


# ---- begin utils/YOLOTracker.py (new file in this patch) ----
from loguru import logger
import subprocess as sp
from ultralytics import YOLO
import time, cv2, numpy as np, math
from traceback import format_exc
from DrGraph.utils.pull_push import NetStream
from DrGraph.utils.Helper import *
from DrGraph.utils.Constant import Constant
from zipfile import ZipFile


class YOLOTracker:
    """Thin wrapper around an Ultralytics YOLO model running in tracking mode."""

    def __init__(self, model_path):
        """Load the model at ``model_path`` and set the tracking options."""
        self.model = YOLO(model_path)
        self.tracking_config = dict(
            tracker="appIOs/configs/yolo11/bytetrack.yaml",
            conf=0.25,
            iou=0.45,
            persist=True,
            verbose=False,
        )
        self.frame_count = 0
        self.processing_time = 0  # duration of the last processed frame, in ms

    def process_frame(self, frame):
        """Detect and track objects in one frame.

        Returns ``(annotated_frame, result)``. If anything raises, the error
        is logged and ``(frame, None)`` is returned instead.
        """
        started_at = time.time()
        try:
            # A single image is passed in, so only the first result matters.
            first_result = self.model.track(source=frame, **self.tracking_config)[0]
            annotated = first_result.plot()

            self.processing_time = (time.time() - started_at) * 1000  # seconds -> ms
            self.frame_count += 1

            # Log a detection summary every 100th frame.
            if self.frame_count % 100 == 0:
                self._print_detection_info(first_result)
        except Exception as e:
            logger.error("YOLO处理异常: {}", format_exc())
            return frame, None
        return annotated, first_result

    def _print_detection_info(self, result):
        """Log the number of detections and distinct track ids in ``result``."""
        boxes = result.boxes
        if boxes is None or len(boxes) == 0:
            logger.info(f"帧 {self.frame_count}: 未检测到目标, 处理时间: {self.processing_time:.2f}ms")
            return

        track_ids = {int(box.id[0]) for box in boxes if box.id is not None}
        logger.info(f"帧 {self.frame_count}: 检测到 {len(boxes)} 个目标, 追踪ID数: {len(track_ids)}, 处理时间: {self.processing_time:.2f}ms")
class YOLOTrackerManager:
    """Supply frames from one of several sources and run them through YOLO.

    A source is selected with one of the ``start*`` methods: network pull
    stream, USB camera, local video, image directory, labelled zip archive or
    a single image file. Starting a source first releases the previous one.
    """

    def __init__(self, model_path, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.tracker = YOLOTracker(model_path)
        self.stream = None
        self.videoStream = None
        self.videoType = Constant.INPUT_NONE
        self.localFile = ''
        self.localPath = ''
        self.localFiles = []
        self._currentFrame = None
        # Initialised here as well as in _stop() so nextFrame() is safe even
        # before any start* method has been called.
        self._frameIndex = -1
        self.totalFrames = 0
        self.frameChanged = False

    def _stop(self):
        """Release the active source and reset all source-related state."""
        if self.videoStream is not None:
            self.videoStream.release()
            self.videoStream = None
        if self.stream is not None:
            self.stream.clear_pull_p(self.stream.pull_p, self.request_id)
            self.stream = None
        self.localFile = ''
        self.localPath = ''
        self.localFiles = []
        self._currentFrame = None
        self.totalFrames = 0
        self._frameIndex = -1
        self.videoType = Constant.INPUT_NONE
        self.frameChanged = True

    def startLocalFile(self, fileName):
        """Use a single image file as the source."""
        self._stop()
        self.localFile = fileName
        self._frameIndex = -1

    def startLocalDir(self, dirName):
        """Use the .jpg/.jpeg/.png files of ``dirName`` (sorted by path)."""
        self._stop()
        self.localPath = dirName
        # Extension match is case-insensitive so ".JPG" etc. are not skipped.
        self.localFiles = sorted(
            os.path.join(dirName, f)
            for f in os.listdir(dirName)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        self.totalFrames = len(self.localFiles)
        Helper.App.progressMax = self.totalFrames
        logger.info("本地目录打开: {}, 总帧数: {}", dirName, self.totalFrames)
        self._frameIndex = 0

    def startLabelledZip(self, labelledPath, categoryPath):
        """Use the images of one category inside a labelled zip archive."""
        self._stop()
        self.localPath = labelledPath
        # Context manager closes the archive handle (it used to be leaked).
        with ZipFile(labelledPath) as archive:
            localFiles = archive.namelist()
        _, self.totalFrames = Helper.getYoloLabellingInfo(categoryPath, localFiles, '')
        imagePath = categoryPath + 'images/'
        self.localFiles = [file for file in localFiles if imagePath in file]
        logger.info(f"标注压缩文件{labelledPath}的{categoryPath}集共有{self.totalFrames}帧, 有效帧数: {len(self.localFiles)}")
        # NOTE(review): the zip branch of nextFrame() loads image number
        # _frameIndex + 1, so starting at 0 appears to skip the first image
        # (setFrameIndex uses index - 1). Left unchanged - confirm intent.
        self._frameIndex = 0
        Helper.App.progressMax = self.totalFrames

    def _openCapture(self, source, videoType, errorMessage):
        """Open a cv2.VideoCapture on ``source``; return True on success."""
        self.videoStream = cv2.VideoCapture(source)
        self.videoType = videoType
        Helper.Sleep(200)
        if self.videoStream.isOpened():
            return True
        logger.error(errorMessage, source)
        # Do not keep a capture object that never opened.
        self.videoStream.release()
        self.videoStream = None
        self.videoType = Constant.INPUT_NONE
        return False

    def startUsbCamera(self, index = 0):
        """Use USB camera ``index`` as the source."""
        self._stop()
        if self._openCapture(index, Constant.INPUT_USB_CAMERA, "无法打开USB摄像头: {}"):
            # A live camera has no frame count; use the largest int32 value.
            self.totalFrames = 0x7FFFFFFF

    def startLocalVideo(self, fileName):
        """Use a local video file as the source."""
        self._stop()
        if not self._openCapture(fileName, Constant.INPUT_LOCAL_VIDEO, "无法打开本地视频流: {}"):
            return
        try:
            total = int(self.videoStream.get(cv2.CAP_PROP_FRAME_COUNT))
        except Exception:
            total = 0
        self.totalFrames = max(total, 0)
        Helper.App.progressMax = self.totalFrames
        logger.info("本地视频打开: {}, 总帧数: {}", fileName, self.totalFrames)

    def startPull(self, url = ''):
        """Start pulling a network stream; a non-empty ``url`` replaces pull_url."""
        self._stop()
        if len(url) > 0:
            self.pull_url = url
        logger.info("拉流地址: {}", self.pull_url)
        self.stream = NetStream(self.pull_url, self.push_url, self.request_id)
        self.stream.prepare_pull()

    def getCurrentFrame(self):
        """Return a copy of the current frame, fetching one first if needed."""
        if self._currentFrame is None:
            self._currentFrame = self.nextFrame()
        if self._currentFrame is not None:
            return self._currentFrame.copy()
        return None
    currentFrame = Property_Rw(getCurrentFrame, None)

    def setFrameIndex(self, index):
        """Seek to ``index`` (clamped); only for local videos and file lists."""
        if self.videoStream is None and len(self.localFiles) == 0:
            return
        if self.videoStream is not None and self.videoType != Constant.INPUT_LOCAL_VIDEO:
            return
        index = max(0, min(index, self.totalFrames - 1))
        if self.videoStream is not None:
            self.videoStream.set(cv2.CAP_PROP_POS_FRAMES, index)
        # nextFrame() advances the index, so park it one position earlier.
        # NOTE(review): the directory branch of nextFrame() loads
        # localFiles[_frameIndex] directly, so this looks one frame early for
        # image directories - confirm before changing.
        self._frameIndex = index - 1
        self._currentFrame = self.nextFrame()
        self.frameChanged = True
    frameIndex = Property_rW(setFrameIndex, 0)

    def getLabels(self):
        """Return the text of the label file paired with the current zip image."""
        with ZipFile(self.localPath, 'r') as zip_ref:
            return zip_ref.read(self.localFile).decode('utf-8')

    def getAnalysisFrame(self, nextFlag):
        """Return ``(frame_copy_or_None, frameChanged)`` for analysis.

        ``frameChanged`` is the flag value before this call. With ``nextFlag``
        set (streaming sources) a new frame is fetched first.
        """
        frameChanged = self.frameChanged
        self.frameChanged = False
        if nextFlag:
            self._currentFrame = self.nextFrame()
            self.frameChanged = True
        frame = self.currentFrame
        return (frame.copy() if frame is not None else None), frameChanged

    def _nextZipFrame(self):
        """Decode image number ``_frameIndex + 1`` from the labelled zip."""
        images = [name for name in self.localFiles if '/images/' in name]
        target = self._frameIndex + 1
        if not 0 <= target < len(images):
            return None
        img_file = images[target]
        try:
            with ZipFile(self.localPath, 'r') as zip_ref:
                image_data = zip_ref.read(img_file)
            frame = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        except Exception as e:
            logger.warning(f"读取压缩文件 {self.localPath} 中的 {img_file} 失败: {e}")
            return None
        self._frameIndex += 1
        # Remember the matching label file for getLabels().
        self.localFile = img_file.replace('/images/', '/labels/').replace('.jpg', '.txt').replace('.png', '.txt')
        return frame

    def _nextDirFrame(self):
        """Read ``localFiles[_frameIndex]`` (wrapping at the end) and advance.

        Returns ``(frame, ok)``; ``ok`` is False when the file is unreadable.
        """
        if self._frameIndex < 0 or self._frameIndex >= len(self.localFiles):
            self._frameIndex = 0
        frame = cv2.imread(self.localFiles[self._frameIndex])
        if frame is None:
            logger.error(f"无法读取目标目录 {self.localPath}中下标为 {self._frameIndex} 的视频文件 {self.localFiles[self._frameIndex]}")
            self._frameIndex = -1
            return None, False
        self._frameIndex += 1
        return frame, True

    def nextFrame(self):
        """Fetch the next frame from the active source as RGB, or None."""
        frame = None
        if self.stream is not None:
            frame = self.stream.next_pull_frame()
        elif self.videoStream is not None:
            ret, frame = self.videoStream.read()
            if ret:
                self._frameIndex += 1
            else:
                frame = None
        elif len(self.localFiles) > 0:
            if self.localPath.endswith('.zip'):
                frame = self._nextZipFrame()
            else:
                frame, ok = self._nextDirFrame()
                if not ok:
                    return None
        elif self.localFile is not None and self.localFile != '':
            frame = cv2.imread(self.localFile)
            if frame is None:
                logger.error("无法读取本地视频文件: {}", self.localFile)
                return None
        if frame is not None:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.totalFrames > 0:
            Helper.App.progress = self._frameIndex
        return frame

    def test_yolo11_recognize(self, frame):
        """Run YOLO on ``frame`` and return the annotated frame."""
        return self.process_frame_with_yolo(frame, self.request_id)

    def process_frame_with_yolo(self, frame, requestId):
        """Run YOLO tracking and overlay FPS and object-count text.

        Returns the annotated frame, or the input frame if processing fails.
        """
        try:
            processed_frame, detection_result = self.tracker.process_frame(frame)

            # Clamp the duration to 1 ms so the division cannot blow up.
            fps_info = f"FPS: {1000/max(self.tracker.processing_time, 1):.1f}"
            cv2.putText(processed_frame, fps_info, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            if detection_result and detection_result.boxes is not None:
                count_info = f"Objects: {len(detection_result.boxes)}"
                cv2.putText(processed_frame, count_info, (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            return processed_frame
        except Exception as e:
            logger.error("YOLO处理异常:{}, requestId:{}", format_exc(), requestId)
            return frame

    def get_gray_mask(self, frame):
        """Return a uint8 mask (255/0) of near-grey pixels in a 3-channel image.

        A pixel is grey when all pairwise channel differences are below 20;
        the first/third channel pair is also accepted when the pixel matches
        the shadow rule (first channel exceeds the third by less than 40).
        """
        maskMat = np.zeros(frame.shape[:2], dtype=np.uint8)

        # NOTE(review): channels are named as if the image were BGR, but
        # nextFrame() converts frames to RGB - confirm the intended order.
        b, g, r = cv2.split(frame)
        # Signed type so the subtractions below cannot wrap around.
        r = r.astype(np.int16)
        g = g.astype(np.int16)
        b = b.astype(np.int16)

        diff_rg = np.abs(r - g)
        diff_rb = np.abs(r - b)
        diff_gb = np.abs(g - b)
        is_shadow = (b > r) & (b - r < 40)

        # Parenthesised explicitly: `diff_rb < 20 | is_shadow` binds as
        # `diff_rb < (20 | is_shadow)` because | outranks comparisons.
        gray_pixels = (diff_rg < 20) & ((diff_rb < 20) | is_shadow) & (diff_gb < 20)

        maskMat[gray_pixels] = 255
        return maskMat

    def debugLine(self, line, y_intersect):
        """Return ``(angle_deg, length, x_intersect)`` for ``(x1, y1, x2, y2)``.

        ``angle_deg`` is folded into [0, 180). ``x_intersect`` is where the
        line crosses ``y = y_intersect``; a horizontal line never does, so
        ``math.inf`` is returned for it.
        """
        x1, y1, x2, y2 = line
        length = np.linalg.norm([x2 - x1, y2 - y1])
        angle_deg = math.degrees(math.atan2(y2 - y1, x2 - x1))
        if angle_deg < 0:
            angle_deg += 180
        if y2 == y1:
            x_intersect = math.inf
        else:
            x_intersect = (x2 - x1) * (y_intersect - y1) / (y2 - y1) + x1
        return angle_deg, length, x_intersect

    def test_highway_recognize(self, frame, debugFlag = False):
        """Highlight the road surface in ``frame``.

        Returns ``(overlay, color_mask, whiteLineMat)``: the frame blended
        with a green road mask, the mask itself, and a Sobel edge image of
        the road area. On failure the error is logged and
        ``(frame_copy, None, None)`` is returned so callers can always unpack
        three values. ``debugFlag`` is currently unused.
        """
        processed_frame = frame.copy()

        try:
            IGNORE_HEIGHT = 100
            # Paint the top band a saturated colour on a private copy so it
            # is never classified as grey; the caller's frame stays intact.
            work = frame.copy()
            work[:IGNORE_HEIGHT, :] = (255, 0, 0)

            gray_mask = self.get_gray_mask(work)

            # Two erosions with a 5x5 kernel remove small noise blobs.
            kernel = np.ones((5, 5), np.uint8)
            gray_mask = cv2.erode(gray_mask, kernel)
            gray_mask = cv2.erode(gray_mask, kernel)

            # Keep only connected regions with an area of at least 10000.
            contours, _ = cv2.findContours(gray_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            filtered_mask = np.zeros_like(gray_mask)
            for contour in contours:
                if cv2.contourArea(contour) >= 10000:
                    cv2.fillPoly(filtered_mask, [contour], 255)

            # Horizontal-gradient (Sobel) image of the road area.
            whiteLineMat = cv2.bitwise_and(processed_frame, processed_frame, mask=filtered_mask)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_RGB2GRAY)
            whiteLineMat = cv2.Sobel(whiteLineMat, cv2.CV_8U, 1, 0, ksize=3)
            whiteLineMat = cv2.cvtColor(whiteLineMat, cv2.COLOR_GRAY2RGB)

            # Green wherever the filtered mask is set, black elsewhere.
            color_mask = np.zeros((*filtered_mask.shape, 3), dtype=np.uint8)
            color_mask[filtered_mask > 0] = (0, 255, 0)

            overlay = cv2.addWeighted(processed_frame, 0.7, color_mask, 0.3, 0)
            return overlay, color_mask, whiteLineMat

        except Exception as e:
            logger.error("路面识别异常:{}", format_exc())
            return processed_frame, None, None
# ---- begin utils/general.py (new file in this patch) ----
import os
import re
import shutil
import math
import textwrap
import platform
import subprocess
import webbrowser
from difflib import SequenceMatcher
from typing import Iterator, Tuple

try:
    # Standard library since Python 3.8.
    from importlib.metadata import version as get_package_version
except ImportError:
    # Backport package for older interpreters.
    from importlib_metadata import version as get_package_version

try:
    import psutil
except ImportError:
    psutil = None


def format_bold(text):
    """Wrap ``text`` in the ANSI bold escape sequence."""
    return f"\033[1m{text}\033[0m"


def format_color(text, color_code):
    """Wrap ``text`` in the ANSI SGR sequence for ``color_code``."""
    return f"\033[{color_code}m{text}\033[0m"


def gradient_text(
    text: str,
    start_color: Tuple[int, int, int] = (0, 0, 255),
    end_color: Tuple[int, int, int] = (255, 0, 255),
    frequency: float = 1.0,
) -> str:
    """Colour each character of ``text`` along a sine-shaped RGB gradient.

    Every character is wrapped in a 24-bit ANSI colour sequence whose
    channels oscillate between ``start_color`` and ``end_color``;
    ``frequency`` scales the sine wave. An empty string yields ``""``.
    """
    # A one-character text has length - 1 == 0; clamp the denominator so the
    # single character sits at position 0 instead of dividing by zero.
    last_index = max(len(text) - 1, 1)

    def color_at(t: float) -> Tuple[int, int, int]:
        # Sine wave for smooth, periodic interpolation between the colours.
        return tuple(
            round(s + (e - s) * (math.sin(math.pi * t * frequency) + 1) / 2)
            for s, e in zip(start_color, end_color)
        )

    pieces = []
    for position, char in enumerate(text):
        r, g, b = color_at(position / last_index)
        pieces.append(f"\033[38;2;{r};{g};{b}m{char}\033[0m")
    return "".join(pieces)


def hex_to_rgb(hex_color):
    """Convert ``"#rrggbb"`` (or shorthand ``"#rgb"``) to an ``(r, g, b)`` tuple.

    The leading ``#`` is optional. Raises ``ValueError`` for non-hex input.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Expand CSS-style shorthand: "f80" -> "ff8800".
        hex_color = "".join(ch * 2 for ch in hex_color)
    return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
def indent_text(text, indent=4):
    """Indent every non-blank line of ``text`` by ``indent`` spaces."""
    return textwrap.indent(text, " " * indent)


def is_chinese(s="人工智能"):
    """Return True if ``str(s)`` contains a character in U+4E00..U+9FFF."""
    return bool(re.search("[\u4e00-\u9fff]", str(s)))


def is_possible_rectangle(points):
    """Return True if four points could be the corners of a rectangle.

    ``points`` is ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]`` in traversal
    order. The squared side lengths are sorted and must pair up exactly (two
    smallest equal, two largest equal). Angles are not checked, so this is a
    necessary condition only.
    """
    if len(points) != 4:
        return False

    dists = sorted(square_dist(points[i], points[(i + 1) % 4]) for i in range(4))
    return dists[0] == dists[1] and dists[2] == dists[3]


def square_dist(p, q):
    """Return the squared Euclidean distance between 2-D points ``p`` and ``q``."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2


def collect_system_info():
    """Collect host and package information for diagnostics.

    Returns ``(system_info, pkg_info)``, two dicts keyed by display label.
    Values that cannot be determined are None or an "N/A" string.
    """
    os_info = platform.platform()
    cpu_info = platform.processor()
    cpu_count = os.cpu_count()

    gib = 1 << 30
    if psutil:
        ram_info = f"{psutil.virtual_memory().total / gib:.1f} GB"
    else:
        ram_info = "N/A (psutil not installed)"

    # Disk usage comes from shutil, so it no longer depends on psutil.
    try:
        total, _used, free = shutil.disk_usage("/")
        disk_info = f"{(total - free) / gib:.1f}/{total / gib:.1f} GB"
    except OSError:
        disk_info = "N/A"

    gpu_info = get_gpu_info()
    cuda_info = get_cuda_version()
    python_info = platform.python_version()
    pyqt5_info = get_installed_package_version("PyQt5")
    onnx_info = get_installed_package_version("onnx")
    ort_info = get_installed_package_version("onnxruntime")
    ort_gpu_info = get_installed_package_version("onnxruntime-gpu")
    opencv_contrib_info = get_installed_package_version(
        "opencv-contrib-python-headless"
    )

    system_info = {
        "Operating System": os_info,
        "CPU": cpu_info,
        "CPU Count": cpu_count,
        "RAM": ram_info,
        "Disk": disk_info,
        "GPU": gpu_info,
        "CUDA": cuda_info,
        "Python Version": python_info,
    }
    pkg_info = {
        "PyQt5 Version": pyqt5_info,
        "ONNX Version": onnx_info,
        "ONNX Runtime Version": ort_info,
        "ONNX Runtime GPU Version": ort_gpu_info,
        "OpenCV Contrib Python Headless Version": opencv_contrib_info,
    }

    return system_info, pkg_info
def find_most_similar_label(text, valid_labels):
    """Return the label most similar to ``text`` (difflib ratio).

    Ties go to the earliest label. Returns None when ``valid_labels`` is
    empty (this used to raise IndexError).
    """
    if not valid_labels:
        return None
    # max() keeps the first of several equally good candidates.
    return max(
        valid_labels,
        key=lambda label: SequenceMatcher(None, text, label).ratio(),
    )


def get_installed_package_version(package_name):
    """Return the installed version of ``package_name``, or None if unknown."""
    try:
        return get_package_version(package_name)
    except Exception:
        return None


def get_cuda_version():
    """Return the last token of nvcc's "release" line, or None.

    None is returned when nvcc is missing, fails, or prints no such line.
    """
    try:
        nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode(
            "utf-8"
        )
    except Exception:
        return None
    for line in nvcc_output.split("\n"):
        if "release" in line:
            return line.split()[-1]
    return None


def get_gpu_info():
    """Return a summary such as ``"CUDA:0 (name, 24576MiB)"`` from nvidia-smi.

    Several GPUs are joined with ", ". Returns None when nvidia-smi cannot be
    run and an empty string when it reports no GPU lines.
    """
    try:
        smi_output = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=index,name,memory.total",
                "--format=csv,noheader,nounits",
            ],
            encoding="utf-8",
        )
    except Exception:
        return None

    gpu_info_lines = []
    for line in smi_output.strip().split("\n"):
        fields = [part.strip() for part in line.split(",")]
        if len(fields) < 3:
            continue
        # Index is first and memory last; anything in between is the name,
        # so a GPU name containing a comma is no longer dropped.
        index, memory = fields[0], fields[-1] + "MiB"
        name = ", ".join(fields[1:-1])
        gpu_info_lines.append(f"CUDA:{index} ({name}, {memory})")
    return ", ".join(gpu_info_lines)


def open_url(url: str) -> None:
    """Open ``url`` in the default browser while suppressing TTY warnings.

    On Linux, WSL is detected through /proc/version and the URL is handed to
    PowerShell; native Linux uses xdg-open. Other platforms, and any failure
    of the above, fall back to ``webbrowser.open``.
    """
    try:
        if platform.system() != "Linux":
            webbrowser.open(url)
            return

        with open("/proc/version", "r") as f:
            running_in_wsl = "microsoft" in f.read().lower()

        if running_in_wsl:
            # Single-quoted PowerShell literal: no $variable or backtick
            # expansion, and embedded single quotes are doubled.
            quoted = url.replace("'", "''")
            subprocess.run(
                ["powershell.exe", "-Command", f"Start-Process '{quoted}'"],
                check=True,
            )
        else:
            # check=True makes a failing launcher reach the fallback below.
            subprocess.run(
                ["xdg-open", url], stderr=subprocess.DEVNULL, check=True
            )
    except Exception:
        # Fallback to regular webbrowser.open
        webbrowser.open(url)
# ---- begin utils/pull_push.py (new file in this patch) ----
import subprocess as sp
from traceback import format_exc
import cv2, time
import numpy as np
from loguru import logger
from DrGraph.utils.Helper import Helper
from DrGraph.utils.Exception import ServiceException
from DrGraph.utils.Constant import Constant


class NetStream:
    """Pull raw video frames from a network stream through an ffmpeg pipe."""

    # Location of the ffmpeg executable used for pulling.
    FFMPEG_PATH = 'D:/DrGraph/DSP/ffmpeg.exe'

    def __init__(self, pull_url, push_url, request_id):
        self.pull_url = pull_url
        self.push_url = push_url
        self.request_id = request_id
        self.pull_p = None  # ffmpeg subprocess, created by start_pull_p()

        # A 1920x1080 stream is assumed. The pipe buffer is reshaped as
        # (1080 * 3 / 2) rows of 1920 bytes before the NV12 -> BGR
        # conversion, so `height` counts buffer rows, not picture rows.
        self.width = 1920
        self.height = 1080 * 3 // 2
        self.width_height_3 = 1920 * 1080 * 3 // 2  # bytes per frame
        self.w_2 = 960
        self.h_2 = 540

        self.frame_count = 0
        self.start_time = time.time()

    def clear_pull_p(self, pull_p, requestId):
        """Terminate the ffmpeg process ``pull_p`` if it is still running.

        If a graceful shutdown raises, the process is killed and the original
        exception is re-raised.
        """
        try:
            if pull_p and pull_p.poll() is None:
                logger.info("关闭拉流管道, requestId:{}", requestId)
                if pull_p.stdout:
                    pull_p.stdout.close()
                pull_p.terminate()
                pull_p.wait(timeout=30)
                logger.info("拉流管道已关闭, requestId:{}", requestId)
        except Exception:
            logger.error("关闭拉流管道异常: {}, requestId:{}", format_exc(), requestId)
            if pull_p and pull_p.poll() is None:
                pull_p.kill()
                pull_p.wait(timeout=30)
            raise

    def start_pull_p(self, pull_url, requestId):
        """Spawn ffmpeg to decode ``pull_url`` into raw frames on stdout.

        Stores the process in ``self.pull_p`` and returns it. Errors are
        logged and re-raised.
        """
        try:
            command = [
                self.FFMPEG_PATH,
                '-re',
                '-y',
                '-an',
                '-c:v', 'h264_cuvid',
                '-i', pull_url,
                '-f', 'rawvideo',
                '-r', '25',
                '-',
            ]
            self.pull_p = sp.Popen(command, stdout=sp.PIPE)
            return self.pull_p
        except ServiceException as s:
            logger.error("构建拉流管道ServiceException异常: url={}, {}, requestId:{}", pull_url, s.msg, requestId)
            raise
        except Exception:
            logger.error("构建拉流管道Exception异常:url={}, {}, requestId:{}", pull_url, format_exc(), requestId)
            raise

    def pull_read_video_stream(self):
        """Read one frame from the pipe and return it as a BGR image.

        The pipe is started on demand. Returns None when no data arrived or
        reading failed; ``ServiceException`` is re-raised after the pipe has
        been closed. Frames wider than ``Constant.pull_frame_width`` are
        halved in both dimensions.
        """
        result = None
        try:
            if self.pull_p is None:
                self.start_pull_p(self.pull_url, self.request_id)
            in_bytes = self.pull_p.stdout.read(self.width_height_3)
            if in_bytes is not None and len(in_bytes) > 0:
                try:
                    result = np.frombuffer(in_bytes, np.uint8).reshape((self.height, self.width))
                    result = cv2.cvtColor(result, cv2.COLOR_YUV2BGR_NV12)
                    if result.shape[1] > Constant.pull_frame_width:
                        result = cv2.resize(result, (result.shape[1] // 2, result.shape[0] // 2), interpolation=cv2.INTER_LINEAR)
                except Exception:
                    # Was `requestId`, a name that does not exist in this method.
                    logger.error("视频格式异常:{}, requestId:{}", format_exc(), self.request_id)
                    # NOTE(review): ExceptionType is not among this file's
                    # visible imports; until it is imported this line raises
                    # NameError, which the generic handler below absorbs.
                    raise ServiceException(ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[0],
                                           ExceptionType.VIDEO_RESOLUTION_EXCEPTION.value[1])
        except ServiceException as s:
            logger.error("ServiceException 读流异常: {}, requestId:{}", s.msg, self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            raise
        except Exception:
            logger.error("Exception 读流异常:{}, requestId:{}", format_exc(), self.request_id)
            self.clear_pull_p(self.pull_p, self.request_id)
            self.pull_p = None
            # The frame geometry is deliberately kept: setting it to None (as
            # before) made every read after a restart use read(None) and
            # reshape((None, None)).
            result = None
        return result

    def prepare_pull(self):
        """Start the pull pipe if it is not running yet."""
        if self.pull_p is None:
            self.start_time = time.time()
            self.start_pull_p(self.pull_url, self.request_id)

    def next_pull_frame(self):
        """Return the next frame, or None when the pipe has not been started."""
        if self.pull_p is None:
            logger.error(f'pull_p is None, requestId: {self.request_id}')
            return None
        return self.pull_read_video_stream()
# ---- begin utils/util_env.py (new file in this patch) ----
import dotenv, os, sys
from pathlib import Path
from loguru import logger


def load_env():
    """Load ``.env`` (UTF-8) and fill in defaults for optional settings.

    Variables without a default are only checked: a missing one is reported
    with a warning. Assigning ``os.getenv(...)`` back into ``os.environ``, as
    before, raised TypeError as soon as one of them was unset.
    """
    dotenv.load_dotenv(encoding="utf-8")

    # API credentials and endpoints - no sensible default exists.
    # NOTE(review): "TRAILY_API_KEY" may be a misspelling of the Tavily key;
    # kept as is because the Settings fields use the same spelling.
    expected_keys = (
        "OPENAI_API_KEY",
        "OPENAI_BASE_URL",
        "DEEPSEEK_API_KEY",
        "DEEPSEEK_BASE_URL",
        "LANGSMITH_API_KEY",
        "LANGSMITH_PROJECT",
        "TRAILY_API_KEY",
    )
    for key in expected_keys:
        if os.getenv(key) is None:
            logger.warning(f"⚠️ 环境变量 {key} 未设置")

    # Server, logging and model settings: keep an existing value, otherwise
    # install the default.
    defaults = {
        "HOST": "localhost",
        "PORT": "8000",
        "DEBUG": "True",
        "LOG_LEVEL": "INFO",
        "DEFAULT_MODEL": "gpt-4o-mini",
        "MAX_TOKENS": "2048",
        "TEMPERATURE": "0.7",
        "TESSDATA_PREFIX": r"C:\ProgramData\anaconda3\envs\rag\share\tessdata",
    }
    for key, value in defaults.items():
        os.environ.setdefault(key, value)

    logger.info("环境变量加载成功")


def _mask_env_value(key, value):
    """Return ``value`` with secrets shortened to their first four characters."""
    sensitive_markers = ("KEY", "SECRET", "TOKEN", "PASSWORD")
    if any(marker in key.upper() for marker in sensitive_markers):
        return f"{value[:4]}***" if value else value
    return value


def print_env():
    """Log all environment variables; secret-looking values are masked."""
    logger.info("当前环境变量:")
    for key, value in os.environ.items():
        logger.info(f"{key}: {_mask_env_value(key, value)}")


def check_environment():
    """Check the runtime environment and log what is missing.

    Verifies FFmpeg availability through ``imageio_ffmpeg``, the presence of
    the ``.env`` file and of ``OPENAI_API_KEY``, and creates the ``logs``
    directory next to this module when it does not exist.
    """
    # Make sibling modules importable.
    current_dir = Path(__file__).parent
    if str(current_dir) not in sys.path:
        sys.path.insert(0, str(current_dir))

    logger.info("🔍 检查环境配置...")

    # Import and use imageio_ffmpeg in one guarded step, so a missing module
    # is reported once instead of failing again when the name is used.
    try:
        import imageio_ffmpeg

        ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
        logger.info(f"✅ 已找到 FFmpeg 路径: {ffmpeg_path}")
        os.environ["IMAGEIO_FFMPEG_EXE"] = ffmpeg_path
    except ImportError:
        logger.warning("⚠️ imageio_ffmpeg 模块未安装,跳过 FFmpeg 检查")
    except Exception as e:
        logger.warning(f"⚠️ 获取 FFmpeg 路径失败: {e}")

    # .env file in the parent directory of this module.
    env_file = current_dir.parent / ".env"
    if not env_file.exists():
        logger.warning(f"⚠️ 未找到 .env 文件{env_file},使用默认配置")
        logger.warning(f"💡 建议创建 .env 文件{env_file}并配置 OPENAI_API_KEY")

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.warning("⚠️ 未配置 OPENAI_API_KEY")
        logger.warning("💡 请在 .env 文件中设置: OPENAI_API_KEY=your_api_key")
    else:
        logger.info(f"✅ OpenAI API Key 已配置 (前6位: {api_key[:6]}...)")

    logs_dir = current_dir / "logs"
    if not logs_dir.exists():
        logs_dir.mkdir(exist_ok=True)
        logger.warning(f"📁 创建日志目录: {logs_dir}")

    logger.info("✅ 环境检查完成\n")
def install_dependencies():
    """Check that the main dependencies can be imported.

    Nothing is installed. Returns True when fastapi, uvicorn and langchain
    all import, otherwise logs the failure plus an install hint and returns
    False.
    """
    logger.info("📦 检查依赖包...")

    try:
        import fastapi
        import uvicorn
        import langchain
    except ImportError as missing:
        logger.error(f"❌ 缺少依赖包: {missing}")
        logger.error("💡 请运行: pip install -r requirements.txt")
        return False
    else:
        logger.info("✅ 主要依赖包已安装")

    return True


from typing import List
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application configuration, read from the environment and ``.env``."""

    # --- OpenAI ---
    openai_api_key: str = ""
    openai_base_url: str = "https://api.openai.com/v1"

    # --- OpenXLab ---
    open_x_lab_ak: str = ""
    open_x_lab_sk: str = ""

    # --- DashScope ---
    dashscope_api_key: str = ""
    dashscope_base_url: str = ""

    # --- Firecrawl ---
    firecrawl_api_key: str = ""

    # --- DeepSeek ---
    deepseek_api_key: str = ""
    deepseek_base_url: str = ""

    # --- LangSmith ---
    langsmith_project: str = ""
    langsmith_api_key: str = ""

    # --- Traily ---
    traily_api_key: str = ""

    # --- Server ---
    host: str = "localhost"
    port: int = 8000
    debug: bool = True

    # --- CORS ---
    allowed_origins: List[str] = ["*"]

    # --- Logging ---
    log_level: str = "INFO"
    log_file: str = "logs/app.log"

    # --- Model ---
    # NOTE(review): load_env() falls back to "gpt-4o-mini" for DEFAULT_MODEL
    # while this default is "gpt-4o" - confirm which one is intended.
    default_model: str = "gpt-4o"
    max_tokens: int = 2048
    temperature: float = 0.7

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
# Global settings instance.
settings = Settings()


# ---- begin utils/util_exceptions.py (new file in this patch) ----
# Custom exceptions and error handlers for the chat agent application.
from typing import Any, Dict, List, Optional
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse, StreamingResponse
from typing import Union
from pydantic import BaseModel
from datetime import datetime
import json
from loguru import logger

from starlette.status import (
    HTTP_400_BAD_REQUEST,
    HTTP_401_UNAUTHORIZED,
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_429_TOO_MANY_REQUESTS,
    HTTP_500_INTERNAL_SERVER_ERROR,
)


class FullHxfResponseModel(BaseModel):
    """Full response envelope: code, HTTP status, payload, error and message."""
    code: int
    status: int
    data: Dict[str, Any]
    error: Optional[Dict[str, Any]]
    message: Optional[str]


class HxfResponse(JSONResponse):
    """JSON response that wraps a payload in the standard envelope.

    The body is ``{"code", "status", "data", "error", "message"}``. ``code``
    is 0, or -1 when the payload is a dict whose ``success`` entry equals
    False. The HTTP status is always 200. A ``StreamingResponse`` passed in
    is returned unchanged instead of being wrapped.
    """

    def __new__(cls, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]], StreamingResponse]):
        # Streaming responses must not be serialised to JSON; hand them back
        # as they are. Python then skips __init__ because the returned object
        # is not an instance of this class.
        if isinstance(response, StreamingResponse):
            return response
        return super().__new__(cls)

    def __init__(self, response: Union[BaseModel, Dict[str, Any], List[BaseModel], List[Dict[str, Any]]]):
        if isinstance(response, list):
            # A list made only of models is dumped item by item; any other
            # list is passed through as is.
            if all(isinstance(item, BaseModel) for item in response):
                data_dict = [item.model_dump(mode='json') for item in response]
            else:
                data_dict = response
        elif isinstance(response, BaseModel):
            data_dict = response.model_dump(mode='json')
        else:
            # Dict or any other JSON-serialisable value.
            data_dict = response

        # Only a dict payload can carry a "success" flag. Without the type
        # check a list containing the string "success" was indexed with a
        # string and raised TypeError.
        failed = (
            isinstance(data_dict, dict)
            and 'success' in data_dict
            and data_dict['success'] == False  # noqa: E712 - 0 also counts as failure
        )

        content = {
            "code": -1 if failed else 0,
            "status": status.HTTP_200_OK,
            "data": data_dict,
            "error": None,
            "message": None
        }
        super().__init__(
            content=content,
            status_code=status.HTTP_200_OK,
            media_type="application/json"
        )
class HxfErrorResponse(JSONResponse):
    """JSON error response in the standard envelope format.

    ``message`` is either plain text or an exception. The envelope's ``code``
    is always -1; its ``status`` field is 500 for ``TypeError``, the
    exception's own ``status_code`` when it has one, and the ``status_code``
    argument otherwise. The HTTP status of the response itself is always the
    ``status_code`` argument.
    """

    def __init__(self, message: Union[str, Exception], status_code: int = status.HTTP_401_UNAUTHORIZED):
        """Build the error envelope from ``message``."""
        if isinstance(message, Exception):
            # Prefer an explicit `message` attribute (ChatAgentException sets one).
            msg = getattr(message, 'message', str(message))
            logger.error(f"[HxfErrorResponse] - {type(message)}, 异常: {message}")
            if isinstance(message, TypeError):
                content = {
                    "code": -1,
                    "status": 500,
                    "data": None,
                    "error": f"错误类型: {type(message)} 错误信息: {message}",
                    "message": msg
                }
            else:
                content = {
                    "code": -1,
                    # Ordinary exceptions have no status_code attribute; fall
                    # back to the argument instead of raising AttributeError.
                    "status": getattr(message, 'status_code', status_code),
                    "data": None,
                    "error": None,
                    "message": msg
                }
        else:
            content = {
                "code": -1,
                "status": status_code,
                "data": None,
                "error": None,
                "message": message
            }
        super().__init__(content=content, status_code=status_code)


class ChatAgentException(Exception):
    """Base exception for the chat agent application.

    Carries a human-readable ``message``, the HTTP ``status_code`` to report
    and an optional ``details`` dict (defaults to an empty dict).
    """

    def __init__(
        self,
        message: str,
        status_code: int = HTTP_500_INTERNAL_SERVER_ERROR,
        details: Optional[Dict[str, Any]] = None
    ):
        self.message = message
        self.status_code = status_code
        self.details = details or {}
        super().__init__(self.message)


class ValidationError(ChatAgentException):
    """Validation error (HTTP 422)."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        # The bare name HTTP_422_UNPROCESSABLE_ENTITY is not imported in this
        # file; take the constant from the imported `status` module instead.
        super().__init__(message, status.HTTP_422_UNPROCESSABLE_ENTITY, details)


class AuthenticationError(ChatAgentException):
    """Authentication error (HTTP 401)."""

    def __init__(self, message: str = "Authentication failed"):
        super().__init__(message, HTTP_401_UNAUTHORIZED)
"""Authentication error exception.""" + + def __init__(self, message: str = "Authentication failed"): + super().__init__(message, HTTP_401_UNAUTHORIZED) + + +class AuthorizationError(ChatAgentException): + """Authorization error exception.""" + + def __init__(self, message: str = "Access denied"): + super().__init__(message, HTTP_403_FORBIDDEN) + + +class NotFoundError(ChatAgentException): + """Resource not found exception.""" + + def __init__(self, message: str = "Resource not found"): + super().__init__(message, HTTP_404_NOT_FOUND) + + +class ConversationNotFoundError(NotFoundError): + """Conversation not found exception.""" + + def __init__(self, conversation_id: str): + super().__init__(f"Conversation with ID {conversation_id} not found") + + +class UserNotFoundError(NotFoundError): + """User not found exception.""" + + def __init__(self, user_id: str): + super().__init__(f"User with ID {user_id} not found") + + +class ChatServiceError(ChatAgentException): + """Chat service error exception.""" + + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + super().__init__(message, HTTP_500_INTERNAL_SERVER_ERROR, details) + + +class OpenAIError(ChatServiceError): + """OpenAI API error exception.""" + + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + super().__init__(f"OpenAI API error: {message}", details) + + +class RateLimitError(ChatAgentException): + """Rate limit exceeded error.""" + pass + + +class DatabaseError(ChatAgentException): + """Database operation error exception.""" + + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + super().__init__(f"Database error: {message}", HTTP_500_INTERNAL_SERVER_ERROR, details) + + +# Error handlers +async def chat_agent_exception_handler(request: Request, exc: ChatAgentException) -> JSONResponse: + """Handle ChatAgentException and its subclasses.""" + from loguru import logger + logger.error( + f"ChatAgentException: {exc.message}", + extra={ 
class FileUtils:
    """Utility class for file operations.

    Provides static methods for filename sanitising, extension/size
    validation, metadata extraction, hashing, and simple directory/file
    management.
    """

    # Allowed file extensions for document upload
    ALLOWED_EXTENSIONS: set[str] = {
        '.txt', '.md', '.csv',  # Text files
        '.pdf',  # PDF files
        '.docx', '.doc',  # Word documents
        '.xlsx', '.xls',  # Excel files
        '.pptx', '.ppt',  # PowerPoint files
        '.rtf',  # Rich text format
        '.odt', '.ods', '.odp'  # OpenDocument formats
    }

    # MIME type mappings
    MIME_TYPE_MAPPING: dict[str, str] = {
        '.txt': 'text/plain',
        '.md': 'text/markdown',
        '.csv': 'text/csv',
        '.pdf': 'application/pdf',
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.doc': 'application/msword',
        '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        '.xls': 'application/vnd.ms-excel',
        '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        '.ppt': 'application/vnd.ms-powerpoint',
        '.rtf': 'application/rtf',
        '.odt': 'application/vnd.oasis.opendocument.text',
        '.ods': 'application/vnd.oasis.opendocument.spreadsheet',
        '.odp': 'application/vnd.oasis.opendocument.presentation'
    }

    # Extension groups shared by the is_*_file helpers and get_file_category,
    # so the two can never drift apart.
    _TEXT_EXTENSIONS = frozenset({'.txt', '.md', '.csv', '.rtf'})
    _DOCUMENT_EXTENSIONS = frozenset({'.docx', '.doc', '.odt'})
    _SPREADSHEET_EXTENSIONS = frozenset({'.xlsx', '.xls', '.ods'})
    _PRESENTATION_EXTENSIONS = frozenset({'.pptx', '.ppt', '.odp'})

    _MAX_FILENAME_LENGTH = 255
    _HASH_CHUNK_SIZE = 64 * 1024

    @staticmethod
    def sanitize_filename(filename: str) -> str:
        """Sanitize filename to remove dangerous characters.

        Args:
            filename: The filename to sanitize.

        Returns:
            A non-empty name of at most 255 characters with control
            characters removed and ``<>:"/\\|?*`` replaced by ``_``.
        """
        # Control characters (NUL, newlines, ...) are dropped entirely.
        sanitized = re.sub(r'[\x00-\x1f\x7f]', '', filename)

        # Remove or replace dangerous characters
        sanitized = re.sub(r'[<>:"/\\|?*]', '_', sanitized)

        # Remove leading/trailing spaces and dots
        sanitized = sanitized.strip(' .')

        # Ensure filename is not empty
        if not sanitized:
            return 'unnamed_file'

        # Limit filename length, keeping the extension when it fits.
        limit = FileUtils._MAX_FILENAME_LENGTH
        if len(sanitized) > limit:
            name, ext = os.path.splitext(sanitized)
            if len(ext) >= limit:
                # An over-long "extension" would make the slice bound below
                # negative and leave the result longer than the limit.
                sanitized = sanitized[:limit]
            else:
                sanitized = name[:limit - len(ext)] + ext

        return sanitized

    @staticmethod
    def get_file_hash(file_path: str, algorithm: str = 'md5') -> str:
        """Calculate file hash.

        Args:
            file_path: The path to the file.
            algorithm: The hash algorithm to use (default: 'md5').

        Returns:
            The hexadecimal representation of the file hash.

        Raises:
            FileNotFoundError: If the file does not exist.
            PermissionError: If the file cannot be read.
            ValueError: If the specified algorithm is not supported.
        """
        # Only the constructor can signal an unsupported algorithm; keeping it
        # in its own try stops unrelated ValueErrors (e.g. an embedded NUL in
        # the path) from being reported as an algorithm problem.
        try:
            hash_func = hashlib.new(algorithm)
        except ValueError as err:
            raise ValueError(f"Unsupported hash algorithm: {algorithm}") from err

        try:
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(FileUtils._HASH_CHUNK_SIZE), b""):
                    hash_func.update(chunk)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {file_path}") from err
        except PermissionError as err:
            raise PermissionError(f"Permission denied when reading file: {file_path}") from err

        return hash_func.hexdigest()

    @staticmethod
    def get_file_info(file_path: str) -> "FileInfo":
        """Get comprehensive file information.

        Args:
            file_path: The path to the file.

        Returns:
            A dictionary with name, lower-cased extension, size (bytes and
            MB), guessed MIME type and encoding, ``st_ctime``/``st_mtime``
            timestamps, and whether the path is a readable regular file.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        stat = path.stat()

        # Get MIME type
        mime_type, encoding = mimetypes.guess_type(str(path))

        return {
            'filename': path.name,
            'extension': path.suffix.lower(),
            'size_bytes': stat.st_size,
            'size_mb': round(stat.st_size / (1024 * 1024), 2),
            'mime_type': mime_type,
            'encoding': encoding,
            'created_at': stat.st_ctime,
            'modified_at': stat.st_mtime,
            'is_file': path.is_file(),
            'is_readable': os.access(file_path, os.R_OK)
        }

    @staticmethod
    def validate_file_extension(filename: str, allowed_extensions: Optional["ExtensionList"] = None) -> bool:
        """Validate file extension.

        Args:
            filename: The filename to validate.
            allowed_extensions: Allowed extensions (default: ALLOWED_EXTENSIONS).

        Returns:
            True if the lower-cased file extension is allowed, False otherwise.
        """
        allowed = FileUtils.ALLOWED_EXTENSIONS if allowed_extensions is None else allowed_extensions
        return Path(filename).suffix.lower() in allowed

    @staticmethod
    def validate_file_size(file_size: int, max_size: int) -> bool:
        """Return True when ``file_size`` does not exceed ``max_size``."""
        return file_size <= max_size

    @staticmethod
    def create_directory(directory_path: str) -> bool:
        """Create directory if it doesn't exist.

        Args:
            directory_path: The path to the directory to create.

        Returns:
            True if the directory was created or already exists, False
            otherwise (including when the path exists but is not a directory).
        """
        try:
            # With exist_ok=True, FileExistsError is only raised when the path
            # exists and is NOT a directory, so it is a failure, not success.
            Path(directory_path).mkdir(parents=True, exist_ok=True)
        except (OSError, ValueError, TypeError):
            return False
        return True

    @staticmethod
    def delete_file(file_path: str) -> bool:
        """Safely delete a file.

        Args:
            file_path: The path to the file to delete.

        Returns:
            True if the file was deleted, False otherwise.
        """
        try:
            path = Path(file_path)
            if not path.is_file():
                return False
            path.unlink()
        except (OSError, ValueError, TypeError):
            return False
        return True

    @staticmethod
    def get_mime_type(filename: str) -> Optional[str]:
        """Return the mapped MIME type for the filename's extension, or None."""
        extension = Path(filename).suffix.lower()
        return FileUtils.MIME_TYPE_MAPPING.get(extension)

    @staticmethod
    def format_file_size(size_bytes: int) -> str:
        """Format file size in human readable format (B up to TB)."""
        if size_bytes == 0:
            return "0 B"

        size_names = ["B", "KB", "MB", "GB", "TB"]
        i = 0
        size = float(size_bytes)

        while size >= 1024.0 and i < len(size_names) - 1:
            size /= 1024.0
            i += 1

        return f"{size:.1f} {size_names[i]}"

    @staticmethod
    def is_text_file(filename: str) -> bool:
        """Check if file is a text file."""
        return Path(filename).suffix.lower() in FileUtils._TEXT_EXTENSIONS

    @staticmethod
    def is_pdf_file(filename: str) -> bool:
        """Check if file is a PDF."""
        return Path(filename).suffix.lower() == '.pdf'

    @staticmethod
    def is_office_file(filename: str) -> bool:
        """Check if file is an Office document."""
        extension = Path(filename).suffix.lower()
        return (
            extension in FileUtils._DOCUMENT_EXTENSIONS
            or extension in FileUtils._SPREADSHEET_EXTENSIONS
            or extension in FileUtils._PRESENTATION_EXTENSIONS
        )

    @staticmethod
    def get_file_category(filename: str) -> str:
        """Get file category based on extension.

        Returns one of 'text', 'pdf', 'document', 'spreadsheet',
        'presentation' or 'unknown'.
        """
        extension = Path(filename).suffix.lower()

        if extension in FileUtils._TEXT_EXTENSIONS:
            return 'text'
        if extension == '.pdf':
            return 'pdf'
        if extension in FileUtils._DOCUMENT_EXTENSIONS:
            return 'document'
        if extension in FileUtils._SPREADSHEET_EXTENSIONS:
            return 'spreadsheet'
        if extension in FileUtils._PRESENTATION_EXTENSIONS:
            return 'presentation'
        return 'unknown'
# Handler that forwards standard-library logging records to loguru.
class LoguruHandler(logging.Handler):
    """Forward ``logging`` records to the loguru ``logger``.

    ``init_logger`` installs this handler on the root logger and on the
    uvicorn/fastapi loggers with ``propagate = False``, so a handler that
    forwards nothing makes those records disappear entirely.
    """

    def emit(self, record):
        """Re-emit ``record`` through loguru at the matching level.

        Follows the interception recipe from the loguru documentation.
        """
        import inspect  # local import: only this method needs it

        # Use the loguru level with the same name when one exists, otherwise
        # fall back to the numeric stdlib level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk back past the frames of the logging machinery so loguru
        # attributes the message to the code that originally logged it.
        frame, depth = inspect.currentframe(), 0
        while frame:
            filename = frame.f_code.co_filename
            is_logging = filename == logging.__file__
            is_frozen = "importlib" in filename and "_bootstrap" in filename
            if depth > 0 and not (is_logging or is_frozen):
                break
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
logger.add( + join(base_path, log_config.get("log_name")), + rotation=log_config.get("rotation"), + retention=log_config.get("retention"), + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True, + encoding=log_config.get("encoding"), + filter=relative_path_formatter, + diagnose=True + ) + # 控制台输出 + if bool(log_config.get("enable_stderr")): + logger.add( + sys.stderr, + format=log_config.get("log_fmt"), + level=log_config.get("level"), + enqueue=True, + filter=relative_path_formatter, + diagnose=True + ) + + log_level = log_config.get("level", "INFO") + logging.basicConfig(handlers=[LoguruHandler()], level=log_level) + + # 特别配置uvicorn和fastapi的日志记录器 + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"]: + uvicorn_logger = logging.getLogger(logger_name) + uvicorn_logger.handlers = [LoguruHandler()] + uvicorn_logger.propagate = False + + + print("\n\n\n") + logger.info(f"----=========== {gradient_text('日志配置初始化完成, 开始新的日志记录')} ===========----") + logger.info("🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀") + logger.info("🚀 DDDD RRRR GGGG RRRR AAA PPPPP H H ") + logger.info("🚀 D D R R G R R A A P P H H ") + logger.info("🚀 D D RRRR GGGGG RRRR AAAAA PPPPP HHHHH ") + logger.info("🚀 D D R R G G R R A A P H H ") + logger.info("🚀 DDDD R R GGGGG R R A A P H H ") + logger.info("🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀🚀") diff --git a/utils/util_models.py b/utils/util_models.py new file mode 100644 index 0000000..8d1c135 --- /dev/null +++ b/utils/util_models.py @@ -0,0 +1,134 @@ +from langchain.chat_models import init_chat_model +from langchain_openai import ChatOpenAI + + +def model_deepseek_chat(temperature: float = 1.3): + """ + 初始化并返回一个基于deepseek-chat模型的聊天模型实例。 + DEEPSEEK_API_KEY: 从环境变量中获取的Deepseek API密钥。 + DEEPSEEK_BASE_URL: 从环境变量中获取的Deepseek API基础URL。 + + 定价: + 输入token:$0.55/1M + 输出token:$1.70/1M + 返回: + ChatDeepseek: 一个初始化的ChatDeepseek模型实例,用于与deepseek-chat模型进行交互。 + """ + llm = init_chat_model("deepseek-chat", temperature=temperature) + 
def model_deepseek_reasoner(temperature: float = 1.3):
    """Create a chat model bound to ``deepseek-reasoner``.

    Pricing noted by the original author: input $0.55 / 1M tokens,
    output $1.70 / 1M tokens.

    Returns:
        The chat model instance produced by ``init_chat_model``.
    """
    return init_chat_model("deepseek-reasoner", temperature=temperature)


def model_gpt_4o_mini(temperature: float = 0.7):
    """Create a chat model bound to ``gpt-4o-mini``.

    Pricing noted by the original author: input $0.15 / 1M tokens,
    output $0.60 / 1M tokens.

    Returns:
        The chat model instance produced by ``init_chat_model``.
    """
    return init_chat_model("gpt-4o-mini", temperature=temperature)


def model_local_deepseek_r1_1dot5b(temperature: float = 0.7, max_tokens: int = 2000):
    """Create a chat model for the local Ollama model ``deepseek-r1:1.5b``.

    Returns:
        The chat model instance produced by ``init_chat_model``, pointed at
        the Ollama server on localhost port 11434.
    """
    return init_chat_model(
        "ollama:deepseek-r1:1.5b",
        base_url="http://localhost:11434",
        temperature=temperature,
        max_tokens=max_tokens,
    )


def model_local_llama3_2dot3b(temperature: float = 0.7, max_tokens: int = 2000):
    """Create a chat model for the local Ollama model ``llama3.2:3b``.

    Returns:
        The chat model instance produced by ``init_chat_model``, pointed at
        the Ollama server on localhost port 11434.
    """
    return init_chat_model(
        "ollama:llama3.2:3b",
        base_url="http://localhost:11434",
        temperature=temperature,
        max_tokens=max_tokens,
    )


def model_local_nomic_embed_text_latest(
    temperature: float = 0.7, max_tokens: int = 2000
):
    """Create a model handle for the local Ollama model ``nomic-embed-text:latest``.

    NOTE(review): the model name suggests an embedding model, yet it is
    loaded through ``init_chat_model`` - confirm this is intended.

    Returns:
        The instance produced by ``init_chat_model``, pointed at the Ollama
        server on localhost port 11434.
    """
    return init_chat_model(
        "ollama:nomic-embed-text:latest",
        base_url="http://localhost:11434",
        temperature=temperature,
        max_tokens=max_tokens,
    )
def model_lan_vllm_qwen2_7b_instruct(
    temperature: float = 0.7, max_tokens: int = 2000
):
    """Create a ``ChatOpenAI`` client for the LAN vLLM server running qwen2-7b-instruct.

    Args:
        temperature: Sampling temperature passed to the model.
        max_tokens: Maximum number of tokens to generate. This value is now
            forwarded; previously it was ignored and 512 was always used.

    Returns:
        A non-streaming ``ChatOpenAI`` instance.
    """
    return ChatOpenAI(
        openai_api_base="http://192.168.10.11:8000/v1",  # vLLM server address
        model_name="qwen2-7b-instruct",  # must match --served-model-name
        openai_api_key="none",  # vLLM does not require a key by default
        max_tokens=max_tokens,
        temperature=temperature,
        streaming=False,  # no streaming output
    )
description="Prompt中使用的变量", + required=False, + source="node" + ), + NodeParameter( + name="user_input", + type=ParameterType.STRING, + description="用户输入文本", + required=False, + source="input" + ) + ], + outputs=[ + NodeParameter( + name="response", + type=ParameterType.STRING, + description="LLM生成的回复" + ), + NodeParameter( + name="tokens_used", + type=ParameterType.NUMBER, + description="使用的token数量" + ) + ] + ) + + elif node_type == NodeType.CODE: + return NodeInputOutput( + inputs=[ + NodeParameter( + name="input_data", + type=ParameterType.OBJECT, + description="代码执行的输入数据", + required=False, + source="node" + ) + ], + outputs=[ + NodeParameter( + name="result", + type=ParameterType.OBJECT, + description="代码执行结果" + ), + NodeParameter( + name="output", + type=ParameterType.STRING, + description="代码输出内容" + ) + ] + ) + + elif node_type == NodeType.HTTP: + return NodeInputOutput( + inputs=[ + NodeParameter( + name="url_params", + type=ParameterType.OBJECT, + description="URL参数", + required=False, + source="node" + ), + NodeParameter( + name="request_body", + type=ParameterType.OBJECT, + description="请求体数据", + required=False, + source="node" + ) + ], + outputs=[ + NodeParameter( + name="response_data", + type=ParameterType.OBJECT, + description="HTTP响应数据" + ), + NodeParameter( + name="status_code", + type=ParameterType.NUMBER, + description="HTTP状态码" + ) + ] + ) + + elif node_type == NodeType.CONDITION: + return NodeInputOutput( + inputs=[ + NodeParameter( + name="condition_data", + type=ParameterType.OBJECT, + description="条件判断的输入数据", + required=True, + source="node" + ) + ], + outputs=[ + NodeParameter( + name="result", + type=ParameterType.BOOLEAN, + description="条件判断结果" + ), + NodeParameter( + name="branch", + type=ParameterType.STRING, + description="执行分支(true/false)" + ) + ] + ) + + else: + # 默认参数 + return NodeInputOutput( + inputs=[ + NodeParameter( + name="input", + type=ParameterType.OBJECT, + description="节点输入数据", + required=False, + source="node" + ) + ], + 
def validate_parameter_connections(nodes: List[Dict], connections: List[Dict]) -> List[str]:
    """Validate that node-sourced input parameters point at real outputs.

    For every input parameter whose ``source`` is ``'node'`` this checks that
    a ``source_node_id`` is given, that it refers to a node in ``nodes``, and
    - when the source node declares parameters and a ``source_field`` is set -
    that the field is one of the source node's outputs.

    Args:
        nodes: Node dictionaries; each needs an ``id`` and may carry ``name``
            and ``parameters`` (with ``inputs`` / ``outputs`` lists).
        connections: Accepted for interface compatibility; not inspected.

    Returns:
        Human-readable error messages; empty when everything is consistent.
    """
    errors: List[str] = []
    node_dict = {node['id']: node for node in nodes}

    def _label(node: Dict) -> str:
        # Fall back to the id so that reporting a problem never raises
        # KeyError for a node saved without a display name.
        return node.get('name', node['id'])

    for node in nodes:
        parameters = node.get('parameters')
        if not parameters:
            continue

        for input_param in parameters.get('inputs', []):
            if input_param.get('source') != 'node':
                continue

            param_name = input_param.get('name', '')
            source_node_id = input_param.get('source_node_id')
            source_field = input_param.get('source_field')

            if not source_node_id:
                errors.append(f"节点 {_label(node)} 的输入参数 {param_name} 缺少来源节点ID")
                continue

            source_node = node_dict.get(source_node_id)
            if source_node is None:
                errors.append(f"节点 {_label(node)} 的输入参数 {param_name} 引用了不存在的节点 {source_node_id}")
                continue

            source_parameters = source_node.get('parameters')
            if not source_parameters or not source_field:
                continue

            output_fields = {output.get('name') for output in source_parameters.get('outputs', [])}
            if source_field not in output_fields:
                errors.append(f"节点 {_label(node)} 的输入参数 {param_name} 引用了节点 {_label(source_node)} 不存在的输出字段 {source_field}")

    return errors
+class BaseResponse(BaseModel): + """基础响应模型""" + id: int + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +# User schemas +class UserBase(BaseModel): + """用户基础模型""" + username: str = Field(..., min_length=3, max_length=50) + email: str = Field(..., max_length=100) + full_name: Optional[str] = Field(None, max_length=100) + bio: Optional[str] = None + avatar_url: Optional[str] = None + + +class UserCreate(UserBase): + """用户创建模型""" + password: str = Field(..., min_length=6) + + +class UserUpdate(BaseModel): + """用户更新模型""" + username: Optional[str] = Field(None, min_length=3, max_length=50) + email: Optional[str] = Field(None, max_length=100) + full_name: Optional[str] = Field(None, max_length=100) + bio: Optional[str] = None + avatar_url: Optional[str] = None + password: Optional[str] = Field(None, min_length=6) + is_active: Optional[bool] = None + department_id: Optional[int] = None + + +class UserResponse(BaseResponse, UserBase): + """用户响应模型""" + is_active: bool + department_id: Optional[int] = None + roles: Optional[List['RoleResponse']] = Field(default=[], description="用户角色列表") + permissions: Optional[List[Dict[str, Any]]] = Field(default=[], description="用户权限列表") + is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员") + + @classmethod + def from_orm(cls, obj): + """从ORM对象创建响应对象,安全处理关系属性(同步版本).""" + # 获取基本字段 + data = { + 'id': obj.id, + 'username': obj.username, + 'email': obj.email, + 'full_name': obj.full_name, + 'is_active': obj.is_active, + 'department_id': obj.department_id, + 'created_at': obj.created_at, + 'updated_at': obj.updated_at, + 'created_by': obj.created_by, + 'updated_by': obj.updated_by, + } + + # 安全处理roles关系 - 仅使用已加载的关系,不尝试刷新 + try: + if hasattr(obj, 'roles'): + try: + from th_agenter.schemas.permission import RoleResponse + # 仅访问已加载的角色,不触发新查询 + data['roles'] = [RoleResponse.from_orm(role) for role in obj.roles if role.is_active] + except Exception: + # 
如果访问roles失败(DetachedInstanceError或延迟加载错误),使用空列表 + data['roles'] = [] + else: + data['roles'] = [] + except Exception: + data['roles'] = [] + + # 安全处理权限信息 - 仅使用已加载的关系,不尝试刷新 + try: + permissions = set() + if hasattr(obj, 'roles'): + try: + for role in obj.roles: + if role.is_active: + try: + for perm in role.permissions: + if perm.is_active: + permissions.add((perm.code, perm.name)) + except Exception: + # 权限加载失败,跳过 + continue + except Exception: + # 角色加载失败,跳过 + pass + data['permissions'] = [{'code': code, 'name': name} for code, name in permissions] + except Exception: + data['permissions'] = [] + + # 添加is_superuser字段 + try: + # 检查是否有is_admin属性或is_superuser属性 + if hasattr(obj, 'is_admin'): + data['is_superuser'] = obj.is_admin + elif hasattr(obj, 'is_superuser'): + if callable(obj.is_superuser): + try: + data['is_superuser'] = obj.is_superuser() + except Exception: + data['is_superuser'] = False + else: + data['is_superuser'] = obj.is_superuser + else: + data['is_superuser'] = False + except Exception: + data['is_superuser'] = False + + return cls(**data) + + @classmethod + async def from_orm_async(cls, obj): + """从ORM对象创建响应对象,安全处理关系属性(异步版本).""" + # 获取基本字段 + data = { + 'id': obj.id, + 'username': obj.username, + 'email': obj.email, + 'full_name': obj.full_name, + 'is_active': obj.is_active, + 'department_id': obj.department_id, + 'created_at': obj.created_at, + 'updated_at': obj.updated_at, + 'created_by': obj.created_by, + 'updated_by': obj.updated_by, + } + + # 安全处理roles关系 + try: + from sqlalchemy.orm import object_session + from sqlalchemy.ext.asyncio import AsyncSession + + session = object_session(obj) + roles_loaded = [] + + if hasattr(obj, 'roles'): + # 根据会话类型加载角色 + if session and isinstance(session, AsyncSession): + # 异步会话,使用await刷新 + await session.refresh(obj, ['roles']) + roles_loaded = obj.roles if obj.roles is not None else [] + else: + # 同步会话或无会话,直接访问 + try: + roles_loaded = obj.roles if obj.roles is not None else [] + except Exception: + roles_loaded = 
[] + else: + roles_loaded = [] + + from th_agenter.schemas.permission import RoleResponse + data['roles'] = [RoleResponse.from_orm(role) for role in roles_loaded] + except Exception as e: + # 如果访问roles失败,使用空列表 + data['roles'] = [] + + # 添加权限信息 + try: + # 获取数据库会话 + from sqlalchemy.orm import object_session + session = object_session(obj) + + is_super_admin = False + if hasattr(obj, 'has_role'): + if callable(obj.has_role): + # 检查has_role是否为异步方法 + import inspect + if inspect.iscoroutinefunction(obj.has_role): + is_super_admin = await obj.has_role('SUPER_ADMIN') + else: + is_super_admin = obj.has_role('SUPER_ADMIN') + + if is_super_admin: + # 超级管理员拥有所有权限 + if session: + from th_agenter.models.permission import Permission + if isinstance(session, AsyncSession): + from sqlalchemy import select + all_permissions = await session.execute(select(Permission).filter(Permission.is_active == True)) + all_permissions = all_permissions.scalars().all() + else: + all_permissions = session.query(Permission).filter(Permission.is_active == True).all() + data['permissions'] = [{'code': perm.code, 'name': perm.name} for perm in all_permissions] + else: + data['permissions'] = [{'code': '*', 'name': '所有权限'}] + else: + # 从角色获取权限 + permissions = set() + # 使用已加载的角色,避免再次访问关系 + for role in roles_loaded: + if role.is_active: + # 同样处理role.permissions关系 + role_perms = [] + if hasattr(role, 'permissions'): + try: + if session and isinstance(session, AsyncSession): + await session.refresh(role, ['permissions']) + role_perms = role.permissions if role.permissions is not None else [] + else: + role_perms = role.permissions if role.permissions is not None else [] + except Exception: + role_perms = [] + + for perm in role_perms: + if perm.is_active: + permissions.add((perm.code, perm.name)) + + data['permissions'] = [{'code': code, 'name': name} for code, name in permissions] + except Exception as e: + # 如果访问权限失败,使用空列表 + data['permissions'] = [] + + # 添加is_superuser字段 + try: + # 
检查是否有is_admin属性或is_superuser属性 + if hasattr(obj, 'is_admin'): + data['is_superuser'] = obj.is_admin + elif hasattr(obj, 'is_superuser'): + if callable(obj.is_superuser): + import inspect + if inspect.iscoroutinefunction(obj.is_superuser): + data['is_superuser'] = await obj.is_superuser() + else: + data['is_superuser'] = obj.is_superuser() + else: + data['is_superuser'] = obj.is_superuser + else: + data['is_superuser'] = False + except Exception: + data['is_superuser'] = False + + return cls(**data) + + +# Authentication schemas +class LoginRequest(BaseModel): + """登录请求模型,兼容前端多余字段(如 selectAccount、captcha、username)""" + email: str = Field(..., max_length=100) + password: str = Field(..., min_length=6) + + model_config = {"extra": "ignore"} + + +class Token(BaseModel): + """访问令牌响应模型""" + access_token: str + token_type: str + expires_in: int + + +# Conversation schemas +class ConversationBase(BaseModel): + """对话基础模型""" + title: str = Field(..., min_length=1, max_length=200) + system_prompt: Optional[str] = None + model_name: str = Field(default="gpt-3.5-turbo", max_length=100) + temperature: str = Field(default="0.7", max_length=10) + max_tokens: int = Field(default=2048, ge=1, le=8192) + knowledge_base_id: Optional[int] = None + + +class ConversationCreate(ConversationBase): + """对话创建模型""" + pass + + +class ConversationUpdate(BaseModel): + """对话更新模型""" + title: Optional[str] = Field(None, min_length=1, max_length=200) + system_prompt: Optional[str] = None + model_name: Optional[str] = Field(None, max_length=100) + temperature: Optional[str] = Field(None, max_length=10) + max_tokens: Optional[int] = Field(None, ge=1, le=8192) + is_archived: Optional[bool] = None + + +class ConversationResponse(BaseResponse, ConversationBase): + """对话响应模型""" + user_id: int + is_archived: bool + message_count: int = 0 + last_message_at: Optional[datetime] = None + messages: Optional[List["MessageResponse"]] = None + + +# Message schemas +class MessageBase(BaseModel): + """消息基础模型""" + 
content: str = Field(..., min_length=1) + role: MessageRole + message_type: MessageType = MessageType.TEXT + metadata: Optional[Dict[str, Any]] = Field(None, alias="message_metadata") + + +class MessageCreate(MessageBase): + """消息创建模型""" + conversation_id: int + + +class MessageResponse(BaseResponse, MessageBase): + """消息响应模型""" + conversation_id: int + context_documents: Optional[List[Dict[str, Any]]] = None + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + total_tokens: Optional[int] = None + + class Config: + from_attributes = True + populate_by_name = True + + +# Chat schemas +class ChatRequest(BaseModel): + """聊天请求模型""" + message: str = Field(..., min_length=1, max_length=10000) + stream: bool = Field(default=False) + use_knowledge_base: bool = Field(default=False) + knowledge_base_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG mode") + use_agent: bool = Field(default=False, description="Enable agent mode with tool calling capabilities") + use_langgraph: bool = Field(default=False, description="Enable LangGraph agent mode with advanced tool calling") + temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0) + max_tokens: Optional[int] = Field(default=2048, ge=1, le=8192) + + +class ChatResponse(BaseModel): + """聊天响应模型""" + user_message: MessageResponse + assistant_message: MessageResponse + total_tokens: Optional[int] = None + model_used: str + + +class StreamChunk(BaseModel): + """流式响应块模型""" + content: str + role: MessageRole = MessageRole.ASSISTANT + finish_reason: Optional[str] = None + tokens_used: Optional[int] = None + + +# Knowledge Base schemas +class KnowledgeBaseBase(BaseModel): + """知识库基础模型""" + name: str = Field(..., min_length=1, max_length=100) + description: Optional[str] = None + embedding_model: str = Field(default="sentence-transformers/all-MiniLM-L6-v2") + chunk_size: int = Field(default=1000, ge=100, le=5000) + chunk_overlap: int = Field(default=200, ge=0, le=1000) + 
+ +class KnowledgeBaseCreate(KnowledgeBaseBase): + """知识库创建模型""" + pass + + +class KnowledgeBaseUpdate(BaseModel): + """知识库更新模型""" + name: Optional[str] = Field(None, min_length=1, max_length=100) + description: Optional[str] = None + embedding_model: Optional[str] = None + chunk_size: Optional[int] = Field(None, ge=100, le=5000) + chunk_overlap: Optional[int] = Field(None, ge=0, le=1000) + is_active: Optional[bool] = None + + +class KnowledgeBaseResponse(BaseResponse, KnowledgeBaseBase): + """知识库响应模型""" + is_active: bool + vector_db_type: str + collection_name: Optional[str] + document_count: int = 0 + active_document_count: int = 0 + + +# Document schemas +class DocumentBase(BaseModel): + """文档基础模型""" + filename: str + original_filename: str + file_type: str + file_size: int + + +class DocumentUpload(BaseModel): + """文档上传模型""" + knowledge_base_id: int + process_immediately: bool = Field(default=True) + + +class DocumentResponse(BaseResponse, DocumentBase): + """文档响应模型""" + knowledge_base_id: int + file_path: str + mime_type: Optional[str] + is_processed: bool + processing_error: Optional[str] + chunk_count: int = 0 + embedding_model: Optional[str] + file_size_mb: float + + +class DocumentListResponse(BaseModel): + """文档列表响应模型""" + documents: List[DocumentResponse] + total: int + page: int + page_size: int + + +class DocumentProcessingStatus(BaseModel): + """文档处理状态模型""" + document_id: int + status: str # 'pending', 'processing', 'completed', 'failed' + progress: float = Field(default=0.0, ge=0.0, le=100.0) + error_message: Optional[str] = None + chunks_created: int = 0 + estimated_time_remaining: Optional[int] = None # seconds + + +# Error schemas +# Document chunk schemas +class DocumentChunk(BaseModel): + """文档分块模型""" + id: str + content: str + metadata: Dict[str, Any] = Field(default_factory=dict) + page_number: Optional[int] = None + chunk_index: int + start_char: Optional[int] = None + end_char: Optional[int] = None + vector_id: Optional[str] = None + + 
+class DocumentChunksResponse(BaseModel): + """文档分块响应模型""" + document_id: int + document_name: str + total_chunks: int + chunks: List[DocumentChunk] + + +class ErrorResponse(BaseModel): + """错误响应模型""" + error: str + detail: Optional[str] = None + code: Optional[str] = None + +# 通用返回结构 +class NormalResponse(BaseModel): + """通用返回模型""" + success: bool + message: str + data: Optional[Dict[str, Any]] = None + +class ExcelPreviewRequest(BaseModel): + """Excel预览请求模型""" + file_id: str + page: int = 1 + page_size: int = 20 + +class FileListResponse(BaseModel): + """文件列表响应模型""" + success: bool + message: str + data: Optional[Dict[str, Any]] = None + + +# 解决前向引用问题 +def rebuild_models(): + """重建模型以解决前向引用问题.""" + try: + from th_agenter.schemas.permission import RoleResponse + UserResponse.model_rebuild() + except ImportError: + # 如果无法导入RoleResponse,跳过重建 + pass + + +# 在模块加载时尝试重建模型 +rebuild_models() \ No newline at end of file diff --git a/utils/util_test.py b/utils/util_test.py new file mode 100644 index 0000000..a1e97ba --- /dev/null +++ b/utils/util_test.py @@ -0,0 +1,20 @@ +from modelscope import snapshot_download +from loguru import logger + +def test_vllm(): + from utils.util_models import model_lan_vllm_qwen2_7b_instruct + llm = model_lan_vllm_qwen2_7b_instruct() + response = llm.invoke("今天天气怎么样?") + logger.info(f"AI Message: {response.content}") + +def test_download_model_modelscope(model_name: str = 'iic/VibeVoice-Realtime-0.5B', cache_dir: str = './models'): + try: + model_dir = snapshot_download( + model_name, + cache_dir=cache_dir + ) + logger.info(f"模型 {model_name} 已下载到 {model_dir}") + return model_dir + except Exception as e: + logger.error(f"下载模型 {model_name} 失败: {e}") + return None \ No newline at end of file diff --git a/utils/vclEnums.py b/utils/vclEnums.py new file mode 100644 index 0000000..494fbf6 --- /dev/null +++ b/utils/vclEnums.py @@ -0,0 +1,120 @@ +from enum import Enum, unique + +CommonColors_Backgnd = [ + "#FF0000", # 红色 + "#00FF00", # 绿色 + "#0000FF", 
# 蓝色 + "#FFFF00", # 黄色 + "#FF00FF", # 品红 + "#00FFFF", # 青色 + "#FFA500", # 橙色 + "#800080", # 紫色 + "#FFC0CB", # 粉色 + "#008000", # 深绿色 + "#000080", # 深蓝色 + "#800000", # 深红色 + "#808000", # 橄榄色 + "#008080", # 青色 + "#808080", # 灰色 + "#FF0080", # 玫瑰红 + "#0080FF", # 天蓝色 + "#FF8000", # 橙红色 + "#8000FF", # 紫罗兰 + "#00FF80" # 海绿色 +] + +CommonColors_Foregnd = [ + "#FFFFFF", # 红色背景 -> 白色文字 + "#000000", # 绿色背景 -> 黑色文字 + "#FFFFFF", # 蓝色背景 -> 白色文字 + "#000000", # 黄色背景 -> 黑色文字 + "#FFFFFF", # 品红背景 -> 白色文字 + "#000000", # 青色背景 -> 黑色文字 + "#000000", # 橙色背景 -> 黑色文字 + "#FFFFFF", # 紫色背景 -> 白色文字 + "#000000", # 粉色背景 -> 黑色文字 + "#FFFFFF", # 深绿色背景 -> 白色文字 + "#FFFFFF", # 深蓝色背景 -> 白色文字 + "#FFFFFF", # 深红色背景 -> 白色文字 + "#FFFFFF", # 橄榄色背景 -> 白色文字 + "#FFFFFF", # 青色背景 -> 白色文字 + "#FFFFFF", # 灰色背景 -> 白色文字 + "#FFFFFF", # 玫瑰红背景 -> 白色文字 + "#000000", # 天蓝色背景 -> 黑色文字 + "#000000", # 橙红色背景 -> 黑色文字 + "#FFFFFF", # 紫罗兰背景 -> 白色文字 + "#000000" # 海绿色背景 -> 黑色文字 +] + +@unique +class LabellingKind(Enum): + Unknown = 0 + Select = 1 + Create = 2 + TempDrag = 3 + TempResize = 4 + TempRotate = 5 + +@unique +class Meta(Enum): + Unknown = 0 + Line = 1 + Rectangle = 2 + Ellipse = 3 + Polygon = 4 + Text = 5 + +@unique +class AiAlg(Enum): + Unknown = 0 + FashionMNIST = 1, + ColorDetector = 2, + Face = 3, + Coco8 = 4 + + +@unique +class Response(Enum): + OK = (0, "OK - 响应成功") + DEBUG = (1, "DEBUG - 正常调试") + WARNING = (2, "FAIL - 响应警告") + ERROR = (3, "ERROR - 响应错误") + EXCEPTION = (4, "EXCEPTION - 响应异常") + CRITICAL = (5, "CRITICAL - 响应严重错误") + +class Flag(Enum): + Paint = True + Client = False + + +class Align(Enum): + Left = 0 + Top = 1 + Right = 2 + Bottom = 3 + Client = 4 + Custom = 5 + +class Anchor(Enum): + Left = 1 + Top = 2 + Right = 4 + Bottom = 8 + +class IconPos(Enum): + NoIcon = 0 + OnlyIcon = 1 + IconLeft = 2 + IconRight = 3 + IconTop = 4 + IconBottom = 5 + +class ControlEvent(Enum): + MouseEnter = 0 + MouseLeave = 1 + MouseMove = 2 + MouseDown = 3 + MouseUp = 4 + Click = 5 + DblClick = 6 + \ No newline at end of file diff 
--git a/utils/wssServer.py b/utils/wssServer.py new file mode 100644 index 0000000..8e9742d --- /dev/null +++ b/utils/wssServer.py @@ -0,0 +1,136 @@ +import time, websockets, json, inspect, os +import numpy as np +from loguru import logger + +from typing import Dict, Set, Callable, Any, Optional +from traceback import format_exc + +from DrGraph.utils.Helper import * +import DrGraph.utils.vclEnums as enums + +class WebSocketServer: + """ + WebSocket服务器类 + 提供WebSocket服务器功能,包括客户端连接管理、消息处理和广播功能 + """ + + def __init__(self, host: str = "localhost", port: int = 8765): + self.host = host + self.port = port + self.message_handlers: Dict[str, Callable] = {} + self.server = None + self.onClientConnect = None + self.register_handler("file", self.handle_file) + + def register_handler(self, message_type: str, handler: Callable): + self.message_handlers[message_type] = handler + + async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t = 'default'): + if isinstance(message, dict): + message = json.dumps(message, ensure_ascii=False) + if Helper.AppFlag_SaveLog: + logger.warning(f"发送消息: {message}") + await websocket.send(f't{message}') + caller = inspect.stack()[1] + logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ") + + async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data): + # if isinstance(data, np.ndarray): + # data = data.tobytes() + await websocket.send(bytearray(b'b' + data)) + + async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str): + try: + data = json.loads(message) + message_type = data.get("type") + payload = data.get("data") + if Helper.AppFlag_SaveLog: + logger.warning(f"收到消息: {message}") + + if message_type in self.message_handlers: + response = await self.message_handlers[message_type](websocket, payload) + if response is not None: + if isinstance(response, (bytes, bytearray, np.ndarray)): + # print("发送图片数据") + await 
self.response_binary(websocket, response); + else: + await self.response_text(websocket, response, message_type); + elif message_type == 'echo': + print("echo message ", int(time.time() * 1000)) + data["type"] = "echo_response" + await self.response_text(websocket, data); + else: + logger.warning(f"未知类型消息: {message} - {websocket.remote_address}") + await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}")) + except json.JSONDecodeError: + logger.error(f"无效的JSON格式 - {message}") + await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}")) + except BrokenPipeError as e: + logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}") + # BrokenPipeError表示连接已断开,不需要特殊处理,让上层处理ConnectionClosed异常 + except Exception as e: + logger.error(f"处理消息时出错: {format_exc()}") + await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {format_exc()}")) + + async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""): + """ + 处理客户端连接 + + 参数: + websocket (websockets.WebSocketServerProtocol): WebSocket连接对象 + path (str): 请求路径 + """ + logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接") + if self.onClientConnect: + await self.onClientConnect(websocket, True) + try: + async for message in websocket: + await self.handle_message(websocket, message) + except websockets.exceptions.ConnectionClosed: + logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接已关闭") + except BrokenPipeError: + logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接BrokenPipeError") + finally: + if self.onClientConnect: + await self.onClientConnect(websocket, False) + + async def start(self): + """ + 启动WebSocket服务器 + """ + self.server = await websockets.serve(self.handle_client, self.host, self.port) + logger.warning(f"WebSocket服务器已启动: 
{self.host}:{self.port}") + + async def stop(self): + """ + 停止WebSocket服务器 + """ + if self.server: + self.server.close() + await self.server.wait_closed() + logger.info("WebSocket服务器已停止") + + async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: str): + logger.info(f"接收文件: {payload}, {payload['command']}") + command = payload['command'] + if command == 'dir': + path = payload.get('path', '/') + files = [] + folders = [] + try: + # 获取目录下的所有文件和文件夹 + if os.path.exists(path) and os.path.isdir(path): + with os.scandir(path) as entries: + for entry in entries: + if entry.is_file(): + files.append(entry.name) + elif entry.is_dir(): + folders.append(entry.name) + else: + logger.warning(f"路径不存在或不是目录: {path}") + except Exception as e: + logger.error(f"读取目录时出错: {path}, 错误: {e}") + + # 合并文件和文件夹列表 + all_items = folders + files + logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}") \ No newline at end of file diff --git a/vl_main.py b/vl_main.py new file mode 100644 index 0000000..e40369f --- /dev/null +++ b/vl_main.py @@ -0,0 +1,133 @@ +# uvicorn main:app --host 0.0.0.0 --port 8000 --reload +from fastapi import FastAPI +from os.path import dirname, realpath + +from dotenv import load_dotenv +load_dotenv() + +from utils.util_log import init_logger +from loguru import logger +base_dir: str = dirname(realpath(__file__)) +init_logger(base_dir) + +from th_agenter.api.routes import router +# from contextlib import asynccontextmanager +# from starlette.exceptions import HTTPException as StarletteHTTPException +# from fastapi.exceptions import RequestValidationError +# from fastapi.responses import JSONResponse +# from fastapi.staticfiles import StaticFiles +# @asynccontextmanager +# async def lifespan(app: FastAPI): +# """Application lifespan manager.""" +# logger.info("[生命周期] - Starting up TH Agenter application...") +# yield +# # Shutdown +# logger.info("[生命周期] - Shutting down TH Agenter application...") + +# def setup_exception_handlers(app: FastAPI) -> 
None: +# """Setup global exception handlers.""" + +# # Import custom exceptions and handlers +# from utils.util_exceptions import ChatAgentException, chat_agent_exception_handler + +# @app.exception_handler(ChatAgentException) +# async def custom_chat_agent_exception_handler(request, exc): +# return await chat_agent_exception_handler(request, exc) + +# @app.exception_handler(StarletteHTTPException) +# async def http_exception_handler(request, exc): +# from utils.util_exceptions import HxfErrorResponse +# logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}") +# return HxfErrorResponse(exc) + +# def make_json_serializable(obj): +# """递归地将对象转换为JSON可序列化的格式""" +# if obj is None or isinstance(obj, (str, int, float, bool)): +# return obj +# elif isinstance(obj, bytes): +# return obj.decode('utf-8') +# elif isinstance(obj, (ValueError, Exception)): +# return str(obj) +# elif isinstance(obj, dict): +# return {k: make_json_serializable(v) for k, v in obj.items()} +# elif isinstance(obj, (list, tuple)): +# return [make_json_serializable(item) for item in obj] +# else: +# # For any other object, convert to string +# return str(obj) + +# @app.exception_handler(RequestValidationError) +# async def validation_exception_handler(request, exc): +# # Convert any non-serializable objects to strings in error details +# try: +# errors = make_json_serializable(exc.errors()) +# except Exception as e: +# # Fallback: if even our conversion fails, use a simple error message +# errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}] +# logger.exception(f"Request Validation Error: {errors}") + +# logger.exception(f"validation_error: {errors}") +# return JSONResponse( +# status_code=422, +# content={ +# "error": { +# "type": "validation_error", +# "message": "Request validation failed", +# "details": errors +# } +# } +# ) + +# @app.exception_handler(Exception) +# async def 
general_exception_handler(request, exc): +# logger.error(f"Unhandled exception: {exc}", exc_info=True) +# return JSONResponse( +# status_code=500, +# content={ +# "error": { +# "type": "internal_error", +# "message": "Internal server error" +# } +# } +# ) + +# def create_app() -> FastAPI: +# """Create and configure FastAPI application.""" +# from th_agenter.core.config import get_settings +# settings = get_settings() + +# # Create FastAPI app +# app = FastAPI( +# title=settings.app_name, +# version=settings.app_version, +# description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由TH Agenter修改", +# debug=settings.debug, +# lifespan=lifespan, +# ) +# app.mount("/static", StaticFiles(directory="static"), name="th_agenter_static") + +# # Add middleware +# from th_agenter.core.app import setup_middleware +# setup_middleware(app, settings) + +# # # Add exception handlers +# setup_exception_handlers(app) +# add_router(app) + +# return app + +# def add_router(app: FastAPI) -> None: +# """Add default routers to the FastAPI application.""" + +# @app.get("/") +# def read_root(): +# logger.info("Hello World") +# return {"Hello": "World"} + +# # Include routers +# app.include_router(router, prefix="/api") + +# app = create_app() +from test.vl_test import vl_test + +vl_test() diff --git a/webIOs/configs/settings.yaml b/webIOs/configs/settings.yaml new file mode 100644 index 0000000..f5c3ba9 --- /dev/null +++ b/webIOs/configs/settings.yaml @@ -0,0 +1,41 @@ +# Chat Agent Configuration +app: + name: "TH Agenter" + version: "0.2.0" + debug: true + environment: "development" + host: "0.0.0.0" + port: 8000 + +# File Configuration +file: + upload_dir: "./data/uploads" + max_size: 10485760 # 10MB + allowed_extensions: [".txt", ".pdf", ".docx", ".md"] + chunk_size: 1000 + chunk_overlap: 200 + semantic_splitter_enabled: true # 启用语义分割器 + +# Storage Configuration +storage: + storage_type: "local" # local or s3 + upload_directory: "./data/uploads" + + # S3 Configuration + s3_bucket_name: "chat-agent-files" 
+ aws_access_key_id: null + aws_secret_access_key: null + aws_region: "us-east-1" + s3_endpoint_url: null + +# CORS Configuration +cors: + allowed_origins: ["*"] + allowed_methods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"] + allowed_headers: ["*"] + +# Chat Configuration +chat: + max_history_length: 10 + system_prompt: "你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。" + max_response_tokens: 1000 \ No newline at end of file diff --git a/webIOs/configs/th_agenter_config_logger.yml b/webIOs/configs/th_agenter_config_logger.yml new file mode 100644 index 0000000..f8464fc --- /dev/null +++ b/webIOs/configs/th_agenter_config_logger.yml @@ -0,0 +1,9 @@ +enable_file_log: true +enable_stderr: true +base_path: "webIOs/output/logs" +log_name: "th_agenter_web.log" +log_fmt: "{time: HH:mm:ss.SSS} [{level:7}] - {message} @ {extra[relative_path]}:{line} in {function}" +level: "INFO" +rotation: "00:00" +retention: "1 days" +encoding: "utf8" \ No newline at end of file