77 lines
3.1 KiB
Python
77 lines
3.1 KiB
Python
|
|
from typing import List
|
||
|
|
from langchain_core.embeddings import Embeddings
|
||
|
|
from loguru import logger
|
||
|
|
from th_agenter.llm.base_llm import BaseLLM
|
||
|
|
|
||
|
|
class EmbedLLM(BaseLLM, Embeddings):
    """Embedding-model base class.

    Inherits LangChain's ``Embeddings`` abstract interface (rather than
    ``BaseLanguageModel``) on top of the project's ``BaseLLM``.  Subclasses
    must implement the four ``embed_*`` methods below.
    """

    def __init__(self, config):
        """Initialize the embedding model.

        Args:
            config: project configuration object; must expose ``model_name``
                (other attributes depend on the concrete subclass).
        """
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed documents (core method required by LangChain).

        Abstract here: fail loudly instead of silently returning ``None``
        when a subclass forgets to override it.
        """
        raise NotImplementedError

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding; must be overridden by subclasses."""
        raise NotImplementedError

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; must be overridden by subclasses."""
        raise NotImplementedError

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding; must be overridden by subclasses."""
        raise NotImplementedError
|
||
|
|
|
||
|
|
# Concrete implementation: BGE embedding model.
|
||
|
|
class BGEEmbedLLM(EmbedLLM):
    """Concrete BGE embedding model.

    Lazily loads either an Ollama-served embedding model (when
    ``config.provider == 'ollama'``) or a local HuggingFace
    sentence-transformers model on first use, then delegates all
    ``embed_*`` calls to that backend.
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Check that the minimal required configuration is present.

        Raises:
            ValueError: if ``model_name`` is missing or empty.
        """
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the underlying embedding backend into ``self.model``.

        Raises:
            ImportError: re-raised when the HuggingFace backend is selected
                but ``sentence-transformers`` is not installed.
        """
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        # getattr(..., default) replaces the original hasattr ternaries —
        # behaviorally identical, just the idiomatic spelling.
        if getattr(self.config, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings

            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, 'base_url', None),
            )
        else:
            try:
                from langchain_huggingface import HuggingFaceEmbeddings

                self.model = HuggingFaceEmbeddings(
                    model_name=self.config.model_name,
                    model_kwargs={
                        "device": getattr(self.config, 'device', "cpu")
                    },
                    encode_kwargs={
                        "normalize_embeddings": getattr(
                            self.config, 'normalize_embeddings', True
                        )
                    },
                )
            except ImportError as e:
                logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
                logger.error("Please install sentence-transformers: pip install sentence-transformers")
                raise

    def _ensure_model(self):
        """Load the backend on first use (shared by all embed methods).

        NOTE(review): assumes ``self.model`` is initialized (e.g. to None)
        by ``BaseLLM.__init__`` — confirm against the base class.
        """
        if not self.model:
            self.load_model()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch-embed *texts* via the loaded backend."""
        self._ensure_model()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch embedding via the loaded backend."""
        self._ensure_model()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string via the loaded backend."""
        self._ensure_model()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query embedding via the loaded backend."""
        self._ensure_model()
        return await self.model.aembed_query(text)
|