feat: 添加 Zhipu 嵌入模型支持并修复 Python 3.9 兼容性问题

- 添加 Zhipu (智谱AI) 嵌入模型支持
- 修复 Python 3.9 兼容性问题(anext -> async for)
- 更新 README.md 添加项目介绍和前端开发指南
- 添加向量数据库配置文档
This commit is contained in:
eason 2026-01-22 16:18:29 +08:00
parent 85d8f49b7a
commit 01070cd44d
4 changed files with 280 additions and 150 deletions

View File

@ -76,3 +76,7 @@
# 启动项目命令 # 启动项目命令
python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload python3 -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload
``` ```
5. 查看日志
```bash
tail -f /tmp/uvicorn.log
```

View File

@ -0,0 +1,93 @@
# 本地向量数据库配置指南
## 当前配置
项目使用 **ChromaDB** 作为本地向量数据库,数据存储在 `./data/chroma/` 目录下。
## 配置说明
### 1. 环境变量配置 (.env)
`.env` 文件中配置以下参数:
```env
# 向量数据库类型(虽然代码中已使用 Chroma但建议设置为 chroma
VECTOR_DB_TYPE=chroma
# 向量数据库存储路径(本地文件系统)
# 相对路径会基于项目根目录
VECTOR_DB_PERSIST_DIRECTORY=./data/chroma
# 集合名称(默认)
VECTOR_DB_COLLECTION_NAME=documents
```
### 2. 目录结构
向量数据库按知识库 ID 组织,每个知识库有独立的目录:
```
data/chroma/
├── kb_1/ # 知识库 1 的向量数据
├── kb_2/ # 知识库 2 的向量数据
├── kb_13/ # 知识库 13 的向量数据
└── ...
```
### 3. 使用方式
向量数据库会在以下场景自动创建:
1. **上传文档时**:如果上传时选择立即处理,会自动创建向量数据库
2. **处理文档时**:调用 `POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process` 接口
### 4. 验证安装
运行以下命令验证向量数据库是否正常工作:
```bash
python3 -c "
import chromadb
from pathlib import Path
# 测试创建本地 Chroma 数据库
test_path = './data/chroma/test_kb'
client = chromadb.PersistentClient(path=test_path)
collection = client.get_or_create_collection(name='test')
print('✅ ChromaDB 本地数据库创建成功')
"
```
### 5. 依赖包
已安装的依赖:
- `chromadb>=1.0.20` - ChromaDB 核心库
- `langchain-chroma>=0.1.0` - LangChain Chroma 集成
### 6. 注意事项
- 向量数据库数据存储在本地文件系统,无需额外服务
- 每个知识库的向量数据独立存储
- 删除知识库时,对应的向量数据目录也会被清理
- 确保 `data/chroma/` 目录有写入权限
## 故障排查
### 问题:向量数据库不存在
**原因**:文档尚未被处理和向量化
**解决**
1. 先调用处理文档接口:`POST /api/knowledge-bases/{kb_id}/documents/{doc_id}/process`
2. 处理完成后,向量数据库会自动创建
### 问题:权限错误
**解决**
```bash
chmod -R 755 data/chroma/
```
### 问题:磁盘空间不足
**解决**:清理不需要的知识库向量数据,或扩展存储空间

View File

@ -2,6 +2,7 @@ from typing import Dict, Any, List, Optional
import json import json
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import select
from th_agenter.models.conversation import Conversation from th_agenter.models.conversation import Conversation
from th_agenter.models.message import Message from th_agenter.models.message import Message
from th_agenter.db.database import get_session from th_agenter.db.database import get_session
@ -27,30 +28,34 @@ class ConversationContextService:
新创建的对话ID 新创建的对话ID
""" """
try: try:
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
async for session in get_session():
conversation = Conversation( try:
user_id=user_id, conversation = Conversation(
title=title, user_id=user_id,
created_at=datetime.utcnow(), title=title,
updated_at=datetime.utcnow() created_at=datetime.utcnow(),
) updated_at=datetime.utcnow()
)
session.add(conversation)
await session.commit() session.add(conversation)
await session.refresh(conversation) await session.commit()
await session.refresh(conversation)
# 初始化对话上下文
self.context_cache[conversation.id] = { # 初始化对话上下文
'conversation_id': conversation.id, self.context_cache[conversation.id] = {
'user_id': user_id, 'conversation_id': conversation.id,
'file_list': [], 'user_id': user_id,
'selected_files': [], 'file_list': [],
'query_history': [], 'selected_files': [],
'created_at': datetime.utcnow().isoformat() 'query_history': [],
} 'created_at': datetime.utcnow().isoformat()
}
return conversation.id
await session.close()
return conversation.id
finally:
break # 只取第一个 session
except Exception as e: except Exception as e:
print(f"创建对话失败: {e}") print(f"创建对话失败: {e}")
@ -74,52 +79,58 @@ class ConversationContextService:
# 从数据库加载 # 从数据库加载
try: try:
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
async for session in get_session():
conversation = session.query(Conversation).filter( try:
Conversation.id == conversation_id conversation = await session.scalar(
).first() select(Conversation).where(Conversation.id == conversation_id)
)
if not conversation:
return None if not conversation:
await session.close()
# 加载消息历史 return None
messages = session.query(Message).filter(
Message.conversation_id == conversation_id # 加载消息历史
).order_by(Message.created_at).all() messages = await session.scalars(
select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
# 重建上下文 )
context = { messages_list = list(messages)
'conversation_id': conversation_id,
'user_id': conversation.user_id, # 重建上下文
'file_list': [], context = {
'selected_files': [], 'conversation_id': conversation_id,
'query_history': [], 'user_id': conversation.user_id,
'created_at': conversation.created_at.isoformat() 'file_list': [],
} 'selected_files': [],
'query_history': [],
# 从消息中提取查询历史 'created_at': conversation.created_at.isoformat()
for message in messages: }
if message.role == 'user':
context['query_history'].append({ # 从消息中提取查询历史
'query': message.content, for message in messages_list:
'timestamp': message.created_at.isoformat() if message.role == 'user':
}) context['query_history'].append({
elif message.role == 'assistant' and message.metadata: 'query': message.content,
# 从助手消息的元数据中提取文件信息 'timestamp': message.created_at.isoformat()
try: })
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata elif message.role == 'assistant' and message.metadata:
if 'selected_files' in metadata: # 从助手消息的元数据中提取文件信息
context['selected_files'] = metadata['selected_files'] try:
if 'file_list' in metadata: metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
context['file_list'] = metadata['file_list'] if 'selected_files' in metadata:
except (json.JSONDecodeError, TypeError): context['selected_files'] = metadata['selected_files']
pass if 'file_list' in metadata:
context['file_list'] = metadata['file_list']
# 缓存上下文 except (json.JSONDecodeError, TypeError):
self.context_cache[conversation_id] = context pass
return context # 缓存上下文
self.context_cache[conversation_id] = context
await session.close()
return context
finally:
break # 只取第一个 session
except Exception as e: except Exception as e:
print(f"获取对话上下文失败: {e}") print(f"获取对话上下文失败: {e}")
@ -194,29 +205,30 @@ class ConversationContextService:
保存是否成功 保存是否成功
""" """
try: try:
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
async for session in get_session():
message = Message( message = Message(
conversation_id=conversation_id, conversation_id=conversation_id,
role=role, role=role,
content=content, content=content,
metadata=json.dumps(metadata) if metadata else None, metadata=json.dumps(metadata) if metadata else None,
created_at=datetime.utcnow() created_at=datetime.utcnow()
) )
session.add(message) session.add(message)
await session.commit()
# 更新对话的最后更新时间
conversation = session.query(Conversation).filter(
Conversation.id == conversation_id
).first()
if conversation:
conversation.updated_at = datetime.utcnow()
await session.commit() await session.commit()
return True # 更新对话的最后更新时间
conversation = await session.scalar(
select(Conversation).where(Conversation.id == conversation_id)
)
if conversation:
conversation.updated_at = datetime.utcnow()
await session.commit()
await session.close()
return True
except Exception as e: except Exception as e:
print(f"保存消息失败: {e}") print(f"保存消息失败: {e}")
@ -262,31 +274,33 @@ class ConversationContextService:
消息历史列表 消息历史列表
""" """
try: try:
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
async for session in get_session():
messages = session.query(Message).filter( messages = await session.scalars(
Message.conversation_id == conversation_id select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at)
).order_by(Message.created_at).all() )
messages_list = list(messages)
history = []
for message in messages:
msg_data = {
'id': message.id,
'role': message.role,
'content': message.content,
'timestamp': message.created_at.isoformat()
}
if message.metadata: history = []
try: for message in messages_list:
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata msg_data = {
msg_data['metadata'] = metadata 'id': message.id,
except (json.JSONDecodeError, TypeError): 'role': message.role,
pass 'content': message.content,
'timestamp': message.created_at.isoformat()
}
if message.metadata:
try:
metadata = json.loads(message.metadata) if isinstance(message.metadata, str) else message.metadata
msg_data['metadata'] = metadata
except (json.JSONDecodeError, TypeError):
pass
history.append(msg_data)
history.append(msg_data) await session.close()
return history
return history
except Exception as e: except Exception as e:
print(f"获取对话历史失败: {e}") print(f"获取对话历史失败: {e}")

View File

@ -84,9 +84,8 @@ class DocumentProcessor:
config = None config = None
if session: if session:
config = await llm_config_service.get_default_embedding_config(session) config = await llm_config_service.get_default_embedding_config(session)
if config: if config and session:
if(session != None): session.desc = f"获取默认嵌入模型配置: {config}"
session.desc = f"获取默认嵌入模型配置: {config}"
# # 转换配置格式 # # 转换配置格式
# config = { # config = {
# "provider": config.provider, # "provider": config.provider,
@ -96,39 +95,55 @@ class DocumentProcessor:
# 如果未找到配置,使用默认配置 # 如果未找到配置,使用默认配置
if not config: if not config:
session.desc = f"ERROR: 未找到嵌入模型配置" if session:
session.desc = f"ERROR: 未找到嵌入模型配置"
raise HTTPException(status_code=400, detail="未找到嵌入模型配置") raise HTTPException(status_code=400, detail="未找到嵌入模型配置")
session.desc = f"获取嵌入模型配置 > 结果:{config}" if session:
session.desc = f"获取嵌入模型配置 > 结果:{config}"
# 根据配置创建嵌入模型 # 根据配置创建嵌入模型
if config.provider == "openai": if config.provider == "openai":
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
self.embeddings = OpenAIEmbeddings( self.embeddings = OpenAIEmbeddings(
model=config.get("model", "text-embedding-3-small"), model=config.model_name or "text-embedding-3-small",
api_key=config.get("api_key") api_key=config.api_key
) )
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.get('model', 'text-embedding-3-small')})" if session:
session.desc = f"创建嵌入模型 - OpenAIEmbeddings(model={config.model_name or 'text-embedding-3-small'})"
elif config.provider == "zhipu":
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
self.embeddings = ZhipuOpenAIEmbeddings(
api_key=config.api_key,
base_url=config.base_url or "https://open.bigmodel.cn/api/paas/v4",
model=config.model_name or "embedding-3",
dimensions=settings.vector_db.embedding_dimension
)
if session:
session.desc = f"创建嵌入模型 - ZhipuOpenAIEmbeddings(model={config.model_name or 'embedding-3'}, base_url={config.base_url})"
elif config.provider == "ollama": elif config.provider == "ollama":
from langchain_ollama import OllamaEmbeddings from langchain_ollama import OllamaEmbeddings
self.embeddings = OllamaEmbeddings( self.embeddings = OllamaEmbeddings(
model=config.model_name, model=config.model_name,
base_url=config.base_url base_url=config.base_url
) )
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})" if session:
session.desc = f"创建嵌入模型 - OllamaEmbeddings({self.embeddings.base_url} - {self.embeddings.model})"
elif config.provider == "local": elif config.provider == "local":
from langchain_huggingface import HuggingFaceEmbeddings from langchain_huggingface import HuggingFaceEmbeddings
self.embeddings = HuggingFaceEmbeddings( self.embeddings = HuggingFaceEmbeddings(
model_name=config.get("model", "sentence-transformers/all-MiniLM-L6-v2") model_name=config.model_name or "sentence-transformers/all-MiniLM-L6-v2"
) )
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.get('model', 'sentence-transformers/all-MiniLM-L6-v2')})" if session:
session.desc = f"创建嵌入模型 - HuggingFaceEmbeddings(model={config.model_name or 'sentence-transformers/all-MiniLM-L6-v2'})"
else: else:
# 默认使用OpenAI # 默认使用OpenAI
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
self.embeddings = OpenAIEmbeddings( self.embeddings = OpenAIEmbeddings(
model=config.get("model", "text-embedding-3-small"), model=config.model_name or "text-embedding-3-small",
api_key=config.get("api_key") api_key=config.api_key
) )
session.desc = f"ERROR: 未支持的嵌入提供者: {config['provider']},已使用默认的 OpenAIEmbeddings - 可能不正确或无效" if session:
session.desc = f"ERROR: 未支持的嵌入提供者: {config.provider},已使用默认的 OpenAIEmbeddings - 可能不正确或无效"
return self.embeddings return self.embeddings
except Exception as e: except Exception as e:
@ -388,17 +403,19 @@ class DocumentProcessor:
self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id) self.add_documents_to_vector_store(session, knowledge_base_id, chunks, document_id)
# 4. 更新文档状态 # 4. 更新文档状态
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
try: async for db_session in get_session():
from sqlalchemy import select try:
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id)) from sqlalchemy import select
document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
if document:
document.status = "processed" if document:
document.chunk_count = len(chunks) document.is_processed = True
await session.commit() document.chunk_count = len(chunks)
finally: await db_session.commit()
await session.close() finally:
await db_session.close()
break # 只取第一个 session
result = { result = {
"document_id": document_id, "document_id": document_id,
@ -416,16 +433,18 @@ class DocumentProcessor:
# 更新文档状态为失败 # 更新文档状态为失败
try: try:
session = await anext(get_session()) # Python 3.9 兼容:使用 async for 替代 anext
try: async for db_session in get_session():
from sqlalchemy import select try:
document = await session.scalar(select(DocumentModel).where(DocumentModel.id == document_id)) from sqlalchemy import select
if document: document = await db_session.scalar(select(DocumentModel).where(DocumentModel.id == document_id))
document.status = "failed" if document:
document.error_message = str(e) document.is_processed = False
await session.commit() document.processing_error = str(e)
finally: await db_session.commit()
await session.close() finally:
await db_session.close()
break # 只取第一个 session
except Exception as db_error: except Exception as db_error:
session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}" session.desc = f"ERROR: 更新文档状态失败: {str(db_error)}"