Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

156 lines
6.8KB

  1. # Auto-anchor utils
  2. import numpy as np
  3. import torch
  4. import yaml
  5. from scipy.cluster.vq import kmeans
  6. from tqdm import tqdm
  7. from utils.general import colorstr
  def check_anchor_order(m):
      """Make the anchor order of a YOLOv5 Detect() module agree with its stride order.

      The area of the first and last entry of ``m.anchor_grid`` is compared with the
      first and last entry of ``m.stride``. When the signs of the two deltas differ,
      ``m.anchors`` and ``m.anchor_grid`` are both flipped in place along dimension 0.

      Arguments:
          m: Detect() module exposing ``anchors``, ``anchor_grid`` and ``stride`` tensors
      """
      areas = m.anchor_grid.prod(-1).view(-1)  # product over the last dim (w * h), flattened
      area_delta = areas[-1] - areas[0]  # last area minus first area
      stride_delta = m.stride[-1] - m.stride[0]  # last stride minus first stride
      if area_delta.sign() == stride_delta.sign():
          return  # anchors already run in the same direction as the strides
      print('Reversing anchor order')
      m.anchors[:] = m.anchors.flip(0)
      m.anchor_grid[:] = m.anchor_grid.flip(0)
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
      """Check how well the model's anchors fit the dataset labels and recompute them if the fit is poor.

      The best possible recall (BPR) of the current anchors is measured on the label
      sizes. When it is below 0.98, new anchors are requested from kmean_anchors() and
      written into the Detect() module, but only if they achieve a higher BPR.

      Arguments:
          dataset: object exposing ``shapes`` and ``labels``; columns 3:5 of every label
              array are used as width/height (presumably normalized - TODO confirm)
          model: model whose last layer ``model[-1]`` is the Detect() module; an object
              wrapping it in a ``module`` attribute is unwrapped first
          thr: anchor-label wh ratio threshold; its reciprocal is the acceptance cut-off
          imgsz: image size that the longest image side is scaled to
      """
      prefix = colorstr('autoanchor: ')
      print(f'\n{prefix}Analyzing anchors... ', end='')
      unwrapped = model.module if hasattr(model, 'module') else model
      detect = unwrapped.model[-1]  # Detect()

      # Scale every image shape so that its longest side equals imgsz, then jitter it by +-10 %
      shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
      scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
      per_image_wh = [labels[:, 3:5] * shape for shape, labels in zip(shapes * scale, dataset.labels)]
      wh = torch.tensor(np.concatenate(per_image_wh)).float()  # label wh
      min_ratio = 1. / thr

      def metric(k):
          """Return (best possible recall, mean anchors above threshold) for anchors k."""
          ratio = wh[:, None] / k[None]
          x = torch.min(ratio, 1. / ratio).min(2)[0]  # worst of the w and h ratios, per label/anchor pair
          best = x.max(1)[0]  # best anchor for every label
          above = (x > min_ratio).float().sum(1).mean()  # anchors above threshold
          recall = (best > min_ratio).float().mean()  # best possible recall
          return recall, above

      bpr, aat = metric(detect.anchor_grid.clone().cpu().view(-1, 2))
      print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
      if bpr < 0.98:  # threshold to recompute
          print('. Attempting to improve anchors, please wait...')
          anchor_count = detect.anchor_grid.numel() // 2  # two values (w, h) per anchor
          candidates = kmean_anchors(dataset, n=anchor_count, img_size=imgsz, thr=thr, gen=1000, verbose=False)
          candidate_bpr, _ = metric(candidates.reshape(-1, 2))
          if candidate_bpr > bpr:  # replace anchors
              candidates = torch.tensor(candidates, device=detect.anchors.device).type_as(detect.anchors)
              detect.anchor_grid[:] = candidates.clone().view_as(detect.anchor_grid)  # for inference
              strides = detect.stride.to(detect.anchors.device).view(-1, 1, 1)
              detect.anchors[:] = candidates.clone().view_as(detect.anchors) / strides  # loss
              check_anchor_order(detect)
              print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
          else:
              print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
      print('')  # newline
  def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
      """Create kmeans-evolved anchors from a training dataset.

      Label sizes are clustered with scipy kmeans to obtain ``n`` initial anchors, which
      are then refined by a simple genetic algorithm: random multiplicative mutations
      are kept whenever they increase the fitness.

      Arguments:
          path: path to dataset *.yaml, or a loaded dataset
          n: number of anchors
          img_size: image size used for training
          thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
          gen: generations to evolve anchors using genetic algorithm
          verbose: print all results

      Return:
          k: kmeans evolved anchors, sorted by area from small to large

      Usage:
          from utils.autoanchor import *; _ = kmean_anchors()
      """
      thr = 1. / thr  # metric() yields min(r, 1 / r), so compare against the reciprocal threshold
      prefix = colorstr('autoanchor: ')
      npr = np.random

      def metric(k, wh):  # compute metrics
          r = wh[:, None] / k[None]
          x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
          # x = wh_iou(wh, torch.tensor(k))  # iou metric
          return x, x.max(1)[0]  # x, best_x

      def anchor_fitness(k):  # mutation fitness
          _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
          return (best * (best > thr).float()).mean()  # fitness

      def print_results(k):
          k = k[np.argsort(k.prod(1))]  # sort small to large
          x, best = metric(k, wh0)
          bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
          print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
          print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
                f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
          for i, anchor in enumerate(k):
              print('%i,%i' % (round(anchor[0]), round(anchor[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
          return k

      if isinstance(path, str):  # *.yaml file
          with open(path) as yaml_file:
              data_dict = yaml.load(yaml_file, Loader=yaml.SafeLoader)  # model dict
          from utils.datasets import LoadImagesAndLabels
          dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
      else:
          dataset = path  # dataset

      # Get label wh: scale every image so that its longest side equals img_size
      shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
      wh0 = np.concatenate([labels[:, 3:5] * shape for shape, labels in zip(shapes, dataset.labels)])  # wh

      # Filter
      i = (wh0 < 3.0).any(1).sum()
      if i:
          print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
      # Keeps every label with at least one side >= 2 pixels (looser than the 3 pixel warning above)
      wh = wh0[(wh0 >= 2.0).any(1)]
      # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

      # Kmeans calculation
      print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
      whitening = wh.std(0)  # sigmas for whitening
      k, _ = kmeans(wh / whitening, n, iter=30)  # points, mean distance (unused)
      k *= whitening
      if len(k) != n:
          # scipy kmeans drops centroids that end up with no observations, so it can return
          # fewer than n points. Evolving those would yield the wrong number of anchors, so
          # start the genetic algorithm from n sorted random anchors instead.
          print(f'{prefix}WARNING: kmeans returned {len(k)} of {n} requested anchors, switching to random init')
          k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size
      wh = torch.tensor(wh, dtype=torch.float32)  # filtered
      wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
      k = print_results(k)

      # Evolve
      fitness, shape, mutation_prob, mutation_sigma = anchor_fitness(k), k.shape, 0.9, 0.1
      pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
      for _ in pbar:
          v = np.ones(shape)
          while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
              v = ((npr.random(shape) < mutation_prob) * npr.random() * npr.randn(*shape) * mutation_sigma + 1).clip(0.3, 3.0)
          kg = (k.copy() * v).clip(min=2.0)  # mutated anchors, at least 2 pixels per side
          fg = anchor_fitness(kg)
          if fg > fitness:  # keep the mutation only when it improves the fitness
              fitness, k = fg, kg.copy()
              pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {fitness:.4f}'
              if verbose:
                  print_results(k)

      return print_results(k)