chore: 提交所有代码
This commit is contained in:
parent
1451d5f616
commit
d96c1eb65f
|
|
@ -0,0 +1,179 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control. See https://pdm.fming.dev/#use-with-ide
|
||||||
|
.pdm.toml
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
webIOs/output/logs/
|
||||||
|
|
||||||
|
# OS generated files
|
||||||
|
Thumbs.db
|
||||||
|
.DS_Store
|
||||||
|
.DS_Store?
|
||||||
|
._*
|
||||||
|
.Spotlight-V100
|
||||||
|
.Trashes
|
||||||
|
ehthumbs.db
|
||||||
|
|
||||||
|
# FastAPI specific
|
||||||
|
*.pyc
|
||||||
|
uvicorn*.log
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts.
|
||||||
|
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||||
|
# format, relative to the token %(here)s which refers to the location of this
|
||||||
|
# ini file
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory. for multiple paths, the path separator
|
||||||
|
# is defined by "path_separator" below.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the tzdata library which can be installed by adding
|
||||||
|
# `alembic[tz]` to the pip requirements.
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to <script_location>/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "path_separator"
|
||||||
|
# below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||||
|
|
||||||
|
# path_separator; This indicates what character is used to split lists of file
|
||||||
|
# paths, including version_locations and prepend_sys_path within configparser
|
||||||
|
# files such as alembic.ini.
|
||||||
|
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||||
|
# to provide os-dependent path splitting.
|
||||||
|
#
|
||||||
|
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||||
|
# take place if path_separator is not present in alembic.ini. If this
|
||||||
|
# option is omitted entirely, fallback logic is as follows:
|
||||||
|
#
|
||||||
|
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||||
|
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||||
|
# behavior of splitting on spaces and/or commas.
|
||||||
|
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||||
|
# behavior of splitting on spaces, commas, or colons.
|
||||||
|
#
|
||||||
|
# Valid values for path_separator are:
|
||||||
|
#
|
||||||
|
# path_separator = :
|
||||||
|
# path_separator = ;
|
||||||
|
# path_separator = space
|
||||||
|
# path_separator = newline
|
||||||
|
#
|
||||||
|
# Use os.pathsep. Default configuration used for new projects.
|
||||||
|
path_separator = os
|
||||||
|
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# database URL. This is consumed by the user-maintained env.py script only.
|
||||||
|
# other means of configuring database URLs may be customized within the env.py
|
||||||
|
# file.
|
||||||
|
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = module
|
||||||
|
# ruff.module = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = ruff
|
||||||
|
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration. This is also consumed by the user-maintained
|
||||||
|
# env.py script only.
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Generic single-database configuration with an async dbapi.
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import asyncio
import os
from logging.config import fileConfig

from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context
from dotenv import load_dotenv

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Load DATABASE_URL from a .env file (if present) and inject it into the
# Alembic config, making env.py the single source of truth for the DB URL
# instead of a hard-coded sqlalchemy.url in alembic.ini.
load_dotenv()
database_url = os.getenv("DATABASE_URL")
if not database_url:
    raise ValueError("环境变量DATABASE_URL未设置")

config.set_main_option("sqlalchemy.url", database_url)

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
from th_agenter.db import Base

# NOTE(review): no model modules are imported here, so `Base.metadata` only
# contains tables whose model classes have already been imported somewhere
# (e.g. by th_agenter.db itself). If autogenerate produces empty diffs,
# re-enable the import below — confirm against th_agenter.db.
# from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The Alembic context is configured with just the database URL rather
    than an Engine, so no DBAPI needs to be installed; any
    context.execute() calls are emitted as SQL text to the script output
    instead of being sent to a live connection.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection: Connection) -> None:
    """Bind the migration context to *connection* and run all migrations."""
    configure_opts = {
        "connection": connection,
        "target_metadata": target_metadata,
    }
    context.configure(**configure_opts)

    with context.begin_transaction():
        context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
    """Create an async Engine and run migrations over one of its connections.

    The engine is built from the [alembic] ini section (prefixed
    "sqlalchemy.") with NullPool, since a migration run needs exactly one
    short-lived connection.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    # Dispose in a finally block so the engine (and its connections) are
    # released even when a migration raises; previously a failure skipped
    # dispose() entirely and leaked the connection.
    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
    """Run migrations in 'online' mode by driving the async runner."""
    migration_coro = run_async_migrations()
    asyncio.run(migration_coro)
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point: Alembic imports this module and we dispatch on the mode
# it was invoked in (e.g. `alembic upgrade --sql` selects offline mode).
_runner = (
    run_migrations_offline
    if context.is_offline_mode()
    else run_migrations_online
)
_runner()
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""Initial migration
|
||||||
|
|
||||||
|
Revision ID: 424646027786
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-12-16 09:56:45.172954
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '424646027786'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('agent_configs',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('enabled_tools', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('max_iterations', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('temperature', sa.String(length=10), nullable=False),
|
||||||
|
sa.Column('system_message', sa.Text(), nullable=True),
|
||||||
|
sa.Column('verbose', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('model_name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('max_tokens', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_default', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
|
||||||
|
op.create_table('conversations',
|
||||||
|
sa.Column('title', sa.String(length=200), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('system_prompt', sa.Text(), nullable=True),
|
||||||
|
sa.Column('model_name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('temperature', sa.String(length=10), nullable=False),
|
||||||
|
sa.Column('max_tokens', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('is_archived', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
|
||||||
|
op.create_table('database_configs',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('db_type', sa.String(length=20), nullable=False),
|
||||||
|
sa.Column('host', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('port', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('database', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('username', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('password', sa.Text(), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_default', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('connection_params', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
|
||||||
|
sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
|
||||||
|
op.create_table('documents',
|
||||||
|
sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('filename', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('original_filename', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||||
|
sa.Column('file_size', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('file_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('mime_type', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('is_processed', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('processing_error', sa.Text(), nullable=True),
|
||||||
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
|
sa.Column('doc_metadata', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('chunk_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('embedding_model', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('vector_ids', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
|
||||||
|
op.create_table('excel_files',
|
||||||
|
sa.Column('original_filename', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||||
|
sa.Column('file_size', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('file_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('sheet_names', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('default_sheet', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('columns_info', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('preview_data', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('data_types', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('total_rows', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('total_columns', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('is_processed', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('processing_error', sa.Text(), nullable=True),
|
||||||
|
sa.Column('last_accessed', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
|
||||||
|
op.create_table('knowledge_bases',
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('embedding_model', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('chunk_size', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('chunk_overlap', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('vector_db_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('collection_name', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
|
||||||
|
op.create_table('llm_configs',
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('model_name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('api_key', sa.String(length=500), nullable=False),
|
||||||
|
sa.Column('base_url', sa.String(length=200), nullable=True),
|
||||||
|
sa.Column('max_tokens', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('temperature', sa.Float(), nullable=False),
|
||||||
|
sa.Column('top_p', sa.Float(), nullable=False),
|
||||||
|
sa.Column('frequency_penalty', sa.Float(), nullable=False),
|
||||||
|
sa.Column('presence_penalty', sa.Float(), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_default', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_embedding', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('extra_config', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('usage_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
|
||||||
|
op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
|
||||||
|
op.create_table('messages',
|
||||||
|
sa.Column('conversation_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
|
||||||
|
sa.Column('content', sa.Text(), nullable=False),
|
||||||
|
sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
|
||||||
|
sa.Column('message_metadata', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('context_documents', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('prompt_tokens', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('completion_tokens', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('total_tokens', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
|
||||||
|
op.create_table('roles',
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('code', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('is_system', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
|
||||||
|
op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
|
||||||
|
op.create_table('table_metadata',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('table_name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('table_schema', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('table_type', sa.String(length=20), nullable=False),
|
||||||
|
sa.Column('table_comment', sa.Text(), nullable=True),
|
||||||
|
sa.Column('database_config_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('columns_info', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('primary_keys', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('foreign_keys', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('indexes', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('sample_data', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('row_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('qa_description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('business_context', sa.Text(), nullable=True),
|
||||||
|
sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
|
||||||
|
op.create_table('users',
|
||||||
|
sa.Column('username', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('email', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('hashed_password', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('full_name', sa.String(length=100), nullable=True),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
|
sa.Column('avatar_url', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('bio', sa.Text(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||||
|
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||||
|
op.create_table('user_roles',
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('role_id', sa.Integer(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
|
||||||
|
sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
|
||||||
|
)
|
||||||
|
op.create_table('workflows',
|
||||||
|
sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
|
||||||
|
sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
|
||||||
|
sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
|
||||||
|
sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
|
||||||
|
sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
|
||||||
|
sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
|
||||||
|
op.create_table('workflow_executions',
|
||||||
|
sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
|
||||||
|
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
|
||||||
|
sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
|
||||||
|
sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
|
||||||
|
sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
|
||||||
|
sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
|
||||||
|
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||||
|
sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
|
||||||
|
sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
|
||||||
|
op.create_table('node_executions',
|
||||||
|
sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
|
||||||
|
sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
|
||||||
|
sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
|
||||||
|
sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
|
||||||
|
sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
|
||||||
|
sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
|
||||||
|
sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
|
||||||
|
sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
|
||||||
|
sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
|
||||||
|
sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
|
||||||
|
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('updated_by', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
|
||||||
|
sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_node_executions_id'), table_name='node_executions')
|
||||||
|
op.drop_table('node_executions')
|
||||||
|
op.drop_index(op.f('ix_workflow_executions_id'), table_name='workflow_executions')
|
||||||
|
op.drop_table('workflow_executions')
|
||||||
|
op.drop_index(op.f('ix_workflows_id'), table_name='workflows')
|
||||||
|
op.drop_table('workflows')
|
||||||
|
op.drop_table('user_roles')
|
||||||
|
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||||
|
op.drop_index(op.f('ix_users_id'), table_name='users')
|
||||||
|
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||||
|
op.drop_table('users')
|
||||||
|
op.drop_index(op.f('ix_table_metadata_table_name'), table_name='table_metadata')
|
||||||
|
op.drop_index(op.f('ix_table_metadata_id'), table_name='table_metadata')
|
||||||
|
op.drop_table('table_metadata')
|
||||||
|
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||||
|
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||||
|
op.drop_index(op.f('ix_roles_code'), table_name='roles')
|
||||||
|
op.drop_table('roles')
|
||||||
|
op.drop_index(op.f('ix_messages_id'), table_name='messages')
|
||||||
|
op.drop_table('messages')
|
||||||
|
op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
|
||||||
|
op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
|
||||||
|
op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
|
||||||
|
op.drop_table('llm_configs')
|
||||||
|
op.drop_index(op.f('ix_knowledge_bases_name'), table_name='knowledge_bases')
|
||||||
|
op.drop_index(op.f('ix_knowledge_bases_id'), table_name='knowledge_bases')
|
||||||
|
op.drop_table('knowledge_bases')
|
||||||
|
op.drop_index(op.f('ix_excel_files_id'), table_name='excel_files')
|
||||||
|
op.drop_table('excel_files')
|
||||||
|
op.drop_index(op.f('ix_documents_id'), table_name='documents')
|
||||||
|
op.drop_table('documents')
|
||||||
|
op.drop_index(op.f('ix_database_configs_id'), table_name='database_configs')
|
||||||
|
op.drop_table('database_configs')
|
||||||
|
op.drop_index(op.f('ix_conversations_id'), table_name='conversations')
|
||||||
|
op.drop_table('conversations')
|
||||||
|
op.drop_index(op.f('ix_agent_configs_name'), table_name='agent_configs')
|
||||||
|
op.drop_index(op.f('ix_agent_configs_id'), table_name='agent_configs')
|
||||||
|
op.drop_table('agent_configs')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
"""Add message_count and last_message_at to conversations
|
||||||
|
|
||||||
|
Revision ID: 8da391c6e2b7
|
||||||
|
Revises: 424646027786
|
||||||
|
Create Date: 2025-12-19 16:16:29.943314
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '8da391c6e2b7'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = '424646027786'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('conversations', sa.Column('message_count', sa.Integer(), nullable=False))
|
||||||
|
op.add_column('conversations', sa.Column('last_message_at', sa.DateTime(), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('conversations', 'last_message_at')
|
||||||
|
op.drop_column('conversations', 'message_count')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
class DrGraphSession:
|
||||||
|
def __init__(self, stepIndex: int, msg: str, session_id: str):
|
||||||
|
logger.info(f"DrGraphSession.__init__: stepIndex={stepIndex}, msg={msg}, session_id={session_id}")
|
||||||
|
self.stepIndex = stepIndex
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
|
match = re.search(r";(-\d+)", msg);
|
||||||
|
level = -3
|
||||||
|
if match:
|
||||||
|
level = int(match.group(1))
|
||||||
|
value = value.replace(f";{level}", "")
|
||||||
|
level = -3 + level
|
||||||
|
|
||||||
|
if "警告" in value or value.startswith("WARNING"):
|
||||||
|
self.log_warning(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
elif "异常" in value or value.startswith("EXCEPTION"):
|
||||||
|
self.log_exception(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
elif "成功" in value or value.startswith("SUCCESS"):
|
||||||
|
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
elif "开始" in value or value.startswith("START"):
|
||||||
|
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
elif "失败" in value or value.startswith("ERROR"):
|
||||||
|
self.log_error(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
else:
|
||||||
|
self.log_info(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||||
|
|
||||||
|
def log_prefix(self) -> str:
|
||||||
|
"""Get log prefix with session ID and desc."""
|
||||||
|
return f"〖Session{self.session_id}〗"
|
||||||
|
|
||||||
|
def parse_source_pos(self, level: int):
|
||||||
|
pos = (traceback.format_stack())[level - 1].strip().split('\n')[0]
|
||||||
|
match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos);
|
||||||
|
if match:
|
||||||
|
file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
|
||||||
|
pos = f"{file}:{match.group(2)} in {match.group(3)}"
|
||||||
|
return pos
|
||||||
|
|
||||||
|
def log_info(self, msg: str, level: int = -2):
|
||||||
|
"""Log info message with session ID."""
|
||||||
|
pos = self.parse_source_pos(level)
|
||||||
|
logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
||||||
|
def log_success(self, msg: str, level: int = -2):
|
||||||
|
"""Log success message with session ID."""
|
||||||
|
pos = self.parse_source_pos(level)
|
||||||
|
logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
||||||
|
def log_warning(self, msg: str, level: int = -2):
|
||||||
|
"""Log warning message with session ID."""
|
||||||
|
pos = self.parse_source_pos(level)
|
||||||
|
logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
||||||
|
def log_error(self, msg: str, level: int = -2):
|
||||||
|
"""Log error message with session ID."""
|
||||||
|
pos = self.parse_source_pos(level)
|
||||||
|
logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
||||||
|
def log_exception(self, msg: str, level: int = -2):
|
||||||
|
"""Log exception message with session ID."""
|
||||||
|
pos = self.parse_source_pos(level)
|
||||||
|
logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from sqlalchemy import create_engine, inspect
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
async def check_table_constraints():
|
||||||
|
try:
|
||||||
|
# 获取数据库连接字符串
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "mysql+asyncmy://root:123456@localhost:3306/th_agenter")
|
||||||
|
|
||||||
|
# 创建异步引擎
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=True)
|
||||||
|
|
||||||
|
# 创建会话
|
||||||
|
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
# 获取数据库连接
|
||||||
|
async with session.begin():
|
||||||
|
# 使用inspect查看表结构
|
||||||
|
inspector = inspect(engine)
|
||||||
|
|
||||||
|
# 获取messages表的所有约束
|
||||||
|
constraints = await engine.run_sync(inspector.get_table_constraints, 'messages')
|
||||||
|
print("Messages表的所有约束:")
|
||||||
|
for constraint in constraints:
|
||||||
|
print(f" 约束名称: {constraint['name']}, 类型: {constraint['type']}")
|
||||||
|
if constraint['type'] == 'PRIMARY KEY':
|
||||||
|
print(f" 主键约束列: {constraint['constrained_columns']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"检查约束时出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(check_table_constraints())
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
import jwt
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
print(f"jwt module path: {inspect.getfile(jwt)}")
|
||||||
|
print(f"jwt module attributes: {dir(jwt)}")
|
||||||
|
try:
|
||||||
|
print(f"jwt module __version__: {jwt.__version__}")
|
||||||
|
except AttributeError:
|
||||||
|
print("jwt module has no __version__ attribute")
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,17 @@
|
||||||
|
services:
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg16
|
||||||
|
container_name: pgvector-db
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: drgraph
|
||||||
|
POSTGRES_PASSWORD: yingping
|
||||||
|
POSTGRES_DB: th_agenter
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
volumes:
|
||||||
|
- pgdata:/var/lib/postgresql/data
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
pgdata:
|
||||||
|
# docker exec -it pgvector-db psql -U drgraph -d th_agenter
|
||||||
Binary file not shown.
|
|
@ -0,0 +1,139 @@
|
||||||
|
# uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
|
||||||
|
# 1. pip install fastapi-cdn-host
|
||||||
|
# 2. import fastapi_cdn_host
|
||||||
|
# 3. fastapi_cdn_host.patch_docs(app)
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import fastapi_cdn_host
|
||||||
|
|
||||||
|
from os.path import dirname, realpath
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
from utils.util_log import init_logger
|
||||||
|
from loguru import logger
|
||||||
|
base_dir: str = dirname(realpath(__file__))
|
||||||
|
init_logger(base_dir)
|
||||||
|
|
||||||
|
from th_agenter.api.routes import router
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Application lifespan manager."""
|
||||||
|
logger.info("[生命周期] - Starting up TH Agenter application...")
|
||||||
|
yield
|
||||||
|
# Shutdown
|
||||||
|
logger.info("[生命周期] - Shutting down TH Agenter application...")
|
||||||
|
|
||||||
|
def setup_exception_handlers(app: FastAPI) -> None:
|
||||||
|
"""Setup global exception handlers."""
|
||||||
|
|
||||||
|
# Import custom exceptions and handlers
|
||||||
|
from utils.util_exceptions import ChatAgentException, chat_agent_exception_handler
|
||||||
|
|
||||||
|
@app.exception_handler(ChatAgentException)
|
||||||
|
async def custom_chat_agent_exception_handler(request, exc):
|
||||||
|
return await chat_agent_exception_handler(request, exc)
|
||||||
|
|
||||||
|
@app.exception_handler(StarletteHTTPException)
|
||||||
|
async def http_exception_handler(request, exc):
|
||||||
|
from utils.util_exceptions import HxfErrorResponse
|
||||||
|
logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
|
||||||
|
return HxfErrorResponse(exc)
|
||||||
|
|
||||||
|
def make_json_serializable(obj):
|
||||||
|
"""递归地将对象转换为JSON可序列化的格式"""
|
||||||
|
if obj is None or isinstance(obj, (str, int, float, bool)):
|
||||||
|
return obj
|
||||||
|
elif isinstance(obj, bytes):
|
||||||
|
return obj.decode('utf-8')
|
||||||
|
elif isinstance(obj, (ValueError, Exception)):
|
||||||
|
return str(obj)
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {k: make_json_serializable(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
return [make_json_serializable(item) for item in obj]
|
||||||
|
else:
|
||||||
|
# For any other object, convert to string
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request, exc):
|
||||||
|
# Convert any non-serializable objects to strings in error details
|
||||||
|
try:
|
||||||
|
errors = make_json_serializable(exc.errors())
|
||||||
|
except Exception as e:
|
||||||
|
# Fallback: if even our conversion fails, use a simple error message
|
||||||
|
errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]
|
||||||
|
logger.exception(f"Request Validation Error: {errors}")
|
||||||
|
|
||||||
|
logger.exception(f"validation_error: {errors}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"type": "validation_error",
|
||||||
|
"message": "Request validation failed",
|
||||||
|
"details": errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request, exc):
|
||||||
|
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"type": "internal_error",
|
||||||
|
"message": "Internal server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create and configure FastAPI application."""
|
||||||
|
from th_agenter.core.config import get_settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Create FastAPI app
|
||||||
|
app = FastAPI(
|
||||||
|
title=settings.app_name,
|
||||||
|
version=settings.app_version,
|
||||||
|
description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由TH Agenter修改",
|
||||||
|
debug=settings.debug,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add middleware
|
||||||
|
from th_agenter.core.app import setup_middleware
|
||||||
|
setup_middleware(app, settings)
|
||||||
|
|
||||||
|
# # Add exception handlers
|
||||||
|
setup_exception_handlers(app)
|
||||||
|
add_router(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
def add_router(app: FastAPI) -> None:
|
||||||
|
"""Add default routers to the FastAPI application."""
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def read_root():
|
||||||
|
logger.info("Hello World")
|
||||||
|
return {"Hello": "World"}
|
||||||
|
|
||||||
|
# Include routers
|
||||||
|
app.include_router(router, prefix="/api")
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
fastapi_cdn_host.patch_docs(app)
|
||||||
|
# from test.example import internet_search_tool
|
||||||
|
|
@ -0,0 +1,128 @@
|
||||||
|
aiohappyeyeballs==2.6.1
|
||||||
|
aiohttp==3.13.2
|
||||||
|
aiomysql==0.3.2
|
||||||
|
aiosignal==1.4.0
|
||||||
|
alembic==1.17.2
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.12.0
|
||||||
|
asyncpg==0.31.0
|
||||||
|
attrs==25.4.0
|
||||||
|
bcrypt==5.0.0
|
||||||
|
boto3==1.42.9
|
||||||
|
botocore==1.42.9
|
||||||
|
certifi==2025.11.12
|
||||||
|
cffi==2.0.0
|
||||||
|
charset-normalizer==3.4.4
|
||||||
|
click==8.3.1
|
||||||
|
colorama==0.4.6
|
||||||
|
cryptography==46.0.3
|
||||||
|
dataclasses-json==0.6.7
|
||||||
|
distro==1.9.0
|
||||||
|
dnspython==2.8.0
|
||||||
|
email-validator==2.3.0
|
||||||
|
fastapi==0.124.4
|
||||||
|
fastapi-cli==0.0.16
|
||||||
|
fastapi-cloud-cli==0.6.0
|
||||||
|
fastar==0.8.0
|
||||||
|
filelock==3.20.0
|
||||||
|
frozenlist==1.8.0
|
||||||
|
greenlet==3.3.0
|
||||||
|
h11==0.16.0
|
||||||
|
httpcore==1.0.9
|
||||||
|
httptools==0.7.1
|
||||||
|
httpx==0.28.1
|
||||||
|
httpx-sse==0.4.3
|
||||||
|
idna==3.11
|
||||||
|
itsdangerous==2.2.0
|
||||||
|
Jinja2==3.1.6
|
||||||
|
jiter==0.12.0
|
||||||
|
jmespath==1.0.1
|
||||||
|
jsonpatch==1.33
|
||||||
|
jsonpointer==3.0.0
|
||||||
|
langchain==1.1.3
|
||||||
|
langchain-classic==1.0.0
|
||||||
|
langchain-chroma>=0.1.0
|
||||||
|
langchain-community==0.4.1
|
||||||
|
langchain-core==1.2.0
|
||||||
|
langchain-openai==1.1.3
|
||||||
|
langchain-postgres==0.0.16
|
||||||
|
langchain-text-splitters==1.0.0
|
||||||
|
langgraph==1.0.5
|
||||||
|
langgraph-checkpoint==3.0.1
|
||||||
|
langgraph-prebuilt==1.0.5
|
||||||
|
langgraph-sdk==0.3.0
|
||||||
|
langsmith==0.4.59
|
||||||
|
loguru==0.7.3
|
||||||
|
Mako==1.3.10
|
||||||
|
markdown-it-py==4.0.0
|
||||||
|
MarkupSafe==3.0.3
|
||||||
|
marshmallow==3.26.1
|
||||||
|
mdurl==0.1.2
|
||||||
|
modelscope==1.33.0
|
||||||
|
multidict==6.7.0
|
||||||
|
mypy_extensions==1.1.0
|
||||||
|
numpy==2.3.5
|
||||||
|
openai==2.11.0
|
||||||
|
orjson==3.11.5
|
||||||
|
ormsgpack==1.12.0
|
||||||
|
packaging==25.0
|
||||||
|
pandas==2.3.3
|
||||||
|
pdfminer.six==20251107
|
||||||
|
pdfplumber==0.11.8
|
||||||
|
pgvector==0.3.6
|
||||||
|
pillow==12.0.0
|
||||||
|
propcache==0.4.1
|
||||||
|
psycopg==3.3.2
|
||||||
|
psycopg-binary==3.3.2
|
||||||
|
psycopg-pool==3.3.0
|
||||||
|
psycopg2==2.9.11
|
||||||
|
pycparser==2.23
|
||||||
|
pydantic==2.12.5
|
||||||
|
pydantic-extra-types==2.10.6
|
||||||
|
pydantic-settings==2.12.0
|
||||||
|
pydantic_core==2.41.5
|
||||||
|
Pygments==2.19.2
|
||||||
|
PyJWT==2.10.1
|
||||||
|
PyMySQL==1.1.2
|
||||||
|
pypdfium2==5.2.0
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python-dotenv==1.2.1
|
||||||
|
python-multipart==0.0.20
|
||||||
|
pytz==2025.2
|
||||||
|
PyYAML==6.0.3
|
||||||
|
regex==2025.11.3
|
||||||
|
requests==2.32.5
|
||||||
|
requests-toolbelt==1.0.0
|
||||||
|
rich==14.2.0
|
||||||
|
rich-toolkit==0.17.0
|
||||||
|
rignore==0.7.6
|
||||||
|
ruamel.yaml==0.18.16
|
||||||
|
ruamel.yaml.clib==0.2.15
|
||||||
|
s3transfer==0.16.0
|
||||||
|
sentry-sdk==2.47.0
|
||||||
|
setuptools==80.9.0
|
||||||
|
shellingham==1.5.4
|
||||||
|
six==1.17.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
SQLAlchemy==2.0.45
|
||||||
|
starlette==0.50.0
|
||||||
|
tenacity==9.1.2
|
||||||
|
tiktoken==0.12.0
|
||||||
|
tqdm==4.67.1
|
||||||
|
typer==0.20.0
|
||||||
|
typing-inspect==0.9.0
|
||||||
|
typing-inspection==0.4.2
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2025.2
|
||||||
|
ujson==5.11.0
|
||||||
|
urllib3==2.6.2
|
||||||
|
uuid_utils==0.12.0
|
||||||
|
uvicorn==0.38.0
|
||||||
|
watchfiles==1.1.1
|
||||||
|
websockets==15.0.1
|
||||||
|
wheel==0.45.1
|
||||||
|
win32_setctime==1.2.0
|
||||||
|
xxhash==3.6.0
|
||||||
|
yarl==1.22.0
|
||||||
|
zstandard==0.25.0
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
# 前端代理到后端 (端口 5005 → 8000)
|
||||||
|
|
||||||
|
前端(如运行在 5005 的 Vite/Vben)访问 `/api/*` 时,需要转发到本后端 **http://localhost:8000**。
|
||||||
|
|
||||||
|
## Vite
|
||||||
|
|
||||||
|
在 `vite.config.ts` 中:
|
||||||
|
|
||||||
|
```ts
|
||||||
|
export default defineConfig({
|
||||||
|
server: {
|
||||||
|
port: 5005,
|
||||||
|
proxy: {
|
||||||
|
'/api': {
|
||||||
|
target: 'http://localhost:8000',
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
## Vben Admin
|
||||||
|
|
||||||
|
在 `.env.development` 或环境变量中设置:
|
||||||
|
|
||||||
|
```
|
||||||
|
VITE_GLOB_API_URL=/api
|
||||||
|
```
|
||||||
|
|
||||||
|
并确保 Vite 的 `server.proxy` 将 `/api` 指向 `http://localhost:8000`(同上)。
|
||||||
|
|
||||||
|
## 直接调后端(排错用)
|
||||||
|
|
||||||
|
不经过前端代理,直接请求后端:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST 'http://localhost:8000/api/auth/login' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"email":"admin@example.com","password":"admin123"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
默认管理员:`admin@example.com` / `admin123`(需先执行 `python3 scripts/seed_admin.py`)。
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# 通过 Homebrew 安装并启动 MySQL(macOS)
|
||||||
|
# 请在终端中执行:./scripts/install_mysql.sh
|
||||||
|
|
||||||
|
set -e
|
||||||
|
cd "$(dirname "$0")/.."
|
||||||
|
|
||||||
|
# 查找 brew
|
||||||
|
BREW=""
|
||||||
|
for p in /opt/homebrew/bin/brew /usr/local/bin/brew; do
|
||||||
|
[ -x "$p" ] && BREW="$p" && break
|
||||||
|
done
|
||||||
|
if [ -z "$BREW" ]; then
|
||||||
|
echo ">>> 未找到 Homebrew。请先安装:"
|
||||||
|
echo " /bin/bash -c \"\$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)\""
|
||||||
|
echo " 安装后按提示把 brew 加入 PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ">>> 安装 MySQL..."
|
||||||
|
$BREW install mysql
|
||||||
|
|
||||||
|
echo ">>> 启动 MySQL 服务..."
|
||||||
|
$BREW services start mysql
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo ">>> MySQL 安装并已启动。等待几秒后,创建库和用户请执行:"
|
||||||
|
echo " ./scripts/setup_mysql_local.sh"
|
||||||
|
echo ""
|
||||||
|
echo ">>> 然后在 .env 中设置:"
|
||||||
|
echo " DATABASE_URL=mysql+aiomysql://root:yingping@localhost:3306/allm?charset=utf8mb4"
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""创建默认管理员账号 admin@example.com / admin123,便于前端登录。"""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 在导入 th_agenter 前加载 .env
|
||||||
|
from pathlib import Path
|
||||||
|
_root = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.insert(0, str(_root))
|
||||||
|
os.chdir(_root)
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 若未设置,可在此指定(或通过环境变量传入)
|
||||||
|
# os.environ.setdefault("DATABASE_URL", "mysql+aiomysql://root:xxx@localhost:3306/allm?charset=utf8mb4")
|
||||||
|
|
||||||
|
from th_agenter.db.database import AsyncSessionFactory
|
||||||
|
from th_agenter.services.user import UserService
|
||||||
|
from utils.util_schemas import UserCreate
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with AsyncSessionFactory() as session:
|
||||||
|
svc = UserService(session)
|
||||||
|
exists = await svc.get_user_by_email("admin@example.com")
|
||||||
|
if exists:
|
||||||
|
print("admin@example.com 已存在,跳过创建")
|
||||||
|
return
|
||||||
|
user = await svc.create_user(UserCreate(
|
||||||
|
username="admin",
|
||||||
|
email="admin@example.com",
|
||||||
|
password="admin123",
|
||||||
|
full_name="Admin",
|
||||||
|
))
|
||||||
|
print(f"已创建管理员: {user.username} / admin@example.com / admin123")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# 本地 MySQL 建库脚本:创建 allm,root 密码 yingping
|
||||||
|
# 连接串:mysql+aiomysql://root:yingping@localhost:3306/allm?charset=utf8mb4
|
||||||
|
#
|
||||||
|
# 使用前请确保 MySQL 已安装并启动:
|
||||||
|
# macOS: brew install mysql && brew services start mysql
|
||||||
|
# Ubuntu: sudo apt install mysql-server && sudo systemctl start mysql
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
DB_NAME="allm"
|
||||||
|
DB_USER="root"
|
||||||
|
DB_PASS="yingping"
|
||||||
|
HOST="127.0.0.1"
|
||||||
|
PORT="3306"
|
||||||
|
|
||||||
|
echo ">>> 检查 MySQL 是否可连接 (${HOST}:${PORT}) ..."
|
||||||
|
|
||||||
|
# 尝试无密码连接(首次安装)
|
||||||
|
if mysql -u "$DB_USER" -h "$HOST" -P "$PORT" -e "SELECT 1" 2>/dev/null; then
|
||||||
|
echo ">>> 使用 root 无密码连接成功,创建库并设置密码..."
|
||||||
|
mysql -u "$DB_USER" -h "$HOST" -P "$PORT" -e "
|
||||||
|
CREATE DATABASE IF NOT EXISTS \`${DB_NAME}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||||
|
ALTER USER 'root'@'localhost' IDENTIFIED BY '${DB_PASS}';
|
||||||
|
ALTER USER 'root'@'127.0.0.1' IDENTIFIED BY '${DB_PASS}';
|
||||||
|
FLUSH PRIVILEGES;
|
||||||
|
"
|
||||||
|
echo ">>> 数据库 ${DB_NAME} 已创建,root 密码已设为 ${DB_PASS}"
|
||||||
|
# 尝试使用 yingping 连接(可能已设置过)
|
||||||
|
elif mysql -u "$DB_USER" -p"${DB_PASS}" -h "$HOST" -P "$PORT" -e "SELECT 1" 2>/dev/null; then
|
||||||
|
echo ">>> 使用 root:yingping 连接成功,确保库存在..."
|
||||||
|
mysql -u "$DB_USER" -p"${DB_PASS}" -h "$HOST" -P "$PORT" -e "
|
||||||
|
CREATE DATABASE IF NOT EXISTS \`${DB_NAME}\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||||
|
"
|
||||||
|
echo ">>> 数据库 ${DB_NAME} 已就绪"
|
||||||
|
else
|
||||||
|
echo ">>> 无法连接 MySQL。请先安装并启动,且能以 root 登录(无密码或已知密码)。"
|
||||||
|
echo ">>> 手动执行:"
|
||||||
|
echo " mysql -u root -p -e \"CREATE DATABASE IF NOT EXISTS ${DB_NAME} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; ALTER USER 'root'@'localhost' IDENTIFIED BY '${DB_PASS}'; FLUSH PRIVILEGES;\""
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo ">>> 在 .env 中设置:"
|
||||||
|
echo "DATABASE_URL=mysql+aiomysql://${DB_USER}:${DB_PASS}@localhost:${PORT}/${DB_NAME}?charset=utf8mb4"
|
||||||
|
echo ""
|
||||||
|
echo ">>> 然后执行迁移:"
|
||||||
|
echo "DATABASE_URL=\"mysql+aiomysql://${DB_USER}:${DB_PASS}@localhost:${PORT}/${DB_NAME}?charset=utf8mb4\" python -m alembic upgrade head"
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
#!/bin/bash
# Local startup (no Docker; uses a locally installed PostgreSQL).
#
# Prerequisites:
# 1. Local PostgreSQL is installed and running, with the pgvector extension available.
# 2. Database th_agenter and user drgraph / password yingping exist (matching docker-compose):
#    Example: psql -U postgres -c "CREATE USER drgraph WITH PASSWORD 'yingping';"
#             psql -U postgres -c "CREATE DATABASE th_agenter OWNER drgraph;"
#             psql -U drgraph -d th_agenter -c "CREATE EXTENSION vector;"
# 3. Before the first run, apply migrations:
#    DATABASE_URL="postgresql+asyncpg://drgraph:yingping@localhost:5432/th_agenter" python3 -m alembic upgrade head
#
# DATABASE_URL may also be set in .env instead of exporting it here.

set -e                     # abort on the first failing command
cd "$(dirname "$0")/.."    # run from the project root regardless of invocation directory

# Default to the local PostgreSQL DSN unless the caller already exported one.
export DATABASE_URL="${DATABASE_URL:-postgresql+asyncpg://drgraph:yingping@localhost:5432/th_agenter}"

# Replace the shell with uvicorn so signals reach the server process directly.
exec python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
# Test package for PostgreSQL agent functionality
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from deepagents import create_deep_agent
|
||||||
|
from openai import OpenAI
|
||||||
|
from langchain.chat_models import init_chat_model
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver # 导入检查点工具
|
||||||
|
from deepagents.backends import StoreBackend
|
||||||
|
from loguru import logger
|
||||||
|
def internet_search_tool(query: str):
    """Run a web search for *query* via DashScope's search-enabled chat API.

    Returns the assistant's answer text. Reads credentials from the
    DASHSCOPE_API_KEY / DASHSCOPE_BASE_URL environment variables.
    """
    logger.info(f"Running internet search for query: {query}")
    # A fresh client per call keeps the tool stateless; credentials come from env.
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url=os.getenv('DASHSCOPE_BASE_URL'),
    )
    # Fixed: these two log calls were f-strings with no placeholders (ruff F541).
    logger.info("create OpenAI")
    completion = client.chat.completions.create(
        model="qwen-plus",
        messages=[
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': query}
        ],
        # DashScope-specific flag that enables built-in web search.
        extra_body={
            "enable_search": True
        }
    )
    logger.info("create completions")
    # Read the answer once instead of traversing the response object twice.
    answer = completion.choices[0].message.content
    logger.info(f"OpenAI response: {answer}")
    return answer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# System prompt to steer the agent to be an expert researcher.
# The date is embedded so the agent can reason about "today".
today = datetime.now().strftime("%Y年%m月%d日")
research_instructions = f"""你是一个智能助手。你的任务是帮助用户完成各种任务。

你可以使用互联网搜索工具来获取信息。
## `internet_search`
使用此工具对给定查询进行互联网搜索。你可以指定返回结果的最大数量、主题以及是否包含原始内容。

今天的日期是:{today}
"""

# Create the deep agent with memory.
# Model credentials/endpoint come from the environment.
model = init_chat_model(
    model="gpt-4.1-mini",
    model_provider='openai',
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url=os.getenv('OPENAI_BASE_URL'),
)
# In-memory checkpointer: conversation history is persisted automatically per thread.
checkpointer = InMemorySaver()

agent = create_deep_agent(  # state is scoped to the thread (per-session)
    tools=[internet_search_tool],
    system_prompt=research_instructions,
    model=model,
    checkpointer=checkpointer,  # enable automatic memory via checkpoints
    # Require human approval before the search tool runs (Human-In-The-Loop).
    interrupt_on={'internet_search_tool':True}
)
|
||||||
|
|
||||||
|
# Multi-turn conversation loop (memory handled automatically by the checkpointer).
printed_msg_ids = set()  # message IDs already printed, to deduplicate across stream snapshots
thread_id = "user_session_001"  # session ID distinguishing users/sessions
config = {"configurable": {"thread_id": thread_id}, "metastore": {'assistant_id': 'owenliang'}}  # per-session config

print("开始对话(输入 'exit' 退出):")
while True:
    user_input = input("\nHUMAN: ").strip()
    if user_input.lower() == 'exit':
        break

    # stream_mode="values" yields full state snapshots repeatedly; deduplicate
    # by message.id and print by message type.
    pending_resume = None
    while True:
        if pending_resume is None:
            # First pass of this user turn: send the user's message.
            request = {"messages": [{"role": "user", "content": user_input}]}
        else:
            # Resuming after a Human-In-The-Loop interrupt: send the decisions.
            from langgraph.types import Command as _Command

            request = _Command(resume=pending_resume)
            pending_resume = None

        for item in agent.stream(
            request,
            config=config,
            stream_mode="values",
        ):
            # Some stream items arrive as (state, metadata) tuples; unwrap them.
            state = item[0] if isinstance(item, tuple) and len(item) == 2 else item

            # First check whether a Human-In-The-Loop interrupt was triggered.
            if isinstance(state, dict) and "__interrupt__" in state:
                interrupts = state["__interrupt__"] or []
                if interrupts:
                    hitl_payload = interrupts[0].value
                    action_requests = hitl_payload.get("action_requests", [])

                    print("\n=== 需要人工审批的工具调用 ===")
                    decisions: list[dict[str, str]] = []
                    for idx, ar in enumerate(action_requests):
                        name = ar.get("name")
                        args = ar.get("args")
                        print(f"[{idx}] 工具 {name} 参数: {args}")
                        # Keep prompting until a valid single-letter decision is given.
                        while True:
                            choice = input("  决策 (a=approve, r=reject): ").strip().lower()
                            if choice in ("a", "r"):
                                break
                        decisions.append({"type": "approve" if choice == "a" else "reject"})

                    # Next iteration resumes this same user turn with the decisions.
                    pending_resume = {"decisions": decisions}
                    break

            # Support both dict states and AgentState dataclasses.
            messages = state.get("messages", []) if isinstance(state, dict) else getattr(state, "messages", [])
            for msg in messages:
                msg_id = getattr(msg, "id", None)
                if msg_id is not None and msg_id in printed_msg_ids:
                    continue
                if msg_id is not None:
                    printed_msg_ids.add(msg_id)

                msg_type = getattr(msg, "type", None)

                if msg_type == "human":
                    # The user's input is already on the terminal; don't echo it.
                    continue

                if msg_type == "ai":
                    tool_calls = getattr(msg, "tool_calls", None) or []
                    if tool_calls:
                        # AI message that initiates tool calls (TOOL CALL).
                        for tc in tool_calls:
                            tool_name = tc.get("name")
                            args = tc.get("args")
                            print(f"TOOL CALL [{tool_name}]: {args}")
                    # If the AI message also carries natural-language content, print it too.
                    if getattr(msg, "content", None):
                        print(f"AI: {msg.content}")
                    continue

                if msg_type == "tool":
                    # Tool execution result (TOOL RESPONSE).
                    tool_name = getattr(msg, "name", None) or "tool"
                    print(f"TOOL RESPONSE [{tool_name}]: {msg.content}")
                    continue

                # Fallback: print any other message type for debugging.
                print(f"[{msg_type}]: {getattr(msg, 'content', None)}")

        # No new interrupt to resume: this user turn is done; wait for the next input.
        if pending_resume is None:
            break
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
import os
|
||||||
|
from loguru import logger
|
||||||
|
# from vllm import LLM, SamplingParams
|
||||||
|
from langchain_ollama import ChatOllama
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
# When auto-downloading models, use ModelScope; otherwise they would come from HuggingFace.
os.environ['VLLM_USE_MODELSCOPE']='True'
|
||||||
|
|
||||||
|
# def get_completion(prompts, model, tokenizer=None, max_tokens=512, temperature=0.8, top_p=0.95, max_model_len=2048):
|
||||||
|
# stop_token_ids = [151329, 151336, 151338]
|
||||||
|
# # 创建采样参数。temperature 控制生成文本的多样性,top_p 控制核心采样的概率
|
||||||
|
# sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens, stop_token_ids=stop_token_ids)
|
||||||
|
# # 初始化 vLLM 推理引擎
|
||||||
|
# llm = LLM(model=model, tokenizer=tokenizer, max_model_len=max_model_len,trust_remote_code=True)
|
||||||
|
# outputs = llm.generate(prompts, sampling_params)
|
||||||
|
# return outputs
|
||||||
|
|
||||||
|
def vl_test():
    """Smoke-test a local Ollama vision model through LangChain's ChatOllama client.

    Sends one text-only prompt and returns the model's reply. Raises whatever
    the client raises after logging the failure.
    """
    logger.info("vl_test")

    # Single source of truth for the model name so the log labels can't drift
    # out of sync with the configured model (they previously claimed qwen3-vl:8b).
    model_name = "llava-phi3:latest"  # alternative: "qwen3-vl:8b"
    client = ChatOllama(
        base_url="http://192.168.10.11:11434",
        model=model_name,
        temperature=0.7,
    )

    try:
        # LangChain 1.x multimodal message format; an image part may be appended:
        # {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
        message = HumanMessage(
            content=[
                {
                    "type": "text",
                    "text": "请描述这张图片的内容"
                },
            ]
        )

        response = client.invoke([message])
        result = response.content
        logger.info(f"{model_name}响应: {result}")
        return result
    except Exception as e:
        logger.error(f"调用{model_name}失败: {e}")
        # Bare raise preserves the original traceback (idiomatic re-raise).
        raise
|
||||||
|
|
||||||
|
|
||||||
|
# Run the smoke test when this file is executed directly.
if __name__ == "__main__":
    vl_test()
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""API module for TH Agenter."""
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""API endpoints for TH Agenter."""
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
"""Authentication endpoints."""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ...core.config import get_settings
|
||||||
|
from ...db.database import DrSession, get_session
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...services.user import UserService
|
||||||
|
from ...schemas.user import UserResponse, UserCreate, LoginResponse
|
||||||
|
from utils.util_schemas import Token, LoginRequest
|
||||||
|
from loguru import logger
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
# Router for the authentication endpoints; mounted by the parent API package.
router = APIRouter()
# Application settings (token lifetimes, security config) resolved once at import time.
settings = get_settings()
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, summary="注册新用户")
async def register(
    request_user_data: UserCreate,
    session: DrSession = Depends(get_session)
):
    """Register a new user, enforcing e-mail and username uniqueness."""
    session.desc = f"START: 注册用户 {request_user_data.email}"
    users = UserService(session)

    # Reject duplicate e-mail first, then duplicate username.
    if await users.get_user_by_email(request_user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"邮箱 {request_user_data.email} 已被注册,请使用其他邮箱注册!!!"
        )
    if await users.get_user_by_username(request_user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"用户名 {request_user_data.username} 已被注册,请使用其他用户名注册!!!"
        )

    created = await users.create_user(request_user_data)
    profile = UserResponse.model_validate(created, from_attributes=True)
    return HxfResponse(profile)
|
||||||
|
|
||||||
|
@router.post("/login", response_model=LoginResponse, summary="邮箱与密码登录")
async def login(
    login_data: LoginRequest,
    session: DrSession = Depends(get_session)
):
    """Authenticate by e-mail + password and return a bearer token with the user profile."""
    session.desc = f"START: 用户 {login_data.email} 尝试登录"

    account = await AuthService.authenticate_user_by_email(session, login_data.email, login_data.password)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"邮箱 {login_data.email} 或密码错误,请检查后重试!!!",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Issue an access token with the configured lifetime.
    lifetime = timedelta(minutes=settings.security.access_token_expire_minutes)
    token = await AuthService.create_access_token(
        session, data={"sub": account.username}, expires_delta=lifetime
    )
    session.desc = f"用户 {account.username} 登录成功"

    payload = LoginResponse(
        access_token=token,
        token_type="bearer",
        expires_in=settings.security.access_token_expire_minutes * 60,
        user=UserResponse.model_validate(account, from_attributes=True)
    )
    return HxfResponse(payload)
|
||||||
|
|
||||||
|
@router.post("/login-oauth", response_model=Token, summary="用户通过用户名和密码登录 (OAuth2 兼容)")
async def login_oauth(
    form_data: OAuth2PasswordRequestForm = Depends(),
    session: DrSession = Depends(get_session)
):
    """OAuth2-compatible username/password login; returns a bearer-token dict."""
    session.desc = f"START: 用户 {form_data.username} 尝试 OAuth2 登录"

    account = await AuthService.authenticate_user(session, form_data.username, form_data.password)
    if not account:
        session.desc = f"用户 {form_data.username} 尝试 OAuth2 登录失败"
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Issue an access token with the configured lifetime.
    lifetime = timedelta(minutes=settings.security.access_token_expire_minutes)
    token = await AuthService.create_access_token(
        session, data={"sub": account.username}, expires_delta=lifetime
    )
    session.desc = f"用户 {account.username} OAuth2 登录成功"

    return HxfResponse(
        {
            "access_token": token,
            "token_type": "bearer",
            "expires_in": settings.security.access_token_expire_minutes * 60
        }
    )
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token, summary="刷新访问token")
async def refresh_token(
    current_user = Depends(AuthService.get_current_user),
    session: DrSession = Depends(get_session)
):
    """刷新访问 token"""
    # Mint a fresh token for the already-authenticated user.
    lifetime = timedelta(minutes=settings.security.access_token_expire_minutes)
    fresh_token = await AuthService.create_access_token(
        session, data={"sub": current_user.username}, expires_delta=lifetime
    )

    payload = Token(
        access_token=fresh_token,
        token_type="bearer",
        expires_in=settings.security.access_token_expire_minutes * 60
    )
    return HxfResponse(payload)
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
async def get_current_user_info(
    current_user = Depends(AuthService.get_current_user)
):
    """Return the profile of the currently authenticated user."""
    profile = UserResponse.model_validate(current_user, from_attributes=True)
    return HxfResponse(profile)
|
||||||
|
|
@ -0,0 +1,283 @@
|
||||||
|
"""Chat endpoints for TH Agenter."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...models.user import User
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...services.chat import ChatService
|
||||||
|
from ...services.conversation import ConversationService
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
from utils.util_schemas import (
|
||||||
|
ConversationCreate,
|
||||||
|
ConversationResponse,
|
||||||
|
ConversationUpdate,
|
||||||
|
MessageCreate,
|
||||||
|
MessageResponse,
|
||||||
|
ChatRequest,
|
||||||
|
ChatResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
# Router for conversation/chat endpoints; mounted by the parent API package.
router = APIRouter()
|
||||||
|
|
||||||
|
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
async def update_conversation(
    conversation_id: int,
    conversation_update: ConversationUpdate,
    session: Session = Depends(get_session)
):
    """Apply the given update to a conversation and return the new state."""
    session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}"
    svc = ConversationService(session)
    updated = await svc.update_conversation(conversation_id, conversation_update)
    session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
    return HxfResponse(ConversationResponse.model_validate(updated))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
async def delete_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Delete the conversation identified by *conversation_id*."""
    session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}"
    svc = ConversationService(session)
    await svc.delete_conversation(conversation_id)
    session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
    return HxfResponse({"message": "Conversation deleted successfully"})
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
async def archive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Archive the conversation; 400 if the service reports failure."""
    svc = ConversationService(session)
    archived = await svc.archive_conversation(conversation_id)
    if not archived:
        session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to archive conversation"
        )

    session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
    return HxfResponse({"message": "Conversation archived successfully"})
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
async def unarchive_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Un-archive the conversation; 400 if the service reports failure."""
    session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}"
    svc = ConversationService(session)
    restored = await svc.unarchive_conversation(conversation_id)
    if not restored:
        session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to unarchive conversation"
        )

    session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
    return HxfResponse({"message": "Conversation unarchived successfully"})
|
||||||
|
|
||||||
|
|
||||||
|
# Message management
|
||||||
|
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息")
async def get_conversation_messages(
    conversation_id: int,
    skip: int = 0,
    limit: int = 100,
    session: Session = Depends(get_session)
):
    """Return a page of messages for the conversation (offset *skip*, size *limit*)."""
    session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    svc = ConversationService(session)
    page = await svc.get_conversation_messages(conversation_id, skip=skip, limit=limit)
    session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
    serialized = [MessageResponse.model_validate(m) for m in page]
    return HxfResponse(serialized)
|
||||||
|
|
||||||
|
# Chat functionality
|
||||||
|
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
async def chat(
    conversation_id: int,
    chat_request: ChatRequest,
    session: Session = Depends(get_session)
):
    """发送消息并获取AI响应"""
    session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
    chat_service = ChatService(session)
    await chat_service.initialize(conversation_id)

    # NOTE(review): the real chat call is commented out and the endpoint currently
    # returns a hard-coded placeholder string — restore the call before release.
    # response = await chat_service.chat(
    #     conversation_id=conversation_id,
    #     message=chat_request.message,
    #     stream=False,
    #     temperature=chat_request.temperature,
    #     max_tokens=chat_request.max_tokens,
    #     use_agent=chat_request.use_agent,  # could be simplified away
    #     use_langgraph=chat_request.use_langgraph,  # could be simplified away
    #     use_knowledge_base=chat_request.use_knowledge_base,  # could be simplified away
    #     knowledge_base_id=chat_request.knowledge_base_id  # could be simplified away
    # )
    response = "oooooooooooooooooooK"
    session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"

    return HxfResponse(response)
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
async def chat_stream(
    conversation_id: int,
    chat_request: ChatRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Send a message and stream the AI response back as newline-delimited chunks."""
    session.title = f"对话{conversation_id} 发送消息并获取流式AI响应"
    session.desc = f"START: 对话{conversation_id} 发送消息 [{chat_request.message}] 并获取流式AI响应 >>> "
    chat_service = ChatService(session)
    await chat_service.initialize(conversation_id, streaming=True)

    async def generate_response(chat_service):
        try:
            async for chunk in chat_service.chat_stream(
                message=chat_request.message
            ):
                yield chunk + "\n"
        except Exception as e:
            logger.error(f"{session.log_prefix()} - 流式响应生成异常: {str(e)}")
            # BUG FIX: StreamingResponse requires str/bytes chunks; the previous
            # code yielded a dict here, which would raise a TypeError mid-stream.
            yield f"data: {json.dumps({'type': 'error', 'message': f'流式响应生成异常: {str(e)}'}, ensure_ascii=False)}\n"

    response = StreamingResponse(
        generate_response(chat_service),
        # BUG FIX: "text/stream" is not a registered media type; SSE-style
        # streams use "text/event-stream".
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )

    return response
|
||||||
|
|
||||||
|
# Conversation management
|
||||||
|
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
async def create_conversation(
    conversation_data: ConversationCreate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Create a new conversation owned by the authenticated user."""
    # Renamed from `id`, which shadowed the builtin.
    user_id = current_user.id
    session.title = f"用户{current_user.username} - 创建新对话"
    session.desc = "START: 创建新对话"
    conversation_service = ConversationService(session)
    conversation = await conversation_service.create_conversation(
        user_id=user_id,
        conversation_data=conversation_data
    )
    session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {user_id}, conversation_id: {conversation.id}"
    response = ConversationResponse.model_validate(conversation)
    return HxfResponse(response)
|
||||||
|
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
async def list_conversations(
    skip: int = 0,
    limit: int = 50,
    search: str = None,
    include_archived: bool = False,
    order_by: str = "updated_at",
    order_desc: bool = True,
    session: Session = Depends(get_session)
):
    """List the current user's conversations with paging, search and ordering."""
    session.title = "获取用户对话列表"
    session.desc = "START: 获取用户对话列表"
    svc = ConversationService(session)
    found = await svc.get_user_conversations(
        skip=skip,
        limit=limit,
        search_query=search,
        include_archived=include_archived,
        order_by=order_by,
        order_desc=order_desc
    )
    session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(found)} 个对话 ..."
    serialized = [ConversationResponse.model_validate(c) for c in found]
    return HxfResponse(serialized)
|
||||||
|
|
||||||
|
@router.get("/conversations/count", summary="获取用户对话总数")
async def get_conversations_count(
    search: str = None,
    include_archived: bool = False,
    session: Session = Depends(get_session)
):
    """Return how many conversations the current user has (honoring filters)."""
    # Local import avoids a module-level import cycle with core.context.
    from th_agenter.core.context import UserContext

    user_id = UserContext.get_current_user_id()
    session.title = f"获取用户对话总数[用户id = {user_id}]"
    session.desc = "START: 获取用户对话总数"
    svc = ConversationService(session)
    total = await svc.get_user_conversations_count(
        search_query=search,
        include_archived=include_archived
    )
    session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {total} 个对话"
    return HxfResponse({"count": total})
|
||||||
|
|
||||||
|
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
async def get_conversation(
    conversation_id: int,
    session: Session = Depends(get_session)
):
    """Return one conversation plus (up to 100 of) its messages; 404 if missing."""
    session.title = f"获取指定对话[对话id = {conversation_id}]"
    session.desc = f"START: 获取指定对话 >>> 对话id = {conversation_id}"

    svc = ConversationService(session)
    conversation = await svc.get_conversation(conversation_id=conversation_id)
    if not conversation:
        session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found"
        )
    session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id} >>> {conversation}"

    payload = ConversationResponse.model_validate(conversation)

    # Attach the first page of messages (fixed window: 0..100).
    history = await svc.get_conversation_messages(conversation_id, skip=0, limit=100)
    payload.messages = [MessageResponse.model_validate(m) for m in history]
    payload.message_count = len(payload.messages)
    return HxfResponse(payload)
|
||||||
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""数据库配置管理API"""
|
||||||
|
from loguru import logger
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from th_agenter.models.user import User
|
||||||
|
from th_agenter.db.database import get_session
|
||||||
|
from th_agenter.services.database_config_service import DatabaseConfigService
|
||||||
|
from th_agenter.services.auth import AuthService
|
||||||
|
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
# 在文件顶部添加
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
# All routes in this module are mounted under /api/database-config.
router = APIRouter(prefix="/api/database-config", tags=["database-config"])

# Cached service-singleton factory.
@lru_cache()
def get_database_config_service() -> DatabaseConfigService:
    """Return a cached DatabaseConfigService singleton (built without a session)."""
    # NOTE(review): constructed with no db session — a temporary workaround.
    # Prefer get_database_service() below, which binds the request's session.
    return DatabaseConfigService(None)  # temporary workaround


# Alternative: a module-level singleton managed by get_database_service().
_database_service_instance = None
|
||||||
|
|
||||||
|
def get_database_service(session: Session = Depends(get_session)) -> DatabaseConfigService:
    """Return the shared DatabaseConfigService, rebinding it to this request's session."""
    global _database_service_instance
    if _database_service_instance is None:
        _database_service_instance = DatabaseConfigService(session)
    else:
        # Swap in the current request's db session on the cached instance.
        # NOTE(review): a mutable module-level singleton shared across requests
        # is not safe under concurrent requests — confirm single-worker usage.
        _database_service_instance.db = session
    return _database_service_instance
|
||||||
|
|
||||||
|
class DatabaseConfigCreate(BaseModel):
    """Request payload for creating or updating a user's database connection config."""

    name: str = Field(..., description="配置名称")
    db_type: str = Field(default="postgresql", description="数据库类型")
    host: str = Field(..., description="主机地址")
    port: int = Field(..., description="端口号")
    database: str = Field(..., description="数据库名")
    username: str = Field(..., description="用户名")
    # Transmitted in plain text; presumably encrypted at rest by the service — TODO confirm.
    password: str = Field(..., description="密码")
    is_default: bool = Field(default=False, description="是否为默认配置")
    connection_params: Dict[str, Any] = Field(default=None, description="额外连接参数")
|
||||||
|
|
||||||
|
class DatabaseConfigResponse(BaseModel):
    """Response shape for a stored database configuration.

    Note: `password` is included — list/by-type endpoints in this module return
    it decrypted (`to_dict(include_password=True, decrypt_service=...)`).
    """

    id: int
    name: str
    db_type: str
    host: str
    port: int
    database: str
    username: str
    password: str
    is_active: bool
    is_default: bool
    created_at: str
    updated_at: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=NormalResponse, summary="创建或更新数据库配置")
async def create_database_config(
    config_data: DatabaseConfigCreate,
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Create (or upsert) a database configuration for the current user."""
    saved = await service.create_or_update_config(current_user.id, config_data.model_dump())
    result = NormalResponse(
        success=True,
        message="保存数据库配置成功",
        data=saved
    )
    return HxfResponse(result)
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表")
async def get_database_configs(
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Return all database configs of the current user, passwords decrypted."""
    stored = service.get_user_configs(current_user.id)
    serialized = [
        cfg.to_dict(include_password=True, decrypt_service=service)
        for cfg in stored
    ]
    return HxfResponse(serialized)
|
||||||
|
|
||||||
|
@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接")
async def test_database_connection(
    config_id: int,
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Probe connectivity for the given config and report the outcome."""
    outcome = await service.test_connection(config_id, current_user.id)
    return HxfResponse(outcome)
|
||||||
|
|
||||||
|
@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表")
async def connect_database(
    config_id: int,
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Open a connection with the given config and list its tables."""
    tables = await service.connect_and_get_tables(config_id, current_user.id)
    return HxfResponse(tables)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tables/{table_name}/data", summary="获取表数据预览")
|
||||||
|
async def get_table_data(
|
||||||
|
table_name: str,
|
||||||
|
db_type: str,
|
||||||
|
limit: int = 100,
|
||||||
|
current_user: User = Depends(AuthService.get_current_user),
|
||||||
|
service: DatabaseConfigService = Depends(get_database_service)
|
||||||
|
):
|
||||||
|
"""获取表数据预览"""
|
||||||
|
try:
|
||||||
|
result = await service.get_table_data(table_name, current_user.id, db_type, limit)
|
||||||
|
return HxfResponse(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取表数据失败: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/tables/{table_name}/schema", summary="获取表结构信息")
|
||||||
|
async def get_table_schema(
|
||||||
|
table_name: str,
|
||||||
|
current_user: User = Depends(AuthService.get_current_user),
|
||||||
|
service: DatabaseConfigService = Depends(get_database_service)
|
||||||
|
):
|
||||||
|
"""获取表结构信息"""
|
||||||
|
result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的?
|
||||||
|
return HxfResponse(result)
|
||||||
|
|
||||||
|
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置")
|
||||||
|
async def get_config_by_type(
|
||||||
|
db_type: str,
|
||||||
|
current_user: User = Depends(AuthService.get_current_user),
|
||||||
|
service: DatabaseConfigService = Depends(get_database_service)
|
||||||
|
):
|
||||||
|
"""根据数据库类型获取配置"""
|
||||||
|
config = service.get_config_by_type(current_user.id, db_type)
|
||||||
|
if not config:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"未找到类型为 {db_type} 的配置"
|
||||||
|
)
|
||||||
|
# 返回包含解密密码的配置
|
||||||
|
return HxfResponse(config.to_dict(include_password=True, decrypt_service=service))
|
||||||
|
|
@ -0,0 +1,616 @@
|
||||||
|
"""Knowledge base API endpoints."""
|
||||||
|
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...models.user import User
|
||||||
|
from ...models.knowledge_base import KnowledgeBase, Document
|
||||||
|
from ...services.knowledge_base import KnowledgeBaseService
|
||||||
|
from ...services.document import DocumentService
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from utils.util_schemas import (
|
||||||
|
KnowledgeBaseCreate,
|
||||||
|
KnowledgeBaseResponse,
|
||||||
|
DocumentResponse,
|
||||||
|
DocumentListResponse,
|
||||||
|
DocumentUpload,
|
||||||
|
DocumentProcessingStatus,
|
||||||
|
DocumentChunksResponse,
|
||||||
|
ErrorResponse
|
||||||
|
)
|
||||||
|
from utils.util_file import FileUtils
|
||||||
|
from ...core.config import settings
|
||||||
|
|
||||||
|
router = APIRouter(tags=["knowledge-bases"])
|
||||||
|
|
||||||
|
@router.post("/", response_model=KnowledgeBaseResponse, summary="创建新的知识库")
|
||||||
|
async def create_knowledge_base(
|
||||||
|
kb_data: KnowledgeBaseCreate,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""创建新的知识库"""
|
||||||
|
# Check if knowledge base with same name already exists for this user
|
||||||
|
session.desc = f"START: 为用户 {current_user.username}[ID={current_user.id}] 创建新的知识库 {kb_data}"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
session.desc = f"检查用户 {current_user.username} 是否已存在知识库 {kb_data.name}"
|
||||||
|
existing_kb = await kb_service.get_knowledge_base_by_name(kb_data.name)
|
||||||
|
if existing_kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"知识库名称 {kb_data.name} 已存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create knowledge base
|
||||||
|
session.desc = f"知识库 {kb_data.name}不存在,创建之"
|
||||||
|
kb = await kb_service.create_knowledge_base(kb_data)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 创建知识库 {kb.name} 成功"
|
||||||
|
response = KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=0,
|
||||||
|
active_document_count=0
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[KnowledgeBaseResponse], summary="获取当前用户的所有知识库")
|
||||||
|
async def list_knowledge_bases(
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""获取当前用户的所有知识库"""
|
||||||
|
session.title = f"获取用户 {current_user.username} 的所有知识库"
|
||||||
|
session.desc = f"START: 获取用户 {current_user.username} 的所有知识库 (skip={skip}, limit={limit})"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
knowledge_bases = await kb_service.get_knowledge_bases(skip=skip, limit=limit)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for kb in knowledge_bases:
|
||||||
|
# 本知识库的文档数量
|
||||||
|
total_docs = await session.scalar(
|
||||||
|
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||||
|
)
|
||||||
|
# 本知识库的已处理文档数量
|
||||||
|
active_docs = await session.scalar(
|
||||||
|
select(func.count()).where(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result.append(KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
))
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取用户 {current_user.username} 的所有 {len(result)} 知识库"
|
||||||
|
return HxfResponse(result)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse, summary="根据知识库ID获取知识库详情")
|
||||||
|
async def get_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""根据知识库ID获取知识库详情"""
|
||||||
|
session.desc = f"START: 获取知识库 {kb_id} 的详情"
|
||||||
|
service = KnowledgeBaseService(session)
|
||||||
|
session.desc = f"检查知识库 {kb_id} 是否存在"
|
||||||
|
kb = await service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count documents
|
||||||
|
total_docs = await session.scalar(
|
||||||
|
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||||
|
)
|
||||||
|
session.desc = f"获取知识库 {kb_id} 共 {total_docs} 个文档"
|
||||||
|
|
||||||
|
active_docs = await session.scalar(
|
||||||
|
select(func.count()).where(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取知识库 {kb_id} 的详情,共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||||
|
response = KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse, summary="更新知识库")
|
||||||
|
async def update_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
kb_data: KnowledgeBaseCreate,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""更新知识库"""
|
||||||
|
session.desc = f"START: 更新知识库 {kb_id}"
|
||||||
|
service = KnowledgeBaseService(session)
|
||||||
|
kb = await service.update_knowledge_base(kb_id, kb_data)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count documents
|
||||||
|
total_docs = await session.scalar(
|
||||||
|
select(func.count()).where(Document.knowledge_base_id == kb.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
active_docs = await session.scalar(
|
||||||
|
select(func.count()).where(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 更新知识库 {kb_id},结果 - 共 {total_docs} 个文档,其中 {active_docs} 个已处理"
|
||||||
|
response = KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.delete("/{kb_id}", summary="删除知识库")
|
||||||
|
async def delete_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""删除知识库"""
|
||||||
|
session.desc = f"START: 删除知识库 {kb_id}"
|
||||||
|
service = KnowledgeBaseService(session)
|
||||||
|
success = await service.delete_knowledge_base(kb_id)
|
||||||
|
if not success:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 删除知识库 {kb_id}"
|
||||||
|
return HxfResponse({"message": "Knowledge base deleted successfully"})
|
||||||
|
|
||||||
|
# Document management endpoints
|
||||||
|
@router.post("/{kb_id}/documents", response_model=DocumentResponse, summary="上传文档到知识库")
|
||||||
|
async def upload_document(
|
||||||
|
kb_id: int,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
process_immediately: bool = Form(True),
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""上传文档到知识库"""
|
||||||
|
session.desc = f"START: 上传文档 {file.filename} ({FileUtils.format_file_size(file.size)}) 到知识库 (ID={kb_id})"
|
||||||
|
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
session.desc = f"获取知识库 {kb_id} 详情完毕 - 名称: {kb.name}, 描述: {kb.description}, 模型: {kb.embedding_model}"
|
||||||
|
# Validate file
|
||||||
|
if not FileUtils.validate_file_extension(file.filename):
|
||||||
|
session.desc = f"ERROR: 文件 {file.filename} 类型不支持,仅支持 {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"文件类型 {file.filename.split('.')[-1]} 不支持。支持类型: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
|
||||||
|
)
|
||||||
|
# Check file size (50MB limit)
|
||||||
|
max_size = 50 * 1024 * 1024 # 50MB
|
||||||
|
if file.size and file.size > max_size:
|
||||||
|
session.desc = f"ERROR: 文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"文件 {file.filename} 大小超过 {FileUtils.format_file_size(max_size)} 限制"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"文件为期望类型,处理文件 {file.filename} - "
|
||||||
|
# Upload document
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
document = await doc_service.upload_document(
|
||||||
|
file, kb_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process document immediately if requested
|
||||||
|
if process_immediately:
|
||||||
|
try:
|
||||||
|
await doc_service.process_document(document.id, kb_id)
|
||||||
|
# Refresh document to get updated status
|
||||||
|
await session.refresh(document)
|
||||||
|
except Exception as e:
|
||||||
|
session.desc = f"ERROR: 处理文档 {document.id} 时出错: {str(e)}"
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 上传文档 {document.id} 到知识库 {kb_id}"
|
||||||
|
response = DocumentResponse(
|
||||||
|
id=document.id,
|
||||||
|
created_at=document.created_at,
|
||||||
|
updated_at=document.updated_at,
|
||||||
|
knowledge_base_id=document.knowledge_base_id,
|
||||||
|
filename=document.filename,
|
||||||
|
original_filename=document.original_filename,
|
||||||
|
file_path=document.file_path,
|
||||||
|
file_type=document.file_type,
|
||||||
|
file_size=document.file_size,
|
||||||
|
mime_type=document.mime_type,
|
||||||
|
is_processed=document.is_processed,
|
||||||
|
processing_error=document.processing_error,
|
||||||
|
chunk_count=document.chunk_count or 0,
|
||||||
|
embedding_model=document.embedding_model,
|
||||||
|
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents", response_model=DocumentListResponse, summary="获取知识库中的文档列表")
|
||||||
|
async def list_documents(
|
||||||
|
kb_id: int,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""获取知识库中的文档列表。"""
|
||||||
|
session.desc = f"START: 获取知识库 {kb_id} 中的文档列表"
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
documents, total = await doc_service.list_documents(kb_id, skip, limit)
|
||||||
|
|
||||||
|
doc_responses = []
|
||||||
|
for doc in documents:
|
||||||
|
doc_responses.append(DocumentResponse(
|
||||||
|
id=doc.id,
|
||||||
|
created_at=doc.created_at,
|
||||||
|
updated_at=doc.updated_at,
|
||||||
|
knowledge_base_id=doc.knowledge_base_id,
|
||||||
|
filename=doc.filename,
|
||||||
|
original_filename=doc.original_filename,
|
||||||
|
file_path=doc.file_path,
|
||||||
|
file_type=doc.file_type,
|
||||||
|
file_size=doc.file_size,
|
||||||
|
mime_type=doc.mime_type,
|
||||||
|
is_processed=doc.is_processed,
|
||||||
|
processing_error=doc.processing_error,
|
||||||
|
chunk_count=doc.chunk_count or 0,
|
||||||
|
embedding_model=doc.embedding_model,
|
||||||
|
file_size_mb=round(doc.file_size / (1024 * 1024), 2)
|
||||||
|
))
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档列表,共 {total} 条"
|
||||||
|
response = DocumentListResponse(
|
||||||
|
documents=doc_responses,
|
||||||
|
total=total,
|
||||||
|
page=skip // limit + 1,
|
||||||
|
page_size=limit
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse, summary="获取知识库中的文档块(片段)")
|
||||||
|
async def get_document_chunks(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取知识库中特定文档的所有文档块(片段)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kb_id: 知识库ID
|
||||||
|
doc_id: 文档ID
|
||||||
|
session: 数据库会话
|
||||||
|
current_user: 当前认证用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DocumentChunksResponse: 文档块(片段)响应模型
|
||||||
|
"""
|
||||||
|
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 所有文档块(片段)"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
knowledge_base = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
|
||||||
|
if not knowledge_base:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="知识库不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify document exists in the knowledge base
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > DocumentService"
|
||||||
|
document = await doc_service.get_document(doc_id, kb_id)
|
||||||
|
session.desc = f"获取知识库 {kb_id} 中的文档 {doc_id} 的信息 > get_document"
|
||||||
|
if not document:
|
||||||
|
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="文档不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get document chunks
|
||||||
|
chunks = await doc_service.get_document_chunks(doc_id)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取文档 {doc_id} 共 {len(chunks)} 个文档块(片段)"
|
||||||
|
response = DocumentChunksResponse(
|
||||||
|
document_id=doc_id,
|
||||||
|
document_name=document.filename,
|
||||||
|
total_chunks=len(chunks),
|
||||||
|
chunks=chunks
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse, summary="获取知识库中的文档详情")
|
||||||
|
async def get_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""获取知识库中的文档详情。"""
|
||||||
|
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
document = await doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取知识库 {kb_id} 中的文档 {doc_id} 详情"
|
||||||
|
response = DocumentResponse(
|
||||||
|
id=document.id,
|
||||||
|
created_at=document.created_at,
|
||||||
|
updated_at=document.updated_at,
|
||||||
|
knowledge_base_id=document.knowledge_base_id,
|
||||||
|
filename=document.filename,
|
||||||
|
original_filename=document.original_filename,
|
||||||
|
file_path=document.file_path,
|
||||||
|
file_type=document.file_type,
|
||||||
|
file_size=document.file_size,
|
||||||
|
mime_type=document.mime_type,
|
||||||
|
is_processed=document.is_processed,
|
||||||
|
processing_error=document.processing_error,
|
||||||
|
chunk_count=document.chunk_count or 0,
|
||||||
|
embedding_model=document.embedding_model,
|
||||||
|
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.delete("/{kb_id}/documents/{doc_id}", summary="删除知识库中的文档")
|
||||||
|
async def delete_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""删除知识库中的文档。"""
|
||||||
|
session.desc = f"START: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
success = await doc_service.delete_document(doc_id, kb_id)
|
||||||
|
if not success:
|
||||||
|
session.desc = f"ERROR: 删除文档 {doc_id} 失败 - 文档不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 删除知识库 {kb_id} 中的文档 {doc_id}"
|
||||||
|
response = {"message": "Document deleted successfully"}
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus, summary="处理知识库中的文档")
|
||||||
|
async def process_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""处理知识库中的文档,用于向量搜索。"""
|
||||||
|
session.desc = f"START: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if document exists
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
document = await doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the document
|
||||||
|
result = await doc_service.process_document(doc_id, kb_id)
|
||||||
|
await session.refresh(document)
|
||||||
|
session.desc = f"SUCCESS: 处理知识库 {kb_id} 中的文档 {doc_id}"
|
||||||
|
response = DocumentProcessingStatus(
|
||||||
|
document_id=doc_id,
|
||||||
|
status=result["status"],
|
||||||
|
progress=result.get("progress", 0.0),
|
||||||
|
error_message=result.get("error_message"),
|
||||||
|
chunks_created=result.get("chunks_created", 0)
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus, summary="获取知识库中的文档处理状态")
|
||||||
|
async def get_document_processing_status(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""获取知识库中的文档处理状态。"""
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
session.desc = f"START: 获取知识库 {kb_id} 中的文档 {doc_id} 处理状态"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
document = await doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
session.desc = f"ERROR: 文档 {doc_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine status
|
||||||
|
if document.processing_error:
|
||||||
|
status_str = "failed"
|
||||||
|
progress = 0.0
|
||||||
|
session.desc = f"ERROR: 文档 {doc_id} 处理失败,错误信息:{document.processing_error}"
|
||||||
|
elif document.is_processed:
|
||||||
|
status_str = "completed"
|
||||||
|
progress = 100.0
|
||||||
|
session.desc = f"SUCCESS: 文档 {doc_id} 处理完成"
|
||||||
|
else:
|
||||||
|
status_str = "pending"
|
||||||
|
progress = 0.0
|
||||||
|
session.desc = f"文档 {doc_id} 处理pending中"
|
||||||
|
|
||||||
|
response = DocumentProcessingStatus(
|
||||||
|
document_id=document.id,
|
||||||
|
status=status_str,
|
||||||
|
progress=progress,
|
||||||
|
error_message=document.processing_error,
|
||||||
|
chunks_created=document.chunk_count or 0
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/search", summary="在知识库中搜索文档")
|
||||||
|
async def search_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""在知识库中搜索文档。"""
|
||||||
|
session.desc = f"START: 在知识库 {kb_id} 中搜索文档,查询:{query}"
|
||||||
|
kb_service = KnowledgeBaseService(session)
|
||||||
|
kb = await kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
session.desc = f"ERROR: 知识库 {kb_id} 不存在"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
doc_service = DocumentService(session)
|
||||||
|
results = await doc_service.search_documents(kb_id, query, limit)
|
||||||
|
session.desc = f"SUCCESS: 在知识库 {kb_id} 中搜索文档,查询:{query},返回 {len(results)} 条结果"
|
||||||
|
response = {
|
||||||
|
"knowledge_base_id": kb_id,
|
||||||
|
"query": query,
|
||||||
|
"results": results,
|
||||||
|
"total_results": len(results)
|
||||||
|
}
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
@ -0,0 +1,473 @@
|
||||||
|
"""LLM configuration management API endpoints."""
|
||||||
|
|
||||||
|
from turtle import textinput
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import or_, select, delete, update
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from th_agenter.llm.embed.embed_llm import BGEEmbedLLM, EmbedLLM
|
||||||
|
from th_agenter.llm.online.online_llm import OnlineLLM
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...models.user import User
|
||||||
|
from ...models.llm_config import LLMConfig
|
||||||
|
from th_agenter.llm.base_llm import LLMConfig_DataClass
|
||||||
|
from ...core.simple_permissions import require_super_admin, require_authenticated_user
|
||||||
|
from ...schemas.llm_config import (
|
||||||
|
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
|
||||||
|
LLMConfigTest
|
||||||
|
)
|
||||||
|
from th_agenter.services.document_processor import get_document_processor
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
|
||||||
|
async def get_llm_configs(
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
|
search: Optional[str] = Query(None),
|
||||||
|
provider: Optional[str] = Query(None),
|
||||||
|
is_active: Optional[bool] = Query(None),
|
||||||
|
is_embedding: Optional[bool] = Query(None),
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取大模型配置列表."""
|
||||||
|
session.title = "获取大模型配置列表"
|
||||||
|
session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
|
||||||
|
stmt = select(LLMConfig)
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
if search:
|
||||||
|
stmt = stmt.where(
|
||||||
|
or_(
|
||||||
|
LLMConfig.name.ilike(f"%{search}%"),
|
||||||
|
LLMConfig.model_name.ilike(f"%{search}%"),
|
||||||
|
LLMConfig.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 服务商筛选
|
||||||
|
if provider:
|
||||||
|
stmt = stmt.where(LLMConfig.provider == provider)
|
||||||
|
|
||||||
|
# 状态筛选
|
||||||
|
if is_active is not None:
|
||||||
|
stmt = stmt.where(LLMConfig.is_active == is_active)
|
||||||
|
|
||||||
|
# 模型类型筛选
|
||||||
|
if is_embedding is not None:
|
||||||
|
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||||
|
|
||||||
|
# 排序
|
||||||
|
stmt = stmt.order_by(LLMConfig.name)
|
||||||
|
|
||||||
|
# 分页
|
||||||
|
stmt = stmt.offset(skip).limit(limit)
|
||||||
|
configs = (await session.execute(stmt)).scalars().all()
|
||||||
|
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..."
|
||||||
|
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers", summary="获取支持的大模型服务商列表")
|
||||||
|
async def get_llm_providers(
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取支持的大模型服务商列表."""
|
||||||
|
session.desc = "START: 获取支持的大模型服务商列表"
|
||||||
|
stmt = select(LLMConfig.provider).distinct()
|
||||||
|
providers = (await session.execute(stmt)).scalars().all()
|
||||||
|
session.desc = f"SUCCESS: 获取 {len(providers)} 个大模型服务商"
|
||||||
|
return HxfResponse([provider for provider in providers if provider])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/active", response_model=List[LLMConfigResponse], summary="获取所有激活的大模型配置")
|
||||||
|
async def get_active_llm_configs(
|
||||||
|
is_embedding: Optional[bool] = Query(None),
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取所有激活的大模型配置."""
|
||||||
|
session.desc = f"START: 获取所有激活的大模型配置, is_embedding={is_embedding}"
|
||||||
|
stmt = select(LLMConfig).where(LLMConfig.is_active == True)
|
||||||
|
|
||||||
|
if is_embedding is not None:
|
||||||
|
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||||
|
|
||||||
|
stmt = stmt.order_by(LLMConfig.created_at)
|
||||||
|
configs = (await session.execute(stmt)).scalars().all()
|
||||||
|
session.desc = f"SUCCESS: 获取 {len(configs)} 个激活的大模型配置"
|
||||||
|
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||||
|
|
||||||
|
@router.get("/default", response_model=LLMConfigResponse, summary="获取默认大模型配置")
async def get_default_llm_config(
    is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Fetch the active default config for the requested kind (chat vs. embedding)."""
    session.desc = f"START: 获取默认大模型配置, is_embedding={is_embedding}"
    lookup = select(LLMConfig).where(
        LLMConfig.is_default == True,
        LLMConfig.is_embedding == is_embedding,
        LLMConfig.is_active == True
    )
    config = (await session.execute(lookup)).scalar_one_or_none()

    if config is None:
        # No active default of this kind — report which kind is missing.
        model_type = "嵌入模型" if is_embedding else "对话模型"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到默认{model_type}配置"
        )

    session.desc = f"SUCCESS: 获取默认大模型配置, is_embedding={is_embedding}"
    return HxfResponse(config.to_dict(include_sensitive=True))
|
||||||
|
|
||||||
|
@router.get("/{config_id}", response_model=LLMConfigResponse, summary="获取大模型配置详情")
async def get_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Return one LLM config by id, including sensitive fields."""
    result = await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    return HxfResponse(config.to_dict(include_sensitive=True))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
async def create_llm_config(
    config_data: LLMConfigCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new LLM configuration (super-admin only).

    Rejects duplicate names, validates the payload, and — when the new
    config is flagged default — clears the default flag on all other
    configs of the same kind (chat vs. embedding) first.
    """
    # Capture the username up-front: lazy attribute access after
    # session.refresh() can raise MissingGreenlet in the async context.
    username = current_user.username
    session.desc = f"START: 创建大模型配置, name={config_data.name}"

    # Reject duplicate config names.
    stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
    existing_config = (await session.execute(stmt)).scalar_one_or_none()
    if existing_config:
        session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="配置名称已存在"
        )

    # FIX: build the config object exactly once — the original constructed
    # an identical object twice and silently discarded the first instance.
    config = LLMConfig_DataClass(
        name=config_data.name,
        provider=config_data.provider,
        model_name=config_data.model_name,
        api_key=config_data.api_key,
        base_url=config_data.base_url,
        max_tokens=config_data.max_tokens,
        temperature=config_data.temperature,
        top_p=config_data.top_p,
        frequency_penalty=config_data.frequency_penalty,
        presence_penalty=config_data.presence_penalty,
        description=config_data.description,
        is_active=config_data.is_active,
        is_default=config_data.is_default,
        is_embedding=config_data.is_embedding,
        extra_config=config_data.extra_config or {}
    )

    # Validate before touching the database.
    validation_result = config.validate_config()
    if not validation_result['valid']:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=validation_result['error']
        )

    # A new default config demotes every other default of the same kind.
    if config_data.is_default:
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == config_data.is_embedding
        ).values({"is_default": False})
        await session.execute(stmt)

    # FIX: the original f-string had no placeholder; include the name.
    session.desc = f"验证大模型配置通过, name={config_data.name}"
    # Audit fields are set automatically by SQLAlchemy event listener.
    session.add(config)
    await session.commit()
    await session.refresh(config)
    session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}"
    return HxfResponse(config.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{config_id}", response_model=LLMConfigResponse, summary="更新大模型配置")
async def update_llm_config(
    config_id: int,
    config_data: LLMConfigUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Update an existing LLM configuration (super-admin only).

    Applies only the fields present in the request body; when the config
    is promoted to default, other defaults of the same kind are demoted
    first, before the new field values are applied.
    """
    # Capture the username before refresh() — lazy access afterwards can
    # fail in the async session (same MissingGreenlet concern noted at
    # creation time).
    username = current_user.username
    session.desc = f"START: 更新大模型配置, id={config_id}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Reject a rename that collides with another config's name.
    if config_data.name and config_data.name != config.name:
        stmt = select(LLMConfig).where(
            LLMConfig.name == config_data.name,
            LLMConfig.id != config_id
        )
        existing_config = (await session.execute(stmt)).scalar_one_or_none()
        if existing_config:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="配置名称已存在"
            )

    # Promoting to default: clear the flag on every other config of the
    # same embedding kind; the updated kind wins if the request changes it.
    if config_data.is_default is True:
        is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == is_embedding,
            LLMConfig.id != config_id
        ).values({"is_default": False})
        await session.execute(stmt)

    # Apply only the fields the client actually sent.
    update_data = config_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(config, field, value)

    await session.commit()
    await session.refresh(config)

    session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}"
    return HxfResponse(config.to_dict())
|
||||||
|
|
||||||
|
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
async def delete_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete an LLM configuration (super-admin only).

    NOTE(review): the route declares 204 No Content yet returns a JSON
    body; FastAPI may reject a body on a 204 response — confirm intent.
    """
    username = current_user.username
    session.desc = f"START: 删除大模型配置, id={config_id}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    session.desc = f"待删除大模型记录 {config.to_dict()}"
    # TODO: check whether conversations or other features still reference
    # this config before deleting (add the relevant checks here).

    # Delete the config row.
    await session.delete(config)
    await session.commit()

    session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}"
    return HxfResponse({"message": "LLM config deleted successfully"})
|
||||||
|
|
||||||
|
@router.post("/{config_id}/test", summary="测试连接大模型配置")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Connectivity test for an LLM configuration (super-admin only).

    Validates the stored config, instantiates the matching client
    (embedding vs. chat) and runs a single round-trip request.
    """
    import time  # local import: only needed here for latency measurement

    username = current_user.username
    session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}")
    config_name = config.name
    # Validate the stored config before attempting any network call.
    validation_result = config.validate_config()
    logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}")
    if not validation_result["valid"]:
        # NOTE(review): this early return is a plain dict while the other
        # branches wrap HxfResponse — confirm callers handle both shapes.
        return {
            "success": False,
            "message": f"配置验证失败: {validation_result['error']}",
            "details": validation_result
        }

    session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}"
    # FIX: define test_message BEFORE the try block — the original assigned
    # it inside try, so any failure before that line made the except
    # handler raise NameError while building the error payload.
    test_message = test_data.message or "Hello, this is a test message."
    # Try to build the client and send a test request.
    try:
        session.desc = f"准备测试LLM功能 > test_message = {test_message}"

        if config.is_embedding:
            # NOTE(review): provider is overwritten in-memory for the test;
            # not committed here, but verify it cannot leak via a later flush.
            config.provider = "ollama"
            streaming_llm = BGEEmbedLLM(config)
        else:
            streaming_llm = OnlineLLM(config)
        session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}"
        streaming_llm.load_model()  # load/initialize the underlying model
        session.desc = f"加载模型完毕,模型名称:{config.model_name},base_url: {config.base_url},准备测试对话..."

        start_time = time.perf_counter()
        if config.is_embedding:
            # Embedding configs are exercised through the embedding API,
            # not the chat API.
            test_text = test_message or "Hello, this is a test message for embedding"
            response = streaming_llm.embed_query(test_text)
        else:
            # Chat configs get a tiny system+user exchange.
            from langchain.messages import SystemMessage, HumanMessage
            messages = [
                SystemMessage(content="你是一个简洁的助手,回答控制在50字以内"),
                HumanMessage(content=test_message)
            ]
            response = streaming_llm.model.invoke(messages)
        # FIX: report the measured round-trip time instead of the original
        # hard-coded placeholder value 150.
        latency_ms = int((time.perf_counter() - start_time) * 1000)
        session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}"

        return HxfResponse({
            "success": True,
            "message": "LLM测试成功",
            "request": test_message,
            "response": response.content if hasattr(response, 'content') else response,
            "latency_ms": latency_ms,
            "config_info": config.to_dict()
        })

    except Exception as test_error:
        session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
        return HxfResponse({
            "success": False,
            "message": f"LLM测试失败: {str(test_error)}",
            "test_message": test_message,
            "config_info": config.to_dict()
        })
|
||||||
|
|
||||||
|
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
async def toggle_llm_config_status(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Flip a config's is_active flag and report the new state."""
    username = current_user.username
    session.desc = f"START: 切换大模型配置状态, id={config_id} by user {username}"

    result = await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Invert the flag; audit fields are stamped by a SQLAlchemy event listener.
    config.is_active = not config.is_active

    await session.commit()
    await session.refresh(config)

    status_text = "激活" if config.is_active else "禁用"
    session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {username}"

    return HxfResponse({
        "message": f"配置已{status_text}",
        "is_active": config.is_active
    })
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
async def set_default_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Mark one LLM config as the default for its kind (chat or embedding).

    Only an active config may become default. All other configs of the
    same kind are demoted, then the document processor's embeddings are
    re-initialized so the new default takes effect immediately.
    """
    username = current_user.username
    session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}"

    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )

    # Only active configs are eligible to be the default.
    if not config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只能将激活的配置设为默认"
        )

    # Demote every other default of the same embedding kind.
    stmt = update(LLMConfig).where(
        LLMConfig.is_embedding == config.is_embedding,
        LLMConfig.id != config_id
    ).values({"is_default": False})
    await session.execute(stmt)

    # Promote this config; audit fields stamped explicitly here.
    config.is_default = True
    config.set_audit_fields(current_user.id, is_update=True)

    await session.commit()
    await session.refresh(config)

    model_type = "嵌入模型" if config.is_embedding else "对话模型"
    # Rebuild the document processor's default embeddings.
    # NOTE(review): calls a private method (_init_embeddings) — confirm
    # this is the intended refresh hook.
    await get_document_processor(session)._init_embeddings()
    session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}"
    return HxfResponse({
        "message": f"已将 {config.name} 设为默认{model_type}配置",
        "is_default": config.is_default
    })
|
||||||
|
|
@ -0,0 +1,281 @@
|
||||||
|
"""Role management API endpoints."""
|
||||||
|
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
from loguru import logger
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import select, and_, or_, delete
|
||||||
|
|
||||||
|
from ...core.simple_permissions import require_super_admin
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...models.user import User
|
||||||
|
from ...models.permission import Role, UserRole
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...schemas.permission import (
|
||||||
|
RoleCreate, RoleUpdate, RoleResponse,
|
||||||
|
UserRoleAssign
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/roles", tags=["roles"])
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[RoleResponse], summary="获取角色列表")
async def get_roles(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user = Depends(require_super_admin),
):
    """Paginated role listing with optional fuzzy search and status filter."""
    session.desc = f"START: 获取用户 {current_user.username} 角色列表"
    query = select(Role)

    # Fuzzy search over name, code and description.
    if search:
        pattern = f"%{search}%"
        query = query.where(
            or_(
                Role.name.ilike(pattern),
                Role.code.ilike(pattern),
                Role.description.ilike(pattern)
            )
        )

    # Optional active/inactive filter.
    if is_active is not None:
        query = query.where(Role.is_active == is_active)

    # Pagination window.
    query = query.offset(skip).limit(limit)
    roles = (await session.execute(query)).scalars().all()
    session.desc = f"SUCCESS: 用户 {current_user.username} 有 {len(roles)} 个角色"
    return HxfResponse([r.to_dict() for r in roles])
|
||||||
|
|
||||||
|
@router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情")
async def get_role(
    role_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Return one role by id (super-admin only)."""
    session.desc = f"START: 获取角色 {role_id} 详情"
    role = (await session.execute(select(Role).where(Role.id == role_id))).scalar_one_or_none()
    if role is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    return HxfResponse(role.to_dict())
|
||||||
|
|
||||||
|
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
async def create_role(
    role_data: RoleCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new role; role codes must be unique."""
    session.desc = f"START: 创建角色 {role_data.name}"

    # Uniqueness check on the role code.
    duplicate = (await session.execute(select(Role).where(Role.code == role_data.code))).scalar_one_or_none()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="角色代码已存在"
        )

    new_role = Role(
        name=role_data.name,
        code=role_data.code,
        description=role_data.description,
        is_active=role_data.is_active
    )
    # Audit fields are stamped explicitly with the creator's id.
    new_role.set_audit_fields(current_user.id)

    session.add(new_role)
    await session.commit()
    await session.refresh(new_role)

    logger.info(f"Role created: {new_role.name} by user {current_user.username}")
    return HxfResponse(new_role.to_dict())
|
||||||
|
|
||||||
|
@router.put("/{role_id}", response_model=RoleResponse, summary="更新角色")
async def update_role(
    role_id: int,
    role_data: RoleUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Update a role (super-admin only); the SUPER_ADMIN role is immutable."""
    session.desc = f"更新用户 {current_user.username} 角色 {role_id}"
    stmt = select(Role).where(Role.id == role_id)
    role = (await session.execute(stmt)).scalar_one_or_none()
    if not role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    # The built-in super-admin role may never be edited.
    if role.code == "SUPER_ADMIN":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="超级管理员角色不能被编辑"
        )

    # Reject a code change that collides with another role's code.
    if role_data.code and role_data.code != role.code:
        stmt = select(Role).where(
            and_(
                Role.code == role_data.code,
                Role.id != role_id
            )
        )
        existing_role = (await session.execute(stmt)).scalar_one_or_none()
        if existing_role:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="角色代码已存在"
            )

    # Apply only the fields the client actually sent.
    update_data = role_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(role, field, value)

    # Audit fields are set automatically by SQLAlchemy event listener

    await session.commit()
    await session.refresh(role)

    logger.info(f"Role updated: {role.name} by user {current_user.username}")
    response = role.to_dict()
    return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
async def delete_role(
    role_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete a role (super-admin only).

    Refuses to delete the SUPER_ADMIN role or any role still assigned
    to users.

    NOTE(review): the route declares 204 No Content yet returns a JSON
    body — confirm FastAPI accepts this.
    """
    stmt = select(Role).where(Role.id == role_id)
    role = (await session.execute(stmt)).scalar_one_or_none()
    if not role:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="角色不存在"
        )

    # The built-in super-admin role is immutable.
    if role.code == "SUPER_ADMIN":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="超级管理员角色不能被删除"
        )

    # FIX: the original called .scalars().count() — SQLAlchemy's
    # ScalarResult has no count() method, so this path raised
    # AttributeError at runtime. Count the fetched rows instead.
    stmt = select(UserRole).where(UserRole.role_id == role_id)
    user_count = len((await session.execute(stmt)).scalars().all())
    if user_count > 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
        )

    # FIX: capture values before delete/commit — attribute access on an
    # expired/deleted instance can fail in the async session.
    role_name = role.name
    operator = current_user.username

    # Delete the role row.
    await session.delete(role)
    await session.commit()

    session.desc = f"角色删除成功: {role_name} by user {operator}"
    response = {"message": f"Role deleted successfully: {role_name} by user {operator}"}
    return HxfResponse(response)
|
||||||
|
|
||||||
|
# User-role assignment sub-router (mounted onto the main roles router at module end).
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||||
|
|
||||||
|
@user_role_router.post("/assign", status_code=status.HTTP_201_CREATED, summary="为用户分配角色")
async def assign_user_roles(
    assignment_data: UserRoleAssign,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Replace a user's role set with the given role ids (super-admin only)."""
    target_id = assignment_data.user_id

    # The target user must exist.
    user = (await session.execute(select(User).where(User.id == target_id))).scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )

    # Every requested role id must resolve to an existing role.
    found = (await session.execute(select(Role).where(Role.id.in_(assignment_data.role_ids)))).scalars().all()
    if len(found) != len(assignment_data.role_ids):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="部分角色不存在"
        )

    # Replace strategy: wipe existing links, then insert the new set.
    await session.execute(delete(UserRole).where(UserRole.user_id == target_id))
    for rid in assignment_data.role_ids:
        session.add(UserRole(
            user_id=target_id,
            role_id=rid
        ))

    await session.commit()

    session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}"

    return HxfResponse({"message": "角色分配成功"})
|
||||||
|
|
||||||
|
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
async def get_user_roles(
    user_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_active_user)
):
    """List the roles assigned to a user (self, or any user for super-admins)."""
    # Permission gate: non-admins may only inspect their own roles.
    is_self = current_user.id == user_id
    if not is_self and not await current_user.is_superuser():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权限查看其他用户的角色"
        )

    user = (await session.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )

    role_query = (
        select(Role)
        .join(UserRole, Role.id == UserRole.role_id)
        .where(UserRole.user_id == user_id)
    )
    roles = (await session.execute(role_query)).scalars().all()

    return HxfResponse([r.to_dict() for r in roles])
|
||||||
|
|
||||||
|
# Mount the user-role sub-router under the main roles router.
router.include_router(user_role_router)
|
||||||
|
|
@ -0,0 +1,338 @@
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import HTTPBearer
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from th_agenter.db.database import get_session
|
||||||
|
from th_agenter.services.auth import AuthService
|
||||||
|
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||||
|
from th_agenter.services.conversation import ConversationService
|
||||||
|
from th_agenter.services.conversation_context import conversation_context_service
|
||||||
|
from utils.util_schemas import BaseResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from loguru import logger
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
# NOTE(review): `security` is not referenced anywhere in this module chunk —
# presumably consumed by route dependencies elsewhere; confirm before removal.
security = HTTPBearer()
|
||||||
|
|
||||||
|
# Request/Response Models
class SmartQueryRequest(BaseModel):
    """Request payload for the smart data-query endpoint."""
    query: str  # natural-language question; endpoint enforces non-empty and <= 1000 chars
    conversation_id: Optional[int] = None  # existing conversation to continue, if any
    is_new_conversation: bool = False  # force creation of a new conversation
|
|
||||||
|
class SmartQueryResponse(BaseModel):
    """Response envelope returned by the smart-query endpoint."""
    success: bool  # whether the workflow completed successfully
    message: str  # human-readable status message
    data: Optional[Dict[str, Any]] = None  # query result payload (may carry conversation_id)
    workflow_steps: Optional[list] = None  # per-step trace of the workflow run
    conversation_id: Optional[int] = None  # conversation the query was attached to, if any
|
||||||
|
|
||||||
|
class ConversationContextResponse(BaseModel):
    """Generic success/message/data envelope used by context and status endpoints."""
    success: bool  # outcome flag
    message: str  # human-readable status message
    data: Optional[Dict[str, Any]] = None  # context or status payload
|
||||||
|
|
||||||
|
@router.post("/query", response_model=SmartQueryResponse, summary="智能问数查询")
async def smart_query(
    request: SmartQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """
    Smart data-query endpoint.

    For a new conversation it creates the conversation record first; it
    then persists the user message, runs the smart-query workflow, and on
    success stores the assistant reply and updated context. Workflow and
    persistence failures are converted into structured error responses
    instead of propagating (persistence failures are best-effort only).
    """
    session.desc = f"START: 用户 {current_user.username} 智能问数查询"
    conversation_id = None

    try:
        # Validate request parameters.
        if not request.query or not request.query.strip():
            session.desc = "ERROR: 用户输入为空, 查询内容不能为空"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容不能为空"
            )

        if len(request.query) > 1000:
            session.desc = "ERROR: 用户输入过长, 查询内容不能超过1000字符"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容过长,请控制在1000字符以内"
            )

        # Initialize the workflow manager.
        workflow_manager = SmartWorkflowManager(session)
        await workflow_manager.initialize()

        # NOTE(review): conversation_service is constructed but never used
        # in this function — confirm whether it can be removed.
        conversation_service = ConversationService(session)

        # Resolve the conversation context.
        conversation_id = request.conversation_id

        # New conversation requested, or no id supplied: create one.
        if request.is_new_conversation or not conversation_id:
            try:
                conversation_id = await conversation_context_service.create_conversation(
                    user_id=current_user.id,
                    title=f"智能问数: {request.query[:20]}..."
                )
                request.is_new_conversation = True
                session.desc = f"创建新对话: {conversation_id}"
            except Exception as e:
                # Creation failure is tolerated: fall back to a transient
                # session (conversation_id stays None; nothing is persisted).
                session.desc = f"WARNING: 创建对话失败,使用临时会话: {e}"
                conversation_id = None
        else:
            # Existing conversation: verify it exists and belongs to the caller.
            try:
                context = await conversation_context_service.get_conversation_context(conversation_id)
                if not context or context.get('user_id') != current_user.id:
                    session.desc = f"ERROR: 对话 {conversation_id} 不存在或无权访问"
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="对话不存在或无权访问"
                    )
                session.desc = f"使用现有对话: {conversation_id}"
            except HTTPException:
                # Re-raise our own 404 untouched.
                session.desc = f"EXCEPTION: 对话 {conversation_id} 不存在或无权访问"
                raise
            except Exception as e:
                # Unexpected lookup failure becomes a 500.
                session.desc = f"ERROR: 验证对话失败: {e}"
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="对话验证失败"
                )

        # Persist the user message (best-effort).
        if conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="user",
                    content=request.query
                )
            except Exception as e:
                session.desc = f"WARNING: 保存用户消息失败: {e}"
                # Do not block the flow; continue with the query.

        # Run the smart-query workflow.
        try:
            result = await workflow_manager.process_smart_query(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=request.is_new_conversation
            )
        except Exception as e:
            session.desc = f"ERROR: 智能查询执行失败: {e}"
            # Return a structured error response instead of raising.
            response = SmartQueryResponse(
                success=False,
                message=f"查询执行失败: {str(e)}",
                data={'error_type': 'query_execution_error'},
                workflow_steps=[{
                    'step': 'error',
                    'status': 'failed',
                    'message': str(e)
                }],
                conversation_id=conversation_id
            )
            return HxfResponse(response)

        # On success, persist the assistant reply and refresh the context.
        if result['success'] and conversation_id:
            try:
                # Save the assistant reply with the workflow trace attached.
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="assistant",
                    content=result.get('data', {}).get('summary', '查询完成'),
                    metadata={
                        'query_result': result.get('data'),
                        'workflow_steps': result.get('workflow_steps', []),
                        'selected_files': result.get('data', {}).get('used_files', [])
                    }
                )

                # Update the conversation context with the files used.
                await conversation_context_service.update_conversation_context(
                    conversation_id=conversation_id,
                    query=request.query,
                    selected_files=result.get('data', {}).get('used_files', [])
                )

            except Exception as e:
                session.desc = f"EXCEPTION: 保存消息到对话历史失败: {e}"
                # Persistence failure does not affect the returned result.

        # Build the response, embedding the conversation id when present.
        response_data = result.get('data', {})
        if conversation_id:
            response_data['conversation_id'] = conversation_id
        session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}"
        response = SmartQueryResponse(
            success=result['success'],
            message=result.get('message', '查询完成'),
            data=response_data,
            workflow_steps=result.get('workflow_steps', []),
            conversation_id=conversation_id
        )
        return HxfResponse(response)

    except HTTPException as e:
        # Let HTTP errors pass through to FastAPI unchanged.
        session.desc = f"EXCEPTION: HTTP异常: {e}"
        raise e
    except Exception as e:
        session.desc = f"ERROR: 智能查询接口异常: {e}"
        # Generic catch-all: structured 200-level error payload.
        response = SmartQueryResponse(
            success=False,
            message="服务器内部错误,请稍后重试",
            data={'error_type': 'internal_server_error'},
            workflow_steps=[{
                'step': 'error',
                'status': 'failed',
                'message': '系统异常'
            }],
            conversation_id=conversation_id
        )
        return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
|
||||||
|
async def get_conversation_context(
|
||||||
|
conversation_id: int,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取对话上下文信息,包括已使用的文件和历史查询
|
||||||
|
"""
|
||||||
|
# 获取对话上下文
|
||||||
|
session.desc = f"START: 获取对话上下文,对话ID: {conversation_id}"
|
||||||
|
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||||
|
|
||||||
|
if not context:
|
||||||
|
session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="对话上下文不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证用户权限
|
||||||
|
if context['user_id'] != current_user.id:
|
||||||
|
session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="无权访问此对话"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取对话历史
|
||||||
|
history = await conversation_context_service.get_conversation_history(conversation_id)
|
||||||
|
context['message_history'] = history
|
||||||
|
session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}"
|
||||||
|
response = ConversationContextResponse(
|
||||||
|
success=True,
|
||||||
|
message="获取对话上下文成功",
|
||||||
|
data=context
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
|
||||||
|
async def get_files_status(
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户当前的文件状态和统计信息
|
||||||
|
"""
|
||||||
|
session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}"
|
||||||
|
workflow_manager = SmartWorkflowManager()
|
||||||
|
await workflow_manager.initialize()
|
||||||
|
|
||||||
|
# 获取用户文件列表
|
||||||
|
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
total_files = len(file_list)
|
||||||
|
total_rows = sum(f.get('row_count', 0) for f in file_list)
|
||||||
|
total_columns = sum(f.get('column_count', 0) for f in file_list)
|
||||||
|
|
||||||
|
# 文件类型统计
|
||||||
|
file_types = {}
|
||||||
|
for file_info in file_list:
|
||||||
|
filename = file_info['filename']
|
||||||
|
ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
|
||||||
|
file_types[ext] = file_types.get(ext, 0) + 1
|
||||||
|
|
||||||
|
status_data = {
|
||||||
|
'total_files': total_files,
|
||||||
|
'total_rows': total_rows,
|
||||||
|
'total_columns': total_columns,
|
||||||
|
'file_types': file_types,
|
||||||
|
'files': [{
|
||||||
|
'id': f['id'],
|
||||||
|
'filename': f['filename'],
|
||||||
|
'row_count': f.get('row_count', 0),
|
||||||
|
'column_count': f.get('column_count', 0),
|
||||||
|
'columns': f.get('columns', []),
|
||||||
|
'upload_time': f.get('upload_time')
|
||||||
|
} for f in file_list],
|
||||||
|
'ready_for_query': total_files > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}"
|
||||||
|
response = ConversationContextResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
|
||||||
|
data=status_data
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
|
||||||
|
async def reset_conversation_context(
|
||||||
|
conversation_id: int,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
重置对话上下文,清除历史查询记录但保留文件
|
||||||
|
"""
|
||||||
|
session.desc = f"START: 重置对话上下文,对话ID: {conversation_id}"
|
||||||
|
# 验证对话存在和用户权限
|
||||||
|
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||||
|
|
||||||
|
if not context:
|
||||||
|
session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="对话上下文不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
if context['user_id'] != current_user.id:
|
||||||
|
session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="无权访问此对话"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重置对话上下文
|
||||||
|
success = await conversation_context_service.reset_conversation_context(conversation_id)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}"
|
||||||
|
response = ConversationContextResponse(
|
||||||
|
success=True,
|
||||||
|
message="对话上下文已重置,可以开始新的数据分析会话"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
else:
|
||||||
|
session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="重置对话上下文失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,755 @@
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
|
||||||
|
from fastapi.security import HTTPBearer
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
import pandas as pd
|
||||||
|
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse, BaseResponse
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from th_agenter.services.smart_query import (
|
||||||
|
SmartQueryService,
|
||||||
|
ExcelAnalysisService,
|
||||||
|
DatabaseQueryService
|
||||||
|
)
|
||||||
|
from th_agenter.services.excel_metadata_service import ExcelMetadataService
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from utils.util_file import FileUtils
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional, AsyncGenerator
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from th_agenter.db.database import get_session
|
||||||
|
from th_agenter.services.auth import AuthService
|
||||||
|
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||||
|
from th_agenter.services.conversation_context import ConversationContextService
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from loguru import logger
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/smart-query", tags=["smart-query"])
|
||||||
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
# Request/Response Models
|
||||||
|
class DatabaseConfig(BaseModel):
    """Connection parameters for an external database (used by /test-db-connection)."""
    type: str      # database dialect identifier, e.g. "mysql" — consumed by DatabaseQueryService
    host: str
    port: str      # NOTE(review): kept as str to match existing clients — confirm before changing to int
    database: str  # database/schema name
    username: str
    password: str
class QueryRequest(BaseModel):
    """A natural-language query with optional pagination and target table."""
    query: str                        # natural-language question text
    page: int = 1                     # 1-based page index
    page_size: int = 20               # rows per page
    table_name: Optional[str] = None  # restrict the query to one table when set
class TableSchemaRequest(BaseModel):
    """Request body for /table-schema: the table whose schema to fetch."""
    table_name: str
class ExcelUploadResponse(BaseModel):
    """Response payload for the /upload-excel endpoint."""
    file_id: int                            # database id of the stored file record
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None   # data field: analysis summary of the uploaded file
class QueryResponse(BaseModel):
    """Generic success/message/data envelope used by several query endpoints."""
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None  # endpoint-specific result payload
@router.post("/upload-excel", response_model=ExcelUploadResponse, summary="上传Excel文件并进行预处理")
|
||||||
|
async def upload_excel(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传Excel文件并进行预处理
|
||||||
|
"""
|
||||||
|
session.desc = f"START: 用户 {current_user.username} 上传 Excel 文件并进行预处理"
|
||||||
|
# 验证文件类型
|
||||||
|
allowed_extensions = ['.xlsx', '.xls', '.csv']
|
||||||
|
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||||
|
|
||||||
|
if file_extension not in allowed_extensions:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 上传了不支持的文件格式 {file_extension},请上传 .xlsx, .xls 或 .csv 文件"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证文件大小 (10MB)
|
||||||
|
content = await file.read()
|
||||||
|
file_size = len(content)
|
||||||
|
if file_size > 10 * 1024 * 1024:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 大小为 {file_size / (1024 * 1024):.2f}MB,超过最大限制 10MB"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="文件大小不能超过 10MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建持久化目录结构
|
||||||
|
backend_dir = Path(__file__).parent.parent.parent.parent # 获取backend目录
|
||||||
|
data_dir = backend_dir / "data/uploads"
|
||||||
|
excel_user_dir = data_dir / f"excel_{current_user.id}"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
excel_user_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成文件名:{uuid}_{原始文件名称}
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
safe_filename = FileUtils.sanitize_filename(file.filename)
|
||||||
|
new_filename = f"{file_id}_{safe_filename}"
|
||||||
|
file_path = excel_user_dir / new_filename
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 使用Excel元信息服务提取并保存元信息
|
||||||
|
metadata_service = ExcelMetadataService(session)
|
||||||
|
excel_file = metadata_service.save_file_metadata(
|
||||||
|
file_path=str(file_path),
|
||||||
|
original_filename=file.filename,
|
||||||
|
user_id=current_user.id,
|
||||||
|
file_size=file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为了兼容现有前端,仍然创建pickle文件
|
||||||
|
try:
|
||||||
|
if file_extension == '.csv':
|
||||||
|
df = pd.read_csv(file_path, encoding='utf-8')
|
||||||
|
else:
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
if file_extension == '.csv':
|
||||||
|
df = pd.read_csv(file_path, encoding='gbk')
|
||||||
|
else:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 编码错误,请确保文件为UTF-8或GBK编码"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="文件编码错误,请确保文件为UTF-8或GBK编码"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 读取失败: {str(e)}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"文件读取失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存pickle文件到同一目录
|
||||||
|
pickle_filename = f"{file_id}_{safe_filename}.pkl"
|
||||||
|
pickle_path = excel_user_dir / pickle_filename
|
||||||
|
df.to_pickle(pickle_path)
|
||||||
|
|
||||||
|
# 数据预处理和分析(保持兼容性)
|
||||||
|
excel_service = ExcelAnalysisService()
|
||||||
|
analysis_result = excel_service.analyze_dataframe(df, file.filename)
|
||||||
|
|
||||||
|
# 添加数据库文件信息
|
||||||
|
analysis_result.update({
|
||||||
|
'file_id': str(excel_file.id),
|
||||||
|
'database_id': excel_file.id,
|
||||||
|
'temp_file_path': str(pickle_path), # 更新为新的pickle路径
|
||||||
|
'original_filename': file.filename,
|
||||||
|
'file_size_mb': excel_file.file_size_mb,
|
||||||
|
'sheet_names': excel_file.sheet_names,
|
||||||
|
})
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}"
|
||||||
|
response = ExcelUploadResponse(
|
||||||
|
file_id=excel_file.id,
|
||||||
|
success=True,
|
||||||
|
message="Excel文件上传成功",
|
||||||
|
data=analysis_result
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据")
|
||||||
|
async def preview_excel(
|
||||||
|
request: ExcelPreviewRequest,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
预览Excel文件数据
|
||||||
|
"""
|
||||||
|
session.desc = f"START: 用户 {current_user.username} 预览文件 {request.file_id}"
|
||||||
|
|
||||||
|
# 验证file_id格式
|
||||||
|
try:
|
||||||
|
file_id = int(request.file_id)
|
||||||
|
except ValueError:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 提供了无效的文件ID格式: {request.file_id}"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail=f"无效的文件ID格式: {request.file_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从数据库获取文件信息
|
||||||
|
metadata_service = ExcelMetadataService(session)
|
||||||
|
excel_file = metadata_service.get_file_by_id(file_id, current_user.id)
|
||||||
|
|
||||||
|
if not excel_file:
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 不存在或已被删除"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="文件不存在或已被删除"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(excel_file.file_path):
|
||||||
|
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 已被移动或删除"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="文件已被移动或删除"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新最后访问时间
|
||||||
|
metadata_service.update_last_accessed(file_id, current_user.id)
|
||||||
|
|
||||||
|
# 读取Excel文件
|
||||||
|
if excel_file.file_type.lower() == 'csv':
|
||||||
|
df = pd.read_csv(excel_file.file_path, encoding='utf-8')
|
||||||
|
else:
|
||||||
|
# 对于Excel文件,使用默认sheet或第一个sheet
|
||||||
|
sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0
|
||||||
|
df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name)
|
||||||
|
|
||||||
|
# 计算分页
|
||||||
|
total_rows = len(df)
|
||||||
|
start_idx = (request.page - 1) * request.page_size
|
||||||
|
end_idx = start_idx + request.page_size
|
||||||
|
|
||||||
|
# 获取分页数据
|
||||||
|
paginated_df = df.iloc[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 转换为字典格式
|
||||||
|
data = paginated_df.fillna('').to_dict('records')
|
||||||
|
columns = df.columns.tolist()
|
||||||
|
session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据"
|
||||||
|
response = QueryResponse(
|
||||||
|
success=True,
|
||||||
|
message="Excel文件预览加载成功",
|
||||||
|
data={
|
||||||
|
'data': data,
|
||||||
|
'columns': columns,
|
||||||
|
'total_rows': total_rows,
|
||||||
|
'page': request.page,
|
||||||
|
'page_size': request.page_size,
|
||||||
|
'total_pages': (total_rows + request.page_size - 1) // request.page_size
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接")
|
||||||
|
async def test_database_connection(
|
||||||
|
config: DatabaseConfig,
|
||||||
|
current_user = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试数据库连接
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_service = DatabaseQueryService()
|
||||||
|
is_connected = await db_service.test_connection(config.model_dump())
|
||||||
|
|
||||||
|
if is_connected:
|
||||||
|
return NormalResponse(
|
||||||
|
success=True,
|
||||||
|
message="数据库连接测试成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = NormalResponse(
|
||||||
|
success=False,
|
||||||
|
message="数据库连接测试失败"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return NormalResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"连接测试失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除第285-314行的connect_database方法
|
||||||
|
# @router.post("/connect-database", response_model=QueryResponse)
|
||||||
|
# async def connect_database(
|
||||||
|
# config_id: int,
|
||||||
|
# current_user = Depends(AuthService.get_current_user),
|
||||||
|
# db: Session = Depends(get_session)
|
||||||
|
# ):
|
||||||
|
# """连接数据库并获取表列表"""
|
||||||
|
# ... (整个方法都删除)
|
||||||
|
|
||||||
|
@router.post("/table-schema", response_model=QueryResponse, summary="获取数据表结构")
|
||||||
|
async def get_table_schema(
|
||||||
|
request: TableSchemaRequest,
|
||||||
|
current_user = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取数据表结构
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_service = DatabaseQueryService()
|
||||||
|
schema_result = await db_service.get_table_schema(request.table_name, current_user.id)
|
||||||
|
|
||||||
|
if schema_result['success']:
|
||||||
|
response = QueryResponse(
|
||||||
|
success=True,
|
||||||
|
message="获取表结构成功",
|
||||||
|
data=schema_result['data']
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
else:
|
||||||
|
response = QueryResponse(
|
||||||
|
success=False,
|
||||||
|
message=schema_result['message']
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response = QueryResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"获取表结构失败: {str(e)}"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
class StreamQueryRequest(BaseModel):
    """Request body for the streaming Excel smart-query endpoint."""
    query: str                              # natural-language question text
    conversation_id: Optional[int] = None   # continue this conversation when set
    is_new_conversation: bool = False       # force creation of a fresh conversation
class DatabaseStreamQueryRequest(BaseModel):
    """Request body for the streaming database query endpoint."""
    query: str                              # natural-language question text
    database_config_id: int                 # id of the stored database connection config
    conversation_id: Optional[int] = None   # continue this conversation when set
    is_new_conversation: bool = False       # force creation of a fresh conversation
@router.post("/execute-excel-query", summary="流式智能问答查询")
|
||||||
|
async def stream_smart_query(
|
||||||
|
request: StreamQueryRequest,
|
||||||
|
current_user=Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
流式智能问答查询接口
|
||||||
|
支持实时推送工作流步骤和最终结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def generate_stream() -> AsyncGenerator[str, None]:
|
||||||
|
workflow_manager = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证请求参数
|
||||||
|
if not request.query or not request.query.strip():
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(request.query) > 1000:
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# 发送开始信号
|
||||||
|
yield f"data: {json.dumps({'type': 'start', 'message': '开始处理查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 初始化服务
|
||||||
|
workflow_manager = SmartWorkflowManager(session)
|
||||||
|
await workflow_manager.initialize()
|
||||||
|
|
||||||
|
conversation_context_service = ConversationContextService()
|
||||||
|
|
||||||
|
# 处理对话上下文
|
||||||
|
conversation_id = request.conversation_id
|
||||||
|
|
||||||
|
# 如果是新对话或没有指定对话ID,创建新对话
|
||||||
|
if request.is_new_conversation or not conversation_id:
|
||||||
|
try:
|
||||||
|
conversation_id = await conversation_context_service.create_conversation(
|
||||||
|
user_id=current_user.id,
|
||||||
|
title=f"智能问数: {request.query[:20]}..."
|
||||||
|
)
|
||||||
|
yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"创建对话失败: {e}")
|
||||||
|
# 不阻断流程,继续执行查询
|
||||||
|
|
||||||
|
# 保存用户消息
|
||||||
|
if conversation_id:
|
||||||
|
try:
|
||||||
|
await conversation_context_service.save_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=request.query
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存用户消息失败: {e}")
|
||||||
|
|
||||||
|
# 执行智能查询工作流(带流式推送)
|
||||||
|
async for step_data in workflow_manager.process_excel_query_stream(
|
||||||
|
user_query=request.query,
|
||||||
|
user_id=current_user.id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
is_new_conversation=request.is_new_conversation
|
||||||
|
):
|
||||||
|
# 推送工作流步骤
|
||||||
|
yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 如果是最终结果,保存到对话历史
|
||||||
|
if step_data.get('type') == 'final_result' and conversation_id:
|
||||||
|
try:
|
||||||
|
result_data = step_data.get('data', {})
|
||||||
|
await conversation_context_service.save_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=result_data.get('summary', '查询完成'),
|
||||||
|
metadata={
|
||||||
|
'query_result': result_data,
|
||||||
|
'workflow_steps': step_data.get('workflow_steps', []),
|
||||||
|
'selected_files': result_data.get('used_files', [])
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新对话上下文
|
||||||
|
await conversation_context_service.update_conversation_context(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
query=request.query,
|
||||||
|
selected_files=result_data.get('used_files', [])
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"查询成功完成,对话ID: {conversation_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存消息到对话历史失败: {e}")
|
||||||
|
|
||||||
|
# 发送完成信号
|
||||||
|
yield f"data: {json.dumps({'type': 'complete', 'message': '查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式智能查询异常: {e}", exc_info=True)
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
if workflow_manager:
|
||||||
|
try:
|
||||||
|
workflow_manager.excel_workflow.executor.shutdown(wait=False)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
response = StreamingResponse(
|
||||||
|
generate_stream(),
|
||||||
|
media_type="text/plain",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Headers": "*",
|
||||||
|
"Access-Control-Allow-Methods": "*"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.post("/execute-db-query", summary="流式数据库查询")
|
||||||
|
async def execute_database_query(
|
||||||
|
request: DatabaseStreamQueryRequest,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
流式数据库查询接口
|
||||||
|
支持实时推送工作流步骤和最终结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def generate_stream() -> AsyncGenerator[str, None]:
|
||||||
|
workflow_manager = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证请求参数
|
||||||
|
if not request.query or not request.query.strip():
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(request.query) > 1000:
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# 发送开始信号
|
||||||
|
yield f"data: {json.dumps({'type': 'start', 'message': '开始处理数据库查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 初始化服务
|
||||||
|
workflow_manager = SmartWorkflowManager(session)
|
||||||
|
await workflow_manager.initialize()
|
||||||
|
conversation_context_service = ConversationContextService()
|
||||||
|
|
||||||
|
# 处理对话上下文
|
||||||
|
conversation_id = request.conversation_id
|
||||||
|
|
||||||
|
# 如果是新对话或没有指定对话ID,创建新对话
|
||||||
|
if request.is_new_conversation or not conversation_id:
|
||||||
|
try:
|
||||||
|
conversation_id = await conversation_context_service.create_conversation(
|
||||||
|
user_id=current_user.id,
|
||||||
|
title=f"数据库查询: {request.query[:20]}..."
|
||||||
|
)
|
||||||
|
yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"创建对话失败: {e}")
|
||||||
|
# 不阻断流程,继续执行查询
|
||||||
|
|
||||||
|
# 保存用户消息
|
||||||
|
if conversation_id:
|
||||||
|
try:
|
||||||
|
await conversation_context_service.save_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=request.query
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存用户消息失败: {e}")
|
||||||
|
|
||||||
|
# 执行数据库查询工作流(带流式推送)
|
||||||
|
async for step_data in workflow_manager.process_database_query_stream(
|
||||||
|
user_query=request.query,
|
||||||
|
user_id=current_user.id,
|
||||||
|
database_config_id=request.database_config_id
|
||||||
|
):
|
||||||
|
# 推送工作流步骤
|
||||||
|
yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 如果是最终结果,保存到对话历史
|
||||||
|
if step_data.get('type') == 'final_result' and conversation_id:
|
||||||
|
try:
|
||||||
|
result_data = step_data.get('data', {})
|
||||||
|
await conversation_context_service.save_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=result_data.get('summary', '查询完成'),
|
||||||
|
metadata={
|
||||||
|
'query_result': result_data,
|
||||||
|
'workflow_steps': step_data.get('workflow_steps', []),
|
||||||
|
'generated_sql': result_data.get('generated_sql', '')
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新对话上下文
|
||||||
|
await conversation_context_service.update_conversation_context(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
query=request.query,
|
||||||
|
selected_files=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"数据库查询成功完成,对话ID: {conversation_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存消息到对话历史失败: {e}")
|
||||||
|
|
||||||
|
# 发送完成信号
|
||||||
|
yield f"data: {json.dumps({'type': 'complete', 'message': '数据库查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式数据库查询异常: {e}", exc_info=True)
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
if workflow_manager:
|
||||||
|
try:
|
||||||
|
workflow_manager.database_workflow.executor.shutdown(wait=False)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
response = StreamingResponse(
|
||||||
|
generate_stream(),
|
||||||
|
media_type="text/plain",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Headers": "*",
|
||||||
|
"Access-Control-Allow-Methods": "*"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.delete("/cleanup-temp-files", summary="清理临时文件")
|
||||||
|
async def cleanup_temp_files(
|
||||||
|
current_user = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
清理临时文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
user_prefix = f"excel_{current_user.id}_"
|
||||||
|
|
||||||
|
cleaned_count = 0
|
||||||
|
for filename in os.listdir(temp_dir):
|
||||||
|
if filename.startswith(user_prefix) and filename.endswith('.pkl'):
|
||||||
|
file_path = os.path.join(temp_dir, filename)
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
cleaned_count += 1
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
response = BaseResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"已清理 {cleaned_count} 个临时文件"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response = BaseResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"清理临时文件失败: {str(e)}"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表")
|
||||||
|
async def get_file_list(
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户上传的Excel文件列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session.desc = f"START: 获取用户 {current_user.id} 的文件列表"
|
||||||
|
metadata_service = ExcelMetadataService(session)
|
||||||
|
skip = (page - 1) * page_size
|
||||||
|
files, total = metadata_service.get_user_files(current_user.id, skip, page_size)
|
||||||
|
|
||||||
|
file_list = []
|
||||||
|
for file in files:
|
||||||
|
file_info = {
|
||||||
|
'id': file.id,
|
||||||
|
'filename': file.original_filename,
|
||||||
|
'file_size': file.file_size,
|
||||||
|
'file_size_mb': file.file_size_mb,
|
||||||
|
'file_type': file.file_type,
|
||||||
|
'sheet_names': file.sheet_names,
|
||||||
|
'sheet_count': file.sheet_count,
|
||||||
|
'last_accessed': file.last_accessed.isoformat() if file.last_accessed else None,
|
||||||
|
'is_processed': file.is_processed,
|
||||||
|
'processing_error': file.processing_error
|
||||||
|
}
|
||||||
|
file_list.append(file_info)
|
||||||
|
|
||||||
|
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件"
|
||||||
|
response = FileListResponse(
|
||||||
|
success=True,
|
||||||
|
message="获取文件列表成功",
|
||||||
|
data={
|
||||||
|
'files': file_list,
|
||||||
|
'total': total,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_pages': (total + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response = FileListResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"获取文件列表失败: {str(e)}"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件")
|
||||||
|
async def delete_file(
|
||||||
|
file_id: int,
|
||||||
|
current_user = Depends(AuthService.get_current_user),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除指定的Excel文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session.desc = f"START: 删除用户 {current_user.id} 的文件 {file_id}"
|
||||||
|
metadata_service = ExcelMetadataService(session)
|
||||||
|
success = metadata_service.delete_file(file_id, current_user.id)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}"
|
||||||
|
response = NormalResponse(
|
||||||
|
success=True,
|
||||||
|
message="文件删除成功"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
else:
|
||||||
|
session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败"
|
||||||
|
response = NormalResponse(
|
||||||
|
success=False,
|
||||||
|
message="文件不存在或删除失败"
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response = NormalResponse(
|
||||||
|
success=True,
|
||||||
|
message=str(e)
|
||||||
|
)
|
||||||
|
return HxfResponse(response)
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息")
async def get_file_info(
    file_id: int,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Return detailed metadata for one of the current user's Excel files.

    Raises 404 when the file does not exist or belongs to another user;
    otherwise refreshes the last-accessed timestamp and returns the info.
    """
    service = ExcelMetadataService(session)
    excel_file = service.get_file_by_id(file_id, current_user.id)

    if not excel_file:
        session.desc = f"ERROR: 获取用户 {current_user.id} 的文件 {file_id} 信息,文件不存在"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文件不存在"
        )

    # Touch the access timestamp before assembling the payload.
    service.update_last_accessed(file_id, current_user.id)

    def _iso(value):
        """ISO-8601 string for a datetime, or None when absent."""
        return value.isoformat() if value else None

    file_info = {
        'id': excel_file.id,
        'filename': excel_file.original_filename,
        'file_size': excel_file.file_size,
        'file_size_mb': excel_file.file_size_mb,
        'file_type': excel_file.file_type,
        'sheet_names': excel_file.sheet_names,
        'default_sheet': excel_file.default_sheet,
        'columns_info': excel_file.columns_info,
        'preview_data': excel_file.preview_data,
        'data_types': excel_file.data_types,
        'total_rows': excel_file.total_rows,
        'total_columns': excel_file.total_columns,
        'upload_time': _iso(excel_file.upload_time),
        'last_accessed': _iso(excel_file.last_accessed),
        'sheets_summary': excel_file.get_all_sheets_summary()
    }

    session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件 {file_id} 信息"
    return HxfResponse(QueryResponse(
        success=True,
        message="获取文件信息成功",
        data=file_info
    ))
@ -0,0 +1,235 @@
|
||||||
|
"""表元数据管理API"""
|
||||||
|
from loguru import logger
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from th_agenter.models.user import User
|
||||||
|
from th_agenter.db.database import get_session
|
||||||
|
from th_agenter.services.table_metadata_service import TableMetadataService
|
||||||
|
from th_agenter.services.auth import AuthService
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
|
||||||
|
|
||||||
|
class TableSelectionRequest(BaseModel):
    """Request body for collecting metadata of user-selected tables."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_names: List[str] = Field(..., description="选中的表名列表")
class TableMetadataResponse(BaseModel):
    """Serialized view of one collected table-metadata record.

    NOTE(review): string fields such as table_comment, qa_description,
    business_context and last_synced_at are declared non-optional here,
    while the handlers below coalesce NULLs with ``or ""`` — confirm the
    model layer never yields None for these when this schema is used.
    """
    id: int
    table_name: str
    table_schema: str
    table_type: str
    table_comment: str
    columns_count: int
    row_count: int
    is_enabled_for_qa: bool
    qa_description: str
    business_context: str
    last_synced_at: str
class QASettingsUpdate(BaseModel):
    """Update payload for a table's question-answering (QA) settings."""
    is_enabled_for_qa: bool = Field(default=True)
    qa_description: str = Field(default="")
    business_context: str = Field(default="")
class TableByNameRequest(BaseModel):
    """Request body for looking up one table's metadata by name."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_name: str = Field(..., description="表名")
@router.post("/collect", summary="收集选中表的元数据")
async def collect_table_metadata(
    request: TableSelectionRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Collect and persist metadata for the tables picked by the user."""
    session.desc = f"START: 用户 {current_user.id} 收集表元数据"
    result = await TableMetadataService(session).collect_and_save_table_metadata(
        current_user.id,
        request.database_config_id,
        request.table_names
    )
    session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据"
    return HxfResponse(result)
@router.get("/", summary="获取用户表元数据列表")
async def get_table_metadata(
    database_config_id: int = None,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """List the current user's collected table metadata.

    Optionally filtered by database configuration; errors are reported as
    a ``success=False`` payload rather than an HTTP error.
    """
    def _serialize(meta):
        """Flatten one metadata record into the response dict shape."""
        columns = meta.columns_info if meta.columns_info else []
        return {
            "id": meta.id,
            "table_name": meta.table_name,
            "table_schema": meta.table_schema,
            "table_type": meta.table_type,
            "table_comment": meta.table_comment or "",
            "columns": columns,
            "column_count": len(columns),
            "row_count": meta.row_count,
            "is_enabled_for_qa": meta.is_enabled_for_qa,
            "qa_description": meta.qa_description or "",
            "business_context": meta.business_context or "",
            "created_at": meta.created_at.isoformat() if meta.created_at else "",
            "updated_at": meta.updated_at.isoformat() if meta.updated_at else "",
            "last_synced_at": meta.last_synced_at.isoformat() if meta.last_synced_at else "",
            "qa_settings": {
                "is_enabled_for_qa": meta.is_enabled_for_qa,
                "qa_description": meta.qa_description or "",
                "business_context": meta.business_context or ""
            }
        }

    try:
        service = TableMetadataService(session)
        metadata_list = await service.get_user_table_metadata(
            current_user.id,
            database_config_id
        )
        return HxfResponse({
            "success": True,
            "data": [_serialize(meta) for meta in metadata_list]
        })
    except Exception as e:
        logger.error(f"获取表元数据失败: {str(e)}")
        return HxfResponse({
            "success": False,
            "message": str(e)
        })
@router.post("/by-table", summary="根据表名获取表元数据")
async def get_table_metadata_by_name(
    request: TableByNameRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Look up one table's metadata by name within a database config.

    Returns a ``success=False`` payload when the record is missing or an
    error occurs, mirroring the list endpoint above.
    """
    try:
        service = TableMetadataService(session)
        metadata = await service.get_table_metadata_by_name(
            current_user.id,
            request.database_config_id,
            request.table_name
        )

        if not metadata:
            return HxfResponse({
                "success": False,
                "data": None,
                "message": "表元数据不存在"
            })

        columns = metadata.columns_info if metadata.columns_info else []
        data = {
            "id": metadata.id,
            "table_name": metadata.table_name,
            "table_schema": metadata.table_schema,
            "table_type": metadata.table_type,
            "table_comment": metadata.table_comment or "",
            "columns": columns,
            "column_count": len(columns),
            "row_count": metadata.row_count,
            "is_enabled_for_qa": metadata.is_enabled_for_qa,
            "qa_description": metadata.qa_description or "",
            "business_context": metadata.business_context or "",
            "created_at": metadata.created_at.isoformat() if metadata.created_at else "",
            "updated_at": metadata.updated_at.isoformat() if metadata.updated_at else "",
            "last_synced_at": metadata.last_synced_at.isoformat() if metadata.last_synced_at else "",
            "qa_settings": {
                "is_enabled_for_qa": metadata.is_enabled_for_qa,
                "qa_description": metadata.qa_description or "",
                "business_context": metadata.business_context or ""
            }
        }
        return HxfResponse({
            "success": True,
            "data": data
        })

    except Exception as e:
        # BUG FIX: the original stacked two identical `except Exception`
        # handlers; the second was unreachable and has been removed.
        logger.error(f"获取表元数据失败: {str(e)}")
        return HxfResponse({
            "success": False,
            "message": str(e)
        })
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
async def update_qa_settings(
    metadata_id: int,
    settings: QASettingsUpdate,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Update the QA settings of one table-metadata record.

    Raises 404 when the record does not exist and 400 for service errors.
    """
    try:
        service = TableMetadataService(session)
        success = await service.update_table_qa_settings(
            current_user.id,
            metadata_id,
            settings.model_dump()
        )

        if success:
            return HxfResponse({
                "success": True,
                "message": "设置更新成功"
            })
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="表元数据不存在"
        )
    except HTTPException:
        # BUG FIX: the 404 above was previously caught by the generic
        # handler below and re-raised as a 400 — propagate it unchanged.
        raise
    except Exception as e:
        logger.error(f"更新问答设置失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
class TableSaveRequest(BaseModel):
    """Request body for persisting the configuration of selected tables."""
    database_config_id: int = Field(..., description="数据库配置ID")
    table_names: List[str] = Field(..., description="要保存的表名列表")
@router.post("/save")
async def save_table_metadata(
    request: TableSaveRequest,
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Persist the metadata configuration of the selected tables."""
    service = TableMetadataService(session)
    result = await service.save_table_metadata_config(
        user_id=current_user.id,
        database_config_id=request.database_config_id,
        table_names=request.table_names
    )

    session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置"

    saved = result['saved_tables']
    return HxfResponse({
        "success": True,
        "message": f"成功保存 {len(saved)} 个表的配置",
        "saved_tables": saved,
        "failed_tables": result.get('failed_tables', [])
    })
@ -0,0 +1,247 @@
|
||||||
|
"""User management endpoints."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...core.simple_permissions import require_super_admin
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...services.user import UserService
|
||||||
|
from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/profile", response_model=UserResponse, summary="获取当前用户的个人信息")
async def get_user_profile(
    current_user = Depends(AuthService.get_current_user)
):
    """Return the authenticated user's own profile."""
    return HxfResponse(UserResponse.model_validate(current_user))
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息")
async def update_user_profile(
    user_update: UserUpdate,
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Update the authenticated user's own profile.

    Rejects the change with 400 when the new e-mail is already taken by
    another account.
    """
    user_service = UserService(session)

    # Guard: an e-mail change must not collide with an existing account.
    email_changed = user_update.email and user_update.email != current_user.email
    if email_changed and await user_service.get_user_by_email(user_update.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    updated_user = await user_service.update_user(current_user.id, user_update)
    return HxfResponse(UserResponse.model_validate(updated_user))
@router.delete("/profile", summary="删除当前用户的账户")
async def delete_user_account(
    current_user = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
):
    """Delete the authenticated user's own account."""
    # Capture the name before the row disappears.
    username = current_user.username
    await UserService(session).delete_user(current_user.id)
    session.desc = f"删除用户 [{username}] 成功"
    return HxfResponse({"message": f"删除用户 {username} 成功"})
# Admin endpoints
@router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)")
async def create_user(
    user_create: UserCreate,
    current_user = Depends(require_super_admin),
    session: Session = Depends(get_session)
):
    """Create a new user account (super-admin only).

    Raises 400 when the username or e-mail is already registered.
    """
    user_service = UserService(session)

    # Reject a duplicate username.
    if await user_service.get_user_by_username(user_create.username):
        session.desc = f"创建用户 [{user_create.username}] 失败 - 用户名已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    # Reject a duplicate e-mail.
    if await user_service.get_user_by_email(user_create.email):
        session.desc = f"创建用户 [{user_create.username}] 失败 - 邮箱已存在"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    new_user = await user_service.create_user(user_create)
    return HxfResponse(UserResponse.model_validate(new_user))
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
async def list_users(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None),
    role_id: Optional[int] = Query(None),
    is_active: Optional[bool] = Query(None),
    current_user = Depends(require_super_admin),
    session: Session = Depends(get_session)
):
    """List all users with paging and optional filters (super-admin only).

    SECURITY FIX: the endpoint's summary declares it admin-only, but it
    previously had no authentication dependency at all, allowing anonymous
    user enumeration. A ``require_super_admin`` dependency is added,
    consistent with ``create_user`` and ``reset_user_password``.
    """
    session.desc = f"START: 列出所有用户,分页={page}, 每页大小={size}, 搜索={search}, 角色ID={role_id}, 激活状态={is_active}"
    user_service = UserService(session)
    skip = (page - 1) * size
    users, total = await user_service.get_users_with_filters(
        skip=skip,
        limit=size,
        search=search,
        role_id=role_id,
        is_active=is_active
    )
    result = {
        "users": [UserResponse.model_validate(user) for user in users],
        "total": total,
        "page": page,
        "page_size": size
    }
    return HxfResponse(result)
@router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)")
async def get_user(
    user_id: int,
    current_user = Depends(AuthService.get_current_active_user),
    session: Session = Depends(get_session)
):
    """Fetch a user by id; 404 when absent.

    NOTE(review): the summary says admin-only, but the dependency only
    requires an active user — confirm the intended access level.
    """
    user = await UserService(session).get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    return HxfResponse(UserResponse.model_validate(user))
@router.put("/change-password", summary="修改当前用户的密码")
async def change_password(
    request: ChangePasswordRequest,
    current_user = Depends(AuthService.get_current_active_user),
    session: Session = Depends(get_session)
):
    """Change the authenticated user's password.

    Service-layer errors are mapped onto HTTP codes by matching their
    message text: wrong current password → 400, too-short password → 400,
    anything else → 500.
    """
    user_service = UserService(session)

    try:
        await user_service.change_password(
            user_id=current_user.id,
            current_password=request.current_password,
            new_password=request.new_password
        )
    except Exception as e:
        reason = str(e)
        if "Current password is incorrect" in reason:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect"
            )
        if "must be at least 6 characters" in reason:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="New password must be at least 6 characters long"
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to change password"
        )
    return HxfResponse({"message": "Password changed successfully"})
@router.put("/{user_id}/reset-password", summary="重置用户密码 (仅管理员权限)")
async def reset_user_password(
    user_id: int,
    request: ResetPasswordRequest,
    current_user = Depends(require_super_admin),
    session: Session = Depends(get_session)
):
    """Reset another user's password (super-admin only).

    Service-layer errors are mapped onto HTTP codes by matching their
    message text: missing user → 404, too-short password → 400,
    anything else → 500.
    """
    user_service = UserService(session)

    try:
        await user_service.reset_password(
            user_id=user_id,
            new_password=request.new_password
        )
    except Exception as e:
        reason = str(e)
        if "User not found" in reason:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )
        if "must be at least 6 characters" in reason:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="New password must be at least 6 characters long"
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to reset password"
        )
    return HxfResponse({"message": "Password reset successfully"})
@router.put("/{user_id}", response_model=UserResponse, summary="更新用户信息 (仅管理员权限)")
async def update_user(
    user_id: int,
    user_update: UserUpdate,
    current_user = Depends(AuthService.get_current_active_user),
    session: Session = Depends(get_session)
):
    """Update any user's profile by id; 404 when absent.

    NOTE(review): the summary says admin-only, but the dependency only
    requires an active user — confirm the intended access level.
    """
    user_service = UserService(session)

    user = await user_service.get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    updated_user = await user_service.update_user(user_id, user_update)
    return HxfResponse(UserResponse.model_validate(updated_user))
@router.delete("/{user_id}", summary="删除用户 (仅管理员权限)")
async def delete_user(
    user_id: int,
    current_user = Depends(AuthService.get_current_active_user),
    session: Session = Depends(get_session)
):
    """Delete a user by id; 404 when absent.

    NOTE(review): the summary says admin-only, but the dependency only
    requires an active user — confirm the intended access level.
    """
    user_service = UserService(session)

    user = await user_service.get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    await user_service.delete_user(user_id)
    return HxfResponse({"message": "User deleted successfully"})
@ -0,0 +1,531 @@
|
||||||
|
"""工作流管理API"""
|
||||||
|
|
||||||
|
from typing import List, Optional, AsyncGenerator
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import select, and_, func
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ...db.database import get_session
|
||||||
|
from ...schemas.workflow import (
|
||||||
|
WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
|
||||||
|
WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
|
||||||
|
)
|
||||||
|
from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
|
||||||
|
from ...services.workflow_engine import get_workflow_engine
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...models.user import User
|
||||||
|
from loguru import logger
|
||||||
|
from utils.util_exceptions import HxfResponse
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
def convert_workflow_for_response(workflow_dict):
    """Rename connection keys 'from_node'/'to_node' to 'from'/'to'.

    The stored workflow definition uses the *_node key names, while the
    response schema expects 'from' and 'to'. Mutates the dict in place and
    returns it for convenient chaining; dicts without a definition or
    connections pass through untouched.
    """
    definition = workflow_dict.get('definition')
    connections = definition.get('connections') if definition else None
    for conn in connections or []:
        for old_key, new_key in (('from_node', 'from'), ('to_node', 'to')):
            if old_key in conn:
                conn[new_key] = conn.pop(old_key)
    return workflow_dict
@router.get("/", response_model=WorkflowListResponse, summary="获取工作流列表")
async def list_workflows(
    skip: Optional[int] = Query(None, ge=0),
    limit: Optional[int] = Query(None, ge=1, le=100),
    workflow_status: Optional[WorkflowStatus] = None,
    search: Optional[str] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """List the current user's workflows with optional filters and paging.

    When neither ``skip`` nor ``limit`` is given, the full result set is
    returned; otherwise missing paging values default to skip=0, limit=10.
    """
    from ...models.workflow import Workflow
    session.title = f"获取用户 {current_user.username} 的所有工作流"
    session.desc = f"START: 获取用户 {current_user.username} 的所有工作流 (skip={skip}, limit={limit})"

    # One set of filter conditions shared by the data and count queries.
    conditions = [Workflow.owner_id == current_user.id]
    if workflow_status:
        conditions.append(Workflow.status == workflow_status)
    if search:
        conditions.append(Workflow.name.ilike(f"%{search}%"))

    stmt = select(Workflow).where(*conditions)
    count_query = select(func.count(Workflow.id)).where(*conditions)

    session.desc = f"查询条件: 状态={workflow_status}, 搜索={search}"
    total = await session.scalar(count_query)
    session.desc = f"查询结果: 共 {total} 条"

    def _payload(rows, page, size):
        """Assemble the paged response object."""
        return WorkflowListResponse(
            workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in rows],
            total=total,
            page=page,
            size=size
        )

    # No paging parameters at all: return everything in one page.
    if skip is None and limit is None:
        workflows = (await session.scalars(stmt)).all()
        session.desc = f"SUCCESS: 没有传分页参数,返回所有数据 - 共 {len(workflows)} 条"
        return HxfResponse(_payload(workflows, 1, total))

    # Fill in defaults for whichever paging parameter is missing.
    skip = 0 if skip is None else skip
    limit = 10 if limit is None else limit

    workflows = (await session.scalars(stmt.offset(skip).limit(limit))).all()
    session.desc = f"SUCCESS: 分页查询 - 共 {len(workflows)} 条"
    return HxfResponse(_payload(workflows, skip // limit + 1, limit))
@router.get("/{workflow_id}", response_model=WorkflowResponse, summary="获取工作流详情")
async def get_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one workflow owned by the current user; 404 when absent."""
    from ...models.workflow import Workflow
    session.title = f"获取工作流 {workflow_id}"
    session.desc = f"START: 获取工作流 {workflow_id}"

    record = await session.scalar(
        select(Workflow).where(
            Workflow.id == workflow_id,
            Workflow.owner_id == current_user.id
        )
    )

    if not record:
        session.desc = f"ERROR: 获取工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    session.desc = f"SUCCESS: 获取工作流数据 {workflow_id}"
    return HxfResponse(WorkflowResponse(**convert_workflow_for_response(record.to_dict())))
@router.put("/{workflow_id}", response_model=WorkflowResponse, summary="更新工作流")
async def update_workflow(
    workflow_id: int,
    workflow_data: WorkflowUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Update a workflow owned by the current user and return its new state."""
    from ...models.workflow import Workflow
    session.title = f"更新工作流 {workflow_id}"
    session.desc = f"START: 更新工作流 {workflow_id}"

    record = await session.scalar(
        select(Workflow).where(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )

    if not record:
        session.desc = f"ERROR: 更新工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    # NOTE(review): every update force-publishes the workflow regardless of
    # the client-supplied status — confirm this is the intended behavior.
    workflow_data.status = WorkflowStatus.PUBLISHED

    session.desc = f"UPDATE: 工作流 {workflow_id} 更新字段 {workflow_data.model_dump(exclude_unset=True)}"
    for field, value in workflow_data.model_dump(exclude_unset=True).items():
        if field == "definition" and value and hasattr(value, 'dict'):
            # A Pydantic model slipped through model_dump: persist a plain dict.
            setattr(record, field, value.dict())
        else:
            setattr(record, field, value)

    record.set_audit_fields(current_user.id, is_update=True)

    await session.commit()
    await session.refresh(record)
    session.desc = f"SUCCESS: 更新工作流数据 commit & refresh {workflow_id}"
    return HxfResponse(WorkflowResponse(**convert_workflow_for_response(record.to_dict())))
@router.delete("/{workflow_id}", summary="删除工作流")
async def delete_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete one of the current user's workflows; 404 when absent."""
    from ...models.workflow import Workflow
    session.title = f"删除工作流 {workflow_id}"
    session.desc = f"START: 删除工作流 {workflow_id}"

    record = await session.scalar(
        select(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )

    if not record:
        session.desc = f"ERROR: 删除工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    session.desc = f"删除工作流: {record.name}"
    await session.delete(record)
    await session.commit()

    session.desc = f"SUCCESS: 删除工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流删除成功"})
@router.post("/{workflow_id}/activate", summary="激活工作流")
async def activate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Mark one of the current user's workflows as published (active)."""
    from ...models.workflow import Workflow
    session.title = f"激活工作流 {workflow_id}"
    session.desc = f"START: 激活工作流 {workflow_id}"

    record = await session.scalar(
        select(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )

    if not record:
        session.desc = f"ERROR: 激活工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    record.status = ModelWorkflowStatus.PUBLISHED
    record.set_audit_fields(current_user.id, is_update=True)
    await session.commit()

    session.desc = f"SUCCESS: 激活工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流激活成功"})
@router.post("/{workflow_id}/deactivate", summary="停用工作流")
async def deactivate_workflow(
    workflow_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Archive (deactivate) one of the current user's workflows."""
    from ...models.workflow import Workflow
    session.title = f"停用工作流 {workflow_id}"
    session.desc = f"START: 停用工作流 {workflow_id}"

    record = await session.scalar(
        select(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        )
    )

    if not record:
        session.desc = f"ERROR: 停用工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    record.status = ModelWorkflowStatus.ARCHIVED
    record.set_audit_fields(current_user.id, is_update=True)

    await session.commit()

    session.desc = f"SUCCESS: 停用工作流数据 commit {workflow_id}"
    return HxfResponse({"message": "工作流停用成功"})
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse], summary="获取工作流执行历史")
async def list_workflow_executions(
    workflow_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return a page of execution records for a workflow owned by the caller.

    Args:
        workflow_id: id of the workflow whose history is requested.
        skip / limit: pagination window (newest executions first).

    Raises:
        HTTPException 404: workflow missing or not owned by the current user.
        HTTPException 500: unexpected failure while loading the history.
    """
    session.title = f"获取工作流执行历史 {workflow_id}"
    session.desc = f"START: 获取工作流执行历史 {workflow_id}"
    try:
        from ...models.workflow import Workflow, WorkflowExecution

        # Verify workflow ownership before exposing any execution data.
        workflow = await session.scalar(
            select(Workflow).where(
                and_(
                    Workflow.id == workflow_id,
                    Workflow.owner_id == current_user.id
                )
            )
        )

        if not workflow:
            session.desc = f"ERROR: 获取工作流执行历史数据 - 工作流不存在"
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="工作流不存在"
            )

        # Fetch the execution history, newest first, within the page window.
        executions = (await session.scalars(
            select(WorkflowExecution).where(
                WorkflowExecution.workflow_id == workflow_id
            ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
        )).all()

        session.desc = f"SUCCESS: 获取工作流执行历史数据 commit {workflow_id}"
        response = [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
        return HxfResponse(response)

    except HTTPException:
        # BUGFIX: the generic handler below used to catch the 404 raised above
        # and convert it into a 500; HTTP errors must propagate unchanged.
        raise
    except Exception:
        session.desc = f"ERROR: 获取工作流执行历史数据 commit {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行历史失败"
        )
|
||||||
|
|
||||||
|
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse, summary="获取工作流执行详情")
async def get_workflow_execution(
    execution_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Return one workflow-execution record, restricted to the caller's workflows.

    Raises:
        HTTPException 404: execution missing or owned by another user.
        HTTPException 500: unexpected failure while loading the record.
    """
    session.title = f"获取工作流执行详情 {execution_id}"
    session.desc = f"START: 获取工作流执行详情 {execution_id}"
    try:
        from ...models.workflow import WorkflowExecution, Workflow

        # Join through Workflow so ownership is enforced in the same query.
        execution = await session.scalar(
            select(WorkflowExecution).join(
                Workflow, WorkflowExecution.workflow_id == Workflow.id
            ).where(
                WorkflowExecution.id == execution_id,
                Workflow.owner_id == current_user.id
            )
        )

        if not execution:
            session.desc = f"ERROR: 获取工作流执行详情数据 - 执行记录不存在"
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="执行记录不存在"
            )

        response = WorkflowExecutionResponse.model_validate(execution)
        session.desc = f"SUCCESS: 获取工作流执行详情数据 commit {execution_id}"
        return HxfResponse(response)

    except HTTPException:
        # BUGFIX: previously the bare `except Exception` below converted the
        # 404 raised above into a 500; re-raise HTTP errors untouched.
        raise
    except Exception:
        session.desc = f"ERROR: 获取工作流执行详情数据 commit {execution_id}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取执行详情失败"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{workflow_id}/execute-stream", summary="流式执行工作流")
async def execute_workflow_stream(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Execute a workflow and stream per-node status as Server-Sent Events.

    The response body is an SSE stream: a ``workflow_start`` event, one event
    per engine step, then ``workflow_complete`` (or a single ``error`` event).
    """
    session.title = f"流式执行工作流 {workflow_id}"
    session.desc = f"START: 流式执行工作流 {workflow_id}"

    async def generate_stream() -> AsyncGenerator[str, None]:
        try:
            from ...models.workflow import Workflow

            # Validate existence + ownership before emitting anything useful.
            workflow = await session.scalar(
                select(Workflow).filter(
                    and_(
                        Workflow.id == workflow_id,
                        Workflow.owner_id == current_user.id
                    )
                )
            )

            if not workflow:
                yield f"data: {json.dumps({'type': 'error', 'message': '工作流不存在'}, ensure_ascii=False)}\n\n"
                return

            if workflow.status != ModelWorkflowStatus.PUBLISHED:
                yield f"data: {json.dumps({'type': 'error', 'message': '工作流未激活,无法执行'}, ensure_ascii=False)}\n\n"
                return

            # Start signal.
            yield f"data: {json.dumps({'type': 'workflow_start', 'workflow_id': workflow_id, 'workflow_name': workflow.name, 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            workflow_engine = await get_workflow_engine(session)

            # Stream every engine step straight to the client.
            async for step_data in workflow_engine.execute_workflow_stream(
                workflow=workflow,
                input_data=request.input_data,
                user_id=current_user.id,
                session=session
            ):
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

            # Completion signal.
            yield f"data: {json.dumps({'type': 'workflow_complete', 'message': '工作流执行完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"流式工作流执行异常: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'工作流执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"

    response = StreamingResponse(
        generate_stream(),
        # BUGFIX: declare the real SSE media type here instead of "text/plain"
        # plus a contradictory explicit "Content-Type" header override.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )
    # NOTE(review): the generator has not run yet at this point — this SUCCESS
    # marker only records that streaming was initiated, not that it finished.
    session.desc = f"SUCCESS: 流式执行工作流 {workflow_id} 完毕"
    # NOTE(review): wrapping a StreamingResponse in HxfResponse is unusual —
    # confirm HxfResponse passes Response objects through unchanged.
    return HxfResponse(response)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/", response_model=WorkflowResponse, summary="创建工作流")
async def create_workflow(
    workflow_data: WorkflowCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a new workflow owned by the current user and return it."""
    from ...models.workflow import Workflow

    session.title = f"创建工作流 {workflow_data.name}"
    session.desc = f"START: 创建工作流 {workflow_data.name}"

    # New workflows always start as PUBLISHED (the status in the payload is ignored).
    new_workflow = Workflow(
        name=workflow_data.name,
        description=workflow_data.description,
        definition=workflow_data.definition.model_dump(),
        version="1.0.0",
        status=ModelWorkflowStatus.PUBLISHED,  # workflow_data.status,
        owner_id=current_user.id
    )
    session.desc = f"创建工作流实例 - Workflow(), {workflow_data.name}"

    new_workflow.set_audit_fields(current_user.id)
    session.desc = f"保存工作流 - set_audit_fields {workflow_data.name}"

    session.add(new_workflow)
    await session.commit()
    await session.refresh(new_workflow)
    session.desc = f"保存工作流 - commit & refresh {workflow_data.name}"

    # Map the persisted definition into the response field names.
    workflow_payload = convert_workflow_for_response(new_workflow.to_dict())
    session.desc = f"转换工作流数据 - convert_workflow_for_response {workflow_data.name}"

    result = WorkflowResponse(**workflow_payload)
    session.desc = f"SUCCESS: 返回工作流数据 - WorkflowResponse {workflow_data.name}"
    return HxfResponse(result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse, summary="执行工作流")
async def execute_workflow(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    session: Session = Depends(get_session),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Synchronously run a published workflow and return its execution result.

    Raises 404 when the workflow is missing / not owned by the caller and 400
    when it is not in the PUBLISHED status.
    """
    from ...models.workflow import Workflow

    session.title = f"执行工作流 {workflow_id}"
    session.desc = f"START: 执行工作流 {workflow_id}"

    # Ownership check folded into the lookup query.
    ownership_stmt = select(Workflow).filter(
        and_(Workflow.id == workflow_id, Workflow.owner_id == current_user.id)
    )
    workflow = await session.scalar(ownership_stmt)

    if workflow is None:
        session.desc = f"ERROR: 执行工作流数据 - 工作流不存在 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作流不存在"
        )

    session.desc = f"获取工作流数据 - Workflow() {workflow_id}"
    # Only published (active) workflows are executable.
    if workflow.status != ModelWorkflowStatus.PUBLISHED:
        session.desc = f"ERROR: 执行工作流数据 - 工作流未激活 {workflow_id}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="工作流未激活,无法执行"
        )

    # Hand off to the workflow engine for the actual run.
    engine = await get_workflow_engine(session)
    session.desc = f"获取工作流引擎 - get_workflow_engine {workflow_id}"
    execution_result = await engine.execute_workflow(
        workflow=workflow,
        input_data=request.input_data,
        user_id=current_user.id,
        session=session
    )

    session.desc = f"SUCCESS: 执行工作流数据 commit {workflow_id}"
    return HxfResponse(execution_result)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""Main API router."""

from fastapi import APIRouter

from .endpoints import chat
from .endpoints import auth
from .endpoints import knowledge_base
from .endpoints import smart_query
from .endpoints import smart_chat
from .endpoints import database_config
from .endpoints import table_metadata

# # System management endpoints
from .endpoints import roles
from .endpoints import llm_configs
from .endpoints import users

# # Workflow endpoints
from .endpoints import workflow

# Create main API router
router = APIRouter()

# (sub-router, URL prefix or None, OpenAPI tags) — registration order preserved.
_REGISTRATIONS = (
    (auth.router, "/auth", ["身份验证"]),
    (users.router, "/users", ["users"]),
    (roles.router, "/admin", ["admin-roles"]),
    (llm_configs.router, "/admin", ["admin-llm-configs"]),
    (knowledge_base.router, "/knowledge-bases", ["knowledge-bases"]),
    (database_config.router, None, ["database-config"]),
    (table_metadata.router, None, ["table-metadata"]),
    (smart_query.router, None, ["smart-query"]),
    (chat.router, "/chat", ["chat"]),
    (smart_chat.router, None, ["smart-chat"]),
    (workflow.router, "/workflows", ["workflows"]),
)

for _sub_router, _prefix, _tags in _REGISTRATIONS:
    if _prefix is None:
        router.include_router(_sub_router, tags=_tags)
    else:
        router.include_router(_sub_router, prefix=_prefix, tags=_tags)
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Core module for TH Agenter."""
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
"""FastAPI application factory."""
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
|
|
||||||
|
from .config import Settings
|
||||||
|
from .middleware import UserContextMiddleware
|
||||||
|
from ..api.routes import router
|
||||||
|
from ..api.endpoints import table_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: log startup, then guarantee a shutdown log.

    The shutdown message is emitted from a ``finally`` block so it is still
    written when the server exits through an exception or task cancellation
    (the original version skipped it in those cases).
    """
    logger.info("Starting up TH Agenter application...")
    try:
        yield
    finally:
        logger.info("Shutting down TH Agenter application...")
|
||||||
|
|
||||||
|
|
||||||
|
# def create_app(settings: Settings = None) -> FastAPI:
|
||||||
|
# """Create and configure FastAPI application."""
|
||||||
|
# if settings is None:
|
||||||
|
# from .config import get_settings
|
||||||
|
# settings = get_settings()
|
||||||
|
|
||||||
|
# # Create FastAPI app
|
||||||
|
# app = FastAPI(
|
||||||
|
# title=settings.app_name,
|
||||||
|
# version=settings.app_version,
|
||||||
|
# description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由DrGraph修改",
|
||||||
|
# debug=settings.debug,
|
||||||
|
# lifespan=lifespan,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Add middleware
|
||||||
|
# setup_middleware(app, settings)
|
||||||
|
|
||||||
|
# # Add exception handlers
|
||||||
|
# setup_exception_handlers(app)
|
||||||
|
|
||||||
|
# # Include routers
|
||||||
|
# app.include_router(router, prefix="/api")
|
||||||
|
|
||||||
|
# app.include_router(table_metadata.router)
|
||||||
|
# # 在现有导入中添加
|
||||||
|
# from ..api.endpoints import database_config
|
||||||
|
|
||||||
|
# # 在路由注册部分添加
|
||||||
|
# app.include_router(database_config.router)
|
||||||
|
# # Health check endpoint
|
||||||
|
# @app.get("/health")
|
||||||
|
# async def health_check():
|
||||||
|
# return {"status": "healthy", "version": settings.app_version}
|
||||||
|
|
||||||
|
# # Root endpoint
|
||||||
|
# @app.get("/")
|
||||||
|
# async def root():
|
||||||
|
# return {"message": "Chat Agent API is running"}
|
||||||
|
# return app
|
||||||
|
|
||||||
|
|
||||||
|
def setup_middleware(app: FastAPI, settings: Settings) -> None:
    """Attach the application's middleware stack to *app*."""

    # User context middleware — registered first so the request-scoped user
    # context is in place for all requests.
    app.add_middleware(UserContextMiddleware)

    # CORS, driven entirely by the settings section.
    cors = settings.cors
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors.allowed_origins,
        allow_credentials=True,
        allow_methods=cors.allowed_methods,
        allow_headers=cors.allowed_headers,
    )

    # Host-header validation only outside development.
    if settings.environment == "production":
        # Configure this properly in production
        app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"])
|
||||||
|
|
||||||
|
def setup_exception_handlers(app: FastAPI) -> None:
    """Register the global exception handlers on *app*.

    Installs three handlers — HTTP errors, request-validation errors and a
    last-resort catch-all — each returning a uniform ``{"error": {...}}`` body.
    """

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        # Preserve the original status code and detail text.
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "error": {
                    "type": "http_error",
                    "message": exc.detail,
                    "status_code": exc.status_code
                }
            }
        )

    def make_json_serializable(obj):
        """Recursively convert *obj* into JSON-serializable primitives."""
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        elif isinstance(obj, bytes):
            return obj.decode('utf-8')
        elif isinstance(obj, Exception):
            # BUGFIX: the old check `(ValueError, Exception)` was redundant —
            # ValueError is already an Exception subclass.
            return str(obj)
        elif isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        else:
            # For any other object, fall back to its string representation.
            return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Validation errors can embed non-serializable objects (e.g. raised
        # exceptions); normalize them before building the JSON body.
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Fallback: if even our conversion fails, use a simple error message
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]

        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        # Last resort: log the full traceback but hide internals from clients.
        logger.error(f"Unhandled exception: {exc}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Create the app instance
|
||||||
|
# app = create_app()
|
||||||
|
|
@ -0,0 +1,468 @@
|
||||||
|
"""Configuration management for TH Agenter."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from requests import Session
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class DatabaseSettings(BaseSettings):
    """Database connection configuration, loaded from the environment / .env."""

    # Connection URL — must be provided via the DATABASE_URL environment variable.
    url: str = Field(..., alias="database_url")
    # SQLAlchemy engine tuning knobs.
    echo: bool = Field(default=False)
    pool_size: int = Field(default=5)
    max_overflow: int = Field(default=10)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class SecuritySettings(BaseSettings):
    """Token-signing / authentication configuration."""

    # WARNING: placeholder default — must be overridden in production.
    secret_key: str = Field(default="your-secret-key-here-change-in-production")
    algorithm: str = Field(default="HS256")
    access_token_expire_minutes: int = Field(default=300)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class ToolSetings(BaseSettings):
    """API keys for external tools (Tavily search, weather).

    NOTE: the class name is misspelled ("ToolSetings"); it is kept unchanged
    because other modules reference it by this exact name.
    """

    tavily_api_key: Optional[str] = Field(default=None)   # Tavily search
    weather_api_key: Optional[str] = Field(default=None)  # weather service

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class LLMSettings(BaseSettings):
    """Chat-model configuration supporting several OpenAI-protocol-compatible providers."""

    # One of: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field(default="openai", alias="llm_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Shared generation parameters
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the active chat-model config, preferring the database default.

        Falls back to the environment-variable configuration for the selected
        provider when the database has no default or the lookup fails.
        """
        try:
            from th_agenter.services.llm_config_service import LLMConfigService

            # Prefer the default chat configuration stored in the database.
            llm_service = LLMConfigService()
            db_config = None
            if session:
                db_config = await llm_service.get_default_chat_config(session)

            if db_config:
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                # SECURITY FIX: never write the raw API key into the session
                # description or logs — the old code interpolated the full
                # config dict, key included.
                safe_config = {**config, "api_key": "***"}
                if session:
                    session.desc = f"使用LLM配置(get_default_chat_config)> {safe_config}"
                else:
                    logger.info(f"使用LLM配置(get_default_chat_config) > {safe_config}")
                return config
        except Exception as e:
            # Database lookup failed: record it and fall through to env vars.
            if session:
                session.desc = f"EXCEPTION: 获取默认对话模型配置失败: {str(e)}"
            else:
                logger.error(f"获取默认对话模型配置失败: {str(e)}")

        # Environment-variable fallback, keyed by the configured provider.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_model
            }
        }

        config = provider_configs.get(self.provider, provider_configs["openai"])
        config.update({
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        })
        return config
|
||||||
|
|
||||||
|
class EmbeddingSettings(BaseSettings):
    """Embedding-model configuration supporting multiple providers."""

    # One of: openai, deepseek, doubao, zhipu, moonshot
    provider: str = Field(default="zhipu", alias="embedding_provider")

    # OpenAI
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_embedding_model: str = Field(default="text-embedding-ada-002")

    # DeepSeek
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_embedding_model: str = Field(default="deepseek-embedding")

    # Doubao
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_embedding_model: str = Field(default="doubao-embedding")

    # Zhipu AI
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_embedding_model: str = Field(default="moonshot-embedding")

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    async def get_current_config(self, session: Session) -> dict:
        """Return the active embedding config, preferring the database default.

        Falls back to the environment-variable configuration for the selected
        provider when the database has no default or the lookup fails.
        """
        try:
            if session:
                session.desc = "尝试从数据库读取默认嵌入模型配置 ... >>> get_current_config";
            # Try the default embedding configuration stored in the database.
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = await llm_service.get_default_embedding_config(session)

            if db_config:
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name
                }
                return config
        except Exception as e:
            # CONSISTENCY FIX: the old code called `session.error(...)`, a
            # method used nowhere else on this session object — report through
            # `session.desc` exactly as the LLMSettings sibling does.
            if session:
                session.desc = f"EXCEPTION: Failed to read embedding config from database, falling back to env vars: {e}"
            else:
                logger.error(f"Failed to read embedding config from database, falling back to env vars: {e}")

        # Environment-variable fallback, keyed by the configured provider.
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_embedding_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_embedding_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_embedding_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_embedding_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_embedding_model
            }
        }

        return provider_configs.get(self.provider, provider_configs["zhipu"])
|
||||||
|
|
||||||
|
class VectorDBSettings(BaseSettings):
    """Vector database configuration (chroma directory or pgvector connection)."""

    type: str = Field(default="pgvector", alias="vector_db_type")
    persist_directory: str = Field(default="./data/chroma")
    collection_name: str = Field(default="documents")
    # Dimension of Zhipu AI's embedding-3 model.
    embedding_dimension: int = Field(default=2048)

    # PostgreSQL pgvector connection settings.
    pgvector_host: str = Field(default="localhost")
    pgvector_port: int = Field(default=5432)
    pgvector_database: str = Field(default="vectordb")
    pgvector_user: str = Field(default="postgres")
    pgvector_password: str = Field(default="")
    pgvector_table_name: str = Field(default="embeddings")
    # NOTE(review): this default (1024) differs from embedding_dimension (2048)
    # — confirm which one the pgvector schema actually uses.
    pgvector_vector_dimension: int = Field(default=1024)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class FileSettings(BaseSettings):
    """File upload and chunking configuration."""

    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10MB
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    semantic_splitter_enabled: bool = Field(default=False)  # enable the semantic text splitter

    @staticmethod
    def _normalize_extensions(values: List[str]) -> List[str]:
        """Ensure every extension starts with a leading dot.

        Extracted helper: this rule was previously duplicated in the validator
        and in get_allowed_extensions_list.
        """
        return [ext if ext.startswith('.') else f'.{ext}' for ext in values]

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Parse a comma-separated string (or list) into a normalized list."""
        if isinstance(v, str):
            return cls._normalize_extensions([ext.strip() for ext in v.split(',')])
        elif isinstance(v, list):
            return cls._normalize_extensions(v)
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Return the allowed extensions as a list, normalizing a raw string value."""
        if isinstance(self.allowed_extensions, list):
            return self.allowed_extensions
        elif isinstance(self.allowed_extensions, str):
            return self._normalize_extensions(
                [ext.strip() for ext in self.allowed_extensions.split(',')]
            )
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class StorageSettings(BaseSettings):
    """Storage backend configuration (local filesystem or S3-compatible)."""

    storage_type: str = Field(default="local")  # "local" or "s3"
    upload_directory: str = Field(default="./data/uploads")

    # S3 / S3-compatible settings.
    s3_bucket_name: str = Field(default="chat-agent-files")
    aws_access_key_id: Optional[str] = Field(default=None)
    aws_secret_access_key: Optional[str] = Field(default=None)
    aws_region: str = Field(default="us-east-1")
    s3_endpoint_url: Optional[str] = Field(default=None)  # for S3-compatible services

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class CORSSettings(BaseSettings):
    """Cross-origin resource sharing configuration."""

    allowed_origins: List[str] = Field(default=["*"])
    allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    allowed_headers: List[str] = Field(default=["*"])

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class ChatSettings(BaseSettings):
    """Chat configuration."""
    # Number of past messages kept when building the conversation prompt.
    max_history_length: int = Field(default=10)
    # Default system prompt (Chinese: "helpful AI assistant answering from context").
    system_prompt: str = Field(default="你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)

    # Consistency fix: every other settings section declares this pydantic
    # BaseSettings config (.env loading, case-insensitive env names, ignore
    # unknown keys); ChatSettings was the only one missing it.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Main application settings, composed of per-domain sub-settings.

    Values come from (in increasing precedence) field defaults, the YAML
    config file loaded by :meth:`load_from_yaml`, and environment
    variables / ``.env`` entries read by pydantic's BaseSettings.
    """

    # App info
    app_name: str = Field(default="TH Agenter")
    app_version: str = Field(default="0.2.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSetings = Field(default_factory=ToolSetings)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
        """Load settings from a YAML file.

        Falls back to ``<this package>/../../configs/settings.yaml`` when
        *config_path* does not exist, and to pure env/defaults when no YAML
        file is found at all. ``${VAR}`` placeholders in the YAML are
        replaced with environment-variable values before validation.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            # Resolve relative to this module: two levels up, then configs/.
            current_dir = Path(__file__).parent
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                # No YAML anywhere: rely on field defaults + environment only.
                return cls()

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Substitute ${VAR} placeholders with environment values.
        config_data = cls._resolve_env_vars_nested(config_data)

        # Instantiate every sub-settings section explicitly so each one still
        # reads its own .env values even when the YAML omits the section
        # (works around environments where BaseSettings would otherwise not
        # pick up .env for nested models).
        section_types = {
            'database': DatabaseSettings,
            'security': SecuritySettings,
            'llm': LLMSettings,
            'embedding': EmbeddingSettings,
            'vector_db': VectorDBSettings,
            'file': FileSettings,
            'storage': StorageSettings,
            'cors': CORSSettings,
            'chat': ChatSettings,
            'tool': ToolSetings,
        }
        settings_kwargs = {
            name: section_cls(**(config_data.get(name, {})))
            for name, section_cls in section_types.items()
        }

        # Top-level scalars (app_name, debug, host, ...) pass straight through.
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten a nested config dict into ``prefix_key`` form."""
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Any) -> Any:
        """Recursively replace ``${VAR}`` strings with ``os.getenv('VAR')``.

        Recurses into dicts *and* lists — the previous version skipped list
        elements, so ``${VAR}`` entries inside YAML sequences were never
        resolved. Unset variables keep the literal placeholder.
        """
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        elif isinstance(config, list):
            # Bug fix: resolve placeholders inside YAML lists as well.
            return [Settings._resolve_env_vars_nested(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        else:
            return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve ``${VAR}`` placeholders in the top level of *config* only."""
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved
|
||||||
|
|
||||||
|
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance.

    The first call loads configuration via Settings.load_from_yaml();
    lru_cache makes every later call return that same object.
    """
    return Settings.load_from_yaml()
|
||||||
|
|
||||||
|
# Module-level singleton: import `settings` for convenient read access.
settings = get_settings()
|
||||||
|
|
@ -0,0 +1,142 @@
|
||||||
|
"""
|
||||||
|
HTTP请求上下文管理,如:获取当前登录用户信息及Token信息
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional
|
||||||
|
import threading
|
||||||
|
from ..models.user import User
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Context variable to store current user
|
||||||
|
current_user_context: ContextVar[Optional[dict]] = ContextVar('current_user', default=None)
|
||||||
|
|
||||||
|
# Thread-local storage as backup
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
class UserContext:
    """Static accessors for the per-request authenticated user.

    The user is stored redundantly in a ContextVar (async-safe) and in
    thread-local storage (fallback), always as a plain dict snapshot of the
    SQLAlchemy ``User`` row rather than the ORM instance itself, so no lazy
    loading can fire outside the request's session.
    """

    @staticmethod
    def _to_user_dict(user: User) -> dict:
        """Snapshot the ORM user into a detached plain dict.

        Extracted helper: set_current_user and set_current_user_with_token
        previously duplicated this construction verbatim.
        """
        return {
            'id': user.id,
            'username': user.username,
            'email': user.email,
            'full_name': user.full_name,
            'is_active': user.is_active
        }

    @staticmethod
    def set_current_user(user: User, canLog: bool = False) -> None:
        """Set current user in context (no ContextVar reset token returned)."""
        if canLog:
            logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")

        user_dict = UserContext._to_user_dict(user)

        # Set in ContextVar, mirror into thread-local as backup.
        current_user_context.set(user_dict)
        _thread_local.current_user = user_dict

        # Verify it was set
        verify_user = current_user_context.get()
        if canLog:
            logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")

    @staticmethod
    def set_current_user_with_token(user: User, canLog: bool = False):
        """Set current user in context and return the ContextVar token.

        The token lets the caller restore the previous context value via
        reset_current_user_token() during cleanup.
        """
        if canLog:
            logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")

        user_dict = UserContext._to_user_dict(user)

        # Set in ContextVar (keeping the token), mirror into thread-local.
        token = current_user_context.set(user_dict)
        _thread_local.current_user = user_dict

        # Verify it was set
        verify_user = current_user_context.get()
        if canLog:
            logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")

        return token

    @staticmethod
    def reset_current_user_token(token):
        """Reset current user context using a token from set_current_user_with_token."""
        logger.info("[UserContext] - Resetting user context using token")

        # Reset ContextVar using token
        current_user_context.reset(token)

        # Clear thread-local as well
        if hasattr(_thread_local, 'current_user'):
            delattr(_thread_local, 'current_user')

    @staticmethod
    def get_current_user() -> Optional[dict]:
        """Get the current user dict, trying ContextVar first, then thread-local."""
        user = current_user_context.get()
        if user:
            return user

        # Fallback to thread-local
        user = getattr(_thread_local, 'current_user', None)
        if user:
            return user

        logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
        return None

    @staticmethod
    def get_current_user_id() -> Optional[int]:
        """Get current user ID, or None; errors are logged rather than raised."""
        try:
            user = UserContext.get_current_user()
            return user.get('id') if user else None
        except Exception as e:
            logger.error(f"[UserContext] - Error getting current user ID: {e}")
            return None

    @staticmethod
    def clear_current_user(canLog: bool = False) -> None:
        """Clear the user from both ContextVar and thread-local storage."""
        if canLog:
            logger.info("[UserContext] - 清除当前用户上下文")

        current_user_context.set(None)
        if hasattr(_thread_local, 'current_user'):
            delattr(_thread_local, 'current_user')

    @staticmethod
    def require_current_user() -> dict:
        """Get current user from context, raising HTTP 401 when absent."""
        # Same lookup as get_current_user (ContextVar + thread-local fallback).
        user = UserContext.get_current_user()
        if user is None:
            from fastapi import HTTPException, status
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No authenticated user in context"
            )
        return user

    @staticmethod
    def require_current_user_id() -> int:
        """Get current user ID from context, raising HTTP 401 when absent."""
        user = UserContext.require_current_user()
        return user.get('id')
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""Custom exceptions for the application."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCustomException(Exception):
    """Base class for all application-specific exceptions.

    Carries a human-readable ``message`` plus an optional ``details`` dict
    for structured error context.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        """Store the message/details and initialize Exception with the message."""
        self.message = message
        self.details = details if details else {}
        super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
# Domain-specific subclasses; all share BaseCustomException's
# (message, details) contract and add no behavior of their own — they
# exist so callers can catch narrow categories.
class NotFoundError(BaseCustomException):
    """Exception raised when a resource is not found."""
    pass


class ValidationError(BaseCustomException):
    """Exception raised when validation fails."""
    pass


class AuthenticationError(BaseCustomException):
    """Exception raised when authentication fails."""
    pass


class AuthorizationError(BaseCustomException):
    """Exception raised when authorization fails."""
    pass


class DatabaseError(BaseCustomException):
    """Exception raised when database operations fail."""
    pass


class ConfigurationError(BaseCustomException):
    """Exception raised when configuration is invalid."""
    pass


class ExternalServiceError(BaseCustomException):
    """Exception raised when external service calls fail."""
    pass


class BusinessLogicError(BaseCustomException):
    """Exception raised when business logic validation fails."""
    pass
|
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
"""
|
||||||
|
中间件管理,如上下文中间件:校验Token等
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import Response
|
||||||
|
from typing import Callable
|
||||||
|
from loguru import logger
|
||||||
|
from fastapi import status
|
||||||
|
from utils.util_exceptions import HxfErrorResponse
|
||||||
|
|
||||||
|
from ..db.database import get_session, AsyncSessionFactory, engine_async
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from ..services.auth import AuthService
|
||||||
|
from .context import UserContext
|
||||||
|
|
||||||
|
class UserContextMiddleware(BaseHTTPMiddleware):
    """Middleware to set user context for authenticated requests.

    For every non-excluded path it requires a ``Bearer`` token, loads the
    user from the database, binds it into UserContext for the duration of
    the request, and clears the context afterwards.
    """

    def __init__(self, app, exclude_paths: list = None):
        super().__init__(app)
        # Verbose per-request logging toggle (off by default).
        self.canLog = False
        # Paths that don't require authentication
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/login-oauth",
            "/auth/login",
            "/auth/register",
            "/auth/login-oauth",
            "/health",
            "/static/"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and set user context if authenticated.

        Returns a 401 ``HxfErrorResponse`` for any missing/invalid token or
        unknown/inactive user; otherwise forwards to ``call_next`` and always
        clears the user context in the final ``finally``.
        """
        if self.canLog:
            logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}")

        # Skip authentication for excluded paths
        path = request.url.path
        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")

        # Exclusion matching: exact match, prefix for "/dir/"-style entries,
        # and "entry + '/'" prefix for plain entries (so "/docs" also covers
        # "/docs/index" but not "/docsy").
        should_skip = False
        for exclude_path in self.exclude_paths:
            # Exact match
            if path == exclude_path:
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
                break
            # For paths ending with '/', check if request path starts with it
            elif exclude_path.endswith('/') and path.startswith(exclude_path):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
                break
            # For paths not ending with '/', check if request path starts with it + '/'
            elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
                should_skip = True
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
                break

        if should_skip:
            if self.canLog:
                logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
            response = await call_next(request)
            return response

        if self.canLog:
            logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")

        # Always clear any existing user context to ensure fresh authentication
        UserContext.clear_current_user(self.canLog)

        # Initialize context token
        # NOTE(review): user_token is captured below but never used for a
        # reset — cleanup goes through clear_current_user() instead.
        user_token = None

        # Try to extract and validate token
        try:
            # Get authorization header
            authorization = request.headers.get("Authorization")
            if not authorization or not authorization.startswith("Bearer "):
                # No token provided, return 401 error
                return HxfErrorResponse(
                    message="缺少或无效的授权头",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Extract token
            token = authorization.split(" ")[1]

            # Verify token
            payload = AuthService.verify_token(token)
            if payload is None:
                # Invalid token, return 401 error
                return HxfErrorResponse(
                    message="无效或过期的令牌",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get username from token ("sub" claim)
            username = payload.get("sub")
            if not username:
                return HxfErrorResponse(
                    message="令牌负载无效",
                    status_code=status.HTTP_401_UNAUTHORIZED
                )

            # Get user from database
            from sqlalchemy import select
            from ..models.user import User

            # Open a short-lived async session just to load the user record;
            # it is always closed in the finally below.
            session = AsyncSession(bind=engine_async)
            try:
                stmt = select(User).where(User.username == username)
                user = await session.execute(stmt)
                user = user.scalar_one_or_none()
                if not user:
                    return HxfErrorResponse(
                        message="用户不存在",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                if not user.is_active:
                    return HxfErrorResponse(
                        message="用户账户已停用",
                        status_code=status.HTTP_401_UNAUTHORIZED
                    )

                # Set user in context using token mechanism
                user_token = UserContext.set_current_user_with_token(user, self.canLog)
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")

                # Verify context is set correctly
                current_user_id = UserContext.get_current_user_id()
                if self.canLog:
                    logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文")
            finally:
                await session.close()

        except Exception as e:
            # Any failure during authentication is mapped to a 401.
            logger.error(f"[MIDDLEWARE] - 认证过程 [{request.method} {request.url.path}] 中设置用户上下文出错: {e}")
            # Return 401 error
            return HxfErrorResponse(
                message="认证过程中出错",
                status_code=status.HTTP_401_UNAUTHORIZED
            )

        # Continue with request
        try:
            response = await call_next(request)
            return response
        except Exception as e:
            # Log error but don't fail the request
            logger.error(f"[MIDDLEWARE] - 请求处理 [{request.method} {request.url.path}] 出错: {e}")
            # Return 500 error
            # NOTE(review): unlike the calls above, this passes the exception
            # object positionally instead of message=/status_code= — confirm
            # HxfErrorResponse handles that form.
            return HxfErrorResponse(e)
        finally:
            # Always clear user context after request processing
            UserContext.clear_current_user(self.canLog)
            if self.canLog:
                logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""LLM工厂类,用于创建和管理LLM实例"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
from loguru import logger
|
||||||
|
from requests import Session
|
||||||
|
from .config import get_settings
|
||||||
|
|
||||||
|
async def new_llm(session: Session = None, model: Optional[str] = None,
                  temperature: Optional[float] = None,
                  streaming: bool = False) -> ChatOpenAI:
    """Create a ChatOpenAI instance.

    Args:
        session: optional DB session, forwarded to settings.llm.get_current_config
        model: optional model name; its prefix selects the provider settings
            (deepseek/doubao/glm→zhipu/moonshot). Unspecified or unrecognized
            prefixes keep the currently configured model.
        temperature: optional sampling temperature override
        streaming: enable streaming responses (default False)

    Returns:
        A configured ChatOpenAI instance.
    """
    settings = get_settings()
    llm_config = await settings.llm.get_current_config(session)

    if model:
        # Map the model-name prefix onto the provider whose settings fields
        # should be used (note: "glm" models live under the zhipu_* fields).
        # This replaces the four near-identical if/elif branches.
        prefix_to_provider = {
            'deepseek': 'deepseek',
            'doubao': 'doubao',
            'glm': 'zhipu',
            'moonshot': 'moonshot',
        }
        for name_prefix, provider in prefix_to_provider.items():
            if model.startswith(name_prefix):
                llm_config['model'] = getattr(settings.llm, f'{provider}_model')
                llm_config['api_key'] = getattr(settings.llm, f'{provider}_api_key')
                llm_config['base_url'] = getattr(settings.llm, f'{provider}_base_url')
                break

    llm = ChatOpenAI(
        model=llm_config['model'],
        api_key=llm_config['api_key'],
        base_url=llm_config['base_url'],
        # Explicit temperature wins over the configured default.
        temperature=temperature if temperature is not None else llm_config['temperature'],
        max_tokens=llm_config['max_tokens'],
        streaming=streaming
    )

    return llm
|
||||||
|
|
||||||
|
async def new_agent(session: Session = None, model: Optional[str] = None,
                    temperature: Optional[float] = None,
                    streaming: bool = False):
    """Create a ReAct agent (currently with no tools) on top of a fresh LLM.

    Args:
        session: optional DB session, forwarded to new_llm
        model: optional model name; defaults to the configured model
        temperature: optional sampling temperature override
        streaming: enable streaming responses (default False)

    Returns:
        The compiled LangGraph ReAct agent wrapping the ChatOpenAI instance.
        (The previous docstring/annotation claimed ChatOpenAI — a copy-paste
        error from new_llm; create_react_agent returns a compiled graph.)
    """
    llm = await new_llm(session, model, temperature, streaming)
    # Empty tool list for now; extend it to give the agent tools.
    return create_react_agent(llm, [])
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""简化的权限检查系统."""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import HTTPException, Depends
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ..db.database import get_session
|
||||||
|
from ..models.user import User
|
||||||
|
from ..models.permission import Role
|
||||||
|
from ..services.auth import AuthService
|
||||||
|
|
||||||
|
|
||||||
|
async def is_super_admin(user: User, session: Session) -> bool:
    """Check whether *user* currently holds the active SUPER_ADMIN role.

    Args:
        user: user to check; may be None (treated as not-admin).
        session: async SQLAlchemy session. Its ``desc`` attribute is used as
            a lightweight debugging breadcrumb throughout.

    Returns:
        True only when an active SUPER_ADMIN role row exists for the user;
        False for None/inactive users and on any query failure.
    """
    # Bug fix: guard before dereferencing — the original wrote
    # session.desc = f"... {user.id} ..." *before* the `if not user` check,
    # so user=None raised AttributeError instead of returning False.
    if not user or not user.is_active:
        session.desc = f"用户 {getattr(user, 'id', None)} 不是活跃状态"
        return False

    session.desc = f"检查用户 {user.id} 是否为超级管理员"

    try:
        # Query through the provided session directly to avoid MissingGreenlet
        # errors from lazy loading outside the async context.
        from sqlalchemy import select
        from ..models.permission import UserRole, Role

        stmt = select(UserRole).join(Role).filter(
            UserRole.user_id == user.id,
            Role.code == 'SUPER_ADMIN',
            Role.is_active == True
        )
        user_role = await session.execute(stmt)
        result = user_role.scalar_one_or_none() is not None
        session.desc = f"用户 {user.id} 超级管理员角色查询结果: {result}"
        return result
    except Exception as e:
        # On any failure: record, log, and deny rather than raise.
        session.desc = f"EXCEPTION: 用户 {user.id} 超级管理员角色查询失败: {str(e)}"
        logger.error(f"检查用户 {user.id} 超级管理员角色失败: {str(e)}")
        return False
|
||||||
|
|
||||||
|
|
||||||
|
async def require_super_admin(
    current_user: User = Depends(AuthService.get_current_user),
    session: Session = Depends(get_session)
) -> User:
    """FastAPI dependency: reject callers without the super-admin role.

    Raises:
        HTTPException: 403 when the current user is not a super admin.
    """
    has_role = await is_super_admin(current_user, session)
    if not has_role:
        raise HTTPException(status_code=403, detail="需要超级管理员权限")
    return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def require_authenticated_user(
    current_user: User = Depends(AuthService.get_current_user)
) -> User:
    """FastAPI dependency: reject missing or inactive users.

    Raises:
        HTTPException: 401 when there is no active authenticated user.
    """
    if not current_user or not current_user.is_active:
        raise HTTPException(status_code=401, detail="需要登录")
    return current_user
|
||||||
|
|
||||||
|
|
||||||
|
class SimplePermissionChecker:
    """Thin permission helper bound to a single database session."""

    def __init__(self, db: Session):
        self.db = db

    async def check_super_admin(self, user: User) -> bool:
        """Return True when *user* holds the super-admin role."""
        return await is_super_admin(user, self.db)

    async def check_user_access(self, user: User, target_user_id: int) -> bool:
        """Return True when *user* may access *target_user_id*'s data.

        Access is granted to active super admins (any account) and to
        active users reading their own account.
        """
        if not user or not user.is_active:
            return False
        # Super admins can reach every account.
        if await self.check_super_admin(user):
            return True
        # Everyone else is limited to their own record.
        return user.id == target_user_id
|
||||||
|
|
||||||
|
|
||||||
|
# 权限装饰器
|
||||||
|
def super_admin_required(func):
    """Decorator marking a service-layer callable as super-admin only.

    The real permission check happens through FastAPI dependencies at the
    route layer; this wrapper simply passes the call through unchanged
    while preserving the wrapped function's metadata.
    """
    @wraps(func)
    def _passthrough(*args, **kwargs):
        return func(*args, **kwargs)
    return _passthrough
|
||||||
|
|
||||||
|
|
||||||
|
def authenticated_required(func):
    """Decorator marking a service-layer callable as login-required.

    The actual authentication check is enforced via FastAPI dependencies at
    the route layer; this wrapper only forwards the call, keeping the
    original function's metadata intact.
    """
    @wraps(func)
    def _forward(*args, **kwargs):
        return func(*args, **kwargs)
    return _forward
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""User utility functions for easy access to current user context."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from ..models.user import User
|
||||||
|
from .context import UserContext
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user() -> Optional[dict]:
    """Get the current authenticated user from context.

    Returns:
        Snapshot dict of the current user (id, username, email, full_name,
        is_active) if authenticated, None otherwise.

    The annotation was corrected from Optional[User]: UserContext stores and
    returns plain dict snapshots, never ORM User instances.
    """
    return UserContext.get_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_id() -> Optional[int]:
    """ID of the authenticated user in the current context.

    Returns:
        The user ID when authenticated, otherwise None.
    """
    return UserContext.get_current_user_id()
|
||||||
|
|
||||||
|
|
||||||
|
def require_current_user() -> dict:
    """Get the current authenticated user from context, raising if absent.

    Returns:
        User snapshot dict (annotation corrected from User: UserContext
        stores dicts, not ORM instances).

    Raises:
        HTTPException: If no authenticated user in context
    """
    return UserContext.require_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
def require_current_user_id() -> int:
    """ID of the authenticated user in the current context.

    Returns:
        Current user ID

    Raises:
        HTTPException: If no authenticated user in context
    """
    return UserContext.require_current_user_id()
|
||||||
|
|
||||||
|
|
||||||
|
def is_user_authenticated() -> bool:
    """Tell whether the current context carries an authenticated user.

    Returns:
        True if a user is bound to the context, False otherwise
    """
    current = UserContext.get_current_user()
    return current is not None
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_username() -> Optional[str]:
    """Get current authenticated user's username from context.

    Returns:
        Current user's username if authenticated, None otherwise
    """
    user = UserContext.get_current_user()
    # Bug fix: UserContext returns a dict snapshot, not a User instance, so
    # attribute access (user.username) raised AttributeError here.
    return user.get('username') if user else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_email() -> Optional[str]:
    """Get current authenticated user's email from context.

    Returns:
        Current user's email if authenticated, None otherwise
    """
    user = UserContext.get_current_user()
    # Bug fix: UserContext returns a dict snapshot, not a User instance, so
    # attribute access (user.email) raised AttributeError here.
    return user.get('email') if user else None
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
"""Database module for TH Agenter."""
|
||||||
|
|
||||||
|
from .database import get_session
|
||||||
|
from .base import Base
|
||||||
|
from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_session", "Base"]
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
"""Database base model."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Integer, DateTime, event
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import MetaData
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
    """Declarative base whose metadata names every constraint explicitly.

    The naming convention gives each index/constraint a predictable,
    table-derived name instead of a backend-generated one, keeping schema
    names deterministic across databases.
    """
    metadata = MetaData(
        naming_convention={
            # ix: index
            "ix": "ix_%(column_0_label)s",
            # uq: unique constraint
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            # ck: check constraint
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            # fk: foreign key constraint
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            # pk: primary key constraint
            "pk": "pk_%(table_name)s"
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(Base):
    """Base model with common fields (PK, timestamps, audit columns)."""

    __abstract__ = True

    # Surrogate primary key shared by all tables.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # Creation/update timestamps, defaulted/refreshed via func.now().
    created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
    # Audit columns: user IDs recorded by set_audit_fields().
    created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    def to_dict(self):
        """Convert model to dictionary.

        Column values exposing .isoformat (datetimes) are serialized to
        ISO-8601 strings; everything else passes through unchanged.
        """
        return {
            column.name: getattr(self, column.name).isoformat() if hasattr(getattr(self, column.name), 'isoformat') else getattr(self, column.name)
            for column in self.__table__.columns
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Create model instance from dictionary.

        Args:
            data: Dictionary containing model field values; keys that do not
                match a mapped column are silently dropped.

        Returns:
            Model instance created from the dictionary
        """
        # Filter out fields that don't exist in the model
        model_fields = {column.name for column in cls.__table__.columns}
        filtered_data = {key: value for key, value in data.items() if key in model_fields}

        # Create and return the instance
        return cls(**filtered_data)

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
        """Populate the created_by/updated_by audit columns.

        Args:
            user_id: user ID to record; when omitted it is taken from the
                request's UserContext, and the call becomes a no-op if no
                user is available there either.
            is_update: True for an update (only updated_by is set),
                False for a create (both created_by and updated_by are set).
        """
        # Fall back to the per-request user context when no ID was supplied.
        if user_id is None:
            from ..core.context import UserContext
            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # No user in context: leave the audit columns untouched.
                return

        # Still unknown: nothing to record.
        if user_id is None:
            return

        if not is_update:
            # Create: stamp both audit columns with the same user.
            self.created_by = user_id
            self.updated_by = user_id
        else:
            # Update: only the updater changes.
            self.updated_by = user_id
|
||||||
|
|
||||||
|
# @event.listens_for(Session, 'before_flush')
|
||||||
|
# def set_audit_fields_before_flush(session, flush_context, instances):
|
||||||
|
# """Automatically set audit fields before flush."""
|
||||||
|
# try:
|
||||||
|
# from th_agenter.core.context import UserContext
|
||||||
|
# user_id = UserContext.get_current_user_id()
|
||||||
|
# except Exception:
|
||||||
|
# user_id = None
|
||||||
|
|
||||||
|
# # 处理新增对象
|
||||||
|
# for instance in session.new:
|
||||||
|
# if isinstance(instance, BaseModel) and user_id:
|
||||||
|
# instance.created_by = user_id
|
||||||
|
# instance.updated_by = user_id
|
||||||
|
|
||||||
|
# # 处理修改对象
|
||||||
|
# for instance in session.dirty:
|
||||||
|
# if isinstance(instance, BaseModel) and user_id:
|
||||||
|
# instance.updated_by = user_id
|
||||||
|
|
||||||
|
# # def __init__(self, **kwargs):
|
||||||
|
# # """Initialize model with automatic audit fields setting."""
|
||||||
|
# # super().__init__(**kwargs)
|
||||||
|
# # # Set audit fields for new instances
|
||||||
|
# # self.set_audit_fields()
|
||||||
|
|
||||||
|
# # def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||||||
|
# # """Set audit fields for create/update operations.
|
||||||
|
|
||||||
|
# # Args:
|
||||||
|
# # user_id: ID of the user performing the operation (optional, will use context if not provided)
|
||||||
|
# # is_update: True for update operations, False for create operations
|
||||||
|
# # """
|
||||||
|
# # # Get user_id from context if not provided
|
||||||
|
# # if user_id is None:
|
||||||
|
# # from ..core.context import UserContext
|
||||||
|
# # try:
|
||||||
|
# # user_id = UserContext.get_current_user_id()
|
||||||
|
# # except Exception:
|
||||||
|
# # # If no user in context, skip setting audit fields
|
||||||
|
# # return
|
||||||
|
|
||||||
|
# # # Skip if still no user_id
|
||||||
|
# # if user_id is None:
|
||||||
|
# # return
|
||||||
|
|
||||||
|
# # if not is_update:
|
||||||
|
# # # For create operations, set both create_by and update_by
|
||||||
|
# # self.created_by = user_id
|
||||||
|
# # self.updated_by = user_id
|
||||||
|
# # else:
|
||||||
|
# # # For update operations, only set update_by
|
||||||
|
# # self.updated_by = user_id
|
||||||
|
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""Database connection and session management."""
|
||||||
|
|
||||||
|
import uuid, re
|
||||||
|
from loguru import logger
|
||||||
|
import traceback
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from utils.general import gradient_text
|
||||||
|
|
||||||
|
from ..core.config import get_settings
|
||||||
|
from .base import Base
|
||||||
|
from utils.util_exceptions import DatabaseError
|
||||||
|
|
||||||
|
# Custom Session class with desc property and unique ID
|
||||||
|
class DrSession(AsyncSession):
    """AsyncSession subclass that carries a short unique id, a title and step-logging helpers."""

    def __init__(self, **kwargs):
        """Initialize the session and stamp it with a short unique id."""
        super().__init__(**kwargs)
        # NOTE(review): this assignment goes through the `title` setter, which
        # reads self.info *before* the hasattr guard below — it works because
        # SQLAlchemy's Session exposes an `info` dict, presumably always; confirm.
        self.title = ""
        self.descs = []
        # Ensure the info dict exists before writing to it.
        if not hasattr(self, 'info'):
            self.info = {}
        # Short session id: first group of a UUID4, e.g. "1a2b3c4d".
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        self.stepIndex = 0

    @property
    def title(self) -> Optional[str]:
        """Current session title stored in session info."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set the title; subsequent titles are prepended with a ' >>> ' separator."""
        if 'title' not in self.info or self.info['title'].strip() == "":
            self.info['title'] = value  # first (or blank) title: plain assignment
        else:
            # NOTE(review): assumes the stored title is a str — .strip() above
            # would raise on None; confirm callers never store None.
            self.info['title'] = value + " >>> " + self.info['title']

    @property
    def desc(self) -> Optional[str]:
        """Last description stored in session info (never written by the setter below)."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Advance the step counter and log the value (does NOT store it in info)."""
        self.stepIndex += 1
        logger.info(value)

    def log_prefix(self) -> str:
        """Log prefix carrying the short session id."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int):
        """Extract a 'file:line in func' caller position from the current stack frame at `level`."""
        pos = (traceback.format_stack())[level].strip().split('\n')[0]
        match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos)
        if match:
            # NOTE(review): hardcoded Windows project root stripped from paths —
            # consider deriving it from the package location instead.
            file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    def log_info(self, msg: str, level: int = -2):
        """Log an info message prefixed with the session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log a success message prefixed with the session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log a warning message prefixed with the session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log an error message prefixed with the session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log an exception (with traceback) prefixed with the session id and caller position."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||||
|
|
||||||
|
# Application-wide async engine; connection settings come from the app config.
engine_async = create_async_engine(
    get_settings().database.url,
    echo=False, # get_settings().database.echo,
    future=True,
    pool_size=get_settings().database.pool_size,
    max_overflow=get_settings().database.max_overflow,
    pool_pre_ping=True,  # validate connections before handing them out
    pool_recycle=3600,  # recycle connections after one hour to avoid stale sockets
)
# NOTE(review): mid-file import — should be moved to the top of the module with
# the other imports; left in place to keep this change comment-only.
from fastapi import HTTPException, Request

# Session factory producing DrSession instances bound to the async engine.
# expire_on_commit=False keeps loaded attributes usable after commit.
AsyncSessionFactory = sessionmaker(
    bind=engine_async,
    class_=DrSession,
    expire_on_commit=False,
    autoflush=True
)
|
||||||
|
|
||||||
|
async def get_session(request: Optional[Request] = None):
    """FastAPI dependency that yields a DrSession bound to the async engine.

    Titles the session with the request method/path and client IP, yields it,
    rolls back and re-raises on any exception, and always closes the session.

    Args:
        request: Incoming FastAPI request (optional; None when used outside
            a request context).

    Yields:
        DrSession: a fresh session for the duration of the request.
    """
    if request:
        url = f"{request.method} {request.url.path}"
        client_host = request.client.host
    else:
        url = "无request"
        client_host = "无request"

    # TODO(review): leftover debug output — consider logger.info instead of print.
    print(url)

    # session = AsyncSessionFactory()
    session = DrSession(bind=engine_async)
    session.title = f"{url} - {client_host}"

    # Attach the originating request for downstream consumers.
    if request:
        session.request = request

    try:
        yield session
    except Exception as e:
        errMsg = f"数据库 session 异常 >>> {e}"
        session.desc = f"EXCEPTION: {errMsg}"
        await session.rollback()
        # Re-raise the original exception unchanged (bare raise preserves the
        # traceback); it is caught and converted at the application boundary.
        raise
    finally:
        session.desc = ""
        await session.close()
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""Add system management tables.
|
||||||
|
|
||||||
|
Revision ID: add_system_management
|
||||||
|
Revises:
|
||||||
|
Create Date: 2024-01-01 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic_sync import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import mysql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'add_system_management'
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Create system management tables.

    Creates departments, permissions, roles, role_permissions, user_roles,
    user_permissions and llm_configs, then extends the existing users table
    with department / admin-flag / login-tracking columns.
    """

    # Create departments table (self-referencing tree via parent_id).
    op.create_table('departments',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('parent_id', sa.Integer(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['parent_id'], ['departments.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
    op.create_index(op.f('ix_departments_parent_id'), 'departments', ['parent_id'], unique=False)

    # Create permissions table.
    op.create_table('permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=100), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_permissions_category'), 'permissions', ['category'], unique=False)
    op.create_index(op.f('ix_permissions_name'), 'permissions', ['name'], unique=False)

    # Create roles table.
    op.create_table('roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False)

    # Create role_permissions association table (role <-> permission, unique pair).
    op.create_table('role_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('role_id', 'permission_id', name='uq_role_permission')
    )
    op.create_index(op.f('ix_role_permissions_permission_id'), 'role_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_role_permissions_role_id'), 'role_permissions', ['role_id'], unique=False)

    # Create user_roles association table (user <-> role, unique pair).
    op.create_table('user_roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'role_id', name='uq_user_role')
    )
    op.create_index(op.f('ix_user_roles_role_id'), 'user_roles', ['role_id'], unique=False)
    op.create_index(op.f('ix_user_roles_user_id'), 'user_roles', ['user_id'], unique=False)

    # Create user_permissions table (direct per-user grant/deny overrides).
    op.create_table('user_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('granted', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'permission_id', name='uq_user_permission')
    )
    op.create_index(op.f('ix_user_permissions_permission_id'), 'user_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_user_permissions_user_id'), 'user_permissions', ['user_id'], unique=False)

    # Create llm_configs table (per-model LLM connection/sampling settings).
    op.create_table('llm_configs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('provider', sa.String(length=50), nullable=False),
        sa.Column('model_name', sa.String(length=100), nullable=False),
        sa.Column('api_key', sa.Text(), nullable=True),
        sa.Column('api_base', sa.String(length=500), nullable=True),
        sa.Column('api_version', sa.String(length=20), nullable=True),
        sa.Column('max_tokens', sa.Integer(), nullable=True),
        sa.Column('temperature', sa.Float(), nullable=True),
        sa.Column('top_p', sa.Float(), nullable=True),
        sa.Column('frequency_penalty', sa.Float(), nullable=True),
        sa.Column('presence_penalty', sa.Float(), nullable=True),
        sa.Column('timeout', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('is_default', sa.Boolean(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)

    # Add new columns to users table.
    # NOTE(review): `default=` is a client-side (Python) default only — existing
    # rows will have NULL, not False/0. If a DB-level default is intended, use
    # server_default=sa.false() / sa.text('0') instead; confirm intent.
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
    op.add_column('users', sa.Column('is_superuser', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('is_admin', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('last_login_at', sa.DateTime(), nullable=True))
    op.add_column('users', sa.Column('login_count', sa.Integer(), nullable=True, default=0))

    # Create foreign key constraint and index for users.department_id.
    op.create_foreign_key('fk_users_department_id', 'users', 'departments', ['department_id'], ['id'])
    op.create_index(op.f('ix_users_department_id'), 'users', ['department_id'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Drop system management tables.

    Reverses upgrade() in dependency order: users columns first, then the
    association tables, then roles/permissions/departments.
    """

    # Drop foreign key and index for users.department_id.
    op.drop_index(op.f('ix_users_department_id'), table_name='users')
    op.drop_constraint('fk_users_department_id', 'users', type_='foreignkey')

    # Drop the columns added to the users table.
    op.drop_column('users', 'login_count')
    op.drop_column('users', 'last_login_at')
    op.drop_column('users', 'is_admin')
    op.drop_column('users', 'is_superuser')
    op.drop_column('users', 'department_id')

    # Drop llm_configs table.
    op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_table('llm_configs')

    # Drop user_permissions table.
    op.drop_index(op.f('ix_user_permissions_user_id'), table_name='user_permissions')
    op.drop_index(op.f('ix_user_permissions_permission_id'), table_name='user_permissions')
    op.drop_table('user_permissions')

    # Drop user_roles table.
    op.drop_index(op.f('ix_user_roles_user_id'), table_name='user_roles')
    op.drop_index(op.f('ix_user_roles_role_id'), table_name='user_roles')
    op.drop_table('user_roles')

    # Drop role_permissions table.
    op.drop_index(op.f('ix_role_permissions_role_id'), table_name='role_permissions')
    op.drop_index(op.f('ix_role_permissions_permission_id'), table_name='role_permissions')
    op.drop_table('role_permissions')

    # Drop roles table.
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_table('roles')

    # Drop permissions table.
    op.drop_index(op.f('ix_permissions_name'), table_name='permissions')
    op.drop_index(op.f('ix_permissions_category'), table_name='permissions')
    op.drop_table('permissions')

    # Drop departments table (last: other tables referenced it).
    op.drop_index(op.f('ix_departments_parent_id'), table_name='departments')
    op.drop_index(op.f('ix_departments_name'), table_name='departments')
    op.drop_table('departments')
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Add user_department association table migration."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
project_root = Path(__file__).parent.parent.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import asyncpg
|
||||||
|
from th_agenter.core.config import get_settings
|
||||||
|
|
||||||
|
async def create_user_department_table():
|
||||||
|
"""Create user_departments association table."""
|
||||||
|
settings = get_settings()
|
||||||
|
database_url = settings.database.url
|
||||||
|
|
||||||
|
print(f"Database URL: {database_url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析PostgreSQL连接URL
|
||||||
|
# postgresql://user:password@host:port/database
|
||||||
|
url_parts = database_url.replace('postgresql://', '').split('/')
|
||||||
|
db_name = url_parts[1] if len(url_parts) > 1 else 'postgres'
|
||||||
|
user_host = url_parts[0].split('@')
|
||||||
|
user_pass = user_host[0].split(':')
|
||||||
|
host_port = user_host[1].split(':')
|
||||||
|
|
||||||
|
user = user_pass[0]
|
||||||
|
password = user_pass[1] if len(user_pass) > 1 else ''
|
||||||
|
host = host_port[0]
|
||||||
|
port = int(host_port[1]) if len(host_port) > 1 else 5432
|
||||||
|
|
||||||
|
# 连接PostgreSQL数据库
|
||||||
|
conn = await asyncpg.connect(
|
||||||
|
user=user,
|
||||||
|
password=password,
|
||||||
|
database=db_name,
|
||||||
|
host=host,
|
||||||
|
port=port
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建user_departments表
|
||||||
|
create_table_sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS user_departments (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
department_id INTEGER NOT NULL,
|
||||||
|
is_primary BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
is_active BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (department_id) REFERENCES departments (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
await conn.execute(create_table_sql)
|
||||||
|
|
||||||
|
# 创建索引
|
||||||
|
create_indexes_sql = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_user_departments_user_id ON user_departments (user_id);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_user_departments_department_id ON user_departments (department_id);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_user_departments_unique ON user_departments (user_id, department_id);"
|
||||||
|
]
|
||||||
|
|
||||||
|
for index_sql in create_indexes_sql:
|
||||||
|
await conn.execute(index_sql)
|
||||||
|
|
||||||
|
print("User departments table created successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error creating user departments table: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if 'conn' in locals():
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: create the association table when run directly.
if __name__ == "__main__":
    asyncio.run(create_user_department_table())
|
||||||
|
|
@ -0,0 +1,457 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Migration script to move hardcoded resources to database."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add the backend directory to Python path
|
||||||
|
backend_dir = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.insert(0, str(backend_dir))
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from th_agenter.core.config import get_settings
|
||||||
|
from th_agenter.db.database import Base
|
||||||
|
from th_agenter.models import * # Import all models to ensure they're registered
|
||||||
|
from th_agenter.utils.logger import get_logger
|
||||||
|
from th_agenter.models.resource import Resource
|
||||||
|
from th_agenter.models.permission import Role
|
||||||
|
from th_agenter.models.resource import RoleResource
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
def migrate_hardcoded_resources():
|
||||||
|
"""Migrate hardcoded resources from init_resource_data.py to database."""
|
||||||
|
db = None
|
||||||
|
try:
|
||||||
|
# Get database settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Create synchronous engine (remove asyncpg from URL)
|
||||||
|
sync_db_url = settings.database.url.replace('postgresql+asyncpg://', 'postgresql://')
|
||||||
|
sync_engine = create_engine(
|
||||||
|
sync_db_url,
|
||||||
|
echo=False,
|
||||||
|
pool_size=settings.database.pool_size,
|
||||||
|
max_overflow=settings.database.max_overflow,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create synchronous session factory
|
||||||
|
SyncSessionFactory = sessionmaker(bind=sync_engine)
|
||||||
|
|
||||||
|
# Get database session
|
||||||
|
db = SyncSessionFactory()
|
||||||
|
|
||||||
|
if db is None:
|
||||||
|
logger.error("Failed to create database session")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create all tables if they don't exist
|
||||||
|
from th_agenter.db.database import engine as global_engine
|
||||||
|
if global_engine:
|
||||||
|
Base.metadata.create_all(bind=global_engine)
|
||||||
|
|
||||||
|
logger.info("Starting hardcoded resources migration...")
|
||||||
|
|
||||||
|
# Check if resources already exist
|
||||||
|
existing_count = db.query(Resource).count()
|
||||||
|
if existing_count > 0:
|
||||||
|
logger.info(f"Found {existing_count} existing resources. Checking role assignments.")
|
||||||
|
# 即使资源已存在,也要检查并分配角色资源关联
|
||||||
|
admin_role = db.query(Role).filter(Role.name == "系统管理员").first()
|
||||||
|
if admin_role:
|
||||||
|
# 获取所有资源
|
||||||
|
all_resources = db.query(Resource).all()
|
||||||
|
assigned_count = 0
|
||||||
|
|
||||||
|
for resource in all_resources:
|
||||||
|
# 检查关联是否已存在
|
||||||
|
existing = db.query(RoleResource).filter(
|
||||||
|
RoleResource.role_id == admin_role.id,
|
||||||
|
RoleResource.resource_id == resource.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not existing:
|
||||||
|
role_resource = RoleResource(
|
||||||
|
role_id=admin_role.id,
|
||||||
|
resource_id=resource.id
|
||||||
|
)
|
||||||
|
db.add(role_resource)
|
||||||
|
assigned_count += 1
|
||||||
|
|
||||||
|
if assigned_count > 0:
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"已为系统管理员角色分配 {assigned_count} 个新资源")
|
||||||
|
else:
|
||||||
|
logger.info("系统管理员角色已拥有所有资源")
|
||||||
|
else:
|
||||||
|
logger.warning("未找到系统管理员角色")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Define hardcoded resource data
|
||||||
|
main_menu_data = [
|
||||||
|
{
|
||||||
|
"name": "智能问答",
|
||||||
|
"code": "CHAT",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/chat",
|
||||||
|
"component": "views/Chat.vue",
|
||||||
|
"icon": "ChatDotRound",
|
||||||
|
"description": "智能问答功能",
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "智能问数",
|
||||||
|
"code": "SMART_QUERY",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/smart-query",
|
||||||
|
"component": "views/SmartQuery.vue",
|
||||||
|
"icon": "DataAnalysis",
|
||||||
|
"description": "智能问数功能",
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "知识库",
|
||||||
|
"code": "KNOWLEDGE",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/knowledge",
|
||||||
|
"component": "views/KnowledgeBase.vue",
|
||||||
|
"icon": "Collection",
|
||||||
|
"description": "知识库管理",
|
||||||
|
"sort_order": 3,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "工作流编排",
|
||||||
|
"code": "WORKFLOW",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/workflow",
|
||||||
|
"component": "views/Workflow.vue",
|
||||||
|
"icon": "Connection",
|
||||||
|
"description": "工作流编排功能",
|
||||||
|
"sort_order": 4,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "智能体管理",
|
||||||
|
"code": "AGENT",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/agent",
|
||||||
|
"component": "views/Agent.vue",
|
||||||
|
"icon": "User",
|
||||||
|
"description": "智能体管理功能",
|
||||||
|
"sort_order": 5,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "系统管理",
|
||||||
|
"code": "SYSTEM",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system",
|
||||||
|
"component": "views/SystemManagement.vue",
|
||||||
|
"icon": "Setting",
|
||||||
|
"description": "系统管理功能",
|
||||||
|
"sort_order": 6,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create main menu resources
|
||||||
|
created_resources = {}
|
||||||
|
for menu_data in main_menu_data:
|
||||||
|
resource = Resource(**menu_data)
|
||||||
|
db.add(resource)
|
||||||
|
db.flush()
|
||||||
|
created_resources[menu_data["code"]] = resource
|
||||||
|
logger.info(f"Created main menu resource: {menu_data['name']}")
|
||||||
|
|
||||||
|
# System management submenu data
|
||||||
|
system_submenu_data = [
|
||||||
|
{
|
||||||
|
"name": "用户管理",
|
||||||
|
"code": "SYSTEM_USERS",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/users",
|
||||||
|
"component": "components/system/UserManagement.vue",
|
||||||
|
"icon": "User",
|
||||||
|
"description": "用户管理功能",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "部门管理",
|
||||||
|
"code": "SYSTEM_DEPARTMENTS",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/departments",
|
||||||
|
"component": "components/system/DepartmentManagement.vue",
|
||||||
|
"icon": "OfficeBuilding",
|
||||||
|
"description": "部门管理功能",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "角色管理",
|
||||||
|
"code": "SYSTEM_ROLES",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/roles",
|
||||||
|
"component": "components/system/RoleManagement.vue",
|
||||||
|
"icon": "Avatar",
|
||||||
|
"description": "角色管理功能",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 3,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "权限管理",
|
||||||
|
"code": "SYSTEM_PERMISSIONS",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/permissions",
|
||||||
|
"component": "components/system/PermissionManagement.vue",
|
||||||
|
"icon": "Lock",
|
||||||
|
"description": "权限管理功能",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 4,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "资源管理",
|
||||||
|
"code": "SYSTEM_RESOURCES",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/resources",
|
||||||
|
"component": "components/system/ResourceManagement.vue",
|
||||||
|
"icon": "Grid",
|
||||||
|
"description": "资源管理功能",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 5,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "大模型管理",
|
||||||
|
"code": "SYSTEM_LLM_CONFIGS",
|
||||||
|
"type": "menu",
|
||||||
|
"path": "/system/llm-configs",
|
||||||
|
"component": "components/system/LLMConfigManagement.vue",
|
||||||
|
"icon": "Cpu",
|
||||||
|
"description": "大模型配置管理",
|
||||||
|
"parent_id": created_resources["SYSTEM"].id,
|
||||||
|
"sort_order": 6,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create system management submenu
|
||||||
|
for submenu_data in system_submenu_data:
|
||||||
|
submenu = Resource(**submenu_data)
|
||||||
|
db.add(submenu)
|
||||||
|
db.flush()
|
||||||
|
created_resources[submenu_data["code"]] = submenu
|
||||||
|
logger.info(f"Created system submenu resource: {submenu_data['name']}")
|
||||||
|
|
||||||
|
# Button resources data
|
||||||
|
button_resources_data = [
|
||||||
|
# User management buttons
|
||||||
|
{
|
||||||
|
"name": "新增用户",
|
||||||
|
"code": "USER_CREATE_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "新增用户按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_USERS"].id,
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "编辑用户",
|
||||||
|
"code": "USER_EDIT_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "编辑用户按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_USERS"].id,
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
# Role management buttons
|
||||||
|
{
|
||||||
|
"name": "新增角色",
|
||||||
|
"code": "ROLE_CREATE_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "新增角色按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_ROLES"].id,
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "编辑角色",
|
||||||
|
"code": "ROLE_EDIT_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "编辑角色按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_ROLES"].id,
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
# Permission management buttons
|
||||||
|
{
|
||||||
|
"name": "新增权限",
|
||||||
|
"code": "PERMISSION_CREATE_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "新增权限按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "编辑权限",
|
||||||
|
"code": "PERMISSION_EDIT_BTN",
|
||||||
|
"type": "button",
|
||||||
|
"description": "编辑权限按钮",
|
||||||
|
"parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create button resources
|
||||||
|
for button_data in button_resources_data:
|
||||||
|
button = Resource(**button_data)
|
||||||
|
db.add(button)
|
||||||
|
db.flush()
|
||||||
|
created_resources[button_data["code"]] = button
|
||||||
|
logger.info(f"Created button resource: {button_data['name']}")
|
||||||
|
|
||||||
|
# API resources data
|
||||||
|
api_resources_data = [
|
||||||
|
# User management APIs
|
||||||
|
{
|
||||||
|
"name": "用户列表API",
|
||||||
|
"code": "USER_LIST_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/users",
|
||||||
|
"description": "获取用户列表API",
|
||||||
|
"sort_order": 1,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "创建用户API",
|
||||||
|
"code": "USER_CREATE_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/users",
|
||||||
|
"description": "创建用户API",
|
||||||
|
"sort_order": 2,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
# Role management APIs
|
||||||
|
{
|
||||||
|
"name": "角色列表API",
|
||||||
|
"code": "ROLE_LIST_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/admin/roles",
|
||||||
|
"description": "获取角色列表API",
|
||||||
|
"sort_order": 5,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "创建角色API",
|
||||||
|
"code": "ROLE_CREATE_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/admin/roles",
|
||||||
|
"description": "创建角色API",
|
||||||
|
"sort_order": 6,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
# Resource management APIs
|
||||||
|
{
|
||||||
|
"name": "资源列表API",
|
||||||
|
"code": "RESOURCE_LIST_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/admin/resources",
|
||||||
|
"description": "获取资源列表API",
|
||||||
|
"sort_order": 10,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "创建资源API",
|
||||||
|
"code": "RESOURCE_CREATE_API",
|
||||||
|
"type": "api",
|
||||||
|
"path": "/api/admin/resources",
|
||||||
|
"description": "创建资源API",
|
||||||
|
"sort_order": 11,
|
||||||
|
"requires_auth": True,
|
||||||
|
"requires_admin": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create API resources
|
||||||
|
for api_data in api_resources_data:
|
||||||
|
api_resource = Resource(**api_data)
|
||||||
|
db.add(api_resource)
|
||||||
|
db.flush()
|
||||||
|
created_resources[api_data["code"]] = api_resource
|
||||||
|
logger.info(f"Created API resource: {api_data['name']}")
|
||||||
|
|
||||||
|
# 分配资源给系统管理员角色
|
||||||
|
admin_role = db.query(Role).filter(Role.name == "系统管理员").first()
|
||||||
|
if admin_role:
|
||||||
|
all_resources = list(created_resources.values())
|
||||||
|
for resource in all_resources:
|
||||||
|
# 检查关联是否已存在
|
||||||
|
existing = db.query(RoleResource).filter(
|
||||||
|
RoleResource.role_id == admin_role.id,
|
||||||
|
RoleResource.resource_id == resource.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not existing:
|
||||||
|
role_resource = RoleResource(
|
||||||
|
role_id=admin_role.id,
|
||||||
|
resource_id=resource.id
|
||||||
|
)
|
||||||
|
db.add(role_resource)
|
||||||
|
|
||||||
|
logger.info(f"已为系统管理员角色分配 {len(all_resources)} 个资源")
|
||||||
|
else:
|
||||||
|
logger.warning("未找到系统管理员角色")
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
total_resources = db.query(Resource).count()
|
||||||
|
logger.info(f"Migration completed successfully. Total resources: {total_resources}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Migration failed: {str(e)}")
|
||||||
|
if db:
|
||||||
|
db.rollback()
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
if db:
|
||||||
|
db.close()
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,146 @@
|
||||||
|
"""删除权限相关表的迁移脚本
|
||||||
|
|
||||||
|
Revision ID: remove_permission_tables
|
||||||
|
Revises: add_system_management
|
||||||
|
Create Date: 2024-01-25 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic_sync import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'remove_permission_tables'
|
||||||
|
down_revision = 'add_system_management'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Drop the permission-related tables, columns and legacy metadata.

    Fixes vs. the previous revision of this script:
    - ``upgrade`` is a plain (sync) function: Alembic invokes it directly
      and ``op.get_bind()`` returns a synchronous connection, so the old
      ``async def`` was never awaited by the framework.
    - ``result.scalar()`` returns a plain value on a sync result; awaiting
      it raised ``TypeError`` at runtime.
    """
    connection = op.get_bind()

    # Drop tables in dependency order (children before parents) so foreign-key
    # constraints do not block the drops. The names are a fixed internal list,
    # so interpolating them into the SQL below is safe.
    tables_to_drop = [
        'user_permissions',      # user <-> permission links
        'role_permissions',      # role <-> permission links
        'permission_resources',  # permission <-> resource links
        'permissions',           # permissions
        'role_resources',        # role <-> resource links
        'resources',             # resources
        'user_departments',      # user <-> department links
        'departments'            # departments
    ]

    for table_name in tables_to_drop:
        try:
            # Only drop tables that actually exist.
            result = connection.execute(text(f"""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = '{table_name}'
                );
            """))
            table_exists = result.scalar()  # sync result: no await

            if table_exists:
                print(f"删除表: {table_name}")
                op.drop_table(table_name)
            else:
                print(f"表 {table_name} 不存在,跳过")

        except Exception as e:
            print(f"删除表 {table_name} 时出错: {e}")
            # Best effort: keep dropping the remaining tables.
            continue

    # Drop the department_id column from the users table, if present.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'users' AND column_name = 'department_id';
        """))
        column_exists = result.fetchone()

        if column_exists:
            print("删除用户表中的 department_id 字段")
            op.drop_column('users', 'department_id')
        else:
            print("用户表中的 department_id 字段不存在,跳过")

    except Exception as e:
        print(f"删除 department_id 字段时出错: {e}")

    # Simplify user_roles to a bare (user_id, role_id) link table when it
    # still carries surrogate-key / audit columns.
    try:
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'user_roles' AND column_name IN ('id', 'created_at', 'updated_at', 'created_by', 'updated_by');
        """))
        extra_columns = [row[0] for row in result.fetchall()]

        if extra_columns:
            print("简化 user_roles 表结构")
            # Create the simplified replacement table.
            op.execute(text("""
                CREATE TABLE user_roles_new (
                    user_id INTEGER NOT NULL,
                    role_id INTEGER NOT NULL,
                    PRIMARY KEY (user_id, role_id),
                    FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
                    FOREIGN KEY (role_id) REFERENCES roles(id) ON DELETE CASCADE
                );
            """))

            # Copy the distinct pairs over.
            op.execute(text("""
                INSERT INTO user_roles_new (user_id, role_id)
                SELECT DISTINCT user_id, role_id FROM user_roles;
            """))

            # Swap the tables: drop the old one, rename the replacement.
            op.drop_table('user_roles')
            op.execute(text("ALTER TABLE user_roles_new RENAME TO user_roles;"))

    except Exception as e:
        print(f"简化 user_roles 表时出错: {e}")
|
||||||
|
def downgrade():
    """Rollback: re-create the permission-related tables (simplified).

    WARNING: the upgrade was destructive. This only recreates a reduced
    schema (permissions + role_permissions + users.department_id); the
    dropped data is NOT restored. Use with care in production.
    """

    # NOTE: destructive rollback — no data is recovered.
    print("警告:回滚操作会重新创建权限相关表,但不会恢复数据")

    # Re-create a simplified version of the permissions table.
    op.create_table('permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('code', sa.String(100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )

    # Role <-> permission link table.
    op.create_table('role_permissions',
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('role_id', 'permission_id')
    )

    # Restore the users.department_id column (values are not restored).
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
|
|
@ -0,0 +1,198 @@
|
||||||
|
from loguru import logger
|
||||||
|
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
|
||||||
|
|
||||||
|
# 核心:导入 LangChain 的基础语言模型抽象类
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.outputs import ChatResult
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
@dataclass
class LLMConfig_DataClass:
    """
    Unified LLM configuration covering online / local / embedding models,
    mirroring the full set of database columns.

    The model kind is derived from ``provider`` + ``is_embedding``:
    - online:    provider in ['openai', 'zhipu', 'baidu'] and is_embedding=False
    - local:     provider in ['llama', 'qwen', 'yi'] and is_embedding=False
    - embedding: provider in ['bge', 'text2vec'] and is_embedding=True
    """
    # ====================== Core database fields ======================
    # Identity
    name: str                               # display name (e.g. "gpt-5")
    model_name: str                         # official model id (e.g. "BAAI/bge-small-zh-v1.5")
    provider: str                           # provider key (openai/llama/bge/zhipu, ...)
    id: Optional[int] = None                # database primary key
    description: Optional[str] = None       # free-form description
    is_active: bool = True                  # enabled flag
    is_default: bool = False                # default-model flag
    is_embedding: bool = False              # embedding-model marker (drives type detection)

    # ====================== Generation parameters (inference models) ======================
    temperature: float = 0.7                # sampling temperature (DB default)
    max_tokens: int = 3000                  # max generated tokens (DB default)
    top_p: float = 0.6                      # nucleus sampling
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # ====================== Online-model parameters ======================
    api_key: Optional[str] = None           # required for online providers
    base_url: Optional[str] = None          # API proxy / endpoint override
    # timeout: int = 30                     # request timeout (seconds) — currently unused
    max_retries: int = 3                    # max retry attempts
    api_version: Optional[str] = None       # e.g. OpenAI "2024-02-15-preview"

    # ====================== Local-model parameters ======================
    model_path: Optional[str] = None        # required for local providers
    device: str = "cpu"                     # cpu / cuda / mps
    n_ctx: int = 2048                       # context window size
    n_threads: int = 8                      # inference threads
    quantization: str = "q4_0"              # q4_0 / q8_0 / f16
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    prompt_template: Optional[str] = None   # custom prompt template

    # ====================== Embedding-model parameters ======================
    normalize_embeddings: bool = True       # normalize output vectors
    batch_size: int = 32                    # batch size for encoding
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)  # extra encode options
    dimension: Optional[int] = None         # vector dimension (e.g. 768)

    # ====================== Metadata (maintained by the database) ======================
    extra_config: Dict[str, Any] = field(default_factory=dict)   # extension config
    usage_count: int = 0                    # usage counter
    last_used_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None        # creator user id
    updated_by: Optional[int] = None        # last editor user id

    api_key_masked: Optional[str] = ""      # masked API key as stored in the DB

    # ====================== Helpers ======================
    def __post_init__(self):
        """Post-init hook: normalize and validate the configuration."""
        # Embedding models never generate text: zero out inference params
        # so they cannot be used by mistake.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0

        # Enforce the per-kind required fields.
        self._validate_required_fields()

    def _validate_required_fields(self):
        """Validate required fields per model kind; raises ValueError."""
        # Online models need an API key.
        if not self.is_embedding and self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            if not self.api_key:
                raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")

        # Local models need a filesystem path.
        if not self.is_embedding and self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            if not self.model_path:
                raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of the public fields (for DB insert/update)."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')  # drop private attributes
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config instance from a database row dict.

        Fix: works on a shallow copy so the caller's dict is never mutated
        (the previous version rewrote the caller's time fields in place).
        Unknown keys are dropped so schema drift does not break construction.
        """
        data = dict(db_dict)

        # 1. Convert ISO-8601 strings to datetime for the time columns.
        time_fields = ['last_used_at', 'created_at', 'updated_at']
        for field_name in time_fields:
            val = data.get(field_name)
            if val and isinstance(val, str):
                try:
                    data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    data[field_name] = None

        # 2. Keep only keys that are declared dataclass fields.
        valid_fields = cls.__dataclass_fields__.keys()
        filtered_dict = {k: v for k, v in data.items() if k in valid_fields}

        # 3. Instantiate (runs __post_init__ validation).
        return cls(**filtered_dict)

    def get_model_type(self) -> str:
        """Classify the model: 'online' / 'local' / 'embedding' / 'unknown'."""
        if self.is_embedding:
            return "embedding"
        if self.provider in ['openai', 'zhipu', 'baidu', 'anthropic']:
            return "online"
        if self.provider in ['llama', 'qwen', 'yi', 'glm', 'mistral']:
            return "local"
        return "unknown"
|
||||||
|
|
||||||
|
class BaseLLM(BaseChatModel):
    """
    Base class for project chat models.

    Inherits LangChain's BaseChatModel (a BaseLanguageModel subclass) so
    concrete subclasses can be passed directly to create_agent and the
    rest of the LangChain runnable ecosystem.
    """
    # Set via __init__; declared here so the pydantic model accepts them.
    config: Any = None   # configuration object (LLMConfig_DataClass-style)
    model: Any = None    # lazily created backend model instance

    def __init__(self, config):
        super().__init__()  # required by BaseChatModel / pydantic
        self.config = config
        self.model = None
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    def _validate_config(self) -> None:
        """Validate the configuration; subclasses override to add checks.

        Fix: previously undefined on the base class, so any subclass that
        did not define it crashed in __init__ with AttributeError. The
        base implementation accepts any config.
        """

    # ---------------------- LangChain protocol (must be overridden) ----------------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Core synchronous generation (required by LangChain).

        messages: chat messages (e.g. [HumanMessage(content="hi")]).
        Returns a ChatResult (LangChain's standard output).

        Fix: the old stub only logged and implicitly returned None, which
        violates the ChatResult contract and fails later downstream; now
        it raises NotImplementedError immediately.
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _generate")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation (LangChain async protocol); see _generate."""
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _agenerate")

    @property
    def _llm_type(self) -> str:
        """Model type tag (e.g. "openai", "llama", "bge")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the backend model; subclasses provide the real logic."""
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement load_model")

    def close(self) -> None:
        """Release backend resources (idempotent)."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    def __enter__(self):
        # Context-manager support: load on enter, release on exit.
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
|
@ -0,0 +1,77 @@
|
||||||
|
from typing import List
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from loguru import logger
|
||||||
|
from th_agenter.llm.base_llm import BaseLLM
|
||||||
|
|
||||||
|
class EmbedLLM(BaseLLM, Embeddings):
    """Embedding base class.

    Implements LangChain's Embeddings interface (not BaseLanguageModel)
    on top of the project's BaseLLM plumbing. Concrete backends must
    override the four embed methods.
    """

    def __init__(self, config):
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch document embedding (required by LangChain).

        Fix: the old stub ``pass``-ed and silently returned None, which
        violates the Embeddings contract; now it fails fast.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement embed_documents")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement aembed_documents")

    def embed_query(self, text: str) -> List[float]:
        """Single-query embedding; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement embed_query")

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding; subclasses must override."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement aembed_query")
|
||||||
|
class BGEEmbedLLM(EmbedLLM):
    """Concrete embedding model (BGE family).

    Delegates to OllamaEmbeddings when the config's provider is 'ollama',
    otherwise to HuggingFaceEmbeddings. The backend object is created
    lazily on first use.
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        # model_name is the only hard requirement for this backend.
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the underlying embeddings backend."""
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        if getattr(self.config, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, 'base_url', None),
            )
            return
        try:
            from langchain_huggingface import HuggingFaceEmbeddings
            self.model = HuggingFaceEmbeddings(
                model_name=self.config.model_name,
                model_kwargs={"device": getattr(self.config, 'device', "cpu")},
                encode_kwargs={"normalize_embeddings": getattr(self.config, 'normalize_embeddings', True)},
            )
        except ImportError as e:
            logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
            logger.error("Please install sentence-transformers: pip install sentence-transformers")
            raise

    def _backend(self):
        # Lazily load the backend on first use, then reuse it.
        if not self.model:
            self.load_model()
        return self.model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents."""
        return self._backend().embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding."""
        return await self._backend().aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self._backend().embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding."""
        return await self._backend().aembed_query(text)
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
import os, dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from utils.Constant import Constant
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
# Load environment variables from a local .env file.
dotenv.load_dotenv()

# Re-export the OpenAI settings only when they are actually present:
# assigning None to os.environ raises TypeError, so guard each value.
for _env_var in ("OPENAI_API_KEY", "OPENAI_BASE_URL"):
    _env_val = os.getenv(_env_var)
    if _env_val is not None:
        os.environ[_env_var] = _env_val
|
|
||||||
|
class LLM_Model_Base(object):
    '''
    Base class for language models.

    Defines the common attributes and helpers shared by all model classes:
    - model name, default "gpt-4o-mini"
    - temperature, default 0.7
    - the underlying model instance (set by subclasses)
    - the model mode (set by subclasses)
    - a display name shown in the UI

    author: DrGraph
    date: 2025-11-20
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        self.model_name = model_name  # 0.15 0.6
        self.temperature = temperature
        self.llmModel = None  # concrete LLM instance, set by subclasses
        self.mode = Constant.LLM_MODE_NONE  # model mode, set by subclasses
        self.name = '未知模型'  # display name, overridden by subclasses

    def buildPromptTemplateValue(self, prompt: str, methodType: str, valueType: str):
        '''
        Wrap the raw user prompt into the value shape a model invoke expects.

        methodType selects how the PromptTemplate is rendered ("format" or
        "invoke"); valueType selects the returned shape ("str", "messages",
        or "promptValue"). Returns None for unmatched combinations.
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        # Fixed wrapper template; {question} receives the raw prompt.
        prompt_template = PromptTemplate.from_template(
            template="请回答以下问题: {question}",
        )
        prompt_template_value = None
        if methodType == "format":
            # Path 1 - render via format(), producing a plain string.
            prompt_str = prompt_template.format(question=prompt)  # prompt is a str
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 format 方法,取得字符串prompt_str, 然后再处理 - {type(prompt_str)} - {prompt_str}")

            if valueType == "str":
                # 1.1 Pass the rendered string straight to the LLM invoke.
                prompt_template_value = prompt_str
                logger.info(f"{self.name} >>> 1.2.1 直接使用字符串")

            elif valueType == "messages":
                # 1.2 Build a HumanMessage list from the prompt.
                # NOTE(review): wraps the raw `prompt`, not the rendered
                # `prompt_str` — confirm this asymmetry with the invoke
                # branch below is intended.
                prompt_template_value = [HumanMessage(content=prompt)]
                logger.info(f"{self.name} >>> 1.2.2 创建HumanMessage对象列表")

        elif methodType == "invoke":
            # Path 2 - render via invoke(), producing a PromptValue.
            prompt_value = prompt_template.invoke(input={"question" : prompt})  # langchain_core.prompt_values.StringPromptValue
            logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 invoke 方法,取得PromptValue, 然后再处理 - {type(prompt_value)} - {prompt_value}")
            if valueType == "str":
                # 2.1 Convert the PromptValue back to a string.
                prompt_template_value = prompt_value.to_string()
                logger.info(f"{self.name} >>> 1.2.1 由 PromptValue 转换为字符串")
            elif valueType == "promptValue":
                # 2.2 Use the PromptValue itself.
                prompt_template_value = prompt_value
                logger.info(f"{self.name} >>> 1.2.2 直接使用 PromptValue 作为 prompt_template_value")
            elif valueType == "messages":
                # 2.3 Convert the PromptValue to a HumanMessage list.
                prompt_template_value = prompt_value.to_messages()
                logger.info(f"{self.name} >>> 1.2.3 使用 PromptValue.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表")

        logger.info(f"{self.name} >>> 1.3 用户输入 最终包装为(PromptValue/str/list of BaseMessages): {type(prompt_template_value)}\n{prompt_template_value}")
        return prompt_template_value
|
@ -0,0 +1,29 @@
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from DrGraph.utils.Constant import Constant
|
||||||
|
from LLM.llm_model_base import LLM_Model_Base
|
||||||
|
|
||||||
|
class Chat_LLM(LLM_Model_Base):
    '''
    Chat-mode model backed by ChatOpenAI.

    - model_name defaults to "gpt-4o-mini"
    - temperature defaults to 0.7
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '聊天模型'
        self.mode = Constant.LLM_MODE_CHAT
        self.llmModel = ChatOpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns the message object so it can be displayed in the chatbot.
    def invoke(self, prompt: str):
        '''
        Invoke the chat model and return its response message.

        Returns None when the underlying call raised (the error is logged).
        Fix: `response` used to be unbound after an exception, so any API
        failure surfaced as a NameError instead of the real error.
        '''
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        response = None  # defined even when invoke() raises
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
        except Exception as e:
            logger.error(e)
        return response
|
|
@ -0,0 +1,44 @@
|
||||||
|
'''
|
||||||
|
非聊天模型类,继承自 LLM_Model_Base
|
||||||
|
|
||||||
|
author: DrGraph
|
||||||
|
date: 2025-11-20
|
||||||
|
'''
|
||||||
|
from loguru import logger
|
||||||
|
from langchain_openai import OpenAI
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from DrGraph.utils.Constant import Constant
|
||||||
|
from LLM.llm_model_base import LLM_Model_Base
|
||||||
|
|
||||||
|
|
||||||
|
class NonChat_LLM(LLM_Model_Base):
    '''
    Non-chat (completion) model, backed by the OpenAI completion endpoint.

    - model_name defaults to "gpt-4o-mini"
    - temperature defaults to 0.7
    - display name is "非聊天模型" (shown in the UI)
    '''
    def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '非聊天模型'
        self.mode = Constant.LLM_MODE_NONCHAT
        self.llmModel = OpenAI(
            model_name=self.model_name,
            temperature=self.temperature,
        )

    # Returns the reply so it can be displayed in the chatbot.
    def invoke(self, prompt: str):
        '''
        Invoke the completion model and return its reply.

        prompt: user input string.
        return: the model's reply string, or None when the call failed.
        Fix: `response` used to be unbound after an exception, so any API
        failure surfaced as a NameError instead of the real error.
        '''
        logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
        response = None  # defined even when invoke() raises
        try:
            response = self.llmModel.invoke(prompt)
            logger.info(f"{self.name} >>> 1.2 助手回复: {type(response)}")
        except Exception as e:
            logger.error(e)
        return response
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from utils.Constant import Constant
|
||||||
|
from th_agenter.llm.llm_model_base import LLM_Model_Base
|
||||||
|
from langchain_ollama import ChatOllama
|
||||||
|
class Chat_Ollama(LLM_Model_Base):
    '''
    Chat model served by a private Ollama instance.

    - base_url defaults to the local Ollama daemon
    - model_name defaults to "OxW/Qwen3-0.6B-GGUF:latest"
    - temperature defaults to 0.7
    '''
    def __init__(self, base_url="http://127.0.0.1:11434", model_name: str = "OxW/Qwen3-0.6B-GGUF:latest", temperature: float = 0.7):
        super().__init__(model_name, temperature)
        self.name = '私有化Ollama模型'
        self.base_url = base_url
        self.llmModel = ChatOllama(
            base_url=self.base_url,
            model=model_name,
            temperature=temperature
        )
        self.mode = Constant.LLM_MODE_LOCAL_OLLAMA

    def invoke(self, prompt: str):
        '''
        Invoke the Ollama chat model and return its response message.

        Returns None when the underlying call raised (the error is logged).
        Fix: `response` used to be unbound after an exception, so any API
        failure surfaced as a NameError instead of the real error.
        '''
        prompt_template_value = self.buildPromptTemplateValue(
            prompt=prompt,
            methodType=Constant.LLM_PROMPT_TEMPLATE_METHOD_INVOKE,
            valueType=Constant.LLM_PROMPT_VALUE_MESSAGES)
        response = None  # defined even when invoke() raises
        try:
            response = self.llmModel.invoke(prompt_template_value)
            logger.info(f"{self.name} >>> 2. 助手回复: {type(response)}\n{response}")
        except Exception as e:
            logger.error(e)
        return response
|
|
@ -0,0 +1,68 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
from th_agenter.llm.base_llm import BaseLLM
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
|
||||||
|
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||||
|
|
||||||
|
|
||||||
|
class LocalLLM(BaseLLM):
    """Chat-style wrapper around a local LlamaCpp (GGUF) model.

    Lazily loads the model on first generation and adapts LangChain
    chat messages to LlamaCpp's plain-text prompt interface.
    """

    def __init__(self, config):
        super().__init__(config)
        # Keep a dedicated reference; the base class may hold it too.
        self.local_config = config

    def _validate_config(self):
        """Raise if the mandatory ``model_path`` setting is missing."""
        if not self.local_config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the LlamaCpp backend from the stored configuration."""
        # Imported lazily so llama-cpp is only required when actually used.
        from langchain_community.llms import LlamaCpp
        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        # Bug fix: the original annotated ``**kwargs: Any`` but this module
        # never imports ``Any``, which raised NameError when the class body
        # executed. The annotation is dropped; behavior is unchanged.
        **kwargs,
    ) -> ChatResult:
        """Synchronous generation adapted to the non-chat LlamaCpp model."""
        if not self.model:
            self.load_model()
        # LlamaCpp is a completion model, so flatten messages to one prompt.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap the raw text in the standard LangChain ChatResult envelope.
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        # Same NameError fix as in ``_generate`` above.
        **kwargs,
    ) -> ChatResult:
        """Async generation; mirrors ``_generate`` via ``ainvoke``."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Format a LangChain message list into the local model's prompt.

        Human turns are wrapped in Llama-style ``[INST]`` tags; assistant
        turns are appended verbatim. Other message types (e.g. system) are
        silently skipped — NOTE(review): confirm that is intentional.
        """
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_core.messages import HumanMessage, BaseMessage
|
||||||
|
from typing import List, Optional, Any, Union
|
||||||
|
from langchain_core.outputs import ChatResult
|
||||||
|
from th_agenter.llm.base_llm import BaseLLM
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
|
||||||
|
class OnlineLLM(BaseLLM):
    """LLM backed by an OpenAI-compatible HTTP API via ``ChatOpenAI``."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Reject configurations that lack an API key."""
        if not self.config.api_key:
            raise ValueError("OnlineLLM 必须配置 api_key")

    def load_model(self):
        """Build the underlying LangChain chat client from the stored config."""
        cfg = self.config
        self.model = ChatOpenAI(
            api_key=cfg.api_key,
            model_name=cfg.model_name,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            base_url=cfg.base_url,
        )

    @property
    def _llm_type(self) -> str:
        # Identifier reported to LangChain for this model family.
        return "openai"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate synchronous generation to the wrapped LangChain model."""
        if not self.model:
            self.load_model()
        # Reuse the underlying model's implementation directly.
        return self.model._generate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate asynchronous generation to the wrapped LangChain model."""
        if not self.model:
            self.load_model()
        return await self.model._agenerate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    # ---------------------- Convenience helpers ----------------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Generate text from a plain string or a prepared message list."""
        msgs = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        return self._generate(msgs, **kwargs).generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async variant of :meth:`generate`."""
        msgs = [HumanMessage(content=prompt)] if isinstance(prompt, str) else prompt
        outcome = await self._agenerate(msgs, **kwargs)
        return outcome.generations[0].text
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Database models for TH Agenter."""
|
||||||
|
|
||||||
|
from .user import User
|
||||||
|
from .conversation import Conversation
|
||||||
|
from .message import Message
|
||||||
|
from .knowledge_base import KnowledgeBase, Document
|
||||||
|
from .agent_config import AgentConfig
|
||||||
|
from .excel_file import ExcelFile
|
||||||
|
from .permission import Role, UserRole
|
||||||
|
from .llm_config import LLMConfig
|
||||||
|
from .workflow import Workflow, WorkflowExecution, NodeExecution
|
||||||
|
from .database_config import DatabaseConfig
|
||||||
|
from .table_metadata import TableMetadata
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"User",
|
||||||
|
"Conversation",
|
||||||
|
"Message",
|
||||||
|
"KnowledgeBase",
|
||||||
|
"Document",
|
||||||
|
"AgentConfig",
|
||||||
|
"ExcelFile",
|
||||||
|
"Role",
|
||||||
|
"UserRole",
|
||||||
|
"LLMConfig",
|
||||||
|
"Workflow",
|
||||||
|
"WorkflowExecution",
|
||||||
|
"NodeExecution",
|
||||||
|
"DatabaseConfig",
|
||||||
|
"TableMetadata"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Agent configuration model."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import String, Text, Boolean, JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
    """Agent configuration model: per-agent tool, model and status settings."""

    __tablename__ = "agent_configs"

    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Agent configuration
    # JSON list of tool identifiers enabled for this agent.
    enabled_tools: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
    max_iterations: Mapped[int] = mapped_column(default=10)
    # NOTE(review): temperature is stored as a string column — confirm
    # callers parse it to float before use.
    temperature: Mapped[str] = mapped_column(String(10), default="0.1")
    system_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    verbose: Mapped[bool] = mapped_column(Boolean, default=True)

    # Model configuration
    model_name: Mapped[str] = mapped_column(String(100), default="gpt-3.5-turbo")
    max_tokens: Mapped[int] = mapped_column(default=2048)

    # Status
    is_active: Mapped[bool] = mapped_column(default=True)
    is_default: Mapped[bool] = mapped_column(default=False)

    def __repr__(self):
        return f"<AgentConfig(id={self.id}, name='{self.name}', is_active={self.is_active})>"

    def __str__(self):
        return f"{self.name}[{self.id}] Active: {self.is_active}"

    def to_dict(self):
        """Convert to dictionary, normalizing ``enabled_tools`` to a list."""
        data = super().to_dict()
        # Guard against NULL stored in the JSON column.
        data['enabled_tools'] = self.enabled_tools or []
        return data
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Conversation model."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy import String, Integer, Text, Boolean, DateTime
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class Conversation(BaseModel):
    """Conversation model: one chat thread with its generation settings."""

    __tablename__ = "conversations"

    title: Mapped[str] = mapped_column(String(200), nullable=False)
    # Plain integer references; FK constraints were deliberately removed.
    user_id: Mapped[int] = mapped_column(Integer, nullable=False)  # Removed ForeignKey("users.id")
    knowledge_base_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)  # Removed ForeignKey("knowledge_bases.id")
    system_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    model_name: Mapped[str] = mapped_column(String(100), nullable=False, default="gpt-3.5-turbo")
    # NOTE(review): temperature stored as a string — confirm callers parse it.
    temperature: Mapped[str] = mapped_column(String(10), nullable=False, default="0.7")
    max_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=2048)
    is_archived: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    message_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    last_message_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)

    # Relationships removed to eliminate foreign key constraints

    def to_dict(self) -> dict:
        """Convert conversation to a dictionary."""
        return {
            "id": self.id,
            "title": self.title,
            "user_id": self.user_id,
            "knowledge_base_id": self.knowledge_base_id,
            "system_prompt": self.system_prompt,
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "is_archived": self.is_archived,
            "message_count": self.message_count,
            # NOTE(review): last_message_at is returned as a datetime, not
            # ISO text — confirm downstream JSON serialization handles it.
            "last_message_at": self.last_message_at,
        }

    def __repr__(self):
        return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id}, system_prompt={self.system_prompt}, model_name='{self.model_name}', temperature='{self.temperature}', message_count={self.message_count})>"
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""数据库配置模型"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from loguru import logger
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy import Integer, String, Text, Boolean, JSON
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# 在现有的DatabaseConfig类中添加关系
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
class DatabaseConfig(BaseModel):
    """Database connection configuration table."""
    __tablename__ = "database_configs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False)  # configuration name
    db_type: Mapped[str] = mapped_column(String(20), nullable=False, unique=True)  # database type: postgresql, mysql, etc. (unique — one config per type)
    host: Mapped[str] = mapped_column(String(255), nullable=False)
    port: Mapped[int] = mapped_column(Integer, nullable=False)
    database: Mapped[str] = mapped_column(String(100), nullable=False)
    username: Mapped[str] = mapped_column(String(100), nullable=False)
    password: Mapped[str] = mapped_column(Text, nullable=False)  # stored encrypted
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_default: Mapped[bool] = mapped_column(Boolean, default=False)
    connection_params: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # extra connection parameters

    def to_dict(self, include_password=False, decrypt_service=None):
        """Serialize the config; optionally include the decrypted password.

        Args:
            include_password: when True AND ``decrypt_service`` is given,
                the decrypted password is added to the result.
            decrypt_service: object exposing ``_decrypt_password`` —
                NOTE(review): this calls a private method of the service;
                consider a public API.
        """
        result = {
            "id": self.id,
            # NOTE(review): created_by is not declared in this class —
            # presumably inherited from BaseModel; verify.
            "created_by": self.created_by,
            "name": self.name,
            "db_type": self.db_type,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "username": self.username,
            "is_active": self.is_active,
            "is_default": self.is_default,
            "connection_params": self.connection_params,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None
        }

        # Include the plaintext password only when explicitly requested and
        # a decryption service was supplied.
        if include_password and decrypt_service:
            logger.info(f"begin decrypt password for db config {self.id}")
            result["password"] = decrypt_service._decrypt_password(self.password)

        return result

    # Relationship intentionally disabled (FK constraints removed project-wide).
    # table_metadata = relationship("TableMetadata", back_populates="database_config")
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Excel file models for smart query."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy import String, Integer, Text, Boolean, JSON, DateTime
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class ExcelFile(BaseModel):
    """Excel file model for storing file metadata."""
    __tablename__ = "excel_files"
    # Basic file information
    # user_id: Mapped[int] = mapped_column(Integer, nullable=False)  # owner user id (disabled)
    original_filename: Mapped[str] = mapped_column(String(255), nullable=False)  # original upload filename
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)  # storage path on disk
    file_size: Mapped[int] = mapped_column(Integer, nullable=False)  # file size in bytes
    file_type: Mapped[str] = mapped_column(String(50), nullable=False)  # file type (.xlsx, .xls, .csv)

    # Excel specific information
    sheet_names: Mapped[list] = mapped_column(JSON, nullable=False)  # list of all sheet names
    default_sheet: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)  # default sheet name

    # Data preview information
    columns_info: Mapped[dict] = mapped_column(JSON, nullable=False)  # column info: {sheet_name: [column_names]}
    preview_data: Mapped[dict] = mapped_column(JSON, nullable=False)  # first rows: {sheet_name: [[row1], [row2], ...]}
    data_types: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # dtype info: {sheet_name: {column: dtype}}

    # Statistics
    total_rows: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # per-sheet row count: {sheet_name: row_count}
    total_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # per-sheet column count: {sheet_name: column_count}

    # Processing status
    is_processed: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)  # whether processing finished
    processing_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # processing error message

    # Upload information
    # upload_time: Mapped[DateTime] = mapped_column(DateTime, default=func.now(), nullable=False)  # upload timestamp (disabled)
    last_accessed: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)  # last access time

    def __repr__(self):
        return f"<ExcelFile(id={self.id}, filename='{self.original_filename}')>"

    @property
    def file_size_mb(self):
        """Get file size in MB."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def sheet_count(self):
        """Get number of sheets."""
        return len(self.sheet_names) if self.sheet_names else 0

    def get_sheet_info(self, sheet_name: Optional[str] = None):
        """Get information for a specific sheet or default sheet.

        Returns ``None`` when the name cannot be resolved to a known sheet.
        """
        # Fall back to the configured default, then to the first sheet.
        if not sheet_name:
            sheet_name = self.default_sheet or (self.sheet_names[0] if self.sheet_names else None)

        if not sheet_name or sheet_name not in self.sheet_names:
            return None

        return {
            'sheet_name': sheet_name,
            'columns': self.columns_info.get(sheet_name, []) if self.columns_info else [],
            'preview_data': self.preview_data.get(sheet_name, []) if self.preview_data else [],
            'data_types': self.data_types.get(sheet_name, {}) if self.data_types else {},
            'total_rows': self.total_rows.get(sheet_name, 0) if self.total_rows else 0,
            'total_columns': self.total_columns.get(sheet_name, 0) if self.total_columns else 0
        }

    def get_all_sheets_summary(self):
        """Get summary information for all sheets."""
        if not self.sheet_names:
            return []

        summary = []
        for sheet_name in self.sheet_names:
            sheet_info = self.get_sheet_info(sheet_name)
            if sheet_info:
                summary.append({
                    'sheet_name': sheet_name,
                    'columns_count': len(sheet_info['columns']),
                    'rows_count': sheet_info['total_rows'],
                    'columns': sheet_info['columns'][:10]  # only first 10 columns
                })
        return summary
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
"""Knowledge base models."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy import String, Integer, Text, Boolean, JSON
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class KnowledgeBase(BaseModel):
    """Knowledge base model: chunking/embedding settings for one corpus."""

    __tablename__ = "knowledge_bases"

    name: Mapped[str] = mapped_column(String(100), unique=False, index=True, nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    embedding_model: Mapped[str] = mapped_column(String(100), nullable=False, default="sentence-transformers/all-MiniLM-L6-v2")
    chunk_size: Mapped[int] = mapped_column(Integer, nullable=False, default=1000)
    chunk_overlap: Mapped[int] = mapped_column(Integer, nullable=False, default=200)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)

    # Vector database settings
    vector_db_type: Mapped[str] = mapped_column(String(50), nullable=False, default="chroma")
    collection_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)  # For vector DB collection

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        return f"<KnowledgeBase(id={self.id}, name='{self.name}')>"

    # Relationships are commented out to remove foreign key constraints, so these properties should be updated
    # @property
    # def document_count(self):
    #     """Get the number of documents in this knowledge base."""
    #     return len(self.documents)

    # @property
    # def active_document_count(self):
    #     """Get the number of active documents in this knowledge base."""
    #     return len([doc for doc in self.documents if doc.is_processed])
|
||||||
|
|
||||||
|
|
||||||
|
class Document(BaseModel):
    """Document model: one uploaded file attached to a knowledge base."""

    __tablename__ = "documents"

    # Plain integer reference; FK constraint was deliberately removed.
    knowledge_base_id: Mapped[int] = mapped_column(Integer, nullable=False)  # Removed ForeignKey("knowledge_bases.id")
    filename: Mapped[str] = mapped_column(String(255), nullable=False)
    original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)
    file_size: Mapped[int] = mapped_column(Integer, nullable=False)  # in bytes
    file_type: Mapped[str] = mapped_column(String(50), nullable=False)  # .pdf, .txt, .docx, etc.
    mime_type: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)

    # Processing status
    is_processed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    processing_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Content and metadata
    content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # Extracted text content
    doc_metadata: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # Additional metadata

    # Chunking information
    chunk_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)

    # Embedding information
    embedding_model: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
    vector_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # Store vector database IDs for chunks

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        return f"<Document(id={self.id}, filename='{self.filename}', kb_id={self.knowledge_base_id})>"

    @property
    def file_size_mb(self):
        """Get file size in MB."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def is_text_file(self):
        """Check if document is a text file."""
        return self.file_type.lower() in ['.txt', '.md', '.csv']

    @property
    def is_pdf_file(self):
        """Check if document is a PDF file."""
        return self.file_type.lower() == '.pdf'

    @property
    def is_office_file(self):
        """Check if document is an Office file."""
        return self.file_type.lower() in ['.docx', '.xlsx', '.pptx']
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
"""LLM Configuration model for managing multiple AI models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class LLMConfig(BaseModel):
    """LLM Configuration model for managing AI model settings."""
    __tablename__ = "llm_configs"

    name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)  # configuration name
    provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True)  # vendor: openai, deepseek, doubao, zhipu, moonshot, baidu
    model_name: Mapped[str] = mapped_column(String(100), nullable=False)  # model identifier
    api_key: Mapped[str] = mapped_column(String(500), nullable=False)  # API key (stored encrypted)
    base_url: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)  # API base URL

    # Model sampling parameters
    max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
    temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
    top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
    frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
    presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)

    # Configuration metadata
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # human-readable description
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)  # enabled flag
    is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # default configuration flag
    is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # whether this is an embedding model

    # Extended configuration (JSON)
    extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True)  # extra client parameters

    # Usage statistics
    usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)  # number of uses
    last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)  # last usage time

    def __repr__(self):
        return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model_name='{self.model_name}', base_url='{self.base_url}')>"

    def to_dict(self, include_sensitive=False):
        """Convert to dictionary, optionally excluding sensitive data.

        When ``include_sensitive`` is False the API key is masked as
        ``api_key_masked`` instead of being exposed.
        """
        data = super().to_dict()
        data.update({
            'name': self.name,
            'provider': self.provider,
            'model_name': self.model_name,
            'base_url': self.base_url,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'description': self.description,
            'is_active': self.is_active,
            'is_default': self.is_default,
            'is_embedding': self.is_embedding,
            'extra_config': self.extra_config,
            'usage_count': self.usage_count,
            'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None
        })

        if include_sensitive:
            data['api_key'] = self.api_key
        else:
            # Only show the first and last few characters of the API key.
            if self.api_key:
                key_len = len(self.api_key)
                if key_len > 8:
                    data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
                else:
                    data['api_key_masked'] = "***"
            else:
                data['api_key_masked'] = None

        return data

    def get_client_config(self) -> Dict[str, Any]:
        """Build the kwargs dict used to create an API client."""
        config = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'model': self.model_name,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty
        }

        # Extra parameters override/extend the base keys above.
        if self.extra_config:
            config.update(self.extra_config)

        return config

    def validate_config(self) -> Dict[str, Any]:
        """Validate this configuration; returns {"valid": bool, "error": str|None}."""
        if not self.name or not self.name.strip():
            return {"valid": False, "error": "配置名称不能为空"}

        if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu', 'ollama']:
            return {"valid": False, "error": f"不支持的服务商 {self.provider}"}

        if not self.model_name or not self.model_name.strip():
            return {"valid": False, "error": "模型名称不能为空"}

        if not self.api_key or not self.api_key.strip():
            return {"valid": False, "error": "API密钥不能为空"}

        if self.max_tokens <= 0 or self.max_tokens > 32000:
            return {"valid": False, "error": "最大令牌数必须在1-32000之间"}

        if self.temperature < 0 or self.temperature > 2:
            return {"valid": False, "error": "温度参数必须在0-2之间"}

        return {"valid": True, "error": None}

    def increment_usage(self):
        """Record one use: bump the counter and refresh the timestamp."""
        self.usage_count += 1
        self.last_used_at = datetime.now()

    @classmethod
    def get_default_config(cls, provider: str, is_embedding: bool = False):
        """Return the default config template for ``provider`` (or {} if unknown)."""
        templates = {
            'openai': {
                'base_url': 'https://api.openai.com/v1',
                # NOTE(review): 'gpt-4.0-mini' looks like a typo for a real
                # OpenAI model id (e.g. 'gpt-4o-mini') — confirm.
                'model_name': 'gpt-4.0-mini' if not is_embedding else 'text-embedding-ada-002',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'deepseek': {
                'base_url': 'https://api.deepseek.com/v1',
                'model_name': 'deepseek-chat' if not is_embedding else 'deepseek-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'doubao': {
                'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
                'model_name': 'doubao-lite-4k' if not is_embedding else 'doubao-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'zhipu': {
                'base_url': 'https://open.bigmodel.cn/api/paas/v4',
                'model_name': 'glm-4' if not is_embedding else 'embedding-3',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'moonshot': {
                'base_url': 'https://api.moonshot.cn/v1',
                'model_name': 'moonshot-v1-8k' if not is_embedding else 'moonshot-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            }
        }

        return templates.get(provider, {})
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""Message model."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import String, Integer, Text, Enum, JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, enum.Enum):
    """Message role enumeration (who produced the message)."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(str, enum.Enum):
    """Message content-type enumeration."""
    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    AUDIO = "audio"
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
    """Message model: one turn inside a conversation."""

    __tablename__ = "messages"

    # Plain integer reference; FK constraint was deliberately removed.
    conversation_id: Mapped[int] = mapped_column(Integer, nullable=False)  # Removed ForeignKey("conversations.id")
    role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
    message_metadata: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # Store additional data like file info, tokens used, etc.

    # For knowledge base context
    context_documents: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # Store retrieved document references

    # Token usage tracking
    prompt_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    completion_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    total_tokens: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        # Truncate long content so reprs stay readable in logs.
        content_preview = self.content[:50] + "..." if len(self.content) > 50 else self.content
        return f"<Message(id={self.id}, role='{self.role}', content='{content_preview}')>"

    def to_dict(self, include_metadata=True):
        """Convert to dictionary; drop metadata/token fields when requested."""
        data = super().to_dict()
        if not include_metadata:
            data.pop('message_metadata', None)
            data.pop('context_documents', None)
            data.pop('prompt_tokens', None)
            data.pop('completion_tokens', None)
            data.pop('total_tokens', None)
        return data

    @property
    def is_from_user(self):
        """Check if message is from user."""
        return self.role == MessageRole.USER

    @property
    def is_from_assistant(self):
        """Check if message is from assistant."""
        return self.role == MessageRole.ASSISTANT
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
"""Role models for simplified RBAC system."""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from sqlalchemy import String, Text, Boolean, ForeignKey, Integer
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from ..db.base import BaseModel, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Role(BaseModel):
    """Role model for simplified RBAC system.

    A role is identified by a unique, indexed ``code`` and is linked to
    users through the ``user_roles`` association table.
    """

    __tablename__ = "roles"

    name: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True)  # display name of the role
    code: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True)  # machine-readable role code
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # free-form role description
    is_system: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)  # built-in system role flag
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)

    # Relationships — only the user relationship is kept.
    users = relationship("User", secondary="user_roles", back_populates="roles")

    def __repr__(self):
        return f"<Role(id={self.id}, code='{self.code}', name='{self.name}')>"

    def to_dict(self):
        """Convert to dictionary."""
        data = super().to_dict()
        data.update({
            'name': self.name,
            'code': self.code,
            'description': self.description,
            'is_system': self.is_system,
            'is_active': self.is_active
        })
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(Base):
    """User-role association row (composite PK on user_id + role_id)."""

    __tablename__ = "user_roles"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'), primary_key=True)
    role_id: Mapped[int] = mapped_column(Integer, ForeignKey('roles.id'), primary_key=True)

    # Read-only relationships for code that works with the association
    # table directly; mutations should go through Role.users / User.roles.
    user = relationship("User", viewonly=True)
    role = relationship("Role", viewonly=True)

    def __repr__(self):
        return f"<UserRole(user_id={self.user_id}, role_id={self.role_id})>"
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""表元数据模型"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import Integer, String, Text, DateTime, Boolean, JSON
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class TableMetadata(BaseModel):
    """Metadata snapshot of a database table, used for table Q&A.

    Holds structural info (columns, keys, indexes), a small data
    sample, and Q&A flags, all synced from a source database.
    """

    __tablename__ = "table_metadata"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    # database_config_id = Column(Integer, ForeignKey('database_configs.id'), nullable=False)
    table_name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    table_schema: Mapped[str] = mapped_column(String(50), default='public')
    table_type: Mapped[str] = mapped_column(String(20), default='BASE TABLE')
    table_comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # table description
    database_config_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)  # database config id (plain int; FK removed)

    # Table structure information
    columns_info: Mapped[dict] = mapped_column(JSON, nullable=False)  # column info: name, type, comment, etc.
    primary_keys: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # primary key list
    foreign_keys: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # foreign key info
    indexes: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # index info

    # Sample data
    sample_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # first 5 sample rows
    row_count: Mapped[int] = mapped_column(Integer, default=0)  # total row count

    # Q&A related
    is_enabled_for_qa: Mapped[bool] = mapped_column(Boolean, default=True)  # whether the table is enabled for Q&A
    qa_description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # Q&A description
    business_context: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # business context

    last_synced_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)  # last sync time

    # Relationships
    # database_config = relationship("DatabaseConfig", back_populates="table_metadata")

    def to_dict(self):
        """Serialize to a plain dict; datetimes become ISO-8601 strings."""
        return {
            "id": self.id,
            "created_by": self.created_by,  # inherited from BaseModel — presumably the creating user's id; confirm
            "database_config_id": self.database_config_id,
            "table_name": self.table_name,
            "table_schema": self.table_schema,
            "table_type": self.table_type,
            "table_comment": self.table_comment,
            "columns_info": self.columns_info,
            "primary_keys": self.primary_keys,
            # "foreign_keys": self.foreign_keys,
            "indexes": self.indexes,
            "sample_data": self.sample_data,
            "row_count": self.row_count,
            "is_enabled_for_qa": self.is_enabled_for_qa,
            "qa_description": self.qa_description,
            "business_context": self.business_context,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_synced_at": self.last_synced_at.isoformat() if self.last_synced_at else None
        }
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""User model."""
|
||||||
|
|
||||||
|
from sqlalchemy import String, Boolean, Text
|
||||||
|
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||||
|
from typing import List, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
    """User model.

    Roles attach via the ``user_roles`` association table; holding the
    active 'SUPER_ADMIN' role marks a user as administrator.
    """

    __tablename__ = "users"

    username: Mapped[str] = mapped_column(String(50), unique=True, index=True, nullable=False)
    email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    avatar_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    bio: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Relationships — only the role relationship is kept.
    roles = relationship("Role", secondary="user_roles", back_populates="users")

    def __repr__(self):
        # NOTE(review): evaluating is_admin here touches self.roles; on an
        # unloaded relationship the property catches the error and yields False.
        return f"<User(id={self.id}, username='{self.username}', email='{self.email}', full_name='{self.full_name}', is_active={self.is_active}, avatar_url='{self.avatar_url}', bio='{self.bio}', is_superuser={self.is_admin})>"

    def to_dict(self, include_sensitive=False, include_roles=False):
        """Convert to dictionary, optionally excluding sensitive data."""
        data = super().to_dict()
        data.update({
            'username': self.username,
            'email': self.email,
            'full_name': self.full_name,
            'is_active': self.is_active,
            'avatar_url': self.avatar_url,
            'bio': self.bio,
            'is_superuser': self.is_admin  # use the sync is_admin property instead of the async is_superuser method
        })

        if not include_sensitive:
            data.pop('hashed_password', None)

        if include_roles:
            try:
                # Safely access the roles relationship attribute.
                data['roles'] = [role.to_dict() for role in self.roles if role.is_active]
            except Exception:
                # Relationship not loaded or access failed: fall back to an empty list.
                data['roles'] = []

        return data

    async def has_role(self, role_code: str) -> bool:
        """Return True if the user holds the active role *role_code*.

        First tries the ``roles`` relationship (refreshing it when on an
        async session); if that raises (detached instance, lazy-load
        failure), falls back to querying the ``user_roles`` association
        table directly.
        """
        try:
            # In an async context the relationship must be loaded before iterating.
            from sqlalchemy.ext.asyncio import AsyncSession
            from sqlalchemy.orm import object_session
            from sqlalchemy import select
            from .permission import Role, UserRole

            session = object_session(self)
            if isinstance(session, AsyncSession):
                # Async session: await a refresh of the relationship.
                await session.refresh(self, ['roles'])
            return any(role.code == role_code and role.is_active for role in self.roles)
        except Exception:
            # Object detached or relationship load failed: query the DB instead.
            from sqlalchemy.orm import object_session
            from sqlalchemy import select
            from .permission import Role, UserRole

            session = object_session(self)
            if session is None:
                # No session available — cannot check, report no role.
                return False
            else:
                from sqlalchemy.ext.asyncio import AsyncSession
                if isinstance(session, AsyncSession):
                    # Async session: use an async query.
                    user_role = await session.execute(
                        select(UserRole).join(Role).filter(
                            UserRole.user_id == self.id,
                            Role.code == role_code,
                            Role.is_active == True
                        )
                    )
                    return user_role.scalar_one_or_none() is not None
                else:
                    # Sync session: use a sync query.
                    user_role = session.query(UserRole).join(Role).filter(
                        UserRole.user_id == self.id,
                        Role.code == role_code,
                        Role.is_active == True
                    ).first()
                    return user_role is not None

    async def is_superuser(self) -> bool:
        """Return True if the user holds the SUPER_ADMIN role."""
        return await self.has_role('SUPER_ADMIN')

    async def is_admin_user(self) -> bool:
        """Return True for administrators (compatibility alias for is_superuser)."""
        return await self.is_superuser()

    # NOTE: a property cannot be async, so this is a sync, best-effort check.
    @property
    def is_admin(self) -> bool:
        """Sync admin check (property form); inspects only loaded roles."""
        # A sync property cannot await, so only already-loaded roles are checked.
        # try/except guards against a possible MissingGreenlet error.
        try:
            # If the roles relationship is loaded, iterating it succeeds here.
            return any(role.code == 'SUPER_ADMIN' and role.is_active for role in self.roles)
        except Exception:
            # Relationship not loaded or access failed: treat as non-admin.
            return False
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
"""Workflow models."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import String, Text, Boolean, Integer, JSON, ForeignKey, Enum
|
||||||
|
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from ..db.base import BaseModel
|
||||||
|
|
||||||
|
class WorkflowStatus(enum.Enum):
    """Lifecycle state of a workflow definition."""

    DRAFT = "DRAFT"          # being edited, not runnable
    PUBLISHED = "PUBLISHED"  # released for execution
    ARCHIVED = "ARCHIVED"    # retired, kept for history
|
||||||
|
|
||||||
|
class NodeType(enum.Enum):
    """Kind of node inside a workflow graph."""

    START = "start"          # entry node
    END = "end"              # terminal node
    LLM = "llm"              # large-language-model node
    CONDITION = "condition"  # conditional branch node
    LOOP = "loop"            # loop node
    CODE = "code"            # code-execution node
    HTTP = "http"            # HTTP request node
    TOOL = "tool"            # tool invocation node
|
||||||
|
|
||||||
|
class ExecutionStatus(enum.Enum):
    """State of a workflow or node execution."""

    PENDING = "pending"      # queued, not started
    RUNNING = "running"      # in progress
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    CANCELLED = "cancelled"  # cancelled before completion
|
||||||
|
|
||||||
|
class Workflow(BaseModel):
    """Workflow model.

    The graph itself (nodes and edges) lives in the JSON ``definition``
    column; runs are tracked by :class:`WorkflowExecution`.
    """

    __tablename__ = "workflows"

    name: Mapped[str] = mapped_column(String(100), nullable=False, comment="工作流名称")
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="工作流描述")
    status: Mapped[WorkflowStatus] = mapped_column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="工作流状态")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, comment="是否激活")

    # Workflow definition (nodes and connections stored as JSON)
    definition: Mapped[dict] = mapped_column(JSON, nullable=False, comment="工作流定义")

    # Version info
    version: Mapped[str] = mapped_column(String(20), default="1.0.0", nullable=False, comment="版本号")

    # Owning user
    owner_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="所有者ID")

    # Relationships
    executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<Workflow(id={self.id}, name='{self.name}', status='{self.status.value}')>"

    def to_dict(self, include_definition=True):
        """Serialize to dict; the (potentially large) ``definition``
        JSON is included only when ``include_definition`` is True."""
        data = super().to_dict()
        data.update({
            'name': self.name,
            'description': self.description,
            'status': self.status.value,
            'is_active': self.is_active,
            'version': self.version,
            'owner_id': self.owner_id
        })

        if include_definition:
            data['definition'] = self.definition

        return data
|
||||||
|
|
||||||
|
class WorkflowExecution(BaseModel):
    """One execution (run) of a workflow."""

    __tablename__ = "workflow_executions"

    workflow_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflows.id"), nullable=False, comment="工作流ID")
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")

    # Execution input and output
    input_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Execution info
    # NOTE(review): timestamps are stored as String(50), not DateTime —
    # presumably ISO-8601 text; confirm with the code that writes them.
    started_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="完成时间")
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="错误信息")

    # Executor
    executor_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="执行者ID")

    # Relationships
    workflow = relationship("Workflow", back_populates="executions")
    node_executions = relationship("NodeExecution", back_populates="workflow_execution", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<WorkflowExecution(id={self.id}, workflow_id={self.workflow_id}, status='{self.status.value}')>"

    def to_dict(self, include_nodes=False):
        """Serialize to dict; per-node records only when ``include_nodes``."""
        data = super().to_dict()
        data.update({
            'workflow_id': self.workflow_id,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'error_message': self.error_message,
            'executor_id': self.executor_id
        })

        if include_nodes:
            data['node_executions'] = [node.to_dict() for node in self.node_executions]

        return data
|
||||||
|
|
||||||
|
class NodeExecution(BaseModel):
    """Per-node execution record within one workflow run."""

    __tablename__ = "node_executions"

    workflow_execution_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="工作流执行ID")
    node_id: Mapped[str] = mapped_column(String(50), nullable=False, comment="节点ID")
    node_type: Mapped[NodeType] = mapped_column(Enum(NodeType), nullable=False, comment="节点类型")
    node_name: Mapped[str] = mapped_column(String(100), nullable=False, comment="节点名称")

    # Execution status and result
    status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
    input_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输入数据")
    output_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="输出数据")

    # Execution timing
    # NOTE(review): timestamps stored as strings — presumably ISO-8601; confirm.
    started_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="开始时间")
    completed_at: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment="完成时间")
    duration_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, comment="执行时长(毫秒)")

    # Error info
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="错误信息")

    # Relationships
    workflow_execution = relationship("WorkflowExecution", back_populates="node_executions")

    def __repr__(self):
        return f"<NodeExecution(id={self.id}, node_id='{self.node_id}', status='{self.status.value}')>"

    def to_dict(self):
        """Serialize this node-execution record to a plain dict."""
        data = super().to_dict()
        data.update({
            'workflow_execution_id': self.workflow_execution_id,
            'node_id': self.node_id,
            'node_type': self.node_type.value,
            'node_name': self.node_name,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'duration_ms': self.duration_ms,
            'error_message': self.error_message
        })

        return data
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Schemas package initialization."""
|
||||||
|
|
||||||
|
from .user import UserCreate, UserUpdate, UserResponse
|
||||||
|
from .permission import (
|
||||||
|
RoleCreate, RoleUpdate, RoleResponse,
|
||||||
|
UserRoleAssign
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# User schemas
|
||||||
|
"UserCreate", "UserUpdate", "UserResponse",
|
||||||
|
|
||||||
|
# Permission schemas
|
||||||
|
"RoleCreate", "RoleUpdate", "RoleResponse",
|
||||||
|
"UserRoleAssign",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
"""LLM Configuration Pydantic schemas."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigBase(BaseModel):
    """Base schema for LLM configurations (fields shared by create/read)."""
    name: str = Field(..., min_length=1, max_length=100, description="配置名称")
    provider: str = Field(..., min_length=1, max_length=50, description="服务商")
    model_name: str = Field(..., min_length=1, max_length=100, description="模型名称")
    api_key: str = Field(..., min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    max_tokens: Optional[int] = Field(4096, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")

    is_active: bool = Field(True, description="是否激活")
    is_default: bool = Field(False, description="是否为默认配置")
    is_embedding: bool = Field(False, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigCreate(LLMConfigBase):
    """Schema for creating an LLM configuration.

    Extends :class:`LLMConfigBase` with validation of the provider
    identifier and a basic sanity check on the API key.
    """

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: str) -> str:
        """Normalize *v* to lowercase and reject unknown providers."""
        # Closed set of supported providers; the same list feeds the
        # user-facing error message, so entries must be unique.
        allowed_providers = [
            'openai', 'azure', 'anthropic', 'google', 'baidu',
            'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
            'ollama', 'custom', 'doubao',  # fixed: 'ollama' was listed twice
        ]
        if v.lower() not in allowed_providers:
            raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
        return v.lower()

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Strip surrounding whitespace and require a minimum length."""
        if len(v.strip()) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigUpdate(BaseModel):
    """Schema for partially updating an LLM configuration.

    Every field is optional; only supplied values are validated and
    applied.
    """
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="配置名称")
    provider: Optional[str] = Field(None, min_length=1, max_length=50, description="服务商")
    model_name: Optional[str] = Field(None, min_length=1, max_length=100, description="模型名称")
    api_key: Optional[str] = Field(None, min_length=1, description="API密钥")
    base_url: Optional[str] = Field(None, description="API基础URL")
    max_tokens: Optional[int] = Field(None, ge=1, le=32000, description="最大令牌数")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="温度参数")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Top-p参数")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="存在惩罚")
    description: Optional[str] = Field(None, max_length=500, description="配置描述")

    is_active: Optional[bool] = Field(None, description="是否激活")
    is_default: Optional[bool] = Field(None, description="是否为默认配置")
    is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
    extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")

    @field_validator('provider')
    @classmethod
    def validate_provider(cls, v: Optional[str]) -> Optional[str]:
        """Normalize the provider to lowercase; ``None`` passes through."""
        if v is not None:
            # Same provider set as LLMConfigCreate; keep the two in sync.
            allowed_providers = [
                'openai', 'azure', 'anthropic', 'google', 'baidu',
                'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
                'ollama', 'custom', 'doubao',  # fixed: 'ollama' was listed twice
            ]
            if v.lower() not in allowed_providers:
                raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
            return v.lower()
        return v

    @field_validator('api_key')
    @classmethod
    def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
        """Strip and length-check the key when provided; ``None`` passes through."""
        if v is not None and len(v.strip()) < 10:
            raise ValueError('API密钥长度不能少于10个字符')
        return v.strip() if v else v
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigResponse(BaseModel):
    """LLM configuration response schema.

    ``api_key`` carries the full key only when the service layer opts
    in (include_sensitive=True); ``api_key_masked`` always redacts.
    """
    id: int
    name: str
    provider: str
    model_name: str
    api_key: Optional[str] = None  # full API key (only returned when include_sensitive=True)
    base_url: Optional[str] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    description: Optional[str] = None

    is_active: bool
    is_default: bool
    is_embedding: bool
    extra_config: Optional[Dict[str, Any]] = None
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    model_config = {
        'from_attributes': True
    }

    @computed_field
    @property
    def api_key_masked(self) -> Optional[str]:
        """Masked view of the API key (first 4 + last 4 characters kept)."""
        # Hide the API key in responses; show only the leading/trailing 4 chars.
        if self.api_key:
            key = self.api_key
            if len(key) > 8:
                return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}"
            else:
                # Short keys are fully masked to avoid leaking most of the key.
                return '*' * len(key)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigTest(BaseModel):
    """Request schema for test-driving an LLM configuration."""
    message: Optional[str] = Field(
        "Hello, this is a test message.",
        max_length=1000,
        description="测试消息"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigClientResponse(BaseModel):
    """Trimmed LLM config response for the frontend (no API key fields)."""
    id: int
    name: str
    provider: str
    model_name: str
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    is_active: bool
    description: Optional[str] = None

    model_config = {
        'from_attributes': True
    }
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Role Pydantic schemas."""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class RoleBase(BaseModel):
    """Shared role fields for create/update/response schemas."""
    name: str = Field(..., min_length=1, max_length=100, description="角色名称")
    code: str = Field(..., min_length=1, max_length=50, description="角色代码")
    description: Optional[str] = Field(None, max_length=500, description="角色描述")
    sort_order: Optional[int] = Field(0, ge=0, description="排序")
    is_active: bool = Field(True, description="是否激活")
|
||||||
|
|
||||||
|
class RoleCreate(RoleBase):
    """Schema used when creating a new role."""

    @field_validator('code')
    @classmethod
    def validate_code(cls, v: str) -> str:
        """Allow only letters, digits, '_' and '-'; normalize to uppercase."""
        cleaned = v.replace('_', '').replace('-', '')
        if not cleaned.isalnum():
            raise ValueError('角色代码只能包含字母、数字、下划线和连字符')
        return v.upper()
|
||||||
|
|
||||||
|
class RoleUpdate(BaseModel):
    """Schema for partially updating a role; every field is optional."""
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="角色名称")
    code: Optional[str] = Field(None, min_length=1, max_length=50, description="角色代码")
    description: Optional[str] = Field(None, max_length=500, description="角色描述")
    sort_order: Optional[int] = Field(None, ge=0, description="排序")
    is_active: Optional[bool] = Field(None, description="是否激活")

    @field_validator('code')
    @classmethod
    def validate_code(cls, v: Optional[str]) -> Optional[str]:
        """Validate charset and uppercase the code; ``None`` passes through."""
        if v is None:
            return v
        cleaned = v.replace('_', '').replace('-', '')
        if not cleaned.isalnum():
            raise ValueError('角色代码只能包含字母、数字、下划线和连字符')
        return v.upper()
|
||||||
|
|
||||||
|
|
||||||
|
class RoleResponse(RoleBase):
    """Role response schema (adds persistence and audit fields)."""
    id: int
    created_at: datetime
    updated_at: Optional[datetime] = None
    created_by: Optional[int] = None
    updated_by: Optional[int] = None

    # Associated info
    user_count: Optional[int] = 0  # number of users holding this role

    model_config = {
        "from_attributes": True
    }
|
||||||
|
|
||||||
|
|
||||||
|
class UserRoleAssign(BaseModel):
    """Payload for assigning a set of roles to one user."""
    user_id: int = Field(..., description="用户ID")
    role_ids: List[int] = Field(..., description="角色ID列表")

    @field_validator('role_ids')
    @classmethod
    def validate_role_ids(cls, v: List[int]) -> List[int]:
        """Reject an empty list and duplicate role ids."""
        if not v:
            raise ValueError('角色ID列表不能为空')
        unique_ids = set(v)
        if len(unique_ids) != len(v):
            raise ValueError('角色ID列表不能包含重复项')
        return v
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""User schemas."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from utils.util_schemas import BaseResponse
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
    """Shared user fields for create/update/response schemas."""
    username: str = Field(..., min_length=3, max_length=50)
    email: str = Field(..., max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
    """User creation schema (adds the plaintext password)."""
    password: str = Field(..., min_length=6)
|
||||||
|
|
||||||
|
class UserUpdate(BaseModel):
    """User update schema; every field is optional (partial update)."""
    username: Optional[str] = Field(None, min_length=3, max_length=50)
    email: Optional[str] = Field(None, max_length=100)
    full_name: Optional[str] = Field(None, max_length=100)
    bio: Optional[str] = None
    avatar_url: Optional[str] = None
    password: Optional[str] = Field(None, min_length=6)
    is_active: Optional[bool] = None
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
    """Change password request schema (self-service flow)."""
    current_password: str = Field(..., description="Current password")
    new_password: str = Field(..., min_length=6, description="New password")
|
||||||
|
|
||||||
|
class ResetPasswordRequest(BaseModel):
    """Admin reset password request schema (no current password needed)."""
    new_password: str = Field(..., min_length=6, description="New password")
|
||||||
|
|
||||||
|
class UserResponse(BaseResponse, UserBase):
    """User response schema."""
    is_active: bool
    is_superuser: Optional[bool] = Field(default=False, description="是否为超级管理员")

    model_config = {
        'from_attributes': True
    }

    @classmethod
    def model_validate(cls, obj, *, from_attributes=False):
        """Build the response from an ORM object, resolving is_superuser to a bool.

        NOTE(review): this override narrows pydantic's ``model_validate``
        signature (no ``strict``/``context`` kwargs) and ignores
        ``from_attributes`` whenever *obj* has a ``__dict__`` — confirm
        callers only use this reduced form.
        """
        if hasattr(obj, '__dict__'):
            data = obj.__dict__.copy()
            # Resolve is_superuser to a plain bool.
            if hasattr(obj, 'is_admin'):
                # Prefer the sync is_admin property over the async is_superuser method.
                data['is_superuser'] = obj.is_admin
            elif hasattr(obj, 'is_superuser') and not callable(obj.is_superuser):
                # is_superuser is a plain attribute rather than a method.
                data['is_superuser'] = obj.is_superuser
            return super().model_validate(data)
        return super().model_validate(obj, from_attributes=from_attributes)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponse(BaseModel):
    """Login response: bearer token details plus the authenticated user."""
    access_token: str
    token_type: str
    expires_in: int  # token lifetime; presumably seconds — confirm with the auth service
    user: UserResponse
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue