Use torchvision.ops.nms (#1460)

This commit is contained in:
Glenn Jocher 2020-11-20 11:23:36 +01:00 committed by GitHub
parent 199c9c7874
commit 394131c2aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -15,6 +15,7 @@ import cv2
import matplotlib import matplotlib
import numpy as np import numpy as np
import torch import torch
import torchvision
import yaml import yaml
from utils.google_utils import gsutil_getsize from utils.google_utils import gsutil_getsize
@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
# Batched NMS # Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torch.ops.torchvision.nms(boxes, scores, iou_thres) i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if i.shape[0] > max_det: # limit detections if i.shape[0] > max_det: # limit detections
i = i[:max_det] i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)