# Auto-anchor utils

import numpy as np
import torch
import yaml
from scipy.cluster.vq import kmeans
from tqdm import tqdm

from utils.general import colorstr


def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
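    # Worked example (illustrative values, not from any real model): anchor areas (100, 400, 900) give
    # da = 900 - 100 > 0 and strides (8, 16, 32) give ds = 32 - 8 > 0, so the signs agree and nothing
    # changes; if the anchors were listed largest-first, da < 0 while ds > 0 and the anchors are flipped.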
    if da.sign() != ds.sign():  # anchor order and stride order disagree
        print('Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    prefix = colorstr('autoanchor: ')
    print(f'\n{prefix}Analyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    anchors = m.anchor_grid.clone().cpu().view(-1, 2)  # current anchors
    bpr, aat = metric(anchors)
    print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        try:
            anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        except Exception as e:
            print(f'{prefix}ERROR: {e}')
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline
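
# Usage sketch (assumption: names such as `dataset`, `model` and `hyp` come from a YOLOv5-style training
# script and are not defined in this module). check_anchors() is typically called once before training starts;
# it updates the model's anchors in place only when the measured BPR falls below 0.98:
#
#   from utils.autoanchor import check_anchors
#   check_anchors(dataset, model, thr=hyp['anchor_t'], imgsz=640)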


def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    thr = 1. / thr
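    # e.g. with the default thr=4.0 the inverted threshold is 0.25: an anchor is counted as fitting a label
    # when min(w/aw, aw/w, h/ah, ah/h) > 0.25, i.e. neither side is more than 4x off from its counterpart.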
    prefix = colorstr('autoanchor: ')

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
        print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
              f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k
    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
    wh = wh0[(wh0 >= 2.0).any(1)]  # keep labels with at least one side >= 2 pixels
    # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans calculation
    print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0] < 100, 0], 400)
    # ax[1].hist(wh[wh[:, 1] < 100, 1], 400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor shape, mutation probability, sigma
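    # Each generation multiplies every anchor coordinate by a random factor near 1 (clipped to [0.3, 3.0]);
    # with mp=0.9 roughly 90% of coordinates mutate, and sigma s=0.1 sets the spread of the gaussian noise.
    # A mutant is kept only if it improves anchor_fitness(), so fitness increases monotonically.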
    pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k)

    return print_results(k)
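

# Standalone usage sketch (assumption: run from the repository root with the coco128 dataset available;
# the path and argument values below simply mirror the defaults above and are illustrative, not required):
if __name__ == '__main__':
    new_anchors = kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True)
    print(new_anchors)  # (n, 2) array of anchor widths/heights in pixels, sorted small to large by area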