jcq #1
|
|
@ -1,6 +1,7 @@
|
|||
# th - 企业级智能体应用平台
|
||||
|
||||
|
||||
|
||||
🚀 **完全开源的大模型应用平台**
|
||||
- 集成智能问答、智能问数、知识库、工作流和智能体编排的大模型解决方案。
|
||||
- 采用 Vue.js + FastAPI + PostgreSQL + Langchain/LangGraph 架构。
|
||||
|
|
@ -461,4 +462,4 @@ http://113.240.110.92:81/
|
|||
|
||||
|
||||
|
||||
**如果这个项目对你有帮助,请给它一个 ⭐️!**
|
||||
**如果这个项目对你有帮助,请给它一个 ⭐️!**
|
||||
|
|
|
|||
|
|
@ -0,0 +1,179 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control. See https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
webIOs/output/logs/
|
||||
|
||||
# OS generated files
|
||||
Thumbs.db
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
|
||||
# FastAPI specific
|
||||
*.pyc
|
||||
uvicorn*.log
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
Generic single-database configuration with an async dbapi.
|
||||
|
||||
alembic revision --autogenerate -m "init"
|
||||
alembic upgrade head
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import asyncio, os
from logging.config import fileConfig

from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Load environment variables from a local .env file and require
# DATABASE_URL; the migration engine below is configured from it.
from dotenv import load_dotenv
load_dotenv()
database_url = os.getenv("DATABASE_URL")
if not database_url:
    # Fail fast: running migrations without a database URL is meaningless.
    # (Message is user-facing Chinese: "environment variable DATABASE_URL is not set".)
    raise ValueError("环境变量DATABASE_URL未设置")

# Inject the URL into the [alembic] config section so that
# async_engine_from_config() can build the engine from it later.
config.set_main_option("sqlalchemy.url", database_url)



# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from th_agenter.db import Base
# Importing model modules registers their tables on Base.metadata; the
# commented import below lists the models this project autogenerates from.
# from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Emit migration SQL in 'offline' mode.

    Only a database URL is given to the context -- no Engine is created,
    so no DBAPI needs to be installed. Calls to context.execute() write
    the generated statements to the script output instead of a database.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on *connection* and run migrations.

    This is the synchronous core that run_async_migrations() bridges to
    via AsyncConnection.run_sync().
    """
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
    )

    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
    """Create an async Engine from the ini config and run migrations.

    The engine is built from the [alembic] section (sqlalchemy.url was
    injected from DATABASE_URL at import time). The actual migration
    work happens in do_run_migrations(), bridged onto the async
    connection with run_sync().
    """

    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        # Dispose the engine even when a migration raises, so pooled
        # connections are released (previously dispose() was skipped
        # on failure and the connections leaked until process exit).
        await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Delegates to the async implementation; asyncio.run() creates and
    tears down the event loop around it.
    """

    asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# Entry point: Alembic imports this module and runs one of the two modes
# depending on whether --sql (offline) was requested on the command line.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
"""init
|
||||
|
||||
Revision ID: 1ea5548d641d
|
||||
Revises:
|
||||
Create Date: 2025-12-13 13:47:07.838600
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1ea5548d641d'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema: create all initial tables, indexes and constraints.

    Auto-generated by Alembic from the project's SQLAlchemy models.
    Tables are created parents-first so foreign keys resolve
    (users/roles before user_roles, workflows before workflow_executions,
    workflow_executions before node_executions).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # -- agent_configs: stored agent tool/LLM configurations --
    op.create_table('agent_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('enabled_tools', sa.JSON(), nullable=False),
    sa.Column('max_iterations', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('system_message', sa.Text(), nullable=True),
    sa.Column('verbose', sa.Boolean(), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_agent_configs'))
    )
    op.create_index(op.f('ix_agent_configs_id'), 'agent_configs', ['id'], unique=False)
    op.create_index(op.f('ix_agent_configs_name'), 'agent_configs', ['name'], unique=False)
    # -- conversations: chat sessions owned by a user --
    op.create_table('conversations',
    sa.Column('title', sa.String(length=200), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('knowledge_base_id', sa.Integer(), nullable=True),
    sa.Column('system_prompt', sa.Text(), nullable=True),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('temperature', sa.String(length=10), nullable=False),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('is_archived', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_conversations'))
    )
    op.create_index(op.f('ix_conversations_id'), 'conversations', ['id'], unique=False)
    # -- database_configs: external database connection settings --
    op.create_table('database_configs',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('db_type', sa.String(length=20), nullable=False),
    sa.Column('host', sa.String(length=255), nullable=False),
    sa.Column('port', sa.Integer(), nullable=False),
    sa.Column('database', sa.String(length=100), nullable=False),
    sa.Column('username', sa.String(length=100), nullable=False),
    sa.Column('password', sa.Text(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('connection_params', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_database_configs')),
    # NOTE(review): unique on db_type allows only one config per database
    # kind -- confirm this restriction is intentional.
    sa.UniqueConstraint('db_type', name=op.f('uq_database_configs_db_type'))
    )
    op.create_index(op.f('ix_database_configs_id'), 'database_configs', ['id'], unique=False)
    # -- documents: uploaded files attached to a knowledge base --
    op.create_table('documents',
    sa.Column('knowledge_base_id', sa.Integer(), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('mime_type', sa.String(length=100), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('content', sa.Text(), nullable=True),
    sa.Column('doc_metadata', sa.JSON(), nullable=True),
    sa.Column('chunk_count', sa.Integer(), nullable=False),
    sa.Column('embedding_model', sa.String(length=100), nullable=True),
    sa.Column('vector_ids', sa.JSON(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_documents'))
    )
    op.create_index(op.f('ix_documents_id'), 'documents', ['id'], unique=False)
    # -- excel_files: uploaded spreadsheets for data-QA features --
    op.create_table('excel_files',
    sa.Column('original_filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('file_size', sa.Integer(), nullable=False),
    sa.Column('file_type', sa.String(length=50), nullable=False),
    sa.Column('sheet_names', sa.JSON(), nullable=False),
    sa.Column('default_sheet', sa.String(length=100), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('preview_data', sa.JSON(), nullable=False),
    sa.Column('data_types', sa.JSON(), nullable=True),
    sa.Column('total_rows', sa.JSON(), nullable=True),
    sa.Column('total_columns', sa.JSON(), nullable=True),
    sa.Column('is_processed', sa.Boolean(), nullable=False),
    sa.Column('processing_error', sa.Text(), nullable=True),
    sa.Column('last_accessed', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_excel_files'))
    )
    op.create_index(op.f('ix_excel_files_id'), 'excel_files', ['id'], unique=False)
    # -- knowledge_bases: vector-store backed document collections --
    op.create_table('knowledge_bases',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('embedding_model', sa.String(length=100), nullable=False),
    sa.Column('chunk_size', sa.Integer(), nullable=False),
    sa.Column('chunk_overlap', sa.Integer(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('vector_db_type', sa.String(length=50), nullable=False),
    sa.Column('collection_name', sa.String(length=100), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_knowledge_bases'))
    )
    op.create_index(op.f('ix_knowledge_bases_id'), 'knowledge_bases', ['id'], unique=False)
    op.create_index(op.f('ix_knowledge_bases_name'), 'knowledge_bases', ['name'], unique=False)
    # -- llm_configs: per-provider LLM credentials and sampling params --
    op.create_table('llm_configs',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('provider', sa.String(length=50), nullable=False),
    sa.Column('model_name', sa.String(length=100), nullable=False),
    sa.Column('api_key', sa.String(length=500), nullable=False),
    sa.Column('base_url', sa.String(length=200), nullable=True),
    sa.Column('max_tokens', sa.Integer(), nullable=False),
    sa.Column('temperature', sa.Float(), nullable=False),
    sa.Column('top_p', sa.Float(), nullable=False),
    sa.Column('frequency_penalty', sa.Float(), nullable=False),
    sa.Column('presence_penalty', sa.Float(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.Column('is_embedding', sa.Boolean(), nullable=False),
    sa.Column('extra_config', sa.JSON(), nullable=True),
    sa.Column('usage_count', sa.Integer(), nullable=False),
    sa.Column('last_used_at', sa.DateTime(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_llm_configs'))
    )
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)
    # -- messages: individual chat turns; role/type are DB enums --
    op.create_table('messages',
    sa.Column('conversation_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('USER', 'ASSISTANT', 'SYSTEM', name='messagerole'), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('message_type', sa.Enum('TEXT', 'IMAGE', 'FILE', 'AUDIO', name='messagetype'), nullable=False),
    sa.Column('message_metadata', sa.JSON(), nullable=True),
    sa.Column('context_documents', sa.JSON(), nullable=True),
    sa.Column('prompt_tokens', sa.Integer(), nullable=True),
    sa.Column('completion_tokens', sa.Integer(), nullable=True),
    sa.Column('total_tokens', sa.Integer(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_messages'))
    )
    op.create_index(op.f('ix_messages_id'), 'messages', ['id'], unique=False)
    # -- roles: RBAC roles; name and code are unique --
    op.create_table('roles',
    sa.Column('name', sa.String(length=100), nullable=False),
    sa.Column('code', sa.String(length=100), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('is_system', sa.Boolean(), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_roles'))
    )
    op.create_index(op.f('ix_roles_code'), 'roles', ['code'], unique=True)
    op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=True)
    # -- table_metadata: introspected schema info for text-to-SQL QA --
    op.create_table('table_metadata',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('table_name', sa.String(length=100), nullable=False),
    sa.Column('table_schema', sa.String(length=50), nullable=False),
    sa.Column('table_type', sa.String(length=20), nullable=False),
    sa.Column('table_comment', sa.Text(), nullable=True),
    sa.Column('database_config_id', sa.Integer(), nullable=True),
    sa.Column('columns_info', sa.JSON(), nullable=False),
    sa.Column('primary_keys', sa.JSON(), nullable=True),
    sa.Column('foreign_keys', sa.JSON(), nullable=True),
    sa.Column('indexes', sa.JSON(), nullable=True),
    sa.Column('sample_data', sa.JSON(), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('is_enabled_for_qa', sa.Boolean(), nullable=False),
    sa.Column('qa_description', sa.Text(), nullable=True),
    sa.Column('business_context', sa.Text(), nullable=True),
    sa.Column('last_synced_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_table_metadata'))
    )
    op.create_index(op.f('ix_table_metadata_id'), 'table_metadata', ['id'], unique=False)
    op.create_index(op.f('ix_table_metadata_table_name'), 'table_metadata', ['table_name'], unique=False)
    # -- users: accounts; username and email are unique --
    op.create_table('users',
    sa.Column('username', sa.String(length=50), nullable=False),
    sa.Column('email', sa.String(length=100), nullable=False),
    sa.Column('hashed_password', sa.String(length=255), nullable=False),
    sa.Column('full_name', sa.String(length=100), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('avatar_url', sa.String(length=255), nullable=True),
    sa.Column('bio', sa.Text(), nullable=True),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_users'))
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    # -- user_roles: many-to-many link table (composite PK, FKs) --
    op.create_table('user_roles',
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('role_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['role_id'], ['roles.id'], name=op.f('fk_user_roles_role_id_roles')),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_user_roles_user_id_users')),
    sa.PrimaryKeyConstraint('user_id', 'role_id', name=op.f('pk_user_roles'))
    )
    # -- workflows: workflow definitions; column comments are Chinese
    #    model metadata and must stay as-is --
    op.create_table('workflows',
    sa.Column('name', sa.String(length=100), nullable=False, comment='工作流名称'),
    sa.Column('description', sa.Text(), nullable=True, comment='工作流描述'),
    sa.Column('status', sa.Enum('DRAFT', 'PUBLISHED', 'ARCHIVED', name='workflowstatus'), nullable=False, comment='工作流状态'),
    sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'),
    sa.Column('definition', sa.JSON(), nullable=False, comment='工作流定义'),
    sa.Column('version', sa.String(length=20), nullable=False, comment='版本号'),
    sa.Column('owner_id', sa.Integer(), nullable=False, comment='所有者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name=op.f('fk_workflows_owner_id_users')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflows'))
    )
    op.create_index(op.f('ix_workflows_id'), 'workflows', ['id'], unique=False)
    # -- workflow_executions: one row per workflow run --
    op.create_table('workflow_executions',
    sa.Column('workflow_id', sa.Integer(), nullable=False, comment='工作流ID'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('executor_id', sa.Integer(), nullable=False, comment='执行者ID'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['executor_id'], ['users.id'], name=op.f('fk_workflow_executions_executor_id_users')),
    sa.ForeignKeyConstraint(['workflow_id'], ['workflows.id'], name=op.f('fk_workflow_executions_workflow_id_workflows')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_workflow_executions'))
    )
    op.create_index(op.f('ix_workflow_executions_id'), 'workflow_executions', ['id'], unique=False)
    # -- node_executions: per-node trace rows within a workflow run --
    op.create_table('node_executions',
    sa.Column('workflow_execution_id', sa.Integer(), nullable=False, comment='工作流执行ID'),
    sa.Column('node_id', sa.String(length=50), nullable=False, comment='节点ID'),
    sa.Column('node_type', sa.Enum('START', 'END', 'LLM', 'CONDITION', 'LOOP', 'CODE', 'HTTP', 'TOOL', name='nodetype'), nullable=False, comment='节点类型'),
    sa.Column('node_name', sa.String(length=100), nullable=False, comment='节点名称'),
    sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', 'CANCELLED', name='executionstatus'), nullable=False, comment='执行状态'),
    sa.Column('input_data', sa.JSON(), nullable=True, comment='输入数据'),
    sa.Column('output_data', sa.JSON(), nullable=True, comment='输出数据'),
    sa.Column('started_at', sa.String(length=50), nullable=True, comment='开始时间'),
    sa.Column('completed_at', sa.String(length=50), nullable=True, comment='完成时间'),
    sa.Column('duration_ms', sa.Integer(), nullable=True, comment='执行时长(毫秒)'),
    sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(), nullable=False),
    sa.Column('updated_at', sa.DateTime(), nullable=False),
    sa.Column('created_by', sa.Integer(), nullable=True),
    sa.Column('updated_by', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['workflow_execution_id'], ['workflow_executions.id'], name=op.f('fk_node_executions_workflow_execution_id_workflow_executions')),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_node_executions'))
    )
    op.create_index(op.f('ix_node_executions_id'), 'node_executions', ['id'], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_node_executions_id'), table_name='node_executions')
|
||||
op.drop_table('node_executions')
|
||||
op.drop_index(op.f('ix_workflow_executions_id'), table_name='workflow_executions')
|
||||
op.drop_table('workflow_executions')
|
||||
op.drop_index(op.f('ix_workflows_id'), table_name='workflows')
|
||||
op.drop_table('workflows')
|
||||
op.drop_table('user_roles')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_id'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f('ix_table_metadata_table_name'), table_name='table_metadata')
|
||||
op.drop_index(op.f('ix_table_metadata_id'), table_name='table_metadata')
|
||||
op.drop_table('table_metadata')
|
||||
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_code'), table_name='roles')
|
||||
op.drop_table('roles')
|
||||
op.drop_index(op.f('ix_messages_id'), table_name='messages')
|
||||
op.drop_table('messages')
|
||||
op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
|
||||
op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
|
||||
op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
|
||||
op.drop_table('llm_configs')
|
||||
op.drop_index(op.f('ix_knowledge_bases_name'), table_name='knowledge_bases')
|
||||
op.drop_index(op.f('ix_knowledge_bases_id'), table_name='knowledge_bases')
|
||||
op.drop_table('knowledge_bases')
|
||||
op.drop_index(op.f('ix_excel_files_id'), table_name='excel_files')
|
||||
op.drop_table('excel_files')
|
||||
op.drop_index(op.f('ix_documents_id'), table_name='documents')
|
||||
op.drop_table('documents')
|
||||
op.drop_index(op.f('ix_database_configs_id'), table_name='database_configs')
|
||||
op.drop_table('database_configs')
|
||||
op.drop_index(op.f('ix_conversations_id'), table_name='conversations')
|
||||
op.drop_table('conversations')
|
||||
op.drop_index(op.f('ix_agent_configs_name'), table_name='agent_configs')
|
||||
op.drop_index(op.f('ix_agent_configs_id'), table_name='agent_configs')
|
||||
op.drop_table('agent_configs')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
import re
|
||||
import traceback
|
||||
from loguru import logger
|
||||
|
||||
class DrGraphSession:
|
||||
def __init__(self, stepIndex: int, msg: str, session_id: str):
|
||||
logger.info(f"DrGraphSession.__init__: stepIndex={stepIndex}, msg={msg}, session_id={session_id}")
|
||||
self.stepIndex = stepIndex
|
||||
self.session_id = session_id
|
||||
|
||||
match = re.search(r";(-\d+)", msg);
|
||||
level = -3
|
||||
if match:
|
||||
level = int(match.group(1))
|
||||
value = value.replace(f";{level}", "")
|
||||
level = -3 + level
|
||||
|
||||
if "警告" in value or value.startswith("WARNING"):
|
||||
self.log_warning(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "异常" in value or value.startswith("EXCEPTION"):
|
||||
self.log_exception(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "成功" in value or value.startswith("SUCCESS"):
|
||||
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "开始" in value or value.startswith("START"):
|
||||
self.log_success(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
elif "失败" in value or value.startswith("ERROR"):
|
||||
self.log_error(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
else:
|
||||
self.log_info(f"第 {self.stepIndex} 步 - {value}", level = level)
|
||||
|
||||
def log_prefix(self) -> str:
|
||||
"""Get log prefix with session ID and desc."""
|
||||
return f"〖Session{self.session_id}〗"
|
||||
|
||||
def parse_source_pos(self, level: int):
|
||||
pos = (traceback.format_stack())[level - 1].strip().split('\n')[0]
|
||||
match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos);
|
||||
if match:
|
||||
file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
|
||||
pos = f"{file}:{match.group(2)} in {match.group(3)}"
|
||||
return pos
|
||||
|
||||
def log_info(self, msg: str, level: int = -2):
|
||||
"""Log info message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_success(self, msg: str, level: int = -2):
|
||||
"""Log success message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_warning(self, msg: str, level: int = -2):
|
||||
"""Log warning message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_error(self, msg: str, level: int = -2):
|
||||
"""Log error message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_exception(self, msg: str, level: int = -2):
|
||||
"""Log exception message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import jwt
|
||||
import inspect
|
||||
|
||||
print(f"jwt module path: {inspect.getfile(jwt)}")
|
||||
print(f"jwt module attributes: {dir(jwt)}")
|
||||
try:
|
||||
print(f"jwt module __version__: {jwt.__version__}")
|
||||
except AttributeError:
|
||||
print("jwt module has no __version__ attribute")
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
services:
|
||||
db:
|
||||
image: pgvector/pgvector:pg16
|
||||
container_name: pgvector-db
|
||||
environment:
|
||||
POSTGRES_USER: drgraph
|
||||
POSTGRES_PASSWORD: yingping
|
||||
POSTGRES_DB: th_agenter
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
# docker exec -it pgvector-db psql -U drgraph -d th_agenter
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
# uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
from fastapi import FastAPI
|
||||
from os.path import dirname, realpath
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from utils.util_log import init_logger
|
||||
from loguru import logger
|
||||
base_dir: str = dirname(realpath(__file__))
|
||||
init_logger(base_dir)
|
||||
|
||||
from th_agenter.api.routes import router
|
||||
from contextlib import asynccontextmanager
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
logger.info("[生命周期] - Starting up TH Agenter application...")
|
||||
yield
|
||||
# Shutdown
|
||||
logger.info("[生命周期] - Shutting down TH Agenter application...")
|
||||
|
||||
def setup_exception_handlers(app: FastAPI) -> None:
|
||||
"""Setup global exception handlers."""
|
||||
|
||||
# Import custom exceptions and handlers
|
||||
from utils.util_exceptions import ChatAgentException, chat_agent_exception_handler
|
||||
|
||||
@app.exception_handler(ChatAgentException)
|
||||
async def custom_chat_agent_exception_handler(request, exc):
|
||||
return await chat_agent_exception_handler(request, exc)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request, exc):
|
||||
from utils.util_exceptions import HxfErrorResponse
|
||||
logger.exception(f"HTTP Exception: {exc.status_code} - {exc.detail} - {request.method} {request.url}")
|
||||
return HxfErrorResponse(exc.status_code, exc.detail)
|
||||
|
||||
def make_json_serializable(obj):
|
||||
"""递归地将对象转换为JSON可序列化的格式"""
|
||||
if obj is None or isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
elif isinstance(obj, bytes):
|
||||
return obj.decode('utf-8')
|
||||
elif isinstance(obj, (ValueError, Exception)):
|
||||
return str(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: make_json_serializable(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [make_json_serializable(item) for item in obj]
|
||||
else:
|
||||
# For any other object, convert to string
|
||||
return str(obj)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request, exc):
|
||||
# Convert any non-serializable objects to strings in error details
|
||||
try:
|
||||
errors = make_json_serializable(exc.errors())
|
||||
except Exception as e:
|
||||
# Fallback: if even our conversion fails, use a simple error message
|
||||
errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]
|
||||
logger.exception(f"Request Validation Error: {errors}")
|
||||
|
||||
logger.exception(f"validation_error: {errors}")
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"type": "validation_error",
|
||||
"message": "Request validation failed",
|
||||
"details": errors
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request, exc):
|
||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"type": "internal_error",
|
||||
"message": "Internal server error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure FastAPI application."""
|
||||
from th_agenter.core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由TH Agenter修改",
|
||||
debug=settings.debug,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
app.mount("/static", StaticFiles(directory="static"), name="th_agenter_static")
|
||||
|
||||
# Add middleware
|
||||
from th_agenter.core.app import setup_middleware
|
||||
setup_middleware(app, settings)
|
||||
|
||||
# # Add exception handlers
|
||||
setup_exception_handlers(app)
|
||||
add_router(app)
|
||||
|
||||
return app
|
||||
|
||||
def add_router(app: FastAPI) -> None:
|
||||
"""Add default routers to the FastAPI application."""
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
logger.info("Hello World")
|
||||
return {"Hello": "World"}
|
||||
|
||||
# Include routers
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
|
||||
|
||||
# app.include_router(table_metadata.router)
|
||||
# # 在现有导入中添加
|
||||
# from ..api.endpoints import database_config
|
||||
|
||||
# # 在路由注册部分添加
|
||||
# app.include_router(database_config.router)
|
||||
# # Health check endpoint
|
||||
# @app.get("/health")
|
||||
# async def health_check():
|
||||
# return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# # Root endpoint
|
||||
# @app.get("/")
|
||||
# async def root():
|
||||
# return {"message": "Chat Agent API is running"}
|
||||
|
||||
# # Test endpoint
|
||||
# @app.get("/test")
|
||||
# async def test_endpoint():
|
||||
# return {"message": "API is working"}
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
# from utils.util_test import test_db
|
||||
# test_db()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.9 KiB |
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
|
|
@ -1 +1 @@
|
|||
"""API module for TH-Agenter."""
|
||||
"""API module for TH Agenter."""
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -1 +1 @@
|
|||
"""API endpoints for TH-Agenter."""
|
||||
"""API endpoints for TH Agenter."""
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,84 +1,87 @@
|
|||
"""Authentication endpoints."""
|
||||
|
||||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...core.config import get_settings
|
||||
from ...db.database import get_db
|
||||
from ...db.database import DrSession, get_session
|
||||
from ...services.auth import AuthService
|
||||
from ...services.user import UserService
|
||||
from ...schemas.user import UserResponse, UserCreate
|
||||
from ...utils.schemas import Token, LoginRequest
|
||||
from ...schemas.user import UserResponse, UserCreate, LoginResponse
|
||||
from utils.util_schemas import Token, LoginRequest
|
||||
from loguru import logger
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse)
|
||||
@router.post("/register", response_model=UserResponse, summary="注册新用户")
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
request_user_data: UserCreate,
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""Register a new user."""
|
||||
user_service = UserService(db)
|
||||
|
||||
# Check if user already exists
|
||||
if user_service.get_user_by_email(user_data.email):
|
||||
"""注册新用户"""
|
||||
user_service = UserService(session)
|
||||
session.desc = f"START: 注册用户 {request_user_data.email}"
|
||||
if await user_service.get_user_by_email(request_user_data.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
detail=f"邮箱 {request_user_data.email} 已被注册,请使用其他邮箱注册!!!"
|
||||
)
|
||||
|
||||
if user_service.get_user_by_username(user_data.username):
|
||||
if await user_service.get_user_by_username(request_user_data.username):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken"
|
||||
detail=f"用户名 {request_user_data.username} 已被注册,请使用其他用户名注册!!!"
|
||||
)
|
||||
|
||||
# Create user
|
||||
user = user_service.create_user(user_data)
|
||||
return UserResponse.from_orm(user)
|
||||
user = await user_service.create_user(request_user_data)
|
||||
response = UserResponse.model_validate(user, from_attributes=True)
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
@router.post("/login", response_model=LoginResponse, summary="邮箱与密码登录")
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""Login with email and password."""
|
||||
"""邮箱与密码登录"""
|
||||
# Authenticate user by email
|
||||
user = AuthService.authenticate_user_by_email(db, login_data.email, login_data.password)
|
||||
session.desc = f"START: 用户 {login_data.email} 尝试登录"
|
||||
user = await AuthService.authenticate_user_by_email(session, login_data.email, login_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
detail=f"邮箱 {login_data.email} 或密码错误,请检查后重试!!!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = AuthService.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
session.desc = f"用户 {user.username} 登录成功"
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
response = LoginResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.security.access_token_expire_minutes * 60,
|
||||
user=UserResponse.model_validate(user, from_attributes=True)
|
||||
)
|
||||
return HxfResponse(response)
|
||||
|
||||
|
||||
@router.post("/login-oauth", response_model=Token)
|
||||
@router.post("/login-oauth", response_model=Token, summary="用户通过用户名和密码登录 (OAuth2 兼容)")
|
||||
async def login_oauth(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""Login with username and password (OAuth2 compatible)."""
|
||||
# Authenticate user
|
||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
"""用户通过用户名和密码登录 (OAuth2 兼容)"""
|
||||
session.desc = f"START: 用户 {form_data.username} 尝试 OAuth2 登录"
|
||||
user = await AuthService.authenticate_user(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
session.desc = f"用户 {form_data.username} 尝试 OAuth2 登录失败"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
|
|
@ -87,9 +90,10 @@ async def login_oauth(
|
|||
|
||||
# Create access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = AuthService.create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
session.desc = f"用户 {user.username} OAuth2 登录成功"
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
|
|
@ -97,29 +101,27 @@ async def login_oauth(
|
|||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
@router.post("/refresh", response_model=Token, summary="刷新访问token")
|
||||
async def refresh_token(
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: DrSession = Depends(get_session)
|
||||
):
|
||||
"""Refresh access token."""
|
||||
"""刷新访问 token"""
|
||||
# Create new access token
|
||||
access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
|
||||
access_token = AuthService.create_access_token(
|
||||
data={"sub": current_user.username}, expires_delta=access_token_expires
|
||||
access_token = await AuthService.create_access_token(
|
||||
session, data={"sub": current_user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": settings.security.access_token_expire_minutes * 60
|
||||
}
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.security.access_token_expire_minutes * 60
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
|
||||
async def get_current_user_info(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""Get current user information."""
|
||||
return UserResponse.from_orm(current_user)
|
||||
"""获取当前用户信息"""
|
||||
return UserResponse.model_validate(current_user, from_attributes=True)
|
||||
|
|
@ -1,16 +1,17 @@
|
|||
"""Chat endpoints."""
|
||||
"""Chat endpoints for TH Agenter."""
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from ...db.database import get_db
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...services.auth import AuthService
|
||||
from ...services.chat import ChatService
|
||||
from ...services.conversation import ConversationService
|
||||
from ...utils.schemas import (
|
||||
from utils.util_schemas import (
|
||||
ConversationCreate,
|
||||
ConversationResponse,
|
||||
ConversationUpdate,
|
||||
|
|
@ -22,24 +23,24 @@ from ...utils.schemas import (
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Conversation management
|
||||
@router.post("/conversations", response_model=ConversationResponse)
|
||||
@router.post("/conversations", response_model=ConversationResponse, summary="创建新对话")
|
||||
async def create_conversation(
|
||||
conversation_data: ConversationCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Create a new conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = conversation_service.create_conversation(
|
||||
"""创建新对话"""
|
||||
session.desc = "START: 创建新对话"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.create_conversation(
|
||||
user_id=current_user.id,
|
||||
conversation_data=conversation_data
|
||||
)
|
||||
return ConversationResponse.from_orm(conversation)
|
||||
session.desc = f"SUCCESS: 创建新对话完毕 >>> 当前用户ID: {current_user.id}, conversation: {conversation}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
|
||||
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||
@router.get("/conversations", response_model=List[ConversationResponse], summary="获取用户对话列表")
|
||||
async def list_conversations(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
|
|
@ -47,11 +48,12 @@ async def list_conversations(
|
|||
include_archived: bool = False,
|
||||
order_by: str = "updated_at",
|
||||
order_desc: bool = True,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""List user's conversations with search and filtering."""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
"""获取用户对话列表"""
|
||||
session.desc = "START: 获取用户对话列表"
|
||||
conversation_service = ConversationService(session)
|
||||
conversations = await conversation_service.get_user_conversations(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
search_query=search,
|
||||
|
|
@ -59,137 +61,140 @@ async def list_conversations(
|
|||
order_by=order_by,
|
||||
order_desc=order_desc
|
||||
)
|
||||
return [ConversationResponse.from_orm(conv) for conv in conversations]
|
||||
session.desc = f"SUCCESS: 获取用户对话列表完毕 >>> {len(conversations)} 个对话"
|
||||
return [ConversationResponse.model_validate(conv) for conv in conversations]
|
||||
|
||||
|
||||
@router.get("/conversations/count")
|
||||
@router.get("/conversations/count", summary="获取用户对话总数")
|
||||
async def get_conversations_count(
|
||||
search: str = None,
|
||||
include_archived: bool = False,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Get total count of conversations."""
|
||||
conversation_service = ConversationService(db)
|
||||
count = conversation_service.get_user_conversations_count(
|
||||
"""获取用户对话总数"""
|
||||
session.desc = "START: 获取用户对话总数"
|
||||
conversation_service = ConversationService(session)
|
||||
count = await conversation_service.get_user_conversations_count(
|
||||
search_query=search,
|
||||
include_archived=include_archived
|
||||
)
|
||||
session.desc = f"SUCCESS: 获取用户对话总数完毕 >>> {count} 个对话"
|
||||
return {"count": count}
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse)
|
||||
@router.get("/conversations/{conversation_id}", response_model=ConversationResponse, summary="获取指定对话")
|
||||
async def get_conversation(
|
||||
conversation_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Get a specific conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
conversation = conversation_service.get_conversation(
|
||||
"""获取指定对话"""
|
||||
session.desc = f"START: 获取指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
conversation = await conversation_service.get_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
session.desc = f"ERROR: 获取指定对话失败 >>> conversation_id: {conversation_id}, 未找到该对话"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Conversation not found"
|
||||
)
|
||||
return ConversationResponse.from_orm(conversation)
|
||||
session.desc = f"SUCCESS: 获取指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(conversation)
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse)
|
||||
@router.put("/conversations/{conversation_id}", response_model=ConversationResponse, summary="更新指定对话")
|
||||
async def update_conversation(
|
||||
conversation_id: int,
|
||||
conversation_update: ConversationUpdate,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Update a conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
updated_conversation = conversation_service.update_conversation(
|
||||
"""更新指定对话"""
|
||||
session.desc = f"START: 更新指定对话 >>> conversation_id: {conversation_id}, conversation_update: {conversation_update}"
|
||||
conversation_service = ConversationService(session)
|
||||
updated_conversation = await conversation_service.update_conversation(
|
||||
conversation_id, conversation_update
|
||||
)
|
||||
return ConversationResponse.from_orm(updated_conversation)
|
||||
session.desc = f"SUCCESS: 更新指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return ConversationResponse.model_validate(updated_conversation)
|
||||
|
||||
|
||||
@router.delete("/conversations/{conversation_id}")
|
||||
@router.delete("/conversations/{conversation_id}", summary="删除指定对话")
|
||||
async def delete_conversation(
|
||||
conversation_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Delete a conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
conversation_service.delete_conversation(conversation_id)
|
||||
"""删除指定对话"""
|
||||
session.desc = f"删除指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
await conversation_service.delete_conversation(conversation_id)
|
||||
session.desc = f"SUCCESS: 删除指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation deleted successfully"}
|
||||
|
||||
|
||||
@router.delete("/conversations")
|
||||
async def delete_all_conversations(
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete all conversations."""
|
||||
conversation_service = ConversationService(db)
|
||||
conversation_service.delete_all_conversations()
|
||||
return {"message": "All conversations deleted successfully"}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/archive")
|
||||
@router.put("/conversations/{conversation_id}/archive", summary="归档指定对话")
|
||||
async def archive_conversation(
|
||||
conversation_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Archive a conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
success = conversation_service.archive_conversation(conversation_id)
|
||||
"""归档指定对话."""
|
||||
conversation_service = ConversationService(session)
|
||||
success = await conversation_service.archive_conversation(conversation_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 归档指定对话失败 >>> conversation_id: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to archive conversation"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation archived successfully"}
|
||||
|
||||
|
||||
@router.put("/conversations/{conversation_id}/unarchive")
|
||||
@router.put("/conversations/{conversation_id}/unarchive", summary="取消归档指定对话")
|
||||
async def unarchive_conversation(
|
||||
conversation_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Unarchive a conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
success = conversation_service.unarchive_conversation(conversation_id)
|
||||
"""取消归档指定对话."""
|
||||
session.desc = f"START: 取消归档指定对话 >>> conversation_id: {conversation_id}"
|
||||
conversation_service = ConversationService(session)
|
||||
success = await conversation_service.unarchive_conversation(conversation_id)
|
||||
if not success:
|
||||
session.desc = f"ERROR: 取消归档指定对话失败 >>> conversation_id: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to unarchive conversation"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 取消归档指定对话完毕 >>> conversation_id: {conversation_id}"
|
||||
return {"message": "Conversation unarchived successfully"}
|
||||
|
||||
|
||||
# Message management
|
||||
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse])
|
||||
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse], summary="获取指定对话的消息")
|
||||
async def get_conversation_messages(
|
||||
conversation_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Get messages from a conversation."""
|
||||
conversation_service = ConversationService(db)
|
||||
messages = conversation_service.get_conversation_messages(
|
||||
"""获取指定对话的消息"""
|
||||
session.desc = f"START: 获取指定对话的消息 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
|
||||
conversation_service = ConversationService(session)
|
||||
messages = await conversation_service.get_conversation_messages(
|
||||
conversation_id, skip=skip, limit=limit
|
||||
)
|
||||
return [MessageResponse.from_orm(msg) for msg in messages]
|
||||
|
||||
session.desc = f"SUCCESS: 获取指定对话的消息完毕 >>> conversation_id: {conversation_id}, skip: {skip}, limit: {limit}"
|
||||
return [MessageResponse.model_validate(msg) for msg in messages]
|
||||
|
||||
# Chat functionality
|
||||
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse)
|
||||
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse, summary="发送消息并获取AI响应")
|
||||
async def chat(
|
||||
conversation_id: int,
|
||||
chat_request: ChatRequest,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Send a message and get AI response."""
|
||||
chat_service = ChatService(db)
|
||||
"""发送消息并获取AI响应"""
|
||||
session.desc = f"START: 发送消息并获取AI响应 >>> conversation_id: {conversation_id}"
|
||||
chat_service = ChatService(session)
|
||||
response = await chat_service.chat(
|
||||
conversation_id=conversation_id,
|
||||
message=chat_request.message,
|
||||
|
|
@ -201,18 +206,18 @@ async def chat(
|
|||
use_knowledge_base=chat_request.use_knowledge_base,
|
||||
knowledge_base_id=chat_request.knowledge_base_id
|
||||
)
|
||||
session.desc = f"SUCCESS: 发送消息并获取AI响应完毕 >>> conversation_id: {conversation_id}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/chat/stream")
|
||||
@router.post("/conversations/{conversation_id}/chat/stream", summary="发送消息并获取流式AI响应")
|
||||
async def chat_stream(
|
||||
conversation_id: int,
|
||||
chat_request: ChatRequest,
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Send a message and get streaming AI response."""
|
||||
chat_service = ChatService(db)
|
||||
"""发送消息并获取流式AI响应."""
|
||||
chat_service = ChatService(session)
|
||||
|
||||
async def generate_response():
|
||||
async for chunk in chat_service.chat_stream(
|
||||
|
|
|
|||
|
|
@ -1,21 +1,19 @@
|
|||
"""数据库配置管理API"""
|
||||
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from th_agenter.models.user import User
|
||||
from th_agenter.db.database import get_db
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.database_config_service import DatabaseConfigService
|
||||
from th_agenter.utils.logger import get_logger
|
||||
from th_agenter.services.auth import AuthService
|
||||
logger = get_logger("database_config_api")
|
||||
router = APIRouter(prefix="/api/database-config", tags=["database-config"])
|
||||
from th_agenter.utils.schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||
|
||||
# 在文件顶部添加
|
||||
from functools import lru_cache
|
||||
|
||||
router = APIRouter(prefix="/api/database-config", tags=["database-config"])
|
||||
# 创建服务单例
|
||||
@lru_cache()
|
||||
def get_database_config_service() -> DatabaseConfigService:
|
||||
|
|
@ -26,15 +24,16 @@ def get_database_config_service() -> DatabaseConfigService:
|
|||
# 或者使用全局变量
|
||||
_database_service_instance = None
|
||||
|
||||
def get_database_service(db: Session = Depends(get_db)) -> DatabaseConfigService:
|
||||
def get_database_service(session: Session = Depends(get_session)) -> DatabaseConfigService:
|
||||
"""获取DatabaseConfigService实例"""
|
||||
global _database_service_instance
|
||||
if _database_service_instance is None:
|
||||
_database_service_instance = DatabaseConfigService(db)
|
||||
_database_service_instance = DatabaseConfigService(session)
|
||||
else:
|
||||
# 更新db session
|
||||
_database_service_instance.db = db
|
||||
_database_service_instance.db = session
|
||||
return _database_service_instance
|
||||
|
||||
class DatabaseConfigCreate(BaseModel):
|
||||
name: str = Field(..., description="配置名称")
|
||||
db_type: str = Field(default="postgresql", description="数据库类型")
|
||||
|
|
@ -46,103 +45,68 @@ class DatabaseConfigCreate(BaseModel):
|
|||
is_default: bool = Field(default=False, description="是否为默认配置")
|
||||
connection_params: Dict[str, Any] = Field(default=None, description="额外连接参数")
|
||||
|
||||
|
||||
class DatabaseConfigResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
db_type: str
|
||||
host: str
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
username: str
|
||||
password: str # 添加密码字段
|
||||
password: str
|
||||
is_active: bool
|
||||
is_default: bool
|
||||
created_at: str
|
||||
updated_at: str = None
|
||||
updated_at: str
|
||||
|
||||
|
||||
@router.post("/", response_model=NormalResponse)
|
||||
@router.post("/", response_model=NormalResponse, summary="创建或更新数据库配置")
|
||||
async def create_database_config(
|
||||
config_data: DatabaseConfigCreate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""创建或更新数据库配置"""
|
||||
try:
|
||||
service = DatabaseConfigService(db)
|
||||
config = await service.create_or_update_config(current_user.id, config_data.dict())
|
||||
return NormalResponse(
|
||||
success=True,
|
||||
message="保存数据库配置成功",
|
||||
data=config
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建或更新数据库配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
config = await service.create_or_update_config(current_user.id, config_data.model_dump())
|
||||
return NormalResponse(
|
||||
success=True,
|
||||
message="保存数据库配置成功",
|
||||
data=config
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[DatabaseConfigResponse])
|
||||
@router.get("/", response_model=List[DatabaseConfigResponse], summary="获取用户的数据库配置列表")
|
||||
async def get_database_configs(
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""获取用户的数据库配置列表"""
|
||||
try:
|
||||
service = DatabaseConfigService(db)
|
||||
configs = service.get_user_configs(current_user.id)
|
||||
configs = service.get_user_configs(current_user.id)
|
||||
|
||||
config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
|
||||
return config_list
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
|
||||
return config_list
|
||||
|
||||
|
||||
@router.post("/{config_id}/test")
|
||||
@router.post("/{config_id}/test", response_model=NormalResponse, summary="测试数据库连接")
|
||||
async def test_database_connection(
|
||||
config_id: int,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""测试数据库连接"""
|
||||
try:
|
||||
service = DatabaseConfigService(db)
|
||||
result = await service.test_connection(config_id, current_user.id)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"测试数据库连接失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
result = await service.test_connection(config_id, current_user.id)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{config_id}/connect")
|
||||
@router.post("/{config_id}/connect", response_model=NormalResponse, summary="连接数据库并获取表列表")
|
||||
async def connect_database(
|
||||
config_id: int,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""连接数据库并获取表列表"""
|
||||
try:
|
||||
result = await service.connect_and_get_tables(config_id, current_user.id)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"连接数据库失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
result = await service.connect_and_get_tables(config_id, current_user.id)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/tables/{table_name}/data")
|
||||
@router.get("/tables/{table_name}/data", summary="获取表数据预览")
|
||||
async def get_table_data(
|
||||
table_name: str,
|
||||
db_type: str,
|
||||
|
|
@ -161,47 +125,28 @@ async def get_table_data(
|
|||
detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tables/{table_name}/schema")
|
||||
@router.get("/tables/{table_name}/schema", summary="获取表结构信息")
|
||||
async def get_table_schema(
|
||||
table_name: str,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""获取表结构信息"""
|
||||
try:
|
||||
service = DatabaseConfigService(db)
|
||||
result = await service.describe_table(table_name, current_user.id)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取表结构失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
result = await service.describe_table(table_name, current_user.id) # 这在哪里实现的?
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse)
|
||||
@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse, summary="根据数据库类型获取配置")
|
||||
async def get_config_by_type(
|
||||
db_type: str,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
service: DatabaseConfigService = Depends(get_database_service)
|
||||
):
|
||||
"""根据数据库类型获取配置"""
|
||||
try:
|
||||
service = DatabaseConfigService(db)
|
||||
config = service.get_config_by_type(current_user.id, db_type)
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到类型为 {db_type} 的配置"
|
||||
)
|
||||
# 返回包含解密密码的配置
|
||||
return config.to_dict(include_password=True, decrypt_service=service)
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库配置失败: {str(e)}")
|
||||
config = service.get_config_by_type(current_user.id, db_type)
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到类型为 {db_type} 的配置"
|
||||
)
|
||||
|
||||
# 返回包含解密密码的配置
|
||||
return config.to_dict(include_password=True, decrypt_service=service)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -3,24 +3,23 @@
|
|||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select, delete, update
|
||||
|
||||
from ...db.database import get_db
|
||||
from loguru import logger
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.llm_config import LLMConfig
|
||||
from ...core.simple_permissions import require_super_admin, require_authenticated_user
|
||||
from ...services.auth import AuthService
|
||||
from ...utils.logger import get_logger
|
||||
from ...schemas.llm_config import (
|
||||
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
|
||||
LLMConfigTest
|
||||
)
|
||||
from th_agenter.services.document_processor import get_document_processor
|
||||
logger = get_logger(__name__)
|
||||
from utils.util_exceptions import HxfResponse
|
||||
|
||||
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[LLMConfigResponse])
|
||||
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
|
||||
async def get_llm_configs(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
|
|
@ -28,501 +27,414 @@ async def get_llm_configs(
|
|||
provider: Optional[str] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
is_embedding: Optional[bool] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取大模型配置列表."""
|
||||
try:
|
||||
query = db.query(LLMConfig)
|
||||
|
||||
# 搜索
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
LLMConfig.name.ilike(f"%{search}%"),
|
||||
LLMConfig.model_name.ilike(f"%{search}%"),
|
||||
LLMConfig.description.ilike(f"%{search}%")
|
||||
)
|
||||
session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig)
|
||||
|
||||
# 搜索
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
LLMConfig.name.ilike(f"%{search}%"),
|
||||
LLMConfig.model_name.ilike(f"%{search}%"),
|
||||
LLMConfig.description.ilike(f"%{search}%")
|
||||
)
|
||||
|
||||
# 服务商筛选
|
||||
if provider:
|
||||
query = query.filter(LLMConfig.provider == provider)
|
||||
|
||||
# 状态筛选
|
||||
if is_active is not None:
|
||||
query = query.filter(LLMConfig.is_active == is_active)
|
||||
|
||||
# 模型类型筛选
|
||||
if is_embedding is not None:
|
||||
query = query.filter(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
# 排序
|
||||
query = query.order_by(LLMConfig.name)
|
||||
|
||||
# 分页
|
||||
configs = query.offset(skip).limit(limit).all()
|
||||
|
||||
return [config.to_dict(include_sensitive=True) for config in configs]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM configs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取大模型配置列表失败"
|
||||
)
|
||||
|
||||
# 服务商筛选
|
||||
if provider:
|
||||
stmt = stmt.where(LLMConfig.provider == provider)
|
||||
|
||||
# 状态筛选
|
||||
if is_active is not None:
|
||||
stmt = stmt.where(LLMConfig.is_active == is_active)
|
||||
|
||||
# 模型类型筛选
|
||||
if is_embedding is not None:
|
||||
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
# 排序
|
||||
stmt = stmt.order_by(LLMConfig.name)
|
||||
|
||||
# 分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
configs = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置"
|
||||
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
@router.get("/providers", summary="获取支持的大模型服务商列表")
|
||||
async def get_llm_providers(
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取支持的大模型服务商列表."""
|
||||
try:
|
||||
providers = db.query(LLMConfig.provider).distinct().all()
|
||||
return [provider[0] for provider in providers if provider[0]]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM providers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取服务商列表失败"
|
||||
)
|
||||
session.desc = "START: 获取支持的大模型服务商列表"
|
||||
stmt = select(LLMConfig.provider).distinct()
|
||||
providers = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(providers)} 个大模型服务商"
|
||||
return HxfResponse([provider for provider in providers if provider])
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[LLMConfigResponse])
|
||||
|
||||
@router.get("/active", response_model=List[LLMConfigResponse], summary="获取所有激活的大模型配置")
|
||||
async def get_active_llm_configs(
|
||||
is_embedding: Optional[bool] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取所有激活的大模型配置."""
|
||||
try:
|
||||
query = db.query(LLMConfig).filter(LLMConfig.is_active == True)
|
||||
|
||||
if is_embedding is not None:
|
||||
query = query.filter(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
configs = query.order_by(LLMConfig.created_at).all()
|
||||
|
||||
return [config.to_dict(include_sensitive=True) for config in configs]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active LLM configs: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取激活配置列表失败"
|
||||
)
|
||||
session.desc = f"START: 获取所有激活的大模型配置, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.is_active == True)
|
||||
|
||||
if is_embedding is not None:
|
||||
stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
|
||||
|
||||
stmt = stmt.order_by(LLMConfig.created_at)
|
||||
configs = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 获取 {len(configs)} 个激活的大模型配置"
|
||||
return HxfResponse([config.to_dict(include_sensitive=True) for config in configs])
|
||||
|
||||
|
||||
@router.get("/default", response_model=LLMConfigResponse)
|
||||
@router.get("/default", response_model=LLMConfigResponse, summary="获取默认大模型配置")
|
||||
async def get_default_llm_config(
|
||||
is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取默认大模型配置."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(
|
||||
LLMConfig.is_default == True,
|
||||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.is_active == True
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
model_type = "嵌入模型" if is_embedding else "对话模型"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到默认{model_type}配置"
|
||||
)
|
||||
|
||||
return config.to_dict(include_sensitive=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting default LLM config: {str(e)}")
|
||||
session.desc = f"START: 获取默认大模型配置, is_embedding={is_embedding}"
|
||||
stmt = select(LLMConfig).where(
|
||||
LLMConfig.is_default == True,
|
||||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.is_active == True
|
||||
)
|
||||
config = (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
model_type = "嵌入模型" if is_embedding else "对话模型"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取默认配置失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到默认{model_type}配置"
|
||||
)
|
||||
|
||||
session.desc = f"SUCCESS: 获取默认大模型配置, is_embedding={is_embedding}"
|
||||
return HxfResponse(config.to_dict(include_sensitive=True))
|
||||
|
||||
|
||||
@router.get("/{config_id}", response_model=LLMConfigResponse)
|
||||
@router.get("/{config_id}", response_model=LLMConfigResponse, summary="获取大模型配置详情")
|
||||
async def get_llm_config(
|
||||
config_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_authenticated_user)
|
||||
):
|
||||
"""获取大模型配置详情."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
return config.to_dict(include_sensitive=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM config {config_id}: {str(e)}")
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取大模型配置详情失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
return HxfResponse(config.to_dict(include_sensitive=True))
|
||||
|
||||
|
||||
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
|
||||
async def create_llm_config(
|
||||
config_data: LLMConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""创建大模型配置."""
|
||||
try:
|
||||
# 检查配置名称是否已存在
|
||||
existing_config = db.query(LLMConfig).filter(
|
||||
LLMConfig.name == config_data.name
|
||||
).first()
|
||||
# 检查配置名称是否已存在
|
||||
session.desc = f"START: 创建大模型配置, name={config_data.name}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
|
||||
existing_config = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 创建临时配置对象进行验证
|
||||
temp_config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = temp_config.validate_config()
|
||||
if not validation_result['valid']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=validation_result['error']
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
if config_data.is_default:
|
||||
stmt = update(LLMConfig).where(
|
||||
LLMConfig.is_embedding == config_data.is_embedding
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
|
||||
# 创建配置
|
||||
config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
session.add(config)
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}"
|
||||
return HxfResponse(config.to_dict())
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=LLMConfigResponse, summary="更新大模型配置")
|
||||
async def update_llm_config(
|
||||
config_id: int,
|
||||
config_data: LLMConfigUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""更新大模型配置."""
|
||||
session.desc = f"START: 更新大模型配置, id={config_id}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 检查配置名称是否已存在(排除自己)
|
||||
if config_data.name and config_data.name != config.name:
|
||||
stmt = select(LLMConfig).where(
|
||||
LLMConfig.name == config_data.name,
|
||||
LLMConfig.id != config_id
|
||||
)
|
||||
existing_config = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 创建临时配置对象进行验证
|
||||
temp_config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = temp_config.validate_config()
|
||||
if not validation_result['valid']:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=validation_result['error']
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
if config_data.is_default:
|
||||
db.query(LLMConfig).filter(
|
||||
LLMConfig.is_embedding == config_data.is_embedding
|
||||
).update({"is_default": False})
|
||||
|
||||
# 创建配置
|
||||
config = LLMConfig(
|
||||
name=config_data.name,
|
||||
provider=config_data.provider,
|
||||
model_name=config_data.model_name,
|
||||
api_key=config_data.api_key,
|
||||
base_url=config_data.base_url,
|
||||
max_tokens=config_data.max_tokens,
|
||||
temperature=config_data.temperature,
|
||||
top_p=config_data.top_p,
|
||||
frequency_penalty=config_data.frequency_penalty,
|
||||
presence_penalty=config_data.presence_penalty,
|
||||
description=config_data.description,
|
||||
is_active=config_data.is_active,
|
||||
is_default=config_data.is_default,
|
||||
is_embedding=config_data.is_embedding,
|
||||
extra_config=config_data.extra_config or {}
|
||||
)
|
||||
config.set_audit_fields(current_user.id)
|
||||
|
||||
db.add(config)
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
logger.info(f"LLM config created: {config.name} by user {current_user.username}")
|
||||
return config.to_dict()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating LLM config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="创建大模型配置失败"
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
if config_data.is_default is True:
|
||||
# 获取当前配置的embedding类型,如果更新中包含is_embedding则使用新值
|
||||
is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding
|
||||
stmt = update(LLMConfig).where(
|
||||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
|
||||
# 更新字段
|
||||
update_data = config_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(config, field, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}"
|
||||
return HxfResponse(config.to_dict())
|
||||
|
||||
|
||||
@router.put("/{config_id}", response_model=LLMConfigResponse)
|
||||
async def update_llm_config(
|
||||
config_id: int,
|
||||
config_data: LLMConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""更新大模型配置."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 检查配置名称是否已存在(排除自己)
|
||||
if config_data.name and config_data.name != config.name:
|
||||
existing_config = db.query(LLMConfig).filter(
|
||||
LLMConfig.name == config_data.name,
|
||||
LLMConfig.id != config_id
|
||||
).first()
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="配置名称已存在"
|
||||
)
|
||||
|
||||
# 如果设为默认,取消同类型的其他默认配置
|
||||
if config_data.is_default is True:
|
||||
# 获取当前配置的embedding类型,如果更新中包含is_embedding则使用新值
|
||||
is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding
|
||||
db.query(LLMConfig).filter(
|
||||
LLMConfig.is_embedding == is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# 更新字段
|
||||
update_data = config_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(config, field, value)
|
||||
|
||||
config.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
logger.info(f"LLM config updated: {config.name} by user {current_user.username}")
|
||||
return config.to_dict()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating LLM config {config_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="更新大模型配置失败"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
|
||||
async def delete_llm_config(
|
||||
config_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""删除大模型配置."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# TODO: 检查是否有对话或其他功能正在使用该配置
|
||||
# 这里可以添加相关的检查逻辑
|
||||
|
||||
db.delete(config)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"LLM config deleted: {config.name} by user {current_user.username}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting LLM config {config_id}: {str(e)}")
|
||||
session.desc = f"START: 删除大模型配置, id={config_id}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="删除大模型配置失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# TODO: 检查是否有对话或其他功能正在使用该配置
|
||||
# 这里可以添加相关的检查逻辑
|
||||
|
||||
session.delete(config)
|
||||
session.commit()
|
||||
|
||||
session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}"
|
||||
return HxfResponse({"message": "LLM config deleted successfully"})
|
||||
|
||||
|
||||
@router.post("/{config_id}/test")
|
||||
@router.post("/{config_id}/test", summary="测试连接大模型配置")
|
||||
async def test_llm_config(
|
||||
config_id: int,
|
||||
test_data: LLMConfigTest,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""测试大模型配置."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = config.validate_config()
|
||||
if not validation_result["valid"]:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"配置验证失败: {validation_result['error']}",
|
||||
"details": validation_result
|
||||
}
|
||||
|
||||
# 尝试创建客户端并发送测试请求
|
||||
try:
|
||||
# 这里应该根据不同的服务商创建相应的客户端
|
||||
# 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
|
||||
|
||||
test_message = test_data.message or "Hello, this is a test message."
|
||||
|
||||
# TODO: 实现具体的测试逻辑
|
||||
# 例如:
|
||||
# client = config.get_client()
|
||||
# response = client.chat.completions.create(
|
||||
# model=config.model_name,
|
||||
# messages=[{"role": "user", "content": test_message}],
|
||||
# max_tokens=100
|
||||
# )
|
||||
|
||||
# 模拟测试成功
|
||||
logger.info(f"LLM config test: {config.name} by user {current_user.username}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置测试成功",
|
||||
"test_message": test_message,
|
||||
"response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
|
||||
"latency_ms": 150, # 模拟延迟
|
||||
"config_info": config.get_client_config()
|
||||
}
|
||||
|
||||
except Exception as test_error:
|
||||
logger.error(f"LLM config test failed: {config.name}, error: {str(test_error)}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"配置测试失败: {str(test_error)}",
|
||||
"test_message": test_message,
|
||||
"config_info": config.get_client_config()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing LLM config {config_id}: {str(e)}")
|
||||
"""测试连接大模型配置."""
|
||||
session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}"
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="测试大模型配置失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
validation_result = config.validate_config()
|
||||
if not validation_result["valid"]:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"配置验证失败: {validation_result['error']}",
|
||||
"details": validation_result
|
||||
}
|
||||
|
||||
# 尝试创建客户端并发送测试请求
|
||||
try:
|
||||
# 这里应该根据不同的服务商创建相应的客户端
|
||||
# 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
|
||||
|
||||
test_message = test_data.message or "Hello, this is a test message."
|
||||
|
||||
# TODO: 实现具体的测试逻辑
|
||||
# 例如:
|
||||
# client = config.get_client()
|
||||
# response = client.chat.completions.create(
|
||||
# model=config.model_name,
|
||||
# messages=[{"role": "user", "content": test_message}],
|
||||
# max_tokens=100
|
||||
# )
|
||||
|
||||
# 模拟测试成功
|
||||
session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}"
|
||||
|
||||
return HxfResponse({
|
||||
"success": True,
|
||||
"message": "配置测试成功",
|
||||
"test_message": test_message,
|
||||
"response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
|
||||
"latency_ms": 150, # 模拟延迟
|
||||
"config_info": config.get_client_config()
|
||||
})
|
||||
|
||||
except Exception as test_error:
|
||||
session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
|
||||
return HxfResponse({
|
||||
"success": False,
|
||||
"message": f"配置测试失败: {str(test_error)}",
|
||||
"test_message": test_message,
|
||||
"config_info": config.get_client_config()
|
||||
})
|
||||
|
||||
@router.post("/{config_id}/toggle-status")
|
||||
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
|
||||
async def toggle_llm_config_status(
|
||||
config_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""切换大模型配置状态."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 切换状态
|
||||
config.is_active = not config.is_active
|
||||
config.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
status_text = "激活" if config.is_active else "禁用"
|
||||
logger.info(f"LLM config status toggled: {config.name} {status_text} by user {current_user.username}")
|
||||
|
||||
return {
|
||||
"message": f"配置已{status_text}",
|
||||
"is_active": config.is_active
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error toggling LLM config status {config_id}: {str(e)}")
|
||||
session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}"
|
||||
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="切换配置状态失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 切换状态
|
||||
config.is_active = not config.is_active
|
||||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
status_text = "激活" if config.is_active else "禁用"
|
||||
session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}"
|
||||
|
||||
return HxfResponse({
|
||||
"message": f"配置已{status_text}",
|
||||
"is_active": config.is_active
|
||||
})
|
||||
|
||||
|
||||
@router.post("/{config_id}/set-default")
|
||||
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
|
||||
async def set_default_llm_config(
|
||||
config_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""设置默认大模型配置."""
|
||||
try:
|
||||
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 检查配置是否激活
|
||||
if not config.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="只能将激活的配置设为默认"
|
||||
)
|
||||
|
||||
# 取消同类型的其他默认配置
|
||||
db.query(LLMConfig).filter(
|
||||
LLMConfig.is_embedding == config.is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# 设置当前配置为默认
|
||||
config.is_default = True
|
||||
config.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
model_type = "嵌入模型" if config.is_embedding else "对话模型"
|
||||
logger.info(f"Default LLM config set: {config.name} ({model_type}) by user {current_user.username}")
|
||||
# 更新文档处理器默认embedding
|
||||
get_document_processor()._init_embeddings()
|
||||
return {
|
||||
"message": f"已将 {config.name} 设为默认{model_type}配置",
|
||||
"is_default": config.is_default
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error setting default LLM config {config_id}: {str(e)}")
|
||||
session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}"
|
||||
|
||||
stmt = select(LLMConfig).where(LLMConfig.id == config_id)
|
||||
config = session.execute(stmt).scalar_one_or_none()
|
||||
if not config:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="设置默认配置失败"
|
||||
)
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="大模型配置不存在"
|
||||
)
|
||||
|
||||
# 检查配置是否激活
|
||||
if not config.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="只能将激活的配置设为默认"
|
||||
)
|
||||
|
||||
# 取消同类型的其他默认配置
|
||||
stmt = update(LLMConfig).where(
|
||||
LLMConfig.is_embedding == config.is_embedding,
|
||||
LLMConfig.id != config_id
|
||||
).values({"is_default": False})
|
||||
session.execute(stmt)
|
||||
|
||||
# 设置当前配置为默认
|
||||
config.is_default = True
|
||||
config.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
session.refresh(config)
|
||||
|
||||
model_type = "嵌入模型" if config.is_embedding else "对话模型"
|
||||
# 更新文档处理器默认embedding
|
||||
get_document_processor()._init_embeddings()
|
||||
session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}"
|
||||
return HxfResponse({
|
||||
"message": f"已将 {config.name} 设为默认{model_type}配置",
|
||||
"is_default": config.is_default
|
||||
})
|
||||
|
|
@ -1,346 +1,273 @@
|
|||
"""Role management API endpoints."""
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy import select, and_, or_, delete
|
||||
|
||||
from ...core.simple_permissions import require_super_admin
|
||||
from ...db.database import get_db
|
||||
from ...db.database import get_session
|
||||
from ...models.user import User
|
||||
from ...models.permission import Role, UserRole
|
||||
from ...services.auth import AuthService
|
||||
from ...utils.logger import get_logger
|
||||
from ...schemas.permission import (
|
||||
RoleCreate, RoleUpdate, RoleResponse,
|
||||
UserRoleAssign
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/roles", tags=["roles"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[RoleResponse])
|
||||
@router.get("/", response_model=List[RoleResponse], summary="获取角色列表")
|
||||
async def get_roles(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
search: Optional[str] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user = Depends(require_super_admin),
|
||||
):
|
||||
"""获取角色列表."""
|
||||
try:
|
||||
query = db.query(Role)
|
||||
|
||||
# 搜索
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Role.name.ilike(f"%{search}%"),
|
||||
Role.code.ilike(f"%{search}%"),
|
||||
Role.description.ilike(f"%{search}%")
|
||||
)
|
||||
session.desc = f"START: 获取用户 {current_user.username} 角色列表"
|
||||
stmt = select(Role)
|
||||
|
||||
# 搜索
|
||||
if search:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Role.name.ilike(f"%{search}%"),
|
||||
Role.code.ilike(f"%{search}%"),
|
||||
Role.description.ilike(f"%{search}%")
|
||||
)
|
||||
|
||||
# 状态筛选
|
||||
if is_active is not None:
|
||||
query = query.filter(Role.is_active == is_active)
|
||||
|
||||
# 分页
|
||||
roles = query.offset(skip).limit(limit).all()
|
||||
|
||||
return [role.to_dict() for role in roles]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting roles: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取角色列表失败"
|
||||
)
|
||||
|
||||
# 状态筛选
|
||||
if is_active is not None:
|
||||
stmt = stmt.where(Role.is_active == is_active)
|
||||
|
||||
# 分页
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
roles = (await session.execute(stmt)).scalars().all()
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 有 {len(roles)} 个角色"
|
||||
return [role.to_dict() for role in roles]
|
||||
|
||||
|
||||
@router.get("/{role_id}", response_model=RoleResponse)
|
||||
@router.get("/{role_id}", response_model=RoleResponse, summary="获取角色详情")
|
||||
async def get_role(
|
||||
role_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""获取角色详情."""
|
||||
try:
|
||||
role = db.query(Role).filter(Role.id == role_id).first()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
return role.to_dict()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting role {role_id}: {str(e)}")
|
||||
session.desc = f"START: 获取角色 {role_id} 详情"
|
||||
stmt = select(Role).where(Role.id == role_id)
|
||||
role = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取角色详情失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
return role.to_dict()
|
||||
|
||||
|
||||
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
|
||||
@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED, summary="创建角色")
|
||||
async def create_role(
|
||||
role_data: RoleCreate,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""创建角色."""
|
||||
try:
|
||||
# 检查角色代码是否已存在
|
||||
existing_role = db.query(Role).filter(
|
||||
Role.code == role_data.code
|
||||
).first()
|
||||
session.desc = f"START: 创建角色 {role_data.name}"
|
||||
# 检查角色代码是否已存在
|
||||
stmt = select(Role).where(Role.code == role_data.code)
|
||||
existing_role = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if existing_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="角色代码已存在"
|
||||
)
|
||||
|
||||
# 创建角色
|
||||
role = Role(
|
||||
name=role_data.name,
|
||||
code=role_data.code,
|
||||
description=role_data.description,
|
||||
is_active=role_data.is_active
|
||||
)
|
||||
role.set_audit_fields(current_user.id)
|
||||
|
||||
await session.add(role)
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
|
||||
logger.info(f"Role created: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
|
||||
@router.put("/{role_id}", response_model=RoleResponse, summary="更新角色")
|
||||
async def update_role(
|
||||
role_id: int,
|
||||
role_data: RoleUpdate,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""更新角色."""
|
||||
session.desc = f"更新用户 {current_user.username} 角色 {role_id}"
|
||||
stmt = select(Role).where(Role.id == role_id)
|
||||
role = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
# 超级管理员角色不能被编辑
|
||||
if role.code == "SUPER_ADMIN":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="超级管理员角色不能被编辑"
|
||||
)
|
||||
|
||||
# 检查角色编码是否已存在(排除当前角色)
|
||||
if role_data.code and role_data.code != role.code:
|
||||
stmt = select(Role).where(
|
||||
and_(
|
||||
Role.code == role_data.code,
|
||||
Role.id != role_id
|
||||
)
|
||||
)
|
||||
existing_role = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if existing_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="角色代码已存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = role_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(role, field, value)
|
||||
|
||||
# Audit fields are set automatically by SQLAlchemy event listener
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
|
||||
logger.info(f"Role updated: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
|
||||
# 创建角色
|
||||
role = Role(
|
||||
name=role_data.name,
|
||||
code=role_data.code,
|
||||
description=role_data.description,
|
||||
is_active=role_data.is_active
|
||||
)
|
||||
role.set_audit_fields(current_user.id)
|
||||
|
||||
db.add(role)
|
||||
db.commit()
|
||||
db.refresh(role)
|
||||
|
||||
logger.info(f"Role created: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating role: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="创建角色失败"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{role_id}", response_model=RoleResponse)
|
||||
async def update_role(
|
||||
role_id: int,
|
||||
role_data: RoleUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""更新角色."""
|
||||
try:
|
||||
role = db.query(Role).filter(Role.id == role_id).first()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
# 超级管理员角色不能被编辑
|
||||
if role.code == "SUPER_ADMIN":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="超级管理员角色不能被编辑"
|
||||
)
|
||||
|
||||
# 检查角色编码是否已存在(排除当前角色)
|
||||
if role_data.code and role_data.code != role.code:
|
||||
existing_role = db.query(Role).filter(
|
||||
and_(
|
||||
Role.code == role_data.code,
|
||||
Role.id != role_id
|
||||
)
|
||||
).first()
|
||||
if existing_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="角色代码已存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = role_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(role, field, value)
|
||||
|
||||
role.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
db.refresh(role)
|
||||
|
||||
logger.info(f"Role updated: {role.name} by user {current_user.username}")
|
||||
return role.to_dict()
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating role {role_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="更新角色失败"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除角色")
|
||||
async def delete_role(
|
||||
role_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""删除角色."""
|
||||
try:
|
||||
role = db.query(Role).filter(Role.id == role_id).first()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
# 超级管理员角色不能被删除
|
||||
if role.code == "SUPER_ADMIN":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="超级管理员角色不能被删除"
|
||||
)
|
||||
|
||||
# 检查是否有用户使用该角色
|
||||
user_count = db.query(UserRole).filter(
|
||||
UserRole.role_id == role_id
|
||||
).count()
|
||||
if user_count > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
|
||||
)
|
||||
|
||||
# 删除角色
|
||||
db.delete(role)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Role deleted: {role.name} by user {current_user.username}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting role {role_id}: {str(e)}")
|
||||
stmt = select(Role).where(Role.id == role_id)
|
||||
role = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="删除角色失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="角色不存在"
|
||||
)
|
||||
|
||||
|
||||
# 超级管理员角色不能被删除
|
||||
if role.code == "SUPER_ADMIN":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="超级管理员角色不能被删除"
|
||||
)
|
||||
|
||||
# 检查是否有用户使用该角色
|
||||
stmt = select(UserRole).where(UserRole.role_id == role_id)
|
||||
user_count = (await session.execute(stmt)).scalars().count()
|
||||
if user_count > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
|
||||
)
|
||||
|
||||
# 删除角色
|
||||
await session.delete(role)
|
||||
await session.commit()
|
||||
|
||||
session.desc = f"角色删除成功: {role.name} by user {current_user.username}"
|
||||
return {"message": f"Role deleted successfully: {role.name} by user {current_user.username}"}
|
||||
|
||||
# 用户角色管理路由
|
||||
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
||||
|
||||
@user_role_router.post("/assign", status_code=status.HTTP_201_CREATED)
|
||||
@user_role_router.post("/assign", status_code=status.HTTP_201_CREATED, summary="为用户分配角色")
|
||||
async def assign_user_roles(
|
||||
assignment_data: UserRoleAssign,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(require_super_admin)
|
||||
):
|
||||
"""为用户分配角色."""
|
||||
try:
|
||||
# 验证用户是否存在
|
||||
user = db.query(User).filter(User.id == assignment_data.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
# 验证角色是否存在
|
||||
roles = db.query(Role).filter(
|
||||
Role.id.in_(assignment_data.role_ids)
|
||||
).all()
|
||||
if len(roles) != len(assignment_data.role_ids):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="部分角色不存在"
|
||||
)
|
||||
|
||||
# 删除现有角色关联
|
||||
db.query(UserRole).filter(
|
||||
UserRole.user_id == assignment_data.user_id
|
||||
).delete()
|
||||
|
||||
# 添加新的角色关联
|
||||
for role_id in assignment_data.role_ids:
|
||||
user_role = UserRole(
|
||||
user_id=assignment_data.user_id,
|
||||
role_id=role_id
|
||||
)
|
||||
db.add(user_role)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}")
|
||||
|
||||
return {"message": "角色分配成功"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error assigning roles to user {assignment_data.user_id}: {str(e)}")
|
||||
# 验证用户是否存在
|
||||
stmt = select(User).where(User.id == assignment_data.user_id)
|
||||
user = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="角色分配失败"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
# 验证角色是否存在
|
||||
stmt = select(Role).where(Role.id.in_(assignment_data.role_ids))
|
||||
roles = (await session.execute(stmt)).scalars().all()
|
||||
if len(roles) != len(assignment_data.role_ids):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="部分角色不存在"
|
||||
)
|
||||
|
||||
# 删除现有角色关联
|
||||
stmt = delete(UserRole).where(UserRole.user_id == assignment_data.user_id)
|
||||
await session.execute(stmt)
|
||||
|
||||
# 添加新的角色关联
|
||||
for role_id in assignment_data.role_ids:
|
||||
user_role = UserRole(
|
||||
user_id=assignment_data.user_id,
|
||||
role_id=role_id
|
||||
)
|
||||
await session.add(user_role)
|
||||
|
||||
await session.commit()
|
||||
|
||||
session.desc = f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}"
|
||||
|
||||
return {"message": "角色分配成功"}
|
||||
|
||||
|
||||
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse])
|
||||
@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse], summary="获取用户角色列表")
|
||||
async def get_user_roles(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_active_user)
|
||||
):
|
||||
"""获取用户角色列表."""
|
||||
try:
|
||||
# 检查权限:用户只能查看自己的角色,或者是超级管理员
|
||||
if current_user.id != user_id and not current_user.is_superuser():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权限查看其他用户的角色"
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
roles = db.query(Role).join(
|
||||
UserRole, Role.id == UserRole.role_id
|
||||
).filter(
|
||||
UserRole.user_id == user_id
|
||||
).all()
|
||||
|
||||
return [role.to_dict() for role in roles]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user roles {user_id}: {str(e)}")
|
||||
# 检查权限:用户只能查看自己的角色,或者是超级管理员
|
||||
if current_user.id != user_id and not await current_user.is_superuser():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取用户角色失败"
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权限查看其他用户的角色"
|
||||
)
|
||||
|
||||
|
||||
stmt = select(User).where(User.id == user_id)
|
||||
user = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
stmt = select(Role).join(
|
||||
UserRole, Role.id == UserRole.role_id
|
||||
).where(
|
||||
UserRole.user_id == user_id
|
||||
)
|
||||
roles = (await session.execute(stmt)).scalars().all()
|
||||
|
||||
return [role.to_dict() for role in roles]
|
||||
|
||||
# 将子路由添加到主路由
|
||||
router.include_router(user_role_router)
|
||||
|
|
@ -2,18 +2,16 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||
from fastapi.security import HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from th_agenter.db.database import get_db
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.auth import AuthService
|
||||
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||
from th_agenter.services.conversation import ConversationService
|
||||
from th_agenter.services.conversation_context import conversation_context_service
|
||||
from th_agenter.utils.schemas import BaseResponse
|
||||
from utils.util_schemas import BaseResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
|
||||
security = HTTPBearer()
|
||||
|
|
@ -36,35 +34,38 @@ class ConversationContextResponse(BaseModel):
|
|||
message: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
@router.post("/query", response_model=SmartQueryResponse)
|
||||
@router.post("/query", response_model=SmartQueryResponse, summary="智能问数查询")
|
||||
async def smart_query(
|
||||
request: SmartQueryRequest,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
智能问数查询接口
|
||||
支持新对话时自动加载文件列表,智能选择相关Excel文件,生成和执行pandas代码
|
||||
"""
|
||||
session.desc = f"START: 用户 {current_user.username} 智能问数查询"
|
||||
conversation_id = None
|
||||
|
||||
try:
|
||||
# 验证请求参数
|
||||
if not request.query or not request.query.strip():
|
||||
session.desc = "ERROR: 用户输入为空, 查询内容不能为空"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="查询内容不能为空"
|
||||
)
|
||||
|
||||
if len(request.query) > 1000:
|
||||
session.desc = "ERROR: 用户输入过长, 查询内容不能超过1000字符"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="查询内容过长,请控制在1000字符以内"
|
||||
)
|
||||
|
||||
# 初始化工作流管理器
|
||||
workflow_manager = SmartWorkflowManager(db)
|
||||
conversation_service = ConversationService(db)
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
conversation_service = ConversationService(session)
|
||||
|
||||
# 处理对话上下文
|
||||
conversation_id = request.conversation_id
|
||||
|
|
@ -77,24 +78,26 @@ async def smart_query(
|
|||
title=f"智能问数: {request.query[:20]}..."
|
||||
)
|
||||
request.is_new_conversation = True
|
||||
logger.info(f"创建新对话: {conversation_id}")
|
||||
session.desc = f"创建新对话: {conversation_id}"
|
||||
except Exception as e:
|
||||
logger.warning(f"创建对话失败,使用临时会话: {e}")
|
||||
session.desc = f"WARNING: 创建对话失败,使用临时会话: {e}"
|
||||
conversation_id = None
|
||||
else:
|
||||
# 验证对话是否存在且属于当前用户
|
||||
try:
|
||||
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||
if not context or context.get('user_id') != current_user.id:
|
||||
session.desc = f"ERROR: 对话 {conversation_id} 不存在或无权访问"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话不存在或无权访问"
|
||||
)
|
||||
logger.info(f"使用现有对话: {conversation_id}")
|
||||
session.desc = f"使用现有对话: {conversation_id}"
|
||||
except HTTPException:
|
||||
session.desc = f"EXCEPTION: 对话 {conversation_id} 不存在或无权访问"
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"验证对话失败: {e}")
|
||||
session.desc = f"ERROR: 验证对话失败: {e}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="对话验证失败"
|
||||
|
|
@ -109,7 +112,7 @@ async def smart_query(
|
|||
content=request.query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存用户消息失败: {e}")
|
||||
session.desc = f"WARNING: 保存用户消息失败: {e}"
|
||||
# 不阻断流程,继续执行查询
|
||||
|
||||
# 执行智能查询工作流
|
||||
|
|
@ -121,7 +124,7 @@ async def smart_query(
|
|||
is_new_conversation=request.is_new_conversation
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"智能查询执行失败: {e}")
|
||||
session.desc = f"ERROR: 智能查询执行失败: {e}"
|
||||
# 返回结构化的错误响应
|
||||
return SmartQueryResponse(
|
||||
success=False,
|
||||
|
|
@ -157,17 +160,15 @@ async def smart_query(
|
|||
selected_files=result.get('data', {}).get('used_files', [])
|
||||
)
|
||||
|
||||
logger.info(f"查询成功完成,对话ID: {conversation_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"保存消息到对话历史失败: {e}")
|
||||
session.desc = f"EXCEPTION: 保存消息到对话历史失败: {e}"
|
||||
# 不影响返回结果,只记录警告
|
||||
|
||||
# 返回结果,包含对话ID
|
||||
response_data = result.get('data', {})
|
||||
if conversation_id:
|
||||
response_data['conversation_id'] = conversation_id
|
||||
|
||||
session.desc = f"SUCCESS: 保存助手回复和更新上下文,对话ID: {conversation_id}"
|
||||
return SmartQueryResponse(
|
||||
success=result['success'],
|
||||
message=result.get('message', '查询完成'),
|
||||
|
|
@ -177,10 +178,10 @@ async def smart_query(
|
|||
)
|
||||
|
||||
except HTTPException:
|
||||
print(e)
|
||||
session.desc = f"EXCEPTION: HTTP异常: {e}"
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"智能查询接口异常: {e}", exc_info=True)
|
||||
session.desc = f"ERROR: 智能查询接口异常: {e}"
|
||||
# 返回通用错误响应
|
||||
return SmartQueryResponse(
|
||||
success=False,
|
||||
|
|
@ -194,149 +195,134 @@ async def smart_query(
|
|||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse)
|
||||
@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse, summary="获取对话上下文")
|
||||
async def get_conversation_context(
|
||||
conversation_id: int,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
获取对话上下文信息,包括已使用的文件和历史查询
|
||||
"""
|
||||
try:
|
||||
# 获取对话上下文
|
||||
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||
|
||||
if not context:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话上下文不存在"
|
||||
)
|
||||
|
||||
# 验证用户权限
|
||||
if context['user_id'] != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权访问此对话"
|
||||
)
|
||||
|
||||
# 获取对话历史
|
||||
history = await conversation_context_service.get_conversation_history(conversation_id)
|
||||
context['message_history'] = history
|
||||
|
||||
return ConversationContextResponse(
|
||||
success=True,
|
||||
message="获取对话上下文成功",
|
||||
data=context
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取对话上下文失败: {e}")
|
||||
# 获取对话上下文
|
||||
session.desc = f"START: 获取对话上下文,对话ID: {conversation_id}"
|
||||
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||
|
||||
if not context:
|
||||
session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取对话上下文失败: {str(e)}"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话上下文不存在"
|
||||
)
|
||||
|
||||
# 验证用户权限
|
||||
if context['user_id'] != current_user.id:
|
||||
session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权访问此对话"
|
||||
)
|
||||
|
||||
# 获取对话历史
|
||||
history = await conversation_context_service.get_conversation_history(conversation_id)
|
||||
context['message_history'] = history
|
||||
session.desc = f"SUCCESS: 获取对话上下文成功,对话ID: {conversation_id}"
|
||||
return ConversationContextResponse(
|
||||
success=True,
|
||||
message="获取对话上下文成功",
|
||||
data=context
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/status", response_model=ConversationContextResponse)
|
||||
@router.get("/files/status", response_model=ConversationContextResponse, summary="获取用户当前的文件状态和统计信息")
|
||||
async def get_files_status(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
获取用户当前的文件状态和统计信息
|
||||
"""
|
||||
try:
|
||||
workflow_manager = SmartWorkflowManager()
|
||||
|
||||
# 获取用户文件列表
|
||||
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
|
||||
|
||||
# 统计信息
|
||||
total_files = len(file_list)
|
||||
total_rows = sum(f.get('row_count', 0) for f in file_list)
|
||||
total_columns = sum(f.get('column_count', 0) for f in file_list)
|
||||
|
||||
# 文件类型统计
|
||||
file_types = {}
|
||||
for file_info in file_list:
|
||||
filename = file_info['filename']
|
||||
ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
|
||||
file_types[ext] = file_types.get(ext, 0) + 1
|
||||
|
||||
status_data = {
|
||||
'total_files': total_files,
|
||||
'total_rows': total_rows,
|
||||
'total_columns': total_columns,
|
||||
'file_types': file_types,
|
||||
'files': [{
|
||||
'id': f['id'],
|
||||
'filename': f['filename'],
|
||||
'row_count': f.get('row_count', 0),
|
||||
'column_count': f.get('column_count', 0),
|
||||
'columns': f.get('columns', []),
|
||||
'upload_time': f.get('upload_time')
|
||||
} for f in file_list],
|
||||
'ready_for_query': total_files > 0
|
||||
}
|
||||
|
||||
return ConversationContextResponse(
|
||||
success=True,
|
||||
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
|
||||
data=status_data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件状态失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取文件状态失败: {str(e)}"
|
||||
)
|
||||
session.desc = f"START: 获取用户文件状态和统计信息,用户ID: {current_user.id}"
|
||||
workflow_manager = SmartWorkflowManager()
|
||||
|
||||
# 获取用户文件列表
|
||||
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
|
||||
|
||||
# 统计信息
|
||||
total_files = len(file_list)
|
||||
total_rows = sum(f.get('row_count', 0) for f in file_list)
|
||||
total_columns = sum(f.get('column_count', 0) for f in file_list)
|
||||
|
||||
# 文件类型统计
|
||||
file_types = {}
|
||||
for file_info in file_list:
|
||||
filename = file_info['filename']
|
||||
ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
|
||||
file_types[ext] = file_types.get(ext, 0) + 1
|
||||
|
||||
status_data = {
|
||||
'total_files': total_files,
|
||||
'total_rows': total_rows,
|
||||
'total_columns': total_columns,
|
||||
'file_types': file_types,
|
||||
'files': [{
|
||||
'id': f['id'],
|
||||
'filename': f['filename'],
|
||||
'row_count': f.get('row_count', 0),
|
||||
'column_count': f.get('column_count', 0),
|
||||
'columns': f.get('columns', []),
|
||||
'upload_time': f.get('upload_time')
|
||||
} for f in file_list],
|
||||
'ready_for_query': total_files > 0
|
||||
}
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户文件状态和统计信息成功,用户ID: {current_user.id}"
|
||||
return ConversationContextResponse(
|
||||
success=True,
|
||||
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
|
||||
data=status_data
|
||||
)
|
||||
|
||||
@router.post("/conversation/{conversation_id}/reset")
|
||||
@router.post("/conversation/{conversation_id}/reset", summary="重置对话上下文")
|
||||
async def reset_conversation_context(
|
||||
conversation_id: int,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
重置对话上下文,清除历史查询记录但保留文件
|
||||
"""
|
||||
try:
|
||||
# 验证对话存在和用户权限
|
||||
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||
|
||||
if not context:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话上下文不存在"
|
||||
)
|
||||
|
||||
if context['user_id'] != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权访问此对话"
|
||||
)
|
||||
|
||||
# 重置对话上下文
|
||||
success = await conversation_context_service.reset_conversation_context(conversation_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "对话上下文已重置,可以开始新的数据分析会话"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="重置对话上下文失败"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置对话上下文失败: {e}")
|
||||
session.desc = f"START: 重置对话上下文,对话ID: {conversation_id}"
|
||||
# 验证对话存在和用户权限
|
||||
context = await conversation_context_service.get_conversation_context(conversation_id)
|
||||
|
||||
if not context:
|
||||
session.desc = f"ERROR: 对话上下文不存在,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话上下文不存在"
|
||||
)
|
||||
|
||||
if context['user_id'] != current_user.id:
|
||||
session.desc = f"ERROR: 无权访问对话上下文,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权访问此对话"
|
||||
)
|
||||
|
||||
# 重置对话上下文
|
||||
success = await conversation_context_service.reset_conversation_context(conversation_id)
|
||||
|
||||
if success:
|
||||
session.desc = f"SUCCESS: 重置对话上下文成功,对话ID: {conversation_id}"
|
||||
return {
|
||||
"success": True,
|
||||
"message": "对话上下文已重置,可以开始新的数据分析会话"
|
||||
}
|
||||
else:
|
||||
session.desc = f"EXCEPTION: 重置对话上下文失败,对话ID: {conversation_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"重置对话上下文失败: {str(e)}"
|
||||
)
|
||||
detail="重置对话上下文失败"
|
||||
)
|
||||
|
||||
|
|
@ -3,10 +3,9 @@ from fastapi.security import HTTPBearer
|
|||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any, List
|
||||
import pandas as pd
|
||||
from th_agenter.utils.schemas import FileListResponse,ExcelPreviewRequest,NormalResponse
|
||||
from utils.util_schemas import FileListResponse,ExcelPreviewRequest,NormalResponse, BaseResponse
|
||||
import os
|
||||
import tempfile
|
||||
from th_agenter.utils.schemas import BaseResponse
|
||||
from th_agenter.services.smart_query import (
|
||||
SmartQueryService,
|
||||
ExcelAnalysisService,
|
||||
|
|
@ -16,7 +15,7 @@ from th_agenter.services.excel_metadata_service import ExcelMetadataService
|
|||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from th_agenter.utils.file_utils import FileUtils
|
||||
from utils.util_file import FileUtils
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -25,16 +24,12 @@ from typing import Optional, AsyncGenerator
|
|||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from th_agenter.db.database import get_db
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.auth import AuthService
|
||||
from th_agenter.services.smart_workflow import SmartWorkflowManager
|
||||
from th_agenter.services.conversation_context import ConversationContextService
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/smart-query", tags=["smart-query"])
|
||||
security = HTTPBearer()
|
||||
|
|
@ -63,206 +58,195 @@ class ExcelUploadResponse(BaseModel):
|
|||
message: str
|
||||
data: Optional[Dict[str, Any]] = None # 添加data字段
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@router.post("/upload-excel", response_model=ExcelUploadResponse)
|
||||
@router.post("/upload-excel", response_model=ExcelUploadResponse, summary="上传Excel文件并进行预处理")
|
||||
async def upload_excel(
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
上传Excel文件并进行预处理
|
||||
"""
|
||||
try:
|
||||
# 验证文件类型
|
||||
allowed_extensions = ['.xlsx', '.xls', '.csv']
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
|
||||
if file_extension not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件"
|
||||
)
|
||||
|
||||
# 验证文件大小 (10MB)
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
if file_size > 10 * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="文件大小不能超过 10MB"
|
||||
)
|
||||
|
||||
# 创建持久化目录结构
|
||||
backend_dir = Path(__file__).parent.parent.parent.parent # 获取backend目录
|
||||
data_dir = backend_dir / "data/uploads"
|
||||
excel_user_dir = data_dir / f"excel_{current_user.id}"
|
||||
|
||||
# 确保目录存在
|
||||
excel_user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名:{uuid}_{原始文件名称}
|
||||
file_id = str(uuid.uuid4())
|
||||
safe_filename = FileUtils.sanitize_filename(file.filename)
|
||||
new_filename = f"{file_id}_{safe_filename}"
|
||||
file_path = excel_user_dir / new_filename
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
# 使用Excel元信息服务提取并保存元信息
|
||||
metadata_service = ExcelMetadataService(db)
|
||||
excel_file = metadata_service.save_file_metadata(
|
||||
file_path=str(file_path),
|
||||
original_filename=file.filename,
|
||||
user_id=current_user.id,
|
||||
file_size=file_size
|
||||
)
|
||||
|
||||
# 为了兼容现有前端,仍然创建pickle文件
|
||||
try:
|
||||
if file_extension == '.csv':
|
||||
df = pd.read_csv(file_path, encoding='utf-8')
|
||||
else:
|
||||
df = pd.read_excel(file_path)
|
||||
except UnicodeDecodeError:
|
||||
if file_extension == '.csv':
|
||||
df = pd.read_csv(file_path, encoding='gbk')
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="文件编码错误,请确保文件为UTF-8或GBK编码"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"文件读取失败: {str(e)}"
|
||||
)
|
||||
|
||||
# 保存pickle文件到同一目录
|
||||
pickle_filename = f"{file_id}_{safe_filename}.pkl"
|
||||
pickle_path = excel_user_dir / pickle_filename
|
||||
df.to_pickle(pickle_path)
|
||||
|
||||
# 数据预处理和分析(保持兼容性)
|
||||
excel_service = ExcelAnalysisService()
|
||||
analysis_result = excel_service.analyze_dataframe(df, file.filename)
|
||||
|
||||
# 添加数据库文件信息
|
||||
analysis_result.update({
|
||||
'file_id': str(excel_file.id),
|
||||
'database_id': excel_file.id,
|
||||
'temp_file_path': str(pickle_path), # 更新为新的pickle路径
|
||||
'original_filename': file.filename,
|
||||
'file_size_mb': excel_file.file_size_mb,
|
||||
'sheet_names': excel_file.sheet_names,
|
||||
})
|
||||
|
||||
return ExcelUploadResponse(
|
||||
file_id=excel_file.id,
|
||||
success=True,
|
||||
message="Excel文件上传成功",
|
||||
data=analysis_result
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
session.desc = f"START: 用户 {current_user.username} 上传 Excel 文件并进行预处理"
|
||||
# 验证文件类型
|
||||
allowed_extensions = ['.xlsx', '.xls', '.csv']
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
|
||||
if file_extension not in allowed_extensions:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 上传了不支持的文件格式 {file_extension},请上传 .xlsx, .xls 或 .csv 文件"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"文件处理失败: {str(e)}"
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件"
|
||||
)
|
||||
|
||||
# 验证文件大小 (10MB)
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
if file_size > 10 * 1024 * 1024:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 大小为 {file_size / (1024 * 1024):.2f}MB,超过最大限制 10MB"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="文件大小不能超过 10MB"
|
||||
)
|
||||
|
||||
# 创建持久化目录结构
|
||||
backend_dir = Path(__file__).parent.parent.parent.parent # 获取backend目录
|
||||
data_dir = backend_dir / "data/uploads"
|
||||
excel_user_dir = data_dir / f"excel_{current_user.id}"
|
||||
|
||||
# 确保目录存在
|
||||
excel_user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名:{uuid}_{原始文件名称}
|
||||
file_id = str(uuid.uuid4())
|
||||
safe_filename = FileUtils.sanitize_filename(file.filename)
|
||||
new_filename = f"{file_id}_{safe_filename}"
|
||||
file_path = excel_user_dir / new_filename
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
# 使用Excel元信息服务提取并保存元信息
|
||||
metadata_service = ExcelMetadataService(session)
|
||||
excel_file = metadata_service.save_file_metadata(
|
||||
file_path=str(file_path),
|
||||
original_filename=file.filename,
|
||||
user_id=current_user.id,
|
||||
file_size=file_size
|
||||
)
|
||||
|
||||
# 为了兼容现有前端,仍然创建pickle文件
|
||||
try:
|
||||
if file_extension == '.csv':
|
||||
df = pd.read_csv(file_path, encoding='utf-8')
|
||||
else:
|
||||
df = pd.read_excel(file_path)
|
||||
except UnicodeDecodeError:
|
||||
if file_extension == '.csv':
|
||||
df = pd.read_csv(file_path, encoding='gbk')
|
||||
else:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 编码错误,请确保文件为UTF-8或GBK编码"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="文件编码错误,请确保文件为UTF-8或GBK编码"
|
||||
)
|
||||
except Exception as e:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 上传的文件 {file.filename} 读取失败: {str(e)}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"文件读取失败: {str(e)}"
|
||||
)
|
||||
|
||||
# 保存pickle文件到同一目录
|
||||
pickle_filename = f"{file_id}_{safe_filename}.pkl"
|
||||
pickle_path = excel_user_dir / pickle_filename
|
||||
df.to_pickle(pickle_path)
|
||||
|
||||
# 数据预处理和分析(保持兼容性)
|
||||
excel_service = ExcelAnalysisService()
|
||||
analysis_result = excel_service.analyze_dataframe(df, file.filename)
|
||||
|
||||
# 添加数据库文件信息
|
||||
analysis_result.update({
|
||||
'file_id': str(excel_file.id),
|
||||
'database_id': excel_file.id,
|
||||
'temp_file_path': str(pickle_path), # 更新为新的pickle路径
|
||||
'original_filename': file.filename,
|
||||
'file_size_mb': excel_file.file_size_mb,
|
||||
'sheet_names': excel_file.sheet_names,
|
||||
})
|
||||
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 上传的文件 {file.filename} 预处理成功,文件ID: {excel_file.id}"
|
||||
return ExcelUploadResponse(
|
||||
file_id=excel_file.id,
|
||||
success=True,
|
||||
message="Excel文件上传成功",
|
||||
data=analysis_result
|
||||
)
|
||||
|
||||
@router.post("/preview-excel", response_model=QueryResponse)
|
||||
@router.post("/preview-excel", response_model=QueryResponse, summary="预览Excel文件数据")
|
||||
async def preview_excel(
|
||||
request: ExcelPreviewRequest,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
预览Excel文件数据
|
||||
"""
|
||||
session.desc = f"START: 用户 {current_user.username} 预览文件 {request.file_id}"
|
||||
|
||||
# 验证file_id格式
|
||||
try:
|
||||
logger.info(f"Preview request for file_id: {request.file_id}, user: {current_user.id}")
|
||||
|
||||
# 验证file_id格式
|
||||
try:
|
||||
file_id = int(request.file_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"无效的文件ID格式: {request.file_id}"
|
||||
)
|
||||
|
||||
# 从数据库获取文件信息
|
||||
metadata_service = ExcelMetadataService(db)
|
||||
excel_file = metadata_service.get_file_by_id(file_id, current_user.id)
|
||||
|
||||
if not excel_file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件不存在或已被删除"
|
||||
)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(excel_file.file_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件已被移动或删除"
|
||||
)
|
||||
|
||||
# 更新最后访问时间
|
||||
metadata_service.update_last_accessed(file_id, current_user.id)
|
||||
|
||||
# 读取Excel文件
|
||||
if excel_file.file_type.lower() == 'csv':
|
||||
df = pd.read_csv(excel_file.file_path, encoding='utf-8')
|
||||
else:
|
||||
# 对于Excel文件,使用默认sheet或第一个sheet
|
||||
sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0
|
||||
df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name)
|
||||
|
||||
# 计算分页
|
||||
total_rows = len(df)
|
||||
start_idx = (request.page - 1) * request.page_size
|
||||
end_idx = start_idx + request.page_size
|
||||
|
||||
# 获取分页数据
|
||||
paginated_df = df.iloc[start_idx:end_idx]
|
||||
|
||||
# 转换为字典格式
|
||||
data = paginated_df.fillna('').to_dict('records')
|
||||
columns = df.columns.tolist()
|
||||
|
||||
return QueryResponse(
|
||||
success=True,
|
||||
message="Excel文件预览加载成功",
|
||||
data={
|
||||
'data': data,
|
||||
'columns': columns,
|
||||
'total_rows': total_rows,
|
||||
'page': request.page,
|
||||
'page_size': request.page_size,
|
||||
'total_pages': (total_rows + request.page_size - 1) // request.page_size
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
file_id = int(request.file_id)
|
||||
except ValueError:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 提供了无效的文件ID格式: {request.file_id}"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"预览文件失败: {str(e)}"
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"无效的文件ID格式: {request.file_id}"
|
||||
)
|
||||
|
||||
# 从数据库获取文件信息
|
||||
metadata_service = ExcelMetadataService(session)
|
||||
excel_file = metadata_service.get_file_by_id(file_id, current_user.id)
|
||||
|
||||
if not excel_file:
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 不存在或已被删除"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件不存在或已被删除"
|
||||
)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(excel_file.file_path):
|
||||
session.desc = f"ERROR: 用户 {current_user.username} 预览文件 {request.file_id} 已被移动或删除"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件已被移动或删除"
|
||||
)
|
||||
|
||||
# 更新最后访问时间
|
||||
metadata_service.update_last_accessed(file_id, current_user.id)
|
||||
|
||||
# 读取Excel文件
|
||||
if excel_file.file_type.lower() == 'csv':
|
||||
df = pd.read_csv(excel_file.file_path, encoding='utf-8')
|
||||
else:
|
||||
# 对于Excel文件,使用默认sheet或第一个sheet
|
||||
sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0
|
||||
df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name)
|
||||
|
||||
# 计算分页
|
||||
total_rows = len(df)
|
||||
start_idx = (request.page - 1) * request.page_size
|
||||
end_idx = start_idx + request.page_size
|
||||
|
||||
# 获取分页数据
|
||||
paginated_df = df.iloc[start_idx:end_idx]
|
||||
|
||||
# 转换为字典格式
|
||||
data = paginated_df.fillna('').to_dict('records')
|
||||
columns = df.columns.tolist()
|
||||
session.desc = f"SUCCESS: 用户 {current_user.username} 预览文件 {request.file_id} 加载成功,共 {total_rows} 行数据"
|
||||
return QueryResponse(
|
||||
success=True,
|
||||
message="Excel文件预览加载成功",
|
||||
data={
|
||||
'data': data,
|
||||
'columns': columns,
|
||||
'total_rows': total_rows,
|
||||
'page': request.page,
|
||||
'page_size': request.page_size,
|
||||
'total_pages': (total_rows + request.page_size - 1) // request.page_size
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/test-db-connection", response_model=NormalResponse)
|
||||
@router.post("/test-db-connection", response_model=NormalResponse, summary="测试数据库连接")
|
||||
async def test_database_connection(
|
||||
config: DatabaseConfig,
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
|
|
@ -272,7 +256,7 @@ async def test_database_connection(
|
|||
"""
|
||||
try:
|
||||
db_service = DatabaseQueryService()
|
||||
is_connected = await db_service.test_connection(config.dict())
|
||||
is_connected = await db_service.test_connection(config.model_dump())
|
||||
|
||||
if is_connected:
|
||||
return NormalResponse(
|
||||
|
|
@ -296,12 +280,12 @@ async def test_database_connection(
|
|||
# async def connect_database(
|
||||
# config_id: int,
|
||||
# current_user = Depends(AuthService.get_current_user),
|
||||
# db: Session = Depends(get_db)
|
||||
# db: Session = Depends(get_session)
|
||||
# ):
|
||||
# """连接数据库并获取表列表"""
|
||||
# ... (整个方法都删除)
|
||||
|
||||
@router.post("/table-schema", response_model=QueryResponse)
|
||||
@router.post("/table-schema", response_model=QueryResponse, summary="获取数据表结构")
|
||||
async def get_table_schema(
|
||||
request: TableSchemaRequest,
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
|
|
@ -330,7 +314,6 @@ async def get_table_schema(
|
|||
success=False,
|
||||
message=f"获取表结构失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class StreamQueryRequest(BaseModel):
|
||||
query: str
|
||||
|
|
@ -343,12 +326,11 @@ class DatabaseStreamQueryRequest(BaseModel):
|
|||
conversation_id: Optional[int] = None
|
||||
is_new_conversation: bool = False
|
||||
|
||||
|
||||
@router.post("/execute-excel-query")
|
||||
@router.post("/execute-excel-query", summary="流式智能问答查询")
|
||||
async def stream_smart_query(
|
||||
request: StreamQueryRequest,
|
||||
current_user=Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
流式智能问答查询接口
|
||||
|
|
@ -372,7 +354,7 @@ async def stream_smart_query(
|
|||
yield f"data: {json.dumps({'type': 'start', 'message': '开始处理查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 初始化服务
|
||||
workflow_manager = SmartWorkflowManager(db)
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
conversation_context_service = ConversationContextService()
|
||||
|
||||
# 处理对话上下文
|
||||
|
|
@ -466,12 +448,11 @@ async def stream_smart_query(
|
|||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/execute-db-query")
|
||||
@router.post("/execute-db-query", summary="流式数据库查询")
|
||||
async def execute_database_query(
|
||||
request: DatabaseStreamQueryRequest,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
流式数据库查询接口
|
||||
|
|
@ -495,7 +476,7 @@ async def execute_database_query(
|
|||
yield f"data: {json.dumps({'type': 'start', 'message': '开始处理数据库查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 初始化服务
|
||||
workflow_manager = SmartWorkflowManager(db)
|
||||
workflow_manager = SmartWorkflowManager(session)
|
||||
conversation_context_service = ConversationContextService()
|
||||
|
||||
# 处理对话上下文
|
||||
|
|
@ -588,7 +569,7 @@ async def execute_database_query(
|
|||
}
|
||||
)
|
||||
|
||||
@router.delete("/cleanup-temp-files")
|
||||
@router.delete("/cleanup-temp-files", summary="清理临时文件")
|
||||
async def cleanup_temp_files(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
|
|
@ -620,18 +601,19 @@ async def cleanup_temp_files(
|
|||
message=f"清理临时文件失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/files", response_model=FileListResponse)
|
||||
@router.get("/files", response_model=FileListResponse, summary="获取用户上传的Excel文件列表")
|
||||
async def get_file_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
获取用户上传的Excel文件列表
|
||||
"""
|
||||
try:
|
||||
metadata_service = ExcelMetadataService(db)
|
||||
session.desc = f"START: 获取用户 {current_user.id} 的文件列表"
|
||||
metadata_service = ExcelMetadataService(session)
|
||||
skip = (page - 1) * page_size
|
||||
files, total = metadata_service.get_user_files(current_user.id, skip, page_size)
|
||||
|
||||
|
|
@ -651,6 +633,7 @@ async def get_file_list(
|
|||
}
|
||||
file_list.append(file_info)
|
||||
|
||||
session.desc = f"SUCCESS: 获取用户 {current_user.id} 的文件列表,共 {total} 个文件"
|
||||
return FileListResponse(
|
||||
success=True,
|
||||
message="获取文件列表成功",
|
||||
|
|
@ -669,25 +652,28 @@ async def get_file_list(
|
|||
message=f"获取文件列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=NormalResponse)
|
||||
@router.delete("/files/{file_id}", response_model=NormalResponse, summary="删除指定的Excel文件")
|
||||
async def delete_file(
|
||||
file_id: int,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
删除指定的Excel文件
|
||||
"""
|
||||
try:
|
||||
metadata_service = ExcelMetadataService(db)
|
||||
session.desc = f"START: 删除用户 {current_user.id} 的文件 {file_id}"
|
||||
metadata_service = ExcelMetadataService(session)
|
||||
success = metadata_service.delete_file(file_id, current_user.id)
|
||||
|
||||
if success:
|
||||
session.desc = f"SUCCESS: 删除用户 {current_user.id} 的文件 {file_id}"
|
||||
return NormalResponse(
|
||||
success=True,
|
||||
message="文件删除成功"
|
||||
)
|
||||
else:
|
||||
session.desc = f"ERROR: 删除用户 {current_user.id} 的文件 {file_id},文件不存在或删除失败"
|
||||
return NormalResponse(
|
||||
success=False,
|
||||
message="文件不存在或删除失败"
|
||||
|
|
@ -699,56 +685,49 @@ async def delete_file(
|
|||
message=str(e)
|
||||
)
|
||||
|
||||
@router.get("/files/{file_id}/info", response_model=QueryResponse)
|
||||
@router.get("/files/{file_id}/info", response_model=QueryResponse, summary="获取指定文件的详细信息")
|
||||
async def get_file_info(
|
||||
file_id: int,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""
|
||||
获取指定文件的详细信息
|
||||
"""
|
||||
try:
|
||||
metadata_service = ExcelMetadataService(db)
|
||||
excel_file = metadata_service.get_file_by_id(file_id, current_user.id)
|
||||
|
||||
if not excel_file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件不存在"
|
||||
)
|
||||
|
||||
# 更新最后访问时间
|
||||
metadata_service.update_last_accessed(file_id, current_user.id)
|
||||
|
||||
file_info = {
|
||||
'id': excel_file.id,
|
||||
'filename': excel_file.original_filename,
|
||||
'file_size': excel_file.file_size,
|
||||
'file_size_mb': excel_file.file_size_mb,
|
||||
'file_type': excel_file.file_type,
|
||||
'sheet_names': excel_file.sheet_names,
|
||||
'default_sheet': excel_file.default_sheet,
|
||||
'columns_info': excel_file.columns_info,
|
||||
'preview_data': excel_file.preview_data,
|
||||
'data_types': excel_file.data_types,
|
||||
'total_rows': excel_file.total_rows,
|
||||
'total_columns': excel_file.total_columns,
|
||||
'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None,
|
||||
'last_accessed': excel_file.last_accessed.isoformat() if excel_file.last_accessed else None,
|
||||
'sheets_summary': excel_file.get_all_sheets_summary()
|
||||
}
|
||||
|
||||
return QueryResponse(
|
||||
success=True,
|
||||
message="获取文件信息成功",
|
||||
data=file_info
|
||||
metadata_service = ExcelMetadataService(session)
|
||||
excel_file = metadata_service.get_file_by_id(file_id, current_user.id)
|
||||
|
||||
if not excel_file:
|
||||
session.desc = f"ERROR: 获取用户 {current_user.id} 的文件 {file_id} 信息,文件不存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件不存在"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return QueryResponse(
|
||||
success=False,
|
||||
message=f"获取文件信息失败: {str(e)}"
|
||||
)
|
||||
|
||||
# 更新最后访问时间
|
||||
metadata_service.update_last_accessed(file_id, current_user.id)
|
||||
|
||||
file_info = {
|
||||
'id': excel_file.id,
|
||||
'filename': excel_file.original_filename,
|
||||
'file_size': excel_file.file_size,
|
||||
'file_size_mb': excel_file.file_size_mb,
|
||||
'file_type': excel_file.file_type,
|
||||
'sheet_names': excel_file.sheet_names,
|
||||
'default_sheet': excel_file.default_sheet,
|
||||
'columns_info': excel_file.columns_info,
|
||||
'preview_data': excel_file.preview_data,
|
||||
'data_types': excel_file.data_types,
|
||||
'total_rows': excel_file.total_rows,
|
||||
'total_columns': excel_file.total_columns,
|
||||
'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None,
|
||||
'last_accessed': excel_file.last_accessed.isoformat() if excel_file.last_accessed else None,
|
||||
'sheets_summary': excel_file.get_all_sheets_summary()
|
||||
}
|
||||
|
||||
return QueryResponse(
|
||||
success=True,
|
||||
message="获取文件信息成功",
|
||||
data=file_info
|
||||
)
|
||||
|
||||
|
|
@ -1,25 +1,21 @@
|
|||
"""表元数据管理API"""
|
||||
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from th_agenter.models.user import User
|
||||
from th_agenter.db.database import get_db
|
||||
from th_agenter.db.database import get_session
|
||||
from th_agenter.services.table_metadata_service import TableMetadataService
|
||||
from th_agenter.utils.logger import get_logger
|
||||
from th_agenter.services.auth import AuthService
|
||||
|
||||
logger = get_logger("table_metadata_api")
|
||||
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
|
||||
|
||||
|
||||
class TableSelectionRequest(BaseModel):
|
||||
database_config_id: int = Field(..., description="数据库配置ID")
|
||||
table_names: List[str] = Field(..., description="选中的表名列表")
|
||||
|
||||
|
||||
class TableMetadataResponse(BaseModel):
|
||||
id: int
|
||||
table_name: str
|
||||
|
|
@ -33,50 +29,43 @@ class TableMetadataResponse(BaseModel):
|
|||
business_context: str
|
||||
last_synced_at: str
|
||||
|
||||
|
||||
class QASettingsUpdate(BaseModel):
|
||||
is_enabled_for_qa: bool = Field(default=True)
|
||||
qa_description: str = Field(default="")
|
||||
business_context: str = Field(default="")
|
||||
|
||||
|
||||
class TableByNameRequest(BaseModel):
|
||||
database_config_id: int = Field(..., description="数据库配置ID")
|
||||
table_name: str = Field(..., description="表名")
|
||||
|
||||
|
||||
@router.post("/collect")
|
||||
@router.post("/collect", summary="收集选中表的元数据")
|
||||
async def collect_table_metadata(
|
||||
request: TableSelectionRequest,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""收集选中表的元数据"""
|
||||
try:
|
||||
service = TableMetadataService(db)
|
||||
result = await service.collect_and_save_table_metadata(
|
||||
current_user.id,
|
||||
request.database_config_id,
|
||||
request.table_names
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"收集表元数据失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
session.desc = f"START: 用户 {current_user.id} 收集表元数据"
|
||||
service = TableMetadataService(session)
|
||||
result = await service.collect_and_save_table_metadata(
|
||||
current_user.id,
|
||||
request.database_config_id,
|
||||
request.table_names
|
||||
)
|
||||
session.desc = f"SUCCESS: 用户 {current_user.id} 收集表元数据"
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", summary="获取用户表元数据列表")
|
||||
async def get_table_metadata(
|
||||
database_config_id: int = None,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""获取表元数据列表"""
|
||||
try:
|
||||
service = TableMetadataService(db)
|
||||
service = TableMetadataService(session)
|
||||
metadata_list = service.get_user_table_metadata(
|
||||
current_user.id,
|
||||
database_config_id
|
||||
|
|
@ -120,15 +109,15 @@ async def get_table_metadata(
|
|||
}
|
||||
|
||||
|
||||
@router.post("/by-table")
|
||||
@router.post("/by-table", summary="根据表名获取表元数据")
|
||||
async def get_table_metadata_by_name(
|
||||
request: TableByNameRequest,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""根据表名获取表元数据"""
|
||||
try:
|
||||
service = TableMetadataService(db)
|
||||
service = TableMetadataService(session)
|
||||
metadata = service.get_table_metadata_by_name(
|
||||
current_user.id,
|
||||
request.database_config_id,
|
||||
|
|
@ -180,16 +169,16 @@ async def get_table_metadata_by_name(
|
|||
}
|
||||
|
||||
|
||||
@router.put("/{metadata_id}/qa-settings")
|
||||
@router.put("/{metadata_id}/qa-settings", summary="更新表的问答设置")
|
||||
async def update_qa_settings(
|
||||
metadata_id: int,
|
||||
settings: QASettingsUpdate,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""更新表的问答设置"""
|
||||
try:
|
||||
service = TableMetadataService(db)
|
||||
service = TableMetadataService(session)
|
||||
success = service.update_table_qa_settings(
|
||||
current_user.id,
|
||||
metadata_id,
|
||||
|
|
@ -220,29 +209,21 @@ class TableSaveRequest(BaseModel):
|
|||
async def save_table_metadata(
|
||||
request: TableSaveRequest,
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""保存选中表的元数据配置"""
|
||||
try:
|
||||
service = TableMetadataService(db)
|
||||
result = await service.save_table_metadata_config(
|
||||
user_id=current_user.id,
|
||||
database_config_id=request.database_config_id,
|
||||
table_names=request.table_names
|
||||
)
|
||||
|
||||
logger.info(f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"成功保存 {len(result['saved_tables'])} 个表的配置",
|
||||
"saved_tables": result['saved_tables'],
|
||||
"failed_tables": result.get('failed_tables', [])
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表元数据配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"保存配置失败: {str(e)}"
|
||||
)
|
||||
service = TableMetadataService(session)
|
||||
result = await service.save_table_metadata_config(
|
||||
user_id=current_user.id,
|
||||
database_config_id=request.database_config_id,
|
||||
table_names=request.table_names
|
||||
)
|
||||
|
||||
session.desc = f"用户 {current_user.id} 保存了 {len(request.table_names)} 个表的配置"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"成功保存 {len(result['saved_tables'])} 个表的配置",
|
||||
"saved_tables": result['saved_tables'],
|
||||
"failed_tables": result.get('failed_tables', [])
|
||||
}
|
||||
|
|
@ -4,7 +4,7 @@ from typing import List, Optional
|
|||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ...db.database import get_db
|
||||
from ...db.database import get_session
|
||||
from ...core.simple_permissions import require_super_admin
|
||||
from ...services.auth import AuthService
|
||||
from ...services.user import UserService
|
||||
|
|
@ -12,27 +12,25 @@ from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePassword
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/profile", response_model=UserResponse)
|
||||
@router.get("/profile", response_model=UserResponse, summary="获取当前用户的个人信息")
|
||||
async def get_user_profile(
|
||||
current_user = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""Get current user profile."""
|
||||
return UserResponse.from_orm(current_user)
|
||||
"""获取当前用户的个人信息."""
|
||||
return UserResponse.model_validate(current_user)
|
||||
|
||||
|
||||
@router.put("/profile", response_model=UserResponse)
|
||||
@router.put("/profile", response_model=UserResponse, summary="更新当前用户的个人信息")
|
||||
async def update_user_profile(
|
||||
user_update: UserUpdate,
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Update current user profile."""
|
||||
user_service = UserService(db)
|
||||
"""更新当前用户的个人信息."""
|
||||
user_service = UserService(session)
|
||||
|
||||
# Check if email is being changed and is already taken
|
||||
if user_update.email and user_update.email != current_user.email:
|
||||
existing_user = user_service.get_user_by_email(user_update.email)
|
||||
existing_user = await user_service.get_user_by_email(user_update.email)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -40,66 +38,67 @@ async def update_user_profile(
|
|||
)
|
||||
|
||||
# Update user
|
||||
updated_user = user_service.update_user(current_user.id, user_update)
|
||||
return UserResponse.from_orm(updated_user)
|
||||
updated_user = await user_service.update_user(current_user.id, user_update)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
|
||||
@router.delete("/profile")
|
||||
@router.delete("/profile", summary="删除当前用户的账户")
|
||||
async def delete_user_account(
|
||||
current_user = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Delete current user account."""
|
||||
user_service = UserService(db)
|
||||
user_service.delete_user(current_user.id)
|
||||
return {"message": "Account deleted successfully"}
|
||||
|
||||
"""删除当前用户的账户."""
|
||||
username = current_user.username
|
||||
user_service = UserService(session)
|
||||
await user_service.delete_user(current_user.id)
|
||||
session.desc = f"删除用户 [{username}] 成功"
|
||||
return {"message": f"删除用户 {username} 成功"}
|
||||
|
||||
# Admin endpoints
|
||||
@router.post("/", response_model=UserResponse)
|
||||
async def create_user(
|
||||
@router.post("/", response_model=UserResponse, summary="创建新用户 (需要有管理员权限)")
|
||||
async def create_user(
|
||||
user_create: UserCreate,
|
||||
# current_user = Depends(require_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
current_user = Depends(require_super_admin),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Create a new user (admin only)."""
|
||||
user_service = UserService(db)
|
||||
"""创建一个新用户 (需要有管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
# Check if username already exists
|
||||
existing_user = user_service.get_user_by_username(user_create.username)
|
||||
existing_user = await user_service.get_user_by_username(user_create.username)
|
||||
if existing_user:
|
||||
session.desc = f"创建用户 [{user_create.username}] 失败 - 用户名已存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
# Check if email already exists
|
||||
existing_user = user_service.get_user_by_email(user_create.email)
|
||||
existing_user = await user_service.get_user_by_email(user_create.email)
|
||||
if existing_user:
|
||||
session.desc = f"创建用户 [{user_create.username}] 失败 - 邮箱已存在"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# Create user
|
||||
new_user = user_service.create_user(user_create)
|
||||
return UserResponse.from_orm(new_user)
|
||||
new_user = await user_service.create_user(user_create)
|
||||
return UserResponse.model_validate(new_user)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", summary="列出所有用户,支持分页和筛选 (仅管理员权限)")
|
||||
async def list_users(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None),
|
||||
role_id: Optional[int] = Query(None),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
# current_user = Depends(require_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""List all users with pagination and filters (admin only)."""
|
||||
user_service = UserService(db)
|
||||
"""列出所有用户,支持分页和筛选 (仅管理员权限)."""
|
||||
session.desc = f"START: 列出所有用户,分页={page}, 每页大小={size}, 搜索={search}, 角色ID={role_id}, 激活状态={is_active}"
|
||||
user_service = UserService(session)
|
||||
skip = (page - 1) * size
|
||||
users, total = user_service.get_users_with_filters(
|
||||
users, total = await user_service.get_users_with_filters(
|
||||
skip=skip,
|
||||
limit=size,
|
||||
search=search,
|
||||
|
|
@ -107,7 +106,7 @@ async def list_users(
|
|||
is_active=is_active
|
||||
)
|
||||
result = {
|
||||
"users": [UserResponse.from_orm(user) for user in users],
|
||||
"users": [UserResponse.model_validate(user) for user in users],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": size
|
||||
|
|
@ -115,34 +114,33 @@ async def list_users(
|
|||
return result
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
@router.get("/{user_id}", response_model=UserResponse, summary="通过ID获取用户信息 (仅管理员权限)")
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Get user by ID (admin only)."""
|
||||
user_service = UserService(db)
|
||||
user = user_service.get_user(user_id)
|
||||
"""通过ID获取用户信息 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return UserResponse.from_orm(user)
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.put("/change-password")
|
||||
@router.put("/change-password", summary="修改当前用户的密码")
|
||||
async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Change current user's password."""
|
||||
user_service = UserService(db)
|
||||
"""修改当前用户的密码."""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
user_service.change_password(
|
||||
await user_service.change_password(
|
||||
user_id=current_user.id,
|
||||
current_password=request.current_password,
|
||||
new_password=request.new_password
|
||||
|
|
@ -165,19 +163,18 @@ async def change_password(
|
|||
detail="Failed to change password"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}/reset-password")
|
||||
@router.put("/{user_id}/reset-password", summary="重置用户密码 (仅管理员权限)")
|
||||
async def reset_user_password(
|
||||
user_id: int,
|
||||
request: ResetPasswordRequest,
|
||||
current_user = Depends(require_super_admin),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Reset user password (admin only)."""
|
||||
user_service = UserService(db)
|
||||
"""重置用户密码 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
user_service.reset_password(
|
||||
await user_service.reset_password(
|
||||
user_id=user_id,
|
||||
new_password=request.new_password
|
||||
)
|
||||
|
|
@ -200,42 +197,41 @@ async def reset_user_password(
|
|||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
@router.put("/{user_id}", response_model=UserResponse, summary="更新用户信息 (仅管理员权限)")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_update: UserUpdate,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Update user by ID (admin only)."""
|
||||
user_service = UserService(db)
|
||||
"""更新用户信息 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
user = user_service.get_user(user_id)
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
updated_user = user_service.update_user(user_id, user_update)
|
||||
return UserResponse.from_orm(updated_user)
|
||||
updated_user = await user_service.update_user(user_id, user_update)
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
@router.delete("/{user_id}", summary="删除用户 (仅管理员权限)")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user = Depends(AuthService.get_current_active_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
"""Delete user by ID (admin only)."""
|
||||
user_service = UserService(db)
|
||||
"""删除用户 (仅管理员权限)."""
|
||||
user_service = UserService(session)
|
||||
|
||||
user = user_service.get_user(user_id)
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
user_service.delete_user(user_id)
|
||||
await user_service.delete_user(user_id)
|
||||
return {"message": "User deleted successfully"}
|
||||
|
|
@ -4,11 +4,11 @@ from typing import List, Optional, AsyncGenerator
|
|||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import select, and_, func
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ...db.database import get_db
|
||||
from ...db.database import get_session
|
||||
from ...schemas.workflow import (
|
||||
WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
|
||||
WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
|
||||
|
|
@ -17,9 +17,7 @@ from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
|
|||
from ...services.workflow_engine import get_workflow_engine
|
||||
from ...services.auth import AuthService
|
||||
from ...models.user import User
|
||||
from ...utils.logger import get_logger
|
||||
|
||||
logger = get_logger("workflow_api")
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -36,41 +34,33 @@ def convert_workflow_for_response(workflow_dict):
|
|||
@router.post("/", response_model=WorkflowResponse)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""创建工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 创建工作流
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
definition=workflow_data.definition.dict(),
|
||||
version="1.0.0",
|
||||
status=workflow_data.status,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
workflow.set_audit_fields(current_user.id)
|
||||
|
||||
db.add(workflow)
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
# 转换definition中的字段映射
|
||||
workflow_dict = convert_workflow_for_response(workflow.to_dict())
|
||||
|
||||
logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="创建工作流失败"
|
||||
)
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 创建工作流
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
definition=workflow_data.definition.model_dump(),
|
||||
version="1.0.0",
|
||||
status=workflow_data.status,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
workflow.set_audit_fields(current_user.id)
|
||||
|
||||
await session.add(workflow)
|
||||
await session.commit()
|
||||
await session.refresh(workflow)
|
||||
|
||||
# 转换definition中的字段映射
|
||||
workflow_dict = convert_workflow_for_response(workflow.to_dict())
|
||||
|
||||
logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**workflow_dict)
|
||||
|
||||
|
||||
@router.get("/", response_model=WorkflowListResponse)
|
||||
async def list_workflows(
|
||||
|
|
@ -78,325 +68,271 @@ async def list_workflows(
|
|||
limit: Optional[int] = Query(None, ge=1, le=100),
|
||||
workflow_status: Optional[WorkflowStatus] = None,
|
||||
search: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流列表"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 构建查询
|
||||
query = db.query(Workflow).filter(Workflow.owner_id == current_user.id)
|
||||
|
||||
if workflow_status:
|
||||
query = query.filter(Workflow.status == workflow_status)
|
||||
|
||||
# 添加搜索功能
|
||||
if search:
|
||||
query = query.filter(Workflow.name.ilike(f"%{search}%"))
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 如果没有传分页参数,返回所有数据
|
||||
if skip is None and limit is None:
|
||||
workflows = query.all()
|
||||
return WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=1,
|
||||
size=total
|
||||
)
|
||||
|
||||
# 使用默认分页参数
|
||||
if skip is None:
|
||||
skip = 0
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
# 分页查询
|
||||
workflows = query.offset(skip).limit(limit).all()
|
||||
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
# 构建查询
|
||||
stmt = select(Workflow).where(Workflow.owner_id == current_user.id)
|
||||
|
||||
if workflow_status:
|
||||
stmt = stmt.where(Workflow.status == workflow_status)
|
||||
|
||||
# 添加搜索功能
|
||||
if search:
|
||||
stmt = stmt.where(Workflow.name.ilike(f"%{search}%"))
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count(Workflow.id)).where(Workflow.owner_id == current_user.id)
|
||||
if workflow_status:
|
||||
count_query = count_query.where(Workflow.status == workflow_status)
|
||||
if search:
|
||||
count_query = count_query.where(Workflow.name.ilike(f"%{search}%"))
|
||||
total = session.scalar(count_query)
|
||||
|
||||
# 如果没有传分页参数,返回所有数据
|
||||
if skip is None and limit is None:
|
||||
workflows = session.scalars(stmt).all()
|
||||
return WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=skip // limit + 1, # 计算页码
|
||||
size=limit
|
||||
page=1,
|
||||
size=total
|
||||
)
|
||||
|
||||
# 使用默认分页参数
|
||||
if skip is None:
|
||||
skip = 0
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workflows: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取工作流列表失败"
|
||||
)
|
||||
# 分页查询
|
||||
workflows = session.scalars(stmt.offset(skip).limit(limit)).all()
|
||||
|
||||
return WorkflowListResponse(
|
||||
workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
|
||||
total=total,
|
||||
page=skip // limit + 1, # 计算页码
|
||||
size=limit
|
||||
)
|
||||
|
||||
@router.get("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def get_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流详情"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflow {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取工作流失败"
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
@router.put("/{workflow_id}", response_model=WorkflowResponse)
|
||||
async def update_workflow(
|
||||
workflow_id: int,
|
||||
workflow_data: WorkflowUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""更新工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = workflow_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "definition" and value:
|
||||
# 如果value是Pydantic模型,转换为字典;如果已经是字典,直接使用
|
||||
if hasattr(value, 'dict'):
|
||||
setattr(workflow, field, value.dict())
|
||||
else:
|
||||
setattr(workflow, field, value)
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = workflow_data.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "definition" and value:
|
||||
# 如果value是Pydantic模型,转换为字典;如果已经是字典,直接使用
|
||||
if hasattr(value, 'dict'):
|
||||
setattr(workflow, field, value.dict())
|
||||
else:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
db.refresh(workflow)
|
||||
|
||||
logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating workflow {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="更新工作流失败"
|
||||
)
|
||||
else:
|
||||
setattr(workflow, field, value)
|
||||
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
session.refresh(workflow)
|
||||
|
||||
logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
|
||||
return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))
|
||||
|
||||
|
||||
@router.delete("/{workflow_id}")
|
||||
async def delete_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""删除工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
db.delete(workflow)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流删除成功"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting workflow {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="删除工作流失败"
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
session.delete(workflow)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流删除成功"}
|
||||
|
||||
|
||||
@router.post("/{workflow_id}/activate")
|
||||
async def activate_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""激活工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.PUBLISHED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流激活成功"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error activating workflow {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="激活工作流失败"
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.PUBLISHED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流激活成功"}
|
||||
|
||||
@router.post("/{workflow_id}/deactivate")
|
||||
async def deactivate_workflow(
|
||||
workflow_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""停用工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.ARCHIVED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流停用成功"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating workflow {workflow_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="停用工作流失败"
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
workflow.status = ModelWorkflowStatus.ARCHIVED
|
||||
workflow.set_audit_fields(current_user.id, is_update=True)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
|
||||
return {"message": "工作流停用成功"}
|
||||
|
||||
@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
|
||||
async def execute_workflow(
|
||||
workflow_id: int,
|
||||
request: WorkflowExecuteRequest,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""执行工作流"""
|
||||
try:
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = db.query(Workflow).filter(
|
||||
from ...models.workflow import Workflow
|
||||
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
if workflow.status != ModelWorkflowStatus.PUBLISHED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="工作流未激活,无法执行"
|
||||
)
|
||||
|
||||
# 获取工作流引擎并执行
|
||||
engine = get_workflow_engine()
|
||||
execution_result = await engine.execute_workflow(
|
||||
workflow=workflow,
|
||||
input_data=request.input_data,
|
||||
user_id=current_user.id,
|
||||
db=db
|
||||
)
|
||||
|
||||
logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
|
||||
return execution_result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing workflow {workflow_id}: {str(e)}")
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"执行工作流失败: {str(e)}"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作流不存在"
|
||||
)
|
||||
|
||||
if workflow.status != ModelWorkflowStatus.PUBLISHED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="工作流未激活,无法执行"
|
||||
)
|
||||
|
||||
# 获取工作流引擎并执行
|
||||
engine = get_workflow_engine()
|
||||
execution_result = await engine.execute_workflow(
|
||||
workflow=workflow,
|
||||
input_data=request.input_data,
|
||||
user_id=current_user.id,
|
||||
session=session
|
||||
)
|
||||
|
||||
logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
|
||||
return execution_result
|
||||
|
||||
|
||||
@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
|
||||
async def list_workflow_executions(
|
||||
workflow_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行历史"""
|
||||
|
|
@ -404,12 +340,14 @@ async def list_workflow_executions(
|
|||
from ...models.workflow import Workflow, WorkflowExecution
|
||||
|
||||
# 验证工作流所有权
|
||||
workflow = db.query(Workflow).filter(
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
)
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
|
|
@ -418,11 +356,13 @@ async def list_workflow_executions(
|
|||
)
|
||||
|
||||
# 获取执行历史
|
||||
executions = db.query(WorkflowExecution).filter(
|
||||
executions = session.scalars(
|
||||
select(WorkflowExecution).where(
|
||||
WorkflowExecution.workflow_id == workflow_id
|
||||
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit).all()
|
||||
).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit)
|
||||
).all()
|
||||
|
||||
return [WorkflowExecutionResponse.from_orm(execution) for execution in executions]
|
||||
return [WorkflowExecutionResponse.model_validate(execution) for execution in executions]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -436,21 +376,21 @@ async def list_workflow_executions(
|
|||
@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
|
||||
async def get_workflow_execution(
|
||||
execution_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""获取工作流执行详情"""
|
||||
try:
|
||||
from ...models.workflow import WorkflowExecution, Workflow
|
||||
|
||||
execution = db.query(WorkflowExecution).join(
|
||||
Workflow, WorkflowExecution.workflow_id == Workflow.id
|
||||
).filter(
|
||||
and_(
|
||||
execution = session.scalar(
|
||||
select(WorkflowExecution).join(
|
||||
Workflow, WorkflowExecution.workflow_id == Workflow.id
|
||||
).where(
|
||||
WorkflowExecution.id == execution_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
).first()
|
||||
)
|
||||
|
||||
if not execution:
|
||||
raise HTTPException(
|
||||
|
|
@ -458,7 +398,7 @@ async def get_workflow_execution(
|
|||
detail="执行记录不存在"
|
||||
)
|
||||
|
||||
return WorkflowExecutionResponse.from_orm(execution)
|
||||
return WorkflowExecutionResponse.model_validate(execution)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -474,7 +414,7 @@ async def get_workflow_execution(
|
|||
async def execute_workflow_stream(
|
||||
workflow_id: int,
|
||||
request: WorkflowExecuteRequest,
|
||||
db: Session = Depends(get_db),
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(AuthService.get_current_user)
|
||||
):
|
||||
"""流式执行工作流,实时推送节点执行状态"""
|
||||
|
|
@ -486,12 +426,14 @@ async def execute_workflow_stream(
|
|||
from ...models.workflow import Workflow
|
||||
|
||||
# 验证工作流
|
||||
workflow = db.query(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
workflow = session.scalar(
|
||||
select(Workflow).filter(
|
||||
and_(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.owner_id == current_user.id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': '工作流不存在'}, ensure_ascii=False)}\n\n"
|
||||
|
|
@ -512,7 +454,7 @@ async def execute_workflow_stream(
|
|||
workflow=workflow,
|
||||
input_data=request.input_data,
|
||||
user_id=current_user.id,
|
||||
db=db
|
||||
session=session
|
||||
):
|
||||
# 推送工作流步骤
|
||||
yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"
|
||||
|
|
|
|||
|
|
@ -1,55 +1,68 @@
|
|||
"""Main API router."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .endpoints import chat
|
||||
|
||||
|
||||
# TODO: Add other routers when implemented
|
||||
from .endpoints import auth
|
||||
from .endpoints import knowledge_base
|
||||
from .endpoints import smart_query
|
||||
from .endpoints import smart_chat
|
||||
|
||||
from .endpoints import database_config
|
||||
from .endpoints import table_metadata
|
||||
|
||||
# System management endpoints
|
||||
# # System management endpoints
|
||||
from .endpoints import roles
|
||||
from .endpoints import llm_configs
|
||||
from .endpoints import users
|
||||
|
||||
# Workflow endpoints
|
||||
# # Workflow endpoints
|
||||
from .endpoints import workflow
|
||||
|
||||
|
||||
# Create main API router
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(
|
||||
auth.router,
|
||||
prefix="/auth",
|
||||
tags=["authentication"]
|
||||
tags=["身份验证"]
|
||||
)
|
||||
|
||||
# Include sub-routers
|
||||
router.include_router(
|
||||
chat.router,
|
||||
prefix="/chat",
|
||||
tags=["chat"]
|
||||
users.router,
|
||||
prefix="/users",
|
||||
tags=["users"]
|
||||
)
|
||||
router.include_router(
|
||||
roles.router,
|
||||
prefix="/admin",
|
||||
tags=["admin-roles"]
|
||||
)
|
||||
router.include_router(
|
||||
llm_configs.router,
|
||||
prefix="/admin",
|
||||
tags=["admin-llm-configs"]
|
||||
)
|
||||
|
||||
|
||||
|
||||
router.include_router(
|
||||
knowledge_base.router,
|
||||
prefix="/knowledge-bases",
|
||||
tags=["knowledge-bases"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
database_config.router,
|
||||
tags=["database-config"]
|
||||
)
|
||||
router.include_router(
|
||||
table_metadata.router,
|
||||
tags=["table-metadata"]
|
||||
)
|
||||
router.include_router(
|
||||
smart_query.router,
|
||||
tags=["smart-query"]
|
||||
)
|
||||
router.include_router(
|
||||
chat.router,
|
||||
prefix="/chat",
|
||||
tags=["chat"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
smart_chat.router,
|
||||
|
|
@ -60,42 +73,11 @@ router.include_router(
|
|||
|
||||
|
||||
|
||||
router.include_router(
|
||||
database_config.router,
|
||||
tags=["database-config"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
table_metadata.router,
|
||||
tags=["table-metadata"]
|
||||
)
|
||||
|
||||
# System management routers
|
||||
router.include_router(
|
||||
roles.router,
|
||||
prefix="/admin",
|
||||
tags=["admin-roles"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
llm_configs.router,
|
||||
prefix="/admin",
|
||||
tags=["admin-llm-configs"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
users.router,
|
||||
prefix="/users",
|
||||
tags=["users"]
|
||||
)
|
||||
|
||||
router.include_router(
|
||||
workflow.router,
|
||||
prefix="/workflows",
|
||||
tags=["workflows"]
|
||||
)
|
||||
|
||||
# Test endpoint
|
||||
@router.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "API test is working"}
|
||||
)
|
||||
|
|
@ -1 +1 @@
|
|||
"""Core module for TH-Agenter."""
|
||||
"""Core module for TH Agenter."""
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,6 +1,6 @@
|
|||
"""FastAPI application factory."""
|
||||
|
||||
import logging
|
||||
from loguru import logger
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
|
@ -10,78 +10,59 @@ from fastapi.exceptions import RequestValidationError
|
|||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from .config import Settings
|
||||
from .logging import setup_logging
|
||||
from .middleware import UserContextMiddleware
|
||||
from ..api.routes import router
|
||||
from ..db.database import init_db
|
||||
from ..api.endpoints import table_metadata
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
# Startup
|
||||
logging.info("Starting up TH-Agenter application...")
|
||||
await init_db()
|
||||
logging.info("Database initialized")
|
||||
|
||||
logger.info("Starting up TH Agenter application...")
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logging.info("Shutting down TH-Agenter application...")
|
||||
logger.info("Shutting down TH Agenter application...")
|
||||
|
||||
|
||||
def create_app(settings: Settings = None) -> FastAPI:
|
||||
"""Create and configure FastAPI application."""
|
||||
if settings is None:
|
||||
from .config import get_settings
|
||||
settings = get_settings()
|
||||
# def create_app(settings: Settings = None) -> FastAPI:
|
||||
# """Create and configure FastAPI application."""
|
||||
# if settings is None:
|
||||
# from .config import get_settings
|
||||
# settings = get_settings()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(settings.logging)
|
||||
# # Create FastAPI app
|
||||
# app = FastAPI(
|
||||
# title=settings.app_name,
|
||||
# version=settings.app_version,
|
||||
# description="基于Vue的第一个聊天智能体应用,使用FastAPI后端,由DrGraph修改",
|
||||
# debug=settings.debug,
|
||||
# lifespan=lifespan,
|
||||
# )
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="A modern chat agent application with Vue frontend and FastAPI backend",
|
||||
debug=settings.debug,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
# # Add middleware
|
||||
# setup_middleware(app, settings)
|
||||
|
||||
# Add middleware
|
||||
setup_middleware(app, settings)
|
||||
# # Add exception handlers
|
||||
# setup_exception_handlers(app)
|
||||
|
||||
# Add exception handlers
|
||||
setup_exception_handlers(app)
|
||||
|
||||
# Include routers
|
||||
app.include_router(router, prefix="/api")
|
||||
# # Include routers
|
||||
# app.include_router(router, prefix="/api")
|
||||
|
||||
# app.include_router(table_metadata.router)
|
||||
# # 在现有导入中添加
|
||||
# from ..api.endpoints import database_config
|
||||
|
||||
|
||||
app.include_router(table_metadata.router)
|
||||
# 在现有导入中添加
|
||||
from ..api.endpoints import database_config
|
||||
|
||||
# 在路由注册部分添加
|
||||
app.include_router(database_config.router)
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "version": settings.app_version}
|
||||
# # 在路由注册部分添加
|
||||
# app.include_router(database_config.router)
|
||||
# # Health check endpoint
|
||||
# @app.get("/health")
|
||||
# async def health_check():
|
||||
# return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Chat Agent API is running"}
|
||||
|
||||
# Test endpoint
|
||||
@app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "API is working"}
|
||||
|
||||
return app
|
||||
# # Root endpoint
|
||||
# @app.get("/")
|
||||
# async def root():
|
||||
# return {"message": "Chat Agent API is running"}
|
||||
# return app
|
||||
|
||||
|
||||
def setup_middleware(app: FastAPI, settings: Settings) -> None:
|
||||
|
|
@ -161,7 +142,7 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
|||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request, exc):
|
||||
logging.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
|
|
@ -174,4 +155,4 @@ def setup_exception_handlers(app: FastAPI) -> None:
|
|||
|
||||
|
||||
# Create the app instance
|
||||
app = create_app()
|
||||
# app = create_app()
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
"""Configuration management for TH-Agenter."""
|
||||
"""Configuration management for TH Agenter."""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
"""Database configuration."""
|
||||
url: str = Field(..., alias="database_url") # Must be provided via environment variable
|
||||
|
|
@ -23,7 +23,6 @@ class DatabaseSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class SecuritySettings(BaseSettings):
|
||||
"""Security configuration."""
|
||||
secret_key: str = Field(default="your-secret-key-here-change-in-production")
|
||||
|
|
@ -47,6 +46,7 @@ class ToolSetings(BaseSettings):
|
|||
"case_sensitive": False,
|
||||
"extra": "ignore"
|
||||
}
|
||||
|
||||
class LLMSettings(BaseSettings):
|
||||
"""大模型配置 - 支持多种OpenAI协议兼容的服务商."""
|
||||
provider: str = Field(default="openai", alias="llm_provider") # openai, deepseek, doubao, zhipu, moonshot
|
||||
|
|
@ -108,8 +108,7 @@ class LLMSettings(BaseSettings):
|
|||
return config
|
||||
except Exception as e:
|
||||
# 如果数据库读取失败,记录错误并回退到环境变量
|
||||
import logging
|
||||
logging.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")
|
||||
logger.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")
|
||||
|
||||
# 回退到原有的环境变量配置
|
||||
provider_configs = {
|
||||
|
|
@ -147,7 +146,6 @@ class LLMSettings(BaseSettings):
|
|||
})
|
||||
return config
|
||||
|
||||
|
||||
class EmbeddingSettings(BaseSettings):
|
||||
"""Embedding模型配置 - 支持多种提供商."""
|
||||
provider: str = Field(default="zhipu", alias="embedding_provider") # openai, deepseek, doubao, zhipu, moonshot
|
||||
|
|
@ -202,8 +200,7 @@ class EmbeddingSettings(BaseSettings):
|
|||
return config
|
||||
except Exception as e:
|
||||
# 如果数据库读取失败,记录错误并回退到环境变量
|
||||
import logging
|
||||
logging.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")
|
||||
logger.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")
|
||||
|
||||
# 回退到原有的环境变量配置
|
||||
provider_configs = {
|
||||
|
|
@ -236,7 +233,6 @@ class EmbeddingSettings(BaseSettings):
|
|||
|
||||
return provider_configs.get(self.provider, provider_configs["zhipu"])
|
||||
|
||||
|
||||
class VectorDBSettings(BaseSettings):
|
||||
"""Vector database configuration."""
|
||||
type: str = Field(default="pgvector", alias="vector_db_type")
|
||||
|
|
@ -260,7 +256,6 @@ class VectorDBSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class FileSettings(BaseSettings):
|
||||
"""File processing configuration."""
|
||||
upload_dir: str = Field(default="./data/uploads")
|
||||
|
|
@ -300,7 +295,6 @@ class FileSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class StorageSettings(BaseSettings):
|
||||
"""Storage configuration."""
|
||||
storage_type: str = Field(default="local") # local or s3
|
||||
|
|
@ -320,23 +314,6 @@ class StorageSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class LoggingSettings(BaseSettings):
|
||||
"""Logging configuration."""
|
||||
level: str = Field(default="INFO")
|
||||
file: str = Field(default="./data/logs/app.log")
|
||||
format: str = Field(default="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
max_bytes: int = Field(default=10485760) # 10MB
|
||||
backup_count: int = Field(default=5)
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class CORSSettings(BaseSettings):
|
||||
"""CORS configuration."""
|
||||
allowed_origins: List[str] = Field(default=["*"])
|
||||
|
|
@ -350,20 +327,18 @@ class CORSSettings(BaseSettings):
|
|||
"extra": "ignore"
|
||||
}
|
||||
|
||||
|
||||
class ChatSettings(BaseSettings):
|
||||
"""Chat configuration."""
|
||||
max_history_length: int = Field(default=10)
|
||||
system_prompt: str = Field(default="你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。")
|
||||
max_response_tokens: int = Field(default=1000)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Main application settings."""
|
||||
|
||||
# App info
|
||||
app_name: str = Field(default="TH-Agenter")
|
||||
app_version: str = Field(default="0.1.0")
|
||||
app_name: str = Field(default="TH Agenter")
|
||||
app_version: str = Field(default="0.2.0")
|
||||
debug: bool = Field(default=True)
|
||||
environment: str = Field(default="development")
|
||||
|
||||
|
|
@ -379,7 +354,6 @@ class Settings(BaseSettings):
|
|||
vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
|
||||
file: FileSettings = Field(default_factory=FileSettings)
|
||||
storage: StorageSettings = Field(default_factory=StorageSettings)
|
||||
logging: LoggingSettings = Field(default_factory=LoggingSettings)
|
||||
cors: CORSSettings = Field(default_factory=CORSSettings)
|
||||
chat: ChatSettings = Field(default_factory=ChatSettings)
|
||||
tool: ToolSetings = Field(default_factory=ToolSetings)
|
||||
|
|
@ -391,13 +365,12 @@ class Settings(BaseSettings):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def load_from_yaml(cls, config_path: str = "../configs/settings.yaml") -> "Settings":
|
||||
def load_from_yaml(cls, config_path: str = "webIOs/configs/settings.yaml") -> "Settings":
|
||||
"""Load settings from YAML file."""
|
||||
config_file = Path(config_path)
|
||||
|
||||
# 如果配置文件不存在,尝试从backend目录查找
|
||||
if not config_file.exists():
|
||||
# 获取当前文件所在目录(backend/th_agenter/core)
|
||||
# 获取当前文件所在目录(backend/open_agent/core)
|
||||
current_dir = Path(__file__).parent
|
||||
# 向上两级到backend目录,然后找configs/settings.yaml
|
||||
backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
|
||||
|
|
@ -424,14 +397,14 @@ class Settings(BaseSettings):
|
|||
settings_kwargs['vector_db'] = VectorDBSettings(**(config_data.get('vector_db', {})))
|
||||
settings_kwargs['file'] = FileSettings(**(config_data.get('file', {})))
|
||||
settings_kwargs['storage'] = StorageSettings(**(config_data.get('storage', {})))
|
||||
settings_kwargs['logging'] = LoggingSettings(**(config_data.get('logging', {})))
|
||||
settings_kwargs['cors'] = CORSSettings(**(config_data.get('cors', {})))
|
||||
settings_kwargs['chat'] = ChatSettings(**(config_data.get('chat', {})))
|
||||
settings_kwargs['tool'] = ToolSetings(**(config_data.get('tool', {})))
|
||||
|
||||
settings_kwargs['tool'] = ToolSetings(**(config_data.get('tool', {})))
|
||||
|
||||
# 添加顶级配置
|
||||
for key, value in config_data.items():
|
||||
if key not in settings_kwargs:
|
||||
# logger.error(f"顶级配置项 {key} 未在子设置类中找到,直接添加到 settings_kwargs")
|
||||
settings_kwargs[key] = value
|
||||
|
||||
return cls(**settings_kwargs)
|
||||
|
|
@ -471,12 +444,10 @@ class Settings(BaseSettings):
|
|||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings.load_from_yaml()
|
||||
settings = Settings.load_from_yaml()
|
||||
return settings
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = get_settings()
|
||||
|
|
@ -6,6 +6,7 @@ from contextvars import ContextVar
|
|||
from typing import Optional
|
||||
import threading
|
||||
from ..models.user import User
|
||||
from loguru import logger
|
||||
|
||||
# Context variable to store current user
|
||||
current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None)
|
||||
|
|
@ -20,8 +21,7 @@ class UserContext:
|
|||
@staticmethod
|
||||
def set_current_user(user: User) -> None:
|
||||
"""Set current user in context."""
|
||||
import logging
|
||||
logging.info(f"Setting user in context: {user.username} (ID: {user.id})")
|
||||
logger.info(f"[UserContext] - Setting user in context: {user.username} (ID: {user.id})")
|
||||
|
||||
# Set in ContextVar
|
||||
current_user_context.set(user)
|
||||
|
|
@ -31,13 +31,12 @@ class UserContext:
|
|||
|
||||
# Verify it was set
|
||||
verify_user = current_user_context.get()
|
||||
logging.info(f"Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
|
||||
@staticmethod
|
||||
def set_current_user_with_token(user: User):
|
||||
"""Set current user in context and return token for cleanup."""
|
||||
import logging
|
||||
logging.info(f"Setting user in context with token: {user.username} (ID: {user.id})")
|
||||
logger.info(f"[UserContext] - Setting user in context with token: {user.username} (ID: {user.id})")
|
||||
|
||||
# Set in ContextVar and get token
|
||||
token = current_user_context.set(user)
|
||||
|
|
@ -47,15 +46,14 @@ class UserContext:
|
|||
|
||||
# Verify it was set
|
||||
verify_user = current_user_context.get()
|
||||
logging.info(f"Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
logger.info(f"[UserContext] - Verification - ContextVar user: {verify_user.username if verify_user else None}")
|
||||
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def reset_current_user_token(token):
|
||||
"""Reset current user context using token."""
|
||||
import logging
|
||||
logging.info("Resetting user context using token")
|
||||
logger.info("[UserContext] - Resetting user context using token")
|
||||
|
||||
# Reset ContextVar using token
|
||||
current_user_context.reset(token)
|
||||
|
|
@ -67,21 +65,21 @@ class UserContext:
|
|||
@staticmethod
|
||||
def get_current_user() -> Optional[User]:
|
||||
"""Get current user from context."""
|
||||
import logging
|
||||
logger.debug("[UserContext] - Attempting to get user from context")
|
||||
|
||||
# Try ContextVar first
|
||||
user = current_user_context.get()
|
||||
if user:
|
||||
logging.debug(f"Got user from ContextVar: {user.username} (ID: {user.id})")
|
||||
logger.debug(f"[UserContext] - Got user from ContextVar: {user.username} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# Fallback to thread-local
|
||||
user = getattr(_thread_local, 'current_user', None)
|
||||
if user:
|
||||
logging.debug(f"Got user from thread-local: {user.username} (ID: {user.id})")
|
||||
logger.debug(f"[UserContext] - Got user from thread-local: {user.username} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
logging.debug("No user found in context (neither ContextVar nor thread-local)")
|
||||
logger.debug("[UserContext] - No user found in context (neither ContextVar nor thread-local)")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -93,8 +91,7 @@ class UserContext:
|
|||
@staticmethod
|
||||
def clear_current_user() -> None:
|
||||
"""Clear current user from context."""
|
||||
import logging
|
||||
logging.info("Clearing user context")
|
||||
logger.info("[UserContext] - 清除当前用户上下文")
|
||||
|
||||
current_user_context.set(None)
|
||||
if hasattr(_thread_local, 'current_user'):
|
||||
|
|
|
|||
|
|
@ -2,23 +2,25 @@
|
|||
中间件管理,如上下文中间件:校验Token等
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Callable
|
||||
from loguru import logger
|
||||
from fastapi import status
|
||||
from utils.util_exceptions import HxfErrorResponse
|
||||
|
||||
from ..db.database import get_db_session
|
||||
from ..db.database import get_session, AsyncSessionFactory, engine_async
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from ..services.auth import AuthService
|
||||
from .context import UserContext
|
||||
|
||||
|
||||
|
||||
class UserContextMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to set user context for authenticated requests."""
|
||||
|
||||
def __init__(self, app, exclude_paths: list = None):
|
||||
super().__init__(app)
|
||||
self.canLog = True
|
||||
# Paths that don't require authentication
|
||||
self.exclude_paths = exclude_paths or [
|
||||
"/docs",
|
||||
|
|
@ -31,42 +33,48 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
"/auth/register",
|
||||
"/auth/login-oauth",
|
||||
"/health",
|
||||
"/test"
|
||||
"/static/"
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request and set user context if authenticated."""
|
||||
import logging
|
||||
logging.info(f"[MIDDLEWARE] Processing request: {request.method} {request.url.path}")
|
||||
if self.canLog:
|
||||
logger.warning(f"[MIDDLEWARE] - 接收到请求信息: {request.method} {request.url.path}")
|
||||
|
||||
# Skip authentication for excluded paths
|
||||
path = request.url.path
|
||||
logging.info(f"[MIDDLEWARE] Checking path: {path} against exclude_paths: {self.exclude_paths}")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 检查路由 [{path}] 是否需要跳过认证: against exclude_paths: {self.exclude_paths}")
|
||||
|
||||
should_skip = False
|
||||
for exclude_path in self.exclude_paths:
|
||||
# Exact match
|
||||
if path == exclude_path:
|
||||
should_skip = True
|
||||
logging.info(f"[MIDDLEWARE] Path {path} exactly matches exclude_path {exclude_path}")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 路由 {path} 完全匹配排除路径 {exclude_path}")
|
||||
break
|
||||
# For paths ending with '/', check if request path starts with it
|
||||
elif exclude_path.endswith('/') and path.startswith(exclude_path):
|
||||
should_skip = True
|
||||
logging.info(f"[MIDDLEWARE] Path {path} starts with exclude_path {exclude_path}")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path} 开头")
|
||||
break
|
||||
# For paths not ending with '/', check if request path starts with it + '/'
|
||||
elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
|
||||
should_skip = True
|
||||
logging.info(f"[MIDDLEWARE] Path {path} starts with exclude_path {exclude_path}/")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 路由 {path} 以排除路径 {exclude_path}/ 开头")
|
||||
break
|
||||
|
||||
if should_skip:
|
||||
logging.info(f"[MIDDLEWARE] Skipping authentication for excluded path: {path}")
|
||||
if self.canLog:
|
||||
logger.warning(f"[MIDDLEWARE] - 路由 {path} 匹配排除路径,跳过认证 >>> await call_next")
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
logging.info(f"[MIDDLEWARE] Processing authenticated request: {path}")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 路由 {path} 需要认证,开始处理")
|
||||
|
||||
# Always clear any existing user context to ensure fresh authentication
|
||||
UserContext.clear_current_user()
|
||||
|
|
@ -80,12 +88,9 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
authorization = request.headers.get("Authorization")
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# No token provided, return 401 error
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Missing or invalid authorization header"},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
return HxfErrorResponse(
|
||||
message="缺少或无效的授权头",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Extract token
|
||||
|
|
@ -95,63 +100,61 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
payload = AuthService.verify_token(token)
|
||||
if payload is None:
|
||||
# Invalid token, return 401 error
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Invalid or expired token"},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
return HxfErrorResponse(
|
||||
message="无效或过期的令牌",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Get username from token
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Invalid token payload"},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
return HxfErrorResponse(
|
||||
message="令牌负载无效",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
db = get_db_session()
|
||||
from sqlalchemy import select
|
||||
from ..models.user import User
|
||||
|
||||
# 创建一个临时的异步会话获取用户信息
|
||||
session = AsyncSession(bind=engine_async)
|
||||
try:
|
||||
from ..models.user import User
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
stmt = select(User).where(User.username == username)
|
||||
user = await session.execute(stmt)
|
||||
user = user.scalar_one_or_none()
|
||||
if not user:
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "User not found"},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
return HxfErrorResponse(
|
||||
message="用户不存在",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "User account is inactive"},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
return HxfErrorResponse(
|
||||
message="用户账户已停用",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Set user in context using token mechanism
|
||||
user_token = UserContext.set_current_user_with_token(user)
|
||||
import logging
|
||||
logging.info(f"User {user.username} (ID: {user.id}) authenticated and set in context")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 用户 {user.username} (ID: {user.id}) 已通过认证并设置到上下文")
|
||||
|
||||
# Verify context is set correctly
|
||||
current_user_id = UserContext.get_current_user_id()
|
||||
logging.info(f"Verified current user ID in context: {current_user_id}")
|
||||
if self.canLog:
|
||||
logger.info(f"[MIDDLEWARE] - 已验证当前用户 ID: {current_user_id} 上下文")
|
||||
finally:
|
||||
db.close()
|
||||
await session.close()
|
||||
|
||||
except Exception as e:
|
||||
# Log error but don't fail the request
|
||||
import logging
|
||||
logging.warning(f"Error setting user context: {e}")
|
||||
logger.error(f"[MIDDLEWARE] - 认证过程中设置用户上下文出错: {e}")
|
||||
# Return 401 error
|
||||
return HxfErrorResponse(
|
||||
message="认证过程中出错",
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Continue with request
|
||||
try:
|
||||
|
|
@ -160,4 +163,5 @@ class UserContextMiddleware(BaseHTTPMiddleware):
|
|||
finally:
|
||||
# Always clear user context after request processing
|
||||
UserContext.clear_current_user()
|
||||
logging.debug(f"[MIDDLEWARE] Cleared user context after processing request: {path}")
|
||||
if self.canLog:
|
||||
logger.debug(f"[MIDDLEWARE] - 已清除请求处理后的用户上下文: {path}")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||
from fastapi import HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..db.database import get_db
|
||||
from ..db.database import get_session
|
||||
from ..models.user import User
|
||||
from ..models.permission import Role
|
||||
from ..services.auth import AuthService
|
||||
|
|
@ -17,18 +17,39 @@ def is_super_admin(user: User, db: Session) -> bool:
|
|||
return False
|
||||
|
||||
# 检查用户是否有超级管理员角色
|
||||
for role in user.roles:
|
||||
if role.code == "SUPER_ADMIN":
|
||||
return True
|
||||
return False
|
||||
try:
|
||||
# 尝试访问已加载的角色
|
||||
for role in user.roles:
|
||||
if role.code == "SUPER_ADMIN":
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
# 如果角色未加载或访问失败,直接从数据库查询
|
||||
from sqlalchemy import select, and_
|
||||
from ..models.permission import Role, UserRole
|
||||
|
||||
try:
|
||||
# 直接查询用户角色
|
||||
stmt = select(Role).join(UserRole).filter(
|
||||
and_(
|
||||
UserRole.user_id == user.id,
|
||||
Role.code == "SUPER_ADMIN",
|
||||
Role.is_active == True
|
||||
)
|
||||
)
|
||||
super_admin_role = db.execute(stmt).scalar_one_or_none()
|
||||
return super_admin_role is not None
|
||||
except Exception:
|
||||
# 如果查询失败,返回False
|
||||
return False
|
||||
|
||||
|
||||
def require_super_admin(
|
||||
current_user: User = Depends(AuthService.get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session: Session = Depends(get_session)
|
||||
) -> User:
|
||||
"""要求超级管理员权限的依赖项."""
|
||||
if not is_super_admin(current_user, db):
|
||||
if not is_super_admin(current_user, session):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="需要超级管理员权限"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Database module for TH-Agenter."""
|
||||
"""Database module for TH Agenter."""
|
||||
|
||||
from .database import get_db, init_db
|
||||
from .database import get_session
|
||||
from .base import Base
|
||||
from th_agenter.models import User, Conversation, Message, KnowledgeBase, Document, AgentConfig, ExcelFile, Role, UserRole, LLMConfig, Workflow, WorkflowExecution, NodeExecution, DatabaseConfig, TableMetadata
|
||||
|
||||
__all__ = ["get_db", "init_db", "Base"]
|
||||
|
||||
__all__ = ["get_session", "Base"]
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,13 +1,27 @@
|
|||
"""Database base model."""
|
||||
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, DateTime, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import Integer, DateTime, event
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
|
||||
from sqlalchemy.sql import func
|
||||
from typing import Optional
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(
|
||||
naming_convention={
|
||||
# ix: index, 索引
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
# uq: unique, 唯一约束
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
# ck: check, 检查约束
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
# fk: foreign key, 外键约束
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
# pk: primary key, 主键约束
|
||||
"pk": "pk_%(table_name)s"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
|
|
@ -15,17 +29,36 @@ class BaseModel(Base):
|
|||
|
||||
__abstract__ = True
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
created_at = Column(DateTime, default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
|
||||
created_by = Column(Integer, nullable=True)
|
||||
updated_by = Column(Integer, nullable=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
|
||||
created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert model to dictionary."""
|
||||
return {
|
||||
column.name: getattr(self, column.name)
|
||||
for column in self.__table__.columns
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
"""Create model instance from dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing model field values
|
||||
|
||||
Returns:
|
||||
Model instance created from the dictionary
|
||||
"""
|
||||
# Filter out fields that don't exist in the model
|
||||
model_fields = {column.name for column in cls.__table__.columns}
|
||||
filtered_data = {key: value for key, value in data.items() if key in model_fields}
|
||||
|
||||
# Create and return the instance
|
||||
return cls(**filtered_data)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize model with automatic audit fields setting."""
|
||||
super().__init__(**kwargs)
|
||||
# Set audit fields for new instances
|
||||
self.set_audit_fields()
|
||||
def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||||
"""Set audit fields for create/update operations.
|
||||
|
||||
|
|
@ -54,9 +87,57 @@ class BaseModel(Base):
|
|||
# For update operations, only set update_by
|
||||
self.updated_by = user_id
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert model to dictionary."""
|
||||
return {
|
||||
column.name: getattr(self, column.name)
|
||||
for column in self.__table__.columns
|
||||
}
|
||||
# @event.listens_for(Session, 'before_flush')
|
||||
# def set_audit_fields_before_flush(session, flush_context, instances):
|
||||
# """Automatically set audit fields before flush."""
|
||||
# try:
|
||||
# from th_agenter.core.context import UserContext
|
||||
# user_id = UserContext.get_current_user_id()
|
||||
# except Exception:
|
||||
# user_id = None
|
||||
|
||||
# # 处理新增对象
|
||||
# for instance in session.new:
|
||||
# if isinstance(instance, BaseModel) and user_id:
|
||||
# instance.created_by = user_id
|
||||
# instance.updated_by = user_id
|
||||
|
||||
# # 处理修改对象
|
||||
# for instance in session.dirty:
|
||||
# if isinstance(instance, BaseModel) and user_id:
|
||||
# instance.updated_by = user_id
|
||||
|
||||
# # def __init__(self, **kwargs):
|
||||
# # """Initialize model with automatic audit fields setting."""
|
||||
# # super().__init__(**kwargs)
|
||||
# # # Set audit fields for new instances
|
||||
# # self.set_audit_fields()
|
||||
|
||||
# # def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
|
||||
# # """Set audit fields for create/update operations.
|
||||
|
||||
# # Args:
|
||||
# # user_id: ID of the user performing the operation (optional, will use context if not provided)
|
||||
# # is_update: True for update operations, False for create operations
|
||||
# # """
|
||||
# # # Get user_id from context if not provided
|
||||
# # if user_id is None:
|
||||
# # from ..core.context import UserContext
|
||||
# # try:
|
||||
# # user_id = UserContext.get_current_user_id()
|
||||
# # except Exception:
|
||||
# # # If no user in context, skip setting audit fields
|
||||
# # return
|
||||
|
||||
# # # Skip if still no user_id
|
||||
# # if user_id is None:
|
||||
# # return
|
||||
|
||||
# # if not is_update:
|
||||
# # # For create operations, set both create_by and update_by
|
||||
# # self.created_by = user_id
|
||||
# # self.updated_by = user_id
|
||||
# # else:
|
||||
# # # For update operations, only set update_by
|
||||
# # self.updated_by = user_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,89 +1,118 @@
|
|||
"""Database connection and session management."""
|
||||
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from typing import Generator
|
||||
import uuid, re
|
||||
from loguru import logger
|
||||
import traceback
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import Optional
|
||||
|
||||
from ..core.config import get_settings
|
||||
from .base import Base
|
||||
from utils.util_exceptions import DatabaseError
|
||||
|
||||
# Custom Session class with desc property and unique ID
|
||||
class DrSession(AsyncSession):
|
||||
"""Custom Session class with desc property and unique ID."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize DrSession with unique ID."""
|
||||
super().__init__(**kwargs)
|
||||
# 确保info属性存在
|
||||
if not hasattr(self, 'info'):
|
||||
self.info = {}
|
||||
self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
|
||||
self.stepIndex = 0
|
||||
|
||||
# Global variables
|
||||
engine = None
|
||||
SessionLocal = None
|
||||
@property
|
||||
def desc(self) -> Optional[str]:
|
||||
"""Get work brief from session info."""
|
||||
return self.info.get('desc')
|
||||
|
||||
@desc.setter
|
||||
def desc(self, value: str) -> None:
|
||||
"""Set work brief in session info."""
|
||||
self.stepIndex += 1
|
||||
|
||||
|
||||
def log_prefix(self) -> str:
|
||||
"""Get log prefix with session ID and desc."""
|
||||
return f"〖Session{self.info['session_id']}〗"
|
||||
|
||||
def parse_source_pos(self, level: int):
|
||||
pos = (traceback.format_stack())[level - 1].strip().split('\n')[0]
|
||||
match = re.search(r"File \"(.+?)\", line (\d+), in (\w+)", pos);
|
||||
if match:
|
||||
file = match.group(1).replace("F:\\DrGraph_Python\\FastAPI\\", "")
|
||||
pos = f"{file}:{match.group(2)} in {match.group(3)}"
|
||||
return pos
|
||||
|
||||
def log_info(self, msg: str, level: int = -2):
|
||||
"""Log info message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_success(self, msg: str, level: int = -2):
|
||||
"""Log success message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_warning(self, msg: str, level: int = -2):
|
||||
"""Log warning message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_error(self, msg: str, level: int = -2):
|
||||
"""Log error message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
def log_exception(self, msg: str, level: int = -2):
|
||||
"""Log exception message with session ID."""
|
||||
pos = self.parse_source_pos(level)
|
||||
logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||||
|
||||
engine_async = create_async_engine(
|
||||
get_settings().database.url,
|
||||
echo=True, # get_settings().database.echo,
|
||||
future=True,
|
||||
pool_size=get_settings().database.pool_size,
|
||||
max_overflow=get_settings().database.max_overflow,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
from fastapi import Request
|
||||
|
||||
def create_database_engine():
|
||||
"""Create database engine."""
|
||||
global engine, SessionLocal
|
||||
|
||||
settings = get_settings()
|
||||
database_url = settings.database.url
|
||||
|
||||
# Determine database type and configure engine
|
||||
engine_kwargs = {
|
||||
"echo": settings.database.echo,
|
||||
}
|
||||
|
||||
if database_url.startswith("sqlite"):
|
||||
# SQLite configuration
|
||||
engine = create_engine(database_url, **engine_kwargs)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
logging.info(f"SQLite database engine created: {database_url}")
|
||||
elif database_url.startswith("postgresql"):
|
||||
# PostgreSQL configuration
|
||||
engine_kwargs.update({
|
||||
"pool_size": settings.database.pool_size,
|
||||
"max_overflow": settings.database.max_overflow,
|
||||
"pool_pre_ping": True,
|
||||
"pool_recycle": 3600,
|
||||
})
|
||||
engine = create_engine(database_url, **engine_kwargs)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
logging.info(f"PostgreSQL database engine created: {database_url}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type. Please use PostgreSQL or SQLite. URL: {database_url}")
|
||||
AsyncSessionFactory = sessionmaker(
|
||||
bind=engine_async,
|
||||
class_=DrSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=True
|
||||
)
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database."""
|
||||
global engine
|
||||
async def get_session(request: Request = None):
|
||||
url = "无request"
|
||||
if request:
|
||||
url = f"{request.method} {request.url.path}"# .split("://")[-1]
|
||||
# session = AsyncSessionFactory()
|
||||
|
||||
if engine is None:
|
||||
create_database_engine()
|
||||
session = DrSession(bind=engine_async)
|
||||
|
||||
# Import all models to ensure they are registered
|
||||
from ..models import user, conversation, message, knowledge_base, permission, workflow
|
||||
session.desc = f"SUCCESS: 创建数据库 session >>> {url}"
|
||||
|
||||
# Create all tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logging.info("Database tables created")
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""Get database session."""
|
||||
global SessionLocal
|
||||
# 设置request属性
|
||||
if request:
|
||||
session.request = request
|
||||
|
||||
if SessionLocal is None:
|
||||
create_database_engine()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
yield session
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logging.error(f"Database session error: {e}")
|
||||
raise
|
||||
errMsg = f"数据库 session 异常 >>> {e}"
|
||||
session.desc = f"EXCEPTION: {errMsg}"
|
||||
await session.rollback()
|
||||
raise e
|
||||
# DatabaseError(e)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_db_session() -> Session:
|
||||
"""Get database session (synchronous)."""
|
||||
global SessionLocal
|
||||
|
||||
if SessionLocal is None:
|
||||
create_database_engine()
|
||||
|
||||
return SessionLocal()
|
||||
session.desc = f"数据库 session 关闭"
|
||||
await session.close()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Revises:
|
|||
Create Date: 2024-01-01 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from alembic_sync import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ sys.path.insert(0, str(backend_dir))
|
|||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from th_agenter.core.config import settings
|
||||
from th_agenter.db.database import Base, get_db_session
|
||||
from th_agenter.db.database import Base, get_session
|
||||
from th_agenter.models import * # Import all models to ensure they're registered
|
||||
from th_agenter.utils.logger import get_logger
|
||||
from th_agenter.models.resource import Resource
|
||||
|
|
@ -26,7 +26,7 @@ def migrate_hardcoded_resources():
|
|||
db = None
|
||||
try:
|
||||
# Get database session
|
||||
db = get_db_session()
|
||||
db = get_session() # xxxx
|
||||
|
||||
if db is None:
|
||||
logger.error("Failed to create database session")
|
||||
|
|
@ -437,15 +437,4 @@ def migrate_hardcoded_resources():
|
|||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
def main():
|
||||
"""Main function to run the migration."""
|
||||
print("=== 硬编码资源数据迁移 ===")
|
||||
success = migrate_hardcoded_resources()
|
||||
if success:
|
||||
print("\n🎉 资源数据迁移完成!")
|
||||
else:
|
||||
print("\n❌ 资源数据迁移失败!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -5,7 +5,7 @@ Revises: add_system_management
|
|||
Create Date: 2024-01-25 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from alembic_sync import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Database models for TH-Agenter."""
|
||||
"""Database models for TH Agenter."""
|
||||
|
||||
from .user import User
|
||||
from .conversation import Conversation
|
||||
|
|
@ -9,6 +9,8 @@ from .excel_file import ExcelFile
|
|||
from .permission import Role, UserRole
|
||||
from .llm_config import LLMConfig
|
||||
from .workflow import Workflow, WorkflowExecution, NodeExecution
|
||||
from .database_config import DatabaseConfig
|
||||
from .table_metadata import TableMetadata
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
|
|
@ -23,5 +25,7 @@ __all__ = [
|
|||
"LLMConfig",
|
||||
"Workflow",
|
||||
"WorkflowExecution",
|
||||
"NodeExecution"
|
||||
"NodeExecution",
|
||||
"DatabaseConfig",
|
||||
"TableMetadata"
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
|||
"""Agent configuration model."""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import String, Text, Boolean, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
|
|
@ -10,44 +10,34 @@ class AgentConfig(BaseModel):
|
|||
|
||||
__tablename__ = "agent_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Agent configuration
|
||||
enabled_tools = Column(JSON, nullable=False, default=list)
|
||||
max_iterations = Column(Integer, default=10)
|
||||
temperature = Column(String(10), default="0.1")
|
||||
system_message = Column(Text, nullable=True)
|
||||
verbose = Column(Boolean, default=True)
|
||||
enabled_tools: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
max_iterations: Mapped[int] = mapped_column(default=10)
|
||||
temperature: Mapped[str] = mapped_column(String(10), default="0.1")
|
||||
system_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
verbose: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Model configuration
|
||||
model_name = Column(String(100), default="gpt-3.5-turbo")
|
||||
max_tokens = Column(Integer, default=2048)
|
||||
model_name: Mapped[str] = mapped_column(String(100), default="gpt-3.5-turbo")
|
||||
max_tokens: Mapped[int] = mapped_column(default=2048)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
is_default: Mapped[bool] = mapped_column(default=False)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AgentConfig(id={self.id}, name='{self.name}', is_active={self.is_active})>"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}[{self.id}] Active: {self.is_active}"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"enabled_tools": self.enabled_tools or [],
|
||||
"max_iterations": self.max_iterations,
|
||||
"temperature": self.temperature,
|
||||
"system_message": self.system_message,
|
||||
"verbose": self.verbose,
|
||||
"model_name": self.model_name,
|
||||
"max_tokens": self.max_tokens,
|
||||
"is_active": self.is_active,
|
||||
"is_default": self.is_default,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at
|
||||
}
|
||||
data = super().to_dict()
|
||||
data['enabled_tools'] = self.enabled_tools or []
|
||||
return data
|
||||
|
|
@ -1,30 +1,32 @@
|
|||
"""Conversation model."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, ForeignKey, Text, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, DateTime
|
||||
from datetime import datetime
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Conversation model."""
|
||||
|
||||
__tablename__ = "conversations"
|
||||
|
||||
title = Column(String(200), nullable=False)
|
||||
user_id = Column(Integer, nullable=False) # Removed ForeignKey("users.id")
|
||||
knowledge_base_id = Column(Integer, nullable=True) # Removed ForeignKey("knowledge_bases.id")
|
||||
system_prompt = Column(Text, nullable=True)
|
||||
model_name = Column(String(100), nullable=False, default="gpt-3.5-turbo")
|
||||
temperature = Column(String(10), nullable=False, default="0.7")
|
||||
max_tokens = Column(Integer, nullable=False, default=2048)
|
||||
is_archived = Column(Boolean, default=False, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(Integer, nullable=False) # Removed ForeignKey("users.id")
|
||||
knowledge_base_id: Mapped[int | None] = mapped_column(Integer, nullable=True) # Removed ForeignKey("knowledge_bases.id")
|
||||
system_prompt: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
model_name: Mapped[str] = mapped_column(String(100), nullable=False, default="gpt-3.5-turbo")
|
||||
temperature: Mapped[str] = mapped_column(String(10), nullable=False, default="0.7")
|
||||
max_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=2048)
|
||||
is_archived: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
message_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
last_message_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
|
||||
# Relationships removed to eliminate foreign key constraints
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"
|
||||
|
||||
return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"
|
||||
|
||||
@property
|
||||
def message_count(self):
|
||||
"""Get the number of messages in this conversation."""
|
||||
|
|
@ -33,6 +35,4 @@ class Conversation(BaseModel):
|
|||
@property
|
||||
def last_message_at(self):
|
||||
"""Get the timestamp of the last message."""
|
||||
if self.messages:
|
||||
return self.messages[-1].created_at
|
||||
return self.created_at
|
||||
return self.messages[-1].created_at or self.created_at
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""数据库配置模型"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
|
||||
from sqlalchemy.sql import func
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import Integer, String, Text, Boolean, JSON
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
|
|
@ -12,17 +13,17 @@ class DatabaseConfig(BaseModel):
|
|||
"""数据库配置表"""
|
||||
__tablename__ = "database_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(100), nullable=False) # 配置名称
|
||||
db_type = Column(String(20), nullable=False, unique=True) # 数据库类型:postgresql, mysql等
|
||||
host = Column(String(255), nullable=False)
|
||||
port = Column(Integer, nullable=False)
|
||||
database = Column(String(100), nullable=False)
|
||||
username = Column(String(100), nullable=False)
|
||||
password = Column(Text, nullable=False) # 加密存储
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
connection_params = Column(JSON, nullable=True) # 额外连接参数
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False) # 配置名称
|
||||
db_type: Mapped[str] = mapped_column(String(20), nullable=False, unique=True) # 数据库类型:postgresql, mysql等
|
||||
host: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
port: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
database: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
username: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
password: Mapped[str] = mapped_column(Text, nullable=False) # 加密存储
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
connection_params: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 额外连接参数
|
||||
|
||||
def to_dict(self, include_password=False, decrypt_service=None):
|
||||
result = {
|
||||
|
|
@ -43,7 +44,7 @@ class DatabaseConfig(BaseModel):
|
|||
|
||||
# 如果需要包含密码且提供了解密服务
|
||||
if include_password and decrypt_service:
|
||||
print('begin decrypt password')
|
||||
logger.info(f"begin decrypt password for db config {self.id}")
|
||||
result["password"] = decrypt_service._decrypt_password(self.password)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,48 +1,44 @@
|
|||
"""Excel file models for smart query."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Text, Boolean, JSON, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, JSON, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class ExcelFile(BaseModel):
|
||||
"""Excel file model for storing file metadata."""
|
||||
|
||||
__tablename__ = "excel_files"
|
||||
|
||||
|
||||
|
||||
# Basic file information
|
||||
# user_id = Column(Integer, nullable=False) # 用户ID
|
||||
original_filename = Column(String(255), nullable=False) # 原始文件名
|
||||
file_path = Column(String(500), nullable=False) # 文件存储路径
|
||||
file_size = Column(Integer, nullable=False) # 文件大小(字节)
|
||||
file_type = Column(String(50), nullable=False) # 文件类型 (.xlsx, .xls, .csv)
|
||||
# user_id: Mapped[int] = mapped_column(Integer, nullable=False) # 用户ID
|
||||
original_filename: Mapped[str] = mapped_column(String(255), nullable=False) # 原始文件名
|
||||
file_path: Mapped[str] = mapped_column(String(500), nullable=False) # 文件存储路径
|
||||
file_size: Mapped[int] = mapped_column(Integer, nullable=False) # 文件大小(字节)
|
||||
file_type: Mapped[str] = mapped_column(String(50), nullable=False) # 文件类型 (.xlsx, .xls, .csv)
|
||||
|
||||
# Excel specific information
|
||||
sheet_names = Column(JSON, nullable=False) # 所有sheet名称列表
|
||||
default_sheet = Column(String(100), nullable=True) # 默认sheet名称
|
||||
sheet_names: Mapped[list] = mapped_column(JSON, nullable=False) # 所有sheet名称列表
|
||||
default_sheet: Mapped[str | None] = mapped_column(String(100), nullable=True) # 默认sheet名称
|
||||
|
||||
# Data preview information
|
||||
columns_info = Column(JSON, nullable=False) # 列信息:{sheet_name: [column_names]}
|
||||
preview_data = Column(JSON, nullable=False) # 前5行数据:{sheet_name: [[row1], [row2], ...]}
|
||||
data_types = Column(JSON, nullable=True) # 数据类型信息:{sheet_name: {column: dtype}}
|
||||
columns_info: Mapped[dict] = mapped_column(JSON, nullable=False) # 列信息:{sheet_name: [column_names]}
|
||||
preview_data: Mapped[dict] = mapped_column(JSON, nullable=False) # 前5行数据:{sheet_name: [[row1], [row2], ...]}
|
||||
data_types: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 数据类型信息:{sheet_name: {column: dtype}}
|
||||
|
||||
# Statistics
|
||||
total_rows = Column(JSON, nullable=True) # 每个sheet的总行数:{sheet_name: row_count}
|
||||
total_columns = Column(JSON, nullable=True) # 每个sheet的总列数:{sheet_name: column_count}
|
||||
total_rows: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 每个sheet的总行数:{sheet_name: row_count}
|
||||
total_columns: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 每个sheet的总列数:{sheet_name: column_count}
|
||||
|
||||
# Processing status
|
||||
is_processed = Column(Boolean, default=True, nullable=False) # 是否已处理
|
||||
processing_error = Column(Text, nullable=True) # 处理错误信息
|
||||
is_processed: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 是否已处理
|
||||
processing_error: Mapped[str | None] = mapped_column(Text, nullable=True) # 处理错误信息
|
||||
|
||||
# Upload information
|
||||
# upload_time = Column(DateTime, default=func.now(), nullable=False) # 上传时间
|
||||
last_accessed = Column(DateTime, nullable=True) # 最后访问时间
|
||||
# upload_time: Mapped[DateTime] = mapped_column(DateTime, default=func.now(), nullable=False) # 上传时间
|
||||
last_accessed: Mapped[DateTime | None] = mapped_column(DateTime, nullable=True) # 最后访问时间
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ExcelFile(id={self.id}, filename='{self.original_filename}', user_id={self.user_id})>"
|
||||
return f"<ExcelFile(id={self.id}, filename='{self.original_filename}')>"
|
||||
|
||||
@property
|
||||
def file_size_mb(self):
|
||||
|
|
|
|||
|
|
@ -1,41 +1,40 @@
|
|||
"""Knowledge base models."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, ForeignKey, Text, Boolean, JSON, Float
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import String, Integer, Text, Boolean, JSON
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class KnowledgeBase(BaseModel):
|
||||
"""Knowledge base model."""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
|
||||
name = Column(String(100), unique=False, index=True, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
embedding_model = Column(String(100), nullable=False, default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
chunk_size = Column(Integer, nullable=False, default=1000)
|
||||
chunk_overlap = Column(Integer, nullable=False, default=200)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(100), unique=False, index=True, nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
embedding_model: Mapped[str] = mapped_column(String(100), nullable=False, default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
chunk_size: Mapped[int] = mapped_column(Integer, nullable=False, default=1000)
|
||||
chunk_overlap: Mapped[int] = mapped_column(Integer, nullable=False, default=200)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Vector database settings
|
||||
vector_db_type = Column(String(50), nullable=False, default="chroma")
|
||||
collection_name = Column(String(100), nullable=True) # For vector DB collection
|
||||
vector_db_type: Mapped[str] = mapped_column(String(50), nullable=False, default="chroma")
|
||||
collection_name: Mapped[str | None] = mapped_column(String(100), nullable=True) # For vector DB collection
|
||||
|
||||
# Relationships removed to eliminate foreign key constraints
|
||||
|
||||
def __repr__(self):
|
||||
return f"<KnowledgeBase(id={self.id}, name='{self.name}')>"
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
"""Get the number of documents in this knowledge base."""
|
||||
return len(self.documents)
|
||||
# Relationships are commented out to remove foreign key constraints, so these properties should be updated
|
||||
# @property
|
||||
# def document_count(self):
|
||||
# """Get the number of documents in this knowledge base."""
|
||||
# return len(self.documents)
|
||||
|
||||
@property
|
||||
def active_document_count(self):
|
||||
"""Get the number of active documents in this knowledge base."""
|
||||
return len([doc for doc in self.documents if doc.is_processed])
|
||||
# @property
|
||||
# def active_document_count(self):
|
||||
# """Get the number of active documents in this knowledge base."""
|
||||
# return len([doc for doc in self.documents if doc.is_processed])
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
|
|
@ -43,28 +42,28 @@ class Document(BaseModel):
|
|||
|
||||
__tablename__ = "documents"
|
||||
|
||||
knowledge_base_id = Column(Integer, nullable=False) # Removed ForeignKey("knowledge_bases.id")
|
||||
filename = Column(String(255), nullable=False)
|
||||
original_filename = Column(String(255), nullable=False)
|
||||
file_path = Column(String(500), nullable=False)
|
||||
file_size = Column(Integer, nullable=False) # in bytes
|
||||
file_type = Column(String(50), nullable=False) # .pdf, .txt, .docx, etc.
|
||||
mime_type = Column(String(100), nullable=True)
|
||||
knowledge_base_id: Mapped[int] = mapped_column(Integer, nullable=False) # Removed ForeignKey("knowledge_bases.id")
|
||||
filename: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
original_filename: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
file_size: Mapped[int] = mapped_column(Integer, nullable=False) # in bytes
|
||||
file_type: Mapped[str] = mapped_column(String(50), nullable=False) # .pdf, .txt, .docx, etc.
|
||||
mime_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
|
||||
# Processing status
|
||||
is_processed = Column(Boolean, default=False, nullable=False)
|
||||
processing_error = Column(Text, nullable=True)
|
||||
is_processed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
processing_error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Content and metadata
|
||||
content = Column(Text, nullable=True) # Extracted text content
|
||||
doc_metadata = Column(JSON, nullable=True) # Additional metadata
|
||||
content: Mapped[str | None] = mapped_column(Text, nullable=True) # Extracted text content
|
||||
doc_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) # Additional metadata
|
||||
|
||||
# Chunking information
|
||||
chunk_count = Column(Integer, default=0, nullable=False)
|
||||
chunk_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
|
||||
# Embedding information
|
||||
embedding_model = Column(String(100), nullable=True)
|
||||
vector_ids = Column(JSON, nullable=True) # Store vector database IDs for chunks
|
||||
embedding_model: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
vector_ids: Mapped[list | None] = mapped_column(JSON, nullable=True) # Store vector database IDs for chunks
|
||||
|
||||
# Relationships removed to eliminate foreign key constraints
|
||||
|
||||
|
|
|
|||
|
|
@ -1,43 +1,42 @@
|
|||
"""LLM Configuration model for managing multiple AI models."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, Integer, Float, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from sqlalchemy import String, Text, Boolean, Integer, Float, JSON, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""LLM Configuration model for managing AI model settings."""
|
||||
|
||||
__tablename__ = "llm_configs"
|
||||
|
||||
name = Column(String(100), nullable=False, index=True) # 配置名称
|
||||
provider = Column(String(50), nullable=False, index=True) # 服务商:openai, deepseek, doubao, zhipu, moonshot, baidu
|
||||
model_name = Column(String(100), nullable=False) # 模型名称
|
||||
api_key = Column(String(500), nullable=False) # API密钥(加密存储)
|
||||
base_url = Column(String(200), nullable=True) # API基础URL
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True) # 配置名称
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False, index=True) # 服务商:openai, deepseek, doubao, zhipu, moonshot, baidu
|
||||
model_name: Mapped[str] = mapped_column(String(100), nullable=False) # 模型名称
|
||||
api_key: Mapped[str] = mapped_column(String(500), nullable=False) # API密钥(加密存储)
|
||||
base_url: Mapped[str | None] = mapped_column(String(200), nullable=True) # API基础URL
|
||||
|
||||
# 模型参数
|
||||
max_tokens = Column(Integer, default=2048, nullable=False)
|
||||
temperature = Column(Float, default=0.7, nullable=False)
|
||||
top_p = Column(Float, default=1.0, nullable=False)
|
||||
frequency_penalty = Column(Float, default=0.0, nullable=False)
|
||||
presence_penalty = Column(Float, default=0.0, nullable=False)
|
||||
max_tokens: Mapped[int] = mapped_column(Integer, default=2048, nullable=False)
|
||||
temperature: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
||||
top_p: Mapped[float] = mapped_column(Float, default=1.0, nullable=False)
|
||||
frequency_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
presence_penalty: Mapped[float] = mapped_column(Float, default=0.0, nullable=False)
|
||||
|
||||
# 配置信息
|
||||
description = Column(Text, nullable=True) # 配置描述
|
||||
is_active = Column(Boolean, default=True, nullable=False) # 是否启用
|
||||
is_default = Column(Boolean, default=False, nullable=False) # 是否为默认配置
|
||||
is_embedding = Column(Boolean, default=False, nullable=False) # 是否为嵌入模型
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True) # 配置描述
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # 是否启用
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为默认配置
|
||||
is_embedding: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否为嵌入模型
|
||||
|
||||
# 扩展配置(JSON格式)
|
||||
extra_config = Column(JSON, nullable=True) # 额外配置参数
|
||||
extra_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) # 额外配置参数
|
||||
|
||||
# 使用统计
|
||||
usage_count = Column(Integer, default=0, nullable=False) # 使用次数
|
||||
last_used_at = Column(String(50), nullable=True) # 最后使用时间
|
||||
usage_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) # 使用次数
|
||||
last_used_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 最后使用时间
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"
|
||||
|
|
@ -122,9 +121,8 @@ class LLMConfig(BaseModel):
|
|||
|
||||
def increment_usage(self):
|
||||
"""增加使用次数."""
|
||||
from datetime import datetime
|
||||
self.usage_count += 1
|
||||
self.last_used_at = datetime.now().isoformat()
|
||||
self.last_used_at = datetime.now()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, provider: str, is_embedding: bool = False):
|
||||
|
|
@ -132,7 +130,7 @@ class LLMConfig(BaseModel):
|
|||
templates = {
|
||||
'openai': {
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'model_name': 'gpt-3.5-turbo' if not is_embedding else 'text-embedding-ada-002',
|
||||
'model_name': 'gpt-4.0-mini' if not is_embedding else 'text-embedding-ada-002',
|
||||
'max_tokens': 2048,
|
||||
'temperature': 0.7
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Message model."""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, ForeignKey, Text, Enum, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import String, Integer, Text, Enum, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
import enum
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
|
@ -27,19 +27,19 @@ class Message(BaseModel):
|
|||
|
||||
__tablename__ = "messages"
|
||||
|
||||
conversation_id = Column(Integer, nullable=False) # Removed ForeignKey("conversations.id")
|
||||
role = Column(Enum(MessageRole), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
message_type = Column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
|
||||
message_metadata = Column(JSON, nullable=True) # Store additional data like file info, tokens used, etc.
|
||||
conversation_id: Mapped[int] = mapped_column(Integer, nullable=False) # Removed ForeignKey("conversations.id")
|
||||
role: Mapped[MessageRole] = mapped_column(Enum(MessageRole), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
|
||||
message_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) # Store additional data like file info, tokens used, etc.
|
||||
|
||||
# For knowledge base context
|
||||
context_documents = Column(JSON, nullable=True) # Store retrieved document references
|
||||
context_documents: Mapped[dict | None] = mapped_column(JSON, nullable=True) # Store retrieved document references
|
||||
|
||||
# Token usage tracking
|
||||
prompt_tokens = Column(Integer, nullable=True)
|
||||
completion_tokens = Column(Integer, nullable=True)
|
||||
total_tokens = Column(Integer, nullable=True)
|
||||
prompt_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
completion_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Relationships removed to eliminate foreign key constraints
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Role models for simplified RBAC system."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import String, Text, Boolean, ForeignKey, Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..db.base import BaseModel, Base
|
||||
|
|
@ -12,11 +12,11 @@ class Role(BaseModel):
|
|||
|
||||
__tablename__ = "roles"
|
||||
|
||||
name = Column(String(100), nullable=False, unique=True, index=True) # 角色名称
|
||||
code = Column(String(100), nullable=False, unique=True, index=True) # 角色编码
|
||||
description = Column(Text, nullable=True) # 角色描述
|
||||
is_system = Column(Boolean, default=False, nullable=False) # 是否系统角色
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True) # 角色名称
|
||||
code: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True) # 角色编码
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True) # 角色描述
|
||||
is_system: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # 是否系统角色
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 关系 - 只保留用户关系
|
||||
users = relationship("User", secondary="user_roles", back_populates="roles")
|
||||
|
|
@ -42,8 +42,8 @@ class UserRole(Base):
|
|||
|
||||
__tablename__ = "user_roles"
|
||||
|
||||
user_id = Column(Integer, ForeignKey('users.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('roles.id'), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey('users.id'), primary_key=True)
|
||||
role_id: Mapped[int] = mapped_column(Integer, ForeignKey('roles.id'), primary_key=True)
|
||||
|
||||
# 关系 - 用于直接操作关联表的场景
|
||||
user = relationship("User", viewonly=True)
|
||||
|
|
|
|||
|
|
@ -1,38 +1,36 @@
|
|||
"""表元数据模型"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Integer, String, Text, DateTime, Boolean, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class TableMetadata(BaseModel):
|
||||
"""表元数据表"""
|
||||
__tablename__ = "table_metadata"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
|
||||
# database_config_id = Column(Integer, ForeignKey('database_configs.id'), nullable=False)
|
||||
table_name = Column(String(100), nullable=False, index=True)
|
||||
table_schema = Column(String(50), default='public')
|
||||
table_type = Column(String(20), default='BASE TABLE')
|
||||
table_comment = Column(Text, nullable=True) # 表描述
|
||||
database_config_id = Column(Integer, nullable=True) #数据库配置ID
|
||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
table_schema: Mapped[str] = mapped_column(String(50), default='public')
|
||||
table_type: Mapped[str] = mapped_column(String(20), default='BASE TABLE')
|
||||
table_comment: Mapped[str | None] = mapped_column(Text, nullable=True) # 表描述
|
||||
database_config_id: Mapped[int | None] = mapped_column(Integer, nullable=True) #数据库配置ID
|
||||
# 表结构信息
|
||||
columns_info = Column(JSON, nullable=False) # 列信息:名称、类型、注释等
|
||||
primary_keys = Column(JSON, nullable=True) # 主键列表
|
||||
foreign_keys = Column(JSON, nullable=True) # 外键信息
|
||||
indexes = Column(JSON, nullable=True) # 索引信息
|
||||
columns_info: Mapped[dict] = mapped_column(JSON, nullable=False) # 列信息:名称、类型、注释等
|
||||
primary_keys: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 主键列表
|
||||
foreign_keys: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 外键信息
|
||||
indexes: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 索引信息
|
||||
|
||||
# 示例数据
|
||||
sample_data = Column(JSON, nullable=True) # 前5条示例数据
|
||||
row_count = Column(Integer, default=0) # 总行数
|
||||
sample_data: Mapped[dict | None] = mapped_column(JSON, nullable=True) # 前5条示例数据
|
||||
row_count: Mapped[int] = mapped_column(Integer, default=0) # 总行数
|
||||
|
||||
# 问答相关
|
||||
is_enabled_for_qa = Column(Boolean, default=True) # 是否启用问答
|
||||
qa_description = Column(Text, nullable=True) # 问答描述
|
||||
business_context = Column(Text, nullable=True) # 业务上下文
|
||||
is_enabled_for_qa: Mapped[bool] = mapped_column(Boolean, default=True) # 是否启用问答
|
||||
qa_description: Mapped[str | None] = mapped_column(Text, nullable=True) # 问答描述
|
||||
business_context: Mapped[str | None] = mapped_column(Text, nullable=True) # 业务上下文
|
||||
|
||||
last_synced_at = Column(DateTime(timezone=True), nullable=True) # 最后同步时间
|
||||
last_synced_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True) # 最后同步时间
|
||||
|
||||
# 关系
|
||||
# database_config = relationship("DatabaseConfig", back_populates="table_metadata")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""User model."""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import String, Boolean, Text
|
||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
|
@ -12,19 +13,19 @@ class User(BaseModel):
|
|||
|
||||
__tablename__ = "users"
|
||||
|
||||
username = Column(String(50), unique=True, index=True, nullable=False)
|
||||
email = Column(String(100), unique=True, index=True, nullable=False)
|
||||
hashed_password = Column(String(255), nullable=False)
|
||||
full_name = Column(String(100), nullable=True)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
avatar_url = Column(String(255), nullable=True)
|
||||
bio = Column(Text, nullable=True)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True, index=True, nullable=False)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
full_name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
bio: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# 关系 - 只保留角色关系
|
||||
roles = relationship("Role", secondary="user_roles", back_populates="users")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
|
||||
return f"<User(id={self.id}, username='{self.username}', email='{self.email}', full_name='{self.full_name}', is_active={self.is_active}, avatar_url='{self.avatar_url}', bio='{self.bio}', is_superuser={self.is_admin})>"
|
||||
|
||||
def to_dict(self, include_sensitive=False, include_roles=False):
|
||||
"""Convert to dictionary, optionally excluding sensitive data."""
|
||||
|
|
@ -36,57 +37,85 @@ class User(BaseModel):
|
|||
'is_active': self.is_active,
|
||||
'avatar_url': self.avatar_url,
|
||||
'bio': self.bio,
|
||||
'is_superuser': self.is_superuser()
|
||||
'is_superuser': self.is_admin # 使用同步的 is_admin 属性代替异步的 is_superuser 方法
|
||||
})
|
||||
|
||||
if not include_sensitive:
|
||||
data.pop('hashed_password', None)
|
||||
|
||||
if include_roles:
|
||||
data['roles'] = [role.to_dict() for role in self.roles if role.is_active]
|
||||
try:
|
||||
# 安全访问roles关系属性
|
||||
data['roles'] = [role.to_dict() for role in self.roles if role.is_active]
|
||||
except Exception:
|
||||
# 如果角色关系未加载或访问出错,返回空列表
|
||||
data['roles'] = []
|
||||
|
||||
return data
|
||||
|
||||
def has_role(self, role_code: str) -> bool:
|
||||
async def has_role(self, role_code: str) -> bool:
|
||||
"""检查用户是否拥有指定角色."""
|
||||
try:
|
||||
# 在异步环境中,需要先加载关系属性
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy import select
|
||||
from .permission import Role, UserRole
|
||||
|
||||
session = object_session(self)
|
||||
if isinstance(session, AsyncSession):
|
||||
# 如果是异步会话,使用await加载关系
|
||||
await session.refresh(self, ['roles'])
|
||||
return any(role.code == role_code and role.is_active for role in self.roles)
|
||||
except Exception:
|
||||
# 如果对象已分离,使用数据库查询
|
||||
# 如果对象已分离或加载关系失败,使用数据库查询
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy import select
|
||||
from .permission import Role, UserRole
|
||||
|
||||
session = object_session(self)
|
||||
if session is None:
|
||||
# 如果没有会话,创建新的会话
|
||||
from ..db.database import SessionLocal
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 如果没有会话,返回False
|
||||
return False
|
||||
else:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
if isinstance(session, AsyncSession):
|
||||
# 如果是异步会话,使用异步查询
|
||||
user_role = await session.execute(
|
||||
select(UserRole).join(Role).filter(
|
||||
UserRole.user_id == self.id,
|
||||
Role.code == role_code,
|
||||
Role.is_active == True
|
||||
)
|
||||
)
|
||||
return user_role.scalar_one_or_none() is not None
|
||||
else:
|
||||
# 如果是同步会话,使用同步查询
|
||||
user_role = session.query(UserRole).join(Role).filter(
|
||||
UserRole.user_id == self.id,
|
||||
Role.code == role_code,
|
||||
Role.is_active == True
|
||||
).first()
|
||||
return user_role is not None
|
||||
finally:
|
||||
session.close()
|
||||
else:
|
||||
user_role = session.query(UserRole).join(Role).filter(
|
||||
UserRole.user_id == self.id,
|
||||
Role.code == role_code,
|
||||
Role.is_active == True
|
||||
).first()
|
||||
return user_role is not None
|
||||
|
||||
def is_superuser(self) -> bool:
|
||||
async def is_superuser(self) -> bool:
|
||||
"""检查用户是否为超级管理员."""
|
||||
return self.has_role('SUPER_ADMIN')
|
||||
return await self.has_role('SUPER_ADMIN')
|
||||
|
||||
def is_admin_user(self) -> bool:
|
||||
async def is_admin_user(self) -> bool:
|
||||
"""检查用户是否为管理员(兼容性方法)."""
|
||||
return self.is_superuser()
|
||||
return await self.is_superuser()
|
||||
|
||||
# 注意:属性方式的 is_admin 无法是异步的,所以我们改为同步方法并简化实现
|
||||
@property
|
||||
def is_admin(self) -> bool:
|
||||
"""检查用户是否为管理员(属性方式)."""
|
||||
return self.is_superuser()
|
||||
# 同步属性无法使用 await,所以我们只能检查已加载的角色
|
||||
# 使用try-except捕获可能的MissingGreenlet错误
|
||||
try:
|
||||
# 检查角色关系是否已经加载
|
||||
# 如果roles属性是一个InstrumentedList且已经加载,那么它应该有__iter__方法
|
||||
return any(role.code == 'SUPER_ADMIN' and role.is_active for role in self.roles)
|
||||
except Exception:
|
||||
# 如果角色关系未加载或访问出错,返回False
|
||||
return False
|
||||
|
|
@ -1,20 +1,17 @@
|
|||
"""Workflow models."""
|
||||
|
||||
from sqlalchemy import Column, String, Text, Boolean, Integer, JSON, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import String, Text, Boolean, Integer, JSON, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||
import enum
|
||||
|
||||
from ..db.base import BaseModel
|
||||
|
||||
|
||||
class WorkflowStatus(enum.Enum):
|
||||
"""工作流状态枚举"""
|
||||
DRAFT = "DRAFT" # 草稿
|
||||
PUBLISHED = "PUBLISHED" # 已发布
|
||||
ARCHIVED = "ARCHIVED" # 已归档
|
||||
|
||||
|
||||
class NodeType(enum.Enum):
|
||||
"""节点类型枚举"""
|
||||
START = "start" # 开始节点
|
||||
|
|
@ -26,7 +23,6 @@ class NodeType(enum.Enum):
|
|||
HTTP = "http" # HTTP请求节点
|
||||
TOOL = "tool" # 工具节点
|
||||
|
||||
|
||||
class ExecutionStatus(enum.Enum):
|
||||
"""执行状态枚举"""
|
||||
PENDING = "pending" # 等待执行
|
||||
|
|
@ -35,25 +31,23 @@ class ExecutionStatus(enum.Enum):
|
|||
FAILED = "failed" # 执行失败
|
||||
CANCELLED = "cancelled" # 已取消
|
||||
|
||||
|
||||
class Workflow(BaseModel):
|
||||
"""工作流模型"""
|
||||
|
||||
__tablename__ = "workflows"
|
||||
|
||||
name = Column(String(100), nullable=False, comment="工作流名称")
|
||||
description = Column(Text, nullable=True, comment="工作流描述")
|
||||
status = Column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="工作流状态")
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False, comment="工作流名称")
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True, comment="工作流描述")
|
||||
status: Mapped[WorkflowStatus] = mapped_column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="工作流状态")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
|
||||
# 工作流定义(JSON格式存储节点和连接信息)
|
||||
definition = Column(JSON, nullable=False, comment="工作流定义")
|
||||
definition: Mapped[dict] = mapped_column(JSON, nullable=False, comment="工作流定义")
|
||||
|
||||
# 版本信息
|
||||
version = Column(String(20), default="1.0.0", nullable=False, comment="版本号")
|
||||
version: Mapped[str] = mapped_column(String(20), default="1.0.0", nullable=False, comment="版本号")
|
||||
|
||||
# 关联用户
|
||||
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False, comment="所有者ID")
|
||||
owner_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="所有者ID")
|
||||
|
||||
# 关系
|
||||
executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")
|
||||
|
|
@ -78,26 +72,25 @@ class Workflow(BaseModel):
|
|||
|
||||
return data
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""工作流执行记录"""
|
||||
|
||||
__tablename__ = "workflow_executions"
|
||||
|
||||
workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False, comment="工作流ID")
|
||||
status = Column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
|
||||
workflow_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflows.id"), nullable=False, comment="工作流ID")
|
||||
status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
|
||||
|
||||
# 执行输入和输出
|
||||
input_data = Column(JSON, nullable=True, comment="输入数据")
|
||||
output_data = Column(JSON, nullable=True, comment="输出数据")
|
||||
input_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输入数据")
|
||||
output_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输出数据")
|
||||
|
||||
# 执行信息
|
||||
started_at = Column(String(50), nullable=True, comment="开始时间")
|
||||
completed_at = Column(String(50), nullable=True, comment="完成时间")
|
||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||
started_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="开始时间")
|
||||
completed_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="完成时间")
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True, comment="错误信息")
|
||||
|
||||
# 执行者
|
||||
executor_id = Column(Integer, ForeignKey("users.id"), nullable=False, comment="执行者ID")
|
||||
executor_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="执行者ID")
|
||||
|
||||
# 关系
|
||||
workflow = relationship("Workflow", back_populates="executions")
|
||||
|
|
@ -125,29 +118,27 @@ class WorkflowExecution(BaseModel):
|
|||
|
||||
return data
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
"""节点执行记录"""
|
||||
|
||||
__tablename__ = "node_executions"
|
||||
|
||||
workflow_execution_id = Column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="工作流执行ID")
|
||||
node_id = Column(String(50), nullable=False, comment="节点ID")
|
||||
node_type = Column(Enum(NodeType), nullable=False, comment="节点类型")
|
||||
node_name = Column(String(100), nullable=False, comment="节点名称")
|
||||
workflow_execution_id: Mapped[int] = mapped_column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="工作流执行ID")
|
||||
node_id: Mapped[str] = mapped_column(String(50), nullable=False, comment="节点ID")
|
||||
node_type: Mapped[NodeType] = mapped_column(Enum(NodeType), nullable=False, comment="节点类型")
|
||||
node_name: Mapped[str] = mapped_column(String(100), nullable=False, comment="节点名称")
|
||||
|
||||
# 执行状态和结果
|
||||
status = Column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
|
||||
input_data = Column(JSON, nullable=True, comment="输入数据")
|
||||
output_data = Column(JSON, nullable=True, comment="输出数据")
|
||||
status: Mapped[ExecutionStatus] = mapped_column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="执行状态")
|
||||
input_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输入数据")
|
||||
output_data: Mapped[dict | None] = mapped_column(JSON, nullable=True, comment="输出数据")
|
||||
|
||||
# 执行时间
|
||||
started_at = Column(String(50), nullable=True, comment="开始时间")
|
||||
completed_at = Column(String(50), nullable=True, comment="完成时间")
|
||||
duration_ms = Column(Integer, nullable=True, comment="执行时长(毫秒)")
|
||||
started_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="开始时间")
|
||||
completed_at: Mapped[str | None] = mapped_column(String(50), nullable=True, comment="完成时间")
|
||||
duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True, comment="执行时长(毫秒)")
|
||||
|
||||
# 错误信息
|
||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True, comment="错误信息")
|
||||
|
||||
# 关系
|
||||
workflow_execution = relationship("WorkflowExecution", back_populates="node_executions")
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,7 +1,7 @@
|
|||
"""LLM Configuration Pydantic schemas."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
|
@ -28,8 +28,9 @@ class LLMConfigBase(BaseModel):
|
|||
class LLMConfigCreate(LLMConfigBase):
|
||||
"""创建大模型配置模式."""
|
||||
|
||||
@validator('provider')
|
||||
def validate_provider(cls, v):
|
||||
@field_validator('provider')
|
||||
@classmethod
|
||||
def validate_provider(cls, v: str) -> str:
|
||||
allowed_providers = [
|
||||
'openai', 'azure', 'anthropic', 'google', 'baidu',
|
||||
'alibaba', 'tencent', 'zhipu', 'moonshot', 'deepseek',
|
||||
|
|
@ -39,8 +40,9 @@ class LLMConfigCreate(LLMConfigBase):
|
|||
raise ValueError(f'不支持的服务商: {v},支持的服务商: {", ".join(allowed_providers)}')
|
||||
return v.lower()
|
||||
|
||||
@validator('api_key')
|
||||
def validate_api_key(cls, v):
|
||||
@field_validator('api_key')
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
if len(v.strip()) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
return v.strip()
|
||||
|
|
@ -65,8 +67,9 @@ class LLMConfigUpdate(BaseModel):
|
|||
is_embedding: Optional[bool] = Field(None, description="是否为嵌入模型")
|
||||
extra_config: Optional[Dict[str, Any]] = Field(None, description="额外配置")
|
||||
|
||||
@validator('provider')
|
||||
def validate_provider(cls, v):
|
||||
@field_validator('provider')
|
||||
@classmethod
|
||||
def validate_provider(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None:
|
||||
allowed_providers = [
|
||||
'openai', 'azure', 'anthropic', 'google', 'baidu',
|
||||
|
|
@ -78,8 +81,9 @@ class LLMConfigUpdate(BaseModel):
|
|||
return v.lower()
|
||||
return v
|
||||
|
||||
@validator('api_key')
|
||||
def validate_api_key(cls, v):
|
||||
@field_validator('api_key')
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and len(v.strip()) < 10:
|
||||
raise ValueError('API密钥长度不能少于10个字符')
|
||||
return v.strip() if v else v
|
||||
|
|
@ -109,17 +113,16 @@ class LLMConfigResponse(BaseModel):
|
|||
created_by: Optional[int] = None
|
||||
updated_by: Optional[int] = None
|
||||
|
||||
# 敏感信息处理
|
||||
api_key_masked: Optional[str] = None
|
||||
model_config = {
|
||||
'from_attributes': True
|
||||
}
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@validator('api_key_masked', pre=True, always=True)
|
||||
def mask_api_key(cls, v, values):
|
||||
@computed_field
|
||||
@property
|
||||
def api_key_masked(self) -> Optional[str]:
|
||||
# 在响应中隐藏API密钥,只显示前4位和后4位
|
||||
if 'api_key' in values and values['api_key']:
|
||||
key = values['api_key']
|
||||
if self.api_key:
|
||||
key = self.api_key
|
||||
if len(key) > 8:
|
||||
return f"{key[:4]}{'*' * (len(key) - 8)}{key[-4:]}"
|
||||
else:
|
||||
|
|
@ -148,5 +151,6 @@ class LLMConfigClientResponse(BaseModel):
|
|||
is_active: bool
|
||||
description: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = {
|
||||
'from_attributes': True
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue