first commit
Commit: 6e733683f8

@@ -0,0 +1,464 @@
# th - Enterprise-Grade Agent Application Platform

🚀 **A fully open-source LLM application platform**

- An LLM solution integrating intelligent Q&A, smart data Q&A, knowledge bases, workflows, and agent orchestration.
- Built on a Vue.js + FastAPI + PostgreSQL + LangChain/LangGraph architecture.
- Designed for enterprise applications: fully open-source code, private deployment support, and flexible extension and secondary development.
- User-level data isolation: each user's data is visible only to that user, ensuring data security.

## 🏗️ Technical Architecture

### Backend Stack
- **Web framework**: FastAPI + SQLAlchemy + Alembic
- **Database**: PostgreSQL 16+ (open-source relational database)
- **Vector database**: PostgreSQL with the pgvector extension (open-source vector store)
- **Agent orchestration**: LangGraph state graphs with conditional routing
- **Tool invocation**: Function Calling
- **Model connection protocol**: MCP (Model Context Protocol)
- **RAG retrieval**: LangChain Vector Store
- **Conversation memory**: ConversationBufferMemory
- **Document processing**: PyPDF2 + python-docx + markdown
- **Data analysis**: Pandas + NumPy

### Frontend Stack
- **Framework**: Vue 3 + TypeScript + Vite
- **UI components**: Element Plus (open-source UI library)
- **HTTP client**: Axios
- **Workflow editor**: In-house visual editor
- **Workflow engine**: DAG-based execution engine
- **Graphics rendering**: Canvas API + SVG
- **Drag and drop**: Vue Draggable
- **Node connections**: Custom edge-routing algorithm

## Local Deployment Guide

### Requirements
- Python 3.10+
- Node.js 18+
- PostgreSQL 16+

### 1. Install the Database: PostgreSQL with the pgvector Extension (Vector Search)

#### Option 1: Docker Installation (Recommended)

Deploy PostgreSQL 16 with the pgvector extension using Docker + Docker Compose.

**1. Create a docker-compose.yml file**

With the following content:

```yaml
version: '3.8'

services:
  db:
    image: pgvector/pgvector:pg16
    container_name: pgvector-db
    environment:
      POSTGRES_USER: myuser
      POSTGRES_PASSWORD: your_password
      POSTGRES_DB: mydb
    ports:
      - "5432:5432"
    volumes:
      - pgdata:/var/lib/postgresql/data
    restart: unless-stopped

volumes:
  pgdata:
```

**Notes:**
- The `pgvector/pgvector:pg16` image ships PostgreSQL 16 with the pgvector extension built in.
- Data is stored in the Docker volume `pgdata` and survives restarts.
- Port 5432 is exposed on the host, so local tools such as pgAdmin, DBeaver, or psql can connect.
- Default database name: mydb
- Default username: myuser
- Default password: your_password

**2. Start the service**

From the directory containing `docker-compose.yml`, run:

```bash
docker-compose up -d
```

Check the container status:

```bash
docker ps
```

The output should include a container named `pgvector-db` with a status of Up.

**3. Verify the pgvector installation**

Enter the PostgreSQL container:

```bash
docker exec -it pgvector-db psql -U myuser -d mydb
```

Enable the pgvector extension:

```sql
CREATE EXTENSION IF NOT EXISTS vector;
```

**Insert and query vector data** (example; this can also be run from a client such as DBeaver):

```sql
-- Create a table with a 3-dimensional vector column
CREATE TABLE items (
    id SERIAL PRIMARY KEY,
    embedding vector(3)
);

-- Insert vector data
INSERT INTO items (embedding) VALUES
    ('[1,1,1]'),
    ('[2,2,2]'),
    ('[1,0,0]');

-- Query the vectors closest to [1,1,1] (by Euclidean distance)
SELECT id, embedding
FROM items
ORDER BY embedding <-> '[1,1,1]'
LIMIT 3;
```

If the statements above run without errors and return results, the installation succeeded.

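As a sanity check on what the `<->` operator computes, this small sketch (an illustrative aside, not part of the project code) reproduces the same Euclidean-distance ordering for the three inserted vectors:

```python
import numpy as np

# The SQL above orders rows by Euclidean (L2) distance, which is what the
# pgvector <-> operator computes. Reproduce that ordering for the same vectors.
items = {1: np.array([1, 1, 1]), 2: np.array([2, 2, 2]), 3: np.array([1, 0, 0])}
query = np.array([1, 1, 1])

ranked = sorted(items, key=lambda i: float(np.linalg.norm(items[i] - query)))
print(ranked)  # [1, 3, 2]: the exact match first, then [1,0,0], then [2,2,2]
```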
### 2. Backend Deployment

```bash
# Clone the project
git clone https://github.com/lkpAgent/chat-agent.git
cd chat-agent/backend

# Create a Python virtual environment (conda is recommended)
conda create -n chat-agent python=3.10
conda activate chat-agent

# Install dependencies
pip install -r requirements.txt

# Configure environment variables (on Windows, simply copy .env.example to .env)
cp .env.example .env

# Edit the .env file to configure the database connection and AI API keys.
# See the configuration notes below for details.

# After configuring the database, initialize the tables and create a login
# account (username: test@example.com, password: 123456)
cd backend/tests
python init_db.py

# Start the backend service (default port 8000)
python -m uvicorn th_agenter.main:app --reload --host 0.0.0.0 --port 8000
# Or run main.py directly:
# cd backend/th_agenter
# python main.py
```

### 3. Frontend Deployment

```bash
# Enter the frontend directory
cd ../frontend

# Install dependencies
npm install

# Configure environment variables
cp .env.example .env
# Edit the .env file to set the backend API address:
# VITE_API_BASE_URL=http://localhost:8000

# Development: start the frontend dev server (default port 3000)
npm run dev

# Production: to deploy under {nginx_home}/html/yourdomain, build with the base path:
# npm run build -- --base=yourdomain
```

Once started, open http://localhost:3000 to reach the login page. The default account is test@example.com with password 123456.

![Login](docs/images/login.png)

### 4. Access the Application
- Frontend: http://localhost:3000
- Backend API: http://localhost:8000
- API docs: http://localhost:8000/docs

### 5. Backend Configuration Notes

#### Backend environment variables (backend/.env)

Core settings: the system database URL (`DATABASE_URL`), the vector database settings, the chat model provider (`LLM_PROVIDER`) and its related options, and the embedding model provider (`EMBEDDING_PROVIDER`).

Tool API keys: Tavily Search and the Xinzhi (Seniverse) weather API.

```env
# ========================================
# Database configuration
# ========================================
DATABASE_URL=postgresql://your_username:your_password@your_host:your_port/your_db
# Example:
# DATABASE_URL=postgresql://myuser:mypassword@127.0.0.1:5432/mydb

# ========================================
# Vector database configuration
# ========================================
VECTOR_DB_TYPE=pgvector
PGVECTOR_HOST=your_host
PGVECTOR_PORT=your_port
PGVECTOR_DATABASE=mydb
PGVECTOR_USER=myuser
PGVECTOR_PASSWORD=your_password

# ========================================
# LLM configuration (third-party services compatible with the OpenAI protocol)
# Only one chat model and one embedding model need to be configured.
# ========================================
# Chat model configuration
# Available providers: openai, deepseek, doubao, zhipu, moonshot
LLM_PROVIDER=doubao

# Embedding model configuration
# Available providers: openai, deepseek, doubao, zhipu, moonshot
EMBEDDING_PROVIDER=zhipu

# OpenAI
OPENAI_API_KEY=your-openai-api-key
OPENAI_MODEL=gpt-4
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_EMBEDDING_MODEL=text-embedding-ada-002

# Zhipu AI
ZHIPU_API_KEY=your-zhipu-api-key
ZHIPU_MODEL=glm-4
ZHIPU_EMBEDDING_MODEL=embedding-3
ZHIPU_BASE_URL=https://open.bigmodel.cn/api/paas/v4

# DeepSeek
DEEPSEEK_API_KEY=your-deepseek-api-key
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
DEEPSEEK_MODEL=deepseek-chat
DEEPSEEK_EMBEDDING_MODEL=deepseek-embedding

# Doubao
DOUBAO_API_KEY=your-doubao-api-key
DOUBAO_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
DOUBAO_MODEL=doubao-1-5-pro-32k-250115
DOUBAO_EMBEDDING_MODEL=doubao-embedding

# Moonshot
MOONSHOT_API_KEY=your-moonshot-api-key
MOONSHOT_BASE_URL=https://api.moonshot.cn/v1
MOONSHOT_MODEL=moonshot-v1-8k
MOONSHOT_EMBEDDING_MODEL=moonshot-embedding

# Tool API configuration
## Tavily Search API
TAVILY_API_KEY=your-tavily-api-key
## Xinzhi (Seniverse) weather API
WEATHER_API_KEY=your_xinzhi_api_key
```

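The `LLM_PROVIDER` switch above implies a lookup of the matching `*_API_KEY`/`*_BASE_URL`/`*_MODEL` variables. A minimal sketch of that resolution (a hypothetical helper, not the project's actual loader):

```python
# Resolve provider-specific settings from LLM_PROVIDER by key prefix.
def resolve_llm_config(env: dict) -> dict:
    provider = env.get("LLM_PROVIDER", "openai").upper()
    return {
        "api_key": env[f"{provider}_API_KEY"],
        "base_url": env[f"{provider}_BASE_URL"],
        "model": env[f"{provider}_MODEL"],
    }

# In the real backend these values would come from os.environ via python-dotenv.
env = {
    "LLM_PROVIDER": "doubao",
    "DOUBAO_API_KEY": "your-doubao-api-key",
    "DOUBAO_BASE_URL": "https://ark.cn-beijing.volces.com/api/v3",
    "DOUBAO_MODEL": "doubao-1-5-pro-32k-250115",
}
print(resolve_llm_config(env)["model"])  # doubao-1-5-pro-32k-250115
```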
## 📖 API Documentation

### Main API Endpoints

#### Authentication
- `POST /auth/login` - User login
- `POST /auth/register` - User registration
- `POST /auth/refresh` - Refresh token

#### Conversation Management
- `GET /chat/conversations` - List conversations
- `POST /chat/conversations` - Create a conversation
- `POST /chat/conversations/{id}/chat` - Send a message

#### Knowledge Base Management
- `POST /knowledge/upload` - Upload a document
- `GET /knowledge/documents` - List documents
- `DELETE /knowledge/documents/{id}` - Delete a document

#### Smart Query
- `POST /smart-query/query` - Intelligent data query
- `POST /smart-query/upload` - Upload an Excel file
- `GET /smart-query/files` - List files

### Full API Documentation
Available at http://localhost:8000/docs once the backend is running.

## 🔧 Development Guide

### Project Structure
```
open-agent/
├── backend/                    # Backend code
│   ├── th_agenter/             # Main application package
│   │   ├── api/                # API routes
│   │   ├── core/               # Core configuration
│   │   ├── db/                 # Database layer
│   │   ├── models/             # Database models
│   │   ├── services/           # Business logic
│   │   ├── utils/              # Utility functions
│   │   └── main.py             # Application entry point
│   ├── tests/                  # Tests
│   └── requirements.txt        # Python dependencies
├── frontend/                   # Frontend code
│   ├── src/
│   │   ├── components/         # Vue components
│   │   ├── views/              # Page components
│   │   │   ├── chat/           # Chat pages
│   │   │   ├── knowledge/      # Knowledge base pages
│   │   │   ├── workflow/       # Workflow pages
│   │   │   └── agent/          # Agent pages
│   │   ├── stores/             # Pinia state management
│   │   ├── api/                # API calls
│   │   ├── types/              # TypeScript types
│   │   └── router/             # Route configuration
│   └── package.json            # Node.js dependencies
├── data/                       # Data directory
│   ├── uploads/                # Uploaded files
│   └── logs/                   # Log files
└── docs/                       # Documentation
```

## ✨ Core Capabilities

### 🤖 Intelligent Q&A

- **Multi-model support**: Integrates DeepSeek, Zhipu AI, Doubao, and other mainstream Chinese AI providers
- **Three conversation modes**:
  - Free chat: interact with the AI model directly
  - RAG chat: retrieval-augmented generation over the knowledge base
  - Agent chat: multi-agent collaboration on complex tasks
- **Multi-turn dialogue**: Continuous conversations with context understanding and memory
- **Conversation history**: Complete session records and management

## 🌟 Technical Highlights

### LangGraph-Based Agent Conversation System
- **Autonomous planning**: Agents invoke tools and plan their execution flow on their own, driven by the task
- **Dynamic tool invocation**: The most suitable tool is selected and executed automatically from context
- **Multi-step task decomposition**: Complex tasks are automatically split into subtasks and executed in order

**Example scenario**:
When a user asks "Which is better for travel right now, Changsha or Beijing?":
1. The agent first calls the search tool to look for relevant information
2. Finding no suitable results, it automatically plans a call to the weather tool
3. It splits the task into two executions: first query Changsha's weather, then Beijing's
4. Based on the temperatures, it judges Beijing more suitable for travel
5. It automatically calls the search tool again to find Beijing's attractions
6. Finally, it integrates all the information into a summarized recommendation

Step 1: call the search engine to look up which city is better for travel.

![Agent Chat](docs/images/agent-chat.png)

Step 2: the search results contain no suitable answer, so the agent recognizes the dead end, changes its plan, and calls the weather tool instead, judging from the weather which city suits travel right now. It also decomposes the task automatically, calling the weather tool once each for Beijing and Changsha to obtain both cities' conditions.

![Agent Chat](docs/images/agent-chat2.png)

Step 3: based on the weather, it judges Beijing more suitable, then calls the search tool again to find Beijing's signature attractions. Finally, it summarizes the tool results against the question, completing the conversation.

![Agent Chat](docs/images/agent-chat3.png)

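The decision flow above can be sketched as a toy control loop. This is illustrative only (simulated tools, hard-coded temperatures, and a simple "cooler city wins" rule); the real system implements it with LangGraph state graphs and conditional routing:

```python
def search_tool(query):
    return None  # simulate: the search finds nothing useful

def weather_tool(city):
    return {"Changsha": 35, "Beijing": 24}[city]  # simulated temperatures (°C)

def run_agent(question, cities):
    steps = []
    if search_tool(question) is None:
        # Conditional routing: search failed, so switch to the weather tool
        steps.append("search failed, switching to weather tool")
        # Task decomposition: one weather call per city
        temps = {c: weather_tool(c) for c in cities}
        best = min(temps, key=temps.get)  # toy judgment: cooler city wins
        steps.append(f"recommend {best}")
    return steps

print(run_agent("Which city is better for travel?", ["Changsha", "Beijing"]))
```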
### 📊 Smart Data Q&A

- **Excel analysis**: Upload Excel files for intelligent data analysis
- **Natural language queries**: Ask in natural language; Python code is generated automatically
- **Database queries**: Connect to PostgreSQL and other databases for Q&A
- **Multi-table joins**: Supports complex queries across multiple tables and files
- **Visualized chain of thought**: The model's reasoning process is rendered visibly

## 🌟 Technical Highlights

### Dual-Engine Smart Data Q&A

**Excel-based Q&A**
- Uses the LangChain code-interpreter plugin to read Excel data into a Pandas DataFrame
- The LLM translates natural-language questions into Pandas syntax and executes it
- Supports joint queries across multiple spreadsheet files and complex data analysis

**Database-based Q&A**
- Implements a PostgreSQL MCP (Model Context Protocol) interface
- The LLM first extracts table metadata to understand table structure and relationships
- It then generates optimized SQL for the user's question
- Supports multi-table joins and complex data retrieval

Excel-based smart Q&A:
![Smart Query](docs/images/smart-query.png)

Database-based smart Q&A:
![Smart Query](docs/images/smart-query2.png)

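The Excel flow above boils down to: the model emits a Pandas expression, and an interpreter executes it. A minimal sketch (the DataFrame and the "generated" expression are illustrative assumptions; `eval` stands in for the sandboxed code interpreter):

```python
import pandas as pd

df = pd.DataFrame({
    "department": ["sales", "sales", "eng"],
    "salary": [100, 200, 300],
})

# For a question like "average salary by department", the model might emit:
generated = "df.groupby('department')['salary'].mean()"

result = eval(generated)  # the code interpreter executes the generated code
print(result.to_dict())   # {'eng': 300.0, 'sales': 150.0}
```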
### 📚 Knowledge Base Management

- **Document processing**: Supports PDF, Word, Markdown, TXT, and other formats
- **Vector storage**: Vector database built on PostgreSQL + pgvector
- **Smart retrieval**: Vector similarity search plus BM25 keyword retrieval
- **Document management**: Upload, delete, categorize, and tag
- **RAG integration**: Seamlessly integrated with the conversation system

## 🌟 Technical Highlights

### Advanced Semantic-Splitting Document Processing
- **Intelligent paragraph splitting**: Splits by LLM semantic understanding rather than traditional text-similarity heuristics
- **Precise cut-point detection**: The LLM directly identifies the best split positions and outputs split-marker strings
- **Efficient pipeline**: The model outputs only the split-position strings; code then performs the actual splitting
- **Performance**: Avoids the heavy vector computation and similarity comparisons of traditional methods
- **Quality**: Deep semantic understanding keeps split boundaries accurate and coherent

### Dual-Recall Retrieval
- **Multi-mode retrieval**: Combines vector similarity matching (semantic search) with BM25 keyword retrieval (literal matching)
- **Hybrid ranking**: A weighted-fusion algorithm ranks results by both semantic relevance and keyword match
- **Recall enhancement**: Dual recall addresses the vocabulary-mismatch problem of pure vector retrieval
- **Precision**: Significantly improves recall and accuracy over any single retrieval method

![Knowledge Management](docs/images/knowledge.png)

![Knowledge Management](docs/images/knowledge2.png)

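A minimal sketch of weighted score fusion for dual recall (the weight `alpha` and max-normalization are illustrative choices; the project's actual fusion algorithm may differ): each channel's scores are normalized, combined per document, and the union is ranked:

```python
def fuse(vector_scores, bm25_scores, alpha=0.7):
    """Rank documents by a weighted blend of normalized channel scores."""
    def norm(scores):
        hi = max(scores.values()) or 1.0  # scale each channel to [0, 1]
        return {doc: s / hi for doc, s in scores.items()}

    v, b = norm(vector_scores), norm(bm25_scores)
    docs = set(v) | set(b)  # union: a doc recalled by either channel survives
    fused = {d: alpha * v.get(d, 0.0) + (1 - alpha) * b.get(d, 0.0) for d in docs}
    return sorted(docs, key=lambda d: -fused[d])

vector_scores = {"doc1": 0.9, "doc2": 0.4, "doc3": 0.2}  # semantic channel
bm25_scores = {"doc2": 12.0, "doc3": 3.0}                # keyword channel
print(fuse(vector_scores, bm25_scores))  # ['doc1', 'doc2', 'doc3']
```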
### 🔧 Workflow Orchestration

- **Visual design**: Drag-and-drop workflow designer
- **Node types**: AI chat, data processing, conditional branching, and more
- **Flow control**: Conditional branches, loops, parallel execution

![Workflow](docs/images/workflow.png)

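At its core, a DAG-based execution engine like the one described above runs each node only after all of its upstream nodes finish. A toy sketch with Python's standard-library `graphlib` (node names are illustrative):

```python
from graphlib import TopologicalSorter

# Each key maps a node to the set of nodes it depends on.
edges = {"process": {"input"}, "branch": {"process"}, "output": {"branch"}}

order = list(TopologicalSorter(edges).static_order())
print(order)  # ['input', 'process', 'branch', 'output']
```

For parallel execution, `TopologicalSorter` also offers `get_ready()`/`done()`, which yield all nodes whose dependencies are satisfied at once.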
### 🤖 Agent Orchestration
- **Multi-agent collaboration**: AI agents from different specialties work together
- **Role definition**: Customize each agent's expertise and knowledge domain
- **Task assignment**: Complex tasks are intelligently decomposed and routed to the right agents
- **Result integration**: Outputs from multiple agents are merged into a final answer

### Online Demo (feel free to register your own account)
http://113.240.110.92:81/

#### 💼 Commercial Use
- ✅ Usable in commercial projects
- ✅ Source code may be modified
- ✅ Private deployment allowed
- ✅ Can be integrated into existing systems
- ✅ No license fees

## 📄 License

This project is released under the [MIT License](LICENSE), which means:
- Free to use, modify, and distribute
- Commercial use allowed
- Just keep the original license notice
- The authors assume no liability

## 🙏 Acknowledgments

**If this project helps you, please give it a ⭐️!**

@@ -0,0 +1,90 @@
# ========================================
# LLM configuration (third-party services compatible with the OpenAI protocol)
# ========================================
# Available providers: openai, deepseek, doubao, zhipu, moonshot
LLM_PROVIDER=doubao

# Embedding model configuration
# Available providers: openai, deepseek, doubao, zhipu, moonshot
EMBEDDING_PROVIDER=zhipu

# OpenAI
OPENAI_API_KEY=your-openai-api-key
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-3.5-turbo

# DeepSeek
DEEPSEEK_API_KEY=your-deepseek-api-key
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
DEEPSEEK_MODEL=deepseek-chat

# Doubao (ByteDance)
DOUBAO_API_KEY=your-doubao-api-key
DOUBAO_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
DOUBAO_MODEL=doubao-1-5-pro-32k-250115

# Zhipu AI
ZHIPU_API_KEY=your-zhipu-api-key
ZHIPU_BASE_URL=https://open.bigmodel.cn/api/paas/v4
ZHIPU_MODEL=glm-4
ZHIPU_EMBEDDING_MODEL=embedding-3

# Moonshot
MOONSHOT_API_KEY=your-moonshot-api-key
MOONSHOT_BASE_URL=https://api.moonshot.cn/v1
MOONSHOT_MODEL=moonshot-v1-8k
MOONSHOT_EMBEDDING_MODEL=moonshot-embedding

# Embedding model names
OPENAI_EMBEDDING_MODEL=text-embedding-ada-002
DEEPSEEK_EMBEDDING_MODEL=deepseek-embedding
DOUBAO_EMBEDDING_MODEL=doubao-embedding

# Tool API configuration
## Tavily Search API
TAVILY_API_KEY=your-tavily-api-key
## Xinzhi (Seniverse) weather API
WEATHER_API_KEY=your_xinzhi_api_key

# ========================================
# Application configuration
# ========================================
# Backend application
APP_NAME=TH-Agenter
APP_VERSION=0.1.0
DEBUG=true
ENVIRONMENT=development
HOST=0.0.0.0
PORT=8000

# Frontend application
VITE_API_BASE_URL=http://localhost:8000/api
VITE_APP_TITLE=TH-Agenter
VITE_APP_VERSION=1.0.0
VITE_ENABLE_MOCK=false
VITE_UPLOAD_MAX_SIZE=10485760
VITE_SUPPORTED_FILE_TYPES=pdf,txt,md,doc,docx,ppt,pptx,xls,xlsx

# ========================================
# Security configuration
# ========================================
SECRET_KEY=your-secret-key-here-change-in-production
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=300

# ========================================
# Database configuration
# ========================================
DATABASE_URL=sqlite:///./TH-Agenter.db
# DATABASE_URL=postgresql://iagent:iagent@118.196.30.45:5432/iagent

# ========================================
# Vector database configuration
# ========================================
VECTOR_DB_TYPE=pgvector
PGVECTOR_HOST=118.196.30.45
PGVECTOR_PORT=5432
PGVECTOR_DATABASE=iagent
PGVECTOR_USER=iagent
PGVECTOR_PASSWORD=iagent
@@ -0,0 +1,89 @@
# ========================================
# LLM configuration (third-party services compatible with the OpenAI protocol)
# ========================================
# Available providers: openai, deepseek, doubao, zhipu, moonshot
LLM_PROVIDER=doubao

# Embedding model configuration
# Available providers: openai, deepseek, doubao, zhipu, moonshot
EMBEDDING_PROVIDER=zhipu

# OpenAI
OPENAI_API_KEY=your-openai-api-key
OPENAI_BASE_URL=https://api.openai.com/v1
OPENAI_MODEL=gpt-3.5-turbo

# DeepSeek
DEEPSEEK_API_KEY=your-deepseek-api-key
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
DEEPSEEK_MODEL=deepseek-chat

# Doubao (ByteDance)
DOUBAO_API_KEY=your-doubao-api-key
DOUBAO_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
DOUBAO_MODEL=doubao-1-5-pro-32k-250115

# Zhipu AI
ZHIPU_API_KEY=your-zhipu-api-key
ZHIPU_BASE_URL=https://open.bigmodel.cn/api/paas/v4
ZHIPU_MODEL=glm-4
ZHIPU_EMBEDDING_MODEL=embedding-3

# Moonshot
MOONSHOT_API_KEY=your-moonshot-api-key
MOONSHOT_BASE_URL=https://api.moonshot.cn/v1
MOONSHOT_MODEL=moonshot-v1-8k
MOONSHOT_EMBEDDING_MODEL=moonshot-embedding

# Embedding model names
OPENAI_EMBEDDING_MODEL=text-embedding-ada-002
DEEPSEEK_EMBEDDING_MODEL=deepseek-embedding
DOUBAO_EMBEDDING_MODEL=doubao-embedding

# Tool API configuration
## Tavily Search API
TAVILY_API_KEY=your-tavily-api-key
## Xinzhi (Seniverse) weather API
WEATHER_API_KEY=your_xinzhi_api_key

# ========================================
# Application configuration
# ========================================
# Backend application
APP_NAME=TH-Agenter
APP_VERSION=0.1.0
DEBUG=true
ENVIRONMENT=development
HOST=0.0.0.0
PORT=8000

# Frontend application
VITE_API_BASE_URL=http://localhost:8000/api
VITE_APP_TITLE=TH-Agenter
VITE_APP_VERSION=1.0.0
VITE_ENABLE_MOCK=false
VITE_UPLOAD_MAX_SIZE=10485760
VITE_SUPPORTED_FILE_TYPES=pdf,txt,md,doc,docx,ppt,pptx,xls,xlsx

# ========================================
# Security configuration
# ========================================
SECRET_KEY=your-secret-key-here-change-in-production
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=300

# ========================================
# Database configuration
# ========================================
DATABASE_URL=postgresql://iagent:iagent@118.196.30.45:5432/iagent

# ========================================
# Vector database configuration
# ========================================
VECTOR_DB_TYPE=pgvector
PGVECTOR_HOST=localhost
PGVECTOR_PORT=5432
PGVECTOR_DATABASE=mydb
PGVECTOR_USER=myuser
PGVECTOR_PASSWORD=mypassword
Binary file not shown.

@@ -0,0 +1,49 @@
# Chat Agent Configuration
app:
  name: "TH-Agenter"
  version: "0.1.0"
  debug: true
  environment: "development"
  host: "0.0.0.0"
  port: 8000

# File Configuration
file:
  upload_dir: "./data/uploads"
  max_size: 10485760  # 10MB
  allowed_extensions: [".txt", ".pdf", ".docx", ".md"]
  chunk_size: 1000
  chunk_overlap: 200
  semantic_splitter_enabled: true  # enable the semantic splitter

# Storage Configuration
storage:
  storage_type: "local"  # local or s3
  upload_directory: "./data/uploads"

  # S3 Configuration
  s3_bucket_name: "chat-agent-files"
  aws_access_key_id: null
  aws_secret_access_key: null
  aws_region: "us-east-1"
  s3_endpoint_url: null

# Logging Configuration
logging:
  level: "INFO"
  file: "./data/logs/app.log"
  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  max_bytes: 10485760  # 10MB
  backup_count: 5

# CORS Configuration
cors:
  allowed_origins: ["*"]
  allowed_methods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
  allowed_headers: ["*"]

# Chat Configuration
chat:
  max_history_length: 10
  system_prompt: "You are a helpful AI assistant. Answer the user's questions based on the provided context."
  max_response_tokens: 1000
File diff suppressed because it is too large.

@@ -0,0 +1,77 @@
# Web framework and core dependencies
fastapi>=0.104.1
uvicorn[standard]>=0.24.0
pydantic>=2.5.0
sqlalchemy>=2.0.23
alembic>=1.13.1
python-multipart>=0.0.6
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
python-dotenv>=1.0.0

# Databases and vector store
psycopg2-binary>=2.9.7    # PostgreSQL
pgvector>=0.2.4           # PostgreSQL pgvector extension
pymysql>=1.1.2            # MySQL

# Excel and data analysis (smart data Q&A)
pandas>=2.1.0
numpy>=1.24.0
openpyxl>=3.1.0           # Excel read/write
xlrd>=2.0.1               # legacy Excel support

# LangChain AI framework
langchain>=0.1.0
langchain-community>=0.0.10
langchain-experimental>=0.0.50   # pandas agent
langchain-postgres>=0.0.6        # PGVector support
langchain-openai>=0.0.5          # OpenAI integration
langgraph>=0.0.40                # LangGraph workflow orchestration

# AI model providers
zhipuai>=2.0.0            # Zhipu AI
openai>=1.0.0             # OpenAI

# Document processing (knowledge base)
pypdf2>=3.0.0             # PDF processing
python-docx>=0.8.11       # Word documents
markdown>=3.5.0           # Markdown files
chardet>=5.2.0            # file encoding detection
pdfplumber>=0.11.7        # PDF content extraction

# Workflow orchestration and agents
celery>=5.3.0             # async task queue
redis>=5.0.0              # Redis cache and message queue
apscheduler>=3.10.0       # scheduled tasks

# File and network handling
aiofiles>=23.2.0          # async file I/O
requests>=2.31.0
httpx>=0.25.0
pyyaml>=6.0               # YAML config parsing
boto3>=1.40.30            # cloud object storage

# Development and testing tools
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-cov>=4.1.0
black>=23.0.0
isort>=5.12.0
flake8>=6.0.0
mypy>=1.5.0
pre-commit>=3.3.0

# Database migrations
alembic>=1.12.0

# Monitoring and logging
prometheus-client>=0.17.0
structlog>=23.1.0

# Security
cryptography>=41.0.0
passlib[bcrypt]>=1.7.4
python-jose[cryptography]>=3.3.0

# Performance
orjson>=3.9.0
@@ -0,0 +1,153 @@
from contextvars import ContextVar
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from typing import Optional
import uuid
import uvicorn

app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Context variables holding the current user and request ID
current_user_ctx: ContextVar[Optional[dict]] = ContextVar("current_user", default=None)
request_id_ctx: ContextVar[Optional[str]] = ContextVar("request_id", default=None)


# User model
class User(BaseModel):
    id: int
    username: str
    email: Optional[str] = None


# Mock user service
class UserService:
    @staticmethod
    def get_current_user_id() -> int:
        """Get the current user's ID directly inside a service."""
        user = current_user_ctx.get()
        if not user:
            raise RuntimeError("No current user available")
        return user["id"]

    @staticmethod
    def get_current_user() -> dict:
        """Get the full current-user record."""
        user = current_user_ctx.get()
        if not user:
            raise RuntimeError("No current user available")
        return user


# Example business service
class TaskService:
    def create_task(self, task_data: dict):
        """Attach the current user's ID automatically when creating a task."""
        current_user_id = UserService.get_current_user_id()

        # Simulated database operation
        task = {
            **task_data,
            "created_by": current_user_id,
            "created_at": "2023-10-01 12:00:00"
        }

        print(f"Task created by user {current_user_id}: {task}")
        return task

    def get_user_tasks(self):
        """Fetch the current user's tasks."""
        current_user_id = UserService.get_current_user_id()

        # Simulated per-user task lookup
        return [{"id": 1, "title": "Sample task", "user_id": current_user_id}]


# Middleware: populate the context variables
@app.middleware("http")
async def set_context_vars(request: Request, call_next):
    # Generate a unique ID for each request
    request_id = str(uuid.uuid4())
    request_id_token = request_id_ctx.set(request_id)

    # Try to extract the user from the Authorization header
    user_token = None
    try:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header.replace("Bearer ", "")
            user = await decode_token_and_get_user(token)  # your auth logic
            user_token = current_user_ctx.set(user)

        response = await call_next(request)
        return response
    finally:
        # Clean up the context
        request_id_ctx.reset(request_id_token)
        if user_token:
            current_user_ctx.reset(user_token)


# Mock authentication function
async def decode_token_and_get_user(token: str) -> Optional[dict]:
    # Replace this with your real authentication logic,
    # e.g. JWT decoding or a database lookup.
    # Simple simulation: map known tokens to users.
    if token == "valid_token_123":
        return {"id": 123, "username": "john_doe", "email": "john@example.com"}
    elif token == "valid_token_456":
        return {"id": 456, "username": "jane_doe", "email": "jane@example.com"}
    else:
        return None


# Dependency: route-level authentication
async def get_current_user_route(token: str = Depends(oauth2_scheme)) -> dict:
    """Authenticate the user at the route layer."""
    user = await decode_token_and_get_user(token)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return user


# Route handlers
@app.post("/tasks")
async def create_task(
    task_data: dict,
    current_user: dict = Depends(get_current_user_route)
):
    """Create a task."""
    # No need to pass user_id into the service explicitly!
    task_service = TaskService()
    task = task_service.create_task(task_data)
    return {"task": task, "message": "Task created successfully"}


@app.get("/tasks")
async def get_tasks(current_user: dict = Depends(get_current_user_route)):
    """Get the current user's tasks."""
    task_service = TaskService()
    tasks = task_service.get_user_tasks()
    return {"tasks": tasks}


@app.get("/users/me")
async def read_users_me(current_user: dict = Depends(get_current_user_route)):
    """Get the current user's info."""
    return current_user


# Test endpoint: read the context user directly in a route
@app.get("/test-context")
async def test_context():
    """Fetch the user straight from the context (no dependency injection)."""
    try:
        user = UserService.get_current_user()
        return {"message": "Successfully got user from context", "user": user}
    except RuntimeError as e:
        return {"error": str(e)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
@@ -0,0 +1,93 @@
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from main import app, current_user_ctx, UserService
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_users_me_with_valid_token():
|
||||||
|
"""测试有效令牌获取用户信息"""
|
||||||
|
response = client.get(
|
||||||
|
"/users/me",
|
||||||
|
headers={"Authorization": "Bearer valid_token_123"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["id"] == 123
|
||||||
|
assert response.json()["username"] == "john_doe"
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_users_me_with_invalid_token():
|
||||||
|
"""测试无效令牌"""
|
||||||
|
response = client.get(
|
||||||
|
"/users/me",
|
||||||
|
headers={"Authorization": "Bearer invalid_token"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "Invalid credentials"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_task_with_user_context():
|
||||||
|
"""测试创建任务时用户上下文是否正确"""
|
||||||
|
response = client.post(
|
||||||
|
"/tasks",
|
||||||
|
json={"title": "Test task", "description": "Test description"},
|
||||||
|
headers={"Authorization": "Bearer valid_token_123"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
# 检查响应中是否包含正确的用户ID
|
||||||
|
assert response.json()["task"]["created_by"] == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tasks_with_different_users():
|
||||||
|
"""测试不同用户获取任务"""
|
||||||
|
# 用户1
|
||||||
|
response1 = client.get(
|
||||||
|
"/tasks",
|
||||||
|
headers={"Authorization": "Bearer valid_token_123"}
|
||||||
|
)
|
||||||
|
assert response1.status_code == 200
|
||||||
|
# 这里应该只返回用户1的任务
|
||||||
|
|
||||||
|
# 用户2
|
||||||
|
response2 = client.get(
|
||||||
|
"/tasks",
|
||||||
|
headers={"Authorization": "Bearer valid_token_456"}
|
||||||
|
)
|
||||||
|
assert response2.status_code == 200
|
||||||
|
# 这里应该只返回用户2的任务
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_outside_request():
|
||||||
|
"""测试在请求上下文外获取用户(应该失败)"""
|
||||||
|
try:
|
||||||
|
UserService.get_current_user()
|
||||||
|
assert False, "Should have raised an exception"
|
||||||
|
except RuntimeError as e:
|
||||||
|
assert "No current user available" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
# 手动设置上下文进行测试
|
||||||
|
def test_user_service_with_manual_context():
|
||||||
|
"""测试手动设置上下文后获取用户"""
|
||||||
|
test_user = {"id": 999, "username": "test_user"}
|
||||||
|
token = current_user_ctx.set(test_user)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = UserService.get_current_user_id()
|
||||||
|
assert user_id == 999
|
||||||
|
|
||||||
|
user = UserService.get_current_user()
|
||||||
|
assert user["username"] == "test_user"
|
||||||
|
finally:
|
||||||
|
current_user_ctx.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# pytest.main([__file__, "-v"])
|
||||||
|
test_read_users_me_with_valid_token()
|
||||||
|
test_read_users_me_with_invalid_token()
|
||||||
|
test_create_task_with_user_context()
|
||||||
|
test_get_tasks_with_different_users()
|
||||||
|
test_context_outside_request()
|
||||||
|
test_user_service_with_manual_context()
|
||||||
|
|
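The tests above exercise a `ContextVar`-based user context. A minimal, self-contained sketch of the pattern they assume follows; the real `current_user_ctx` and `UserService` live in the application module, and these stand-ins only mirror the behavior the assertions check:

```python
from contextvars import ContextVar

# Stand-in for the application's context object (assumed name)
current_user_ctx: ContextVar[dict] = ContextVar("current_user")


class UserService:
    @staticmethod
    def get_current_user() -> dict:
        try:
            return current_user_ctx.get()
        except LookupError:
            # Mirrors the error message the tests assert on
            raise RuntimeError("No current user available")

    @staticmethod
    def get_current_user_id() -> int:
        return UserService.get_current_user()["id"]


# Set, read, and reset the context, as the manual-context test does
token = current_user_ctx.set({"id": 999, "username": "test_user"})
try:
    assert UserService.get_current_user_id() == 999
finally:
    current_user_ctx.reset(token)
```

The `set()`/`reset(token)` pairing is what makes the value request-scoped: after `reset`, `get_current_user()` raises again, which is exactly what `test_context_outside_request` relies on.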
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Initialize the database."""

import sys
import os


def find_project_root():
    """Locate the project root directory."""
    current_dir = os.path.abspath(os.getcwd())
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Candidate project-root locations
    possible_roots = [
        current_dir,                                  # current working directory
        script_dir,                                   # directory containing this script
        os.path.dirname(script_dir),                  # parent of the script directory
        os.path.dirname(os.path.dirname(script_dir))  # grandparent of the script directory
    ]

    for root in possible_roots:
        backend_dir = os.path.join(root, 'backend')
        if os.path.exists(backend_dir) and os.path.exists(os.path.join(backend_dir, 'th_agenter')):
            return root, backend_dir

    raise FileNotFoundError("Could not locate the project root and backend directories")


# Locate the project root and backend directories
project_root, backend_dir = find_project_root()

# Add the backend directory to the Python path
sys.path.insert(0, backend_dir)

# Save the original working directory
original_cwd = os.getcwd()

# Switch the working directory to backend so the .env file can be found
os.chdir(backend_dir)

from th_agenter.db.database import get_db, init_db
from th_agenter.services.user import UserService
from th_agenter.utils.schemas import UserCreate
import asyncio


async def create_database_tables():
    """Create all database tables using SQLAlchemy models."""
    try:
        await init_db()
        print('Database tables created successfully using SQLAlchemy models')
        return True
    except Exception as e:
        print(f'Error creating database tables: {e}')
        return False


async def create_test_user():
    """Create a test user."""
    # First, create all database tables using SQLAlchemy models
    if not await create_database_tables():
        print('Failed to create database tables')
        return None

    db = next(get_db())

    try:
        user_service = UserService(db)

        # Create test user
        user_data = UserCreate(
            username='test',
            email='test@example.com',
            password='123456',
            full_name='Test User 1'
        )

        # Check if user already exists
        existing_user = user_service.get_user_by_email(user_data.email)
        if existing_user:
            print(f'User already exists: {existing_user.username} ({existing_user.email})')
            return existing_user

        # Create new user
        user = user_service.create_user(user_data)
        print(f'Created user: {user.username} ({user.email})')
        return user
    except Exception as e:
        print(f'Error creating user: {e}')
        return None
    finally:
        db.close()


if __name__ == "__main__":
    try:
        asyncio.run(create_test_user())
    finally:
        # Restore the original working directory
        os.chdir(original_cwd)
@@ -0,0 +1,125 @@
import os
import sys
import pandas as pd
import pickle

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'backend'))


def execute(df_1, df_2):
    # The contract-date columns are assumed to be strings; convert them to datetimes
    if '合同日期' in df_1.columns:
        df_1['合同日期'] = pd.to_datetime(df_1['合同日期'])
    if '合同日期' in df_2.columns:
        df_2['合同日期'] = pd.to_datetime(df_2['合同日期'])

    # Keep only the 2024 and 2025 rows
    filtered_df_1 = df_1[
        (df_1['合同日期'].dt.year == 2024) | (df_1['合同日期'].dt.year == 2025)]
    filtered_df_2 = df_2[
        (df_2['合同日期'].dt.year == 2024) | (df_2['合同日期'].dt.year == 2025)]
    # Merge the two dataframes
    combined_df = pd.concat([filtered_df_1[:5], filtered_df_2[:7]], ignore_index=True)
    # Clean null values before deduplicating
    # combined_df_clean = combined_df.dropna(subset=['项目号'])  # ensure the primary key is not null

    # Fill nulls in the numeric columns
    combined_df_filled = combined_df.fillna({
        '总合同额': 0,
        '已确认比例': 0,
        '分包合同额': 0
    })
    # Find the distinct projects (dedupe the filled frame by project number)
    unique_projects = combined_df_filled.drop_duplicates(subset=['项目号'])
    return unique_projects


def test_load_selected_dataframes():
    try:
        file1_path = '2025年在手合同数据.xlsx.pkl'
        file2_path = '2024年在手合同数据.xlsx.pkl'
        target_filenames = [file1_path, file2_path]
        dataframes = {}
        base_dir = os.path.join(r"D:\workspace-py\chat-agent\backend", "data", "uploads", "excel_6")

        all_files = os.listdir(base_dir)
        for filename in target_filenames:
            matching_files = []
            for file in all_files:
                if file.endswith(f"_(unknown)") or file.endswith(f"_(unknown).pkl"):
                    matching_files.append(file)
            if not matching_files:
                print(f"No matching file found: (unknown)")
                continue

            # If there are multiple matches, pick the most recent one
            if len(matching_files) > 1:
                matching_files.sort(key=lambda x: os.path.getmtime(os.path.join(base_dir, x)), reverse=True)
                print(f"Multiple matching files found; using the most recent: {matching_files[0]}")

            selected_file = matching_files[0]
            file_path = os.path.join(base_dir, selected_file)

            try:
                # Prefer loading the pickle file
                if selected_file.endswith('.pkl'):
                    with open(file_path, 'rb') as f:
                        df = pickle.load(f)
                    print(f"Loaded from pickle: {selected_file}")
                else:
                    # No pickle file; try the original file instead
                    if selected_file.endswith(('.xlsx', '.xls')):
                        df = pd.read_excel(file_path)
                    elif selected_file.endswith('.csv'):
                        df = pd.read_csv(file_path)
                    else:
                        print(f"Unsupported file format: {selected_file}")
                        continue
                    print(f"Loaded from original file: {selected_file}")

                # Use the original filename as the key
                dataframes[filename] = df
                print(f"Loaded DataFrame: (unknown), shape: {df.shape}")

            except Exception as e:
                print(f"Failed to load file {selected_file}: {e}")
                continue

        return dataframes
    except Exception as e:
        print(e)
        return {}


if __name__ == '__main__':
    dataframes = test_load_selected_dataframes()
    df_names = list(dataframes.keys())
    if len(df_names) >= 2:
        df_1 = dataframes[df_names[0]]
        df_2 = dataframes[df_names[1]]

        print(f"DataFrame 1 ({df_names[0]}) shape: {df_1.shape}")
        print(f"DataFrame 1 columns: {list(df_1.columns)}")
        print("DataFrame 1 head:")
        print(df_1.head())
        print()

        print(f"DataFrame 2 ({df_names[1]}) shape: {df_2.shape}")
        print(f"DataFrame 2 columns: {list(df_2.columns)}")
        print("DataFrame 2 head:")
        print(df_2.head())
        print()

        # Run the user-provided data-processing logic
        print("Running data-processing logic...")
        result = execute(df_1, df_2)

        print("Result:")
        print(f"Result shape: {result.shape}")
        print(f"Result columns: {list(result.columns)}")
        print("Result data:")
        print(result)
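The filter / concat / fillna / dedupe pipeline in `execute()` can be exercised on tiny in-memory frames. The column names below follow the script above; the sample values are invented for illustration:

```python
import pandas as pd

# Two small frames standing in for the two contract files
df_1 = pd.DataFrame({
    "项目号": ["A", "B"],
    "合同日期": ["2024-01-05", "2023-03-01"],  # B falls outside 2024/2025
    "总合同额": [100.0, None],
})
df_2 = pd.DataFrame({
    "项目号": ["A", "C"],
    "合同日期": ["2025-06-10", "2025-07-01"],
    "总合同额": [None, 300.0],
})

for df in (df_1, df_2):
    df["合同日期"] = pd.to_datetime(df["合同日期"])

# Keep only 2024/2025 rows, then merge
def year_mask(d):
    return d["合同日期"].dt.year.isin([2024, 2025])

combined = pd.concat([df_1[year_mask(df_1)], df_2[year_mask(df_2)]],
                     ignore_index=True)
# Fill numeric nulls, then keep the first row per project number
combined = combined.fillna({"总合同额": 0})
unique_projects = combined.drop_duplicates(subset=["项目号"])
print(list(unique_projects["项目号"]))  # project B was filtered out, A deduped
```

`drop_duplicates` keeps the first occurrence by default, so the 2024 row for project A wins over the 2025 one; passing `keep="last"` would invert that choice.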
@@ -0,0 +1,12 @@
"""th - A modern chat agent application."""

__version__ = "0.1.0"
__author__ = "Your Name"
__email__ = "your.email@example.com"
__description__ = "A modern chat agent application with Vue frontend and FastAPI backend"

# Export the main components
from .core.config import settings
from .core.app import create_app

__all__ = ["settings", "create_app", "__version__"]
@@ -0,0 +1 @@
"""API module for TH-Agenter."""
@@ -0,0 +1 @@
"""API endpoints for TH-Agenter."""
@@ -0,0 +1,125 @@
"""Authentication endpoints."""

from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session

from ...core.config import get_settings
from ...db.database import get_db
from ...services.auth import AuthService
from ...services.user import UserService
from ...schemas.user import UserResponse, UserCreate
from ...utils.schemas import Token, LoginRequest

router = APIRouter()
settings = get_settings()


@router.post("/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    db: Session = Depends(get_db)
):
    """Register a new user."""
    user_service = UserService(db)

    # Check if user already exists
    if user_service.get_user_by_email(user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    if user_service.get_user_by_username(user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already taken"
        )

    # Create user
    user = user_service.create_user(user_data)
    return UserResponse.from_orm(user)


@router.post("/login", response_model=Token)
async def login(
    login_data: LoginRequest,
    db: Session = Depends(get_db)
):
    """Login with email and password."""
    # Authenticate user by email
    user = AuthService.authenticate_user_by_email(db, login_data.email, login_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Create access token
    access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
    access_token = AuthService.create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": settings.security.access_token_expire_minutes * 60
    }


@router.post("/login-oauth", response_model=Token)
async def login_oauth(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    """Login with username and password (OAuth2 compatible)."""
    # Authenticate user
    user = AuthService.authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Create access token
    access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
    access_token = AuthService.create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": settings.security.access_token_expire_minutes * 60
    }


@router.post("/refresh", response_model=Token)
async def refresh_token(
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Refresh access token."""
    # Create new access token
    access_token_expires = timedelta(minutes=settings.security.access_token_expire_minutes)
    access_token = AuthService.create_access_token(
        data={"sub": current_user.username}, expires_delta=access_token_expires
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": settings.security.access_token_expire_minutes * 60
    }


@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
    current_user = Depends(AuthService.get_current_user)
):
    """Get current user information."""
    return UserResponse.from_orm(current_user)
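`AuthService.create_access_token` is consumed here but defined elsewhere. As a stdlib-only sketch of what such a helper typically does, the version below encodes a payload plus an expiry into an HMAC-signed token; the real service presumably uses a JWT library, `SECRET` stands in for the configured key, and `expires_minutes` is a simplification of the `expires_delta` timedelta used above:

```python
import base64
import hashlib
import hmac
import json
import time

SECRET = b"change-me"  # assumption: the app reads this from settings.security


def create_access_token(data: dict, expires_minutes: int = 30) -> str:
    """Sign a payload with an expiry timestamp (illustrative, not JWT)."""
    payload = {**data, "exp": int(time.time()) + expires_minutes * 60}
    body = base64.urlsafe_b64encode(json.dumps(payload).encode())
    sig = hmac.new(SECRET, body, hashlib.sha256).hexdigest().encode()
    return (body + b"." + sig).decode()


def decode_token(token: str) -> dict:
    """Verify the signature and expiry, then return the payload."""
    body, sig = token.encode().rsplit(b".", 1)
    expected = hmac.new(SECRET, body, hashlib.sha256).hexdigest().encode()
    if not hmac.compare_digest(expected, sig):
        raise ValueError("bad signature")
    payload = json.loads(base64.urlsafe_b64decode(body))
    if payload["exp"] < time.time():
        raise ValueError("token expired")
    return payload


tok = create_access_token({"sub": "alice"})
assert decode_token(tok)["sub"] == "alice"
```

The shape matches the endpoints above: the token carries `{"sub": username}` and a server-side expiry, and any tampering or expiry is rejected at decode time.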
@@ -0,0 +1,237 @@
"""Chat endpoints."""

from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ...db.database import get_db
from ...models.user import User
from ...services.auth import AuthService
from ...services.chat import ChatService
from ...services.conversation import ConversationService
from ...utils.schemas import (
    ConversationCreate,
    ConversationResponse,
    ConversationUpdate,
    MessageCreate,
    MessageResponse,
    ChatRequest,
    ChatResponse
)

router = APIRouter()


# Conversation management
@router.post("/conversations", response_model=ConversationResponse)
async def create_conversation(
    conversation_data: ConversationCreate,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Create a new conversation."""
    conversation_service = ConversationService(db)
    conversation = conversation_service.create_conversation(
        user_id=current_user.id,
        conversation_data=conversation_data
    )
    return ConversationResponse.from_orm(conversation)


@router.get("/conversations", response_model=List[ConversationResponse])
async def list_conversations(
    skip: int = 0,
    limit: int = 50,
    search: str = None,
    include_archived: bool = False,
    order_by: str = "updated_at",
    order_desc: bool = True,
    db: Session = Depends(get_db)
):
    """List user's conversations with search and filtering."""
    conversation_service = ConversationService(db)
    conversations = conversation_service.get_user_conversations(
        skip=skip,
        limit=limit,
        search_query=search,
        include_archived=include_archived,
        order_by=order_by,
        order_desc=order_desc
    )
    return [ConversationResponse.from_orm(conv) for conv in conversations]


@router.get("/conversations/count")
async def get_conversations_count(
    search: str = None,
    include_archived: bool = False,
    db: Session = Depends(get_db)
):
    """Get total count of conversations."""
    conversation_service = ConversationService(db)
    count = conversation_service.get_user_conversations_count(
        search_query=search,
        include_archived=include_archived
    )
    return {"count": count}


@router.get("/conversations/{conversation_id}", response_model=ConversationResponse)
async def get_conversation(
    conversation_id: int,
    db: Session = Depends(get_db)
):
    """Get a specific conversation."""
    conversation_service = ConversationService(db)
    conversation = conversation_service.get_conversation(
        conversation_id=conversation_id
    )
    if not conversation:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Conversation not found"
        )
    return ConversationResponse.from_orm(conversation)


@router.put("/conversations/{conversation_id}", response_model=ConversationResponse)
async def update_conversation(
    conversation_id: int,
    conversation_update: ConversationUpdate,
    db: Session = Depends(get_db)
):
    """Update a conversation."""
    conversation_service = ConversationService(db)
    updated_conversation = conversation_service.update_conversation(
        conversation_id, conversation_update
    )
    return ConversationResponse.from_orm(updated_conversation)


@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
    conversation_id: int,
    db: Session = Depends(get_db)
):
    """Delete a conversation."""
    conversation_service = ConversationService(db)
    conversation_service.delete_conversation(conversation_id)
    return {"message": "Conversation deleted successfully"}


@router.delete("/conversations")
async def delete_all_conversations(
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Delete all conversations."""
    conversation_service = ConversationService(db)
    conversation_service.delete_all_conversations()
    return {"message": "All conversations deleted successfully"}


@router.put("/conversations/{conversation_id}/archive")
async def archive_conversation(
    conversation_id: int,
    db: Session = Depends(get_db)
):
    """Archive a conversation."""
    conversation_service = ConversationService(db)
    success = conversation_service.archive_conversation(conversation_id)
    if not success:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to archive conversation"
        )

    return {"message": "Conversation archived successfully"}


@router.put("/conversations/{conversation_id}/unarchive")
async def unarchive_conversation(
    conversation_id: int,
    db: Session = Depends(get_db)
):
    """Unarchive a conversation."""
    conversation_service = ConversationService(db)
    success = conversation_service.unarchive_conversation(conversation_id)
    if not success:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to unarchive conversation"
        )

    return {"message": "Conversation unarchived successfully"}


# Message management
@router.get("/conversations/{conversation_id}/messages", response_model=List[MessageResponse])
async def get_conversation_messages(
    conversation_id: int,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """Get messages from a conversation."""
    conversation_service = ConversationService(db)
    messages = conversation_service.get_conversation_messages(
        conversation_id, skip=skip, limit=limit
    )
    return [MessageResponse.from_orm(msg) for msg in messages]


# Chat functionality
@router.post("/conversations/{conversation_id}/chat", response_model=ChatResponse)
async def chat(
    conversation_id: int,
    chat_request: ChatRequest,
    db: Session = Depends(get_db)
):
    """Send a message and get AI response."""
    chat_service = ChatService(db)
    response = await chat_service.chat(
        conversation_id=conversation_id,
        message=chat_request.message,
        stream=False,
        temperature=chat_request.temperature,
        max_tokens=chat_request.max_tokens,
        use_agent=chat_request.use_agent,
        use_langgraph=chat_request.use_langgraph,
        use_knowledge_base=chat_request.use_knowledge_base,
        knowledge_base_id=chat_request.knowledge_base_id
    )

    return response


@router.post("/conversations/{conversation_id}/chat/stream")
async def chat_stream(
    conversation_id: int,
    chat_request: ChatRequest,
    db: Session = Depends(get_db)
):
    """Send a message and get streaming AI response."""
    chat_service = ChatService(db)

    async def generate_response():
        async for chunk in chat_service.chat_stream(
            conversation_id=conversation_id,
            message=chat_request.message,
            temperature=chat_request.temperature,
            max_tokens=chat_request.max_tokens,
            use_agent=chat_request.use_agent,
            use_langgraph=chat_request.use_langgraph,
            use_knowledge_base=chat_request.use_knowledge_base,
            knowledge_base_id=chat_request.knowledge_base_id
        ):
            yield f"data: {chunk}\n\n"

    return StreamingResponse(
        generate_response(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
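The streaming endpoint frames each chunk as `data: <chunk>` followed by a blank line, i.e. server-sent-events style framing over `text/plain`. A minimal client-side parser for that framing (illustrative only; a real client would read the HTTP response incrementally rather than from a complete string):

```python
def parse_sse(stream: str) -> list[str]:
    """Extract chunk payloads from 'data: ...\\n\\n' frames."""
    chunks = []
    for frame in stream.split("\n\n"):
        if frame.startswith("data: "):
            chunks.append(frame[len("data: "):])
    return chunks


# Two frames as the endpoint above would emit them
frames = "data: Hello\n\ndata: world\n\n"
print(parse_sse(frames))  # ['Hello', 'world']
```

Splitting on the blank line between frames is what makes the protocol self-delimiting, which is why `generate_response()` must append `\n\n` to every chunk.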
@@ -0,0 +1,207 @@
"""Database configuration management API."""

from functools import lru_cache
from typing import List, Dict, Any

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from th_agenter.models.user import User
from th_agenter.db.database import get_db
from th_agenter.services.database_config_service import DatabaseConfigService
from th_agenter.services.auth import AuthService
from th_agenter.utils.logger import get_logger
from th_agenter.utils.schemas import FileListResponse, ExcelPreviewRequest, NormalResponse

logger = get_logger("database_config_api")
router = APIRouter(prefix="/api/database-config", tags=["database-config"])


# Service singleton via lru_cache
@lru_cache()
def get_database_config_service() -> DatabaseConfigService:
    """Return a DatabaseConfigService singleton."""
    # Note: the db session still needs to be handled properly here
    return DatabaseConfigService(None)  # temporary workaround


# Alternatively, use a module-level global
_database_service_instance = None


def get_database_service(db: Session = Depends(get_db)) -> DatabaseConfigService:
    """Return the shared DatabaseConfigService instance."""
    global _database_service_instance
    if _database_service_instance is None:
        _database_service_instance = DatabaseConfigService(db)
    else:
        # Refresh the db session
        _database_service_instance.db = db
    return _database_service_instance


class DatabaseConfigCreate(BaseModel):
    name: str = Field(..., description="Configuration name")
    db_type: str = Field(default="postgresql", description="Database type")
    host: str = Field(..., description="Host address")
    port: int = Field(..., description="Port number")
    database: str = Field(..., description="Database name")
    username: str = Field(..., description="Username")
    password: str = Field(..., description="Password")
    is_default: bool = Field(default=False, description="Whether this is the default configuration")
    connection_params: Dict[str, Any] = Field(default=None, description="Extra connection parameters")


class DatabaseConfigResponse(BaseModel):
    id: int
    name: str
    db_type: str
    host: str
    port: int
    database: str
    username: str
    password: str  # include the password field
    is_active: bool
    is_default: bool
    created_at: str
    updated_at: str = None


@router.post("/", response_model=NormalResponse)
async def create_database_config(
    config_data: DatabaseConfigCreate,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Create or update a database configuration."""
    try:
        service = DatabaseConfigService(db)
        config = await service.create_or_update_config(current_user.id, config_data.dict())
        return NormalResponse(
            success=True,
            message="Database configuration saved successfully",
            data=config
        )
    except Exception as e:
        logger.error(f"Failed to create or update database configuration: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


@router.get("/", response_model=List[DatabaseConfigResponse])
async def get_database_configs(
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """List the user's database configurations."""
    try:
        service = DatabaseConfigService(db)
        configs = service.get_user_configs(current_user.id)

        config_list = [config.to_dict(include_password=True, decrypt_service=service) for config in configs]
        return config_list
    except Exception as e:
        logger.error(f"Failed to fetch database configurations: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )


@router.post("/{config_id}/test")
async def test_database_connection(
    config_id: int,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Test a database connection."""
    try:
        service = DatabaseConfigService(db)
        result = await service.test_connection(config_id, current_user.id)
        return result
    except Exception as e:
        logger.error(f"Database connection test failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


@router.post("/{config_id}/connect")
async def connect_database(
    config_id: int,
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Connect to the database and list its tables."""
    try:
        result = await service.connect_and_get_tables(config_id, current_user.id)
        return result
    except Exception as e:
        logger.error(f"Failed to connect to the database: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


@router.get("/tables/{table_name}/data")
async def get_table_data(
    table_name: str,
    db_type: str,
    limit: int = 100,
    current_user: User = Depends(AuthService.get_current_user),
    service: DatabaseConfigService = Depends(get_database_service)
):
    """Preview table data."""
    try:
        result = await service.get_table_data(table_name, current_user.id, db_type, limit)
        return result
    except Exception as e:
        logger.error(f"Failed to fetch table data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


@router.get("/tables/{table_name}/schema")
async def get_table_schema(
    table_name: str,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Get table schema information."""
    try:
        service = DatabaseConfigService(db)
        result = await service.describe_table(table_name, current_user.id)
        return result
    except Exception as e:
        logger.error(f"Failed to fetch table schema: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


@router.get("/by-type/{db_type}", response_model=DatabaseConfigResponse)
async def get_config_by_type(
    db_type: str,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Get a configuration by database type."""
    try:
        service = DatabaseConfigService(db)
        config = service.get_config_by_type(current_user.id, db_type)
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No configuration found for type {db_type}"
            )
        # Return the configuration with the decrypted password
        return config.to_dict(include_password=True, decrypt_service=service)
    except HTTPException:
        # Re-raise the deliberate 404 instead of remapping it to a 500 below
        raise
    except Exception as e:
        logger.error(f"Failed to fetch database configuration: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
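The module above keeps one shared `DatabaseConfigService` and swaps its `db` attribute on every request, which its own comments flag as a workaround. An alternative that avoids shared mutable state is to build a cheap wrapper per request and let the dependency system scope the session; sketched below with stand-in classes (`FakeSession` and `Service` are not the real application classes):

```python
class FakeSession:
    """Stand-in for a per-request SQLAlchemy session."""
    def __init__(self, name: str) -> None:
        self.name = name


class Service:
    """Stand-in for DatabaseConfigService: a thin wrapper over a session."""
    def __init__(self, db: FakeSession) -> None:
        self.db = db


def get_service(db: FakeSession) -> Service:
    # A new wrapper per request; in FastAPI this would be
    # `def get_service(db=Depends(get_db)): return Service(db)`,
    # so the session, not the service, carries the request scope.
    return Service(db)


s1 = get_service(FakeSession("req1"))
s2 = get_service(FakeSession("req2"))
print(s1 is s2, s1.db.name, s2.db.name)  # False req1 req2
```

Because construction is trivial, the per-request version costs almost nothing and cannot leak one request's session into a concurrent one, which the global-instance variant can under async concurrency.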
@@ -0,0 +1,666 @@
"""Knowledge base API endpoints."""

import os
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session

from ...db.database import get_db
from ...models.user import User
from ...models.knowledge_base import KnowledgeBase, Document
from ...services.knowledge_base import KnowledgeBaseService
from ...services.document import DocumentService
from ...services.auth import AuthService
from ...utils.schemas import (
    KnowledgeBaseCreate,
    KnowledgeBaseResponse,
    DocumentResponse,
    DocumentListResponse,
    DocumentUpload,
    DocumentProcessingStatus,
    DocumentChunksResponse,
    ErrorResponse
)
from ...utils.file_utils import FileUtils
from ...core.config import settings

logger = logging.getLogger(__name__)
router = APIRouter(tags=["knowledge-bases"])


@router.post("/", response_model=KnowledgeBaseResponse)
async def create_knowledge_base(
    kb_data: KnowledgeBaseCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a new knowledge base."""
    try:
        # Check if knowledge base with same name already exists for this user
        service = KnowledgeBaseService(db)
        existing_kb = service.get_knowledge_base_by_name(kb_data.name)
        if existing_kb:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Knowledge base with this name already exists"
            )

        # Create knowledge base
        kb = service.create_knowledge_base(kb_data)

        return KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=0,
|
||||||
|
active_document_count=0
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to create knowledge base: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[KnowledgeBaseResponse])
|
||||||
|
async def list_knowledge_bases(
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""List knowledge bases for current user."""
|
||||||
|
try:
|
||||||
|
service = KnowledgeBaseService(db)
|
||||||
|
knowledge_bases = service.get_knowledge_bases(skip=skip, limit=limit)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for kb in knowledge_bases:
|
||||||
|
# Count documents
|
||||||
|
total_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
active_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
).count()
|
||||||
|
|
||||||
|
result.append(KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
))
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to list knowledge bases: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}", response_model=KnowledgeBaseResponse)
|
||||||
|
async def get_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Get knowledge base by ID."""
|
||||||
|
try:
|
||||||
|
service = KnowledgeBaseService(db)
|
||||||
|
kb = service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count documents
|
||||||
|
total_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
active_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
).count()
|
||||||
|
|
||||||
|
return KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get knowledge base: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{kb_id}", response_model=KnowledgeBaseResponse)
|
||||||
|
async def update_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
kb_data: KnowledgeBaseCreate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Update knowledge base."""
|
||||||
|
try:
|
||||||
|
service = KnowledgeBaseService(db)
|
||||||
|
kb = service.update_knowledge_base(kb_id, kb_data)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count documents
|
||||||
|
total_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
active_docs = db.query(Document).filter(
|
||||||
|
Document.knowledge_base_id == kb.id,
|
||||||
|
Document.is_processed == True
|
||||||
|
).count()
|
||||||
|
|
||||||
|
return KnowledgeBaseResponse(
|
||||||
|
id=kb.id,
|
||||||
|
created_at=kb.created_at,
|
||||||
|
updated_at=kb.updated_at,
|
||||||
|
name=kb.name,
|
||||||
|
description=kb.description,
|
||||||
|
embedding_model=kb.embedding_model,
|
||||||
|
chunk_size=kb.chunk_size,
|
||||||
|
chunk_overlap=kb.chunk_overlap,
|
||||||
|
is_active=kb.is_active,
|
||||||
|
vector_db_type=kb.vector_db_type,
|
||||||
|
collection_name=kb.collection_name,
|
||||||
|
document_count=total_docs,
|
||||||
|
active_document_count=active_docs
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to update knowledge base: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{kb_id}")
|
||||||
|
async def delete_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Delete knowledge base."""
|
||||||
|
try:
|
||||||
|
service = KnowledgeBaseService(db)
|
||||||
|
success = service.delete_knowledge_base(kb_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Knowledge base deleted successfully"}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete knowledge base: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Document management endpoints
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
async def upload_document(
    kb_id: int,
    file: UploadFile = File(...),
    process_immediately: bool = Form(True),
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Upload document to knowledge base."""
    try:
        # Verify knowledge base exists and user has access
        kb_service = KnowledgeBaseService(db)
        kb = kb_service.get_knowledge_base(kb_id)
        if not kb:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Knowledge base not found"
            )

        # Validate file
        if not FileUtils.validate_file_extension(file.filename):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File type not supported. Allowed types: {', '.join(FileUtils.ALLOWED_EXTENSIONS)}"
            )

        # Check file size (50MB limit)
        max_size = 50 * 1024 * 1024  # 50MB
        if file.size and file.size > max_size:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File too large. Maximum size: {FileUtils.format_file_size(max_size)}"
            )

        # Upload document
        doc_service = DocumentService(db)
        document = await doc_service.upload_document(file, kb_id)

        # Process document immediately if requested
        if process_immediately:
            try:
                await doc_service.process_document(document.id, kb_id)
                # Refresh document to get updated status
                db.refresh(document)
            except Exception as e:
                # Log error but don't fail the upload
                logger.error(f"Failed to process document immediately: {e}")

        return DocumentResponse(
            id=document.id,
            created_at=document.created_at,
            updated_at=document.updated_at,
            knowledge_base_id=document.knowledge_base_id,
            filename=document.filename,
            original_filename=document.original_filename,
            file_path=document.file_path,
            file_type=document.file_type,
            file_size=document.file_size,
            mime_type=document.mime_type,
            is_processed=document.is_processed,
            processing_error=document.processing_error,
            chunk_count=document.chunk_count or 0,
            embedding_model=document.embedding_model,
            file_size_mb=round(document.file_size / (1024 * 1024), 2)
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload document: {str(e)}"
        )

@router.get("/{kb_id}/documents", response_model=DocumentListResponse)
|
||||||
|
async def list_documents(
|
||||||
|
kb_id: int,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""List documents in knowledge base."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
documents, total = doc_service.list_documents(kb_id, skip, limit)
|
||||||
|
|
||||||
|
doc_responses = []
|
||||||
|
for doc in documents:
|
||||||
|
doc_responses.append(DocumentResponse(
|
||||||
|
id=doc.id,
|
||||||
|
created_at=doc.created_at,
|
||||||
|
updated_at=doc.updated_at,
|
||||||
|
knowledge_base_id=doc.knowledge_base_id,
|
||||||
|
filename=doc.filename,
|
||||||
|
original_filename=doc.original_filename,
|
||||||
|
file_path=doc.file_path,
|
||||||
|
file_type=doc.file_type,
|
||||||
|
file_size=doc.file_size,
|
||||||
|
mime_type=doc.mime_type,
|
||||||
|
is_processed=doc.is_processed,
|
||||||
|
processing_error=doc.processing_error,
|
||||||
|
chunk_count=doc.chunk_count or 0,
|
||||||
|
embedding_model=doc.embedding_model,
|
||||||
|
file_size_mb=round(doc.file_size / (1024 * 1024), 2)
|
||||||
|
))
|
||||||
|
|
||||||
|
return DocumentListResponse(
|
||||||
|
documents=doc_responses,
|
||||||
|
total=total,
|
||||||
|
page=skip // limit + 1,
|
||||||
|
page_size=limit
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to list documents: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}", response_model=DocumentResponse)
|
||||||
|
async def get_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Get document by ID."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
document = doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return DocumentResponse(
|
||||||
|
id=document.id,
|
||||||
|
created_at=document.created_at,
|
||||||
|
updated_at=document.updated_at,
|
||||||
|
knowledge_base_id=document.knowledge_base_id,
|
||||||
|
filename=document.filename,
|
||||||
|
original_filename=document.original_filename,
|
||||||
|
file_path=document.file_path,
|
||||||
|
file_type=document.file_type,
|
||||||
|
file_size=document.file_size,
|
||||||
|
mime_type=document.mime_type,
|
||||||
|
is_processed=document.is_processed,
|
||||||
|
processing_error=document.processing_error,
|
||||||
|
chunk_count=document.chunk_count or 0,
|
||||||
|
embedding_model=document.embedding_model,
|
||||||
|
file_size_mb=round(document.file_size / (1024 * 1024), 2)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get document: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{kb_id}/documents/{doc_id}")
|
||||||
|
async def delete_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Delete document from knowledge base."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
success = doc_service.delete_document(doc_id, kb_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Document deleted successfully"}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete document: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/documents/{doc_id}/process", response_model=DocumentProcessingStatus)
|
||||||
|
async def process_document(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Process document for vector search."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if document exists
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
document = doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process the document
|
||||||
|
result = await doc_service.process_document(doc_id, kb_id)
|
||||||
|
|
||||||
|
return DocumentProcessingStatus(
|
||||||
|
document_id=doc_id,
|
||||||
|
status=result["status"],
|
||||||
|
progress=result.get("progress", 0.0),
|
||||||
|
error_message=result.get("error_message"),
|
||||||
|
chunks_created=result.get("chunks_created", 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to process document: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}/status", response_model=DocumentProcessingStatus)
|
||||||
|
async def get_document_processing_status(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Get document processing status."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
document = doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine status
|
||||||
|
if document.processing_error:
|
||||||
|
status_str = "failed"
|
||||||
|
progress = 0.0
|
||||||
|
elif document.is_processed:
|
||||||
|
status_str = "completed"
|
||||||
|
progress = 100.0
|
||||||
|
else:
|
||||||
|
status_str = "pending"
|
||||||
|
progress = 0.0
|
||||||
|
|
||||||
|
return DocumentProcessingStatus(
|
||||||
|
document_id=document.id,
|
||||||
|
status=status_str,
|
||||||
|
progress=progress,
|
||||||
|
error_message=document.processing_error,
|
||||||
|
chunks_created=document.chunk_count or 0
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get document status: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/search")
|
||||||
|
async def search_knowledge_base(
|
||||||
|
kb_id: int,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""Search documents in a knowledge base."""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
kb = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
results = doc_service.search_documents(kb_id, query, limit)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"knowledge_base_id": kb_id,
|
||||||
|
"query": query,
|
||||||
|
"results": results,
|
||||||
|
"total_results": len(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to search knowledge base: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/documents/{doc_id}/chunks", response_model=DocumentChunksResponse)
|
||||||
|
async def get_document_chunks(
|
||||||
|
kb_id: int,
|
||||||
|
doc_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get document chunks (segments) for a specific document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kb_id: Knowledge base ID
|
||||||
|
doc_id: Document ID
|
||||||
|
db: Database session
|
||||||
|
current_user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DocumentChunksResponse: Document chunks with metadata
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify knowledge base exists and user has access
|
||||||
|
kb_service = KnowledgeBaseService(db)
|
||||||
|
knowledge_base = kb_service.get_knowledge_base(kb_id)
|
||||||
|
if not knowledge_base:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Knowledge base not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify document exists in the knowledge base
|
||||||
|
doc_service = DocumentService(db)
|
||||||
|
document = doc_service.get_document(doc_id, kb_id)
|
||||||
|
if not document:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Document not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get document chunks
|
||||||
|
chunks = doc_service.get_document_chunks(doc_id)
|
||||||
|
|
||||||
|
return DocumentChunksResponse(
|
||||||
|
document_id=doc_id,
|
||||||
|
document_name=document.filename,
|
||||||
|
total_chunks=len(chunks),
|
||||||
|
chunks=chunks
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get document chunks: {str(e)}"
|
||||||
|
)
|
||||||
|
|
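Two small pieces of logic in the file above are worth pinning down: the status endpoint maps the `processing_error`/`is_processed` flags onto three states (an error takes precedence over the processed flag), and the document list derives a 1-based page number from `skip`/`limit`. A minimal sketch of both rules as pure functions (the helper names are illustrative, not part of the codebase):

```python
def processing_status(is_processed, processing_error):
    """Mirror get_document_processing_status: an error wins, then the
    processed flag, otherwise the document is still pending."""
    if processing_error:
        return ("failed", 0.0)
    if is_processed:
        return ("completed", 100.0)
    return ("pending", 0.0)


def page_number(skip, limit):
    """Mirror list_documents: 1-based page derived from the offset."""
    return skip // limit + 1


print(processing_status(True, None))    # ('completed', 100.0)
print(processing_status(True, "boom"))  # ('failed', 0.0) - error takes precedence
print(page_number(0, 50), page_number(100, 50))  # 1 3
```

Note that `page_number` assumes `limit > 0`; the endpoint's default of 50 satisfies this, but a client-supplied `limit=0` would divide by zero.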
@@ -0,0 +1,528 @@
"""LLM configuration management API endpoints."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
from ...db.database import get_db
|
||||||
|
from ...models.user import User
|
||||||
|
from ...models.llm_config import LLMConfig
|
||||||
|
from ...core.simple_permissions import require_super_admin, require_authenticated_user
|
||||||
|
from ...services.auth import AuthService
|
||||||
|
from ...utils.logger import get_logger
|
||||||
|
from ...schemas.llm_config import (
|
||||||
|
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
|
||||||
|
LLMConfigTest
|
||||||
|
)
|
||||||
|
from th_agenter.services.document_processor import get_document_processor
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[LLMConfigResponse])
|
||||||
|
async def get_llm_configs(
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
|
search: Optional[str] = Query(None),
|
||||||
|
provider: Optional[str] = Query(None),
|
||||||
|
is_active: Optional[bool] = Query(None),
|
||||||
|
is_embedding: Optional[bool] = Query(None),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取大模型配置列表."""
|
||||||
|
try:
|
||||||
|
query = db.query(LLMConfig)
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
if search:
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
LLMConfig.name.ilike(f"%{search}%"),
|
||||||
|
LLMConfig.model_name.ilike(f"%{search}%"),
|
||||||
|
LLMConfig.description.ilike(f"%{search}%")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 服务商筛选
|
||||||
|
if provider:
|
||||||
|
query = query.filter(LLMConfig.provider == provider)
|
||||||
|
|
||||||
|
# 状态筛选
|
||||||
|
if is_active is not None:
|
||||||
|
query = query.filter(LLMConfig.is_active == is_active)
|
||||||
|
|
||||||
|
# 模型类型筛选
|
||||||
|
if is_embedding is not None:
|
||||||
|
query = query.filter(LLMConfig.is_embedding == is_embedding)
|
||||||
|
|
||||||
|
# 排序
|
||||||
|
query = query.order_by(LLMConfig.name)
|
||||||
|
|
||||||
|
# 分页
|
||||||
|
configs = query.offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
return [config.to_dict(include_sensitive=True) for config in configs]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting LLM configs: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="获取大模型配置列表失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers")
|
||||||
|
async def get_llm_providers(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取支持的大模型服务商列表."""
|
||||||
|
try:
|
||||||
|
providers = db.query(LLMConfig.provider).distinct().all()
|
||||||
|
return [provider[0] for provider in providers if provider[0]]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting LLM providers: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="获取服务商列表失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/active", response_model=List[LLMConfigResponse])
|
||||||
|
async def get_active_llm_configs(
|
||||||
|
is_embedding: Optional[bool] = Query(None),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取所有激活的大模型配置."""
|
||||||
|
try:
|
||||||
|
query = db.query(LLMConfig).filter(LLMConfig.is_active == True)
|
||||||
|
|
||||||
|
if is_embedding is not None:
|
||||||
|
query = query.filter(LLMConfig.is_embedding == is_embedding)
|
||||||
|
|
||||||
|
configs = query.order_by(LLMConfig.created_at).all()
|
||||||
|
|
||||||
|
return [config.to_dict(include_sensitive=True) for config in configs]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting active LLM configs: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="获取激活配置列表失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/default", response_model=LLMConfigResponse)
|
||||||
|
async def get_default_llm_config(
|
||||||
|
is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取默认大模型配置."""
|
||||||
|
try:
|
||||||
|
config = db.query(LLMConfig).filter(
|
||||||
|
LLMConfig.is_default == True,
|
||||||
|
LLMConfig.is_embedding == is_embedding,
|
||||||
|
LLMConfig.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
model_type = "嵌入模型" if is_embedding else "对话模型"
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"未找到默认{model_type}配置"
|
||||||
|
)
|
||||||
|
|
||||||
|
return config.to_dict(include_sensitive=True)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting default LLM config: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="获取默认配置失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{config_id}", response_model=LLMConfigResponse)
|
||||||
|
async def get_llm_config(
|
||||||
|
config_id: int,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_authenticated_user)
|
||||||
|
):
|
||||||
|
"""获取大模型配置详情."""
|
||||||
|
try:
|
||||||
|
config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
|
||||||
|
if not config:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="大模型配置不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
return config.to_dict(include_sensitive=True)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting LLM config {config_id}: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="获取大模型配置详情失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_llm_config(
    config_data: LLMConfigCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Create an LLM configuration."""
    try:
        # Reject duplicate configuration names
        existing_config = db.query(LLMConfig).filter(
            LLMConfig.name == config_data.name
        ).first()
        if existing_config:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="配置名称已存在"
            )

        # Build the configuration object and validate it before persisting
        config = LLMConfig(
            name=config_data.name,
            provider=config_data.provider,
            model_name=config_data.model_name,
            api_key=config_data.api_key,
            base_url=config_data.base_url,
            max_tokens=config_data.max_tokens,
            temperature=config_data.temperature,
            top_p=config_data.top_p,
            frequency_penalty=config_data.frequency_penalty,
            presence_penalty=config_data.presence_penalty,
            description=config_data.description,
            is_active=config_data.is_active,
            is_default=config_data.is_default,
            is_embedding=config_data.is_embedding,
            extra_config=config_data.extra_config or {}
        )

        validation_result = config.validate_config()
        if not validation_result['valid']:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=validation_result['error']
            )

        # If this config becomes the default, clear the flag on other configs of the same type
        if config_data.is_default:
            db.query(LLMConfig).filter(
                LLMConfig.is_embedding == config_data.is_embedding
            ).update({"is_default": False})

        config.set_audit_fields(current_user.id)

        db.add(config)
        db.commit()
        db.refresh(config)

        logger.info(f"LLM config created: {config.name} by user {current_user.username}")
        return config.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error creating LLM config: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="创建大模型配置失败"
        )


@router.put("/{config_id}", response_model=LLMConfigResponse)
async def update_llm_config(
    config_id: int,
    config_data: LLMConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Update an LLM configuration."""
    try:
        config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="大模型配置不存在"
            )

        # Check whether the new name is already taken (excluding this config)
        if config_data.name and config_data.name != config.name:
            existing_config = db.query(LLMConfig).filter(
                LLMConfig.name == config_data.name,
                LLMConfig.id != config_id
            ).first()
            if existing_config:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="配置名称已存在"
                )

        # If set as default, clear the default flag on other configs of the same type
        if config_data.is_default is True:
            # Use the updated is_embedding value when provided, otherwise the current one
            is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding
            db.query(LLMConfig).filter(
                LLMConfig.is_embedding == is_embedding,
                LLMConfig.id != config_id
            ).update({"is_default": False})

        # Apply the updated fields
        update_data = config_data.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(config, field, value)

        config.set_audit_fields(current_user.id, is_update=True)

        db.commit()
        db.refresh(config)

        logger.info(f"LLM config updated: {config.name} by user {current_user.username}")
        return config.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error updating LLM config {config_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="更新大模型配置失败"
        )


@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_llm_config(
    config_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Delete an LLM configuration."""
    try:
        config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="大模型配置不存在"
            )

        # TODO: check whether any conversation or other feature still uses this config
        # before allowing deletion

        db.delete(config)
        db.commit()

        logger.info(f"LLM config deleted: {config.name} by user {current_user.username}")

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error deleting LLM config {config_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除大模型配置失败"
        )


@router.post("/{config_id}/test")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Test an LLM configuration."""
    try:
        config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="大模型配置不存在"
            )

        # Validate the configuration first
        validation_result = config.validate_config()
        if not validation_result["valid"]:
            return {
                "success": False,
                "message": f"配置验证失败: {validation_result['error']}",
                "details": validation_result
            }

        # Try to create a client and send a test request
        test_message = test_data.message or "Hello, this is a test message."
        try:
            # The concrete client depends on the provider; this is a generic skeleton.
            # TODO: implement the real test logic, e.g.:
            # client = config.get_client()
            # response = client.chat.completions.create(
            #     model=config.model_name,
            #     messages=[{"role": "user", "content": test_message}],
            #     max_tokens=100
            # )

            # Simulated success for now
            logger.info(f"LLM config test: {config.name} by user {current_user.username}")

            return {
                "success": True,
                "message": "配置测试成功",
                "test_message": test_message,
                "response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
                "latency_ms": 150,  # simulated latency
                "config_info": config.get_client_config()
            }

        except Exception as test_error:
            logger.error(f"LLM config test failed: {config.name}, error: {str(test_error)}")
            return {
                "success": False,
                "message": f"配置测试失败: {str(test_error)}",
                "test_message": test_message,
                "config_info": config.get_client_config()
            }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error testing LLM config {config_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="测试大模型配置失败"
        )


@router.post("/{config_id}/toggle-status")
async def toggle_llm_config_status(
    config_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Toggle the active status of an LLM configuration."""
    try:
        config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="大模型配置不存在"
            )

        # Flip the status
        config.is_active = not config.is_active
        config.set_audit_fields(current_user.id, is_update=True)

        db.commit()
        db.refresh(config)

        status_text = "激活" if config.is_active else "禁用"
        logger.info(f"LLM config status toggled: {config.name} {status_text} by user {current_user.username}")

        return {
            "message": f"配置已{status_text}",
            "is_active": config.is_active
        }

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error toggling LLM config status {config_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="切换配置状态失败"
        )


@router.post("/{config_id}/set-default")
async def set_default_llm_config(
    config_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Set the default LLM configuration."""
    try:
        config = db.query(LLMConfig).filter(LLMConfig.id == config_id).first()
        if not config:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="大模型配置不存在"
            )

        # Only active configurations may become the default
        if not config.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="只能将激活的配置设为默认"
            )

        # Clear the default flag on other configs of the same type
        db.query(LLMConfig).filter(
            LLMConfig.is_embedding == config.is_embedding,
            LLMConfig.id != config_id
        ).update({"is_default": False})

        # Mark this configuration as the default
        config.is_default = True
        config.set_audit_fields(current_user.id, is_update=True)

        db.commit()
        db.refresh(config)

        model_type = "嵌入模型" if config.is_embedding else "对话模型"
        logger.info(f"Default LLM config set: {config.name} ({model_type}) by user {current_user.username}")

        # Refresh the document processor's default embedding model
        get_document_processor()._init_embeddings()

        return {
            "message": f"已将 {config.name} 设为默认{model_type}配置",
            "is_default": config.is_default
        }

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error setting default LLM config {config_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="设置默认配置失败"
        )

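The create, update, and set-default handlers above all maintain the same invariant: at most one default configuration per model type (chat vs. embedding), enforced by a bulk `update({"is_default": False})` on the sibling rows before the target is flagged. A minimal standalone sketch of that rule, using plain dicts in place of the SQLAlchemy model (function and field names here are illustrative, not from the repo):

```python
def set_default(configs, config_id):
    """Mark one config as default and clear the flag on others of the same type."""
    target = next(c for c in configs if c["id"] == config_id)
    for c in configs:
        # Only siblings of the same type (chat vs. embedding) are affected
        if c["is_embedding"] == target["is_embedding"] and c["id"] != config_id:
            c["is_default"] = False
    target["is_default"] = True
    return configs

configs = [
    {"id": 1, "is_embedding": False, "is_default": True},
    {"id": 2, "is_embedding": False, "is_default": False},
    {"id": 3, "is_embedding": True, "is_default": True},
]
set_default(configs, 2)
# config 1 loses its default flag; the embedding config (3) is untouched
```

In the real handlers the loop is a single SQL `UPDATE`, so the invariant holds even across rows not loaded into the session.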
@ -0,0 +1,346 @@
"""Role management API endpoints."""

from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_

from ...core.simple_permissions import require_super_admin
from ...db.database import get_db
from ...models.user import User
from ...models.permission import Role, UserRole
from ...services.auth import AuthService
from ...utils.logger import get_logger
from ...schemas.permission import (
    RoleCreate, RoleUpdate, RoleResponse,
    UserRoleAssign
)

logger = get_logger(__name__)
router = APIRouter(prefix="/roles", tags=["roles"])


@router.get("/", response_model=List[RoleResponse])
async def get_roles(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin),
):
    """List roles."""
    try:
        query = db.query(Role)

        # Keyword search over name, code, and description
        if search:
            query = query.filter(
                or_(
                    Role.name.ilike(f"%{search}%"),
                    Role.code.ilike(f"%{search}%"),
                    Role.description.ilike(f"%{search}%")
                )
            )

        # Filter by active status
        if is_active is not None:
            query = query.filter(Role.is_active == is_active)

        # Pagination
        roles = query.offset(skip).limit(limit).all()

        return [role.to_dict() for role in roles]

    except Exception as e:
        logger.error(f"Error getting roles: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取角色列表失败"
        )


@router.get("/{role_id}", response_model=RoleResponse)
async def get_role(
    role_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Get role details."""
    try:
        role = db.query(Role).filter(Role.id == role_id).first()
        if not role:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="角色不存在"
            )

        return role.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting role {role_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取角色详情失败"
        )


@router.post("/", response_model=RoleResponse, status_code=status.HTTP_201_CREATED)
async def create_role(
    role_data: RoleCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Create a role."""
    try:
        # Reject duplicate role codes
        existing_role = db.query(Role).filter(
            Role.code == role_data.code
        ).first()
        if existing_role:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="角色代码已存在"
            )

        # Create the role
        role = Role(
            name=role_data.name,
            code=role_data.code,
            description=role_data.description,
            is_active=role_data.is_active
        )
        role.set_audit_fields(current_user.id)

        db.add(role)
        db.commit()
        db.refresh(role)

        logger.info(f"Role created: {role.name} by user {current_user.username}")
        return role.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error creating role: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="创建角色失败"
        )


@router.put("/{role_id}", response_model=RoleResponse)
async def update_role(
    role_id: int,
    role_data: RoleUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Update a role."""
    try:
        role = db.query(Role).filter(Role.id == role_id).first()
        if not role:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="角色不存在"
            )

        # The super-admin role must not be edited
        if role.code == "SUPER_ADMIN":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="超级管理员角色不能被编辑"
            )

        # Check whether the new role code is already taken (excluding this role)
        if role_data.code and role_data.code != role.code:
            existing_role = db.query(Role).filter(
                and_(
                    Role.code == role_data.code,
                    Role.id != role_id
                )
            ).first()
            if existing_role:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="角色代码已存在"
                )

        # Apply the updated fields
        update_data = role_data.dict(exclude_unset=True)
        for field, value in update_data.items():
            setattr(role, field, value)

        role.set_audit_fields(current_user.id, is_update=True)

        db.commit()
        db.refresh(role)

        logger.info(f"Role updated: {role.name} by user {current_user.username}")
        return role.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error updating role {role_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="更新角色失败"
        )


@router.delete("/{role_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_role(
    role_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Delete a role."""
    try:
        role = db.query(Role).filter(Role.id == role_id).first()
        if not role:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="角色不存在"
            )

        # The super-admin role must not be deleted
        if role.code == "SUPER_ADMIN":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="超级管理员角色不能被删除"
            )

        # Block deletion while users are still assigned to this role
        user_count = db.query(UserRole).filter(
            UserRole.role_id == role_id
        ).count()
        if user_count > 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"无法删除角色,还有 {user_count} 个用户关联此角色"
            )

        # Delete the role
        db.delete(role)
        db.commit()

        logger.info(f"Role deleted: {role.name} by user {current_user.username}")

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error deleting role {role_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除角色失败"
        )


# User-role assignment routes
user_role_router = APIRouter(prefix="/user-roles", tags=["user-roles"])


@user_role_router.post("/assign", status_code=status.HTTP_201_CREATED)
async def assign_user_roles(
    assignment_data: UserRoleAssign,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_super_admin)
):
    """Assign roles to a user."""
    try:
        # Verify the user exists
        user = db.query(User).filter(User.id == assignment_data.user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        # Verify all requested roles exist
        roles = db.query(Role).filter(
            Role.id.in_(assignment_data.role_ids)
        ).all()
        if len(roles) != len(assignment_data.role_ids):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="部分角色不存在"
            )

        # Remove the user's existing role links
        db.query(UserRole).filter(
            UserRole.user_id == assignment_data.user_id
        ).delete()

        # Add the new role links
        for role_id in assignment_data.role_ids:
            user_role = UserRole(
                user_id=assignment_data.user_id,
                role_id=role_id
            )
            db.add(user_role)

        db.commit()

        logger.info(f"User roles assigned: user {user.username}, roles {assignment_data.role_ids} by user {current_user.username}")

        return {"message": "角色分配成功"}

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error assigning roles to user {assignment_data.user_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="角色分配失败"
        )


@user_role_router.get("/user/{user_id}", response_model=List[RoleResponse])
async def get_user_roles(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_active_user)
):
    """Get the roles assigned to a user."""
    try:
        # Permission check: users may only view their own roles unless they are a super admin
        if current_user.id != user_id and not current_user.is_superuser():
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="无权限查看其他用户的角色"
            )

        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        roles = db.query(Role).join(
            UserRole, Role.id == UserRole.role_id
        ).filter(
            UserRole.user_id == user_id
        ).all()

        return [role.to_dict() for role in roles]

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting user roles {user_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取用户角色失败"
        )


# Mount the user-role sub-router on the main router
router.include_router(user_role_router)

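`assign_user_roles` above replaces the user's whole role set rather than merging: every existing `UserRole` row for the user is deleted, then one row is inserted per requested role ID. The same replace semantics in miniature, with plain `(user_id, role_id)` tuples standing in for `UserRole` rows (names illustrative, not from the repo):

```python
def assign_roles(user_roles, user_id, role_ids):
    """Replace all of user_id's role links with the given role_ids."""
    # Drop the user's existing links, keep everyone else's
    kept = [(u, r) for (u, r) in user_roles if u != user_id]
    # Insert one link per requested role
    return kept + [(user_id, r) for r in role_ids]

links = [(1, 10), (1, 11), (2, 10)]
links = assign_roles(links, 1, [12])
# user 1 now has only role 12; user 2 is untouched
```

Because the delete and the inserts happen in one transaction in the real handler, a failure midway rolls back to the original role set rather than leaving the user with no roles.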
@ -0,0 +1,342 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPBearer
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any
import logging
from datetime import datetime

from th_agenter.db.database import get_db
from th_agenter.services.auth import AuthService
from th_agenter.services.smart_workflow import SmartWorkflowManager
from th_agenter.services.conversation import ConversationService
from th_agenter.services.conversation_context import conversation_context_service
from th_agenter.utils.schemas import BaseResponse
from pydantic import BaseModel

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/smart-chat", tags=["smart-chat"])
security = HTTPBearer()


# Request/Response Models
class SmartQueryRequest(BaseModel):
    query: str
    conversation_id: Optional[int] = None
    is_new_conversation: bool = False


class SmartQueryResponse(BaseModel):
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None
    workflow_steps: Optional[list] = None
    conversation_id: Optional[int] = None


class ConversationContextResponse(BaseModel):
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None


@router.post("/query", response_model=SmartQueryResponse)
async def smart_query(
    request: SmartQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Smart data-query endpoint.

    For a new conversation, automatically loads the user's file list, selects
    the relevant Excel files, then generates and executes pandas code.
    """
    conversation_id = None

    try:
        # Validate the request
        if not request.query or not request.query.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容不能为空"
            )

        if len(request.query) > 1000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="查询内容过长,请控制在1000字符以内"
            )

        # Initialize the workflow manager
        workflow_manager = SmartWorkflowManager(db)
        conversation_service = ConversationService(db)

        # Resolve the conversation context
        conversation_id = request.conversation_id

        # Create a new conversation when requested or when no ID was given
        if request.is_new_conversation or not conversation_id:
            try:
                conversation_id = await conversation_context_service.create_conversation(
                    user_id=current_user.id,
                    title=f"智能问数: {request.query[:20]}..."
                )
                request.is_new_conversation = True
                logger.info(f"创建新对话: {conversation_id}")
            except Exception as e:
                logger.warning(f"创建对话失败,使用临时会话: {e}")
                conversation_id = None
        else:
            # Verify the conversation exists and belongs to the current user
            try:
                context = await conversation_context_service.get_conversation_context(conversation_id)
                if not context or context.get('user_id') != current_user.id:
                    raise HTTPException(
                        status_code=status.HTTP_404_NOT_FOUND,
                        detail="对话不存在或无权访问"
                    )
                logger.info(f"使用现有对话: {conversation_id}")
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"验证对话失败: {e}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="对话验证失败"
                )

        # Persist the user message
        if conversation_id:
            try:
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="user",
                    content=request.query
                )
            except Exception as e:
                logger.warning(f"保存用户消息失败: {e}")
                # Non-fatal: continue with the query

        # Run the smart-query workflow
        try:
            result = await workflow_manager.process_smart_query(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=request.is_new_conversation
            )
        except Exception as e:
            logger.error(f"智能查询执行失败: {e}")
            # Return a structured error response
            return SmartQueryResponse(
                success=False,
                message=f"查询执行失败: {str(e)}",
                data={'error_type': 'query_execution_error'},
                workflow_steps=[{
                    'step': 'error',
                    'status': 'failed',
                    'message': str(e)
                }],
                conversation_id=conversation_id
            )

        # On success, persist the assistant reply and update the context
        if result['success'] and conversation_id:
            try:
                # Save the assistant reply
                await conversation_context_service.save_message(
                    conversation_id=conversation_id,
                    role="assistant",
                    content=result.get('data', {}).get('summary', '查询完成'),
                    metadata={
                        'query_result': result.get('data'),
                        'workflow_steps': result.get('workflow_steps', []),
                        'selected_files': result.get('data', {}).get('used_files', [])
                    }
                )

                # Update the conversation context
                await conversation_context_service.update_conversation_context(
                    conversation_id=conversation_id,
                    query=request.query,
                    selected_files=result.get('data', {}).get('used_files', [])
                )

                logger.info(f"查询成功完成,对话ID: {conversation_id}")

            except Exception as e:
                logger.warning(f"保存消息到对话历史失败: {e}")
                # Log only; the returned result is unaffected

        # Return the result together with the conversation ID
        response_data = result.get('data', {})
        if conversation_id:
            response_data['conversation_id'] = conversation_id

        return SmartQueryResponse(
            success=result['success'],
            message=result.get('message', '查询完成'),
            data=response_data,
            workflow_steps=result.get('workflow_steps', []),
            conversation_id=conversation_id
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"智能查询接口异常: {e}", exc_info=True)
        # Generic error response
        return SmartQueryResponse(
            success=False,
            message="服务器内部错误,请稍后重试",
            data={'error_type': 'internal_server_error'},
            workflow_steps=[{
                'step': 'error',
                'status': 'failed',
                'message': '系统异常'
            }],
            conversation_id=conversation_id
        )

@router.get("/conversation/{conversation_id}/context", response_model=ConversationContextResponse)
async def get_conversation_context(
    conversation_id: int,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Get the conversation context, including files already used and query history.
    """
    try:
        # Fetch the conversation context
        context = await conversation_context_service.get_conversation_context(conversation_id)

        if not context:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="对话上下文不存在"
            )

        # Verify the user's access
        if context['user_id'] != current_user.id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="无权访问此对话"
            )

        # Fetch the message history
        history = await conversation_context_service.get_conversation_history(conversation_id)
        context['message_history'] = history

        return ConversationContextResponse(
            success=True,
            message="获取对话上下文成功",
            data=context
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取对话上下文失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取对话上下文失败: {str(e)}"
        )

@router.get("/files/status", response_model=ConversationContextResponse)
|
||||||
|
async def get_files_status(
|
||||||
|
current_user = Depends(AuthService.get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户当前的文件状态和统计信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
workflow_manager = SmartWorkflowManager()
|
||||||
|
|
||||||
|
# 获取用户文件列表
|
||||||
|
file_list = await workflow_manager.excel_workflow._load_user_file_list(current_user.id)
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
total_files = len(file_list)
|
||||||
|
total_rows = sum(f.get('row_count', 0) for f in file_list)
|
||||||
|
total_columns = sum(f.get('column_count', 0) for f in file_list)
|
||||||
|
|
||||||
|
# 文件类型统计
|
||||||
|
file_types = {}
|
||||||
|
for file_info in file_list:
|
||||||
|
filename = file_info['filename']
|
||||||
|
ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
|
||||||
|
file_types[ext] = file_types.get(ext, 0) + 1
|
||||||
|
|
||||||
|
status_data = {
|
||||||
|
'total_files': total_files,
|
||||||
|
'total_rows': total_rows,
|
||||||
|
'total_columns': total_columns,
|
||||||
|
'file_types': file_types,
|
||||||
|
'files': [{
|
||||||
|
'id': f['id'],
|
||||||
|
'filename': f['filename'],
|
||||||
|
'row_count': f.get('row_count', 0),
|
||||||
|
'column_count': f.get('column_count', 0),
|
||||||
|
'columns': f.get('columns', []),
|
||||||
|
'upload_time': f.get('upload_time')
|
||||||
|
} for f in file_list],
|
||||||
|
'ready_for_query': total_files > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return ConversationContextResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"当前有{total_files}个可用文件" if total_files > 0 else "暂无可用文件,请先上传Excel文件",
|
||||||
|
data=status_data
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文件状态失败: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"获取文件状态失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
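The file-type tally in `get_files_status` reduces to grouping filenames by lowercase extension, with extensionless names bucketed as `unknown`. A minimal standalone sketch of that logic (the sample filenames are invented):

```python
def count_file_types(filenames):
    """Tally files by lowercase extension; names without a dot count as 'unknown'."""
    file_types = {}
    for filename in filenames:
        ext = filename.split('.')[-1].lower() if '.' in filename else 'unknown'
        file_types[ext] = file_types.get(ext, 0) + 1
    return file_types

print(count_file_types(['a.xlsx', 'b.XLSX', 'c.csv', 'README']))
# → {'xlsx': 2, 'csv': 1, 'unknown': 1}
```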
@router.post("/conversation/{conversation_id}/reset")
async def reset_conversation_context(
    conversation_id: int,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Reset the conversation context: clear the query history but keep the files.
    """
    try:
        # Verify the conversation exists and the caller owns it
        context = await conversation_context_service.get_conversation_context(conversation_id)

        if not context:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="对话上下文不存在"
            )

        if context['user_id'] != current_user.id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="无权访问此对话"
            )

        # Reset the conversation context
        success = await conversation_context_service.reset_conversation_context(conversation_id)

        if success:
            return {
                "success": True,
                "message": "对话上下文已重置,可以开始新的数据分析会话"
            }
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="重置对话上下文失败"
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"重置对话上下文失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重置对话上下文失败: {str(e)}"
        )
@ -0,0 +1,754 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List, AsyncGenerator
from pydantic import BaseModel
from datetime import datetime
from pathlib import Path
import json
import logging
import os
import tempfile
import uuid

import pandas as pd

from th_agenter.utils.schemas import FileListResponse, ExcelPreviewRequest, NormalResponse, BaseResponse
from th_agenter.utils.file_utils import FileUtils
from th_agenter.services.smart_query import (
    SmartQueryService,
    ExcelAnalysisService,
    DatabaseQueryService
)
from th_agenter.services.excel_metadata_service import ExcelMetadataService
from th_agenter.db.database import get_db
from th_agenter.services.auth import AuthService
from th_agenter.services.smart_workflow import SmartWorkflowManager
from th_agenter.services.conversation_context import ConversationContextService


logger = logging.getLogger(__name__)

router = APIRouter(prefix="/smart-query", tags=["smart-query"])
security = HTTPBearer()


# Request/Response Models
class DatabaseConfig(BaseModel):
    type: str
    host: str
    port: str
    database: str
    username: str
    password: str


class QueryRequest(BaseModel):
    query: str
    page: int = 1
    page_size: int = 20
    table_name: Optional[str] = None


class TableSchemaRequest(BaseModel):
    table_name: str


class ExcelUploadResponse(BaseModel):
    file_id: int
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None  # payload carrying the analysis details


class QueryResponse(BaseModel):
    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None

@router.post("/upload-excel", response_model=ExcelUploadResponse)
async def upload_excel(
    file: UploadFile = File(...),
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Upload an Excel file and preprocess it.
    """
    try:
        # Validate the file type
        allowed_extensions = ['.xlsx', '.xls', '.csv']
        file_extension = os.path.splitext(file.filename)[1].lower()

        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的文件格式,请上传 .xlsx, .xls 或 .csv 文件"
            )

        # Validate the file size (10MB limit)
        content = await file.read()
        file_size = len(content)
        if file_size > 10 * 1024 * 1024:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="文件大小不能超过 10MB"
            )

        # Build the persistent directory structure
        backend_dir = Path(__file__).parent.parent.parent.parent  # backend directory
        data_dir = backend_dir / "data/uploads"
        excel_user_dir = data_dir / f"excel_{current_user.id}"

        # Make sure the directory exists
        excel_user_dir.mkdir(parents=True, exist_ok=True)

        # Build the stored filename: {uuid}_{original filename}
        file_id = str(uuid.uuid4())
        safe_filename = FileUtils.sanitize_filename(file.filename)
        new_filename = f"{file_id}_{safe_filename}"
        file_path = excel_user_dir / new_filename

        # Save the file
        with open(file_path, 'wb') as f:
            f.write(content)

        # Extract and persist metadata via the Excel metadata service
        metadata_service = ExcelMetadataService(db)
        excel_file = metadata_service.save_file_metadata(
            file_path=str(file_path),
            original_filename=file.filename,
            user_id=current_user.id,
            file_size=file_size
        )

        # For compatibility with the existing frontend, still create a pickle file
        try:
            if file_extension == '.csv':
                df = pd.read_csv(file_path, encoding='utf-8')
            else:
                df = pd.read_excel(file_path)
        except UnicodeDecodeError:
            if file_extension == '.csv':
                df = pd.read_csv(file_path, encoding='gbk')
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="文件编码错误,请确保文件为UTF-8或GBK编码"
                )
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"文件读取失败: {str(e)}"
            )

        # Save the pickle file into the same directory
        pickle_filename = f"{file_id}_{safe_filename}.pkl"
        pickle_path = excel_user_dir / pickle_filename
        df.to_pickle(pickle_path)

        # Preprocess and analyze the data (kept for compatibility)
        excel_service = ExcelAnalysisService()
        analysis_result = excel_service.analyze_dataframe(df, file.filename)

        # Attach the database file info
        analysis_result.update({
            'file_id': str(excel_file.id),
            'database_id': excel_file.id,
            'temp_file_path': str(pickle_path),  # the new pickle path
            'original_filename': file.filename,
            'file_size_mb': excel_file.file_size_mb,
            'sheet_names': excel_file.sheet_names,
        })

        return ExcelUploadResponse(
            file_id=excel_file.id,
            success=True,
            message="Excel文件上传成功",
            data=analysis_result
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"文件处理失败: {str(e)}"
        )

@router.post("/preview-excel", response_model=QueryResponse)
async def preview_excel(
    request: ExcelPreviewRequest,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Preview the data of an Excel file.
    """
    try:
        logger.info(f"Preview request for file_id: {request.file_id}, user: {current_user.id}")

        # Validate the file_id format
        try:
            file_id = int(request.file_id)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"无效的文件ID格式: {request.file_id}"
            )

        # Look up the file record in the database
        metadata_service = ExcelMetadataService(db)
        excel_file = metadata_service.get_file_by_id(file_id, current_user.id)

        if not excel_file:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="文件不存在或已被删除"
            )

        # Make sure the file still exists on disk
        if not os.path.exists(excel_file.file_path):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="文件已被移动或删除"
            )

        # Update the last-accessed timestamp
        metadata_service.update_last_accessed(file_id, current_user.id)

        # Read the file
        if excel_file.file_type.lower() == 'csv':
            df = pd.read_csv(excel_file.file_path, encoding='utf-8')
        else:
            # For Excel files, use the configured default sheet or the first sheet
            sheet_name = excel_file.default_sheet if excel_file.default_sheet else 0
            df = pd.read_excel(excel_file.file_path, sheet_name=sheet_name)

        # Compute the pagination bounds
        total_rows = len(df)
        start_idx = (request.page - 1) * request.page_size
        end_idx = start_idx + request.page_size

        # Slice out the requested page
        paginated_df = df.iloc[start_idx:end_idx]

        # Convert to a list of dicts
        data = paginated_df.fillna('').to_dict('records')
        columns = df.columns.tolist()

        return QueryResponse(
            success=True,
            message="Excel文件预览加载成功",
            data={
                'data': data,
                'columns': columns,
                'total_rows': total_rows,
                'page': request.page,
                'page_size': request.page_size,
                'total_pages': (total_rows + request.page_size - 1) // request.page_size
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"预览文件失败: {str(e)}"
        )

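The pagination in `preview_excel` turns a 1-indexed page number into a half-open row slice, and `total_pages` is a ceiling division. The arithmetic in isolation:

```python
def page_bounds(total_rows: int, page: int, page_size: int):
    """Return the [start, end) slice for a 1-indexed page, plus the page count."""
    start = (page - 1) * page_size
    end = start + page_size
    total_pages = (total_rows + page_size - 1) // page_size  # ceil(total / size)
    return start, end, total_pages

print(page_bounds(45, 3, 20))  # → (40, 60, 3): the last page holds rows 40..44
```

Slicing with `df.iloc[start:end]` is safe even when `end` exceeds the row count, since pandas clamps out-of-range slice bounds.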
@router.post("/test-db-connection", response_model=NormalResponse)
async def test_database_connection(
    config: DatabaseConfig,
    current_user = Depends(AuthService.get_current_user)
):
    """
    Test a database connection.
    """
    try:
        db_service = DatabaseQueryService()
        is_connected = await db_service.test_connection(config.dict())

        if is_connected:
            return NormalResponse(
                success=True,
                message="数据库连接测试成功"
            )
        else:
            return NormalResponse(
                success=False,
                message="数据库连接测试失败"
            )

    except Exception as e:
        return NormalResponse(
            success=False,
            message=f"连接测试失败: {str(e)}"
        )

@router.post("/table-schema", response_model=QueryResponse)
async def get_table_schema(
    request: TableSchemaRequest,
    current_user = Depends(AuthService.get_current_user)
):
    """
    Get the schema of a database table.
    """
    try:
        db_service = DatabaseQueryService()
        schema_result = await db_service.get_table_schema(request.table_name, current_user.id)

        if schema_result['success']:
            return QueryResponse(
                success=True,
                message="获取表结构成功",
                data=schema_result['data']
            )
        else:
            return QueryResponse(
                success=False,
                message=schema_result['message']
            )

    except Exception as e:
        return QueryResponse(
            success=False,
            message=f"获取表结构失败: {str(e)}"
        )

class StreamQueryRequest(BaseModel):
    query: str
    conversation_id: Optional[int] = None
    is_new_conversation: bool = False


class DatabaseStreamQueryRequest(BaseModel):
    query: str
    database_config_id: int
    conversation_id: Optional[int] = None
    is_new_conversation: bool = False

@router.post("/execute-excel-query")
async def stream_smart_query(
    request: StreamQueryRequest,
    current_user=Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Streaming smart-query endpoint.
    Pushes workflow steps and the final result in real time.
    """

    async def generate_stream() -> AsyncGenerator[str, None]:
        workflow_manager = None

        try:
            # Validate the request parameters
            if not request.query or not request.query.strip():
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
                return

            if len(request.query) > 1000:
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
                return

            # Send the start signal
            yield f"data: {json.dumps({'type': 'start', 'message': '开始处理查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            # Initialize services
            workflow_manager = SmartWorkflowManager(db)
            conversation_context_service = ConversationContextService()

            # Handle the conversation context
            conversation_id = request.conversation_id

            # Create a new conversation when requested or when no ID was given
            if request.is_new_conversation or not conversation_id:
                try:
                    conversation_id = await conversation_context_service.create_conversation(
                        user_id=current_user.id,
                        title=f"智能问数: {request.query[:20]}..."
                    )
                    yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
                except Exception as e:
                    logger.warning(f"创建对话失败: {e}")
                    # Do not abort; continue with the query

            # Save the user message
            if conversation_id:
                try:
                    await conversation_context_service.save_message(
                        conversation_id=conversation_id,
                        role="user",
                        content=request.query
                    )
                except Exception as e:
                    logger.warning(f"保存用户消息失败: {e}")

            # Run the smart-query workflow with streaming updates
            async for step_data in workflow_manager.process_excel_query_stream(
                user_query=request.query,
                user_id=current_user.id,
                conversation_id=conversation_id,
                is_new_conversation=request.is_new_conversation
            ):
                # Push the workflow step
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

                # On the final result, persist it to the conversation history
                if step_data.get('type') == 'final_result' and conversation_id:
                    try:
                        result_data = step_data.get('data', {})
                        await conversation_context_service.save_message(
                            conversation_id=conversation_id,
                            role="assistant",
                            content=result_data.get('summary', '查询完成'),
                            metadata={
                                'query_result': result_data,
                                'workflow_steps': step_data.get('workflow_steps', []),
                                'selected_files': result_data.get('used_files', [])
                            }
                        )

                        # Update the conversation context
                        await conversation_context_service.update_conversation_context(
                            conversation_id=conversation_id,
                            query=request.query,
                            selected_files=result_data.get('used_files', [])
                        )

                        logger.info(f"查询成功完成,对话ID: {conversation_id}")

                    except Exception as e:
                        logger.warning(f"保存消息到对话历史失败: {e}")

            # Send the completion signal
            yield f"data: {json.dumps({'type': 'complete', 'message': '查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"流式智能查询异常: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"

        finally:
            # Release resources
            if workflow_manager:
                try:
                    workflow_manager.excel_workflow.executor.shutdown(wait=False)
                except Exception:
                    pass

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )

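Each event the endpoint yields is a Server-Sent-Events frame: a `data: ` line carrying a JSON payload, terminated by a blank line. A client-side sketch of splitting such a buffer back into payloads (the sample frames are invented):

```python
import json

def parse_sse_events(raw: str):
    """Split a raw SSE buffer into the JSON payloads of its 'data:' frames."""
    events = []
    for frame in raw.split("\n\n"):
        frame = frame.strip()
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))
    return events

buf = 'data: {"type": "start"}\n\ndata: {"type": "complete"}\n\n'
print(parse_sse_events(buf))  # → [{'type': 'start'}, {'type': 'complete'}]
```

A real client would feed chunks into this incrementally as they arrive rather than splitting one complete buffer, but the framing rule is the same.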
@router.post("/execute-db-query")
async def execute_database_query(
    request: DatabaseStreamQueryRequest,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Streaming database-query endpoint.
    Pushes workflow steps and the final result in real time.
    """

    async def generate_stream() -> AsyncGenerator[str, None]:
        workflow_manager = None

        try:
            # Validate the request parameters
            if not request.query or not request.query.strip():
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容不能为空'}, ensure_ascii=False)}\n\n"
                return

            if len(request.query) > 1000:
                yield f"data: {json.dumps({'type': 'error', 'message': '查询内容过长,请控制在1000字符以内'}, ensure_ascii=False)}\n\n"
                return

            # Send the start signal
            yield f"data: {json.dumps({'type': 'start', 'message': '开始处理数据库查询', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            # Initialize services
            workflow_manager = SmartWorkflowManager(db)
            conversation_context_service = ConversationContextService()

            # Handle the conversation context
            conversation_id = request.conversation_id

            # Create a new conversation when requested or when no ID was given
            if request.is_new_conversation or not conversation_id:
                try:
                    conversation_id = await conversation_context_service.create_conversation(
                        user_id=current_user.id,
                        title=f"数据库查询: {request.query[:20]}..."
                    )
                    yield f"data: {json.dumps({'type': 'conversation_created', 'conversation_id': conversation_id}, ensure_ascii=False)}\n\n"
                except Exception as e:
                    logger.warning(f"创建对话失败: {e}")
                    # Do not abort; continue with the query

            # Save the user message
            if conversation_id:
                try:
                    await conversation_context_service.save_message(
                        conversation_id=conversation_id,
                        role="user",
                        content=request.query
                    )
                except Exception as e:
                    logger.warning(f"保存用户消息失败: {e}")

            # Run the database-query workflow with streaming updates
            async for step_data in workflow_manager.process_database_query_stream(
                user_query=request.query,
                user_id=current_user.id,
                database_config_id=request.database_config_id
            ):
                # Push the workflow step
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

                # On the final result, persist it to the conversation history
                if step_data.get('type') == 'final_result' and conversation_id:
                    try:
                        result_data = step_data.get('data', {})
                        await conversation_context_service.save_message(
                            conversation_id=conversation_id,
                            role="assistant",
                            content=result_data.get('summary', '查询完成'),
                            metadata={
                                'query_result': result_data,
                                'workflow_steps': step_data.get('workflow_steps', []),
                                'generated_sql': result_data.get('generated_sql', '')
                            }
                        )

                        # Update the conversation context
                        await conversation_context_service.update_conversation_context(
                            conversation_id=conversation_id,
                            query=request.query,
                            selected_files=[]
                        )

                        logger.info(f"数据库查询成功完成,对话ID: {conversation_id}")

                    except Exception as e:
                        logger.warning(f"保存消息到对话历史失败: {e}")

            # Send the completion signal
            yield f"data: {json.dumps({'type': 'complete', 'message': '数据库查询处理完成', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"流式数据库查询异常: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'查询执行失败: {str(e)}'}, ensure_ascii=False)}\n\n"

        finally:
            # Release resources
            if workflow_manager:
                try:
                    workflow_manager.database_workflow.executor.shutdown(wait=False)
                except Exception:
                    pass

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Methods": "*"
        }
    )

@router.delete("/cleanup-temp-files")
async def cleanup_temp_files(
    current_user = Depends(AuthService.get_current_user)
):
    """
    Clean up temporary files.
    """
    try:
        temp_dir = tempfile.gettempdir()
        user_prefix = f"excel_{current_user.id}_"

        cleaned_count = 0
        for filename in os.listdir(temp_dir):
            if filename.startswith(user_prefix) and filename.endswith('.pkl'):
                file_path = os.path.join(temp_dir, filename)
                try:
                    os.remove(file_path)
                    cleaned_count += 1
                except OSError:
                    pass

        return BaseResponse(
            success=True,
            message=f"已清理 {cleaned_count} 个临时文件"
        )

    except Exception as e:
        return BaseResponse(
            success=False,
            message=f"清理临时文件失败: {str(e)}"
        )

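The cleanup endpoint selects candidates purely by the `excel_{user_id}_` prefix and the `.pkl` suffix before deleting them. The selection logic on its own (sorted here for deterministic output):

```python
import os
import tempfile

def find_user_pickles(user_id, temp_dir=None):
    """List files named 'excel_{user_id}_*.pkl' in the given (or system) temp dir."""
    temp_dir = temp_dir or tempfile.gettempdir()
    prefix = f"excel_{user_id}_"
    return sorted(f for f in os.listdir(temp_dir)
                  if f.startswith(prefix) and f.endswith('.pkl'))
```

Keeping the trailing underscore in the prefix matters: without it, user 1's prefix would also match user 10's files.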
@router.get("/files", response_model=FileListResponse)
async def get_file_list(
    page: int = 1,
    page_size: int = 20,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    List the user's uploaded Excel files.
    """
    try:
        metadata_service = ExcelMetadataService(db)
        skip = (page - 1) * page_size
        files, total = metadata_service.get_user_files(current_user.id, skip, page_size)

        file_list = []
        for file in files:
            file_info = {
                'id': file.id,
                'filename': file.original_filename,
                'file_size': file.file_size,
                'file_size_mb': file.file_size_mb,
                'file_type': file.file_type,
                'sheet_names': file.sheet_names,
                'sheet_count': file.sheet_count,
                'last_accessed': file.last_accessed.isoformat() if file.last_accessed else None,
                'is_processed': file.is_processed,
                'processing_error': file.processing_error
            }
            file_list.append(file_info)

        return FileListResponse(
            success=True,
            message="获取文件列表成功",
            data={
                'files': file_list,
                'total': total,
                'page': page,
                'page_size': page_size,
                'total_pages': (total + page_size - 1) // page_size
            }
        )

    except Exception as e:
        return FileListResponse(
            success=False,
            message=f"获取文件列表失败: {str(e)}"
        )

@router.delete("/files/{file_id}", response_model=NormalResponse)
async def delete_file(
    file_id: int,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Delete the specified Excel file.
    """
    try:
        metadata_service = ExcelMetadataService(db)
        success = metadata_service.delete_file(file_id, current_user.id)

        if success:
            return NormalResponse(
                success=True,
                message="文件删除成功"
            )
        else:
            return NormalResponse(
                success=False,
                message="文件不存在或删除失败"
            )

    except Exception as e:
        # Report failure (not success) when an exception occurs
        return NormalResponse(
            success=False,
            message=str(e)
        )

@router.get("/files/{file_id}/info", response_model=QueryResponse)
async def get_file_info(
    file_id: int,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """
    Get detailed info for the specified file.
    """
    try:
        metadata_service = ExcelMetadataService(db)
        excel_file = metadata_service.get_file_by_id(file_id, current_user.id)

        if not excel_file:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="文件不存在"
            )

        # Update the last-accessed timestamp
        metadata_service.update_last_accessed(file_id, current_user.id)

        file_info = {
            'id': excel_file.id,
            'filename': excel_file.original_filename,
            'file_size': excel_file.file_size,
            'file_size_mb': excel_file.file_size_mb,
            'file_type': excel_file.file_type,
            'sheet_names': excel_file.sheet_names,
            'default_sheet': excel_file.default_sheet,
            'columns_info': excel_file.columns_info,
            'preview_data': excel_file.preview_data,
            'data_types': excel_file.data_types,
            'total_rows': excel_file.total_rows,
            'total_columns': excel_file.total_columns,
            'upload_time': excel_file.upload_time.isoformat() if excel_file.upload_time else None,
            'last_accessed': excel_file.last_accessed.isoformat() if excel_file.last_accessed else None,
            'sheets_summary': excel_file.get_all_sheets_summary()
        }

        return QueryResponse(
            success=True,
            message="获取文件信息成功",
            data=file_info
        )

    except HTTPException:
        raise
    except Exception as e:
        return QueryResponse(
            success=False,
            message=f"获取文件信息失败: {str(e)}"
        )

@ -0,0 +1,248 @@
|
||||||
|
"""表元数据管理API"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from th_agenter.models.user import User
|
||||||
|
from th_agenter.db.database import get_db
|
||||||
|
from th_agenter.services.table_metadata_service import TableMetadataService
|
||||||
|
from th_agenter.utils.logger import get_logger
|
||||||
|
from th_agenter.services.auth import AuthService
|
||||||
|
|
||||||
|
logger = get_logger("table_metadata_api")
|
||||||
|
router = APIRouter(prefix="/api/table-metadata", tags=["table-metadata"])
|
||||||
|
|
||||||
|
|
||||||
|
class TableSelectionRequest(BaseModel):
|
||||||
|
database_config_id: int = Field(..., description="数据库配置ID")
|
||||||
|
table_names: List[str] = Field(..., description="选中的表名列表")
|
||||||
|
|
||||||
|
|
||||||
|
class TableMetadataResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
table_name: str
|
||||||
|
table_schema: str
|
||||||
|
table_type: str
|
||||||
|
table_comment: str
|
||||||
|
columns_count: int
|
||||||
|
row_count: int
|
||||||
|
is_enabled_for_qa: bool
|
||||||
|
qa_description: str
|
||||||
|
business_context: str
|
||||||
|
last_synced_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class QASettingsUpdate(BaseModel):
|
||||||
|
is_enabled_for_qa: bool = Field(default=True)
|
||||||
|
qa_description: str = Field(default="")
|
||||||
|
business_context: str = Field(default="")
|
||||||
|
|
||||||
|
|
||||||
|
class TableByNameRequest(BaseModel):
|
||||||
|
database_config_id: int = Field(..., description="数据库配置ID")
|
||||||
|
table_name: str = Field(..., description="表名")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/collect")
|
||||||
|
async def collect_table_metadata(
|
||||||
|
request: TableSelectionRequest,
|
||||||
|
current_user: User = Depends(AuthService.get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""收集选中表的元数据"""
|
||||||
|
try:
|
||||||
|
service = TableMetadataService(db)
|
||||||
|
result = await service.collect_and_save_table_metadata(
|
||||||
|
current_user.id,
|
||||||
|
request.database_config_id,
|
||||||
|
request.table_names
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"收集表元数据失败: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)


@router.get("/")
async def get_table_metadata(
    database_config_id: int = None,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Get the list of table metadata."""
    try:
        service = TableMetadataService(db)
        metadata_list = service.get_user_table_metadata(
            current_user.id,
            database_config_id
        )

        data = [
            {
                "id": meta.id,
                "table_name": meta.table_name,
                "table_schema": meta.table_schema,
                "table_type": meta.table_type,
                "table_comment": meta.table_comment or "",
                "columns": meta.columns_info if meta.columns_info else [],
                "column_count": len(meta.columns_info) if meta.columns_info else 0,
                "row_count": meta.row_count,
                "is_enabled_for_qa": meta.is_enabled_for_qa,
                "qa_description": meta.qa_description or "",
                "business_context": meta.business_context or "",
                "created_at": meta.created_at.isoformat() if meta.created_at else "",
                "updated_at": meta.updated_at.isoformat() if meta.updated_at else "",
                "last_synced_at": meta.last_synced_at.isoformat() if meta.last_synced_at else "",
                "qa_settings": {
                    "is_enabled_for_qa": meta.is_enabled_for_qa,
                    "qa_description": meta.qa_description or "",
                    "business_context": meta.business_context or ""
                }
            }
            for meta in metadata_list
        ]

        return {
            "success": True,
            "data": data
        }

    except Exception as e:
        logger.error(f"Failed to get table metadata: {str(e)}")
        return {
            "success": False,
            "message": str(e)
        }
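The endpoint above serializes nullable timestamps with an `isoformat()`/empty-string fallback. A stdlib sketch of that pattern in isolation:

```python
from datetime import datetime

def serialize_ts(value):
    # Mirror the endpoint's pattern: ISO-8601 string, or "" when unset
    return value.isoformat() if value else ""

synced = datetime(2024, 5, 1, 12, 30)
print(serialize_ts(synced))      # → 2024-05-01T12:30:00
print(repr(serialize_ts(None)))  # → ''
```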


@router.post("/by-table")
async def get_table_metadata_by_name(
    request: TableByNameRequest,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Get table metadata by table name."""
    try:
        service = TableMetadataService(db)
        metadata = service.get_table_metadata_by_name(
            current_user.id,
            request.database_config_id,
            request.table_name
        )

        if metadata:
            data = {
                "id": metadata.id,
                "table_name": metadata.table_name,
                "table_schema": metadata.table_schema,
                "table_type": metadata.table_type,
                "table_comment": metadata.table_comment or "",
                "columns": metadata.columns_info if metadata.columns_info else [],
                "column_count": len(metadata.columns_info) if metadata.columns_info else 0,
                "row_count": metadata.row_count,
                "is_enabled_for_qa": metadata.is_enabled_for_qa,
                "qa_description": metadata.qa_description or "",
                "business_context": metadata.business_context or "",
                "created_at": metadata.created_at.isoformat() if metadata.created_at else "",
                "updated_at": metadata.updated_at.isoformat() if metadata.updated_at else "",
                "last_synced_at": metadata.last_synced_at.isoformat() if metadata.last_synced_at else "",
                "qa_settings": {
                    "is_enabled_for_qa": metadata.is_enabled_for_qa,
                    "qa_description": metadata.qa_description or "",
                    "business_context": metadata.business_context or ""
                }
            }
            return {"success": True, "data": data}
        else:
            return {"success": False, "data": None, "message": "Table metadata not found"}

    except Exception as e:
        logger.error(f"Failed to get table metadata: {str(e)}")
        return {
            "success": False,
            "message": str(e)
        }


@router.put("/{metadata_id}/qa-settings")
async def update_qa_settings(
    metadata_id: int,
    settings: QASettingsUpdate,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Update a table's Q&A settings."""
    try:
        service = TableMetadataService(db)
        success = service.update_table_qa_settings(
            current_user.id,
            metadata_id,
            settings.dict()
        )

        if success:
            return {"success": True, "message": "Settings updated successfully"}
        else:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Table metadata not found"
            )
    except HTTPException:
        # Re-raise the 404 above instead of rewrapping it as a 400 below
        raise
    except Exception as e:
        logger.error(f"Failed to update Q&A settings: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )


class TableSaveRequest(BaseModel):
    database_config_id: int = Field(..., description="Database configuration ID")
    table_names: List[str] = Field(..., description="List of table names to save")


@router.post("/save")
async def save_table_metadata(
    request: TableSaveRequest,
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Save the metadata configuration for the selected tables."""
    try:
        service = TableMetadataService(db)
        result = await service.save_table_metadata_config(
            user_id=current_user.id,
            database_config_id=request.database_config_id,
            table_names=request.table_names
        )

        logger.info(f"User {current_user.id} saved configuration for {len(request.table_names)} tables")

        return {
            "success": True,
            "message": f"Successfully saved configuration for {len(result['saved_tables'])} tables",
            "saved_tables": result['saved_tables'],
            "failed_tables": result.get('failed_tables', [])
        }

    except Exception as e:
        logger.error(f"Failed to save table metadata configuration: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to save configuration: {str(e)}"
        )
@ -0,0 +1,241 @@
"""User management endpoints."""

from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session

from ...db.database import get_db
from ...core.simple_permissions import require_super_admin
from ...services.auth import AuthService
from ...services.user import UserService
from ...schemas.user import UserResponse, UserUpdate, UserCreate, ChangePasswordRequest, ResetPasswordRequest

router = APIRouter()


@router.get("/profile", response_model=UserResponse)
async def get_user_profile(
    current_user = Depends(AuthService.get_current_user)
):
    """Get current user profile."""
    return UserResponse.from_orm(current_user)


@router.put("/profile", response_model=UserResponse)
async def update_user_profile(
    user_update: UserUpdate,
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Update current user profile."""
    user_service = UserService(db)

    # Check if email is being changed and is already taken
    if user_update.email and user_update.email != current_user.email:
        existing_user = user_service.get_user_by_email(user_update.email)
        if existing_user:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered"
            )

    # Update user
    updated_user = user_service.update_user(current_user.id, user_update)
    return UserResponse.from_orm(updated_user)


@router.delete("/profile")
async def delete_user_account(
    current_user = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
):
    """Delete current user account."""
    user_service = UserService(db)
    user_service.delete_user(current_user.id)
    return {"message": "Account deleted successfully"}


# Admin endpoints
@router.post("/", response_model=UserResponse)
async def create_user(
    user_create: UserCreate,
    # current_user = Depends(require_superuser),
    db: Session = Depends(get_db)
):
    """Create a new user (admin only)."""
    user_service = UserService(db)

    # Check if username already exists
    existing_user = user_service.get_user_by_username(user_create.username)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    # Check if email already exists
    existing_user = user_service.get_user_by_email(user_create.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Create user
    new_user = user_service.create_user(user_create)
    return UserResponse.from_orm(new_user)


@router.get("/")
async def list_users(
    page: int = Query(1, ge=1),
    size: int = Query(20, ge=1, le=100),
    search: Optional[str] = Query(None),
    role_id: Optional[int] = Query(None),
    is_active: Optional[bool] = Query(None),
    # current_user = Depends(require_superuser),
    db: Session = Depends(get_db)
):
    """List all users with pagination and filters (admin only)."""
    user_service = UserService(db)
    skip = (page - 1) * size
    users, total = user_service.get_users_with_filters(
        skip=skip,
        limit=size,
        search=search,
        role_id=role_id,
        is_active=is_active
    )
    result = {
        "users": [UserResponse.from_orm(user) for user in users],
        "total": total,
        "page": page,
        "page_size": size
    }
    return result
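`list_users` converts the 1-based `page`/`size` query parameters into an offset/limit pair with `skip = (page - 1) * size`; the workflow list endpoint later inverts the same arithmetic with `skip // limit + 1`. A stdlib sketch of both directions:

```python
def page_to_skip(page: int, size: int) -> int:
    # 1-based page number -> row offset, as in list_users
    return (page - 1) * size

def skip_to_page(skip: int, limit: int) -> int:
    # row offset -> 1-based page number, as in list_workflows
    return skip // limit + 1

print(page_to_skip(3, 20))   # → 40
print(skip_to_page(40, 20))  # → 3
```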


@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    current_user = Depends(AuthService.get_current_active_user),
    db: Session = Depends(get_db)
):
    """Get user by ID (admin only)."""
    user_service = UserService(db)
    user = user_service.get_user(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    return UserResponse.from_orm(user)


@router.put("/change-password")
async def change_password(
    request: ChangePasswordRequest,
    current_user = Depends(AuthService.get_current_active_user),
    db: Session = Depends(get_db)
):
    """Change current user's password."""
    user_service = UserService(db)

    try:
        user_service.change_password(
            user_id=current_user.id,
            current_password=request.current_password,
            new_password=request.new_password
        )
        return {"message": "Password changed successfully"}
    except Exception as e:
        if "Current password is incorrect" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect"
            )
        elif "must be at least 6 characters" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="New password must be at least 6 characters long"
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to change password"
            )


@router.put("/{user_id}/reset-password")
async def reset_user_password(
    user_id: int,
    request: ResetPasswordRequest,
    current_user = Depends(require_super_admin),
    db: Session = Depends(get_db)
):
    """Reset user password (admin only)."""
    user_service = UserService(db)

    try:
        user_service.reset_password(
            user_id=user_id,
            new_password=request.new_password
        )
        return {"message": "Password reset successfully"}
    except Exception as e:
        if "User not found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )
        elif "must be at least 6 characters" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="New password must be at least 6 characters long"
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to reset password"
            )


@router.put("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    user_update: UserUpdate,
    current_user = Depends(AuthService.get_current_active_user),
    db: Session = Depends(get_db)
):
    """Update user by ID (admin only)."""
    user_service = UserService(db)

    user = user_service.get_user(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    updated_user = user_service.update_user(user_id, user_update)
    return UserResponse.from_orm(updated_user)


@router.delete("/{user_id}")
async def delete_user(
    user_id: int,
    current_user = Depends(AuthService.get_current_active_user),
    db: Session = Depends(get_db)
):
    """Delete user by ID (admin only)."""
    user_service = UserService(db)

    user = user_service.get_user(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    user_service.delete_user(user_id)
    return {"message": "User deleted successfully"}
@ -0,0 +1,538 @@
"""Workflow management API."""

from typing import List, Optional, AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import and_
import json
from datetime import datetime

from ...db.database import get_db
from ...schemas.workflow import (
    WorkflowCreate, WorkflowUpdate, WorkflowResponse, WorkflowListResponse,
    WorkflowExecuteRequest, WorkflowExecutionResponse, NodeExecutionResponse, WorkflowStatus
)
from ...models.workflow import WorkflowStatus as ModelWorkflowStatus
from ...services.workflow_engine import get_workflow_engine
from ...services.auth import AuthService
from ...models.user import User
from ...utils.logger import get_logger

logger = get_logger("workflow_api")

router = APIRouter()


def convert_workflow_for_response(workflow_dict):
    """Convert workflow data to fit the response model."""
    if workflow_dict.get('definition') and workflow_dict['definition'].get('connections'):
        for conn in workflow_dict['definition']['connections']:
            if 'from_node' in conn:
                conn['from'] = conn.pop('from_node')
            if 'to_node' in conn:
                conn['to'] = conn.pop('to_node')
    return workflow_dict
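`convert_workflow_for_response` renames the stored `from_node`/`to_node` connection keys to the `from`/`to` aliases the response schema expects, mutating the dict in place. A self-contained sketch of the same transformation on a sample payload (the node names are illustrative):

```python
def convert_workflow_for_response(workflow_dict):
    # Same key-renaming pass as the API helper
    if workflow_dict.get('definition') and workflow_dict['definition'].get('connections'):
        for conn in workflow_dict['definition']['connections']:
            if 'from_node' in conn:
                conn['from'] = conn.pop('from_node')
            if 'to_node' in conn:
                conn['to'] = conn.pop('to_node')
    return workflow_dict

wf = {"definition": {"connections": [{"from_node": "start", "to_node": "llm_1"}]}}
print(convert_workflow_for_response(wf))
# → {'definition': {'connections': [{'from': 'start', 'to': 'llm_1'}]}}
```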


@router.post("/", response_model=WorkflowResponse)
async def create_workflow(
    workflow_data: WorkflowCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Create a workflow."""
    try:
        from ...models.workflow import Workflow

        # Create the workflow
        workflow = Workflow(
            name=workflow_data.name,
            description=workflow_data.description,
            definition=workflow_data.definition.dict(),
            version="1.0.0",
            status=workflow_data.status,
            owner_id=current_user.id
        )
        workflow.set_audit_fields(current_user.id)

        db.add(workflow)
        db.commit()
        db.refresh(workflow)

        # Map definition field names for the response model
        workflow_dict = convert_workflow_for_response(workflow.to_dict())

        logger.info(f"Created workflow: {workflow.name} by user {current_user.username}")
        return WorkflowResponse(**workflow_dict)

    except Exception as e:
        db.rollback()
        logger.error(f"Error creating workflow: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to create workflow"
        )


@router.get("/", response_model=WorkflowListResponse)
async def list_workflows(
    skip: Optional[int] = Query(None, ge=0),
    limit: Optional[int] = Query(None, ge=1, le=100),
    workflow_status: Optional[WorkflowStatus] = None,
    search: Optional[str] = Query(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Get the workflow list."""
    try:
        from ...models.workflow import Workflow

        # Build the query
        query = db.query(Workflow).filter(Workflow.owner_id == current_user.id)

        if workflow_status:
            query = query.filter(Workflow.status == workflow_status)

        # Optional name search
        if search:
            query = query.filter(Workflow.name.ilike(f"%{search}%"))

        # Total count
        total = query.count()

        # Without pagination parameters, return everything
        if skip is None and limit is None:
            workflows = query.all()
            return WorkflowListResponse(
                workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
                total=total,
                page=1,
                size=total
            )

        # Fall back to default pagination parameters
        if skip is None:
            skip = 0
        if limit is None:
            limit = 10

        # Paged query
        workflows = query.offset(skip).limit(limit).all()

        return WorkflowListResponse(
            workflows=[WorkflowResponse(**convert_workflow_for_response(w.to_dict())) for w in workflows],
            total=total,
            page=skip // limit + 1,  # derive the page number
            size=limit
        )

    except Exception as e:
        logger.error(f"Error listing workflows: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to list workflows"
        )


@router.get("/{workflow_id}", response_model=WorkflowResponse)
async def get_workflow(
    workflow_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Get workflow details."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get workflow"
        )


@router.put("/{workflow_id}", response_model=WorkflowResponse)
async def update_workflow(
    workflow_id: int,
    workflow_data: WorkflowUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Update a workflow."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        # Update fields
        update_data = workflow_data.dict(exclude_unset=True)
        for field, value in update_data.items():
            if field == "definition" and value:
                # If value is a Pydantic model, convert it to a dict; if it is already a dict, use it directly
                if hasattr(value, 'dict'):
                    setattr(workflow, field, value.dict())
                else:
                    setattr(workflow, field, value)
            else:
                setattr(workflow, field, value)

        workflow.set_audit_fields(current_user.id, is_update=True)

        db.commit()
        db.refresh(workflow)

        logger.info(f"Updated workflow: {workflow.name} by user {current_user.username}")
        return WorkflowResponse(**convert_workflow_for_response(workflow.to_dict()))

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error updating workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update workflow"
        )


@router.delete("/{workflow_id}")
async def delete_workflow(
    workflow_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Delete a workflow."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        db.delete(workflow)
        db.commit()

        logger.info(f"Deleted workflow: {workflow.name} by user {current_user.username}")
        return {"message": "Workflow deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error deleting workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete workflow"
        )


@router.post("/{workflow_id}/activate")
async def activate_workflow(
    workflow_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Activate a workflow."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        workflow.status = ModelWorkflowStatus.PUBLISHED
        workflow.set_audit_fields(current_user.id, is_update=True)

        db.commit()

        logger.info(f"Activated workflow: {workflow.name} by user {current_user.username}")
        return {"message": "Workflow activated successfully"}

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error activating workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to activate workflow"
        )


@router.post("/{workflow_id}/deactivate")
async def deactivate_workflow(
    workflow_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Deactivate a workflow."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        workflow.status = ModelWorkflowStatus.ARCHIVED
        workflow.set_audit_fields(current_user.id, is_update=True)

        db.commit()

        logger.info(f"Deactivated workflow: {workflow.name} by user {current_user.username}")
        return {"message": "Workflow deactivated successfully"}

    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error deactivating workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to deactivate workflow"
        )


@router.post("/{workflow_id}/execute", response_model=WorkflowExecutionResponse)
async def execute_workflow(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Execute a workflow."""
    try:
        from ...models.workflow import Workflow

        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        if workflow.status != ModelWorkflowStatus.PUBLISHED:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Workflow is not active and cannot be executed"
            )

        # Get the workflow engine and execute
        engine = get_workflow_engine()
        execution_result = await engine.execute_workflow(
            workflow=workflow,
            input_data=request.input_data,
            user_id=current_user.id,
            db=db
        )

        logger.info(f"Executed workflow: {workflow.name} by user {current_user.username}")
        return execution_result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error executing workflow {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to execute workflow: {str(e)}"
        )


@router.get("/{workflow_id}/executions", response_model=List[WorkflowExecutionResponse])
async def list_workflow_executions(
    workflow_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Get the workflow execution history."""
    try:
        from ...models.workflow import Workflow, WorkflowExecution

        # Verify workflow ownership
        workflow = db.query(Workflow).filter(
            and_(
                Workflow.id == workflow_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not workflow:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Workflow not found"
            )

        # Fetch the execution history
        executions = db.query(WorkflowExecution).filter(
            WorkflowExecution.workflow_id == workflow_id
        ).order_by(WorkflowExecution.created_at.desc()).offset(skip).limit(limit).all()

        return [WorkflowExecutionResponse.from_orm(execution) for execution in executions]

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing workflow executions {workflow_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get execution history"
        )


@router.get("/executions/{execution_id}", response_model=WorkflowExecutionResponse)
async def get_workflow_execution(
    execution_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Get workflow execution details."""
    try:
        from ...models.workflow import WorkflowExecution, Workflow

        execution = db.query(WorkflowExecution).join(
            Workflow, WorkflowExecution.workflow_id == Workflow.id
        ).filter(
            and_(
                WorkflowExecution.id == execution_id,
                Workflow.owner_id == current_user.id
            )
        ).first()

        if not execution:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Execution record not found"
            )

        return WorkflowExecutionResponse.from_orm(execution)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting workflow execution {execution_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get execution details"
        )


@router.post("/{workflow_id}/execute-stream")
async def execute_workflow_stream(
    workflow_id: int,
    request: WorkflowExecuteRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(AuthService.get_current_user)
):
    """Execute a workflow as a stream, pushing node execution status in real time."""

    async def generate_stream() -> AsyncGenerator[str, None]:
        workflow_engine = None

        try:
            from ...models.workflow import Workflow

            # Validate the workflow
            workflow = db.query(Workflow).filter(
                and_(
                    Workflow.id == workflow_id,
                    Workflow.owner_id == current_user.id
                )
            ).first()

            if not workflow:
                yield f"data: {json.dumps({'type': 'error', 'message': 'Workflow not found'}, ensure_ascii=False)}\n\n"
                return

            if workflow.status != ModelWorkflowStatus.PUBLISHED:
                yield f"data: {json.dumps({'type': 'error', 'message': 'Workflow is not active and cannot be executed'}, ensure_ascii=False)}\n\n"
                return

            # Send the start signal
            yield f"data: {json.dumps({'type': 'workflow_start', 'workflow_id': workflow_id, 'workflow_name': workflow.name, 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

            # Get the workflow engine
            workflow_engine = get_workflow_engine()

            # Execute the workflow (streaming variant)
            async for step_data in workflow_engine.execute_workflow_stream(
                workflow=workflow,
                input_data=request.input_data,
                user_id=current_user.id,
                db=db
            ):
                # Push each workflow step
                yield f"data: {json.dumps(step_data, ensure_ascii=False)}\n\n"

            # Send the completion signal
            yield f"data: {json.dumps({'type': 'workflow_complete', 'message': 'Workflow execution completed', 'timestamp': datetime.now().isoformat()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            logger.error(f"Streaming workflow execution failed: {e}", exc_info=True)
            yield f"data: {json.dumps({'type': 'error', 'message': f'Workflow execution failed: {str(e)}'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_stream(),
|
||||||
|
media_type="text/plain",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Access-Control-Allow-Origin": "*",
|
||||||
|
"Access-Control-Allow-Headers": "*",
|
||||||
|
"Access-Control-Allow-Methods": "*"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
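The streaming endpoint above emits Server-Sent-Events-style `data: <json>` frames separated by blank lines. A minimal client-side parser for that framing can be sketched as follows (the payloads shown are illustrative, matching the endpoint's `workflow_start`/`workflow_complete` message types):

```python
import json

def parse_sse_frames(raw: str):
    """Parse SSE-style 'data: <json>\\n\\n' frames into a list of dicts."""
    events = []
    for frame in raw.split("\n\n"):
        frame = frame.strip()
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))
    return events

stream = (
    'data: {"type": "workflow_start", "workflow_id": 1}\n\n'
    'data: {"type": "workflow_complete", "message": "done"}\n\n'
)
print([e["type"] for e in parse_sse_frames(stream)])
# → ['workflow_start', 'workflow_complete']
```

A real client would read the HTTP response incrementally and apply the same splitting per chunk, buffering partial frames between reads.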
@@ -0,0 +1,101 @@
"""Main API router."""

from fastapi import APIRouter

from .endpoints import chat
from .endpoints import auth
from .endpoints import knowledge_base
from .endpoints import smart_query
from .endpoints import smart_chat
from .endpoints import database_config
from .endpoints import table_metadata

# System management endpoints
from .endpoints import roles
from .endpoints import llm_configs
from .endpoints import users

# Workflow endpoints
from .endpoints import workflow


# Create main API router
router = APIRouter()

router.include_router(
    auth.router,
    prefix="/auth",
    tags=["authentication"]
)

# Include sub-routers
router.include_router(
    chat.router,
    prefix="/chat",
    tags=["chat"]
)

router.include_router(
    knowledge_base.router,
    prefix="/knowledge-bases",
    tags=["knowledge-bases"]
)

router.include_router(
    smart_query.router,
    tags=["smart-query"]
)

router.include_router(
    smart_chat.router,
    tags=["smart-chat"]
)

router.include_router(
    database_config.router,
    tags=["database-config"]
)

router.include_router(
    table_metadata.router,
    tags=["table-metadata"]
)

# System management routers
router.include_router(
    roles.router,
    prefix="/admin",
    tags=["admin-roles"]
)

router.include_router(
    llm_configs.router,
    prefix="/admin",
    tags=["admin-llm-configs"]
)

router.include_router(
    users.router,
    prefix="/users",
    tags=["users"]
)

router.include_router(
    workflow.router,
    prefix="/workflows",
    tags=["workflows"]
)

# Test endpoint
@router.get("/test")
async def test_endpoint():
    return {"message": "API test is working"}

@@ -0,0 +1 @@
"""Core module for TH-Agenter."""

Binary file not shown.
@@ -0,0 +1,177 @@
"""FastAPI application factory."""

import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException

from .config import Settings
from .logging import setup_logging
from .middleware import UserContextMiddleware
from ..api.routes import router
from ..db.database import init_db
from ..api.endpoints import table_metadata
from ..api.endpoints import database_config


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager."""
    # Startup
    logging.info("Starting up TH-Agenter application...")
    await init_db()
    logging.info("Database initialized")

    yield

    # Shutdown
    logging.info("Shutting down TH-Agenter application...")


def create_app(settings: Settings = None) -> FastAPI:
    """Create and configure FastAPI application."""
    if settings is None:
        from .config import get_settings
        settings = get_settings()

    # Setup logging
    setup_logging(settings.logging)

    # Create FastAPI app
    app = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description="A modern chat agent application with Vue frontend and FastAPI backend",
        debug=settings.debug,
        lifespan=lifespan,
    )

    # Add middleware
    setup_middleware(app, settings)

    # Add exception handlers
    setup_exception_handlers(app)

    # Include routers
    app.include_router(router, prefix="/api")
    app.include_router(table_metadata.router)
    app.include_router(database_config.router)

    # Health check endpoint
    @app.get("/health")
    async def health_check():
        return {"status": "healthy", "version": settings.app_version}

    # Root endpoint
    @app.get("/")
    async def root():
        return {"message": "Chat Agent API is running"}

    # Test endpoint
    @app.get("/test")
    async def test_endpoint():
        return {"message": "API is working"}

    return app


def setup_middleware(app: FastAPI, settings: Settings) -> None:
    """Setup application middleware."""

    # User context middleware (should be first to set context for all requests)
    app.add_middleware(UserContextMiddleware)

    # CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors.allowed_origins,
        allow_credentials=True,
        allow_methods=settings.cors.allowed_methods,
        allow_headers=settings.cors.allowed_headers,
    )

    # Trusted host middleware (for production)
    if settings.environment == "production":
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["*"]  # Configure this properly in production
        )


def setup_exception_handlers(app: FastAPI) -> None:
    """Setup global exception handlers."""

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "error": {
                    "type": "http_error",
                    "message": exc.detail,
                    "status_code": exc.status_code
                }
            }
        )

    def make_json_serializable(obj):
        """Recursively convert an object into a JSON-serializable structure."""
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        elif isinstance(obj, bytes):
            return obj.decode('utf-8')
        elif isinstance(obj, Exception):
            return str(obj)
        elif isinstance(obj, dict):
            return {k: make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_json_serializable(item) for item in obj]
        else:
            # For any other object, convert to string
            return str(obj)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Convert any non-serializable objects to strings in error details
        try:
            errors = make_json_serializable(exc.errors())
        except Exception as e:
            # Fallback: if even our conversion fails, use a simple error message
            errors = [{"type": "serialization_error", "msg": f"Error processing validation details: {str(e)}"}]

        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "type": "validation_error",
                    "message": "Request validation failed",
                    "details": errors
                }
            }
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):
        logging.error(f"Unhandled exception: {exc}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "type": "internal_error",
                    "message": "Internal server error"
                }
            }
        )


# Create the app instance
app = create_app()

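The `make_json_serializable` helper above is defined inside `setup_exception_handlers`; the same recursive conversion can be exercised standalone (it is re-declared here for illustration):

```python
import json

def make_json_serializable(obj):
    """Recursively convert an object into a JSON-serializable structure."""
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, bytes):
        return obj.decode('utf-8')
    if isinstance(obj, Exception):
        return str(obj)
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_json_serializable(item) for item in obj]
    return str(obj)

# Mixed payload: bytes and exceptions become strings, containers recurse.
payload = {"msg": b"hello", "errors": [ValueError("bad input")], "count": 3}
print(json.dumps(make_json_serializable(payload)))
# → {"msg": "hello", "errors": ["bad input"], "count": 3}
```

This matters for `RequestValidationError.errors()`, which since Pydantic v2 can carry raw input objects (including bytes or exception instances) that `JSONResponse` cannot serialize directly.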
@@ -0,0 +1,482 @@
"""Configuration management for TH-Agenter."""

import os
import yaml
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from functools import lru_cache


class DatabaseSettings(BaseSettings):
    """Database configuration."""
    url: str = Field(..., alias="database_url")  # Must be provided via environment variable
    echo: bool = Field(default=False)
    pool_size: int = Field(default=5)
    max_overflow: int = Field(default=10)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class SecuritySettings(BaseSettings):
    """Security configuration."""
    secret_key: str = Field(default="your-secret-key-here-change-in-production")
    algorithm: str = Field(default="HS256")
    access_token_expire_minutes: int = Field(default=300)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class ToolSettings(BaseSettings):
    """External tool configuration."""
    # Tavily search configuration
    tavily_api_key: Optional[str] = Field(default=None)
    weather_api_key: Optional[str] = Field(default=None)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class LLMSettings(BaseSettings):
    """LLM configuration - supports multiple OpenAI-protocol-compatible providers."""
    provider: str = Field(default="openai", alias="llm_provider")  # openai, deepseek, doubao, zhipu, moonshot

    # OpenAI configuration
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_model: str = Field(default="gpt-3.5-turbo")

    # DeepSeek configuration
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_model: str = Field(default="deepseek-chat")

    # Doubao configuration
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_model: str = Field(default="doubao-lite-4k")

    # Zhipu AI configuration
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_model: str = Field(default="glm-4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot configuration
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_model: str = Field(default="moonshot-v1-8k")

    # Common configuration
    max_tokens: int = Field(default=2048)
    temperature: float = Field(default=0.7)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Get the configuration for the selected provider, preferring the default config stored in the database."""
        try:
            # Try to read the default chat model config from the database
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_chat_config()

            if db_config:
                # A default config exists in the database, so use it
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name,
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature
                }
                return config
        except Exception as e:
            # If the database read fails, log the error and fall back to environment variables
            import logging
            logging.warning(f"Failed to read LLM config from database, falling back to env vars: {e}")

        # Fall back to the environment-variable-based configuration
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_model
            }
        }

        config = provider_configs.get(self.provider, provider_configs["openai"])
        config.update({
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        })
        return config


class EmbeddingSettings(BaseSettings):
    """Embedding model configuration - supports multiple providers."""
    provider: str = Field(default="zhipu", alias="embedding_provider")  # openai, deepseek, doubao, zhipu, moonshot

    # OpenAI configuration
    openai_api_key: Optional[str] = Field(default=None)
    openai_base_url: str = Field(default="https://api.openai.com/v1")
    openai_embedding_model: str = Field(default="text-embedding-ada-002")

    # DeepSeek configuration
    deepseek_api_key: Optional[str] = Field(default=None)
    deepseek_base_url: str = Field(default="https://api.deepseek.com/v1")
    deepseek_embedding_model: str = Field(default="deepseek-embedding")

    # Doubao configuration
    doubao_api_key: Optional[str] = Field(default=None)
    doubao_base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")
    doubao_embedding_model: str = Field(default="doubao-embedding")

    # Zhipu AI configuration
    zhipu_api_key: Optional[str] = Field(default=None)
    zhipu_base_url: str = Field(default="https://open.bigmodel.cn/api/paas/v4")
    zhipu_embedding_model: str = Field(default="embedding-3")

    # Moonshot configuration
    moonshot_api_key: Optional[str] = Field(default=None)
    moonshot_base_url: str = Field(default="https://api.moonshot.cn/v1")
    moonshot_embedding_model: str = Field(default="moonshot-embedding")

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    def get_current_config(self) -> dict:
        """Get the configuration for the selected embedding provider, preferring the default config stored in the database."""
        try:
            # Try to read the default embedding model config from the database
            from th_agenter.services.llm_config_service import LLMConfigService
            llm_service = LLMConfigService()
            db_config = llm_service.get_default_embedding_config()

            if db_config:
                # A default config exists in the database, so use it
                config = {
                    "api_key": db_config.api_key,
                    "base_url": db_config.base_url,
                    "model": db_config.model_name
                }
                return config
        except Exception as e:
            # If the database read fails, log the error and fall back to environment variables
            import logging
            logging.warning(f"Failed to read embedding config from database, falling back to env vars: {e}")

        # Fall back to the environment-variable-based configuration
        provider_configs = {
            "openai": {
                "api_key": self.openai_api_key,
                "base_url": self.openai_base_url,
                "model": self.openai_embedding_model
            },
            "deepseek": {
                "api_key": self.deepseek_api_key,
                "base_url": self.deepseek_base_url,
                "model": self.deepseek_embedding_model
            },
            "doubao": {
                "api_key": self.doubao_api_key,
                "base_url": self.doubao_base_url,
                "model": self.doubao_embedding_model
            },
            "zhipu": {
                "api_key": self.zhipu_api_key,
                "base_url": self.zhipu_base_url,
                "model": self.zhipu_embedding_model
            },
            "moonshot": {
                "api_key": self.moonshot_api_key,
                "base_url": self.moonshot_base_url,
                "model": self.moonshot_embedding_model
            }
        }

        return provider_configs.get(self.provider, provider_configs["zhipu"])


class VectorDBSettings(BaseSettings):
    """Vector database configuration."""
    type: str = Field(default="pgvector", alias="vector_db_type")
    persist_directory: str = Field(default="./data/chroma")
    collection_name: str = Field(default="documents")
    embedding_dimension: int = Field(default=2048)  # Dimension of the Zhipu AI embedding-3 model

    # PostgreSQL pgvector configuration
    pgvector_host: str = Field(default="localhost")
    pgvector_port: int = Field(default=5432)
    pgvector_database: str = Field(default="vectordb")
    pgvector_user: str = Field(default="postgres")
    pgvector_password: str = Field(default="")
    pgvector_table_name: str = Field(default="embeddings")
    pgvector_vector_dimension: int = Field(default=1024)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class FileSettings(BaseSettings):
    """File processing configuration."""
    upload_dir: str = Field(default="./data/uploads")
    max_size: int = Field(default=10485760)  # 10MB
    allowed_extensions: Union[str, List[str]] = Field(default=[".txt", ".pdf", ".docx", ".md"])
    chunk_size: int = Field(default=1000)
    chunk_overlap: int = Field(default=200)
    semantic_splitter_enabled: bool = Field(default=False)  # Whether the semantic splitter is enabled

    @field_validator('allowed_extensions', mode='before')
    @classmethod
    def parse_allowed_extensions(cls, v):
        """Parse comma-separated string to list of extensions."""
        if isinstance(v, str):
            # Split by comma and add dots if not present
            extensions = [ext.strip() for ext in v.split(',')]
            return [ext if ext.startswith('.') else f'.{ext}' for ext in extensions]
        elif isinstance(v, list):
            # Ensure all extensions start with dot
            return [ext if ext.startswith('.') else f'.{ext}' for ext in v]
        return v

    def get_allowed_extensions_list(self) -> List[str]:
        """Get allowed extensions as a list."""
        if isinstance(self.allowed_extensions, list):
            return self.allowed_extensions
        elif isinstance(self.allowed_extensions, str):
            # Split by comma and add dots if not present
            extensions = [ext.strip() for ext in self.allowed_extensions.split(',')]
            return [ext if ext.startswith('.') else f'.{ext}' for ext in extensions]
        return []

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class StorageSettings(BaseSettings):
    """Storage configuration."""
    storage_type: str = Field(default="local")  # local or s3
    upload_directory: str = Field(default="./data/uploads")

    # S3 settings
    s3_bucket_name: str = Field(default="chat-agent-files")
    aws_access_key_id: Optional[str] = Field(default=None)
    aws_secret_access_key: Optional[str] = Field(default=None)
    aws_region: str = Field(default="us-east-1")
    s3_endpoint_url: Optional[str] = Field(default=None)  # For S3-compatible services

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class LoggingSettings(BaseSettings):
    """Logging configuration."""
    level: str = Field(default="INFO")
    file: str = Field(default="./data/logs/app.log")
    format: str = Field(default="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    max_bytes: int = Field(default=10485760)  # 10MB
    backup_count: int = Field(default=5)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class CORSSettings(BaseSettings):
    """CORS configuration."""
    allowed_origins: List[str] = Field(default=["*"])
    allowed_methods: List[str] = Field(default=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
    allowed_headers: List[str] = Field(default=["*"])

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }


class ChatSettings(BaseSettings):
    """Chat configuration."""
    max_history_length: int = Field(default=10)
    system_prompt: str = Field(default="你是一个有用的AI助手,请根据提供的上下文信息回答用户的问题。")
    max_response_tokens: int = Field(default=1000)


class Settings(BaseSettings):
    """Main application settings."""

    # App info
    app_name: str = Field(default="TH-Agenter")
    app_version: str = Field(default="0.1.0")
    debug: bool = Field(default=True)
    environment: str = Field(default="development")

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    # Configuration sections
    database: DatabaseSettings = Field(default_factory=DatabaseSettings)
    security: SecuritySettings = Field(default_factory=SecuritySettings)
    llm: LLMSettings = Field(default_factory=LLMSettings)
    embedding: EmbeddingSettings = Field(default_factory=EmbeddingSettings)
    vector_db: VectorDBSettings = Field(default_factory=VectorDBSettings)
    file: FileSettings = Field(default_factory=FileSettings)
    storage: StorageSettings = Field(default_factory=StorageSettings)
    logging: LoggingSettings = Field(default_factory=LoggingSettings)
    cors: CORSSettings = Field(default_factory=CORSSettings)
    chat: ChatSettings = Field(default_factory=ChatSettings)
    tool: ToolSettings = Field(default_factory=ToolSettings)

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore"
    }

    @classmethod
    def load_from_yaml(cls, config_path: str = "../configs/settings.yaml") -> "Settings":
        """Load settings from YAML file."""
        config_file = Path(config_path)

        # If the config file does not exist, look for it relative to the backend directory
        if not config_file.exists():
            # Directory containing this file (backend/th_agenter/core)
            current_dir = Path(__file__).parent
            # Go up two levels to the backend directory, then look for configs/settings.yaml
            backend_config_path = current_dir.parent.parent / "configs" / "settings.yaml"
            if backend_config_path.exists():
                config_file = backend_config_path
            else:
                return cls()

        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # Resolve environment variable substitutions
        config_data = cls._resolve_env_vars_nested(config_data)

        # Create an instance for each sub-settings class so environment variables load correctly.
        # If the YAML has no matching section, default BaseSettings loading applies
        # (which reads the .env file automatically).
        settings_kwargs = {}

        # Handle each sub-settings section explicitly, so .env values are still picked up
        # in cases (such as debug runs) where the environment differs
        settings_kwargs['database'] = DatabaseSettings(**(config_data.get('database', {})))
        settings_kwargs['security'] = SecuritySettings(**(config_data.get('security', {})))
        settings_kwargs['llm'] = LLMSettings(**(config_data.get('llm', {})))
        settings_kwargs['embedding'] = EmbeddingSettings(**(config_data.get('embedding', {})))
        settings_kwargs['vector_db'] = VectorDBSettings(**(config_data.get('vector_db', {})))
        settings_kwargs['file'] = FileSettings(**(config_data.get('file', {})))
        settings_kwargs['storage'] = StorageSettings(**(config_data.get('storage', {})))
        settings_kwargs['logging'] = LoggingSettings(**(config_data.get('logging', {})))
        settings_kwargs['cors'] = CORSSettings(**(config_data.get('cors', {})))
        settings_kwargs['chat'] = ChatSettings(**(config_data.get('chat', {})))
        settings_kwargs['tool'] = ToolSettings(**(config_data.get('tool', {})))

        # Add top-level configuration values
        for key, value in config_data.items():
            if key not in settings_kwargs:
                settings_kwargs[key] = value

        return cls(**settings_kwargs)

    @staticmethod
    def _flatten_config(config: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        """Flatten nested configuration dictionary."""
        flat = {}
        for key, value in config.items():
            new_key = f"{prefix}_{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(Settings._flatten_config(value, new_key))
            else:
                flat[new_key] = value
        return flat

    @staticmethod
    def _resolve_env_vars_nested(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve environment variables in nested configuration."""
        if isinstance(config, dict):
            return {key: Settings._resolve_env_vars_nested(value) for key, value in config.items()}
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            env_var = config[2:-1]
            return os.getenv(env_var, config)
        else:
            return config

    @staticmethod
    def _resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve environment variables in configuration values."""
        resolved = {}
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                env_var = value[2:-1]
                resolved[key] = os.getenv(env_var, value)
            else:
                resolved[key] = value
        return resolved


@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance."""
    return Settings.load_from_yaml()


# Global settings instance
settings = get_settings()

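The `${VAR}` placeholder resolution used by `load_from_yaml` can be exercised in isolation (the resolver is re-declared standalone here for illustration):

```python
import os

def resolve_env_vars_nested(config):
    """Recursively replace string values of the form "${VAR}" with
    os.environ["VAR"], leaving unresolved placeholders intact."""
    if isinstance(config, dict):
        return {k: resolve_env_vars_nested(v) for k, v in config.items()}
    if isinstance(config, str) and config.startswith("${") and config.endswith("}"):
        return os.getenv(config[2:-1], config)
    return config

os.environ["DB_PASSWORD"] = "s3cret"
raw = {
    "database": {
        "url": "postgresql://myuser:${DB_PASSWORD}@localhost/db",
        "password": "${DB_PASSWORD}",
    }
}
resolved = resolve_env_vars_nested(raw)
print(resolved["database"]["password"])  # → s3cret
```

Note that only values that are *entirely* a placeholder get substituted: in the example above the `url` keeps the embedded `${DB_PASSWORD}` literally, which is a limitation of this resolver worth keeping in mind when writing `settings.yaml`.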
@@ -0,0 +1,120 @@
"""
HTTP request context management, e.g. accessing the currently logged-in user and token information.
"""

from contextvars import ContextVar
from typing import Optional
import threading
from ..models.user import User

# Context variable to store current user
current_user_context: ContextVar[Optional[User]] = ContextVar('current_user', default=None)

# Thread-local storage as backup
_thread_local = threading.local()


class UserContext:
    """User context manager for accessing current user globally."""

    @staticmethod
    def set_current_user(user: User) -> None:
        """Set current user in context."""
        import logging
        logging.info(f"Setting user in context: {user.username} (ID: {user.id})")

        # Set in ContextVar
        current_user_context.set(user)

        # Also set in thread-local as backup
        _thread_local.current_user = user

        # Verify it was set
        verify_user = current_user_context.get()
        logging.info(f"Verification - ContextVar user: {verify_user.username if verify_user else None}")

    @staticmethod
    def set_current_user_with_token(user: User):
        """Set current user in context and return token for cleanup."""
        import logging
        logging.info(f"Setting user in context with token: {user.username} (ID: {user.id})")

        # Set in ContextVar and get token
        token = current_user_context.set(user)

        # Also set in thread-local as backup
        _thread_local.current_user = user

        # Verify it was set
        verify_user = current_user_context.get()
        logging.info(f"Verification - ContextVar user: {verify_user.username if verify_user else None}")

        return token

    @staticmethod
    def reset_current_user_token(token):
        """Reset current user context using token."""
        import logging
        logging.info("Resetting user context using token")
|
|
||||||
|
# Reset ContextVar using token
|
||||||
|
current_user_context.reset(token)
|
||||||
|
|
||||||
|
# Clear thread-local as well
|
||||||
|
if hasattr(_thread_local, 'current_user'):
|
||||||
|
delattr(_thread_local, 'current_user')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_current_user() -> Optional[User]:
|
||||||
|
"""Get current user from context."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Try ContextVar first
|
||||||
|
user = current_user_context.get()
|
||||||
|
if user:
|
||||||
|
logging.debug(f"Got user from ContextVar: {user.username} (ID: {user.id})")
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Fallback to thread-local
|
||||||
|
user = getattr(_thread_local, 'current_user', None)
|
||||||
|
if user:
|
||||||
|
logging.debug(f"Got user from thread-local: {user.username} (ID: {user.id})")
|
||||||
|
return user
|
||||||
|
|
||||||
|
logging.debug("No user found in context (neither ContextVar nor thread-local)")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_current_user_id() -> Optional[int]:
|
||||||
|
"""Get current user ID from context."""
|
||||||
|
user = UserContext.get_current_user()
|
||||||
|
return user.id if user else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clear_current_user() -> None:
|
||||||
|
"""Clear current user from context."""
|
||||||
|
import logging
|
||||||
|
logging.info("Clearing user context")
|
||||||
|
|
||||||
|
current_user_context.set(None)
|
||||||
|
if hasattr(_thread_local, 'current_user'):
|
||||||
|
delattr(_thread_local, 'current_user')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def require_current_user() -> User:
|
||||||
|
"""Get current user from context, raise exception if not found."""
|
||||||
|
# Use the same logic as get_current_user to check both ContextVar and thread-local
|
||||||
|
user = UserContext.get_current_user()
|
||||||
|
if user is None:
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="No authenticated user in context"
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def require_current_user_id() -> int:
|
||||||
|
"""Get current user ID from context, raise exception if not found."""
|
||||||
|
user = UserContext.require_current_user()
|
||||||
|
return user.id
|
||||||
|
|
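The set/reset-token pattern used by `set_current_user_with_token` and `reset_current_user_token` is standard `contextvars` behavior; a minimal standalone sketch (with a string in place of the ORM `User`):

```python
from contextvars import ContextVar
from typing import Optional

current_user: ContextVar[Optional[str]] = ContextVar("current_user", default=None)

# set() returns a Token that restores the previous value on reset()
token = current_user.set("alice")
assert current_user.get() == "alice"

current_user.reset(token)  # back to the default
assert current_user.get() is None
```

Resetting with the token (rather than calling `set(None)`) restores whatever value was in place before the request began, which is why the middleware keeps the token around.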
@@ -0,0 +1,52 @@

"""Custom exceptions for the application."""

from typing import Any, Dict, Optional


class BaseCustomException(Exception):
    """Base custom exception class."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class NotFoundError(BaseCustomException):
    """Exception raised when a resource is not found."""
    pass


class ValidationError(BaseCustomException):
    """Exception raised when validation fails."""
    pass


class AuthenticationError(BaseCustomException):
    """Exception raised when authentication fails."""
    pass


class AuthorizationError(BaseCustomException):
    """Exception raised when authorization fails."""
    pass


class DatabaseError(BaseCustomException):
    """Exception raised when database operations fail."""
    pass


class ConfigurationError(BaseCustomException):
    """Exception raised when configuration is invalid."""
    pass


class ExternalServiceError(BaseCustomException):
    """Exception raised when external service calls fail."""
    pass


class BusinessLogicError(BaseCustomException):
    """Exception raised when business logic validation fails."""
    pass
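A quick illustration of the `details` payload these exceptions carry. The two classes are re-declared here so the snippet runs standalone; the `conversation_id` key is purely illustrative:

```python
from typing import Any, Dict, Optional

class BaseCustomException(Exception):
    # Mirror of the base class above: a message plus a structured details dict
    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)

class NotFoundError(BaseCustomException):
    pass

try:
    raise NotFoundError("Conversation not found", details={"conversation_id": 42})
except BaseCustomException as exc:
    caught = (exc.message, exc.details)

print(caught)  # ('Conversation not found', {'conversation_id': 42})
```

Catching `BaseCustomException` in one exception handler lets an API layer turn any of the subclasses into a structured error response.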
@@ -0,0 +1,47 @@

"""LLM factory for creating and managing LLM instances."""

from typing import Optional
from langchain_openai import ChatOpenAI
from .config import get_settings


def create_llm(model: Optional[str] = None, temperature: Optional[float] = None, streaming: bool = False) -> ChatOpenAI:
    """Create an LLM instance.

    Args:
        model: Optional model name. If omitted, the default model from the config file is used.
        temperature: Optional sampling temperature.
        streaming: Whether to enable streaming responses. Defaults to False.

    Returns:
        A ChatOpenAI instance.
    """
    settings = get_settings()
    llm_config = settings.llm.get_current_config()

    if model:
        # Pick the provider-specific configuration based on the model name prefix
        if model.startswith('deepseek'):
            llm_config['model'] = settings.llm.deepseek_model
            llm_config['api_key'] = settings.llm.deepseek_api_key
            llm_config['base_url'] = settings.llm.deepseek_base_url
        elif model.startswith('doubao'):
            llm_config['model'] = settings.llm.doubao_model
            llm_config['api_key'] = settings.llm.doubao_api_key
            llm_config['base_url'] = settings.llm.doubao_base_url
        elif model.startswith('glm'):
            llm_config['model'] = settings.llm.zhipu_model
            llm_config['api_key'] = settings.llm.zhipu_api_key
            llm_config['base_url'] = settings.llm.zhipu_base_url
        elif model.startswith('moonshot'):
            llm_config['model'] = settings.llm.moonshot_model
            llm_config['api_key'] = settings.llm.moonshot_api_key
            llm_config['base_url'] = settings.llm.moonshot_base_url

    return ChatOpenAI(
        model=llm_config['model'],
        api_key=llm_config['api_key'],
        base_url=llm_config['base_url'],
        temperature=temperature if temperature is not None else llm_config['temperature'],
        max_tokens=llm_config['max_tokens'],
        streaming=streaming
    )
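Calling `create_llm` needs live API keys, but the model-prefix routing it performs can be sketched on its own. The `pick_provider` helper below is illustrative (not part of the module); the prefix-to-provider mapping follows the settings fields above:

```python
def pick_provider(model: str) -> str:
    # Mirrors create_llm's prefix dispatch: each prefix selects a provider's
    # model/api_key/base_url block from settings; "glm" maps to the zhipu fields
    prefixes = {
        "deepseek": "deepseek",
        "doubao": "doubao",
        "glm": "zhipu",
        "moonshot": "moonshot",
    }
    for prefix, provider in prefixes.items():
        if model.startswith(prefix):
            return provider
    return "default"  # fall through to the config file's current default

print(pick_provider("glm-4"))        # zhipu
print(pick_provider("gpt-4o-mini"))  # default
```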
@@ -0,0 +1,64 @@

"""Logging configuration for TH-Agenter."""

import logging
import logging.handlers
from pathlib import Path
from typing import Optional

from .config import LoggingSettings


def setup_logging(logging_config: LoggingSettings) -> None:
    """Setup application logging."""

    # Use an absolute path so the log location is stable regardless of the working directory
    log_file_path = logging_config.file
    if not Path(log_file_path).is_absolute():
        # Resolve relative paths against the project root
        # (the project root is the parent of the backend directory)
        backend_dir = Path(__file__).parent.parent.parent
        log_file_path = str(backend_dir / log_file_path)

    # Create the logs directory if it doesn't exist
    log_file = Path(log_file_path)
    log_file.parent.mkdir(parents=True, exist_ok=True)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, logging_config.level.upper()))

    # Clear existing handlers
    root_logger.handlers.clear()

    # Create formatter
    formatter = logging.Formatter(logging_config.format)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    # File handler with rotation
    file_handler = logging.handlers.RotatingFileHandler(
        filename=log_file_path,
        maxBytes=logging_config.max_bytes,
        backupCount=logging_config.backup_count,
        encoding="utf-8"
    )
    file_handler.setLevel(getattr(logging, logging_config.level.upper()))
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)

    # Set specific logger levels
    logging.getLogger("uvicorn").setLevel(logging.INFO)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
    logging.getLogger("fastapi").setLevel(logging.INFO)
    logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)

    logging.info("Logging configured successfully")


def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Get a logger instance."""
    return logging.getLogger(name or __name__)
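A runnable sketch of the rotating-file setup used above, writing to a temporary directory (the format string and size limits are illustrative, not the project's configured values):

```python
import logging
import logging.handlers
import tempfile
from pathlib import Path

log_file = Path(tempfile.mkdtemp()) / "app.log"

# Same handler type and parameters as setup_logging's file handler
handler = logging.handlers.RotatingFileHandler(
    filename=log_file, maxBytes=1_048_576, backupCount=3, encoding="utf-8"
)
handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info("logging configured")
handler.flush()

print(log_file.read_text(encoding="utf-8").strip())
```

Once the file reaches `maxBytes`, the handler renames it to `app.log.1` (keeping up to `backupCount` old files) and starts a fresh `app.log`.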
@@ -0,0 +1,163 @@

"""
Middleware management, e.g. the context middleware that validates tokens.
"""

import logging
from typing import Callable, Optional

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from ..db.database import get_db_session
from ..services.auth import AuthService
from .context import UserContext


class UserContextMiddleware(BaseHTTPMiddleware):
    """Middleware to set user context for authenticated requests."""

    def __init__(self, app, exclude_paths: Optional[list] = None):
        super().__init__(app)
        # Paths that don't require authentication
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/api/auth/login",
            "/api/auth/register",
            "/api/auth/login-oauth",
            "/auth/login",
            "/auth/register",
            "/auth/login-oauth",
            "/health",
            "/test"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and set user context if authenticated."""
        logging.info(f"[MIDDLEWARE] Processing request: {request.method} {request.url.path}")

        # Skip authentication for excluded paths
        path = request.url.path
        logging.info(f"[MIDDLEWARE] Checking path: {path} against exclude_paths: {self.exclude_paths}")

        should_skip = False
        for exclude_path in self.exclude_paths:
            # Exact match
            if path == exclude_path:
                should_skip = True
                logging.info(f"[MIDDLEWARE] Path {path} exactly matches exclude_path {exclude_path}")
                break
            # For paths ending with '/', check if the request path starts with it
            elif exclude_path.endswith('/') and path.startswith(exclude_path):
                should_skip = True
                logging.info(f"[MIDDLEWARE] Path {path} starts with exclude_path {exclude_path}")
                break
            # For paths not ending with '/', check if the request path starts with it + '/'
            elif not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
                should_skip = True
                logging.info(f"[MIDDLEWARE] Path {path} starts with exclude_path {exclude_path}/")
                break

        if should_skip:
            logging.info(f"[MIDDLEWARE] Skipping authentication for excluded path: {path}")
            response = await call_next(request)
            return response

        logging.info(f"[MIDDLEWARE] Processing authenticated request: {path}")

        # Always clear any existing user context to ensure fresh authentication
        UserContext.clear_current_user()

        # Try to extract and validate the token
        try:
            # Get the authorization header
            authorization = request.headers.get("Authorization")
            if not authorization or not authorization.startswith("Bearer "):
                # No token provided, return a 401 error
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Missing or invalid authorization header"},
                    headers={"WWW-Authenticate": "Bearer"}
                )

            # Extract the token
            token = authorization.split(" ")[1]

            # Verify the token
            payload = AuthService.verify_token(token)
            if payload is None:
                # Invalid token, return a 401 error
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Invalid or expired token"},
                    headers={"WWW-Authenticate": "Bearer"}
                )

            # Get the username from the token
            username = payload.get("sub")
            if not username:
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Invalid token payload"},
                    headers={"WWW-Authenticate": "Bearer"}
                )

            # Get the user from the database
            db = get_db_session()
            try:
                from ..models.user import User
                user = db.query(User).filter(User.username == username).first()
                if not user:
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "User not found"},
                        headers={"WWW-Authenticate": "Bearer"}
                    )

                if not user.is_active:
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "User account is inactive"},
                        headers={"WWW-Authenticate": "Bearer"}
                    )

                # Set the user in context using the token mechanism
                user_token = UserContext.set_current_user_with_token(user)
                logging.info(f"User {user.username} (ID: {user.id}) authenticated and set in context")

                # Verify the context is set correctly
                current_user_id = UserContext.get_current_user_id()
                logging.info(f"Verified current user ID in context: {current_user_id}")
            finally:
                db.close()

        except Exception as e:
            # Log the error but don't fail the request; downstream code will see no user in context
            logging.warning(f"Error setting user context: {e}")

        # Continue with the request
        try:
            response = await call_next(request)
            return response
        finally:
            # Always clear the user context after request processing
            UserContext.clear_current_user()
            logging.debug(f"[MIDDLEWARE] Cleared user context after processing request: {path}")
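The three matching rules in `dispatch`'s skip check (exact match, trailing-slash prefix, and prefix-plus-slash) can be factored into a small predicate for testing; `is_excluded` is an illustrative name, not part of the module:

```python
def is_excluded(path: str, exclude_paths: list) -> bool:
    # Same logic as UserContextMiddleware.dispatch's should_skip loop
    for exclude_path in exclude_paths:
        # Exact match
        if path == exclude_path:
            return True
        # Exclude paths ending with '/' act as plain prefixes
        if exclude_path.endswith('/') and path.startswith(exclude_path):
            return True
        # Other exclude paths only match whole sub-paths ("/docs" matches
        # "/docs/x" but not "/docsx")
        if not exclude_path.endswith('/') and exclude_path != '/' and path.startswith(exclude_path + '/'):
            return True
    return False

excludes = ["/docs", "/health", "/api/auth/login"]
print(is_excluded("/docs/oauth2-redirect", excludes))  # True: whole sub-path of /docs
print(is_excluded("/api/chat", excludes))              # False: requires auth
```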
@@ -0,0 +1,90 @@

"""Simplified permission checking system."""

from functools import wraps
from fastapi import HTTPException, Depends
from sqlalchemy.orm import Session

from ..db.database import get_db
from ..models.user import User
from ..models.permission import Role
from ..services.auth import AuthService


def is_super_admin(user: User, db: Session) -> bool:
    """Check whether the user is a super admin."""
    if not user or not user.is_active:
        return False

    # Check whether the user has the super admin role
    for role in user.roles:
        if role.code == "SUPER_ADMIN":
            return True
    return False


def require_super_admin(
    current_user: User = Depends(AuthService.get_current_user),
    db: Session = Depends(get_db)
) -> User:
    """Dependency that requires super admin privileges."""
    if not is_super_admin(current_user, db):
        raise HTTPException(
            status_code=403,
            detail="Super admin privileges required"
        )
    return current_user


def require_authenticated_user(
    current_user: User = Depends(AuthService.get_current_user)
) -> User:
    """Dependency that requires an authenticated user."""
    if not current_user or not current_user.is_active:
        raise HTTPException(
            status_code=401,
            detail="Login required"
        )
    return current_user


class SimplePermissionChecker:
    """Simplified permission checker."""

    def __init__(self, db: Session):
        self.db = db

    def check_super_admin(self, user: User) -> bool:
        """Check whether the user is a super admin."""
        return is_super_admin(user, self.db)

    def check_user_access(self, user: User, target_user_id: int) -> bool:
        """Check user access rights (self or super admin)."""
        if not user or not user.is_active:
            return False

        # Super admins can access all users
        if self.check_super_admin(user):
            return True

        # Regular users can only access their own information
        return user.id == target_user_id


# Permission decorators
def super_admin_required(func):
    """Super admin permission decorator."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # This decorator is mainly for the service layer; the actual FastAPI dependency check happens at the route layer
        return func(*args, **kwargs)
    return wrapper


def authenticated_required(func):
    """Authenticated user permission decorator."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # This decorator is mainly for the service layer; the actual FastAPI dependency check happens at the route layer
        return func(*args, **kwargs)
    return wrapper
@@ -0,0 +1,76 @@

"""User utility functions for easy access to current user context."""

from typing import Optional
from ..models.user import User
from .context import UserContext


def get_current_user() -> Optional[User]:
    """Get current authenticated user from context.

    Returns:
        Current user if authenticated, None otherwise
    """
    return UserContext.get_current_user()


def get_current_user_id() -> Optional[int]:
    """Get current authenticated user ID from context.

    Returns:
        Current user ID if authenticated, None otherwise
    """
    return UserContext.get_current_user_id()


def require_current_user() -> User:
    """Get current authenticated user from context, raise exception if not found.

    Returns:
        Current user

    Raises:
        HTTPException: If no authenticated user in context
    """
    return UserContext.require_current_user()


def require_current_user_id() -> int:
    """Get current authenticated user ID from context, raise exception if not found.

    Returns:
        Current user ID

    Raises:
        HTTPException: If no authenticated user in context
    """
    return UserContext.require_current_user_id()


def is_user_authenticated() -> bool:
    """Check if there is an authenticated user in the current context.

    Returns:
        True if user is authenticated, False otherwise
    """
    return UserContext.get_current_user() is not None


def get_current_username() -> Optional[str]:
    """Get current authenticated user's username from context.

    Returns:
        Current user's username if authenticated, None otherwise
    """
    user = UserContext.get_current_user()
    return user.username if user else None


def get_current_user_email() -> Optional[str]:
    """Get current authenticated user's email from context.

    Returns:
        Current user's email if authenticated, None otherwise
    """
    user = UserContext.get_current_user()
    return user.email if user else None
@@ -0,0 +1,6 @@

"""Database module for TH-Agenter."""

from .database import get_db, init_db
from .base import Base

__all__ = ["get_db", "init_db", "Base"]
@@ -0,0 +1,62 @@

"""Database base model."""

from typing import Optional

from sqlalchemy import Column, Integer, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func


Base = declarative_base()


class BaseModel(Base):
    """Base model with common fields."""

    __abstract__ = True

    id = Column(Integer, primary_key=True, index=True)
    created_at = Column(DateTime, default=func.now(), nullable=False)
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False)
    created_by = Column(Integer, nullable=True)
    updated_by = Column(Integer, nullable=True)

    def __init__(self, **kwargs):
        """Initialize model and automatically set audit fields."""
        super().__init__(**kwargs)
        # Set audit fields for new instances
        self.set_audit_fields()

    def set_audit_fields(self, user_id: Optional[int] = None, is_update: bool = False):
        """Set audit fields for create/update operations.

        Args:
            user_id: ID of the user performing the operation (optional; taken from context if not provided)
            is_update: True for update operations, False for create operations
        """
        # Get user_id from context if not provided
        if user_id is None:
            from ..core.context import UserContext
            try:
                user_id = UserContext.get_current_user_id()
            except Exception:
                # If no user in context, skip setting audit fields
                return

        # Skip if still no user_id
        if user_id is None:
            return

        if not is_update:
            # For create operations, set both created_by and updated_by
            self.created_by = user_id
            self.updated_by = user_id
        else:
            # For update operations, only set updated_by
            self.updated_by = user_id

    def to_dict(self):
        """Convert model to dictionary."""
        return {
            column.name: getattr(self, column.name)
            for column in self.__table__.columns
        }
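The create/update audit rule in `set_audit_fields` reduces to a small pure function. The sketch below uses a plain dict as a stand-in for the model instance (illustrative, not the ORM code):

```python
def set_audit_fields(row: dict, user_id, is_update: bool = False) -> None:
    # Same rule as BaseModel.set_audit_fields once a user_id is known
    if user_id is None:
        return  # no user in context: leave audit fields untouched
    if not is_update:
        row["created_by"] = user_id  # create: stamp both fields
    row["updated_by"] = user_id      # create and update both stamp updated_by

row = {}
set_audit_fields(row, 7)                  # create by user 7
set_audit_fields(row, 9, is_update=True)  # later update by user 9
print(row)  # {'created_by': 7, 'updated_by': 9}
```

So `created_by` records the original creator forever, while `updated_by` tracks the most recent writer.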
@@ -0,0 +1,89 @@

"""Database connection and session management."""

import logging
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator

from ..core.config import get_settings
from .base import Base


# Global variables
engine = None
SessionLocal = None


def create_database_engine():
    """Create database engine."""
    global engine, SessionLocal

    settings = get_settings()
    database_url = settings.database.url

    # Determine database type and configure the engine
    engine_kwargs = {
        "echo": settings.database.echo,
    }

    if database_url.startswith("sqlite"):
        # SQLite configuration
        engine = create_engine(database_url, **engine_kwargs)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        logging.info(f"SQLite database engine created: {database_url}")
    elif database_url.startswith("postgresql"):
        # PostgreSQL configuration
        engine_kwargs.update({
            "pool_size": settings.database.pool_size,
            "max_overflow": settings.database.max_overflow,
            "pool_pre_ping": True,
            "pool_recycle": 3600,
        })
        engine = create_engine(database_url, **engine_kwargs)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        logging.info(f"PostgreSQL database engine created: {database_url}")
    else:
        raise ValueError(f"Unsupported database type. Please use PostgreSQL or SQLite. URL: {database_url}")


async def init_db():
    """Initialize database."""
    global engine

    if engine is None:
        create_database_engine()

    # Import all models to ensure they are registered
    from ..models import user, conversation, message, knowledge_base, permission, workflow

    # Create all tables
    Base.metadata.create_all(bind=engine)
    logging.info("Database tables created")


def get_db() -> Generator[Session, None, None]:
    """Get database session."""
    global SessionLocal

    if SessionLocal is None:
        create_database_engine()

    db = SessionLocal()
    try:
        yield db
    except Exception as e:
        db.rollback()
        logging.error(f"Database session error: {e}")
        raise
    finally:
        db.close()


def get_db_session() -> Session:
    """Get database session (synchronous)."""
    global SessionLocal

    if SessionLocal is None:
        create_database_engine()

    return SessionLocal()
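`get_db` relies on FastAPI running the dependency generator to completion after the response, which is what triggers the `finally: db.close()`. The cleanup contract can be shown with a fake session (`FakeSession` is an illustrative stand-in for a SQLAlchemy `Session`):

```python
class FakeSession:
    # Stand-in for a SQLAlchemy Session (illustrative only)
    def __init__(self):
        self.closed = False
        self.rolled_back = False
    def rollback(self):
        self.rolled_back = True
    def close(self):
        self.closed = True

def get_db():
    # Same shape as the FastAPI dependency above
    db = FakeSession()
    try:
        yield db
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

gen = get_db()
db = next(gen)  # FastAPI injects this into the route handler
gen.close()     # after the response, FastAPI finalizes the generator
print(db.closed, db.rolled_back)  # True False
```

If the route handler raises instead, FastAPI throws the exception into the generator, so `rollback()` runs before `close()`.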
@@ -0,0 +1 @@

nRWvsuXfWm1IXThkPQ7lA4HlTiNP4CkCYxqczEfrRR4=
@ -0,0 +1,121 @@
|
||||||
|
"""Initialize system management data."""
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from ..models.permission import Role
|
||||||
|
from ..models.user import User
|
||||||
|
from ..services.auth import AuthService
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def init_roles(db: Session) -> None:
|
||||||
|
"""初始化系统角色."""
|
||||||
|
roles_data = [
|
||||||
|
{
|
||||||
|
"name": "超级管理员",
|
||||||
|
"code": "SUPER_ADMIN",
|
||||||
|
"description": "系统超级管理员,拥有所有权限"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "普通用户",
|
||||||
|
"code": "USER",
|
||||||
|
"description": "普通用户,基础功能权限"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for role_data in roles_data:
|
||||||
|
# 检查角色是否已存在
|
||||||
|
existing_role = db.query(Role).filter(
|
||||||
|
            Role.code == role_data["code"]
        ).first()

        if not existing_role:
            # Create the role
            role = Role(
                name=role_data["name"],
                code=role_data["code"],
                description=role_data["description"]
            )
            role.set_audit_fields(1)  # system user ID is 1
            db.add(role)
            logger.info(f"Created role: {role_data['name']} ({role_data['code']})")

    db.commit()
    logger.info("Roles initialization completed")


def init_admin_user(db: Session) -> None:
    """Initialize the default admin user."""
    logger.info("Starting admin user initialization...")

    # Check whether an admin user already exists
    existing_admin = db.query(User).filter(
        User.username == "admin"
    ).first()

    if existing_admin:
        logger.info("Admin user already exists")
        return

    # Create the admin user
    hashed_password = AuthService.get_password_hash("admin123")

    admin_user = User(
        username="admin",
        email="admin@example.com",
        hashed_password=hashed_password,
        full_name="系统管理员",
        is_active=True
    )

    admin_user.set_audit_fields(1)
    db.add(admin_user)
    db.flush()

    # Assign the super admin role
    super_admin_role = db.query(Role).filter(
        Role.code == "SUPER_ADMIN"
    ).first()

    if super_admin_role:
        admin_user.roles.append(super_admin_role)

    db.commit()
    logger.info("Admin user created: admin / admin123")


def init_system_data(db: Session) -> None:
    """Initialize all system data."""
    logger.info("Starting system data initialization...")

    try:
        # Initialize roles
        init_roles(db)

        # Initialize the admin user
        init_admin_user(db)

        logger.info("System data initialization completed successfully")

    except Exception as e:
        logger.error(f"Error during system data initialization: {str(e)}")
        db.rollback()
        raise


if __name__ == "__main__":
    # This script can be run on its own to initialize data.
    # Use an absolute import here: a relative import fails when the file
    # is executed directly as a script.
    from th_agenter.db.database import SessionLocal

    db = SessionLocal()
    try:
        init_system_data(db)
    finally:
        db.close()
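Both initializers above follow the same check-then-create pattern, which is what makes re-running the script safe. A minimal, database-free sketch of that idempotency guarantee — plain dicts stand in for the session and tables, and `get_or_create` is an illustrative helper, not part of the project:

```python
def get_or_create(store: dict, code: str, factory):
    """Return the record for `code`, creating it only on first sight."""
    if code not in store:
        store[code] = factory(code)
    return store[code]

roles: dict = {}
for _ in range(2):  # simulate running the initialization script twice
    for role_code in ("SUPER_ADMIN", "ADMIN", "USER"):
        get_or_create(roles, role_code, lambda c: {"code": c})

print(len(roles))  # still 3 after the second run: no duplicates
```

Because creation only happens when the lookup misses, a second run is a no-op — the same property the role and admin-user initializers rely on.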
@ -0,0 +1,216 @@
"""Add system management tables.

Revision ID: add_system_management
Revises:
Create Date: 2024-01-01 00:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = 'add_system_management'
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
    """Create system management tables."""

    # Create departments table
    op.create_table('departments',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('parent_id', sa.Integer(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['parent_id'], ['departments.id']),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
    op.create_index(op.f('ix_departments_parent_id'), 'departments', ['parent_id'], unique=False)

    # Create permissions table
    op.create_table('permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=100), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_permissions_category'), 'permissions', ['category'], unique=False)
    op.create_index(op.f('ix_permissions_name'), 'permissions', ['name'], unique=False)

    # Create roles table
    op.create_table('roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('code', sa.String(length=50), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )
    op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False)

    # Create role_permissions table
    op.create_table('role_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id']),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id']),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('role_id', 'permission_id', name='uq_role_permission')
    )
    op.create_index(op.f('ix_role_permissions_permission_id'), 'role_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_role_permissions_role_id'), 'role_permissions', ['role_id'], unique=False)

    # Create user_roles table
    op.create_table('user_roles',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id']),
        sa.ForeignKeyConstraint(['user_id'], ['users.id']),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'role_id', name='uq_user_role')
    )
    op.create_index(op.f('ix_user_roles_role_id'), 'user_roles', ['role_id'], unique=False)
    op.create_index(op.f('ix_user_roles_user_id'), 'user_roles', ['user_id'], unique=False)

    # Create user_permissions table
    op.create_table('user_permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.Column('granted', sa.Boolean(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id']),
        sa.ForeignKeyConstraint(['user_id'], ['users.id']),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('user_id', 'permission_id', name='uq_user_permission')
    )
    op.create_index(op.f('ix_user_permissions_permission_id'), 'user_permissions', ['permission_id'], unique=False)
    op.create_index(op.f('ix_user_permissions_user_id'), 'user_permissions', ['user_id'], unique=False)

    # Create llm_configs table
    op.create_table('llm_configs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('provider', sa.String(length=50), nullable=False),
        sa.Column('model_name', sa.String(length=100), nullable=False),
        sa.Column('api_key', sa.Text(), nullable=True),
        sa.Column('api_base', sa.String(length=500), nullable=True),
        sa.Column('api_version', sa.String(length=20), nullable=True),
        sa.Column('max_tokens', sa.Integer(), nullable=True),
        sa.Column('temperature', sa.Float(), nullable=True),
        sa.Column('top_p', sa.Float(), nullable=True),
        sa.Column('frequency_penalty', sa.Float(), nullable=True),
        sa.Column('presence_penalty', sa.Float(), nullable=True),
        sa.Column('timeout', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=True),
        sa.Column('is_default', sa.Boolean(), nullable=True),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('sort_order', sa.Integer(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('created_by', sa.Integer(), nullable=True),
        sa.Column('updated_by', sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
    op.create_index(op.f('ix_llm_configs_provider'), 'llm_configs', ['provider'], unique=False)

    # Add new columns to users table
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
    op.add_column('users', sa.Column('is_superuser', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('is_admin', sa.Boolean(), nullable=True, default=False))
    op.add_column('users', sa.Column('last_login_at', sa.DateTime(), nullable=True))
    op.add_column('users', sa.Column('login_count', sa.Integer(), nullable=True, default=0))

    # Create foreign key constraint for department_id
    op.create_foreign_key('fk_users_department_id', 'users', 'departments', ['department_id'], ['id'])
    op.create_index(op.f('ix_users_department_id'), 'users', ['department_id'], unique=False)


def downgrade():
    """Drop system management tables."""

    # Drop foreign key and index for users.department_id
    op.drop_index(op.f('ix_users_department_id'), table_name='users')
    op.drop_constraint('fk_users_department_id', 'users', type_='foreignkey')

    # Drop new columns from users table
    op.drop_column('users', 'login_count')
    op.drop_column('users', 'last_login_at')
    op.drop_column('users', 'is_admin')
    op.drop_column('users', 'is_superuser')
    op.drop_column('users', 'department_id')

    # Drop llm_configs table
    op.drop_index(op.f('ix_llm_configs_provider'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_table('llm_configs')

    # Drop user_permissions table
    op.drop_index(op.f('ix_user_permissions_user_id'), table_name='user_permissions')
    op.drop_index(op.f('ix_user_permissions_permission_id'), table_name='user_permissions')
    op.drop_table('user_permissions')

    # Drop user_roles table
    op.drop_index(op.f('ix_user_roles_user_id'), table_name='user_roles')
    op.drop_index(op.f('ix_user_roles_role_id'), table_name='user_roles')
    op.drop_table('user_roles')

    # Drop role_permissions table
    op.drop_index(op.f('ix_role_permissions_role_id'), table_name='role_permissions')
    op.drop_index(op.f('ix_role_permissions_permission_id'), table_name='role_permissions')
    op.drop_table('role_permissions')

    # Drop roles table
    op.drop_index(op.f('ix_roles_name'), table_name='roles')
    op.drop_table('roles')

    # Drop permissions table
    op.drop_index(op.f('ix_permissions_name'), table_name='permissions')
    op.drop_index(op.f('ix_permissions_category'), table_name='permissions')
    op.drop_table('permissions')

    # Drop departments table
    op.drop_index(op.f('ix_departments_parent_id'), table_name='departments')
    op.drop_index(op.f('ix_departments_name'), table_name='departments')
    op.drop_table('departments')
@ -0,0 +1,83 @@
"""Add user_department association table migration."""

import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

import asyncio
import asyncpg
from th_agenter.core.config import get_settings


async def create_user_department_table():
    """Create user_departments association table."""
    settings = get_settings()
    database_url = settings.database.url

    print(f"Database URL: {database_url}")

    conn = None
    try:
        # Parse the PostgreSQL connection URL
        # postgresql://user:password@host:port/database
        url_parts = database_url.replace('postgresql://', '').split('/')
        db_name = url_parts[1] if len(url_parts) > 1 else 'postgres'
        user_host = url_parts[0].split('@')
        user_pass = user_host[0].split(':')
        host_port = user_host[1].split(':')

        user = user_pass[0]
        password = user_pass[1] if len(user_pass) > 1 else ''
        host = host_port[0]
        port = int(host_port[1]) if len(host_port) > 1 else 5432

        # Connect to the PostgreSQL database
        conn = await asyncpg.connect(
            user=user,
            password=password,
            database=db_name,
            host=host,
            port=port
        )

        # Create the user_departments table
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS user_departments (
                id SERIAL PRIMARY KEY,
                user_id INTEGER NOT NULL,
                department_id INTEGER NOT NULL,
                is_primary BOOLEAN NOT NULL DEFAULT true,
                is_active BOOLEAN NOT NULL DEFAULT true,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
                FOREIGN KEY (department_id) REFERENCES departments (id) ON DELETE CASCADE
            );
        """

        await conn.execute(create_table_sql)

        # Create indexes
        create_indexes_sql = [
            "CREATE INDEX IF NOT EXISTS idx_user_departments_user_id ON user_departments (user_id);",
            "CREATE INDEX IF NOT EXISTS idx_user_departments_department_id ON user_departments (department_id);",
            "CREATE UNIQUE INDEX IF NOT EXISTS idx_user_departments_unique ON user_departments (user_id, department_id);"
        ]

        for index_sql in create_indexes_sql:
            await conn.execute(index_sql)

        print("User departments table created successfully")

    except Exception as e:
        print(f"Error creating user departments table: {e}")
        raise
    finally:
        if conn is not None:
            await conn.close()


if __name__ == "__main__":
    asyncio.run(create_user_department_table())
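The hand-rolled `split('/')` / `split('@')` parsing above breaks on passwords containing `@`, `:` or percent-encoded characters. The standard library already handles this; a sketch of a drop-in replacement (the example DSN below is made up):

```python
from urllib.parse import urlsplit, unquote


def parse_database_url(database_url: str) -> dict:
    """Turn a postgresql:// DSN into asyncpg.connect() keyword arguments."""
    parts = urlsplit(database_url)
    return {
        "user": unquote(parts.username or ""),
        "password": unquote(parts.password or ""),  # %xx escapes decoded
        "host": parts.hostname or "localhost",
        "port": parts.port or 5432,
        "database": parts.path.lstrip("/") or "postgres",
    }


params = parse_database_url("postgresql://myuser:p%40ss@db.internal:5433/th_agenter")
print(params["password"], params["port"], params["database"])  # p@ss 5433 th_agenter
```

`urlsplit` parses the netloc the same way for any scheme, so it also covers DSNs like `postgresql://user@host/db` where the password and port are omitted.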
@ -0,0 +1,451 @@
#!/usr/bin/env python3
"""Migration script to move hardcoded resources to database."""

import sys
from pathlib import Path

# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(backend_dir))

from th_agenter.db.database import Base, get_db_session
from th_agenter.models import *  # Import all models to ensure they're registered
from th_agenter.utils.logger import get_logger
from th_agenter.models.resource import Resource, RoleResource
from th_agenter.models.permission import Role

logger = get_logger(__name__)


def migrate_hardcoded_resources():
    """Migrate hardcoded resources from init_resource_data.py to database."""
    db = None
    try:
        # Get database session
        db = get_db_session()

        if db is None:
            logger.error("Failed to create database session")
            return False

        # Create all tables if they don't exist
        from th_agenter.db.database import engine as global_engine
        if global_engine:
            Base.metadata.create_all(bind=global_engine)

        logger.info("Starting hardcoded resources migration...")

        # Check if resources already exist
        existing_count = db.query(Resource).count()
        if existing_count > 0:
            logger.info(f"Found {existing_count} existing resources. Checking role assignments.")
            # Even when resources already exist, verify and create the role-resource links
            admin_role = db.query(Role).filter(Role.name == "系统管理员").first()
            if admin_role:
                # Fetch all resources
                all_resources = db.query(Resource).all()
                assigned_count = 0

                for resource in all_resources:
                    # Check whether the link already exists
                    existing = db.query(RoleResource).filter(
                        RoleResource.role_id == admin_role.id,
                        RoleResource.resource_id == resource.id
                    ).first()

                    if not existing:
                        role_resource = RoleResource(
                            role_id=admin_role.id,
                            resource_id=resource.id
                        )
                        db.add(role_resource)
                        assigned_count += 1

                if assigned_count > 0:
                    db.commit()
                    logger.info(f"Assigned {assigned_count} new resources to the admin role")
                else:
                    logger.info("Admin role already holds all resources")
            else:
                logger.warning('Admin role "系统管理员" not found')

            return True

        # Define hardcoded resource data
        main_menu_data = [
            {
                "name": "智能问答",
                "code": "CHAT",
                "type": "menu",
                "path": "/chat",
                "component": "views/Chat.vue",
                "icon": "ChatDotRound",
                "description": "智能问答功能",
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": False
            },
            {
                "name": "智能问数",
                "code": "SMART_QUERY",
                "type": "menu",
                "path": "/smart-query",
                "component": "views/SmartQuery.vue",
                "icon": "DataAnalysis",
                "description": "智能问数功能",
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": False
            },
            {
                "name": "知识库",
                "code": "KNOWLEDGE",
                "type": "menu",
                "path": "/knowledge",
                "component": "views/KnowledgeBase.vue",
                "icon": "Collection",
                "description": "知识库管理",
                "sort_order": 3,
                "requires_auth": True,
                "requires_admin": False
            },
            {
                "name": "工作流编排",
                "code": "WORKFLOW",
                "type": "menu",
                "path": "/workflow",
                "component": "views/Workflow.vue",
                "icon": "Connection",
                "description": "工作流编排功能",
                "sort_order": 4,
                "requires_auth": True,
                "requires_admin": False
            },
            {
                "name": "智能体管理",
                "code": "AGENT",
                "type": "menu",
                "path": "/agent",
                "component": "views/Agent.vue",
                "icon": "User",
                "description": "智能体管理功能",
                "sort_order": 5,
                "requires_auth": True,
                "requires_admin": False
            },
            {
                "name": "系统管理",
                "code": "SYSTEM",
                "type": "menu",
                "path": "/system",
                "component": "views/SystemManagement.vue",
                "icon": "Setting",
                "description": "系统管理功能",
                "sort_order": 6,
                "requires_auth": True,
                "requires_admin": True
            }
        ]

        # Create main menu resources
        created_resources = {}
        for menu_data in main_menu_data:
            resource = Resource(**menu_data)
            db.add(resource)
            db.flush()
            created_resources[menu_data["code"]] = resource
            logger.info(f"Created main menu resource: {menu_data['name']}")

        # System management submenu data
        system_submenu_data = [
            {
                "name": "用户管理",
                "code": "SYSTEM_USERS",
                "type": "menu",
                "path": "/system/users",
                "component": "components/system/UserManagement.vue",
                "icon": "User",
                "description": "用户管理功能",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "部门管理",
                "code": "SYSTEM_DEPARTMENTS",
                "type": "menu",
                "path": "/system/departments",
                "component": "components/system/DepartmentManagement.vue",
                "icon": "OfficeBuilding",
                "description": "部门管理功能",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "角色管理",
                "code": "SYSTEM_ROLES",
                "type": "menu",
                "path": "/system/roles",
                "component": "components/system/RoleManagement.vue",
                "icon": "Avatar",
                "description": "角色管理功能",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 3,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "权限管理",
                "code": "SYSTEM_PERMISSIONS",
                "type": "menu",
                "path": "/system/permissions",
                "component": "components/system/PermissionManagement.vue",
                "icon": "Lock",
                "description": "权限管理功能",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 4,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "资源管理",
                "code": "SYSTEM_RESOURCES",
                "type": "menu",
                "path": "/system/resources",
                "component": "components/system/ResourceManagement.vue",
                "icon": "Grid",
                "description": "资源管理功能",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 5,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "大模型管理",
                "code": "SYSTEM_LLM_CONFIGS",
                "type": "menu",
                "path": "/system/llm-configs",
                "component": "components/system/LLMConfigManagement.vue",
                "icon": "Cpu",
                "description": "大模型配置管理",
                "parent_id": created_resources["SYSTEM"].id,
                "sort_order": 6,
                "requires_auth": True,
                "requires_admin": True
            }
        ]

        # Create system management submenu
        for submenu_data in system_submenu_data:
            submenu = Resource(**submenu_data)
            db.add(submenu)
            db.flush()
            created_resources[submenu_data["code"]] = submenu
            logger.info(f"Created system submenu resource: {submenu_data['name']}")

        # Button resources data
        button_resources_data = [
            # User management buttons
            {
                "name": "新增用户",
                "code": "USER_CREATE_BTN",
                "type": "button",
                "description": "新增用户按钮",
                "parent_id": created_resources["SYSTEM_USERS"].id,
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "编辑用户",
                "code": "USER_EDIT_BTN",
                "type": "button",
                "description": "编辑用户按钮",
                "parent_id": created_resources["SYSTEM_USERS"].id,
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": True
            },
            # Role management buttons
            {
                "name": "新增角色",
                "code": "ROLE_CREATE_BTN",
                "type": "button",
                "description": "新增角色按钮",
                "parent_id": created_resources["SYSTEM_ROLES"].id,
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "编辑角色",
                "code": "ROLE_EDIT_BTN",
                "type": "button",
                "description": "编辑角色按钮",
                "parent_id": created_resources["SYSTEM_ROLES"].id,
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": True
            },
            # Permission management buttons
            {
                "name": "新增权限",
                "code": "PERMISSION_CREATE_BTN",
                "type": "button",
                "description": "新增权限按钮",
                "parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "编辑权限",
                "code": "PERMISSION_EDIT_BTN",
                "type": "button",
                "description": "编辑权限按钮",
                "parent_id": created_resources["SYSTEM_PERMISSIONS"].id,
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": True
            }
        ]

        # Create button resources
        for button_data in button_resources_data:
            button = Resource(**button_data)
            db.add(button)
            db.flush()
            created_resources[button_data["code"]] = button
            logger.info(f"Created button resource: {button_data['name']}")

        # API resources data
        api_resources_data = [
            # User management APIs
            {
                "name": "用户列表API",
                "code": "USER_LIST_API",
                "type": "api",
                "path": "/api/users",
                "description": "获取用户列表API",
                "sort_order": 1,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "创建用户API",
                "code": "USER_CREATE_API",
                "type": "api",
                "path": "/api/users",
                "description": "创建用户API",
                "sort_order": 2,
                "requires_auth": True,
                "requires_admin": True
            },
            # Role management APIs
            {
                "name": "角色列表API",
                "code": "ROLE_LIST_API",
                "type": "api",
                "path": "/api/admin/roles",
                "description": "获取角色列表API",
                "sort_order": 5,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "创建角色API",
                "code": "ROLE_CREATE_API",
                "type": "api",
                "path": "/api/admin/roles",
                "description": "创建角色API",
                "sort_order": 6,
                "requires_auth": True,
                "requires_admin": True
            },
            # Resource management APIs
            {
                "name": "资源列表API",
                "code": "RESOURCE_LIST_API",
                "type": "api",
                "path": "/api/admin/resources",
                "description": "获取资源列表API",
                "sort_order": 10,
                "requires_auth": True,
                "requires_admin": True
            },
            {
                "name": "创建资源API",
                "code": "RESOURCE_CREATE_API",
                "type": "api",
                "path": "/api/admin/resources",
                "description": "创建资源API",
                "sort_order": 11,
                "requires_auth": True,
                "requires_admin": True
            }
        ]

        # Create API resources
        for api_data in api_resources_data:
            api_resource = Resource(**api_data)
            db.add(api_resource)
            db.flush()
            created_resources[api_data["code"]] = api_resource
            logger.info(f"Created API resource: {api_data['name']}")

        # Assign the resources to the admin role
        admin_role = db.query(Role).filter(Role.name == "系统管理员").first()
        if admin_role:
            all_resources = list(created_resources.values())
            for resource in all_resources:
                # Check whether the link already exists
                existing = db.query(RoleResource).filter(
                    RoleResource.role_id == admin_role.id,
                    RoleResource.resource_id == resource.id
                ).first()

                if not existing:
                    role_resource = RoleResource(
                        role_id=admin_role.id,
                        resource_id=resource.id
                    )
                    db.add(role_resource)

            logger.info(f"Assigned {len(all_resources)} resources to the admin role")
        else:
            logger.warning('Admin role "系统管理员" not found')

        db.commit()

        total_resources = db.query(Resource).count()
        logger.info(f"Migration completed successfully. Total resources: {total_resources}")

        return True

    except Exception as e:
        logger.error(f"Migration failed: {str(e)}")
        if db:
            db.rollback()
        return False
    finally:
        if db:
            db.close()


def main():
    """Main function to run the migration."""
    print("=== Hardcoded resource data migration ===")
    success = migrate_hardcoded_resources()
    if success:
        print("\n🎉 Resource data migration completed!")
    else:
        print("\n❌ Resource data migration failed!")


if __name__ == "__main__":
    main()
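The role-assignment loops above issue one `RoleResource` existence query per resource (an N+1 pattern). For larger resource sets, the already-assigned pairs can be fetched in a single query and the missing links computed with a set difference; a minimal sketch of that idea in plain Python (`missing_links` is illustrative, not part of the project):

```python
def missing_links(assigned: set, role_id: int, resource_ids: list) -> list:
    """Return the (role_id, resource_id) pairs that still need a RoleResource row."""
    return [(role_id, rid) for rid in resource_ids if (role_id, rid) not in assigned]


# Pairs already present in role_resources (fetched once), vs. the resources to cover.
already_assigned = {(1, 10), (1, 11)}
print(missing_links(already_assigned, 1, [10, 11, 12]))  # [(1, 12)]
```

Only the pairs returned by `missing_links` would then be bulk-inserted, turning N existence checks into one SELECT plus one INSERT batch.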
@@ -0,0 +1,146 @@
"""Migration script that drops the permission-related tables.

Revision ID: remove_permission_tables
Revises: add_system_management
Create Date: 2024-01-25 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text


# revision identifiers, used by Alembic.
revision = 'remove_permission_tables'
down_revision = 'add_system_management'
branch_labels = None
depends_on = None


def upgrade():
    """Drop the permission-related tables."""

    # Get a database connection
    connection = op.get_bind()

    # Drop foreign-key constraints and tables (in dependency order)
    tables_to_drop = [
        'user_permissions',      # user-permission association table
        'role_permissions',      # role-permission association table
        'permission_resources',  # permission-resource association table
        'permissions',           # permission table
        'role_resources',        # role-resource association table
        'resources',             # resource table
        'user_departments',      # user-department association table
        'departments'            # department table
    ]

    for table_name in tables_to_drop:
        try:
            # Check whether the table exists
            result = connection.execute(text(f"""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = '{table_name}'
                );
            """))
            table_exists = result.scalar()

            if table_exists:
                print(f"Dropping table: {table_name}")
                op.drop_table(table_name)
            else:
                print(f"Table {table_name} does not exist, skipping")

        except Exception as e:
            print(f"Error dropping table {table_name}: {e}")
            # Continue with the remaining tables
            continue

    # Drop the department-related column from the users table
    try:
        # Check whether the column exists
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'users' AND column_name = 'department_id';
        """))
        column_exists = result.fetchone()

        if column_exists:
            print("Dropping department_id column from users table")
            op.drop_column('users', 'department_id')
        else:
            print("department_id column does not exist on users table, skipping")

    except Exception as e:
        print(f"Error dropping department_id column: {e}")

    # Simplify the user_roles table structure (if needed)
    try:
        # Check whether user_roles carries extra columns
        result = connection.execute(text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = 'user_roles' AND column_name IN ('id', 'created_at', 'updated_at', 'created_by', 'updated_by');
        """))
        extra_columns = [row[0] for row in result.fetchall()]

        if extra_columns:
            print("Simplifying user_roles table structure")
            # Create the new simplified table
            op.execute(text("""
                CREATE TABLE user_roles_new (
                    user_id INTEGER NOT NULL,
                    role_id INTEGER NOT NULL,
                    PRIMARY KEY (user_id, role_id),
                    FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
                    FOREIGN KEY (role_id) REFERENCES roles(id) ON DELETE CASCADE
                );
            """))

            # Migrate data
            op.execute(text("""
                INSERT INTO user_roles_new (user_id, role_id)
                SELECT DISTINCT user_id, role_id FROM user_roles;
            """))

            # Drop the old table, rename the new one
            op.drop_table('user_roles')
            op.execute(text("ALTER TABLE user_roles_new RENAME TO user_roles;"))

    except Exception as e:
        print(f"Error simplifying user_roles table: {e}")


def downgrade():
    """Rollback: recreate the permission-related tables."""

    # Note: this is a destructive operation; rolling back loses data
    # and should be used with caution in production.

    print("Warning: rollback recreates the permission-related tables but does not restore data")

    # Recreate the basic permission table structure (simplified)
    op.create_table('permissions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('code', sa.String(100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, default=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('code')
    )

    op.create_table('role_permissions',
        sa.Column('role_id', sa.Integer(), nullable=False),
        sa.Column('permission_id', sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(['permission_id'], ['permissions.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('role_id', 'permission_id')
    )

    # Add the department_id column back to the users table
    op.add_column('users', sa.Column('department_id', sa.Integer(), nullable=True))
@@ -0,0 +1,35 @@
from dotenv import load_dotenv
import os
import sys
from pathlib import Path

# Add backend directory to Python path for direct execution
if __name__ == "__main__":
    backend_dir = Path(__file__).parent.parent
    if str(backend_dir) not in sys.path:
        sys.path.insert(0, str(backend_dir))

# Load environment variables from .env file
load_dotenv()

from th_agenter.core.app import create_app

# Create FastAPI application using factory function
app = create_app()

# Table-metadata routes are registered here in main.py


if __name__ == "__main__":
    import uvicorn
    import argparse

    parser = argparse.ArgumentParser(description='TH-Agenter Backend Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
    parser.add_argument('--reload', action='store_true', help='Enable auto-reload')

    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port, reload=args.reload)
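The entry point above is driven entirely by three argparse flags. A minimal, stdlib-only sketch of how those flags are interpreted (the argument definitions mirror the parser in the file; `parse_args` is given an explicit list here instead of reading `sys.argv`):

```python
import argparse

parser = argparse.ArgumentParser(description='TH-Agenter Backend Server')
parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
parser.add_argument('--reload', action='store_true', help='Enable auto-reload')

# Passing a list instead of reading sys.argv, for demonstration only
args = parser.parse_args(['--port', '9000', '--reload'])
print(args.host, args.port, args.reload)  # 0.0.0.0 9000 True
```

Unspecified flags fall back to their defaults, so `--host` stays `0.0.0.0` while `--port` and `--reload` take the supplied values.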
@@ -0,0 +1,27 @@
"""Database models for TH-Agenter."""

from .user import User
from .conversation import Conversation
from .message import Message
from .knowledge_base import KnowledgeBase, Document
from .agent_config import AgentConfig
from .excel_file import ExcelFile
from .permission import Role, UserRole
from .llm_config import LLMConfig
from .workflow import Workflow, WorkflowExecution, NodeExecution

__all__ = [
    "User",
    "Conversation",
    "Message",
    "KnowledgeBase",
    "Document",
    "AgentConfig",
    "ExcelFile",
    "Role",
    "UserRole",
    "LLMConfig",
    "Workflow",
    "WorkflowExecution",
    "NodeExecution"
]
Binary file not shown.
@@ -0,0 +1,53 @@
"""Agent configuration model."""

from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON
from sqlalchemy.sql import func
from ..db.base import BaseModel


class AgentConfig(BaseModel):
    """Agent configuration model."""

    __tablename__ = "agent_configs"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False, index=True)
    description = Column(Text, nullable=True)

    # Agent configuration
    enabled_tools = Column(JSON, nullable=False, default=list)
    max_iterations = Column(Integer, default=10)
    temperature = Column(String(10), default="0.1")
    system_message = Column(Text, nullable=True)
    verbose = Column(Boolean, default=True)

    # Model configuration
    model_name = Column(String(100), default="gpt-3.5-turbo")
    max_tokens = Column(Integer, default=2048)

    # Status
    is_active = Column(Boolean, default=True)
    is_default = Column(Boolean, default=False)

    def __repr__(self):
        return f"<AgentConfig(id={self.id}, name='{self.name}', is_active={self.is_active})>"

    def to_dict(self):
        """Convert to dictionary."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "enabled_tools": self.enabled_tools or [],
            "max_iterations": self.max_iterations,
            "temperature": self.temperature,
            "system_message": self.system_message,
            "verbose": self.verbose,
            "model_name": self.model_name,
            "max_tokens": self.max_tokens,
            "is_active": self.is_active,
            "is_default": self.is_default,
            "created_at": self.created_at,
            "updated_at": self.updated_at
        }
@@ -0,0 +1,5 @@
"""Database base model."""

from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
@@ -0,0 +1,38 @@
"""Conversation model."""

from sqlalchemy import Column, String, Integer, ForeignKey, Text, Boolean
from sqlalchemy.orm import relationship

from ..db.base import BaseModel


class Conversation(BaseModel):
    """Conversation model."""

    __tablename__ = "conversations"

    title = Column(String(200), nullable=False)
    user_id = Column(Integer, nullable=False)  # Removed ForeignKey("users.id")
    knowledge_base_id = Column(Integer, nullable=True)  # Removed ForeignKey("knowledge_bases.id")
    system_prompt = Column(Text, nullable=True)
    model_name = Column(String(100), nullable=False, default="gpt-3.5-turbo")
    temperature = Column(String(10), nullable=False, default="0.7")
    max_tokens = Column(Integer, nullable=False, default=2048)
    is_archived = Column(Boolean, default=False, nullable=False)

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        return f"<Conversation(id={self.id}, title='{self.title}', user_id={self.user_id})>"

    @property
    def message_count(self):
        """Get the number of messages in this conversation.

        Note: relies on a ``messages`` collection that the removed
        relationship no longer provides; callers must attach it themselves.
        """
        return len(self.messages)

    @property
    def last_message_at(self):
        """Get the timestamp of the last message."""
        if self.messages:
            return self.messages[-1].created_at
        return self.created_at
@@ -0,0 +1,52 @@
"""Database configuration model."""

from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON
from sqlalchemy.sql import func
from ..db.base import BaseModel


# Relationship support for the existing DatabaseConfig class
from sqlalchemy.orm import relationship

class DatabaseConfig(BaseModel):
    """Database configuration table."""
    __tablename__ = "database_configs"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)  # Configuration name
    db_type = Column(String(20), nullable=False, unique=True)  # Database type: postgresql, mysql, etc.
    host = Column(String(255), nullable=False)
    port = Column(Integer, nullable=False)
    database = Column(String(100), nullable=False)
    username = Column(String(100), nullable=False)
    password = Column(Text, nullable=False)  # Stored encrypted
    is_active = Column(Boolean, default=True)
    is_default = Column(Boolean, default=False)
    connection_params = Column(JSON, nullable=True)  # Extra connection parameters

    def to_dict(self, include_password=False, decrypt_service=None):
        result = {
            "id": self.id,
            "created_by": self.created_by,
            "name": self.name,
            "db_type": self.db_type,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "username": self.username,
            "is_active": self.is_active,
            "is_default": self.is_default,
            "connection_params": self.connection_params,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None
        }

        # Include the password only when requested and a decrypt service is provided
        if include_password and decrypt_service:
            print('begin decrypt password')
            result["password"] = decrypt_service._decrypt_password(self.password)

        return result

    # Relationships
    # table_metadata = relationship("DatabaseConfig" counterpart is kept commented out)
    # table_metadata = relationship("TableMetadata", back_populates="database_config")
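The `to_dict` contract above (the password column is omitted unless the caller both asks for it and supplies a decrypt service) can be sketched with plain dictionaries. `config_public_view` and the `decrypt` callable are hypothetical stand-ins for illustration, not the model's actual API:

```python
def config_public_view(row, include_password=False, decrypt=None):
    # Drop the sensitive key by default, mirroring to_dict's behavior
    public = {k: v for k, v in row.items() if k != "password"}
    # Restore it only when explicitly requested with a decrypt callable
    if include_password and decrypt:
        public["password"] = decrypt(row["password"])
    return public

row = {"name": "prod", "db_type": "postgresql", "password": "enc:secret"}
print(config_public_view(row))  # no "password" key
print(config_public_view(row, include_password=True,
                         decrypt=lambda s: s[4:]))  # password restored
```

Keeping decryption behind an explicit opt-in means serializing a config for an API response can never leak the credential by accident.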
@@ -0,0 +1,89 @@
"""Excel file models for smart query."""

from sqlalchemy import Column, String, Integer, Text, Boolean, JSON, DateTime
from sqlalchemy.sql import func

from ..db.base import BaseModel


class ExcelFile(BaseModel):
    """Excel file model for storing file metadata."""

    __tablename__ = "excel_files"

    # Basic file information
    # user_id = Column(Integer, nullable=False)  # User ID
    original_filename = Column(String(255), nullable=False)  # Original filename
    file_path = Column(String(500), nullable=False)  # Storage path
    file_size = Column(Integer, nullable=False)  # File size in bytes
    file_type = Column(String(50), nullable=False)  # File type (.xlsx, .xls, .csv)

    # Excel specific information
    sheet_names = Column(JSON, nullable=False)  # List of all sheet names
    default_sheet = Column(String(100), nullable=True)  # Default sheet name

    # Data preview information
    columns_info = Column(JSON, nullable=False)  # Column info: {sheet_name: [column_names]}
    preview_data = Column(JSON, nullable=False)  # First 5 rows: {sheet_name: [[row1], [row2], ...]}
    data_types = Column(JSON, nullable=True)  # Data type info: {sheet_name: {column: dtype}}

    # Statistics
    total_rows = Column(JSON, nullable=True)  # Per-sheet row counts: {sheet_name: row_count}
    total_columns = Column(JSON, nullable=True)  # Per-sheet column counts: {sheet_name: column_count}

    # Processing status
    is_processed = Column(Boolean, default=True, nullable=False)  # Whether the file has been processed
    processing_error = Column(Text, nullable=True)  # Processing error message

    # Upload information
    # upload_time = Column(DateTime, default=func.now(), nullable=False)  # Upload time
    last_accessed = Column(DateTime, nullable=True)  # Last accessed time

    def __repr__(self):
        # Note: user_id is not defined on this model (the column above is
        # commented out), so this assumes the attribute is provided elsewhere.
        return f"<ExcelFile(id={self.id}, filename='{self.original_filename}', user_id={self.user_id})>"

    @property
    def file_size_mb(self):
        """Get file size in MB."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def sheet_count(self):
        """Get number of sheets."""
        return len(self.sheet_names) if self.sheet_names else 0

    def get_sheet_info(self, sheet_name: str = None):
        """Get information for a specific sheet or the default sheet."""
        if not sheet_name:
            sheet_name = self.default_sheet or (self.sheet_names[0] if self.sheet_names else None)

        if not sheet_name or sheet_name not in self.sheet_names:
            return None

        return {
            'sheet_name': sheet_name,
            'columns': self.columns_info.get(sheet_name, []) if self.columns_info else [],
            'preview_data': self.preview_data.get(sheet_name, []) if self.preview_data else [],
            'data_types': self.data_types.get(sheet_name, {}) if self.data_types else {},
            'total_rows': self.total_rows.get(sheet_name, 0) if self.total_rows else 0,
            'total_columns': self.total_columns.get(sheet_name, 0) if self.total_columns else 0
        }

    def get_all_sheets_summary(self):
        """Get summary information for all sheets."""
        if not self.sheet_names:
            return []

        summary = []
        for sheet_name in self.sheet_names:
            sheet_info = self.get_sheet_info(sheet_name)
            if sheet_info:
                summary.append({
                    'sheet_name': sheet_name,
                    'columns_count': len(sheet_info['columns']),
                    'rows_count': sheet_info['total_rows'],
                    'columns': sheet_info['columns'][:10]  # Only the first 10 columns
                })
        return summary
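The `file_size_mb` and `sheet_count` helpers above are pure arithmetic on the stored columns, so they can be checked standalone. This sketch reuses the same expressions outside the ORM:

```python
def file_size_mb(file_size_bytes):
    # Same arithmetic as ExcelFile.file_size_mb: bytes -> MB, rounded to 2 places
    return round(file_size_bytes / (1024 * 1024), 2)

def sheet_count(sheet_names):
    # Same null-guard as ExcelFile.sheet_count
    return len(sheet_names) if sheet_names else 0

print(file_size_mb(1_572_864))          # 1.5
print(sheet_count(["Sheet1", "Data"]))  # 2
print(sheet_count(None))                # 0
```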
@@ -0,0 +1,92 @@
"""Knowledge base models."""

from sqlalchemy import Column, String, Integer, ForeignKey, Text, Boolean, JSON, Float
from sqlalchemy.orm import relationship

from ..db.base import BaseModel


class KnowledgeBase(BaseModel):
    """Knowledge base model."""

    __tablename__ = "knowledge_bases"

    name = Column(String(100), unique=False, index=True, nullable=False)
    description = Column(Text, nullable=True)
    embedding_model = Column(String(100), nullable=False, default="sentence-transformers/all-MiniLM-L6-v2")
    chunk_size = Column(Integer, nullable=False, default=1000)
    chunk_overlap = Column(Integer, nullable=False, default=200)
    is_active = Column(Boolean, default=True, nullable=False)

    # Vector database settings
    vector_db_type = Column(String(50), nullable=False, default="chroma")
    collection_name = Column(String(100), nullable=True)  # For vector DB collection

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        return f"<KnowledgeBase(id={self.id}, name='{self.name}')>"

    @property
    def document_count(self):
        """Get the number of documents in this knowledge base.

        Note: relies on a ``documents`` collection that the removed
        relationship no longer provides; callers must attach it themselves.
        """
        return len(self.documents)

    @property
    def active_document_count(self):
        """Get the number of active documents in this knowledge base."""
        return len([doc for doc in self.documents if doc.is_processed])


class Document(BaseModel):
    """Document model."""

    __tablename__ = "documents"

    knowledge_base_id = Column(Integer, nullable=False)  # Removed ForeignKey("knowledge_bases.id")
    filename = Column(String(255), nullable=False)
    original_filename = Column(String(255), nullable=False)
    file_path = Column(String(500), nullable=False)
    file_size = Column(Integer, nullable=False)  # in bytes
    file_type = Column(String(50), nullable=False)  # .pdf, .txt, .docx, etc.
    mime_type = Column(String(100), nullable=True)

    # Processing status
    is_processed = Column(Boolean, default=False, nullable=False)
    processing_error = Column(Text, nullable=True)

    # Content and metadata
    content = Column(Text, nullable=True)  # Extracted text content
    doc_metadata = Column(JSON, nullable=True)  # Additional metadata

    # Chunking information
    chunk_count = Column(Integer, default=0, nullable=False)

    # Embedding information
    embedding_model = Column(String(100), nullable=True)
    vector_ids = Column(JSON, nullable=True)  # Store vector database IDs for chunks

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        return f"<Document(id={self.id}, filename='{self.filename}', kb_id={self.knowledge_base_id})>"

    @property
    def file_size_mb(self):
        """Get file size in MB."""
        return round(self.file_size / (1024 * 1024), 2)

    @property
    def is_text_file(self):
        """Check if document is a text file."""
        return self.file_type.lower() in ['.txt', '.md', '.csv']

    @property
    def is_pdf_file(self):
        """Check if document is a PDF file."""
        return self.file_type.lower() == '.pdf'

    @property
    def is_office_file(self):
        """Check if document is an Office file."""
        return self.file_type.lower() in ['.docx', '.xlsx', '.pptx']
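The `chunk_size`/`chunk_overlap` columns above parameterize how documents are split for retrieval. A hypothetical sketch of what those two parameters mean; the real pipeline uses LangChain text splitters, so `chunk_text` here is an illustration of the parameters, not the project's actual implementation:

```python
def chunk_text(text, chunk_size=1000, chunk_overlap=200):
    # Each chunk is at most chunk_size characters; consecutive chunks
    # share chunk_overlap characters so context spans chunk boundaries.
    chunks = []
    step = chunk_size - chunk_overlap
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break
    return chunks

chunks = chunk_text("a" * 1500)
print(len(chunks))  # 2 chunks: [0:1000] and [800:1500]
```

With the model's defaults (1000/200), a 1500-character document yields two chunks whose 200-character overlap keeps sentences that straddle the cut available to both.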
@@ -0,0 +1,165 @@
"""LLM Configuration model for managing multiple AI models."""

from sqlalchemy import Column, String, Text, Boolean, Integer, Float, JSON
from sqlalchemy.orm import relationship
from typing import Dict, Any, Optional
import json

from ..db.base import BaseModel


class LLMConfig(BaseModel):
    """LLM Configuration model for managing AI model settings."""

    __tablename__ = "llm_configs"

    name = Column(String(100), nullable=False, index=True)  # Configuration name
    provider = Column(String(50), nullable=False, index=True)  # Provider: openai, deepseek, doubao, zhipu, moonshot, baidu
    model_name = Column(String(100), nullable=False)  # Model name
    api_key = Column(String(500), nullable=False)  # API key (stored encrypted)
    base_url = Column(String(200), nullable=True)  # API base URL

    # Model parameters
    max_tokens = Column(Integer, default=2048, nullable=False)
    temperature = Column(Float, default=0.7, nullable=False)
    top_p = Column(Float, default=1.0, nullable=False)
    frequency_penalty = Column(Float, default=0.0, nullable=False)
    presence_penalty = Column(Float, default=0.0, nullable=False)

    # Configuration metadata
    description = Column(Text, nullable=True)  # Configuration description
    is_active = Column(Boolean, default=True, nullable=False)  # Whether the config is enabled
    is_default = Column(Boolean, default=False, nullable=False)  # Whether this is the default config
    is_embedding = Column(Boolean, default=False, nullable=False)  # Whether this is an embedding model

    # Extended configuration (JSON)
    extra_config = Column(JSON, nullable=True)  # Extra configuration parameters

    # Usage statistics
    usage_count = Column(Integer, default=0, nullable=False)  # Number of uses
    last_used_at = Column(String(50), nullable=True)  # Last used timestamp

    def __repr__(self):
        return f"<LLMConfig(id={self.id}, name='{self.name}', provider='{self.provider}', model='{self.model_name}')>"

    def to_dict(self, include_sensitive=False):
        """Convert to dictionary, optionally excluding sensitive data."""
        data = super().to_dict()
        data.update({
            'name': self.name,
            'provider': self.provider,
            'model_name': self.model_name,
            'base_url': self.base_url,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty,
            'description': self.description,
            'is_active': self.is_active,
            'is_default': self.is_default,
            'is_embedding': self.is_embedding,
            'extra_config': self.extra_config,
            'usage_count': self.usage_count,
            'last_used_at': self.last_used_at
        })

        if include_sensitive:
            data['api_key'] = self.api_key
        else:
            # Show only the first and last few characters of the API key
            if self.api_key:
                key_len = len(self.api_key)
                if key_len > 8:
                    data['api_key_masked'] = f"{self.api_key[:4]}...{self.api_key[-4:]}"
                else:
                    data['api_key_masked'] = "***"
            else:
                data['api_key_masked'] = None

        return data

    def get_client_config(self) -> Dict[str, Any]:
        """Get the configuration used to create a client."""
        config = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'model': self.model_name,
            'max_tokens': self.max_tokens,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.frequency_penalty,
            'presence_penalty': self.presence_penalty
        }

        # Merge in extra configuration
        if self.extra_config:
            config.update(self.extra_config)

        return config

    def validate_config(self) -> Dict[str, Any]:
        """Validate that the configuration is usable."""
        if not self.name or not self.name.strip():
            return {"valid": False, "error": "Configuration name cannot be empty"}

        if not self.provider or self.provider not in ['openai', 'deepseek', 'doubao', 'zhipu', 'moonshot', 'baidu']:
            return {"valid": False, "error": "Unsupported provider"}

        if not self.model_name or not self.model_name.strip():
            return {"valid": False, "error": "Model name cannot be empty"}

        if not self.api_key or not self.api_key.strip():
            return {"valid": False, "error": "API key cannot be empty"}

        if self.max_tokens <= 0 or self.max_tokens > 32000:
            return {"valid": False, "error": "max_tokens must be between 1 and 32000"}

        if self.temperature < 0 or self.temperature > 2:
            return {"valid": False, "error": "temperature must be between 0 and 2"}

        return {"valid": True, "error": None}

    def increment_usage(self):
        """Increment the usage counter."""
        from datetime import datetime
        self.usage_count += 1
        self.last_used_at = datetime.now().isoformat()

    @classmethod
    def get_default_config(cls, provider: str, is_embedding: bool = False):
        """Get the default configuration template for a provider."""
        templates = {
            'openai': {
                'base_url': 'https://api.openai.com/v1',
                'model_name': 'gpt-3.5-turbo' if not is_embedding else 'text-embedding-ada-002',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'deepseek': {
                'base_url': 'https://api.deepseek.com/v1',
                'model_name': 'deepseek-chat' if not is_embedding else 'deepseek-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'doubao': {
                'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
                'model_name': 'doubao-lite-4k' if not is_embedding else 'doubao-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'zhipu': {
                'base_url': 'https://open.bigmodel.cn/api/paas/v4',
                'model_name': 'glm-4' if not is_embedding else 'embedding-3',
                'max_tokens': 2048,
                'temperature': 0.7
            },
            'moonshot': {
                'base_url': 'https://api.moonshot.cn/v1',
                'model_name': 'moonshot-v1-8k' if not is_embedding else 'moonshot-embedding',
                'max_tokens': 2048,
                'temperature': 0.7
            }
        }

        return templates.get(provider, {})
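The API-key masking rule in `to_dict` above (keys longer than 8 characters show only their first and last four characters) is self-contained enough to verify standalone:

```python
def mask_api_key(api_key):
    # Same rule as LLMConfig.to_dict's api_key_masked field
    if not api_key:
        return None
    if len(api_key) > 8:
        return f"{api_key[:4]}...{api_key[-4:]}"
    return "***"

print(mask_api_key("sk-1234567890abcdef"))  # sk-1...cdef
print(mask_api_key("short"))                # ***
print(mask_api_key(""))                     # None
```

Short keys collapse to `***` rather than being partially revealed, since showing four characters of an eight-character key would leak most of it.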
@@ -0,0 +1,69 @@
"""Message model."""

from sqlalchemy import Column, String, Integer, ForeignKey, Text, Enum, JSON
from sqlalchemy.orm import relationship
import enum

from ..db.base import BaseModel


class MessageRole(str, enum.Enum):
    """Message role enumeration."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class MessageType(str, enum.Enum):
    """Message type enumeration."""
    TEXT = "text"
    IMAGE = "image"
    FILE = "file"
    AUDIO = "audio"


class Message(BaseModel):
    """Message model."""

    __tablename__ = "messages"

    conversation_id = Column(Integer, nullable=False)  # Removed ForeignKey("conversations.id")
    role = Column(Enum(MessageRole), nullable=False)
    content = Column(Text, nullable=False)
    message_type = Column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
    message_metadata = Column(JSON, nullable=True)  # Store additional data like file info, tokens used, etc.

    # For knowledge base context
    context_documents = Column(JSON, nullable=True)  # Store retrieved document references

    # Token usage tracking
    prompt_tokens = Column(Integer, nullable=True)
    completion_tokens = Column(Integer, nullable=True)
    total_tokens = Column(Integer, nullable=True)

    # Relationships removed to eliminate foreign key constraints

    def __repr__(self):
        content_preview = self.content[:50] + "..." if len(self.content) > 50 else self.content
        return f"<Message(id={self.id}, role='{self.role}', content='{content_preview}')>"

    def to_dict(self, include_metadata=True):
        """Convert to dictionary."""
        data = super().to_dict()
        if not include_metadata:
            data.pop('message_metadata', None)
            data.pop('context_documents', None)
            data.pop('prompt_tokens', None)
            data.pop('completion_tokens', None)
            data.pop('total_tokens', None)
        return data

    @property
    def is_from_user(self):
        """Check if message is from user."""
        return self.role == MessageRole.USER

    @property
    def is_from_assistant(self):
        """Check if message is from assistant."""
        return self.role == MessageRole.ASSISTANT
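The truncation expression in `Message.__repr__` above can be pulled out and checked on its own. A standalone version of the same rule (content longer than 50 characters is cut and suffixed with `"..."`):

```python
def content_preview(content, limit=50):
    # Same ternary as Message.__repr__: truncate long content, pass short through
    return content[:limit] + "..." if len(content) > limit else content

print(content_preview("hello"))        # hello
print(len(content_preview("x" * 80)))  # 53
```

The 53 comes from 50 kept characters plus the three-character `"..."` suffix; content at exactly the limit is returned unchanged.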
@ -0,0 +1,53 @@
|
||||||
|
"""Role models for simplified RBAC system."""
|
from sqlalchemy import Column, String, Text, Boolean, ForeignKey, Integer
from sqlalchemy.orm import relationship
from typing import List, Dict, Any

from ..db.base import BaseModel, Base


class Role(BaseModel):
    """Role model for simplified RBAC system."""

    __tablename__ = "roles"

    name = Column(String(100), nullable=False, unique=True, index=True)  # role name
    code = Column(String(100), nullable=False, unique=True, index=True)  # role code
    description = Column(Text, nullable=True)  # role description
    is_system = Column(Boolean, default=False, nullable=False)  # whether this is a system role
    is_active = Column(Boolean, default=True, nullable=False)

    # Relationships - keep only the user relationship
    users = relationship("User", secondary="user_roles", back_populates="roles")

    def __repr__(self):
        return f"<Role(id={self.id}, code='{self.code}', name='{self.name}')>"

    def to_dict(self):
        """Convert to dictionary."""
        data = super().to_dict()
        data.update({
            'name': self.name,
            'code': self.code,
            'description': self.description,
            'is_system': self.is_system,
            'is_active': self.is_active
        })
        return data


class UserRole(Base):
    """User role association model."""

    __tablename__ = "user_roles"

    user_id = Column(Integer, ForeignKey('users.id'), primary_key=True)
    role_id = Column(Integer, ForeignKey('roles.id'), primary_key=True)

    # Relationships - for code paths that operate on the association table directly
    user = relationship("User", viewonly=True)
    role = relationship("Role", viewonly=True)

    def __repr__(self):
        return f"<UserRole(user_id={self.user_id}, role_id={self.role_id})>"
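The `secondary="user_roles"` string in the Role and User models wires a classic many-to-many through the association table. A minimal, self-contained sketch of that wiring, using an in-memory SQLite database instead of the project's PostgreSQL setup (the local `Base`, the engine URL, and the trimmed-down columns are assumptions for the demo, not the project's real `..db.base`):

```python
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, create_engine
from sqlalchemy.orm import declarative_base, relationship, sessionmaker

Base = declarative_base()  # stand-in for the project's ..db.base.Base

class UserRole(Base):
    __tablename__ = "user_roles"
    user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
    role_id = Column(Integer, ForeignKey("roles.id"), primary_key=True)

class Role(Base):
    __tablename__ = "roles"
    id = Column(Integer, primary_key=True)
    code = Column(String(100), unique=True, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)
    # the string "user_roles" is resolved against Base.metadata at mapper-configure time
    users = relationship("User", secondary="user_roles", back_populates="roles")

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(50), unique=True, nullable=False)
    roles = relationship("Role", secondary="user_roles", back_populates="users")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# appending to .roles populates user_roles automatically
alice = User(username="alice", roles=[Role(code="SUPER_ADMIN")])
session.add(alice)
session.commit()

codes = [r.code for r in session.query(User).filter_by(username="alice").one().roles]
print(codes)  # ['SUPER_ADMIN']
```

Because the association rows are managed by the relationship, application code normally only touches `user.roles`; the explicit `UserRole` class exists for the fallback query path used elsewhere in the codebase.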
@ -0,0 +1,61 @@
"""Table metadata models."""

from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, JSON, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from ..db.base import BaseModel


class TableMetadata(BaseModel):
    """Table metadata table."""
    __tablename__ = "table_metadata"

    id = Column(Integer, primary_key=True, index=True)
    # database_config_id = Column(Integer, ForeignKey('database_configs.id'), nullable=False)
    table_name = Column(String(100), nullable=False, index=True)
    table_schema = Column(String(50), default='public')
    table_type = Column(String(20), default='BASE TABLE')
    table_comment = Column(Text, nullable=True)  # table description
    database_config_id = Column(Integer, nullable=True)  # database config ID

    # Table structure information
    columns_info = Column(JSON, nullable=False)  # column info: names, types, comments, etc.
    primary_keys = Column(JSON, nullable=True)  # list of primary key columns
    foreign_keys = Column(JSON, nullable=True)  # foreign key info
    indexes = Column(JSON, nullable=True)  # index info

    # Sample data
    sample_data = Column(JSON, nullable=True)  # first 5 sample rows
    row_count = Column(Integer, default=0)  # total row count

    # Q&A related
    is_enabled_for_qa = Column(Boolean, default=True)  # whether enabled for Q&A
    qa_description = Column(Text, nullable=True)  # Q&A description
    business_context = Column(Text, nullable=True)  # business context

    last_synced_at = Column(DateTime(timezone=True), nullable=True)  # last sync time

    # Relationships
    # database_config = relationship("DatabaseConfig", back_populates="table_metadata")

    def to_dict(self):
        return {
            "id": self.id,
            "created_by": self.created_by,  # renamed to created_by
            "database_config_id": self.database_config_id,
            "table_name": self.table_name,
            "table_schema": self.table_schema,
            "table_type": self.table_type,
            "table_comment": self.table_comment,
            "columns_info": self.columns_info,
            "primary_keys": self.primary_keys,
            # "foreign_keys": self.foreign_keys,
            "indexes": self.indexes,
            "sample_data": self.sample_data,
            "row_count": self.row_count,
            "is_enabled_for_qa": self.is_enabled_for_qa,
            "qa_description": self.qa_description,
            "business_context": self.business_context,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "last_synced_at": self.last_synced_at.isoformat() if self.last_synced_at else None
        }
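The JSON columns above (`columns_info`, `primary_keys`, `row_count`) are presumably populated by a sync job that introspects the target database. One way such a sync might look, sketched against an in-memory SQLite table via SQLAlchemy's inspector (the `collect_table_metadata` helper and the exact dict shape are assumptions, not the project's real sync code):

```python
from sqlalchemy import create_engine, inspect, text

# a throwaway table to introspect
engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE orders (id INTEGER PRIMARY KEY, amount REAL NOT NULL)"))
    conn.execute(text("INSERT INTO orders (amount) VALUES (9.5), (12.0)"))

def collect_table_metadata(engine, table_name):
    """Build the kind of payload TableMetadata stores in its JSON columns."""
    insp = inspect(engine)
    columns_info = [
        {"name": c["name"], "type": str(c["type"]), "nullable": c["nullable"]}
        for c in insp.get_columns(table_name)
    ]
    primary_keys = insp.get_pk_constraint(table_name)["constrained_columns"]
    with engine.connect() as conn:
        row_count = conn.execute(text(f"SELECT COUNT(*) FROM {table_name}")).scalar()
    return {"columns_info": columns_info, "primary_keys": primary_keys, "row_count": row_count}

meta = collect_table_metadata(engine, "orders")
print(meta["primary_keys"], meta["row_count"])
```

A payload like this, plus the human-written `qa_description`/`business_context`, is what gives the text-to-SQL Q&A feature enough schema context to generate queries.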
@ -0,0 +1,92 @@
"""User model."""

from sqlalchemy import Column, String, Boolean, Text
from sqlalchemy.orm import relationship
from typing import List, Optional

from ..db.base import BaseModel


class User(BaseModel):
    """User model."""

    __tablename__ = "users"

    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    full_name = Column(String(100), nullable=True)
    is_active = Column(Boolean, default=True, nullable=False)
    avatar_url = Column(String(255), nullable=True)
    bio = Column(Text, nullable=True)

    # Relationships - keep only the role relationship
    roles = relationship("Role", secondary="user_roles", back_populates="users")

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"

    def to_dict(self, include_sensitive=False, include_roles=False):
        """Convert to dictionary, optionally excluding sensitive data."""
        data = super().to_dict()
        data.update({
            'username': self.username,
            'email': self.email,
            'full_name': self.full_name,
            'is_active': self.is_active,
            'avatar_url': self.avatar_url,
            'bio': self.bio,
            'is_superuser': self.is_superuser()
        })

        if not include_sensitive:
            data.pop('hashed_password', None)

        if include_roles:
            data['roles'] = [role.to_dict() for role in self.roles if role.is_active]

        return data

    def has_role(self, role_code: str) -> bool:
        """Check whether the user has the given role."""
        try:
            return any(role.code == role_code and role.is_active for role in self.roles)
        except Exception:
            # If the object is detached, fall back to a database query
            from sqlalchemy.orm import object_session
            from .permission import Role, UserRole

            session = object_session(self)
            if session is None:
                # No session attached; create a new one
                from ..db.database import SessionLocal
                session = SessionLocal()
                try:
                    user_role = session.query(UserRole).join(Role).filter(
                        UserRole.user_id == self.id,
                        Role.code == role_code,
                        Role.is_active == True
                    ).first()
                    return user_role is not None
                finally:
                    session.close()
            else:
                user_role = session.query(UserRole).join(Role).filter(
                    UserRole.user_id == self.id,
                    Role.code == role_code,
                    Role.is_active == True
                ).first()
                return user_role is not None

    def is_superuser(self) -> bool:
        """Check whether the user is a super administrator."""
        return self.has_role('SUPER_ADMIN')

    def is_admin_user(self) -> bool:
        """Check whether the user is an administrator (compatibility method)."""
        return self.is_superuser()

    @property
    def is_admin(self) -> bool:
        """Check whether the user is an administrator (property form)."""
        return self.is_superuser()
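On the common path, `has_role()` above reduces to a scan over the already-loaded roles collection: a role counts only if its code matches and it is active. A dependency-free sketch of just that check (the `SimpleNamespace` stand-ins are assumptions for illustration, not the real ORM objects):

```python
from types import SimpleNamespace

def has_role(roles, role_code):
    # mirrors: any(role.code == role_code and role.is_active for role in self.roles)
    return any(role.code == role_code and role.is_active for role in roles)

roles = [
    SimpleNamespace(code="SUPER_ADMIN", is_active=True),
    SimpleNamespace(code="EDITOR", is_active=False),  # inactive roles never grant access
]

print(has_role(roles, "SUPER_ADMIN"), has_role(roles, "EDITOR"))  # True False
```

The database fallback in the model exists only because a detached instance raises when its lazy-loaded `roles` collection is touched; the predicate it evaluates is the same.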
@ -0,0 +1,175 @@
"""Workflow models."""

from sqlalchemy import Column, String, Text, Boolean, Integer, JSON, ForeignKey, Enum
from sqlalchemy.orm import relationship
from typing import Dict, Any, Optional, List
import enum

from ..db.base import BaseModel


class WorkflowStatus(enum.Enum):
    """Workflow status enum."""
    DRAFT = "DRAFT"          # draft
    PUBLISHED = "PUBLISHED"  # published
    ARCHIVED = "ARCHIVED"    # archived


class NodeType(enum.Enum):
    """Node type enum."""
    START = "start"          # start node
    END = "end"              # end node
    LLM = "llm"              # large language model node
    CONDITION = "condition"  # conditional branch node
    LOOP = "loop"            # loop node
    CODE = "code"            # code execution node
    HTTP = "http"            # HTTP request node
    TOOL = "tool"            # tool node


class ExecutionStatus(enum.Enum):
    """Execution status enum."""
    PENDING = "pending"      # waiting to run
    RUNNING = "running"      # running
    COMPLETED = "completed"  # completed
    FAILED = "failed"        # failed
    CANCELLED = "cancelled"  # cancelled


class Workflow(BaseModel):
    """Workflow model."""

    __tablename__ = "workflows"

    name = Column(String(100), nullable=False, comment="workflow name")
    description = Column(Text, nullable=True, comment="workflow description")
    status = Column(Enum(WorkflowStatus), default=WorkflowStatus.DRAFT, nullable=False, comment="workflow status")
    is_active = Column(Boolean, default=True, nullable=False, comment="whether active")

    # Workflow definition (nodes and connections stored as JSON)
    definition = Column(JSON, nullable=False, comment="workflow definition")

    # Version information
    version = Column(String(20), default="1.0.0", nullable=False, comment="version number")

    # Owning user
    owner_id = Column(Integer, ForeignKey("users.id"), nullable=False, comment="owner ID")

    # Relationships
    executions = relationship("WorkflowExecution", back_populates="workflow", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<Workflow(id={self.id}, name='{self.name}', status='{self.status.value}')>"

    def to_dict(self, include_definition=True):
        """Convert to dictionary."""
        data = super().to_dict()
        data.update({
            'name': self.name,
            'description': self.description,
            'status': self.status.value,
            'is_active': self.is_active,
            'version': self.version,
            'owner_id': self.owner_id
        })

        if include_definition:
            data['definition'] = self.definition

        return data


class WorkflowExecution(BaseModel):
    """Workflow execution record."""

    __tablename__ = "workflow_executions"

    workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False, comment="workflow ID")
    status = Column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="execution status")

    # Execution input and output
    input_data = Column(JSON, nullable=True, comment="input data")
    output_data = Column(JSON, nullable=True, comment="output data")

    # Execution info
    started_at = Column(String(50), nullable=True, comment="start time")
    completed_at = Column(String(50), nullable=True, comment="completion time")
    error_message = Column(Text, nullable=True, comment="error message")

    # Executor
    executor_id = Column(Integer, ForeignKey("users.id"), nullable=False, comment="executor ID")

    # Relationships
    workflow = relationship("Workflow", back_populates="executions")
    node_executions = relationship("NodeExecution", back_populates="workflow_execution", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<WorkflowExecution(id={self.id}, workflow_id={self.workflow_id}, status='{self.status.value}')>"

    def to_dict(self, include_nodes=False):
        """Convert to dictionary."""
        data = super().to_dict()
        data.update({
            'workflow_id': self.workflow_id,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'error_message': self.error_message,
            'executor_id': self.executor_id
        })

        if include_nodes:
            data['node_executions'] = [node.to_dict() for node in self.node_executions]

        return data


class NodeExecution(BaseModel):
    """Node execution record."""

    __tablename__ = "node_executions"

    workflow_execution_id = Column(Integer, ForeignKey("workflow_executions.id"), nullable=False, comment="workflow execution ID")
    node_id = Column(String(50), nullable=False, comment="node ID")
    node_type = Column(Enum(NodeType), nullable=False, comment="node type")
    node_name = Column(String(100), nullable=False, comment="node name")

    # Execution status and result
    status = Column(Enum(ExecutionStatus), default=ExecutionStatus.PENDING, nullable=False, comment="execution status")
    input_data = Column(JSON, nullable=True, comment="input data")
    output_data = Column(JSON, nullable=True, comment="output data")

    # Execution timing
    started_at = Column(String(50), nullable=True, comment="start time")
    completed_at = Column(String(50), nullable=True, comment="completion time")
    duration_ms = Column(Integer, nullable=True, comment="execution duration (ms)")

    # Error info
    error_message = Column(Text, nullable=True, comment="error message")

    # Relationships
    workflow_execution = relationship("WorkflowExecution", back_populates="node_executions")

    def __repr__(self):
        return f"<NodeExecution(id={self.id}, node_id='{self.node_id}', status='{self.status.value}')>"

    def to_dict(self):
        """Convert to dictionary."""
        data = super().to_dict()
        data.update({
            'workflow_execution_id': self.workflow_execution_id,
            'node_id': self.node_id,
            'node_type': self.node_type.value,
            'node_name': self.node_name,
            'status': self.status.value,
            'input_data': self.input_data,
            'output_data': self.output_data,
            'started_at': self.started_at,
            'completed_at': self.completed_at,
            'duration_ms': self.duration_ms,
            'error_message': self.error_message
        })

        return data
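The `definition` JSON column holds the node/edge graph that the DAG execution engine walks. A sketch of one plausible payload shape and of how an engine could derive an execution order from it with a stdlib topological sort (the exact node/edge schema here is an assumption; the real one is produced by the visual editor):

```python
from graphlib import TopologicalSorter

# hypothetical definition payload: node types match the NodeType enum values
definition = {
    "nodes": [
        {"id": "n1", "type": "start", "name": "Start"},
        {"id": "n2", "type": "llm", "name": "Answer"},
        {"id": "n3", "type": "end", "name": "End"},
    ],
    "edges": [
        {"source": "n1", "target": "n2"},
        {"source": "n2", "target": "n3"},
    ],
}

def execution_order(definition):
    """Topologically sort node ids; raises CycleError if the graph is not a DAG."""
    deps = {node["id"]: set() for node in definition["nodes"]}
    for edge in definition["edges"]:
        deps[edge["target"]].add(edge["source"])  # a target depends on its source
    return list(TopologicalSorter(deps).static_order())

order = execution_order(definition)
print(order)  # ['n1', 'n2', 'n3']
```

An engine following this order would create one `NodeExecution` row per node as it runs, recording status, I/O, and `duration_ms` for each.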
@ -0,0 +1,16 @@
"""Schemas package initialization."""

from .user import UserCreate, UserUpdate, UserResponse
from .permission import (
    RoleCreate, RoleUpdate, RoleResponse,
    UserRoleAssign
)

__all__ = [
    # User schemas
    "UserCreate", "UserUpdate", "UserResponse",

    # Permission schemas
    "RoleCreate", "RoleUpdate", "RoleResponse",
    "UserRoleAssign",
]