You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
7.0KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Auto-anchor utils
  4. """
  5. import random
  6. import numpy as np
  7. import torch
  8. import yaml
  9. from tqdm import tqdm
  10. from utils.general import LOGGER, colorstr, emojis
# Prefix prepended to every AutoAnchor log message in this module; colorstr() (from utils.general) presumably adds terminal styling.
PREFIX = colorstr('AutoAnchor: ')
  12. def check_anchor_order(m):
  13. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  14. a = m.anchors.prod(-1).view(-1) # anchor area
  15. da = a[-1] - a[0] # delta a
  16. ds = m.stride[-1] - m.stride[0] # delta s
  17. if da.sign() != ds.sign(): # same order
  18. LOGGER.info(f'{PREFIX}Reversing anchor order')
  19. m.anchors[:] = m.anchors.flip(0)
  20. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  21. # Check anchor fit to data, recompute if necessary
  22. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  23. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  24. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  25. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  26. def metric(k): # compute metric
  27. r = wh[:, None] / k[None]
  28. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  29. best = x.max(1)[0] # best_x
  30. aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
  31. bpr = (best > 1 / thr).float().mean() # best possible recall
  32. return bpr, aat
  33. anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
  34. bpr, aat = metric(anchors.cpu().view(-1, 2))
  35. s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
  36. if bpr > 0.98: # threshold to recompute
  37. LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅'))
  38. else:
  39. LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
  40. na = m.anchors.numel() // 2 # number of anchors
  41. try:
  42. anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  43. except Exception as e:
  44. LOGGER.info(f'{PREFIX}ERROR: {e}')
  45. new_bpr = metric(anchors)[0]
  46. if new_bpr > bpr: # replace anchors
  47. anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
  48. m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
  49. check_anchor_order(m)
  50. LOGGER.info(f'{PREFIX}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
  51. else:
  52. LOGGER.info(f'{PREFIX}Original anchors better than new anchors. Proceeding with original anchors.')
  53. def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  54. """ Creates kmeans-evolved anchors from training dataset
  55. Arguments:
  56. dataset: path to data.yaml, or a loaded dataset
  57. n: number of anchors
  58. img_size: image size used for training
  59. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  60. gen: generations to evolve anchors using genetic algorithm
  61. verbose: print all results
  62. Return:
  63. k: kmeans evolved anchors
  64. Usage:
  65. from utils.autoanchor import *; _ = kmean_anchors()
  66. """
  67. from scipy.cluster.vq import kmeans
  68. thr = 1 / thr
  69. def metric(k, wh): # compute metrics
  70. r = wh[:, None] / k[None]
  71. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  72. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  73. return x, x.max(1)[0] # x, best_x
  74. def anchor_fitness(k): # mutation fitness
  75. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  76. return (best * (best > thr).float()).mean() # fitness
  77. def print_results(k, verbose=True):
  78. k = k[np.argsort(k.prod(1))] # sort small to large
  79. x, best = metric(k, wh0)
  80. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  81. s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
  82. f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
  83. f'past_thr={x[x > thr].mean():.3f}-mean: '
  84. for i, x in enumerate(k):
  85. s += '%i,%i, ' % (round(x[0]), round(x[1]))
  86. if verbose:
  87. LOGGER.info(s[:-2])
  88. return k
  89. if isinstance(dataset, str): # *.yaml file
  90. with open(dataset, errors='ignore') as f:
  91. data_dict = yaml.safe_load(f) # model dict
  92. from utils.datasets import LoadImagesAndLabels
  93. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  94. # Get label wh
  95. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  96. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  97. # Filter
  98. i = (wh0 < 3.0).any(1).sum()
  99. if i:
  100. LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
  101. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  102. # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
  103. # Kmeans calculation
  104. LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
  105. s = wh.std(0) # sigmas for whitening
  106. k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
  107. assert len(k) == n, f'{PREFIX}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
  108. k *= s
  109. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  110. wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
  111. k = print_results(k, verbose=False)
  112. # Plot
  113. # k, d = [None] * 20, [None] * 20
  114. # for i in tqdm(range(1, 21)):
  115. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  116. # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
  117. # ax = ax.ravel()
  118. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  119. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  120. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  121. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  122. # fig.savefig('wh.png', dpi=200)
  123. # Evolve
  124. npr = np.random
  125. f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  126. pbar = tqdm(range(gen), desc=f'{PREFIX}Evolving anchors with Genetic Algorithm:') # progress bar
  127. for _ in pbar:
  128. v = np.ones(sh)
  129. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  130. v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  131. kg = (k.copy() * v).clip(min=2.0)
  132. fg = anchor_fitness(kg)
  133. if fg > f:
  134. f, k = fg, kg.copy()
  135. pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  136. if verbose:
  137. print_results(k, verbose)
  138. return print_results(k)