Use torchvision.ops.nms (#1460)
This commit is contained in:
parent
199c9c7874
commit
394131c2aa
|
|
@ -15,6 +15,7 @@ import cv2
|
|||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import yaml
|
||||
|
||||
from utils.google_utils import gsutil_getsize
|
||||
|
|
@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
|
|||
# Batched NMS
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||
i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||
if i.shape[0] > max_det: # limit detections
|
||||
i = i[:max_det]
|
||||
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
||||
|
|
|
|||
Loading…
Reference in New Issue