# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Model validation metrics
"""

import math
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch


def fitness(x):
    # Model fitness as a weighted combination of metrics
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)
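
# Usage sketch (illustrative only, kept as a comment so it is not executed on import):
# x is an (n, 4+) array whose first four columns are [P, R, mAP@0.5, mAP@0.5:0.95], e.g.
#   fitness(np.array([[0.7, 0.6, 0.5, 0.3]]))  # -> 0.1 * 0.5 + 0.9 * 0.3 = 0.32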


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + eps)  # recall curve
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and j == 0:
                    py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
    names = {i: v for i, v in enumerate(names)}  # to dict
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

    i = f1.mean(0).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype('int32')
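
# Usage sketch for ap_per_class (illustrative only, kept as a comment so it is not executed on import):
# tp is an (n_pred, n_iou) 0/1 array marking which predictions are correct at each IoU threshold,
# conf and pred_cls are (n_pred,), target_cls is (n_labels,), and names should be a dict mapping
# class index to class name (it is filtered with names.items() above). The example names are placeholders.
#   tp_c, fp_c, p, r, f1, ap, classes = ap_per_class(tp, conf, pred_cls, target_cls,
#                                                    plot=False, names={0: 'person', 1: 'car'})
#   map50, map = ap[:, 0].mean(), ap.mean()  # mAP@0.5 and mAP@0.5:0.95 averaged over classes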


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
    return ap, mpre, mrec
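
# Worked example of the precision envelope (illustrative only, kept as a comment):
# with recall = [0.2, 0.5, 0.8] and precision = [0.6, 0.9, 0.4], the padded curves are
#   mrec = [0.0, 0.2, 0.5, 0.8, 1.0] and mpre = [1.0, 0.6, 0.9, 0.4, 0.0];
# the right-to-left running maximum turns mpre into the monotone envelope [1.0, 0.9, 0.9, 0.4, 0.0],
# and ap is the area under that envelope, here sampled at 101 recall points (COCO style).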


class ConfusionMatrix:
    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres

    def process_batch(self, detections, labels):
        """
        Update the confusion matrix with one batch of detections and ground-truth labels.
        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(np.int16)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # background FP

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1  # background FN

    def matrix(self):
        return self.matrix

    def tp_fp(self):
        tp = self.matrix.diagonal()  # true positives
        fp = self.matrix.sum(1) - tp  # false positives
        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
        return tp[:-1], fp[:-1]  # remove background class

    def plot(self, normalize=True, save_dir='', names=()):
        try:
            import seaborn as sn

            array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            nc, nn = self.nc, len(names)  # number of classes, names
            sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
            labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
                sn.heatmap(array,
                           annot=nc < 30,
                           annot_kws={"size": 8},
                           cmap='Blues',
                           fmt='.2f',
                           square=True,
                           vmin=0.0,
                           xticklabels=names + ['background FP'] if labels else "auto",
                           yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
            fig.axes[0].set_xlabel('True')
            fig.axes[0].set_ylabel('Predicted')
            fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
            plt.close()
        except Exception as e:
            print(f'WARNING: ConfusionMatrix plot failure: {e}')

    def print(self):
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
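
# Usage sketch for ConfusionMatrix (illustrative only, kept as a comment so it is not executed on import):
#   cm = ConfusionMatrix(nc=80, conf=0.25, iou_thres=0.45)
#   cm.process_batch(detections, labels)  # detections: (N, 6) x1,y1,x2,y2,conf,cls; labels: (M, 5) cls,x1,y1,x2,y2
#   tp, fp = cm.tp_fp()                   # per-class TP/FP counts with the background row/column removed
#   cm.plot(save_dir='.', names=class_names)  # class_names is a placeholder list of nc class-name strings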


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU
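
# Example call (illustrative only, kept as a comment so it is not executed on import):
#   b1 = torch.tensor([[0., 0., 2., 2.]])                    # one box in xyxy format
#   b2 = torch.tensor([[0., 0., 2., 2.], [1., 1., 3., 3.]])  # two boxes in xyxy format
#   bbox_iou(b1, b2, xywh=False)                 # plain IoU, approximately [[1.0], [0.143]]
#   bbox_iou(b1, b2, xywh=False, CIoU=True)      # CIoU adds center-distance and aspect-ratio penalties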


def box_area(box):
    # box = xyxy(4,n)
    return (box[2] - box[0]) * (box[3] - box[1])


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter)
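
# Example call (illustrative only, kept as a comment so it is not executed on import):
#   box1 = torch.tensor([[0., 0., 2., 2.]])                    # (1, 4)
#   box2 = torch.tensor([[1., 1., 3., 3.], [0., 0., 2., 2.]])  # (2, 4)
#   box_iou(box1, box2)  # (1, 2) pairwise IoU matrix, approximately [[0.143, 1.000]]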


def bbox_ioa(box1, box2, eps=1E-7):
    """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
    box1:       np.array of shape(4)
    box2:       np.array of shape(nx4)
    returns:    np.array of shape(n)
    """

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)
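
# Example call (illustrative only, kept as a comment so it is not executed on import):
#   wh_iou(torch.tensor([[2., 2.]]), torch.tensor([[2., 2.], [1., 1.]]))  # -> [[1.00, 0.25]]
# Only widths and heights matter here: the intersection uses min(w, h) per dimension,
# as if the boxes were aligned to a shared corner.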


# Plots ----------------------------------------------------------------------------------------------------------------


def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
    # Precision-recall curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)
    plt.close()


def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
    # Metric-confidence curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)

    y = py.mean(0)
    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)
    plt.close()