199 lines
9.4 KiB
Python
199 lines
9.4 KiB
Python
from loguru import logger
|
||
from typing import List, Dict, Optional, Union, AsyncGenerator, Generator, Any
|
||
|
||
# 核心:导入 LangChain 的基础语言模型抽象类
|
||
from langchain_core.language_models import BaseLanguageModel
|
||
from langchain_core.language_models.chat_models import BaseChatModel
|
||
from langchain_core.messages import BaseMessage
|
||
from langchain_core.outputs import ChatResult
|
||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||
from dataclasses import dataclass, field
|
||
from typing import Optional, Dict, Any, List
|
||
from datetime import datetime
|
||
|
||
@dataclass
class LLMConfig_DataClass:
    """
    Unified LLM configuration base class covering online / local / embedding
    models; maps the full set of database fields.

    Model type is distinguished by provider + is_embedding:
    - online:    provider in ['openai', 'zhipu', 'baidu'] and is_embedding=False
    - local:     provider in ['llama', 'qwen', 'yi'] and is_embedding=False
    - embedding: provider in ['bge', 'text2vec'] and is_embedding=True
    """

    # Shared provider groups (unannotated class attributes, NOT dataclass
    # fields). Hoisted here so validation and get_model_type cannot drift.
    _ONLINE_PROVIDERS = ('openai', 'zhipu', 'baidu', 'anthropic')
    _LOCAL_PROVIDERS = ('llama', 'qwen', 'yi', 'glm', 'mistral')

    # ====================== Core database fields (required / optional) ======================
    # Basic identification
    name: str                                  # custom model name (e.g. "gpt-5")
    model_name: str                            # official model id (e.g. "gpt-5", "BAAI/bge-small-zh-v1.5")
    provider: str                              # provider key (openai/llama/bge/zhipu, ...)
    id: Optional[int] = None                   # database primary key
    description: Optional[str] = None          # model description
    is_active: bool = True                     # whether the model is enabled
    is_default: bool = False                   # whether this is the default model
    is_embedding: bool = False                 # embedding-model flag (core discriminator)

    # ====================== Generic generation parameters (all inference models) ======================
    temperature: float = 0.7                   # sampling temperature (matches DB default)
    max_tokens: int = 3000                     # max generation length (matches DB default)
    top_p: float = 0.6                         # nucleus-sampling top-p
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # ====================== Online-model-only parameters ======================
    api_key: Optional[str] = None              # API key (required for online models)
    base_url: Optional[str] = None             # API proxy base URL (e.g. https://api.openai-proxy.org/v1)
    # timeout: int = 30                        # request timeout in seconds (currently disabled)
    max_retries: int = 3                       # maximum retry count
    api_version: Optional[str] = None          # API version (e.g. OpenAI "2024-02-15-preview")

    # ====================== Local-model-only parameters ======================
    model_path: Optional[str] = None           # local model file path (required for local models)
    device: str = "cpu"                        # execution device (cpu/cuda/mps)
    n_ctx: int = 2048                          # context window size
    n_threads: int = 8                         # inference thread count
    quantization: str = "q4_0"                 # quantization level (q4_0/q8_0/f16)
    load_in_8bit: bool = False                 # load weights in 8-bit
    load_in_4bit: bool = False                 # load weights in 4-bit
    prompt_template: Optional[str] = None      # custom prompt template

    # ====================== Embedding-model-only parameters ======================
    normalize_embeddings: bool = True          # L2-normalize output vectors
    batch_size: int = 32                       # batch size for encoding
    encode_kwargs: Dict[str, Any] = field(default_factory=dict)  # extra encode options
    dimension: Optional[int] = None            # vector dimension (e.g. 768)

    # ====================== Metadata fields (maintained by the database) ======================
    extra_config: Dict[str, Any] = field(default_factory=dict)  # extra extension config
    usage_count: int = 0                       # usage counter
    last_used_at: Optional[datetime] = None    # last-used timestamp
    created_at: Optional[datetime] = None      # creation timestamp
    updated_at: Optional[datetime] = None      # update timestamp
    created_by: Optional[int] = None           # creator user id
    updated_by: Optional[int] = None           # updater user id

    api_key_masked: Optional[str] = ""         # masked API key as stored in the database

    # ====================== Core helper methods ======================
    def __post_init__(self) -> None:
        """Post-init hook: auto-correct and validate the configuration."""
        # 1. Embedding models must not carry inference parameters; zero them
        #    out so they cannot be misused downstream.
        if self.is_embedding:
            self.max_tokens = 0
            self.temperature = 0.0
            self.top_p = 0.0

        # 2. Validate required fields for the detected model type.
        self._validate_required_fields()

    def _validate_required_fields(self) -> None:
        """Validate required fields per model type; raises ValueError on failure."""
        # Embedding models need neither api_key nor model_path.
        if self.is_embedding:
            return

        # Online models must carry an API key.
        if self.provider in self._ONLINE_PROVIDERS and not self.api_key:
            raise ValueError(f"[{self.name}] 在线模型({self.provider})必须配置 api_key")

        # Local models must point at a model file.
        if self.provider in self._LOCAL_PROVIDERS and not self.model_path:
            raise ValueError(f"[{self.name}] 本地模型({self.provider})必须配置 model_path")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict for database insert/update."""
        return {
            key: value for key, value in self.__dict__.items()
            if not key.startswith('_')  # exclude private attributes
        }

    @classmethod
    def from_db_dict(cls, db_dict: Dict[str, Any]) -> "LLMConfig_DataClass":
        """Build a config instance from a raw database row dict.

        The caller's dict is NOT mutated (the previous version rewrote its
        time fields in place). Unknown keys are dropped; note that
        ``api_key_masked`` IS a dataclass field and is therefore kept.
        """
        # Shallow copy so time-field conversion does not leak back to the caller.
        data = dict(db_dict)

        # 1. Convert ISO-8601 strings to datetime; unparsable values become None.
        for field_name in ('last_used_at', 'created_at', 'updated_at'):
            val = data.get(field_name)
            if val and isinstance(val, str):
                try:
                    data[field_name] = datetime.fromisoformat(val.replace('Z', '+00:00'))
                except (ValueError, TypeError):
                    data[field_name] = None

        # 2. Keep only keys that are actual dataclass fields.
        valid_fields = cls.__dataclass_fields__.keys()
        filtered = {k: v for k, v in data.items() if k in valid_fields}

        # 3. Instantiate (this runs __post_init__ validation).
        return cls(**filtered)

    def get_model_type(self) -> str:
        """Return the model category: "online" / "local" / "embedding" / "unknown"."""
        if self.is_embedding:
            return "embedding"
        if self.provider in self._ONLINE_PROVIDERS:
            return "online"
        if self.provider in self._LOCAL_PROVIDERS:
            return "local"
        return "unknown"
|
||
|
||
|
||
class BaseLLM(BaseChatModel):
    """
    Base chat model inheriting LangChain's BaseChatModel (a subclass of
    BaseLanguageModel) so concrete subclasses plug directly into create_agent.

    Subclasses must override _generate / _agenerate / load_model and may
    override _validate_config for provider-specific checks.
    """
    # Configuration object, set in __init__; expected to expose model_name.
    config: Any = None
    # Underlying model handle; populated by load_model(), cleared by close().
    model: Any = None

    def __init__(self, config):
        super().__init__()  # required: initializes the BaseChatModel machinery
        self.config = config
        self.model = None
        self._validate_config()
        logger.info(f"初始化 {self.__class__.__name__},模型: {config.model_name}")

    def _validate_config(self) -> None:
        """Validate self.config; no-op by default, override in subclasses.

        Defined here because __init__ always calls it — previously it was
        never defined, so constructing any subclass without its own
        _validate_config raised AttributeError.
        """

    # ---------------------- Required LangChain protocol methods ----------------------
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Core synchronous generation (LangChain requires an implementation).

        messages: chat messages, e.g. [HumanMessage(content="hello")]
        Returns a ChatResult (LangChain's standard output type).
        """
        logger.error(f"{self.__class__.__name__} 未实现 同步 _generate 方法")
        # Silently returning None breaks LangChain callers expecting a
        # ChatResult — fail loudly instead.
        raise NotImplementedError(f"{self.__class__.__name__}._generate")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous generation (LangChain async protocol)."""
        logger.error(f"{self.__class__.__name__} 未实现 异步 _agenerate 方法")
        raise NotImplementedError(f"{self.__class__.__name__}._agenerate")

    @property
    def _llm_type(self) -> str:
        """Return the model type identifier (e.g. "openai", "llama", "bge")."""
        return self.__class__.__name__

    def load_model(self) -> None:
        """Load the underlying model (subclass-specific logic)."""
        logger.error(f"{self.__class__.__name__} 未实现 load_model 方法")
        # __enter__ relies on this actually loading something; make the
        # missing implementation an explicit error instead of a silent no-op.
        raise NotImplementedError(f"{self.__class__.__name__}.load_model")

    def close(self) -> None:
        """Release resources held by the loaded model, if any."""
        if self.model:
            logger.info(f"释放 {self.__class__.__name__} 模型资源")
            self.model = None

    def __enter__(self):
        # Context-manager entry: load the model and hand back the instance.
        self.load_model()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release resources, even when the body raised.
        self.close()
|