|
|
@@ -58,7 +58,8 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): |
|
|
|
print('\nAnalyzing anchors... ', end='') |
|
|
|
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() |
|
|
|
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) |
|
|
|
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh |
|
|
|
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale |
|
|
|
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh |
|
|
|
|
|
|
|
def metric(k): # compute metric |
|
|
|
r = wh[:, None] / k[None] |
|
|
@@ -77,12 +78,23 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640): |
|
|
|
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors) |
|
|
|
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference |
|
|
|
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss |
|
|
|
check_anchor_order(m) |
|
|
|
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.') |
|
|
|
else: |
|
|
|
print('Original anchors better than new anchors. Proceeding with original anchors.') |
|
|
|
print('') # newline |
|
|
|
|
|
|
|
|
|
|
|
def check_anchor_order(m): |
|
|
|
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary |
|
|
|
a = m.anchor_grid.prod(-1).view(-1) # anchor area |
|
|
|
da = a[-1] - a[0] # delta a |
|
|
|
ds = m.stride[-1] - m.stride[0] # delta s |
|
|
|
if da.sign() != ds.sign(): # same order |
|
|
|
m.anchors[:] = m.anchors.flip(0) |
|
|
|
m.anchor_grid[:] = m.anchor_grid.flip(0) |
|
|
|
|
|
|
|
|
|
|
|
def check_file(file): |
|
|
|
# Searches for file if not found locally |
|
|
|
if os.path.isfile(file): |