# hxf/backend/th_agenter/llm/base_llm.py
from loguru import logger
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
# 核心:导入 LangChain 的基础语言模型抽象类
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.callbacks import CallbackManagerForLLMRun
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List
from datetime import datetime
@dataclass
class LLMConfig_DataClass:
    """
    Unified LLM configuration covering online / local / embedding models.

    Maps the complete set of database columns. The model kind is determined
    by ``provider`` + ``is_embedding``:
      - online models:    provider in ['openai', 'zhipu', 'baidu', ...] and is_embedding=False
      - local models:     provider in ['llama', 'qwen', 'yi', ...] and is_embedding=False
      - embedding models: provider in ['bge', 'text2vec', ...] and is_embedding=True
    """

    # ====================== Core database fields (required / optional) ======================
    name: str                                   # user-facing model name (e.g. "gpt-5")
    model_name: str                             # official model id (e.g. "BAAI/bge-small-zh-v1.5")
    provider: str                               # provider key (openai/llama/bge/zhipu, ...)
    id: Optional[int] = None                    # database primary key
    description: Optional[str] = None           # model description
    is_active: bool = True                      # whether the model is enabled
    is_default: bool = False                    # whether this is the default model
    is_embedding: bool = False                  # True => embedding model (key discriminator)

    # ====================== Generation parameters (shared by chat models) ======================
    temperature: float = 0.7                    # sampling temperature (matches DB default)
    max_tokens: int = 3000                      # max generated tokens (matches DB default)
    top_p: float = 0.6                          # nucleus sampling top-p
    frequency_penalty: float = 0.0              # frequency penalty
    presence_penalty: float = 0.0               # presence penalty

    # ====================== Online-model-only parameters (optional) ======================
    api_key: Optional[str] = None               # API key (required for online models)
    base_url: Optional[str] = None              # API proxy base URL
    max_retries: int = 3                        # max retry attempts
    api_version: Optional[str] = None           # API version (e.g. "2024-02-15-preview")

    # ====================== Local-model-only parameters (optional) ======================
    model_path: Optional[str] = None            # local weights path (required for local models)
    device: str = "cpu"                         # runtime device: cpu/cuda/mps
    n_ctx: int = 2048                           # context window size
    n_threads: int = 8                          # inference threads
    quantization: str = "q4_0"                  # quantization level: q4_0/q8_0/f16
    load_in_8bit: bool = False                  # load weights in 8-bit
    load_in_4bit: bool = False                  # load weights in 4-bit
    prompt_template: Optional[str] = None       # custom prompt template

    # ====================== Embedding-model-only parameters (optional) ======================
    normalize_embeddings: bool = True           # L2-normalize output vectors
    batch_size: int = 32                        # encoding batch size
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)  # extra encode options
    dimension: Optional[int] = None             # vector dimension (e.g. 768)

    # ====================== Metadata (maintained by the database) ======================
    extra_config: Dict[str, Any] = field(default_factory=dict)  # free-form extras
    usage_count: int = 0                        # times used
    last_used_at: Optional[datetime] = None     # last-used timestamp
    created_at: Optional[datetime] = None       # creation timestamp
    updated_at: Optional[datetime] = None       # update timestamp
    created_by: Optional[int] = None            # creator user id
    updated_by: Optional[int] = None            # updater user id
    api_key_masked: Optional[str] = ""          # masked API key as stored in the DB

    # Provider groups shared by validation and type detection — kept in one
    # place so the two methods cannot silently drift apart.
    _ONLINE_PROVIDERS = ('openai', 'zhipu', 'baidu', 'anthropic')
    _LOCAL_PROVIDERS = ('llama', 'qwen', 'yi', 'glm', 'mistral')

    # ====================== Core utility methods ======================
    def __post_init__(self):
        """Post-init hook: normalize then validate the configuration."""
        # Embedding models never generate text: zero out generation params
        # so they cannot be misused downstream.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0
        # Validate required fields per model kind.
        self._validate_required_fields()

    def _validate_required_fields(self):
        """Validate required fields according to the model kind.

        Raises:
            ValueError: when an online model lacks ``api_key`` or a local
                model lacks ``model_path``.
        """
        if not self.is_embedding and self.provider in self._ONLINE_PROVIDERS:
            if not self.api_key:
                raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")
        if not self.is_embedding and self.provider in self._LOCAL_PROVIDERS:
            if not self.model_path:
                raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (for inserting/updating the DB row)."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')  # drop private attributes
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config from a raw database row dict.

        The input dict is NOT mutated. ISO-8601 time strings are parsed to
        ``datetime`` ('Z' suffix accepted; unparseable values become None),
        and keys that are not dataclass fields are dropped.

        Args:
            db_dict: raw row mapping as returned by the database layer.

        Returns:
            A validated ``LLMConfig_DataClass`` instance.
        """
        # Work on a shallow copy so the caller's dict is never modified
        # (the original implementation wrote converted values back into it).
        data = dict(db_dict)
        for field_name in ('last_used_at', 'created_at', 'updated_at'):
            val = data.get(field_name)
            if val and isinstance(val, str):
                try:
                    data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    data[field_name] = None
        # Keep only keys that are actual dataclass fields.
        valid_fields = cls.__dataclass_fields__.keys()
        return cls(**{k: v for k, v in data.items() if k in valid_fields})

    def get_model_type(self) -> str:
        """Return the model kind: 'embedding', 'online', 'local' or 'unknown'."""
        if self.is_embedding:
            return "embedding"
        if self.provider in self._ONLINE_PROVIDERS:
            return "online"
        if self.provider in self._LOCAL_PROVIDERS:
            return "local"
        return "unknown"
class BaseLLM(BaseChatModel):
    """
    LangChain-compatible chat model base class.

    Inherits ``BaseChatModel`` / ``BaseLanguageModel`` so concrete subclasses
    can be passed directly to LangChain agent factories (e.g. ``create_agent``).
    Subclasses must implement ``_generate`` / ``_agenerate`` and ``load_model``.
    """

    # Declared as model fields so assignment in __init__ is accepted by the
    # pydantic model underlying BaseChatModel.
    config: Any = None  # configuration object (see LLMConfig_DataClass)
    model: Any = None   # underlying model handle; populated by load_model()

    def __init__(self, config):
        """
        Args:
            config: configuration object exposing at least ``model_name``.

        Raises:
            AttributeError: if the concrete subclass does not define
                ``_validate_config`` (it is not provided by this base class).
        """
        super().__init__()  # mandatory: initialize the pydantic base model
        self.config = config
        self.model = None
        # NOTE(review): _validate_config is not defined in this base class —
        # it is expected from subclasses; confirm every subclass provides it.
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    # ---------------------- Required LangChain protocol methods ----------------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Core synchronous generation (required by the LangChain protocol).

        Args:
            messages: chat messages, e.g. ``[HumanMessage(content="你好")]``.
            stop: optional stop sequences.
            run_manager: LangChain callback manager.

        Returns:
            ChatResult: standard LangChain chat output (from subclasses).

        Raises:
            NotImplementedError: always, in this base class. The previous
                behaviour (log and implicitly return None) violated the
                declared return type and crashed callers later with an
                opaque TypeError.
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _generate")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation (LangChain async protocol).

        Raises:
            NotImplementedError: always, in this base class.
        """
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement _agenerate")

    @property
    def _llm_type(self) -> str:
        """Model type identifier (e.g. "openai", "llama", "bge")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the underlying model (subclass-specific logic).

        Raises:
            NotImplementedError: always, in this base class.
        """
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")
        raise NotImplementedError(f"{self.__class__.__name__} must implement load_model")

    def close(self) -> None:
        """Release model resources (safe to call multiple times)."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    def __enter__(self):
        # Context-manager support: load the model on entry ...
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # ... and release resources on exit, even if the body raised.
        self.close()