落水人员检测 (Man Overboard Detection)

import numpy as np


class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))  # rows: ground truth, columns: prediction

    def Recall_Precision(self):
        TP = np.diag(self.confusion_matrix)
        TP_add_FN = self.confusion_matrix.sum(axis=1)  # row sums: all ground-truth samples of each class
        TP_add_FP = self.confusion_matrix.sum(axis=0)  # column sums: all samples predicted as each class
        Recall = TP / TP_add_FN  # share of actual positives the model recovers; higher recall means more positives are correctly identified
        Precision = TP / TP_add_FP  # share of predicted positives that are truly positive
        F1 = 2 * (Precision * Recall) / (Precision + Recall)
        # Optionally macro-average over classes:
        # recall = np.nanmean(Recall)
        # precision = np.nanmean(Precision)
        # f1 = np.nanmean(F1)
        return Recall, Precision, F1

    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        Acc = np.nanmean(Acc)
        return Acc

    def Mean_Intersection_over_Union(self):
        class_IoU = np.diag(self.confusion_matrix) / (
                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                np.diag(self.confusion_matrix))
        MIoU = np.nanmean(class_IoU)
        return class_IoU, MIoU

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                np.diag(self.confusion_matrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))
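
# Usage sketch (added for illustration, not part of the original file):
# builds a toy 2-class confusion matrix from dummy label maps and reads
# off the segmentation metrics defined above.
if __name__ == '__main__':
    evaluator = Evaluator(num_class=2)
    gt = np.array([[0, 0, 1], [1, 1, 0]])
    pred = np.array([[0, 1, 1], [1, 0, 0]])
    evaluator.add_batch(gt, pred)
    class_iou, miou = evaluator.Mean_Intersection_over_Union()
    recall, precision, f1 = evaluator.Recall_Precision()
    print('per-class IoU:', class_iou, 'MIoU:', miou)
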
# Model validation metrics
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from . import general


def fitness(x):
    # Model fitness as a weighted combination of metrics
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)
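
# Sketch (added for illustration, not in the original file): with the weights
# above, fitness reduces to 0.1 * mAP@0.5 + 0.9 * mAP@0.5:0.95 per result row.
# >>> fitness(np.array([[0.8, 0.7, 0.65, 0.45]]))  # [P, R, mAP@0.5, mAP@0.5:0.95]
# array([0.47])
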
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Objectness value from 0-1 (nparray).
        pred_cls: Predicted object classes (nparray).
        target_cls: True object classes (nparray).
        plot: Plot precision-recall curve at mAP@0.5
        save_dir: Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)
    nc = unique_classes.shape[0]  # number of classes

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and j == 0:
                    py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

    i = f1.mean(0).argmax()  # max F1 index
    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
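
# Sketch (added for illustration): four single-class detections scored against
# three ground-truth objects; tp is assumed to be pre-matched at one IoU
# threshold (nx1), so all values here are dummies.
# >>> tp = np.array([[1.0], [1.0], [0.0], [1.0]])
# >>> conf = np.array([0.9, 0.8, 0.7, 0.6])
# >>> p, r, ap, f1, cls = ap_per_class(tp, conf, np.zeros(4), np.zeros(3))
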
def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall: The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
    mpre = np.concatenate(([1.], precision, [0.]))

    # Compute the precision envelope (monotonically decreasing from right to left)
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec
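
# Sketch (added for illustration): AP of a toy two-point PR curve; the sentinel
# values and the precision envelope make the integral well defined at both ends.
# >>> ap, mpre, mrec = compute_ap(np.array([0.5, 1.0]), np.array([1.0, 0.5]))
# With the 101-point COCO interpolation, ap comes out near 0.875.
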
class ConfusionMatrix:
    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        self.matrix = np.zeros((nc + 1, nc + 1))  # extra row/column for background
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres

    def process_batch(self, detections, labels):
        """
        Update the confusion matrix with one batch of detections and labels,
        matched by intersection-over-union (Jaccard index).
        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
        Arguments:
            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
            labels (Array[M, 5]), class, x1, y1, x2, y2
        Returns:
            None, updates confusion matrix accordingly
        """
        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = general.box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Keep only the highest-IoU match per detection, then per label
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(np.int16)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # background FP

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1  # background FN

    def matrix(self):
        return self.matrix

    def plot(self, save_dir='', names=()):
        try:
            import seaborn as sn

            array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6)  # normalize columns
            array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

            fig = plt.figure(figsize=(12, 9), tight_layout=True)
            sn.set(font_scale=1.0 if self.nc < 50 else 0.8)  # for label size
            labels = (0 < len(names) < 99) and len(names) == self.nc  # apply names to ticklabels
            sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
                       xticklabels=names + ['background FP'] if labels else "auto",
                       yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
            fig.axes[0].set_xlabel('True')
            fig.axes[0].set_ylabel('Predicted')
            fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
        except Exception:
            pass

    def print(self):
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
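
# Sketch (added for illustration): one detection matched to one label; assumes
# box_iou from this repo's utils.general module, as imported above.
# >>> cm = ConfusionMatrix(nc=1)
# >>> dets = torch.tensor([[10., 10., 50., 50., 0.9, 0.]])  # x1, y1, x2, y2, conf, cls
# >>> gts = torch.tensor([[0., 12., 12., 48., 48.]])        # cls, x1, y1, x2, y2
# >>> cm.process_batch(dets, gts)                           # IoU ≈ 0.81 > 0.45 → correct
# >>> cm.print()
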
# Plots ----------------------------------------------------------------------------------------------------------------

def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
    # Precision-recall curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)


def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
    # Metric-confidence curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)

    y = py.mean(0)
    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir), dpi=250)