from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from th_agenter.llm.base_llm import BaseLLM

class LocalLLM(BaseLLM):
    """Chat-model adapter around a local LlamaCpp model.

    Loads the model lazily on first generation and converts LangChain chat
    messages into the plain-text prompt format that LlamaCpp (a non-chat
    completion model) expects.
    """

    def __init__(self, config):
        """Store the local-model configuration; model loading is deferred.

        Args:
            config: configuration object; expected to expose ``model_path``,
                ``temperature``, ``max_tokens``, ``n_ctx`` and ``n_threads``
                (read in :meth:`load_model`).
        """
        super().__init__(config)
        self.local_config = config
        # Bug fix: _generate/_agenerate test `self.model` before lazy-loading,
        # so the attribute must exist even if the base class did not set it.
        # getattr preserves anything the base __init__ may have assigned.
        self.model = getattr(self, "model", None)

    def _validate_config(self):
        """Raise ValueError when the mandatory ``model_path`` is missing."""
        if not self.local_config.model_path:
            raise ValueError("LocalLLM 必须配置 model_path")

    def load_model(self):
        """Instantiate the LlamaCpp backend from the stored configuration."""
        # Imported lazily so langchain_community is only required when a
        # local model is actually used.
        from langchain_community.llms import LlamaCpp
        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        """LangChain model-type identifier."""
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronously generate a completion for *messages*.

        Lazily loads the model on first use, then adapts the chat messages
        to LlamaCpp's (non-chat) plain-prompt calling convention.
        """
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap the raw completion text in LangChain's standard ChatResult.
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async variant of :meth:`_generate` using ``ainvoke``."""
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Format a LangChain message list into the local model's prompt.

        Human turns are wrapped in Llama-style ``<s>[INST] ... [/INST]``
        markers; AI turns are appended verbatim. Other message types
        (e.g. SystemMessage) are silently dropped — NOTE(review): confirm
        this is intentional.
        """
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)