69 lines
2.6 KiB
Python
69 lines
2.6 KiB
Python
|
|
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration

from th_agenter.llm.base_llm import BaseLLM
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LocalLLM(BaseLLM):
    """Chat-model wrapper that runs a local GGUF model through LlamaCpp.

    The model is loaded lazily on first generation.  Configuration is read
    from the ``config`` object passed to the constructor; it must provide
    ``model_path``, ``temperature``, ``max_tokens``, ``n_ctx`` and
    ``n_threads`` (only ``model_path`` is validated here).
    """

    def __init__(self, config):
        """Store the local-model config and delegate common setup to BaseLLM.

        Args:
            config: Settings object for the local model (see class docstring).
        """
        super().__init__(config)
        self.local_config = config

    def _validate_config(self):
        """Raise ``ValueError`` if the mandatory ``model_path`` is missing/empty."""
        if not self.local_config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Construct the underlying ``LlamaCpp`` instance and bind it to ``self.model``."""
        # Imported locally so langchain_community is only required when a
        # local model is actually loaded.
        from langchain_community.llms import LlamaCpp

        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        """LangChain model-type identifier used in serialization/telemetry."""
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously generate a chat completion from the local model.

        Args:
            messages: Conversation history as LangChain messages.
            stop: Optional stop sequences forwarded to the model.
            run_manager: Optional LangChain callback manager (unused here).
            **kwargs: Extra arguments forwarded to ``LlamaCpp.invoke``.

        Returns:
            A ``ChatResult`` with a single ``AIMessage`` generation.
        """
        # Lazy-load on first use.  getattr guards against AttributeError in
        # case the base class never initialized `self.model` — NOTE(review):
        # confirm whether BaseLLM sets `model` in its __init__.
        if not getattr(self, "model", None):
            self.load_model()
        # LlamaCpp is a plain completion model (not a chat model), so the
        # message list is flattened into a single prompt string first.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap the raw text in the standard LangChain chat-result shape.
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate``; see that method for details."""
        if not getattr(self, "model", None):
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Flatten LangChain messages into a single prompt string.

        Human turns are wrapped in Llama-2/Mistral-style ``<s>[INST] ... [/INST]``
        markers; AI turns are appended verbatim.  Other message types (e.g.
        SystemMessage) are silently dropped — NOTE(review): confirm that is
        intentional.
        """
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)