浏览代码

Bug fix mAP0.5-0.95 (#6787)

* Improve mAP0.5-0.95

Two changes provided
1. Added limit on the maximum number of detections for each image likewise pycocotools
2. Rework process_batch function

Changes #2 solved issue #4251
I also independently encountered the problem described in issue #4251 that the values for the same thresholds do not match when changing the limits in the torch.linspace function.
These changes solve this problem.

Currently during validation yolov5x.pt model the following results were obtained:
from yolov5 validation
               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 157/157 [01:07<00:00,  2.33it/s]
                 all       5000      36335      0.743      0.626      0.682      0.506
from pycocotools
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.505
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.685

These results are very close, although not completely pass the competition issue #2258.
I think it's problem with false positive bboxes matched ignored criteria, but this is not actual for custom datasets and does not require an additional solution.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove line to retain pycocotools results

* Update val.py

* Update val.py

* Remove to device op

* Higher precision int conversion

* Update val.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Anton Lebedev GitHub 2 年前
父节点
当前提交
43569d53da
找不到此签名对应的密钥 GPG 密钥 ID: 4AEE18F83AFDEB23
共有 2 个文件被更改,包括 14 次插入13 次删除
  1. +2
    -2
      utils/metrics.py
  2. +12
    -11
      val.py

+ 2
- 2
utils/metrics.py 查看文件

@@ -90,7 +90,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
p, r, f1 = p[:, i], r[:, i], f1[:, i]
tp = (r * nt).round() # true positives
fp = (tp / (p + eps) - tp).round() # false positives
return tp, fp, p, r, f1, ap, unique_classes.astype('int32')
return tp, fp, p, r, f1, ap, unique_classes.astype(int)


def compute_ap(recall, precision):
@@ -156,7 +156,7 @@ class ConfusionMatrix:
matches = np.zeros((0, 3))

n = matches.shape[0] > 0
m0, m1, _ = matches.transpose().astype(np.int16)
m0, m1, _ = matches.transpose().astype(int)
for i, gc in enumerate(gt_classes):
j = m0 == i
if n and sum(j) == 1:

+ 12
- 11
val.py 查看文件

@@ -79,16 +79,17 @@ def process_batch(detections, labels, iouv):
"""
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
iou = box_iou(labels[:, 1:], detections[:, :4])
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
matches = torch.from_numpy(matches).to(iouv.device)
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(iouv)):
x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return correct


@@ -265,7 +266,7 @@ def run(
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
else:
nt = torch.zeros(1)


正在加载...
取消
保存