import glob
import logging
import math
import os
import platform
import random
import shutil
import subprocess
import time
from contextlib import contextmanager
from copy import copy
from pathlib import Path

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from scipy.cluster.vq import kmeans
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from utils.google_utils import gsutil_getsize
from utils.torch_utils import init_seeds as init_torch_seeds
from utils.torch_utils import is_parallel

# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
matplotlib.rc('font', **{'size': 11})

# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for each local master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
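

# Example usage (illustrative sketch, not part of this module): in a DDP training
# script, let only the local master prepare the dataset while the other ranks wait:
#   with torch_distributed_zero_first(local_rank):
#       dataset = create_dataset(...)  # hypothetical helper; ranks != 0 block until rank 0 finishes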


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed=seed)


def get_latest_run(search_dir='./runs'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def check_git_status():
    # Suggest 'git pull' if repo is out of date
    if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
        s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
        if 'Your branch is behind' in s:
            print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size
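

# Example (illustrative):
#   check_img_size(641)  # warns and returns 672, the next multiple of the default stride s=32
#   check_img_size(640)  # returns 640 unchanged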


def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    print('\nAnalyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
    print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to generate improved anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(new_anchors.reshape(-1, 2))[0]
        if new_bpr > bpr:  # replace anchors
            new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print('Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline


def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # orders differ, flip anchors to match stride order
        print('Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_file(file):
    # Search for file if not found
    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        return files[0]  # return first file if multiple found


def check_dataset(dict):
    # Download dataset if not found
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [os.path.abspath(x) for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(os.path.exists(x) for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [*val])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')


def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor
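

# Example (illustrative):
#   make_divisible(100, 32)  # -> 128 (always rounds up)
#   make_divisible(64, 32)   # -> 64 (already a multiple)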


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
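

# Example (illustrative): two class-0 labels and one class-1 label in a 2-class
# dataset give the rarer class double the weight:
#   labels = [np.array([[0, .5, .5, .1, .1], [0, .5, .5, .1, .1], [1, .5, .5, .1, .1]])]
#   labels_to_class_weights(labels, nc=2)  # -> tensor([0.3333, 0.6667], dtype=torch.float64)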


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class mAPs
    n = len(labels)
    class_counts = np.array([np.bincount(labels[i][:, 0].astype(int), minlength=nc) for i in range(n)])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
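

# Example (illustrative): the two conversions are inverses of each other:
#   b = torch.tensor([[10., 20., 30., 60.]])  # xyxy
#   xyxy2xywh(b)             # -> tensor([[20., 40., 20., 40.]]): center x, center y, w, h
#   xywh2xyxy(xyxy2xywh(b))  # -> tensor([[10., 20., 30., 60.]]): round trip recovers the input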


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:    True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls: Predicted object classes (nparray).
        target_cls: True object classes (nparray).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # Number of ground truth objects
        n_p = i.sum()  # Number of predicted objects

        if n_p == 0 or n_gt == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_gt + 1e-16)  # recall curve
            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

            # Plot
            # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            # ax.plot(recall, precision)
            # ax.set_xlabel('Recall')
            # ax.set_ylabel('Precision')
            # ax.set_xlim(0, 1.01)
            # ax.set_ylim(0, 1.01)
            # fig.tight_layout()
            # fig.savefig('PR_curve.png', dpi=300)

    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype('int32')


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall:    The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap
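

# Example (illustrative): a detector whose precision falls linearly from 1.0 to 0.5
# as recall rises from 0.5 to 1.0 scores roughly AP ~= 0.87 with 101-point interpolation:
#   compute_ap(np.array([0.5, 1.0]), np.array([1.0, 0.5]))  # ~0.87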


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
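

# Example (illustrative): two unit squares, one shifted half a box to the right:
#   b1 = torch.tensor([0., 0., 1., 1.])
#   b2 = torch.tensor([[0.5, 0., 1.5, 1.]])
#   bbox_iou(b1, b2)  # -> tensor([0.3333]): inter 0.5 / union 1.5
#   bbox_iou(b1, torch.tensor([[1., 1., 2., 2.]]), GIoU=True)  # -> tensor([-0.5000]): disjoint boxes go negative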


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
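

# Example (illustrative): wh_iou compares shapes as if the boxes were corner-aligned,
# which is how anchor shapes are matched against label shapes:
#   wh_iou(torch.tensor([[10., 10.]]), torch.tensor([[10., 10.], [20., 10.]]))
#   # -> tensor([[1.0000, 0.5000]]): identical shape, then half the area shared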


class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    # return positive, negative label smoothing BCE targets
    return 1.0 - 0.5 * eps, 0.5 * eps
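

# Example (illustrative): smooth_BCE(eps=0.1) -> (0.95, 0.05), so positive targets
# are trained toward 0.95 and negatives toward 0.05 instead of hard 1/0 labels.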


class BCEBlurWithLogitsLoss(nn.Module):
    # BCEwithLogitLoss() with reduced missing label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()


def compute_loss(p, targets, model):  # predictions, targets, model
    device = targets.device
    lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)

    # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # Focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # Losses
    nt = 0  # number of targets
    no = len(p)  # number of outputs
    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

        n = b.shape[0]  # number of targets
        if n:
            nt += n  # cumulative targets
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # Regression
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
            giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # giou(prediction, target)
            lbox += (1.0 - giou).mean()  # giou loss

            # Objectness
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio

            # Classification
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn, device=device)  # targets
                t[range(n), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss

    s = 3 / no  # output count scaling
    lbox *= h['giou'] * s
    lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
    lcls *= h['cls'] * s
    bs = tobj.shape[0]  # batch size

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()


def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

    g = 0.5  # bias
    off = torch.tensor([[0, 0],
                        [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                        # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                        ], device=targets.device).float() * g  # offsets

    for i in range(det.nl):
        anchors = det.anchors[i]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        t = targets * gain
        if nt:
            # Matches
            r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
            t = t[j]  # filter

            # Offsets
            gxy = t[:, 2:4]  # grid xy
            gxi = gain[[2, 3]] - gxy  # inverse
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            t = t.repeat((5, 1, 1))[j]
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # grid xy indices

        # Append
        a = t[:, 6].long()  # anchor indices
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction[0].shape[1] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [None] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Sort by confidence
        # x = x[x[:, 4].argsort(descending=True)]

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torch.ops.torchvision.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except Exception:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
                print(x, i, x.shape, i.shape)

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output
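

# Example usage (illustrative sketch; `model` and `img` are assumed to come from a
# detect-style pipeline and are not defined in this module):
#   pred = model(img)[0]  # raw predictions, shape (batch, n, 5 + nc)
#   pred = non_max_suppression(pred, conf_thres=0.4, iou_thres=0.5, classes=[0])
#   # pred is a list with one (n, 6) tensor per image: x1, y1, x2, y2, conf, cls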


def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))


def coco_class_count(path='../coco/labels/train2014/'):
    # Histogram of occurrences per class
    nc = 80  # number classes
    x = np.zeros(nc, dtype='int32')
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
        print(i, len(files))


def coco_only_people(path='../coco/labels/train2017/'):  # from utils.general import *; coco_only_people()
    # Find images with only people
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        if all(labels[:, 0] == 0):
            print(labels.shape[0], file)


def crop_images_random(path='../images/', scale=0.50):  # from utils.general import *; crop_images_random()
    # crops images into random squares up to scale fraction
    # WARNING: overwrites images!
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        img = cv2.imread(file)  # BGR
        if img is not None:
            h, w = img.shape[:2]

            # create random square crop
            a = 30  # minimum size (pixels)
            mask_h = random.randint(a, int(max(a, h * scale)))  # crop height
            mask_w = mask_h  # crop width

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # crop to the random box and overwrite the original image
            cv2.imwrite(file, img[ymin:ymax, xmin:xmax])


def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
    # Makes single-class coco datasets. from utils.general import *; coco_single_class_labels()
    if os.path.exists('new/'):
        shutil.rmtree('new/')  # delete output folder
    os.makedirs('new/')  # make new output folder
    os.makedirs('new/labels/')
    os.makedirs('new/images/')
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        with open(file, 'r') as f:
            labels = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
        i = labels[:, 0] == label_class
        if any(i):
            img_file = file.replace('labels', 'images').replace('txt', 'jpg')
            labels[:, 0] = 0  # reset class to 0
            with open('new/images.txt', 'a') as f:  # add image to dataset list
                f.write(img_file + '\n')
            with open('new/labels/' + Path(file).name, 'a') as f:  # write label
                for l in labels[i]:
                    f.write('%g %.6f %.6f %.6f %.6f\n' % tuple(l))
            shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg'))  # copy images


def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.general import *; _ = kmean_anchors()
    """
    thr = 1. / thr

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
              (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print('WARNING: Extremely small objects found. '
              '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels

    # Kmeans calculation
    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
            if verbose:
                print_results(k)

    return print_results(k)


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload


def apply_classifier(x, model, img, im0):
    # applies a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def fitness(x):
    # Returns fitness (for use with results.txt or evolve.txt)
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)


def output_to_target(output, width, height):
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
    if isinstance(output, torch.Tensor):
        output = output.cpu().numpy()

    targets = []
    for i, o in enumerate(output):
        if o is not None:
            for pred in o:
                box = pred[:4]
                w = (box[2] - box[0]) / width
                h = (box[3] - box[1]) / height
                x = box[0] / width + w / 2
                y = box[1] / height + h / 2
                conf = pred[4]
                cls = int(pred[5])

                targets.append([i, cls, x, y, w, h, conf])

    return np.array(targets)


def increment_dir(dir, comment=''):
    # Increments a directory runs/exp1 --> runs/exp2_comment
    n = 0  # number
    dir = str(Path(dir))  # os-agnostic
    d = sorted(glob.glob(dir + '*'))  # directories
    if len(d):
        n = max([int(x[len(dir):x.find('_') if '_' in x else None]) for x in d]) + 1  # increment
    return dir + str(n) + ('_' + comment if comment else '')
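

# Example (illustrative): with runs/exp0 and runs/exp1 already on disk,
#   increment_dir('runs/exp')          # -> 'runs/exp2'
#   increment_dir('runs/exp', 'test')  # -> 'runs/exp2_test'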


# Plotting functions ---------------------------------------------------------------------------------------------------
def hist2d(x, y, n=100):
    # 2d histogram used in labels.png and evolve.png
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
    return np.log(hist[xidx, yidx])


def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        return b, a

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def plot_wh_methods():  # from utils.general import *; plot_wh_methods()
    # Compares the two methods for width-height anchor multiplication
    # https://github.com/ultralytics/yolov3/issues/168
    x = np.arange(-4.0, 4.0, .1)
    ya = np.exp(x)
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), dpi=150)
    plt.plot(x, ya, '.-', label='YOLOv3')
    plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
    plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)


def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    if os.path.isfile(fname):  # do not overwrite
        return None

    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()

    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise
    if np.max(images[0]) <= 1:
        images *= 255

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    # Empty array for output
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)

    # Fix class - colour map
    prop_cycle = plt.rcParams['axes.prop_cycle']
    # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
    color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]

    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            gt = image_targets.shape[1] == 6  # ground truth if no conf column
            conf = None if gt else image_targets[:, 6]  # check for confidence presence (gt vs pred)

            boxes[[0, 2]] *= w
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] *= h
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = color_lut[cls % len(color_lut)]
                cls = names[cls] if names else cls
                if gt or conf[j] > 0.3:  # 0.3 conf thresh
                    label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths is not None:
            label = os.path.basename(paths[i])[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname is not None:
        mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))

    return mosaic


def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    # Plot LR simulating training for full epochs
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.tight_layout()
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)


def plot_test_txt():  # from utils.general import *; plot_test()
    # Plot test.txt histograms
    x = np.loadtxt('test.txt', dtype=np.float32)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)


def plot_targets_txt():  # from utils.general import *; plot_targets_txt()
    # Plot targets.txt histograms
    x = np.loadtxt('targets.txt', dtype=np.float32).T
    s = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
        ax[i].legend()
        ax[i].set_title(s[i])
    plt.savefig('targets.jpg', dpi=200)


def plot_study_txt(f='study.txt', x=None):  # from utils.general import *; plot_study_txt()
    # Plot study.txt generated by test.py
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
    ax = ax.ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    for f in ['study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
        for i in range(7):
            ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
            ax[i].set_title(s[i])

        j = y[3].argmax() + 1
        ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

    ax2.grid()
    ax2.set_xlim(0, 30)
    ax2.set_ylim(28, 50)
    ax2.set_yticks(np.arange(30, 55, 5))
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    plt.savefig('study_mAP_latency.png', dpi=300)
    plt.savefig(f.replace('.txt', '.png'), dpi=300)


def plot_labels(labels, save_dir=''):
    # plot dataset labels
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes

    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    ax[0].set_xlabel('classes')
    ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')
    ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
    ax[2].set_xlabel('width')
    ax[2].set_ylabel('height')
    plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
    plt.close()

    # seaborn correlogram
    try:
        import seaborn as sns
        import pandas as pd
        x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
        sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
                     plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
                     diag_kws=dict(bins=50))
        plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
        plt.close()
    except Exception:
        pass


def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.general import *; plot_evolution()
    # Plot hyperparameter evolution results in evolve.txt
    with open(yaml_file) as f:
        hyp = yaml.load(f, Loader=yaml.FullLoader)
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # weights = (f - f.min()) ** 2  # for weighted results
    plt.figure(figsize=(10, 10), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, (k, v) in enumerate(hyp.items()):
        y = x[:, i + 7]
        # mu = (y * weights).sum() / weights.sum()  # best weighted result
        mu = y[f.argmax()]  # best single result
        plt.subplot(5, 5, i + 1)
        plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:
            plt.yticks([])
        print('%15s: %.3g' % (k, mu))
    plt.savefig('evolve.png', dpi=200)
    print('\nPlot saved as evolve.png')


def plot_results_overlay(start=0, stop=0):  # from utils.general import *; plot_results_overlay()
    # Plot training 'results*.txt', overlaying train and val losses
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in [i, i + 5]:
                y = results[j, x]
                ax[i].plot(x, y, marker='.', label=s[j])
                # y_smooth = butter_lowpass_filtfilt(y)
                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])

            ax[i].set_title(t[i])
            ax[i].legend()
            ax[i].set_ylabel(f) if i == 0 else None  # add filename
        fig.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
                 save_dir=''):  # from utils.general import *; plot_results()
    # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
    fig, ax = plt.subplots(2, 5, figsize=(12, 6))
    ax = ax.ravel()
    s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        # os.system('rm -rf storage.googleapis.com')
        # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
        files = ['results%g.txt' % x for x in id]
        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
        os.system(c)
    else:
        files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # dont show zero loss values
                    # y /= y[0]  # normalize
                label = labels[fi] if len(labels) else Path(f).stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6)
                ax[i].set_title(s[i])
                # if i in [5, 6, 7]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (f, e))

    fig.tight_layout()
    ax[1].legend()
    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
  1055. fig.savefig(Path(save_dir) / 'results.png', dpi=200)