hxf/backend/th_agenter/llm/llm_model_base.py

70 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os, dotenv
from loguru import logger
from utils.Constant import Constant
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
# 加载环境变量
dotenv.load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["OPENAI_BASE_URL"] = os.getenv("OPENAI_BASE_URL")
class LLM_Model_Base(object):
'''
语言模型基类
所有语言模型类的基类,定义了语言模型的基本属性和方法。
- 语言模型名称, 缺省为"gpt-4o-mini"
- 温度缺省为0.7
- 语言模型实例, 由子类实现
- 语言模型模式, 由子类实现
- 语言模型名称, 用于描述语言模型, 在人机界面中显示
author: DrGraph
date: 2025-11-20
'''
def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0.7):
self.model_name = model_name # 0.15 0.6
self.temperature = temperature
self.llmModel = None
self.mode = Constant.LLM_MODE_NONE
self.name = '未知模型'
def buildPromptTemplateValue(self, prompt: str, methodType: str, valueType: str):
logger.info(f"{self.name} >>> 1.1 用户输入: {type(prompt)}")
prompt_template = PromptTemplate.from_template(
template="请回答以下问题: {question}",
)
prompt_template_value = None
if methodType == "format":
# 方式1 - 使用format方法取得字符串
prompt_str = prompt_template.format(question=prompt) # prompt 为 字符串
logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 format 方法取得字符串prompt_str, 然后再处理 - {type(prompt_str)} - {prompt_str}")
if valueType == "str":
# 1.1 直接用字符串进行调用LLM的invoke
prompt_template_value = prompt_str
logger.info(f"{self.name} >>> 1.2.1 直接使用字符串")
elif valueType == "messages":
# 1.2 由字符串创建HumanMessage对象列表
prompt_template_value = [HumanMessage(content=prompt)]
logger.info(f"{self.name} >>> 1.2.2 创建HumanMessage对象列表")
elif methodType == "invoke":
# 方式2 - 使用invoke方法取得PromptValue
prompt_value = prompt_template.invoke(input={"question" : prompt}) # prompt 为 langchain_core.prompt_values.StringPromptValue
logger.info(f"{self.name} >>> 1.2 通过PromptTemplate实例 invoke 方法取得PromptValue, 然后再处理 - {type(prompt_value)} - {prompt_value}")
if valueType == "str":
# 2.1 再倒回字符串方式
prompt_template_value = prompt_value.to_string()
logger.info(f"{self.name} >>> 1.2.1 由 PromptValue 转换为字符串")
elif valueType == "promptValue":
# 2.2 直接使用 prompt_value 作为 prompt_template_value
prompt_template_value = prompt_value
logger.info(f"{self.name} >>> 1.2.2 直接使用 PromptValue 作为 prompt_template_value")
elif valueType == "messages":
# 2.3 使用 prompt_value.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表
prompt_template_value = prompt_value.to_messages()
logger.info(f"{self.name} >>> 1.2.3 使用 PromptValue.to_messages() 方法,将 PromptValue 转换为 HumanMessage 对象列表")
logger.info(f"{self.name} >>> 1.3 用户输入 最终包装为(PromptValue/str/list of BaseMessages): {type(prompt_template_value)}\n{prompt_template_value}")
return prompt_template_value