陈博士backend
This commit is contained in:
parent
d964b7d4b9
commit
6bf0601638
|
|
@ -0,0 +1,179 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control. See https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
webIOs/output/logs/
|
||||
|
||||
# OS generated files
|
||||
Thumbs.db
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
|
||||
# FastAPI specific
|
||||
*.pyc
|
||||
uvicorn*.log
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
Generic single-database configuration with an async dbapi.
|
||||
|
||||
alembic revision --autogenerate -m "init"
|
||||
alembic upgrade head
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import asyncio
import os
from logging.config import fileConfig

from dotenv import load_dotenv
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Pull the database URL from the environment (a local .env file is honored)
# instead of hard-coding credentials in alembic.ini.
load_dotenv()
database_url = os.getenv("DATABASE_URL")
if not database_url:
    raise ValueError("环境变量DATABASE_URL未设置")

# configparser treats "%" as interpolation syntax, so a URL containing
# percent-encoded characters (e.g. a password with "%40") would raise an
# InterpolationSyntaxError later; escape it before storing.
config.set_main_option("sqlalchemy.url", database_url.replace("%", "%%"))

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here for 'autogenerate' support;
# importing Base (after models are registered on it) gives Alembic the
# full application schema to diff against.
from th_agenter.db import Base  # noqa: E402

target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Emit migration SQL without a live database connection.

    Only the configured URL is needed here — no Engine (and therefore no
    DBAPI) is created. Calls to context.execute() are written to the
    script output instead of being executed against a database.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
    """Bind the Alembic context to *connection* and run the migrations.

    This is the synchronous worker handed to AsyncConnection.run_sync()
    by run_async_migrations().
    """
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
    """Create an async Engine from the ini section, connect, and migrate.

    The engine is built with NullPool because it lives only for the
    duration of this migration run. Disposal is done in a ``finally``
    block so the engine (and its event loop resources) are released even
    when a migration fails — the original only disposed on success.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            # run_sync bridges to the synchronous Alembic migration API.
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""

    # Thin sync wrapper: Alembic calls env.py synchronously, so drive the
    # async migration coroutine to completion on a fresh event loop here.
    asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# Entry point: Alembic imports env.py and this module-level dispatch runs
# immediately, choosing SQL-script generation (offline) or a live database
# run (online) based on how alembic was invoked (--sql flag).
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
## Mako template for new Alembic revision files. Every non-``##`` line below
## is emitted verbatim into the generated migration; ``##`` lines are Mako
## comments and are stripped from the output.
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
"""init

Revision ID: 1ea5548d641d
Revises:
Create Date: 2025-12-13 13:47:07.838600

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
# Root migration of the chain: down_revision is None, so downgrading to
# "base" drops everything this revision creates.
revision: str = '1ea5548d641d'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema.

    Initial schema: creates all application tables and their indexes.
    Order matters — FK-free tables come first so the FK-bearing tables
    (user_roles, workflows, workflow_executions, node_executions) can
    reference users, roles, workflows and workflow_executions.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Agent runtime configuration: tool list, model, sampling parameters.
    op.create_table('agent_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('enabled_tools', sa.JSON(), nullable=False),
    sa.Column('max_iterations', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('system_message', sa.Text(), nullable=True),
    sa.Column('verbose', sa.Boolean(), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
    )
    op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
    op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
    # Chat conversations. NOTE(review): user_id / knowledge_base_id carry no
    # ForeignKeyConstraint here — confirm whether that is intentional.
    op.create_table('conversations',
    sa.Column('title', sa.String(length=200), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
    sa.Column('system_prompt', sa.Text(), nullable=True),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_archived', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
    )
    op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
    # External database connection settings. The unique constraint on
    # db_type allows at most one config per database type.
    op.create_table('database_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('db_type', sa.String(length=20), nullable=False),
    sa.Column('host', sa.String(length=255), nullable=False),
    sa.Column('port', sa.Integer(), nullable=False),
    sa.Column('database', sa.String(length=100), nullable=False),
    sa.Column('username', sa.String(length=100), nullable=False),
    sa.Column('password', sa.Text(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('connection_params', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
    sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
    )
    op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
    # Uploaded documents with extraction / embedding bookkeeping.
    op.create_table('documents',
    sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('mime_type', sa.String(length=100), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('content', sa.Text(), nullable=True),
    sa.Column('doc_metadata', sa.JSON(), nullable=True),
    sa.Column('chunk_count', sa.Integer(), nullable=False),
    sa.Column('embedding_model', sa.String(length=100), nullable=True),
    sa.Column('vector_ids', sa.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
    )
    op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
    # Uploaded spreadsheets with parsed sheet/column metadata and previews.
    op.create_table('excel_files',
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('sheet_names', sa.JSON(), nullable=False),
    sa.Column('default_sheet', sa.String(length=100), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('preview_data', sa.JSON(), nullable=False),
    sa.Column('data_types', sa.JSON(), nullable=True),
    sa.Column('total_rows', sa.JSON(), nullable=True),
    sa.Column('total_columns', sa.JSON(), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('last_accessed', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
    )
    op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
    # Knowledge-base containers with chunking / vector store settings.
    op.create_table('knowledge_bases',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('embedding_model', sa.String(length=100), nullable=False),
    sa.Column('chunk_size', sa.Integer(), nullable=False),
    sa.Column('chunk_overlap', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('vector_db_type', sa.String(length=50), nullable=False),
    sa.Column('collection_name', sa.String(length=100), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
    )
    op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
    op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
    # LLM provider configurations. NOTE(review): api_key is stored as a
    # plain String(500) — confirm it is encrypted at the application layer.
    op.create_table('llm_configs',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('provider', sa.String(length=50), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('api_key', sa.String(length=500), nullable=False),
    sa.Column('base_url', sa.String(length=200), nullable=True),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.Float(), nullable=False),
    sa.Column('top_p', sa.Float(), nullable=False),
    sa.Column('frequency_penalty', sa.Float(), nullable=False),
    sa.Column('presence_penalty', sa.Float(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('is_embedding', sa.Boolean(), nullable=False),
    sa.Column('extra_config', sa.JSON(), nullable=True),
    sa.Column('usage_count', sa.Integer(), nullable=False),
    sa.Column('last_used_at', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
    )
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
    # Conversation messages with role/type enums and token accounting.
    op.create_table('messages',
    sa.Column('conversation_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
    sa.Column('message_metadata', sa.JSON(), nullable=True),
    sa.Column('context_documents', sa.JSON(), nullable=True),
    sa.Column('prompt_tokens', sa.Integer(), nullable=True),
    sa.Column('completion_tokens', sa.Integer(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
    )
    op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
    # RBAC roles; name and code are uniquely indexed.
    op.create_table('roles',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('code', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_system', sa.Boolean(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
    )
    op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    # Synced metadata about external tables used for table Q&A.
    op.create_table('table_metadata',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('table_name', sa.String(length=100), nullable=False),
    sa.Column('table_schema', sa.String(length=50), nullable=False),
    sa.Column('table_type', sa.String(length=20), nullable=False),
    sa.Column('table_comment', sa.Text(), nullable=True),
    sa.Column('database_config_id', sa.Integer(), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('primary_keys', sa.JSON(), nullable=True),
    sa.Column('foreign_keys', sa.JSON(), nullable=True),
    sa.Column('indexes', sa.JSON(), nullable=True),
    sa.Column('sample_data', sa.JSON(), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
    sa.Column('qa_description', sa.Text(), nullable=True),
    sa.Column('business_context', sa.Text(), nullable=True),
    sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
    )
    op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
    op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
    # Application users; username and email are uniquely indexed.
    op.create_table('users',
    sa.Column('username', sa.String(length=50), nullable=False),
    sa.Column('email', sa.String(length=100), nullable=False),
    sa.Column('hashed_password', sa.String(length=255), nullable=False),
    sa.Column('full_name', sa.String(length=100), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('avatar_url', sa.String(length=255), nullable=True),
    sa.Column('bio', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # users <-> roles many-to-many link table (composite primary key).
    op.create_table('user_roles',
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('role_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
    sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
    )
    # Workflow definitions, owned by a user.
    op.create_table('workflows',
    sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
    sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
    sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
    sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
    sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
    sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
    )
    op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
    # One row per workflow run. NOTE(review): started_at / completed_at are
    # String(50) rather than DateTime — confirm this is intentional.
    op.create_table('workflow_executions',
    sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
    sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
    )
    op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
    # Per-node execution records within a workflow run; reuses the
    # 'executionstatus' enum declared for workflow_executions above.
    op.create_table('node_executions',
    sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
    sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
    sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
    sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
    )
    op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Tables are dropped in reverse dependency order (children before the
    # tables their foreign keys reference). Each entry lists the indexes
    # that must be removed before its table is dropped.
    tables_with_indexes = [
        ('node_executions', ['ix_node_executions_id']),
        ('workflow_executions', ['ix_workflow_executions_id']),
        ('workflows', ['ix_workflows_id']),
        ('user_roles', []),
        ('users', ['ix_users_username', 'ix_users_id', 'ix_users_email']),
        ('table_metadata', ['ix_table_metadata_table_name', 'ix_table_metadata_id']),
        ('roles', ['ix_roles_name', 'ix_roles_id', 'ix_roles_code']),
        ('messages', ['ix_messages_id']),
        ('llm_configs', ['ix_llm_configs_provider', 'ix_llm_configs_name', 'ix_llm_configs_id']),
        ('knowledge_bases', ['ix_knowledge_bases_name', 'ix_knowledge_bases_id']),
        ('excel_files', ['ix_excel_files_id']),
        ('documents', ['ix_documents_id']),
        ('database_configs', ['ix_database_configs_id']),
        ('conversations', ['ix_conversations_id']),
        ('agent_configs', ['ix_agent_configs_name', 'ix_agent_configs_id']),
    ]
    for table_name, index_names in tables_with_indexes:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_table(table_name)
    # ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
import re
|
||||
import traceback
|
||||
from loguru import logger
|
||||
|
||||
class DrGraphSession:
    """Per-session logging helper.

    Parses an incoming step message, extracts an optional ``;-N`` stack-level
    marker from it, then dispatches to the matching severity logger based on
    keywords found in the message. All log lines carry the session id and the
    caller's source position.
    """

    def __init__(self, stepIndex: int, msg: str, session_id: str):
        logger.info(f"DrGraphSession.__init__: stepIndex={stepIndex}, msg={msg}, session_id={session_id}")
        self.stepIndex = stepIndex
        self.session_id = session_id

        # BUG FIX: the original body read an undefined name `value`, raising
        # NameError on every construction; `value` is the message text.
        value = msg

        # An optional ";-N" suffix in the message adjusts the stack level used
        # to report the caller's source position; the marker is stripped from
        # the message before dispatch.
        match = re.search(r";(-\d+)", value)
        level = -3
        if match:
            level = int(match.group(1))
            value = value.replace(f";{level}", "")
            level = -3 + level

        # Dispatch on message keywords (Chinese or English prefixes).
        if "警告" in value or value.startswith("WARNING"):
            self.log_warning(f"第 {self.stepIndex} 步 - {value}", level=level)
        elif "异常" in value or value.startswith("EXCEPTION"):
            self.log_exception(f"第 {self.stepIndex} 步 - {value}", level=level)
        elif "成功" in value or value.startswith("SUCCESS"):
            self.log_success(f"第 {self.stepIndex} 步 - {value}", level=level)
        elif "开始" in value or value.startswith("START"):
            # NOTE(review): "START" messages are routed to log_success, same as
            # in the original dispatch — confirm this is intentional.
            self.log_success(f"第 {self.stepIndex} 步 - {value}", level=level)
        elif "失败" in value or value.startswith("ERROR"):
            self.log_error(f"第 {self.stepIndex} 步 - {value}", level=level)
        else:
            self.log_info(f"第 {self.stepIndex} 步 - {value}", level=level)

    def log_prefix(self) -> str:
        """Get log prefix with session ID and desc."""
        return f"〖Session{self.session_id}〗"

    def parse_source_pos(self, level: int) -> str:
        """Return a compact "file:line in func" string for the stack frame
        *level* frames up (negative index into the formatted stack)."""
        pos = (traceback.format_stack())[level - 1].strip().split('\n')[0]
        match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos)
        if match:
            # NOTE(review): machine-specific path prefix stripped for brevity —
            # this only applies on the original developer's Windows checkout.
            file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    def log_info(self, msg: str, level: int = -2):
        """Log info message with session ID."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log success message with session ID."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log warning message with session ID."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log error message with session ID."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log exception message with session ID."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
|
@ -0,0 +1,9 @@
|
|||
import jwt
import inspect

# Diagnostic script: report which `jwt` module is actually installed (PyJWT
# vs. the conflicting `jwt` package) by printing its file location, its
# public attributes, and — when present — its version string.
module_path = inspect.getfile(jwt)
print(f"jwt module path: {module_path}")
print(f"jwt module attributes: {dir(jwt)}")
try:
    version = jwt.__version__
except AttributeError:
    print("jwt module has no __version__ attribute")
else:
    print(f"jwt module __version__: {version}")
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
services:
|
||||
db:
|
||||
image: pgvector/pgvector:pg16
|
||||
container_name: pgvector-db
|
||||
environment:
|
||||
POSTGRES_USER: drgraph
|
||||
      POSTGRES_PASSWORD: yingping  # NOTE(review): hard-coded dev credential — load from .env/secret for production
|
||||
POSTGRES_DB: th_agenter
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
# docker exec -it pgvector-db psql -U drgraph -d th_agenter
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
# uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
from fastapi import FastAPI
|
||||
from os.path import dirname, realpath
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from utils.util_log import init_logger
|
||||
from loguru import logger
|
||||
base_dir: str = dirname(realpath(__file__))
|
||||
init_logger(base_dir)
|
||||
|
||||
from th_agenter.api.routes import router
|
||||
from contextlib import asynccontextmanager
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI app (ASGI lifespan protocol)."""
    # --- startup phase ---
    logger.info("[生命周期] - Starting up TH Agenter application...")
    # Hand control to the running application until it shuts down.
    yield
    # --- shutdown phase ---
    logger.info("[生命周期] - Shutting down TH Agenter application...")
|
||||
|
||||
def setup_exception_handlers(app: FastAPI) -> None:
    """Setup global exception handlers.

    Registers handlers for:
      * ChatAgentException      — project errors, delegated to their own handler
      * StarletteHTTPException  — logged, then wrapped in HxfErrorResponse
      * RequestValidationError  — details sanitized into JSON-serializable form
      * Exception               — catch-all returning a generic 500 payload
    """
    # Import custom exceptions and handlers
    from utils.util_exceptions import ChatAgentException, chat_agent_exception_handler

    @app.exception_handler(ChatAgentException)
    async def custom_chat_agent_exception_handler(request, exc):
        return await chat_agent_exception_handler(request, exc)

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        from utils.util_exceptions import HxfErrorResponse
        logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
        return HxfErrorResponse(exc.status_code, exc.detail)

    def make_json_serializable(obj):
        """Recursively convert *obj* into JSON-serializable primitives."""
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        elif isinstance(obj, bytes):
            return obj.decode('utf-8')
        elif isinstance(obj, (ValueError, Exception)):
            # Exceptions embedded in validation details become their message.
            return str(obj)
        elif isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        else:
            # For any other object, convert to string
            return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Convert any non-serializable objects to strings in error details
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Fallback: if even our conversion fails, use a simple error message
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]
        # BUG FIX: the original logged the same errors twice in a row
        # (two consecutive logger.exception calls); keep a single entry.
        logger.exception(f"Request Validation Error: {errors}")
        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        # BUG FIX: loguru's logger does not honor the stdlib-logging
        # `exc_info=True` kwarg, so the original never logged the traceback;
        # attach the exception explicitly via logger.opt().
        logger.opt(exception=exc).error(f"Unhandled exception: {exc}")
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )
|
||||
|
||||
def create_app() -> FastAPI:
    """Build and configure the FastAPI application instance."""
    from th_agenter.core.config import get_settings
    settings = get_settings()

    # Application object with metadata sourced from settings.
    application = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由TH Agenter修改",
        debug=settings.debug,
        lifespan=lifespan,
    )

    # Serve static assets from ./static.
    application.mount("/static", StaticFiles(directory="static"), name="th_agenter_static")

    # Middleware, then global exception handlers, then routes.
    from th_agenter.core.app import setup_middleware
    setup_middleware(application, settings)
    setup_exception_handlers(application)
    add_router(application)

    return application
|
||||
|
||||
def add_router(app: FastAPI) -> None:
    """Attach the default routes to *app*."""

    @app.get("/")
    def read_root():
        # Minimal liveness/demo endpoint at the application root.
        logger.info("Hello World")
        return {"Hello": "World"}

    # All API endpoints are served under the /api prefix.
    app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
|
||||
# app.include_router(table_metadata.router)
|
||||
# # 在现有导入中添加
|
||||
# from ..api.endpoints import database_config
|
||||
|
||||
# # 在路由注册部分添加
|
||||
# app.include_router(database_config.router)
|
||||
# # Health check endpoint
|
||||
# @app.get("/health")
|
||||
# async def health_check():
|
||||
# return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# # Root endpoint
|
||||
# @app.get("/")
|
||||
# async def root():
|
||||
# return {"message": "Chat Agent API is running"}
|
||||
|
||||
# # Test endpoint
|
||||
# @app.get("/test")
|
||||
# async def test_endpoint():
|
||||
# return {"message": "API is working"}
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
# from utils.util_test import test_db
|
||||
# test_db()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.9 KiB |
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1 @@
|
|||
"""API module for TH Agenter."""
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""API endpoints for TH Agenter."""
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
"""Authentication endpoints."""
|
||||
|
||||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...core.config import get_settings
|
||||
from ...db.database import DrSession, get_session
|
||||
from ...services.auth import AuthService
|
||||
from ...services.user import UserService
|
||||
from ...schemas.user import UserResponse, UserCreate, LoginResponse
|
||||
from utils.util_schemas import Token, LoginRequest
|
||||
from loguru import logger
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter()
|
||||
settings = get_settings()
|
||||
|
||||
@router.post("/register", response_model=UserResponse, summary="注册新用户")
|
||||
async def register(
|
||||
request_user_data: UserCreate,
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""注册新用户"""
|
||||
user_service = UserService(session)
|
||||
session.desc = f"START: 注册用户 {request_user_data.email}"
|
||||
if await user_service.get_user_by_email(request_user_data.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"邮箱 {request_user_data.email} 已被注册,请使用其他邮箱注册!!!"
|
||||
)
|
||||
|
||||
if await user_service.get_user_by_username(request_user_data.username):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"用户名 {request_user_data.username} 已被注册,请使用其他用户名注册!!!"
|
||||
)
|
||||
|
||||
user = await user_service.create_user(request_user_data)
|
||||
response = UserResponse.model_validate(user, from_attributes=True)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/login", response_model=LoginResponse, summary="邮箱与密码登录")
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""邮箱与密码登录"""
|
||||
# Authenticate user by email
|
||||
session.desc = f"START: 用户 {login_data.email} 尝试登录"
|
||||
user = await AuthService.authenticate_user_by_email(session, login_data.email, login_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"邮箱 {login_data.email} 或密码错误,请检查后重试!!!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
session.desc = f"用户 {user.username} 登录成功"
|
||||
|
||||
response = LoginResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.security.access_token_expire_minutes * 60,
|
||||
user=UserResponse.model_validate(user, from_attributes=True)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
@router.post("/login-oauth", response_model=Token, summary="用户通过用户名和密码登录 (OAuth2 兼容)")
|
||||
async def login_oauth(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""用户通过用户名和密码登录 (OAuth2 兼容)"""
|
||||
session.desc = f"START: 用户 {form_data.username} 尝试 OAuth2 登录"
|
||||
user = await AuthService.authenticate_user(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
session.desc = f"用户 {form_data.username} 尝试 OAuth2 登录失败"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
session.desc = f"用户 {user.username} OAuth2 登录成功"
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
|
||||
@router.post("/refresh", response_model=Token, summary="刷新访问token")
|
||||
async def refresh_token(
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""刷新访问 token"""
|
||||
# Create new access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": current_user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.security.access_token_expire_minutes * 60
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
|
||||
async def get_current_user_info(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
return UserResponse.model_validate(current_user, from_attributes=True)
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""Chat endpoints for TH Agenter."""
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...services.auth import AuthService
|
||||
from ...services.chat import ChatService
|
||||
from ...services.conversation import ConversationService
|
||||
from utils.util_schemas import (
|
||||
ConversationCreate,
|
||||
ConversationResponse,
|
||||
ConversationUpdate,
|
||||
MessageCreate,
|
||||
MessageResponse,
|
||||
ChatRequest,
|
||||
ChatResponse
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Conversation management
|
||||
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
|
||||
async def create_conversation(
|
||||
conversation_data: ConversationCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""创建新对话"""
|
||||
session.desc = "START: 创建新对话"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.create_conversation(
|
||||
user_id=current_user.id,
|
||||
conversation_data=conversation_data
|
||||
)
|
||||
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
|
||||
async def list_conversations(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
order_by: str = "updated_at",
|
||||
order_desc: bool = True,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话列表"""
|
||||
session.desc = "START: 获取用户对话列表"
|
||||
conversation_service = ConversationService(session)
|
||||
conversations = await conversation_service.get_user_conversations(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
search_query=search,
|
||||
include_archived=include_archived,
|
||||
order_by=order_by,
|
||||
order_desc=order_desc
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话"
|
||||
return [ConversationResponse.model_validate(conv) for conv in conversations]
|
||||
|
||||
@router.get("/conversations/count", summary="获取用户对话总数")
|
||||
async def get_conversations_count(
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取用户对话总数"""
|
||||
session.desc = "START: 获取用户对话总数"
|
||||
conversation_service = ConversationService(session)
|
||||
count = await conversation_service.get_user_conversations_count(
|
||||
search_query=search,
|
||||
include_archived=include_archived
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
|
||||
return {"count": count}
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
|
||||
async def get_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取指定对话"""
|
||||
session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.get_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Conversation not found"
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
|
||||
async def update_conversation(
|
||||
conversation_id: int,
|
||||
conversation_update: ConversationUpdate,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""更新指定对话"""
|
||||
session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}"
|
||||
conversation_service = ConversationService(session)
|
||||
updated_conversation = await conversation_service.update_conversation(
|
||||
conversation_id, conversation_update
|
||||
)
|
||||
session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(updated_conversation)
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
|
||||
async def delete_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""删除指定对话"""
|
||||
session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
await conversation_service.delete_conversation(conversation_id)
|
||||
session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation deleted successfully"}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
|
||||
async def archive_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""归档指定对话."""
|
||||
conversation_service = ConversationService(session)
|
||||
success = await conversation_service.archive_conversation(conversation_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to archive conversation"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation archived successfully"}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
|
||||
async def unarchive_conversation(
|
||||
conversation_id: int,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""取消归档指定对话."""
|
||||
session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
success = await conversation_service.unarchive_conversation(conversation_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to unarchive conversation"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation unarchived successfully"}
|
||||
|
||||
|
||||
# Message management
|
||||
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息")
|
||||
async def get_conversation_messages(
|
||||
conversation_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取指定对话的消息"""
|
||||
session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
|
||||
conversation_service = ConversationService(session)
|
||||
messages = await conversation_service.get_conversation_messages(
|
||||
conversation_id, skip=skip, limit=limit
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
|
||||
return [MessageResponse.model_validate(msg) for msg in messages]
|
||||
|
||||
# Chat functionality
|
||||
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
|
||||
async def chat(
|
||||
conversation_id: int,
|
||||
chat_request: ChatRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""发送消息并获取AI响应"""
|
||||
session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
|
||||
chat_service = ChatService(session)
|
||||
response = await chat_service.chat(
|
||||
conversation_id=conversation_id,
|
||||
message=chat_request.message,
|
||||
stream=False,
|
||||
temperature=chat_request.temperature,
|
||||
max_tokens=chat_request.max_tokens,
|
||||
use_agent=chat_request.use_agent,
|
||||
use_langgraph=chat_request.use_langgraph,
|
||||
use_knowledge_base=chat_request.use_knowledge_base,
|
||||
knowledge_base_id=chat_request.knowledge_base_id
|
||||
)
|
||||
session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
|
||||
|
||||
return response
|
||||
|
||||
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
|
||||
async def chat_stream(
|
||||
conversation_id: int,
|
||||
chat_request: ChatRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""发送消息并获取流式AI响应."""
|
||||
chat_service = ChatService(session)
|
||||
|
||||
async def generate_response():
|
||||
async for chunk in chat_service.chat_stream(
|
||||
conversation_id=conversation_id,
|
||||
message=chat_request.message,
|
||||
temperature=chat_request.temperature,
|
||||
max_tokens=chat_request.max_tokens,
|
||||
use_agent=chat_request.use_agent,
|
||||
use_langgraph=chat_request.use_langgraph,
|
||||
use_knowledge_base=chat_request.use_knowledge_base,
|
||||
knowledge_base_id=chat_request.knowledge_base_id
|
||||
):
|
||||
yield f"data: {chunk}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_response(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
"""数据库配置管理API"""
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from th_agenter.models.user import User
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.database_config_service import DatabaseConfigService
|
||||
from th_agenter.services.auth import AuthService
|
||||
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||
|
||||
# 在文件顶部添加
|
||||
from functools import lru_cache
|
||||
|
||||
router = APIRouter(prefix="/api/database-config", tags=["database-config"])
|
||||
# 创建服务单例
|
||||
# 创建服务单例
@lru_cache()
def get_database_config_service() -> DatabaseConfigService:
    """获取DatabaseConfigService单例

    NOTE(review): acknowledged temporary hack — the cached singleton is
    constructed with a None db session and would fail on any call that
    touches the database; get_database_service below is the variant that
    actually carries a session.
    """
    # 注意:这里需要处理db session的问题
    return DatabaseConfigService(None)  # 临时方案
|
||||
|
||||
# 或者使用全局变量
|
||||
_database_service_instance = None
|
||||
|
||||
def get_database_service(session: Session = Depends(get_session)) -> DatabaseConfigService:
    """获取DatabaseConfigService实例

    Returns a module-level singleton whose ``db`` attribute is swapped to the
    current request's session on every call.

    NOTE(review): mutating a shared singleton's session per request is not
    safe under concurrent requests — two in-flight requests can observe each
    other's session. Consider constructing the service per request instead.
    """
    global _database_service_instance
    if _database_service_instance is None:
        _database_service_instance = DatabaseConfigService(session)
    else:
        # 更新db session
        _database_service_instance.db = session
    return _database_service_instance
|
||||
|
||||
class DatabaseConfigCreate(BaseModel):
    """Request payload for creating or updating a database connection config."""
    name: str = Field(..., description="配置名称")
    db_type: str = Field(default="postgresql", description="数据库类型")
    host: str = Field(..., description="主机地址")
    port: int = Field(..., description="端口号")
    database: str = Field(..., description="数据库名")
    username: str = Field(..., description="用户名")
    password: str = Field(..., description="密码")
    is_default: bool = Field(default=False, description="是否为默认配置")
    connection_params: Dict[str, Any] = Field(default=None, description="额外连接参数")
|
||||
|
||||
class DatabaseConfigResponse(BaseModel):
    """Response payload describing a stored database connection config.

    NOTE(review): the plaintext password is exposed in responses (see
    to_dict(include_password=True) at the call sites) — confirm this is
    intended for this internal tool.
    """
    id: int
    name: str
    db_type: str
    host: str
    port: int
    database: str
    username: str
    password: str
    is_active: bool
    is_default: bool
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
@router.post("/", response_model=NormalResponse, summary="创建或更新数据库配置")
|
||||
async def create_database_config(
|
||||
config_data: DatabaseConfigCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""创建或更新数据库配置"""
|
||||
config = await service.create_or_update_config(current_user.id, config_data.model_dump())
|
||||
return NormalResponse(
|
||||
success=True,
|
||||
message="保存数据库配置成功",
|
||||
data=config
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表")
|
||||
async def get_database_configs(
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""获取用户的数据库配置列表"""
|
||||
configs = service.get_user_configs(current_user.id)
|
||||
|
||||
config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
|
||||
return config_list
|
||||
|
||||
@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接")
|
||||
async def test_database_connection(
|
||||
config_id: int,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""测试数据库连接"""
|
||||
result = await service.test_connection(config_id, current_user.id)
|
||||
return result
|
||||
|
||||
@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表")
|
||||
async def connect_database(
|
||||
config_id: int,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""连接数据库并获取表列表"""
|
||||
result = await service.connect_and_get_tables(config_id, current_user.id)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/tables/{table_name}/data", summary="获取表数据预览")
|
||||
async def get_table_data(
|
||||
table_name: str,
|
||||
db_type: str,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""获取表数据预览"""
|
||||
try:
|
||||
result = await service.get_table_data(table_name, current_user.id, db_type, limit)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取表数据失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
@router.get("/tables/{table_name}/schema", summary="获取表结构信息")
|
||||
async def get_table_schema(
|
||||
table_name: str,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""获取表结构信息"""
|
||||
result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的?
|
||||
return result
|
||||
|
||||
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置")
|
||||
async def get_config_by_type(
|
||||
db_type: str,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""根据数据库类型获取配置"""
|
||||
config = service.get_config_by_type(current_user.id, db_type)
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到类型为 {db_type} 的配置"
|
||||
)
|
||||
# 返回包含解密密码的配置
|
||||
return config.to_dict(include_password=True, decrypt_service=service)
|
||||
|
|
@ -0,0 +1,599 @@
|
|||
"""Knowledge base API endpoints."""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.knowledge_base import KnowledgeBase, Document
|
||||
from ...services.knowledge_base import KnowledgeBaseService
|
||||
from ...services.document import DocumentService
|
||||
from ...services.auth import AuthService
|
||||
from utils.util_schemas import (
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeBaseResponse,
|
||||
DocumentResponse,
|
||||
DocumentListResponse,
|
||||
DocumentUpload,
|
||||
DocumentProcessingStatus,
|
||||
DocumentChunksResponse,
|
||||
ErrorResponse
|
||||
)
|
||||
from utils.util_file import FileUtils
|
||||
from ...core.config import settings
|
||||
|
||||
router = APIRouter(tags=["knowledge-bases"])
|
||||
|
||||
@router.post("/", response_model=KnowledgeBaseResponse, summary="创建新的知识库")
|
||||
async def create_knowledge_base(
|
||||
kb_data: KnowledgeBaseCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""创建新的知识库"""
|
||||
# Check if knowledge base with same name already exists for this user
|
||||
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data.name}"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
|
||||
existing_kb = service.get_knowledge_base_by_name(kb_data.name)
|
||||
if existing_kb:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"知识库名称 {kb_data.name} 已存在"
|
||||
)
|
||||
|
||||
# Create knowledge base
|
||||
session.desc = f"知识库 {kb_data.name}不存在,创建之"
|
||||
kb = service.create_knowledge_base(kb_data)
|
||||
|
||||
session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
|
||||
return KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
name=kb.name,
|
||||
description=kb.description,
|
||||
embedding_model=kb.embedding_model,
|
||||
chunk_size=kb.chunk_size,
|
||||
chunk_overlap=kb.chunk_overlap,
|
||||
is_active=kb.is_active,
|
||||
vector_db_type=kb.vector_db_type,
|
||||
collection_name=kb.collection_name,
|
||||
document_count=0,
|
||||
active_document_count=0
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
|
||||
async def list_knowledge_bases(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search: Optional[str] = None,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取当前用户的所有知识库"""
|
||||
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
|
||||
knowledge_bases = await service.get_knowledge_bases(skip=skip, limit=limit)
|
||||
|
||||
result = []
|
||||
for kb in knowledge_bases:
|
||||
# Count documents
|
||||
total_docs = await session.scalar(
|
||||
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||
)
|
||||
|
||||
active_docs = await session.scalar(
|
||||
select(func.count()).where(
|
||||
Document.knowledge_base_id == kb.id,
|
||||
Document.is_processed == True
|
||||
)
|
||||
)
|
||||
|
||||
result.append(KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
name=kb.name,
|
||||
description=kb.description,
|
||||
embedding_model=kb.embedding_model,
|
||||
chunk_size=kb.chunk_size,
|
||||
chunk_overlap=kb.chunk_overlap,
|
||||
is_active=kb.is_active,
|
||||
vector_db_type=kb.vector_db_type,
|
||||
collection_name=kb.collection_name,
|
||||
document_count=total_docs,
|
||||
active_document_count=active_docs
|
||||
))
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库"
|
||||
return result
|
||||
|
||||
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
|
||||
async def get_knowledge_base(
|
||||
kb_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""根据知识库ID获取知识库详情"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 的详情"
|
||||
service = KnowledgeBaseService(session)
|
||||
session.desc = f"检查知识库 {kb_id} 是否存在"
|
||||
kb = service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Count documents
|
||||
total_docs = await session.scalar(
|
||||
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||
)
|
||||
session.desc = f"获取知识库 {kb_id} 共 {total_docs} 个文档"
|
||||
|
||||
active_docs = await session.scalar(
|
||||
select(func.count()).where(
|
||||
Document.knowledge_base_id == kb.id,
|
||||
Document.is_processed == True
|
||||
)
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||
return KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
name=kb.name,
|
||||
description=kb.description,
|
||||
embedding_model=kb.embedding_model,
|
||||
chunk_size=kb.chunk_size,
|
||||
chunk_overlap=kb.chunk_overlap,
|
||||
is_active=kb.is_active,
|
||||
vector_db_type=kb.vector_db_type,
|
||||
collection_name=kb.collection_name,
|
||||
document_count=total_docs,
|
||||
active_document_count=active_docs
|
||||
)
|
||||
|
||||
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
|
||||
async def update_knowledge_base(
|
||||
kb_id: int,
|
||||
kb_data: KnowledgeBaseCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""更新知识库"""
|
||||
session.desc = f"START: 更新知识库 {kb_id}"
|
||||
service = KnowledgeBaseService(session)
|
||||
kb = service.update_knowledge_base(kb_id, kb_data)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Count documents
|
||||
total_docs = await session.scalar(
|
||||
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||
)
|
||||
|
||||
active_docs = await session.scalar(
|
||||
select(func.count()).where(
|
||||
Document.knowledge_base_id == kb.id,
|
||||
Document.is_processed == True
|
||||
)
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 更新知识库 {kb_id},共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||
return KnowledgeBaseResponse(
|
||||
id=kb.id,
|
||||
created_at=kb.created_at,
|
||||
updated_at=kb.updated_at,
|
||||
name=kb.name,
|
||||
description=kb.description,
|
||||
embedding_model=kb.embedding_model,
|
||||
chunk_size=kb.chunk_size,
|
||||
chunk_overlap=kb.chunk_overlap,
|
||||
is_active=kb.is_active,
|
||||
vector_db_type=kb.vector_db_type,
|
||||
collection_name=kb.collection_name,
|
||||
document_count=total_docs,
|
||||
active_document_count=active_docs
|
||||
)
|
||||
|
||||
@router.delete("/{kb_id}", summary="删除知识库")
|
||||
async def delete_knowledge_base(
|
||||
kb_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除知识库"""
|
||||
session.desc = f"START: 删除知识库 {kb_id}"
|
||||
service = KnowledgeBaseService(session)
|
||||
success = service.delete_knowledge_base(kb_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 删除知识库 {kb_id}"
|
||||
return {"message": "Knowledge base deleted successfully"}
|
||||
|
||||
# Document management endpoints
|
||||
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
|
||||
async def upload_document(
|
||||
kb_id: int,
|
||||
file: UploadFile = File(...),
|
||||
process_immediately: bool = Form(True),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""上传文档到知识库"""
|
||||
session.desc = f"START: 上传文档到知识库 {kb_id}"
|
||||
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Validate file
|
||||
if not FileUtils.validate_file_extension(file.filename):
|
||||
session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
# Check file size (50MB limit)
|
||||
max_size = 50 * 1024 * 1024 # 50MB
|
||||
if file.size and file.size > max_size:
|
||||
session.desc = f"ERROR: 文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
|
||||
)
|
||||
|
||||
# Upload document
|
||||
doc_service = DocumentService(session)
|
||||
document = await doc_service.upload_document(
|
||||
file, kb_id
|
||||
)
|
||||
|
||||
# Process document immediately if requested
|
||||
if process_immediately:
|
||||
try:
|
||||
await doc_service.process_document(document.id, kb_id)
|
||||
# Refresh document to get updated status
|
||||
await session.refresh(document)
|
||||
except Exception as e:
|
||||
session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"
|
||||
|
||||
session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"
|
||||
return DocumentResponse(
|
||||
id=document.id,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
knowledge_base_id=document.knowledge_base_id,
|
||||
filename=document.filename,
|
||||
original_filename=document.original_filename,
|
||||
file_path=document.file_path,
|
||||
file_type=document.file_type,
|
||||
file_size=document.file_size,
|
||||
mime_type=document.mime_type,
|
||||
is_processed=document.is_processed,
|
||||
processing_error=document.processing_error,
|
||||
chunk_count=document.chunk_count or 0,
|
||||
embedding_model=document.embedding_model,
|
||||
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
|
||||
async def list_documents(
|
||||
kb_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档列表。"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
documents, total = doc_service.list_documents(kb_id, skip, limit)
|
||||
|
||||
doc_responses = []
|
||||
for doc in documents:
|
||||
doc_responses.append(DocumentResponse(
|
||||
id=doc.id,
|
||||
created_at=doc.created_at,
|
||||
updated_at=doc.updated_at,
|
||||
knowledge_base_id=doc.knowledge_base_id,
|
||||
filename=doc.filename,
|
||||
original_filename=doc.original_filename,
|
||||
file_path=doc.file_path,
|
||||
file_type=doc.file_type,
|
||||
file_size=doc.file_size,
|
||||
mime_type=doc.mime_type,
|
||||
is_processed=doc.is_processed,
|
||||
processing_error=doc.processing_error,
|
||||
chunk_count=doc.chunk_count or 0,
|
||||
embedding_model=doc.embedding_model,
|
||||
file_size_mb=round(doc.file_size / (1024 * 1024), 2)
|
||||
))
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total} 条"
|
||||
return DocumentListResponse(
|
||||
documents=doc_responses,
|
||||
total=total,
|
||||
page=skip // limit + 1,
|
||||
page_size=limit
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
|
||||
async def get_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档详情。"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
# Verify knowledge base exists and user has access
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||
return DocumentResponse(
|
||||
id=document.id,
|
||||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
knowledge_base_id=document.knowledge_base_id,
|
||||
filename=document.filename,
|
||||
original_filename=document.original_filename,
|
||||
file_path=document.file_path,
|
||||
file_type=document.file_type,
|
||||
file_size=document.file_size,
|
||||
mime_type=document.mime_type,
|
||||
is_processed=document.is_processed,
|
||||
processing_error=document.processing_error,
|
||||
chunk_count=document.chunk_count or 0,
|
||||
embedding_model=document.embedding_model,
|
||||
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||
)
|
||||
|
||||
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
|
||||
async def delete_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除知识库中的文档。"""
|
||||
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
success = doc_service.delete_document(doc_id, kb_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||
return {"message": "Document deleted successfully"}
|
||||
|
||||
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
|
||||
async def process_document(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""处理知识库中的文档,用于向量搜索。"""
|
||||
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Check if document exists
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Process the document
|
||||
result = await doc_service.process_document(doc_id, kb_id)
|
||||
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||
return DocumentProcessingStatus(
|
||||
document_id=doc_id,
|
||||
status=result["status"],
|
||||
progress=result.get("progress", 0.0),
|
||||
error_message=result.get("error_message"),
|
||||
chunks_created=result.get("chunks_created", 0)
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
|
||||
async def get_document_processing_status(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取知识库中的文档处理状态。"""
|
||||
# Verify knowledge base exists and user has access
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document not found"
|
||||
)
|
||||
|
||||
# Determine status
|
||||
if document.processing_error:
|
||||
status_str = "failed"
|
||||
progress = 0.0
|
||||
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
|
||||
elif document.is_processed:
|
||||
status_str = "completed"
|
||||
progress = 100.0
|
||||
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
|
||||
else:
|
||||
status_str = "pending"
|
||||
progress = 0.0
|
||||
session.desc = f"文档 {doc_id} 处理pending中"
|
||||
|
||||
return DocumentProcessingStatus(
|
||||
document_id=document.id,
|
||||
status=status_str,
|
||||
progress=progress,
|
||||
error_message=document.processing_error,
|
||||
chunks_created=document.chunk_count or 0
|
||||
)
|
||||
|
||||
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
|
||||
async def search_knowledge_base(
|
||||
kb_id: int,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""在知识库中搜索文档。"""
|
||||
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
kb = kb_service.get_knowledge_base(kb_id)
|
||||
if not kb:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Knowledge base not found"
|
||||
)
|
||||
|
||||
# Perform search
|
||||
doc_service = DocumentService(session)
|
||||
results = doc_service.search_documents(kb_id, query, limit)
|
||||
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
|
||||
return {
|
||||
"knowledge_base_id": kb_id,
|
||||
"query": query,
|
||||
"results": results,
|
||||
"total_results": len(results)
|
||||
}
|
||||
|
||||
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
|
||||
async def get_document_chunks(
|
||||
kb_id: int,
|
||||
doc_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""
|
||||
获取知识库中特定文档的所有文档块(片段)。
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
doc_id: 文档ID
|
||||
session: 数据库会话
|
||||
current_user: 当前认证用户
|
||||
|
||||
Returns:
|
||||
DocumentChunksResponse: 文档块(片段)响应模型
|
||||
"""
|
||||
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
|
||||
kb_service = KnowledgeBaseService(session)
|
||||
knowledge_base = kb_service.get_knowledge_base(kb_id)
|
||||
if not knowledge_base:
|
||||
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="知识库不存在"
|
||||
)
|
||||
|
||||
# Verify document exists in the knowledge base
|
||||
doc_service = DocumentService(session)
|
||||
document = doc_service.get_document(doc_id, kb_id)
|
||||
if not document:
|
||||
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文档不存在"
|
||||
)
|
||||
|
||||
# Get document chunks
|
||||
chunks = doc_service.get_document_chunks(doc_id)
|
||||
session.desc = f"SUCCESS: 获取文档 {doc_id} 共 {len(chunks)} 个文档块(片段)"
|
||||
return DocumentChunksResponse(
|
||||
document_id=doc_id,
|
||||
document_name=document.filename,
|
||||
total_chunks=len(chunks),
|
||||
chunks=chunks
|
||||
)
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
"""LLM configuration management API endpoints."""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_, select, delete, update
|
||||
|
||||
from loguru import logger
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.llm_config import LLMConfig
|
||||
from ...core.simple_permissions import require_super_admin, require_authenticated_user
|
||||
from ...schemas.llm_config import (
|
||||
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
|
||||
LLMConfigTest
|
||||
)
|
||||
from th_agenter.services.document_processor import get_document_processor
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
|
||||
|
||||
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
|
||||
async def get_llm_configs(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
search: Optional[str] = Query(None),
|
||||
provider: Optional[str] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
is_embedding: Optional[bool] = Query(None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取大模型配置列表."""
|
||||
session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig)
|
||||
|
||||
# 搜索
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
LLMConfig.name.ilike(f"%{search}%"),
|
||||
LLMConfig.model_name.ilike(f"%{search}%"),
|
||||
LLMConfig.description.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
# 服务商筛选
|
||||
if provider:
|
||||
stmt = stmt.where(LLMConfig.provider == provider)
|
||||
|
||||
# 状态筛选
|
||||
if is_active is not None:
|
||||
stmt = stmt.where(LLMConfig.is_active == is_active)
|
||||
|
||||
# 模型类型筛选
|
||||
if is_embedding is not None:
|
||||
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
# 排序
|
||||
stmt = stmt.order_by(LLMConfig.name)
|
||||
|
||||
# 分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
configs = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置"
|
||||
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||
|
||||
|
||||
@router.get("/providers", summary="获取支持的大模型服务商列表")
|
||||
async def get_llm_providers(
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取支持的大模型服务商列表."""
|
||||
session.desc = "START: 获取支持的大模型服务商列表"
|
||||
stmt = select(LLMConfig.provider).distinct()
|
||||
providers = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(providers)} 个大模型服务商"
|
||||
return HxfResponse([provider for provider in providers if provider])
|
||||
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[LLMConfigResponse], summary="获取所有激活的大模型配置")
|
||||
async def get_active_llm_configs(
|
||||
is_embedding: Optional[bool] = Query(None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取所有激活的大模型配置."""
|
||||
session.desc = f"START: 获取所有激活的大模型配置, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.is_active == True)
|
||||
|
||||
if is_embedding is not None:
|
||||
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
stmt = stmt.order_by(LLMConfig.created_at)
|
||||
configs = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个激活的大模型配置"
|
||||
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||
|
||||
@router.get("/default", response_model=LLMConfigResponse, summary="获取默认大模型配置")
|
||||
async def get_default_llm_config(
|
||||
is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取默认大模型配置."""
|
||||
session.desc = f"START: 获取默认大模型配置, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig).where(
|
||||
LLMConfig.is_default == True,
|
||||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.is_active == True
|
||||
)
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
model_type = "嵌入模型" if is_embedding else "对话模型"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到默认{model_type}配置"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取默认大模型配置, is_embedding={is_embedding}"
|
||||
return HxfResponse(config.to_dict(include_sensitive=True))
|
||||
|
||||
@router.get("/{config_id}", response_model=LLMConfigResponse, summary="获取大模型配置详情")
|
||||
async def get_llm_config(
|
||||
config_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取大模型配置详情."""
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
return HxfResponse(config.to_dict(include_sensitive=True))
|
||||
|
||||
|
||||
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
|
||||
async def create_llm_config(
|
||||
config_data: LLMConfigCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""创建大模型配置."""
|
||||
# 检查配置名称是否已存在
|
||||
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
|
||||
existing_config = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 创建临时配置对象进行验证
|
||||
temp_config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = temp_config.validate_config()
|
||||
if not validation_result['valid']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=validation_result['error']
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
if config_data.is_default:
|
||||
stmt = update(LLMConfig).where(
|
||||
LLMConfig.is_embedding == config_data.is_embedding
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
|
||||
# 创建配置
|
||||
config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
session.add(config)
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}"
|
||||
return HxfResponse(config.to_dict())
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=LLMConfigResponse, summary="更新大模型配置")
async def update_llm_config(
    config_id: int,
    config_data: LLMConfigUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Update an existing LLM configuration (super-admin only)."""
    session.desc = f"START: 更新大模型配置, id={config_id}"

    # Load the target config or fail with 404.
    config = session.execute(
        select(LLMConfig).where(LLMConfig.id == config_id)
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Reject a rename that would collide with another config's name.
    new_name = config_data.name
    if new_name and new_name != config.name:
        duplicate = session.execute(
            select(LLMConfig).where(
                LLMConfig.name == new_name,
                LLMConfig.id != config_id
            )
        ).scalar_one_or_none()
        if duplicate is not None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="配置名称已存在"
            )

    # Becoming the default demotes every other default of the same kind
    # (chat vs. embedding).
    if config_data.is_default is True:
        # Use the incoming is_embedding when provided, else the stored one.
        target_kind = (
            config_data.is_embedding
            if config_data.is_embedding is not None
            else config.is_embedding
        )
        session.execute(
            update(LLMConfig)
            .where(
                LLMConfig.is_embedding == target_kind,
                LLMConfig.id != config_id
            )
            .values({"is_default": False})
        )

    # Apply only the fields the caller actually sent.
    for attr, new_value in config_data.model_dump(exclude_unset=True).items():
        setattr(config, attr, new_value)

    session.commit()
    session.refresh(config)

    session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}"
    return HxfResponse(config.to_dict())
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
async def delete_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete an LLM configuration (super-admin only)."""
    session.desc = f"START: 删除大模型配置, id={config_id}"

    config = session.execute(
        select(LLMConfig).where(LLMConfig.id == config_id)
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # TODO: check whether any conversation or other feature still uses this
    # config before allowing deletion.

    session.delete(config)
    session.commit()

    session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}"
    # NOTE(review): the route declares 204 No Content but returns a body —
    # confirm whether the status code or the response payload is intended.
    return HxfResponse({"message": "LLM config deleted successfully"})
|
||||
|
||||
@router.post("/{config_id}/test", summary="测试连接大模型配置")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Run a connectivity test against an LLM configuration.

    The stored configuration is statically validated first; the actual
    provider round-trip is still a TODO, so a successful run currently
    returns a simulated response.
    """
    session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = session.execute(stmt).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Static validation of the stored configuration.
    validation_result = config.validate_config()
    if not validation_result["valid"]:
        # NOTE(review): this branch returns a bare dict while every other
        # branch wraps the payload in HxfResponse — confirm intended.
        return {
            "success": False,
            "message": f"配置验证失败: {validation_result['error']}",
            "details": validation_result
        }

    # BUGFIX: resolve the test prompt BEFORE entering the try block. It was
    # previously assigned inside the try, so an early failure made the
    # except-branch reference an unbound name (NameError) while handling
    # the original error.
    test_message = test_data.message or "Hello, this is a test message."

    try:
        # TODO: implement the real per-provider test request, e.g.:
        # client = config.get_client()
        # response = client.chat.completions.create(
        #     model=config.model_name,
        #     messages=[{"role": "user", "content": test_message}],
        #     max_tokens=100
        # )

        # Simulated success until the real client call is implemented.
        session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}"

        return HxfResponse({
            "success": True,
            "message": "配置测试成功",
            "test_message": test_message,
            "response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
            "latency_ms": 150,  # simulated latency
            "config_info": config.get_client_config()
        })

    except Exception as test_error:
        session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
        return HxfResponse({
            "success": False,
            "message": f"配置测试失败: {str(test_error)}",
            "test_message": test_message,
            "config_info": config.get_client_config()
        })
|
||||
|
||||
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
async def toggle_llm_config_status(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Flip an LLM configuration between active and inactive."""
    session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}"

    config = session.execute(
        select(LLMConfig).where(LLMConfig.id == config_id)
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Invert the flag; audit fields are maintained by a SQLAlchemy event
    # listener, so no manual bookkeeping is needed here.
    config.is_active = not config.is_active

    session.commit()
    session.refresh(config)

    status_text = "激活" if config.is_active else "禁用"
    session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}"

    return HxfResponse({
        "message": f"配置已{status_text}",
        "is_active": config.is_active
    })
|
||||
|
||||
|
||||
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
async def set_default_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Mark one active LLM configuration as the default of its kind."""
    session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}"

    config = session.execute(
        select(LLMConfig).where(LLMConfig.id == config_id)
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Only an active config may become the default.
    if not config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只能将激活的配置设为默认"
        )

    # Demote every other default of the same kind (chat vs. embedding).
    session.execute(
        update(LLMConfig)
        .where(
            LLMConfig.is_embedding == config.is_embedding,
            LLMConfig.id != config_id
        )
        .values({"is_default": False})
    )

    # Promote this one.
    config.is_default = True
    config.set_audit_fields(current_user.id, is_update=True)

    session.commit()
    session.refresh(config)

    model_type = "嵌入模型" if config.is_embedding else "对话模型"
    # Re-initialise the document processor so it picks up the new default
    # embedding immediately.
    get_document_processor()._init_embeddings()
    session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}"
    return HxfResponse({
        "message": f"已将 {config.name} 设为默认{model_type}配置",
        "is_default": config.is_default
    })
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
"""Role management API endpoints."""
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, and_, or_, delete
|
||||
|
||||
from ...core.simple_permissions import require_super_admin
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.permission import Role, UserRole
|
||||
from ...services.auth import AuthService
|
||||
from ...schemas.permission import (
|
||||
RoleCreate, RoleUpdate, RoleResponse,
|
||||
UserRoleAssign
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/roles", tags=["roles"])
|
||||
|
||||
@router.get("/", response_model=List[RoleResponse], summary="获取角色列表")
async def get_roles(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user = Depends(require_super_admin),
):
    """List roles with optional fuzzy search, status filter and paging."""
    session.desc = f"START: 获取用户 {current_user.username} 角色列表"
    stmt = select(Role)

    # Fuzzy match against name, code and description.
    if search:
        pattern = f"%{search}%"
        stmt = stmt.where(
            or_(
                Role.name.ilike(pattern),
                Role.code.ilike(pattern),
                Role.description.ilike(pattern),
            )
        )

    # Optional active/inactive filter.
    if is_active is not None:
        stmt = stmt.where(Role.is_active == is_active)

    # Paging.
    stmt = stmt.offset(skip).limit(limit)
    roles = (await session.execute(stmt)).scalars().all()
    session.desc = f"SUCCESS: 用户 {current_user.username} 有 {len(roles)} 个角色"
    return [role.to_dict() for role in roles]
|
||||
|
||||
@router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情")
async def get_role(
    role_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Fetch a single role by id, or 404 if it does not exist."""
    session.desc = f"START: 获取角色 {role_id} 详情"
    role = (await session.execute(
        select(Role).where(Role.id == role_id)
    )).scalar_one_or_none()
    if role is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    return role.to_dict()
|
||||
|
||||
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
async def create_role(
    role_data: RoleCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new role.

    Rejects duplicate role codes; audit fields are stamped with the
    creating user's id.
    """
    session.desc = f"START: 创建角色 {role_data.name}"
    # Role codes must be unique.
    stmt = select(Role).where(Role.code == role_data.code)
    existing_role = (await session.execute(stmt)).scalar_one_or_none()
    if existing_role:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="角色代码已存在"
        )

    # Build the new role.
    role = Role(
        name=role_data.name,
        code=role_data.code,
        description=role_data.description,
        is_active=role_data.is_active
    )
    role.set_audit_fields(current_user.id)

    # BUGFIX: Session.add() is synchronous (even on AsyncSession) and
    # returns None — awaiting it raised "object NoneType can't be used in
    # 'await' expression".
    session.add(role)
    await session.commit()
    await session.refresh(role)

    logger.info(f"Role created: {role.name} by user {current_user.username}")
    return role.to_dict()
|
||||
|
||||
@router.put("/{role_id}", response_model=RoleResponse, summary="更新角色")
async def update_role(
    role_id: int,
    role_data: RoleUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Update a role; the built-in SUPER_ADMIN role is immutable."""
    session.desc = f"更新用户 {current_user.username} 角色 {role_id}"

    role = (await session.execute(
        select(Role).where(Role.id == role_id)
    )).scalar_one_or_none()
    if role is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    # The super-admin role may never be edited.
    if role.code == "SUPER_ADMIN":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="超级管理员角色不能被编辑"
        )

    # A code change must not collide with another role's code.
    if role_data.code and role_data.code != role.code:
        clash = (await session.execute(
            select(Role).where(
                and_(
                    Role.code == role_data.code,
                    Role.id != role_id
                )
            )
        )).scalar_one_or_none()
        if clash is not None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="角色代码已存在"
            )

    # Apply only the fields the caller sent; audit fields are handled by a
    # SQLAlchemy event listener.
    for attr, new_value in role_data.model_dump(exclude_unset=True).items():
        setattr(role, attr, new_value)

    await session.commit()
    await session.refresh(role)

    logger.info(f"Role updated: {role.name} by user {current_user.username}")
    return role.to_dict()
|
||||
|
||||
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
async def delete_role(
    role_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete a role.

    The SUPER_ADMIN role and any role still assigned to users cannot be
    deleted.
    """
    stmt = select(Role).where(Role.id == role_id)
    role = (await session.execute(stmt)).scalar_one_or_none()
    if not role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    # The super-admin role may never be removed.
    if role.code == "SUPER_ADMIN":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="超级管理员角色不能被删除"
        )

    # Block deletion while users are still linked to the role.
    # BUGFIX: ScalarResult has no .count() method (the old
    # `.scalars().count()` raised AttributeError); count the fetched rows
    # instead.
    stmt = select(UserRole).where(UserRole.role_id == role_id)
    user_count = len((await session.execute(stmt)).scalars().all())
    if user_count > 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
        )

    # Delete the role.
    await session.delete(role)
    await session.commit()

    session.desc = f"角色删除成功: {role.name} by user {current_user.username}"
    # NOTE(review): the route declares 204 No Content but returns a body —
    # confirm whether the status code or the payload is intended.
    return {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
|
||||
|
||||
# Sub-router for user-role assignment endpoints (mounted on the main router below).
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
||||
@user_role_router.post("/assign", status_code=status.HTTP_201_CREATED, summary="为用户分配角色")
async def assign_user_roles(
    assignment_data: UserRoleAssign,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Replace a user's role assignments with the given set of role ids."""
    # The target user must exist.
    stmt = select(User).where(User.id == assignment_data.user_id)
    user = (await session.execute(stmt)).scalar_one_or_none()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )

    # Every requested role must exist.
    stmt = select(Role).where(Role.id.in_(assignment_data.role_ids))
    roles = (await session.execute(stmt)).scalars().all()
    if len(roles) != len(assignment_data.role_ids):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="部分角色不存在"
        )

    # Replace-all semantics: drop existing links first.
    stmt = delete(UserRole).where(UserRole.user_id == assignment_data.user_id)
    await session.execute(stmt)

    # Insert the new assignments.
    for role_id in assignment_data.role_ids:
        user_role = UserRole(
            user_id=assignment_data.user_id,
            role_id=role_id
        )
        # BUGFIX: Session.add() is synchronous and returns None — awaiting
        # it raised a TypeError at runtime.
        session.add(user_role)

    await session.commit()

    session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}"

    return {"message": "角色分配成功"}
|
||||
|
||||
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
async def get_user_roles(
    user_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_active_user)
):
    """List the roles of a user (own roles, or any user for super-admins)."""
    # A non-superuser may only inspect their own roles.
    if current_user.id != user_id and not await current_user.is_superuser():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权限查看其他用户的角色"
        )

    user = (await session.execute(
        select(User).where(User.id == user_id)
    )).scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )

    # Roles reachable through the user-role link table.
    role_stmt = (
        select(Role)
        .join(UserRole, Role.id == UserRole.role_id)
        .where(UserRole.user_id == user_id)
    )
    roles = (await session.execute(role_stmt)).scalars().all()

    return [role.to_dict() for role in roles]
|
||||
|
||||
# Mount the user-role sub-router on the main roles router.
router.include_router(user_role_router)
|
||||
|
|
@ -0,0 +1,328 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.auth import AuthService
|
||||
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||
from th_agenter.services.conversation import ConversationService
|
||||
from th_agenter.services.conversation_context import conversation_context_service
|
||||
from utils.util_schemas import BaseResponse
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
|
||||
security = HTTPBearer()
|
||||
|
||||
# Request/Response Models
|
||||
class SmartQueryRequest(BaseModel):
    """Request body for the smart data-analysis query endpoint."""

    # Natural-language question to run against the user's uploaded files.
    query: str
    # Existing conversation to continue, if any.
    conversation_id: Optional[int] = None
    # Force-start a fresh conversation (also implied when no conversation_id is given).
    is_new_conversation: bool = False
|
||||
|
||||
class SmartQueryResponse(BaseModel):
    """Envelope returned by the smart query endpoint."""

    # True when the workflow completed without error.
    success: bool
    # Human-readable status / summary message.
    message: str
    # Query result payload; exact shape comes from the workflow — TODO confirm.
    data: Optional[Dict[str, Any]] = None
    # Per-step workflow trace (dicts with step / status / message keys).
    workflow_steps: Optional[list] = None
    # Conversation the query ran in, when one was resolved.
    conversation_id: Optional[int] = None
|
||||
|
||||
class ConversationContextResponse(BaseModel):
    """Generic success/message/data envelope used by the context and file-status endpoints."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
@router.post("/query", response_model=SmartQueryResponse, summary="智能问数查询")
async def smart_query(
    request: SmartQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Run a smart data-analysis query.

    For a new conversation the user's file list is loaded automatically;
    the workflow then selects relevant Excel files and generates/executes
    pandas code. Unexpected failures are converted into a structured error
    response instead of a raw 500.
    """
    session.desc = f"START: 用户 {current_user.username} 智能问数查询"
    conversation_id = None

    try:
        # --- input validation -------------------------------------------
        if not request.query or not request.query.strip():
            session.desc = "ERROR: 用户输入为空, 查询内容不能为空"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容不能为空"
            )

        if len(request.query) > 1000:
            session.desc = "ERROR: 用户输入过长, 查询内容不能超过1000字符"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容过长,请控制在1000字符以内"
            )

        # --- service setup ----------------------------------------------
        workflow_manager = SmartWorkflowManager(session)
        # NOTE(review): constructed but never used below — confirm whether
        # the constructor has required side effects before removing.
        conversation_service = ConversationService(session)

        # --- conversation resolution ------------------------------------
        conversation_id = request.conversation_id

        if request.is_new_conversation or not conversation_id:
            # Start a fresh conversation; fall back to a transient session
            # if creation fails.
            try:
                conversation_id = await conversation_context_service.create_conversation(
                    user_id=current_user.id,
                    title=f"智能问数: {request.query[:20]}..."
                )
                request.is_new_conversation = True
                session.desc = f"创建新对话: {conversation_id}"
            except Exception as e:
                session.desc = f"WARNING: 创建对话失败,使用临时会话: {e}"
                conversation_id = None
        else:
            # Reuse an existing conversation after verifying ownership.
            try:
                context = await conversation_context_service.get_conversation_context(conversation_id)
                if not context or context.get('user_id') != current_user.id:
                    session.desc = f"ERROR: 对话 {conversation_id} 不存在或无权访问"
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="对话不存在或无权访问"
                    )
                session.desc = f"使用现有对话: {conversation_id}"
            except HTTPException:
                session.desc = f"EXCEPTION: 对话 {conversation_id} 不存在或无权访问"
                raise
            except Exception as e:
                session.desc = f"ERROR: 验证对话失败: {e}"
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="对话验证失败"
                )

        # --- persist the user message (best effort) ---------------------
        if conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="user",
                    content=request.query
                )
            except Exception as e:
                # Do not block the query on history persistence.
                session.desc = f"WARNING: 保存用户消息失败: {e}"

        # --- run the smart query workflow -------------------------------
        try:
            result = await workflow_manager.process_smart_query(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=request.is_new_conversation
            )
        except Exception as e:
            session.desc = f"ERROR: 智能查询执行失败: {e}"
            # Return a structured error response.
            return SmartQueryResponse(
                success=False,
                message=f"查询执行失败: {str(e)}",
                data={'error_type': 'query_execution_error'},
                workflow_steps=[{
                    'step': 'error',
                    'status': 'failed',
                    'message': str(e)
                }],
                conversation_id=conversation_id
            )

        # --- persist assistant reply + context (best effort) ------------
        if result['success'] and conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="assistant",
                    content=result.get('data', {}).get('summary', '查询完成'),
                    metadata={
                        'query_result': result.get('data'),
                        'workflow_steps': result.get('workflow_steps', []),
                        'selected_files': result.get('data', {}).get('used_files', [])
                    }
                )

                await conversation_context_service.update_conversation_context(
                    conversation_id=conversation_id,
                    query=request.query,
                    selected_files=result.get('data', {}).get('used_files', [])
                )

            except Exception as e:
                # History failures are recorded but never mask the result.
                session.desc = f"EXCEPTION: 保存消息到对话历史失败: {e}"

        # --- build the response, tagging it with the conversation id ----
        response_data = result.get('data', {})
        if conversation_id:
            response_data['conversation_id'] = conversation_id
            session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}"
        return SmartQueryResponse(
            success=result['success'],
            message=result.get('message', '查询完成'),
            data=response_data,
            workflow_steps=result.get('workflow_steps', []),
            conversation_id=conversation_id
        )

    except HTTPException as e:
        # BUGFIX: this handler formerly formatted an undefined name `e`,
        # turning every HTTPException into a NameError; bind it explicitly.
        session.desc = f"EXCEPTION: HTTP异常: {e}"
        raise
    except Exception as e:
        session.desc = f"ERROR: 智能查询接口异常: {e}"
        # Generic structured error response.
        return SmartQueryResponse(
            success=False,
            message="服务器内部错误,请稍后重试",
            data={'error_type': 'internal_server_error'},
            workflow_steps=[{
                'step': 'error',
                'status': 'failed',
                'message': '系统异常'
            }],
            conversation_id=conversation_id
        )
|
||||
|
||||
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
async def get_conversation_context(
    conversation_id: int,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Return a conversation's context: used files, query history and messages."""
    session.desc = f"START: 获取对话上下文,对话ID: {conversation_id}"
    context = await conversation_context_service.get_conversation_context(conversation_id)

    if not context:
        session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="对话上下文不存在"
        )

    # Only the owner may read the conversation.
    if context['user_id'] != current_user.id:
        session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此对话"
        )

    # Attach the full message history before returning.
    context['message_history'] = await conversation_context_service.get_conversation_history(conversation_id)
    session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}"
    return ConversationContextResponse(
        success=True,
        message="获取对话上下文成功",
        data=context
    )
|
||||
|
||||
|
||||
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
async def get_files_status(
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Summarise the current user's uploaded files: counts, types and readiness."""
    session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}"
    # NOTE(review): constructed without a session here, unlike the query
    # endpoint, and reaches into the private _load_user_file_list helper —
    # confirm both are intended.
    workflow_manager = SmartWorkflowManager()

    file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)

    # Aggregate totals across all files.
    total_files = len(file_list)
    total_rows = sum(f.get('row_count', 0) for f in file_list)
    total_columns = sum(f.get('column_count', 0) for f in file_list)

    # Count files per extension.
    file_types = {}
    for entry in file_list:
        filename = entry['filename']
        ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
        file_types[ext] = file_types.get(ext, 0) + 1

    status_data = {
        'total_files': total_files,
        'total_rows': total_rows,
        'total_columns': total_columns,
        'file_types': file_types,
        'files': [{
            'id': f['id'],
            'filename': f['filename'],
            'row_count': f.get('row_count', 0),
            'column_count': f.get('column_count', 0),
            'columns': f.get('columns', []),
            'upload_time': f.get('upload_time')
        } for f in file_list],
        'ready_for_query': total_files > 0
    }

    session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}"
    return ConversationContextResponse(
        success=True,
        message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
        data=status_data
    )
|
||||
|
||||
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
async def reset_conversation_context(
    conversation_id: int,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Clear a conversation's query history while keeping its files."""
    session.desc = f"START: 重置对话上下文,对话ID: {conversation_id}"

    # The conversation must exist ...
    context = await conversation_context_service.get_conversation_context(conversation_id)
    if not context:
        session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="对话上下文不存在"
        )

    # ... and belong to the caller.
    if context['user_id'] != current_user.id:
        session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此对话"
        )

    # Delegate the actual reset to the context service.
    if await conversation_context_service.reset_conversation_context(conversation_id):
        session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}"
        return {
            "success": True,
            "message": "对话上下文已重置,可以开始新的数据分析会话"
        }

    session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}"
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="重置对话上下文失败"
    )
|
||||
|
||||
|
|
@ -0,0 +1,733 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
|
||||
from fastapi.security import HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any, List
|
||||
import pandas as pd
|
||||
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse, BaseResponse
|
||||
import os
|
||||
import tempfile
|
||||
from th_agenter.services.smart_query import (
|
||||
SmartQueryService,
|
||||
ExcelAnalysisService,
|
||||
DatabaseQueryService
|
||||
)
|
||||
from th_agenter.services.excel_metadata_service import ExcelMetadataService
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from utils.util_file import FileUtils
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, AsyncGenerator
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.auth import AuthService
|
||||
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||
from th_agenter.services.conversation_context import ConversationContextService
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/smart-query", tags=["smart-query"])
|
||||
security = HTTPBearer()
|
||||
|
||||
# Request/Response Models
|
||||
class DatabaseConfig(BaseModel):
    """Connection parameters for an external database source."""

    # Database kind (e.g. "mysql" / "postgresql") — TODO confirm accepted values.
    type: str
    host: str
    # Kept as a string, not an int — presumably interpolated into a DSN; verify against usage.
    port: str
    database: str
    username: str
    password: str
|
||||
|
||||
class QueryRequest(BaseModel):
    """Paged natural-language query request."""

    # Natural-language query text.
    query: str
    # 1-based page index.
    page: int = 1
    page_size: int = 20
    # Optional explicit table to query — NOTE(review): confirm semantics against the handler.
    table_name: Optional[str] = None
|
||||
|
||||
class TableSchemaRequest(BaseModel):
    """Request for the schema of a single table."""

    table_name: str
|
||||
|
||||
class ExcelUploadResponse(BaseModel):
    """Result of an Excel upload."""

    # Database id of the stored file record.
    file_id: int
    success: bool
    message: str
    # Extra payload (preview / metadata); optional.
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
class QueryResponse(BaseModel):
    """Generic success/message/data envelope for query endpoints."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
@router.post("/upload-excel", response_model=ExcelUploadResponse, summary="上传Excel文件并进行预处理")
async def upload_excel(
    file: UploadFile = File(...),
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Upload an Excel/CSV file and preprocess it.

    Validates extension and size (<= 10MB), persists the raw upload under
    ``data/uploads/excel_{user_id}/``, records metadata via
    ExcelMetadataService, and additionally pickles the parsed DataFrame
    for compatibility with the existing frontend.
    """
    session.desc = f"START: 用户 {current_user.username} 上传 Excel 文件并进行预处理"
    # Validate file extension.
    allowed_extensions = ['.xlsx', '.xls', '.csv']
    file_extension = os.path.splitext(file.filename)[1].lower()

    if file_extension not in allowed_extensions:
        session.desc = f"ERROR: 用户 {current_user.username} 上传了不支持的文件格式 {file_extension},请上传 .xlsx, .xls 或 .csv 文件"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件"
        )

    # Validate file size (10MB limit). Note: reads the whole body into memory.
    content = await file.read()
    file_size = len(content)
    if file_size > 10 * 1024 * 1024:
        session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 大小为 {file_size / (1024 * 1024):.2f}MB,超过最大限制 10MB"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="文件大小不能超过 10MB"
        )

    # Build the persistent upload directory structure.
    backend_dir = Path(__file__).parent.parent.parent.parent  # backend/ root relative to this module
    data_dir = backend_dir / "data/uploads"
    excel_user_dir = data_dir / f"excel_{current_user.id}"

    # Ensure the per-user directory exists.
    excel_user_dir.mkdir(parents=True, exist_ok=True)

    # Generated filename: {uuid}_{sanitized original name}.
    file_id = str(uuid.uuid4())
    safe_filename = FileUtils.sanitize_filename(file.filename)
    new_filename = f"{file_id}_{safe_filename}"
    file_path = excel_user_dir / new_filename

    # Persist the raw upload.
    with open(file_path, 'wb') as f:
        f.write(content)

    # Extract and store file metadata via the Excel metadata service.
    metadata_service = ExcelMetadataService(session)
    excel_file = metadata_service.save_file_metadata(
        file_path=str(file_path),
        original_filename=file.filename,
        user_id=current_user.id,
        file_size=file_size
    )

    # For frontend compatibility, also parse into a DataFrame and pickle it.
    try:
        if file_extension == '.csv':
            df = pd.read_csv(file_path, encoding='utf-8')
        else:
            df = pd.read_excel(file_path)
    except UnicodeDecodeError:
        # CSV fallback: retry with GBK (common for files exported from Chinese Excel).
        if file_extension == '.csv':
            df = pd.read_csv(file_path, encoding='gbk')
        else:
            session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 编码错误,请确保文件为UTF-8或GBK编码"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="文件编码错误,请确保文件为UTF-8或GBK编码"
            )
    except Exception as e:
        session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 读取失败: {str(e)}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件读取失败: {str(e)}"
        )

    # Save the pickled DataFrame next to the raw file.
    pickle_filename = f"{file_id}_{safe_filename}.pkl"
    pickle_path = excel_user_dir / pickle_filename
    df.to_pickle(pickle_path)

    # Analyze the DataFrame (kept for compatibility with existing consumers).
    excel_service = ExcelAnalysisService()
    analysis_result = excel_service.analyze_dataframe(df, file.filename)

    # Attach database-backed file info to the analysis payload.
    analysis_result.update({
        'file_id': str(excel_file.id),
        'database_id': excel_file.id,
        'temp_file_path': str(pickle_path),  # new pickle location
        'original_filename': file.filename,
        'file_size_mb': excel_file.file_size_mb,
        'sheet_names': excel_file.sheet_names,
    })

    session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}"
    return ExcelUploadResponse(
        file_id=excel_file.id,
        success=True,
        message="Excel文件上传成功",
        data=analysis_result
    )
|
||||
|
||||
@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据")
async def preview_excel(
    request: ExcelPreviewRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Preview an uploaded Excel/CSV file as one paginated page of rows.

    Looks the file up by id (scoped to the current user), reads it with
    pandas, and returns the requested page plus pagination info.
    """
    session.desc = f"START: 用户 {current_user.username} 预览文件 {request.file_id}"

    # Validate file_id format.
    try:
        file_id = int(request.file_id)
    except ValueError:
        session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 提供了无效的文件ID格式: {request.file_id}"
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"无效的文件ID格式: {request.file_id}"
        )

    # Fetch file info from the database (ownership enforced via user id).
    metadata_service = ExcelMetadataService(session)
    excel_file = metadata_service.get_file_by_id(file_id, current_user.id)

    if not excel_file:
        session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 不存在或已被删除"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文件不存在或已被删除"
        )

    # Verify the file still exists on disk.
    if not os.path.exists(excel_file.file_path):
        session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 已被移动或删除"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文件已被移动或删除"
        )

    # Touch the last-accessed timestamp.
    metadata_service.update_last_accessed(file_id, current_user.id)

    # Load the file into a DataFrame.
    if excel_file.file_type.lower() == 'csv':
        df = pd.read_csv(excel_file.file_path, encoding='utf-8')
    else:
        # Excel: use the configured default sheet, else the first one.
        sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0
        df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name)

    # Compute pagination bounds.
    total_rows = len(df)
    start_idx = (request.page - 1) * request.page_size
    end_idx = start_idx + request.page_size

    # Slice out the requested page.
    paginated_df = df.iloc[start_idx:end_idx]

    # NaN -> '' so the payload serializes cleanly to JSON.
    data = paginated_df.fillna('').to_dict('records')
    columns = df.columns.tolist()
    session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据"
    return QueryResponse(
        success=True,
        message="Excel文件预览加载成功",
        data={
            'data': data,
            'columns': columns,
            'total_rows': total_rows,
            'page': request.page,
            'page_size': request.page_size,
            'total_pages': (total_rows + request.page_size - 1) // request.page_size
        }
    )
|
||||
|
||||
@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接")
async def test_database_connection(
    config: DatabaseConfig,
    current_user = Depends(AuthService.get_current_user)
):
    """
    测试数据库连接
    """
    # Any failure during the probe is reported in-band, never as an HTTP error.
    try:
        is_connected = await DatabaseQueryService().test_connection(config.model_dump())
    except Exception as exc:
        return NormalResponse(
            success=False,
            message=f"连接测试失败: {str(exc)}"
        )

    if not is_connected:
        return NormalResponse(
            success=False,
            message="数据库连接测试失败"
        )
    return NormalResponse(
        success=True,
        message="数据库连接测试成功"
    )
|
||||
|
||||
# 删除第285-314行的connect_database方法
|
||||
# @router.post("/connect-database", response_model=QueryResponse)
|
||||
# async def connect_database(
|
||||
# config_id: int,
|
||||
# current_user = Depends(AuthService.get_current_user),
|
||||
# db: Session = Depends(get_session)
|
||||
# ):
|
||||
# """连接数据库并获取表列表"""
|
||||
# ... (整个方法都删除)
|
||||
|
||||
@router.post("/table-schema", response_model=QueryResponse, summary="获取数据表结构")
async def get_table_schema(
    request: TableSchemaRequest,
    current_user = Depends(AuthService.get_current_user)
):
    """
    获取数据表结构
    """
    try:
        result = await DatabaseQueryService().get_table_schema(
            request.table_name, current_user.id
        )
    except Exception as exc:
        # Report failures in-band instead of raising an HTTP error.
        return QueryResponse(
            success=False,
            message=f"获取表结构失败: {str(exc)}"
        )

    if not result['success']:
        return QueryResponse(
            success=False,
            message=result['message']
        )
    return QueryResponse(
        success=True,
        message="获取表结构成功",
        data=result['data']
    )
|
||||
|
||||
class StreamQueryRequest(BaseModel):
    """Body for the streaming Excel smart-query endpoint."""
    query: str
    conversation_id: Optional[int] = None   # existing conversation to append to, if any
    is_new_conversation: bool = False       # force creation of a new conversation
|
||||
|
||||
class DatabaseStreamQueryRequest(BaseModel):
    """Body for the streaming database-query endpoint."""
    query: str
    database_config_id: int                 # which saved DB configuration to query
    conversation_id: Optional[int] = None
    is_new_conversation: bool = False
|
||||
|
||||
@router.post("/execute-excel-query", summary="流式智能问答查询")
async def stream_smart_query(
    request: StreamQueryRequest,
    current_user=Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Streaming smart-query endpoint (Server-Sent Events).

    Pushes workflow steps and the final result in real time as
    ``data: {...}\\n\\n`` SSE frames. Conversation bookkeeping is best-effort:
    failures to create/save conversation records are logged but never abort
    the query stream.
    """

    async def generate_stream() -> AsyncGenerator[str, None]:
        workflow_manager = None

        try:
            # Validate request parameters.
            if not request.query or not request.query.strip():
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
                return

            if len(request.query) > 1000:
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
                return

            # Emit the start frame.
            yield f"data: {json.dumps({'type': 'start', 'message': '开始处理查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            # Initialize services.
            workflow_manager = SmartWorkflowManager(session)
            conversation_context_service = ConversationContextService()

            # Resolve the conversation context.
            conversation_id = request.conversation_id

            # Create a new conversation when requested or when none was given.
            if request.is_new_conversation or not conversation_id:
                try:
                    conversation_id = await conversation_context_service.create_conversation(
                        user_id=current_user.id,
                        title=f"智能问数: {request.query[:20]}..."
                    )
                    yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
                except Exception as e:
                    logger.warning(f"创建对话失败: {e}")
                    # Non-fatal: continue the query without a conversation.

            # Persist the user's message (best-effort).
            if conversation_id:
                try:
                    await conversation_context_service.save_message(
                        conversation_id=conversation_id,
                        role="user",
                        content=request.query
                    )
                except Exception as e:
                    logger.warning(f"保存用户消息失败: {e}")

            # Run the smart-query workflow, relaying each step to the client.
            async for step_data in workflow_manager.process_excel_query_stream(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=request.is_new_conversation
            ):
                # Relay the workflow step.
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

                # On the final result, persist it to conversation history.
                if step_data.get('type') == 'final_result' and conversation_id:
                    try:
                        result_data = step_data.get('data', {})
                        await conversation_context_service.save_message(
                            conversation_id=conversation_id,
                            role="assistant",
                            content=result_data.get('summary', '查询完成'),
                            metadata={
                                'query_result': result_data,
                                'workflow_steps': step_data.get('workflow_steps', []),
                                'selected_files': result_data.get('used_files', [])
                            }
                        )

                        # Refresh the conversation context with this query.
                        await conversation_context_service.update_conversation_context(
                            conversation_id=conversation_id,
                            query=request.query,
                            selected_files=result_data.get('used_files', [])
                        )

                        logger.info(f"查询成功完成,对话ID: {conversation_id}")

                    except Exception as e:
                        logger.warning(f"保存消息到对话历史失败: {e}")

            # Emit the completion frame.
            yield f"data: {json.dumps({'type': 'complete', 'message': '查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"流式智能查询异常: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"

        finally:
            # Release workflow resources; shutdown errors are ignored.
            if workflow_manager:
                try:
                    workflow_manager.excel_workflow.executor.shutdown(wait=False)
                except:
                    pass

    return StreamingResponse(
        generate_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )
|
||||
|
||||
@router.post("/execute-db-query", summary="流式数据库查询")
async def execute_database_query(
    request: DatabaseStreamQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Streaming database-query endpoint (Server-Sent Events).

    Mirrors ``stream_smart_query`` but runs the database workflow against
    the saved config named by ``database_config_id``. Conversation
    bookkeeping is best-effort and never aborts the stream.
    """

    async def generate_stream() -> AsyncGenerator[str, None]:
        workflow_manager = None

        try:
            # Validate request parameters.
            if not request.query or not request.query.strip():
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
                return

            if len(request.query) > 1000:
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
                return

            # Emit the start frame.
            yield f"data: {json.dumps({'type': 'start', 'message': '开始处理数据库查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            # Initialize services.
            workflow_manager = SmartWorkflowManager(session)
            conversation_context_service = ConversationContextService()

            # Resolve the conversation context.
            conversation_id = request.conversation_id

            # Create a new conversation when requested or when none was given.
            if request.is_new_conversation or not conversation_id:
                try:
                    conversation_id = await conversation_context_service.create_conversation(
                        user_id=current_user.id,
                        title=f"数据库查询: {request.query[:20]}..."
                    )
                    yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
                except Exception as e:
                    logger.warning(f"创建对话失败: {e}")
                    # Non-fatal: continue the query without a conversation.

            # Persist the user's message (best-effort).
            if conversation_id:
                try:
                    await conversation_context_service.save_message(
                        conversation_id=conversation_id,
                        role="user",
                        content=request.query
                    )
                except Exception as e:
                    logger.warning(f"保存用户消息失败: {e}")

            # Run the database-query workflow, relaying each step to the client.
            async for step_data in workflow_manager.process_database_query_stream(
                user_query=request.query,
                user_id=current_user.id,
                database_config_id=request.database_config_id
            ):
                # Relay the workflow step.
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

                # On the final result, persist it to conversation history.
                if step_data.get('type') == 'final_result' and conversation_id:
                    try:
                        result_data = step_data.get('data', {})
                        await conversation_context_service.save_message(
                            conversation_id=conversation_id,
                            role="assistant",
                            content=result_data.get('summary', '查询完成'),
                            metadata={
                                'query_result': result_data,
                                'workflow_steps': step_data.get('workflow_steps', []),
                                'generated_sql': result_data.get('generated_sql', '')
                            }
                        )

                        # Refresh the conversation context with this query.
                        await conversation_context_service.update_conversation_context(
                            conversation_id=conversation_id,
                            query=request.query,
                            selected_files=[]
                        )

                        logger.info(f"数据库查询成功完成,对话ID: {conversation_id}")

                    except Exception as e:
                        logger.warning(f"保存消息到对话历史失败: {e}")

            # Emit the completion frame.
            yield f"data: {json.dumps({'type': 'complete', 'message': '数据库查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"流式数据库查询异常: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"

        finally:
            # Release workflow resources; shutdown errors are ignored.
            if workflow_manager:
                try:
                    workflow_manager.database_workflow.executor.shutdown(wait=False)
                except:
                    pass

    return StreamingResponse(
        generate_stream(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )
|
||||
|
||||
@router.delete("/cleanup-temp-files", summary="清理临时文件")
async def cleanup_temp_files(
    current_user = Depends(AuthService.get_current_user)
):
    """
    清理临时文件
    """
    try:
        temp_dir = tempfile.gettempdir()
        prefix = f"excel_{current_user.id}_"

        # Collect this user's leftover pickle files first, then remove them.
        candidates = [
            name for name in os.listdir(temp_dir)
            if name.startswith(prefix) and name.endswith('.pkl')
        ]

        removed = 0
        for name in candidates:
            try:
                os.remove(os.path.join(temp_dir, name))
            except OSError:
                continue  # file vanished or is locked — skip silently
            removed += 1

        return BaseResponse(
            success=True,
            message=f"已清理 {removed} 个临时文件"
        )

    except Exception as e:
        return BaseResponse(
            success=False,
            message=f"清理临时文件失败: {str(e)}"
        )
|
||||
|
||||
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表")
async def get_file_list(
    page: int = 1,
    page_size: int = 20,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    List the current user's uploaded Excel files, paginated.
    """
    try:
        session.desc = f"START: 获取用户 {current_user.id} 的文件列表"
        metadata_service = ExcelMetadataService(session)
        skip = (page - 1) * page_size
        files, total = metadata_service.get_user_files(current_user.id, skip, page_size)

        file_list = []
        for file in files:
            file_info = {
                'id': file.id,
                'filename': file.original_filename,
                'file_size': file.file_size,
                'file_size_mb': file.file_size_mb,
                'file_type': file.file_type,
                'sheet_names': file.sheet_names,
                'sheet_count': file.sheet_count,
                'last_accessed': file.last_accessed.isoformat() if file.last_accessed else None,
                'is_processed': file.is_processed,
                'processing_error': file.processing_error
            }
            file_list.append(file_info)

        session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件"
        return FileListResponse(
            success=True,
            message="获取文件列表成功",
            data={
                'files': file_list,
                'total': total,
                'page': page,
                'page_size': page_size,
                'total_pages': (total + page_size - 1) // page_size
            }
        )

    except Exception as e:
        # Failures are reported in-band rather than as HTTP errors.
        return FileListResponse(
            success=False,
            message=f"获取文件列表失败: {str(e)}"
        )
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件")
async def delete_file(
    file_id: int,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Delete one of the current user's Excel files.

    Fix vs. original: the exception handler returned ``success=True``
    alongside the error message, so clients treated failed deletions as
    successes. It now returns ``success=False``.
    """
    try:
        session.desc = f"START: 删除用户 {current_user.id} 的文件 {file_id}"
        metadata_service = ExcelMetadataService(session)
        success = metadata_service.delete_file(file_id, current_user.id)

        if success:
            session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}"
            return NormalResponse(
                success=True,
                message="文件删除成功"
            )
        else:
            session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败"
            return NormalResponse(
                success=False,
                message="文件不存在或删除失败"
            )

    except Exception as e:
        # BUG FIX: was success=True — an exception is a failure.
        return NormalResponse(
            success=False,
            message=str(e)
        )
|
||||
|
||||
@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息")
async def get_file_info(
    file_id: int,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Return detailed stored metadata for one of the user's files.
    """
    metadata_service = ExcelMetadataService(session)
    excel_file = metadata_service.get_file_by_id(file_id, current_user.id)

    if not excel_file:
        session.desc = f"ERROR: 获取用户 {current_user.id} 的文件 {file_id} 信息,文件不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文件不存在"
        )

    # Touch the last-accessed timestamp.
    metadata_service.update_last_accessed(file_id, current_user.id)

    file_info = {
        'id': excel_file.id,
        'filename': excel_file.original_filename,
        'file_size': excel_file.file_size,
        'file_size_mb': excel_file.file_size_mb,
        'file_type': excel_file.file_type,
        'sheet_names': excel_file.sheet_names,
        'default_sheet': excel_file.default_sheet,
        'columns_info': excel_file.columns_info,
        'preview_data': excel_file.preview_data,
        'data_types': excel_file.data_types,
        'total_rows': excel_file.total_rows,
        'total_columns': excel_file.total_columns,
        'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None,
        'last_accessed': excel_file.last_accessed.isoformat() if excel_file.last_accessed else None,
        'sheets_summary': excel_file.get_all_sheets_summary()
    }

    return QueryResponse(
        success=True,
        message="获取文件信息成功",
        data=file_info
    )
|
||||
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
"""表元数据管理API"""
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from th_agenter.models.user import User
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.table_metadata_service import TableMetadataService
|
||||
from th_agenter.services.auth import AuthService
|
||||
|
||||
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
|
||||
|
||||
class TableSelectionRequest(BaseModel):
    """Tables selected for metadata collection within one database config."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_names: List[str] = Field(..., description="选中的表名列表")
|
||||
|
||||
class TableMetadataResponse(BaseModel):
    """Serialized view of a stored table-metadata record.

    NOTE(review): currently unused by the endpoints below, which return
    plain dicts — confirm whether it should be wired up or removed.
    """
    id: int
    table_name: str
    table_schema: str
    table_type: str
    table_comment: str
    columns_count: int
    row_count: int
    is_enabled_for_qa: bool
    qa_description: str
    business_context: str
    last_synced_at: str
|
||||
|
||||
class QASettingsUpdate(BaseModel):
    """Editable Q&A settings for one table."""
    is_enabled_for_qa: bool = Field(default=True)
    qa_description: str = Field(default="")
    business_context: str = Field(default="")
|
||||
|
||||
class TableByNameRequest(BaseModel):
    """Lookup key for a single table's metadata."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_name: str = Field(..., description="表名")
|
||||
|
||||
|
||||
@router.post("/collect", summary="收集选中表的元数据")
async def collect_table_metadata(
    request: TableSelectionRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """收集选中表的元数据"""
    session.desc = f"START: 用户 {current_user.id} 收集表元数据"
    # Delegate collection + persistence to the metadata service.
    metadata_service = TableMetadataService(session)
    collected = await metadata_service.collect_and_save_table_metadata(
        current_user.id,
        request.database_config_id,
        request.table_names,
    )
    session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据"
    return collected
|
||||
|
||||
|
||||
@router.get("/", summary="获取用户表元数据列表")
async def get_table_metadata(
    database_config_id: int = None,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """List the user's table-metadata records, optionally filtered by config."""
    try:
        service = TableMetadataService(session)
        metadata_list = service.get_user_table_metadata(
            current_user.id,
            database_config_id
        )

        data = [
            {
                "id": meta.id,
                "table_name": meta.table_name,
                "table_schema": meta.table_schema,
                "table_type": meta.table_type,
                "table_comment": meta.table_comment or "",
                "columns": meta.columns_info if meta.columns_info else [],
                "column_count": len(meta.columns_info) if meta.columns_info else 0,
                "row_count": meta.row_count,
                "is_enabled_for_qa": meta.is_enabled_for_qa,
                "qa_description": meta.qa_description or "",
                "business_context": meta.business_context or "",
                "created_at": meta.created_at.isoformat() if meta.created_at else "",
                "updated_at": meta.updated_at.isoformat() if meta.updated_at else "",
                "last_synced_at": meta.last_synced_at.isoformat() if meta.last_synced_at else "",
                # Nested copy of the Q&A fields for the frontend settings panel.
                "qa_settings": {
                    "is_enabled_for_qa": meta.is_enabled_for_qa,
                    "qa_description": meta.qa_description or "",
                    "business_context": meta.business_context or ""
                }
            }
            for meta in metadata_list
        ]

        return {
            "success": True,
            "data": data
        }

    except Exception as e:
        logger.error(f"获取表元数据失败: {str(e)}")
        return {
            "success": False,
            "message": str(e)
        }
|
||||
|
||||
|
||||
@router.post("/by-table", summary="根据表名获取表元数据")
async def get_table_metadata_by_name(
    request: TableByNameRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Fetch metadata for one table by name within a database config.

    Fix vs. original: removed the orphaned duplicate ``return``/``except``
    block that trailed the function (an ``except`` with no matching ``try``
    — a syntax error and unreachable code).
    """
    try:
        service = TableMetadataService(session)
        metadata = service.get_table_metadata_by_name(
            current_user.id,
            request.database_config_id,
            request.table_name
        )

        if metadata:
            data = {
                "id": metadata.id,
                "table_name": metadata.table_name,
                "table_schema": metadata.table_schema,
                "table_type": metadata.table_type,
                "table_comment": metadata.table_comment or "",
                "columns": metadata.columns_info if metadata.columns_info else [],
                "column_count": len(metadata.columns_info) if metadata.columns_info else 0,
                "row_count": metadata.row_count,
                "is_enabled_for_qa": metadata.is_enabled_for_qa,
                "qa_description": metadata.qa_description or "",
                "business_context": metadata.business_context or "",
                "created_at": metadata.created_at.isoformat() if metadata.created_at else "",
                "updated_at": metadata.updated_at.isoformat() if metadata.updated_at else "",
                "last_synced_at": metadata.last_synced_at.isoformat() if metadata.last_synced_at else "",
                # Nested copy of the Q&A fields for the frontend settings panel.
                "qa_settings": {
                    "is_enabled_for_qa": metadata.is_enabled_for_qa,
                    "qa_description": metadata.qa_description or "",
                    "business_context": metadata.business_context or ""
                }
            }
            return {"success": True, "data": data}
        else:
            return {"success": False, "data": None, "message": "表元数据不存在"}

    except Exception as e:
        logger.error(f"获取表元数据失败: {str(e)}")
        return {
            "success": False,
            "message": str(e)
        }
|
||||
|
||||
|
||||
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
async def update_qa_settings(
    metadata_id: int,
    settings: QASettingsUpdate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Update the Q&A settings of one table-metadata record.

    Fixes vs. original: uses Pydantic v2 ``model_dump()`` (consistent with
    the rest of this codebase) instead of the deprecated ``dict()``, and
    re-raises the intentional 404 instead of letting the broad handler
    rewrap it as a 400.
    """
    try:
        service = TableMetadataService(session)
        success = service.update_table_qa_settings(
            current_user.id,
            metadata_id,
            settings.model_dump()
        )

        if success:
            return {"success": True, "message": "设置更新成功"}
        else:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="表元数据不存在"
            )
    except HTTPException:
        # Propagate intentional HTTP errors (the 404 above) unchanged.
        raise
    except Exception as e:
        logger.error(f"更新问答设置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
|
||||
|
||||
|
||||
class TableSaveRequest(BaseModel):
    """Tables whose metadata configuration should be persisted."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_names: List[str] = Field(..., description="要保存的表名列表")
|
||||
|
||||
|
||||
@router.post("/save")
async def save_table_metadata(
    request: TableSaveRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """保存选中表的元数据配置"""
    # Persist the selected tables' configuration via the metadata service.
    outcome = await TableMetadataService(session).save_table_metadata_config(
        user_id=current_user.id,
        database_config_id=request.database_config_id,
        table_names=request.table_names,
    )

    session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置"

    saved = outcome['saved_tables']
    return {
        "success": True,
        "message": f"成功保存 {len(saved)} 个表的配置",
        "saved_tables": saved,
        "failed_tables": outcome.get('failed_tables', []),
    }
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
"""User management endpoints."""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.database import get_session
|
||||
from ...core.simple_permissions import require_super_admin
|
||||
from ...services.auth import AuthService
|
||||
from ...services.user import UserService
|
||||
from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/profile", response_model=UserResponse, summary="获取当前用户的个人信息")
async def get_user_profile(
    current_user = Depends(AuthService.get_current_user)
):
    """获取当前用户的个人信息."""
    # Convert the ORM user into the public response schema.
    profile = UserResponse.model_validate(current_user)
    return profile
|
||||
|
||||
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息")
async def update_user_profile(
    user_update: UserUpdate,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """更新当前用户的个人信息."""
    user_service = UserService(session)

    # A changed email must not collide with another account.
    email_changed = bool(user_update.email) and user_update.email != current_user.email
    if email_changed and await user_service.get_user_by_email(user_update.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Apply the update and return the refreshed profile.
    updated = await user_service.update_user(current_user.id, user_update)
    return UserResponse.model_validate(updated)
|
||||
|
||||
@router.delete("/profile", summary="删除当前用户的账户")
async def delete_user_account(
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """删除当前用户的账户."""
    # Capture the name before the row disappears.
    username = current_user.username
    await UserService(session).delete_user(current_user.id)
    session.desc = f"删除用户 [{username}] 成功"
    return {"message": f"删除用户 {username} 成功"}
|
||||
|
||||
# Admin endpoints
|
||||
@router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)")
|
||||
async def create_user(
|
||||
user_create: UserCreate,
|
||||
current_user = Depends(require_super_admin),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""创建一个新用户 (需要有管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
# Check if username already exists
|
||||
existing_user = await user_service.get_user_by_username(user_create.username)
|
||||
if existing_user:
|
||||
session.desc = f"创建用户 [{user_create.username}] 失败 - 用户名已存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
# Check if email already exists
|
||||
existing_user = await user_service.get_user_by_email(user_create.email)
|
||||
if existing_user:
|
||||
session.desc = f"创建用户 [{user_create.username}] 失败 - 邮箱已存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# Create user
|
||||
new_user = await user_service.create_user(user_create)
|
||||
return UserResponse.model_validate(new_user)
|
||||
|
||||
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
|
||||
async def list_users(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None),
|
||||
role_id: Optional[int] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""列出所有用户,支持分页和筛选 (仅管理员权限)."""
|
||||
session.desc = f"START: 列出所有用户,分页={page}, 每页大小={size}, 搜索={search}, 角色ID={role_id}, 激活状态={is_active}"
|
||||
user_service = UserService(session)
|
||||
skip = (page - 1) * size
|
||||
users, total = await user_service.get_users_with_filters(
|
||||
skip=skip,
|
||||
limit=size,
|
||||
search=search,
|
||||
role_id=role_id,
|
||||
is_active=is_active
|
||||
)
|
||||
result = {
|
||||
"users": [UserResponse.model_validate(user) for user in users],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": size
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)")
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""通过ID获取用户信息 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
@router.put("/change-password", summary="修改当前用户的密码")
|
||||
async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""修改当前用户的密码."""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
await user_service.change_password(
|
||||
user_id=current_user.id,
|
||||
current_password=request.current_password,
|
||||
new_password=request.new_password
|
||||
)
|
||||
return {"message": "Password changed successfully"}
|
||||
except Exception as e:
|
||||
if "Current password is incorrect" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
elif "must be at least 6 characters" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password must be at least 6 characters long"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to change password"
|
||||
)
|
||||
|
||||
@router.put("/{user_id}/reset-password", summary="重置用户密码 (仅管理员权限)")
|
||||
async def reset_user_password(
|
||||
user_id: int,
|
||||
request: ResetPasswordRequest,
|
||||
current_user = Depends(require_super_admin),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""重置用户密码 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
await user_service.reset_password(
|
||||
user_id=user_id,
|
||||
new_password=request.new_password
|
||||
)
|
||||
return {"message": "Password reset successfully"}
|
||||
except Exception as e:
|
||||
if "User not found" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
elif "must be at least 6 characters" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password must be at least 6 characters long"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to reset password"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse, summary="更新用户信息 (仅管理员权限)")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_update: UserUpdate,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""更新用户信息 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
updated_user = await user_service.update_user(user_id, user_update)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@router.delete("/{user_id}", summary="删除用户 (仅管理员权限)")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""删除用户 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
await user_service.delete_user(user_id)
|
||||
return {"message": "User deleted successfully"}
|
||||
|
|
@ -0,0 +1,480 @@
|
|||
"""工作流管理API"""
|
||||
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, and_, func
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ...db.database import get_session
|
||||
from ...schemas.workflow import (
|
||||
WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
|
||||
WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
|
||||
)
|
||||
from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
|
||||
from ...services.workflow_engine import get_workflow_engine
|
||||
from ...services.auth import AuthService
|
||||
from ...models.user import User
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def convert_workflow_for_response(workflow_dict):
    """Rename connection keys ``from_node``/``to_node`` to ``from``/``to``.

    The response schema expects ``from``/``to``; the model stores
    ``from_node``/``to_node``. Mutates ``workflow_dict`` in place and
    returns the same dict.
    """
    definition = workflow_dict.get('definition')
    connections = definition.get('connections') if definition else None
    if connections:
        for connection in connections:
            for old_key, new_key in (('from_node', 'from'), ('to_node', 'to')):
                if old_key in connection:
                    connection[new_key] = connection.pop(old_key)
    return workflow_dict
|
||||
|
||||
@router.post("/", response_model=WorkflowResponse)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""创建工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 创建工作流
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
definition=workflow_data.definition.model_dump(),
|
||||
version="1.0.0",
|
||||
status=workflow_data.status,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
workflow.set_audit_fields(current_user.id)
|
||||
|
||||
await session.add(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
|
||||
# 转换definition中的字段映射
|
||||
workflow_dict = convert_workflow_for_response(workflow.to_dict())
|
||||
|
||||
logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
|
||||
@router.get("/", response_model=WorkflowListResponse)
|
||||
async def list_workflows(
|
||||
skip: Optional[int] = Query(None, ge=0),
|
||||
limit: Optional[int] = Query(None, ge=1, le=100),
|
||||
workflow_status: Optional[WorkflowStatus] = None,
|
||||
search: Optional[str] = Query(None),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流列表"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 构建查询
|
||||
stmt = select(Workflow).where(Workflow.owner_id == current_user.id)
|
||||
|
||||
if workflow_status:
|
||||
stmt = stmt.where(Workflow.status == workflow_status)
|
||||
|
||||
# 添加搜索功能
|
||||
if search:
|
||||
stmt = stmt.where(Workflow.name.ilike(f"%{search}%"))
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count(Workflow.id)).where(Workflow.owner_id == current_user.id)
|
||||
if workflow_status:
|
||||
count_query = count_query.where(Workflow.status == workflow_status)
|
||||
if search:
|
||||
count_query = count_query.where(Workflow.name.ilike(f"%{search}%"))
|
||||
total = session.scalar(count_query)
|
||||
|
||||
# 如果没有传分页参数,返回所有数据
|
||||
if skip is None and limit is None:
|
||||
workflows = session.scalars(stmt).all()
|
||||
return WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=1,
|
||||
size=total
|
||||
)
|
||||
|
||||
# 使用默认分页参数
|
||||
if skip is None:
|
||||
skip = 0
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
# 分页查询
|
||||
workflows = session.scalars(stmt.offset(skip).limit(limit)).all()
|
||||
|
||||
return WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=skip // limit + 1, # 计算页码
|
||||
size=limit
|
||||
)
|
||||
|
||||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def get_workflow(
|
||||
workflow_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流详情"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def update_workflow(
|
||||
workflow_id: int,
|
||||
workflow_data: WorkflowUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""更新工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = workflow_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "definition" and value:
|
||||
# 如果value是Pydantic模型,转换为字典;如果已经是字典,直接使用
|
||||
if hasattr(value, 'dict'):
|
||||
setattr(workflow, field, value.dict())
|
||||
else:
|
||||
setattr(workflow, field, value)
|
||||
else:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
|
||||
logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}")
|
||||
async def delete_workflow(
|
||||
workflow_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
session.delete(workflow)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流删除成功"}
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/activate")
|
||||
async def activate_workflow(
|
||||
workflow_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""激活工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.PUBLISHED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流激活成功"}
|
||||
|
||||
@router.post("/{workflow_id}/deactivate")
|
||||
async def deactivate_workflow(
|
||||
workflow_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""停用工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.ARCHIVED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流停用成功"}
|
||||
|
||||
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
|
||||
async def execute_workflow(
|
||||
workflow_id: int,
|
||||
request: WorkflowExecuteRequest,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""执行工作流"""
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
if workflow.status != ModelWorkflowStatus.PUBLISHED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="工作流未激活,无法执行"
|
||||
)
|
||||
|
||||
# 获取工作流引擎并执行
|
||||
engine = get_workflow_engine()
|
||||
execution_result = await engine.execute_workflow(
|
||||
workflow=workflow,
|
||||
input_data=request.input_data,
|
||||
user_id=current_user.id,
|
||||
session=session
|
||||
)
|
||||
|
||||
logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
|
||||
return execution_result
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
|
||||
async def list_workflow_executions(
|
||||
workflow_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行历史"""
|
||||
try:
|
||||
from ...models.workflow import Workflow, WorkflowExecution
|
||||
|
||||
# 验证工作流所有权
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 获取执行历史
|
||||
executions = session.scalars(
|
||||
select(WorkflowExecution).where(
|
||||
WorkflowExecution.workflow_id == workflow_id
|
||||
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
|
||||
).all()
|
||||
|
||||
return [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取执行历史失败"
|
||||
)
|
||||
|
||||
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
|
||||
async def get_workflow_execution(
|
||||
execution_id: int,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行详情"""
|
||||
try:
|
||||
from ...models.workflow import WorkflowExecution, Workflow
|
||||
|
||||
execution = session.scalar(
|
||||
select(WorkflowExecution).join(
|
||||
Workflow, WorkflowExecution.workflow_id == Workflow.id
|
||||
).where(
|
||||
WorkflowExecution.id == execution_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="执行记录不存在"
|
||||
)
|
||||
|
||||
return WorkflowExecutionResponse.model_validate(execution)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow execution {execution_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取执行详情失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/execute-stream")
|
||||
async def execute_workflow_stream(
|
||||
workflow_id: int,
|
||||
request: WorkflowExecuteRequest,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""流式执行工作流,实时推送节点执行状态"""
|
||||
|
||||
async def generate_stream() -> AsyncGenerator[str, None]:
|
||||
workflow_engine = None
|
||||
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 验证工作流
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': '工作流不存在'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
if workflow.status != ModelWorkflowStatus.PUBLISHED:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': '工作流未激活,无法执行'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
|
||||
# 发送开始信号
|
||||
yield f"data: {json.dumps({'type': 'workflow_start', 'workflow_id': workflow_id, 'workflow_name': workflow.name, 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 获取工作流引擎
|
||||
workflow_engine = get_workflow_engine()
|
||||
|
||||
# 执行工作流(流式版本)
|
||||
async for step_data in workflow_engine.execute_workflow_stream(
|
||||
workflow=workflow,
|
||||
input_data=request.input_data,
|
||||
user_id=current_user.id,
|
||||
session=session
|
||||
):
|
||||
# 推送工作流步骤
|
||||
yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送完成信号
|
||||
yield f"data: {json.dumps({'type': 'workflow_complete', 'message': '工作流执行完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式工作流执行异常: {e}", exc_info=True)
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'工作流执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_stream(),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "text/event-stream",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": "*",
|
||||
"Access-Control-Allow-Methods": "*"
|
||||
}
|
||||
)
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
"""Main API router."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .endpoints import chat
|
||||
from .endpoints import auth
|
||||
from .endpoints import knowledge_base
|
||||
from .endpoints import smart_query
|
||||
from .endpoints import smart_chat
|
||||
from .endpoints import database_config
|
||||
from .endpoints import table_metadata
|
||||
|
||||
# # System management endpoints
|
||||
from .endpoints import roles
|
||||
from .endpoints import llm_configs
|
||||
from .endpoints import users
|
||||
|
||||
# # Workflow endpoints
|
||||
from .endpoints import workflow
|
||||
|
||||
# Create main API router
|
||||
# Create main API router. Sub-routers are mounted below; note that some
# routers (database_config, table_metadata, smart_query, smart_chat) carry
# their own prefixes internally, so none is given here.
router = APIRouter()

# Authentication (login / token) endpoints.
router.include_router(
    auth.router,
    prefix="/auth",
    tags=["身份验证"]
)
# User profile + admin user management.
router.include_router(
    users.router,
    prefix="/users",
    tags=["users"]
)
# Role administration (admin area).
router.include_router(
    roles.router,
    prefix="/admin",
    tags=["admin-roles"]
)
# LLM provider configuration (admin area, shares the /admin prefix).
router.include_router(
    llm_configs.router,
    prefix="/admin",
    tags=["admin-llm-configs"]
)
# Knowledge-base management.
router.include_router(
    knowledge_base.router,
    prefix="/knowledge-bases",
    tags=["knowledge-bases"]
)
# Database connection configuration (prefix defined inside the router).
router.include_router(
    database_config.router,
    tags=["database-config"]
)
# Table metadata browsing (prefix defined inside the router).
router.include_router(
    table_metadata.router,
    tags=["table-metadata"]
)
# Natural-language query endpoints (prefix defined inside the router).
router.include_router(
    smart_query.router,
    tags=["smart-query"]
)
# Plain chat endpoints.
router.include_router(
    chat.router,
    prefix="/chat",
    tags=["chat"]
)

# Smart chat endpoints (prefix defined inside the router).
router.include_router(
    smart_chat.router,
    tags=["smart-chat"]
)

# Workflow CRUD + execution endpoints.
router.include_router(
    workflow.router,
    prefix="/workflows",
    tags=["workflows"]
)
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Core module for TH Agenter."""
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
"""FastAPI application factory."""
|
||||
|
||||
from loguru import logger
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from .config import Settings
|
||||
from .middleware import UserContextMiddleware
|
||||
from ..api.routes import router
|
||||
from ..api.endpoints import table_metadata
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Code before ``yield`` runs at startup; code after runs at (clean)
    shutdown. Currently only logs both transitions.
    """
    logger.info("Starting up TH Agenter application...")
    yield
    logger.info("Shutting down TH Agenter application...")
|
||||
|
||||
|
||||
# def create_app(settings: Settings = None) -> FastAPI:
|
||||
# """Create and configure FastAPI application."""
|
||||
# if settings is None:
|
||||
# from .config import get_settings
|
||||
# settings = get_settings()
|
||||
|
||||
# # Create FastAPI app
|
||||
# app = FastAPI(
|
||||
# title=settings.app_name,
|
||||
# version=settings.app_version,
|
||||
# description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由DrGraph修改",
|
||||
# debug=settings.debug,
|
||||
# lifespan=lifespan,
|
||||
# )
|
||||
|
||||
# # Add middleware
|
||||
# setup_middleware(app, settings)
|
||||
|
||||
# # Add exception handlers
|
||||
# setup_exception_handlers(app)
|
||||
|
||||
# # Include routers
|
||||
# app.include_router(router, prefix="/api")
|
||||
|
||||
# app.include_router(table_metadata.router)
|
||||
# # 在现有导入中添加
|
||||
# from ..api.endpoints import database_config
|
||||
|
||||
# # 在路由注册部分添加
|
||||
# app.include_router(database_config.router)
|
||||
# # Health check endpoint
|
||||
# @app.get("/health")
|
||||
# async def health_check():
|
||||
# return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# # Root endpoint
|
||||
# @app.get("/")
|
||||
# async def root():
|
||||
# return {"message": "Chat Agent API is running"}
|
||||
# return app
|
||||
|
||||
|
||||
def setup_middleware(app: FastAPI, settings: Settings) -> None:
    """Setup application middleware.

    Registration order is significant — do not reorder these calls
    without reviewing the middleware stack.
    """

    # User context middleware (should be first to set context for all requests)
    # NOTE(review): Starlette's add_middleware prepends, so middleware added
    # LAST is outermost — confirm this registration order actually runs the
    # user-context middleware where intended.
    app.add_middleware(UserContextMiddleware)

    # CORS middleware; origins/methods/headers come from settings.cors.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors.allowed_origins,
        allow_credentials=True,
        allow_methods=settings.cors.allowed_methods,
        allow_headers=settings.cors.allowed_headers,
    )

    # Trusted host middleware (for production)
    if settings.environment == "production":
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["*"]  # Configure this properly in production
        )
|
||||
|
||||
|
||||
def setup_exception_handlers(app: FastAPI) -> None:
    """Setup global exception handlers.

    Installs three handlers on ``app``: HTTP errors (pass-through of
    status/detail), request-validation errors (422 with sanitized
    details), and a catch-all 500. All responses share the
    ``{"error": {...}}`` envelope.
    """

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        # Re-shape FastAPI/Starlette HTTP errors into the common envelope.
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "error": {
                    "type": "http_error",
                    "message": exc.detail,
                    "status_code": exc.status_code
                }
            }
        )

    def make_json_serializable(obj):
        """Recursively convert an object into a JSON-serializable form.

        Validation error details may embed exceptions, bytes, or arbitrary
        objects that ``JSONResponse`` cannot encode; everything unknown is
        stringified.
        """
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        elif isinstance(obj, bytes):
            return obj.decode('utf-8')
        elif isinstance(obj, (ValueError, Exception)):
            return str(obj)
        elif isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        else:
            # For any other object, convert to string
            return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Convert any non-serializable objects to strings in error details
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Fallback: if even our conversion fails, use a simple error message
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]

        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        # Last-resort handler: log the full traceback, hide details from clients.
        logger.error(f"Unhandled exception: {exc}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )
|
||||
|
||||
|
||||
# Create the app instance
|
||||
# app = create_app()
|
||||
|
|
@ -0,0 +1,453 @@
|
|||
"""Configuration management for TH Agenter."""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
class DatabaseSettings(BaseSettings):
    """Database configuration."""
    # SQLAlchemy connection URL; required — read from the DATABASE_URL env var.
    url: str = Field(..., alias="database_url")  # Must be provided via environment variable
    # Echo SQL statements to the log (SQLAlchemy engine echo flag).
    echo: bool = Field(default=False)
    # Connection-pool sizing: base pool size and allowed overflow connections.
    pool_size: int = Field(default=5)
    max_overflow: int = Field(default=10)

    # pydantic-settings options: load from .env, case-insensitive env names,
    # silently ignore unrelated keys.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class SecuritySettings(BaseSettings):
    """Security configuration."""
    # JWT signing key — the default is a placeholder and MUST be overridden
    # via environment variable in any real deployment.
    secret_key: str = Field(default="your-secret-key-here-change-in-production")
    # JWT signing algorithm.
    algorithm: str = Field(default="HS256")
    # Access-token lifetime in minutes (default 300 = 5 hours).
    access_token_expire_minutes: int = Field(default=300)

    # pydantic-settings options: load from .env, case-insensitive env names,
    # silently ignore unrelated keys.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class ToolSetings(BaseSettings):
    """External tool API-key configuration (Tavily search, weather).

    NOTE(review): the class name is misspelled ("Setings") — renaming would
    break existing importers, so it is left unchanged here.
    """
    # Tavily web-search API key (optional; feature disabled when unset).
    tavily_api_key: Optional[str] = Field(default=None)
    # Weather service API key (optional; feature disabled when unset).
    weather_api_key: Optional[str] = Field(default=None)
    # pydantic-settings options: load from .env, case-insensitive env names,
    # silently ignore unrelated keys.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class LLMSettings(BaseSettings):
    """LLM configuration — supports several OpenAI-protocol-compatible providers."""
    # Active provider key: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field(default="openai", alias="llm_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao (ByteDance Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    # NOTE(review): embedding model field also exists on EmbeddingSettings —
    # this duplicate may be unused here; confirm before removing.
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Generation knobs shared by all providers
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Return the active chat-model config, preferring the DB default.

        Tries the default chat-model row stored in the database first; on any
        failure (or when no default row exists) falls back to the env-var
        backed per-provider table below.

        Returns:
            dict with keys api_key, base_url, model, max_tokens, temperature.
        """
        try:
            # Imported lazily to avoid a circular import at module load time.
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_chat_config()

            if db_config:
                # A DB default exists — use it, keeping the local generation knobs.
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                return config
        except Exception as e:
            # DB unavailable or misconfigured — log and fall back to env vars.
            logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")

        # Env-var fallback: static per-provider credential/model table.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_model
            }
        }

        # Unknown provider names silently fall back to the OpenAI entry.
        config = provider_configs.get(self.provider, provider_configs["openai"])
        config.update({
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        })
        return config
|
||||
|
||||
class EmbeddingSettings(BaseSettings):
    """Embedding-model configuration — supports several providers."""
    # Active provider key: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field(default="zhipu", alias="embedding_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_embedding_model: str = Field(default="text-embedding-ada-002")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_embedding_model: str = Field(default="deepseek-embedding")

    # Doubao (ByteDance Ark)
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_embedding_model: str = Field(default="doubao-embedding")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_embedding_model: str = Field(default="moonshot-embedding")

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Return the active embedding config, preferring the DB default.

        Tries the default embedding-model row stored in the database first;
        on any failure (or when no default row exists) falls back to the
        env-var backed per-provider table below.

        Returns:
            dict with keys api_key, base_url, model.
        """
        try:
            # Imported lazily to avoid a circular import at module load time.
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_embedding_config()

            if db_config:
                # A DB default exists — use it directly.
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name
                }
                return config
        except Exception as e:
            # DB unavailable or misconfigured — log and fall back to env vars.
            logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")

        # Env-var fallback: static per-provider credential/model table.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_embedding_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_embedding_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_embedding_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_embedding_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_embedding_model
            }
        }

        # Unknown provider names silently fall back to the Zhipu entry.
        return provider_configs.get(self.provider, provider_configs["zhipu"])
|
||||
|
||||
class VectorDBSettings(BaseSettings):
    """Vector database configuration (chroma-style local store or pgvector)."""
    type: str = Field(default="pgvector", alias="vector_db_type")
    persist_directory: str = Field(default="./data/chroma")  # local store path
    collection_name: str = Field(default="documents")
    embedding_dimension: int = Field(default=2048)  # dimension of Zhipu AI's embedding-3 model

    # PostgreSQL pgvector configuration
    pgvector_host: str = Field(default="localhost")
    pgvector_port: int = Field(default=5432)
    pgvector_database: str = Field(default="vectordb")
    pgvector_user: str = Field(default="postgres")
    pgvector_password: str = Field(default="")
    pgvector_table_name: str = Field(default="embeddings")
    # NOTE(review): 1024 here vs embedding_dimension=2048 above — confirm
    # which one the pgvector schema actually uses.
    pgvector_vector_dimension: int = Field(default=1024)

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class FileSettings(BaseSettings):
    """File upload and chunking configuration."""
    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10MB
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    semantic_splitter_enabled: bool = Field(default=False)  # toggle for the semantic splitter

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Normalize a comma-separated string or a list into dotted extensions."""
        def dotted(ext):
            # Prefix a dot only when one is missing.
            return ext if ext.startswith('.') else f'.{ext}'

        if isinstance(v, str):
            # Comma-separated string: split, trim whitespace, ensure dots.
            return [dotted(part.strip()) for part in v.split(',')]
        if isinstance(v, list):
            # Already a list: just ensure every entry carries a dot.
            return [dotted(ext) for ext in v]
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Return the allowed extensions as a list (empty for unknown types)."""
        value = self.allowed_extensions
        if isinstance(value, list):
            return value
        if isinstance(value, str):
            trimmed = [part.strip() for part in value.split(',')]
            return [e if e.startswith('.') else f'.{e}' for e in trimmed]
        return []

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class StorageSettings(BaseSettings):
    """File storage backend configuration."""
    storage_type: str = Field(default="local")  # "local" or "s3"
    upload_directory: str = Field(default="./data/uploads")

    # S3 settings (used only when storage_type == "s3")
    s3_bucket_name: str = Field(default="chat-agent-files")
    aws_access_key_id: Optional[str] = Field(default=None)
    aws_secret_access_key: Optional[str] = Field(default=None)
    aws_region: str = Field(default="us-east-1")
    s3_endpoint_url: Optional[str] = Field(default=None)  # For S3-compatible services

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class CORSSettings(BaseSettings):
    """CORS configuration.

    NOTE(review): the "*" defaults allow every origin/header — acceptable in
    development, should be tightened for production.
    """
    allowed_origins: List[str] = Field(default=["*"])
    allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    allowed_headers: List[str] = Field(default=["*"])

    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class ChatSettings(BaseSettings):
    """Chat behavior configuration.

    Controls how many history turns are kept, the default system prompt,
    and the per-reply response-token cap.
    """
    max_history_length: int = Field(default=10)  # prior turns kept in context
    system_prompt: str = Field(default="你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)  # cap on generated tokens per reply

    # Consistency fix: every sibling settings class declares this
    # model_config so values load from .env with case-insensitive keys and
    # unrelated keys ignored; ChatSettings previously omitted it, so its
    # values could not be configured through the .env file like the rest.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||
|
||||
class Settings(BaseSettings):
    """Main application settings: top-level app/server fields plus one nested
    settings object per subsystem. Built either from env vars directly or via
    :meth:`load_from_yaml`, which merges a YAML file over the env defaults."""

    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server bind address/port
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections — each default_factory re-reads env/.env on its own.
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)
    # pydantic-settings options: read .env, case-insensitive, ignore extras.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Load settings from a YAML file, merged over env-var defaults.

        Lookup order: the given path, then ``<backend>/configs/settings.yaml``
        relative to this module, then env-only defaults. ``${VAR}`` strings in
        the YAML are substituted from the environment.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            # Directory of this module (e.g. backend/<pkg>/core)
            current_dir = Path(__file__).parent
            # Walk two levels up to the backend dir, then configs/settings.yaml
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                return cls()

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Substitute ${ENV_VAR} placeholders throughout the tree.
        config_data = cls._resolve_env_vars_nested(config_data)

        # Build each sub-settings instance explicitly so env/.env values are
        # still loaded for any section the YAML omits.
        settings_kwargs = {}

        settings_kwargs['database'] = DatabaseSettings(**(config_data.get('database', {})))
        settings_kwargs['security'] = SecuritySettings(**(config_data.get('security', {})))
        settings_kwargs['llm'] = LLMSettings(**(config_data.get('llm', {})))
        settings_kwargs['embedding'] = EmbeddingSettings(**(config_data.get('embedding', {})))
        settings_kwargs['vector_db'] = VectorDBSettings(**(config_data.get('vector_db', {})))
        settings_kwargs['file'] = FileSettings(**(config_data.get('file', {})))
        settings_kwargs['storage'] = StorageSettings(**(config_data.get('storage', {})))
        settings_kwargs['cors'] = CORSSettings(**(config_data.get('cors', {})))
        settings_kwargs['chat'] = ChatSettings(**(config_data.get('chat', {})))
        settings_kwargs['tool'] = ToolSetings(**(config_data.get('tool', {})))

        # Pass remaining top-level YAML keys (app_name, debug, ...) straight through.
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested dict into ``prefix_key`` form (recursive)."""
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively replace ``${VAR}`` string values with ``os.getenv(VAR)``.

        Unset variables keep the literal ``${VAR}`` string as their value.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        else:
            return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Shallow (one-level) version of :meth:`_resolve_env_vars_nested`.

        NOTE(review): not referenced within this module — possibly kept for
        external callers; confirm before removing.
        """
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide :class:`Settings` instance.

    The YAML/env load runs once; subsequent calls hit the lru_cache.
    """
    return Settings.load_from_yaml()
|
||||
|
||||
settings = get_settings()
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
HTTP请求上下文管理,如:获取当前登录用户信息及Token信息
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
import threading
|
||||
from ..models.user import User
|
||||
from loguru import logger
|
||||
|
||||
# Context variable to store current user
|
||||
current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None)
|
||||
|
||||
# Thread-local storage as backup
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
class UserContext:
    """Global access to the currently authenticated user.

    Primary storage is a ``ContextVar`` (async-safe, per-task); a
    thread-local copy is kept as a fallback for code paths running outside
    the original task's context.
    """

    @staticmethod
    def set_current_user(user: User) -> None:
        """Store *user* in both the ContextVar and the thread-local backup."""
        logger.info(f"[UserContext] - Setting user in context: {user.username} (ID: {user.id})")

        # Primary: ContextVar (the reset Token is discarded here — use
        # set_current_user_with_token when the caller needs to reset).
        current_user_context.set(user)

        # Backup: thread-local
        _thread_local.current_user = user

        # Read back immediately so the log shows whether the set took effect.
        verify_user = current_user_context.get()
        logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")

    @staticmethod
    def set_current_user_with_token(user: User):
        """Like :meth:`set_current_user`, but return the ContextVar Token so
        the caller can restore the previous value via
        :meth:`reset_current_user_token`."""
        logger.info(f"[UserContext] - Setting user in context with token: {user.username} (ID: {user.id})")

        # Primary: ContextVar — keep the reset token for the caller.
        token = current_user_context.set(user)

        # Backup: thread-local
        _thread_local.current_user = user

        # Read back immediately so the log shows whether the set took effect.
        verify_user = current_user_context.get()
        logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")

        return token

    @staticmethod
    def reset_current_user_token(token) -> None:
        """Restore the ContextVar to its pre-set value using *token* and drop
        the thread-local backup."""
        logger.info("[UserContext] - Resetting user context using token")

        # Restore the previous ContextVar value.
        current_user_context.reset(token)

        # Drop the thread-local backup too.
        if hasattr(_thread_local, 'current_user'):
            delattr(_thread_local, 'current_user')

    @staticmethod
    def get_current_user() -> Optional[User]:
        """Return the current user, checking the ContextVar first and the
        thread-local backup second; ``None`` when neither holds a user."""
        logger.debug("[UserContext] - Attempting to get user from context")

        # Primary lookup: ContextVar
        user = current_user_context.get()
        if user:
            logger.debug(f"[UserContext] - Got user from ContextVar: {user.username} (ID: {user.id})")
            return user

        # Fallback lookup: thread-local
        user = getattr(_thread_local, 'current_user', None)
        if user:
            logger.debug(f"[UserContext] - Got user from thread-local: {user.username} (ID: {user.id})")
            return user

        logger.debug("[UserContext] - No user found in context (neither ContextVar nor thread-local)")
        return None

    @staticmethod
    def get_current_user_id() -> Optional[int]:
        """Return the current user's id, or ``None`` when unauthenticated."""
        user = UserContext.get_current_user()
        return user.id if user else None

    @staticmethod
    def clear_current_user() -> None:
        """Clear the current user from both storages.

        NOTE(review): sets the ContextVar to ``None`` rather than resetting
        via a Token, so any outer value is overwritten, not restored.
        """
        logger.info("[UserContext] - 清除当前用户上下文")

        current_user_context.set(None)
        if hasattr(_thread_local, 'current_user'):
            delattr(_thread_local, 'current_user')

    @staticmethod
    def require_current_user() -> User:
        """Return the current user or raise HTTP 401 when none is bound."""
        # Same dual lookup (ContextVar then thread-local) as get_current_user.
        user = UserContext.get_current_user()
        if user is None:
            # Imported here to keep this module importable without FastAPI.
            from fastapi import HTTPException, status
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No authenticated user in context"
            )
        return user

    @staticmethod
    def require_current_user_id() -> int:
        """Return the current user's id or raise HTTP 401 when none is bound."""
        user = UserContext.require_current_user()
        return user.id
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""Custom exceptions for the application."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class BaseCustomException(Exception):
    """Root of the application's custom exception hierarchy.

    Carries a human-readable ``message`` plus an optional ``details``
    mapping with structured context about the failure.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.details = details or {}
|
||||
|
||||
|
||||
class NotFoundError(BaseCustomException):
    """Raised when a requested resource cannot be located."""
|
||||
|
||||
|
||||
class ValidationError(BaseCustomException):
    """Raised when input data fails validation."""
|
||||
|
||||
|
||||
class AuthenticationError(BaseCustomException):
    """Raised when a caller's identity cannot be verified."""
|
||||
|
||||
|
||||
class AuthorizationError(BaseCustomException):
    """Raised when an authenticated caller lacks permission."""
|
||||
|
||||
|
||||
class DatabaseError(BaseCustomException):
    """Raised when a database operation fails."""
|
||||
|
||||
|
||||
class ConfigurationError(BaseCustomException):
    """Raised when application configuration is invalid."""
|
||||
|
||||
|
||||
class ExternalServiceError(BaseCustomException):
    """Raised when a call to an external service fails."""
|
||||
|
||||
|
||||
class BusinessLogicError(BaseCustomException):
    """Raised when a business-rule check fails."""
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""LLM工厂类,用于创建和管理LLM实例"""
|
||||
|
||||
from typing import Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from .config import get_settings
|
||||
|
||||
def create_llm(model: Optional[str] = None, temperature: Optional[float] = None, streaming: bool = False) -> ChatOpenAI:
    """Create a configured ChatOpenAI client.

    Args:
        model: Optional explicit model name. Its prefix (deepseek / doubao /
            glm / moonshot) selects the matching provider's credentials and
            base URL, and the name itself is used as the model. When omitted,
            the default provider config from settings is used.
        temperature: Optional sampling-temperature override; falls back to
            the configured default when None.
        streaming: Whether to enable streaming responses (default False).

    Returns:
        A ChatOpenAI instance bound to the resolved provider config.
    """
    settings = get_settings()
    llm_config = settings.llm.get_current_config()

    if model:
        # Model-name prefix -> (api_key, base_url) of the provider serving it.
        credential_map = {
            'deepseek': (settings.llm.deepseek_api_key, settings.llm.deepseek_base_url),
            'doubao': (settings.llm.doubao_api_key, settings.llm.doubao_base_url),
            'glm': (settings.llm.zhipu_api_key, settings.llm.zhipu_base_url),
            'moonshot': (settings.llm.moonshot_api_key, settings.llm.moonshot_base_url),
        }
        for prefix, (api_key, base_url) in credential_map.items():
            if model.startswith(prefix):
                # Bug fix: honor the caller's requested model name. Previously
                # the provider's *default* model (e.g. deepseek_model) silently
                # replaced the explicit choice, making the `model` argument a
                # provider selector only.
                llm_config['model'] = model
                llm_config['api_key'] = api_key
                llm_config['base_url'] = base_url
                break
        # Unrecognized prefixes keep the current provider config unchanged,
        # matching the original behavior.

    return ChatOpenAI(
        model=llm_config['model'],
        api_key=llm_config['api_key'],
        base_url=llm_config['base_url'],
        temperature=temperature if temperature is not None else llm_config['temperature'],
        max_tokens=llm_config['max_tokens'],
        streaming=streaming
    )
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
"""
|
||||
中间件管理,如上下文中间件:校验Token等
|
||||
"""
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Callable
|
||||
from loguru import logger
|
||||
from fastapi import status
|
||||
from utils.util_exceptions import HxfErrorResponse
|
||||
|
||||
from ..db.database import get_session, AsyncSessionFactory, engine_async
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from ..services.auth import AuthService
|
||||
from .context import UserContext
|
||||
|
||||
class UserContextMiddleware(BaseHTTPMiddleware):
    """Middleware that authenticates each request via a Bearer token and
    binds the resulting user to :class:`UserContext` for the request's
    duration. Unauthenticated requests to non-excluded paths get a 401."""

    def __init__(self, app, exclude_paths: list = None):
        super().__init__(app)
        # Verbose per-request logging toggle (hard-coded on).
        self.canLog = True
        # Paths that don't require authentication
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/login-oauth",
            "/auth/login",
            "/auth/register",
            "/auth/login-oauth",
            "/health",
            "/static/"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and set user context if authenticated.

        Flow: exclusion check -> Bearer token extraction/verification ->
        user lookup in the DB -> bind user to context -> run the handler ->
        always clear the context afterwards.
        """
        if self.canLog:
            logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}")

        # Skip authentication for excluded paths
        path = request.url.path
        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")

        should_skip = False
        for exclude_path in self.exclude_paths:
            # Exact match
            if path == exclude_path:
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
                break
            # For paths ending with '/', check if request path starts with it
            elif exclude_path.endswith('/') and path.startswith(exclude_path):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
                break
            # For paths not ending with '/', check if request path starts with
            # it + '/' (prefix match on a whole path segment only)
            elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
                break

        if should_skip:
            if self.canLog:
                logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
            response = await call_next(request)
            return response

        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")

        # Always clear any existing user context to ensure fresh authentication
        UserContext.clear_current_user()

        # Initialize context token
        # NOTE(review): user_token is assigned below but never passed to
        # UserContext.reset_current_user_token; cleanup instead relies on the
        # clear_current_user() call in the final `finally` block.
        user_token = None

        # Try to extract and validate token
        try:
            # Get authorization header
            authorization = request.headers.get("Authorization")
            if not authorization or not authorization.startswith("Bearer "):
                # No token provided, return 401 error
                return HxfErrorResponse(
                    message="缺少或无效的授权头",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Extract token (text after "Bearer ")
            token = authorization.split(" ")[1]

            # Verify token signature/expiry
            payload = AuthService.verify_token(token)
            if payload is None:
                # Invalid token, return 401 error
                return HxfErrorResponse(
                    message="无效或过期的令牌",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get username from the token's "sub" claim
            username = payload.get("sub")
            if not username:
                return HxfErrorResponse(
                    message="令牌负载无效",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get user from database
            from sqlalchemy import select
            from ..models.user import User

            # Open a short-lived async session just for the user lookup
            session = AsyncSession(bind=engine_async)
            try:
                stmt = select(User).where(User.username == username)
                user = await session.execute(stmt)
                user = user.scalar_one_or_none()
                if not user:
                    return HxfErrorResponse(
                        message="用户不存在",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                if not user.is_active:
                    return HxfErrorResponse(
                        message="用户账户已停用",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                # Set user in context using the token-returning variant
                user_token = UserContext.set_current_user_with_token(user)
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")

                # Verify context is set correctly
                current_user_id = UserContext.get_current_user_id()
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文")
            finally:
                # Session is closed even when one of the 401 returns fires.
                await session.close()

        except Exception as e:
            # Any unexpected failure in the auth path is logged and mapped to 401.
            logger.error(f"[MIDDLEWARE] - 认证过程中设置用户上下文出错: {e}")
            # Return 401 error
            return HxfErrorResponse(
                message="认证过程中出错",
                status_code=status.HTTP_401_UNAUTHORIZED
            )

        # Continue with request
        try:
            response = await call_next(request)
            return response
        finally:
            # Always clear user context after request processing
            UserContext.clear_current_user()
            if self.canLog:
                logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""简化的权限检查系统."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.database import get_session
|
||||
from ..models.user import User
|
||||
from ..models.permission import Role
|
||||
from ..services.auth import AuthService
|
||||
|
||||
|
||||
def is_super_admin(user: User, db: Session) -> bool:
    """Return True when *user* is active and holds the SUPER_ADMIN role.

    First inspects the roles already loaded on the instance; if that access
    raises (roles not eagerly loaded), falls back to querying the role
    assignment directly from the database.
    """
    if not user or not user.is_active:
        return False

    try:
        # Fast path: roles already attached to the user object.
        return any(role.code == "SUPER_ADMIN" for role in user.roles)
    except Exception:
        # Roles unavailable on the instance — query the DB directly.
        from sqlalchemy import select, and_
        from ..models.permission import Role, UserRole

        try:
            stmt = select(Role).join(UserRole).filter(
                and_(
                    UserRole.user_id == user.id,
                    Role.code == "SUPER_ADMIN",
                    Role.is_active == True
                )
            )
            return db.execute(stmt).scalar_one_or_none() is not None
        except Exception:
            # Query failed as well — deny by default.
            return False
|
||||
|
||||
|
||||
def require_super_admin(
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
) -> User:
    """FastAPI dependency: pass the user through only if they are a super
    admin; otherwise reject with HTTP 403."""
    if is_super_admin(current_user, session):
        return current_user
    raise HTTPException(
        status_code=403,
        detail="需要超级管理员权限"
    )
|
||||
|
||||
|
||||
def require_authenticated_user(
    current_user: User = Depends(AuthService.get_current_user)
) -> User:
    """FastAPI dependency: pass the user through only if present and active;
    otherwise reject with HTTP 401."""
    if current_user and current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=401,
        detail="需要登录"
    )
|
||||
|
||||
|
||||
class SimplePermissionChecker:
    """Lightweight permission checks bound to one database session."""

    def __init__(self, db: Session):
        self.db = db

    def check_super_admin(self, user: User) -> bool:
        """True when *user* holds the SUPER_ADMIN role."""
        return is_super_admin(user, self.db)

    def check_user_access(self, user: User, target_user_id: int) -> bool:
        """True when *user* may access *target_user_id*'s data.

        Super admins may access every user; everyone else only themselves.
        """
        if not user or not user.is_active:
            return False
        return self.check_super_admin(user) or user.id == target_user_id
|
||||
|
||||
|
||||
# 权限装饰器
|
||||
def super_admin_required(func):
    """Marker decorator for service-layer functions needing super-admin rights.

    Enforcement happens in the route layer via FastAPI dependencies; this
    wrapper simply forwards the call unchanged while preserving metadata.
    """
    @wraps(func)
    def _delegate(*args, **kwargs):
        return func(*args, **kwargs)
    return _delegate
|
||||
|
||||
|
||||
def authenticated_required(func):
    """Marker decorator for service-layer functions needing authentication.

    Enforcement is done at the route layer through FastAPI dependencies;
    this wrapper simply forwards the call, keeping the original metadata.
    """
    @wraps(func)
    def wrapped(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapped
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
"""User utility functions for easy access to current user context."""
|
||||
|
||||
from typing import Optional
|
||||
from ..models.user import User
|
||||
from .context import UserContext
|
||||
|
||||
|
||||
def get_current_user() -> Optional[User]:
    """Return the authenticated user stored in the request context.

    Returns:
        The current ``User`` when one is bound to the context, else None.
    """
    return UserContext.get_current_user()
|
||||
|
||||
|
||||
def get_current_user_id() -> Optional[int]:
    """Return the authenticated user's id from the request context.

    Returns:
        The current user's id when authenticated, else None.
    """
    return UserContext.get_current_user_id()
|
||||
|
||||
|
||||
def require_current_user() -> User:
    """Return the authenticated user from context, failing loudly if absent.

    Returns:
        The current ``User``.

    Raises:
        HTTPException: When no authenticated user is bound to the context.
    """
    return UserContext.require_current_user()
|
||||
|
||||
|
||||
def require_current_user_id() -> int:
    """Return the authenticated user's id from context, failing loudly if absent.

    Returns:
        The current user's id.

    Raises:
        HTTPException: When no authenticated user is bound to the context.
    """
    return UserContext.require_current_user_id()
|
||||
|
||||
|
||||
def is_user_authenticated() -> bool:
    """Report whether the current context carries an authenticated user.

    Returns:
        True when a user is bound to the context, False otherwise.
    """
    current = UserContext.get_current_user()
    return current is not None
|
||||
|
||||
|
||||
def get_current_username() -> Optional[str]:
    """Return the authenticated user's username, or None when anonymous."""
    current = UserContext.get_current_user()
    if not current:
        return None
    return current.username
|
||||
|
||||
|
||||
def get_current_user_email() -> Optional[str]:
    """Return the authenticated user's email, or None when anonymous."""
    current = UserContext.get_current_user()
    if not current:
        return None
    return current.email
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""Database module for TH Agenter."""

from .database import get_session
from .base import Base
# Imported for their side effect: presumably this registers every ORM model
# on Base.metadata so create_all / migrations see them — TODO confirm.
from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata


# Only the session factory and declarative base are part of the public API.
__all__ = ["get_session", "Base"]
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
"""Database base model."""
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Integer, DateTime, event
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
|
||||
from sqlalchemy.sql import func
|
||||
from typing import Optional
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base with a deterministic constraint naming convention.

    Explicit names keep index/constraint identifiers stable across databases
    and make Alembic autogenerate diffs predictable.
    """
    metadata = MetaData(
        naming_convention={
            # ix: index
            "ix": "ix_%(column_0_label)s",
            # uq: unique constraint
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            # ck: check constraint
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            # fk: foreign key constraint
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            # pk: primary key constraint
            "pk": "pk_%(table_name)s"
        }
    )
|
||||
|
||||
|
||||
class BaseModel(Base):
    """Abstract base providing primary key, timestamps and audit columns."""

    __abstract__ = True

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
    created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    def to_dict(self):
        """Return a dict mapping every mapped column name to its value."""
        result = {}
        for col in self.__table__.columns:
            result[col.name] = getattr(self, col.name)
        return result

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from *data*, silently ignoring unknown keys.

        Args:
            data: Mapping of field names to values.

        Returns:
            A new model instance populated from the recognised keys.
        """
        known = {col.name for col in cls.__table__.columns}
        kwargs = {key: value for key, value in data.items() if key in known}
        return cls(**kwargs)

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
        """Populate the audit columns for a create or update operation.

        Args:
            user_id: Acting user's id; resolved from UserContext when omitted.
            is_update: True to touch only ``updated_by``; False sets both
                ``created_by`` and ``updated_by``.
        """
        if user_id is None:
            from ..core.context import UserContext
            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # No user bound to the context — leave audit fields untouched.
                return
        if user_id is None:
            # Still nobody to attribute the change to.
            return
        self.updated_by = user_id
        if not is_update:
            self.created_by = user_id
|
||||
|
||||
# @event.listens_for(Session, 'before_flush')
|
||||
# def set_audit_fields_before_flush(session, flush_context, instances):
|
||||
# """Automatically set audit fields before flush."""
|
||||
# try:
|
||||
# from th_agenter.core.context import UserContext
|
||||
# user_id = UserContext.get_current_user_id()
|
||||
# except Exception:
|
||||
# user_id = None
|
||||
|
||||
# # 处理新增对象
|
||||
# for instance in session.new:
|
||||
# if isinstance(instance, BaseModel) and user_id:
|
||||
# instance.created_by = user_id
|
||||
# instance.updated_by = user_id
|
||||
|
||||
# # 处理修改对象
|
||||
# for instance in session.dirty:
|
||||
# if isinstance(instance, BaseModel) and user_id:
|
||||
# instance.updated_by = user_id
|
||||
|
||||
# # def __init__(self, **kwargs):
|
||||
# # """Initialize model with automatic audit fields setting."""
|
||||
# # super().__init__(**kwargs)
|
||||
# # # Set audit fields for new instances
|
||||
# # self.set_audit_fields()
|
||||
|
||||
# # def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||||
# # """Set audit fields for create/update operations.
|
||||
|
||||
# # Args:
|
||||
# # user_id: ID of the user performing the operation (optional, will use context if not provided)
|
||||
# # is_update: True for update operations, False for create operations
|
||||
# # """
|
||||
# # # Get user_id from context if not provided
|
||||
# # if user_id is None:
|
||||
# # from ..core.context import UserContext
|
||||
# # try:
|
||||
# # user_id = UserContext.get_current_user_id()
|
||||
# # except Exception:
|
||||
# # # If no user in context, skip setting audit fields
|
||||
# # return
|
||||
|
||||
# # # Skip if still no user_id
|
||||
# # if user_id is None:
|
||||
# # return
|
||||
|
||||
# # if not is_update:
|
||||
# # # For create operations, set both create_by and update_by
|
||||
# # self.created_by = user_id
|
||||
# # self.updated_by = user_id
|
||||
# # else:
|
||||
# # # For update operations, only set update_by
|
||||
# # self.updated_by = user_id
|
||||
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""Database connection and session management."""
|
||||
|
||||
import uuid, re
|
||||
from loguru import logger
|
||||
import traceback
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import Optional
|
||||
|
||||
from ..core.config import get_settings
|
||||
from .base import Base
|
||||
from utils.util_exceptions import DatabaseError
|
||||
|
||||
# Custom Session class with desc property and unique ID
|
||||
class DrSession(AsyncSession):
    """AsyncSession with a short unique id, a step counter and a work brief
    (``desc``) used to tag log lines with their originating session."""

    def __init__(self, **kwargs):
        """Initialize DrSession with a unique (8-char) session id."""
        super().__init__(**kwargs)
        # Make sure the info mapping exists before storing our own keys.
        if not hasattr(self, 'info'):
            self.info = {}
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        self.stepIndex = 0

    @property
    def desc(self) -> Optional[str]:
        """Get the current work brief from session info."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Set the work brief and advance the step counter.

        Bug fix: the previous setter only incremented ``stepIndex`` and
        discarded *value*, so the getter always returned None even though
        callers assign ``session.desc`` to track progress.
        """
        self.info['desc'] = value
        self.stepIndex += 1

    def log_prefix(self) -> str:
        """Return the log prefix carrying the session id."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int):
        """Best-effort ``file:line in func`` for the stack frame at *level*."""
        pos = (traceback.format_stack())[level - 1].strip().split('\n')[0]
        match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos)
        if match:
            # NOTE(review): hardcoded developer path prefix — consider making
            # this configurable instead of baked in.
            file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    def log_info(self, msg: str, level: int = -2):
        """Log info message with session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log success message with session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log warning message with session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log error message with session ID and caller position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log exception (with traceback) including session ID and position."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
# Async engine configured from application settings; pre-ping plus hourly
# recycle guard against stale pooled connections.
engine_async = create_async_engine(
    get_settings().database.url,
    echo=True, # get_settings().database.echo,  NOTE(review): SQL echo is forced on — restore the settings flag before production
    future=True,
    pool_size=get_settings().database.pool_size,
    max_overflow=get_settings().database.max_overflow,
    pool_pre_ping=True,
    pool_recycle=3600,
)
# NOTE(review): this import belongs at the top of the module.
from fastapi import Request

# Session factory producing DrSession instances bound to the async engine.
AsyncSessionFactory = sessionmaker(
    bind=engine_async,
    class_=DrSession,
    expire_on_commit=False,
    autoflush=True
)
|
||||
|
||||
async def get_session(request: Request = None):
    """Yield a ``DrSession`` scoped to one request (FastAPI dependency).

    Rolls back on any exception and always closes the session afterwards.

    Args:
        request: Optional FastAPI request; used only to enrich the session's
            work brief and to attach the request onto the session.

    Yields:
        A ``DrSession`` produced by the configured factory.
    """
    url = "无request"
    if request:
        url = f"{request.method} {request.url.path}"

    # Use the configured factory so expire_on_commit/autoflush settings apply
    # (the previous code instantiated DrSession directly, bypassing them).
    session = AsyncSessionFactory()

    session.desc = f"SUCCESS: 创建数据库 session >>> {url}"

    # Attach the originating request for downstream consumers.
    if request:
        session.request = request

    try:
        yield session
    except Exception as e:
        errMsg = f"数据库 session 异常 >>> {e}"
        session.desc = f"EXCEPTION: {errMsg}"
        await session.rollback()
        # Bare raise preserves the original traceback.
        raise
    finally:
        session.desc = "数据库 session 关闭"
        await session.close()
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
"""Add system management tables.
|
||||
|
||||
Revision ID: add_system_management
|
||||
Revises:
|
||||
Create Date: 2024-01-01 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic_sync import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_system_management'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Create system management tables.

    Creates departments, permissions, roles, the role_permissions /
    user_roles / user_permissions association tables and llm_configs,
    then extends users with department and login-tracking columns.
    """

    # Create departments table (self-referencing parent_id for hierarchy)
    op.create_table('departments',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('parent_id', sa.Integer(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['parent_id'], ['departments.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
    op.create_index(op.f('ix_departments_parent_id'), 'departments', ['parent_id'], unique=False)

    # Create permissions table
    op.create_table('permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=100), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_permissions_category'), 'permissions', ['category'], unique=False)
    op.create_index(op.f('ix_permissions_name'), 'permissions', ['name'], unique=False)

    # Create roles table
    op.create_table('roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False)

    # Create role_permissions table (role <-> permission association)
    op.create_table('role_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('role_id', 'permission_id', name='uq_role_permission')
    )
    op.create_index(op.f('ix_role_permissions_permission_id'), 'role_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_role_permissions_role_id'), 'role_permissions', ['role_id'], unique=False)

    # Create user_roles table (user <-> role association)
    op.create_table('user_roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'role_id', name='uq_user_role')
    )
    op.create_index(op.f('ix_user_roles_role_id'), 'user_roles', ['role_id'], unique=False)
    op.create_index(op.f('ix_user_roles_user_id'), 'user_roles', ['user_id'], unique=False)

    # Create user_permissions table (direct per-user grants/revocations)
    op.create_table('user_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('granted', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'permission_id', name='uq_user_permission')
    )
    op.create_index(op.f('ix_user_permissions_permission_id'), 'user_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_user_permissions_user_id'), 'user_permissions', ['user_id'], unique=False)

    # Create llm_configs table (LLM provider/model connection settings)
    op.create_table('llm_configs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('provider', sa.String(length=50), nullable=False),
        sa.Column('model_name', sa.String(length=100), nullable=False),
        sa.Column('api_key', sa.Text(), nullable=True),
        sa.Column('api_base', sa.String(length=500), nullable=True),
        sa.Column('api_version', sa.String(length=20), nullable=True),
        sa.Column('max_tokens', sa.Integer(), nullable=True),
        sa.Column('temperature', sa.Float(), nullable=True),
        sa.Column('top_p', sa.Float(), nullable=True),
        sa.Column('frequency_penalty', sa.Float(), nullable=True),
        sa.Column('presence_penalty', sa.Float(), nullable=True),
        sa.Column('timeout', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('is_default', sa.Boolean(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)

    # Add new columns to users table
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
    op.add_column('users', sa.Column('is_superuser', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('is_admin', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('last_login_at', sa.DateTime(), nullable=True))
    op.add_column('users', sa.Column('login_count', sa.Integer(), nullable=True, default=0))

    # Create foreign key constraint for department_id
    op.create_foreign_key('fk_users_department_id', 'users', 'departments', ['department_id'], ['id'])
    op.create_index(op.f('ix_users_department_id'), 'users', ['department_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade():
    """Drop system management tables.

    Reverses upgrade() in dependency order: the users alterations first,
    then each association table before the tables it references.
    """

    # Drop foreign key and index for users.department_id
    op.drop_index(op.f('ix_users_department_id'), table_name='users')
    op.drop_constraint('fk_users_department_id', 'users', type_='foreignkey')

    # Drop new columns from users table
    op.drop_column('users', 'login_count')
    op.drop_column('users', 'last_login_at')
    op.drop_column('users', 'is_admin')
    op.drop_column('users', 'is_superuser')
    op.drop_column('users', 'department_id')

    # Drop llm_configs table
    op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_table('llm_configs')

    # Drop user_permissions table
    op.drop_index(op.f('ix_user_permissions_user_id'), table_name='user_permissions')
    op.drop_index(op.f('ix_user_permissions_permission_id'), table_name='user_permissions')
    op.drop_table('user_permissions')

    # Drop user_roles table
    op.drop_index(op.f('ix_user_roles_user_id'), table_name='user_roles')
    op.drop_index(op.f('ix_user_roles_role_id'), table_name='user_roles')
    op.drop_table('user_roles')

    # Drop role_permissions table
    op.drop_index(op.f('ix_role_permissions_role_id'), table_name='role_permissions')
    op.drop_index(op.f('ix_role_permissions_permission_id'), table_name='role_permissions')
    op.drop_table('role_permissions')

    # Drop roles table
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_table('roles')

    # Drop permissions table
    op.drop_index(op.f('ix_permissions_name'), table_name='permissions')
    op.drop_index(op.f('ix_permissions_category'), table_name='permissions')
    op.drop_table('permissions')

    # Drop departments table
    op.drop_index(op.f('ix_departments_parent_id'), table_name='departments')
    op.drop_index(op.f('ix_departments_name'), table_name='departments')
    op.drop_table('departments')
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
"""Add user_department association table migration."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
from th_agenter.core.config import get_settings
|
||||
|
||||
async def create_user_department_table():
    """Create the user_departments association table and its indexes.

    Connects directly via asyncpg using the configured database URL.
    Idempotent: every statement uses CREATE ... IF NOT EXISTS.

    Raises:
        Exception: Re-raised after logging if table creation fails.
    """
    # Local import keeps this script's top-level sys.path bootstrap intact.
    from urllib.parse import urlsplit, unquote

    settings = get_settings()
    database_url = settings.database.url

    print(f"Database URL: {database_url}")

    conn = None
    try:
        # Parse the URL with the stdlib instead of naive string splitting so
        # passwords containing ':' / '@' and percent-encoding are handled.
        parts = urlsplit(database_url)
        user = unquote(parts.username) if parts.username else ''
        password = unquote(parts.password) if parts.password else ''
        host = parts.hostname
        port = parts.port or 5432
        db_name = parts.path.lstrip('/') or 'postgres'

        # 连接PostgreSQL数据库
        conn = await asyncpg.connect(
            user=user,
            password=password,
            database=db_name,
            host=host,
            port=port
        )

        # 创建user_departments表
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS user_departments (
            id SERIAL PRIMARY KEY,
            user_id INTEGER NOT NULL,
            department_id INTEGER NOT NULL,
            is_primary BOOLEAN NOT NULL DEFAULT true,
            is_active BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
            FOREIGN KEY (department_id) REFERENCES departments (id) ON DELETE CASCADE
        );
        """

        await conn.execute(create_table_sql)

        # 创建索引 (lookup indexes plus a uniqueness guard on the pair)
        create_indexes_sql = [
            "CREATE INDEX IF NOT EXISTS idx_user_departments_user_id ON user_departments (user_id);",
            "CREATE INDEX IF NOT EXISTS idx_user_departments_department_id ON user_departments (department_id);",
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_user_departments_unique ON user_departments (user_id, department_id);"
        ]

        for index_sql in create_indexes_sql:
            await conn.execute(index_sql)

        print("User departments table created successfully")

    except Exception as e:
        print(f"Error creating user departments table: {e}")
        raise
    finally:
        # Sentinel check replaces the fragile "'conn' in locals()" test.
        if conn is not None:
            await conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the migration as a standalone script.
    asyncio.run(create_user_department_table())
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Migration script to move hardcoded resources to database."""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from th_agenter.core.config import settings
|
||||
from th_agenter.db.database import Base, get_session
|
||||
from th_agenter.models import * # Import all models to ensure they're registered
|
||||
from th_agenter.utils.logger import get_logger
|
||||
from th_agenter.models.resource import Resource
|
||||
from th_agenter.models.permission import Role
|
||||
from th_agenter.models.resource import RoleResource
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def _assign_all_resources_to_admin(db, resources):
    """Link every resource in *resources* to the "系统管理员" (system admin) role.

    Existing links are skipped so the operation stays idempotent.  Returns the
    number of newly created links, or None when the admin role does not exist
    (a warning is logged in that case).
    """
    admin_role = db.query(Role).filter(Role.name == "系统管理员").first()
    if not admin_role:
        logger.warning("未找到系统管理员角色")
        return None

    created = 0
    for resource in resources:
        existing = db.query(RoleResource).filter(
            RoleResource.role_id == admin_role.id,
            RoleResource.resource_id == resource.id,
        ).first()
        if not existing:
            db.add(RoleResource(role_id=admin_role.id, resource_id=resource.id))
            created += 1
    return created


def _create_resources(db, items, created_resources, label):
    """Insert a Resource row per dict in *items* and record it by code.

    ``flush()`` is called after each add so the generated id is available
    immediately — later items reference earlier ones via ``parent_id``.
    """
    for data in items:
        resource = Resource(**data)
        db.add(resource)
        db.flush()  # populate resource.id before children reference it
        created_resources[data["code"]] = resource
        logger.info(f"Created {label} resource: {data['name']}")


def migrate_hardcoded_resources():
    """Migrate hardcoded resources from init_resource_data.py to database.

    Idempotent: when resources already exist, only tops up the role-resource
    links for the system admin role.  Returns True on success, False on
    failure (the session is rolled back on error and always closed).
    """
    db = None
    try:
        db = get_session()
        if db is None:
            logger.error("Failed to create database session")
            return False

        # Create all tables if they don't exist yet.
        from th_agenter.db.database import engine as global_engine
        if global_engine:
            Base.metadata.create_all(bind=global_engine)

        logger.info("Starting hardcoded resources migration...")

        # Already migrated?  Then only reconcile admin role assignments.
        existing_count = db.query(Resource).count()
        if existing_count > 0:
            logger.info(f"Found {existing_count} existing resources. Checking role assignments.")
            assigned = _assign_all_resources_to_admin(db, db.query(Resource).all())
            if assigned:
                db.commit()
                logger.info(f"已为系统管理员角色分配 {assigned} 个新资源")
            elif assigned == 0:
                logger.info("系统管理员角色已拥有所有资源")
            return True

        created_resources = {}

        # ---- Top-level menu entries -------------------------------------
        main_menu_data = [
            {"name": "智能问答", "code": "CHAT", "type": "menu", "path": "/chat",
             "component": "views/Chat.vue", "icon": "ChatDotRound",
             "description": "智能问答功能", "sort_order": 1,
             "requires_auth": True, "requires_admin": False},
            {"name": "智能问数", "code": "SMART_QUERY", "type": "menu", "path": "/smart-query",
             "component": "views/SmartQuery.vue", "icon": "DataAnalysis",
             "description": "智能问数功能", "sort_order": 2,
             "requires_auth": True, "requires_admin": False},
            {"name": "知识库", "code": "KNOWLEDGE", "type": "menu", "path": "/knowledge",
             "component": "views/KnowledgeBase.vue", "icon": "Collection",
             "description": "知识库管理", "sort_order": 3,
             "requires_auth": True, "requires_admin": False},
            {"name": "工作流编排", "code": "WORKFLOW", "type": "menu", "path": "/workflow",
             "component": "views/Workflow.vue", "icon": "Connection",
             "description": "工作流编排功能", "sort_order": 4,
             "requires_auth": True, "requires_admin": False},
            {"name": "智能体管理", "code": "AGENT", "type": "menu", "path": "/agent",
             "component": "views/Agent.vue", "icon": "User",
             "description": "智能体管理功能", "sort_order": 5,
             "requires_auth": True, "requires_admin": False},
            {"name": "系统管理", "code": "SYSTEM", "type": "menu", "path": "/system",
             "component": "views/SystemManagement.vue", "icon": "Setting",
             "description": "系统管理功能", "sort_order": 6,
             "requires_auth": True, "requires_admin": True},
        ]
        _create_resources(db, main_menu_data, created_resources, "main menu")

        # ---- System-management submenu (needs SYSTEM id from above) -----
        system_parent_id = created_resources["SYSTEM"].id
        system_submenu_data = [
            {"name": "用户管理", "code": "SYSTEM_USERS", "type": "menu", "path": "/system/users",
             "component": "components/system/UserManagement.vue", "icon": "User",
             "description": "用户管理功能", "parent_id": system_parent_id,
             "sort_order": 1, "requires_auth": True, "requires_admin": True},
            {"name": "部门管理", "code": "SYSTEM_DEPARTMENTS", "type": "menu", "path": "/system/departments",
             "component": "components/system/DepartmentManagement.vue", "icon": "OfficeBuilding",
             "description": "部门管理功能", "parent_id": system_parent_id,
             "sort_order": 2, "requires_auth": True, "requires_admin": True},
            {"name": "角色管理", "code": "SYSTEM_ROLES", "type": "menu", "path": "/system/roles",
             "component": "components/system/RoleManagement.vue", "icon": "Avatar",
             "description": "角色管理功能", "parent_id": system_parent_id,
             "sort_order": 3, "requires_auth": True, "requires_admin": True},
            {"name": "权限管理", "code": "SYSTEM_PERMISSIONS", "type": "menu", "path": "/system/permissions",
             "component": "components/system/PermissionManagement.vue", "icon": "Lock",
             "description": "权限管理功能", "parent_id": system_parent_id,
             "sort_order": 4, "requires_auth": True, "requires_admin": True},
            {"name": "资源管理", "code": "SYSTEM_RESOURCES", "type": "menu", "path": "/system/resources",
             "component": "components/system/ResourceManagement.vue", "icon": "Grid",
             "description": "资源管理功能", "parent_id": system_parent_id,
             "sort_order": 5, "requires_auth": True, "requires_admin": True},
            {"name": "大模型管理", "code": "SYSTEM_LLM_CONFIGS", "type": "menu", "path": "/system/llm-configs",
             "component": "components/system/LLMConfigManagement.vue", "icon": "Cpu",
             "description": "大模型配置管理", "parent_id": system_parent_id,
             "sort_order": 6, "requires_auth": True, "requires_admin": True},
        ]
        _create_resources(db, system_submenu_data, created_resources, "system submenu")

        # ---- Button-level resources -------------------------------------
        button_resources_data = [
            {"name": "新增用户", "code": "USER_CREATE_BTN", "type": "button",
             "description": "新增用户按钮", "parent_id": created_resources["SYSTEM_USERS"].id,
             "sort_order": 1, "requires_auth": True, "requires_admin": True},
            {"name": "编辑用户", "code": "USER_EDIT_BTN", "type": "button",
             "description": "编辑用户按钮", "parent_id": created_resources["SYSTEM_USERS"].id,
             "sort_order": 2, "requires_auth": True, "requires_admin": True},
            {"name": "新增角色", "code": "ROLE_CREATE_BTN", "type": "button",
             "description": "新增角色按钮", "parent_id": created_resources["SYSTEM_ROLES"].id,
             "sort_order": 1, "requires_auth": True, "requires_admin": True},
            {"name": "编辑角色", "code": "ROLE_EDIT_BTN", "type": "button",
             "description": "编辑角色按钮", "parent_id": created_resources["SYSTEM_ROLES"].id,
             "sort_order": 2, "requires_auth": True, "requires_admin": True},
            {"name": "新增权限", "code": "PERMISSION_CREATE_BTN", "type": "button",
             "description": "新增权限按钮", "parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
             "sort_order": 1, "requires_auth": True, "requires_admin": True},
            {"name": "编辑权限", "code": "PERMISSION_EDIT_BTN", "type": "button",
             "description": "编辑权限按钮", "parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
             "sort_order": 2, "requires_auth": True, "requires_admin": True},
        ]
        _create_resources(db, button_resources_data, created_resources, "button")

        # ---- API resources ----------------------------------------------
        api_resources_data = [
            {"name": "用户列表API", "code": "USER_LIST_API", "type": "api", "path": "/api/users",
             "description": "获取用户列表API", "sort_order": 1,
             "requires_auth": True, "requires_admin": True},
            {"name": "创建用户API", "code": "USER_CREATE_API", "type": "api", "path": "/api/users",
             "description": "创建用户API", "sort_order": 2,
             "requires_auth": True, "requires_admin": True},
            {"name": "角色列表API", "code": "ROLE_LIST_API", "type": "api", "path": "/api/admin/roles",
             "description": "获取角色列表API", "sort_order": 5,
             "requires_auth": True, "requires_admin": True},
            {"name": "创建角色API", "code": "ROLE_CREATE_API", "type": "api", "path": "/api/admin/roles",
             "description": "创建角色API", "sort_order": 6,
             "requires_auth": True, "requires_admin": True},
            {"name": "资源列表API", "code": "RESOURCE_LIST_API", "type": "api", "path": "/api/admin/resources",
             "description": "获取资源列表API", "sort_order": 10,
             "requires_auth": True, "requires_admin": True},
            {"name": "创建资源API", "code": "RESOURCE_CREATE_API", "type": "api", "path": "/api/admin/resources",
             "description": "创建资源API", "sort_order": 11,
             "requires_auth": True, "requires_admin": True},
        ]
        _create_resources(db, api_resources_data, created_resources, "API")

        # ---- Grant everything to the system admin role ------------------
        all_resources = list(created_resources.values())
        if _assign_all_resources_to_admin(db, all_resources) is not None:
            logger.info(f"已为系统管理员角色分配 {len(all_resources)} 个资源")

        db.commit()

        total_resources = db.query(Resource).count()
        logger.info(f"Migration completed successfully. Total resources: {total_resources}")
        return True

    except Exception as e:
        logger.error(f"Migration failed: {str(e)}")
        if db:
            db.rollback()
        return False
    finally:
        if db:
            db.close()


if __name__ == "__main__":
    # BUG FIX: the file previously ended with a bare call to an undefined
    # ``main()``; run the actual migration and report success via exit code.
    sys.exit(0 if migrate_hardcoded_resources() else 1)
|
@ -0,0 +1,146 @@
|
|||
"""删除权限相关表的迁移脚本
|
||||
|
||||
Revision ID: remove_permission_tables
|
||||
Revises: add_system_management
|
||||
Create Date: 2024-01-25 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic_sync import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'remove_permission_tables'
|
||||
down_revision = 'add_system_management'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Drop the permission-related tables and columns (best-effort).

    Each table drop is attempted independently: a failure is printed and the
    loop continues, so one broken table does not abort the whole migration.
    """
    connection = op.get_bind()

    # Drop order follows foreign-key dependencies: association tables first,
    # then the tables they reference.
    tables_to_drop = [
        'user_permissions',      # user <-> permission association
        'role_permissions',      # role <-> permission association
        'permission_resources',  # permission <-> resource association
        'permissions',           # permission table
        'role_resources',        # role <-> resource association
        'resources',             # resource table
        'user_departments',      # user <-> department association
        'departments',           # department table
    ]

    # Bound parameter instead of f-string interpolation into the SQL text:
    # identical behaviour for these constant names, but injection-safe and
    # free of quoting pitfalls.
    exists_stmt = text(
        "SELECT EXISTS ("
        " SELECT FROM information_schema.tables"
        " WHERE table_name = :table_name"
        ");"
    )

    for table_name in tables_to_drop:
        try:
            table_exists = connection.execute(
                exists_stmt, {"table_name": table_name}
            ).scalar()

            if table_exists:
                print(f"删除表: {table_name}")
                op.drop_table(table_name)
            else:
                print(f"表 {table_name} 不存在,跳过")

        except Exception as e:
            print(f"删除表 {table_name} 时出错: {e}")
            continue  # keep dropping the remaining tables

    # Drop users.department_id if it is still present.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'users' AND column_name = 'department_id';
        """))
        if result.fetchone():
            print("删除用户表中的 department_id 字段")
            op.drop_column('users', 'department_id')
        else:
            print("用户表中的 department_id 字段不存在,跳过")
    except Exception as e:
        print(f"删除 department_id 字段时出错: {e}")

    # Rebuild user_roles as a plain (user_id, role_id) association table when
    # it still carries surrogate-key / audit columns.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'user_roles' AND column_name IN ('id', 'created_at', 'updated_at', 'created_by', 'updated_by');
        """))
        extra_columns = [row[0] for row in result.fetchall()]

        if extra_columns:
            print("简化 user_roles 表结构")
            op.execute(text("""
                CREATE TABLE user_roles_new (
                    user_id INTEGER NOT NULL,
                    role_id INTEGER NOT NULL,
                    PRIMARY KEY (user_id, role_id),
                    FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
                    FOREIGN KEY (role_id) REFERENCES roles(id) ON DELETE CASCADE
                );
            """))

            # Copy the distinct pairs, then swap the tables.
            op.execute(text("""
                INSERT INTO user_roles_new (user_id, role_id)
                SELECT DISTINCT user_id, role_id FROM user_roles;
            """))

            op.drop_table('user_roles')
            op.execute(text("ALTER TABLE user_roles_new RENAME TO user_roles;"))

    except Exception as e:
        print(f"简化 user_roles 表时出错: {e}")
|
||||
def downgrade():
    """Recreate a minimal version of the permission tables (schema only).

    Destructive rollback: the table structure is restored in simplified form,
    but the data dropped by :func:`upgrade` is NOT recovered.
    """
    print("警告:回滚操作会重新创建权限相关表,但不会恢复数据")

    # Basic permissions table (simplified relative to the original schema).
    op.create_table(
        'permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('code', sa.String(100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code'),
    )

    # Role <-> permission association table.
    op.create_table(
        'role_permissions',
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('role_id', 'permission_id'),
    )

    # Restore the users.department_id column removed by upgrade().
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
|
@ -0,0 +1,31 @@
|
|||
"""Database models for TH Agenter."""
|
||||
|
||||
from .user import User
|
||||
from .conversation import Conversation
|
||||
from .message import Message
|
||||
from .knowledge_base import KnowledgeBase, Document
|
||||
from .agent_config import AgentConfig
|
||||
from .excel_file import ExcelFile
|
||||
from .permission import Role, UserRole
|
||||
from .llm_config import LLMConfig
|
||||
from .workflow import Workflow, WorkflowExecution, NodeExecution
|
||||
from .database_config import DatabaseConfig
|
||||
from .table_metadata import TableMetadata
|
||||
|
||||
# Public API of the models package; one name per import above.
__all__ = [
    "User",
    "Conversation",
    "Message",
    "KnowledgeBase",
    "Document",
    "AgentConfig",
    "ExcelFile",
    "Role",
    "UserRole",
    "LLMConfig",
    "Workflow",
    "WorkflowExecution",
    "NodeExecution",
    "DatabaseConfig",
    "TableMetadata",
]
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""Agent configuration model."""
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Persisted configuration for an agent: tools, model settings, status."""

    __tablename__ = "agent_configs"

    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Agent behaviour configuration
    enabled_tools: Mapped[list] = mapped_column(JSON, nullable=False, default=list)  # tool identifiers
    max_iterations: Mapped[int] = mapped_column(default=10)
    temperature: Mapped[str] = mapped_column(String(10), default="0.1")  # stored as string
    system_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    verbose: Mapped[bool] = mapped_column(Boolean, default=True)

    # Backing model configuration
    model_name: Mapped[str] = mapped_column(String(100), default="gpt-3.5-turbo")
    max_tokens: Mapped[int] = mapped_column(default=2048)

    # Lifecycle flags
    is_active: Mapped[bool] = mapped_column(default=True)
    is_default: Mapped[bool] = mapped_column(default=False)

    def __repr__(self):
        return f"<AgentConfig(id={self.id}, name='{self.name}', is_active={self.is_active})>"

    def __str__(self):
        return f"{self.name}[{self.id}] Active: {self.is_active}"

    def to_dict(self):
        """Serialize to a dict, normalising ``enabled_tools`` to a list."""
        payload = super().to_dict()
        payload['enabled_tools'] = self.enabled_tools if self.enabled_tools else []
        return payload
|
|
@ -0,0 +1,38 @@
|
|||
"""Conversation model."""
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, DateTime
|
||||
from datetime import datetime
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class Conversation(BaseModel):
    """Conversation model.

    ``message_count`` and ``last_message_at`` are denormalized columns; the
    code that writes messages is expected to keep them up to date.
    """

    __tablename__ = "conversations"

    title: Mapped[str] = mapped_column(String(200), nullable=False)
    user_id: Mapped[int] = mapped_column(Integer, nullable=False)  # no FK: constraints removed
    knowledge_base_id: Mapped[int | None] = mapped_column(Integer, nullable=True)  # no FK: constraints removed
    system_prompt: Mapped[str | None] = mapped_column(Text, nullable=True)
    model_name: Mapped[str] = mapped_column(String(100), nullable=False, default="gpt-3.5-turbo")
    temperature: Mapped[str] = mapped_column(String(10), nullable=False, default="0.7")  # stored as string
    max_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=2048)
    is_archived: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    message_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    last_message_at: Mapped[datetime | None] = mapped_column(nullable=True)

    # BUG FIX: the original class additionally defined ``message_count`` and
    # ``last_message_at`` as @property accessors reading ``self.messages``.
    # The ``messages`` relationship was removed along with the FK constraints,
    # so those properties (a) shadowed the mapped columns declared above and
    # (b) raised AttributeError when accessed.  They have been removed; the
    # persisted columns are authoritative.

    def __repr__(self):
        return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"
|
|
@ -0,0 +1,53 @@
|
|||
"""数据库配置模型"""
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import Integer, String, Text, Boolean, JSON
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
# 在现有的DatabaseConfig类中添加关系
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
class DatabaseConfig(BaseModel):
    """External database connection configuration."""

    __tablename__ = "database_configs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False)  # display name
    db_type: Mapped[str] = mapped_column(String(20), nullable=False, unique=True)  # e.g. postgresql, mysql
    host: Mapped[str] = mapped_column(String(255), nullable=False)
    port: Mapped[int] = mapped_column(Integer, nullable=False)
    database: Mapped[str] = mapped_column(String(100), nullable=False)
    username: Mapped[str] = mapped_column(String(100), nullable=False)
    password: Mapped[str] = mapped_column(Text, nullable=False)  # stored encrypted
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_default: Mapped[bool] = mapped_column(Boolean, default=False)
    connection_params: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # extra driver options

    def to_dict(self, include_password=False, decrypt_service=None):
        """Serialize this config.

        The decrypted password is included only when *include_password* is
        true AND a *decrypt_service* is supplied.
        """
        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None

        result = {
            "id": self.id,
            "created_by": self.created_by,
            "name": self.name,
            "db_type": self.db_type,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "username": self.username,
            "is_active": self.is_active,
            "is_default": self.is_default,
            "connection_params": self.connection_params,
            "created_at": created,
            "updated_at": updated,
        }

        if include_password and decrypt_service:
            logger.info(f"begin decrypt password for db config {self.id}")
            # NOTE(review): reaches into a private method of the service;
            # consider exposing a public decrypt API.
            result["password"] = decrypt_service._decrypt_password(self.password)

        return result

# Relationship to TableMetadata intentionally left disabled:
# table_metadata = relationship("TableMetadata", back_populates="database_config")
|
|
@ -0,0 +1,85 @@
|
|||
"""Excel file models for smart query."""
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, JSON, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class ExcelFile(BaseModel):
    """Excel file model: upload metadata, per-sheet previews and statistics."""

    __tablename__ = "excel_files"

    # Basic file information
    # user_id: Mapped[int] = mapped_column(Integer, nullable=False)  # owner, currently disabled
    original_filename: Mapped[str] = mapped_column(String(255), nullable=False)  # name as uploaded
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)  # storage path
    file_size: Mapped[int] = mapped_column(Integer, nullable=False)  # bytes
    file_type: Mapped[str] = mapped_column(String(50), nullable=False)  # .xlsx, .xls, .csv

    # Excel-specific information
    sheet_names: Mapped[list] = mapped_column(JSON, nullable=False)  # every sheet name
    default_sheet: Mapped[str | None] = mapped_column(String(100), nullable=True)

    # Data preview information, all keyed by sheet name
    columns_info: Mapped[dict] = mapped_column(JSON, nullable=False)  # {sheet: [column names]}
    preview_data: Mapped[dict] = mapped_column(JSON, nullable=False)  # {sheet: [row, ...]} leading rows
    data_types: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # {sheet: {column: dtype}}

    # Statistics, keyed by sheet name
    total_rows: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # {sheet: row count}
    total_columns: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # {sheet: column count}

    # Processing status
    is_processed: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    processing_error: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Upload information
    # upload_time: Mapped[DateTime] = mapped_column(DateTime, default=func.now(), nullable=False)
    # NOTE(review): ``Mapped[DateTime | None]`` annotates with the SQLAlchemy
    # *type object* instead of ``datetime.datetime``; it should likely be
    # ``Mapped[datetime | None]``.  Kept as-is to avoid adding an import in
    # this change — confirm against SQLAlchemy 2.0 typed-ORM rules.
    last_accessed: Mapped[DateTime | None] = mapped_column(DateTime, nullable=True)

    def __repr__(self):
        return f"<ExcelFile(id={self.id}, filename='{self.original_filename}')>"

    @property
    def file_size_mb(self):
        """File size in megabytes, rounded to two decimals."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def sheet_count(self):
        """Number of sheets (0 when the sheet list is missing/empty)."""
        return len(self.sheet_names) if self.sheet_names else 0

    def get_sheet_info(self, sheet_name: str | None = None):
        """Return preview/metadata for one sheet.

        Falls back to the default sheet (or the first sheet) when *sheet_name*
        is falsy; returns None when the sheet is unknown.  (Annotation fixed:
        the original declared ``sheet_name: str = None``.)
        """
        if not sheet_name:
            sheet_name = self.default_sheet or (self.sheet_names[0] if self.sheet_names else None)

        if not sheet_name or sheet_name not in self.sheet_names:
            return None

        return {
            'sheet_name': sheet_name,
            'columns': self.columns_info.get(sheet_name, []) if self.columns_info else [],
            'preview_data': self.preview_data.get(sheet_name, []) if self.preview_data else [],
            'data_types': self.data_types.get(sheet_name, {}) if self.data_types else {},
            'total_rows': self.total_rows.get(sheet_name, 0) if self.total_rows else 0,
            'total_columns': self.total_columns.get(sheet_name, 0) if self.total_columns else 0,
        }

    def get_all_sheets_summary(self):
        """Summarize every sheet: name, column/row counts, first 10 columns."""
        if not self.sheet_names:
            return []

        summary = []
        for sheet_name in self.sheet_names:
            sheet_info = self.get_sheet_info(sheet_name)
            if sheet_info:
                summary.append({
                    'sheet_name': sheet_name,
                    'columns_count': len(sheet_info['columns']),
                    'rows_count': sheet_info['total_rows'],
                    'columns': sheet_info['columns'][:10],  # cap at first 10 columns
                })
        return summary
|
@ -0,0 +1,91 @@
|
|||
"""Knowledge base models."""
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, JSON
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class KnowledgeBase(BaseModel):
    """Knowledge base model: chunking, embedding and vector-store settings."""

    __tablename__ = "knowledge_bases"

    name: Mapped[str] = mapped_column(String(100), unique=False, index=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    embedding_model: Mapped[str] = mapped_column(String(100), nullable=False, default="sentence-transformers/all-MiniLM-L6-v2")
    chunk_size: Mapped[int] = mapped_column(Integer, nullable=False, default=1000)
    chunk_overlap: Mapped[int] = mapped_column(Integer, nullable=False, default=200)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)

    # Vector database settings
    vector_db_type: Mapped[str] = mapped_column(String(50), nullable=False, default="chroma")
    collection_name: Mapped[str | None] = mapped_column(String(100), nullable=True)  # vector-DB collection

    # Relationships were intentionally removed along with FK constraints;
    # document counts must therefore be queried via the documents table.

    def __repr__(self):
        return f"<KnowledgeBase(id={self.id}, name='{self.name}')>"
|
||||
class Document(BaseModel):
    """Document model: an uploaded file belonging to a knowledge base."""

    __tablename__ = "documents"

    knowledge_base_id: Mapped[int] = mapped_column(Integer, nullable=False)  # no FK: constraints removed
    filename: Mapped[str] = mapped_column(String(255), nullable=False)
    original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)
    file_size: Mapped[int] = mapped_column(Integer, nullable=False)  # bytes
    file_type: Mapped[str] = mapped_column(String(50), nullable=False)  # .pdf, .txt, .docx, ...
    mime_type: Mapped[str | None] = mapped_column(String(100), nullable=True)

    # Processing status
    is_processed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    processing_error: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Content and metadata
    content: Mapped[str | None] = mapped_column(Text, nullable=True)  # extracted text
    doc_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True)

    # Chunking information
    chunk_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)

    # Embedding information
    embedding_model: Mapped[str | None] = mapped_column(String(100), nullable=True)
    vector_ids: Mapped[list | None] = mapped_column(JSON, nullable=True)  # vector-DB ids per chunk

    # Relationships intentionally removed (no FK constraints).

    def __repr__(self):
        return f"<Document(id={self.id}, filename='{self.filename}', kb_id={self.knowledge_base_id})>"

    @property
    def file_size_mb(self):
        """File size in megabytes, rounded to two decimals."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def is_text_file(self):
        """True for plain-text formats (.txt/.md/.csv)."""
        return self.file_type.lower() in ('.txt', '.md', '.csv')

    @property
    def is_pdf_file(self):
        """True for PDF documents."""
        return self.file_type.lower() == '.pdf'

    @property
    def is_office_file(self):
        """True for Office formats (.docx/.xlsx/.pptx)."""
        return self.file_type.lower() in ('.docx', '.xlsx', '.pptx')
|
|
@ -0,0 +1,163 @@
|
|||
"""LLM Configuration model for managing multiple AI models."""
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class LLMConfig(BaseModel):
    """LLM configuration model for managing AI model settings.

    One row per provider/model combination, holding the credentials,
    sampling parameters and usage bookkeeping needed to build an API client.
    """

    __tablename__ = "llm_configs"

    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)  # configuration name
    provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True)  # provider: openai, deepseek, doubao, zhipu, moonshot, baidu
    model_name: Mapped[str] = mapped_column(String(100), nullable=False)  # model name
    api_key: Mapped[str] = mapped_column(String(500), nullable=False)  # API key (intended to be stored encrypted)
    base_url: Mapped[str | None] = mapped_column(String(200), nullable=True)  # API base URL

    # Sampling parameters
    max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
    temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
    top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
    frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
    presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)

    # Configuration flags and description
    description: Mapped[str | None] = mapped_column(Text, nullable=True)  # free-form description
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)  # enabled?
    is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # default configuration?
    is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # embedding model?

    # Extra provider-specific options (JSON); merged into get_client_config()
    extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)

    # Usage statistics
    usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)  # number of uses
    last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)  # last time the config was used

    def __repr__(self):
        return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"

    def to_dict(self, include_sensitive=False):
        """Convert to a dictionary.

        Args:
            include_sensitive: when True the raw ``api_key`` is included;
                otherwise only a masked ``api_key_masked`` value is emitted.
        """
        data = super().to_dict()
        data.update({
            'name': self.name,
            'provider': self.provider,
            'model_name': self.model_name,
            'base_url': self.base_url,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'description': self.description,
            'is_active': self.is_active,
            'is_default': self.is_default,
            'is_embedding': self.is_embedding,
            'extra_config': self.extra_config,
            'usage_count': self.usage_count,
            'last_used_at': self.last_used_at
        })

        if include_sensitive:
            data['api_key'] = self.api_key
        else:
            # Only reveal the first and last few characters of the API key
            if self.api_key:
                key_len = len(self.api_key)
                if key_len > 8:
                    data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
                else:
                    data['api_key_masked'] = "***"
            else:
                data['api_key_masked'] = None

        return data

    def get_client_config(self) -> Dict[str, Any]:
        """Build the keyword arguments used to create an API client for this config."""
        config = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'model': self.model_name,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty
        }

        # Merge provider-specific extras (note: may override the keys above)
        if self.extra_config:
            config.update(self.extra_config)

        return config

    def validate_config(self) -> Dict[str, Any]:
        """Validate the configuration; returns ``{'valid': bool, 'error': str | None}``."""
        if not self.name or not self.name.strip():
            return {"valid": False, "error": "配置名称不能为空"}

        if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']:
            return {"valid": False, "error": "不支持的服务商"}

        if not self.model_name or not self.model_name.strip():
            return {"valid": False, "error": "模型名称不能为空"}

        if not self.api_key or not self.api_key.strip():
            return {"valid": False, "error": "API密钥不能为空"}

        if self.max_tokens <= 0 or self.max_tokens > 32000:
            return {"valid": False, "error": "最大令牌数必须在1-32000之间"}

        if self.temperature < 0 or self.temperature > 2:
            return {"valid": False, "error": "温度参数必须在0-2之间"}

        return {"valid": True, "error": None}

    def increment_usage(self):
        """Record one use of this configuration (caller is responsible for committing)."""
        self.usage_count += 1
        self.last_used_at = datetime.now()

    @classmethod
    def get_default_config(cls, provider: str, is_embedding: bool = False):
        """Return a default configuration template for *provider*.

        Returns an empty dict for unknown providers.
        """
        templates = {
            'openai': {
                'base_url': 'https://api.openai.com/v1',
                # Fixed: 'gpt-4.0-mini' is not a valid OpenAI model id; the correct id is 'gpt-4o-mini'.
                'model_name': 'gpt-4o-mini' if not is_embedding else 'text-embedding-ada-002',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'deepseek': {
                'base_url': 'https://api.deepseek.com/v1',
                'model_name': 'deepseek-chat' if not is_embedding else 'deepseek-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'doubao': {
                'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
                'model_name': 'doubao-lite-4k' if not is_embedding else 'doubao-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'zhipu': {
                'base_url': 'https://open.bigmodel.cn/api/paas/v4',
                'model_name': 'glm-4' if not is_embedding else 'embedding-3',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'moonshot': {
                'base_url': 'https://api.moonshot.cn/v1',
                'model_name': 'moonshot-v1-8k' if not is_embedding else 'moonshot-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            # Added: 'baidu' is accepted by validate_config() but previously
            # had no template, so callers silently received {}.
            'baidu': {
                'base_url': 'https://qianfan.baidubce.com/v2',
                'model_name': 'ernie-3.5-8k' if not is_embedding else 'embedding-v1',
                'max_tokens': 2048,
                'temperature': 0.7
            }
        }

        return templates.get(provider, {})
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Message model."""
|
||||
|
||||
from sqlalchemy import String, Integer, Text, Enum, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
import enum
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class MessageRole(str, enum.Enum):
    """Message role enumeration (str-valued so members compare equal to plain strings)."""
    USER = "user"            # message authored by the end user
    ASSISTANT = "assistant"  # message produced by the assistant/model
    SYSTEM = "system"        # system-level message
||||
class MessageType(str, enum.Enum):
    """Message content type enumeration."""
    TEXT = "text"    # plain text content
    IMAGE = "image"  # image content
    FILE = "file"    # file content
    AUDIO = "audio"  # audio content
||||
class Message(BaseModel):
    """Single chat message belonging to a conversation."""

    __tablename__ = "messages"

    conversation_id: Mapped[int] = mapped_column(Integer, nullable=False)  # plain column; ForeignKey("conversations.id") intentionally removed
    role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
    message_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # extra data: file info, tokens used, etc.

    # Knowledge-base context: references to the documents retrieved for this message
    context_documents: Mapped[dict | None] = mapped_column(JSON, nullable=True)

    # Token usage accounting
    prompt_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    completion_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        """Short representation; content is truncated to 50 characters."""
        text = self.content
        if len(text) > 50:
            text = text[:50] + "..."
        return f"<Message(id={self.id}, role='{self.role}', content='{text}')>"

    def to_dict(self, include_metadata=True):
        """Serialize to a dict; metadata and token fields are dropped on request."""
        data = super().to_dict()
        if not include_metadata:
            for field in ('message_metadata', 'context_documents',
                          'prompt_tokens', 'completion_tokens', 'total_tokens'):
                data.pop(field, None)
        return data

    @property
    def is_from_user(self):
        """Whether this message was authored by the user role."""
        return self.role == MessageRole.USER

    @property
    def is_from_assistant(self):
        """Whether this message was produced by the assistant role."""
        return self.role == MessageRole.ASSISTANT
|
@ -0,0 +1,53 @@
|
|||
"""Role models for simplified RBAC system."""
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..db.base import BaseModel, Base
|
||||
|
||||
|
||||
class Role(BaseModel):
    """Role model for the simplified RBAC system."""

    __tablename__ = "roles"

    name: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True)  # human-readable role name
    code: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True)  # unique role code
    description: Mapped[str | None] = mapped_column(Text, nullable=True)  # role description
    is_system: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # built-in system role?
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)

    # Relationships - only the user relationship is kept
    users = relationship("User", secondary="user_roles", back_populates="roles")

    def __repr__(self):
        """Debug representation with the stable identifiers."""
        return "<Role(id={}, code='{}', name='{}')>".format(self.id, self.code, self.name)

    def to_dict(self):
        """Serialize to a dictionary, extending the base-model fields."""
        payload = super().to_dict()
        payload.update(
            name=self.name,
            code=self.code,
            description=self.description,
            is_system=self.is_system,
            is_active=self.is_active,
        )
        return payload
|
||||
class UserRole(Base):
    """Association table linking users to roles (composite primary key)."""

    __tablename__ = "user_roles"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'), primary_key=True)
    role_id: Mapped[int] = mapped_column(Integer, ForeignKey('roles.id'), primary_key=True)

    # Read-only convenience relationships for code that works with the link table directly
    user = relationship("User", viewonly=True)
    role = relationship("Role", viewonly=True)

    def __repr__(self):
        """Debug representation showing the linked ids."""
        return "<UserRole(user_id={}, role_id={})>".format(self.user_id, self.role_id)
|
|
@ -0,0 +1,59 @@
|
|||
"""表元数据模型"""
|
||||
|
||||
from sqlalchemy import Integer, String, Text, DateTime, Boolean, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class TableMetadata(BaseModel):
    """Table metadata record.

    Caches the structure, sample rows and Q&A settings of one table from a
    configured source database.
    """
    __tablename__ = "table_metadata"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # database_config_id = Column(Integer, ForeignKey('database_configs.id'), nullable=False)
    table_name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    table_schema: Mapped[str] = mapped_column(String(50), default='public')
    table_type: Mapped[str] = mapped_column(String(20), default='BASE TABLE')
    table_comment: Mapped[str | None] = mapped_column(Text, nullable=True)  # table description
    database_config_id: Mapped[int | None] = mapped_column(Integer, nullable=True)  # database config id (FK intentionally removed)
    # Table structure information
    columns_info: Mapped[dict] = mapped_column(JSON, nullable=False)  # column info: names, types, comments, etc.
    primary_keys: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # primary key list
    foreign_keys: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # foreign key info
    indexes: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # index info

    # Sample data
    sample_data: Mapped[dict | None] = mapped_column(JSON, nullable=True)  # first 5 sample rows
    row_count: Mapped[int] = mapped_column(Integer, default=0)  # total row count

    # Q&A related
    is_enabled_for_qa: Mapped[bool] = mapped_column(Boolean, default=True)  # whether Q&A is enabled for this table
    qa_description: Mapped[str | None] = mapped_column(Text, nullable=True)  # Q&A description
    business_context: Mapped[str | None] = mapped_column(Text, nullable=True)  # business context

    # NOTE(review): the annotation uses the SQLAlchemy type `DateTime` instead of
    # `datetime.datetime` — should probably be `Mapped[datetime | None]`; confirm
    # before changing (would need a `from datetime import datetime` import).
    last_synced_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)  # last sync time

    # Relationship intentionally disabled:
    # database_config = relationship("DatabaseConfig", back_populates="table_metadata")

    def to_dict(self):
        """Serialize to a JSON-friendly dict (datetimes rendered via isoformat)."""
        return {
            "id": self.id,
            "created_by": self.created_by,  # changed to created_by
            "database_config_id": self.database_config_id,
            "table_name": self.table_name,
            "table_schema": self.table_schema,
            "table_type": self.table_type,
            "table_comment": self.table_comment,
            "columns_info": self.columns_info,
            "primary_keys": self.primary_keys,
            # "foreign_keys": self.foreign_keys,
            "indexes": self.indexes,
            "sample_data": self.sample_data,
            "row_count": self.row_count,
            "is_enabled_for_qa": self.is_enabled_for_qa,
            "qa_description": self.qa_description,
            "business_context": self.business_context,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_synced_at": self.last_synced_at.isoformat() if self.last_synced_at else None
        }
|
|
@ -0,0 +1,121 @@
|
|||
"""User model."""
|
||||
|
||||
from sqlalchemy import String, Boolean, Text
|
||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class User(BaseModel):
    """User account model."""

    __tablename__ = "users"

    username: Mapped[str] = mapped_column(String(50), unique=True, index=True, nullable=False)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)  # stripped from to_dict() unless include_sensitive
    full_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    avatar_url: Mapped[str | None] = mapped_column(String(255), nullable=True)
    bio: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships - only the role relationship is kept
    roles = relationship("Role", secondary="user_roles", back_populates="users")

    def __repr__(self):
        # NOTE(review): reads self.is_admin, which inspects the roles relationship;
        # is_admin swallows errors, so repr is safe, but the flag may read False
        # whenever roles are not loaded.
        return f"<User(id={self.id}, username='{self.username}', email='{self.email}', full_name='{self.full_name}', is_active={self.is_active}, avatar_url='{self.avatar_url}', bio='{self.bio}', is_superuser={self.is_admin})>"

    def to_dict(self, include_sensitive=False, include_roles=False):
        """Convert to dictionary, optionally excluding sensitive data.

        Args:
            include_sensitive: keep ``hashed_password`` in the output.
            include_roles: add a ``roles`` list of active roles; falls back to
                an empty list when the relationship cannot be read.
        """
        data = super().to_dict()
        data.update({
            'username': self.username,
            'email': self.email,
            'full_name': self.full_name,
            'is_active': self.is_active,
            'avatar_url': self.avatar_url,
            'bio': self.bio,
            'is_superuser': self.is_admin  # use the sync is_admin property instead of the async is_superuser() method
        })

        if not include_sensitive:
            data.pop('hashed_password', None)

        if include_roles:
            try:
                # Safely access the roles relationship attribute
                data['roles'] = [role.to_dict() for role in self.roles if role.is_active]
            except Exception:
                # Relationship not loaded or access failed - return an empty list
                data['roles'] = []

        return data

    async def has_role(self, role_code: str) -> bool:
        """Return True when the user has an active role with ``role_code``.

        First tries to refresh the relationship (on an async session) and scan
        it; on any failure falls back to querying the user_roles link table.
        """
        try:
            # In an async environment the relationship must be loaded explicitly
            from sqlalchemy.ext.asyncio import AsyncSession
            from sqlalchemy.orm import object_session
            from sqlalchemy import select
            from .permission import Role, UserRole

            session = object_session(self)
            if isinstance(session, AsyncSession):
                # Async session: refresh the relationship before reading it
                await session.refresh(self, ['roles'])
            return any(role.code == role_code and role.is_active for role in self.roles)
        except Exception:
            # Object detached or relationship loading failed - query the database instead
            from sqlalchemy.orm import object_session
            from sqlalchemy import select
            from .permission import Role, UserRole

            session = object_session(self)
            if session is None:
                # No session available - deny by default
                return False
            else:
                from sqlalchemy.ext.asyncio import AsyncSession
                if isinstance(session, AsyncSession):
                    # Async session: run an async query
                    user_role = await session.execute(
                        select(UserRole).join(Role).filter(
                            UserRole.user_id == self.id,
                            Role.code == role_code,
                            Role.is_active == True
                        )
                    )
                    return user_role.scalar_one_or_none() is not None
                else:
                    # Sync session: run a sync query
                    user_role = session.query(UserRole).join(Role).filter(
                        UserRole.user_id == self.id,
                        Role.code == role_code,
                        Role.is_active == True
                    ).first()
                    return user_role is not None

    async def is_superuser(self) -> bool:
        """Check whether the user is a super administrator."""
        return await self.has_role('SUPER_ADMIN')

    async def is_admin_user(self) -> bool:
        """Check whether the user is an administrator (compatibility alias)."""
        return await self.is_superuser()

    # Note: a property cannot be async, so is_admin is a simplified sync check
    @property
    def is_admin(self) -> bool:
        """Sync admin check based on already-loaded roles only."""
        # A sync property cannot await, so only inspect roles that are already
        # loaded; catch MissingGreenlet and similar lazy-loading errors.
        try:
            # If the roles relationship is loaded it is directly iterable
            return any(role.code == 'SUPER_ADMIN' and role.is_active for role in self.roles)
        except Exception:
            # Relationship not loaded or access failed - report non-admin
            return False
|
|
@ -0,0 +1,166 @@
|
|||
"""Workflow models."""
|
||||
|
||||
from sqlalchemy import String, Text, Boolean, Integer, JSON, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||
import enum
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
class WorkflowStatus(enum.Enum):
    """Workflow lifecycle status."""
    DRAFT = "DRAFT"          # draft
    PUBLISHED = "PUBLISHED"  # published
    ARCHIVED = "ARCHIVED"    # archived
||||
class NodeType(enum.Enum):
    """Workflow node type."""
    START = "start"          # start node
    END = "end"              # end node
    LLM = "llm"              # large language model node
    CONDITION = "condition"  # conditional branch node
    LOOP = "loop"            # loop node
    CODE = "code"            # code execution node
    HTTP = "http"            # HTTP request node
    TOOL = "tool"            # tool node
||||
class ExecutionStatus(enum.Enum):
    """Execution status of a workflow run or node run."""
    PENDING = "pending"      # waiting to run
    RUNNING = "running"      # currently running
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    CANCELLED = "cancelled"  # cancelled
||||
class Workflow(BaseModel):
    """Workflow definition owned by a user."""
    __tablename__ = "workflows"

    name: Mapped[str] = mapped_column(String(100), nullable=False, comment="工作流名称")
    description: Mapped[str | None] = mapped_column(Text, nullable=True, comment="工作流描述")
    status: Mapped[WorkflowStatus] = mapped_column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="工作流状态")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, comment="是否激活")

    # Graph definition (nodes and connections) stored as JSON
    definition: Mapped[dict] = mapped_column(JSON, nullable=False, comment="工作流定义")

    # Version string
    version: Mapped[str] = mapped_column(String(20), default="1.0.0", nullable=False, comment="版本号")

    # Owning user
    owner_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="所有者ID")

    # Relationships
    executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")

    def __repr__(self):
        """Compact debug representation."""
        return "<Workflow(id={}, name='{}', status='{}')>".format(
            self.id, self.name, self.status.value
        )

    def to_dict(self, include_definition=True):
        """Serialize to a dict; the (potentially large) definition is optional."""
        payload = super().to_dict()
        payload.update({
            'name': self.name,
            'description': self.description,
            'status': self.status.value,
            'is_active': self.is_active,
            'version': self.version,
            'owner_id': self.owner_id
        })
        if include_definition:
            payload['definition'] = self.definition
        return payload
||||
class WorkflowExecution(BaseModel):
    """One execution (run) of a workflow."""

    __tablename__ = "workflow_executions"

    workflow_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflows.id"), nullable=False, comment="工作流ID")
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")

    # Run input and output
    input_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Execution info
    # NOTE(review): timestamps are stored as String(50), not DateTime columns —
    # confirm the intended serialization format.
    started_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="完成时间")
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True, comment="错误信息")

    # User who launched the run
    executor_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="执行者ID")

    # Relationships
    workflow = relationship("Workflow", back_populates="executions")
    node_executions = relationship("NodeExecution", back_populates="workflow_execution", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<WorkflowExecution(id={self.id}, workflow_id={self.workflow_id}, status='{self.status.value}')>"

    def to_dict(self, include_nodes=False):
        """Serialize to a dict; per-node execution records only on request."""
        data = super().to_dict()
        data.update({
            'workflow_id': self.workflow_id,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'error_message': self.error_message,
            'executor_id': self.executor_id
        })

        if include_nodes:
            data['node_executions'] = [node.to_dict() for node in self.node_executions]

        return data
|
||||
class NodeExecution(BaseModel):
    """Execution record for a single node within one workflow run."""
    __tablename__ = "node_executions"

    workflow_execution_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="工作流执行ID")
    node_id: Mapped[str] = mapped_column(String(50), nullable=False, comment="节点ID")
    node_type: Mapped[NodeType] = mapped_column(Enum(NodeType), nullable=False, comment="节点类型")
    node_name: Mapped[str] = mapped_column(String(100), nullable=False, comment="节点名称")

    # Execution status and results
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
    input_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Timing
    # NOTE(review): timestamps stored as String(50), matching WorkflowExecution — confirm format.
    started_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="完成时间")
    duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True, comment="执行时长(毫秒)")

    # Error details
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True, comment="错误信息")

    # Relationships
    workflow_execution = relationship("WorkflowExecution", back_populates="node_executions")

    def __repr__(self):
        return f"<NodeExecution(id={self.id}, node_id='{self.node_id}', status='{self.status.value}')>"

    def to_dict(self):
        """Serialize to a dict, flattening enum members to their values."""
        data = super().to_dict()
        data.update({
            'workflow_execution_id': self.workflow_execution_id,
            'node_id': self.node_id,
            'node_type': self.node_type.value,
            'node_name': self.node_name,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'duration_ms': self.duration_ms,
            'error_message': self.error_message
        })

        return data
|
|
@ -0,0 +1,16 @@
|
|||
"""Schemas package initialization."""
|
||||
|
||||
from .user import UserCreate, UserUpdate, UserResponse
|
||||
from .permission import (
|
||||
RoleCreate, RoleUpdate, RoleResponse,
|
||||
UserRoleAssign
|
||||
)
|
||||
|
||||
# Explicit public API of the schemas package.
__all__ = [
    # User schemas
    "UserCreate", "UserUpdate", "UserResponse",

    # Permission schemas
    "RoleCreate", "RoleUpdate", "RoleResponse",
    "UserRoleAssign",
]
|
|
@ -0,0 +1,156 @@
|
|||
"""LLM Configuration Pydantic schemas."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LLMConfigBase(BaseModel):
    """Shared fields for LLM configuration schemas."""
    name: str = Field(..., min_length=1, max_length=100, description="配置名称")
    provider: str = Field(..., min_length=1, max_length=50, description="服务商")
    model_name: str = Field(..., min_length=1, max_length=100, description="模型名称")
    api_key: str = Field(..., min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    # NOTE(review): default 4096 differs from the ORM model's column default (2048) —
    # confirm which value is intended.
    max_tokens: Optional[int] = Field(4096, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")

    is_active: bool = Field(True, description="是否激活")
    is_default: bool = Field(False, description="是否为默认配置")
    is_embedding: bool = Field(False, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")
||||
class LLMConfigCreate(LLMConfigBase):
    """Schema for creating an LLM configuration."""

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Lowercase the provider and reject values outside the allowed list."""
        allowed_providers = [
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', "doubao"
        ]
        if v.lower() not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return v.lower()

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Trim surrounding whitespace and enforce a minimal key length."""
        if len(v.strip()) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return v.strip()
||||
class LLMConfigUpdate(BaseModel):
    """Schema for partially updating an LLM configuration (all fields optional)."""
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称")
    provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商")
    model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称")
    api_key: Optional[str] = Field(None, min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")

    is_active: Optional[bool] = Field(None, description="是否激活")
    is_default: Optional[bool] = Field(None, description="是否为默认配置")
    is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: Optional[str]) -> Optional[str]:
        """Lowercase and validate the provider when supplied; pass None through."""
        # NOTE(review): this list is duplicated from LLMConfigCreate — consider a
        # shared module-level constant to keep the two in sync.
        if v is not None:
            allowed_providers = [
                'openai', 'azure', 'anthropic', 'google', 'baidu',
                'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
                'ollama', 'custom',"doubao"
            ]
            if v.lower() not in allowed_providers:
                raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
            return v.lower()
        return v

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
        """Trim and length-check the key when supplied; pass None through."""
        if v is not None and len(v.strip()) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return v.strip() if v else v
|
||||
class LLMConfigResponse(BaseModel):
    """Response schema for an LLM configuration."""
    id: int
    name: str
    provider: str
    model_name: str
    api_key: Optional[str] = None  # full API key (only populated when include_sensitive=True)
    base_url: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    description: Optional[str] = None

    is_active: bool
    is_default: bool
    is_embedding: bool
    extra_config: Optional[Dict[str, Any]] = None
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    # Allow construction directly from ORM objects
    model_config = {
        'from_attributes': True
    }

    @computed_field
    @property
    def api_key_masked(self) -> Optional[str]:
        """Masked key for display: first 4 chars, '*' padding, last 4 chars.

        NOTE(review): masking format differs from LLMConfig.to_dict(), which
        emits 'abcd...wxyz' — confirm which format the frontend expects.
        """
        # Hide the API key in responses; show only the first and last 4 characters
        if self.api_key:
            key = self.api_key
            if len(key) > 8:
                return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}"
            else:
                return '*' * len(key)
        return None
||||
class LLMConfigTest(BaseModel):
    """Schema for test-driving an LLM configuration with a sample message."""
    message: Optional[str] = Field(
        "Hello, this is a test message.",
        max_length=1000,
        description="测试消息"
    )
||||
class LLMConfigClientResponse(BaseModel):
    """Trimmed LLM configuration for the frontend (no credentials exposed)."""
    id: int
    name: str
    provider: str
    model_name: str
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    is_active: bool
    description: Optional[str] = None

    # Allow construction directly from ORM objects
    model_config = {
        'from_attributes': True
    }
|
@ -0,0 +1,69 @@
|
|||
"""Role Pydantic schemas."""
|
||||
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
|
||||
class RoleBase(BaseModel):
    """Shared role fields used by create/response schemas."""

    name: str = Field(..., description="角色名称", min_length=1, max_length=100)
    code: str = Field(..., description="角色代码", min_length=1, max_length=50)
    description: Optional[str] = Field(None, description="角色描述", max_length=500)
    sort_order: Optional[int] = Field(0, description="排序", ge=0)
    is_active: bool = Field(True, description="是否激活")
|
||||
|
||||
class RoleCreate(RoleBase):
    """Schema for creating a role."""

    @field_validator('code')
    @classmethod
    def validate_code(cls, v: str) -> str:
        """Allow only letters, digits, '_' and '-'; normalize to upper case."""
        stripped = v.replace('_', '').replace('-', '')
        if not stripped.isalnum():
            raise ValueError('角色代码只能包含字母、数字、下划线和连字符')
        return v.upper()
|
||||
|
||||
class RoleUpdate(BaseModel):
    """Partial-update schema for a role; all fields are optional."""

    name: Optional[str] = Field(None, description="角色名称", min_length=1, max_length=100)
    code: Optional[str] = Field(None, description="角色代码", min_length=1, max_length=50)
    description: Optional[str] = Field(None, description="角色描述", max_length=500)
    sort_order: Optional[int] = Field(None, description="排序", ge=0)
    is_active: Optional[bool] = Field(None, description="是否激活")

    @field_validator('code')
    @classmethod
    def validate_code(cls, v: Optional[str]) -> Optional[str]:
        """Same charset rule as RoleCreate, but tolerates a missing value."""
        if v is None:
            return v
        if not v.replace('_', '').replace('-', '').isalnum():
            raise ValueError('角色代码只能包含字母、数字、下划线和连字符')
        return v.upper() if v else v
|
||||
|
||||
|
||||
class RoleResponse(RoleBase):
    """Role representation returned by the API."""

    id: int
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    # Association info: number of users holding this role.
    user_count: Optional[int] = 0

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserRoleAssign(BaseModel):
    """Request schema for assigning a set of roles to a user."""

    user_id: int = Field(..., description="用户ID")
    role_ids: List[int] = Field(..., description="角色ID列表")

    @field_validator('role_ids')
    @classmethod
    def validate_role_ids(cls, v: List[int]) -> List[int]:
        """Require a non-empty, duplicate-free list of role ids."""
        if not v:
            raise ValueError('角色ID列表不能为空')
        if len(set(v)) != len(v):
            raise ValueError('角色ID列表不能包含重复项')
        return v
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""User schemas."""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from utils.util_schemas import BaseResponse
|
||||
|
||||
class UserBase(BaseModel):
    """Common user fields shared by create/update/response schemas."""

    username: str = Field(..., min_length=3, max_length=50)
    email: str = Field(..., max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    # Free-form profile fields.
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
|
||||
|
||||
class UserCreate(UserBase):
    """Schema for registering a new user; adds the plaintext password."""

    password: str = Field(..., min_length=6)
|
||||
|
||||
class UserUpdate(BaseModel):
    """Partial-update schema for a user; every field is optional."""

    username: Optional[str] = Field(None, min_length=3, max_length=50)
    email: Optional[str] = Field(None, max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    password: Optional[str] = Field(None, min_length=6)
    is_active: Optional[bool] = None
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
    """Self-service password change: requires the current password."""

    current_password: str = Field(..., description="Current password")
    new_password: str = Field(..., description="New password", min_length=6)
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
    """Admin-initiated password reset: only the new password is needed."""

    new_password: str = Field(..., description="New password", min_length=6)
|
||||
|
||||
class UserResponse(BaseResponse, UserBase):
    """User representation returned by the API."""

    is_active: bool
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    model_config = {
        'from_attributes': True
    }

    @classmethod
    def model_validate(cls, obj, *, from_attributes=False, **kwargs):
        """Build the response model from an object, resolving ``is_superuser``.

        ORM user objects expose an async ``is_superuser()`` method, which
        pydantic would otherwise serialize as a coroutine/bound method.  The
        synchronous ``is_admin`` attribute is preferred when available.

        Fix over the previous version: extra keyword arguments accepted by
        ``BaseModel.model_validate`` (``strict``, ``context``, ...) are now
        forwarded instead of raising ``TypeError`` / being dropped.
        """
        if hasattr(obj, '__dict__'):
            data = obj.__dict__.copy()
            if hasattr(obj, 'is_admin'):
                # Prefer the synchronous is_admin attribute over the async
                # is_superuser() method.
                data['is_superuser'] = obj.is_admin
            elif hasattr(obj, 'is_superuser') and not callable(obj.is_superuser):
                # is_superuser is a plain attribute rather than a method.
                data['is_superuser'] = obj.is_superuser
            # data is a plain dict, so from_attributes is not needed here.
            return super().model_validate(data, **kwargs)
        return super().model_validate(obj, from_attributes=from_attributes, **kwargs)
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
    """Login result: bearer token plus the authenticated user's profile."""

    access_token: str
    token_type: str
    expires_in: int  # token lifetime in seconds
    user: UserResponse
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
"""Workflow schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class WorkflowStatus(str, Enum):
    """Lifecycle status of a workflow definition."""

    DRAFT = "DRAFT"
    PUBLISHED = "PUBLISHED"
    ARCHIVED = "ARCHIVED"
|
||||
|
||||
|
||||
class NodeType(str, Enum):
    """Kind of a workflow node."""

    START = "start"
    END = "end"
    LLM = "llm"
    CONDITION = "condition"
    LOOP = "loop"
    CODE = "code"
    HTTP = "http"
    TOOL = "tool"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
    """Runtime state of a workflow or node execution."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
# 节点定义相关模式
|
||||
class NodePosition(BaseModel):
    """Canvas coordinates of a node."""

    x: float
    y: float
|
||||
|
||||
|
||||
# 参数定义相关模式
|
||||
class ParameterType(str, Enum):
    """JSON-style type of a node parameter."""

    STRING = "string"
    NUMBER = "number"
    BOOLEAN = "boolean"
    OBJECT = "object"
    ARRAY = "array"
|
||||
|
||||
|
||||
class NodeParameter(BaseModel):
    """Definition of a single node input/output parameter."""

    name: str = Field(..., min_length=1, max_length=50)
    type: ParameterType
    description: Optional[str] = None
    required: bool = True
    default_value: Optional[Any] = None
    # Where the value comes from: 'input' (user input), 'node' (another
    # node's output) or 'variable' (variable reference).
    source: Optional[str] = None
    source_node_id: Optional[str] = None  # source node id (when source == 'node')
    source_field: Optional[str] = None  # field name on the source node's output
    variable_name: Optional[str] = None  # variable name (end-node output params)
|
||||
|
||||
|
||||
class NodeInputOutput(BaseModel):
    """Input/output parameter lists of a node."""

    # pydantic copies these defaults per instance, so the literals are safe.
    inputs: List[NodeParameter] = []
    outputs: List[NodeParameter] = []
|
||||
|
||||
|
||||
class NodeConfig(BaseModel):
    """Marker base class for per-node-type configuration payloads."""
    pass
|
||||
|
||||
|
||||
class LLMNodeConfig(NodeConfig):
    """Configuration for an LLM node."""

    model_id: Optional[int] = None  # LLM configuration id
    model_name: Optional[str] = None  # model name (frontend compatibility)
    temperature: float = Field(default=0.7, ge=0, le=2)
    max_tokens: Optional[int] = Field(default=None, gt=0)
    prompt: str = Field(..., min_length=1)
    enable_variable_substitution: bool = True  # substitute {vars} in the prompt
|
||||
|
||||
|
||||
class ConditionNodeConfig(NodeConfig):
    """Configuration for a condition (branch) node."""

    condition: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class LoopNodeConfig(NodeConfig):
    """Configuration for a loop node; exactly one mode-specific field applies."""

    loop_type: str = Field(..., pattern="^(count|while|foreach)$")
    count: Optional[int] = Field(None, description="循环次数(当loop_type为count时)")
    condition: Optional[str] = Field(None, description="循环条件(当loop_type为while时)")
    iterable: Optional[str] = Field(None, description="可迭代对象(当loop_type为foreach时)")
|
||||
|
||||
|
||||
class CodeNodeConfig(NodeConfig):
    """Configuration for a code-execution node."""

    language: str = Field(..., pattern="^(python|javascript)$")
    code: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class HttpNodeConfig(NodeConfig):
    """Configuration for an HTTP request node."""

    method: str = Field(..., pattern="^(GET|POST|PUT|DELETE|PATCH)$")
    url: str = Field(..., min_length=1)
    headers: Optional[Dict[str, str]] = None
    body: Optional[str] = None  # raw request body, if any
|
||||
|
||||
|
||||
class ToolNodeConfig(NodeConfig):
    """Configuration for a tool-invocation node."""

    tool_type: str
    parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class WorkflowNode(BaseModel):
    """A single node within a workflow definition."""

    id: str
    type: NodeType
    name: str
    description: Optional[str] = None
    position: NodePosition
    config: Optional[Dict[str, Any]] = None  # type-specific config payload
    parameters: Optional[NodeInputOutput] = None  # declared inputs/outputs
|
||||
|
||||
|
||||
class WorkflowConnection(BaseModel):
    """Directed edge between two workflow nodes."""

    id: str
    # 'from' is a Python keyword, hence the aliases for the wire format.
    from_node: str = Field(..., alias="from")
    to_node: str = Field(..., alias="to")
    from_point: str = Field(default="output")
    to_point: str = Field(default="input")
|
||||
|
||||
|
||||
class WorkflowDefinition(BaseModel):
    """Graph payload of a workflow: its nodes and the edges between them."""

    nodes: List[WorkflowNode]
    connections: List[WorkflowConnection]
|
||||
|
||||
|
||||
# 工作流CRUD模式
|
||||
class WorkflowCreate(BaseModel):
    """Request schema for creating a workflow."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = None
    definition: WorkflowDefinition
    status: WorkflowStatus = WorkflowStatus.DRAFT  # new workflows start as drafts
|
||||
|
||||
|
||||
class WorkflowUpdate(BaseModel):
    """Partial-update schema for a workflow."""

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = None
    definition: Optional[WorkflowDefinition] = None
    status: Optional[WorkflowStatus] = None
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class WorkflowResponse(BaseModel):
    """Workflow representation returned by the API."""

    id: int
    name: str
    description: Optional[str]
    status: WorkflowStatus
    is_active: bool
    version: str
    owner_id: int
    definition: Optional[WorkflowDefinition] = None  # omitted in list views
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# 工作流执行相关模式
|
||||
class WorkflowExecuteRequest(BaseModel):
    """Request body for triggering a workflow run."""

    input_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class NodeExecutionResponse(BaseModel):
    """Execution record of a single node within a workflow run."""

    id: int
    node_id: str
    node_type: NodeType
    node_name: str
    status: ExecutionStatus
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Dict[str, Any]]
    # Timestamps are serialized strings; None while the node has not
    # started/finished.
    started_at: Optional[str]
    completed_at: Optional[str]
    duration_ms: Optional[int]
    error_message: Optional[str]

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseModel):
    """Execution record of a whole workflow run."""

    id: int
    workflow_id: int
    status: ExecutionStatus
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Dict[str, Any]]
    started_at: Optional[str]
    completed_at: Optional[str]
    error_message: Optional[str]
    executor_id: int
    # Per-node records; None when not expanded by the endpoint.
    node_executions: Optional[List[NodeExecutionResponse]] = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# 工作流列表响应
|
||||
class WorkflowListResponse(BaseModel):
    """Paginated list of workflows."""

    workflows: List[WorkflowResponse]
    total: int  # total matching rows, not just this page
    page: int
    size: int
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Agent services package.
|
||||
|
||||
轻量化导入:仅暴露基础工具类型,避免在包导入时加载耗时的服务层。使用 AgentService 时请从子模块显式导入:
|
||||
from open_agent.services.agent.agent_service import AgentService
|
||||
"""
|
||||
|
||||
from .base import BaseTool, ToolRegistry
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolRegistry"
|
||||
]
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
"""LangChain Agent service with tool calling capabilities."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool, ToolRegistry, ToolResult
|
||||
from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool
|
||||
from ..postgresql_tool_manager import get_postgresql_tool
|
||||
from ..mysql_tool_manager import get_mysql_tool
|
||||
from ...core.config import get_settings
|
||||
from ...utils.logger import get_logger
|
||||
from ..agent_config import AgentConfigService
|
||||
|
||||
logger = get_logger("agent_service")
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Runtime configuration for the LangChain agent."""

    # Names of tools the agent is allowed to use.
    enabled_tools: List[str] = Field(default_factory=lambda: [
        "calculator", "weather", "search", "datetime", "file", "generate_image", "postgresql_mcp", "mysql_mcp"
    ])
    max_iterations: int = Field(default=10)
    temperature: float = Field(default=0.1)
    system_message: str = Field(
        default="You are a helpful AI assistant with access to various tools. "
                "Use the available tools to help answer user questions accurately. "
                "Always explain your reasoning and the tools you're using."
    )
    verbose: bool = Field(default=True)
|
||||
|
||||
|
||||
class AgentService:
    """LangChain Agent service with tool calling capabilities.

    Wraps tool registration, DB-backed configuration and agent execution
    (blocking and streaming).  A new agent executor is built per request so
    configuration changes take effect immediately.
    """

    def __init__(self, db_session=None):
        """Wire up settings, the tool registry and optional DB-backed config.

        Args:
            db_session: optional database session; when given, persisted
                agent configuration is loaded on construction.
        """
        self.settings = get_settings()
        self.tool_registry = ToolRegistry()
        self.config = AgentConfig()
        self.db_session = db_session
        # Config persistence is only available when a DB session is supplied.
        self.config_service = AgentConfigService(db_session) if db_session else None
        self._initialize_tools()
        self._load_config()

    def _initialize_tools(self):
        """Initialize and register all available tools."""
        tools = [
            WeatherQueryTool(),
            TavilySearchTool(),
            DateTimeTool(),
            get_postgresql_tool(),  # singleton PostgreSQL MCP tool
            get_mysql_tool()  # singleton MySQL MCP tool
        ]
        for tool in tools:
            self.tool_registry.register(tool)
            logger.info(f"Registered tool: {tool.get_name()}")

    def _load_config(self):
        """Overlay configuration values persisted in the database, if any.

        Best-effort: failures fall back to the in-code defaults.
        """
        if not self.config_service:
            return
        try:
            config_dict = self.config_service.get_config_dict()
            for key, value in config_dict.items():
                if hasattr(self.config, key):
                    setattr(self.config, key, value)
            logger.info("Loaded agent configuration from database")
        except Exception as e:
            logger.warning(f"Failed to load config from database, using defaults: {str(e)}")

    def _get_enabled_tools(self) -> List[Any]:
        """Return the registered tool objects named in ``config.enabled_tools``."""
        enabled_tools = []
        for tool_name in self.config.enabled_tools:
            tool = self.tool_registry.get_tool(tool_name)
            if tool:
                enabled_tools.append(tool)
                logger.debug(f"Enabled tool: {tool_name}")
            else:
                logger.warning(f"Tool not found: {tool_name}")
        return enabled_tools

    def _create_agent_executor(self) -> Any:
        """Build a fresh LangChain agent bound to the current LLM and tools."""
        # Imported lazily to avoid pulling the LLM stack in at module load.
        from ...core.llm import create_llm
        llm = create_llm()

        tools = self._get_enabled_tools()

        prompt = ChatPromptTemplate.from_messages([
            ("system", self.config.system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # Create agent using new LangChain 1.0+ API.
        return create_agent(
            llm=llm,
            tools=tools,
            prompt=prompt
        )

    @staticmethod
    def _to_langchain_history(chat_history: Optional[List[Dict[str, str]]]) -> List[Any]:
        """Convert ``[{'role': ..., 'content': ...}]`` history to LangChain messages.

        Unknown roles are silently skipped (same behavior as before the
        extraction of this helper from chat/chat_stream).
        """
        messages: List[Any] = []
        for msg in chat_history or []:
            if msg["role"] == "user":
                messages.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                messages.append(AIMessage(content=msg["content"]))
        return messages

    @staticmethod
    def _extract_output(result: Any) -> str:
        """Pull the text output out of an agent invocation result."""
        if isinstance(result, dict) and "output" in result:
            return result["output"]
        return str(result)

    async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
        """Process a chat message with the agent (non-streaming).

        Returns a dict with ``response``, ``tool_calls`` and ``success``
        (plus ``error`` on failure); this method never raises.
        """
        try:
            logger.info(f"Processing agent chat message: {message[:100]}...")
            agent = self._create_agent_executor()
            langchain_history = self._to_langchain_history(chat_history)

            result = await agent.ainvoke({
                "input": message,
                "chat_history": langchain_history
            })
            logger.info("Agent response generated successfully")
            return {
                "response": self._extract_output(result),
                "tool_calls": [],
                "success": True
            }
        except Exception as e:
            logger.error(f"Agent chat error: {str(e)}", exc_info=True)
            return {
                "response": f"Sorry, I encountered an error: {str(e)}",
                "tool_calls": [],
                "success": False,
                "error": str(e)
            }

    async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
        """Process a chat message with the agent, yielding streaming events.

        Yields ``status``/``response``/``error`` event dicts.  The final
        response is re-emitted in growing chunks to simulate token streaming.
        """
        import re  # local import: only needed by the chunker below

        tool_calls = []  # reserved: populated once tool-call tracing is wired up
        try:
            logger.info(f"Processing agent chat stream: {message[:100]}...")
            agent = self._create_agent_executor()
            langchain_history = self._to_langchain_history(chat_history)

            # Initial status event.
            yield {
                "type": "status",
                "content": "🤖 开始分析您的请求...",
                "done": False
            }
            await asyncio.sleep(0.2)

            result = await agent.ainvoke({
                "input": message,
                "chat_history": langchain_history
            })
            response_content = self._extract_output(result)

            # Split into word-plus-trailing-whitespace pieces so newlines and
            # spacing in the model output survive re-assembly (the previous
            # str.split() + " ".join collapsed all formatting, breaking
            # markdown/code blocks).
            pieces = re.findall(r'\S+\s*', response_content)
            current_content = ""
            for i, piece in enumerate(pieces):
                current_content += piece
                last = i == len(pieces) - 1
                # Emit every 2 pieces, or at the end.
                if (i + 1) % 2 == 0 or last:
                    yield {
                        "type": "response",
                        "content": current_content.strip(),
                        "tool_calls": tool_calls if last else [],
                        "done": last
                    }
                # Small delay to simulate typing.
                if not last:
                    await asyncio.sleep(0.05)

            logger.info("Agent stream response completed")
        except Exception as e:
            logger.error(f"Agent chat stream error: {str(e)}", exc_info=True)
            yield {
                "type": "error",
                "content": f"Sorry, I encountered an error: {str(e)}",
                "done": True
            }

    def update_config(self, config: Dict[str, Any]):
        """Update agent configuration from a plain dict; unknown keys are ignored."""
        try:
            for key, value in config.items():
                if hasattr(self.config, key):
                    setattr(self.config, key, value)
                    logger.info(f"Updated agent config: {key} = {value}")
        except Exception as e:
            logger.error(f"Error updating agent config: {str(e)}", exc_info=True)
            raise

    def load_config_from_db(self, config_id: Optional[int] = None):
        """Load a (possibly specific) configuration from the database."""
        if not self.config_service:
            logger.warning("No database session available for loading config")
            return
        try:
            config_dict = self.config_service.get_config_dict(config_id)
            self.update_config(config_dict)
            logger.info(f"Loaded configuration from database (ID: {config_id})")
        except Exception as e:
            logger.error(f"Error loading config from database: {str(e)}")
            raise

    def get_available_tools(self) -> List[Dict[str, Any]]:
        """Describe all registered tools (name, parameters, enabled flag)."""
        tools = []
        # Use the public accessor instead of reaching into the registry's
        # private _tools dict.
        for tool_name, tool in self.tool_registry.get_all_tools().items():
            tools.append({
                "name": tool.get_name(),
                "description": tool.get_description(),
                "parameters": [{
                    "name": param.name,
                    "type": param.type.value,
                    "description": param.description,
                    "required": param.required,
                    "default": param.default,
                    "enum": param.enum
                } for param in tool.get_parameters()],
                "enabled": tool_name in self.config.enabled_tools
            })
        return tools

    def get_config(self) -> Dict[str, Any]:
        """Get the current agent configuration as a plain dict."""
        # model_dump() is the pydantic v2 API; .dict() is deprecated there.
        return self.config.model_dump()
|
||||
|
||||
|
||||
# Global agent service instance
|
||||
_agent_service: Optional[AgentService] = None
|
||||
|
||||
|
||||
def get_agent_service(db_session=None) -> AgentService:
    """Return the process-wide AgentService singleton, creating it on first use."""
    global _agent_service
    if _agent_service is None:
        _agent_service = AgentService(db_session)
        return _agent_service
    if db_session and not _agent_service.db_session:
        # Late-bind the DB session so persisted configuration can be loaded.
        _agent_service.db_session = db_session
        _agent_service.config_service = AgentConfigService(db_session)
        _agent_service._load_config()
    return _agent_service
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
"""Base classes for Agent tools."""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Callable
|
||||
from pydantic import BaseModel, Field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from ...utils.logger import get_logger
|
||||
|
||||
logger = get_logger("agent_tools")
|
||||
|
||||
|
||||
class ToolParameterType(str, Enum):
    """JSON-schema-style type of a tool parameter."""

    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
|
||||
|
||||
|
||||
@dataclass
class ToolParameter:
    """Declarative definition of a single tool parameter."""

    name: str
    type: ToolParameterType
    description: str
    required: bool = True
    default: Any = None
    enum: Optional[List[Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Render this parameter as a JSON-schema property dict."""
        schema: Dict[str, Any] = {
            "type": self.type.value,
            "description": self.description,
        }
        # Only emit the optional keys when they carry information.
        if self.enum:
            schema["enum"] = self.enum
        if self.default is not None:
            schema["default"] = self.default
        return schema
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """Outcome of a single tool invocation."""

    success: bool = Field(description="Whether the tool execution was successful")
    result: Any = Field(description="The result data")
    error: Optional[str] = Field(default=None, description="Error message if failed")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
|
||||
|
||||
|
||||
class BaseTool(ABC):
    """Abstract base class every Agent tool implements."""

    def __init__(self):
        # Cache the subclass-provided metadata as instance attributes.
        self.name = self.get_name()
        self.description = self.get_description()
        self.parameters = self.get_parameters()

    @abstractmethod
    def get_name(self) -> str:
        """Get tool name."""
        pass

    @abstractmethod
    def get_description(self) -> str:
        """Get tool description."""
        pass

    @abstractmethod
    def get_parameters(self) -> List[ToolParameter]:
        """Get tool parameters."""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool with given parameters."""
        pass

    def get_schema(self) -> Dict[str, Any]:
        """Render this tool as an OpenAI-style function schema for LangChain."""
        properties = {param.name: param.to_dict() for param in self.parameters}
        required = [param.name for param in self.parameters if param.required]
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required
                }
            }
        }

    @staticmethod
    def _coerce(param: ToolParameter, value: Any) -> Any:
        """Best-effort conversion of *value* to the parameter's declared type."""
        if param.type == ToolParameterType.INTEGER and not isinstance(value, int):
            try:
                return int(value)
            except (ValueError, TypeError):
                raise ValueError(f"Parameter '{param.name}' must be an integer")
        if param.type == ToolParameterType.FLOAT and not isinstance(value, (int, float)):
            try:
                return float(value)
            except (ValueError, TypeError):
                raise ValueError(f"Parameter '{param.name}' must be a number")
        if param.type == ToolParameterType.BOOLEAN and not isinstance(value, bool):
            if isinstance(value, str):
                return value.lower() in ('true', '1', 'yes', 'on')
            return bool(value)
        return value

    def validate_parameters(self, **kwargs) -> Dict[str, Any]:
        """Validate and normalize the call arguments.

        Fills in defaults, coerces basic types and checks enum membership;
        raises ValueError for a missing required or malformed parameter.
        """
        validated: Dict[str, Any] = {}
        for param in self.parameters:
            value = kwargs.get(param.name)

            if value is None:
                if param.required:
                    raise ValueError(f"Required parameter '{param.name}' is missing")
                if param.default is not None:
                    value = param.default

            if value is not None:
                value = self._coerce(param, value)
                if param.enum and value not in param.enum:
                    raise ValueError(f"Parameter '{param.name}' must be one of {param.enum}")

            validated[param.name] = value
        return validated
|
||||
|
||||
|
||||
class ToolRegistry:
    """Registry that owns tool instances and their enabled/disabled state."""

    def __init__(self):
        self._tools: Dict[str, BaseTool] = {}
        self._enabled_tools: Dict[str, bool] = {}

    def register(self, tool: BaseTool, enabled: bool = True) -> None:
        """Register a tool (re-registering replaces the previous instance)."""
        name = tool.get_name()
        self._tools[name] = tool
        self._enabled_tools[name] = enabled
        logger.info(f"Registered tool: {name} (enabled: {enabled})")

    def unregister(self, tool_name: str) -> None:
        """Remove a tool; unknown names are ignored."""
        if tool_name in self._tools:
            del self._tools[tool_name]
            del self._enabled_tools[tool_name]
            logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Look up a tool by name, or None."""
        return self._tools.get(tool_name)

    def get_enabled_tools(self) -> Dict[str, BaseTool]:
        """Mapping of name -> tool for every enabled tool."""
        return {
            name: tool
            for name, tool in self._tools.items()
            if self._enabled_tools.get(name, False)
        }

    def get_all_tools(self) -> Dict[str, BaseTool]:
        """Shallow copy of the full name -> tool mapping."""
        return self._tools.copy()

    def enable_tool(self, tool_name: str) -> None:
        """Mark a registered tool as enabled."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = True
            logger.info(f"Enabled tool: {tool_name}")

    def disable_tool(self, tool_name: str) -> None:
        """Mark a registered tool as disabled."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = False
            logger.info(f"Disabled tool: {tool_name}")

    def is_enabled(self, tool_name: str) -> bool:
        """True when the tool exists and is enabled."""
        return self._enabled_tools.get(tool_name, False)

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """Function schemas for every enabled tool."""
        return [tool.get_schema() for tool in self.get_enabled_tools().values()]

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """Validate arguments and run a tool; errors come back as ToolResult."""
        tool = self.get_tool(tool_name)
        if not tool:
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' not found"
            )
        if not self.is_enabled(tool_name):
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' is disabled"
            )
        try:
            validated_params = tool.validate_parameters(**kwargs)
            result = await tool.execute(**validated_params)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return result
        except Exception as e:
            logger.error(f"Tool '{tool_name}' execution failed: {str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool execution failed: {str(e)}"
            )
|
||||
|
||||
|
||||
# Global tool registry instance
|
||||
tool_registry = ToolRegistry()
|
||||
|
|
@ -0,0 +1,741 @@
|
|||
"""LangGraph Agent service with tool calling capabilities."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain.chat_models import init_chat_model
|
||||
# from langgraph.prebuilt import create_react_agent
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import ToolRegistry
|
||||
from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool
|
||||
from ..postgresql_tool_manager import get_postgresql_tool
|
||||
from ...core.config import get_settings
|
||||
from ...utils.logger import get_logger
|
||||
from ..agent_config import AgentConfigService
|
||||
from th_agenter.services.mcp.mcp_dynamic_tools import load_mcp_tools
|
||||
|
||||
logger = get_logger("langgraph_agent_service")
|
||||
|
||||
|
||||
|
||||
class LangGraphAgentConfig(BaseModel):
    """Runtime configuration for the LangGraph agent."""

    model_name: str = Field(default="gpt-3.5-turbo")
    model_provider: str = Field(default="openai")
    base_url: Optional[str] = Field(default=None)
    api_key: Optional[str] = Field(default=None)
    enabled_tools: List[str] = Field(default_factory=lambda: [
        "calculator", "weather", "search", "file", "image"
    ])
    max_iterations: int = Field(default=10)
    temperature: float = Field(default=0.7)
    max_tokens: int = Field(default=1000)
    # Fix: the rule list was numbered 1/3/4 (rule 2 had been deleted);
    # renumbered consecutively so the prompt reads coherently.
    system_message: str = Field(
        default="""你是一个有用的AI助手,可以使用各种工具来帮助用户解决问题。
重要规则:
1. 工具调用失败时,必须仔细分析失败原因,特别是参数格式问题
2. 在重新调用工具前,先解释上次失败的原因和改进方案
3. 确保每个工具调用的参数格式严格符合工具的要求 """
    )
    verbose: bool = Field(default=True)
|
||||
|
||||
|
||||
class LangGraphAgentService:
|
||||
"""LangGraph Agent service using low-level LangGraph graph (React pattern)."""
|
||||
|
||||
def __init__(self, db_session=None):
|
||||
self.settings = get_settings()
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.config = LangGraphAgentConfig()
|
||||
self.tools = []
|
||||
self.db_session = db_session
|
||||
self.config_service = AgentConfigService(db_session) if db_session else None
|
||||
self._initialize_tools()
|
||||
self._load_config()
|
||||
self._create_react_agent()
|
||||
|
||||
def _initialize_tools(self):
|
||||
"""Initialize available tools."""
|
||||
try:
|
||||
dynamic_tools = load_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.warning(f"加载 MCP 动态工具失败,使用本地工具回退: {e}")
|
||||
dynamic_tools = []
|
||||
|
||||
# Always keep DateTimeTool locally
|
||||
base_tools = [DateTimeTool()]
|
||||
|
||||
if dynamic_tools:
|
||||
self.tools = dynamic_tools + base_tools
|
||||
logger.info(f"LangGraph 绑定 MCP 动态工具: {[t.name for t in dynamic_tools]}")
|
||||
else:
|
||||
# Fallback to local weather/search when MCP not available
|
||||
self.tools = [
|
||||
WeatherQueryTool(),
|
||||
TavilySearchTool(),
|
||||
] + base_tools
|
||||
logger.info("MCP 不可用,已回退到本地 Weather/Search 工具")
|
||||
|
||||
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from database if available."""
|
||||
if self.config_service:
|
||||
try:
|
||||
db_config = self.config_service.get_active_config()
|
||||
if db_config:
|
||||
# Update config with database values
|
||||
config_dict = db_config.config_data
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info("Loaded configuration from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config from database: {e}")
|
||||
|
||||
|
||||
|
||||
def _create_react_agent(self):
|
||||
"""Create LangGraph agent using low-level StateGraph with explicit nodes/edges."""
|
||||
try:
|
||||
# Initialize the model
|
||||
llm_config = get_settings().llm.get_current_config()
|
||||
self.model = init_chat_model(
|
||||
model=llm_config['model'],
|
||||
model_provider='openai',
|
||||
temperature=llm_config['temperature'],
|
||||
max_tokens=llm_config['max_tokens'],
|
||||
base_url= llm_config['base_url'],
|
||||
api_key=llm_config['api_key']
|
||||
)
|
||||
|
||||
# Bind tools to the model so it can propose tool calls
|
||||
try:
|
||||
self.bound_model = self.model.bind_tools(self.tools)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}")
|
||||
self.bound_model = self.model
|
||||
|
||||
# Build low-level React graph: State -> agent -> tools -> agent ... until stop
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langchain_core.messages import ToolMessage, BaseMessage
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[List[BaseMessage], add_messages]
|
||||
|
||||
# Node: call the model
|
||||
def agent_node(state: AgentState) -> AgentState:
|
||||
messages = state["messages"]
|
||||
# Optionally include a system instruction at the start for first turn
|
||||
if messages and messages[0].__class__.__name__ != 'SystemMessage':
|
||||
# Keep user history untouched; rely on upstream to include system if desired
|
||||
pass
|
||||
ai = self.bound_model.invoke(messages)
|
||||
return {"messages": [ai]}
|
||||
|
||||
# Node: execute tools requested by the last AI message
|
||||
def tools_node(state: AgentState) -> AgentState:
|
||||
messages = state["messages"]
|
||||
last = messages[-1]
|
||||
outputs: List[ToolMessage] = []
|
||||
try:
|
||||
tool_calls = getattr(last, 'tool_calls', []) or []
|
||||
tool_map = {t.name: t for t in self.tools}
|
||||
for call in tool_calls:
|
||||
name = call.get('name') if isinstance(call, dict) else getattr(call, 'name', None)
|
||||
args = call.get('args') if isinstance(call, dict) else getattr(call, 'args', {})
|
||||
call_id = call.get('id') if isinstance(call, dict) else getattr(call, 'id', '')
|
||||
if name in tool_map:
|
||||
try:
|
||||
result = tool_map[name].invoke(args)
|
||||
except Exception as te:
|
||||
result = f"Tool {name} execution error: {te}"
|
||||
else:
|
||||
result = f"Unknown tool: {name}"
|
||||
outputs.append(ToolMessage(content=str(result), tool_call_id=call_id))
|
||||
except Exception as e:
|
||||
outputs.append(ToolMessage(content=f"Tool execution error: {e}", tool_call_id=""))
|
||||
return {"messages": outputs}
|
||||
|
||||
# Router: decide next step after agent node
|
||||
def route_after_agent(state: AgentState) -> str:
|
||||
last = state["messages"][-1]
|
||||
finish_reason = None
|
||||
try:
|
||||
meta = getattr(last, 'response_metadata', {}) or {}
|
||||
finish_reason = meta.get('finish_reason')
|
||||
except Exception:
|
||||
finish_reason = None
|
||||
# If the model decided to call tools, continue to tools node
|
||||
if getattr(last, 'tool_calls', None):
|
||||
return "tools"
|
||||
# Otherwise, end
|
||||
return END
|
||||
|
||||
graph = StateGraph(AgentState)
|
||||
graph.add_node("agent", agent_node)
|
||||
graph.add_node("tools", tools_node)
|
||||
graph.add_edge(START, "agent")
|
||||
graph.add_conditional_edges("agent", route_after_agent, {"tools": "tools", END: END})
|
||||
graph.add_edge("tools", "agent")
|
||||
|
||||
# Compile graph and store as self.agent for compatibility with existing code
|
||||
self.react_agent = graph.compile()
|
||||
|
||||
logger.info("LangGraph low-level React agent created successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _format_tools_info(self) -> str:
|
||||
"""Format tools information for the prompt."""
|
||||
tools_info = []
|
||||
for tool_name in self.config.enabled_tools:
|
||||
tool = self.tool_registry.get_tool(tool_name)
|
||||
if tool:
|
||||
params_info = []
|
||||
for param in tool.get_parameters():
|
||||
params_info.append(f" - {param.name} ({param.type.value}): {param.description}")
|
||||
|
||||
tool_info = f"**{tool.get_name()}**: {tool.get_description()}"
|
||||
if params_info:
|
||||
tool_info += "\n" + "\n".join(params_info)
|
||||
tools_info.append(tool_info)
|
||||
|
||||
return "\n\n".join(tools_info)
|
||||
|
||||
|
||||
|
||||
async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
|
||||
"""Process a chat message using LangGraph."""
|
||||
try:
|
||||
logger.info(f"Starting chat with message: {message[:100]}...")
|
||||
|
||||
# Convert chat history to messages
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Add current message
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# Use the low-level graph directly
|
||||
result = await self.react_agent.ainvoke({"messages": messages}, {"recursion_limit": 6}, )
|
||||
|
||||
# Extract final response
|
||||
final_response = ""
|
||||
if "messages" in result and result["messages"]:
|
||||
last_message = result["messages"][-1]
|
||||
if hasattr(last_message, 'content'):
|
||||
final_response = last_message.content
|
||||
elif isinstance(last_message, dict) and "content" in last_message:
|
||||
final_response = last_message["content"]
|
||||
|
||||
return {
|
||||
"response": final_response,
|
||||
"intermediate_steps": [],
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LangGraph chat error: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"response": f"抱歉,处理您的请求时出现错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[
|
||||
Dict[str, Any], None]:
|
||||
"""Process a chat message using LangGraph with streaming."""
|
||||
try:
|
||||
logger.info(f"Starting streaming chat with message: {message[:100]}...")
|
||||
|
||||
# Convert chat history to messages
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Add current message
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# Track state for streaming
|
||||
intermediate_steps = []
|
||||
final_response_started = False
|
||||
accumulated_response = ""
|
||||
final_ai_message = None
|
||||
|
||||
# Stream the agent execution
|
||||
async for event in self.react_agent.astream({"messages": messages}):
|
||||
# Handle different event types from LangGraph
|
||||
print('event===', event)
|
||||
if isinstance(event, dict):
|
||||
for node_name, node_output in event.items():
|
||||
logger.info(f"Processing node: {node_name}, output type: {type(node_output)}")
|
||||
|
||||
# 处理 tools 节点
|
||||
if "tools" in node_name.lower():
|
||||
# 提取工具信息
|
||||
tool_infos = []
|
||||
|
||||
if isinstance(node_output, dict) and "messages" in node_output:
|
||||
messages_in_output = node_output["messages"]
|
||||
|
||||
for msg in messages_in_output:
|
||||
# 处理 ToolMessage 对象
|
||||
if hasattr(msg, 'name') and hasattr(msg, 'content'):
|
||||
tool_info = {
|
||||
"tool_name": msg.name,
|
||||
"tool_output": msg.content,
|
||||
"tool_call_id": getattr(msg, 'tool_call_id', ''),
|
||||
"status": "completed"
|
||||
}
|
||||
tool_infos.append(tool_info)
|
||||
elif isinstance(msg, dict):
|
||||
if 'name' in msg and 'content' in msg:
|
||||
tool_info = {
|
||||
"tool_name": msg['name'],
|
||||
"tool_output": msg['content'],
|
||||
"tool_call_id": msg.get('tool_call_id', ''),
|
||||
"status": "completed"
|
||||
}
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
# 返回 tools_end 事件
|
||||
for tool_info in tool_infos:
|
||||
yield {
|
||||
"type": "tools_end",
|
||||
"content": f"工具 {tool_info['tool_name']} 执行完成",
|
||||
"tool_name": tool_info["tool_name"],
|
||||
"tool_output": tool_info["tool_output"],
|
||||
"node_name": node_name,
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 处理 agent 节点
|
||||
elif "agent" in node_name.lower():
|
||||
if isinstance(node_output, dict) and "messages" in node_output:
|
||||
messages_in_output = node_output["messages"]
|
||||
if messages_in_output:
|
||||
last_msg = messages_in_output[-1]
|
||||
|
||||
# 获取 finish_reason
|
||||
finish_reason = None
|
||||
if hasattr(last_msg, 'response_metadata'):
|
||||
finish_reason = last_msg.response_metadata.get('finish_reason')
|
||||
elif isinstance(last_msg, dict) and 'response_metadata' in last_msg:
|
||||
finish_reason = last_msg['response_metadata'].get('finish_reason')
|
||||
|
||||
# 判断是否为 thinking 或 response
|
||||
if finish_reason == 'tool_calls':
|
||||
# thinking 状态
|
||||
thinking_content = "🤔 正在思考..."
|
||||
if hasattr(last_msg, 'content') and last_msg.content:
|
||||
thinking_content = f"🤔 思考: {last_msg.content[:200]}..."
|
||||
elif isinstance(last_msg, dict) and "content" in last_msg:
|
||||
thinking_content = f"🤔 思考: {last_msg['content'][:200]}..."
|
||||
|
||||
yield {
|
||||
"type": "thinking",
|
||||
"content": thinking_content,
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
elif finish_reason == 'stop':
|
||||
# response 状态
|
||||
if hasattr(last_msg, 'content') and hasattr(last_msg,
|
||||
'__class__') and 'AI' in last_msg.__class__.__name__:
|
||||
current_content = last_msg.content
|
||||
final_ai_message = last_msg
|
||||
|
||||
if not final_response_started and current_content:
|
||||
final_response_started = True
|
||||
yield {
|
||||
"type": "response_start",
|
||||
"content": "",
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": False
|
||||
}
|
||||
|
||||
if current_content and len(current_content) > len(accumulated_response):
|
||||
new_content = current_content[len(accumulated_response):]
|
||||
|
||||
for char in new_content:
|
||||
accumulated_response += char
|
||||
yield {
|
||||
"type": "response",
|
||||
"content": accumulated_response,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
else:
|
||||
# 其他 agent 状态
|
||||
yield {
|
||||
"type": "step",
|
||||
"content": f"📋 执行步骤: {node_name}",
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 处理其他节点
|
||||
else:
|
||||
yield {
|
||||
"type": "step",
|
||||
"content": f"📋 执行步骤: {node_name}",
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 最终完成事件
|
||||
yield {
|
||||
"type": "complete",
|
||||
"content": accumulated_response,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_stream: {str(e)}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"处理请求时出错: {str(e)}",
|
||||
"done": True
|
||||
}
|
||||
|
||||
# 确保最终响应包含完整内容
|
||||
final_content = accumulated_response
|
||||
if not final_content and final_ai_message and hasattr(final_ai_message, 'content'):
|
||||
final_content = final_ai_message.content or ""
|
||||
|
||||
# Final completion signal
|
||||
yield {
|
||||
"type": "response",
|
||||
"content": final_content,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LangGraph chat stream error: {str(e)}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"抱歉,处理您的请求时出现错误: {str(e)}",
|
||||
"error": str(e),
|
||||
"done": True
|
||||
}
|
||||
|
||||
def get_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available tools."""
|
||||
tools = []
|
||||
for tool in self.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": [],
|
||||
"enabled": True
|
||||
})
|
||||
return tools
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Get current agent configuration."""
|
||||
return self.config.dict()
|
||||
|
||||
def update_config(self, config: Dict[str, Any]):
|
||||
"""Update agent configuration."""
|
||||
for key, value in config.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
|
||||
# Recreate agent with new config
|
||||
self._create_react_agent()
|
||||
logger.info("Agent configuration updated")
|
||||
|
||||
def _create_plan_execute_agent(self):
|
||||
"""Create a Plan-and-Execute agent using LangGraph low-level API.
|
||||
结构:START -> planner -> executor(loop) -> summarize -> END
|
||||
- planner:根据用户问题生成计划(JSON 数组)
|
||||
- executor:逐步执行计划(可调用工具),收集每步结果
|
||||
- summarize:综合计划与执行结果产出最终回答
|
||||
"""
|
||||
from typing import TypedDict, Annotated, List
|
||||
import json
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
|
||||
try:
|
||||
self.bound_model = self.model.bind_tools(self.tools)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}")
|
||||
self.bound_model = self.model
|
||||
class PlanState(TypedDict):
|
||||
messages: Annotated[List[BaseMessage], add_messages]
|
||||
plan_steps: List[str]
|
||||
current_step: int
|
||||
step_results: List[str]
|
||||
|
||||
def planner_node(state: PlanState) -> PlanState:
|
||||
messages = state.get("messages", [])
|
||||
plan_prompt = (
|
||||
"你是规划助手。基于对话内容生成可执行计划,"
|
||||
"用 JSON 数组返回,每个元素是一条明确且可操作的步骤。"
|
||||
"仅输出 JSON,不要额外解释。"
|
||||
)
|
||||
ai_plan = self.model.invoke(messages + [HumanMessage(content=plan_prompt)])
|
||||
steps: List[str] = []
|
||||
try:
|
||||
parsed = json.loads(ai_plan.content)
|
||||
if isinstance(parsed, list):
|
||||
steps = [str(s) for s in parsed]
|
||||
except Exception:
|
||||
# 回退:按行拆分
|
||||
steps = [s.strip() for s in ai_plan.content.split("\n") if s.strip()]
|
||||
return {
|
||||
"messages": [ai_plan],
|
||||
"plan_steps": steps,
|
||||
"current_step": 0,
|
||||
"step_results": []
|
||||
}
|
||||
|
||||
def executor_node(state: PlanState) -> PlanState:
|
||||
idx = state.get("current_step", 0)
|
||||
steps = state.get("plan_steps", [])
|
||||
msgs = state.get("messages", [])
|
||||
if idx >= len(steps):
|
||||
return {"messages": [], "current_step": idx, "step_results": state.get("step_results", [])}
|
||||
|
||||
step_text = steps[idx]
|
||||
exec_prompt = (
|
||||
f"请执行计划的第{idx+1}步:{step_text}。"
|
||||
"需要用工具时创建工具调用;完成后给出该步的结果。"
|
||||
)
|
||||
ai_exec = self.bound_model.invoke(msgs + [HumanMessage(content=exec_prompt)])
|
||||
|
||||
new_messages: List[BaseMessage] = [ai_exec]
|
||||
step_result_content = None
|
||||
|
||||
# 处理工具调用
|
||||
tool_map = {t.name: t for t in self.tools}
|
||||
tool_msgs: List[ToolMessage] = []
|
||||
tool_calls = getattr(ai_exec, "tool_calls", []) or (ai_exec.additional_kwargs.get("tool_calls") if hasattr(ai_exec, "additional_kwargs") else [])
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
name = call.get("name")
|
||||
args = call.get("args", {})
|
||||
tool_obj = tool_map.get(name)
|
||||
if tool_obj:
|
||||
try:
|
||||
result = tool_obj.invoke(args)
|
||||
except Exception as e:
|
||||
result = f"工具执行失败: {e}"
|
||||
else:
|
||||
result = f"未找到工具: {name}"
|
||||
tool_call_id = call.get("id") or call.get("tool_call_id") or call.get("call_id") or f"tool_{name}"
|
||||
tool_msgs.append(ToolMessage(content=str(result), tool_call_id=tool_call_id, name=name or "tool"))
|
||||
new_messages.extend(tool_msgs)
|
||||
# 基于工具输出总结该步结果
|
||||
summarize_step = "请基于上述工具输出,总结该步骤的结果,给出结构化要点与可读说明。"
|
||||
ai_step = self.bound_model.invoke(msgs + [ai_exec] + tool_msgs + [HumanMessage(content=summarize_step)])
|
||||
step_result_content = ai_step.content
|
||||
new_messages.append(ai_step)
|
||||
else:
|
||||
step_result_content = ai_exec.content
|
||||
|
||||
all_results = list(state.get("step_results", []))
|
||||
if step_result_content:
|
||||
all_results.append(step_result_content)
|
||||
|
||||
return {
|
||||
"messages": new_messages,
|
||||
"current_step": idx + 1,
|
||||
"step_results": all_results
|
||||
}
|
||||
|
||||
def route_after_planner(state: PlanState) -> str:
|
||||
return "executor" if state.get("plan_steps") else END
|
||||
|
||||
def route_after_executor(state: PlanState) -> str:
|
||||
cur = state.get("current_step", 0)
|
||||
total = len(state.get("plan_steps", []))
|
||||
return "executor" if cur < total else "summarize"
|
||||
|
||||
def summarize_node(state: PlanState) -> PlanState:
|
||||
import json as _json
|
||||
msgs = state.get("messages", [])
|
||||
steps = state.get("plan_steps", [])
|
||||
results = state.get("step_results", [])
|
||||
final_prompt = (
|
||||
"请综合以上计划与各步骤结果,生成最终回答。"
|
||||
"要求:逻辑清晰、结论明确、可读性强;如存在不确定性请注明。"
|
||||
)
|
||||
context_msg = HumanMessage(content=(
|
||||
f"计划: {_json.dumps(steps, ensure_ascii=False)}\n"
|
||||
f"步骤结果: {_json.dumps(results, ensure_ascii=False)}\n"
|
||||
f"{final_prompt}"
|
||||
))
|
||||
ai_final = self.model.invoke(msgs + [context_msg])
|
||||
return {"messages": [ai_final]}
|
||||
|
||||
graph = StateGraph(PlanState)
|
||||
graph.add_node("planner", planner_node)
|
||||
graph.add_node("executor", executor_node)
|
||||
graph.add_node("summarize", summarize_node)
|
||||
graph.add_edge(START, "planner")
|
||||
graph.add_conditional_edges("planner", route_after_planner, {"executor": "executor", END: END})
|
||||
graph.add_conditional_edges("executor", route_after_executor, {"executor": "executor", "summarize": "summarize"})
|
||||
graph.add_edge("summarize", END)
|
||||
|
||||
self.plan_execute_agent = graph.compile()
|
||||
|
||||
async def chat_plan_execute(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
|
||||
"""Single-turn Plan-and-Execute chat."""
|
||||
# 确保 agent 已创建
|
||||
if not hasattr(self, "plan_execute_agent"):
|
||||
self._create_plan_execute_agent()
|
||||
|
||||
# 构建消息
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", "")
|
||||
if role == "user":
|
||||
messages.append(HumanMessage(content=content))
|
||||
else:
|
||||
messages.append(AIMessage(content=content))
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
try:
|
||||
result = await self.plan_execute_agent.ainvoke({"messages": messages}, config={"recursion_limit": self.config.max_iterations})
|
||||
final_msg = None
|
||||
if isinstance(result, dict) and "messages" in result:
|
||||
ms = result["messages"]
|
||||
if ms:
|
||||
final_msg = ms[-1]
|
||||
final_text = getattr(final_msg, "content", "") if final_msg else ""
|
||||
return {
|
||||
"status": "success",
|
||||
"response": final_text,
|
||||
"raw": str(result)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_plan_execute: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def chat_stream_plan_execute(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Streamed Plan-and-Execute chat."""
|
||||
import asyncio as _asyncio
|
||||
if not hasattr(self, "plan_execute_agent"):
|
||||
self._create_plan_execute_agent()
|
||||
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", "")
|
||||
if role == "user":
|
||||
messages.append(HumanMessage(content=content))
|
||||
else:
|
||||
messages.append(AIMessage(content=content))
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
try:
|
||||
accumulated = ""
|
||||
async for event in self.react_agent.astream({"messages": messages}, config={"recursion_limit": self.config.max_iterations}):
|
||||
for key, node_output in event.items():
|
||||
node_name = key[0] if isinstance(key, tuple) else key
|
||||
if node_name == "planner":
|
||||
# 规划阶段
|
||||
content = "生成计划中..."
|
||||
if node_output and isinstance(node_output, dict):
|
||||
m = node_output.get("messages", [])
|
||||
if m:
|
||||
last = m[-1]
|
||||
if hasattr(last, "content"):
|
||||
content = str(last.content)[:400]
|
||||
yield {"type": "planning", "content": content, "done": False}
|
||||
await _asyncio.sleep(0.05)
|
||||
elif node_name == "executor":
|
||||
# 执行阶段(可能包含工具)
|
||||
yield {"type": "step", "content": "执行计划步骤", "done": False}
|
||||
await _asyncio.sleep(0.05)
|
||||
if node_output and isinstance(node_output, dict):
|
||||
msgs = node_output.get("messages", [])
|
||||
# 输出工具结束标记
|
||||
tool_msgs = [m for m in msgs if hasattr(m, "__class__") and "Tool" in m.__class__.__name__]
|
||||
if tool_msgs:
|
||||
yield {"type": "tools_end", "content": f"完成 {len(tool_msgs)} 次工具执行", "done": False}
|
||||
await _asyncio.sleep(0.03)
|
||||
# 尝试输出该步总结
|
||||
ai_msgs = [m for m in msgs if hasattr(m, "__class__") and "AI" in m.__class__.__name__]
|
||||
if ai_msgs:
|
||||
text = ai_msgs[-1].content
|
||||
if text:
|
||||
accumulated = text
|
||||
yield {"type": "response", "content": accumulated, "done": False}
|
||||
await _asyncio.sleep(0.02)
|
||||
elif node_name == "summarize":
|
||||
# 最终总结
|
||||
if node_output and isinstance(node_output, dict):
|
||||
msgs = node_output.get("messages", [])
|
||||
if msgs:
|
||||
final = msgs[-1]
|
||||
content = getattr(final, "content", "")
|
||||
if content:
|
||||
yield {"type": "response_start", "content": "", "done": False}
|
||||
yield {"type": "response", "content": content, "done": False}
|
||||
accumulated = content
|
||||
await _asyncio.sleep(0.02)
|
||||
yield {"type": "complete", "content": accumulated, "done": True}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_stream_plan_execute: {e}", exc_info=True)
|
||||
yield {"type": "error", "content": str(e), "done": True}
|
||||
|
||||
|
||||
# Global instance
_langgraph_agent_service: Optional[LangGraphAgentService] = None


def get_langgraph_agent_service(db_session=None) -> LangGraphAgentService:
    """Return the process-wide LangGraph agent service, creating it on first use.

    NOTE(review): db_session is only honored on the very first call; later
    callers receive the existing singleton regardless of the session they
    pass — confirm this is intended.
    """
    global _langgraph_agent_service
    if _langgraph_agent_service is None:
        _langgraph_agent_service = LangGraphAgentService(db_session)
        logger.info("LangGraph Agent service initialized")
    return _langgraph_agent_service
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
"""Agent services package.
|
||||
|
||||
轻量化导入:仅暴露基础工具类型,避免在包导入时加载耗时的服务层。使用 AgentService 时请从子模块显式导入:
|
||||
from open_agent.services.agent.agent_service import AgentService
|
||||
"""
|
||||
|
||||
from .base import BaseTool, ToolRegistry
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolRegistry"
|
||||
]
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
"""LangChain Agent service with tool calling capabilities."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import BaseTool, ToolRegistry, ToolResult
|
||||
from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool
|
||||
from ..postgresql_tool_manager import get_postgresql_tool
|
||||
from ..mysql_tool_manager import get_mysql_tool
|
||||
from ...core.config import get_settings
|
||||
from ..agent_config import AgentConfigService
|
||||
from loguru import logger
|
||||
|
||||
class AgentConfig(BaseModel):
    """Agent configuration."""
    # Tool names enabled by default, including the MCP database tools.
    enabled_tools: List[str] = Field(default_factory=lambda: [
        "calculator", "weather", "search", "datetime", "file", "generate_image", "postgresql_mcp", "mysql_mcp"
    ])
    # Safety cap on agent reasoning iterations.
    max_iterations: int = Field(default=10)
    temperature: float = Field(default=0.1)
    # Default system prompt handed to the agent.
    system_message: str = Field(
        default="You are a helpful AI assistant with access to various tools. "
                "Use the available tools to help answer user questions accurately. "
                "Always explain your reasoning and the tools you're using."
    )
    verbose: bool = Field(default=True)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""LangChain Agent service with tool calling capabilities."""
|
||||
|
||||
def __init__(self, db_session=None):
|
||||
self.settings = get_settings()
|
||||
self.tool_registry = ToolRegistry()
|
||||
self.config = AgentConfig()
|
||||
self.db_session = db_session
|
||||
self.config_service = AgentConfigService(db_session) if db_session else None
|
||||
self._initialize_tools()
|
||||
self._load_config()
|
||||
|
||||
def _initialize_tools(self):
|
||||
"""Initialize and register all available tools."""
|
||||
tools = [
|
||||
WeatherQueryTool(),
|
||||
TavilySearchTool(),
|
||||
DateTimeTool(),
|
||||
get_postgresql_tool(), # 使用单例PostgreSQL MCP工具
|
||||
get_mysql_tool() # 使用单例MySQL MCP工具
|
||||
]
|
||||
|
||||
for tool in tools:
|
||||
self.tool_registry.register(tool)
|
||||
logger.info(f"Registered tool: {tool.get_name()}")
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from database if available."""
|
||||
if self.config_service:
|
||||
try:
|
||||
config_dict = self.config_service.get_config_dict()
|
||||
# Update config with database values
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info("Loaded agent configuration from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config from database, using defaults: {str(e)}")
|
||||
|
||||
def _get_enabled_tools(self) -> List[Any]:
|
||||
"""Get list of enabled LangChain tools."""
|
||||
enabled_tools = []
|
||||
|
||||
for tool_name in self.config.enabled_tools:
|
||||
tool = self.tool_registry.get_tool(tool_name)
|
||||
if tool:
|
||||
enabled_tools.append(tool)
|
||||
logger.debug(f"Enabled tool: {tool_name}")
|
||||
else:
|
||||
logger.warning(f"Tool not found: {tool_name}")
|
||||
|
||||
return enabled_tools
|
||||
|
||||
def _create_agent_executor(self) -> Any:
|
||||
"""Create LangChain agent executor."""
|
||||
# Get LLM configuration
|
||||
from ...core.llm import create_llm
|
||||
llm = create_llm()
|
||||
|
||||
# Get enabled tools
|
||||
tools = self._get_enabled_tools()
|
||||
|
||||
# Create prompt template
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", self.config.system_message),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", "{input}"),
|
||||
])
|
||||
|
||||
# Create agent using new LangChain 1.0+ API
|
||||
agent = create_agent(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
|
||||
"""Process chat message with agent."""
|
||||
try:
|
||||
logger.info(f"Processing agent chat message: {message[:100]}...")
|
||||
|
||||
# Create agent
|
||||
agent = self._create_agent_executor()
|
||||
|
||||
# Convert chat history to LangChain format
|
||||
langchain_history = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
langchain_history.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
langchain_history.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Execute agent
|
||||
result = await agent.ainvoke({
|
||||
"input": message,
|
||||
"chat_history": langchain_history
|
||||
})
|
||||
|
||||
logger.info(f"Agent response generated successfully")
|
||||
|
||||
return {
|
||||
"response": result["output"] if isinstance(result, dict) and "output" in result else str(result),
|
||||
"tool_calls": [],
|
||||
"success": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent chat error: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"response": f"Sorry, I encountered an error: {str(e)}",
|
||||
"tool_calls": [],
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Process chat message with agent (streaming)."""
|
||||
tool_calls = [] # Initialize tool_calls at the beginning
|
||||
try:
|
||||
logger.info(f"Processing agent chat stream: {message[:100]}...")
|
||||
|
||||
# Create agent
|
||||
agent = self._create_agent_executor()
|
||||
|
||||
# Convert chat history to LangChain format
|
||||
langchain_history = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
langchain_history.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
langchain_history.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Yield initial status
|
||||
yield {
|
||||
"type": "status",
|
||||
"content": "🤖 开始分析您的请求...",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Generate response
|
||||
result = await agent.ainvoke({
|
||||
"input": message,
|
||||
"chat_history": langchain_history
|
||||
})
|
||||
|
||||
response_content = result["output"] if isinstance(result, dict) and "output" in result else str(result)
|
||||
|
||||
# Yield the final response in chunks to simulate streaming
|
||||
words = response_content.split()
|
||||
current_content = ""
|
||||
|
||||
for i, word in enumerate(words):
|
||||
current_content += word + " "
|
||||
|
||||
# Yield every 2-3 words or at the end
|
||||
if (i + 1) % 2 == 0 or i == len(words) - 1:
|
||||
yield {
|
||||
"type": "response",
|
||||
"content": current_content.strip(),
|
||||
"tool_calls": tool_calls if i == len(words) - 1 else [],
|
||||
"done": i == len(words) - 1
|
||||
}
|
||||
|
||||
# Small delay to simulate typing
|
||||
if i < len(words) - 1:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
logger.info(f"Agent stream response completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent chat stream error: {str(e)}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"Sorry, I encountered an error: {str(e)}",
|
||||
"done": True
|
||||
}
|
||||
|
||||
def update_config(self, config: Dict[str, Any]):
|
||||
"""Update agent configuration."""
|
||||
try:
|
||||
# Update configuration
|
||||
for key, value in config.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info(f"Updated agent config: {key} = {value}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent config: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def load_config_from_db(self, config_id: Optional[int] = None):
|
||||
"""Load configuration from database."""
|
||||
if not self.config_service:
|
||||
logger.warning("No database session available for loading config")
|
||||
return
|
||||
|
||||
try:
|
||||
config_dict = self.config_service.get_config_dict(config_id)
|
||||
self.update_config(config_dict)
|
||||
logger.info(f"Loaded configuration from database (ID: {config_id})")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading config from database: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available tools."""
|
||||
tools = []
|
||||
for tool_name, tool in self.tool_registry._tools.items():
|
||||
tools.append({
|
||||
"name": tool.get_name(),
|
||||
"description": tool.get_description(),
|
||||
"parameters": [{
|
||||
"name": param.name,
|
||||
"type": param.type.value,
|
||||
"description": param.description,
|
||||
"required": param.required,
|
||||
"default": param.default,
|
||||
"enum": param.enum
|
||||
} for param in tool.get_parameters()],
|
||||
"enabled": tool_name in self.config.enabled_tools
|
||||
})
|
||||
return tools
|
||||
|
||||
    def get_config(self) -> Dict[str, Any]:
        """Return the current agent configuration as a plain dict.

        Delegates to the pydantic model's ``dict()`` serialization.
        """
        return self.config.dict()
||||
|
||||
|
||||
# Lazily-created, process-wide AgentService singleton.
_agent_service: Optional[AgentService] = None


def get_agent_service(db_session=None) -> AgentService:
    """Return the shared AgentService instance, creating it on first use.

    If the singleton already exists but was built without a database
    session, a later call that supplies one upgrades it in place: the
    session and config service are attached and configuration reloaded.
    """
    global _agent_service

    if _agent_service is None:
        _agent_service = AgentService(db_session)
        return _agent_service

    if db_session and not _agent_service.db_session:
        # Late-bind the database session the first time one is provided.
        _agent_service.db_session = db_session
        _agent_service.config_service = AgentConfigService(db_session)
        _agent_service._load_config()

    return _agent_service
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Base classes for Agent tools."""
|
||||
|
||||
import json
|
||||
from loguru import logger
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type, Callable
|
||||
from pydantic import BaseModel, Field
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
class ToolParameterType(str, Enum):
    """JSON-schema-style type tags accepted for tool parameters.

    Inherits from ``str`` so members compare equal to — and serialize as —
    their plain string values.
    """
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
|
||||
|
||||
@dataclass
class ToolParameter:
    """Declarative description of a single tool argument.

    The attributes mirror JSON-schema property fields. ``required`` is
    tracked here but emitted at the schema level rather than inside the
    per-property dict produced by :meth:`to_dict`.
    """
    name: str
    type: ToolParameterType
    description: str
    required: bool = True
    default: Any = None
    enum: Optional[List[Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Render this parameter as a JSON-schema property dict."""
        schema: Dict[str, Any] = {
            "type": self.type.value,
            "description": self.description,
        }
        # Optional facets are emitted only when meaningful: a non-empty
        # enum list and any explicit (non-None) default value.
        if self.enum:
            schema["enum"] = self.enum
        if self.default is not None:
            schema["default"] = self.default
        return schema
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """Uniform envelope returned by every tool execution.

    ``success`` distinguishes normal results from failures; on failure
    ``error`` holds a human-readable message and ``result`` may be None.
    """
    success: bool = Field(description="Whether the tool execution was successful")
    result: Any = Field(description="The result data")
    error: Optional[str] = Field(default=None, description="Error message if failed")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
||||
|
||||
class BaseTool(ABC):
    """Abstract base for all Agent tools.

    Subclasses supply a name, a description and a parameter list, plus an
    async ``execute``; this base provides LangChain function-schema
    generation and lightweight parameter validation/coercion.
    """

    def __init__(self):
        # Cache the subclass-provided metadata once at construction time.
        self.name = self.get_name()
        self.description = self.get_description()
        self.parameters = self.get_parameters()

    @abstractmethod
    def get_name(self) -> str:
        """Return the unique tool name."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of what the tool does."""

    @abstractmethod
    def get_parameters(self) -> List[ToolParameter]:
        """Return the declared parameters for this tool."""

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with validated keyword arguments."""

    def get_schema(self) -> Dict[str, Any]:
        """Build the OpenAI-function-style schema LangChain expects."""
        properties = {spec.name: spec.to_dict() for spec in self.parameters}
        required = [spec.name for spec in self.parameters if spec.required]
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    def validate_parameters(self, **kwargs) -> Dict[str, Any]:
        """Validate, default and coerce the supplied keyword arguments.

        Unknown keyword arguments are dropped. Optional parameters that are
        absent are recorded as their default (or None).

        Raises:
            ValueError: if a required parameter is absent, a value cannot be
                coerced to its declared type, or it is outside its enum.
        """
        validated: Dict[str, Any] = {}
        for spec in self.parameters:
            value = kwargs.get(spec.name)

            # Required parameters must be supplied explicitly; a declared
            # default does not satisfy a required parameter.
            if spec.required and value is None:
                raise ValueError(f"Required parameter '{spec.name}' is missing")

            if value is None and spec.default is not None:
                value = spec.default

            if value is not None:
                value = self._coerce(spec, value)

            if spec.enum and value not in spec.enum:
                raise ValueError(f"Parameter '{spec.name}' must be one of {spec.enum}")

            validated[spec.name] = value
        return validated

    def _coerce(self, spec: ToolParameter, value: Any) -> Any:
        """Best-effort conversion of *value* to the declared parameter type."""
        if spec.type == ToolParameterType.INTEGER and not isinstance(value, int):
            try:
                return int(value)
            except (ValueError, TypeError):
                raise ValueError(f"Parameter '{spec.name}' must be an integer")
        if spec.type == ToolParameterType.FLOAT and not isinstance(value, (int, float)):
            try:
                return float(value)
            except (ValueError, TypeError):
                raise ValueError(f"Parameter '{spec.name}' must be a number")
        if spec.type == ToolParameterType.BOOLEAN and not isinstance(value, bool):
            # Accept common truthy strings; anything else goes through bool().
            if isinstance(value, str):
                return value.lower() in ('true', '1', 'yes', 'on')
            return bool(value)
        return value
|
||||
|
||||
|
||||
class ToolRegistry:
    """Keeps track of registered tools and their enabled/disabled state."""

    def __init__(self):
        # Name -> tool instance, and name -> enabled flag, kept in lockstep.
        self._tools: Dict[str, BaseTool] = {}
        self._enabled_tools: Dict[str, bool] = {}

    def register(self, tool: BaseTool, enabled: bool = True) -> None:
        """Add *tool* to the registry (replacing any same-named entry)."""
        tool_name = tool.get_name()
        self._tools[tool_name] = tool
        self._enabled_tools[tool_name] = enabled
        logger.info(f"Registered tool: {tool_name} (enabled: {enabled})")

    def unregister(self, tool_name: str) -> None:
        """Remove a tool if present; unknown names are ignored."""
        if tool_name not in self._tools:
            return
        del self._tools[tool_name]
        del self._enabled_tools[tool_name]
        logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Look up a tool by name, returning None when absent."""
        return self._tools.get(tool_name)

    def get_enabled_tools(self) -> Dict[str, BaseTool]:
        """Return name -> tool for every currently enabled tool."""
        return {
            name: registered
            for name, registered in self._tools.items()
            if self._enabled_tools.get(name, False)
        }

    def get_all_tools(self) -> Dict[str, BaseTool]:
        """Return a shallow copy of the full name -> tool mapping."""
        return dict(self._tools)

    def enable_tool(self, tool_name: str) -> None:
        """Mark a registered tool as enabled; unknown names are ignored."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = True
            logger.info(f"Enabled tool: {tool_name}")

    def disable_tool(self, tool_name: str) -> None:
        """Mark a registered tool as disabled; unknown names are ignored."""
        if tool_name in self._tools:
            self._enabled_tools[tool_name] = False
            logger.info(f"Disabled tool: {tool_name}")

    def is_enabled(self, tool_name: str) -> bool:
        """True when the tool exists and is currently enabled."""
        return self._enabled_tools.get(tool_name, False)

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """Collect the LangChain function schemas of all enabled tools."""
        return [registered.get_schema() for registered in self.get_enabled_tools().values()]

    async def execute_tool(self, tool_name: str, **kwargs) -> ToolResult:
        """Validate arguments and run the named tool.

        Never raises: missing/disabled tools and validation or execution
        failures are all reported as a failed ToolResult.
        """
        registered = self.get_tool(tool_name)

        if not registered:
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' not found"
            )

        if not self.is_enabled(tool_name):
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool '{tool_name}' is disabled"
            )

        try:
            # Validate/coerce first so the tool only ever sees clean input.
            arguments = registered.validate_parameters(**kwargs)
            outcome = await registered.execute(**arguments)
            logger.info(f"Tool '{tool_name}' executed successfully")
            return outcome
        except Exception as e:
            logger.error(f"Tool '{tool_name}' execution failed: {str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                result=None,
                error=f"Tool execution failed: {str(e)}"
            )
|
||||
|
||||
|
||||
# Global tool registry instance shared by the whole application.
tool_registry = ToolRegistry()
||||
|
|
@ -0,0 +1,737 @@
|
|||
"""LangGraph Agent service with tool calling capabilities."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain.chat_models import init_chat_model
|
||||
# from langgraph.prebuilt import create_react_agent
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import ToolRegistry
|
||||
from th_agenter.services.tools import WeatherQueryTool, TavilySearchTool, DateTimeTool
|
||||
from ..postgresql_tool_manager import get_postgresql_tool
|
||||
from ...core.config import get_settings
|
||||
from ..agent_config import AgentConfigService
|
||||
from th_agenter.services.mcp.mcp_dynamic_tools import load_mcp_tools
|
||||
from loguru import logger
|
||||
|
||||
class LangGraphAgentConfig(BaseModel):
    """Configuration for the LangGraph agent.

    Covers model selection/connection, tool enablement, generation limits,
    and the system prompt injected into conversations.
    """
    model_name: str = Field(default="gpt-3.5-turbo")
    model_provider: str = Field(default="openai")
    base_url: Optional[str] = Field(default=None)
    api_key: Optional[str] = Field(default=None)
    # Names of tools the agent may call; must match registered tool names.
    enabled_tools: List[str] = Field(default_factory=lambda: [
        "calculator", "weather", "search", "file", "image"
    ])
    max_iterations: int = Field(default=10)
    temperature: float = Field(default=0.7)
    max_tokens: int = Field(default=1000)
    # Fix: the rule list was numbered 1/3/4 (rule 2 missing); renumbered 1/2/3.
    system_message: str = Field(
        default="""你是一个有用的AI助手,可以使用各种工具来帮助用户解决问题。
重要规则:
1. 工具调用失败时,必须仔细分析失败原因,特别是参数格式问题
2. 在重新调用工具前,先解释上次失败的原因和改进方案
3. 确保每个工具调用的参数格式严格符合工具的要求 """
    )
    verbose: bool = Field(default=True)
|
||||
|
||||
|
||||
class LangGraphAgentService:
|
||||
"""LangGraph Agent service using low-level LangGraph graph (React pattern)."""
|
||||
|
||||
    def __init__(self, db_session=None):
        """Set up settings, tools and configuration, then build the React graph.

        Args:
            db_session: Optional database session; when provided, persisted
                agent configuration is loaded before the graph is compiled.
        """
        self.settings = get_settings()
        self.tool_registry = ToolRegistry()
        self.config = LangGraphAgentConfig()
        self.tools = []
        self.db_session = db_session
        # Config service only exists when a DB session is available.
        self.config_service = AgentConfigService(db_session) if db_session else None
        # Order matters: tools first, then config overrides, then the graph
        # (which binds the tools and reads the final config).
        self._initialize_tools()
        self._load_config()
        self._create_react_agent()
||||
|
||||
def _initialize_tools(self):
|
||||
"""Initialize available tools."""
|
||||
try:
|
||||
dynamic_tools = load_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.warning(f"加载 MCP 动态工具失败,使用本地工具回退: {e}")
|
||||
dynamic_tools = []
|
||||
|
||||
# Always keep DateTimeTool locally
|
||||
base_tools = [DateTimeTool()]
|
||||
|
||||
if dynamic_tools:
|
||||
self.tools = dynamic_tools + base_tools
|
||||
logger.info(f"LangGraph 绑定 MCP 动态工具: {[t.name for t in dynamic_tools]}")
|
||||
else:
|
||||
# Fallback to local weather/search when MCP not available
|
||||
self.tools = [
|
||||
WeatherQueryTool(),
|
||||
TavilySearchTool(),
|
||||
] + base_tools
|
||||
logger.info("MCP 不可用,已回退到本地 Weather/Search 工具")
|
||||
|
||||
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from database if available."""
|
||||
if self.config_service:
|
||||
try:
|
||||
db_config = self.config_service.get_active_config()
|
||||
if db_config:
|
||||
# Update config with database values
|
||||
config_dict = db_config.config_data
|
||||
for key, value in config_dict.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
logger.info("Loaded configuration from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load config from database: {e}")
|
||||
|
||||
|
||||
|
||||
    def _create_react_agent(self):
        """Create LangGraph agent using low-level StateGraph with explicit nodes/edges.

        Builds a two-node React loop: the "agent" node calls the tool-bound
        model, the "tools" node executes any requested tool calls, and
        control cycles agent -> tools -> agent until the model stops
        requesting tools. The compiled graph is stored on ``self.react_agent``.

        Raises:
            Exception: re-raised after logging if model initialization or
                graph construction fails.
        """
        try:
            # Initialize the model from the live LLM settings.
            # NOTE(review): model_provider is hardcoded to 'openai' and ignores
            # self.config.model_provider — confirm this is intentional.
            llm_config = get_settings().llm.get_current_config()
            self.model = init_chat_model(
                model=llm_config['model'],
                model_provider='openai',
                temperature=llm_config['temperature'],
                max_tokens=llm_config['max_tokens'],
                base_url=llm_config['base_url'],
                api_key=llm_config['api_key']
            )

            # Bind tools to the model so it can propose tool calls; fall back
            # to the bare model if the provider does not support binding.
            try:
                self.bound_model = self.model.bind_tools(self.tools)
            except Exception as e:
                logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}")
                self.bound_model = self.model

            # Build low-level React graph: State -> agent -> tools -> agent ... until stop
            from typing import TypedDict
            from langgraph.graph import StateGraph, START, END
            from langchain_core.messages import ToolMessage, BaseMessage
            from typing import Annotated
            from langgraph.graph.message import add_messages

            class AgentState(TypedDict):
                # add_messages makes state updates append, not replace.
                messages: Annotated[List[BaseMessage], add_messages]

            # Node: call the model with the accumulated conversation.
            def agent_node(state: AgentState) -> AgentState:
                messages = state["messages"]
                # Optionally include a system instruction at the start for first turn
                if messages and messages[0].__class__.__name__ != 'SystemMessage':
                    # Keep user history untouched; rely on upstream to include system if desired
                    pass
                ai = self.bound_model.invoke(messages)
                return {"messages": [ai]}

            # Node: execute tools requested by the last AI message.
            def tools_node(state: AgentState) -> AgentState:
                messages = state["messages"]
                last = messages[-1]
                outputs: List[ToolMessage] = []
                try:
                    # tool_calls entries may be dicts or objects depending on provider.
                    tool_calls = getattr(last, 'tool_calls', []) or []
                    tool_map = {t.name: t for t in self.tools}
                    for call in tool_calls:
                        name = call.get('name') if isinstance(call, dict) else getattr(call, 'name', None)
                        args = call.get('args') if isinstance(call, dict) else getattr(call, 'args', {})
                        call_id = call.get('id') if isinstance(call, dict) else getattr(call, 'id', '')
                        if name in tool_map:
                            try:
                                result = tool_map[name].invoke(args)
                            except Exception as te:
                                # Surface per-tool failures to the model instead of aborting.
                                result = f"Tool {name} execution error: {te}"
                        else:
                            result = f"Unknown tool: {name}"
                        outputs.append(ToolMessage(content=str(result), tool_call_id=call_id))
                except Exception as e:
                    outputs.append(ToolMessage(content=f"Tool execution error: {e}", tool_call_id=""))
                return {"messages": outputs}

            # Router: decide next step after agent node.
            def route_after_agent(state: AgentState) -> str:
                last = state["messages"][-1]
                # NOTE(review): finish_reason is computed but never used below —
                # routing relies solely on the presence of tool_calls.
                finish_reason = None
                try:
                    meta = getattr(last, 'response_metadata', {}) or {}
                    finish_reason = meta.get('finish_reason')
                except Exception:
                    finish_reason = None
                # If the model decided to call tools, continue to tools node
                if getattr(last, 'tool_calls', None):
                    return "tools"
                # Otherwise, end
                return END

            graph = StateGraph(AgentState)
            graph.add_node("agent", agent_node)
            graph.add_node("tools", tools_node)
            graph.add_edge(START, "agent")
            graph.add_conditional_edges("agent", route_after_agent, {"tools": "tools", END: END})
            graph.add_edge("tools", "agent")

            # Compile graph and store as self.react_agent for chat()/chat_stream().
            self.react_agent = graph.compile()

            logger.info("LangGraph low-level React agent created successfully")

        except Exception as e:
            logger.error(f"Failed to create agent: {str(e)}")
            raise
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _format_tools_info(self) -> str:
|
||||
"""Format tools information for the prompt."""
|
||||
tools_info = []
|
||||
for tool_name in self.config.enabled_tools:
|
||||
tool = self.tool_registry.get_tool(tool_name)
|
||||
if tool:
|
||||
params_info = []
|
||||
for param in tool.get_parameters():
|
||||
params_info.append(f" - {param.name} ({param.type.value}): {param.description}")
|
||||
|
||||
tool_info = f"**{tool.get_name()}**: {tool.get_description()}"
|
||||
if params_info:
|
||||
tool_info += "\n" + "\n".join(params_info)
|
||||
tools_info.append(tool_info)
|
||||
|
||||
return "\n\n".join(tools_info)
|
||||
|
||||
|
||||
|
||||
async def chat(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
|
||||
"""Process a chat message using LangGraph."""
|
||||
try:
|
||||
logger.info(f"Starting chat with message: {message[:100]}...")
|
||||
|
||||
# Convert chat history to messages
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Add current message
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# Use the low-level graph directly
|
||||
result = await self.react_agent.ainvoke({"messages": messages}, {"recursion_limit": 6}, )
|
||||
|
||||
# Extract final response
|
||||
final_response = ""
|
||||
if "messages" in result and result["messages"]:
|
||||
last_message = result["messages"][-1]
|
||||
if hasattr(last_message, 'content'):
|
||||
final_response = last_message.content
|
||||
elif isinstance(last_message, dict) and "content" in last_message:
|
||||
final_response = last_message["content"]
|
||||
|
||||
return {
|
||||
"response": final_response,
|
||||
"intermediate_steps": [],
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LangGraph chat error: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"response": f"抱歉,处理您的请求时出现错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def chat_stream(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[
|
||||
Dict[str, Any], None]:
|
||||
"""Process a chat message using LangGraph with streaming."""
|
||||
try:
|
||||
logger.info(f"Starting streaming chat with message: {message[:100]}...")
|
||||
|
||||
# Convert chat history to messages
|
||||
messages = []
|
||||
if chat_history:
|
||||
for msg in chat_history:
|
||||
if msg["role"] == "user":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "assistant":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
|
||||
# Add current message
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# Track state for streaming
|
||||
intermediate_steps = []
|
||||
final_response_started = False
|
||||
accumulated_response = ""
|
||||
final_ai_message = None
|
||||
|
||||
# Stream the agent execution
|
||||
async for event in self.react_agent.astream({"messages": messages}):
|
||||
# Handle different event types from LangGraph
|
||||
print('event===', event)
|
||||
if isinstance(event, dict):
|
||||
for node_name, node_output in event.items():
|
||||
logger.info(f"Processing node: {node_name}, output type: {type(node_output)}")
|
||||
|
||||
# 处理 tools 节点
|
||||
if "tools" in node_name.lower():
|
||||
# 提取工具信息
|
||||
tool_infos = []
|
||||
|
||||
if isinstance(node_output, dict) and "messages" in node_output:
|
||||
messages_in_output = node_output["messages"]
|
||||
|
||||
for msg in messages_in_output:
|
||||
# 处理 ToolMessage 对象
|
||||
if hasattr(msg, 'name') and hasattr(msg, 'content'):
|
||||
tool_info = {
|
||||
"tool_name": msg.name,
|
||||
"tool_output": msg.content,
|
||||
"tool_call_id": getattr(msg, 'tool_call_id', ''),
|
||||
"status": "completed"
|
||||
}
|
||||
tool_infos.append(tool_info)
|
||||
elif isinstance(msg, dict):
|
||||
if 'name' in msg and 'content' in msg:
|
||||
tool_info = {
|
||||
"tool_name": msg['name'],
|
||||
"tool_output": msg['content'],
|
||||
"tool_call_id": msg.get('tool_call_id', ''),
|
||||
"status": "completed"
|
||||
}
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
# 返回 tools_end 事件
|
||||
for tool_info in tool_infos:
|
||||
yield {
|
||||
"type": "tools_end",
|
||||
"content": f"工具 {tool_info['tool_name']} 执行完成",
|
||||
"tool_name": tool_info["tool_name"],
|
||||
"tool_output": tool_info["tool_output"],
|
||||
"node_name": node_name,
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 处理 agent 节点
|
||||
elif "agent" in node_name.lower():
|
||||
if isinstance(node_output, dict) and "messages" in node_output:
|
||||
messages_in_output = node_output["messages"]
|
||||
if messages_in_output:
|
||||
last_msg = messages_in_output[-1]
|
||||
|
||||
# 获取 finish_reason
|
||||
finish_reason = None
|
||||
if hasattr(last_msg, 'response_metadata'):
|
||||
finish_reason = last_msg.response_metadata.get('finish_reason')
|
||||
elif isinstance(last_msg, dict) and 'response_metadata' in last_msg:
|
||||
finish_reason = last_msg['response_metadata'].get('finish_reason')
|
||||
|
||||
# 判断是否为 thinking 或 response
|
||||
if finish_reason == 'tool_calls':
|
||||
# thinking 状态
|
||||
thinking_content = "🤔 正在思考..."
|
||||
if hasattr(last_msg, 'content') and last_msg.content:
|
||||
thinking_content = f"🤔 思考: {last_msg.content[:200]}..."
|
||||
elif isinstance(last_msg, dict) and "content" in last_msg:
|
||||
thinking_content = f"🤔 思考: {last_msg['content'][:200]}..."
|
||||
|
||||
yield {
|
||||
"type": "thinking",
|
||||
"content": thinking_content,
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
elif finish_reason == 'stop':
|
||||
# response 状态
|
||||
if hasattr(last_msg, 'content') and hasattr(last_msg,
|
||||
'__class__') and 'AI' in last_msg.__class__.__name__:
|
||||
current_content = last_msg.content
|
||||
final_ai_message = last_msg
|
||||
|
||||
if not final_response_started and current_content:
|
||||
final_response_started = True
|
||||
yield {
|
||||
"type": "response_start",
|
||||
"content": "",
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": False
|
||||
}
|
||||
|
||||
if current_content and len(current_content) > len(accumulated_response):
|
||||
new_content = current_content[len(accumulated_response):]
|
||||
|
||||
for char in new_content:
|
||||
accumulated_response += char
|
||||
yield {
|
||||
"type": "response",
|
||||
"content": accumulated_response,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
else:
|
||||
# 其他 agent 状态
|
||||
yield {
|
||||
"type": "step",
|
||||
"content": f"📋 执行步骤: {node_name}",
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 处理其他节点
|
||||
else:
|
||||
yield {
|
||||
"type": "step",
|
||||
"content": f"📋 执行步骤: {node_name}",
|
||||
"node_name": node_name,
|
||||
"raw_output": str(node_output)[:500] if node_output else "",
|
||||
"done": False
|
||||
}
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 最终完成事件
|
||||
yield {
|
||||
"type": "complete",
|
||||
"content": accumulated_response,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_stream: {str(e)}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"处理请求时出错: {str(e)}",
|
||||
"done": True
|
||||
}
|
||||
|
||||
# 确保最终响应包含完整内容
|
||||
final_content = accumulated_response
|
||||
if not final_content and final_ai_message and hasattr(final_ai_message, 'content'):
|
||||
final_content = final_ai_message.content or ""
|
||||
|
||||
# Final completion signal
|
||||
yield {
|
||||
"type": "response",
|
||||
"content": final_content,
|
||||
"intermediate_steps": intermediate_steps,
|
||||
"done": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LangGraph chat stream error: {str(e)}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
"content": f"抱歉,处理您的请求时出现错误: {str(e)}",
|
||||
"error": str(e),
|
||||
"done": True
|
||||
}
|
||||
|
||||
def get_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of available tools."""
|
||||
tools = []
|
||||
for tool in self.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": [],
|
||||
"enabled": True
|
||||
})
|
||||
return tools
|
||||
|
||||
    def get_config(self) -> Dict[str, Any]:
        """Return the current agent configuration as a plain dict.

        Delegates to the pydantic model's ``dict()`` serialization.
        """
        return self.config.dict()
||||
|
||||
def update_config(self, config: Dict[str, Any]):
|
||||
"""Update agent configuration."""
|
||||
for key, value in config.items():
|
||||
if hasattr(self.config, key):
|
||||
setattr(self.config, key, value)
|
||||
|
||||
# Recreate agent with new config
|
||||
self._create_react_agent()
|
||||
logger.info("Agent configuration updated")
|
||||
|
||||
def _create_plan_execute_agent(self):
    """Create a Plan-and-Execute agent using the LangGraph low-level API.

    Graph structure: START -> planner -> executor (loop) -> summarize -> END

    - planner: generates an executable plan (JSON array) from the conversation
    - executor: runs the plan step by step (may invoke tools) and collects
      each step's result
    - summarize: combines the plan and step results into the final answer

    Side effects:
        Sets ``self.bound_model`` (tool-bound model, or the raw model if
        binding fails) and ``self.plan_execute_agent`` (the compiled graph).
    """
    from typing import TypedDict, Annotated, List
    import json
    from langgraph.graph import StateGraph, START, END
    from langgraph.graph.message import add_messages
    from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
    try:
        self.bound_model = self.model.bind_tools(self.tools)
    except Exception as e:
        # Fall back to the raw model; tool calls simply won't be produced.
        logger.warning(f"Failed to bind tools to model, tool calling may not work: {e}")
        self.bound_model = self.model

    class PlanState(TypedDict):
        # Conversation messages; add_messages appends instead of replacing.
        messages: Annotated[List[BaseMessage], add_messages]
        # Plan produced by the planner node, one string per step.
        plan_steps: List[str]
        # Index of the next step the executor should run.
        current_step: int
        # Accumulated per-step result texts, consumed by summarize.
        step_results: List[str]

    def planner_node(state: PlanState) -> PlanState:
        """Ask the (unbound) model for a JSON-array plan and parse it."""
        messages = state.get("messages", [])
        plan_prompt = (
            "你是规划助手。基于对话内容生成可执行计划,"
            "用 JSON 数组返回,每个元素是一条明确且可操作的步骤。"
            "仅输出 JSON,不要额外解释。"
        )
        ai_plan = self.model.invoke(messages + [HumanMessage(content=plan_prompt)])
        steps: List[str] = []
        try:
            parsed = json.loads(ai_plan.content)
            if isinstance(parsed, list):
                steps = [str(s) for s in parsed]
        except Exception:
            # Fallback: the model ignored the JSON instruction — split by line.
            steps = [s.strip() for s in ai_plan.content.split("\n") if s.strip()]
        return {
            "messages": [ai_plan],
            "plan_steps": steps,
            "current_step": 0,
            "step_results": []
        }

    def executor_node(state: PlanState) -> PlanState:
        """Execute the current plan step, resolving tool calls inline."""
        idx = state.get("current_step", 0)
        steps = state.get("plan_steps", [])
        msgs = state.get("messages", [])
        if idx >= len(steps):
            # Defensive: routing should prevent this, but return a no-op update.
            return {"messages": [], "current_step": idx, "step_results": state.get("step_results", [])}

        step_text = steps[idx]
        exec_prompt = (
            f"请执行计划的第{idx+1}步:{step_text}。"
            "需要用工具时创建工具调用;完成后给出该步的结果。"
        )
        ai_exec = self.bound_model.invoke(msgs + [HumanMessage(content=exec_prompt)])

        new_messages: List[BaseMessage] = [ai_exec]
        step_result_content = None

        # Handle tool calls requested by the model for this step.
        tool_map = {t.name: t for t in self.tools}
        tool_msgs: List[ToolMessage] = []
        # Tool calls may live on .tool_calls or (provider-dependent) in additional_kwargs.
        tool_calls = getattr(ai_exec, "tool_calls", []) or (ai_exec.additional_kwargs.get("tool_calls") if hasattr(ai_exec, "additional_kwargs") else [])
        if tool_calls:
            for call in tool_calls:
                name = call.get("name")
                args = call.get("args", {})
                tool_obj = tool_map.get(name)
                if tool_obj:
                    try:
                        result = tool_obj.invoke(args)
                    except Exception as e:
                        # Surface the failure to the model instead of aborting the step.
                        result = f"工具执行失败: {e}"
                else:
                    result = f"未找到工具: {name}"
                # Providers disagree on the id field; fall back to a synthetic id.
                tool_call_id = call.get("id") or call.get("tool_call_id") or call.get("call_id") or f"tool_{name}"
                tool_msgs.append(ToolMessage(content=str(result), tool_call_id=tool_call_id, name=name or "tool"))
            new_messages.extend(tool_msgs)
            # Summarize this step's outcome from the tool outputs.
            summarize_step = "请基于上述工具输出,总结该步骤的结果,给出结构化要点与可读说明。"
            ai_step = self.bound_model.invoke(msgs + [ai_exec] + tool_msgs + [HumanMessage(content=summarize_step)])
            step_result_content = ai_step.content
            new_messages.append(ai_step)
        else:
            # No tools needed — the model's reply is the step result.
            step_result_content = ai_exec.content

        all_results = list(state.get("step_results", []))
        if step_result_content:
            all_results.append(step_result_content)

        return {
            "messages": new_messages,
            "current_step": idx + 1,
            "step_results": all_results
        }

    def route_after_planner(state: PlanState) -> str:
        # An empty plan means there is nothing to execute — end immediately.
        return "executor" if state.get("plan_steps") else END

    def route_after_executor(state: PlanState) -> str:
        # Loop back into the executor until every plan step has run.
        cur = state.get("current_step", 0)
        total = len(state.get("plan_steps", []))
        return "executor" if cur < total else "summarize"

    def summarize_node(state: PlanState) -> PlanState:
        """Combine the plan and per-step results into the final answer."""
        import json as _json
        msgs = state.get("messages", [])
        steps = state.get("plan_steps", [])
        results = state.get("step_results", [])
        final_prompt = (
            "请综合以上计划与各步骤结果,生成最终回答。"
            "要求:逻辑清晰、结论明确、可读性强;如存在不确定性请注明。"
        )
        context_msg = HumanMessage(content=(
            f"计划: {_json.dumps(steps, ensure_ascii=False)}\n"
            f"步骤结果: {_json.dumps(results, ensure_ascii=False)}\n"
            f"{final_prompt}"
        ))
        ai_final = self.model.invoke(msgs + [context_msg])
        return {"messages": [ai_final]}

    # Assemble the graph: START -> planner -> executor(loop) -> summarize -> END
    graph = StateGraph(PlanState)
    graph.add_node("planner", planner_node)
    graph.add_node("executor", executor_node)
    graph.add_node("summarize", summarize_node)
    graph.add_edge(START, "planner")
    graph.add_conditional_edges("planner", route_after_planner, {"executor": "executor", END: END})
    graph.add_conditional_edges("executor", route_after_executor, {"executor": "executor", "summarize": "summarize"})
    graph.add_edge("summarize", END)

    self.plan_execute_agent = graph.compile()
|
||||
|
||||
async def chat_plan_execute(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> Dict[str, Any]:
    """Single-turn Plan-and-Execute chat.

    Converts the optional role/content history plus the new user message
    into LangChain messages, runs the compiled plan-execute graph once,
    and returns a status dict containing the final assistant text.
    """
    # Build the graph lazily on first use.
    if not hasattr(self, "plan_execute_agent"):
        self._create_plan_execute_agent()

    # "user" entries map to HumanMessage; everything else is treated as AI.
    lc_messages = [
        HumanMessage(content=entry.get("content", ""))
        if entry.get("role") == "user"
        else AIMessage(content=entry.get("content", ""))
        for entry in (chat_history or [])
    ]
    lc_messages.append(HumanMessage(content=message))

    try:
        result = await self.plan_execute_agent.ainvoke(
            {"messages": lc_messages},
            config={"recursion_limit": self.config.max_iterations},
        )
        # The final answer is the content of the last message, when present.
        final_text = ""
        if isinstance(result, dict) and "messages" in result:
            tail = result["messages"]
            if tail:
                final_text = getattr(tail[-1], "content", "")
        return {
            "status": "success",
            "response": final_text,
            "raw": str(result)
        }
    except Exception as e:
        logger.error(f"Error in chat_plan_execute: {e}", exc_info=True)
        return {
            "status": "error",
            "error": str(e)
        }
|
||||
|
||||
async def chat_stream_plan_execute(self, message: str, chat_history: Optional[List[Dict[str, str]]] = None) -> AsyncGenerator[Dict[str, Any], None]:
    """Streamed Plan-and-Execute chat.

    Yields progress dicts of the shape {"type": ..., "content": ..., "done": ...}
    as the plan-execute graph moves through its planner / executor / summarize
    nodes, ending with a "complete" chunk (or an "error" chunk on failure).
    """
    import asyncio as _asyncio
    # Build the graph lazily on first use.
    if not hasattr(self, "plan_execute_agent"):
        self._create_plan_execute_agent()

    # Build the LangChain message list from history plus the new user message.
    messages = []
    if chat_history:
        for msg in chat_history:
            role = msg.get("role")
            content = msg.get("content", "")
            if role == "user":
                messages.append(HumanMessage(content=content))
            else:
                messages.append(AIMessage(content=content))
    messages.append(HumanMessage(content=message))

    try:
        accumulated = ""
        # BUG FIX: stream the plan-execute graph, not self.react_agent.
        # The node names handled below ("planner"/"executor"/"summarize")
        # only exist in the plan-execute graph built above, so streaming
        # the ReAct agent produced no planning/step/summary events at all.
        async for event in self.plan_execute_agent.astream({"messages": messages}, config={"recursion_limit": self.config.max_iterations}):
            for key, node_output in event.items():
                # Some LangGraph versions key events by tuple; normalize.
                node_name = key[0] if isinstance(key, tuple) else key
                if node_name == "planner":
                    # Planning phase: surface a preview of the generated plan.
                    content = "生成计划中..."
                    if node_output and isinstance(node_output, dict):
                        m = node_output.get("messages", [])
                        if m:
                            last = m[-1]
                            if hasattr(last, "content"):
                                # Truncate so the preview chunk stays small.
                                content = str(last.content)[:400]
                    yield {"type": "planning", "content": content, "done": False}
                    await _asyncio.sleep(0.05)
                elif node_name == "executor":
                    # Execution phase (may include tool calls).
                    yield {"type": "step", "content": "执行计划步骤", "done": False}
                    await _asyncio.sleep(0.05)
                    if node_output and isinstance(node_output, dict):
                        msgs = node_output.get("messages", [])
                        # Emit a marker once this step's tool calls finished.
                        tool_msgs = [m for m in msgs if hasattr(m, "__class__") and "Tool" in m.__class__.__name__]
                        if tool_msgs:
                            yield {"type": "tools_end", "content": f"完成 {len(tool_msgs)} 次工具执行", "done": False}
                            await _asyncio.sleep(0.03)
                        # Try to surface this step's summary text.
                        ai_msgs = [m for m in msgs if hasattr(m, "__class__") and "AI" in m.__class__.__name__]
                        if ai_msgs:
                            text = ai_msgs[-1].content
                            if text:
                                accumulated = text
                                yield {"type": "response", "content": accumulated, "done": False}
                                await _asyncio.sleep(0.02)
                elif node_name == "summarize":
                    # Final summary phase.
                    if node_output and isinstance(node_output, dict):
                        msgs = node_output.get("messages", [])
                        if msgs:
                            final = msgs[-1]
                            content = getattr(final, "content", "")
                            if content:
                                yield {"type": "response_start", "content": "", "done": False}
                                yield {"type": "response", "content": content, "done": False}
                                accumulated = content
                                await _asyncio.sleep(0.02)
        yield {"type": "complete", "content": accumulated, "done": True}
    except Exception as e:
        logger.error(f"Error in chat_stream_plan_execute: {e}", exc_info=True)
        yield {"type": "error", "content": str(e), "done": True}
|
||||
|
||||
|
||||
# Global instance
|
||||
_langgraph_agent_service: Optional[LangGraphAgentService] = None
|
||||
|
||||
|
||||
def get_langgraph_agent_service(db_session=None) -> LangGraphAgentService:
    """Return the process-wide LangGraph agent service, creating it on first call.

    Args:
        db_session: Optional database session forwarded to the service
            constructor only when the singleton is first created.
    """
    global _langgraph_agent_service

    # Fast path: the singleton already exists.
    if _langgraph_agent_service is not None:
        return _langgraph_agent_service

    _langgraph_agent_service = LangGraphAgentService(db_session)
    logger.info("LangGraph Agent service initialized")
    return _langgraph_agent_service
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
"""Agent configuration service."""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
|
||||
from ..models.agent_config import AgentConfig
|
||||
from utils.util_exceptions import ValidationError, NotFoundError
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AgentConfigService:
    """Service for managing agent configurations.

    CRUD plus default-selection helpers over the AgentConfig model.
    All mutating methods commit on success and roll back + re-raise on
    any failure, so the session stays usable for the caller.
    """

    def __init__(self, db: Session):
        # SQLAlchemy session used for all queries in this service.
        self.db = db

    def create_config(self, config_data: Dict[str, Any]) -> AgentConfig:
        """Create a new agent configuration.

        Args:
            config_data: Field values; "name" is required and must be unique.
                Missing fields fall back to the defaults below.

        Returns:
            The persisted AgentConfig instance (refreshed after commit).

        Raises:
            ValidationError: If the name is missing or already taken.
        """
        try:
            # Validate required fields
            if not config_data.get("name"):
                raise ValidationError("Configuration name is required")

            # Check if name already exists
            existing = self.db.query(AgentConfig).filter(
                AgentConfig.name == config_data["name"]
            ).first()
            if existing:
                raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")

            # Create new configuration
            config = AgentConfig(
                name=config_data["name"],
                description=config_data.get("description", ""),
                enabled_tools=config_data.get("enabled_tools", ["calculator", "weather", "search", "datetime", "file"]),
                max_iterations=config_data.get("max_iterations", 10),
                temperature=config_data.get("temperature", 0.1),
                system_message=config_data.get("system_message", "You are a helpful AI assistant with access to various tools. Use the available tools to help answer user questions accurately. Always explain your reasoning and the tools you're using."),
                verbose=config_data.get("verbose", True),
                is_active=config_data.get("is_active", True),
                is_default=config_data.get("is_default", False)
            )

            # If this is set as default, unset other defaults so at most
            # one configuration is the default at a time.
            if config.is_default:
                self.db.query(AgentConfig).filter(
                    AgentConfig.is_default == True
                ).update({"is_default": False})

            self.db.add(config)
            self.db.commit()
            self.db.refresh(config)

            logger.info(f"Created agent configuration: {config.name}")
            return config

        except Exception as e:
            # Roll back so the session stays usable, then re-raise for the caller.
            self.db.rollback()
            logger.error(f"Error creating agent configuration: {str(e)}")
            raise

    def get_config(self, config_id: int) -> Optional[AgentConfig]:
        """Get agent configuration by ID, or None if not found."""
        return self.db.query(AgentConfig).filter(
            AgentConfig.id == config_id
        ).first()

    def get_config_by_name(self, name: str) -> Optional[AgentConfig]:
        """Get agent configuration by name, or None if not found."""
        return self.db.query(AgentConfig).filter(
            AgentConfig.name == name
        ).first()

    def get_default_config(self) -> Optional[AgentConfig]:
        """Get the active default agent configuration, or None if unset."""
        return self.db.query(AgentConfig).filter(
            and_(AgentConfig.is_default == True, AgentConfig.is_active == True)
        ).first()

    def list_configs(self, active_only: bool = True) -> List[AgentConfig]:
        """List agent configurations, newest first.

        Args:
            active_only: When True (default), only is_active=True rows
                are returned (soft-deleted ones are hidden).
        """
        query = self.db.query(AgentConfig)
        if active_only:
            query = query.filter(AgentConfig.is_active == True)
        return query.order_by(AgentConfig.created_at.desc()).all()

    def update_config(self, config_id: int, config_data: Dict[str, Any]) -> AgentConfig:
        """Update an existing agent configuration.

        Raises:
            NotFoundError: If no configuration has the given ID.
            ValidationError: If a name change collides with another config.
        """
        try:
            config = self.get_config(config_id)
            if not config:
                raise NotFoundError(f"Agent configuration with ID {config_id} not found")

            # Check if name change conflicts with existing
            if "name" in config_data and config_data["name"] != config.name:
                existing = self.db.query(AgentConfig).filter(
                    and_(
                        AgentConfig.name == config_data["name"],
                        AgentConfig.id != config_id
                    )
                ).first()
                if existing:
                    raise ValidationError(f"Configuration with name '{config_data['name']}' already exists")

            # Update fields
            # NOTE(review): any attribute present on the model can be set here
            # (including is_active / is_default) — callers must sanitize input.
            for key, value in config_data.items():
                if hasattr(config, key):
                    setattr(config, key, value)

            # If this is set as default, unset other defaults
            if config_data.get("is_default", False):
                self.db.query(AgentConfig).filter(
                    and_(
                        AgentConfig.is_default == True,
                        AgentConfig.id != config_id
                    )
                ).update({"is_default": False})

            self.db.commit()
            self.db.refresh(config)

            logger.info(f"Updated agent configuration: {config.name}")
            return config

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error updating agent configuration: {str(e)}")
            raise

    def delete_config(self, config_id: int) -> bool:
        """Delete agent configuration (soft delete by setting is_active=False).

        Raises:
            NotFoundError: If no configuration has the given ID.
            ValidationError: If the configuration is the current default.
        """
        try:
            config = self.get_config(config_id)
            if not config:
                raise NotFoundError(f"Agent configuration with ID {config_id} not found")

            # Don't allow deleting the default configuration
            if config.is_default:
                raise ValidationError("Cannot delete the default configuration")

            config.is_active = False
            self.db.commit()

            logger.info(f"Deleted agent configuration: {config.name}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error deleting agent configuration: {str(e)}")
            raise

    def set_default_config(self, config_id: int) -> AgentConfig:
        """Mark a configuration as the default (must exist and be active).

        Raises:
            NotFoundError: If no configuration has the given ID.
            ValidationError: If the configuration is inactive.
        """
        try:
            config = self.get_config(config_id)
            if not config:
                raise NotFoundError(f"Agent configuration with ID {config_id} not found")

            if not config.is_active:
                raise ValidationError("Cannot set inactive configuration as default")

            # Unset other defaults
            self.db.query(AgentConfig).filter(
                AgentConfig.is_default == True
            ).update({"is_default": False})

            # Set this as default
            config.is_default = True
            self.db.commit()
            self.db.refresh(config)

            logger.info(f"Set default agent configuration: {config.name}")
            return config

        except Exception as e:
            self.db.rollback()
            logger.error(f"Error setting default agent configuration: {str(e)}")
            raise

    def get_config_dict(self, config_id: Optional[int] = None) -> Dict[str, Any]:
        """Get configuration as dictionary. If no ID provided, returns default config.

        Falls back to hard-coded defaults when neither the requested nor a
        default configuration exists, so callers always get a usable dict.
        """
        if config_id:
            config = self.get_config(config_id)
        else:
            config = self.get_default_config()

        if not config:
            # Return default values if no configuration found
            return {
                "enabled_tools": ["calculator", "weather", "search", "datetime", "file", "generate_image"],
                "max_iterations": 10,
                "temperature": 0.1,
                "system_message": "You are a helpful AI assistant with access to various tools. Use the available tools to help answer user questions accurately. Always explain your reasoning and the tools you're using.",
                "verbose": True
            }

        return {
            "enabled_tools": config.enabled_tools,
            "max_iterations": config.max_iterations,
            "temperature": config.temperature,
            "system_message": config.system_message,
            "verbose": config.verbose
        }
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
"""Authentication service."""
|
||||
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
import bcrypt
|
||||
import jwt
|
||||
|
||||
from ..core.config import settings
|
||||
from ..db.database import get_session
|
||||
from ..models.user import User
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
class AuthService:
    """Authentication service.

    Static helpers for JWT issuance/verification, bcrypt password hashing,
    and FastAPI dependencies that resolve the current user.

    NOTE(review): several methods annotate ``session`` as a sync ``Session``
    but ``await session.execute(...)`` — this requires an AsyncSession;
    confirm against the session factory in ``..db.database``.
    """
    @staticmethod
    async def get_current_user(
        credentials: HTTPAuthorizationCredentials = Depends(security),
        session: Session = Depends(get_session)
    ) -> User:
        """Get the current authenticated user.

        Verifies the bearer token, loads the user named in its ``sub``
        claim, and stores it in the request-scoped UserContext.

        Raises:
            HTTPException: 401 if the token is invalid, lacks a username,
                or the user does not exist.
        """
        from ..core.context import UserContext

        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

        token = credentials.credentials
        # NOTE(review): the `session.desc` assignments below look like ad-hoc
        # debug breadcrumbs attached to the session object — confirm that
        # something actually reads them.
        session.desc = f"[AuthService] 取得token: {token[:50]}..."
        payload = AuthService.verify_token(token)
        if payload is None:
            session.desc = "ERROR: 令牌验证失败"
            raise credentials_exception

        session.desc = f"[AuthService] 令牌有效 - 解析得到有效载荷: {payload}"
        username: str = payload.get("sub")
        if username is None:
            session.desc = "ERROR: 令牌中没有用户名"
            raise credentials_exception

        session.desc = f"[AuthService] 获取当前用户 - 查找名为 {username} 的用户"
        stmt = select(User).where(User.username == username)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if user is None:
            session.desc = f"ERROR: 数据库中未找到用户 {username}"
            raise credentials_exception

        # Set user in context for global access
        UserContext.set_current_user(user)
        session.desc = f"[AuthService] 用户 {user.username} (ID: {user.id}) 已设置为当前用户"

        return user

    @staticmethod
    def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
        """Get the current active user; raises 400 if the account is inactive."""
        if not current_user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Inactive user"
            )
        return current_user

    @staticmethod
    async def authenticate_user_by_email(session: Session, email: str, password: str) -> Optional[User]:
        """Authenticate a user with email and password.

        Returns:
            The User on success; None if the email is unknown or the
            password does not match.
        """
        session.desc = f"根据邮箱 {email} 验证用户密码"
        stmt = select(User).where(User.email == email)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def authenticate_user(session: Session, username: str, password: str) -> Optional[User]:
        """Authenticate a user with username and password.

        Returns:
            The User on success; None if the username is unknown or the
            password does not match.
        """
        session.desc = f"根据用户名 {username} 验证用户密码"
        stmt = select(User).where(User.username == username)
        user = (await session.execute(stmt)).scalar_one_or_none()
        if not user:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def create_access_token(session: Session, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token.

        Args:
            session: Used only for the debug breadcrumb below.
            data: Claims to embed; an ``exp`` claim is added automatically.
            expires_delta: Optional lifetime; defaults to the configured
                ``access_token_expire_minutes``.

        Returns:
            The encoded JWT string.
        """
        session.desc = f"创建 JWT 访问 token - 数据: {data}"
        to_encode = data.copy()
        if expires_delta:
            expire = datetime.now(timezone.utc) + expires_delta
        else:
            expire = datetime.now(timezone.utc) + timedelta(minutes=settings.security.access_token_expire_minutes)

        to_encode.update({"exp": expire})
        encoded_jwt = jwt.encode(
            to_encode,
            settings.security.secret_key,
            algorithm=settings.security.algorithm
        )
        return encoded_jwt

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """Verify a plaintext password against its bcrypt hash."""
        # Use the bcrypt library directly for password verification
        return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Generate a bcrypt password hash with a fresh random salt."""
        # Hash directly with the bcrypt library
        salt = bcrypt.gensalt()
        hashed_bytes = bcrypt.hashpw(password.encode('utf-8'), salt)
        hashed_password = hashed_bytes.decode('utf-8')
        return hashed_password

    @staticmethod
    def verify_token(token: str) -> Optional[dict]:
        """Verify a JWT token; return its payload dict, or None when invalid.

        NOTE(review): the error path logs a secret-key prefix — consider
        removing that log line before production use.
        """
        try:
            payload = jwt.decode(
                token,
                settings.security.secret_key,
                algorithms=[settings.security.algorithm]
            )
            return payload
        except jwt.PyJWTError as e:
            logger.error(f"Token verification failed: {e}")
            logger.error(f"Token: {token[:50]}...")
            logger.error(f"Secret key: {settings.security.secret_key[:20]}...")
            logger.error(f"Algorithm: {settings.security.algorithm}")
            return None
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
"""Chat service for AI model integration using LangChain."""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional, List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from loguru import logger
|
||||
from ..core.config import settings
|
||||
from ..models.message import MessageRole
|
||||
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
|
||||
from utils.util_exceptions import ChatServiceError, OpenAIError
|
||||
from .conversation import ConversationService
|
||||
from .langchain_chat import LangChainChatService
|
||||
from .knowledge_chat import KnowledgeChatService
|
||||
from .agent.agent_service import get_agent_service
|
||||
from .agent.langgraph_agent_service import get_langgraph_agent_service
|
||||
|
||||
class ChatService:
|
||||
"""Service for handling AI chat functionality using LangChain."""
|
||||
|
||||
def __init__(self, db: Session):
    """Wire up the chat backends sharing one database session.

    Args:
        db: SQLAlchemy session shared by the conversation store and
            every delegated chat service below.
    """
    self.db = db
    # Conversation persistence (messages, history).
    self.conversation_service = ConversationService(db)

    # Initialize LangChain chat service (plain LLM chat).
    self.langchain_service = LangChainChatService(db)

    # Initialize Knowledge chat service (RAG over knowledge bases).
    self.knowledge_service = KnowledgeChatService(db)

    # Initialize Agent service with database session
    self.agent_service = get_agent_service(db)

    # Initialize LangGraph Agent service with database session
    self.langgraph_service = get_langgraph_agent_service(db)

    logger.info("ChatService initialized with LangChain backend and Agent support")
|
||||
|
||||
|
||||
|
||||
async def chat(
    self,
    conversation_id: int,
    message: str,
    stream: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    use_agent: bool = False,
    use_langgraph: bool = False,
    use_knowledge_base: bool = False,
    knowledge_base_id: Optional[int] = None
) -> ChatResponse:
    """Send a message and get an AI response.

    Routing (first match wins):
      * use_knowledge_base + knowledge_base_id -> knowledge-base RAG chat
      * use_langgraph -> LangGraph agent
      * use_agent -> classic agent
      * otherwise -> plain LangChain chat

    Args:
        conversation_id: Target conversation; must exist for agent modes.
        message: The user's message text.
        stream / temperature / max_tokens: Forwarded to the non-agent backends.
        use_agent / use_langgraph / use_knowledge_base: Routing flags.
        knowledge_base_id: Knowledge base to query (with use_knowledge_base).

    Returns:
        ChatResponse wrapping the persisted assistant message.

    Raises:
        ChatServiceError: If the conversation is missing or the agent fails.
    """
    if use_knowledge_base and knowledge_base_id:
        logger.info(f"Processing chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")
        # Use knowledge base chat service
        return await self.knowledge_service.chat_with_knowledge_base(
            conversation_id=conversation_id,
            message=message,
            knowledge_base_id=knowledge_base_id,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens
        )

    if use_langgraph:
        logger.info(f"Processing chat request for conversation {conversation_id} via LangGraph Agent")
        chat_history = self._load_chat_history(conversation_id)
        agent_result = await self.langgraph_service.chat(message, chat_history)
        if not agent_result["success"]:
            raise ChatServiceError(f"LangGraph Agent error: {agent_result.get('error', 'Unknown error')}")
        return self._persist_agent_exchange(
            conversation_id,
            message,
            agent_result["response"],
            {"intermediate_steps": agent_result["intermediate_steps"]},
        )

    if use_agent:
        logger.info(f"Processing chat request for conversation {conversation_id} via Agent")
        chat_history = self._load_chat_history(conversation_id)
        agent_result = await self.agent_service.chat(message, chat_history)
        if not agent_result["success"]:
            raise ChatServiceError(f"Agent error: {agent_result.get('error', 'Unknown error')}")
        return self._persist_agent_exchange(
            conversation_id,
            message,
            agent_result["response"],
            {"tool_calls": agent_result["tool_calls"]},
        )

    logger.info(f"Processing chat request for conversation {conversation_id} via LangChain")
    # Delegate to LangChain service
    return await self.langchain_service.chat(
        conversation_id=conversation_id,
        message=message,
        stream=stream,
        temperature=temperature,
        max_tokens=max_tokens
    )

def _load_chat_history(self, conversation_id: int) -> List[Dict[str, str]]:
    """Validate the conversation exists and return its messages as role/content dicts.

    Raises:
        ChatServiceError: If the conversation does not exist.
    """
    conversation = self.conversation_service.get_conversation(conversation_id)
    if not conversation:
        raise ChatServiceError(f"Conversation {conversation_id} not found")
    messages = self.conversation_service.get_conversation_messages(conversation_id)
    return [{
        "role": "user" if msg.role == MessageRole.USER else "assistant",
        "content": msg.content
    } for msg in messages]

def _persist_agent_exchange(self, conversation_id: int, user_content: str, assistant_content: str, metadata: Dict[str, Any]) -> ChatResponse:
    """Persist the user/assistant message pair and build the ChatResponse."""
    # Save user message
    self.conversation_service.add_message(
        conversation_id=conversation_id,
        content=user_content,
        role=MessageRole.USER
    )
    # Save assistant response with its agent metadata
    assistant_message = self.conversation_service.add_message(
        conversation_id=conversation_id,
        content=assistant_content,
        role=MessageRole.ASSISTANT,
        message_metadata=metadata
    )
    return ChatResponse(
        message=MessageResponse(
            id=assistant_message.id,
            content=assistant_content,
            role=MessageRole.ASSISTANT,
            conversation_id=conversation_id,
            created_at=assistant_message.created_at,
            # BUG FIX: the model stores the dict under `message_metadata`;
            # `.metadata` is reserved by SQLAlchemy's declarative base and
            # resolves to the table MetaData object, not the saved dict.
            metadata=assistant_message.message_metadata
        )
    )
|
||||
|
||||
async def chat_stream(
    self,
    conversation_id: int,
    message: str,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    use_agent: bool = False,
    use_langgraph: bool = False,
    use_knowledge_base: bool = False,
    knowledge_base_id: Optional[int] = None
) -> AsyncGenerator[str, None]:
    """Send a message and get a streaming AI response.

    Routing mirrors ``chat``: knowledge base, LangGraph agent, classic
    agent, or plain LangChain. Every yielded item is a JSON-encoded dict.
    Agent modes persist the user message up front and the final assistant
    response (with metadata) after the stream completes.
    """
    if use_knowledge_base and knowledge_base_id:
        logger.info(f"Processing streaming chat request for conversation {conversation_id} via Knowledge Base {knowledge_base_id}")

        # Use knowledge base chat service streaming
        async for content in self.knowledge_service.chat_stream_with_knowledge_base(
            conversation_id=conversation_id,
            message=message,
            knowledge_base_id=knowledge_base_id,
            temperature=temperature,
            max_tokens=max_tokens
        ):
            # Create stream chunk for compatibility with existing API
            stream_chunk = StreamChunk(
                content=content,
                role=MessageRole.ASSISTANT
            )
            yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
    elif use_langgraph:
        logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangGraph Agent")

        # Get conversation history for LangGraph agent
        conversation = self.conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        messages = self.conversation_service.get_conversation_messages(conversation_id)
        chat_history = [{
            "role": "user" if msg.role == MessageRole.USER else "assistant",
            "content": msg.content
        } for msg in messages]

        # Save user message first (before streaming, so it persists even
        # if the stream errors out mid-way)
        user_message = self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=message,
            role=MessageRole.USER
        )

        # Use LangGraph agent service streaming; track the latest full
        # response so it can be persisted after the stream ends.
        full_response = ""
        intermediate_steps = []

        async for chunk in self.langgraph_service.chat_stream(message, chat_history):
            if chunk["type"] == "response":
                full_response = chunk["content"]
                intermediate_steps = chunk.get("intermediate_steps", [])

                # Return the chunk as-is to maintain type information
                yield json.dumps(chunk, ensure_ascii=False)

            elif chunk["type"] == "error":
                # Return the chunk as-is to maintain type information
                # (note: the user message above is NOT rolled back here)
                yield json.dumps(chunk, ensure_ascii=False)
                return
            else:
                # For other types (status, step, etc.), pass through
                yield json.dumps(chunk, ensure_ascii=False)

        # Save assistant response
        if full_response:
            self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=full_response,
                role=MessageRole.ASSISTANT,
                message_metadata={"intermediate_steps": intermediate_steps}
            )
    elif use_agent:
        logger.info(f"Processing streaming chat request for conversation {conversation_id} via Agent")

        # Get conversation history for agent
        conversation = self.conversation_service.get_conversation(conversation_id)
        if not conversation:
            raise ChatServiceError(f"Conversation {conversation_id} not found")

        messages = self.conversation_service.get_conversation_messages(conversation_id)
        chat_history = [{
            "role": "user" if msg.role == MessageRole.USER else "assistant",
            "content": msg.content
        } for msg in messages]

        # Save user message first
        user_message = self.conversation_service.add_message(
            conversation_id=conversation_id,
            content=message,
            role=MessageRole.USER
        )

        # Use agent service streaming; same persistence pattern as above
        # but with tool_calls metadata instead of intermediate_steps.
        full_response = ""
        tool_calls = []

        async for chunk in self.agent_service.chat_stream(message, chat_history):
            if chunk["type"] == "response":
                full_response = chunk["content"]
                tool_calls = chunk.get("tool_calls", [])

                # Return the chunk as-is to maintain type information
                yield json.dumps(chunk, ensure_ascii=False)

            elif chunk["type"] == "error":
                # Return the chunk as-is to maintain type information
                yield json.dumps(chunk, ensure_ascii=False)
                return
            else:
                # For other types (status, tool_start, etc.), pass through
                yield json.dumps(chunk, ensure_ascii=False)

        # Save assistant response
        if full_response:
            self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=full_response,
                role=MessageRole.ASSISTANT,
                message_metadata={"tool_calls": tool_calls}
            )
    else:
        logger.info(f"Processing streaming chat request for conversation {conversation_id} via LangChain")

        # Delegate to LangChain service and wrap response in JSON format
        async for content in self.langchain_service.chat_stream(
            conversation_id=conversation_id,
            message=message,
            temperature=temperature,
            max_tokens=max_tokens
        ):
            # Create stream chunk for compatibility with existing API
            stream_chunk = StreamChunk(
                content=content,
                role=MessageRole.ASSISTANT
            )
            yield json.dumps(stream_chunk.dict(), ensure_ascii=False)
|
||||
|
||||
async def get_available_models(self) -> List[str]:
    """Return the names of the models the LangChain service can serve.

    This service keeps no model state of its own; the LangChain service
    is the single source of truth for availability.
    """
    logger.info("Getting available models via LangChain")
    models = await self.langchain_service.get_available_models()
    return models
|
||||
|
||||
def update_model_config(
    self,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None
):
    """Forward an LLM configuration change to the LangChain service.

    Any argument left as None is passed through unchanged; the LangChain
    service owns the live configuration and decides how to merge it.
    """
    logger.info(f"Updating model config via LangChain: model={model}, temperature={temperature}, max_tokens={max_tokens}")
    overrides = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    self.langchain_service.update_model_config(**overrides)
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
"""Conversation service."""
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, desc, func, or_
|
||||
|
||||
from ..models.conversation import Conversation
|
||||
from ..models.message import Message, MessageRole
|
||||
from utils.util_schemas import ConversationCreate, ConversationUpdate
|
||||
from utils.util_exceptions import ConversationNotFoundError, DatabaseError
|
||||
from ..core.context import UserContext
|
||||
from datetime import datetime, timezone
|
||||
from loguru import logger
|
||||
|
||||
class ConversationService:
    """Service for managing conversations and messages.

    All reads are scoped to the user resolved from ``UserContext``, so a
    user can never see or modify another user's conversations.
    """

    def __init__(self, session: Session):
        # NOTE(review): every method below awaits session calls, so despite
        # the sync ``Session`` annotation this is expected to be an
        # SQLAlchemy AsyncSession — confirm at the call sites.
        self.session = session

    async def create_conversation(
        self,
        user_id: int,
        conversation_data: ConversationCreate
    ) -> Conversation:
        """Create a new conversation owned by *user_id*.

        Raises:
            DatabaseError: if the insert fails; the session is rolled back.
        """
        logger.info(f"Creating new conversation for user {user_id}: {conversation_data}")

        try:
            conversation = Conversation(
                **conversation_data.model_dump(),
                user_id=user_id
            )

            # Set audit fields (creator / timestamps) before the insert.
            conversation.set_audit_fields(user_id=user_id, is_update=False)

            self.session.add(conversation)
            await self.session.commit()
            await self.session.refresh(conversation)

            logger.info(f"Successfully created conversation {conversation.id} for user {user_id}")
            return conversation

        except Exception as e:
            logger.error(f"Failed to create conversation: {str(e)}", exc_info=True)
            await self.session.rollback()
            raise DatabaseError(f"Failed to create conversation: {str(e)}")

    async def get_conversation(self, conversation_id: int) -> Optional[Conversation]:
        """Get a conversation by ID, scoped to the current user.

        Returns:
            The conversation, or None when it does not exist or belongs to
            another user.

        Raises:
            DatabaseError: on any underlying query failure.
        """
        try:
            user_id = UserContext.get_current_user_id()
            conversation = await self.session.scalar(
                select(Conversation).where(
                    Conversation.id == conversation_id,
                    Conversation.user_id == user_id
                )
            )

            if not conversation:
                logger.warning(f"Conversation {conversation_id} not found")

            return conversation

        except Exception as e:
            logger.error(f"Failed to get conversation {conversation_id}: {str(e)}", exc_info=True)
            raise DatabaseError(f"Failed to get conversation: {str(e)}")

    async def get_user_conversations(
        self,
        skip: int = 0,
        limit: int = 50,
        search_query: Optional[str] = None,
        include_archived: bool = False,
        order_by: str = "updated_at",
        order_desc: bool = True
    ) -> List[Conversation]:
        """Get the current user's conversations with search and filtering.

        An unknown ``order_by`` attribute name silently falls back to
        ``updated_at``, so arbitrary client input cannot raise here.
        """
        user_id = UserContext.get_current_user_id()
        query = select(Conversation).where(
            Conversation.user_id == user_id
        )

        # Hide archived conversations unless explicitly requested.
        if not include_archived:
            query = query.where(Conversation.is_archived == False)

        # Case-insensitive substring search over title and system prompt.
        if search_query and search_query.strip():
            search_term = f"%{search_query.strip()}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term),
                    Conversation.system_prompt.ilike(search_term)
                )
            )

        # Ordering (unknown column names degrade to updated_at).
        order_column = getattr(Conversation, order_by, Conversation.updated_at)
        if order_desc:
            query = query.order_by(desc(order_column))
        else:
            query = query.order_by(order_column)

        return (await self.session.scalars(query.offset(skip).limit(limit))).all()

    async def update_conversation(
        self,
        conversation_id: int,
        conversation_update: ConversationUpdate
    ) -> Optional[Conversation]:
        """Partially update a conversation.

        Returns:
            The refreshed conversation, or None when it does not exist.

        Raises:
            DatabaseError: if the commit fails; the session is rolled back.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return None

        # Only apply fields the caller actually set (partial update).
        # model_dump() is the pydantic v2 API, consistent with
        # create_conversation above; .dict() is deprecated in v2.
        update_data = conversation_update.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(conversation, field, value)

        # Update audit fields
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)

        try:
            await self.session.commit()
            await self.session.refresh(conversation)
            return conversation
        except Exception as e:
            logger.error(f"Failed to update conversation {conversation_id}: {str(e)}", exc_info=True)
            await self.session.rollback()
            raise DatabaseError(f"Failed to update conversation: {str(e)}")

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation.

        Returns:
            True when the conversation existed and was deleted.
        """
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        # AsyncSession.delete is a coroutine and must be awaited; without
        # the await the delete was silently dropped.
        await self.session.delete(conversation)
        await self.session.commit()
        return True

    async def get_conversation_messages(
        self,
        conversation_id: int,
        skip: int = 0,
        limit: int = 100
    ) -> List[Message]:
        """Get messages from a conversation in chronological order."""
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(Message.created_at).offset(skip).limit(limit)
        )).all()

    async def add_message(
        self,
        conversation_id: int,
        content: str,
        role: MessageRole,
        message_metadata: Optional[dict] = None,
        context_documents: Optional[list] = None,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None
    ) -> Message:
        """Add a message to a conversation and touch its timestamp.

        Returns:
            The persisted, refreshed message.
        """
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,
            message_metadata=message_metadata,
            context_documents=context_documents,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens
        )

        # Set audit fields
        message.set_audit_fields()

        self.session.add(message)
        await self.session.commit()
        await self.session.refresh(message)

        # Update conversation's updated_at timestamp so listings ordered by
        # recency reflect the new activity.
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self.session.commit()

        return message

    async def get_conversation_history(
        self,
        conversation_id: int,
        limit: int = 20
    ) -> List[Message]:
        """Get the most recent *limit* messages, oldest first.

        Fetches newest-first so LIMIT keeps the latest messages, then
        reverses to restore chronological order for the LLM context.
        """
        return (await self.session.scalars(
            select(Message).where(
                Message.conversation_id == conversation_id
            ).order_by(desc(Message.created_at)).limit(limit)
        )).all()[::-1]  # Reverse to get chronological order

    async def update_conversation_timestamp(self, conversation_id: int) -> None:
        """Update conversation's updated_at timestamp."""
        conversation = await self.get_conversation(conversation_id)
        if conversation:
            conversation.updated_at = datetime.now(timezone.utc)
            conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
            await self.session.commit()

    async def get_user_conversations_count(
        self,
        search_query: Optional[str] = None,
        include_archived: bool = False
    ) -> int:
        """Count the current user's conversations (same filters as listing)."""
        user_id = UserContext.get_current_user_id()
        query = select(func.count(Conversation.id)).where(
            Conversation.user_id == user_id
        )

        if not include_archived:
            query = query.where(Conversation.is_archived == False)

        if search_query and search_query.strip():
            search_term = f"%{search_query.strip()}%"
            query = query.where(
                or_(
                    Conversation.title.ilike(search_term),
                    Conversation.system_prompt.ilike(search_term)
                )
            )

        return (await self.session.scalar(query)) or 0

    async def archive_conversation(self, conversation_id: int) -> bool:
        """Archive a conversation; returns False when it does not exist."""
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = True
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self.session.commit()
        return True

    async def unarchive_conversation(self, conversation_id: int) -> bool:
        """Unarchive a conversation; returns False when it does not exist."""
        conversation = await self.get_conversation(conversation_id)
        if not conversation:
            return False

        conversation.is_archived = False
        conversation.set_audit_fields(user_id=conversation.user_id, is_update=True)
        await self.session.commit()
        return True
|
||||
|
|
@ -0,0 +1,310 @@
|
|||
from typing import Dict, Any, List, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from th_agenter.models.conversation import Conversation
|
||||
from th_agenter.models.message import Message
|
||||
from th_agenter.db.database import get_session
|
||||
|
||||
class ConversationContextService:
    """
    Conversation-context management service.

    Manages per-conversation history and context (file list, selected
    files, query history) for the smart data-Q&A feature.  Context lives
    in an in-process cache (``context_cache``), so it is lost on restart
    and not shared across worker processes.
    """

    def __init__(self):
        # In-memory cache: conversation_id -> context dict.
        self.context_cache = {}

    async def create_conversation(self, user_id: int, title: str = "智能问数对话") -> int:
        """
        Create a new conversation.

        Args:
            user_id: Owner user id.
            title: Conversation title.

        Returns:
            The id of the newly created conversation.
        """
        # Bind session before the try so the finally clause cannot hit an
        # unbound local if next(get_session()) itself raises.
        session = None
        try:
            session = next(get_session())

            conversation = Conversation(
                user_id=user_id,
                title=title,
                # NOTE(review): datetime.utcnow() is naive (and deprecated
                # in Python 3.12); kept so stored timestamps stay
                # consistent with existing rows.
                created_at=datetime.utcnow(),
                updated_at=datetime.utcnow()
            )

            session.add(conversation)
            session.commit()
            session.refresh(conversation)

            # Seed an empty context for the new conversation.
            self.context_cache[conversation.id] = {
                'conversation_id': conversation.id,
                'user_id': user_id,
                'file_list': [],
                'selected_files': [],
                'query_history': [],
                'created_at': datetime.utcnow().isoformat()
            }

            return conversation.id

        except Exception as e:
            print(f"创建对话失败: {e}")
            raise
        finally:
            if session is not None:
                session.close()

    async def get_conversation_context(self, conversation_id: int) -> Optional[Dict[str, Any]]:
        """
        Get the context for a conversation.

        Serves the cached context when available; otherwise rebuilds it
        from the persisted messages and caches the result.

        Returns:
            The context dict, or None when the conversation does not
            exist or loading fails.
        """
        # Fast path: serve from the in-process cache.
        if conversation_id in self.context_cache:
            return self.context_cache[conversation_id]

        # Slow path: rebuild from the database.
        session = None
        try:
            session = next(get_session())

            conversation = session.query(Conversation).filter(
                Conversation.id == conversation_id
            ).first()

            if not conversation:
                return None

            # Load the message history in chronological order.
            messages = session.query(Message).filter(
                Message.conversation_id == conversation_id
            ).order_by(Message.created_at).all()

            # Rebuild the context skeleton.
            context = {
                'conversation_id': conversation_id,
                'user_id': conversation.user_id,
                'file_list': [],
                'selected_files': [],
                'query_history': [],
                'created_at': conversation.created_at.isoformat()
            }

            # Recover query history and file selections from the messages.
            for message in messages:
                if message.role == 'user':
                    context['query_history'].append({
                        'query': message.content,
                        'timestamp': message.created_at.isoformat()
                    })
                elif message.role == 'assistant' and message.metadata:
                    # File info is stored in assistant-message metadata.
                    try:
                        metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
                        if 'selected_files' in metadata:
                            context['selected_files'] = metadata['selected_files']
                        if 'file_list' in metadata:
                            context['file_list'] = metadata['file_list']
                    except (json.JSONDecodeError, TypeError):
                        # Malformed metadata is ignored rather than fatal.
                        pass

            # Cache for subsequent lookups.
            self.context_cache[conversation_id] = context

            return context

        except Exception as e:
            print(f"获取对话上下文失败: {e}")
            return None
        finally:
            if session is not None:
                session.close()

    async def update_conversation_context(
        self,
        conversation_id: int,
        file_list: List[Dict[str, Any]] = None,
        selected_files: List[Dict[str, Any]] = None,
        query: str = None
    ) -> bool:
        """
        Update the context of a conversation.

        Args:
            conversation_id: Conversation id.
            file_list: Replacement file list (None leaves it unchanged).
            selected_files: Replacement selection (None leaves it unchanged).
            query: User query to append to the history, if any.

        Returns:
            True on success, False when the context cannot be loaded.
        """
        try:
            # Load (or rebuild) the context first.
            context = await self.get_conversation_context(conversation_id)
            if not context:
                return False

            if file_list is not None:
                context['file_list'] = file_list

            if selected_files is not None:
                context['selected_files'] = selected_files

            if query is not None:
                context['query_history'].append({
                    'query': query,
                    'timestamp': datetime.utcnow().isoformat()
                })

            # Write back to the cache (context may be a freshly built dict).
            self.context_cache[conversation_id] = context

            return True

        except Exception as e:
            print(f"更新对话上下文失败: {e}")
            return False

    async def save_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """
        Persist a message to the database.

        Args:
            conversation_id: Conversation id.
            role: Message role (user/assistant).
            content: Message body.
            metadata: Optional metadata dict (stored as a JSON string).

        Returns:
            True on success, False on failure.
        """
        session = None
        try:
            session = next(get_session())

            message = Message(
                conversation_id=conversation_id,
                role=role,
                content=content,
                metadata=json.dumps(metadata) if metadata else None,
                created_at=datetime.utcnow()
            )

            session.add(message)
            session.commit()

            # Touch the conversation's updated_at timestamp.
            conversation = session.query(Conversation).filter(
                Conversation.id == conversation_id
            ).first()

            if conversation:
                conversation.updated_at = datetime.utcnow()
                session.commit()

            return True

        except Exception as e:
            print(f"保存消息失败: {e}")
            return False
        finally:
            if session is not None:
                session.close()

    async def reset_conversation_context(self, conversation_id: int) -> bool:
        """
        Reset a conversation's context.

        Keeps identity fields but clears the file list, selection and
        query history.

        Returns:
            True on success (including when nothing was cached).
        """
        try:
            if conversation_id in self.context_cache:
                context = self.context_cache[conversation_id]
                # Keep basic identity info; clear files and query history.
                context.update({
                    'file_list': [],
                    'selected_files': [],
                    'query_history': []
                })

            return True

        except Exception as e:
            print(f"重置对话上下文失败: {e}")
            return False

    async def get_conversation_history(self, conversation_id: int) -> List[Dict[str, Any]]:
        """
        Get a conversation's message history.

        Returns:
            Chronologically ordered message dicts (empty list on failure).
        """
        session = None
        try:
            session = next(get_session())

            messages = session.query(Message).filter(
                Message.conversation_id == conversation_id
            ).order_by(Message.created_at).all()

            history = []
            for message in messages:
                msg_data = {
                    'id': message.id,
                    'role': message.role,
                    'content': message.content,
                    'timestamp': message.created_at.isoformat()
                }

                if message.metadata:
                    try:
                        metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
                        msg_data['metadata'] = metadata
                    except (json.JSONDecodeError, TypeError):
                        # Malformed metadata is simply omitted.
                        pass

                history.append(msg_data)

            return history

        except Exception as e:
            print(f"获取对话历史失败: {e}")
            return []
        finally:
            if session is not None:
                session.close()

    def clear_cache(self, conversation_id: int = None):
        """
        Clear cached contexts.

        Args:
            conversation_id: Specific conversation to evict; None clears all.
        """
        # Compare against None explicitly so conversation_id == 0 evicts a
        # single entry instead of wiping the whole cache.
        if conversation_id is not None:
            self.context_cache.pop(conversation_id, None)
        else:
            self.context_cache.clear()
|
||||
|
||||
# Module-level singleton shared by all callers.  Its context cache lives in
# process memory only, so state is per-worker and lost on restart.
conversation_context_service = ConversationContextService()
|
||||
|
|
@ -0,0 +1,375 @@
|
|||
"""数据库配置服务"""
|
||||
from loguru import logger
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from cryptography.fernet import Fernet
|
||||
import os
|
||||
|
||||
from ..models.database_config import DatabaseConfig
|
||||
from utils.util_exceptions import ValidationError, NotFoundError
|
||||
from .postgresql_tool_manager import get_postgresql_tool
|
||||
from .mysql_tool_manager import get_mysql_tool
|
||||
|
||||
class DatabaseConfigService:
    """Database configuration management service.

    Stores per-user database connection configs (passwords encrypted with
    Fernet) and delegates live connections/queries to the PostgreSQL/MySQL
    tool managers.
    """

    def __init__(self, db_session: Session):
        # Sync SQLAlchemy session used for config persistence.
        self.session = db_session
        # Tool managers hold live DB connections keyed by str(user_id)
        # in their `.connections` dicts (see get_table_data below).
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # Initialize the symmetric encryption key used for passwords.
        self.encryption_key = self._get_or_create_encryption_key()
        self.cipher = Fernet(self.encryption_key)
    def _get_or_create_encryption_key(self) -> bytes:
        """Load the Fernet key from disk, generating it on first run.

        NOTE(review): assumes the ``db/`` directory already exists — the
        write will fail otherwise.  Confirm deployment creates it.
        """
        key_file = "db/db_config_key.key"
        if os.path.exists(key_file):
            print('find db_config_key')
            with open(key_file, 'rb') as f:
                return f.read()

        else:
            print('not find db_config_key')
            key = Fernet.generate_key()
            with open(key_file, 'wb') as f:
                f.write(key)
            return key

    def _encrypt_password(self, password: str) -> str:
        """Encrypt a plaintext password for storage (Fernet, str in/out)."""
        return self.cipher.encrypt(password.encode()).decode()

    def _decrypt_password(self, encrypted_password: str) -> str:
        """Decrypt a stored password back to plaintext."""
        return self.cipher.decrypt(encrypted_password.encode()).decode()

    async def create_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Create a database config after verifying the connection works.

        Raises:
            ValidationError: on missing fields or failed connection test.
        """
        try:
            # Validate that all required fields are present.
            required_fields = ['name', 'db_type', 'host', 'port', 'database', 'username', 'password']
            for field in required_fields:
                if field not in config_data:
                    raise ValidationError(f"缺少必需字段: {field}")


            # Test the connection with the plaintext credentials before
            # persisting anything.
            test_config = {
                'host': config_data['host'],
                'port': config_data['port'],
                'database': config_data['database'],
                'username': config_data['username'],
                'password': config_data['password']
            }
            if 'postgresql' == config_data['db_type']:
                test_result = await self.postgresql_tool.execute(
                    operation="test_connection",
                    connection_config=test_config
                )
                if not test_result.success:
                    raise ValidationError(f"数据库连接测试失败: {test_result.error}")
            elif 'mysql' == config_data['db_type']:
                test_result = await self.mysql_tool.execute(
                    operation="test_connection",
                    connection_config=test_config
                )
                if not test_result.success:
                    raise ValidationError(f"数据库连接测试失败: {test_result.error}")
            # If this config is marked default, clear the flag on the
            # user's other configs first so at most one default exists.
            if config_data.get('is_default', False):
                stmt = select(DatabaseConfig).where(
                    DatabaseConfig.created_by == user_id,
                    DatabaseConfig.is_default == True
                )
                result = self.session.execute(stmt)
                for config in result.scalars():
                    config.is_default = False

            # Create the config row (password stored encrypted).
            db_config = DatabaseConfig(
                created_by=user_id,
                name=config_data['name'],
                db_type=config_data['db_type'],
                host=config_data['host'],
                port=config_data['port'],
                database=config_data['database'],
                username=config_data['username'],
                password=self._encrypt_password(config_data['password']),
                is_active=config_data.get('is_active', True),
                is_default=config_data.get('is_default', False),
                connection_params=config_data.get('connection_params')
            )

            self.session.add(db_config)
            self.session.commit()
            self.session.refresh(db_config)

            logger.info(f"创建数据库配置成功: {db_config.name} (ID: {db_config.id})")
            return db_config

        except Exception as e:
            self.session.rollback()
            logger.error(f"创建数据库配置失败: {str(e)}")
            raise

    def get_user_configs(self, user_id: int, active_only: bool = True) -> List[DatabaseConfig]:
        """List a user's database configs, newest first."""
        stmt = select(DatabaseConfig).where(DatabaseConfig.created_by == user_id)
        if active_only:
            stmt = stmt.where(DatabaseConfig.is_active == True)
        stmt = stmt.order_by(DatabaseConfig.created_at.desc())
        return self.session.scalars(stmt).all()

    def get_config_by_id(self, config_id: int, user_id: int) -> Optional[DatabaseConfig]:
        """Get a config by id, scoped to its creator (None if not found)."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.id == config_id,
            DatabaseConfig.created_by == user_id
        )
        return self.session.scalar(stmt)

    def get_default_config(self, user_id: int) -> Optional[DatabaseConfig]:
        """Get the user's default config.

        NOTE(review): the ``is_default`` filter is commented out, so this
        actually returns an arbitrary active config — confirm intent.
        """
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            # DatabaseConfig.is_default == True,
            DatabaseConfig.is_active == True
        )
        return self.session.scalar(stmt)

    async def test_connection(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Test the stored config's connection.

        NOTE(review): always uses the PostgreSQL tool regardless of the
        config's db_type — MySQL configs are not handled here; confirm.

        Raises:
            NotFoundError: when the config does not exist for this user.
        """
        config = self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        test_config = {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password)
        }

        result = await self.postgresql_tool.execute(
            operation="test_connection",
            connection_config=test_config
        )

        return {
            'success': result.success,
            'message': result.result.get('message') if result.success else result.error,
            'details': result.result if result.success else None
        }

    async def connect_and_get_tables(self, config_id: int, user_id: int) -> Dict[str, Any]:
        """Connect to the configured database and fetch its table list.

        NOTE(review): if ``config.db_type`` is neither 'postgresql' nor
        'mysql', ``connect_result`` is never assigned and the check below
        raises UnboundLocalError — confirm db_type is constrained upstream.

        Raises:
            NotFoundError: when the config does not exist for this user.
        """
        config = self.get_config_by_id(config_id, user_id)
        if not config:
            raise NotFoundError("数据库配置不存在")

        connection_config = {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': self._decrypt_password(config.password)
        }

        if 'postgresql' == config.db_type:
            # Connect via the PostgreSQL tool manager.
            connect_result = await self.postgresql_tool.execute(
                operation="connect",
                connection_config=connection_config,
                user_id=str(user_id)
            )
        elif 'mysql' == config.db_type:
            # Connect via the MySQL tool manager.
            connect_result = await self.mysql_tool.execute(
                operation="connect",
                connection_config=connection_config,
                user_id=str(user_id)
            )

        if not connect_result.success:
            return {
                'success': False,
                'message': connect_result.error
            }
        # The live connection is kept inside the tool manager's
        # `connections` dict, keyed by str(user_id).
        return {
            'success': True,
            'data': connect_result.result,
            'config_name': config.name
        }

    async def get_table_data(self, table_name: str, user_id: int, db_type: str, limit: int = 100) -> Dict[str, Any]:
        """Preview rows from a table, reusing the established connection.

        Returns a {'success', 'data'/'message', ...} dict rather than
        raising — callers branch on 'success'.
        """
        try:
            user_id_str = str(user_id)

            # Pick the tool manager matching the database type.
            if db_type.lower() == 'postgresql':
                db_tool = self.postgresql_tool
            elif db_type.lower() == 'mysql':
                db_tool = self.mysql_tool
            else:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {db_type}'
                }

            # Require an already-established connection for this user.
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            # NOTE(review): table_name is interpolated directly into SQL —
            # SQL injection risk if it can come from untrusted input;
            # it should be validated/quoted as an identifier.
            sql_query = f"SELECT * FROM {table_name}"
            result = await db_tool.execute(
                operation="execute_query",
                user_id=user_id_str,
                sql_query=sql_query,
                limit=limit
            )

            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': db_type
            }

        except Exception as e:
            logger.error(f"获取表数据失败: {str(e)}", exc_info=True)
            return {
                'success': False,
                'message': f'获取表数据失败: {str(e)}'
            }

    def disconnect_database(self, user_id: int) -> Dict[str, Any]:
        """Disconnect the user's database connection.

        NOTE(review): two likely bugs here — (1) ``execute`` is awaited
        everywhere else, so calling it without ``await`` in this sync
        method presumably returns an un-run coroutine and never
        disconnects; (2) ``self.user_connections`` is never initialized on
        this class, so the membership test raises AttributeError, which the
        except swallows into a failure response.  Confirm and fix upstream.
        """
        try:
            # Ask the PostgreSQL tool manager to drop the connection.
            self.postgresql_tool.execute(
                operation="disconnect",
                user_id=str(user_id)
            )

            # Remove from local connection tracking.
            if user_id in self.user_connections:
                del self.user_connections[user_id]

            return {
                'success': True,
                'message': '数据库连接已断开'
            }
        except Exception as e:
            return {
                'success': False,
                'message': f'断开连接失败: {str(e)}'
            }

    def get_config_by_type(self, user_id: int, db_type: str) -> Optional[DatabaseConfig]:
        """Get the user's active config for a given database type."""
        stmt = select(DatabaseConfig).where(
            DatabaseConfig.created_by == user_id,
            DatabaseConfig.db_type == db_type,
            DatabaseConfig.is_active == True
        )
        return self.session.scalar(stmt)

    async def create_or_update_config(self, user_id: int, config_data: Dict[str, Any]) -> DatabaseConfig:
        """Create or update a config, keeping db_type unique per user."""
        try:
            # Check whether a config of this type already exists.
            existing_config = self.get_config_by_type(user_id, config_data['db_type'])

            if existing_config:
                # Update the existing config in place; passwords are
                # re-encrypted, unknown keys are silently skipped.
                for key, value in config_data.items():
                    if key == 'password':
                        setattr(existing_config, key, self._encrypt_password(value))
                    elif hasattr(existing_config, key):
                        setattr(existing_config, key, value)

                self.session.commit()
                self.session.refresh(existing_config)
                logger.info(f"更新数据库配置成功: {existing_config.name} (ID: {existing_config.id})")
                return existing_config
            else:
                # No config of this type yet — create a fresh one.
                return await self.create_config(user_id, config_data)

        except Exception as e:
            self.session.rollback()
            logger.error(f"创建或更新数据库配置失败: {str(e)}")
            raise


    async def describe_table(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Get a table's structure, reusing the established connection.

        Uses the user's default config only to decide which tool manager
        to query; the actual connection must already exist.
        """
        try:
            logger.error(f"未实现的逻辑,暂自编 - describe_table: {table_name}")
            user_id_str = str(user_id)

            # Resolve the user's default database config.
            default_config = self.get_default_config(user_id)
            if not default_config:
                return {
                    'success': False,
                    'message': '未找到默认数据库配置'
                }

            # Pick the tool manager matching the database type.
            if default_config.db_type.lower() == 'postgresql':
                db_tool = self.postgresql_tool
            elif default_config.db_type.lower() == 'mysql':
                db_tool = self.mysql_tool
            else:
                return {
                    'success': False,
                    'message': f'不支持的数据库类型: {default_config.db_type}'
                }

            # Require an already-established connection for this user.
            if user_id_str not in db_tool.connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接数据库'
                }

            # Run describe_table over the existing connection.
            result = await db_tool.execute(
                operation="describe_table",
                user_id=user_id_str,
                table_name=table_name
            )

            if not result.success:
                return {
                    'success': False,
                    'message': result.error
                }

            return {
                'success': True,
                'data': result.result,
                'db_type': default_config.db_type
            }

        except Exception as e:
            logger.error(f"获取表结构失败: {str(e)}", exc_info=True)
            return {
                'success': False,
                'message': f'获取表结构失败: {str(e)}'
            }
|
||||
|
||||
|
|
@ -0,0 +1,302 @@
|
|||
"""Document service."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import UploadFile
|
||||
|
||||
from ..models.knowledge_base import Document, KnowledgeBase
|
||||
from ..core.config import get_settings
|
||||
from utils.util_file import FileUtils
|
||||
from .storage import storage_service
|
||||
from .document_processor import get_document_processor
|
||||
from utils.util_schemas import DocumentChunk
|
||||
from loguru import logger
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class DocumentService:
    """Document service for managing documents in knowledge bases.

    Wraps upload, lookup, deletion, processing and statistics for
    ``Document`` rows, delegating file storage to ``storage_service`` and
    chunking/embedding to the document processor.
    """

    def __init__(self, session: Session):
        # NOTE(review): ``self.session.desc`` is assigned throughout this class
        # as a human-readable audit/trace string. It is not a standard
        # SQLAlchemy Session attribute — presumably consumed by surrounding
        # middleware or logging; confirm.
        self.session = session
        self.file_utils = FileUtils()

    async def upload_document(self, file: UploadFile, kb_id: int) -> Document:
        """Upload a document to knowledge base.

        Validates the target knowledge base, the filename and the extension,
        stores the file via the storage service, and persists a ``Document``
        row (unprocessed). Raises ``ValueError`` on any validation failure.
        """
        self.session.desc = f"上传文档 {file.filename} 到知识库 {kb_id}"
        # Validate knowledge base exists
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        kb = self.session.scalar(stmt)
        if not kb:
            self.session.desc = f"ERROR: 知识库 {kb_id} 不存在"
            raise ValueError(f"知识库 {kb_id} 不存在")

        # Validate file has a name at all
        if not file.filename:
            self.session.desc = f"ERROR: 上传文件时未提供文件名"
            raise ValueError("No filename provided")

        # Validate file extension against the configured allow-list
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in settings.file.allowed_extensions:
            self.session.desc = f"ERROR: 非期望的文件类型 {file_extension}"
            raise ValueError(f"非期望的文件类型 {file_extension}")

        # Upload file using storage service
        storage_info = await storage_service.upload_file(file, kb_id)

        # Create document record
        document = Document(
            knowledge_base_id=kb_id,
            filename=os.path.basename(storage_info["file_path"]),
            original_filename=file.filename,
            file_path=storage_info.get("full_path", storage_info["file_path"]),  # Use absolute path if available
            file_size=storage_info["size"],
            file_type=file_extension,
            mime_type=storage_info["mime_type"],
            is_processed=False
        )

        # Set audit fields (created/updated metadata on the model)
        document.set_audit_fields()

        self.session.add(document)
        self.session.commit()
        self.session.refresh(document)

        self.session.desc = f"SUCCESS: 成功上传文档 {file.filename} 到知识库 {kb_id} (Doc ID: {document.id})"
        return document

    def get_document(self, doc_id: int, kb_id: int = None) -> Optional[Document]:
        """Look up a document by id, optionally filtered by knowledge base."""
        self.session.desc = f"查询文档 {doc_id}"
        stmt = select(Document).where(Document.id == doc_id)
        if kb_id is not None:
            stmt = stmt.where(Document.knowledge_base_id == kb_id)
        return self.session.scalar(stmt)

    def get_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> List[Document]:
        """List documents of a knowledge base with offset/limit pagination."""
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
        stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            .offset(skip)
            .limit(limit)
        )
        return self.session.scalars(stmt).all()

    def list_documents(self, kb_id: int, skip: int = 0, limit: int = 50) -> tuple[List[Document], int]:
        """List documents with pagination, also returning the total count."""
        self.session.desc = f"查询知识库 {kb_id} 中的文档 (跳过 {skip} 条,限制 {limit} 条)"
        # Get total count
        count_stmt = select(func.count(Document.id)).where(Document.knowledge_base_id == kb_id)
        total = self.session.scalar(count_stmt)

        # Get documents with pagination
        documents_stmt = (
            select(Document)
            .where(Document.knowledge_base_id == kb_id)
            .offset(skip)
            .limit(limit)
        )
        documents = self.session.scalars(documents_stmt).all()

        return documents, total

    def delete_document(self, doc_id: int, kb_id: int = None) -> bool:
        """Delete a document by id (optionally scoped to a knowledge base).

        Removes the stored file (best-effort), the vector-store chunks and
        the database row. Returns False when the document does not exist.
        """
        self.session.desc = f"删除文档 {doc_id}"
        document = self.get_document(doc_id, kb_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False

        # Delete file from storage (best-effort: failure recorded, not raised)
        try:
            storage_service.delete_file(document.file_path)
            logger.info(f"Deleted file: {document.file_path}")
        except Exception as e:
            self.session.desc = f"EXCEPTION: 删除文档 {doc_id} 关联文件时失败: {e}"

        # Remove the document's chunks from the vector store.
        # NOTE(review): ``kb_id`` may be None when the caller omits it — verify
        # the processor tolerates a None knowledge-base id here.
        get_document_processor().delete_document_from_vector_store(kb_id, doc_id)
        # Delete database record
        self.session.delete(document)
        self.session.commit()
        self.session.desc = f"SUCCESS: 成功删除文档 {doc_id}"
        return True

    async def process_document(self, doc_id: int, kb_id: int = None) -> Dict[str, Any]:
        """Process a document: extract text and build embeddings.

        Returns the processor's result dict; on failure, rolls back, records
        the error on the row (best-effort) and returns a ``failed`` payload.
        """
        try:
            self.session.desc = f"处理文档 {doc_id}"
            document = self.get_document(doc_id, kb_id)
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                raise ValueError(f"Document {doc_id} not found")

            if document.is_processed:
                self.session.desc = f"INFO: 文档 {doc_id} 已处理"
                return {
                    "document_id": doc_id,
                    "status": "already_processed",
                    "message": "文档已处理"
                }

            # Clear any previous processing error before reprocessing
            document.processing_error = None
            self.session.commit()

            # Delegate the heavy lifting to the document processor
            result = get_document_processor().process_document(
                document_id=doc_id,
                file_path=document.file_path,
                knowledge_base_id=document.knowledge_base_id
            )
            self.session.desc = f"SUCCESS: 成功处理文档 {doc_id}"

            # On success, persist the processed flag and chunk count
            if result["status"] == "success":
                document.is_processed = True
                document.chunk_count = result.get("chunks_count", 0)
                self.session.commit()
                self.session.refresh(document)
                logger.info(f"Processed document: {document.filename} (ID: {doc_id})")

            return result

        except Exception as e:
            self.session.rollback()
            self.session.desc = f"EXCEPTION: 处理文档 {doc_id} 时失败: {e}"

            # Best-effort: record the error on the document row
            try:
                document = self.get_document(doc_id)
                if document:
                    document.processing_error = str(e)
                    self.session.commit()
            except Exception as db_error:
                logger.error(f"Failed to update document error status: {db_error}")

            return {
                "document_id": doc_id,
                "status": "failed",
                "error": str(e),
                "message": "文档处理失败"
            }

    async def _extract_text(self, document: Document) -> str:
        """Extract raw text from the document's file.

        NOTE(review): the PDF/Office branches are placeholders that return a
        descriptive string rather than real extracted text — the processor's
        loaders are the actual extraction path; confirm this method is still
        in use.
        """
        try:
            if document.is_text_file:
                # Read text files directly
                with open(document.file_path, 'r', encoding='utf-8') as f:
                    return f.read()

            elif document.is_pdf_file:
                # TODO: Implement PDF text extraction using PyPDF2 or similar
                # For now, return placeholder
                return f"PDF content from {document.original_filename}"

            elif document.is_office_file:
                # TODO: Implement Office file text extraction using python-docx, openpyxl, etc.
                # For now, return placeholder
                return f"Office document content from {document.original_filename}"

            else:
                self.session.desc = f"ERROR: 不支持的文件类型: {document.file_type}"
                raise ValueError(f"不支持的文件类型: {document.file_type}")

        except Exception as e:
            self.session.desc = f"EXCEPTION: 从文档 {document.file_path} 提取文本时失败: {e}"
            raise

    def update_document_status(self, doc_id: int, is_processed: bool, error: Optional[str] = None) -> bool:
        """Update the processed flag / error text; False if row is missing."""
        self.session.desc = f"更新文档 {doc_id} 处理状态为 {is_processed}"
        document = self.get_document(doc_id)
        if not document:
            self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
            return False

        document.is_processed = is_processed
        document.processing_error = error

        self.session.commit()
        self.session.desc = f"SUCCESS: 更新文档 {doc_id} 处理状态为 {is_processed}"
        return True

    def search_documents(self, kb_id: int, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Similarity-search documents in a knowledge base; [] on failure."""
        try:
            # Delegate vector similarity search to the document processor
            self.session.desc = f"搜索知识库 {kb_id} 中的文档使用向量相似度: {query}"
            results = get_document_processor().search_similar_documents(kb_id, query, limit)
            self.session.desc = f"SUCCESS: 搜索知识库 {kb_id} 中的文档使用向量相似度: {query} >>> {len(results)} 条结果"
            return results
        except Exception as e:
            self.session.desc = f"EXCEPTION: 搜索知识库 {kb_id} 中的文档使用向量相似度时失败: {e}"
            logger.error(f"查找知识库 {kb_id} 中的文档使用向量相似度时失败: {e}")
            return []

    def get_document_stats(self, kb_id: int) -> Dict[str, Any]:
        """Aggregate counts, sizes and file-type histogram for a KB.

        NOTE(review): fetches at most 1000 documents — counts are truncated
        for larger knowledge bases; confirm this cap is acceptable.
        """
        documents = self.get_documents(kb_id, limit=1000)  # Get all documents

        total_count = len(documents)
        processed_count = len([doc for doc in documents if doc.is_processed])
        total_size = sum(doc.file_size for doc in documents)

        # Histogram of file extensions
        file_types = {}
        for doc in documents:
            file_type = doc.file_type
            file_types[file_type] = file_types.get(file_type, 0) + 1

        return {
            "total_documents": total_count,
            "processed_documents": processed_count,
            "pending_documents": total_count - processed_count,
            "total_size_bytes": total_size,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "file_types": file_types
        }

    def get_document_chunks(self, doc_id: int) -> List[DocumentChunk]:
        """Return the stored chunks of a document as ``DocumentChunk`` objects.

        Returns an empty list when the document does not exist or on any
        processor/deserialization failure.
        """
        try:
            self.session.desc = f"获取文档 {doc_id} 的文档块"
            stmt = select(Document).where(Document.id == doc_id)
            document = self.session.scalar(stmt)
            if not document:
                self.session.desc = f"ERROR: 文档 {doc_id} 不存在"
                return []

            # Get chunks from document processor
            chunks_data = get_document_processor().get_document_chunks(document.knowledge_base_id, doc_id)

            # Convert to DocumentChunk objects
            chunks = []
            for chunk_data in chunks_data:
                chunk = DocumentChunk(
                    id=chunk_data["id"],
                    content=chunk_data["content"],
                    metadata=chunk_data["metadata"],
                    page_number=chunk_data.get("page_number"),
                    chunk_index=chunk_data["chunk_index"],
                    start_char=chunk_data.get("start_char"),
                    end_char=chunk_data.get("end_char")
                )
                chunks.append(chunk)

            self.session.desc = f"SUCCESS: 获取文档 {doc_id} 的文档块: {len(chunks)} 个"
            return chunks

        except Exception as e:
            self.session.desc = f"EXCEPTION: 获取文档 {doc_id} 的文档块时失败: {e}"
            return []
|
||||
|
|
@ -0,0 +1,973 @@
|
|||
"""文档处理服务,负责文档的分段、向量化和索引"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import (
|
||||
TextLoader,
|
||||
PyPDFLoader,
|
||||
Docx2txtLoader,
|
||||
UnstructuredMarkdownLoader
|
||||
)
|
||||
import pdfplumber
|
||||
from langchain_core.documents import Document
|
||||
from langchain_postgres import PGVector
|
||||
from typing import List
|
||||
# 旧的ZhipuEmbeddings类已移除,现在统一使用EmbeddingFactory创建embedding实例
|
||||
|
||||
from ..core.config import BaseSettings, get_settings
|
||||
from ..models.knowledge_base import Document as DocumentModel
|
||||
from ..db.database import get_session
|
||||
from loguru import logger
|
||||
|
||||
settings = get_settings()
|
||||
class PGVectorConnectionPool:
    """Connection-pool manager for the PGVector backing database.

    Provides pooled SQLAlchemy sessions plus a small query helper. Callers in
    this module (chunk lookup / uuid lookup by document id) depend on
    ``get_session`` and ``execute_query``; previously those methods were
    commented out, so every call raised ``AttributeError``. This restores the
    synchronous implementation the module was written against.
    """

    def __init__(self):
        # NOTE: an async implementation is still planned (待异步方式实现);
        # until then the synchronous pool below keeps callers working.
        self.engine = None
        self.SessionLocal = None
        self._init_connection_pool()

    def _init_connection_pool(self):
        """Create the SQLAlchemy engine and session factory (pgvector only).

        ``create_engine`` is lazy: no database connection is made until a
        session is first checked out, so building the pool here is cheap.
        """
        if settings.vector_db.type == "pgvector":
            # URL-encode the password so special characters (e.g. '@') survive
            encoded_password = quote(settings.vector_db.pgvector_password, safe="")
            connection_string = (
                f"postgresql://{settings.vector_db.pgvector_user}:"
                f"{encoded_password}@"
                f"{settings.vector_db.pgvector_host}:"
                f"{settings.vector_db.pgvector_port}/"
                f"{settings.vector_db.pgvector_database}"
            )

            self.engine = create_engine(
                connection_string,
                poolclass=QueuePool,
                pool_size=5,         # base pool size
                max_overflow=10,     # extra connections beyond the pool
                pool_pre_ping=True,  # validate connections on checkout
                pool_recycle=3600,   # recycle connections after one hour
                echo=False           # do not echo SQL statements
            )

            self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
            logger.info(f"PGVector连接池已初始化: {settings.vector_db.pgvector_host}:{settings.vector_db.pgvector_port}")

    def get_session(self):
        """Return a new pooled session; raise if the pool was never built."""
        if self.SessionLocal is None:
            raise RuntimeError("连接池未初始化")
        return self.SessionLocal()

    def execute_query(self, query: str, params=None):
        """Execute *query* (named ``:param`` placeholders) and return all rows.

        ``params`` may be a mapping of bind parameters; the session is always
        closed, returning its connection to the pool.
        """
        session = self.get_session()
        try:
            result = session.execute(text(query), params or {})
            return result.fetchall()
        finally:
            session.close()
|
||||
|
||||
class DocumentProcessor:
|
||||
"""文档处理器,负责文档的加载、分段和向量化"""
|
||||
|
||||
def __init__(self):
|
||||
# 初始化语义分割器配置
|
||||
self.semantic_splitter_enabled = settings.file.semantic_splitter_enabled
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=settings.file.chunk_size,
|
||||
chunk_overlap=settings.file.chunk_overlap,
|
||||
length_function=len,
|
||||
separators=["\n\n", "\n", " ", ""]
|
||||
)
|
||||
|
||||
# 初始化嵌入模型 - 根据配置选择提供商
|
||||
self._init_embeddings()
|
||||
|
||||
# 初始化连接池(仅对PGVector)
|
||||
self.pgvector_pool = None
|
||||
|
||||
# PostgreSQL pgvector连接配置
|
||||
print('settings.vector_db.type=============', settings.vector_db.type)
|
||||
if settings.vector_db.type == "pgvector":
|
||||
# 新版本PGVector使用psycopg3连接字符串
|
||||
# 对密码进行URL编码以处理特殊字符(如@符号)
|
||||
encoded_password = quote(settings.vector_db.pgvector_password, safe="")
|
||||
self.connection_string = (
|
||||
f"postgresql+psycopg://{settings.vector_db.pgvector_user}:"
|
||||
f"{encoded_password}@"
|
||||
f"{settings.vector_db.pgvector_host}:"
|
||||
f"{settings.vector_db.pgvector_port}/"
|
||||
f"{settings.vector_db.pgvector_database}"
|
||||
)
|
||||
# 初始化连接池
|
||||
self.pgvector_pool = PGVectorConnectionPool()
|
||||
else:
|
||||
# 向量数据库存储路径(Chroma兼容)
|
||||
vector_db_path = settings.vector_db.persist_directory
|
||||
if not os.path.isabs(vector_db_path):
|
||||
# 如果是相对路径,则基于项目根目录计算绝对路径
|
||||
# 项目根目录是backend的父目录
|
||||
backend_dir = Path(__file__).parent.parent.parent
|
||||
vector_db_path = str(backend_dir / vector_db_path)
|
||||
self.vector_db_path = vector_db_path
|
||||
|
||||
def _init_embeddings(self):
|
||||
"""根据配置初始化embedding模型"""
|
||||
from .embedding_factory import EmbeddingFactory
|
||||
self.embeddings = EmbeddingFactory.create_embeddings()
|
||||
|
||||
def load_document(self, file_path: str) -> List[Document]:
|
||||
"""根据文件类型加载文档"""
|
||||
file_extension = Path(file_path).suffix.lower()
|
||||
|
||||
try:
|
||||
if file_extension == '.txt':
|
||||
loader = TextLoader(file_path, encoding='utf-8')
|
||||
documents = loader.load()
|
||||
elif file_extension == '.pdf':
|
||||
# 使用pdfplumber处理PDF文件,更稳定
|
||||
documents = self._load_pdf_with_pdfplumber(file_path)
|
||||
elif file_extension == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
documents = loader.load()
|
||||
elif file_extension == '.md':
|
||||
loader = UnstructuredMarkdownLoader(file_path)
|
||||
documents = loader.load()
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {file_extension}")
|
||||
|
||||
logger.info(f"成功加载文档: {file_path}, 页数: {len(documents)}")
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载文档失败 {file_path}: {str(e)}")
|
||||
raise
|
||||
|
||||
def _load_pdf_with_pdfplumber(self, file_path: str) -> List[Document]:
|
||||
"""使用pdfplumber加载PDF文档"""
|
||||
documents = []
|
||||
try:
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page_num, page in enumerate(pdf.pages):
|
||||
text = page.extract_text()
|
||||
if text and text.strip(): # 只处理有文本内容的页面
|
||||
doc = Document(
|
||||
page_content=text,
|
||||
metadata={
|
||||
"source": file_path,
|
||||
"page": page_num + 1
|
||||
}
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
except Exception as e:
|
||||
logger.error(f"使用pdfplumber加载PDF失败 {file_path}: {str(e)}")
|
||||
# 如果pdfplumber失败,回退到PyPDFLoader
|
||||
try:
|
||||
loader = PyPDFLoader(file_path)
|
||||
return loader.load()
|
||||
except Exception as fallback_e:
|
||||
logger.error(f"PyPDFLoader回退也失败 {file_path}: {str(fallback_e)}")
|
||||
raise fallback_e
|
||||
|
||||
def _merge_documents(self, documents: List[Document]) -> Document:
|
||||
"""将多个文档合并成一个文档"""
|
||||
merged_text = ""
|
||||
merged_metadata = {}
|
||||
|
||||
for doc in documents:
|
||||
if merged_text:
|
||||
merged_text += "\n\n"
|
||||
merged_text += doc.page_content
|
||||
# 合并元数据
|
||||
merged_metadata.update(doc.metadata)
|
||||
|
||||
return Document(page_content=merged_text, metadata=merged_metadata)
|
||||
|
||||
def _get_semantic_split_points(self, text: str) -> List[str]:
|
||||
"""使用大模型分析文档内容,返回合适的分割点列表"""
|
||||
try:
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from ..core.config import get_settings
|
||||
|
||||
|
||||
|
||||
prompt = f"""
|
||||
# 任务说明
|
||||
请分析文档内容,识别出适合作为分割点的关键位置。分割点应该是能够将文档划分为有意义段落的文本片段。
|
||||
|
||||
|
||||
# 分割规则
|
||||
请严格按照以下规则识别分割点:
|
||||
|
||||
## 基本要求
|
||||
1. 分割点必须是完整的句子开头或段落开头
|
||||
2. 每个分割后的部分应包含相对完整的语义内容
|
||||
3. 每个分割部分的理想长度控制在500字以内,严禁超过1000字。如果超过了1000字,要强制分段。
|
||||
|
||||
## 短段落处理
|
||||
4. 如果某部分长度可能小于50字,应将其与后续内容合并,避免产生过短片段
|
||||
|
||||
## 唯一性保证(重要)
|
||||
5. 确保每个分割点在文档中具有唯一性:
|
||||
- 检查文内是否存在相同的文本片段
|
||||
- 如果存在重复,需要扩展分割点字符串,直到获得唯一标识
|
||||
- 扩展方法:在当前分割点后追加几个字符,形成更长的唯一字符串
|
||||
|
||||
## 示例说明
|
||||
原始文档:
|
||||
"目录:
|
||||
第一章 标题一
|
||||
第二章 标题二
|
||||
正文
|
||||
第一章 标题一
|
||||
这是第一章的内容
|
||||
|
||||
第二章 标题二
|
||||
这是第二章的内容"
|
||||
|
||||
错误分割点:"第一章 标题一"(在目录和正文中重复出现)
|
||||
|
||||
正确分割点:"第一章 标题一\n这是第"(通过追加内容确保唯一性)
|
||||
|
||||
# 输出格式
|
||||
- 只返回分割点文本字符串
|
||||
- 每个分割点用~~分隔
|
||||
- 不要包含任何其他内容或解释
|
||||
|
||||
示例输出:分割点1~~分割点2~~分割点3
|
||||
|
||||
|
||||
文档内容:
|
||||
{text[:10000]} # 限制输入长度
|
||||
"""
|
||||
from ..core.llm import create_llm
|
||||
llm = create_llm(temperature=0.2)
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
# 解析响应获取分割点列表
|
||||
split_points = [point.strip() for point in response.content.split('~~') if point.strip()]
|
||||
logger.info(f"语义分析得到 {len(split_points)} 个分割点")
|
||||
return split_points
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取语义分割点失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def _split_by_semantic_points(self, text: str, split_points: List[str]) -> List[str]:
|
||||
"""根据语义分割点切分文本"""
|
||||
chunks = []
|
||||
current_pos = 0
|
||||
|
||||
# 按顺序查找每个分割点并切分文本
|
||||
for point in split_points:
|
||||
pos = text.find(point, current_pos)
|
||||
if pos != -1:
|
||||
# 添加当前位置到分割点位置的文本块
|
||||
if pos > current_pos:
|
||||
chunk = text[current_pos:pos].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
current_pos = pos
|
||||
|
||||
# 添加最后一个文本块
|
||||
if current_pos < len(text):
|
||||
chunk = text[current_pos:].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents into chunks (with short-merge and forced splitting).

        When semantic splitting is enabled: merge all pages, ask the LLM for
        split points, cut the text there, merge short pieces (<100 chars) and
        force-split over-long ones (>1000 chars). On any missing split points
        — or when semantic splitting is disabled — fall back to the default
        recursive splitter. Re-raises on failure.
        """
        try:
            if self.semantic_splitter_enabled and documents:
                # 1. Merge all pages into one document
                merged_doc = self._merge_documents(documents)

                # 2. Ask the LLM for semantic split points
                split_points = self._get_semantic_split_points(merged_doc.page_content)

                if split_points:
                    # 3. Cut the merged text at the proposed points
                    text_chunks = self._split_by_semantic_points(merged_doc.page_content, split_points)

                    # 4. Merge short pieces and force-split over-long ones
                    processed_chunks = []
                    buffer = ""
                    for chunk in text_chunks:
                        # Over-long chunk (>1000 chars): flush buffer first
                        if len(chunk) > 1000:
                            if buffer:
                                processed_chunks.append(buffer)
                                buffer = ""

                            # Hard-split the over-long chunk
                            forced_splits = self._force_split_long_chunk(chunk)
                            processed_chunks.extend(forced_splits)
                        else:
                            # Accumulate short pieces into the buffer
                            if not buffer:
                                buffer = chunk
                            else:
                                if len(buffer) < 100:
                                    buffer = f"{buffer}\n{chunk}"
                                else:
                                    processed_chunks.append(buffer)
                                    buffer = chunk

                    # Flush whatever remains in the buffer
                    if buffer:
                        processed_chunks.append(buffer)

                    # 5. Wrap the final pieces in Document objects
                    chunks = []
                    for i, chunk in enumerate(processed_chunks):
                        doc = Document(
                            page_content=chunk,
                            metadata={
                                **merged_doc.metadata,
                                'chunk_index': i,
                                # NOTE(review): length > 100 does not prove a
                                # merge happened — flag looks misleading; confirm.
                                'merged': len(chunk) > 100,
                                # NOTE(review): force-split pieces are capped at
                                # 1000 chars, so this flag appears to always be
                                # False for them — verify intended semantics.
                                'forced_split': len(chunk) > 1000
                            }
                        )
                        chunks.append(doc)
                else:
                    # No split points obtained — fall back to default splitter
                    logger.warning("语义分割失败,使用默认分割器")
                    chunks = self.text_splitter.split_documents(documents)
            else:
                # Semantic splitting disabled — use the default splitter
                chunks = self.text_splitter.split_documents(documents)

            logger.info(f"文档分割完成,共生成 {len(chunks)} 个文档块")
            return chunks

        except Exception as e:
            logger.error(f"文档分割失败: {str(e)}")
            raise
|
||||
|
||||
def _force_split_long_chunk(self, chunk: str) -> List[str]:
|
||||
"""强制分割超长段落(超过1000字符)"""
|
||||
max_length = 1000
|
||||
chunks = []
|
||||
|
||||
# 先尝试按换行符分割
|
||||
if '\n' in chunk:
|
||||
lines = chunk.split('\n')
|
||||
current_chunk = ""
|
||||
for line in lines:
|
||||
if len(current_chunk) + len(line) + 1 > max_length:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = line
|
||||
else:
|
||||
chunks.append(line[:max_length])
|
||||
current_chunk = line[max_length:]
|
||||
else:
|
||||
if current_chunk:
|
||||
current_chunk += "\n" + line
|
||||
else:
|
||||
current_chunk = line
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
else:
|
||||
# 没有换行符则直接按长度分割
|
||||
chunks = [chunk[i:i + max_length] for i in range(0, len(chunk), max_length)]
|
||||
|
||||
return chunks
|
||||
|
||||
    def create_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> str:
        """Create a vector store for a knowledge base and index *documents*.

        Returns the pgvector collection name or the Chroma persistence path,
        depending on the configured backend. Re-raises on failure.
        """
        try:
            if settings.vector_db.type == "pgvector":
                # Tag every chunk so it can later be located / deleted by id
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # One collection per knowledge base
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                # New-style PGVector instance (psycopg3 connection string)
                vector_store = PGVector(
                    connection=self.connection_string,
                    embeddings=self.embeddings,
                    collection_name=collection_name,
                    use_jsonb=True  # store metadata as JSONB
                )

                # Index the chunks (embeddings computed here)
                vector_store.add_documents(documents)

                logger.info(f"PostgreSQL pgvector存储创建成功: {collection_name}")
                return collection_name
            else:
                # Chroma-compatible mode; lazy import keeps pgvector deploys light
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                # Same chunk metadata as the pgvector branch
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # Build the store (embeddings computed here)
                vector_store = Chroma.from_documents(
                    documents=documents,
                    embedding=self.embeddings,
                    persist_directory=kb_vector_path
                )

                # NOTE(review): ``persist()`` is a no-op / removed in newer
                # Chroma releases — confirm the pinned version still needs it.
                vector_store.persist()

                logger.info(f"向量存储创建成功: {kb_vector_path}")
                return kb_vector_path

        except Exception as e:
            logger.error(f"创建向量存储失败: {str(e)}")
            raise
|
||||
|
||||
    def add_documents_to_vector_store(self, knowledge_base_id: int, documents: List[Document], document_id: int = None) -> None:
        """Append *documents* to the knowledge base's existing vector store.

        Falls back to ``create_vector_store`` when the pgvector collection
        cannot be connected to or the Chroma directory does not exist.
        Re-raises on failure.
        """
        try:
            if settings.vector_db.type == "pgvector":
                # Tag every chunk so it can later be located / deleted by id
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # One collection per knowledge base
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
                try:
                    # Connect to the existing collection
                    vector_store = PGVector(
                        connection=self.connection_string,
                        embeddings=self.embeddings,
                        collection_name=collection_name,
                        use_jsonb=True
                    )
                    # Append the new chunks
                    vector_store.add_documents(documents)
                except Exception as e:
                    # Collection missing/broken: create a fresh store instead
                    logger.warning(f"连接现有向量存储失败,创建新的向量存储: {e}")
                    self.create_vector_store(knowledge_base_id, documents, document_id)
                    return

                logger.info(f"文档已添加到PostgreSQL pgvector存储: {collection_name}")
            else:
                # Chroma-compatible mode; lazy import keeps pgvector deploys light
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                # No store yet for this knowledge base: create one
                if not os.path.exists(kb_vector_path):
                    self.create_vector_store(knowledge_base_id, documents, document_id)
                    return

                # Same chunk metadata as the pgvector branch
                for i, doc in enumerate(documents):
                    doc.metadata.update({
                        "knowledge_base_id": knowledge_base_id,
                        "document_id": str(document_id) if document_id else "unknown",
                        "chunk_id": f"{knowledge_base_id}_{document_id}_{i}",
                        "chunk_index": i
                    })

                # Load the existing store
                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                # Append the new chunks and persist
                vector_store.add_documents(documents)
                vector_store.persist()

                logger.info(f"文档已添加到向量存储: {kb_vector_path}")

        except Exception as e:
            logger.error(f"添加文档到向量存储失败: {str(e)}")
            raise
|
||||
|
||||
    def process_document(self, document_id: int, file_path: str, knowledge_base_id: int) -> Dict[str, Any]:
        """Process one document end-to-end: load, split, index, update status.

        Returns a result dict with ``status`` "success" or "failed"; failures
        are also recorded (best-effort) on the document row.
        """
        try:
            logger.info(f"开始处理文档 ID: {document_id}, 路径: {file_path}")

            # 1. Load the file into LangChain documents
            documents = self.load_document(file_path)

            # 2. Split into chunks
            chunks = self.split_documents(documents)

            # 3. Index the chunks in the vector store
            self.add_documents_to_vector_store(knowledge_base_id, chunks, document_id)

            # 4. Mark the document as processed
            # NOTE(review): this sets ``status`` while DocumentService uses
            # ``is_processed`` — confirm the model defines both fields.
            with next(get_session()) as session:
                document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
                if document:
                    document.status = "processed"
                    document.chunk_count = len(chunks)
                    session.commit()

            result = {
                "document_id": document_id,
                "status": "success",
                "chunks_count": len(chunks),
                "message": "文档处理完成"
            }

            logger.info(f"文档处理完成: {result}")
            return result

        except Exception as e:
            logger.error(f"文档处理失败 ID: {document_id}: {str(e)}")

            # Best-effort: record the failure on the document row
            try:
                with next(get_session()) as session:
                    document = session.query(DocumentModel).filter(DocumentModel.id == document_id).first()
                    if document:
                        document.status = "failed"
                        document.error_message = str(e)
                        session.commit()
            except Exception as db_error:
                logger.error(f"更新文档状态失败: {str(db_error)}")

            return {
                "document_id": document_id,
                "status": "failed",
                "error": str(e),
                "message": "文档处理失败"
            }
|
||||
|
||||
def _get_document_ids_from_vector_store(self, knowledge_base_id: int, document_id: int) -> List[str]:
|
||||
"""查询指定document_id的所有向量记录的uuid"""
|
||||
try:
|
||||
collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
|
||||
|
||||
# 使用连接池执行查询
|
||||
if self.pgvector_pool:
|
||||
query = f"""
|
||||
SELECT uuid FROM langchain_pg_embedding
|
||||
WHERE collection_id = (
|
||||
SELECT uuid FROM langchain_pg_collection
|
||||
WHERE name = %s
|
||||
) AND cmetadata->>'document_id' = %s
|
||||
"""
|
||||
|
||||
result = self.pgvector_pool.execute_query(query, (collection_name, str(document_id)))
|
||||
return [row[0] for row in result] if result else []
|
||||
else:
|
||||
logger.warning("PGVector连接池未初始化")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询文档向量记录失败: {str(e)}")
|
||||
return []
|
||||
|
||||
    def delete_document_from_vector_store(self, knowledge_base_id: int, document_id: int) -> None:
        """Delete all of a document's chunks from the vector store.

        pgvector: looks up matching embedding ids by metadata and deletes
        them through the PGVector API. Chroma: currently a no-op besides
        logging (see NOTE below). Re-raises only on unexpected outer errors.
        """
        try:
            if settings.vector_db.type == "pgvector":
                # One collection per knowledge base
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                try:
                    # Connect to the collection
                    vector_store = PGVector(
                        connection=self.connection_string,
                        embeddings=self.embeddings,
                        collection_name=collection_name,
                        use_jsonb=True
                    )

                    # Query the ids of the chunks to delete directly via SQL
                    try:
                        from sqlalchemy import text
                        from sqlalchemy.orm import Session

                        # NOTE(review): ``_engine`` is a private attribute of
                        # langchain_postgres.PGVector — may break across
                        # library upgrades; confirm the pinned version.
                        engine = vector_store._engine

                        with Session(engine) as session:
                            # All embedding rows whose metadata matches the id
                            query_sql = text(
                                f"SELECT id FROM langchain_pg_embedding "
                                f"WHERE cmetadata->>'document_id' = :doc_id"
                            )
                            result = session.execute(query_sql, {"doc_id": str(document_id)})
                            ids_to_delete = [row[0] for row in result.fetchall()]

                        if ids_to_delete:
                            # Delete through the PGVector API by id
                            vector_store.delete(ids=ids_to_delete)
                            logger.info(f"成功删除 {len(ids_to_delete)} 个文档块: document_id={document_id}")
                        else:
                            logger.warning(f"未找到要删除的文档ID: document_id={document_id}")

                    except Exception as query_error:
                        logger.error(f"查询要删除的文档时出错: {query_error}")
                        # Query failure likely means the document is absent
                        logger.warning(f"无法查询到要删除的文档: document_id={document_id}")
                        return

                    logger.info(f"文档已从PostgreSQL pgvector存储中删除: document_id={document_id}")
                except Exception as e:
                    logger.warning(f"PostgreSQL pgvector存储不存在或删除失败: {collection_name}, {str(e)}")
            else:
                # Chroma-compatible mode; lazy import keeps pgvector deploys light
                from langchain_community.vectorstores import Chroma
                kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

                if not os.path.exists(kb_vector_path):
                    logger.warning(f"向量存储不存在: {kb_vector_path}")
                    return

                # Load the store
                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                # NOTE(review): no deletion is actually performed in the
                # Chroma branch — the success log below is misleading until
                # a real delete (e.g. via document_id metadata filter) is
                # implemented.
                logger.info(f"文档已从向量存储中删除: document_id={document_id}")

        except Exception as e:
            logger.error(f"从向量存储删除文档失败: {str(e)}")
            raise
|
||||
|
||||
def get_document_chunks(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
    """Return every stored chunk of a document, ordered by chunk_index.

    Retrieval strategy:
    - pgvector backend: try a direct SQL lookup first (no embedding API
      call), then fall back to the improved LangChain retrieval.
    - Chroma backend: delegate to the dedicated Chroma path.

    Returns an empty list on any failure.
    """
    try:
        if settings.vector_db.type != "pgvector":
            # Chroma compatibility mode has its own retrieval path.
            return self._get_chunks_chroma(knowledge_base_id, document_id)

        # PostgreSQL pgvector: prefer direct SQL over similarity search.
        collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"
        try:
            sql_chunks = self._get_chunks_by_sql(knowledge_base_id, document_id)
            if sql_chunks:
                return sql_chunks

            # Empty SQL result -> fall back to the LangChain-based method.
            logger.info("SQL查询失败,使用LangChain回退方案")
            return self._get_chunks_by_langchain_improved(knowledge_base_id, document_id, collection_name)

        except Exception as e:
            logger.warning(f"PostgreSQL pgvector存储访问失败: {collection_name}, {str(e)}")
            return []

    except Exception as e:
        logger.error(f"获取文档分段失败 document_id: {document_id}, kb_id: {knowledge_base_id}: {str(e)}")
        return []
|
||||
|
||||
def _get_chunks_by_sql(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
    """Fetch a document's chunks with a direct pooled SQL query (preferred).

    Reads the ``langchain_pg_embedding`` table directly, so no embedding /
    similarity-search round trip is needed. Results are ordered by the
    numeric ``chunk_index`` stored in the JSONB metadata.

    Returns:
        List of chunk dicts (id, content, metadata, page_number,
        chunk_index, start_char, end_char); [] on any failure.
    """
    try:
        if not self.pgvector_pool:
            # The pool is created elsewhere; without it we cannot query.
            logger.error("PGVector连接池未初始化")
            return []

        # Direct SQL keeps this cheap: no similarity search, no embedding call.
        # NOTE(review): assumes ingestion stores string-typed 'document_id',
        # 'knowledge_base_id' and 'chunk_index' keys in cmetadata — confirm
        # against the document ingestion path.
        query = """
            SELECT
                id,
                document,
                cmetadata
            FROM langchain_pg_embedding
            WHERE cmetadata->>'document_id' = :document_id
            AND cmetadata->>'knowledge_base_id' = :knowledge_base_id
            ORDER BY
                CAST(cmetadata->>'chunk_index' AS INTEGER) ASC;
        """

        # Borrow a session from the pool; released in the finally block.
        # NOTE(review): relies on a module-level `text` import from
        # sqlalchemy — confirm it exists at the top of this file.
        session = self.pgvector_pool.get_session()
        try:
            result = session.execute(
                text(query),
                {
                    'document_id': str(document_id),
                    'knowledge_base_id': str(knowledge_base_id)
                }
            )
            results = result.fetchall()

            chunks = []
            for row in results:
                # SQLAlchemy row attribute access; cmetadata is JSONB -> dict.
                metadata = row.cmetadata
                chunk = {
                    "id": f"chunk_{document_id}_{metadata.get('chunk_index', 0)}",
                    "content": row.document,
                    "metadata": metadata,
                    "page_number": metadata.get("page"),
                    "chunk_index": metadata.get("chunk_index", 0),
                    "start_char": metadata.get("start_char"),
                    "end_char": metadata.get("end_char")
                }
                chunks.append(chunk)

            logger.info(f"通过SQLAlchemy连接池查询获取到文档 {document_id} 的 {len(chunks)} 个分段")
            return chunks

        finally:
            session.close()

    except Exception as e:
        logger.error(f"SQLAlchemy连接池查询失败: {e}")
        return []
|
||||
|
||||
def _get_chunks_by_langchain_improved(self, knowledge_base_id: int, document_id: int, collection_name: str) -> List[Dict[str, Any]]:
    """Fallback chunk retrieval through LangChain similarity search.

    Used when the direct SQL path yields nothing. Avoids empty-string
    queries (which can trip embedding APIs) by probing with a generic
    query first, then re-querying with a snippet of real document content.

    Returns:
        Chunk dicts sorted by chunk_index; [] on failure.
    """
    try:
        vector_store = PGVector(
            connection=self.connection_string,
            embeddings=self.embeddings,
            collection_name=collection_name,
            use_jsonb=True
        )

        # Use a meaningful query instead of an empty one so the embedding
        # API is never called with an empty string.
        try:
            # Probe: fetch a few chunks belonging to this document.
            sample_results = vector_store.similarity_search(
                query="文档内容",  # generic probe term, never an empty string
                k=5,
                filter={"document_id": {"$eq": str(document_id)}}
            )

            if sample_results:
                # Re-query using a snippet of actual content for recall;
                # k=1000 aims to pull every chunk of the document.
                first_content = sample_results[0].page_content[:50]
                results = vector_store.similarity_search(
                    query=first_content,
                    k=1000,
                    filter={"document_id": {"$eq": str(document_id)}}
                )
            else:
                # No filtered hits: search broadly without the filter.
                results = vector_store.similarity_search(
                    query="文档",
                    k=1000
                )
                # Then filter client-side on document_id.
                results = [doc for doc in results if doc.metadata.get("document_id") == str(document_id)]

        except Exception as e:
            logger.warning(f"改进的相似性搜索失败: {e}")
            return []

        chunks = []
        for i, doc in enumerate(results):
            chunk = {
                "id": f"chunk_{document_id}_{i}",
                "content": doc.page_content,
                "metadata": doc.metadata,
                "page_number": doc.metadata.get("page"),
                # Fall back to enumeration order when metadata lacks an index.
                "chunk_index": doc.metadata.get("chunk_index", i),
                "start_char": doc.metadata.get("start_char"),
                "end_char": doc.metadata.get("end_char")
            }
            chunks.append(chunk)

        # Present chunks in document order regardless of similarity ranking.
        chunks.sort(key=lambda x: x.get("chunk_index", 0))

        logger.info(f"通过改进的LangChain方法获取到文档 {document_id} 的 {len(chunks)} 个分段")
        return chunks

    except Exception as e:
        logger.error(f"LangChain改进方法失败: {e}")
        return []
|
||||
|
||||
def _get_chunks_chroma(self, knowledge_base_id: int, document_id: int) -> List[Dict[str, Any]]:
    """Collect a document's chunks from an on-disk Chroma store.

    Loads the per-knowledge-base Chroma collection, pulls every record's
    metadata and content, and keeps only entries whose metadata
    document_id matches. Returns [] when the store is missing or on error.
    """
    try:
        from langchain_community.vectorstores import Chroma

        # Path layout: <vector_db_path>/kb_<knowledge_base_id>
        vector_db_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

        if not os.path.exists(vector_db_path):
            logger.warning(f"向量数据库不存在: {vector_db_path}")
            return []

        vectorstore = Chroma(
            persist_directory=vector_db_path,
            embedding_function=self.embeddings
        )

        # Fetch all metadata/content pairs, then filter client-side.
        collection = vectorstore._collection
        all_docs = collection.get(include=["metadatas", "documents"])

        chunks = []
        target_id = str(document_id)
        for meta, content in zip(all_docs["metadatas"], all_docs["documents"]):
            if meta.get("document_id") != target_id:
                continue
            # Sequential position of this chunk within the document.
            position = len(chunks)
            chunks.append({
                "id": f"chunk_{document_id}_{position}",
                "content": content,
                "metadata": meta,
                "page_number": meta.get("page"),
                "chunk_index": position,
                "start_char": meta.get("start_char"),
                "end_char": meta.get("end_char")
            })

        logger.info(f"获取到文档 {document_id} 的 {len(chunks)} 个分段")
        return chunks

    except Exception as e:
        logger.error(f"Chroma存储处理失败: {e}")
        return []
|
||||
|
||||
def _format_search_results(self, scored_docs) -> List[Dict[str, Any]]:
    """Turn (Document, distance) pairs into result dicts sorted by distance.

    Both backends return a distance where smaller means more similar
    (cosine distance for pgvector, Euclidean for Chroma), so the same
    normalization 1 / (1 + distance) maps it into a (0, 1] score.
    """
    formatted_results = []
    for doc, distance_score in scored_docs:
        # Convert distance to a 0-1 similarity score.
        similarity_score = 1.0 / (1.0 + distance_score)

        formatted_results.append({
            "content": doc.page_content,
            "metadata": doc.metadata,
            "similarity_score": distance_score,  # raw distance, kept for callers
            "normalized_score": similarity_score,  # normalized similarity score
            "source": doc.metadata.get('filename', 'unknown'),
            "document_id": doc.metadata.get('document_id', 'unknown'),
            "chunk_id": doc.metadata.get('chunk_id', 'unknown')
        })

    # Ascending distance == most similar first.
    formatted_results.sort(key=lambda x: x['similarity_score'])
    return formatted_results

def search_similar_documents(self, knowledge_base_id: int, query: str, k: int = 5) -> List[Dict[str, Any]]:
    """Search a knowledge base for the k chunks most similar to *query*.

    Args:
        knowledge_base_id: Knowledge base to search.
        query: Natural-language search query.
        k: Maximum number of results. Defaults to 5.

    Returns:
        Result dicts (content, metadata, raw and normalized scores,
        source, document_id, chunk_id) sorted most-similar-first;
        [] on any failure.
    """
    try:
        if settings.vector_db.type == "pgvector":
            # PostgreSQL pgvector backend
            collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

            try:
                vector_store = PGVector(
                    connection=self.connection_string,
                    embeddings=self.embeddings,
                    collection_name=collection_name,
                    use_jsonb=True
                )

                results = vector_store.similarity_search_with_score(query, k=k)
                # Shared formatting/sorting for both backends (was duplicated).
                formatted_results = self._format_search_results(results)

                logger.info(f"PostgreSQL pgvector搜索完成,找到 {len(formatted_results)} 个相关文档")
                return formatted_results

            except Exception as e:
                logger.warning(f"PostgreSQL pgvector存储不存在: {collection_name}, {str(e)}")
                return []
        else:
            # Chroma compatibility mode
            from langchain_community.vectorstores import Chroma
            kb_vector_path = os.path.join(self.vector_db_path, f"kb_{knowledge_base_id}")

            if not os.path.exists(kb_vector_path):
                logger.warning(f"向量存储不存在: {kb_vector_path}")
                return []

            vector_store = Chroma(
                persist_directory=kb_vector_path,
                embedding_function=self.embeddings
            )

            results = vector_store.similarity_search_with_score(query, k=k)
            formatted_results = self._format_search_results(results)

            logger.info(f"搜索完成,找到 {len(formatted_results)} 个相关文档")
            return formatted_results

    except Exception as e:
        logger.error(f"搜索文档失败: {str(e)}")
        return []  # return an empty list instead of raising
|
||||
|
||||
|
||||
# Global document-processor instance (lazily initialized).
document_processor = None

def get_document_processor():
    """Return the shared DocumentProcessor, creating it on first use.

    Lazy initialization avoids paying DocumentProcessor's construction cost
    (and any connections it opens) at import time.
    NOTE(review): not thread-safe — concurrent first calls could build two
    instances; confirm whether callers only hit this from one thread.
    """
    global document_processor
    if document_processor is None:
        document_processor = DocumentProcessor()
    return document_processor
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
"""Embedding factory for different providers."""
|
||||
|
||||
from typing import Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
|
||||
from ..core.config import settings
|
||||
from loguru import logger
|
||||
|
||||
class EmbeddingFactory:
    """Factory class for creating embedding instances based on provider."""

    @staticmethod
    def create_embeddings(
        provider: Optional[str] = None,
        model: Optional[str] = None,
        dimensions: Optional[int] = None
    ) -> Embeddings:
        """Create an embeddings instance for the requested provider.

        Args:
            provider: Embedding provider (openai, zhipu, deepseek, doubao,
                moonshot, sentence-transformers). Defaults to the configured
                provider.
            model: Model name; defaults to the provider's configured model.
            dimensions: Embedding dimensions; defaults to the configured
                vector-db dimension.

        Returns:
            Embeddings instance.

        Raises:
            ValueError: If the provider is not supported.
        """
        # Resolve defaults from the active embedding configuration.
        embedding_config = settings.embedding.get_current_config()
        provider = provider or settings.embedding.provider
        model = model or embedding_config.get("model")
        dimensions = dimensions or settings.vector_db.embedding_dimension

        logger.info(f"Creating embeddings with provider: {provider}, model: {model}")

        if provider == "openai":
            return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)
        elif provider in ["zhipu", "deepseek", "doubao", "moonshot"]:
            return EmbeddingFactory._create_openai_compatible_embeddings(embedding_config, model, dimensions, provider)
        elif provider == "sentence-transformers":
            return EmbeddingFactory._create_huggingface_embeddings(model)
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")

    @staticmethod
    def _create_openai_embeddings(embedding_config: dict, model: str, dimensions: int) -> OpenAIEmbeddings:
        """Create OpenAI embeddings, defaulting unknown models to ada-002."""
        # ROBUSTNESS FIX: `model` may be None when the config has no "model"
        # key (create_embeddings uses .get); calling .startswith on None
        # raised AttributeError before. Guard and fall back to the default.
        use_model = model if model and model.startswith("text-embedding") else "text-embedding-ada-002"
        return OpenAIEmbeddings(
            api_key=embedding_config["api_key"],
            base_url=embedding_config["base_url"],
            model=use_model,
            # Only the text-embedding-3 family accepts a dimensions argument.
            dimensions=dimensions if use_model.startswith("text-embedding-3") else None
        )

    @staticmethod
    def _create_openai_compatible_embeddings(embedding_config: dict, model: str, dimensions: int, provider: str) -> Embeddings:
        """Create OpenAI-compatible embeddings for ZhipuAI, DeepSeek, Doubao, Moonshot."""
        if provider == "zhipu":
            # ROBUSTNESS FIX: same None-model guard as the OpenAI path.
            return ZhipuOpenAIEmbeddings(
                api_key=embedding_config["api_key"],
                base_url=embedding_config["base_url"],
                model=model if model and model.startswith("embedding") else "embedding-3",
                dimensions=dimensions
            )
        else:
            # NOTE(review): model may still be None for these providers; it
            # is passed through unchanged — confirm each provider's config
            # always supplies a model name.
            return OpenAIEmbeddings(
                api_key=embedding_config["api_key"],
                base_url=embedding_config["base_url"],
                model=model,
                dimensions=dimensions
            )

    @staticmethod
    def _create_huggingface_embeddings(model: str) -> HuggingFaceEmbeddings:
        """Create HuggingFace sentence-transformers embeddings on CPU."""
        return HuggingFaceEmbeddings(
            model_name=model,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
"""Excel metadata extraction service."""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.excel_file import ExcelFile
|
||||
from ..db.database import get_session
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ExcelMetadataService:
    """Service for extracting and managing Excel file metadata."""

    def __init__(self, session: Session):
        """Store the database session used by all service methods.

        BUGFIX: every method in this class accesses ``self.db``, but only
        ``self.session`` was assigned, so any database call raised
        AttributeError. Both names are kept so existing code touching
        either attribute keeps working.
        """
        self.session = session
        self.db = session

    def extract_file_metadata(self, file_path: str, original_filename: str,
                              user_id: int, file_size: int) -> Dict[str, Any]:
        """Extract sheet/column/preview metadata from an Excel or CSV file.

        Args:
            file_path: Location of the file on disk.
            original_filename: Name as uploaded; used to detect the type.
            user_id: Uploading user's id (unused here; kept for interface
                compatibility).
            file_size: Size in bytes (unused here; kept for interface
                compatibility).

        Returns:
            Dict with sheet names, default sheet, per-sheet columns,
            preview rows, dtypes, row/column counts and processing status.
            On failure the collections are empty, ``is_processed`` is False
            and ``processing_error`` carries the message.
        """
        try:
            # Determine file type from the original (user-visible) name.
            file_extension = os.path.splitext(original_filename)[1].lower()

            if file_extension == '.csv':
                # CSV has no sheet concept; model it as a single sheet.
                df = pd.read_csv(file_path)
                sheets_data = {'Sheet1': df}
            else:
                # sheet_name=None loads every sheet into a dict of frames.
                sheets_data = pd.read_excel(file_path, sheet_name=None)

            sheet_names = list(sheets_data.keys())
            columns_info = {}
            preview_data = {}
            data_types = {}
            total_rows = {}
            total_columns = {}

            for sheet_name, df in sheets_data.items():
                # Drop pandas' auto-generated "Unnamed: N" filler columns.
                # NOTE(review): assumes string column labels; purely numeric
                # headers would make .str.contains raise — confirm inputs.
                df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

                # Column labels, coerced to str for JSON safety.
                columns_info[sheet_name] = [str(col) if not isinstance(col, str) else col for col in df.columns.tolist()]

                # Preview: first 5 rows, converted to JSON-serializable values.
                preview_df = df.head(5)
                preview_values = []
                for row in preview_df.values:
                    string_row = []
                    for value in row:
                        if pd.isna(value):
                            string_row.append(None)
                        elif hasattr(value, 'strftime'):
                            # datetime/timestamp objects -> fixed text format
                            string_row.append(value.strftime('%Y-%m-%d %H:%M:%S'))
                        elif isinstance(value, str):
                            # Preserve unicode (e.g. Chinese) text verbatim.
                            string_row.append(value)
                        else:
                            string_row.append(str(value))
                    preview_values.append(string_row)
                preview_data[sheet_name] = preview_values

                # Column dtypes, rendered as strings.
                data_types[sheet_name] = {col: str(dtype) for col, dtype in df.dtypes.items()}

                # Basic shape statistics.
                total_rows[sheet_name] = len(df)
                total_columns[sheet_name] = len(df.columns)

            # First sheet is treated as the default.
            default_sheet = sheet_names[0] if sheet_names else None

            return {
                'sheet_names': sheet_names,
                'default_sheet': default_sheet,
                'columns_info': columns_info,
                'preview_data': preview_data,
                'data_types': data_types,
                'total_rows': total_rows,
                'total_columns': total_columns,
                'is_processed': True,
                'processing_error': None
            }

        except Exception as e:
            logger.error(f"Error extracting metadata from {file_path}: {str(e)}")
            return {
                'sheet_names': [],
                'default_sheet': None,
                'columns_info': {},
                'preview_data': {},
                'data_types': {},
                'total_rows': {},
                'total_columns': {},
                'is_processed': False,
                'processing_error': str(e)
            }

    def save_file_metadata(self, file_path: str, original_filename: str,
                           user_id: int, file_size: int) -> ExcelFile:
        """Extract metadata and persist an ExcelFile record.

        Raises:
            Exception: Re-raised after rollback if the database save fails.
        """
        try:
            # Extract metadata (never raises; failure is encoded in the dict).
            metadata = self.extract_file_metadata(file_path, original_filename, user_id, file_size)

            # Determine file type from the original name.
            file_extension = os.path.splitext(original_filename)[1].lower()

            excel_file = ExcelFile(
                original_filename=original_filename,
                file_path=file_path,
                file_size=file_size,
                file_type=file_extension,
                sheet_names=metadata['sheet_names'],
                default_sheet=metadata['default_sheet'],
                columns_info=metadata['columns_info'],
                preview_data=metadata['preview_data'],
                data_types=metadata['data_types'],
                total_rows=metadata['total_rows'],
                total_columns=metadata['total_columns'],
                is_processed=metadata['is_processed'],
                processing_error=metadata['processing_error']
            )

            self.db.add(excel_file)
            self.db.commit()
            self.db.refresh(excel_file)

            logger.info(f"Saved metadata for file {original_filename} with ID {excel_file.id}")
            return excel_file

        except Exception as e:
            logger.error(f"Error saving metadata for {original_filename}: {str(e)}")
            self.db.rollback()
            raise

    def get_user_files(self, user_id: int, skip: int = 0, limit: int = 50) -> Tuple[List[ExcelFile], int]:
        """Return a page of the user's files plus the total count.

        Returns ([], 0) on error.
        """
        try:
            total = self.db.query(ExcelFile).filter(ExcelFile.created_by == user_id).count()

            # Newest first, paginated.
            files = (self.db.query(ExcelFile)
                     .filter(ExcelFile.created_by == user_id)
                     .order_by(ExcelFile.created_at.desc())
                     .offset(skip)
                     .limit(limit)
                     .all())

            return files, total

        except Exception as e:
            logger.error(f"Error getting user files for user {user_id}: {str(e)}")
            return [], 0

    def get_file_by_id(self, file_id: int, user_id: int) -> Optional[ExcelFile]:
        """Return the file with *file_id* owned by *user_id*, or None."""
        try:
            return (self.db.query(ExcelFile)
                    .filter(ExcelFile.id == file_id, ExcelFile.created_by == user_id)
                    .first())
        except Exception as e:
            logger.error(f"Error getting file {file_id} for user {user_id}: {str(e)}")
            return None

    def delete_file(self, file_id: int, user_id: int) -> bool:
        """Delete a file's database record and its on-disk file.

        Returns True on success, False if the record is missing or the
        deletion fails (transaction rolled back).
        """
        try:
            excel_file = self.get_file_by_id(file_id, user_id)
            if not excel_file:
                return False

            # Remove the physical file first, if it still exists.
            if os.path.exists(excel_file.file_path):
                os.remove(excel_file.file_path)
                logger.info(f"Deleted physical file: {excel_file.file_path}")

            self.db.delete(excel_file)
            self.db.commit()

            logger.info(f"Deleted Excel file record with ID {file_id}")
            return True

        except Exception as e:
            logger.error(f"Error deleting file {file_id}: {str(e)}")
            self.db.rollback()
            return False

    def update_last_accessed(self, file_id: int, user_id: int) -> bool:
        """Stamp the file's last_accessed time with the database's now()."""
        try:
            excel_file = self.get_file_by_id(file_id, user_id)
            if not excel_file:
                return False

            from sqlalchemy.sql import func
            excel_file.last_accessed = func.now()
            self.db.commit()

            return True

        except Exception as e:
            logger.error(f"Error updating last accessed for file {file_id}: {str(e)}")
            self.db.rollback()
            return False

    def get_file_summary_for_llm(self, user_id: int) -> List[Dict[str, Any]]:
        """Summarize the user's files for use as LLM context.

        Returns [] on error.
        """
        try:
            # CONSISTENCY FIX: every other query in this service filters on
            # ExcelFile.created_by; the original filtered on ExcelFile.user_id,
            # a field not used anywhere else in this class.
            files = self.db.query(ExcelFile).filter(ExcelFile.created_by == user_id).all()

            summary = []
            for file in files:
                file_info = {
                    'file_id': file.id,
                    'filename': file.original_filename,
                    'file_type': file.file_type,
                    'sheets': file.get_all_sheets_summary(),
                    # NOTE(review): assumes the model defines upload_time;
                    # other methods reference created_at — confirm field name.
                    'upload_time': file.upload_time.isoformat() if file.upload_time else None
                }
                summary.append(file_info)

            return summary

        except Exception as e:
            logger.error(f"Error getting file summary for user {user_id}: {str(e)}")
            return []
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
"""Knowledge base service."""
|
||||
|
||||
# Standard library imports
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
# Third-party imports
|
||||
from loguru import logger
|
||||
from sqlalchemy import select, and_, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Local imports
|
||||
from ..core.config import get_settings
|
||||
from ..core.context import UserContext
|
||||
from ..models.knowledge_base import KnowledgeBase
|
||||
from .document_processor import get_document_processor
|
||||
from utils.util_schemas import KnowledgeBaseCreate, KnowledgeBaseUpdate
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
class KnowledgeBaseService:
    """Service for managing knowledge bases.

    Provides create / read / update / delete and search operations over
    KnowledgeBase records, plus vector-similarity search within a single
    knowledge base.
    """

    def __init__(self, session: Session):
        """Initialize the service.

        Args:
            session (Session): Database session used for ORM operations.
        """
        self.session = session

    def create_knowledge_base(self, kb_data: KnowledgeBaseCreate) -> KnowledgeBase:
        """Create a new knowledge base.

        Args:
            kb_data (KnowledgeBaseCreate): Data for the new knowledge base.

        Returns:
            KnowledgeBase: The created instance.

        Raises:
            Exception: If creation fails (the transaction is rolled back).
        """
        try:
            # Derive a vector-database collection name from the KB name.
            collection_name = f"kb_{kb_data.name.lower().replace(' ', '_').replace('-', '_')}"

            kb = KnowledgeBase(
                name=kb_data.name,
                description=kb_data.description,
                embedding_model=kb_data.embedding_model,
                chunk_size=kb_data.chunk_size,
                chunk_overlap=kb_data.chunk_overlap,
                vector_db_type=settings.vector_db.type,
                collection_name=collection_name
            )

            # Stamp audit fields (creator, timestamps).
            kb.set_audit_fields()

            self.session.add(kb)
            self.session.commit()
            self.session.refresh(kb)

            logger.info(f"Created knowledge base: {kb.name} (ID: {kb.id})")
            return kb

        except Exception as e:
            self.session.rollback()
            logger.error(f"Failed to create knowledge base: {str(e)}")
            raise

    def get_knowledge_base(self, kb_id: int) -> Optional[KnowledgeBase]:
        """Return the knowledge base with *kb_id*, or None if not found."""
        stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
        return self.session.execute(stmt).scalar_one_or_none()

    def get_knowledge_base_by_name(self, name: str) -> Optional[KnowledgeBase]:
        """Return the current user's knowledge base named *name*, or None."""
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.name == name,
            KnowledgeBase.created_by == UserContext.get_current_user().id
        )
        return self.session.execute(stmt).scalar_one_or_none()

    async def get_knowledge_bases(self, skip: int = 0, limit: int = 50, active_only: bool = True) -> List[KnowledgeBase]:
        """List the current user's knowledge bases.

        Args:
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum records to return. Defaults to 50.
            active_only (bool, optional): Only return active knowledge bases.
                Defaults to True.

        Returns:
            List[KnowledgeBase]: The current user's knowledge bases.
        """
        stmt = select(KnowledgeBase).where(KnowledgeBase.created_by == UserContext.get_current_user().id)

        if active_only:
            stmt = stmt.where(KnowledgeBase.is_active == True)

        stmt = stmt.offset(skip).limit(limit)
        # BUGFIX: self.session is a synchronous Session (see __init__ and
        # every other method here), so its execute() result must not be
        # awaited — the stray `await` raised a TypeError at runtime. The
        # `async def` signature is kept so callers that await still work.
        return self.session.execute(stmt).scalars().all()

    def update_knowledge_base(self, kb_id: int, kb_update: KnowledgeBaseUpdate) -> Optional[KnowledgeBase]:
        """Update a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to update.
            kb_update (KnowledgeBaseUpdate): Fields to update.

        Returns:
            Optional[KnowledgeBase]: The updated instance, or None if not
            found.

        Raises:
            Exception: If the update fails (transaction rolled back).
        """
        try:
            kb = self.get_knowledge_base(kb_id)
            if not kb:
                return None

            # Apply only the fields the caller explicitly set.
            update_data = kb_update.model_dump(exclude_unset=True)
            for field, value in update_data.items():
                setattr(kb, field, value)

            # Stamp audit fields for the update.
            kb.set_audit_fields(is_update=True)

            self.session.commit()
            self.session.refresh(kb)

            # NOTE(review): writing a description onto the Session object is
            # unusual — presumably consumed by audit middleware; confirm.
            self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb.name} (ID: {kb.id})"
            return kb

        except Exception as e:
            self.session.rollback()
            self.session.desc = f"[KNOWLEDGE_BASE] 更新知识库 {kb_id} 失败: {str(e)}"
            raise

    def delete_knowledge_base(self, kb_id: int) -> bool:
        """Delete a knowledge base.

        Args:
            kb_id (int): ID of the knowledge base to delete.

        Returns:
            bool: True if deleted, False if not found.
        """
        kb = self.get_knowledge_base(kb_id)
        if not kb:
            return False

        # TODO: Clean up vector database collection
        # This should be implemented when vector database service is ready

        self.session.delete(kb)
        self.session.commit()

        return True

    def search_knowledge_bases(self, query: str, skip: int = 0, limit: int = 50) -> List[KnowledgeBase]:
        """Search knowledge bases by name or description for the current user.

        Args:
            query (str): Search query.
            skip (int, optional): Number of records to skip. Defaults to 0.
            limit (int, optional): Maximum number of records to return. Defaults to 50.

        Returns:
            List[KnowledgeBase]: List of matching knowledge bases.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.created_by == UserContext.get_current_user().id,
            KnowledgeBase.is_active == True,
            # Case-insensitive substring match on name OR description.
            or_(
                KnowledgeBase.name.ilike(f"%{query}%"),
                KnowledgeBase.description.ilike(f"%{query}%")
            )
        )

        stmt = stmt.offset(skip).limit(limit)
        return self.session.execute(stmt).scalars().all()

    async def search(self, kb_id: int, query: str, top_k: int = 5, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """Search in knowledge base using vector similarity.

        Args:
            kb_id (int): ID of the knowledge base to search in.
            query (str): Search query.
            top_k (int, optional): Maximum number of results to return. Defaults to 5.
            similarity_threshold (float, optional): Minimum similarity score for results. Defaults to 0.7.

        Returns:
            List[Dict[str, Any]]: List of search results with content, source, score, and metadata.
        """
        try:
            logger.info(f"Searching in knowledge base {kb_id} for: {query}")

            # Delegate vector search to the document processor.
            search_results = get_document_processor().search_similar_documents(
                knowledge_base_id=kb_id,
                query=query,
                k=top_k
            )

            # Keep only hits at or above the similarity threshold.
            filtered_results = []
            for result in search_results:
                # Scores are already normalized by the document processor.
                normalized_score = result.get('normalized_score', 0)

                if normalized_score >= similarity_threshold:
                    filtered_results.append({
                        "content": result.get('content', ''),
                        "source": result.get('source', 'unknown'),
                        "score": normalized_score,
                        "metadata": result.get('metadata', {}),
                        "document_id": result.get('document_id', 'unknown'),
                        "chunk_id": result.get('chunk_id', 'unknown')
                    })

            logger.info(f"Found {len(filtered_results)} relevant documents (threshold: {similarity_threshold})")
            return filtered_results

        except Exception as e:
            logger.error(f"Search failed for knowledge base {kb_id}: {str(e)}")
            return []
|
||||
|
|
@ -0,0 +1,369 @@
|
|||
"""Knowledge base chat service using LangChain RAG."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_postgres import PGVector
|
||||
from .embedding_factory import EmbeddingFactory
|
||||
|
||||
from ..core.config import settings
|
||||
from ..models.message import MessageRole
|
||||
from utils.util_schemas import ChatResponse, MessageResponse
|
||||
from utils.util_exceptions import ChatServiceError
|
||||
from .conversation import ConversationService
|
||||
from .document_processor import get_document_processor
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class KnowledgeChatService:
|
||||
"""Knowledge base chat service using LangChain RAG."""
|
||||
|
||||
def __init__(self, db: Session):
    """Set up the chat service: LLM clients, embeddings, conversations.

    Args:
        db (Session): Database session shared with the conversation service.
    """
    self.db = db
    self.conversation_service = ConversationService(db)

    # Resolve the active LLM provider configuration once.
    llm_config = settings.llm.get_current_config()

    # Both clients share every setting except streaming mode; build the
    # kwargs once instead of duplicating the five-field block (DRY fix).
    common_kwargs = dict(
        model=llm_config["model"],
        api_key=llm_config["api_key"],
        base_url=llm_config["base_url"],
        temperature=llm_config["temperature"],
        max_tokens=llm_config["max_tokens"],
    )

    # Non-streaming client for plain request/response calls.
    self.llm = ChatOpenAI(streaming=False, **common_kwargs)

    # Streaming client for incremental (token-by-token) responses.
    self.streaming_llm = ChatOpenAI(streaming=True, **common_kwargs)

    # Embeddings are provider-specific; delegated to the factory.
    self.embeddings = EmbeddingFactory.create_embeddings()

    logger.info(f"Knowledge Chat Service initialized with model: {self.llm.model_name}")
|
||||
|
||||
    def _get_vector_store(self, knowledge_base_id: int) -> Optional[PGVector]:
        """Load the vector store backing a knowledge base, or None if unavailable.

        The backend is chosen from settings: PGVector (one collection per KB)
        or a legacy per-KB Chroma directory on disk.

        NOTE(review): the return annotation says PGVector, but the legacy
        branch returns a Chroma instance — the effective type is
        Optional[VectorStore]; confirm and widen the annotation.
        """
        try:
            if settings.vector_db.type == "pgvector":
                # PGVector backend: collection name derived from the KB id.
                doc_processor = get_document_processor()
                collection_name = f"{settings.vector_db.pgvector_table_name}_kb_{knowledge_base_id}"

                vector_store = PGVector(
                    connection=doc_processor.connection_string,
                    embeddings=self.embeddings,
                    collection_name=collection_name,
                    use_jsonb=True
                )

                return vector_store
            else:
                # Legacy Chroma backend: vectors persisted on disk per KB.
                import os
                kb_vector_path = os.path.join(get_document_processor().vector_db_path, f"kb_{knowledge_base_id}")

                if not os.path.exists(kb_vector_path):
                    logger.warning(f"Vector store not found for knowledge base {knowledge_base_id}")
                    return None

                vector_store = Chroma(
                    persist_directory=kb_vector_path,
                    embedding_function=self.embeddings
                )

                return vector_store

        except Exception as e:
            # Degrade to "not found"; callers treat None as missing/unprocessed.
            logger.error(f"Failed to load vector store for KB {knowledge_base_id}: {str(e)}")
            return None
    def _create_rag_chain(self, vector_store, conversation_history: List[Dict[str, str]]):
        """Build a RAG pipeline (retrieve -> prompt -> LLM -> str) plus the retriever.

        Args:
            vector_store: LangChain vector store for the target knowledge base.
            conversation_history: Prior turns as {"role", "content"} dicts,
                injected verbatim into the chat_history placeholder.

        Returns:
            Tuple of (rag_chain, retriever); the chain is invoked with the raw
            question string.

        NOTE(review): chat_history is fed as plain dicts with roles
        "human"/"assistant" — confirm the pinned langchain-core coerces these
        into message objects (the streaming path converts them explicitly).
        """
        # Create retriever (top-5 similarity search).
        retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5}
        )

        # System prompt: answer strictly from retrieved context, admit gaps.
        system_prompt = """你是一个智能助手,基于提供的上下文信息回答用户问题。

上下文信息:
{context}

请根据上下文信息回答用户的问题。如果上下文信息不足以回答问题,请诚实地说明。
保持回答准确、有用且简洁。"""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}")
        ])

        # Flatten retrieved documents into one context string.
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        # LCEL pipeline: retrieve context, fill the prompt, call the LLM,
        # parse the reply down to a plain string.
        rag_chain = (
            {
                "context": retriever | format_docs,
                "question": RunnablePassthrough(),
                "chat_history": lambda x: conversation_history
            }
            | prompt
            | self.llm
            | StrOutputParser()
        )

        return rag_chain, retriever
def _prepare_conversation_history(self, messages: List) -> List[Dict[str, str]]:
|
||||
"""Prepare conversation history for RAG chain."""
|
||||
history = []
|
||||
|
||||
for msg in messages[:-1]: # Exclude the last message (current user message)
|
||||
if msg.role == MessageRole.USER:
|
||||
history.append({"role": "human", "content": msg.content})
|
||||
elif msg.role == MessageRole.ASSISTANT:
|
||||
history.append({"role": "assistant", "content": msg.content})
|
||||
|
||||
return history
|
||||
|
||||
    async def chat_with_knowledge_base(
        self,
        conversation_id: int,
        message: str,
        knowledge_base_id: int,
        stream: bool = False,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> ChatResponse:
        """Chat with a knowledge base using RAG.

        Persists the user message, runs retrieval + generation, persists the
        assistant reply together with the retrieved context, and returns both
        stored messages.

        Args:
            conversation_id: Existing conversation to append to.
            message: User question.
            knowledge_base_id: Knowledge base whose vector store is queried.
            stream: When True, delegates to the (currently non-streaming)
                placeholder `_generate_streaming_response`.
            temperature: Accepted for interface compatibility; not applied —
                the service-level LLM settings are used.
            max_tokens: Same as `temperature`.

        Raises:
            ChatServiceError: On any failure. NOTE(review): errors raised
                inside the try (e.g. "Conversation not found") are re-wrapped
                as "Knowledge base chat failed: ..." by the blanket handler.
        """
        try:
            # Get conversation and validate it exists.
            conversation = self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")

            # Get vector store; None means missing or not yet processed.
            vector_store = self._get_vector_store(knowledge_base_id)
            if not vector_store:
                raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")

            # Save user message before generation so it survives failures below.
            user_message = self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )

            # Get conversation history (includes the message just saved; the
            # helper excludes the final/current turn).
            messages = self.conversation_service.get_conversation_messages(conversation_id)
            conversation_history = self._prepare_conversation_history(messages)

            # Create RAG chain
            rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)

            # Fetch relevant documents separately so they can be stored with
            # the reply (the chain retrieves again internally).
            # NOTE(review): `get_relevant_documents` is deprecated in newer
            # langchain-core in favour of `retriever.invoke` — confirm the
            # pinned version still supports it.
            relevant_docs = retriever.get_relevant_documents(message)
            context_documents = []

            for doc in relevant_docs:
                context_documents.append({
                    "content": doc.page_content[:500],  # Limit stored content length
                    "metadata": doc.metadata,
                    "source": doc.metadata.get("filename", "unknown")
                })

            # Generate response
            if stream:
                # For streaming, we'll use a different approach (placeholder).
                response_content = await self._generate_streaming_response(
                    rag_chain, message, conversation_id
                )
            else:
                # Run the blocking chain off the event loop.
                response_content = await asyncio.to_thread(rag_chain.invoke, message)

            # Save assistant message with the retrieved context attached.
            assistant_message = self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=response_content,
                role=MessageRole.ASSISTANT,
                context_documents=context_documents
            )

            # Create response
            return ChatResponse(
                user_message=MessageResponse.from_orm(user_message),
                assistant_message=MessageResponse.from_orm(assistant_message),
                model_used=self.llm.model_name,
                total_tokens=None  # TODO: Calculate tokens if needed
            )

        except Exception as e:
            logger.error(f"Knowledge base chat failed: {str(e)}")
            raise ChatServiceError(f"Knowledge base chat failed: {str(e)}")
async def _generate_streaming_response(
|
||||
self,
|
||||
rag_chain,
|
||||
message: str,
|
||||
conversation_id: int
|
||||
) -> str:
|
||||
"""Generate streaming response (placeholder for now)."""
|
||||
# For now, use non-streaming approach
|
||||
# TODO: Implement proper streaming with RAG chain
|
||||
return await asyncio.to_thread(rag_chain.invoke, message)
|
||||
|
||||
    async def chat_stream_with_knowledge_base(
        self,
        conversation_id: int,
        message: str,
        knowledge_base_id: int,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Chat with a knowledge base using RAG, yielding reply chunks.

        Persists the user message up front and the full assistant reply (or an
        error notice) afterwards. Failures are surfaced to the client as a
        final yielded chunk rather than raised.

        NOTE(review): unlike the non-streaming path, the conversation's
        existence is not validated here — confirm callers guarantee it.
        """
        try:
            # Get vector store; None means missing or not yet processed.
            vector_store = self._get_vector_store(knowledge_base_id)
            if not vector_store:
                raise ChatServiceError(f"Knowledge base {knowledge_base_id} not found or not processed")

            # Get conversation history (prior turns as role/content dicts).
            messages = self.conversation_service.get_conversation_messages(conversation_id)
            conversation_history = self._prepare_conversation_history(messages)

            # Create RAG chain — only the retriever is used below; generation
            # goes through a per-request streaming pipeline instead.
            rag_chain, retriever = self._create_rag_chain(vector_store, conversation_history)

            # Save user message
            user_message = self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )

            # Get relevant documents and flatten them into one context string.
            relevant_docs = retriever.get_relevant_documents(message)
            context = "\n\n".join([doc.page_content for doc in relevant_docs])

            # Create a streaming LLM honouring per-request overrides.
            # NOTE(review): `temperature or ...` / `max_tokens or ...` treat an
            # explicit 0 / 0.0 override as "unset" and fall back to config.
            llm_config = settings.llm.get_current_config()
            streaming_llm = ChatOpenAI(
                model=llm_config["model"],
                temperature=temperature or llm_config["temperature"],
                max_tokens=max_tokens or llm_config["max_tokens"],
                streaming=True,
                api_key=llm_config["api_key"],
                base_url=llm_config["base_url"]
            )

            # Create prompt for streaming
            prompt = ChatPromptTemplate.from_messages([
                ("system", "你是一个智能助手。请基于以下上下文信息回答用户的问题。如果上下文中没有相关信息,请诚实地说明。\n\n上下文信息:\n{context}"),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}")
            ])

            # Convert history dicts into LangChain message objects.
            chat_history_messages = []
            for hist in conversation_history:
                if hist["role"] == "human":
                    chat_history_messages.append(HumanMessage(content=hist["content"]))
                elif hist["role"] == "assistant":
                    chat_history_messages.append(AIMessage(content=hist["content"]))

            # Create streaming chain — context and history are closed over,
            # only the question comes from the invocation payload.
            streaming_chain = (
                {
                    "context": lambda x: context,
                    "chat_history": lambda x: chat_history_messages,
                    "question": lambda x: x["question"]
                }
                | prompt
                | streaming_llm
                | StrOutputParser()
            )

            # Generate streaming response, accumulating for persistence.
            full_response = ""
            async for chunk in streaming_chain.astream({"question": message}):
                if chunk:
                    full_response += chunk
                    yield chunk

            # Save the complete assistant response (skip if nothing arrived).
            if full_response:
                self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    content=full_response,
                    role=MessageRole.ASSISTANT,
                    message_metadata={
                        "knowledge_base_id": knowledge_base_id,
                        "relevant_docs_count": len(relevant_docs)
                    }
                )

        except Exception as e:
            # Surface the failure to the client as a chunk, then persist it.
            logger.error(f"Error in knowledge base streaming chat: {str(e)}")
            error_message = f"知识库对话出错: {str(e)}"
            yield error_message

            # Save error message
            self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT
            )
async def search_knowledge_base(
|
||||
self,
|
||||
knowledge_base_id: int,
|
||||
query: str,
|
||||
k: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search knowledge base for relevant documents."""
|
||||
|
||||
try:
|
||||
vector_store = self._get_vector_store(knowledge_base_id)
|
||||
if not vector_store:
|
||||
return []
|
||||
|
||||
# Perform similarity search
|
||||
results = vector_store.similarity_search_with_score(query, k=k)
|
||||
|
||||
formatted_results = []
|
||||
for doc, score in results:
|
||||
formatted_results.append({
|
||||
"content": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
"similarity_score": float(score),
|
||||
"source": doc.metadata.get("filename", "unknown")
|
||||
})
|
||||
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Knowledge base search failed: {str(e)}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
"""LangChain-based chat service."""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional, List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from ..core.config import settings
|
||||
from ..models.message import MessageRole
|
||||
from utils.util_schemas import ChatResponse, StreamChunk, MessageResponse
|
||||
from utils.util_exceptions import ChatServiceError, OpenAIError, AuthenticationError, RateLimitError
|
||||
from loguru import logger
|
||||
from .conversation import ConversationService
|
||||
|
||||
|
||||
class StreamingCallbackHandler(BaseCallbackHandler):
    """Custom callback handler for streaming responses.

    Accumulates tokens emitted by the LLM so the full reply text can be
    reassembled after the stream completes.
    """

    def __init__(self):
        # Ordered buffer of tokens received so far.
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Handle a new token from the LLM (called once per streamed token)."""
        self.tokens.append(token)

    def get_response(self) -> str:
        """Get the complete response accumulated so far."""
        return "".join(self.tokens)

    def clear(self):
        """Clear the token buffer before starting a new stream."""
        self.tokens = []
class LangChainChatService:
    """LangChain-based chat service for AI model integration."""

    def __init__(self, db: Session):
        """Wire up DB access, conversation persistence and the two LLM clients."""
        self.db = db
        self.conversation_service = ConversationService(db)

        # Imported here rather than at module top — presumably to avoid an
        # import cycle with ..core.llm; confirm before hoisting.
        from ..core.llm import create_llm

        # Debug logging of the configured provider.
        logger.info(f"LLM Provider: {settings.llm.provider}")

        # Blocking client for normal chat calls.
        self.llm = create_llm(streaming=False)

        # Streaming client for stream responses.
        self.streaming_llm = create_llm(streaming=True)

        # Shared token accumulator, cleared at the start of each stream.
        self.streaming_handler = StreamingCallbackHandler()

        logger.info(f"LangChain ChatService initialized with model: {self.llm.model_name}")
def _prepare_langchain_messages(self, conversation, history: List) -> List:
|
||||
"""Prepare messages for LangChain format."""
|
||||
messages = []
|
||||
|
||||
# Add system message if conversation has system prompt
|
||||
if hasattr(conversation, 'system_prompt') and conversation.system_prompt:
|
||||
messages.append(SystemMessage(content=conversation.system_prompt))
|
||||
else:
|
||||
# Default system message
|
||||
messages.append(SystemMessage(
|
||||
content="You are a helpful AI assistant. Please provide accurate and helpful responses."
|
||||
))
|
||||
|
||||
# Add conversation history
|
||||
for msg in history[:-1]: # Exclude the last message (current user message)
|
||||
if msg.role == MessageRole.USER:
|
||||
messages.append(HumanMessage(content=msg.content))
|
||||
elif msg.role == MessageRole.ASSISTANT:
|
||||
messages.append(AIMessage(content=msg.content))
|
||||
|
||||
# Add current user message
|
||||
if history:
|
||||
last_msg = history[-1]
|
||||
if last_msg.role == MessageRole.USER:
|
||||
messages.append(HumanMessage(content=last_msg.content))
|
||||
|
||||
return messages
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
conversation_id: int,
|
||||
message: str,
|
||||
stream: bool = False,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> ChatResponse:
|
||||
"""Send a message and get AI response using LangChain."""
|
||||
logger.info(f"Processing LangChain chat request for conversation {conversation_id}")
|
||||
|
||||
try:
|
||||
# Get conversation details
|
||||
conversation = self.conversation_service.get_conversation(conversation_id)
|
||||
if not conversation:
|
||||
raise ChatServiceError("Conversation not found")
|
||||
|
||||
# Add user message to database
|
||||
user_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
|
||||
# Get conversation history for context
|
||||
history = self.conversation_service.get_conversation_history(
|
||||
conversation_id, limit=20
|
||||
)
|
||||
|
||||
# Prepare messages for LangChain
|
||||
langchain_messages = self._prepare_langchain_messages(conversation, history)
|
||||
|
||||
# Update LLM parameters if provided
|
||||
llm_to_use = self.llm
|
||||
if temperature is not None or max_tokens is not None:
|
||||
llm_config = settings.llm.get_current_config()
|
||||
llm_to_use = ChatOpenAI(
|
||||
model=llm_config["model"],
|
||||
openai_api_key=llm_config["api_key"],
|
||||
openai_api_base=llm_config["base_url"],
|
||||
temperature=temperature if temperature is not None else float(conversation.temperature),
|
||||
max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
|
||||
streaming=False
|
||||
)
|
||||
|
||||
# Call LangChain LLM
|
||||
response = await llm_to_use.ainvoke(langchain_messages)
|
||||
|
||||
# Extract response content
|
||||
assistant_content = response.content
|
||||
|
||||
# Add assistant message to database
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=assistant_content,
|
||||
role=MessageRole.ASSISTANT,
|
||||
message_metadata={
|
||||
"model": llm_to_use.model_name,
|
||||
"langchain_version": "0.1.0",
|
||||
"provider": "langchain_openai"
|
||||
}
|
||||
)
|
||||
|
||||
# Update conversation timestamp
|
||||
self.conversation_service.update_conversation_timestamp(conversation_id)
|
||||
|
||||
logger.info(f"Successfully processed LangChain chat request for conversation {conversation_id}")
|
||||
|
||||
return ChatResponse(
|
||||
user_message=MessageResponse.from_orm(user_message),
|
||||
assistant_message=MessageResponse.from_orm(assistant_message),
|
||||
total_tokens=None, # LangChain doesn't provide token count by default
|
||||
model_used=llm_to_use.model_name
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process LangChain chat request for conversation {conversation_id}: {str(e)}", exc_info=True)
|
||||
|
||||
# Classify error types for better handling
|
||||
error_type = type(e).__name__
|
||||
error_message = self._format_error_message(e)
|
||||
|
||||
# Add error message to database
|
||||
assistant_message = self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
content=error_message,
|
||||
role=MessageRole.ASSISTANT,
|
||||
message_metadata={
|
||||
"error": True,
|
||||
"error_type": error_type,
|
||||
"original_error": str(e),
|
||||
"langchain_error": True
|
||||
}
|
||||
)
|
||||
|
||||
# Re-raise specific exceptions for proper error handling
|
||||
if "rate limit" in str(e).lower():
|
||||
raise RateLimitError(str(e))
|
||||
elif "api key" in str(e).lower() or "authentication" in str(e).lower():
|
||||
raise AuthenticationError(str(e))
|
||||
elif "openai" in str(e).lower():
|
||||
raise OpenAIError(str(e))
|
||||
|
||||
return ChatResponse(
|
||||
user_message=MessageResponse.from_orm(user_message),
|
||||
assistant_message=MessageResponse.from_orm(assistant_message),
|
||||
total_tokens=0,
|
||||
model_used=self.llm.model_name
|
||||
)
|
||||
|
||||
    async def chat_stream(
        self,
        conversation_id: int,
        message: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Send a message and yield a streaming AI response using LangChain.

        Persists the user message first and the full assistant reply after the
        stream ends; failures are yielded to the client as a final chunk and
        also persisted, rather than raised.
        """
        logger.info(f"Processing LangChain streaming chat request for conversation {conversation_id}")

        try:
            # Get conversation details and validate existence.
            conversation = self.conversation_service.get_conversation(conversation_id)
            if not conversation:
                raise ChatServiceError("Conversation not found")

            # Add user message to database
            user_message = self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=message,
                role=MessageRole.USER
            )

            # Get a bounded window of conversation history for context.
            history = self.conversation_service.get_conversation_history(
                conversation_id, limit=20
            )

            # Prepare messages for LangChain
            langchain_messages = self._prepare_langchain_messages(conversation, history)

            # Build a one-off streaming LLM only when overrides are given.
            streaming_llm_to_use = self.streaming_llm
            if temperature is not None or max_tokens is not None:
                llm_config = settings.llm.get_current_config()
                streaming_llm_to_use = ChatOpenAI(
                    model=llm_config["model"],
                    openai_api_key=llm_config["api_key"],
                    openai_api_base=llm_config["base_url"],
                    temperature=temperature if temperature is not None else float(conversation.temperature),
                    max_tokens=max_tokens if max_tokens is not None else conversation.max_tokens,
                    streaming=True
                )

            # Clear previous streaming handler state
            self.streaming_handler.clear()

            # Stream response, accumulating the full text for persistence.
            full_response = ""
            async for chunk in streaming_llm_to_use.astream(langchain_messages):
                # Handle different chunk shapes defensively to avoid KeyError.
                chunk_content = None
                if hasattr(chunk, 'content'):
                    # Object-like chunks with a content attribute (AIMessageChunk).
                    chunk_content = chunk.content
                elif isinstance(chunk, dict) and 'content' in chunk:
                    # Dict-like chunks with a content key.
                    chunk_content = chunk['content']
                elif isinstance(chunk, dict) and 'error' in chunk:
                    # Handle error chunks explicitly: report and keep streaming.
                    logger.error(f"Error in LLM response: {chunk['error']}")
                    yield self._format_error_message(Exception(chunk['error']))
                    continue

                if chunk_content:
                    full_response += chunk_content
                    yield chunk_content

            # Add the complete assistant message to the database.
            assistant_message = self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=full_response,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "model": streaming_llm_to_use.model_name,
                    "langchain_version": "0.1.0",
                    "provider": "langchain_openai",
                    "streaming": True
                }
            )

            # Update conversation timestamp
            self.conversation_service.update_conversation_timestamp(conversation_id)

            logger.info(f"Successfully processed LangChain streaming chat request for conversation {conversation_id}")

        except Exception as e:
            # Format the failure note defensively so logging cannot raise
            # again (e.g. a KeyError from str-formatting the exception).
            error_info = f"Failed to process LangChain streaming chat request for conversation {conversation_id}"
            logger.error(error_info, exc_info=True)

            # Format a user-facing error message and surface it as a chunk.
            error_message = self._format_error_message(e)
            yield error_message

            # Add error message to database
            self.conversation_service.add_message(
                conversation_id=conversation_id,
                content=error_message,
                role=MessageRole.ASSISTANT,
                message_metadata={
                    "error": True,
                    "error_type": type(e).__name__,
                    "original_error": str(e),
                    "langchain_error": True,
                    "streaming": True
                }
            )
async def get_available_models(self) -> List[str]:
|
||||
"""Get list of available models from LangChain."""
|
||||
try:
|
||||
# LangChain doesn't have a direct method to list models
|
||||
# Return commonly available OpenAI models
|
||||
return [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini"
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get available models: {str(e)}")
|
||||
return ["gpt-3.5-turbo"]
|
||||
|
||||
def update_model_config(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
):
|
||||
"""Update LLM configuration."""
|
||||
from ..core.llm import create_llm
|
||||
|
||||
# 重新创建LLM实例
|
||||
self.llm = create_llm(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
streaming=False
|
||||
)
|
||||
|
||||
self.streaming_llm = create_llm(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
logger.info(f"Updated LLM configuration: model={model}, temperature={temperature}, max_tokens={max_tokens}")
|
||||
|
||||
def _format_error_message(self, error: Exception) -> str:
|
||||
"""Format error message for user display."""
|
||||
error_type = type(error).__name__
|
||||
error_str = str(error)
|
||||
|
||||
# Provide user-friendly error messages
|
||||
if "rate limit" in error_str.lower():
|
||||
return "服务器繁忙,请稍后再试。"
|
||||
elif "api key" in error_str.lower() or "authentication" in error_str.lower():
|
||||
return "API认证失败,请检查配置。"
|
||||
elif "timeout" in error_str.lower():
|
||||
return "请求超时,请重试。"
|
||||
elif "connection" in error_str.lower():
|
||||
return "网络连接错误,请检查网络连接。"
|
||||
elif "model" in error_str.lower() and "not found" in error_str.lower():
|
||||
return "指定的模型不可用,请选择其他模型。"
|
||||
else:
|
||||
return f"处理请求时发生错误:{error_str}"
|
||||
|
||||
async def _retry_with_backoff(self, func, max_retries: int = 3, base_delay: float = 1.0):
|
||||
"""Retry function with exponential backoff."""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await func()
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise e
|
||||
|
||||
# Check if error is retryable
|
||||
if not self._is_retryable_error(e):
|
||||
raise e
|
||||
|
||||
delay = base_delay * (2 ** attempt)
|
||||
logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay}s: {str(e)}")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
"""Check if an error is retryable."""
|
||||
error_str = str(error).lower()
|
||||
retryable_errors = [
|
||||
"timeout",
|
||||
"connection",
|
||||
"server error",
|
||||
"internal error",
|
||||
"rate limit"
|
||||
]
|
||||
return any(err in error_str for err in retryable_errors)
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
"""LLM配置服务 - 从数据库读取默认配置"""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from ..models.llm_config import LLMConfig
|
||||
from ..db.database import get_session
|
||||
from loguru import logger
|
||||
|
||||
class LLMConfigService:
    """LLM configuration management service backed by the llm_config table.

    The two default-config getters previously duplicated the same query and
    error handling verbatim; they now delegate to `_find_default_config`.
    """

    def __init__(self, db_session: Optional[Session] = None):
        # Fall back to a fresh session when none is supplied.
        self.db = db_session or get_session()  # TODO DrGraph: review async session handling

    def _find_default_config(self, is_embedding: bool, kind_label: str) -> Optional[LLMConfig]:
        """Fetch the default active config for chat or embedding models.

        Args:
            is_embedding: True for embedding models, False for chat models.
            kind_label: Chinese label ("对话" / "嵌入") interpolated into the
                log messages so they stay identical to the original wording.

        Returns:
            The matching LLMConfig, or None when absent or on query failure.
        """
        try:
            stmt = select(LLMConfig).where(
                and_(
                    LLMConfig.is_default == True,
                    LLMConfig.is_embedding == is_embedding,
                    LLMConfig.is_active == True
                )
            )
            config = self.db.execute(stmt).scalar_one_or_none()

            if not config:
                logger.warning(f"未找到默认{kind_label}模型配置")
                return None

            return config

        except Exception as e:
            logger.error(f"获取默认{kind_label}模型配置失败: {str(e)}")
            return None

    def get_default_chat_config(self) -> Optional[LLMConfig]:
        """Return the default chat-model configuration, if any."""
        return self._find_default_config(is_embedding=False, kind_label="对话")

    def get_default_embedding_config(self) -> Optional[LLMConfig]:
        """Return the default embedding-model configuration, if any."""
        return self._find_default_config(is_embedding=True, kind_label="嵌入")

    def get_config_by_id(self, config_id: int) -> Optional[LLMConfig]:
        """Fetch a configuration by primary key; None when missing or on error."""
        try:
            stmt = select(LLMConfig).where(LLMConfig.id == config_id)
            return self.db.execute(stmt).scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取配置失败: {str(e)}")
            return None

    def get_active_configs(self, is_embedding: Optional[bool] = None) -> List[LLMConfig]:
        """List active configs, optionally filtered by embedding flag, oldest first."""
        try:
            stmt = select(LLMConfig).where(LLMConfig.is_active == True)

            if is_embedding is not None:
                stmt = stmt.where(LLMConfig.is_embedding == is_embedding)

            stmt = stmt.order_by(LLMConfig.created_at)
            return self.db.execute(stmt).scalars().all()

        except Exception as e:
            logger.error(f"获取激活配置失败: {str(e)}")
            return []

    def _get_fallback_chat_config(self) -> Dict[str, Any]:
        """Fallback chat-model config sourced from environment settings."""
        from ..core.config import get_settings
        settings = get_settings()
        return settings.llm.get_current_config()

    def _get_fallback_embedding_config(self) -> Dict[str, Any]:
        """Fallback embedding-model config sourced from environment settings."""
        from ..core.config import get_settings
        settings = get_settings()
        return settings.embedding.get_current_config()

    def test_config(self, config_id: int, test_message: str = "Hello") -> Dict[str, Any]:
        """Sanity-check a configuration; currently only verifies it exists.

        NOTE(review): no real connectivity probe is performed yet — a test
        request against the provider could be issued here.
        """
        try:
            config = self.get_config_by_id(config_id)
            if not config:
                return {"success": False, "error": "配置不存在"}

            return {"success": True, "message": "配置测试成功"}

        except Exception as e:
            logger.error(f"测试配置失败: {str(e)}")
            return {"success": False, "error": str(e)}
# Module-level singleton instance.
_llm_config_service = None


def get_llm_config_service(db_session: Optional[Session] = None) -> LLMConfigService:
    """Return the shared LLMConfigService instance.

    A new instance is built on first use, or whenever an explicit session is
    supplied (which also replaces the cached singleton).
    """
    global _llm_config_service
    needs_rebuild = _llm_config_service is None or db_session is not None
    if needs_rebuild:
        _llm_config_service = LLMConfigService(db_session)
    return _llm_config_service
|
|
@ -0,0 +1,110 @@
|
|||
"""LLM service for workflow execution."""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||||
|
||||
from ..models.llm_config import LLMConfig
|
||||
from loguru import logger
|
||||
|
||||
class LLMService:
    """LLM service for workflow execution (blocking and streaming calls).

    Fixes/changes vs. the previous revision:
    - `temperature or config.temperature` (and the max_tokens twin) silently
      discarded explicit 0 / 0.0 overrides; `is not None` is now used.
    - Client construction and role->message conversion were duplicated across
      both call paths; they now share private helpers.
    - Wrapped exceptions are chained with `from e` to keep the original cause.
    """

    def __init__(self):
        # Stateless: every call builds its client from the given model config.
        pass

    @staticmethod
    def _to_langchain_messages(messages: List[Dict[str, str]]) -> List:
        """Convert {"role", "content"} dicts to LangChain message objects.

        Unknown roles are silently skipped, matching the per-method loops this
        helper replaces.
        """
        role_to_cls = {
            "system": SystemMessage,
            "user": HumanMessage,
            "assistant": AIMessage,
        }
        converted = []
        for msg in messages:
            message_cls = role_to_cls.get(msg.get("role", "user"))
            if message_cls is not None:
                converted.append(message_cls(content=msg.get("content", "")))
        return converted

    @staticmethod
    def _build_llm(
        model_config: "LLMConfig",
        temperature: Optional[float],
        max_tokens: Optional[int],
        streaming: bool
    ) -> "ChatOpenAI":
        """Create a ChatOpenAI client from a stored model configuration.

        Per-call overrides take precedence over config values; `is not None`
        ensures explicit 0 / 0.0 overrides are honoured.
        """
        return ChatOpenAI(
            model=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=temperature if temperature is not None else model_config.temperature,
            max_tokens=max_tokens if max_tokens is not None else model_config.max_tokens,
            streaming=streaming
        )

    async def chat_completion(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Run a blocking chat completion and return the reply text.

        Raises:
            Exception: wraps any underlying failure with a "LLM调用失败" prefix.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=False)
            langchain_messages = self._to_langchain_messages(messages)

            response = await llm.ainvoke(langchain_messages)
            return response.content

        except Exception as e:
            logger.error(f"LLM调用失败: {str(e)}")
            raise Exception(f"LLM调用失败: {str(e)}") from e

    async def chat_completion_stream(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding non-empty content chunks.

        Raises:
            Exception: wraps any underlying failure with a "LLM流式调用失败" prefix.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=True)
            langchain_messages = self._to_langchain_messages(messages)

            async for chunk in llm.astream(langchain_messages):
                if hasattr(chunk, 'content') and chunk.content:
                    yield chunk.content

        except Exception as e:
            logger.error(f"LLM流式调用失败: {str(e)}")
            raise Exception(f"LLM流式调用失败: {str(e)}") from e

    def get_model_info(self, model_config: "LLMConfig") -> Dict[str, Any]:
        """Summarize a model configuration as a plain dict."""
        return {
            "id": model_config.id,
            "name": model_config.model_name,
            "provider": model_config.provider,
            "base_url": model_config.base_url,
            "temperature": model_config.temperature,
            "max_tokens": model_config.max_tokens,
            "is_active": model_config.is_active
        }
|
|
@ -0,0 +1,145 @@
|
|||
"""Dynamic MCP tool wrapper for LangChain/LangGraph.
|
||||
|
||||
Fetches available MCP tools from the MCP server and exposes them as LangChain BaseTool
|
||||
instances that call the MCP `/execute` endpoint at runtime.
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
import json
|
||||
import requests
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from th_agenter.core.config import get_settings
|
||||
from loguru import logger
|
||||
import os
|
||||
|
||||
# Map MCP parameter types to Python type hints.
# Used when building the Pydantic args model for a dynamic MCP tool:
# each declared MCP parameter type is translated to the annotation below;
# types not listed here fall back to Any (see _build_args_schema).
_TYPE_MAP: Dict[str, Any] = {
    "string": str,
    "integer": int,
    "float": float,
    "boolean": bool,
    "array": List[Any],
    "object": Dict[str, Any],
}
|
||||
|
||||
|
||||
def _build_args_schema(params: List[Dict[str, Any]]) -> Type[BaseModel]:
    """Build a Pydantic BaseModel class dynamically from MCP tool params.

    Each entry of *params* describes one tool argument, e.g.::

        {"name": ..., "type": ..., "required": ..., "default": ..., "description": ...}

    Types are mapped through ``_TYPE_MAP`` (unknown types become ``Any``).
    Required parameters without a default are marked with Ellipsis so
    Pydantic treats them as mandatory; otherwise the declared default
    (None for optional parameters) is used.

    Returns:
        A new ``BaseModel`` subclass named ``MCPToolArgs``.
    """
    annotations: Dict[str, Any] = {}
    fields: Dict[str, Any] = {}

    for p in params:
        name = p.get("name")
        ptype = p.get("type", "string")
        required = p.get("required", True)
        default = p.get("default", None)
        description = p.get("description", "")

        annotations[name] = _TYPE_MAP.get(ptype, Any)

        # BUGFIX/cleanup: the previous special-case branch for "enum present
        # and default is None" computed exactly this same value, so it was
        # dead code; the single expression below covers every case.
        field_default = ... if required and default is None else default

        fields[name] = Field(
            default=field_default,
            description=description,
        )

    # Assemble the model namespace: annotations drive the field types,
    # the Field(...) objects supply defaults and descriptions.
    namespace = {"__annotations__": annotations}
    namespace.update(fields)
    return type("MCPToolArgs", (BaseModel,), namespace)
|
||||
|
||||
|
||||
class MCPDynamicTool(BaseTool):
    """LangChain BaseTool wrapper that executes MCP tools via HTTP.

    Each instance represents one remote MCP tool: its name, description and
    argument schema come from the MCP server's tool listing, and invocation
    POSTs to the server's ``/execute`` endpoint at runtime.
    """

    # Dynamic metadata supplied by the MCP server at construction time.
    name: str
    description: str
    args_schema: Type[BaseModel]

    # Declared as pydantic PrivateAttr so BaseTool's model machinery does
    # not treat them as public fields.
    _mcp_base_url: str = PrivateAttr()
    _tool_name: str = PrivateAttr()

    def __init__(self, mcp_base_url: str, tool_info: Dict[str, Any]):
        """Build a tool from an MCP server tool descriptor.

        Args:
            mcp_base_url: Base URL of the MCP server (trailing slash stripped).
            tool_info: Descriptor with "name", "description" and "parameters"
                keys as returned by the server's tool listing. Raises KeyError
                if "name" is missing.
        """
        # Initialize BaseTool with dynamic metadata
        super().__init__(
            name=tool_info.get("name", "tool"),
            description=tool_info.get("description", ""),
            args_schema=_build_args_schema(tool_info.get("parameters", [])),
        )
        # set private attrs after BaseTool init to avoid pydantic stripping
        self._mcp_base_url = mcp_base_url.rstrip("/")
        self._tool_name = tool_info["name"]

    def _execute(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST the tool invocation to the MCP server's /execute endpoint.

        Returns the server's JSON response on success, or a failure dict
        (``success=False`` plus the error string) if the request raises —
        callers never see an exception from this method.
        """
        url = f"{self._mcp_base_url}/execute"
        payload = {
            "tool_name": self._tool_name,
            "parameters": params,
        }
        logger.info(f"调用 MCP 工具: {self._tool_name} 参数: {params}")
        try:
            resp = requests.post(url, json=payload, timeout=30)
            resp.raise_for_status()
            data = resp.json()
            return data
        except Exception as e:
            logger.error(f"MCP 工具调用失败: {e}")
            return {
                "success": False,
                "error": str(e),
                "result": None,
                "tool_name": self._tool_name,
            }

    def _run(self, **kwargs: Any) -> str:
        """Synchronous execution for LangChain tools.

        Serializes the MCP response to a JSON string: the raw result on
        success, otherwise a ``{"error": ...}`` payload.
        """
        data = self._execute(kwargs)
        if not isinstance(data, dict):
            return json.dumps({"success": False, "error": "Invalid MCP response"}, ensure_ascii=False)
        # Return string content; LangChain expects textual content for ToolMessage
        if data.get("success"):
            return json.dumps(data.get("result", {}), ensure_ascii=False)
        return json.dumps({"error": data.get("error")}, ensure_ascii=False)

    async def _arun(self, **kwargs: Any) -> str:
        # LangChain will call the async version when available; we simply
        # delegate to the (blocking) sync implementation for now.
        return self._run(**kwargs)
|
||||
|
||||
|
||||
def load_mcp_tools(include: Optional[List[str]] = None) -> List[MCPDynamicTool]:
    """Load MCP tools from the MCP server and construct dynamic tools.

    Args:
        include: Optional whitelist of tool names (e.g. ["weather", "search"]);
            when provided, tools whose name is not listed are skipped.

    Returns:
        The successfully constructed tools; an empty list if the server's
        tool listing cannot be fetched.
    """
    settings = get_settings()
    # Resolution order for the server URL: settings.tool.mcp_server_url,
    # then the MCP_SERVER_URL environment variable, then a localhost default.
    base_url = (
        getattr(settings.tool, "mcp_server_url", None)
        or os.getenv("MCP_SERVER_URL")
        or "http://127.0.0.1:8001"
    )

    listing_url = f"{base_url.rstrip('/')}/tools"
    try:
        response = requests.get(listing_url, timeout=15)
        response.raise_for_status()
        descriptors = response.json()
    except Exception as e:
        logger.error(f"获取 MCP 工具列表失败: {e}")
        return []

    loaded: List[MCPDynamicTool] = []
    for descriptor in descriptors:
        name = descriptor.get("name")
        if include and name not in include:
            continue
        try:
            loaded.append(MCPDynamicTool(mcp_base_url=base_url, tool_info=descriptor))
        except Exception as e:
            # A single malformed descriptor must not abort the whole load.
            logger.warning(f"构建 MCP 工具'{name}'失败: {e}")
    logger.info(f"已加载 MCP 工具: {[t.name for t in loaded]}")
    return loaded
|
||||
|
|
@ -0,0 +1,454 @@
|
|||
"""MySQL MCP (Model Context Protocol) tool for database operations."""
|
||||
|
||||
import json
from datetime import datetime
from typing import List, Dict, Any, Optional

import pymysql
from loguru import logger

from th_agenter.services.agent.base import BaseTool, ToolParameter, ToolParameterType, ToolResult
|
||||
|
||||
class MySQLMCPTool(BaseTool):
    """MySQL MCP tool for database operations and intelligent querying.

    Operations are dispatched via the ``operation`` parameter:
    connect / list_tables / describe_table / execute_query /
    test_connection / disconnect. Live connections are cached per
    ``user_id`` in :attr:`connections`.
    """

    def __init__(self):
        super().__init__()
        # user_id -> {"connection": ..., "config": ..., "connected_at": ...}
        self.connections = {}

    def get_name(self) -> str:
        return "mysql_mcp"

    def get_description(self) -> str:
        return "MySQL MCP服务工具,提供数据库连接、表结构查询、SQL执行等功能,支持智能数据问答。"

    def get_parameters(self) -> List[ToolParameter]:
        """Declare the tool's parameters for the agent framework."""
        return [
            ToolParameter(
                name="operation",
                type=ToolParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["connect", "list_tables", "describe_table", "execute_query", "test_connection", "disconnect"]
            ),
            ToolParameter(
                name="connection_config",
                type=ToolParameterType.OBJECT,
                description="数据库连接配置 {host, port, database, username, password}",
                required=False
            ),
            ToolParameter(
                name="user_id",
                type=ToolParameterType.STRING,
                description="用户ID,用于管理连接",
                required=False
            ),
            ToolParameter(
                name="table_name",
                type=ToolParameterType.STRING,
                description="表名(用于describe_table操作)",
                required=False
            ),
            ToolParameter(
                name="sql_query",
                type=ToolParameterType.STRING,
                description="SQL查询语句(用于execute_query操作)",
                required=False
            ),
            ToolParameter(
                name="limit",
                type=ToolParameterType.INTEGER,
                description="查询结果限制数量,默认100",
                required=False,
                default=100
            )
        ]

    def _get_tables(self, connection) -> List[Dict[str, Any]]:
        """List all tables of the connected database's current schema."""
        cursor = connection.cursor()
        try:
            cursor.execute("""
                SELECT
                    table_name,
                    table_type,
                    table_schema
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                ORDER BY table_name;
            """)

            tables = []
            for row in cursor.fetchall():
                tables.append({
                    "table_name": row[0],
                    "table_type": row[1],
                    "table_schema": row[2]
                })
            return tables
        finally:
            cursor.close()

    def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
        """Return columns, keys, indexes and comment for *table_name*.

        All statements use parameterized queries against
        information_schema, so *table_name* is never interpolated.
        """
        cursor = connection.cursor()
        try:
            # Column definitions.
            cursor.execute("""
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default,
                    character_maximum_length,
                    numeric_precision,
                    numeric_scale,
                    column_comment
                FROM information_schema.columns
                WHERE table_schema = DATABASE() AND table_name = %s
                ORDER BY ordinal_position;
            """, (table_name,))

            columns = []
            for row in cursor.fetchall():
                columns.append({
                    "column_name": row[0],
                    "data_type": row[1],
                    "is_nullable": row[2] == 'YES',
                    "column_default": row[3],
                    "character_maximum_length": row[4],
                    "numeric_precision": row[5],
                    "numeric_scale": row[6],
                    "column_comment": row[7] or ""
                })

            # Primary key columns, in key order.
            cursor.execute("""
                SELECT column_name
                FROM information_schema.key_column_usage
                WHERE table_schema = DATABASE()
                AND table_name = %s
                AND constraint_name = 'PRIMARY'
                ORDER BY ordinal_position;
            """, (table_name,))
            primary_keys = [row[0] for row in cursor.fetchall()]

            # Foreign key relationships.
            cursor.execute("""
                SELECT
                    column_name,
                    referenced_table_name,
                    referenced_column_name
                FROM information_schema.key_column_usage
                WHERE table_schema = DATABASE()
                AND table_name = %s
                AND referenced_table_name IS NOT NULL;
            """, (table_name,))
            foreign_keys = []
            for row in cursor.fetchall():
                foreign_keys.append({
                    "column_name": row[0],
                    "referenced_table": row[1],
                    "referenced_column": row[2]
                })

            # Index definitions (non_unique == 0 means a unique index).
            cursor.execute("""
                SELECT
                    index_name,
                    column_name,
                    non_unique
                FROM information_schema.statistics
                WHERE table_schema = DATABASE()
                AND table_name = %s
                ORDER BY index_name, seq_in_index;
            """, (table_name,))
            indexes = []
            for row in cursor.fetchall():
                indexes.append({
                    "index_name": row[0],
                    "column_name": row[1],
                    "is_unique": row[2] == 0
                })

            # Table-level comment (empty string when absent).
            cursor.execute("""
                SELECT table_comment
                FROM information_schema.tables
                WHERE table_schema = DATABASE() AND table_name = %s;
            """, (table_name,))
            table_comment = ""
            result = cursor.fetchone()
            if result:
                table_comment = result[0] or ""

            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": primary_keys,
                "foreign_keys": foreign_keys,
                "indexes": indexes,
                "table_comment": table_comment
            }
        finally:
            cursor.close()

    def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
        """Execute a SQL statement and return rows as a list of dicts.

        A LIMIT clause is appended to statements lacking one; non-SELECT
        statements are committed so data modifications persist.
        """
        cursor = connection.cursor()
        try:
            # Append a LIMIT if the query has none. int() coerces the limit
            # so a non-integer value cannot inject SQL.
            if limit and int(limit) > 0 and "LIMIT" not in sql_query.upper():
                sql_query = f"{sql_query.rstrip(';')} LIMIT {int(limit)}"

            cursor.execute(sql_query)

            # Column names; empty for statements with no result set.
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            # BUGFIX: pymysql does not autocommit by default, so without this
            # commit, INSERT/UPDATE/DELETE changes were silently rolled back
            # when the connection closed.
            if cursor.description is None:
                connection.commit()

            rows = cursor.fetchall()

            # Convert tuples to dicts keyed by column name.
            data = []
            for row in rows:
                row_dict = {}
                for i, value in enumerate(row):
                    if i < len(columns):
                        # datetime values are not JSON-serializable; emit ISO strings.
                        if isinstance(value, datetime):
                            row_dict[columns[i]] = value.isoformat()
                        else:
                            row_dict[columns[i]] = value
                data.append(row_dict)
            return {
                "success": True,
                "data": data,
                "columns": columns,
                "row_count": len(data),
                "query": sql_query
            }
        finally:
            cursor.close()

    def _create_connection(self, config: Dict[str, Any]) -> pymysql.Connection:
        """Create a MySQL connection from a config dict (port defaults to 3306).

        Raises:
            Exception: wrapping the underlying driver error on failure.
        """
        try:
            connection = pymysql.connect(
                host=config['host'],
                port=int(config.get('port', 3306)),
                user=config['username'],
                password=config['password'],
                database=config['database'],
                connect_timeout=10,
                charset='utf8mb4'
            )
            return connection
        except Exception as e:
            raise Exception(f"MySQL连接失败: {str(e)}")

    def _test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Open a throwaway connection, report server version and InnoDB support."""
        try:
            conn = self._create_connection(config)
            cursor = conn.cursor()

            cursor.execute("SELECT VERSION();")
            version = cursor.fetchone()[0]

            cursor.execute("SHOW ENGINES;")
            engines = cursor.fetchall()
            has_innodb = any('InnoDB' in str(engine) for engine in engines)

            cursor.close()
            conn.close()

            return {
                "success": True,
                "version": version,
                "has_innodb": has_innodb,
                "message": "连接测试成功"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "message": "连接测试失败"
            }

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the MySQL MCP tool operation.

        Dispatches on ``operation``; see the class docstring for the
        supported values. All failures are reported via ToolResult.error
        rather than raised.
        """
        try:
            operation = kwargs.get("operation")
            connection_config = kwargs.get("connection_config", {})
            user_id = kwargs.get("user_id")
            table_name = kwargs.get("table_name")
            sql_query = kwargs.get("sql_query")
            limit = kwargs.get("limit", 100)

            logger.info(f"执行MySQL MCP操作: {operation}")

            if operation == "test_connection":
                if not connection_config:
                    return ToolResult(
                        success=False,
                        error="缺少连接配置参数"
                    )
                result = self._test_connection(connection_config)
                return ToolResult(
                    success=result["success"],
                    result=result,
                    error=result.get("error")
                )

            elif operation == "connect":
                if not connection_config:
                    return ToolResult(
                        success=False,
                        error="缺少connection_config参数"
                    )
                if not user_id:
                    return ToolResult(
                        success=False,
                        error="缺少user_id参数"
                    )

                try:
                    # Reuse the shared connection factory instead of a duplicated
                    # inline pymysql.connect() (gains connect_timeout and a
                    # default port, consistent with _test_connection).
                    connection = self._create_connection(connection_config)

                    # Cache the connection for subsequent operations.
                    self.connections[user_id] = {
                        "connection": connection,
                        "config": connection_config,
                        "connected_at": datetime.now().isoformat()
                    }

                    tables = self._get_tables(connection)
                    return ToolResult(
                        success=True,
                        result={
                            "message": "数据库连接成功",
                            "database": connection_config["database"],
                            "tables": tables,
                            "table_count": len(tables)
                        }
                    )
                except Exception as e:
                    return ToolResult(
                        success=False,
                        error=f"连接失败: {str(e)}"
                    )

            elif operation == "list_tables":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )

                connection = self.connections[user_id]["connection"]
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={
                        "tables": tables,
                        "table_count": len(tables)
                    }
                )

            elif operation == "describe_table":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )
                if not table_name:
                    return ToolResult(
                        success=False,
                        error="缺少table_name参数"
                    )

                connection = self.connections[user_id]["connection"]
                table_info = self._describe_table(connection, table_name)
                return ToolResult(
                    success=True,
                    result=table_info
                )

            elif operation == "execute_query":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )
                if not sql_query:
                    return ToolResult(
                        success=False,
                        error="缺少sql_query参数"
                    )

                connection = self.connections[user_id]["connection"]
                query_result = self._execute_query(connection, sql_query, limit)
                return ToolResult(
                    success=True,
                    result=query_result
                )

            elif operation == "disconnect":
                if user_id and user_id in self.connections:
                    try:
                        self.connections[user_id]["connection"].close()
                        del self.connections[user_id]
                        return ToolResult(
                            success=True,
                            result={"message": "数据库连接已断开"}
                        )
                    except Exception as e:
                        return ToolResult(
                            success=False,
                            error=f"断开连接失败: {str(e)}"
                        )
                else:
                    return ToolResult(
                        success=True,
                        result={"message": "用户未连接数据库"}
                    )

            else:
                # BUGFIX: the message was previously passed as result= instead
                # of error=, inconsistent with every other failure path (and
                # with the PostgreSQL twin of this tool).
                return ToolResult(
                    success=False,
                    error=f"不支持的操作类型: {operation}"
                )

        except Exception as e:
            logger.error(f"MySQL MCP工具执行失败: {str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                error=f"工具执行失败: {str(e)}"
            )
|
||||
|
|
@ -0,0 +1,385 @@
|
|||
"""PostgreSQL MCP (Model Context Protocol) tool for database operations."""
|
||||
|
||||
import json
import re
from datetime import datetime
from typing import List, Dict, Any, Optional

import psycopg2
from loguru import logger

from th_agenter.services.agent.base import BaseTool, ToolParameter, ToolParameterType, ToolResult
|
||||
|
||||
class PostgreSQLMCPTool(BaseTool):
    """PostgreSQL MCP tool for database operations and intelligent querying.

    Operations are dispatched via the ``operation`` parameter:
    connect / list_tables / describe_table / execute_query /
    test_connection / disconnect. Live connections are cached per
    ``user_id`` in :attr:`connections`.
    """

    def __init__(self):
        super().__init__()
        # user_id -> {"connection": ..., "config": ..., "connected_at": ...}
        self.connections = {}

    def get_name(self) -> str:
        return "postgresql_mcp"

    def get_description(self) -> str:
        return "PostgreSQL MCP服务工具,提供数据库连接、表结构查询、SQL执行等功能,支持智能数据问答。"

    def get_parameters(self) -> List[ToolParameter]:
        """Declare the tool's parameters for the agent framework."""
        return [
            ToolParameter(
                name="operation",
                type=ToolParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["connect", "list_tables", "describe_table", "execute_query", "test_connection", "disconnect"]
            ),
            ToolParameter(
                name="connection_config",
                type=ToolParameterType.OBJECT,
                description="数据库连接配置 {host, port, database, username, password}",
                required=False
            ),
            ToolParameter(
                name="user_id",
                type=ToolParameterType.STRING,
                description="用户ID,用于管理连接",
                required=False
            ),
            ToolParameter(
                name="table_name",
                type=ToolParameterType.STRING,
                description="表名(用于describe_table操作)",
                required=False
            ),
            ToolParameter(
                name="sql_query",
                type=ToolParameterType.STRING,
                description="SQL查询语句(用于execute_query操作)",
                required=False
            ),
            ToolParameter(
                name="limit",
                type=ToolParameterType.INTEGER,
                description="查询结果限制数量,默认100",
                required=False,
                default=100
            )
        ]

    def _create_connection(self, config: Dict[str, Any]) -> psycopg2.extensions.connection:
        """Create a PostgreSQL connection from a config dict (port defaults to 5432).

        Raises:
            Exception: wrapping the underlying driver error on failure.
        """
        try:
            connection = psycopg2.connect(
                host=config['host'],
                port=int(config.get('port', 5432)),
                user=config['username'],
                password=config['password'],
                database=config['database'],
                connect_timeout=10
            )
            return connection
        except Exception as e:
            raise Exception(f"PostgreSQL连接失败: {str(e)}")

    def _test_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Open a throwaway connection, report server version and pgvector availability."""
        try:
            conn = self._create_connection(config)
            cursor = conn.cursor()

            cursor.execute("SELECT version();")
            version = cursor.fetchone()[0]

            # Check whether the pgvector extension is installed.
            cursor.execute("SELECT * FROM pg_extension WHERE extname = 'vector';")
            has_vector = bool(cursor.fetchall())

            cursor.close()
            conn.close()

            return {
                "success": True,
                "version": version,
                "has_pgvector": has_vector,
                "message": "连接测试成功"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "message": "连接测试失败"
            }

    def _get_tables(self, connection) -> List[Dict[str, Any]]:
        """List all tables of the public schema."""
        cursor = connection.cursor()
        try:
            cursor.execute("""
                SELECT
                    table_name,
                    table_type,
                    table_schema
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)

            tables = []
            for row in cursor.fetchall():
                tables.append({
                    "table_name": row[0],
                    "table_type": row[1],
                    "table_schema": row[2]
                })
            return tables
        finally:
            cursor.close()

    def _describe_table(self, connection, table_name: str) -> Dict[str, Any]:
        """Return columns, primary keys and row count for *table_name*.

        Raises:
            ValueError: if *table_name* is not a plain SQL identifier.
        """
        # SECURITY FIX: table_name is interpolated into the COUNT(*) statement
        # below, which allowed SQL injection. Restrict it to a bare identifier.
        if not re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", table_name):
            raise ValueError(f"非法的表名: {table_name}")

        cursor = connection.cursor()
        try:
            # Column definitions (parameterized).
            cursor.execute("""
                SELECT
                    column_name,
                    data_type,
                    is_nullable,
                    column_default,
                    character_maximum_length,
                    numeric_precision,
                    numeric_scale
                FROM information_schema.columns
                WHERE table_name = %s AND table_schema = 'public'
                ORDER BY ordinal_position;
            """, (table_name,))

            columns = []
            for row in cursor.fetchall():
                columns.append({
                    "column_name": row[0],
                    "data_type": row[1],
                    "is_nullable": row[2],
                    "column_default": row[3],
                    "character_maximum_length": row[4],
                    "numeric_precision": row[5],
                    "numeric_scale": row[6]
                })

            # Primary key columns.
            cursor.execute("""
                SELECT column_name
                FROM information_schema.key_column_usage
                WHERE table_name = %s AND table_schema = 'public'
                AND constraint_name IN (
                    SELECT constraint_name
                    FROM information_schema.table_constraints
                    WHERE table_name = %s AND constraint_type = 'PRIMARY KEY'
                );
            """, (table_name, table_name))
            primary_keys = [row[0] for row in cursor.fetchall()]

            # Row count; table_name was validated above, so interpolation is safe.
            cursor.execute(f"SELECT COUNT(*) FROM {table_name};")
            row_count = cursor.fetchone()[0]

            return {
                "table_name": table_name,
                "columns": columns,
                "primary_keys": primary_keys,
                "row_count": row_count
            }
        finally:
            cursor.close()

    def _execute_query(self, connection, sql_query: str, limit: int = 100) -> Dict[str, Any]:
        """Execute a SQL statement; SELECTs return rows, writes are committed."""
        cursor = connection.cursor()
        try:
            # Append a LIMIT if the query has none. int() coerces the limit
            # so a non-integer value cannot inject SQL.
            if limit and "LIMIT" not in sql_query.upper():
                sql_query = f"{sql_query.rstrip(';')} LIMIT {int(limit)};"

            cursor.execute(sql_query)

            # Column names; cursor.description is None for non-SELECT statements.
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            if cursor.description:  # SELECT-style query with a result set
                rows = cursor.fetchall()
                data = []
                for row in rows:
                    row_dict = {}
                    for i, value in enumerate(row):
                        if i < len(columns):
                            # datetime values are not JSON-serializable; emit ISO strings.
                            if isinstance(value, datetime):
                                row_dict[columns[i]] = value.isoformat()
                            else:
                                row_dict[columns[i]] = value
                    data.append(row_dict)

                return {
                    "success": True,
                    "data": data,
                    "columns": columns,
                    "row_count": len(data),
                    "query": sql_query
                }
            else:  # INSERT/UPDATE/DELETE statement
                # BUGFIX: psycopg2 runs inside a transaction by default, so
                # without an explicit commit these changes were discarded when
                # the connection closed, despite reporting success.
                connection.commit()
                affected_rows = cursor.rowcount
                return {
                    "success": True,
                    "affected_rows": affected_rows,
                    "query": sql_query,
                    "message": f"查询执行成功,影响 {affected_rows} 行"
                }
        finally:
            cursor.close()

    async def execute(self, operation: str, connection_config: Optional[Dict[str, Any]] = None,
                      user_id: Optional[str] = None, table_name: Optional[str] = None,
                      sql_query: Optional[str] = None, limit: int = 100) -> ToolResult:
        """Execute a PostgreSQL MCP operation.

        Dispatches on *operation*; see the class docstring for supported
        values. All failures are reported via ToolResult.error rather
        than raised.
        """
        try:
            logger.info(f"执行PostgreSQL MCP操作: {operation}")

            if operation == "test_connection":
                if not connection_config:
                    return ToolResult(
                        success=False,
                        error="缺少连接配置参数"
                    )
                result = self._test_connection(connection_config)
                return ToolResult(
                    success=result["success"],
                    result=result,
                    error=result.get("error")
                )

            elif operation == "connect":
                if not connection_config or not user_id:
                    return ToolResult(
                        success=False,
                        error="缺少连接配置或用户ID参数"
                    )

                try:
                    connection = self._create_connection(connection_config)
                    # Cache the connection for subsequent operations.
                    self.connections[user_id] = {
                        "connection": connection,
                        "config": connection_config,
                        "connected_at": datetime.now().isoformat()
                    }

                    tables = self._get_tables(connection)
                    return ToolResult(
                        success=True,
                        result={
                            "message": "数据库连接成功",
                            "database": connection_config["database"],
                            "tables": tables,
                            "table_count": len(tables)
                        }
                    )
                except Exception as e:
                    return ToolResult(
                        success=False,
                        error=f"连接失败: {str(e)}"
                    )

            elif operation == "list_tables":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )

                connection = self.connections[user_id]["connection"]
                tables = self._get_tables(connection)
                return ToolResult(
                    success=True,
                    result={
                        "tables": tables,
                        "table_count": len(tables)
                    }
                )

            elif operation == "describe_table":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )
                if not table_name:
                    return ToolResult(
                        success=False,
                        error="缺少table_name参数"
                    )

                connection = self.connections[user_id]["connection"]
                table_info = self._describe_table(connection, table_name)
                return ToolResult(
                    success=True,
                    result=table_info
                )

            elif operation == "execute_query":
                if not user_id or user_id not in self.connections:
                    return ToolResult(
                        success=False,
                        error="用户未连接数据库,请先执行connect操作"
                    )
                if not sql_query:
                    return ToolResult(
                        success=False,
                        error="缺少sql_query参数"
                    )

                connection = self.connections[user_id]["connection"]
                query_result = self._execute_query(connection, sql_query, limit)
                return ToolResult(
                    success=True,
                    result=query_result
                )

            elif operation == "disconnect":
                if user_id and user_id in self.connections:
                    try:
                        self.connections[user_id]["connection"].close()
                        del self.connections[user_id]
                        return ToolResult(
                            success=True,
                            result={"message": "数据库连接已断开"}
                        )
                    except Exception as e:
                        return ToolResult(
                            success=False,
                            error=f"断开连接失败: {str(e)}"
                        )
                else:
                    return ToolResult(
                        success=True,
                        result={"message": "用户未连接数据库"}
                    )

            else:
                return ToolResult(
                    success=False,
                    error=f"不支持的操作类型: {operation}"
                )

        except Exception as e:
            logger.error(f"PostgreSQL MCP工具执行失败: {str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                error=f"工具执行失败: {str(e)}"
            )
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""MySQL MCP工具全局管理器"""
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from th_agenter.services.mcp.mysql_mcp import MySQLMCPTool
|
||||
|
||||
class MySQLToolManager:
    """Process-wide singleton that owns the shared MySQL MCP tool."""

    _instance: Optional["MySQLToolManager"] = None
    _mysql_tool: Optional["MySQLMCPTool"] = None

    def __new__(cls) -> "MySQLToolManager":
        # Every instantiation returns the one cached instance.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def mysql_tool(self) -> "MySQLMCPTool":
        """Lazily create and return the shared MySQLMCPTool instance."""
        if self._mysql_tool is None:
            self._mysql_tool = MySQLMCPTool()
            logger.info("创建全局MySQL工具实例")
        return self._mysql_tool

    def get_tool(self) -> "MySQLMCPTool":
        """Alias accessor for :attr:`mysql_tool`."""
        return self.mysql_tool
|
||||
|
||||
|
||||
# Module-level singleton shared by every importer of this module.
mysql_tool_manager = MySQLToolManager()


def get_mysql_tool() -> "MySQLMCPTool":
    """Return the process-wide shared MySQL tool instance."""
    return mysql_tool_manager.mysql_tool
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
"""PostgreSQL MCP工具全局管理器"""
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from th_agenter.services.mcp.postgresql_mcp import PostgreSQLMCPTool
|
||||
|
||||
class PostgreSQLToolManager:
    """Process-wide singleton that owns the shared PostgreSQL MCP tool."""

    _instance: Optional["PostgreSQLToolManager"] = None
    _postgresql_tool: Optional["PostgreSQLMCPTool"] = None

    def __new__(cls) -> "PostgreSQLToolManager":
        # Every instantiation returns the one cached instance.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def postgresql_tool(self) -> "PostgreSQLMCPTool":
        """Lazily create and return the shared PostgreSQLMCPTool instance."""
        if self._postgresql_tool is None:
            self._postgresql_tool = PostgreSQLMCPTool()
            logger.info("创建全局PostgreSQL工具实例")
        return self._postgresql_tool

    def get_tool(self) -> "PostgreSQLMCPTool":
        """Alias accessor for :attr:`postgresql_tool`."""
        return self.postgresql_tool
|
||||
|
||||
|
||||
# Module-level singleton shared by every importer of this module.
postgresql_tool_manager = PostgreSQLToolManager()


def get_postgresql_tool() -> "PostgreSQLMCPTool":
    """Return the process-wide shared PostgreSQL tool instance."""
    return postgresql_tool_manager.postgresql_tool
|
||||
|
|
@ -0,0 +1,879 @@
|
|||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from langchain_openai import ChatOpenAI
|
||||
from th_agenter.core.context import UserContext
|
||||
from .smart_query import DatabaseQueryService
|
||||
from .postgresql_tool_manager import get_postgresql_tool
|
||||
from .mysql_tool_manager import get_mysql_tool
|
||||
from .table_metadata_service import TableMetadataService
|
||||
from ..core.config import get_settings
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SmartWorkflowError(Exception):
    """Base exception for errors raised by the smart database workflow."""
    pass
|
||||
|
||||
class DatabaseConnectionError(SmartWorkflowError):
    """Raised when the target database cannot be connected to."""
    pass
|
||||
|
||||
class TableSchemaError(SmartWorkflowError):
    """Raised when table structure/metadata retrieval fails."""
    pass
|
||||
|
||||
class SQLGenerationError(SmartWorkflowError):
    """Raised when SQL generation fails."""
    pass
|
||||
|
||||
class QueryExecutionError(SmartWorkflowError):
    """Raised when query execution fails."""
    pass
|
||||
|
||||
|
||||
class SmartDatabaseWorkflowManager:
    """
    Smart database workflow manager.

    Coordinates the full pipeline: database connection, table-metadata
    retrieval, SQL generation, query execution and AI summarisation.
    """

    def __init__(self, db=None):
        # Thread pool used to run blocking DB work off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.database_service = DatabaseQueryService()
        # Shared singleton tools, one per supported database engine.
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
        # System-DB session (presumably SQLAlchemy — TODO confirm); optional.
        self.db = db
        # Metadata service only exists when a DB session was supplied.
        self.table_metadata_service = TableMetadataService(db) if db else None

        # Imported lazily here, presumably to avoid a circular import at
        # module load time — TODO confirm.
        from ..core.llm import create_llm
        self.llm = create_llm()
|
||||
|
||||
def _get_database_tool(self, db_type: str):
|
||||
"""根据数据库类型获取对应的数据库工具"""
|
||||
if db_type.lower() == 'postgresql':
|
||||
return self.postgresql_tool
|
||||
elif db_type.lower() == 'mysql':
|
||||
return self.mysql_tool
|
||||
else:
|
||||
raise ValueError(f"不支持的数据库类型: {db_type}")
|
||||
|
||||
async def _run_in_executor(self, func, *args):
|
||||
"""在线程池中运行阻塞函数"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, func, *args)
|
||||
|
||||
def _convert_query_result_to_table_data(self, query_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库查询结果转换为前端表格数据格式
|
||||
参考Excel处理方式,以表格形式返回结果
|
||||
"""
|
||||
try:
|
||||
data = query_result.get('data', [])
|
||||
columns = query_result.get('columns', [])
|
||||
row_count = query_result.get('row_count', 0)
|
||||
|
||||
if not data or not columns:
|
||||
return {
|
||||
'result_type': 'table',
|
||||
'columns': [],
|
||||
'data': [],
|
||||
'total': 0,
|
||||
'message': '查询未返回数据'
|
||||
}
|
||||
|
||||
# 构建列定义
|
||||
table_columns = []
|
||||
for i, col_name in enumerate(columns):
|
||||
table_columns.append({
|
||||
'prop': f'col_{i}',
|
||||
'label': str(col_name),
|
||||
'width': 'auto'
|
||||
})
|
||||
|
||||
# 转换数据行
|
||||
table_data = []
|
||||
for row_index, row in enumerate(data):
|
||||
row_data = {'_index': str(row_index)}
|
||||
# 处理字典格式的行数据
|
||||
if isinstance(row, dict):
|
||||
for i, col_name in enumerate(columns):
|
||||
col_prop = f'col_{i}'
|
||||
value = row.get(col_name)
|
||||
# 处理None值和特殊值
|
||||
if value is None:
|
||||
row_data[col_prop] = ''
|
||||
elif isinstance(value, (int, float, str, bool)):
|
||||
row_data[col_prop] = str(value)
|
||||
else:
|
||||
row_data[col_prop] = str(value)
|
||||
else:
|
||||
# 处理列表格式的行数据(兼容性处理)
|
||||
for i, value in enumerate(row):
|
||||
col_prop = f'col_{i}'
|
||||
# 处理None值和特殊值
|
||||
if value is None:
|
||||
row_data[col_prop] = ''
|
||||
elif isinstance(value, (int, float, str, bool)):
|
||||
row_data[col_prop] = str(value)
|
||||
else:
|
||||
row_data[col_prop] = str(value)
|
||||
|
||||
table_data.append(row_data)
|
||||
|
||||
return {
|
||||
'result_type': 'table_data',
|
||||
'columns': table_columns,
|
||||
'data': table_data,
|
||||
'total': row_count,
|
||||
'message': f'查询成功,共返回 {row_count} 条记录'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换查询结果异常: {str(e)}")
|
||||
return {
|
||||
'result_type': 'error',
|
||||
'columns': [],
|
||||
'data': [],
|
||||
'total': 0,
|
||||
'message': f'结果转换失败: {str(e)}'
|
||||
}
|
||||
|
||||
async def process_database_query_stream(
    self,
    user_query: str,
    user_id: int,
    database_config_id: int
):
    """
    Streamed variant of the smart database Q&A workflow, driven by saved
    table metadata. Yields each workflow step to the caller in real time.

    Pipeline:
    1. Resolve the database config by ``database_config_id`` and connect.
    2. Read saved table metadata from the system DB (QA-enabled tables only).
    3. Select relevant tables and generate SQL from the metadata.
    4. Execute the SQL query.
    5. Post-process the rows into table form.
    6. Generate an AI summary.
    7. Emit the final result.

    Args:
        user_query: the user's natural-language question.
        user_id: ID of the requesting user.
        database_config_id: ID of the stored database configuration.

    Yields:
        Dicts of type 'workflow_step', 'error', or 'final_result'.
        On any step failure an 'error' dict is yielded and the generator
        returns early.
    """
    # Accumulated history of completed steps, echoed in error/final payloads.
    workflow_steps = []

    try:
        logger.info(f"开始执行流式数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")

        # Step 1: resolve the config and verify the database connection.
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'database_connection',
                'status': 'running',
                'message': '正在建立数据库连接...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            # Resolve config and test the connection.
            connection_result = await self._connect_database(user_id, database_config_id)
            if not connection_result['success']:
                raise DatabaseConnectionError(connection_result['message'])

            step_data.update({
                'status': 'completed',
                'message': '数据库连接成功',
                'details': {'database': connection_result.get('database_name', 'Unknown')}
            })
            yield step_data

            workflow_steps.append({
                'step': 'database_connection',
                'status': 'completed',
                'message': '数据库连接成功'
            })

        except Exception as e:
            error_msg = f'数据库连接失败: {str(e)}'
            step_data = {
                'type': 'workflow_step',
                'step': 'database_connection',
                'status': 'failed',
                'message': error_msg,
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            yield {
                'type': 'error',
                'message': error_msg,
                'workflow_steps': workflow_steps
            }
            return

        # Step 2: read saved table metadata (QA-enabled tables only).
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'table_metadata',
                'status': 'running',
                'message': '正在从系统数据库读取表元数据...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            # Saved metadata comes from the system DB, not the target DB.
            tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)

            step_data.update({
                'status': 'completed',
                'message': f'成功读取 {len(tables_info)} 个启用问答的表元数据',
                'details': {'table_count': len(tables_info), 'tables': list(tables_info.keys())}
            })
            yield step_data

            workflow_steps.append({
                'step': 'table_metadata',
                'status': 'completed',
                'message': f'成功读取表元数据'
            })

        except Exception as e:
            error_msg = f'获取表元数据失败: {str(e)}'
            step_data = {
                'type': 'workflow_step',
                'step': 'table_metadata',
                'status': 'failed',
                'message': error_msg,
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            yield {
                'type': 'error',
                'message': error_msg,
                'workflow_steps': workflow_steps
            }
            return

        # Step 3: pick relevant tables, then generate SQL from the metadata.
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'sql_generation',
                'status': 'running',
                'message': '正在根据表元数据生成SQL查询...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            # LLM-assisted table selection.
            target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
            # NOTE(review): step_data is rebound here, so the later
            # .update() below mutates this 'table_selected' step — the
            # 'SQL查询生成成功' message is emitted under step
            # 'table_selected', not 'sql_generation'. Confirm intent.
            step_data = {
                'type': 'workflow_step',
                'step': 'table_selected',
                'status': 'completed',
                'message': f'已经智能选择了相关表: {", ".join(target_tables)}',
                'timestamp': datetime.now().isoformat()
            }

            yield step_data
            # NOTE(review): this history entry is labeled 'table_metadata'
            # but describes table selection — looks like a copy-paste slip.
            workflow_steps.append({
                'step': 'table_metadata',
                'status': 'completed',
                'message': f'已经智能选择了相关表: {", ".join(target_tables)}',
            })
            sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)

            step_data.update({
                'status': 'completed',
                'message': 'SQL查询生成成功',
                'details': {
                    'target_tables': target_tables,
                    'generated_sql': sql_query[:100] + '...' if len(sql_query) > 100 else sql_query
                }
            })
            yield step_data

            workflow_steps.append({
                'step': 'sql_generation',
                'status': 'completed',
                'message': 'SQL语句生成成功'
            })

        except Exception as e:
            error_msg = f'SQL生成失败: {str(e)}'
            step_data = {
                'type': 'workflow_step',
                'step': 'sql_generation',
                'status': 'failed',
                'message': error_msg,
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            yield {
                'type': 'error',
                'message': error_msg,
                'workflow_steps': workflow_steps
            }
            return

        # Step 4: execute the generated SQL.
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'query_execution',
                'status': 'running',
                'message': '正在执行SQL查询...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            query_result = await self._execute_database_query(user_id, sql_query, database_config_id)

            step_data.update({
                'status': 'completed',
                'message': f'查询执行成功,返回 {query_result.get("row_count", 0)} 条记录',
                'details': {'row_count': query_result.get('row_count', 0)}
            })
            yield step_data

            workflow_steps.append({
                'step': 'query_execution',
                'status': 'completed',
                'message': '查询执行成功'
            })

        except Exception as e:
            error_msg = f'查询执行失败: {str(e)}'
            step_data = {
                'type': 'workflow_step',
                'step': 'query_execution',
                'status': 'failed',
                'message': error_msg,
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            yield {
                'type': 'error',
                'message': error_msg,
                'workflow_steps': workflow_steps
            }
            return

        # Step 5 (tabularisation) happens inside step 7 below.
        # Step 6: generate the AI summary; a failure here is non-fatal.
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'ai_summary',
                'status': 'running',
                'message': '正在生成查询结果总结...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            summary = await self._generate_database_summary(user_query, query_result, ', '.join(target_tables))

            step_data.update({
                'status': 'completed',
                'message': '总结生成完成',
                'details': {
                    'tables_analyzed': target_tables,
                    'summary_length': len(summary)
                }
            })
            yield step_data

            workflow_steps.append({
                'step': 'ai_summary',
                'status': 'completed',
                'message': '总结生成完成'
            })

        except Exception as e:
            # Summary failure downgrades to a warning; the query succeeded.
            logger.warning(f'生成总结失败: {str(e)}')
            summary = '查询执行完成,但生成总结时出现问题。'

            workflow_steps.append({
                'step': 'ai_summary',
                'status': 'warning',
                'message': '总结生成失败,但查询成功'
            })

        # Step 7: format the final result — like the Excel path, the data is
        # returned in table form wherever possible.
        try:
            step_data = {
                'type': 'workflow_step',
                'step': 'result_formatting',
                'status': 'running',
                'message': '正在格式化查询结果...',
                'timestamp': datetime.now().isoformat()
            }
            yield step_data

            # Shape the raw rows into the frontend table payload.
            table_data = self._convert_query_result_to_table_data(query_result)

            step_data.update({
                'status': 'completed',
                'message': '结果格式化完成'
            })
            yield step_data

            workflow_steps.append({
                'step': 'result_formatting',
                'status': 'completed',
                'message': '结果格式化完成'
            })

            # Final payload for the caller.
            final_result = {
                'type': 'final_result',
                'success': True,
                'data': {
                    **table_data,
                    'generated_sql': sql_query,
                    'summary': summary,
                    'table_name': target_tables,
                    'query_result': query_result,
                    'metadata_source': 'saved_database'  # provenance marker
                },
                'workflow_steps': workflow_steps,
                'timestamp': datetime.now().isoformat()
            }

            yield final_result
            logger.info(f"数据库查询工作流完成 - 用户ID: {user_id}")

        except Exception as e:
            error_msg = f'结果格式化失败: {str(e)}'
            yield {
                'type': 'error',
                'message': error_msg,
                'workflow_steps': workflow_steps
            }
            return

    except Exception as e:
        # Catch-all: anything not handled per-step becomes a system error.
        logger.error(f"数据库查询工作流异常: {str(e)}", exc_info=True)
        yield {
            'type': 'error',
            'message': f'系统异常: {str(e)}',
            'workflow_steps': workflow_steps
        }
|
||||
|
||||
async def _connect_database(self, user_id: int, database_config_id: int) -> Dict[str, Any]:
    """
    Resolve the stored config for the user and verify a database connection
    (an existing connection is reused by the underlying tool).

    Args:
        user_id: owner of the config.
        database_config_id: stored configuration ID.

    Returns:
        {'success': bool, 'message': str, ...} — never raises; all failures
        are reported through the 'success'/'message' fields.
    """
    try:
        # Imported lazily, presumably to avoid a circular import — TODO confirm.
        from ..services.database_config_service import DatabaseConfigService
        config_service = DatabaseConfigService(self.db)
        config = config_service.get_config_by_id(database_config_id, user_id)

        if not config:
            return {'success': False, 'message': '数据库配置不存在'}

        # Pick the tool matching the configured engine.
        try:
            db_tool = self._get_database_tool(config.db_type)
        except ValueError as e:
            return {'success': False, 'message': str(e)}

        # Connection parameters; the stored password is decrypted here.
        connection_config = {
            'host': config.host,
            'port': config.port,
            'database': config.database,
            'username': config.username,
            'password': config_service._decrypt_password(config.password)
        }

        try:
            connection = db_tool._test_connection(connection_config)
            # Fixed: truthiness test instead of the non-idiomatic `== True`.
            if connection['success']:
                return {
                    'success': True,
                    'database_name': config.database,
                    'db_type': config.db_type,
                    'message': '连接成功'
                }
            return {
                'success': False,
                'database_name': config.database,
                'db_type': config.db_type,
                'message': '连接失败'
            }
        except Exception as e:
            return {
                'success': False,
                'message': f'连接失败: {str(e)}'
            }

    except Exception as e:
        logger.error(f"数据库连接异常: {str(e)}")
        return {'success': False, 'message': f'连接异常: {str(e)}'}
|
||||
|
||||
async def _get_saved_tables_metadata(self, user_id: int, database_config_id: int) -> Dict[str, Dict[str, Any]]:
    """
    Read previously saved table metadata from the system database.

    Only tables with QA enabled are returned.

    Returns:
        Mapping of table name -> metadata dict (columns, primary keys,
        row count, comments, QA description, business context).

    Raises:
        TableSchemaError: when the service is missing, no metadata exists,
            no table has QA enabled, or the read itself fails.
    """
    try:
        if not self.table_metadata_service:
            raise TableSchemaError("表元数据服务未初始化")

        # Fetch all saved metadata rows for this user/config.
        saved_metadata = self.table_metadata_service.get_user_table_metadata(
            user_id, database_config_id
        )

        if not saved_metadata:
            raise TableSchemaError(f"未找到数据库配置ID {database_config_id} 的表元数据,请先在数据库管理页面收集表元数据")

        # Convert rows to the workflow's expected dict shape.
        tables_metadata = {}
        for meta in saved_metadata:
            # Only QA-enabled tables participate in question answering.
            if meta.is_enabled_for_qa:
                tables_metadata[meta.table_name] = {
                    'table_name': meta.table_name,
                    'columns': meta.columns_info or [],
                    'primary_keys': meta.primary_keys or [],
                    'row_count': meta.row_count or 0,
                    'table_comment': meta.table_comment or '',
                    'qa_description': meta.qa_description or '',
                    'business_context': meta.business_context or '',
                    'from_saved_metadata': True  # provenance marker
                }

        if not tables_metadata:
            raise TableSchemaError("没有启用问答的表,请在数据库管理页面启用相关表的问答功能")

        logger.info(f"从系统数据库读取表元数据成功,共 {len(tables_metadata)} 个启用问答的表")
        return tables_metadata

    except TableSchemaError:
        # Fixed: our own domain errors pass through unchanged — previously
        # the generic handler below re-wrapped them, double-prefixing the
        # message ("读取表元数据失败: 表元数据服务未初始化").
        raise
    except Exception as e:
        logger.error(f"读取保存的表元数据异常: {str(e)}")
        raise TableSchemaError(f'读取表元数据失败: {str(e)}')
|
||||
|
||||
async def _get_table_schema(self, user_id: int, table_name: str, database_config_id: int) -> Dict[str, Any]:
    """
    Fetch the live schema of one table from the target database.

    Args:
        user_id: owner of the database config.
        table_name: table to describe.
        database_config_id: stored configuration ID.

    Returns:
        The 'schema' dict produced by the database tool.

    Raises:
        TableSchemaError: missing config, unsupported DB type, or any
            failure while describing the table.
    """
    try:
        # Imported lazily, presumably to avoid a circular import — TODO confirm.
        from ..services.database_config_service import DatabaseConfigService
        config_service = DatabaseConfigService(self.db)
        config = config_service.get_config_by_id(database_config_id, user_id)

        if not config:
            raise TableSchemaError('数据库配置不存在')

        # Pick the tool matching the configured engine.
        try:
            db_tool = self._get_database_tool(config.db_type)
        except ValueError as e:
            raise TableSchemaError(str(e))

        # Ask the engine-specific tool for the table description.
        schema_result = await db_tool.describe_table(table_name)

        if schema_result.get('success'):
            return schema_result.get('schema', {})
        else:
            raise TableSchemaError(schema_result.get('error', '获取表结构失败'))

    except Exception as e:
        # NOTE(review): this also re-wraps the TableSchemaError raised
        # above, double-prefixing its message — confirm whether intended.
        logger.error(f"获取表结构异常: {str(e)}")
        raise TableSchemaError(f'获取表结构失败: {str(e)}')
|
||||
|
||||
async def _select_target_table(self, user_query: str, tables_info: Dict[str, Dict]) -> tuple[List[str], List[Dict]]:
    """
    Choose the table(s) relevant to the user query; may return several.

    Uses the LLM when more than one table is available; falls back to the
    first table on any error or when the LLM picks only unknown names.

    Args:
        user_query: the user's natural-language question.
        tables_info: table name -> metadata dict.

    Returns:
        (selected table names, their metadata dicts) — always non-empty.
    """
    try:
        if len(tables_info) == 1:
            # Single candidate: no LLM call needed.
            table_name = list(tables_info.keys())[0]
            return [table_name], [tables_info[table_name]]

        # Multiple tables: summarise each one for the LLM selector.
        tables_summary = []
        for table_name, schema in tables_info.items():
            columns = schema.get('columns', [])
            # Columns may use 'column_name' or 'name' depending on source.
            column_names = [col.get('column_name', col.get('name', '')) for col in columns]
            qa_desc = schema.get('qa_description', '')
            business_ctx = schema.get('business_context', '')
            tables_summary.append(f"表名: {table_name}\n字段: {', '.join(column_names[:10])}\n表描述: {qa_desc}\n业务上下文: {business_ctx}")

        prompt = f"""
用户查询: {user_query}

可用的表:
{chr(10).join(tables_summary)}

请根据用户查询选择相关的表,可以选择多个表。分析表之间可能的关联关系,返回所有相关的表名,用逗号分隔。
可以通过qa_description(表描述),business_context(表的业务上下文),以及column_names几个字段判断要使用哪些表。
注意:只返回表名列表,后面不要跟其他的内容。
例如直接输出: table1,table2,table3
"""

        # The model is expected to answer with a comma-separated name list.
        response = await self.llm.ainvoke(prompt)
        selected_tables = [t.strip() for t in response.content.strip().split(',')]

        # Keep only names that actually exist in the metadata.
        valid_tables = []
        valid_schemas = []
        for table in selected_tables:
            if table in tables_info:
                valid_tables.append(table)
                valid_schemas.append(tables_info[table])
            else:
                logger.warning(f"LLM选择的表 {table} 不存在")

        if valid_tables:
            return valid_tables, valid_schemas
        else:
            # Nothing valid: fall back to the first table.
            table_name = list(tables_info.keys())[0]
            logger.warning(f"没有找到有效的表,使用默认表 {table_name}")
            return [table_name], [tables_info[table_name]]

    except Exception as e:
        # Any failure also falls back to the first table.
        # NOTE(review): this raises IndexError if tables_info is empty —
        # callers currently guarantee it is not; confirm.
        logger.error(f"选择目标表异常: {str(e)}")
        table_name = list(tables_info.keys())[0]
        return [table_name], [tables_info[table_name]]
|
||||
|
||||
async def _generate_sql_query(self, user_query: str, table_names: List[str], table_schemas: List[Dict]) -> str:
|
||||
"""生成SQL语句,支持多表关联查询"""
|
||||
try:
|
||||
# 构建所有表的结构信息
|
||||
tables_info = []
|
||||
for table_name, schema in zip(table_names, table_schemas):
|
||||
columns_info = []
|
||||
for col in schema.get('columns', []):
|
||||
col_info = f"{col['column_name']} ({col['data_type']})"
|
||||
columns_info.append(col_info)
|
||||
|
||||
table_info = f"表名: {table_name}\n"
|
||||
table_info += f"表描述: {schema.get('qa_description', '')}\n"
|
||||
table_info += f"业务上下文: {schema.get('business_context', '')}\n"
|
||||
table_info += "字段信息:\n" + "\n".join(columns_info)
|
||||
tables_info.append(table_info)
|
||||
|
||||
schema_text = "\n\n".join(tables_info)
|
||||
|
||||
prompt = f"""
|
||||
基于以下表结构,将自然语言查询转换为SQL语句。如果需要关联多个表,请分析表之间的关系,使用合适的JOIN语法:
|
||||
|
||||
{schema_text}
|
||||
|
||||
用户查询: {user_query}
|
||||
|
||||
请生成对应的SQL查询语句,要求:
|
||||
1. 只返回SQL语句,不要包含其他解释
|
||||
2. 如果查询涉及多个表,需要正确处理表之间的关联关系
|
||||
3. 使用合适的JOIN类型(INNER JOIN、LEFT JOIN等)
|
||||
4. 确保SELECT的字段来源明确,必要时使用表名前缀
|
||||
"""
|
||||
|
||||
# 使用LLM生成SQL
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
sql_query = response.content.strip()
|
||||
|
||||
# 清理SQL语句
|
||||
if sql_query.startswith('```sql'):
|
||||
sql_query = sql_query[6:]
|
||||
if sql_query.endswith('```'):
|
||||
sql_query = sql_query[:-3]
|
||||
|
||||
sql_query = sql_query.strip()
|
||||
|
||||
logger.info(f"生成的SQL查询: {sql_query}")
|
||||
return sql_query
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SQL生成异常: {str(e)}")
|
||||
raise SQLGenerationError(f'SQL生成失败: {str(e)}')
|
||||
|
||||
async def _execute_database_query(self, user_id: int, sql_query: str, database_config_id: int) -> Dict[str, Any]:
    """
    Execute the generated SQL via the user's existing connection.

    Args:
        user_id: whose cached connection to use.
        sql_query: SQL text to run.
        database_config_id: stored configuration ID (for engine lookup).

    Returns:
        {'success': True, 'data': rows, 'row_count', 'columns', 'sql_query'}.

    Raises:
        QueryExecutionError: missing config, unsupported DB type, missing
            cached connection, or any execution failure.
    """
    try:
        # Imported lazily, presumably to avoid a circular import — TODO confirm.
        from ..services.database_config_service import DatabaseConfigService
        config_service = DatabaseConfigService(self.db)
        config = config_service.get_config_by_id(database_config_id, user_id)

        if not config:
            raise QueryExecutionError('数据库配置不存在')

        # Pick the tool matching the configured engine.
        try:
            db_tool = self._get_database_tool(config.db_type)
        except ValueError as e:
            raise QueryExecutionError(str(e))

        # Reuse the connection cached per user by the tool; the user must
        # have connected earlier in this workflow.
        if str(user_id) in db_tool.connections:
            query_result = db_tool._execute_query(db_tool.connections[str(user_id)]['connection'], sql_query)
        else:
            raise QueryExecutionError('请重新进行数据库连接')

        if query_result.get('success'):
            data = query_result.get('data', [])
            return {
                'success': True,
                'data': data,
                'row_count': len(data),
                'columns': query_result.get('columns', []),
                'sql_query': sql_query
            }
        else:
            raise QueryExecutionError(query_result.get('error', '查询执行失败'))

    except Exception as e:
        # NOTE(review): this also re-wraps the QueryExecutionErrors raised
        # above, double-prefixing their messages.
        logger.error(f"查询执行异常: {str(e)}")
        raise QueryExecutionError(f'查询执行失败: {str(e)}')
|
||||
|
||||
async def _generate_database_summary(self, user_query: str, query_result: Dict, tables_str: str) -> str:
    """
    Generate an AI summary of the query result; supports multi-table results.

    Args:
        user_query: the user's question.
        query_result: result dict from ``_execute_database_query``.
        tables_str: comma-joined names of the tables involved.

    Returns:
        The summary text; on LLM failure, a plain statistics string is
        returned instead (this method never raises).
    """
    try:
        data = query_result.get('data', [])
        row_count = query_result.get('row_count', 0)
        columns = query_result.get('columns', [])
        sql_query = query_result.get('sql_query', '')

        # Prompt combining the question, result stats and a 3-row sample.
        prompt = f"""
用户查询: {user_query}
涉及的表: {tables_str}
查询结果: 共 {row_count} 条记录
查询的字段: {', '.join(columns)}
执行的SQL: {sql_query}

前几条数据示例:
{str(data[:3]) if data else '无数据'}

请基于以上信息,用中文生成一个简洁的查询结果总结,包括:
1. 查询涉及的表及其关系
2. 查询的主要发现和数据特征
3. 如果有关联查询,说明关联的结果特点
4. 最后对用户的问题进行回答

总结要求:
1. 语言简洁明了
2. 重点突出查询结果
3. 如果是多表查询,需要说明表之间的关系
4. 总结不超过300字
"""

        # Generate the summary with the LLM.
        response = await self.llm.ainvoke(prompt)
        summary = response.content.strip()

        logger.info(f"生成的总结: {summary[:100]}...")
        return summary

    except Exception as e:
        logger.error(f"总结生成异常: {str(e)}")
        return f"查询完成,共返回 {query_result.get('row_count', 0)} 条记录。涉及的表: {tables_str}"
|
||||
|
||||
async def process_database_query(
    self,
    user_query: str,
    user_id: int,
    database_config_id: int,
    table_name: Optional[str] = None,
    conversation_id: Optional[int] = None,
    is_new_conversation: bool = False
) -> Dict[str, Any]:
    """
    Non-streaming smart database Q&A workflow (driven by saved metadata).

    Pipeline: resolve config → connect → read QA-enabled table metadata →
    select tables & generate SQL → execute → tabularise → summarise →
    assemble the response.

    Args:
        user_query: the user's natural-language question.
        user_id: requesting user's ID.
        database_config_id: stored database configuration ID.
        table_name: optional explicit table name.
            NOTE(review): accepted but never referenced below — confirm.
        conversation_id: conversation ID (unused in this method).
        is_new_conversation: whether this is a fresh conversation (unused).

    Returns:
        {'success': True, 'data': {...}} on success, or
        {'success': False, 'error': ..., 'error_type': ...} on failure
        (this method never raises).
    """
    try:
        logger.info(f"开始执行数据库查询工作流 - 用户ID: {user_id}, 数据库配置ID: {database_config_id}, 查询: {user_query[:50]}...")

        # Step 1: resolve the config and verify the connection.
        connection_result = await self._connect_database(user_id, database_config_id)
        if not connection_result['success']:
            raise DatabaseConnectionError(connection_result['message'])

        logger.info("数据库连接成功")

        # Step 2: read saved metadata (QA-enabled tables only).
        tables_info = await self._get_saved_tables_metadata(user_id, database_config_id)

        logger.info(f"表元数据读取完成 - 共{len(tables_info)}个启用问答的表")

        # Step 3: pick the relevant tables, then generate the SQL.
        target_tables, target_schemas = await self._select_target_table(user_query, tables_info)
        sql_query = await self._generate_sql_query(user_query, target_tables, target_schemas)

        logger.info(f"SQL生成完成 - 目标表: {', '.join(target_tables)}")

        # Step 4: run the query.
        query_result = await self._execute_database_query(user_id, sql_query, database_config_id)
        logger.info("查询执行完成")

        # Step 5: shape rows into the frontend table payload.
        table_data = self._convert_query_result_to_table_data(query_result)

        # Step 6: AI summary of the result.
        summary = await self._generate_database_summary(user_query, query_result, ', '.join(target_tables))

        # Step 7: assemble the response.
        return {
            'success': True,
            'data': {
                **table_data,
                'generated_sql': sql_query,
                'summary': summary,
                'table_names': target_tables,
                'query_result': query_result,
                'metadata_source': 'saved_database'  # provenance marker
            }
        }

    except SmartWorkflowError as e:
        # Known workflow failures keep their concrete exception class name.
        logger.error(f"数据库工作流异常: {str(e)}")
        return {
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }
    except Exception as e:
        logger.error(f"数据库工作流未知异常: {str(e)}", exc_info=True)
        return {
            'success': False,
            'error': f'系统异常: {str(e)}',
            'error_type': 'SystemError'
        }
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,717 @@
|
|||
import pandas as pd
|
||||
import pymysql
|
||||
import psycopg2
|
||||
import tempfile
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
from loguru import logger
|
||||
|
||||
# 在 SmartQueryService 类中添加方法
|
||||
|
||||
from .table_metadata_service import TableMetadataService
|
||||
|
||||
class SmartQueryService:
    """
    Base class for smart-query services.

    Owns a small thread pool so blocking work (DB drivers, pandas) can run
    without stalling the asyncio event loop.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=4)
        # Set via set_db_session(); stays None until a session is attached.
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Attach a DB session and build the table-metadata service from it."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def _run_in_executor(self, func, *args):
        """Run blocking ``func(*args)`` on the pool and await its result."""
        # Fixed: get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated for this use and may bind the
        # wrong loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)
|
||||
|
||||
class ExcelAnalysisService(SmartQueryService):
    """
    Excel data-analysis service.

    Profiles uploaded spreadsheets and answers natural-language questions
    over them via a pandas agent.
    """
    def __init__(self):
        super().__init__()
        self.user_dataframes = {}  # per-user DataFrame cache
|
||||
|
||||
def analyze_dataframe(self, df: pd.DataFrame, filename: str) -> Dict[str, Any]:
    """
    Analyze a DataFrame and return basic profiling information.

    Args:
        df: the loaded worksheet.
        filename: original upload name, echoed back in the payload.

    Returns:
        Dict with shape, per-column info (dtype, nulls, uniques, numeric
        stats), a 5-row preview, data-quality issues (missing values /
        duplicate rows) and memory usage.

    Raises:
        Exception: wrapping any analysis failure.
    """
    try:
        # Basic shape.
        rows, columns = df.shape

        # Per-column profile.
        column_info = []
        for col in df.columns:
            col_info = {
                'name': col,
                'dtype': str(df[col].dtype),
                'null_count': int(df[col].isnull().sum()),
                'unique_count': int(df[col].nunique())
            }

            # Numeric columns also get summary statistics (NaN-skipping,
            # None when the whole column is null).
            # Fixed: the original called df.fillna({col: 0}) without using
            # the result — a no-op (fillna is not in-place by default). The
            # dead call is removed; if nulls should really count as 0 in
            # the statistics, that needs a deliberate, explicit change.
            if pd.api.types.is_numeric_dtype(df[col]):
                col_info.update({
                    'mean': float(df[col].mean()) if not df[col].isnull().all() else None,
                    'std': float(df[col].std()) if not df[col].isnull().all() else None,
                    'min': float(df[col].min()) if not df[col].isnull().all() else None,
                    'max': float(df[col].max()) if not df[col].isnull().all() else None
                })

            column_info.append(col_info)

        # First five rows, with NaN rendered as ''.
        preview_data = df.head().fillna('').to_dict('records')

        # Data-quality checks.
        quality_issues = []

        # Columns with missing values.
        missing_cols = df.columns[df.isnull().any()].tolist()
        if missing_cols:
            quality_issues.append({
                'type': 'missing_values',
                'description': f'以下列存在缺失值: {", ".join(map(str, missing_cols))}',
                'columns': missing_cols
            })

        # Fully duplicated rows.
        duplicate_count = df.duplicated().sum()
        if duplicate_count > 0:
            quality_issues.append({
                'type': 'duplicate_rows',
                'description': f'发现 {duplicate_count} 行重复数据',
                'count': int(duplicate_count)
            })

        return {
            'filename': filename,
            'rows': rows,
            'columns': columns,
            'column_names': [str(col) for col in df.columns.tolist()],
            'column_info': column_info,
            'preview': preview_data,
            'quality_issues': quality_issues,
            'memory_usage': f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB"
        }

    except Exception as e:
        # Fixed: log through the app logger instead of bare print().
        logger.error(f"DataFrame分析异常: {str(e)}")
        raise Exception(f"DataFrame分析失败: {str(e)}")
|
||||
|
||||
def _create_pandas_agent(self, df: pd.DataFrame):
    """
    Create a LangChain pandas agent for the given DataFrame.

    Currently disabled: the actual agent construction is commented out and
    this method always returns None (see NOTE below).
    """
    try:
        # ZhipuAI chat model intended as the agent LLM.
        llm = ChatZhipuAI(
            model="glm-4",
            api_key=os.getenv("ZHIPUAI_API_KEY"),
            temperature=0.1
        )
        # NOTE(review): agent creation is deliberately short-circuited —
        # `llm` above is created but unused, an error is logged on every
        # call, and None is returned. Downstream callers must tolerate a
        # None agent; confirm before re-enabling the block below.
        agent = None
        logger.error('创建pandas代理失败 - 暂屏蔽处理')

        # # 创建pandas代理
        # agent = create_pandas_dataframe_agent(
        # llm=llm,
        # df=df,
        # verbose=True,
        # return_intermediate_steps=True,
        # handle_parsing_errors=True,
        # max_iterations=3,
        # early_stopping_method="force",
        # allow_dangerous_code=True # 允许执行代码以支持数据分析
        # )

        return agent

    except Exception as e:
        raise Exception(f"创建pandas代理失败: {str(e)}")
|
||||
|
||||
def _execute_pandas_query(self, agent, query: str) -> Dict[str, Any]:
    """
    Run a natural-language query through the pandas agent and normalise
    the result into a table payload.

    Args:
        agent: a LangChain pandas agent (with an ``invoke`` method).
        query: the user's natural-language question.

    Returns:
        {'data', 'columns', 'total', 'result_type'} — result_type is
        'dataframe' for tabular results, 'scalar' otherwise.

    Raises:
        Exception: wrapping any agent failure.
    """
    try:
        # invoke() handles agents that expose multiple output keys.
        agent_result = agent.invoke({"input": query})
        # Prefer the 'output' key; fall back to the raw result object.
        result = agent_result.get('output', agent_result)

        if isinstance(result, pd.DataFrame):
            # Tabular result: NaN rendered as ''.
            data = result.fillna('').to_dict('records')
            columns = result.columns.tolist()
            total = len(result)

            return {
                'data': data,
                'columns': columns,
                'total': total,
                'result_type': 'dataframe'
            }
        else:
            # Scalar / textual result wrapped as a one-row table.
            return {
                'data': [{'result': str(result)}],
                'columns': ['result'],
                'total': 1,
                'result_type': 'scalar'
            }

    except Exception as e:
        raise Exception(f"pandas查询执行失败: {str(e)}")
|
||||
|
||||
async def execute_natural_language_query(
    self,
    query: str,
    user_id: int,
    page: int = 1,
    page_size: int = 20
) -> Dict[str, Any]:
    """
    Answer a natural-language question over the user's uploaded Excel data.

    Loads the user's most recent pickled DataFrame from the temp dir, runs
    the query through the pandas agent, paginates the rows and adds an AI
    summary.

    Args:
        query: natural-language question.
        user_id: ID used to locate the user's uploaded file.
        page: 1-based page number.
        page_size: rows per page.

    Returns:
        {'success': True, 'data': {...}} or {'success': False,
        'message': ...} — this method never raises.
    """
    try:
        # Locate the user's pickled uploads in the temp directory.
        temp_dir = tempfile.gettempdir()
        user_files = [f for f in os.listdir(temp_dir)
                      if f.startswith(f"excel_{user_id}_") and f.endswith('.pkl')]

        if not user_files:
            return {
                'success': False,
                'message': '未找到上传的Excel文件,请先上传文件'
            }

        # "Latest" is by lexicographic filename order — correct only if
        # the name embeds a sortable timestamp. NOTE(review): confirm;
        # sorting by mtime would be more robust.
        latest_file = sorted(user_files)[-1]
        file_path = os.path.join(temp_dir, latest_file)

        # NOTE(review): unpickling is only safe because these files are
        # produced by our own upload path — never point this at untrusted
        # input.
        df = pd.read_pickle(file_path)

        # Build the pandas agent (currently disabled: returns None).
        agent = self._create_pandas_agent(df)

        # Run the blocking agent call on the thread pool.
        query_result = await self._run_in_executor(
            self._execute_pandas_query, agent, query
        )

        # Paginate the result rows.
        total = query_result['total']
        start_idx = (page - 1) * page_size
        end_idx = start_idx + page_size

        paginated_data = query_result['data'][start_idx:end_idx]

        # AI-generated summary of the result.
        summary = await self._generate_summary(query, query_result, df)

        return {
            'success': True,
            'data': {
                'data': paginated_data,
                'columns': query_result['columns'],
                'total': total,
                'page': page,
                'page_size': page_size,
                'generated_code': f"# 基于自然语言查询: {query}\n# 使用LangChain Pandas代理执行",
                'summary': summary,
                'result_type': query_result['result_type']
            }
        }

    except Exception as e:
        return {
            'success': False,
            'message': f"查询执行失败: {str(e)}"
        }
|
||||
|
||||
    async def _generate_summary(self, query: str, result: Dict[str, Any], df: pd.DataFrame) -> str:
        """Generate an AI-written Chinese summary of a pandas query result.

        Builds a prompt from dataset statistics and the shape of the query
        result, sends it to the GLM-4 model, and returns the model's text.
        Summarisation is best-effort: on any failure a fallback message is
        returned instead of raising, so the query itself still succeeds.
        """
        try:
            # Low temperature for stable, factual summaries.
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )

            # The prompt combines the user's question, the dataset shape and
            # the result shape; the model fills in the actual analysis.
            prompt = f"""
用户查询: {query}

数据集信息:
- 总行数: {len(df)}
- 总列数: {len(df.columns)}
- 列名: {', '.join(str(col) for col in df.columns.tolist())}

查询结果:
- 结果类型: {result['result_type']}
- 结果行数: {result['total']}
- 结果列数: {len(result['columns'])}

请基于以上信息,用中文生成一个简洁的分析总结,包括:
1. 查询的主要目的
2. 关键发现
3. 数据洞察
4. 建议的后续分析方向

总结应该专业、准确、易懂,控制在200字以内。
"""

            # llm.invoke is blocking; run it in the executor to keep the
            # event loop responsive.
            response = await self._run_in_executor(
                lambda: llm.invoke([HumanMessage(content=prompt)])
            )

            return response.content

        except Exception as e:
            # Report the error inside the summary text rather than failing
            # the whole query.
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
|
||||
|
||||
class DatabaseQueryService(SmartQueryService):
    """Natural-language query service backed by a relational database.

    Supports MySQL and PostgreSQL. Keeps at most one live DB-API connection
    per user in ``self.user_connections``.
    """

    def __init__(self):
        super().__init__()
        # Maps user_id -> {'config', 'connection', 'connected_at'}.
        # NOTE(review): connections are never closed on service shutdown;
        # callers should disconnect users explicitly or sockets will leak.
        self.user_connections = {}

    @staticmethod
    def _validate_identifier(name: str) -> str:
        """Allow only plain SQL identifiers.

        Table names are interpolated into SQL strings below; this guard
        prevents SQL injection through a user-supplied table name.
        """
        if not name or not all(c.isalnum() or c == '_' for c in name):
            raise ValueError(f"非法的表名: {name}")
        return name

    def _create_connection(self, config: Dict[str, str]):
        """Create a DB-API connection from a config dict.

        Raises:
            Exception: with a user-readable message on any failure,
                including unsupported database types.
        """
        db_type = config['type'].lower()

        try:
            if db_type == 'mysql':
                connection = pymysql.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database'],
                    charset='utf8mb4'
                )
            elif db_type == 'postgresql':
                connection = psycopg2.connect(
                    host=config['host'],
                    port=int(config['port']),
                    user=config['username'],
                    password=config['password'],
                    database=config['database']
                )
            else:
                # Fix: previously an unsupported type fell through and
                # crashed with UnboundLocalError on `connection`.
                raise ValueError(f"不支持的数据库类型: {config['type']}")

            return connection

        except Exception as e:
            raise Exception(f"数据库连接失败: {str(e)}")

    async def test_connection(self, config: Dict[str, str]) -> bool:
        """Return True when a connection can be opened with this config."""
        try:
            connection = await self._run_in_executor(self._create_connection, config)
            connection.close()
            return True
        except Exception:
            return False

    async def connect_database(self, config: Dict[str, str], user_id: int) -> Dict[str, Any]:
        """Connect for this user and return the list of tables."""
        try:
            connection = await self._run_in_executor(self._create_connection, config)

            # List the tables visible in the target database.
            tables = await self._run_in_executor(self._get_tables, connection, config['type'])

            # Fix: close any stale connection for this user before
            # overwriting it, otherwise reconnects leak sockets.
            stale = self.user_connections.pop(user_id, None)
            if stale:
                try:
                    stale['connection'].close()
                except Exception:
                    pass

            self.user_connections[user_id] = {
                'config': config,
                'connection': connection,
                'connected_at': datetime.now()
            }

            return {
                'success': True,
                'data': {
                    'tables': tables,
                    'database_type': config['type'],
                    'database_name': config['database']
                }
            }

        except Exception as e:
            return {
                'success': False,
                'message': f"数据库连接失败: {str(e)}"
            }

    def _get_tables(self, connection, db_type: str) -> List[str]:
        """Return the table names for the connected database."""
        cursor = connection.cursor()

        try:
            if db_type.lower() == 'mysql':
                cursor.execute("SHOW TABLES")
                tables = [row[0] for row in cursor.fetchall()]
            elif db_type.lower() == 'postgresql':
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                """)
                tables = [row[0] for row in cursor.fetchall()]
            elif db_type.lower() == 'sqlserver':
                # NOTE(review): _create_connection does not support
                # 'sqlserver', so this branch is currently unreachable —
                # kept for forward compatibility; confirm intent.
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_type = 'BASE TABLE'
                """)
                tables = [row[0] for row in cursor.fetchall()]
            else:
                tables = []

            return tables

        finally:
            cursor.close()

    async def get_table_schema(self, table_name: str, user_id: int) -> Dict[str, Any]:
        """Return the column schema for one table of the user's database."""
        try:
            if user_id not in self.user_connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }

            connection = self.user_connections[user_id]['connection']
            db_type = self.user_connections[user_id]['config']['type']

            schema = await self._run_in_executor(
                self._get_table_schema, connection, table_name, db_type
            )

            return {
                'success': True,
                'data': {
                    'schema': schema,
                    'table_name': table_name
                }
            }

        except Exception as e:
            return {
                'success': False,
                'message': f"获取表结构失败: {str(e)}"
            }

    def _get_table_schema(self, connection, table_name: str, db_type: str) -> List[Dict[str, Any]]:
        """Fetch column metadata for *table_name* using dialect-specific SQL."""
        cursor = connection.cursor()

        try:
            if db_type.lower() == 'mysql':
                # Security fix: DESCRIBE cannot be parameterized, so the
                # identifier must be validated before interpolation.
                cursor.execute(f"DESCRIBE {self._validate_identifier(table_name)}")
                columns = cursor.fetchall()
                schema = [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': 'YES' if col[2] == 'YES' else 'NO',
                    'column_key': col[3],
                    'column_default': col[4]
                } for col in columns]
            elif db_type.lower() == 'postgresql':
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, (table_name,))
                columns = cursor.fetchall()
                schema = [{
                    'column_name': col[0],
                    'data_type': col[1],
                    'is_nullable': col[2],
                    'column_default': col[3]
                } for col in columns]
            else:
                schema = []

            return schema

        finally:
            cursor.close()

    async def execute_natural_language_query(
        self,
        query: str,
        table_name: str,
        user_id: int,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """Translate a natural-language question to SQL, run it, summarise it.

        Returns an envelope dict; never raises.
        """
        try:
            if user_id not in self.user_connections:
                return {
                    'success': False,
                    'message': '数据库连接已断开,请重新连接'
                }

            connection = self.user_connections[user_id]['connection']

            # Security fix: the table name comes from the caller and is
            # interpolated into SQL — restrict it to a plain identifier.
            self._validate_identifier(table_name)

            # TODO: integrate the MCP service for real NL->SQL translation;
            # _convert_to_sql is a keyword-based placeholder.
            sql_query = await self._convert_to_sql(query, table_name, connection)

            result = await self._run_in_executor(
                self._execute_sql_query, connection, sql_query, page, page_size
            )

            # Best-effort AI summary of the query result.
            summary = await self._generate_db_summary(query, result, table_name)

            result['generated_code'] = sql_query
            result['summary'] = summary

            return {
                'success': True,
                'data': result
            }

        except Exception as e:
            return {
                'success': False,
                'message': f"数据库查询执行失败: {str(e)}"
            }

    async def _convert_to_sql(self, query: str, table_name: str, connection) -> str:
        """Map a natural-language query to SQL by keyword heuristics.

        TODO: replace with the MCP service; this only recognises a handful
        of common patterns and otherwise falls back to a limited SELECT *.
        """
        # Defence in depth: validate even when called directly.
        self._validate_identifier(table_name)

        query_lower = query.lower()

        if '所有' in query or '全部' in query or 'all' in query_lower:
            return f"SELECT * FROM {table_name} LIMIT 100"
        elif '统计' in query or '总数' in query or 'count' in query_lower:
            return f"SELECT COUNT(*) as total_count FROM {table_name}"
        elif '最近' in query or 'recent' in query_lower:
            # Assumes an auto-increment `id` column approximates recency —
            # TODO confirm against real schemas.
            return f"SELECT * FROM {table_name} ORDER BY id DESC LIMIT 10"
        elif '分组' in query or 'group' in query_lower:
            # Placeholder grouping; must be adapted to the actual schema.
            return f"SELECT COUNT(*) as count FROM {table_name} GROUP BY id LIMIT 10"
        else:
            return f"SELECT * FROM {table_name} LIMIT 20"

    def _execute_sql_query(self, connection, sql_query: str, page: int, page_size: int) -> Dict[str, Any]:
        """Execute *sql_query* and return one page of rows as dicts.

        NOTE(review): fetches the entire result set to compute `total`
        before slicing — fine for LIMITed queries, expensive otherwise.
        """
        cursor = connection.cursor()

        try:
            cursor.execute(sql_query)

            # Column names from the cursor description (empty for DDL etc.).
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            all_results = cursor.fetchall()
            total = len(all_results)

            start_idx = (page - 1) * page_size
            end_idx = start_idx + page_size
            paginated_results = all_results[start_idx:end_idx]

            # Zip each row with the column names into a dict.
            data = []
            for row in paginated_results:
                row_dict = {}
                for i, value in enumerate(row):
                    if i < len(columns):
                        row_dict[columns[i]] = value
                data.append(row_dict)

            return {
                'data': data,
                'columns': columns,
                'total': total,
                'page': page,
                'page_size': page_size
            }

        finally:
            cursor.close()

    async def _generate_db_summary(self, query: str, result: Dict[str, Any], table_name: str) -> str:
        """Generate an AI-written Chinese summary of a database query result.

        Best-effort: returns a fallback message on any failure.
        """
        try:
            llm = ChatZhipuAI(
                model="glm-4",
                api_key=os.getenv("ZHIPUAI_API_KEY"),
                temperature=0.3
            )

            prompt = f"""
用户查询: {query}
目标表: {table_name}

查询结果:
- 结果行数: {result['total']}
- 结果列数: {len(result['columns'])}
- 列名: {', '.join(result['columns'])}

请基于以上信息,用中文生成一个简洁的数据库查询分析总结,包括:
1. 查询的主要目的
2. 关键数据发现
3. 数据特征分析
4. 建议的后续查询方向

总结应该专业、准确、易懂,控制在200字以内。
"""

            # llm.invoke is blocking; keep it off the event loop.
            response = await self._run_in_executor(
                lambda: llm.invoke([HumanMessage(content=prompt)])
            )

            return response.content

        except Exception as e:
            return f"查询已完成,但生成总结时出现错误: {str(e)}"
|
||||
|
||||
# 在 SmartQueryService 类中添加方法
|
||||
|
||||
from .table_metadata_service import TableMetadataService
|
||||
|
||||
class SmartQueryService:
    """Smart query service enriched with per-user table metadata.

    NOTE(review): this re-declares ``SmartQueryService`` and shadows the
    earlier class of the same name in this module (which
    ``DatabaseQueryService`` inherits from). The comment above suggests
    these methods were meant to be *added* to the original class — consider
    merging instead of redefining. TODO confirm intent.
    """

    def __init__(self):
        super().__init__()
        # Populated later via set_db_session().
        self.table_metadata_service = None

    def set_db_session(self, db_session):
        """Attach a database session and build the metadata service on it."""
        self.table_metadata_service = TableMetadataService(db_session)

    async def get_database_context(self, user_id: int, query: str) -> str:
        """Build a textual database context for question answering.

        Returns an empty string when no session/metadata is available or on
        any error, so callers can always embed the result in a prompt.
        """
        if not self.table_metadata_service:
            return ""

        try:
            table_metadata_list = self.table_metadata_service.get_user_table_metadata(user_id)

            if not table_metadata_list:
                return ""

            context_parts = []
            context_parts.append("=== 数据库表信息 ===")

            for metadata in table_metadata_list:
                table_info = []
                table_info.append(f"表名: {metadata.table_name}")

                if metadata.table_comment:
                    table_info.append(f"表描述: {metadata.table_comment}")

                if metadata.qa_description:
                    table_info.append(f"业务说明: {metadata.qa_description}")

                # Column list with optional per-column comments.
                if metadata.columns_info:
                    columns = []
                    for col in metadata.columns_info:
                        col_desc = f"{col['column_name']} ({col['data_type']})"
                        if col.get('column_comment'):
                            col_desc += f" - {col['column_comment']}"
                        columns.append(col_desc)
                    table_info.append(f"字段: {', '.join(columns)}")

                # First two sample rows only, to keep the prompt small.
                # NOTE(review): assumes sample_data is a list of rows — if it
                # is stored as a JSON string this slices characters; confirm.
                if metadata.sample_data:
                    table_info.append(f"示例数据: {metadata.sample_data[:2]}")

                table_info.append(f"总行数: {metadata.row_count}")

                context_parts.append("\n".join(table_info))
                context_parts.append("---")

            return "\n".join(context_parts)

        except Exception as e:
            logger.error(f"获取数据库上下文失败: {str(e)}")
            return ""

    async def execute_smart_query(self, query: str, user_id: int, **kwargs) -> Dict[str, Any]:
        """Run a smart query with the user's table metadata prepended.

        Returns an envelope dict; never raises.
        """
        try:
            db_context = await self.get_database_context(user_id, query)

            enhanced_prompt = f"""
{db_context}

用户问题: {query}

请基于上述数据库表信息,生成相应的SQL查询语句。
注意:
1. 使用准确的表名和字段名
2. 考虑数据类型和约束
3. 参考示例数据理解数据格式
4. 生成高效的查询语句
"""

            # Fix: this class has no base providing execute_smart_query, so
            # the old bare super() call always raised a cryptic
            # AttributeError. Fail with an explicit message instead.
            parent_impl = getattr(super(), "execute_smart_query", None)
            if parent_impl is None:
                raise AttributeError("基类未实现 execute_smart_query,无法执行查询")
            return await parent_impl(enhanced_prompt, user_id, **kwargs)

        except Exception as e:
            logger.error(f"智能查询失败: {str(e)}")
            return {
                'success': False,
                'message': f"查询失败: {str(e)}"
            }
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
from typing import Dict, Any, List, Optional, Union
|
||||
import logging
|
||||
from .smart_excel_workflow import SmartExcelWorkflowManager
|
||||
from .smart_db_workflow import SmartDatabaseWorkflowManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 异常类已迁移到各自的工作流文件中
|
||||
|
||||
class SmartWorkflowManager:
    """Unified entry point for smart-query workflows.

    Holds one Excel and one database workflow manager and forwards each
    call to the appropriate one; it contains no logic of its own.
    """

    def __init__(self, db=None):
        self.db = db
        self.excel_workflow = SmartExcelWorkflowManager(db)
        self.database_workflow = SmartDatabaseWorkflowManager(db)

    async def process_excel_query_stream(
        self,
        user_query: str,
        user_id: int,
        conversation_id: Optional[int] = None,
        is_new_conversation: bool = False
    ):
        """Stream an Excel smart query through the Excel workflow manager."""
        stream = self.excel_workflow.process_excel_query_stream(
            user_query, user_id, conversation_id, is_new_conversation
        )
        async for chunk in stream:
            yield chunk

    async def process_database_query_stream(
        self,
        user_query: str,
        user_id: int,
        database_config_id: int,
        conversation_id: Optional[int] = None,
        is_new_conversation: bool = False
    ):
        """Stream a database smart query through the database workflow manager.

        NOTE(review): conversation_id / is_new_conversation are accepted but
        not forwarded by the delegate call — confirm this is intentional.
        """
        stream = self.database_workflow.process_database_query_stream(
            user_query, user_id, database_config_id
        )
        async for chunk in stream:
            yield chunk

    async def process_smart_query(
        self,
        user_query: str,
        user_id: int,
        conversation_id: Optional[int] = None,
        is_new_conversation: bool = False
    ) -> Dict[str, Any]:
        """Non-streaming smart query; delegates to the Excel workflow."""
        outcome = await self.excel_workflow.process_smart_query(
            user_query=user_query,
            user_id=user_id,
            conversation_id=conversation_id,
            is_new_conversation=is_new_conversation
        )
        return outcome

    async def process_database_query(
        self,
        user_query: str,
        user_id: int,
        database_config_id: int,
        conversation_id: Optional[int] = None,
        is_new_conversation: bool = False
    ) -> Dict[str, Any]:
        """Non-streaming database smart query; delegates to the DB workflow."""
        outcome = await self.database_workflow.process_database_query(
            user_query, user_id, database_config_id, None, conversation_id, is_new_conversation
        )
        return outcome
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
"""File storage service supporting local and S3 storage."""
|
||||
|
||||
import os
import time
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, BinaryIO, Dict, Any

import boto3
from botocore.exceptions import ClientError, NoCredentialsError
from fastapi import UploadFile

from ..core.config import settings
from utils.util_file import FileUtils
|
||||
|
||||
|
||||
class StorageBackend(ABC):
    """Common interface implemented by every storage backend."""

    @abstractmethod
    async def upload_file(
        self,
        file: UploadFile,
        file_path: str
    ) -> Dict[str, Any]:
        """Persist *file* under *file_path* and return storage metadata."""
        ...

    @abstractmethod
    async def delete_file(self, file_path: str) -> bool:
        """Remove the stored file; return True on success."""
        ...

    @abstractmethod
    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Return an access URL for the file, or None when unavailable."""
        ...

    @abstractmethod
    async def file_exists(self, file_path: str) -> bool:
        """Return True when *file_path* is present in storage."""
        ...
|
||||
|
||||
|
||||
class LocalStorageBackend(StorageBackend):
    """Stores uploads on the local file system under a base directory."""

    def __init__(self, base_path: str):
        self.base_path = Path(base_path)
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _resolve(self, file_path: str) -> Path:
        """Join *file_path* onto the configured base directory."""
        return self.base_path / file_path

    async def upload_file(
        self,
        file: UploadFile,
        file_path: str
    ) -> Dict[str, Any]:
        """Write the uploaded file to disk and return its metadata."""
        target = self._resolve(file_path)

        # Make sure the parent directory chain exists.
        target.parent.mkdir(parents=True, exist_ok=True)

        payload = await file.read()
        with open(target, "wb") as handle:
            handle.write(payload)

        # Collect size / MIME info from the written file.
        info = FileUtils.get_file_info(str(target))

        return {
            "file_path": file_path,
            "full_path": str(target),
            "size": info["size_bytes"],
            "mime_type": info["mime_type"],
            "storage_type": "local"
        }

    async def delete_file(self, file_path: str) -> bool:
        """Delete the file; returns the FileUtils success flag."""
        return FileUtils.delete_file(str(self._resolve(file_path)))

    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Return a dev-server URL when the file exists, else None."""
        # In production, files would be served through a real web server.
        if self._resolve(file_path).exists():
            return f"/files/{file_path}"
        return None

    async def file_exists(self, file_path: str) -> bool:
        """True when the file is present under the base directory."""
        return self._resolve(file_path).exists()
|
||||
|
||||
|
||||
class S3StorageBackend(StorageBackend):
    """Amazon S3 (or S3-compatible, e.g. MinIO) storage backend."""

    def __init__(
        self,
        bucket_name: str,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: str = "us-east-1",
        endpoint_url: Optional[str] = None
    ):
        self.bucket_name = bucket_name
        self.aws_region = aws_region

        # Credentials may be None, in which case boto3 falls back to its
        # default credential chain (env vars, instance profile, ...).
        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region
        )

        self.s3_client = session.client(
            's3',
            endpoint_url=endpoint_url  # For S3-compatible services like MinIO
        )

        # Verify the bucket exists, creating it when it does not.
        self._ensure_bucket_exists()

    def _ensure_bucket_exists(self):
        """Ensure the configured S3 bucket exists, creating it if missing."""
        try:
            self.s3_client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            # Fix: the error code is a string and is not always numeric
            # (e.g. 'NoSuchBucket'), so int() could raise ValueError and
            # mask the real failure. Compare as strings instead.
            error_code = e.response.get('Error', {}).get('Code', '')
            if error_code in ('404', 'NoSuchBucket'):
                # Bucket doesn't exist, create it. us-east-1 must not
                # receive a LocationConstraint.
                try:
                    if self.aws_region == 'us-east-1':
                        self.s3_client.create_bucket(Bucket=self.bucket_name)
                    else:
                        self.s3_client.create_bucket(
                            Bucket=self.bucket_name,
                            CreateBucketConfiguration={'LocationConstraint': self.aws_region}
                        )
                except ClientError as create_error:
                    raise Exception(f"Failed to create S3 bucket: {create_error}")
            else:
                raise Exception(f"Failed to access S3 bucket: {e}")

    async def upload_file(
        self,
        file: UploadFile,
        file_path: str
    ) -> Dict[str, Any]:
        """Upload the file content to S3 under *file_path*."""
        try:
            content = await file.read()

            content_type = FileUtils.get_mime_type(file.filename) or 'application/octet-stream'

            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=file_path,
                Body=content,
                ContentType=content_type,
                Metadata={
                    'original_filename': file.filename or 'unknown',
                    # Fix: was `os.time.time()` — `os` has no `time`
                    # attribute, which made every upload crash.
                    'upload_timestamp': str(int(time.time()))
                }
            )

            return {
                "file_path": file_path,
                "bucket": self.bucket_name,
                "size": len(content),
                "mime_type": content_type,
                "storage_type": "s3"
            }
        except (ClientError, NoCredentialsError) as e:
            raise Exception(f"Failed to upload file to S3: {e}")

    async def delete_file(self, file_path: str) -> bool:
        """Delete the object from S3; False on any client error."""
        try:
            self.s3_client.delete_object(
                Bucket=self.bucket_name,
                Key=file_path
            )
            return True
        except ClientError:
            return False

    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Return a presigned GET URL (1 hour expiry), or None on error."""
        try:
            url = self.s3_client.generate_presigned_url(
                'get_object',
                Params={'Bucket': self.bucket_name, 'Key': file_path},
                ExpiresIn=3600  # 1 hour
            )
            return url
        except ClientError:
            return None

    async def file_exists(self, file_path: str) -> bool:
        """Return True when a HEAD request on the object succeeds."""
        try:
            self.s3_client.head_object(
                Bucket=self.bucket_name,
                Key=file_path
            )
            return True
        except ClientError:
            return False
|
||||
|
||||
|
||||
class StorageService:
    """Facade that routes storage calls to the configured backend."""

    def __init__(self):
        self.storage_type = settings.storage.storage_type

        if self.storage_type == 's3':
            self.backend = S3StorageBackend(
                bucket_name=settings.storage.s3_bucket_name,
                aws_access_key_id=settings.storage.aws_access_key_id,
                aws_secret_access_key=settings.storage.aws_secret_access_key,
                aws_region=settings.storage.aws_region,
                endpoint_url=settings.storage.s3_endpoint_url
            )
        else:
            # Normalise relative upload paths against the project root
            # (the parent of the backend directory) so behaviour does not
            # depend on the current working directory.
            upload_dir = settings.storage.upload_directory
            if not os.path.isabs(upload_dir):
                project_root = Path(__file__).parent.parent.parent
                upload_dir = str(project_root / upload_dir)
            self.backend = LocalStorageBackend(upload_dir)

    def generate_file_path(self, knowledge_base_id: int, filename: str) -> str:
        """Build a unique storage key: kb_{kb_id}/{uuid}_{sanitized_name}."""
        clean_name = FileUtils.sanitize_filename(filename)
        unique_id = str(uuid.uuid4())
        return f"kb_{knowledge_base_id}/{unique_id}_{clean_name}"

    async def upload_file(
        self,
        file: UploadFile,
        knowledge_base_id: int
    ) -> Dict[str, Any]:
        """Upload via the active backend under a freshly generated key."""
        storage_key = self.generate_file_path(knowledge_base_id, file.filename)
        return await self.backend.upload_file(file, storage_key)

    async def delete_file(self, file_path: str) -> bool:
        """Delete via the active backend."""
        return await self.backend.delete_file(file_path)

    async def get_file_url(self, file_path: str) -> Optional[str]:
        """Resolve an access URL via the active backend."""
        return await self.backend.get_file_url(file_path)

    async def file_exists(self, file_path: str) -> bool:
        """Existence check via the active backend."""
        return await self.backend.file_exists(file_path)
|
||||
|
||||
|
||||
# Global storage service instance, created at import time.
# NOTE(review): instantiation reads settings (and, for the S3 backend, may
# perform network calls) as a side effect of importing this module —
# confirm this is acceptable for all importers.
storage_service = StorageService()
|
||||
|
|
@ -0,0 +1,455 @@
|
|||
"""表元数据管理服务"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, func
|
||||
from datetime import datetime
|
||||
|
||||
from ..models.table_metadata import TableMetadata
|
||||
from ..models.database_config import DatabaseConfig
|
||||
from utils.util_exceptions import ValidationError, NotFoundError
|
||||
from .postgresql_tool_manager import get_postgresql_tool
|
||||
from .mysql_tool_manager import get_mysql_tool
|
||||
from loguru import logger
|
||||
|
||||
class TableMetadataService:
|
||||
"""表元数据管理服务"""
|
||||
|
||||
    def __init__(self, db_session: Session):
        """Bind a SQLAlchemy session and the per-dialect DB tool singletons."""
        self.session = db_session
        # Dialect-specific query tools, shared process-wide singletons.
        self.postgresql_tool = get_postgresql_tool()
        self.mysql_tool = get_mysql_tool()
|
||||
|
||||
async def collect_and_save_table_metadata(
|
||||
self,
|
||||
user_id: int,
|
||||
database_config_id: int,
|
||||
table_names: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""收集并保存表元数据"""
|
||||
self.session.desc = f"为用户 {user_id} 收集数据库 {database_config_id} 的表元数据"
|
||||
try:
|
||||
# 获取数据库配置
|
||||
stmt = select(DatabaseConfig).where(
|
||||
DatabaseConfig.id == database_config_id,
|
||||
DatabaseConfig.created_by == user_id
|
||||
)
|
||||
db_config = self.session.scalar_one_or_none(stmt)
|
||||
|
||||
if not db_config:
|
||||
self.session.desc = "ERROR: 数据库配置不存在"
|
||||
raise NotFoundError("数据库配置不存在")
|
||||
|
||||
# 根据数据库类型选择相应的工具
|
||||
if db_config.db_type.lower() == 'postgresql':
|
||||
db_tool = self.postgresql_tool
|
||||
elif db_config.db_type.lower() == 'mysql':
|
||||
db_tool = self.mysql_tool
|
||||
else:
|
||||
self.session.desc = f"ERROR: 不支持的数据库类型: {db_config.db_type}, 期望为postgresql或mysql"
|
||||
raise Exception(f"不支持的数据库类型: {db_config.db_type}")
|
||||
|
||||
# 检查是否已有连接,如果没有则建立连接
|
||||
user_id_str = str(user_id)
|
||||
if user_id_str not in db_tool.connections:
|
||||
connection_config = {
|
||||
'host': db_config.host,
|
||||
'port': db_config.port,
|
||||
'database': db_config.database,
|
||||
'username': db_config.username,
|
||||
'password': self._decrypt_password(db_config.password)
|
||||
}
|
||||
|
||||
# 连接数据库
|
||||
connect_result = await db_tool.execute(
|
||||
operation="connect",
|
||||
connection_config=connection_config,
|
||||
user_id=user_id_str
|
||||
)
|
||||
|
||||
if not connect_result.success:
|
||||
self.session.desc = f"ERROR: 数据库连接失败: {connect_result.error}"
|
||||
raise Exception(f"数据库连接失败: {connect_result.error}")
|
||||
|
||||
self.session.desc = f"SUCCESS: 为用户 {user_id} 建立了新的{db_config.db_type}数据库连接"
|
||||
else:
|
||||
self.session.desc = f"SUCCESS: 复用用户 {user_id} 的现有{db_config.db_type}数据库连接"
|
||||
|
||||
collected_tables = []
|
||||
failed_tables = []
|
||||
|
||||
for table_name in table_names:
|
||||
try:
|
||||
# 收集表元数据
|
||||
metadata = await self._collect_single_table_metadata(
|
||||
user_id, table_name, db_config.db_type
|
||||
)
|
||||
|
||||
# 保存或更新元数据
|
||||
table_metadata = await self._save_table_metadata(
|
||||
user_id, database_config_id, table_name, metadata
|
||||
)
|
||||
|
||||
collected_tables.append({
|
||||
'table_name': table_name,
|
||||
'metadata_id': table_metadata.id,
|
||||
'columns_count': len(metadata['columns_info']),
|
||||
'sample_rows': len(metadata['sample_data'])
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.session.desc = f"ERROR: 收集表 {table_name} 元数据失败: {str(e)}"
|
||||
failed_tables.append({
|
||||
'table_name': table_name,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'collected_tables': collected_tables,
|
||||
'failed_tables': failed_tables,
|
||||
'total_collected': len(collected_tables),
|
||||
'total_failed': len(failed_tables)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.session.desc = f"ERROR: 收集表元数据失败: {str(e)}"
|
||||
return {
|
||||
'success': False,
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
async def _collect_single_table_metadata(
|
||||
self,
|
||||
user_id: int,
|
||||
table_name: str,
|
||||
db_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""收集单个表的元数据"""
|
||||
self.session.desc = f"为用户 {user_id} 收集表 {table_name} 的元数据"
|
||||
# 根据数据库类型选择相应的工具
|
||||
if db_type.lower() == 'postgresql':
|
||||
db_tool = self.postgresql_tool
|
||||
elif db_type.lower() == 'mysql':
|
||||
db_tool = self.mysql_tool
|
||||
else:
|
||||
self.session.desc = f"ERROR: 不支持的数据库类型: {db_type}, 期望为postgresql或mysql"
|
||||
raise Exception(f"不支持的数据库类型: {db_type}")
|
||||
|
||||
# 获取表结构
|
||||
schema_result = await db_tool.execute(
|
||||
operation="describe_table",
|
||||
user_id=str(user_id),
|
||||
table_name=table_name
|
||||
)
|
||||
|
||||
if not schema_result.success:
|
||||
self.session.desc = f"ERROR: 获取表 {table_name} 结构失败: {schema_result.error}"
|
||||
raise Exception(f"获取表结构失败: {schema_result.error}")
|
||||
|
||||
schema_data = schema_result.result
|
||||
|
||||
# 获取示例数据(前5条)
|
||||
sample_result = await db_tool.execute(
|
||||
operation="execute_query",
|
||||
user_id=str(user_id),
|
||||
sql_query=f"SELECT * FROM {table_name} LIMIT 5",
|
||||
limit=5
|
||||
)
|
||||
|
||||
sample_data = []
|
||||
if sample_result.success:
|
||||
sample_data = sample_result.result.get('data', [])
|
||||
|
||||
# 获取行数统计
|
||||
count_result = await db_tool.execute(
|
||||
operation="execute_query",
|
||||
user_id=str(user_id),
|
||||
sql_query=f"SELECT COUNT(*) as total_rows FROM {table_name}",
|
||||
limit=1
|
||||
)
|
||||
|
||||
row_count = 0
|
||||
if count_result.success and count_result.result.get('data'):
|
||||
row_count = count_result.result['data'][0].get('total_rows', 0)
|
||||
|
||||
self.session.desc = f"SUCCESS: 为用户 {user_id} 收集表 {table_name} 的元数据, 包含 {len(schema_data.get('columns', []))} 列, {row_count} 行数据"
|
||||
|
||||
return {
|
||||
'columns_info': schema_data.get('columns', []),
|
||||
'primary_keys': schema_data.get('primary_keys', []),
|
||||
'foreign_keys': schema_data.get('foreign_keys', []),
|
||||
'indexes': schema_data.get('indexes', []),
|
||||
'sample_data': sample_data,
|
||||
'row_count': row_count,
|
||||
'table_comment': schema_data.get('table_comment', '')
|
||||
}
|
||||
|
||||
async def _save_table_metadata(
|
||||
self,
|
||||
user_id: int,
|
||||
database_config_id: int,
|
||||
table_name: str,
|
||||
metadata: Dict[str, Any]
|
||||
) -> TableMetadata:
|
||||
"""保存表元数据"""
|
||||
self.session.desc = f"为用户 {user_id} 保存表 {table_name} 的元数据"
|
||||
|
||||
# 检查是否已存在
|
||||
stmt = select(TableMetadata).where(
|
||||
TableMetadata.created_by == user_id,
|
||||
TableMetadata.database_config_id == database_config_id,
|
||||
TableMetadata.table_name == table_name
|
||||
)
|
||||
existing = self.session.scalar_one_or_none(stmt)
|
||||
|
||||
if existing:
|
||||
self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
# 更新现有记录
|
||||
existing.columns_info = metadata['columns_info']
|
||||
existing.primary_keys = metadata['primary_keys']
|
||||
existing.foreign_keys = metadata['foreign_keys']
|
||||
existing.indexes = metadata['indexes']
|
||||
existing.sample_data = metadata['sample_data']
|
||||
existing.row_count = metadata['row_count']
|
||||
existing.table_comment = metadata['table_comment']
|
||||
existing.last_synced_at = datetime.utcnow()
|
||||
|
||||
self.session.commit()
|
||||
self.session.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
# 创建新记录
|
||||
table_metadata = TableMetadata(
|
||||
created_by=user_id,
|
||||
database_config_id=database_config_id,
|
||||
table_name=table_name,
|
||||
table_schema='public',
|
||||
table_type='BASE TABLE',
|
||||
table_comment=metadata['table_comment'],
|
||||
columns_info=metadata['columns_info'],
|
||||
primary_keys=metadata['primary_keys'],
|
||||
foreign_keys=metadata['foreign_keys'],
|
||||
indexes=metadata['indexes'],
|
||||
sample_data=metadata['sample_data'],
|
||||
row_count=metadata['row_count'],
|
||||
is_enabled_for_qa=True,
|
||||
last_synced_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.session.add(table_metadata)
|
||||
self.session.commit()
|
||||
self.session.refresh(table_metadata)
|
||||
self.session.desc = f"SUCCESS: 创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
|
||||
return table_metadata
|
||||
|
||||
async def save_table_metadata_config(
    self,
    user_id: int,
    database_config_id: int,
    table_names: List[str]
) -> Dict[str, Any]:
    """Save table metadata configuration (simplified: basic info only).

    For each table name, enables it for QA: updates the existing
    TableMetadata row if present, otherwise creates a placeholder row
    whose detailed metadata can be filled in later via the collect API.

    Args:
        user_id: Owner of the metadata rows (stored as ``created_by``).
        database_config_id: Target database configuration id.
        table_names: Table names to enable for QA.

    Returns:
        Summary dict with ``saved_tables``, ``failed_tables``,
        ``total_saved`` and ``total_failed``.

    Raises:
        NotFoundError: If the database configuration does not exist or
            is not owned by ``user_id``.
    """
    self.session.desc = f"为用户 {user_id} 保存数据库配置 {database_config_id} 表 {table_names} 的元数据配置"
    # Verify the database configuration exists and belongs to this user.
    stmt = select(DatabaseConfig).where(
        DatabaseConfig.id == database_config_id,
        DatabaseConfig.user_id == user_id
    )
    db_config = self.session.scalar_one_or_none(stmt)

    if not db_config:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 不存在"
        raise NotFoundError("数据库配置不存在")

    saved_tables = []
    failed_tables = []

    for table_name in table_names:
        try:
            # BUGFIX: filter on ``created_by`` — the column this service
            # writes on insert and queries on everywhere else. The old
            # filter used ``TableMetadata.user_id``, so the existence
            # check never matched and updates became duplicate inserts.
            stmt = select(TableMetadata).where(
                TableMetadata.created_by == user_id,
                TableMetadata.database_config_id == database_config_id,
                TableMetadata.table_name == table_name
            )
            existing = self.session.scalar_one_or_none(stmt)

            if existing:
                self.session.desc = f"更新用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置"
                # Re-enable the existing row and bump its sync timestamp.
                existing.is_enabled_for_qa = True
                existing.last_synced_at = datetime.utcnow()
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'updated'
                })
            else:
                self.session.desc = f"创建用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置"
                # Insert a placeholder row; counts and column details are
                # filled in later by the collect endpoint.
                metadata = TableMetadata(
                    created_by=user_id,
                    database_config_id=database_config_id,
                    table_name=table_name,
                    table_schema='public',  # default schema
                    table_type='table',  # default type
                    table_comment='',
                    columns_count=0,  # updated later via the collect API
                    row_count=0,  # updated later via the collect API
                    is_enabled_for_qa=True,
                    qa_description='',
                    business_context='',
                    sample_data='{}',
                    # NOTE(review): other methods in this service populate
                    # ``columns_info`` (plural) — confirm ``column_info`` is
                    # a real model column and not a typo.
                    column_info='{}',
                    last_synced_at=datetime.utcnow()
                )

                self.session.add(metadata)
                saved_tables.append({
                    'table_name': table_name,
                    'action': 'created'
                })

        except Exception as e:
            # Best-effort per table: record the failure and continue.
            self.session.desc = f"ERROR: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据配置失败: {str(e)}"
            failed_tables.append({
                'table_name': table_name,
                'error': str(e)
            })

    # Commit all successful changes in one transaction.
    self.session.commit()
    self.session.desc = f"SUCCESS: 保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_names} 的元数据配置"
    return {
        'saved_tables': saved_tables,
        'failed_tables': failed_tables,
        'total_saved': len(saved_tables),
        'total_failed': len(failed_tables)
    }
|
||||
|
||||
|
||||
def get_user_table_metadata(
    self,
    user_id: int,
    database_config_id: Optional[int] = None
) -> List[TableMetadata]:
    """Return the user's QA-enabled table metadata rows.

    Args:
        user_id: Owner (``created_by``) of the rows.
        database_config_id: Optional filter; when None, rows from all of
            the user's database configurations are returned.

    Returns:
        All matching rows with ``is_enabled_for_qa`` set.
    """
    self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表元数据列表"
    stmt = select(TableMetadata).where(TableMetadata.created_by == user_id)

    # BUGFIX: the parameter is declared Optional with a None default, but
    # the old code raised NotFoundError("数据库配置不存在") when it was
    # omitted, contradicting the signature. An absent filter now simply
    # means "all of the user's configurations".
    if database_config_id is not None:
        stmt = stmt.where(TableMetadata.database_config_id == database_config_id)
    stmt = stmt.where(TableMetadata.is_enabled_for_qa == True)  # noqa: E712 (SQLAlchemy comparison)
    return self.session.scalars(stmt).all()
|
||||
|
||||
def get_table_metadata_by_name(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str
) -> Optional[TableMetadata]:
    """Look up a single table's metadata row by table name.

    Returns the TableMetadata row owned by the user for the given
    database configuration, or None when no such row exists.
    """
    self.session.desc = f"获取用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
    lookup = (
        select(TableMetadata)
        .where(TableMetadata.created_by == user_id)
        .where(TableMetadata.database_config_id == database_config_id)
        .where(TableMetadata.table_name == table_name)
    )
    return self.session.scalar_one_or_none(lookup)
|
||||
|
||||
def update_table_qa_settings(
    self,
    user_id: int,
    metadata_id: int,
    settings: Dict[str, Any]
) -> bool:
    """Apply QA-related settings to one of the user's metadata rows.

    Only the keys ``is_enabled_for_qa``, ``qa_description`` and
    ``business_context`` are honoured; other keys are ignored.

    Returns:
        True on success; False when the row is missing or the update
        fails (the transaction is rolled back in the failure case).
    """
    self.session.desc = f"更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置"
    try:
        query = select(TableMetadata).where(
            TableMetadata.id == metadata_id,
            TableMetadata.created_by == user_id
        )
        record = self.session.scalar_one_or_none(query)

        if not record:
            self.session.desc = f"用户 {user_id} 数据库库配置表 metadata_id={metadata_id} 不存在"
            return False

        # Copy over only the whitelisted keys that are present.
        for field in ('is_enabled_for_qa', 'qa_description', 'business_context'):
            if field in settings:
                setattr(record, field, settings[field])

        self.session.commit()
        return True

    except Exception as e:
        self.session.desc = f"ERROR: 更新用户 {user_id} 数据库配置表 metadata_id={metadata_id} 的问答设置失败: {str(e)}"
        self.session.rollback()
        return False
|
||||
|
||||
def save_table_metadata(
    self,
    user_id: int,
    database_config_id: int,
    table_name: str,
    columns_info: List[Dict[str, Any]],
    primary_keys: List[str],
    row_count: int,
    table_comment: str = ''
) -> TableMetadata:
    """Create or update a single table's metadata row (upsert).

    An existing row for (user, config, table) is refreshed in place;
    otherwise a new QA-enabled row is inserted. The change is committed
    and the affected row is returned.
    """
    self.session.desc = f"保存用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 的元数据"
    # Probe for a row already describing this table.
    lookup = select(TableMetadata).where(
        TableMetadata.created_by == user_id,
        TableMetadata.database_config_id == database_config_id,
        TableMetadata.table_name == table_name
    )
    record = self.session.scalar_one_or_none(lookup)

    if record is not None:
        self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 已存在,更新其元数据"
        # Refresh the stored snapshot in place.
        record.columns_info = columns_info
        record.primary_keys = primary_keys
        record.row_count = row_count
        record.table_comment = table_comment
        record.last_synced_at = datetime.utcnow()
        self.session.commit()
        return record

    self.session.desc = f"用户 {user_id} 数据库配置 {database_config_id} 表 {table_name} 不存在,创建新记录"
    # No prior row: insert a fresh, QA-enabled one.
    fresh = TableMetadata(
        created_by=user_id,
        database_config_id=database_config_id,
        table_name=table_name,
        table_schema='public',
        table_type='BASE TABLE',
        table_comment=table_comment,
        columns_info=columns_info,
        primary_keys=primary_keys,
        row_count=row_count,
        is_enabled_for_qa=True,
        last_synced_at=datetime.utcnow()
    )

    self.session.add(fresh)
    self.session.commit()
    self.session.refresh(fresh)
    return fresh
|
||||
|
||||
def _decrypt_password(self, encrypted_password: str) -> str:
    """Decrypt a stored database password (decryption not yet implemented).

    NOTE(review): placeholder — should apply the same decryption logic as
    DatabaseConfigService; until then the input is returned unchanged,
    meaning stored passwords are effectively treated as plaintext here.
    """
    # Temporary pass-through until real decryption is wired in.
    return encrypted_password
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
"""Agent tools package."""
|
||||
|
||||
from .weather import WeatherQueryTool
|
||||
from .search import TavilySearchTool
|
||||
from .datetime_tool import DateTimeTool
|
||||
from th_agenter.services.mcp.postgresql_mcp import PostgreSQLMCPTool
|
||||
from th_agenter.services.mcp.mysql_mcp import MySQLMCPTool
|
||||
|
||||
|
||||
# Try to import LangChain native tools if available
|
||||
# TODO: 暂屏蔽
|
||||
# try:
|
||||
# from .langchain_native_tools import LANGCHAIN_NATIVE_TOOLS
|
||||
# except ImportError:
|
||||
# LANGCHAIN_NATIVE_TOOLS = []
|
||||
|
||||
__all__ = [
|
||||
'WeatherQueryTool',
|
||||
'TavilySearchTool',
|
||||
'DateTimeTool',
|
||||
'PostgreSQLMCPTool',
|
||||
'MySQLMCPTool',
|
||||
'LANGCHAIN_NATIVE_TOOLS'
|
||||
]
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Type, Literal, ClassVar
|
||||
import datetime
|
||||
import pytz
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("datetime_tool")
|
||||
|
||||
# Input parameter model (Pydantic schema replacing the old get_parameters()).
class DateTimeInput(BaseModel):
    """Arguments accepted by DateTimeTool."""

    # The operation selector is the only required field.
    operation: Literal["current_time", "timezone_convert", "date_diff", "add_time", "format_date"] = Field(
        description="操作类型: current_time(当前时间), timezone_convert(时区转换), "
                    "date_diff(日期差), add_time(时间加减), format_date(格式化日期)"
    )
    timezone: Optional[str] = Field(
        default="UTC",
        description="时区名称 (e.g., 'UTC', 'Asia/Shanghai')"
    )
    # BUGFIX: these two fields are Optional but had no explicit default,
    # which makes them *required* under Pydantic v2 and would break
    # operations (e.g. current_time) that do not supply them.
    date_string: Optional[str] = Field(
        default=None,
        description="日期字符串 (格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS)"
    )
    target_timezone: Optional[str] = Field(
        default=None,
        description="目标时区(用于时区转换)"
    )
    days: Optional[int] = Field(
        default=0,
        description="要加减的天数"
    )
    hours: Optional[int] = Field(
        default=0,
        description="要加减的小时数"
    )
    format: Optional[str] = Field(
        default="%Y-%m-%d %H:%M:%S",
        description="日期格式字符串 (e.g., '%Y-%m-%d %H:%M:%S')"
    )
|
||||
|
||||
class DateTimeTool(BaseTool):
    """Date/time utility tool (timezone conversion, date arithmetic, etc.)."""

    name: ClassVar[str] = "datetime_tool"
    description: ClassVar[str] = """执行日期时间相关操作,包括:
- 获取当前时间
- 时区转换
- 计算日期差
- 日期时间加减
- 格式化日期
使用时必须指定operation参数确定操作类型。"""
    args_schema: Type[BaseModel] = DateTimeInput

    def _parse_datetime(self, date_string: str, timezone_str: str = "UTC") -> datetime.datetime:
        """Parse *date_string* against a list of common formats.

        The parsed naive datetime is localized into *timezone_str*.

        Raises:
            ValueError: If no known format matches.
        """
        tz = pytz.timezone(timezone_str)
        formats = [
            "%Y-%m-%d %H:%M:%S",
            "%Y-%m-%d %H:%M",
            "%Y-%m-%d",
            "%Y/%m/%d %H:%M:%S",
            "%Y/%m/%d",
            "%d/%m/%Y %H:%M:%S",
            "%d/%m/%Y",
            "%m/%d/%Y %H:%M:%S",
            "%m/%d/%Y"
        ]

        for fmt in formats:
            try:
                dt = datetime.datetime.strptime(date_string, fmt)
                # pytz requires localize() rather than tzinfo= for correct offsets.
                return tz.localize(dt)
            except ValueError:
                continue
        raise ValueError(f"无法解析日期字符串: {date_string}")

    def _run(self,
             operation: str,
             timezone: str = "UTC",
             date_string: Optional[str] = None,
             target_timezone: Optional[str] = None,
             days: int = 0,
             hours: int = 0,
             format: str = "%Y-%m-%d %H:%M:%S") -> dict:  # noqa: A002 (``format`` shadows builtin; part of the public tool schema)
        """Dispatch and execute the requested date/time operation.

        Returns a dict with ``status``, a ``result`` payload and a
        human-readable ``summary``; errors are reported as
        ``{"status": "error", ...}`` rather than raised.
        """
        logger.info(f"执行日期时间操作: {operation}")

        try:
            if operation == "current_time":
                tz = pytz.timezone(timezone)
                now = datetime.datetime.now(tz)
                return {
                    "status": "success",
                    "result": {
                        "formatted": now.strftime(format),
                        "iso": now.isoformat(),
                        "timestamp": now.timestamp(),
                        "timezone": timezone
                    },
                    "summary": f"当前时间 ({timezone}): {now.strftime(format)}"
                }

            elif operation == "timezone_convert":
                if not date_string or not target_timezone:
                    raise ValueError("必须提供date_string和target_timezone参数")

                source_dt = self._parse_datetime(date_string, timezone)
                target_dt = source_dt.astimezone(pytz.timezone(target_timezone))

                return {
                    "status": "success",
                    "result": {
                        "source": source_dt.strftime(format),
                        "target": target_dt.strftime(format),
                        "source_tz": timezone,
                        "target_tz": target_timezone
                    },
                    "summary": f"时区转换: {source_dt.strftime(format)} → {target_dt.strftime(format)}"
                }

            elif operation == "date_diff":
                if not date_string:
                    raise ValueError("必须提供date_string参数")

                target_dt = self._parse_datetime(date_string, timezone)
                current_dt = datetime.datetime.now(pytz.timezone(timezone))
                delta = target_dt - current_dt
                # Magnitude for display; raw components for the result dict.
                abs_delta = abs(delta)

                return {
                    "status": "success",
                    "result": {
                        "days": delta.days,
                        "hours": delta.seconds // 3600,
                        "total_seconds": delta.total_seconds(),
                        # BUGFIX: a sub-day future difference has days == 0,
                        # so the old ``delta.days > 0`` test mislabeled it.
                        "is_future": delta.total_seconds() > 0
                    },
                    # BUGFIX: use the absolute delta for display. For past
                    # dates the raw split (e.g. -1 day + 23h for "1 hour
                    # ago") produced a misleading "1天 23小时".
                    "summary": f"日期差: {abs_delta.days}天 {abs_delta.seconds//3600}小时"
                }

            elif operation == "add_time":
                # Base defaults to "now" in the requested timezone.
                base_dt = self._parse_datetime(date_string, timezone) if date_string \
                    else datetime.datetime.now(pytz.timezone(timezone))
                new_dt = base_dt + datetime.timedelta(days=days, hours=hours)

                return {
                    "status": "success",
                    "result": {
                        "original": base_dt.strftime(format),
                        "new": new_dt.strftime(format),
                        "delta": f"{days}天 {hours}小时"
                    },
                    "summary": f"时间计算: {base_dt.strftime(format)} + {days}天 {hours}小时 = {new_dt.strftime(format)}"
                }

            elif operation == "format_date":
                dt = self._parse_datetime(date_string, timezone) if date_string \
                    else datetime.datetime.now(pytz.timezone(timezone))
                formatted = dt.strftime(format)

                return {
                    "status": "success",
                    "result": {
                        "original": dt.isoformat(),
                        "formatted": formatted
                    },
                    "summary": f"格式化结果: {formatted}"
                }

            else:
                raise ValueError(f"未知操作类型: {operation}")

        except Exception as e:
            logger.error(f"操作失败: {str(e)}")
            return {
                "status": "error",
                "message": str(e),
                "operation": operation
            }

    async def _arun(self, **kwargs):
        """Async entry point — delegates to the synchronous _run."""
        return self._run(**kwargs)
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
"""基于TavilySearch的搜索工具"""
|
||||
|
||||
from th_agenter.core.config import get_settings
|
||||
from loguru import logger
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from typing import Optional, Type, ClassVar
|
||||
|
||||
# Input parameter model for the search tool (replaces the old get_parameters()).
class SearchInput(BaseModel):
    """Arguments accepted by TavilySearchTool."""

    # The query text is the only required field.
    query: str = Field(description="搜索查询内容")
    max_results: Optional[int] = Field(
        default=5,
        description="返回结果的最大数量(默认:5)"
    )
    topic: Optional[str] = Field(
        default="general",
        description="搜索主题,可选值:general, academic, news, places"
    )
|
||||
|
||||
|
||||
class TavilySearchTool(BaseTool):
    """Web search tool backed by the Tavily search engine."""

    name: ClassVar[str] = "tavily_search_tool"
    description: ClassVar[str] = """使用Tavily搜索引擎进行网络搜索,可以获取最新信息。
输入应该包含搜索查询(query),可选参数包括max_results和topic。"""  # replaces get_description()
    args_schema: Type[BaseModel] = SearchInput  # Pydantic model defines the arguments
    _tavily_api_key: str = PrivateAttr()
    _search_client: TavilySearchResults = PrivateAttr()

    def __init__(self, **kwargs):
        """Read the Tavily API key from settings and build the client.

        Raises:
            ValueError: If no API key is configured in settings.
        """
        super().__init__(**kwargs)
        self._tavily_api_key = get_settings().tool.tavily_api_key
        if not self._tavily_api_key:
            raise ValueError("Tavily API key not found in settings")

        # Initialise the LangChain-provided Tavily client.
        self._search_client = TavilySearchResults(
            tavily_api_key=self._tavily_api_key
        )

    def _run(self, query: str, max_results: int = 5, topic: str = "general"):
        """Run a Tavily search and return a normalised result payload."""
        try:
            logger.info(f"执行搜索:{query}")
            # Delegate to LangChain's built-in Tavily tool.
            results = self._search_client.run({
                "query": query,
                "max_results": max_results,
                "topic": topic
            })

            # Normalise the Tavily payload (expected: a list of result dicts).
            if isinstance(results, list):
                formatted = []
                for r in results:
                    content = r.get("content", "")
                    # BUGFIX: only add an ellipsis when the content was
                    # actually truncated; previously "..." was appended
                    # unconditionally, even to snippets under 200 chars.
                    if len(content) > 200:
                        content = content[:200] + "..."
                    formatted.append({
                        "title": r.get("title", ""),
                        "url": r.get("url", ""),
                        "content": content
                    })
                return {
                    "status": "success",
                    "results": formatted
                }
            else:
                return {"status": "error", "message": "Unexpected result format"}

        except Exception as e:
            logger.error(f"搜索失败: {str(e)}")
            return {"status": "error", "message": str(e)}

    async def _arun(self, **kwargs):
        """Async entry point — delegates directly to the synchronous _run.

        (BUGFIX: the original had a second, dead string literal here.)
        """
        return self._run(**kwargs)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue