from typing import Any, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI

from th_agenter.llm.base_llm import BaseLLM


class OnlineLLM(BaseLLM):
    """Online (API-backed) chat model that wraps LangChain's ChatOpenAI client."""

    def __init__(self, config):
        super().__init__(config)

    def _validate_config(self):
        if not self.config.api_key:
            raise ValueError("OnlineLLM requires an api_key to be configured")
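
    # NOTE: the attributes read from `config` here and in load_model (api_key,
    # model_name, temperature, max_tokens, base_url) imply a config object
    # shaped roughly like the hypothetical sketch below; the real definition
    # lives elsewhere in th_agenter and may differ:
    #
    #     @dataclass
    #     class LLMConfig:
    #         api_key: str
    #         model_name: str
    #         temperature: float = 0.7
    #         max_tokens: int = 1024
    #         base_url: Optional[str] = None
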
    def load_model(self):
        # Alternative construction via LangChain's generic factory
        # (kept from the original):
        # from langchain.chat_models import init_chat_model
        # self.model = init_chat_model(
        #     self.config.model_name,
        #     self.config.api_key)
        self.model = ChatOpenAI(
            api_key=self.config.api_key,
            model_name=self.config.model_name,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            base_url=self.config.base_url,
        )

    @property
    def _llm_type(self) -> str:
        return "openai"  # identifies the model type

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Delegate to the underlying LangChain model's _generate method."""
        if not self.model:
            self.load_model()
        # Reuse the underlying model's implementation
        return self.model._generate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart: delegate to the underlying model's _agenerate."""
        if not self.model:
            self.load_model()
        return await self.model._agenerate(
            messages=messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    # ---------------------- Custom convenience methods ----------------------
    def generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Convenience method: accepts a plain string prompt or a message list."""
        if isinstance(prompt, str):
            messages = [HumanMessage(content=prompt)]
        else:
            messages = prompt
        result = self._generate(messages, **kwargs)
        return result.generations[0].text

    async def async_generate(self, prompt: Union[str, List[BaseMessage]], **kwargs) -> str:
        """Async convenience method: accepts a plain string prompt or a message list."""
        if isinstance(prompt, str):
            messages = [HumanMessage(content=prompt)]
        else:
            messages = prompt
        result = await self._agenerate(messages, **kwargs)
        return result.generations[0].text
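

# ---------------------- Usage sketch (illustrative only) ----------------------
# A minimal sketch of how OnlineLLM might be driven. SimpleNamespace stands in
# for the real config object from th_agenter (whose definition is not shown
# here), and the model name / base_url values are placeholders, not project
# defaults. The sketch also assumes BaseLLM initializes self.model to None,
# as the lazy-load check in _generate implies.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        api_key="YOUR_API_KEY",     # required; _validate_config raises without it
        model_name="gpt-4o-mini",   # placeholder model name
        temperature=0.7,
        max_tokens=1024,
        base_url=None,              # or an OpenAI-compatible endpoint URL
    )
    llm = OnlineLLM(config)
    print(llm.generate("Hello!"))   # lazy-loads the ChatOpenAI client on first call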