您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

1299 行
53KB

  1. import glob
  2. import logging
  3. import os
  4. import platform
  5. import random
  6. import re
  7. import shutil
  8. import subprocess
  9. import time
  10. from contextlib import contextmanager
  11. from copy import copy
  12. from pathlib import Path
  13. import cv2
  14. import math
  15. import matplotlib
  16. import matplotlib.pyplot as plt
  17. import numpy as np
  18. import torch
  19. import torch.nn as nn
  20. import yaml
  21. from scipy.cluster.vq import kmeans
  22. from scipy.signal import butter, filtfilt
  23. from tqdm import tqdm
  24. from utils.google_utils import gsutil_getsize
  25. from utils.torch_utils import is_parallel, init_torch_seeds
  26. # Set printoptions
  27. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  28. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  29. matplotlib.rc('font', **{'size': 11})
  30. # Prevent OpenCV from multithreading (to use PyTorch DataLoader)
  31. cv2.setNumThreads(0)
  32. @contextmanager
  33. def torch_distributed_zero_first(local_rank: int):
  34. """
  35. Decorator to make all processes in distributed training wait for each local_master to do something.
  36. """
  37. if local_rank not in [-1, 0]:
  38. torch.distributed.barrier()
  39. yield
  40. if local_rank == 0:
  41. torch.distributed.barrier()
  42. def set_logging(rank=-1):
  43. logging.basicConfig(
  44. format="%(message)s",
  45. level=logging.INFO if rank in [-1, 0] else logging.WARN)
  46. def init_seeds(seed=0):
  47. random.seed(seed)
  48. np.random.seed(seed)
  49. init_torch_seeds(seed)
  50. def get_latest_run(search_dir='./runs'):
  51. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  52. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  53. return max(last_list, key=os.path.getctime) if last_list else ''
  54. def check_git_status():
  55. # Suggest 'git pull' if repo is out of date
  56. if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
  57. s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
  58. if 'Your branch is behind' in s:
  59. print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
  60. def check_img_size(img_size, s=32):
  61. # Verify img_size is a multiple of stride s
  62. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  63. if new_size != img_size:
  64. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  65. return new_size
  66. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  67. # Check anchor fit to data, recompute if necessary
  68. print('\nAnalyzing anchors... ', end='')
  69. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  70. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  71. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  72. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  73. def metric(k): # compute metric
  74. r = wh[:, None] / k[None]
  75. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  76. best = x.max(1)[0] # best_x
  77. aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold
  78. bpr = (best > 1. / thr).float().mean() # best possible recall
  79. return bpr, aat
  80. bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
  81. print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
  82. if bpr < 0.98: # threshold to recompute
  83. print('. Attempting to generate improved anchors, please wait...' % bpr)
  84. na = m.anchor_grid.numel() // 2 # number of anchors
  85. new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  86. new_bpr = metric(new_anchors.reshape(-1, 2))[0]
  87. if new_bpr > bpr: # replace anchors
  88. new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
  89. m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
  90. m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
  91. check_anchor_order(m)
  92. print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
  93. else:
  94. print('Original anchors better than new anchors. Proceeding with original anchors.')
  95. print('') # newline
  96. def check_anchor_order(m):
  97. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  98. a = m.anchor_grid.prod(-1).view(-1) # anchor area
  99. da = a[-1] - a[0] # delta a
  100. ds = m.stride[-1] - m.stride[0] # delta s
  101. if da.sign() != ds.sign(): # same order
  102. print('Reversing anchor order')
  103. m.anchors[:] = m.anchors.flip(0)
  104. m.anchor_grid[:] = m.anchor_grid.flip(0)
  105. def check_file(file):
  106. # Search for file if not found
  107. if os.path.isfile(file) or file == '':
  108. return file
  109. else:
  110. files = glob.glob('./**/' + file, recursive=True) # find file
  111. assert len(files), 'File Not Found: %s' % file # assert file was found
  112. assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files) # assert unique
  113. return files[0] # return file
  114. def check_dataset(dict):
  115. # Download dataset if not found
  116. val, s = dict.get('val'), dict.get('download')
  117. if val and len(val):
  118. val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])] # val path
  119. if not all(os.path.exists(x) for x in val):
  120. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [*val])
  121. if s and len(s): # download script
  122. print('Downloading %s ...' % s)
  123. if s.startswith('http') and s.endswith('.zip'): # URL
  124. f = Path(s).name # filename
  125. torch.hub.download_url_to_file(s, f)
  126. r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
  127. else: # bash script
  128. r = os.system(s)
  129. print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
  130. else:
  131. raise Exception('Dataset not found.')
  132. def make_divisible(x, divisor):
  133. # Returns x evenly divisible by divisor
  134. return math.ceil(x / divisor) * divisor
  135. def labels_to_class_weights(labels, nc=80):
  136. # Get class weights (inverse frequency) from training labels
  137. if labels[0] is None: # no labels loaded
  138. return torch.Tensor()
  139. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  140. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  141. weights = np.bincount(classes, minlength=nc) # occurrences per class
  142. # Prepend gridpoint count (for uCE training)
  143. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  144. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  145. weights[weights == 0] = 1 # replace empty bins with 1
  146. weights = 1 / weights # number of targets per class
  147. weights /= weights.sum() # normalize
  148. return torch.from_numpy(weights)
  149. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  150. # Produces image weights based on class mAPs
  151. n = len(labels)
  152. class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
  153. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  154. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  155. return image_weights
  156. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  157. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  158. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  159. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  160. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  161. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  162. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  163. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  164. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  165. return x
  166. def xyxy2xywh(x):
  167. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  168. y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
  169. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  170. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  171. y[:, 2] = x[:, 2] - x[:, 0] # width
  172. y[:, 3] = x[:, 3] - x[:, 1] # height
  173. return y
  174. def xywh2xyxy(x):
  175. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  176. y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
  177. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  178. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  179. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  180. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  181. return y
  182. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  183. # Rescale coords (xyxy) from img1_shape to img0_shape
  184. if ratio_pad is None: # calculate from img0_shape
  185. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  186. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  187. else:
  188. gain = ratio_pad[0][0]
  189. pad = ratio_pad[1]
  190. coords[:, [0, 2]] -= pad[0] # x padding
  191. coords[:, [1, 3]] -= pad[1] # y padding
  192. coords[:, :4] /= gain
  193. clip_coords(coords, img0_shape)
  194. return coords
  195. def clip_coords(boxes, img_shape):
  196. # Clip bounding xyxy bounding boxes to image shape (height, width)
  197. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  198. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  199. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  200. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  201. def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
  202. """ Compute the average precision, given the recall and precision curves.
  203. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  204. # Arguments
  205. tp: True positives (nparray, nx1 or nx10).
  206. conf: Objectness value from 0-1 (nparray).
  207. pred_cls: Predicted object classes (nparray).
  208. target_cls: True object classes (nparray).
  209. plot: Plot precision-recall curve at mAP@0.5
  210. fname: Plot filename
  211. # Returns
  212. The average precision as computed in py-faster-rcnn.
  213. """
  214. # Sort by objectness
  215. i = np.argsort(-conf)
  216. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  217. # Find unique classes
  218. unique_classes = np.unique(target_cls)
  219. # Create Precision-Recall curve and compute AP for each class
  220. px, py = np.linspace(0, 1, 1000), [] # for plotting
  221. pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
  222. s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
  223. ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
  224. for ci, c in enumerate(unique_classes):
  225. i = pred_cls == c
  226. n_gt = (target_cls == c).sum() # Number of ground truth objects
  227. n_p = i.sum() # Number of predicted objects
  228. if n_p == 0 or n_gt == 0:
  229. continue
  230. else:
  231. # Accumulate FPs and TPs
  232. fpc = (1 - tp[i]).cumsum(0)
  233. tpc = tp[i].cumsum(0)
  234. # Recall
  235. recall = tpc / (n_gt + 1e-16) # recall curve
  236. r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # r at pr_score, negative x, xp because xp decreases
  237. # Precision
  238. precision = tpc / (tpc + fpc) # precision curve
  239. p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score
  240. # AP from recall-precision curve
  241. for j in range(tp.shape[1]):
  242. ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
  243. if j == 0:
  244. py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
  245. # Compute F1 score (harmonic mean of precision and recall)
  246. f1 = 2 * p * r / (p + r + 1e-16)
  247. if plot:
  248. py = np.stack(py, axis=1)
  249. fig, ax = plt.subplots(1, 1, figsize=(5, 5))
  250. ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision)
  251. ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
  252. ax.set_xlabel('Recall')
  253. ax.set_ylabel('Precision')
  254. ax.set_xlim(0, 1)
  255. ax.set_ylim(0, 1)
  256. plt.legend()
  257. fig.tight_layout()
  258. fig.savefig(fname, dpi=200)
  259. return p, r, ap, f1, unique_classes.astype('int32')
  260. def compute_ap(recall, precision):
  261. """ Compute the average precision, given the recall and precision curves.
  262. Source: https://github.com/rbgirshick/py-faster-rcnn.
  263. # Arguments
  264. recall: The recall curve (list).
  265. precision: The precision curve (list).
  266. # Returns
  267. The average precision as computed in py-faster-rcnn.
  268. """
  269. # Append sentinel values to beginning and end
  270. mrec = recall # np.concatenate(([0.], recall, [recall[-1] + 1E-3]))
  271. mpre = precision # np.concatenate(([0.], precision, [0.]))
  272. # Compute the precision envelope
  273. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  274. # Integrate area under curve
  275. method = 'interp' # methods: 'continuous', 'interp'
  276. if method == 'interp':
  277. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  278. ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
  279. else: # 'continuous'
  280. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
  281. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  282. return ap, mpre, mrec
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    """Return the IoU (or GIoU / DIoU / CIoU) of box1 against box2.

    Args:
        box1: box coordinates indexed as box1[0..3]. A single box of 4 values, or a 4xn tensor
            (compute_loss passes `pbox.T`), which gives an element-wise comparison with box2.
        box2: nx4 boxes; transposed internally so that box2[0..3] are its coordinate rows.
        x1y1x2y2: True if boxes are (x1, y1, x2, y2); False if (x_center, y_center, w, h).
        GIoU, DIoU, CIoU: return the generalized / distance / complete IoU variant instead of
            plain IoU. DIoU takes precedence over CIoU if both are set.
        eps: small constant added to box heights, the union and the enclosing-box terms to
            avoid division by zero.

    Returns:
        Tensor of IoU-family values, one per box in box2.
    """
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area (negative overlaps clamped to 0)
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area (eps on the heights also keeps the atan(w / h) terms below finite)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                # v: aspect-ratio consistency term
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                # alpha is a trade-off weight computed without gradient tracking
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
  322. def box_iou(box1, box2):
  323. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  324. """
  325. Return intersection-over-union (Jaccard index) of boxes.
  326. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  327. Arguments:
  328. box1 (Tensor[N, 4])
  329. box2 (Tensor[M, 4])
  330. Returns:
  331. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  332. IoU values for every element in boxes1 and boxes2
  333. """
  334. def box_area(box):
  335. # box = 4xn
  336. return (box[2] - box[0]) * (box[3] - box[1])
  337. area1 = box_area(box1.T)
  338. area2 = box_area(box2.T)
  339. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  340. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  341. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  342. def wh_iou(wh1, wh2):
  343. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  344. wh1 = wh1[:, None] # [N,1,2]
  345. wh2 = wh2[None] # [1,M,2]
  346. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  347. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  348. class FocalLoss(nn.Module):
  349. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  350. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  351. super(FocalLoss, self).__init__()
  352. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  353. self.gamma = gamma
  354. self.alpha = alpha
  355. self.reduction = loss_fcn.reduction
  356. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  357. def forward(self, pred, true):
  358. loss = self.loss_fcn(pred, true)
  359. # p_t = torch.exp(-loss)
  360. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  361. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  362. pred_prob = torch.sigmoid(pred) # prob from logits
  363. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  364. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  365. modulating_factor = (1.0 - p_t) ** self.gamma
  366. loss *= alpha_factor * modulating_factor
  367. if self.reduction == 'mean':
  368. return loss.mean()
  369. elif self.reduction == 'sum':
  370. return loss.sum()
  371. else: # 'none'
  372. return loss
  373. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  374. # return positive, negative label smoothing BCE targets
  375. return 1.0 - 0.5 * eps, 0.5 * eps
  376. class BCEBlurWithLogitsLoss(nn.Module):
  377. # BCEwithLogitLoss() with reduced missing label effects.
  378. def __init__(self, alpha=0.05):
  379. super(BCEBlurWithLogitsLoss, self).__init__()
  380. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  381. self.alpha = alpha
  382. def forward(self, pred, true):
  383. loss = self.loss_fcn(pred, true)
  384. pred = torch.sigmoid(pred) # prob from logits
  385. dx = pred - true # reduce only missing label effects
  386. # dx = (pred - true).abs() # reduce missing label and false label effects
  387. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  388. loss *= alpha_factor
  389. return loss.mean()
  390. def compute_loss(p, targets, model): # predictions, targets, model
  391. device = targets.device
  392. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  393. tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
  394. h = model.hyp # hyperparameters
  395. # Define criteria
  396. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
  397. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
  398. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  399. cp, cn = smooth_BCE(eps=0.0)
  400. # Focal loss
  401. g = h['fl_gamma'] # focal loss gamma
  402. if g > 0:
  403. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  404. # Losses
  405. nt = 0 # number of targets
  406. np = len(p) # number of outputs
  407. balance = [4.0, 1.0, 0.4] if np == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6
  408. for i, pi in enumerate(p): # layer index, layer predictions
  409. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  410. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
  411. n = b.shape[0] # number of targets
  412. if n:
  413. nt += n # cumulative targets
  414. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  415. # Regression
  416. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  417. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  418. pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
  419. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  420. lbox += (1.0 - iou).mean() # iou loss
  421. # Objectness
  422. tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
  423. # Classification
  424. if model.nc > 1: # cls loss (only if multiple classes)
  425. t = torch.full_like(ps[:, 5:], cn, device=device) # targets
  426. t[range(n), tcls[i]] = cp
  427. lcls += BCEcls(ps[:, 5:], t) # BCE
  428. # Append targets to text file
  429. # with open('targets.txt', 'a') as file:
  430. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  431. lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
  432. s = 3 / np # output count scaling
  433. lbox *= h['box'] * s
  434. lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
  435. lcls *= h['cls'] * s
  436. bs = tobj.shape[0] # batch size
  437. loss = lbox + lobj + lcls
  438. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
def build_targets(p, targets, model):
    """Build per-detection-layer training targets for compute_loss().

    Args:
        p: list of prediction tensors, one per detection layer. Only their shapes are read:
            dims 3 and 2 give the grid size that targets are scaled by (presumably nx, ny;
            confirm against Detect()).
        targets: tensor of shape (nt, 6) with columns (image, class, x, y, w, h). x, y, w, h
            are presumably normalized to 0-1, since they are multiplied by the grid size.
        model: model whose last layer (model.model[-1], or model.module.model[-1] when
            is_parallel(model)) is the Detect() module providing na, nl and anchors.
            model.hyp['anchor_t'] is the maximum target/anchor size ratio for a match.

    Returns:
        tcls, tbox, indices, anch: lists with one entry per detection layer holding the
        matched target classes, target boxes (xy relative to the assigned grid cell, wh in grid
        units), (image, anchor, grid y, grid x) index tuples and the matched anchors.
    """
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
    # Replicate every target once per anchor and append the anchor index as column 6 -> (na, nt, 7)
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

    g = 0.5  # bias
    # Candidate cell offsets: the target's own cell plus its 4 direct neighbours, scaled by g
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                        # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                        ], device=targets.device).float() * g  # offsets

    for i in range(det.nl):
        anchors = det.anchors[i]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        t = targets * gain  # scale xywh to this layer's grid
        if nt:
            # Matches: keep (anchor, target) pairs whose w and h ratios, in either direction, are below anchor_t
            r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
            t = t[j]  # filter

            # Offsets
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            # j, k: fractional x / y position < g (closer to the left / top cell edge), excluding gxy <= 1
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            # l, m: the same test measured from the right / bottom side of the grid
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            # Row 0 (all True) keeps every target for its own cell; rows 1-4 add the flagged neighbour cells
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j]
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets).long()  # integer cell each (possibly shifted) target is assigned to
        gi, gj = gij.T  # grid xy indices

        # Append
        a = t[:, 6].long()  # anchor indices
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch
  487. def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
  488. """Performs Non-Maximum Suppression (NMS) on inference results
  489. Returns:
  490. detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
  491. """
  492. nc = prediction[0].shape[1] - 5 # number of classes
  493. xc = prediction[..., 4] > conf_thres # candidates
  494. # Settings
  495. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  496. max_det = 300 # maximum number of detections per image
  497. time_limit = 10.0 # seconds to quit after
  498. redundant = True # require redundant detections
  499. multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
  500. t = time.time()
  501. output = [None] * prediction.shape[0]
  502. for xi, x in enumerate(prediction): # image index, image inference
  503. # Apply constraints
  504. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  505. x = x[xc[xi]] # confidence
  506. # If none remain process next image
  507. if not x.shape[0]:
  508. continue
  509. # Compute conf
  510. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  511. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  512. box = xywh2xyxy(x[:, :4])
  513. # Detections matrix nx6 (xyxy, conf, cls)
  514. if multi_label:
  515. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  516. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  517. else: # best class only
  518. conf, j = x[:, 5:].max(1, keepdim=True)
  519. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  520. # Filter by class
  521. if classes:
  522. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  523. # Apply finite constraint
  524. # if not torch.isfinite(x).all():
  525. # x = x[torch.isfinite(x).all(1)]
  526. # If none remain process next image
  527. n = x.shape[0] # number of boxes
  528. if not n:
  529. continue
  530. # Sort by confidence
  531. # x = x[x[:, 4].argsort(descending=True)]
  532. # Batched NMS
  533. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  534. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  535. i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
  536. if i.shape[0] > max_det: # limit detections
  537. i = i[:max_det]
  538. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  539. try: # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  540. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  541. weights = iou * scores[None] # box weights
  542. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  543. if redundant:
  544. i = i[iou.sum(1) > 1] # require redundancy
  545. except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
  546. print(x, i, x.shape, i.shape)
  547. pass
  548. output[xi] = x[i]
  549. if (time.time() - t) > time_limit:
  550. break # time limit exceeded
  551. return output
  552. def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
  553. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  554. x = torch.load(f, map_location=torch.device('cpu'))
  555. x['optimizer'] = None
  556. x['training_results'] = None
  557. x['epoch'] = -1
  558. x['model'].half() # to FP16
  559. for p in x['model'].parameters():
  560. p.requires_grad = False
  561. torch.save(x, s or f)
  562. mb = os.path.getsize(s or f) / 1E6 # filesize
  563. print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
  564. def coco_class_count(path='../coco/labels/train2014/'):
  565. # Histogram of occurrences per class
  566. nc = 80 # number classes
  567. x = np.zeros(nc, dtype='int32')
  568. files = sorted(glob.glob('%s/*.*' % path))
  569. for i, file in enumerate(files):
  570. labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
  571. x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
  572. print(i, len(files))
  573. def coco_only_people(path='../coco/labels/train2017/'): # from utils.general import *; coco_only_people()
  574. # Find images with only people
  575. files = sorted(glob.glob('%s/*.*' % path))
  576. for i, file in enumerate(files):
  577. labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
  578. if all(labels[:, 0] == 0):
  579. print(labels.shape[0], file)
  580. def crop_images_random(path='../images/', scale=0.50): # from utils.general import *; crop_images_random()
  581. # crops images into random squares up to scale fraction
  582. # WARNING: overwrites images!
  583. for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
  584. img = cv2.imread(file) # BGR
  585. if img is not None:
  586. h, w = img.shape[:2]
  587. # create random mask
  588. a = 30 # minimum size (pixels)
  589. mask_h = random.randint(a, int(max(a, h * scale))) # mask height
  590. mask_w = mask_h # mask width
  591. # box
  592. xmin = max(0, random.randint(0, w) - mask_w // 2)
  593. ymin = max(0, random.randint(0, h) - mask_h // 2)
  594. xmax = min(w, xmin + mask_w)
  595. ymax = min(h, ymin + mask_h)
  596. # apply random color mask
  597. cv2.imwrite(file, img[ymin:ymax, xmin:xmax])
  598. def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
  599. # Makes single-class coco datasets. from utils.general import *; coco_single_class_labels()
  600. if os.path.exists('new/'):
  601. shutil.rmtree('new/') # delete output folder
  602. os.makedirs('new/') # make new output folder
  603. os.makedirs('new/labels/')
  604. os.makedirs('new/images/')
  605. for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
  606. with open(file, 'r') as f:
  607. labels = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
  608. i = labels[:, 0] == label_class
  609. if any(i):
  610. img_file = file.replace('labels', 'images').replace('txt', 'jpg')
  611. labels[:, 0] = 0 # reset class to 0
  612. with open('new/images.txt', 'a') as f: # add image to dataset list
  613. f.write(img_file + '\n')
  614. with open('new/labels/' + Path(file).name, 'a') as f: # write label
  615. for l in labels[i]:
  616. f.write('%g %.6f %.6f %.6f %.6f\n' % tuple(l))
  617. shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
  618. def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  619. """ Creates kmeans-evolved anchors from training dataset
  620. Arguments:
  621. path: path to dataset *.yaml, or a loaded dataset
  622. n: number of anchors
  623. img_size: image size used for training
  624. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  625. gen: generations to evolve anchors using genetic algorithm
  626. Return:
  627. k: kmeans evolved anchors
  628. Usage:
  629. from utils.general import *; _ = kmean_anchors()
  630. """
  631. thr = 1. / thr
  632. def metric(k, wh): # compute metrics
  633. r = wh[:, None] / k[None]
  634. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  635. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  636. return x, x.max(1)[0] # x, best_x
  637. def fitness(k): # mutation fitness
  638. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  639. return (best * (best > thr).float()).mean() # fitness
  640. def print_results(k):
  641. k = k[np.argsort(k.prod(1))] # sort small to large
  642. x, best = metric(k, wh0)
  643. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  644. print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
  645. print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
  646. (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
  647. for i, x in enumerate(k):
  648. print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
  649. return k
  650. if isinstance(path, str): # *.yaml file
  651. with open(path) as f:
  652. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  653. from utils.datasets import LoadImagesAndLabels
  654. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  655. else:
  656. dataset = path # dataset
  657. # Get label wh
  658. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  659. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  660. # Filter
  661. i = (wh0 < 3.0).any(1).sum()
  662. if i:
  663. print('WARNING: Extremely small objects found. '
  664. '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
  665. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  666. # Kmeans calculation
  667. print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
  668. s = wh.std(0) # sigmas for whitening
  669. k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
  670. k *= s
  671. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  672. wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
  673. k = print_results(k)
  674. # Plot
  675. # k, d = [None] * 20, [None] * 20
  676. # for i in tqdm(range(1, 21)):
  677. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  678. # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
  679. # ax = ax.ravel()
  680. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  681. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  682. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  683. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  684. # fig.tight_layout()
  685. # fig.savefig('wh.png', dpi=200)
  686. # Evolve
  687. npr = np.random
  688. f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  689. pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
  690. for _ in pbar:
  691. v = np.ones(sh)
  692. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  693. v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  694. kg = (k.copy() * v).clip(min=2.0)
  695. fg = fitness(kg)
  696. if fg > f:
  697. f, k = fg, kg.copy()
  698. pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
  699. if verbose:
  700. print_results(k)
  701. return print_results(k)
  702. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  703. # Print mutation results to evolve.txt (for use with train.py --evolve)
  704. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  705. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  706. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  707. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  708. if bucket:
  709. url = 'gs://%s/evolve.txt' % bucket
  710. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  711. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  712. with open('evolve.txt', 'a') as f: # append result
  713. f.write(c + b + '\n')
  714. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  715. x = x[np.argsort(-fitness(x))] # sort
  716. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  717. # Save yaml
  718. for i, k in enumerate(hyp.keys()):
  719. hyp[k] = float(x[0, i + 7])
  720. with open(yaml_file, 'w') as f:
  721. results = tuple(x[0, :7])
  722. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  723. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  724. yaml.dump(hyp, f, sort_keys=False)
  725. if bucket:
  726. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  727. def apply_classifier(x, model, img, im0):
  728. # applies a second stage classifier to yolo outputs
  729. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  730. for i, d in enumerate(x): # per image
  731. if d is not None and len(d):
  732. d = d.clone()
  733. # Reshape and pad cutouts
  734. b = xyxy2xywh(d[:, :4]) # boxes
  735. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  736. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  737. d[:, :4] = xywh2xyxy(b).long()
  738. # Rescale boxes from img_size to im0 size
  739. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  740. # Classes
  741. pred_cls1 = d[:, 5].long()
  742. ims = []
  743. for j, a in enumerate(d): # per item
  744. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  745. im = cv2.resize(cutout, (224, 224)) # BGR
  746. # cv2.imwrite('test%i.jpg' % j, cutout)
  747. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  748. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  749. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  750. ims.append(im)
  751. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  752. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  753. return x
  754. def fitness(x):
  755. # Returns fitness (for use with results.txt or evolve.txt)
  756. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  757. return (x[:, :4] * w).sum(1)
  758. def output_to_target(output, width, height):
  759. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  760. if isinstance(output, torch.Tensor):
  761. output = output.cpu().numpy()
  762. targets = []
  763. for i, o in enumerate(output):
  764. if o is not None:
  765. for pred in o:
  766. box = pred[:4]
  767. w = (box[2] - box[0]) / width
  768. h = (box[3] - box[1]) / height
  769. x = box[0] / width + w / 2
  770. y = box[1] / height + h / 2
  771. conf = pred[4]
  772. cls = int(pred[5])
  773. targets.append([i, cls, x, y, w, h, conf])
  774. return np.array(targets)
  775. def increment_dir(dir, comment=''):
  776. # Increments a directory runs/exp1 --> runs/exp2_comment
  777. n = 0 # number
  778. dir = str(Path(dir)) # os-agnostic
  779. dirs = sorted(glob.glob(dir + '*')) # directories
  780. if dirs:
  781. matches = [re.search(r"exp(\d+)", d) for d in dirs]
  782. idxs = [int(m.groups()[0]) for m in matches if m]
  783. if idxs:
  784. n = max(idxs) + 1 # increment
  785. return dir + str(n) + ('_' + comment if comment else '')
  786. # Plotting functions ---------------------------------------------------------------------------------------------------
  787. def hist2d(x, y, n=100):
  788. # 2d histogram used in labels.png and evolve.png
  789. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  790. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  791. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  792. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  793. return np.log(hist[xidx, yidx])
  794. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  795. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  796. def butter_lowpass(cutoff, fs, order):
  797. nyq = 0.5 * fs
  798. normal_cutoff = cutoff / nyq
  799. b, a = butter(order, normal_cutoff, btype='low', analog=False)
  800. return b, a
  801. b, a = butter_lowpass(cutoff, fs, order=order)
  802. return filtfilt(b, a, data) # forward-backward filter
  803. def plot_one_box(x, img, color=None, label=None, line_thickness=None):
  804. # Plots one bounding box on image img
  805. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  806. color = color or [random.randint(0, 255) for _ in range(3)]
  807. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  808. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  809. if label:
  810. tf = max(tl - 1, 1) # font thickness
  811. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  812. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  813. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  814. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  815. def plot_wh_methods(): # from utils.general import *; plot_wh_methods()
  816. # Compares the two methods for width-height anchor multiplication
  817. # https://github.com/ultralytics/yolov3/issues/168
  818. x = np.arange(-4.0, 4.0, .1)
  819. ya = np.exp(x)
  820. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  821. fig = plt.figure(figsize=(6, 3), dpi=150)
  822. plt.plot(x, ya, '.-', label='YOLOv3')
  823. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  824. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  825. plt.xlim(left=-4, right=4)
  826. plt.ylim(bottom=0, top=6)
  827. plt.xlabel('input')
  828. plt.ylabel('output')
  829. plt.grid()
  830. plt.legend()
  831. fig.tight_layout()
  832. fig.savefig('comparison.png', dpi=200)
  833. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  834. tl = 3 # line thickness
  835. tf = max(tl - 1, 1) # font thickness
  836. if isinstance(images, torch.Tensor):
  837. images = images.cpu().float().numpy()
  838. if isinstance(targets, torch.Tensor):
  839. targets = targets.cpu().numpy()
  840. # un-normalise
  841. if np.max(images[0]) <= 1:
  842. images *= 255
  843. bs, _, h, w = images.shape # batch size, _, height, width
  844. bs = min(bs, max_subplots) # limit plot images
  845. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  846. # Check if we should resize
  847. scale_factor = max_size / max(h, w)
  848. if scale_factor < 1:
  849. h = math.ceil(scale_factor * h)
  850. w = math.ceil(scale_factor * w)
  851. # Empty array for output
  852. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
  853. # Fix class - colour map
  854. prop_cycle = plt.rcParams['axes.prop_cycle']
  855. # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  856. hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  857. color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]
  858. for i, img in enumerate(images):
  859. if i == max_subplots: # if last batch has fewer images than we expect
  860. break
  861. block_x = int(w * (i // ns))
  862. block_y = int(h * (i % ns))
  863. img = img.transpose(1, 2, 0)
  864. if scale_factor < 1:
  865. img = cv2.resize(img, (w, h))
  866. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  867. if len(targets) > 0:
  868. image_targets = targets[targets[:, 0] == i]
  869. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  870. classes = image_targets[:, 1].astype('int')
  871. gt = image_targets.shape[1] == 6 # ground truth if no conf column
  872. conf = None if gt else image_targets[:, 6] # check for confidence presence (gt vs pred)
  873. boxes[[0, 2]] *= w
  874. boxes[[0, 2]] += block_x
  875. boxes[[1, 3]] *= h
  876. boxes[[1, 3]] += block_y
  877. for j, box in enumerate(boxes.T):
  878. cls = int(classes[j])
  879. color = color_lut[cls % len(color_lut)]
  880. cls = names[cls] if names else cls
  881. if gt or conf[j] > 0.3: # 0.3 conf thresh
  882. label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j])
  883. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  884. # Draw image filename labels
  885. if paths is not None:
  886. label = os.path.basename(paths[i])[:40] # trim to 40 char
  887. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  888. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  889. lineType=cv2.LINE_AA)
  890. # Image border
  891. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  892. if fname is not None:
  893. mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
  894. cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))
  895. return mosaic
  896. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  897. # Plot LR simulating training for full epochs
  898. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  899. y = []
  900. for _ in range(epochs):
  901. scheduler.step()
  902. y.append(optimizer.param_groups[0]['lr'])
  903. plt.plot(y, '.-', label='LR')
  904. plt.xlabel('epoch')
  905. plt.ylabel('LR')
  906. plt.grid()
  907. plt.xlim(0, epochs)
  908. plt.ylim(0)
  909. plt.tight_layout()
  910. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  911. def plot_test_txt(): # from utils.general import *; plot_test()
  912. # Plot test.txt histograms
  913. x = np.loadtxt('test.txt', dtype=np.float32)
  914. box = xyxy2xywh(x[:, :4])
  915. cx, cy = box[:, 0], box[:, 1]
  916. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  917. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  918. ax.set_aspect('equal')
  919. plt.savefig('hist2d.png', dpi=300)
  920. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  921. ax[0].hist(cx, bins=600)
  922. ax[1].hist(cy, bins=600)
  923. plt.savefig('hist1d.png', dpi=200)
  924. def plot_targets_txt(): # from utils.general import *; plot_targets_txt()
  925. # Plot targets.txt histograms
  926. x = np.loadtxt('targets.txt', dtype=np.float32).T
  927. s = ['x targets', 'y targets', 'width targets', 'height targets']
  928. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  929. ax = ax.ravel()
  930. for i in range(4):
  931. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  932. ax[i].legend()
  933. ax[i].set_title(s[i])
  934. plt.savefig('targets.jpg', dpi=200)
  935. def plot_study_txt(f='study.txt', x=None): # from utils.general import *; plot_study_txt()
  936. # Plot study.txt generated by test.py
  937. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  938. ax = ax.ravel()
  939. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  940. for f in ['study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
  941. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  942. x = np.arange(y.shape[1]) if x is None else np.array(x)
  943. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  944. for i in range(7):
  945. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  946. ax[i].set_title(s[i])
  947. j = y[3].argmax() + 1
  948. ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
  949. label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  950. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
  951. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  952. ax2.grid()
  953. ax2.set_xlim(0, 30)
  954. ax2.set_ylim(28, 50)
  955. ax2.set_yticks(np.arange(30, 55, 5))
  956. ax2.set_xlabel('GPU Speed (ms/img)')
  957. ax2.set_ylabel('COCO AP val')
  958. ax2.legend(loc='lower right')
  959. plt.savefig('study_mAP_latency.png', dpi=300)
  960. plt.savefig(f.replace('.txt', '.png'), dpi=300)
  961. def plot_labels(labels, save_dir=''):
  962. # plot dataset labels
  963. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  964. nc = int(c.max() + 1) # number of classes
  965. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  966. ax = ax.ravel()
  967. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  968. ax[0].set_xlabel('classes')
  969. ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
  970. ax[1].set_xlabel('x')
  971. ax[1].set_ylabel('y')
  972. ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
  973. ax[2].set_xlabel('width')
  974. ax[2].set_ylabel('height')
  975. plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
  976. plt.close()
  977. # seaborn correlogram
  978. try:
  979. import seaborn as sns
  980. import pandas as pd
  981. x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  982. sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
  983. plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
  984. diag_kws=dict(bins=50))
  985. plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
  986. plt.close()
  987. except Exception as e:
  988. pass
  989. def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general import *; plot_evolution()
  990. # Plot hyperparameter evolution results in evolve.txt
  991. with open(yaml_file) as f:
  992. hyp = yaml.load(f, Loader=yaml.FullLoader)
  993. x = np.loadtxt('evolve.txt', ndmin=2)
  994. f = fitness(x)
  995. # weights = (f - f.min()) ** 2 # for weighted results
  996. plt.figure(figsize=(10, 12), tight_layout=True)
  997. matplotlib.rc('font', **{'size': 8})
  998. for i, (k, v) in enumerate(hyp.items()):
  999. y = x[:, i + 7]
  1000. # mu = (y * weights).sum() / weights.sum() # best weighted result
  1001. mu = y[f.argmax()] # best single result
  1002. plt.subplot(6, 5, i + 1)
  1003. plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
  1004. plt.plot(mu, f.max(), 'k+', markersize=15)
  1005. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  1006. if i % 5 != 0:
  1007. plt.yticks([])
  1008. print('%15s: %.3g' % (k, mu))
  1009. plt.savefig('evolve.png', dpi=200)
  1010. print('\nPlot saved as evolve.png')
  1011. def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay()
  1012. # Plot training 'results*.txt', overlaying train and val losses
  1013. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  1014. t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  1015. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  1016. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  1017. n = results.shape[1] # number of rows
  1018. x = range(start, min(stop, n) if stop else n)
  1019. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  1020. ax = ax.ravel()
  1021. for i in range(5):
  1022. for j in [i, i + 5]:
  1023. y = results[j, x]
  1024. ax[i].plot(x, y, marker='.', label=s[j])
  1025. # y_smooth = butter_lowpass_filtfilt(y)
  1026. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  1027. ax[i].set_title(t[i])
  1028. ax[i].legend()
  1029. ax[i].set_ylabel(f) if i == 0 else None # add filename
  1030. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  1031. def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
  1032. # from utils.general import *; plot_results(save_dir='runs/exp0')
  1033. # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
  1034. fig, ax = plt.subplots(2, 5, figsize=(12, 6))
  1035. ax = ax.ravel()
  1036. s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
  1037. 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  1038. if bucket:
  1039. # os.system('rm -rf storage.googleapis.com')
  1040. # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  1041. files = ['results%g.txt' % x for x in id]
  1042. c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
  1043. os.system(c)
  1044. else:
  1045. files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
  1046. assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
  1047. for fi, f in enumerate(files):
  1048. try:
  1049. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  1050. n = results.shape[1] # number of rows
  1051. x = range(start, min(stop, n) if stop else n)
  1052. for i in range(10):
  1053. y = results[i, x]
  1054. if i in [0, 1, 2, 5, 6, 7]:
  1055. y[y == 0] = np.nan # don't show zero loss values
  1056. # y /= y[0] # normalize
  1057. label = labels[fi] if len(labels) else Path(f).stem
  1058. ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
  1059. ax[i].set_title(s[i])
  1060. # if i in [5, 6, 7]: # share train and val loss y axes
  1061. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  1062. except Exception as e:
  1063. print('Warning: Plotting error for %s; %s' % (f, e))
  1064. fig.tight_layout()
  1065. ax[1].legend()
  1066. fig.savefig(Path(save_dir) / 'results.png', dpi=200)