You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

autoanchor.py 7.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. AutoAnchor utils
  4. """
  5. import random
  6. import numpy as np
  7. import torch
  8. import yaml
  9. from tqdm import tqdm
  10. from utils.general import LOGGER, colorstr, emojis
# Tag placed at the front of the log messages emitted by the functions in this module
PREFIX = colorstr('AutoAnchor: ')
  12. def check_anchor_order(m):
  13. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  14. a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer
  15. da = a[-1] - a[0] # delta a
  16. ds = m.stride[-1] - m.stride[0] # delta s
  17. if da and (da.sign() != ds.sign()): # same order
  18. LOGGER.info(f'{PREFIX}Reversing anchor order')
  19. m.anchors[:] = m.anchors.flip(0)
  20. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  21. # Check anchor fit to data, recompute if necessary
  22. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  23. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  24. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  25. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  26. def metric(k): # compute metric
  27. r = wh[:, None] / k[None]
  28. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  29. best = x.max(1)[0] # best_x
  30. aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
  31. bpr = (best > 1 / thr).float().mean() # best possible recall
  32. return bpr, aat
  33. stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides
  34. anchors = m.anchors.clone() * stride # current anchors
  35. bpr, aat = metric(anchors.cpu().view(-1, 2))
  36. s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
  37. if bpr > 0.98: # threshold to recompute
  38. LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅'))
  39. else:
  40. LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
  41. na = m.anchors.numel() // 2 # number of anchors
  42. try:
  43. anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  44. except Exception as e:
  45. LOGGER.info(f'{PREFIX}ERROR: {e}')
  46. new_bpr = metric(anchors)[0]
  47. if new_bpr > bpr: # replace anchors
  48. anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
  49. m.anchors[:] = anchors.clone().view_as(m.anchors)
  50. check_anchor_order(m) # must be in pixel-space (not grid-space)
  51. m.anchors /= stride
  52. s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
  53. else:
  54. s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
  55. LOGGER.info(emojis(s))
  56. def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  57. """ Creates kmeans-evolved anchors from training dataset
  58. Arguments:
  59. dataset: path to data.yaml, or a loaded dataset
  60. n: number of anchors
  61. img_size: image size used for training
  62. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  63. gen: generations to evolve anchors using genetic algorithm
  64. verbose: print all results
  65. Return:
  66. k: kmeans evolved anchors
  67. Usage:
  68. from utils.autoanchor import *; _ = kmean_anchors()
  69. """
  70. from scipy.cluster.vq import kmeans
  71. npr = np.random
  72. thr = 1 / thr
  73. def metric(k, wh): # compute metrics
  74. r = wh[:, None] / k[None]
  75. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  76. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  77. return x, x.max(1)[0] # x, best_x
  78. def anchor_fitness(k): # mutation fitness
  79. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  80. return (best * (best > thr).float()).mean() # fitness
  81. def print_results(k, verbose=True):
  82. k = k[np.argsort(k.prod(1))] # sort small to large
  83. x, best = metric(k, wh0)
  84. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  85. s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
  86. f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
  87. f'past_thr={x[x > thr].mean():.3f}-mean: '
  88. for i, x in enumerate(k):
  89. s += '%i,%i, ' % (round(x[0]), round(x[1]))
  90. if verbose:
  91. LOGGER.info(s[:-2])
  92. return k
  93. if isinstance(dataset, str): # *.yaml file
  94. with open(dataset, errors='ignore') as f:
  95. data_dict = yaml.safe_load(f) # model dict
  96. from utils.datasets import LoadImagesAndLabels
  97. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  98. # Get label wh
  99. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  100. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  101. # Filter
  102. i = (wh0 < 3.0).any(1).sum()
  103. if i:
  104. LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
  105. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  106. # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
  107. # Kmeans init
  108. try:
  109. LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
  110. assert n <= len(wh) # apply overdetermined constraint
  111. s = wh.std(0) # sigmas for whitening
  112. k = kmeans(wh / s, n, iter=30)[0] * s # points
  113. assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar
  114. except Exception:
  115. LOGGER.warning(f'{PREFIX}WARNING: switching strategies from kmeans to random init')
  116. k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
  117. wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
  118. k = print_results(k, verbose=False)
  119. # Plot
  120. # k, d = [None] * 20, [None] * 20
  121. # for i in tqdm(range(1, 21)):
  122. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  123. # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
  124. # ax = ax.ravel()
  125. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  126. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  127. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  128. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  129. # fig.savefig('wh.png', dpi=200)
  130. # Evolve
  131. f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  132. pbar = tqdm(range(gen), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  133. for _ in pbar:
  134. v = np.ones(sh)
  135. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  136. v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  137. kg = (k.copy() * v).clip(min=2.0)
  138. fg = anchor_fitness(kg)
  139. if fg > f:
  140. f, k = fg, kg.copy()
  141. pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  142. if verbose:
  143. print_results(k, verbose)
  144. return print_results(k)