|
|
@@ -40,7 +40,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): |
|
|
|
bpr = (best > 1 / thr).float().mean() # best possible recall |
|
|
|
return bpr, aat |
|
|
|
|
|
|
|
anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors |
|
|
|
stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides |
|
|
|
anchors = m.anchors.clone() * stride # current anchors |
|
|
|
bpr, aat = metric(anchors.cpu().view(-1, 2)) |
|
|
|
s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). ' |
|
|
|
if bpr > 0.98: # threshold to recompute |
|
|
@@ -55,8 +56,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): |
|
|
|
new_bpr = metric(anchors)[0] |
|
|
|
if new_bpr > bpr: # replace anchors |
|
|
|
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) |
|
|
|
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss |
|
|
|
check_anchor_order(m) |
|
|
|
m.anchors[:] = anchors.clone().view_as(m.anchors) |
|
|
|
check_anchor_order(m) # must be in pixel-space (not grid-space) |
|
|
|
m.anchors /= stride |
|
|
|
s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)' |
|
|
|
else: |
|
|
|
s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)' |