# hxf/backend/th_agenter/api/endpoints/llm_configs.py

"""LLM configuration management API endpoints."""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import or_, select, delete, update
from loguru import logger
from ...db.database import get_session
from ...models.user import User
from ...models.llm_config import LLMConfig
from ...core.simple_permissions import require_super_admin, require_authenticated_user
from ...schemas.llm_config import (
LLMConfigCreate, LLMConfigUpdate, LLMConfigResponse,
LLMConfigTest
)
from th_agenter.services.document_processor import get_document_processor
from utils.util_exceptions import HxfResponse
# Router for all LLM-configuration endpoints; mounted under the /llm-configs prefix.
router = APIRouter(prefix="/llm-configs", tags=["llm-configs"])
@router.get("/", response_model=List[LLMConfigResponse], summary="获取大模型配置列表")
async def get_llm_configs(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    search: Optional[str] = Query(None),
    provider: Optional[str] = Query(None),
    is_active: Optional[bool] = Query(None),
    is_embedding: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """List LLM configurations with optional search, filtering and paging.

    Filters are applied only when the corresponding query parameter is given;
    results are ordered by name and include sensitive fields (API keys).
    """
    session.desc = f"START: 获取大模型配置列表, skip={skip}, limit={limit}, search={search}, provider={provider}, is_active={is_active}, is_embedding={is_embedding}"
    stmt = select(LLMConfig)
    # Fuzzy search across name, model name and description.
    if search:
        pattern = f"%{search}%"
        stmt = stmt.where(
            or_(
                LLMConfig.name.ilike(pattern),
                LLMConfig.model_name.ilike(pattern),
                LLMConfig.description.ilike(pattern),
            )
        )
    # Exact-match filters.
    if provider:
        stmt = stmt.where(LLMConfig.provider == provider)
    if is_active is not None:
        stmt = stmt.where(LLMConfig.is_active == is_active)
    if is_embedding is not None:
        stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
    # Stable ordering, then offset/limit paging.
    stmt = stmt.order_by(LLMConfig.name).offset(skip).limit(limit)
    configs = (await session.execute(stmt)).scalars().all()
    session.desc = f"SUCCESS: 获取 {len(configs)} 个大模型配置"
    return HxfResponse([cfg.to_dict(include_sensitive=True) for cfg in configs])
@router.get("/providers", summary="获取支持的大模型服务商列表")
async def get_llm_providers(
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Return the distinct, non-empty provider names stored in LLM configs."""
    session.desc = "START: 获取支持的大模型服务商列表"
    result = await session.execute(select(LLMConfig.provider).distinct())
    providers = result.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(providers)} 个大模型服务商"
    # Drop NULL / empty provider values before returning.
    return HxfResponse([p for p in providers if p])
@router.get("/active", response_model=List[LLMConfigResponse], summary="获取所有激活的大模型配置")
async def get_active_llm_configs(
    is_embedding: Optional[bool] = Query(None),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """List every active LLM configuration, optionally restricted to one model type.

    Results are ordered by creation time and include sensitive fields.
    """
    session.desc = f"START: 获取所有激活的大模型配置, is_embedding={is_embedding}"
    stmt = select(LLMConfig).where(LLMConfig.is_active == True)
    # Optional chat-vs-embedding type filter.
    if is_embedding is not None:
        stmt = stmt.where(LLMConfig.is_embedding == is_embedding)
    result = await session.execute(stmt.order_by(LLMConfig.created_at))
    configs = result.scalars().all()
    session.desc = f"SUCCESS: 获取 {len(configs)} 个激活的大模型配置"
    return HxfResponse([cfg.to_dict(include_sensitive=True) for cfg in configs])
@router.get("/default", response_model=LLMConfigResponse, summary="获取默认大模型配置")
async def get_default_llm_config(
    is_embedding: bool = Query(False, description="是否获取嵌入模型默认配置"),
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Return the active default configuration for the requested model type.

    Raises:
        HTTPException: 404 when no active default of that type exists.
    """
    session.desc = f"START: 获取默认大模型配置, is_embedding={is_embedding}"
    stmt = select(LLMConfig).where(
        LLMConfig.is_default == True,
        LLMConfig.is_embedding == is_embedding,
        LLMConfig.is_active == True
    )
    config = (await session.execute(stmt)).scalar_one_or_none()
    if config is None:
        model_type = "嵌入模型" if is_embedding else "对话模型"
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到默认{model_type}配置"
        )
    session.desc = f"SUCCESS: 获取默认大模型配置, is_embedding={is_embedding}"
    return HxfResponse(config.to_dict(include_sensitive=True))
@router.get("/{config_id}", response_model=LLMConfigResponse, summary="获取大模型配置详情")
async def get_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_authenticated_user)
):
    """Fetch one LLM configuration by id, including sensitive fields.

    Raises:
        HTTPException: 404 when no configuration with ``config_id`` exists.
    """
    # NOTE(review): get_session appears to yield an async session (results are
    # awaited in the handlers above); the Session annotation is misleading — confirm.
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: session.execute() was not awaited; an un-awaited coroutine has
    # no scalar_one_or_none(), so this endpoint could never succeed.
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    return HxfResponse(config.to_dict(include_sensitive=True))
@router.post("/", response_model=LLMConfigResponse, status_code=status.HTTP_201_CREATED, summary="创建大模型配置")
async def create_llm_config(
    config_data: LLMConfigCreate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Create a new LLM configuration (super-admin only).

    Enforces name uniqueness, validates the configuration, and — when the new
    config is marked default — demotes any existing default of the same type.

    Raises:
        HTTPException: 400 on duplicate name or failed validation.
    """
    session.desc = f"START: 创建大模型配置, name={config_data.name}"
    # Reject duplicate configuration names.
    stmt = select(LLMConfig).where(LLMConfig.name == config_data.name)
    # BUG FIX: execute/commit/refresh were not awaited; sibling handlers show
    # the session is async, so nothing here would actually run or persist.
    existing_config = (await session.execute(stmt)).scalar_one_or_none()
    if existing_config:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="配置名称已存在"
        )
    # Build the config once and validate it directly. The original constructed
    # an identical throwaway "temp_config" solely for validation, then rebuilt
    # the same object field-by-field — a duplication hazard when fields change.
    config = LLMConfig(
        name=config_data.name,
        provider=config_data.provider,
        model_name=config_data.model_name,
        api_key=config_data.api_key,
        base_url=config_data.base_url,
        max_tokens=config_data.max_tokens,
        temperature=config_data.temperature,
        top_p=config_data.top_p,
        frequency_penalty=config_data.frequency_penalty,
        presence_penalty=config_data.presence_penalty,
        description=config_data.description,
        is_active=config_data.is_active,
        is_default=config_data.is_default,
        is_embedding=config_data.is_embedding,
        extra_config=config_data.extra_config or {}
    )
    validation_result = config.validate_config()
    if not validation_result['valid']:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=validation_result['error']
        )
    # A new default demotes every other default of the same model type.
    if config_data.is_default:
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == config_data.is_embedding
        ).values({"is_default": False})
        await session.execute(stmt)
    # Audit fields are set automatically by SQLAlchemy event listener.
    session.add(config)
    await session.commit()
    await session.refresh(config)
    session.desc = f"SUCCESS: 创建大模型配置, name={config.name} by user {current_user.username}"
    return HxfResponse(config.to_dict())
@router.put("/{config_id}", response_model=LLMConfigResponse, summary="更新大模型配置")
async def update_llm_config(
    config_id: int,
    config_data: LLMConfigUpdate,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Partially update an LLM configuration (super-admin only).

    Only fields explicitly set on ``config_data`` are applied. Renames are
    checked for uniqueness; promoting to default demotes same-type defaults.

    Raises:
        HTTPException: 404 if the config does not exist; 400 on duplicate name.
    """
    session.desc = f"START: 更新大模型配置, id={config_id}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: execute/commit/refresh were not awaited (async session — the
    # list/detail handlers above all await the same calls).
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Uniqueness check for renames (excluding this config itself).
    if config_data.name and config_data.name != config.name:
        stmt = select(LLMConfig).where(
            LLMConfig.name == config_data.name,
            LLMConfig.id != config_id
        )
        existing_config = (await session.execute(stmt)).scalar_one_or_none()
        if existing_config:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="配置名称已存在"
            )
    # Promoting to default demotes other defaults of the same model type.
    if config_data.is_default is True:
        # Use the incoming is_embedding when supplied, else the stored value.
        is_embedding = config_data.is_embedding if config_data.is_embedding is not None else config.is_embedding
        stmt = update(LLMConfig).where(
            LLMConfig.is_embedding == is_embedding,
            LLMConfig.id != config_id
        ).values({"is_default": False})
        await session.execute(stmt)
    # Apply only fields the client actually sent.
    update_data = config_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(config, field, value)
    await session.commit()
    await session.refresh(config)
    session.desc = f"SUCCESS: 更新大模型配置, id={config_id} by user {current_user.username}"
    return HxfResponse(config.to_dict())
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除大模型配置")
async def delete_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Delete an LLM configuration (super-admin only).

    Raises:
        HTTPException: 404 when the config does not exist.
    """
    # NOTE(review): the route declares 204 No Content but returns a body below;
    # FastAPI rejects bodies on 204 responses — confirm intended status code.
    session.desc = f"START: 删除大模型配置, id={config_id}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: execute/delete/commit were not awaited (async session).
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # TODO: 检查是否有对话或其他功能正在使用该配置
    # 这里可以添加相关的检查逻辑
    await session.delete(config)
    await session.commit()
    session.desc = f"SUCCESS: 删除大模型配置, id={config_id} by user {current_user.username}"
    return HxfResponse({"message": "LLM config deleted successfully"})
@router.post("/{config_id}/test", summary="测试连接大模型配置")
async def test_llm_config(
    config_id: int,
    test_data: LLMConfigTest,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Validate an LLM configuration and run a (currently simulated) connection test.

    Raises:
        HTTPException: 404 when the config does not exist.
    """
    session.desc = f"TEST: 测试连接大模型配置 {config_id} by user {current_user.username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: session.execute() was not awaited (async session).
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Static validation before attempting any request.
    validation_result = config.validate_config()
    if not validation_result["valid"]:
        return {
            "success": False,
            "message": f"配置验证失败: {validation_result['error']}",
            "details": validation_result
        }
    # BUG FIX: bind test_message before the try block — the except handler
    # references it, which raised UnboundLocalError if the failure happened
    # before the assignment inside try.
    test_message = test_data.message or "Hello, this is a test message."
    # 尝试创建客户端并发送测试请求
    try:
        # 这里应该根据不同的服务商创建相应的客户端
        # 由于具体的客户端实现可能因服务商而异,这里提供一个通用的框架
        # TODO: 实现具体的测试逻辑
        # 例如:
        # client = config.get_client()
        # response = client.chat.completions.create(
        #     model=config.model_name,
        #     messages=[{"role": "user", "content": test_message}],
        #     max_tokens=100
        # )
        # 模拟测试成功
        session.desc = f"SUCCESS: 模拟测试连接大模型配置 {config.name} by user {current_user.username}"
        return HxfResponse({
            "success": True,
            "message": "配置测试成功",
            "test_message": test_message,
            "response": "这是一个模拟的测试响应。实际实现中,这里会是大模型的真实响应。",
            "latency_ms": 150,  # simulated latency
            "config_info": config.get_client_config()
        })
    except Exception as test_error:
        session.desc = f"ERROR: 测试连接大模型配置 {config.name} 失败, error: {str(test_error)}"
        return HxfResponse({
            "success": False,
            "message": f"配置测试失败: {str(test_error)}",
            "test_message": test_message,
            "config_info": config.get_client_config()
        })
@router.post("/{config_id}/toggle-status", summary="切换大模型配置状态")
async def toggle_llm_config_status(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Flip a config's active/inactive flag (super-admin only).

    Raises:
        HTTPException: 404 when the config does not exist.
    """
    session.desc = f"START: 切换大模型配置状态, id={config_id} by user {current_user.username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: execute/commit/refresh were not awaited (async session).
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Flip the flag; audit fields are set automatically by a SQLAlchemy
    # event listener.
    config.is_active = not config.is_active
    await session.commit()
    await session.refresh(config)
    status_text = "激活" if config.is_active else "禁用"
    session.desc = f"SUCCESS: 切换大模型配置状态: {config.name} {status_text} by user {current_user.username}"
    return HxfResponse({
        "message": f"配置已{status_text}",
        "is_active": config.is_active
    })
@router.post("/{config_id}/set-default", summary="设置默认大模型配置")
async def set_default_llm_config(
    config_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(require_super_admin)
):
    """Make a configuration the default for its model type (super-admin only).

    Demotes all other defaults of the same type, then reinitializes the
    document processor's embeddings so the new default takes effect.

    Raises:
        HTTPException: 404 if the config does not exist; 400 if it is inactive.
    """
    session.desc = f"START: 设置大模型配置 {config_id} 为默认 by user {current_user.username}"
    stmt = select(LLMConfig).where(LLMConfig.id == config_id)
    # BUG FIX: execute/commit/refresh were not awaited (async session).
    config = (await session.execute(stmt)).scalar_one_or_none()
    if not config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="大模型配置不存在"
        )
    # Only active configurations may become the default.
    if not config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="只能将激活的配置设为默认"
        )
    # Demote every other default of the same model type.
    stmt = update(LLMConfig).where(
        LLMConfig.is_embedding == config.is_embedding,
        LLMConfig.id != config_id
    ).values({"is_default": False})
    await session.execute(stmt)
    # Promote this config and stamp audit fields explicitly.
    config.is_default = True
    config.set_audit_fields(current_user.id, is_update=True)
    await session.commit()
    await session.refresh(config)
    model_type = "嵌入模型" if config.is_embedding else "对话模型"
    # Refresh the document processor's default embedding model.
    # NOTE(review): this reaches into a private method (_init_embeddings);
    # consider exposing a public reload hook on the document processor.
    get_document_processor()._init_embeddings()
    session.desc = f"SUCCESS: 设置大模型配置 {config.name} ({model_type}) 为默认 by user {current_user.username}"
    return HxfResponse({
        "message": f"已将 {config.name} 设为默认{model_type}配置",
        "is_default": config.is_default
    })