You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
7.1KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. AutoAnchor utils
  4. """
  5. import random
  6. import numpy as np
  7. import torch
  8. import yaml
  9. from tqdm import tqdm
  10. from utils.general import LOGGER, colorstr, emojis
PREFIX = colorstr('AutoAnchor: ')  # prefix prepended to every AutoAnchor log message in this module
  12. def check_anchor_order(m):
  13. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  14. a = m.anchors.prod(-1).view(-1) # anchor area
  15. da = a[-1] - a[0] # delta a
  16. ds = m.stride[-1] - m.stride[0] # delta s
  17. if da.sign() != ds.sign(): # same order
  18. LOGGER.info(f'{PREFIX}Reversing anchor order')
  19. m.anchors[:] = m.anchors.flip(0)
  20. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  21. # Check anchor fit to data, recompute if necessary
  22. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  23. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  24. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  25. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  26. def metric(k): # compute metric
  27. r = wh[:, None] / k[None]
  28. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  29. best = x.max(1)[0] # best_x
  30. aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
  31. bpr = (best > 1 / thr).float().mean() # best possible recall
  32. return bpr, aat
  33. anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
  34. bpr, aat = metric(anchors.cpu().view(-1, 2))
  35. s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
  36. if bpr > 0.98: # threshold to recompute
  37. LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅'))
  38. else:
  39. LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
  40. na = m.anchors.numel() // 2 # number of anchors
  41. try:
  42. anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  43. except Exception as e:
  44. LOGGER.info(f'{PREFIX}ERROR: {e}')
  45. new_bpr = metric(anchors)[0]
  46. if new_bpr > bpr: # replace anchors
  47. anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
  48. m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
  49. check_anchor_order(m)
  50. LOGGER.info(f'{PREFIX}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
  51. else:
  52. LOGGER.info(f'{PREFIX}Original anchors better than new anchors. Proceeding with original anchors.')
  53. def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  54. """ Creates kmeans-evolved anchors from training dataset
  55. Arguments:
  56. dataset: path to data.yaml, or a loaded dataset
  57. n: number of anchors
  58. img_size: image size used for training
  59. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  60. gen: generations to evolve anchors using genetic algorithm
  61. verbose: print all results
  62. Return:
  63. k: kmeans evolved anchors
  64. Usage:
  65. from utils.autoanchor import *; _ = kmean_anchors()
  66. """
  67. from scipy.cluster.vq import kmeans
  68. npr = np.random
  69. thr = 1 / thr
  70. def metric(k, wh): # compute metrics
  71. r = wh[:, None] / k[None]
  72. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  73. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  74. return x, x.max(1)[0] # x, best_x
  75. def anchor_fitness(k): # mutation fitness
  76. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  77. return (best * (best > thr).float()).mean() # fitness
  78. def print_results(k, verbose=True):
  79. k = k[np.argsort(k.prod(1))] # sort small to large
  80. x, best = metric(k, wh0)
  81. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  82. s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
  83. f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
  84. f'past_thr={x[x > thr].mean():.3f}-mean: '
  85. for i, x in enumerate(k):
  86. s += '%i,%i, ' % (round(x[0]), round(x[1]))
  87. if verbose:
  88. LOGGER.info(s[:-2])
  89. return k
  90. if isinstance(dataset, str): # *.yaml file
  91. with open(dataset, errors='ignore') as f:
  92. data_dict = yaml.safe_load(f) # model dict
  93. from utils.datasets import LoadImagesAndLabels
  94. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  95. # Get label wh
  96. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  97. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  98. # Filter
  99. i = (wh0 < 3.0).any(1).sum()
  100. if i:
  101. LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
  102. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  103. # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
  104. # Kmeans calculation
  105. LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
  106. s = wh.std(0) # sigmas for whitening
  107. k = kmeans(wh / s, n, iter=30)[0] * s # points
  108. if len(k) != n: # kmeans may return fewer points than requested if wh is insufficient or too similar
  109. LOGGER.warning(f'{PREFIX}WARNING: scipy.cluster.vq.kmeans returned only {len(k)} of {n} requested points')
  110. k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
  111. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  112. wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
  113. k = print_results(k, verbose=False)
  114. # Plot
  115. # k, d = [None] * 20, [None] * 20
  116. # for i in tqdm(range(1, 21)):
  117. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  118. # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
  119. # ax = ax.ravel()
  120. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  121. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  122. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  123. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  124. # fig.savefig('wh.png', dpi=200)
  125. # Evolve
  126. f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  127. pbar = tqdm(range(gen), desc=f'{PREFIX}Evolving anchors with Genetic Algorithm:') # progress bar
  128. for _ in pbar:
  129. v = np.ones(sh)
  130. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  131. v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  132. kg = (k.copy() * v).clip(min=2.0)
  133. fg = anchor_fitness(kg)
  134. if fg > f:
  135. f, k = fg, kg.copy()
  136. pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  137. if verbose:
  138. print_results(k, verbose)
  139. return print_results(k)