ソースを参照

NMS updates

5.0
Glenn Jocher 4年前
コミット
66d73e4cf3
1個のファイルの変更6行の追加6行の削除
  1. +6
    -6
      utils/utils.py

+ 6
- 6
utils/utils.py ファイルの表示

@@ -494,7 +494,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
continue

# Compute conf
x[..., 5:] *= x[..., 4:5] # conf = obj_conf * cls_conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf

# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
@@ -502,10 +502,10 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero().t()
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 5:].max(1)
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres]
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

# Filter by class
if classes:
@@ -524,8 +524,8 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
# x = x[x[:, 4].argsort(descending=True)]

# Batched NMS
c = x[:, 5] * 0 if agnostic else x[:, 5] # classes
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
if i.shape[0] > max_det: # limit detections
i = i[:max_det]

読み込み中…
キャンセル
保存