hyf-backend/th_agenter/llm/embed/embed_llm.py

77 lines
3.1 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
from typing import List
from langchain_core.embeddings import Embeddings
from loguru import logger
from th_agenter.llm.base_llm import BaseLLM
class EmbedLLM(BaseLLM, Embeddings):
    """Base class for embedding models.

    Inherits LangChain's ``Embeddings`` abstract class (rather than
    ``BaseLanguageModel``) so instances plug directly into LangChain
    vector stores and retrievers. Subclasses must override the four
    ``embed_*`` methods below.
    """

    def __init__(self, config):
        """Initialize from *config*.

        Args:
            config: Model configuration; must expose ``model_name``
                (read here for logging and presumably by ``BaseLLM`` —
                TODO confirm against BaseLLM).
        """
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed documents (core method required by LangChain).

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        # Fail loudly instead of silently returning None, which would
        # surface as a confusing TypeError far from the real cause.
        raise NotImplementedError("Subclasses must implement embed_documents")

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement aembed_documents")

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement embed_query")

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement aembed_query")
# Concrete BGE embedding model implementation.
class BGEEmbedLLM(EmbedLLM):
    """BGE embedding model.

    Backed either by an Ollama server (``config.provider == 'ollama'``)
    or by local HuggingFace sentence-transformers weights. The backing
    client is created lazily on first embedding call.
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Reject configurations without a model name.

        Raises:
            ValueError: if ``config.model_name`` is empty or missing.
        """
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the backing embeddings client into ``self.model``.

        Ollama is chosen when ``config.provider == 'ollama'``; otherwise a
        local HuggingFace model is loaded (requires sentence-transformers).

        Raises:
            ImportError: if the HuggingFace backend is selected but
                ``langchain_huggingface`` / sentence-transformers is
                not installed (logged, then re-raised).
        """
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        # getattr-with-default replaces the hasattr/ternary chains:
        # same semantics, one lookup.
        if getattr(self.config, "provider", None) == "ollama":
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, "base_url", None),
            )
        else:
            try:
                from langchain_huggingface import HuggingFaceEmbeddings
                self.model = HuggingFaceEmbeddings(
                    model_name=self.config.model_name,
                    model_kwargs={
                        "device": getattr(self.config, "device", "cpu")
                    },
                    encode_kwargs={
                        "normalize_embeddings": getattr(
                            self.config, "normalize_embeddings", True
                        )
                    },
                )
            except ImportError as e:
                logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
                logger.error("Please install sentence-transformers: pip install sentence-transformers")
                raise

    def _ensure_model(self):
        """Lazily load the backing client on first use.

        NOTE(review): getattr guards against ``model`` never having been
        set by BaseLLM.__init__ — confirm BaseLLM initializes it.
        """
        if not getattr(self, "model", None):
            self.load_model()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed *texts*, delegating to the backing client."""
        self._ensure_model()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding, delegating to the backing client."""
        self._ensure_model()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via the backing client."""
        self._ensure_model()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding via the backing client."""
        self._ensure_model()
        return await self.model.aembed_query(text)