86 lines
3.5 KiB
Python
86 lines
3.5 KiB
Python
"""Embedding factory for different providers."""
|
|
|
|
from typing import Optional
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_openai import OpenAIEmbeddings
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
from requests import Session
|
|
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
|
|
from ..core.config import settings
|
|
from loguru import logger
|
|
|
|
class EmbeddingFactory:
    """Factory class for creating embedding instances based on provider."""

    # Providers that are reached through an OpenAI-compatible API and are
    # therefore built by ``_create_openai_compatible_embeddings``.
    _OPENAI_COMPATIBLE_PROVIDERS = ("zhipu", "deepseek", "doubao", "moonshot")

    @staticmethod
    async def create_embeddings(
        session: Optional[Session] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        dimensions: Optional[int] = None
    ) -> Embeddings:
        """Create an embeddings instance based on the provider.

        Args:
            session: Passed to ``settings.embedding.get_current_config``. When
                given, its ``desc`` attribute is set to a short description of
                the model being created. May be ``None``.
            provider: Embedding provider (openai, zhipu, deepseek, doubao,
                moonshot, sentence-transformers). Falls back to
                ``settings.embedding.provider`` when not given.
            model: Model name. Falls back to the ``"model"`` entry of the
                current embedding config when not given.
            dimensions: Embedding dimensions. Falls back to
                ``settings.vector_db.embedding_dimension`` when not given.

        Returns:
            Embeddings instance for the selected provider.

        Raises:
            ValueError: If ``provider`` is not a supported provider.
        """
        # Load the current embedding configuration; the helpers below read its
        # "api_key" and "base_url" entries, and "model" is used as a fallback.
        embedding_config = await settings.embedding.get_current_config(session)
        provider = provider or settings.embedding.provider
        model = model or embedding_config.get("model")
        dimensions = dimensions or settings.vector_db.embedding_dimension

        # ``session`` defaults to None, so only annotate it when one was passed.
        if session is not None:
            session.desc = f"创建嵌入模型: {provider}, {model}"

        if provider == "openai":
            return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)
        if provider in EmbeddingFactory._OPENAI_COMPATIBLE_PROVIDERS:
            return EmbeddingFactory._create_openai_compatible_embeddings(
                embedding_config, model, dimensions, provider
            )
        if provider == "sentence-transformers":
            return EmbeddingFactory._create_huggingface_embeddings(model)
        raise ValueError(f"Unsupported embedding provider: {provider}")

    @staticmethod
    def _create_openai_embeddings(
        embedding_config: dict, model: Optional[str], dimensions: int
    ) -> OpenAIEmbeddings:
        """Create OpenAI embeddings.

        Model names that do not start with ``text-embedding`` (including a
        missing model) fall back to ``text-embedding-ada-002``. ``dimensions``
        is forwarded only for ``text-embedding-3*`` models; otherwise ``None``
        is passed.
        """
        # Treat a missing model name as "" so the prefix checks select the fallback.
        name = model or ""
        resolved_model = name if name.startswith("text-embedding") else "text-embedding-ada-002"
        # Per the OpenAI embeddings API, ``dimensions`` is only supported by
        # the text-embedding-3 family.
        resolved_dimensions = dimensions if name.startswith("text-embedding-3") else None
        return OpenAIEmbeddings(
            api_key=embedding_config["api_key"],
            base_url=embedding_config["base_url"],
            model=resolved_model,
            dimensions=resolved_dimensions
        )

    @staticmethod
    def _create_openai_compatible_embeddings(
        embedding_config: dict, model: Optional[str], dimensions: int, provider: str
    ) -> Embeddings:
        """Create OpenAI-compatible embeddings for ZhipuAI, DeepSeek, Doubao, Moonshot.

        For ``zhipu``, model names that do not start with ``embedding``
        (including a missing model) fall back to ``embedding-3``. Other
        providers receive ``model`` unchanged.
        """
        if provider == "zhipu":
            # Treat a missing model name as "" so the prefix check selects the fallback.
            name = model or ""
            return ZhipuOpenAIEmbeddings(
                api_key=embedding_config["api_key"],
                base_url=embedding_config["base_url"],
                model=name if name.startswith("embedding") else "embedding-3",
                dimensions=dimensions
            )
        # NOTE(review): OpenAIEmbeddings checks context length by tokenizing
        # with tiktoken by default; some OpenAI-compatible providers may not
        # accept token-id input or the ``dimensions`` field — confirm per provider.
        return OpenAIEmbeddings(
            api_key=embedding_config["api_key"],
            base_url=embedding_config["base_url"],
            model=model,
            dimensions=dimensions
        )

    @staticmethod
    def _create_huggingface_embeddings(model: Optional[str]) -> HuggingFaceEmbeddings:
        """Create HuggingFace (sentence-transformers) embeddings.

        The model runs on CPU and produces normalized embeddings.
        """
        return HuggingFaceEmbeddings(
            model_name=model,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )