hyf-backend/th_agenter/services/embedding_factory.py

86 lines
3.5 KiB
Python

"""Embedding factory for different providers."""
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from requests import Session
from .zhipu_embeddings import ZhipuOpenAIEmbeddings
from ..core.config import settings
from loguru import logger
class EmbeddingFactory:
    """Factory class for creating embedding instances based on provider."""

    # Providers that are served through an OpenAI-compatible embeddings API.
    _OPENAI_COMPATIBLE_PROVIDERS = ("zhipu", "deepseek", "doubao", "moonshot")

    @staticmethod
    async def create_embeddings(
        session: Optional[Session] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        dimensions: Optional[int] = None,
    ) -> Embeddings:
        """Create an embeddings instance for the requested provider.

        Args:
            session: Passed through to ``settings.embedding.get_current_config``.
                When given, its ``desc`` attribute is set to a description of
                this operation. May be ``None``.
                NOTE(review): annotated with ``requests.Session`` (from the
                module imports), which looks wrong for an object carrying
                ``desc`` -- confirm the real session type.
            provider: Embedding provider (openai, zhipu, deepseek, doubao,
                moonshot, sentence-transformers). Defaults to
                ``settings.embedding.provider``.
            model: Model name. Defaults to the ``"model"`` entry of the current
                embedding config.
            dimensions: Embedding dimensions. Defaults to
                ``settings.vector_db.embedding_dimension``.

        Returns:
            An ``Embeddings`` instance for the selected provider.

        Raises:
            ValueError: If the provider is unsupported, or if no model is
                available for a provider that has no fallback model
                (deepseek, doubao, moonshot, sentence-transformers).
        """
        # Use the new embedding configuration.
        embedding_config = await settings.embedding.get_current_config(session)
        provider = provider or settings.embedding.provider
        model = model or embedding_config.get("model")
        dimensions = dimensions or settings.vector_db.embedding_dimension

        # `session` defaults to None; only annotate it when one was supplied,
        # otherwise calls without a session would fail with AttributeError.
        if session is not None:
            session.desc = f"创建嵌入模型: {provider}, {model}"

        if provider == "openai":
            return EmbeddingFactory._create_openai_embeddings(embedding_config, model, dimensions)
        if provider in EmbeddingFactory._OPENAI_COMPATIBLE_PROVIDERS:
            return EmbeddingFactory._create_openai_compatible_embeddings(
                embedding_config, model, dimensions, provider
            )
        if provider == "sentence-transformers":
            return EmbeddingFactory._create_huggingface_embeddings(model)
        raise ValueError(f"Unsupported embedding provider: {provider}")

    @staticmethod
    def _create_openai_embeddings(
        embedding_config: dict, model: Optional[str], dimensions: int
    ) -> OpenAIEmbeddings:
        """Create OpenAI embeddings.

        Any model whose name does not start with ``text-embedding`` (including
        a missing/empty one) is replaced by ``text-embedding-ada-002``.
        ``dimensions`` is forwarded only for ``text-embedding-3*`` models;
        every other model gets ``dimensions=None``.

        Args:
            embedding_config: Must contain ``"api_key"`` and ``"base_url"``.
            model: Requested model name, or ``None``.
            dimensions: Requested embedding size.
        """
        # Treat a missing model like any other non-matching name so the
        # ada-002 fallback applies instead of crashing on None.startswith.
        model = model or ""
        return OpenAIEmbeddings(
            api_key=embedding_config["api_key"],
            base_url=embedding_config["base_url"],
            model=model if model.startswith("text-embedding") else "text-embedding-ada-002",
            # The OpenAI API only accepts `dimensions` for text-embedding-3 models.
            dimensions=dimensions if model.startswith("text-embedding-3") else None,
        )

    @staticmethod
    def _create_openai_compatible_embeddings(
        embedding_config: dict, model: Optional[str], dimensions: int, provider: str
    ) -> Embeddings:
        """Create OpenAI-compatible embeddings for ZhipuAI, DeepSeek, Doubao, Moonshot.

        For ``zhipu``, any model not named ``embedding*`` (including a missing
        one) falls back to ``embedding-3``. The other providers have no
        fallback model, so a model name is required.

        Args:
            embedding_config: Must contain ``"api_key"`` and ``"base_url"``.
            model: Requested model name, or ``None``.
            dimensions: Requested embedding size, forwarded as-is.
            provider: One of ``zhipu``, ``deepseek``, ``doubao``, ``moonshot``.

        Raises:
            ValueError: If ``provider`` is not ``zhipu`` and ``model`` is empty.
        """
        if provider == "zhipu":
            return ZhipuOpenAIEmbeddings(
                api_key=embedding_config["api_key"],
                base_url=embedding_config["base_url"],
                model=model if model and model.startswith("embedding") else "embedding-3",
                dimensions=dimensions,
            )
        if not model:
            raise ValueError(f"No embedding model configured for provider: {provider}")
        return OpenAIEmbeddings(
            api_key=embedding_config["api_key"],
            base_url=embedding_config["base_url"],
            model=model,
            dimensions=dimensions,
        )

    @staticmethod
    def _create_huggingface_embeddings(model: Optional[str]) -> HuggingFaceEmbeddings:
        """Create HuggingFace embeddings on CPU with normalized output vectors.

        Args:
            model: HuggingFace model name passed as ``model_name``.

        Raises:
            ValueError: If ``model`` is empty; silently falling back to a
                library default model could produce vectors whose size does
                not match the configured vector store.
        """
        if not model:
            raise ValueError("No embedding model configured for provider: sentence-transformers")
        return HuggingFaceEmbeddings(
            model_name=model,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )