You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1230 lines
50KB

  1. import glob
  2. import math
  3. import os
  4. import random
  5. import shutil
  6. import subprocess
  7. import time
  8. from contextlib import contextmanager
  9. from copy import copy
  10. from pathlib import Path
  11. from sys import platform
  12. import cv2
  13. import matplotlib
  14. import matplotlib.pyplot as plt
  15. import numpy as np
  16. import torch
  17. import torch.nn as nn
  18. import torchvision
  19. import yaml
  20. from scipy.signal import butter, filtfilt
  21. from tqdm import tqdm
  22. from . import torch_utils #  torch_utils, google_utils
  23. # Set printoptions
  24. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  25. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  26. matplotlib.rc('font', **{'size': 11})
  27. # Prevent OpenCV from multithreading (to use PyTorch DataLoader)
  28. cv2.setNumThreads(0)
  29. @contextmanager
  30. def torch_distributed_zero_first(local_rank: int):
  31. """
  32. Decorator to make all processes in distributed training wait for each local_master to do something.
  33. """
  34. if local_rank not in [-1, 0]:
  35. torch.distributed.barrier()
  36. yield
  37. if local_rank == 0:
  38. torch.distributed.barrier()
  39. def init_seeds(seed=0):
  40. random.seed(seed)
  41. np.random.seed(seed)
  42. torch_utils.init_seeds(seed=seed)
  43. def get_latest_run(search_dir='./runs'):
  44. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  45. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  46. return max(last_list, key=os.path.getctime)
  47. def check_git_status():
  48. # Suggest 'git pull' if repo is out of date
  49. if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'):
  50. s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
  51. if 'Your branch is behind' in s:
  52. print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
  53. def check_img_size(img_size, s=32):
  54. # Verify img_size is a multiple of stride s
  55. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  56. if new_size != img_size:
  57. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  58. return new_size
  59. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  60. # Check anchor fit to data, recompute if necessary
  61. print('\nAnalyzing anchors... ', end='')
  62. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  63. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  64. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  65. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  66. def metric(k): # compute metric
  67. r = wh[:, None] / k[None]
  68. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  69. best = x.max(1)[0] # best_x
  70. return (best > 1. / thr).float().mean() #  best possible recall
  71. bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
  72. print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
  73. if bpr < 0.99: # threshold to recompute
  74. print('. Attempting to generate improved anchors, please wait...' % bpr)
  75. na = m.anchor_grid.numel() // 2 # number of anchors
  76. new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  77. new_bpr = metric(new_anchors.reshape(-1, 2))
  78. if new_bpr > bpr: # replace anchors
  79. new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
  80. m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
  81. m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
  82. check_anchor_order(m)
  83. print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
  84. else:
  85. print('Original anchors better than new anchors. Proceeding with original anchors.')
  86. print('') # newline
  87. def check_anchor_order(m):
  88. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  89. a = m.anchor_grid.prod(-1).view(-1) # anchor area
  90. da = a[-1] - a[0] # delta a
  91. ds = m.stride[-1] - m.stride[0] # delta s
  92. if da.sign() != ds.sign(): # same order
  93. print('Reversing anchor order')
  94. m.anchors[:] = m.anchors.flip(0)
  95. m.anchor_grid[:] = m.anchor_grid.flip(0)
  96. def check_file(file):
  97. # Searches for file if not found locally
  98. if os.path.isfile(file):
  99. return file
  100. else:
  101. files = glob.glob('./**/' + file, recursive=True) # find file
  102. assert len(files), 'File Not Found: %s' % file # assert file was found
  103. return files[0] # return first file if multiple found
  104. def make_divisible(x, divisor):
  105. # Returns x evenly divisble by divisor
  106. return math.ceil(x / divisor) * divisor
  107. def labels_to_class_weights(labels, nc=80):
  108. # Get class weights (inverse frequency) from training labels
  109. if labels[0] is None: # no labels loaded
  110. return torch.Tensor()
  111. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  112. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  113. weights = np.bincount(classes, minlength=nc) # occurences per class
  114. # Prepend gridpoint count (for uCE trianing)
  115. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  116. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  117. weights[weights == 0] = 1 # replace empty bins with 1
  118. weights = 1 / weights # number of targets per class
  119. weights /= weights.sum() # normalize
  120. return torch.from_numpy(weights)
  121. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  122. # Produces image weights based on class mAPs
  123. n = len(labels)
  124. class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
  125. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  126. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  127. return image_weights
  128. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  129. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  130. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  131. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  132. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  133. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  134. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  135. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  136. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  137. return x
  138. def xyxy2xywh(x):
  139. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  140. y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
  141. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  142. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  143. y[:, 2] = x[:, 2] - x[:, 0] # width
  144. y[:, 3] = x[:, 3] - x[:, 1] # height
  145. return y
  146. def xywh2xyxy(x):
  147. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  148. y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
  149. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  150. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  151. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  152. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  153. return y
  154. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  155. # Rescale coords (xyxy) from img1_shape to img0_shape
  156. if ratio_pad is None: # calculate from img0_shape
  157. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  158. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  159. else:
  160. gain = ratio_pad[0][0]
  161. pad = ratio_pad[1]
  162. coords[:, [0, 2]] -= pad[0] # x padding
  163. coords[:, [1, 3]] -= pad[1] # y padding
  164. coords[:, :4] /= gain
  165. clip_coords(coords, img0_shape)
  166. return coords
  167. def clip_coords(boxes, img_shape):
  168. # Clip bounding xyxy bounding boxes to image shape (height, width)
  169. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  170. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  171. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  172. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  173. def ap_per_class(tp, conf, pred_cls, target_cls):
  174. """ Compute the average precision, given the recall and precision curves.
  175. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  176. # Arguments
  177. tp: True positives (nparray, nx1 or nx10).
  178. conf: Objectness value from 0-1 (nparray).
  179. pred_cls: Predicted object classes (nparray).
  180. target_cls: True object classes (nparray).
  181. # Returns
  182. The average precision as computed in py-faster-rcnn.
  183. """
  184. # Sort by objectness
  185. i = np.argsort(-conf)
  186. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  187. # Find unique classes
  188. unique_classes = np.unique(target_cls)
  189. # Create Precision-Recall curve and compute AP for each class
  190. pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
  191. s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
  192. ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
  193. for ci, c in enumerate(unique_classes):
  194. i = pred_cls == c
  195. n_gt = (target_cls == c).sum() # Number of ground truth objects
  196. n_p = i.sum() # Number of predicted objects
  197. if n_p == 0 or n_gt == 0:
  198. continue
  199. else:
  200. # Accumulate FPs and TPs
  201. fpc = (1 - tp[i]).cumsum(0)
  202. tpc = tp[i].cumsum(0)
  203. # Recall
  204. recall = tpc / (n_gt + 1e-16) # recall curve
  205. r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # r at pr_score, negative x, xp because xp decreases
  206. # Precision
  207. precision = tpc / (tpc + fpc) # precision curve
  208. p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score
  209. # AP from recall-precision curve
  210. for j in range(tp.shape[1]):
  211. ap[ci, j] = compute_ap(recall[:, j], precision[:, j])
  212. # Plot
  213. # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
  214. # ax.plot(recall, precision)
  215. # ax.set_xlabel('Recall')
  216. # ax.set_ylabel('Precision')
  217. # ax.set_xlim(0, 1.01)
  218. # ax.set_ylim(0, 1.01)
  219. # fig.tight_layout()
  220. # fig.savefig('PR_curve.png', dpi=300)
  221. # Compute F1 score (harmonic mean of precision and recall)
  222. f1 = 2 * p * r / (p + r + 1e-16)
  223. return p, r, ap, f1, unique_classes.astype('int32')
  224. def compute_ap(recall, precision):
  225. """ Compute the average precision, given the recall and precision curves.
  226. Source: https://github.com/rbgirshick/py-faster-rcnn.
  227. # Arguments
  228. recall: The recall curve (list).
  229. precision: The precision curve (list).
  230. # Returns
  231. The average precision as computed in py-faster-rcnn.
  232. """
  233. # Append sentinel values to beginning and end
  234. mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
  235. mpre = np.concatenate(([0.], precision, [0.]))
  236. # Compute the precision envelope
  237. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  238. # Integrate area under curve
  239. method = 'interp' # methods: 'continuous', 'interp'
  240. if method == 'interp':
  241. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  242. ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
  243. else: # 'continuous'
  244. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
  245. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  246. return ap
  247. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
  248. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  249. box2 = box2.t()
  250. # Get the coordinates of bounding boxes
  251. if x1y1x2y2: # x1, y1, x2, y2 = box1
  252. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  253. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  254. else: # transform from xywh to xyxy
  255. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  256. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  257. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  258. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  259. # Intersection area
  260. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  261. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  262. # Union Area
  263. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
  264. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
  265. union = (w1 * h1 + 1e-16) + w2 * h2 - inter
  266. iou = inter / union # iou
  267. if GIoU or DIoU or CIoU:
  268. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  269. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  270. if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
  271. c_area = cw * ch + 1e-16 # convex area
  272. return iou - (c_area - union) / c_area # GIoU
  273. if DIoU or CIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  274. # convex diagonal squared
  275. c2 = cw ** 2 + ch ** 2 + 1e-16
  276. # centerpoint distance squared
  277. rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
  278. if DIoU:
  279. return iou - rho2 / c2 # DIoU
  280. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  281. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  282. with torch.no_grad():
  283. alpha = v / (1 - iou + v)
  284. return iou - (rho2 / c2 + v * alpha) # CIoU
  285. return iou
  286. def box_iou(box1, box2):
  287. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  288. """
  289. Return intersection-over-union (Jaccard index) of boxes.
  290. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  291. Arguments:
  292. box1 (Tensor[N, 4])
  293. box2 (Tensor[M, 4])
  294. Returns:
  295. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  296. IoU values for every element in boxes1 and boxes2
  297. """
  298. def box_area(box):
  299. # box = 4xn
  300. return (box[2] - box[0]) * (box[3] - box[1])
  301. area1 = box_area(box1.t())
  302. area2 = box_area(box2.t())
  303. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  304. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  305. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  306. def wh_iou(wh1, wh2):
  307. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  308. wh1 = wh1[:, None] # [N,1,2]
  309. wh2 = wh2[None] # [1,M,2]
  310. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  311. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  312. class FocalLoss(nn.Module):
  313. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  314. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  315. super(FocalLoss, self).__init__()
  316. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  317. self.gamma = gamma
  318. self.alpha = alpha
  319. self.reduction = loss_fcn.reduction
  320. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  321. def forward(self, pred, true):
  322. loss = self.loss_fcn(pred, true)
  323. # p_t = torch.exp(-loss)
  324. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  325. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  326. pred_prob = torch.sigmoid(pred) # prob from logits
  327. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  328. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  329. modulating_factor = (1.0 - p_t) ** self.gamma
  330. loss *= alpha_factor * modulating_factor
  331. if self.reduction == 'mean':
  332. return loss.mean()
  333. elif self.reduction == 'sum':
  334. return loss.sum()
  335. else: # 'none'
  336. return loss
  337. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  338. # return positive, negative label smoothing BCE targets
  339. return 1.0 - 0.5 * eps, 0.5 * eps
  340. class BCEBlurWithLogitsLoss(nn.Module):
  341. # BCEwithLogitLoss() with reduced missing label effects.
  342. def __init__(self, alpha=0.05):
  343. super(BCEBlurWithLogitsLoss, self).__init__()
  344. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  345. self.alpha = alpha
  346. def forward(self, pred, true):
  347. loss = self.loss_fcn(pred, true)
  348. pred = torch.sigmoid(pred) # prob from logits
  349. dx = pred - true # reduce only missing label effects
  350. # dx = (pred - true).abs() # reduce missing label and false label effects
  351. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  352. loss *= alpha_factor
  353. return loss.mean()
  354. def compute_loss(p, targets, model): # predictions, targets, model
  355. device = targets.device
  356. ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
  357. lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device)
  358. tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
  359. h = model.hyp # hyperparameters
  360. red = 'mean' # Loss reduction (sum or mean)
  361. # Define criteria
  362. BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device)
  363. BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device)
  364. # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  365. cp, cn = smooth_BCE(eps=0.0)
  366. # focal loss
  367. g = h['fl_gamma'] # focal loss gamma
  368. if g > 0:
  369. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  370. # per output
  371. nt = 0 # number of targets
  372. np = len(p) # number of outputs
  373. balance = [4.0, 1.0, 0.4] if np == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6
  374. for i, pi in enumerate(p): # layer index, layer predictions
  375. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  376. tobj = torch.zeros_like(pi[..., 0]).to(device) # target obj
  377. nb = b.shape[0] # number of targets
  378. if nb:
  379. nt += nb # cumulative targets
  380. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  381. # GIoU
  382. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  383. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  384. pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
  385. giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target)
  386. lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
  387. # Obj
  388. tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
  389. # Class
  390. if model.nc > 1: # cls loss (only if multiple classes)
  391. t = torch.full_like(ps[:, 5:], cn).to(device) # targets
  392. t[range(nb), tcls[i]] = cp
  393. lcls += BCEcls(ps[:, 5:], t) # BCE
  394. # Append targets to text file
  395. # with open('targets.txt', 'a') as file:
  396. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  397. lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
  398. s = 3 / np # output count scaling
  399. lbox *= h['giou'] * s
  400. lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
  401. lcls *= h['cls'] * s
  402. bs = tobj.shape[0] # batch size
  403. if red == 'sum':
  404. g = 3.0 # loss gain
  405. lobj *= g / bs
  406. if nt:
  407. lcls *= g / nt / model.nc
  408. lbox *= g / nt
  409. loss = lbox + lobj + lcls
  410. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
  411. def build_targets(p, targets, model):
  412. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  413. det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \
  414. else model.model[-1] # Detect() module
  415. na, nt = det.na, targets.shape[0] # number of anchors, targets
  416. tcls, tbox, indices, anch = [], [], [], []
  417. gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
  418. off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
  419. at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
  420. g = 0.5 # offset
  421. style = 'rect4'
  422. for i in range(det.nl):
  423. anchors = det.anchors[i]
  424. gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
  425. # Match targets to anchors
  426. a, t, offsets = [], targets * gain, 0
  427. if nt:
  428. r = t[None, :, 4:6] / anchors[:, None] # wh ratio
  429. j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
  430. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
  431. a, t = at[j], t.repeat(na, 1, 1)[j] # filter
  432. # overlaps
  433. gxy = t[:, 2:4] # grid xy
  434. z = torch.zeros_like(gxy)
  435. if style == 'rect2':
  436. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  437. a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
  438. offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
  439. elif style == 'rect4':
  440. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  441. l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
  442. a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
  443. offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g
  444. # Define
  445. b, c = t[:, :2].long().T # image, class
  446. gxy = t[:, 2:4] # grid xy
  447. gwh = t[:, 4:6] # grid wh
  448. gij = (gxy - offsets).long()
  449. gi, gj = gij.T # grid xy indices
  450. # Append
  451. indices.append((b, a, gj, gi)) # image, anchor, grid indices
  452. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  453. anch.append(anchors[a]) # anchors
  454. tcls.append(c) # class
  455. return tcls, tbox, indices, anch
  456. def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
  457. """Performs Non-Maximum Suppression (NMS) on inference results
  458. Returns:
  459. detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
  460. """
  461. if prediction.dtype is torch.float16:
  462. prediction = prediction.float() # to FP32
  463. nc = prediction[0].shape[1] - 5 # number of classes
  464. xc = prediction[..., 4] > conf_thres # candidates
  465. # Settings
  466. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  467. max_det = 300 # maximum number of detections per image
  468. time_limit = 10.0 # seconds to quit after
  469. redundant = True # require redundant detections
  470. multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
  471. t = time.time()
  472. output = [None] * prediction.shape[0]
  473. for xi, x in enumerate(prediction): # image index, image inference
  474. # Apply constraints
  475. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  476. x = x[xc[xi]] # confidence
  477. # If none remain process next image
  478. if not x.shape[0]:
  479. continue
  480. # Compute conf
  481. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  482. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  483. box = xywh2xyxy(x[:, :4])
  484. # Detections matrix nx6 (xyxy, conf, cls)
  485. if multi_label:
  486. i, j = (x[:, 5:] > conf_thres).nonzero().t()
  487. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  488. else: # best class only
  489. conf, j = x[:, 5:].max(1, keepdim=True)
  490. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  491. # Filter by class
  492. if classes:
  493. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  494. # Apply finite constraint
  495. # if not torch.isfinite(x).all():
  496. # x = x[torch.isfinite(x).all(1)]
  497. # If none remain process next image
  498. n = x.shape[0] # number of boxes
  499. if not n:
  500. continue
  501. # Sort by confidence
  502. # x = x[x[:, 4].argsort(descending=True)]
  503. # Batched NMS
  504. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  505. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  506. i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
  507. if i.shape[0] > max_det: # limit detections
  508. i = i[:max_det]
  509. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  510. try: # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  511. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  512. weights = iou * scores[None] # box weights
  513. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  514. if redundant:
  515. i = i[iou.sum(1) > 1] # require redundancy
  516. except: # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
  517. print(x, i, x.shape, i.shape)
  518. pass
  519. output[xi] = x[i]
  520. if (time.time() - t) > time_limit:
  521. break # time limit exceeded
  522. return output
  523. def strip_optimizer(f='weights/best.pt', s=''): # from utils.utils import *; strip_optimizer()
  524. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  525. x = torch.load(f, map_location=torch.device('cpu'))
  526. x['optimizer'] = None
  527. x['training_results'] = None
  528. x['epoch'] = -1
  529. x['model'].half() # to FP16
  530. for p in x['model'].parameters():
  531. p.requires_grad = False
  532. torch.save(x, s or f)
  533. mb = os.path.getsize(s or f) / 1E6 # filesize
  534. print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
  535. def coco_class_count(path='../coco/labels/train2014/'):
  536. # Histogram of occurrences per class
  537. nc = 80 # number classes
  538. x = np.zeros(nc, dtype='int32')
  539. files = sorted(glob.glob('%s/*.*' % path))
  540. for i, file in enumerate(files):
  541. labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
  542. x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
  543. print(i, len(files))
  544. def coco_only_people(path='../coco/labels/train2017/'): # from utils.utils import *; coco_only_people()
  545. # Find images with only people
  546. files = sorted(glob.glob('%s/*.*' % path))
  547. for i, file in enumerate(files):
  548. labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
  549. if all(labels[:, 0] == 0):
  550. print(labels.shape[0], file)
  551. def crop_images_random(path='../images/', scale=0.50): # from utils.utils import *; crop_images_random()
  552. # crops images into random squares up to scale fraction
  553. # WARNING: overwrites images!
  554. for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
  555. img = cv2.imread(file) # BGR
  556. if img is not None:
  557. h, w = img.shape[:2]
  558. # create random mask
  559. a = 30 # minimum size (pixels)
  560. mask_h = random.randint(a, int(max(a, h * scale))) # mask height
  561. mask_w = mask_h # mask width
  562. # box
  563. xmin = max(0, random.randint(0, w) - mask_w // 2)
  564. ymin = max(0, random.randint(0, h) - mask_h // 2)
  565. xmax = min(w, xmin + mask_w)
  566. ymax = min(h, ymin + mask_h)
  567. # apply random color mask
  568. cv2.imwrite(file, img[ymin:ymax, xmin:xmax])
  569. def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
  570. # Makes single-class coco datasets. from utils.utils import *; coco_single_class_labels()
  571. if os.path.exists('new/'):
  572. shutil.rmtree('new/') # delete output folder
  573. os.makedirs('new/') # make new output folder
  574. os.makedirs('new/labels/')
  575. os.makedirs('new/images/')
  576. for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
  577. with open(file, 'r') as f:
  578. labels = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
  579. i = labels[:, 0] == label_class
  580. if any(i):
  581. img_file = file.replace('labels', 'images').replace('txt', 'jpg')
  582. labels[:, 0] = 0 # reset class to 0
  583. with open('new/images.txt', 'a') as f: # add image to dataset list
  584. f.write(img_file + '\n')
  585. with open('new/labels/' + Path(file).name, 'a') as f: # write label
  586. for l in labels[i]:
  587. f.write('%g %.6f %.6f %.6f %.6f\n' % tuple(l))
  588. shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
  589. def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  590. """ Creates kmeans-evolved anchors from training dataset
  591. Arguments:
  592. path: path to dataset *.yaml, or a loaded dataset
  593. n: number of anchors
  594. img_size: image size used for training
  595. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  596. gen: generations to evolve anchors using genetic algorithm
  597. Return:
  598. k: kmeans evolved anchors
  599. Usage:
  600. from utils.utils import *; _ = kmean_anchors()
  601. """
  602. thr = 1. / thr
  603. def metric(k, wh): # compute metrics
  604. r = wh[:, None] / k[None]
  605. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  606. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  607. return x, x.max(1)[0] # x, best_x
  608. def fitness(k): # mutation fitness
  609. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  610. return (best * (best > thr).float()).mean() # fitness
  611. def print_results(k):
  612. k = k[np.argsort(k.prod(1))] # sort small to large
  613. x, best = metric(k, wh0)
  614. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  615. print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
  616. print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
  617. (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
  618. for i, x in enumerate(k):
  619. print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
  620. return k
  621. if isinstance(path, str): # *.yaml file
  622. with open(path) as f:
  623. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  624. from utils.datasets import LoadImagesAndLabels
  625. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  626. else:
  627. dataset = path # dataset
  628. # Get label wh
  629. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  630. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  631. # Filter
  632. i = (wh0 < 3.0).any(1).sum()
  633. if i:
  634. print('WARNING: Extremely small objects found. '
  635. '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
  636. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  637. # Kmeans calculation
  638. from scipy.cluster.vq import kmeans
  639. print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
  640. s = wh.std(0) # sigmas for whitening
  641. k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
  642. k *= s
  643. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  644. wh0 = torch.tensor(wh0, dtype=torch.float32) # unflitered
  645. k = print_results(k)
  646. # Plot
  647. # k, d = [None] * 20, [None] * 20
  648. # for i in tqdm(range(1, 21)):
  649. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  650. # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
  651. # ax = ax.ravel()
  652. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  653. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  654. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  655. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  656. # fig.tight_layout()
  657. # fig.savefig('wh.png', dpi=200)
  658. # Evolve
  659. npr = np.random
  660. f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  661. pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
  662. for _ in pbar:
  663. v = np.ones(sh)
  664. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  665. v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  666. kg = (k.copy() * v).clip(min=2.0)
  667. fg = fitness(kg)
  668. if fg > f:
  669. f, k = fg, kg.copy()
  670. pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
  671. if verbose:
  672. print_results(k)
  673. return print_results(k)
  674. def print_mutation(hyp, results, bucket=''):
  675. # Print mutation results to evolve.txt (for use with train.py --evolve)
  676. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  677. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  678. c = '%10.4g' * len(results) % results # results (P, R, mAP, F1, test_loss)
  679. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  680. if bucket:
  681. os.system('gsutil cp gs://%s/evolve.txt .' % bucket) # download evolve.txt
  682. with open('evolve.txt', 'a') as f: # append result
  683. f.write(c + b + '\n')
  684. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  685. np.savetxt('evolve.txt', x[np.argsort(-fitness(x))], '%10.3g') # save sort by fitness
  686. if bucket:
  687. os.system('gsutil cp evolve.txt gs://%s' % bucket) # upload evolve.txt
  688. def apply_classifier(x, model, img, im0):
  689. # applies a second stage classifier to yolo outputs
  690. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  691. for i, d in enumerate(x): # per image
  692. if d is not None and len(d):
  693. d = d.clone()
  694. # Reshape and pad cutouts
  695. b = xyxy2xywh(d[:, :4]) # boxes
  696. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  697. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  698. d[:, :4] = xywh2xyxy(b).long()
  699. # Rescale boxes from img_size to im0 size
  700. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  701. # Classes
  702. pred_cls1 = d[:, 5].long()
  703. ims = []
  704. for j, a in enumerate(d): # per item
  705. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  706. im = cv2.resize(cutout, (224, 224)) # BGR
  707. # cv2.imwrite('test%i.jpg' % j, cutout)
  708. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  709. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  710. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  711. ims.append(im)
  712. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  713. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  714. return x
  715. def fitness(x):
  716. # Returns fitness (for use with results.txt or evolve.txt)
  717. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  718. return (x[:, :4] * w).sum(1)
  719. def output_to_target(output, width, height):
  720. # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
  721. if isinstance(output, torch.Tensor):
  722. output = output.cpu().numpy()
  723. targets = []
  724. for i, o in enumerate(output):
  725. if o is not None:
  726. for pred in o:
  727. box = pred[:4]
  728. w = (box[2] - box[0]) / width
  729. h = (box[3] - box[1]) / height
  730. x = box[0] / width + w / 2
  731. y = box[1] / height + h / 2
  732. conf = pred[4]
  733. cls = int(pred[5])
  734. targets.append([i, cls, x, y, w, h, conf])
  735. return np.array(targets)
  736. def increment_dir(dir, comment=''):
  737. # Increments a directory runs/exp1 --> runs/exp2_comment
  738. n = 0 # number
  739. dir = str(Path(dir)) # os-agnostic
  740. d = sorted(glob.glob(dir + '*')) # directories
  741. if len(d):
  742. n = max([int(x[len(dir):x.find('_') if '_' in x else None]) for x in d]) + 1 # increment
  743. return dir + str(n) + ('_' + comment if comment else '')
  744. # Plotting functions ---------------------------------------------------------------------------------------------------
  745. def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  746. # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
  747. def butter_lowpass(cutoff, fs, order):
  748. nyq = 0.5 * fs
  749. normal_cutoff = cutoff / nyq
  750. b, a = butter(order, normal_cutoff, btype='low', analog=False)
  751. return b, a
  752. b, a = butter_lowpass(cutoff, fs, order=order)
  753. return filtfilt(b, a, data) # forward-backward filter
  754. def plot_one_box(x, img, color=None, label=None, line_thickness=None):
  755. # Plots one bounding box on image img
  756. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  757. color = color or [random.randint(0, 255) for _ in range(3)]
  758. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  759. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  760. if label:
  761. tf = max(tl - 1, 1) # font thickness
  762. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  763. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  764. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  765. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  766. def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
  767. # Compares the two methods for width-height anchor multiplication
  768. # https://github.com/ultralytics/yolov3/issues/168
  769. x = np.arange(-4.0, 4.0, .1)
  770. ya = np.exp(x)
  771. yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
  772. fig = plt.figure(figsize=(6, 3), dpi=150)
  773. plt.plot(x, ya, '.-', label='YOLOv3')
  774. plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
  775. plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
  776. plt.xlim(left=-4, right=4)
  777. plt.ylim(bottom=0, top=6)
  778. plt.xlabel('input')
  779. plt.ylabel('output')
  780. plt.grid()
  781. plt.legend()
  782. fig.tight_layout()
  783. fig.savefig('comparison.png', dpi=200)
  784. def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
  785. tl = 3 # line thickness
  786. tf = max(tl - 1, 1) # font thickness
  787. if os.path.isfile(fname): # do not overwrite
  788. return None
  789. if isinstance(images, torch.Tensor):
  790. images = images.cpu().float().numpy()
  791. if isinstance(targets, torch.Tensor):
  792. targets = targets.cpu().numpy()
  793. # un-normalise
  794. if np.max(images[0]) <= 1:
  795. images *= 255
  796. bs, _, h, w = images.shape # batch size, _, height, width
  797. bs = min(bs, max_subplots) # limit plot images
  798. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  799. # Check if we should resize
  800. scale_factor = max_size / max(h, w)
  801. if scale_factor < 1:
  802. h = math.ceil(scale_factor * h)
  803. w = math.ceil(scale_factor * w)
  804. # Empty array for output
  805. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
  806. # Fix class - colour map
  807. prop_cycle = plt.rcParams['axes.prop_cycle']
  808. # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
  809. hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
  810. color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]
  811. for i, img in enumerate(images):
  812. if i == max_subplots: # if last batch has fewer images than we expect
  813. break
  814. block_x = int(w * (i // ns))
  815. block_y = int(h * (i % ns))
  816. img = img.transpose(1, 2, 0)
  817. if scale_factor < 1:
  818. img = cv2.resize(img, (w, h))
  819. mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
  820. if len(targets) > 0:
  821. image_targets = targets[targets[:, 0] == i]
  822. boxes = xywh2xyxy(image_targets[:, 2:6]).T
  823. classes = image_targets[:, 1].astype('int')
  824. gt = image_targets.shape[1] == 6 # ground truth if no conf column
  825. conf = None if gt else image_targets[:, 6] # check for confidence presence (gt vs pred)
  826. boxes[[0, 2]] *= w
  827. boxes[[0, 2]] += block_x
  828. boxes[[1, 3]] *= h
  829. boxes[[1, 3]] += block_y
  830. for j, box in enumerate(boxes.T):
  831. cls = int(classes[j])
  832. color = color_lut[cls % len(color_lut)]
  833. cls = names[cls] if names else cls
  834. if gt or conf[j] > 0.3: # 0.3 conf thresh
  835. label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j])
  836. plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
  837. # Draw image filename labels
  838. if paths is not None:
  839. label = os.path.basename(paths[i])[:40] # trim to 40 char
  840. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  841. cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
  842. lineType=cv2.LINE_AA)
  843. # Image border
  844. cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
  845. if fname is not None:
  846. mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
  847. cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))
  848. return mosaic
  849. def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
  850. # Plot LR simulating training for full epochs
  851. optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
  852. y = []
  853. for _ in range(epochs):
  854. scheduler.step()
  855. y.append(optimizer.param_groups[0]['lr'])
  856. plt.plot(y, '.-', label='LR')
  857. plt.xlabel('epoch')
  858. plt.ylabel('LR')
  859. plt.grid()
  860. plt.xlim(0, epochs)
  861. plt.ylim(0)
  862. plt.tight_layout()
  863. plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
  864. def plot_test_txt(): # from utils.utils import *; plot_test()
  865. # Plot test.txt histograms
  866. x = np.loadtxt('test.txt', dtype=np.float32)
  867. box = xyxy2xywh(x[:, :4])
  868. cx, cy = box[:, 0], box[:, 1]
  869. fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
  870. ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
  871. ax.set_aspect('equal')
  872. plt.savefig('hist2d.png', dpi=300)
  873. fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
  874. ax[0].hist(cx, bins=600)
  875. ax[1].hist(cy, bins=600)
  876. plt.savefig('hist1d.png', dpi=200)
  877. def plot_targets_txt(): # from utils.utils import *; plot_targets_txt()
  878. # Plot targets.txt histograms
  879. x = np.loadtxt('targets.txt', dtype=np.float32).T
  880. s = ['x targets', 'y targets', 'width targets', 'height targets']
  881. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  882. ax = ax.ravel()
  883. for i in range(4):
  884. ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
  885. ax[i].legend()
  886. ax[i].set_title(s[i])
  887. plt.savefig('targets.jpg', dpi=200)
  888. def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_study_txt()
  889. # Plot study.txt generated by test.py
  890. fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
  891. ax = ax.ravel()
  892. fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
  893. for f in ['coco_study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
  894. y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
  895. x = np.arange(y.shape[1]) if x is None else np.array(x)
  896. s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
  897. for i in range(7):
  898. ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
  899. ax[i].set_title(s[i])
  900. j = y[3].argmax() + 1
  901. ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
  902. label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
  903. ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.8, 39.6, 43.0, 47.5, 49.4, 50.7],
  904. 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
  905. ax2.grid()
  906. ax2.set_xlim(0, 30)
  907. ax2.set_ylim(28, 50)
  908. ax2.set_yticks(np.arange(30, 55, 5))
  909. ax2.set_xlabel('GPU Speed (ms/img)')
  910. ax2.set_ylabel('COCO AP val')
  911. ax2.legend(loc='lower right')
  912. plt.savefig('study_mAP_latency.png', dpi=300)
  913. plt.savefig(f.replace('.txt', '.png'), dpi=200)
  914. def plot_labels(labels, save_dir=''):
  915. # plot dataset labels
  916. def hist2d(x, y, n=100):
  917. xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
  918. hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
  919. xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
  920. yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
  921. return np.log(hist[xidx, yidx])
  922. c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
  923. nc = int(c.max() + 1) # number of classes
  924. fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
  925. ax = ax.ravel()
  926. ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  927. ax[0].set_xlabel('classes')
  928. ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
  929. ax[1].set_xlabel('x')
  930. ax[1].set_ylabel('y')
  931. ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
  932. ax[2].set_xlabel('width')
  933. ax[2].set_ylabel('height')
  934. plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
  935. plt.close()
  936. def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp)
  937. # Plot hyperparameter evolution results in evolve.txt
  938. x = np.loadtxt('evolve.txt', ndmin=2)
  939. f = fitness(x)
  940. # weights = (f - f.min()) ** 2 # for weighted results
  941. plt.figure(figsize=(12, 10), tight_layout=True)
  942. matplotlib.rc('font', **{'size': 8})
  943. for i, (k, v) in enumerate(hyp.items()):
  944. y = x[:, i + 7]
  945. # mu = (y * weights).sum() / weights.sum() # best weighted result
  946. mu = y[f.argmax()] # best single result
  947. plt.subplot(4, 5, i + 1)
  948. plt.plot(mu, f.max(), 'o', markersize=10)
  949. plt.plot(y, f, '.')
  950. plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
  951. print('%15s: %.3g' % (k, mu))
  952. plt.savefig('evolve.png', dpi=200)
  953. def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_results_overlay()
  954. # Plot training 'results*.txt', overlaying train and val losses
  955. s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
  956. t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
  957. for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
  958. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  959. n = results.shape[1] # number of rows
  960. x = range(start, min(stop, n) if stop else n)
  961. fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
  962. ax = ax.ravel()
  963. for i in range(5):
  964. for j in [i, i + 5]:
  965. y = results[j, x]
  966. ax[i].plot(x, y, marker='.', label=s[j])
  967. # y_smooth = butter_lowpass_filtfilt(y)
  968. # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
  969. ax[i].set_title(t[i])
  970. ax[i].legend()
  971. ax[i].set_ylabel(f) if i == 0 else None # add filename
  972. fig.savefig(f.replace('.txt', '.png'), dpi=200)
  973. def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
  974. save_dir=''): # from utils.utils import *; plot_results()
  975. # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
  976. fig, ax = plt.subplots(2, 5, figsize=(12, 6))
  977. ax = ax.ravel()
  978. s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
  979. 'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
  980. if bucket:
  981. os.system('rm -rf storage.googleapis.com')
  982. files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
  983. else:
  984. files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt')
  985. for fi, f in enumerate(files):
  986. try:
  987. results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
  988. n = results.shape[1] # number of rows
  989. x = range(start, min(stop, n) if stop else n)
  990. for i in range(10):
  991. y = results[i, x]
  992. if i in [0, 1, 2, 5, 6, 7]:
  993. y[y == 0] = np.nan # dont show zero loss values
  994. # y /= y[0] # normalize
  995. label = labels[fi] if len(labels) else Path(f).stem
  996. ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
  997. ax[i].set_title(s[i])
  998. # if i in [5, 6, 7]: # share train and val loss y axes
  999. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  1000. except:
  1001. print('Warning: Plotting error for %s, skipping file' % f)
  1002. fig.tight_layout()
  1003. ax[1].legend()
  1004. fig.savefig(Path(save_dir) / 'results.png', dpi=200)