|
|
@@ -15,6 +15,7 @@ import cv2 |
|
|
|
import matplotlib |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import torchvision |
|
|
|
import yaml |
|
|
|
|
|
|
|
from utils.google_utils import gsutil_getsize |
|
|
@@ -323,7 +324,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, |
|
|
|
# Batched NMS |
|
|
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes |
|
|
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores |
|
|
|
i = torch.ops.torchvision.nms(boxes, scores, iou_thres) |
|
|
|
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS |
|
|
|
if i.shape[0] > max_det: # limit detections |
|
|
|
i = i[:max_det] |
|
|
|
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) |