hyf-backend/th_agenter/llm/local/local_llm.py

from typing import Any, List, Optional

from th_agenter.llm.base_llm import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.outputs import ChatResult, ChatGeneration


class LocalLLM(BaseLLM):
    def __init__(self, config):
        super().__init__(config)
        self.local_config = config
        self.model = None  # loaded lazily on the first generate call

    def _validate_config(self):
        if not self.local_config.model_path:
            raise ValueError("LocalLLM requires model_path to be configured")

    def load_model(self):
        # Imported lazily so llama-cpp-python is only required when a local
        # model is actually used.
        from langchain_community.llms import LlamaCpp

        self.model = LlamaCpp(
            model_path=self.local_config.model_path,
            temperature=self.local_config.temperature,
            max_tokens=self.local_config.max_tokens,
            n_ctx=self.local_config.n_ctx,
            n_threads=self.local_config.n_threads,
            verbose=False,
        )

    @property
    def _llm_type(self) -> str:
        return "llama"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self.model:
            self.load_model()
        # LlamaCpp is a completion-style (non-chat) model, so flatten the
        # chat messages into a single prompt string first.
        prompt = self._format_messages(messages)
        text = self.model.invoke(prompt, stop=stop, **kwargs)
        # Wrap the raw completion in a ChatResult (LangChain's standard format).
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if not self.model:
            self.load_model()
        prompt = self._format_messages(messages)
        text = await self.model.ainvoke(prompt, stop=stop, **kwargs)
        generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(generations=[generation])

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Format a LangChain message list into a prompt for the local model."""
        prompt_parts = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                # Llama-2-style instruction template.
                prompt_parts.append(f"<s>[INST] {msg.content} [/INST]")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(msg.content)
        return "".join(prompt_parts)