feat: 添加 Zhipu 嵌入模型支持并修复 Python 3.9 兼容性问题
- 添加 Zhipu (智谱AI) 嵌入模型支持 - 修复 Python 3.9 兼容性问题(anext -> async for) - 更新 README.md 添加项目介绍和前端开发指南 - 添加向量数据库配置文档
This commit is contained in:
parent
85d8f49b7a
commit
01070cd44d
|
|
@@ -76,3 +76,7 @@
|
||||||
# 启动项目命令
|
# 启动项目命令
|
||||||
python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
```
|
```
|
||||||
|
5. 查看日志
|
||||||
|
```bash
|
||||||
|
tail -f /tmp/uvicorn.log
|
||||||
|
```
|
||||||
|
|
@@ -0,0 +1,93 @@
|
||||||
|
# 本地向量数据库配置指南
|
||||||
|
|
||||||
|
## 当前配置
|
||||||
|
|
||||||
|
项目使用 **ChromaDB** 作为本地向量数据库,数据存储在 `./data/chroma/` 目录下。
|
||||||
|
|
||||||
|
## 配置说明
|
||||||
|
|
||||||
|
### 1. 环境变量配置 (.env)
|
||||||
|
|
||||||
|
在 `.env` 文件中配置以下参数:
|
||||||
|
|
||||||
|
```env
|
||||||
|
# 向量数据库类型(代码中已固定使用 Chroma,但仍建议显式设置为 chroma,便于配置自说明)
|
||||||
|
VECTOR_DB_TYPE=chroma
|
||||||
|
|
||||||
|
# 向量数据库存储路径(本地文件系统)
|
||||||
|
# 相对路径会基于项目根目录
|
||||||
|
VECTOR_DB_PERSIST_DIRECTORY=./data/chroma
|
||||||
|
|
||||||
|
# 集合名称(默认)
|
||||||
|
VECTOR_DB_COLLECTION_NAME=documents
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 目录结构
|
||||||
|
|
||||||
|
向量数据库按知识库 ID 组织,每个知识库有独立的目录:
|
||||||
|
|
||||||
|
```
|
||||||
|
data/chroma/
|
||||||
|
├── kb_1/ # 知识库 1 的向量数据
|
||||||
|
├── kb_2/ # 知识库 2 的向量数据
|
||||||
|
├── kb_13/ # 知识库 13 的向量数据
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 使用方式
|
||||||
|
|
||||||
|
向量数据库会在以下场景自动创建:
|
||||||
|
|
||||||
|
1. **上传文档时**:如果上传时选择立即处理,会自动创建向量数据库
|
||||||
|
2. **处理文档时**:调用 `POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process` 接口
|
||||||
|
|
||||||
|
### 4. 验证安装
|
||||||
|
|
||||||
|
运行以下命令验证向量数据库是否正常工作:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -c "
|
||||||
|
import chromadb
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 测试创建本地 Chroma 数据库
|
||||||
|
test_path = './data/chroma/test_kb'
|
||||||
|
client = chromadb.PersistentClient(path=test_path)
|
||||||
|
collection = client.get_or_create_collection(name='test')
|
||||||
|
print('✅ ChromaDB 本地数据库创建成功')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 依赖包
|
||||||
|
|
||||||
|
已安装的依赖:
|
||||||
|
- `chromadb>=1.0.20` - ChromaDB 核心库
|
||||||
|
- `langchain-chroma>=0.1.0` - LangChain Chroma 集成
|
||||||
|
|
||||||
|
### 6. 注意事项
|
||||||
|
|
||||||
|
- 向量数据库数据存储在本地文件系统,无需额外服务
|
||||||
|
- 每个知识库的向量数据独立存储
|
||||||
|
- 删除知识库时,对应的向量数据目录也会被清理
|
||||||
|
- 确保 `data/chroma/` 目录有写入权限
|
||||||
|
|
||||||
|
## 故障排查
|
||||||
|
|
||||||
|
### 问题:向量数据库不存在
|
||||||
|
|
||||||
|
**原因**:文档尚未被处理和向量化
|
||||||
|
|
||||||
|
**解决**:
|
||||||
|
1. 先调用处理文档接口:`POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process`
|
||||||
|
2. 处理完成后,向量数据库会自动创建
|
||||||
|
|
||||||
|
### 问题:权限错误
|
||||||
|
|
||||||
|
**解决**:
|
||||||
|
```bash
|
||||||
|
chmod -R 755 data/chroma/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 问题:磁盘空间不足
|
||||||
|
|
||||||
|
**解决**:清理不需要的知识库向量数据,或扩展存储空间
|
||||||
|
|
@@ -2,6 +2,7 @@ from typing import Dict, Any, List, Optional
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import select
|
||||||
from th_agenter.models.conversation import Conversation
|
from th_agenter.models.conversation import Conversation
|
||||||
from th_agenter.models.message import Message
|
from th_agenter.models.message import Message
|
||||||
from th_agenter.db.database import get_session
|
from th_agenter.db.database import get_session
|
||||||
|
|
@@ -27,30 +28,34 @@ class ConversationContextService:
|
||||||
新创建的对话ID
|
新创建的对话ID
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
|
async for session in get_session():
|
||||||
conversation = Conversation(
|
try:
|
||||||
user_id=user_id,
|
conversation = Conversation(
|
||||||
title=title,
|
user_id=user_id,
|
||||||
created_at=datetime.utcnow(),
|
title=title,
|
||||||
updated_at=datetime.utcnow()
|
created_at=datetime.utcnow(),
|
||||||
)
|
updated_at=datetime.utcnow()
|
||||||
|
)
|
||||||
session.add(conversation)
|
|
||||||
await session.commit()
|
session.add(conversation)
|
||||||
await session.refresh(conversation)
|
await session.commit()
|
||||||
|
await session.refresh(conversation)
|
||||||
# 初始化对话上下文
|
|
||||||
self.context_cache[conversation.id] = {
|
# 初始化对话上下文
|
||||||
'conversation_id': conversation.id,
|
self.context_cache[conversation.id] = {
|
||||||
'user_id': user_id,
|
'conversation_id': conversation.id,
|
||||||
'file_list': [],
|
'user_id': user_id,
|
||||||
'selected_files': [],
|
'file_list': [],
|
||||||
'query_history': [],
|
'selected_files': [],
|
||||||
'created_at': datetime.utcnow().isoformat()
|
'query_history': [],
|
||||||
}
|
'created_at': datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
return conversation.id
|
|
||||||
|
await session.close()
|
||||||
|
return conversation.id
|
||||||
|
finally:
|
||||||
|
break # 只取第一个 session
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"创建对话失败: {e}")
|
print(f"创建对话失败: {e}")
|
||||||
|
|
@@ -74,52 +79,58 @@ class ConversationContextService:
|
||||||
|
|
||||||
# 从数据库加载
|
# 从数据库加载
|
||||||
try:
|
try:
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
|
async for session in get_session():
|
||||||
conversation = session.query(Conversation).filter(
|
try:
|
||||||
Conversation.id == conversation_id
|
conversation = await session.scalar(
|
||||||
).first()
|
select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
)
|
||||||
if not conversation:
|
|
||||||
return None
|
if not conversation:
|
||||||
|
await session.close()
|
||||||
# 加载消息历史
|
return None
|
||||||
messages = session.query(Message).filter(
|
|
||||||
Message.conversation_id == conversation_id
|
# 加载消息历史
|
||||||
).order_by(Message.created_at).all()
|
messages = await session.scalars(
|
||||||
|
select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
|
||||||
# 重建上下文
|
)
|
||||||
context = {
|
messages_list = list(messages)
|
||||||
'conversation_id': conversation_id,
|
|
||||||
'user_id': conversation.user_id,
|
# 重建上下文
|
||||||
'file_list': [],
|
context = {
|
||||||
'selected_files': [],
|
'conversation_id': conversation_id,
|
||||||
'query_history': [],
|
'user_id': conversation.user_id,
|
||||||
'created_at': conversation.created_at.isoformat()
|
'file_list': [],
|
||||||
}
|
'selected_files': [],
|
||||||
|
'query_history': [],
|
||||||
# 从消息中提取查询历史
|
'created_at': conversation.created_at.isoformat()
|
||||||
for message in messages:
|
}
|
||||||
if message.role == 'user':
|
|
||||||
context['query_history'].append({
|
# 从消息中提取查询历史
|
||||||
'query': message.content,
|
for message in messages_list:
|
||||||
'timestamp': message.created_at.isoformat()
|
if message.role == 'user':
|
||||||
})
|
context['query_history'].append({
|
||||||
elif message.role == 'assistant' and message.metadata:
|
'query': message.content,
|
||||||
# 从助手消息的元数据中提取文件信息
|
'timestamp': message.created_at.isoformat()
|
||||||
try:
|
})
|
||||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
elif message.role == 'assistant' and message.metadata:
|
||||||
if 'selected_files' in metadata:
|
# 从助手消息的元数据中提取文件信息
|
||||||
context['selected_files'] = metadata['selected_files']
|
try:
|
||||||
if 'file_list' in metadata:
|
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||||
context['file_list'] = metadata['file_list']
|
if 'selected_files' in metadata:
|
||||||
except (json.JSONDecodeError, TypeError):
|
context['selected_files'] = metadata['selected_files']
|
||||||
pass
|
if 'file_list' in metadata:
|
||||||
|
context['file_list'] = metadata['file_list']
|
||||||
# 缓存上下文
|
except (json.JSONDecodeError, TypeError):
|
||||||
self.context_cache[conversation_id] = context
|
pass
|
||||||
|
|
||||||
return context
|
# 缓存上下文
|
||||||
|
self.context_cache[conversation_id] = context
|
||||||
|
|
||||||
|
await session.close()
|
||||||
|
return context
|
||||||
|
finally:
|
||||||
|
break # 只取第一个 session
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取对话上下文失败: {e}")
|
print(f"获取对话上下文失败: {e}")
|
||||||
|
|
@@ -194,29 +205,30 @@ class ConversationContextService:
|
||||||
保存是否成功
|
保存是否成功
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
|
async for session in get_session():
|
||||||
message = Message(
|
message = Message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role=role,
|
role=role,
|
||||||
content=content,
|
content=content,
|
||||||
metadata=json.dumps(metadata) if metadata else None,
|
metadata=json.dumps(metadata) if metadata else None,
|
||||||
created_at=datetime.utcnow()
|
created_at=datetime.utcnow()
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(message)
|
session.add(message)
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
# 更新对话的最后更新时间
|
|
||||||
conversation = session.query(Conversation).filter(
|
|
||||||
Conversation.id == conversation_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if conversation:
|
|
||||||
conversation.updated_at = datetime.utcnow()
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return True
|
# 更新对话的最后更新时间
|
||||||
|
conversation = await session.scalar(
|
||||||
|
select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversation:
|
||||||
|
conversation.updated_at = datetime.utcnow()
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await session.close()
|
||||||
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"保存消息失败: {e}")
|
print(f"保存消息失败: {e}")
|
||||||
|
|
@@ -262,31 +274,33 @@ class ConversationContextService:
|
||||||
消息历史列表
|
消息历史列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
|
async for session in get_session():
|
||||||
messages = session.query(Message).filter(
|
messages = await session.scalars(
|
||||||
Message.conversation_id == conversation_id
|
select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
|
||||||
).order_by(Message.created_at).all()
|
)
|
||||||
|
messages_list = list(messages)
|
||||||
history = []
|
|
||||||
for message in messages:
|
|
||||||
msg_data = {
|
|
||||||
'id': message.id,
|
|
||||||
'role': message.role,
|
|
||||||
'content': message.content,
|
|
||||||
'timestamp': message.created_at.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.metadata:
|
history = []
|
||||||
try:
|
for message in messages_list:
|
||||||
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
msg_data = {
|
||||||
msg_data['metadata'] = metadata
|
'id': message.id,
|
||||||
except (json.JSONDecodeError, TypeError):
|
'role': message.role,
|
||||||
pass
|
'content': message.content,
|
||||||
|
'timestamp': message.created_at.isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.metadata:
|
||||||
|
try:
|
||||||
|
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
|
||||||
|
msg_data['metadata'] = metadata
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
history.append(msg_data)
|
||||||
|
|
||||||
history.append(msg_data)
|
await session.close()
|
||||||
|
return history
|
||||||
return history
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取对话历史失败: {e}")
|
print(f"获取对话历史失败: {e}")
|
||||||
|
|
|
||||||
|
|
@@ -84,9 +84,8 @@ class DocumentProcessor:
|
||||||
config = None
|
config = None
|
||||||
if session:
|
if session:
|
||||||
config = await llm_config_service.get_default_embedding_config(session)
|
config = await llm_config_service.get_default_embedding_config(session)
|
||||||
if config:
|
if config and session:
|
||||||
if(session != None):
|
session.desc = f"获取默认嵌入模型配置: {config}"
|
||||||
session.desc = f"获取默认嵌入模型配置: {config}"
|
|
||||||
# # 转换配置格式
|
# # 转换配置格式
|
||||||
# config = {
|
# config = {
|
||||||
# "provider": config.provider,
|
# "provider": config.provider,
|
||||||
|
|
@@ -96,39 +95,55 @@ class DocumentProcessor:
|
||||||
|
|
||||||
# 如果未找到配置,使用默认配置
|
# 如果未找到配置,使用默认配置
|
||||||
if not config:
|
if not config:
|
||||||
session.desc = f"ERROR: 未找到嵌入模型配置"
|
if session:
|
||||||
|
session.desc = f"ERROR: 未找到嵌入模型配置"
|
||||||
raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
|
raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
|
||||||
session.desc = f"获取嵌入模型配置 > 结果:{config}"
|
if session:
|
||||||
|
session.desc = f"获取嵌入模型配置 > 结果:{config}"
|
||||||
|
|
||||||
# 根据配置创建嵌入模型
|
# 根据配置创建嵌入模型
|
||||||
if config.provider == "openai":
|
if config.provider == "openai":
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
self.embeddings = OpenAIEmbeddings(
|
self.embeddings = OpenAIEmbeddings(
|
||||||
model=config.get("model", "text-embedding-3-small"),
|
model=config.model_name or "text-embedding-3-small",
|
||||||
api_key=config.get("api_key")
|
api_key=config.api_key
|
||||||
)
|
)
|
||||||
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.get('model', 'text-embedding-3-small')})"
|
if session:
|
||||||
|
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.model_name or 'text-embedding-3-small'})"
|
||||||
|
elif config.provider == "zhipu":
|
||||||
|
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
|
||||||
|
self.embeddings = ZhipuOpenAIEmbeddings(
|
||||||
|
api_key=config.api_key,
|
||||||
|
base_url=config.base_url or "https://open.bigmodel.cn/api/paas/v4",
|
||||||
|
model=config.model_name or "embedding-3",
|
||||||
|
dimensions=settings.vector_db.embedding_dimension
|
||||||
|
)
|
||||||
|
if session:
|
||||||
|
session.desc = f"创建嵌入模型 - ZhipuOpenAIEmbeddings(model={config.model_name or 'embedding-3'}, base_url={config.base_url})"
|
||||||
elif config.provider == "ollama":
|
elif config.provider == "ollama":
|
||||||
from langchain_ollama import OllamaEmbeddings
|
from langchain_ollama import OllamaEmbeddings
|
||||||
self.embeddings = OllamaEmbeddings(
|
self.embeddings = OllamaEmbeddings(
|
||||||
model=config.model_name,
|
model=config.model_name,
|
||||||
base_url=config.base_url
|
base_url=config.base_url
|
||||||
)
|
)
|
||||||
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
|
if session:
|
||||||
|
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
|
||||||
elif config.provider == "local":
|
elif config.provider == "local":
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
self.embeddings = HuggingFaceEmbeddings(
|
self.embeddings = HuggingFaceEmbeddings(
|
||||||
model_name=config.get("model", "sentence-transformers/all-MiniLM-L6-v2")
|
model_name=config.model_name or "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
)
|
)
|
||||||
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.get('model', 'sentence-transformers/all-MiniLM-L6-v2')})"
|
if session:
|
||||||
|
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.model_name or 'sentence-transformers/all-MiniLM-L6-v2'})"
|
||||||
else:
|
else:
|
||||||
# 默认使用OpenAI
|
# 默认使用OpenAI
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
self.embeddings = OpenAIEmbeddings(
|
self.embeddings = OpenAIEmbeddings(
|
||||||
model=config.get("model", "text-embedding-3-small"),
|
model=config.model_name or "text-embedding-3-small",
|
||||||
api_key=config.get("api_key")
|
api_key=config.api_key
|
||||||
)
|
)
|
||||||
session.desc = f"ERROR: 未支持的嵌入提供者: {config['provider']},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"
|
if session:
|
||||||
|
session.desc = f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"
|
||||||
|
|
||||||
return self.embeddings
|
return self.embeddings
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@@ -388,17 +403,19 @@ class DocumentProcessor:
|
||||||
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
|
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
|
||||||
|
|
||||||
# 4. 更新文档状态
|
# 4. 更新文档状态
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
try:
|
async for db_session in get_session():
|
||||||
from sqlalchemy import select
|
try:
|
||||||
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
from sqlalchemy import select
|
||||||
|
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||||
if document:
|
|
||||||
document.status = "processed"
|
if document:
|
||||||
document.chunk_count = len(chunks)
|
document.is_processed = True
|
||||||
await session.commit()
|
document.chunk_count = len(chunks)
|
||||||
finally:
|
await db_session.commit()
|
||||||
await session.close()
|
finally:
|
||||||
|
await db_session.close()
|
||||||
|
break # 只取第一个 session
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"document_id": document_id,
|
"document_id": document_id,
|
||||||
|
|
@@ -416,16 +433,18 @@ class DocumentProcessor:
|
||||||
|
|
||||||
# 更新文档状态为失败
|
# 更新文档状态为失败
|
||||||
try:
|
try:
|
||||||
session = await anext(get_session())
|
# Python 3.9 兼容:使用 async for 替代 anext
|
||||||
try:
|
async for db_session in get_session():
|
||||||
from sqlalchemy import select
|
try:
|
||||||
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
from sqlalchemy import select
|
||||||
if document:
|
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
|
||||||
document.status = "failed"
|
if document:
|
||||||
document.error_message = str(e)
|
document.is_processed = False
|
||||||
await session.commit()
|
document.processing_error = str(e)
|
||||||
finally:
|
await db_session.commit()
|
||||||
await session.close()
|
finally:
|
||||||
|
await db_session.close()
|
||||||
|
break # 只取第一个 session
|
||||||
except Exception as db_error:
|
except Exception as db_error:
|
||||||
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"
|
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue