import glob
import math
import os
import random
import shutil
import subprocess
import time
from copy import copy
from pathlib import Path
from sys import platform

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import yaml
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from . import torch_utils  # torch_utils, google_utils

# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
matplotlib.rc('font', **{'size': 11})

# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)


def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch_utils.init_seeds(seed=seed)


def check_git_status():
    # Suggest 'git pull' if repo is out of date
    if platform in ['linux', 'darwin']:
        s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
        if 'Your branch is behind' in s:
            print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    print('\nAnalyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        return (best > 1. / thr).float().mean()  # best possible recall

    bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
    print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
    if bpr < 0.99:  # threshold to recompute
        print('. Attempting to generate improved anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(new_anchors.reshape(-1, 2))
        if new_bpr > bpr:  # replace anchors
            new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print('Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline


def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # anchor and stride order differ, flip anchors
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_file(file):
    # Searches for file if not found locally
    if os.path.isfile(file):
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        return files[0]  # return first file if multiple found


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor
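
# Example (illustrative): make_divisible(100, 32) -> 128, i.e. 100 rounded up to the next multiple of 32.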


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class mAPs
    n = len(labels)
    class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x
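
# Example (illustrative): coco80_to_coco91_class()[0] == 1, mapping the first 80-class index (person) to its
# 91-class paper index; ids unused in the paper indexing (12, 26, 29, 30, ...) are simply skipped.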


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
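
# Example (illustrative): the two conversions are inverses of each other, e.g. for a single box
#   xyxy2xywh(np.array([[0., 0., 10., 20.]]))  -> [[5., 10., 10., 20.]]  (center x, center y, w, h)
#   xywh2xyxy(np.array([[5., 10., 10., 20.]])) -> [[0., 0., 10., 20.]]   (x1, y1, x2, y2)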


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = max(img1_shape) / max(img0_shape)  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:         True positives (nparray, nx1 or nx10).
        conf:       Objectness value from 0-1 (nparray).
        pred_cls:   Predicted object classes (nparray).
        target_cls: True object classes (nparray).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # Number of ground truth objects
        n_p = i.sum()  # Number of predicted objects

        if n_p == 0 or n_gt == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_gt + 1e-16)  # recall curve
            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

            # Plot
            # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            # ax.plot(recall, precision)
            # ax.set_xlabel('Recall')
            # ax.set_ylabel('Precision')
            # ax.set_xlim(0, 1.01)
            # ax.set_ylim(0, 1.01)
            # fig.tight_layout()
            # fig.savefig('PR_curve.png', dpi=300)

    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype('int32')


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall:    The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.t()

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
    union = (w1 * h1 + 1e-16) + w2 * h2 - inter

    iou = inter / union  # iou
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if GIoU:  # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + 1e-16  # convex area
            return iou - (c_area - union) / c_area  # GIoU
        if DIoU or CIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            # convex diagonal squared
            c2 = cw ** 2 + ch ** 2 + 1e-16
            # centerpoint distance squared
            rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (1 - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU

    return iou


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.t())
    area2 = box_area(box2.t())

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
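
# Example (illustrative): wh_iou treats boxes as if they shared a corner, so
#   wh_iou(torch.tensor([[10., 10.]]), torch.tensor([[10., 10.], [20., 20.]]))
# returns [[1.0000, 0.2500]]: identical wh gives IoU 1, and 10x10 vs 20x20 gives 100 / 400 = 0.25.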


class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    # return positive, negative label smoothing BCE targets
    return 1.0 - 0.5 * eps, 0.5 * eps
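
# Example (illustrative): smooth_BCE(eps=0.1) returns (0.95, 0.05), the positive/negative targets used in place
# of 1.0/0.0 when the classification targets are built in compute_loss().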


class BCEBlurWithLogitsLoss(nn.Module):
    # BCEWithLogitsLoss() with reduced missing-label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()


def compute_loss(p, targets, model):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters
    red = 'mean'  # Loss reduction (sum or mean)

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)

    # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # per output
    nt = 0  # number of targets
    np = len(p)  # number of outputs
    balance = [1.0, 1.0, 1.0]
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0])  # target obj

        nb = b.shape[0]  # number of targets
        if nb:
            nt += nb  # cumulative targets
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # GIoU
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou(prediction, target)
            lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss

            # Obj
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio

            # Class
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn)  # targets
                t[range(nb), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss

    s = 3 / np  # output count scaling
    lbox *= h['giou'] * s
    lobj *= h['obj'] * s
    lcls *= h['cls'] * s
    bs = tobj.shape[0]  # batch size
    if red == 'sum':
        g = 3.0  # loss gain
        lobj *= g / bs
        if nt:
            lcls *= g / nt / model.nc
            lbox *= g / nt

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()


def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \
        else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain
    off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float()  # overlap offsets
    at = torch.arange(na).view(na, 1).repeat(1, nt)  # anchor tensor, same as .repeat_interleave(nt)

    style = 'rect4'
    for i in range(det.nl):
        anchors = det.anchors[i]
        gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        a, t, offsets = [], targets * gain, 0
        if nt:
            r = t[None, :, 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
            a, t = at[j], t.repeat(na, 1, 1)[j]  # filter

            # overlaps
            g = 0.5  # offset
            gxy = t[:, 2:4]  # grid xy
            z = torch.zeros_like(gxy)
            if style == 'rect2':
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
                offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
            elif style == 'rect4':
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
                a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
                offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # grid xy indices

        # Append
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
        detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """
    if prediction.dtype is torch.float16:
        prediction = prediction.float()  # to FP32

    nc = prediction[0].shape[1] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [None] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero().t()
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Sort by confidence
        # x = x[x[:, 4].argsort(descending=True)]

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
                print(x, i, x.shape, i.shape)
                pass

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output
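
# Usage sketch (illustrative only; assumes a detection `model` in eval mode and a preprocessed image tensor
# `img` of shape (1, 3, H, W), which are not defined in this file):
#   pred = model(img)[0]                                               # raw predictions, shape (1, n, 5 + nc)
#   det = non_max_suppression(pred, conf_thres=0.4, iou_thres=0.5)[0]  # nx6 tensor (xyxy, conf, cls) or None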


def strip_optimizer(f='weights/best.pt'):  # from utils.utils import *; strip_optimizer()
    # Strip optimizer from *.pt files for lighter files (reduced by 1/2 size)
    x = torch.load(f, map_location=torch.device('cpu'))
    x['optimizer'] = None
    x['model'].half()  # to FP16
    torch.save(x, f)
    print('Optimizer stripped from %s' % f)


def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from utils.utils import *; create_pretrained()
    # create pretrained checkpoint 's' from 'f' (create_pretrained(x, x) for x in glob.glob('./*.pt'))
    device = torch.device('cpu')
    x = torch.load(f, map_location=device)

    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = True
    torch.save(x, s)
    print('%s saved as pretrained checkpoint %s' % (f, s))


def coco_class_count(path='../coco/labels/train2014/'):
    # Histogram of occurrences per class
    nc = 80  # number classes
    x = np.zeros(nc, dtype='int32')
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
        print(i, len(files))


def coco_only_people(path='../coco/labels/train2017/'):  # from utils.utils import *; coco_only_people()
    # Find images with only people
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        if all(labels[:, 0] == 0):
            print(labels.shape[0], file)


def crop_images_random(path='../images/', scale=0.50):  # from utils.utils import *; crop_images_random()
    # crops images into random squares up to scale fraction
    # WARNING: overwrites images!
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        img = cv2.imread(file)  # BGR
        if img is not None:
            h, w = img.shape[:2]

            # create random mask
            a = 30  # minimum size (pixels)
            mask_h = random.randint(a, int(max(a, h * scale)))  # mask height
            mask_w = mask_h  # mask width

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # crop and overwrite the image with the selected square
            cv2.imwrite(file, img[ymin:ymax, xmin:xmax])


def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
    # Makes single-class coco datasets. from utils.utils import *; coco_single_class_labels()
    if os.path.exists('new/'):
        shutil.rmtree('new/')  # delete output folder
    os.makedirs('new/')  # make new output folder
    os.makedirs('new/labels/')
    os.makedirs('new/images/')
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        with open(file, 'r') as f:
            labels = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
        i = labels[:, 0] == label_class
        if any(i):
            img_file = file.replace('labels', 'images').replace('txt', 'jpg')
            labels[:, 0] = 0  # reset class to 0
            with open('new/images.txt', 'a') as f:  # add image to dataset list
                f.write(img_file + '\n')
            with open('new/labels/' + Path(file).name, 'a') as f:  # write label
                for l in labels[i]:
                    f.write('%g %.6f %.6f %.6f %.6f\n' % tuple(l))
            shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg'))  # copy images


def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

    Arguments:
        path: path to dataset *.yaml, or a loaded dataset
        n: number of anchors
        img_size: image size used for training
        thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
        gen: generations to evolve anchors using genetic algorithm

    Return:
        k: kmeans evolved anchors

    Usage:
        from utils.utils import *; _ = kmean_anchors()
    """
    thr = 1. / thr

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
              (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print('WARNING: Extremely small objects found. '
              '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels

    # Kmeans calculation
    from scipy.cluster.vq import kmeans
    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
            if verbose:
                print_results(k)

    return print_results(k)


def print_mutation(hyp, results, bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    np.savetxt('evolve.txt', x[np.argsort(-fitness(x))], '%10.3g')  # save sort by fitness

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt


def apply_classifier(x, model, img, im0):
    # applies a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def fitness(x):
    # Returns fitness (for use with results.txt or evolve.txt)
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)
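
# Example (illustrative): with w = [0.0, 0.0, 0.1, 0.9], a results row with mAP@0.5 = 0.5 and mAP@0.5:0.95 = 0.4
# scores 0.1 * 0.5 + 0.9 * 0.4 = 0.41; P and R do not contribute to fitness.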


def output_to_target(output, width, height):
    """
    Convert a YOLO model output to target format
    [batch_id, class_id, x, y, w, h, conf]
    """
    if isinstance(output, torch.Tensor):
        output = output.cpu().numpy()

    targets = []
    for i, o in enumerate(output):
        if o is not None:
            for pred in o:
                box = pred[:4]
                w = (box[2] - box[0]) / width
                h = (box[3] - box[1]) / height
                x = box[0] / width + w / 2
                y = box[1] / height + h / 2
                conf = pred[4]
                cls = int(pred[5])

                targets.append([i, cls, x, y, w, h, conf])

    return np.array(targets)


# Plotting functions ---------------------------------------------------------------------------------------------------
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        return b, a

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter
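
# Example (illustrative): smooth a noisy 1-D curve such as a loss history, e.g.
#   y_smooth = butter_lowpass_filtfilt(y)  # same length as y, keeps only the low-frequency trend
# as in the commented-out smoothing lines of plot_results_overlay() below.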


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def plot_wh_methods():  # from utils.utils import *; plot_wh_methods()
    # Compares the two methods for width-height anchor multiplication
    # https://github.com/ultralytics/yolov3/issues/168
    x = np.arange(-4.0, 4.0, .1)
    ya = np.exp(x)
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), dpi=150)
    plt.plot(x, ya, '.-', label='yolo method')
    plt.plot(x, yb ** 2, '.-', label='^2 power method')
    plt.plot(x, yb ** 2.5, '.-', label='^2.5 power method')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)


def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    if os.path.isfile(fname):  # do not overwrite
        return None

    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()

    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise
    if np.max(images[0]) <= 1:
        images *= 255

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    # Empty array for output
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)

    # Fix class - colour map
    prop_cycle = plt.rcParams['axes.prop_cycle']
    # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
    color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]

    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            gt = image_targets.shape[1] == 6  # ground truth if no conf column
            conf = None if gt else image_targets[:, 6]  # check for confidence presence (gt vs pred)

            boxes[[0, 2]] *= w
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] *= h
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = color_lut[cls % len(color_lut)]
                cls = names[cls] if names else cls
                if gt or conf[j] > 0.3:  # 0.3 conf thresh
                    label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths is not None:
            label = os.path.basename(paths[i])[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname is not None:
        mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))

    return mosaic


def plot_lr_scheduler(optimizer, scheduler, epochs=300):
    # Plot LR simulating training for full epochs
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.tight_layout()
    plt.savefig('LR.png', dpi=200)


def plot_test_txt():  # from utils.utils import *; plot_test()
    # Plot test.txt histograms
    x = np.loadtxt('test.txt', dtype=np.float32)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)


def plot_targets_txt():  # from utils.utils import *; plot_targets_txt()
    # Plot targets.txt histograms
    x = np.loadtxt('targets.txt', dtype=np.float32).T
    s = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
        ax[i].legend()
        ax[i].set_title(s[i])
    plt.savefig('targets.jpg', dpi=200)


def plot_study_txt(f='study.txt', x=None):  # from utils.utils import *; plot_study_txt()
    # Plot study.txt generated by test.py
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
    ax = ax.ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    for f in ['coco_study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
        for i in range(7):
            ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
            ax[i].set_title(s[i])

        j = y[3].argmax() + 1
        ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
    ax2.grid()
    ax2.set_xlim(0, 30)
    ax2.set_ylim(28, 50)
    ax2.set_yticks(np.arange(30, 55, 5))
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    plt.savefig('study_mAP_latency.png', dpi=300)
    plt.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_labels(labels):
    # plot dataset labels
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes

    def hist2d(x, y, n=100):
        xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
        hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
        xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
        yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
        return np.log(hist[xidx, yidx])

    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    ax[0].hist(c, bins=int(c.max() + 1))
    ax[0].set_xlabel('classes')
    ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')
    ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
    ax[2].set_xlabel('width')
    ax[2].set_ylabel('height')
    plt.savefig('labels.png', dpi=200)
    plt.close()


def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
    # Plot hyperparameter evolution results in evolve.txt
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # weights = (f - f.min()) ** 2  # for weighted results
    plt.figure(figsize=(12, 10), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, (k, v) in enumerate(hyp.items()):
        y = x[:, i + 7]
        # mu = (y * weights).sum() / weights.sum()  # best weighted result
        mu = y[f.argmax()]  # best single result
        plt.subplot(4, 5, i + 1)
        plt.plot(mu, f.max(), 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        print('%15s: %.3g' % (k, mu))
    plt.savefig('evolve.png', dpi=200)


def plot_results_overlay(start=0, stop=0):  # from utils.utils import *; plot_results_overlay()
    # Plot training 'results*.txt', overlaying train and val losses
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in [i, i + 5]:
                y = results[j, x]
                ax[i].plot(x, y, marker='.', label=s[j])
                # y_smooth = butter_lowpass_filtfilt(y)
                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])

            ax[i].set_title(t[i])
            ax[i].legend()
            ax[i].set_ylabel(f) if i == 0 else None  # add filename
        fig.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_results(start=0, stop=0, bucket='', id=(), labels=()):  # from utils.utils import *; plot_results()
    # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
    fig, ax = plt.subplots(2, 5, figsize=(12, 6))
    ax = ax.ravel()
    s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        os.system('rm -rf storage.googleapis.com')
        files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
    else:
        files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # dont show zero loss values
                    # y /= y[0]  # normalize
                label = labels[fi] if len(labels) else Path(f).stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
                ax[i].set_title(s[i])
                # if i in [5, 6, 7]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except:
            print('Warning: Plotting error for %s, skipping file' % f)

    fig.tight_layout()
    ax[1].legend()
    fig.savefig('results.png', dpi=200)