|
|
@@ -494,7 +494,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c |
|
|
|
continue |
|
|
|
|
|
|
|
# Compute conf |
|
|
|
x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf |
|
|
|
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf |
|
|
|
|
|
|
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2) |
|
|
|
box = xywh2xyxy(x[:, :4]) |
|
|
@@ -502,10 +502,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c |
|
|
|
# Detections matrix nx6 (xyxy, conf, cls) |
|
|
|
if multi_label: |
|
|
|
i, j = (x[:, 5:] > conf_thres).nonzero().t() |
|
|
|
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) |
|
|
|
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) |
|
|
|
else: # best class only |
|
|
|
conf, j = x[:, 5:].max(1) |
|
|
|
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres] |
|
|
|
conf, j = x[:, 5:].max(1, keepdim=True) |
|
|
|
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] |
|
|
|
|
|
|
|
# Filter by class |
|
|
|
if classes: |
|
|
@@ -524,8 +524,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c |
|
|
|
# x = x[x[:, 4].argsort(descending=True)] |
|
|
|
|
|
|
|
# Batched NMS |
|
|
|
c = x[:, 5] * 0 if agnostic else x[:, 5] # classes |
|
|
|
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores |
|
|
|
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes |
|
|
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores |
|
|
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) |
|
|
|
if i.shape[0] > max_det: # limit detections |
|
|
|
i = i[:max_det] |