hyf-backend/th_agenter/api/endpoints/llm_configs.py

473 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM configuration management API endpoints."""
from turtle import textinput
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage
from sqlalchemy.orm import Session
from sqlalchemy import or_, select, delete, update
from loguru import logger
from th_agenter.llm.embed.embed_llm import BGEEmbedLLM, EmbedLLM
from th_agenter.llm.online.online_llm import OnlineLLM
from ...db.database import get_session
from ...models.user import User
from ...models.llm_config import LLMConfig
from th_agenter.llm.base_llm import LLMConfig_DataClass
from ...core.simple_permissions import require_super_admin, require_authenticated_user
from ...schemas.llm_config import (
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
LLMConfigTest
)
from th_agenter.services.document_processor import get_document_processor
from utils.util_exceptions import HxfResponse
# Router for LLM-configuration management endpoints, mounted under /llm-configs.
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
async def get_llm_configs(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    provider: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    is_embedding: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """List LLM configurations with optional search, filters and pagination."""
    session.title = "获取大模型配置列表"
    session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
    query = select(LLMConfig)
    # Fuzzy search across name, model name and description.
    if search:
        pattern = f"%{search}%"
        query = query.where(
            or_(
                LLMConfig.name.ilike(pattern),
                LLMConfig.model_name.ilike(pattern),
                LLMConfig.description.ilike(pattern),
            )
        )
    # Optional exact-match filters; tri-state booleans only filter when supplied.
    if provider:
        query = query.where(LLMConfig.provider == provider)
    if is_active is not None:
        query = query.where(LLMConfig.is_active == is_active)
    if is_embedding is not None:
        query = query.where(LLMConfig.is_embedding == is_embedding)
    # Stable name ordering plus offset/limit pagination.
    query = query.order_by(LLMConfig.name).offset(skip).limit(limit)
    result = await session.execute(query)
    configs = result.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置 ..."
    return HxfResponse([cfg.to_dict(include_sensitive=True) for cfg in configs])
@router.get("/providers", summary="获取支持的大模型服务商列表")
async def get_llm_providers(
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Return the distinct, non-empty provider names present in the config table."""
    session.desc = "START: 获取支持的大模型服务商列表"
    result = await session.execute(select(LLMConfig.provider).distinct())
    providers = result.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(providers)} 个大模型服务商"
    # Drop NULL/empty provider values before returning.
    return HxfResponse([p for p in providers if p])
@router.get("/active", response_model=List[LLMConfigResponse], summary="获取所有激活的大模型配置")
async def get_active_llm_configs(
    is_embedding: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """List every active configuration, optionally filtered by model type."""
    session.desc = f"START: 获取所有激活的大模型配置, is_embedding={is_embedding}"
    query = select(LLMConfig).where(LLMConfig.is_active == True)
    # is_embedding is tri-state: None means "both chat and embedding models".
    if is_embedding is not None:
        query = query.where(LLMConfig.is_embedding == is_embedding)
    result = await session.execute(query.order_by(LLMConfig.created_at))
    configs = result.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(configs)} 个激活的大模型配置"
    return HxfResponse([cfg.to_dict(include_sensitive=True) for cfg in configs])
@router.get("/default", response_model=LLMConfigResponse, summary="获取默认大模型配置")
async def get_default_llm_config(
    is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Fetch the active default config for either chat (False) or embedding (True) models.

    Raises 404 when no active default of the requested type exists.
    """
    session.desc = f"START: 获取默认大模型配置, is_embedding={is_embedding}"
    query = select(LLMConfig).where(
        LLMConfig.is_default == True,
        LLMConfig.is_embedding == is_embedding,
        LLMConfig.is_active == True
    )
    config = (await session.execute(query)).scalar_one_or_none()
    if config is None:
        model_type = "嵌入模型" if is_embedding else "对话模型"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到默认{model_type}配置"
        )
    session.desc = f"SUCCESS: 获取默认大模型配置, is_embedding={is_embedding}"
    return HxfResponse(config.to_dict(include_sensitive=True))
@router.get("/{config_id}", response_model=LLMConfigResponse, summary="获取大模型配置详情")
async def get_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Fetch a single configuration by primary key; 404 when it does not exist."""
    result = await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    return HxfResponse(config.to_dict(include_sensitive=True))
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
async def create_llm_config(
    config_data: LLMConfigCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new LLM configuration.

    Enforces a unique name, validates the payload via the LLM config
    dataclass, demotes other defaults of the same model type when
    ``is_default`` is set, then persists and returns the new record.
    """
    # Capture the username up front: reading the attribute after refresh()
    # can trigger a MissingGreenlet error under the async session.
    username = current_user.username
    session.desc = f"START: 创建大模型配置, name={config_data.name}"
    # Reject duplicate configuration names.
    stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
    existing_config = (await session.execute(stmt)).scalar_one_or_none()
    if existing_config:
        session.desc = f"ERROR: 配置名称已存在, name={config_data.name}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="配置名称已存在"
        )
    # Build the field set once — previously these 16 kwargs were duplicated
    # verbatim in two back-to-back constructor calls.
    config_fields = dict(
        name=config_data.name,
        provider=config_data.provider,
        model_name=config_data.model_name,
        api_key=config_data.api_key,
        base_url=config_data.base_url,
        max_tokens=config_data.max_tokens,
        temperature=config_data.temperature,
        top_p=config_data.top_p,
        frequency_penalty=config_data.frequency_penalty,
        presence_penalty=config_data.presence_penalty,
        description=config_data.description,
        is_active=config_data.is_active,
        is_default=config_data.is_default,
        is_embedding=config_data.is_embedding,
        extra_config=config_data.extra_config or {}
    )
    # Validate through the dataclass helper before touching the database.
    validation_result = LLMConfig_DataClass(**config_fields).validate_config()
    if not validation_result['valid']:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=validation_result['error']
        )
    # A new default demotes every other default of the same model type.
    if config_data.is_default:
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == config_data.is_embedding
        ).values({"is_default": False})
        await session.execute(stmt)
    session.desc = f"验证大模型配置, config_data"
    # Persist the ORM entity. The dataclass above is validation-only: passing
    # a non-mapped object to session.add() would raise UnmappedInstanceError.
    # (Assumes the LLMConfig model accepts the same field names — the schema
    # filters on all of them above; TODO confirm extra_config is a column.)
    config = LLMConfig(**config_fields)
    # Audit fields are set automatically by SQLAlchemy event listener
    session.add(config)
    await session.commit()
    await session.refresh(config)
    session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {username}"
    return HxfResponse(config.to_dict())
@router.put("/{config_id}", response_model=LLMConfigResponse, summary="更新大模型配置")
async def update_llm_config(
    config_id: int,
    config_data: LLMConfigUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Apply a partial update to an existing LLM configuration."""
    username = current_user.username
    session.desc = f"START: 更新大模型配置, id={config_id}"
    config = (
        await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Renaming: the new name must not collide with any other config.
    if config_data.name and config_data.name != config.name:
        clash = (
            await session.execute(
                select(LLMConfig).where(
                    LLMConfig.name == config_data.name,
                    LLMConfig.id != config_id
                )
            )
        ).scalar_one_or_none()
        if clash:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="配置名称已存在"
            )
    # Promoting to default demotes every other default of the same model type.
    if config_data.is_default is True:
        # Use the incoming is_embedding when supplied, otherwise the stored value.
        is_embedding = (
            config.is_embedding
            if config_data.is_embedding is None
            else config_data.is_embedding
        )
        await session.execute(
            update(LLMConfig).where(
                LLMConfig.is_embedding == is_embedding,
                LLMConfig.id != config_id
            ).values({"is_default": False})
        )
    # Copy only the fields the client actually sent.
    for field, value in config_data.model_dump(exclude_unset=True).items():
        setattr(config, field, value)
    await session.commit()
    await session.refresh(config)
    session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {username}"
    return HxfResponse(config.to_dict())
@router.delete("/{config_id}", summary="删除大模型配置")
async def delete_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete an LLM configuration by id.

    Returns a JSON confirmation body, so the route must not declare
    204 No Content (the previous decorator did; a 204 response with a
    body is invalid HTTP and fails at the ASGI/h11 layer). A plain 200
    with a message also matches every other endpoint in this router.
    """
    username = current_user.username
    session.desc = f"START: 删除大模型配置, id={config_id}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    session.desc = f"待删除大模型记录 {config.to_dict()}"
    # TODO: check whether any conversation or other feature still references
    # this config before deleting it.
    await session.delete(config)
    await session.commit()
    session.desc = f"SUCCESS: 删除大模型配置成功, id={config_id} by user {username}"
    return HxfResponse({"message": "LLM config deleted successfully"})
@router.post("/{config_id}/test", summary="测试连接大模型配置")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Connectivity-test an LLM configuration.

    Statically validates the stored config, instantiates the matching client
    (embedding vs. chat) and issues a single round-trip request. Failures are
    reported in the response body rather than as HTTP errors.
    """
    username = current_user.username
    session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    logger.info(f"TEST: 测试连接大模型配置 {config_id} by user {username}")
    config_name = config.name
    # Static validation before any network call.
    validation_result = config.validate_config()
    logger.info(f"TEST: 验证大模型配置 {config_name} validation_result = {validation_result}")
    if not validation_result["valid"]:
        return {
            "success": False,
            "message": f"配置验证失败: {validation_result['error']}",
            "details": validation_result
        }
    session.desc = f"准备测试LLM功能 > 测试连接大模型配置 {config.to_dict()}"
    # Bind test_message BEFORE the try block: the except handler references it,
    # and previously a failure raised before its assignment would have caused
    # a NameError inside the handler.
    test_message = test_data.message or "Hello, this is a test message."
    try:
        session.desc = f"准备测试LLM功能 > test_message = {test_message}"
        if config.is_embedding:
            # NOTE(review): overwrites the provider on the in-memory ORM
            # object. Nothing commits here, but an autoflush elsewhere could
            # persist it — consider copying the config instead; verify.
            config.provider = "ollama"
            streaming_llm = BGEEmbedLLM(config)
        else:
            streaming_llm = OnlineLLM(config)
        session.desc = f"创建{'EmbeddingLLM' if config.is_embedding else 'OnlineLLM'}完毕 > 测试连接大模型配置 {config.to_dict()}"
        streaming_llm.load_model()  # load/initialize the underlying model client
        session.desc = f"加载模型完毕,模型名称:{config.model_name}base_url: {config.base_url},准备测试对话..."
        if config.is_embedding:
            # Embedding models are exercised through the embedding API, not chat.
            response = streaming_llm.embed_query(test_message)
        else:
            # Fixed import path: this dependency set exposes message classes in
            # langchain_core.messages (see AIMessage at file top); the original
            # "from langchain.messages import ..." does not resolve here.
            from langchain_core.messages import SystemMessage, HumanMessage
            messages = [
                SystemMessage(content="你是一个简洁的助手回答控制在50字以内"),
                HumanMessage(content=test_message)
            ]
            response = streaming_llm.model.invoke(messages)
        session.desc = f"测试连接大模型配置 {config_name} 成功 >>> 响应: {type(response)}"
        return HxfResponse({
            "success": True,
            "message": "LLM测试成功",
            "request": test_message,
            # Chat responses carry .content; embedding results are plain vectors.
            "response": response.content if hasattr(response, 'content') else response,
            "latency_ms": 150,  # placeholder latency, not measured
            "config_info": config.to_dict()
        })
    except Exception as test_error:
        session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
        return HxfResponse({
            "success": False,
            "message": f"LLM测试失败: {str(test_error)}",
            "test_message": test_message,
            "config_info": config.to_dict()
        })
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
async def toggle_llm_config_status(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Flip a configuration between active and inactive."""
    username = current_user.username
    session.desc = f"START: 切换大模型配置状态, id={config_id} by user {username}"
    config = (
        await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Invert the flag; audit fields are set automatically by a SQLAlchemy
    # event listener, so nothing else needs touching here.
    config.is_active = not config.is_active
    await session.commit()
    await session.refresh(config)
    status_text = "激活" if config.is_active else "禁用"
    session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {username}"
    return HxfResponse({
        "message": f"配置已{status_text}",
        "is_active": config.is_active
    })
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
async def set_default_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Mark one active configuration as the default for its model type."""
    username = current_user.username
    session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {username}"
    config = (
        await session.execute(select(LLMConfig).where(LLMConfig.id == config_id))
    ).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Only an active configuration may become the default.
    if not config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只能将激活的配置设为默认"
        )
    # Demote every other default of the same model type (chat vs. embedding).
    await session.execute(
        update(LLMConfig).where(
            LLMConfig.is_embedding == config.is_embedding,
            LLMConfig.id != config_id
        ).values({"is_default": False})
    )
    config.is_default = True
    config.set_audit_fields(current_user.id, is_update=True)
    await session.commit()
    await session.refresh(config)
    model_type = "嵌入模型" if config.is_embedding else "对话模型"
    # Re-initialize the document processor so it picks up the new default embedding.
    await get_document_processor(session)._init_embeddings()
    session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {username}"
    return HxfResponse({
        "message": f"已将 {config.name} 设为默认{model_type}配置",
        "is_default": config.is_default
    })