|
|
@@ -17,7 +17,6 @@ import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torchvision |
|
|
|
import yaml |
|
|
|
from scipy.cluster.vq import kmeans |
|
|
|
from scipy.signal import butter, filtfilt |
|
|
@@ -651,7 +650,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, |
|
|
|
# Batched NMS |
|
|
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes |
|
|
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores |
|
|
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |
|
|
|
i = torch.ops.torchvision.nms(boxes, scores, iou_thres) |
|
|
|
if i.shape[0] > max_det: # limit detections |
|
|
|
i = i[:max_det] |
|
|
|
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) |