import glob
import math
import os
import random
import shutil
import subprocess
import time
from copy import copy
from pathlib import Path
from sys import platform

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from . import torch_utils, google_utils  # torch_utils, google_utils

# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
matplotlib.rc('font', **{'size': 11})

# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)

def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch_utils.init_seeds(seed=seed)

def check_git_status():
    # Suggest 'git pull' if repo is out of date
    if platform in ['linux', 'darwin']:
        s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
        if 'Your branch is behind' in s:
            print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    if img_size % s != 0:
        print('WARNING: --img-size %g must be multiple of max stride %g' % (img_size, s))
    return make_divisible(img_size, s)  # nearest gs-multiple

def check_best_possible_recall(dataset, anchors, thr):
    # Check best possible recall of dataset with current anchors
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])).float()  # wh
    ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None]  # ratio
    m = torch.max(ratio, 1. / ratio).max(2)[0]  # max ratio
    bpr = (m.min(1)[0] < thr).float().mean()  # best possible recall
    mr = (m < thr).float().mean()  # match ratio
    print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
    print(('                   ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))

    assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
                      'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr

def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor

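# Usage example (illustrative, not part of the original module):
#   make_divisible(100, 32)    # -> 128, nearest multiple of 32 rounded up
#   check_img_size(100, s=32)  # warns and returns 128, so --img-size stays stride-aligned
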
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class mAPs
    n = len(labels)
    class_counts = np.array([np.bincount(labels[i][:, 0].astype(np.int), minlength=nc) for i in range(n)])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights

def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

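# Usage example (illustrative only): the two conversions are inverses of each other.
#   b = torch.tensor([[10., 20., 50., 80.]])  # x1, y1, x2, y2
#   xyxy2xywh(b)                              # -> [[30., 50., 40., 60.]] (cx, cy, w, h)
#   xywh2xyxy(xyxy2xywh(b))                   # -> [[10., 20., 50., 80.]] again
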
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = max(img1_shape) / max(img0_shape)  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

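# Usage example (illustrative sketch, variable names assumed): map detections from the
# letterboxed inference resolution back to the original frame, as done in detect/test scripts.
#   det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
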
def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2

def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Objectness value from 0-1 (nparray).
        pred_cls: Predicted object classes (nparray).
        target_cls: True object classes (nparray).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # Number of ground truth objects
        n_p = i.sum()  # Number of predicted objects

        if n_p == 0 or n_gt == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_gt + 1e-16)  # recall curve
            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

            # Plot
            # fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            # ax.plot(recall, precision)
            # ax.set_xlabel('Recall')
            # ax.set_ylabel('Precision')
            # ax.set_xlim(0, 1.01)
            # ax.set_ylim(0, 1.01)
            # fig.tight_layout()
            # fig.savefig('PR_curve.png', dpi=300)

    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype('int32')

def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall: The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [min(recall[-1] + 1E-3, 1.)]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap

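# Worked example (toy numbers, illustrative only): perfect precision up to recall 0.5 and
# nothing beyond gives AP of roughly 0.5 under the 101-point interpolation.
#   compute_ap(np.array([0.25, 0.5]), np.array([1.0, 1.0]))  # ~0.5
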
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.t()

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
    union = (w1 * h1 + 1e-16) + w2 * h2 - inter

    iou = inter / union  # iou
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if GIoU:  # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + 1e-16  # convex area
            return iou - (c_area - union) / c_area  # GIoU
        if DIoU or CIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            # convex diagonal squared
            c2 = cw ** 2 + ch ** 2 + 1e-16
            # centerpoint distance squared
            rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (1 - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU

    return iou

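# Usage example (illustrative only): two overlapping 2x2 boxes in xyxy format.
#   b1 = torch.tensor([0., 0., 2., 2.])
#   b2 = torch.tensor([[1., 1., 3., 3.]])
#   bbox_iou(b1, b2)             # ~0.143 (intersection 1 / union 7)
#   bbox_iou(b1, b2, GIoU=True)  # ~-0.079, lower because of the enclosing-box penalty
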
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.t())
    area2 = box_area(box2.t())

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)

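# Usage example (illustrative only): pairwise IoU matrix between N=2 and M=1 xyxy boxes.
#   a = torch.tensor([[0., 0., 2., 2.], [0., 0., 1., 1.]])
#   b = torch.tensor([[1., 1., 3., 3.]])
#   box_iou(a, b)  # tensor([[0.1429], [0.0000]]), shape (N, M)
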
def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)

class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

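# Usage example (illustrative only): wrap an existing BCE criterion in focal loss, as
# compute_loss() below does when hyp['fl_gamma'] > 0.
#   criterion = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
#   loss = criterion(torch.randn(8, 80), torch.zeros(8, 80))  # scalar, 'mean' reduction inherited
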
def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    # return positive, negative label smoothing BCE targets
    return 1.0 - 0.5 * eps, 0.5 * eps

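# Example: smooth_BCE(eps=0.1) returns (0.95, 0.05), i.e. positive targets become 0.95 and
# negative targets 0.05 instead of hard 1/0 labels.
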
class BCEBlurWithLogitsLoss(nn.Module):
    # BCEwithLogitLoss() with reduced missing label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()

def compute_loss(p, targets, model):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
    tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
    h = model.hyp  # hyperparameters
    red = 'mean'  # Loss reduction (sum or mean)

    # Define criteria
    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)

    # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    # focal loss
    g = h['fl_gamma']  # focal loss gamma
    if g > 0:
        BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

    # per output
    nt = 0  # targets
    for i, pi in enumerate(p):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0])  # target obj

        nb = b.shape[0]  # number of targets
        if nb:
            nt += nb  # cumulative targets
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # GIoU
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou(prediction, target)
            lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss

            # Obj
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio

            # Class
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn)  # targets
                t[range(nb), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE

            # Append targets to text file
            # with open('targets.txt', 'a') as file:
            #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

        lobj += BCEobj(pi[..., 4], tobj)  # obj loss

    lbox *= h['giou']
    lobj *= h['obj']
    lcls *= h['cls']
    bs = tobj.shape[0]  # batch size
    if red == 'sum':
        g = 3.0  # loss gain
        lobj *= g / bs
        if nt:
            lcls *= g / nt / model.nc
            lbox *= g / nt

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

def build_targets(p, targets, model):
    # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
    det = model.module.model[-1] if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) \
        else model.model[-1]  # Detect() module
    na, nt = det.na, targets.shape[0]  # number of anchors, targets
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(6, device=targets.device)  # normalized to gridspace gain
    off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float()  # overlap offsets
    at = torch.arange(na).view(na, 1).repeat(1, nt)  # anchor tensor, same as .repeat_interleave(nt)

    style = 'rect4'
    for i in range(det.nl):
        anchors = det.anchors[i]
        gain[2:] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

        # Match targets to anchors
        a, t, offsets = [], targets * gain, 0
        if nt:
            r = t[None, :, 4:6] / anchors[:, None]  # wh ratio
            j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t']  # compare
            # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n) = wh_iou(anchors(3,2), gwh(n,2))
            a, t = at[j], t.repeat(na, 1, 1)[j]  # filter

            # overlaps
            gxy = t[:, 2:4]  # grid xy
            z = torch.zeros_like(gxy)
            if style == 'rect2':
                g = 0.2  # offset
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
                offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g

            elif style == 'rect4':
                g = 0.5  # offset
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
                a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
                offsets = torch.cat((z, z[j] + off[0], z[k] + off[1], z[l] + off[2], z[m] + off[3]), 0) * g

        # Define
        b, c = t[:, :2].long().T  # image, class
        gxy = t[:, 2:4]  # grid xy
        gwh = t[:, 4:6]  # grid wh
        gij = (gxy - offsets).long()
        gi, gj = gij.T  # grid xy indices

        # Append
        indices.append((b, a, gj, gi))  # image, anchor, grid indices
        tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
        anch.append(anchors[a])  # anchors
        tcls.append(c)  # class

    return tcls, tbox, indices, anch

def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
    """
    Performs Non-Maximum Suppression on inference results

    Returns detections with shape:
        nx6 (x1, y1, x2, y2, conf, cls)
    """
    if prediction.dtype is torch.float16:
        prediction = prediction.float()  # to FP32

    nc = prediction[0].shape[1] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    fast |= conf_thres > 0.001  # fast mode
    if fast:
        merge = False
        multi_label = False
    else:
        merge = True  # merge for best mAP (adds 0.5ms/img)
        multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [None] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero().t()
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Sort by confidence
        # x = x[x[:, 4].argsort(descending=True)]

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except Exception:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
                print(x, i, x.shape, i.shape)

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output

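# Usage example (illustrative sketch, variable names assumed): typical inference-time flow.
#   pred = model(img)[0]                             # raw output, (batch, boxes, 5 + nc)
#   pred = non_max_suppression(pred, conf_thres=0.4, iou_thres=0.5)
#   for det in pred:                                 # one (n, 6) tensor or None per image
#       if det is not None and len(det):
#           det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
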
def strip_optimizer(f='weights/best.pt'):  # from utils.utils import *; strip_optimizer()
    # Strip optimizer from *.pt files for lighter files (reduced by 1/2 size)
    x = torch.load(f, map_location=torch.device('cpu'))
    x['optimizer'] = None
    torch.save(x, f)
    print('Optimizer stripped from %s' % f)


def create_backbone(f='weights/best.pt', s='weights/backbone.pt'):  # from utils.utils import *; create_backbone()
    # create backbone 's' from 'f'
    device = torch.device('cpu')
    x = torch.load(f, map_location=device)
    torch.save(x, s)  # update model if SourceChangeWarning
    x = torch.load(s, map_location=device)

    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    for p in x['model'].parameters():
        p.requires_grad = True
    torch.save(x, s)
    print('%s modified for backbone use and saved as %s' % (f, s))

def coco_class_count(path='../coco/labels/train2014/'):
    # Histogram of occurrences per class
    nc = 80  # number classes
    x = np.zeros(nc, dtype='int32')
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        x += np.bincount(labels[:, 0].astype('int32'), minlength=nc)
        print(i, len(files))


def coco_only_people(path='../coco/labels/train2017/'):  # from utils.utils import *; coco_only_people()
    # Find images with only people
    files = sorted(glob.glob('%s/*.*' % path))
    for i, file in enumerate(files):
        labels = np.loadtxt(file, dtype=np.float32).reshape(-1, 5)
        if all(labels[:, 0] == 0):
            print(labels.shape[0], file)


def crop_images_random(path='../images/', scale=0.50):  # from utils.utils import *; crop_images_random()
    # crops images into random squares up to scale fraction
    # WARNING: overwrites images!
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        img = cv2.imread(file)  # BGR
        if img is not None:
            h, w = img.shape[:2]

            # create random mask
            a = 30  # minimum size (pixels)
            mask_h = random.randint(a, int(max(a, h * scale)))  # mask height
            mask_w = mask_h  # mask width

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # crop to the random box and overwrite the original image
            cv2.imwrite(file, img[ymin:ymax, xmin:xmax])


def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
    # Makes single-class coco datasets. from utils.utils import *; coco_single_class_labels()
    if os.path.exists('new/'):
        shutil.rmtree('new/')  # delete output folder
    os.makedirs('new/')  # make new output folder
    os.makedirs('new/labels/')
    os.makedirs('new/images/')
    for file in tqdm(sorted(glob.glob('%s/*.*' % path))):
        with open(file, 'r') as f:
            labels = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
        i = labels[:, 0] == label_class
        if any(i):
            img_file = file.replace('labels', 'images').replace('txt', 'jpg')
            labels[:, 0] = 0  # reset class to 0
            with open('new/images.txt', 'a') as f:  # add image to dataset list
                f.write(img_file + '\n')
            with open('new/labels/' + Path(file).name, 'a') as f:  # write label
                for l in labels[i]:
                    f.write('%g %.6f %.6f %.6f %.6f\n' % tuple(l))
            shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg'))  # copy images

def kmean_anchors(path='./data/coco128.txt', n=9, img_size=(640, 640), thr=0.20, gen=1000):
    # Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
    # n: number of anchors
    # img_size: (min, max) image size used for multi-scale training (can be same values)
    # thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
    # gen: generations to evolve anchors using genetic algorithm
    from utils.datasets import LoadImagesAndLabels

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        iou = wh_iou(wh, torch.Tensor(k))
        max_iou = iou.max(1)[0]
        bpr, aat = (max_iou > thr).float().mean(), (iou > thr).float().mean() * n  # best possible recall, anch > thr

        # thr = 5.0
        # r = wh[:, None] / k[None]
        # ar = torch.max(r, 1. / r).max(2)[0]
        # max_ar = ar.min(1)[0]
        # bpr, aat = (max_ar < thr).float().mean(), (ar < thr).float().mean() * n  # best possible recall, anch > thr

        print('%.2f iou_thr: %.3f best possible recall, %.2f anchors > thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, IoU_all=%.3f/%.3f-mean/best, IoU>thr=%.3f-mean: ' %
              (n, img_size, iou.mean(), max_iou.mean(), iou[iou > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    def fitness(k):  # mutation fitness
        iou = wh_iou(wh, torch.Tensor(k))  # iou
        max_iou = iou.max(1)[0]
        return (max_iou * (max_iou > thr).float()).mean()  # product

    # def fitness_ratio(k):  # mutation fitness
    #     # wh(5316,2), k(9,2)
    #     r = wh[:, None] / k[None]
    #     x = torch.max(r, 1. / r).max(2)[0]
    #     m = x.min(1)[0]
    #     return 1. / (m * (m < 5).float()).mean()  # product

    # Get label wh
    wh = []
    dataset = LoadImagesAndLabels(path, augment=True, rect=True)
    nr = 1 if img_size[0] == img_size[1] else 3  # number augmentation repetitions
    for s, l in zip(dataset.shapes, dataset.labels):
        # wh.append(l[:, 3:5] * (s / s.max()))  # image normalized to letterbox normalized wh
        wh.append(l[:, 3:5] * s)  # image normalized to pixels
    wh = np.concatenate(wh, 0).repeat(nr, axis=0)  # augment 3x
    # wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1))  # normalized to pixels (multi-scale)
    wh = wh[(wh > 2.0).all(1)]  # remove below threshold boxes (< 2 pixels wh)

    # Kmeans calculation
    from scipy.cluster.vq import kmeans
    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.Tensor(wh)
    k = print_results(k)

    # # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    for _ in tqdm(range(gen), desc='Evolving anchors'):
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            print_results(k)
    k = print_results(k)

    return k

def print_mutation(hyp, results, bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    np.savetxt('evolve.txt', x[np.argsort(-fitness(x))], '%10.3g')  # save sort by fitness

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt

def apply_classifier(x, model, img, im0):
    # applies a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def fitness(x):
    # Returns fitness (for use with results.txt or evolve.txt)
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)

def output_to_target(output, width, height):
    """
    Convert a YOLO model output to target format
    [batch_id, class_id, x, y, w, h, conf]
    """
    if isinstance(output, torch.Tensor):
        output = output.cpu().numpy()

    targets = []
    for i, o in enumerate(output):
        if o is not None:
            for pred in o:
                box = pred[:4]
                w = (box[2] - box[0]) / width
                h = (box[3] - box[1]) / height
                x = box[0] / width + w / 2
                y = box[1] / height + h / 2
                conf = pred[4]
                cls = int(pred[5])

                targets.append([i, cls, x, y, w, h, conf])

    return np.array(targets)

# Plotting functions ---------------------------------------------------------------------------------------------------
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        b, a = butter(order, normal_cutoff, btype='low', analog=False)
        return b, a

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter

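# Usage example (illustrative only): smooth a noisy 1-D metric curve before plotting, as
# hinted in the commented line of plot_results_overlay() below.
#   y = np.random.randn(300).cumsum()
#   y_smooth = butter_lowpass_filtfilt(y, cutoff=1500, fs=50000, order=5)
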
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

def plot_wh_methods():  # from utils.utils import *; plot_wh_methods()
    # Compares the two methods for width-height anchor multiplication
    # https://github.com/ultralytics/yolov3/issues/168
    x = np.arange(-4.0, 4.0, .1)
    ya = np.exp(x)
    yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), dpi=150)
    plt.plot(x, ya, '.-', label='yolo method')
    plt.plot(x, yb ** 2, '.-', label='^2 power method')
    plt.plot(x, yb ** 2.5, '.-', label='^2.5 power method')
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)

def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    if os.path.isfile(fname):  # do not overwrite
        return None

    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()

    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise
    if np.max(images[0]) <= 1:
        images *= 255

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    # Empty array for output
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)

    # Fix class - colour map
    prop_cycle = plt.rcParams['axes.prop_cycle']
    # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    hex2rgb = lambda h: tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
    color_lut = [hex2rgb(h) for h in prop_cycle.by_key()['color']]

    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            gt = image_targets.shape[1] == 6  # ground truth if no conf column
            conf = None if gt else image_targets[:, 6]  # check for confidence presence (gt vs pred)

            boxes[[0, 2]] *= w
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] *= h
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = color_lut[cls % len(color_lut)]
                cls = names[cls] if names else cls
                if gt or conf[j] > 0.3:  # 0.3 conf thresh
                    label = '%s' % cls if gt else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths is not None:
            label = os.path.basename(paths[i])[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname is not None:
        mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))

    return mosaic

def plot_lr_scheduler(optimizer, scheduler, epochs=300):
    # Plot LR simulating training for full epochs
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.tight_layout()
    plt.savefig('LR.png', dpi=200)

def plot_test_txt():  # from utils.utils import *; plot_test()
    # Plot test.txt histograms
    x = np.loadtxt('test.txt', dtype=np.float32)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)

def plot_targets_txt():  # from utils.utils import *; plot_targets_txt()
    # Plot targets.txt histograms
    x = np.loadtxt('targets.txt', dtype=np.float32).T
    s = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std()))
        ax[i].legend()
        ax[i].set_title(s[i])
    plt.savefig('targets.jpg', dpi=200)

def plot_study_txt(f='study.txt', x=None):  # from utils.utils import *; plot_study_txt()
    # Plot study.txt generated by test.py
    fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
    ax = ax.ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    for f in ['coco_study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
        for i in range(7):
            ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
            ax[i].set_title(s[i])

        j = y[3].argmax() + 1
        ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
    ax2.set_xlim(0, 30)
    ax2.set_ylim(25, 50)
    ax2.set_xlabel('GPU Latency (ms)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    ax2.grid()
    plt.savefig('study_mAP_latency.png', dpi=300)
    plt.savefig(f.replace('.txt', '.png'), dpi=200)

def plot_labels(labels):
    # plot dataset labels
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes

    def hist2d(x, y, n=100):
        xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
        hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
        xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
        yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
        return np.log(hist[xidx, yidx])

    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    ax[0].hist(c, bins=int(c.max() + 1))
    ax[0].set_xlabel('classes')
    ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')
    ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
    ax[2].set_xlabel('width')
    ax[2].set_ylabel('height')
    plt.savefig('labels.png', dpi=200)

def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
    # Plot hyperparameter evolution results in evolve.txt
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # weights = (f - f.min()) ** 2  # for weighted results
    plt.figure(figsize=(12, 10), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, (k, v) in enumerate(hyp.items()):
        y = x[:, i + 7]
        # mu = (y * weights).sum() / weights.sum()  # best weighted result
        mu = y[f.argmax()]  # best single result
        plt.subplot(4, 5, i + 1)
        plt.plot(mu, f.max(), 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        print('%15s: %.3g' % (k, mu))
    plt.savefig('evolve.png', dpi=200)

def plot_results_overlay(start=0, stop=0):  # from utils.utils import *; plot_results_overlay()
    # Plot training 'results*.txt', overlaying train and val losses
    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
    t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
        ax = ax.ravel()
        for i in range(5):
            for j in [i, i + 5]:
                y = results[j, x]
                ax[i].plot(x, y, marker='.', label=s[j])
                # y_smooth = butter_lowpass_filtfilt(y)
                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])

            ax[i].set_title(t[i])
            ax[i].legend()
            ax[i].set_ylabel(f) if i == 0 else None  # add filename
        fig.savefig(f.replace('.txt', '.png'), dpi=200)

def plot_results(start=0, stop=0, bucket='', id=(), labels=()):  # from utils.utils import *; plot_results()
    # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
    fig, ax = plt.subplots(2, 5, figsize=(12, 6))
    ax = ax.ravel()
    s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
         'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
    if bucket:
        os.system('rm -rf storage.googleapis.com')
        files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
    else:
        files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
            n = results.shape[1]  # number of rows
            x = range(start, min(stop, n) if stop else n)
            for i in range(10):
                y = results[i, x]
                if i in [0, 1, 2, 5, 6, 7]:
                    y[y == 0] = np.nan  # don't show zero loss values
                    # y /= y[0]  # normalize
                label = labels[fi] if len(labels) else Path(f).stem
                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
                ax[i].set_title(s[i])
                # if i in [5, 6, 7]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception:
            print('Warning: Plotting error for %s, skipping file' % f)

    fig.tight_layout()
    ax[1].legend()
    fig.savefig('results.png', dpi=200)