from typing import List

from langchain_core.embeddings import Embeddings
from loguru import logger

from th_agenter.llm.base_llm import BaseLLM


class EmbedLLM(BaseLLM, Embeddings):
    """Base class for embedding models.

    Inherits LangChain's ``Embeddings`` abstract class (rather than
    ``BaseLanguageModel``), since an embedding model exposes text
    vectorization instead of text generation.

    Subclasses must override the four ``embed_*`` methods; calling them
    on the base class raises ``NotImplementedError`` instead of silently
    returning ``None`` (the original ``pass`` bodies violated the
    declared return types).
    """

    def __init__(self, config):
        # config is a project config object; at minimum it carries
        # model_name (read here for logging and by BaseLLM.__init__).
        logger.info(f"初始化 EmbedLLM 模型: {config.model_name}")
        super().__init__(config)
        logger.info(f"已加载 EmbedLLM 模型: {config.model_name}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Batch document vectorization (core method required by LangChain)."""
        raise NotImplementedError

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch document vectorization."""
        raise NotImplementedError

    def embed_query(self, text: str) -> List[float]:
        """Vectorize a single query text."""
        raise NotImplementedError

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query vectorization."""
        raise NotImplementedError


# Concrete implementation for BGE-style embedding models.
class BGEEmbedLLM(EmbedLLM):
    """BGE embedding model backed by Ollama or a local HuggingFace model.

    The backend is selected by ``config.provider``: ``'ollama'`` uses
    ``langchain_ollama.OllamaEmbeddings``; anything else falls back to
    ``langchain_huggingface.HuggingFaceEmbeddings``. The underlying
    model is loaded lazily on first use.
    """

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        """Raise if the mandatory ``model_name`` setting is missing/empty."""
        if not self.config.model_name:
            raise ValueError("必须配置 model_name")

    def load_model(self):
        """Instantiate the backend embedding model and store it on ``self.model``.

        Raises:
            ImportError: if the HuggingFace backend is selected but
                ``sentence-transformers`` is not installed.
        """
        logger.info(f"正在加载 嵌入 模型: {self.config.model_name}")
        # getattr with a default replaces the original hasattr chains;
        # semantics are identical (missing attribute -> fallback value).
        if getattr(self.config, 'provider', None) == 'ollama':
            from langchain_ollama import OllamaEmbeddings
            self.model = OllamaEmbeddings(
                model=self.config.model_name,
                base_url=getattr(self.config, 'base_url', None),
            )
        else:
            try:
                from langchain_huggingface import HuggingFaceEmbeddings
                self.model = HuggingFaceEmbeddings(
                    model_name=self.config.model_name,
                    model_kwargs={
                        "device": getattr(self.config, 'device', 'cpu'),
                    },
                    encode_kwargs={
                        "normalize_embeddings": getattr(
                            self.config, 'normalize_embeddings', True
                        ),
                    },
                )
            except ImportError as e:
                # Surface the missing optional dependency with an
                # actionable hint, then re-raise for the caller.
                logger.error(f"Failed to load HuggingFaceEmbeddings: {e}")
                logger.error("Please install sentence-transformers: pip install sentence-transformers")
                raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Vectorize a batch of documents, lazily loading the model first.

        NOTE(review): assumes ``self.model`` is pre-initialized (presumably
        to None by BaseLLM) — confirm against BaseLLM.__init__.
        """
        if not self.model:
            self.load_model()
        return self.model.embed_documents(texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async batch document vectorization with lazy model loading."""
        if not self.model:
            self.load_model()
        return await self.model.aembed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Vectorize a single query text with lazy model loading."""
        if not self.model:
            self.load_model()
        return self.model.embed_query(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Async single-query vectorization with lazy model loading."""
        if not self.model:
            self.load_model()
        return await self.model.aembed_query(text)