hxf/backend/th_agenter/llm/base_llm.py

199 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from loguru import logger
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
# 核心:导入 LangChain 的基础语言模型抽象类
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.callbacks import CallbackManagerForLLMRun
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from datetime import datetime
@dataclass
class LLMConfig_DataClass:
    """
    Unified LLM configuration covering online / local / embedding models,
    mirroring the full set of database columns.

    Model type is derived from ``provider`` + ``is_embedding``:
      - online:    provider in ('openai', 'zhipu', 'baidu', 'anthropic'), is_embedding=False
      - local:     provider in ('llama', 'qwen', 'yi', 'glm', 'mistral'), is_embedding=False
      - embedding: is_embedding=True (e.g. provider 'bge', 'text2vec')
    """

    # Provider groups shared by validation and type detection. Plain class
    # attributes WITHOUT annotations, so the dataclass machinery does not
    # treat them as fields.
    _ONLINE_PROVIDERS = ('openai', 'zhipu', 'baidu', 'anthropic')
    _LOCAL_PROVIDERS = ('llama', 'qwen', 'yi', 'glm', 'mistral')

    # ================= core database fields (required / optional) =================
    name: str                                   # user-defined display name (e.g. "gpt-5")
    model_name: str                             # official model identifier (e.g. "gpt-5", "BAAI/bge-small-zh-v1.5")
    provider: str                               # provider key (openai/llama/bge/zhipu, ...)
    id: Optional[int] = None                    # database primary key
    description: Optional[str] = None           # human-readable description
    is_active: bool = True                      # whether the model is enabled
    is_default: bool = False                    # whether this is the default model
    is_embedding: bool = False                  # embedding-model flag (drives type detection)
    # ================= generic generation parameters (chat models) ================
    temperature: float = 0.7                    # sampling temperature (matches DB default)
    max_tokens: int = 3000                      # max generated tokens (matches DB default)
    top_p: float = 0.6                          # nucleus-sampling top-p
    frequency_penalty: float = 0.0              # frequency penalty
    presence_penalty: float = 0.0               # presence penalty
    # ================= online-model-only parameters ===============================
    api_key: Optional[str] = None               # API key (required for online providers)
    base_url: Optional[str] = None              # API proxy base URL (e.g. https://api.openai-proxy.org/v1)
    # timeout: int = 30                         # request timeout in seconds (kept disabled, as in DB schema)
    max_retries: int = 3                        # max retry attempts
    api_version: Optional[str] = None           # API version (e.g. OpenAI "2024-02-15-preview")
    # ================= local-model-only parameters ================================
    model_path: Optional[str] = None            # local model file path (required for local providers)
    device: str = "cpu"                         # execution device (cpu/cuda/mps)
    n_ctx: int = 2048                           # context window size
    n_threads: int = 8                          # inference thread count
    quantization: str = "q4_0"                  # quantization level (q4_0/q8_0/f16)
    load_in_8bit: bool = False                  # load weights in 8-bit
    load_in_4bit: bool = False                  # load weights in 4-bit
    prompt_template: Optional[str] = None       # custom prompt template
    # ================= embedding-model-only parameters ============================
    normalize_embeddings: bool = True           # L2-normalize output vectors
    batch_size: int = 32                        # batch size for encoding
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)  # extra encode() kwargs
    dimension: Optional[int] = None             # embedding dimension (e.g. 768)
    # ================= metadata (maintained by the database) ======================
    extra_config: Dict[str, Any] = field(default_factory=dict)   # free-form extension config
    usage_count: int = 0                        # number of times used
    last_used_at: Optional[datetime] = None     # last-used timestamp
    created_at: Optional[datetime] = None       # creation timestamp
    updated_at: Optional[datetime] = None       # last-update timestamp
    created_by: Optional[int] = None            # creator user id
    updated_by: Optional[int] = None            # last editor user id
    api_key_masked: Optional[str] = ""          # masked API key as stored in the DB

    # ============================== utilities =====================================
    def __post_init__(self) -> None:
        """Post-init hook: normalize and validate the configuration."""
        # 1. Embedding models never generate text — force-zero the sampling
        #    parameters so they cannot be used by mistake.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0
        # 2. Validate required fields according to the detected model type.
        self._validate_required_fields()

    def _validate_required_fields(self) -> None:
        """Validate required fields per model type.

        Raises:
            ValueError: if an online model lacks ``api_key`` or a local model
                lacks ``model_path``.
        """
        if not self.is_embedding and self.provider in self._ONLINE_PROVIDERS:
            if not self.api_key:
                raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
        if not self.is_embedding and self.provider in self._LOCAL_PROVIDERS:
            if not self.model_path:
                raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of the config (for inserting/updating the DB).

        Private (underscore-prefixed) attributes are excluded.
        """
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config instance from a raw database row dict.

        The input dict is NOT mutated (a shallow copy is taken before the
        time-field conversion — the original implementation wrote converted
        datetimes back into the caller's dict).
        """
        # Work on a copy so the caller's dict is left untouched.
        data = dict(db_dict)
        # 1. Convert ISO-8601 time strings -> datetime (unparseable -> None).
        for field_name in ('last_used_at', 'created_at', 'updated_at'):
            val = data.get(field_name)
            if val and isinstance(val, str):
                try:
                    data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    data[field_name] = None
        # 2. Drop keys that are not declared dataclass fields (e.g. extra
        #    DB-only columns), so cls(**...) cannot raise on unknown kwargs.
        valid_fields = cls.__dataclass_fields__.keys()
        filtered = {k: v for k, v in data.items() if k in valid_fields}
        # 3. Instantiate (runs __post_init__ normalization/validation).
        return cls(**filtered)

    def get_model_type(self) -> str:
        """Return the model category: "online", "local", "embedding" or "unknown"."""
        if self.is_embedding:
            return "embedding"
        if self.provider in self._ONLINE_PROVIDERS:
            return "online"
        if self.provider in self._LOCAL_PROVIDERS:
            return "local"
        return "unknown"
class BaseLLM(BaseChatModel):
    """
    Base chat-model class inheriting LangChain's BaseChatModel (a subclass of
    BaseLanguageModel), so concrete subclasses can be passed directly to
    create_agent.

    Subclasses must override :meth:`_generate` (and ideally
    :meth:`_agenerate` and :meth:`load_model`).
    """

    # Declared as model fields so assignment in __init__ is accepted by the
    # (pydantic-based) BaseChatModel.
    config: Any = None   # the LLMConfig-style configuration object
    model: Any = None    # the underlying loaded model handle (set by load_model)

    def __init__(self, config):
        """Store the config, validate it, and log initialization.

        Args:
            config: a configuration object exposing at least ``model_name``
                (presumably an LLMConfig_DataClass — confirm at call sites).
        """
        super().__init__()  # required: initializes the BaseChatModel machinery
        self.config = config
        self.model = None
        # NOTE(review): _validate_config is not defined in this class — it is
        # expected to be provided by the subclass (or added here later).
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    # ------------- required LangChain protocol methods (must be overridden) -------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous generation entry point required by LangChain.

        Args:
            messages: conversation messages (e.g. ``[HumanMessage(content=...)]``).
            stop: optional stop sequences.
            run_manager: LangChain callback manager.

        Raises:
            NotImplementedError: always — subclasses must override. (The
            original silently returned None, which violates the declared
            ``ChatResult`` return type and fails far from the real cause.)
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _generate()"
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation entry point (LangChain async protocol).

        Raises:
            NotImplementedError: always — subclasses must override.
        """
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _agenerate()"
        )

    @property
    def _llm_type(self) -> str:
        """Model type identifier required by LangChain (e.g. "openai", "llama")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the underlying model (subclass-specific logic).

        Raises:
            NotImplementedError: always — subclasses must override.
        """
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement load_model()"
        )

    def close(self) -> None:
        """Release model resources; safe to call when no model is loaded."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    # Context-manager support: `with SomeLLM(cfg) as llm:` loads on entry
    # and releases on exit.
    def __enter__(self):
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()