Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

163 lines
6.9KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Auto-anchor utils
  4. """
  5. import random
  6. import numpy as np
  7. import torch
  8. import yaml
  9. from tqdm import tqdm
  10. from utils.general import colorstr
  11. def check_anchor_order(m):
  12. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  13. a = m.anchors.prod(-1).view(-1) # anchor area
  14. da = a[-1] - a[0] # delta a
  15. ds = m.stride[-1] - m.stride[0] # delta s
  16. if da.sign() != ds.sign(): # same order
  17. print('Reversing anchor order')
  18. m.anchors[:] = m.anchors.flip(0)
  19. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  20. # Check anchor fit to data, recompute if necessary
  21. prefix = colorstr('autoanchor: ')
  22. print(f'\n{prefix}Analyzing anchors... ', end='')
  23. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  24. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  25. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  26. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  27. def metric(k): # compute metric
  28. r = wh[:, None] / k[None]
  29. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  30. best = x.max(1)[0] # best_x
  31. aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold
  32. bpr = (best > 1. / thr).float().mean() # best possible recall
  33. return bpr, aat
  34. anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
  35. bpr, aat = metric(anchors.cpu().view(-1, 2))
  36. print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
  37. if bpr < 0.98: # threshold to recompute
  38. print('. Attempting to improve anchors, please wait...')
  39. na = m.anchors.numel() // 2 # number of anchors
  40. try:
  41. anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  42. except Exception as e:
  43. print(f'{prefix}ERROR: {e}')
  44. new_bpr = metric(anchors)[0]
  45. if new_bpr > bpr: # replace anchors
  46. anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
  47. m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
  48. check_anchor_order(m)
  49. print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
  50. else:
  51. print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
  52. print('') # newline
  53. def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  54. """ Creates kmeans-evolved anchors from training dataset
  55. Arguments:
  56. dataset: path to data.yaml, or a loaded dataset
  57. n: number of anchors
  58. img_size: image size used for training
  59. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  60. gen: generations to evolve anchors using genetic algorithm
  61. verbose: print all results
  62. Return:
  63. k: kmeans evolved anchors
  64. Usage:
  65. from utils.autoanchor import *; _ = kmean_anchors()
  66. """
  67. from scipy.cluster.vq import kmeans
  68. thr = 1. / thr
  69. prefix = colorstr('autoanchor: ')
  70. def metric(k, wh): # compute metrics
  71. r = wh[:, None] / k[None]
  72. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  73. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  74. return x, x.max(1)[0] # x, best_x
  75. def anchor_fitness(k): # mutation fitness
  76. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  77. return (best * (best > thr).float()).mean() # fitness
  78. def print_results(k):
  79. k = k[np.argsort(k.prod(1))] # sort small to large
  80. x, best = metric(k, wh0)
  81. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  82. print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
  83. print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
  84. f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
  85. for i, x in enumerate(k):
  86. print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
  87. return k
  88. if isinstance(dataset, str): # *.yaml file
  89. with open(dataset, errors='ignore') as f:
  90. data_dict = yaml.safe_load(f) # model dict
  91. from utils.datasets import LoadImagesAndLabels
  92. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  93. # Get label wh
  94. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  95. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  96. # Filter
  97. i = (wh0 < 3.0).any(1).sum()
  98. if i:
  99. print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
  100. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  101. # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
  102. # Kmeans calculation
  103. print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
  104. s = wh.std(0) # sigmas for whitening
  105. k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
  106. assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
  107. k *= s
  108. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  109. wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
  110. k = print_results(k)
  111. # Plot
  112. # k, d = [None] * 20, [None] * 20
  113. # for i in tqdm(range(1, 21)):
  114. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  115. # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
  116. # ax = ax.ravel()
  117. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  118. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  119. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  120. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  121. # fig.savefig('wh.png', dpi=200)
  122. # Evolve
  123. npr = np.random
  124. f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  125. pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
  126. for _ in pbar:
  127. v = np.ones(sh)
  128. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  129. v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  130. kg = (k.copy() * v).clip(min=2.0)
  131. fg = anchor_fitness(kg)
  132. if fg > f:
  133. f, k = fg, kg.copy()
  134. pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  135. if verbose:
  136. print_results(k)
  137. return print_results(k)