Use torchvision.ops.nms (#1460)
This commit is contained in:
parent
199c9c7874
commit
394131c2aa
|
|
@ -15,6 +15,7 @@ import cv2
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from utils.google_utils import gsutil_getsize
|
from utils.google_utils import gsutil_getsize
|
||||||
|
|
@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
|
||||||
# Batched NMS
|
# Batched NMS
|
||||||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||||
i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||||
if i.shape[0] > max_det: # limit detections
|
if i.shape[0] > max_det: # limit detections
|
||||||
i = i[:max_det]
|
i = i[:max_det]
|
||||||
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue