選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

score.py 6.1KB

2年前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. """Evaluation Metrics for Semantic Segmentation"""
  2. import torch
  3. import numpy as np
  # Names exported by ``from <this module> import *``.
  __all__ = [
      'SegmentationMetric',
      'batch_pix_accuracy',
      'batch_intersection_union',
      'pixelAccuracy',
      'intersectionAndUnion',
      'hist_info',
      'compute_score',
  ]
  class SegmentationMetric(object):
      """Accumulate pixel accuracy (pixAcc) and mean IoU (mIoU) across batches.

      Feed batches through :meth:`update`, read the running scores with
      :meth:`get`, and start over with :meth:`reset`.
      """

      def __init__(self, nclass):
          """Create an empty accumulator for ``nclass`` classes."""
          super(SegmentationMetric, self).__init__()
          self.nclass = nclass
          self.reset()

      def update(self, preds, labels):
          """Add one batch (or a list/tuple of batches) to the running totals.

          Parameters
          ----------
          preds : torch.Tensor or numpy.ndarray, or list/tuple of them
              Predicted class scores; ``batch_pix_accuracy`` and
              ``batch_intersection_union`` take the argmax over dim 1.
          labels : torch.Tensor or numpy.ndarray, or list/tuple of them
              Ground-truth class ids matching ``preds``; those helpers treat
              negative ids as unlabeled.

          Raises
          ------
          TypeError
              If ``preds`` is not a tensor, ndarray, list or tuple (such input
              used to be ignored silently, leaving the totals unchanged).
          ValueError
              If ``preds`` and ``labels`` are sequences of different lengths
              (``zip`` used to drop the surplus items silently).
          """
          def evaluate_worker(self, pred, label):
              # The helpers are torch-only: wrap numpy input as tensors and put
              # the labels on the prediction's device so comparisons work.
              pred = torch.as_tensor(pred)
              label = torch.as_tensor(label, device=pred.device)
              correct, labeled = batch_pix_accuracy(pred, label)
              inter, union = batch_intersection_union(pred, label, self.nclass)
              self.total_correct += correct
              self.total_label += labeled
              # Keep the accumulators on the same device as the per-batch areas.
              if self.total_inter.device != inter.device:
                  self.total_inter = self.total_inter.to(inter.device)
                  self.total_union = self.total_union.to(union.device)
              self.total_inter += inter
              self.total_union += union

          if isinstance(preds, (torch.Tensor, np.ndarray)):
              evaluate_worker(self, preds, labels)
          elif isinstance(preds, (list, tuple)):
              if len(preds) != len(labels):
                  raise ValueError(
                      "preds and labels must have the same length, got %d and %d"
                      % (len(preds), len(labels)))
              for pred, label in zip(preds, labels):
                  evaluate_worker(self, pred, label)
          else:
              raise TypeError(
                  "preds must be a torch.Tensor, numpy.ndarray, list or tuple, got %s"
                  % type(preds).__name__)

      def get(self):
          """Return the scores accumulated so far.

          Returns
          -------
          metrics : tuple of float
              ``(pixAcc, mIoU)``. A tiny epsilon keeps both divisions finite, so
              both are 0.0 before any update. mIoU averages over all ``nclass``
              classes; a class with zero union contributes an IoU of 0.
          """
          eps = 2.220446049250313e-16  # np.spacing(1), written out literally
          pixAcc = 1.0 * self.total_correct / (eps + self.total_label)
          IoU = 1.0 * self.total_inter / (eps + self.total_union)
          mIoU = IoU.mean().item()
          return pixAcc, mIoU

      def reset(self):
          """Reset the internal evaluation result to its initial (empty) state."""
          self.total_inter = torch.zeros(self.nclass)
          self.total_union = torch.zeros(self.nclass)
          self.total_correct = 0
          self.total_label = 0
  # PyTorch versions of the metrics.
  def batch_pix_accuracy(output, target):
      """Count correctly classified and labeled pixels in a batch.

      Parameters
      ----------
      output : torch.Tensor
          Class scores; the predicted class is the argmax over dim 1
          (presumably ``(N, C, H, W)`` logits -- TODO confirm with callers).
      target : torch.Tensor
          Ground-truth class ids (presumably ``(N, H, W)``). Pixels with a
          negative id (e.g. -1) are treated as unlabeled and ignored.

      Returns
      -------
      tuple of int
          ``(pixel_correct, pixel_labeled)``.
      """
      # Take the argmax on the raw scores: casting to long first truncates
      # fractional scores (0.3 and 0.7 both become 0) and picks the wrong class.
      predict = torch.argmax(output, 1) + 1
      # Shift ids by one so every negative label maps to <= 0 and can be masked.
      target = target.long() + 1
      labeled = target > 0
      pixel_labeled = torch.sum(labeled).item()
      pixel_correct = torch.sum((predict == target) & labeled).item()
      assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
      return pixel_correct, pixel_labeled
  def batch_intersection_union(output, target, nclass):
      """Per-class intersection and union pixel areas for a batch.

      Parameters
      ----------
      output : torch.Tensor
          Class scores; the predicted class is the argmax over dim 1
          (presumably ``(N, C, H, W)`` logits -- TODO confirm with callers).
      target : torch.Tensor
          Ground-truth class ids; negative ids are treated as unlabeled.
      nclass : int
          Number of classes; ids outside ``[0, nclass)`` are not counted.

      Returns
      -------
      (area_inter, area_union) : tuple of torch.Tensor
          float32 CPU tensors of length ``nclass``.
      """
      # Shift ids to 1..nclass so that 0 can mean "unlabeled / no match".
      predict = torch.argmax(output, 1) + 1
      target = target.long() + 1
      # Zero out predictions on unlabeled pixels so they are not penalized.
      predict = predict * (target > 0).long()
      # Matching pixels keep their (shifted) class id; mismatches become 0.
      intersection = predict * (predict == target).long()

      def _class_areas(ids):
          # Exact integer count of each id in 1..nclass; 0 and out-of-range ids
          # are dropped. Unlike torch.histc(min=1, max=nclass), this stays
          # correct for nclass == 1 (where histc's range degenerates and also
          # counted the 0 ids).
          flat = ids.reshape(-1).cpu()
          flat = flat[(flat >= 1) & (flat <= nclass)]
          return torch.bincount(flat - 1, minlength=nclass).float()

      area_inter = _class_areas(intersection)
      area_pred = _class_areas(predict)
      area_lab = _class_areas(target)
      area_union = area_pred + area_lab - area_inter
      assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area"
      return area_inter, area_union
  def pixelAccuracy(imPred, imLab):
      """Pixel accuracy of one predicted label map against its ground truth.

      Pixels whose ground-truth label is negative count as unlabeled and are
      excluded, so predictions in unlabeled regions are not penalized.

      To aggregate over many images::

          for i in range(Nimages):
              (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = pixelAccuracy(imPred[i], imLab[i])
          mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled))

      Returns
      -------
      tuple
          ``(pixel_accuracy, pixel_correct, pixel_labeled)``. With numpy inputs
          and no labeled pixel, ``pixel_accuracy`` is NaN (0 / 0).
      """
      valid = imLab >= 0
      pixel_labeled = np.sum(valid)
      pixel_correct = np.sum((imPred == imLab) & valid)
      return (pixel_correct * 1.0 / pixel_labeled, pixel_correct, pixel_labeled)
  def intersectionAndUnion(imPred, imLab, numClass):
      """Per-class intersection and union areas for a single image.

      Class ids are histogrammed over ``numClass`` bins spanning
      ``[1, numClass]``, so ids outside that range (including 0) are not
      counted.

      To aggregate over many images::

          for i in range(Nimages):
              (area_intersection[:, i], area_union[:, i]) = intersectionAndUnion(imPred[i], imLab[i], numClass)
          IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1) + area_union, axis=1)

      Returns
      -------
      tuple of numpy.ndarray
          ``(area_intersection, area_union)``, each of length ``numClass``.
      """
      hist_range = (1, numClass)

      def _areas(ids):
          counts, _edges = np.histogram(ids, bins=numClass, range=hist_range)
          return counts

      # Predictions on unlabeled (negative-label) pixels are zeroed out so they
      # fall outside the histogram range and are not penalized.
      masked_pred = imPred * (imLab >= 0)
      # Matching pixels keep their class id; mismatches become 0.
      hits = masked_pred * (masked_pred == imLab)

      area_intersection = _areas(hits)
      area_union = _areas(masked_pred) + _areas(imLab) - area_intersection
      return (area_intersection, area_union)
  def hist_info(pred, label, num_cls):
      """Build a confusion matrix plus labeled / correct pixel counts.

      Parameters
      ----------
      pred : numpy.ndarray
          Predicted class ids, same shape as ``label``. Assumed to lie in
          ``[0, num_cls)`` -- out-of-range ids would land in the wrong cell or
          break the reshape (TODO confirm callers guarantee this).
      label : numpy.ndarray
          Ground-truth class ids; only entries in ``[0, num_cls)`` are counted
          (anything else, e.g. an ignore value, is skipped).
      num_cls : int
          Number of classes.

      Returns
      -------
      hist : numpy.ndarray, shape (num_cls, num_cls)
          ``hist[i, j]`` = number of counted pixels with label ``i`` predicted as ``j``.
      labeled : numpy integer
          Number of counted pixels.
      correct : numpy integer
          Number of counted pixels where ``pred == label``.
      """
      assert pred.shape == label.shape
      valid = (label >= 0) & (label < num_cls)
      labeled = np.sum(valid)
      correct = np.sum(pred[valid] == label[valid])
      # Cast both operands: np.bincount rejects float indices, so float-valued
      # predictions used to raise when only the labels were cast.
      index = num_cls * label[valid].astype(int) + pred[valid].astype(int)
      hist = np.bincount(index, minlength=num_cls ** 2).reshape(num_cls, num_cls)
      return hist, labeled, correct
  def compute_score(hist, correct, labeled):
      """Derive IoU-based scores and pixel accuracy from a confusion matrix.

      Parameters
      ----------
      hist : numpy.ndarray, shape (C, C)
          Square confusion matrix (e.g. as built by ``hist_info``).
      correct : number
          Count of correctly classified pixels.
      labeled : number
          Count of labeled pixels.

      Returns
      -------
      iu : numpy.ndarray
          Per-class IoU; NaN for classes whose union is 0.
      mean_IU : float
          ``nanmean`` of ``iu``.
      mean_IU_no_back : float
          ``nanmean`` of ``iu[1:]`` (class 0 excluded, presumably background).
      mean_pixel_acc : float
          ``correct / labeled``.
      """
      inter = np.diag(hist)
      union = hist.sum(1) + hist.sum(0) - inter
      # 0 / 0 for absent classes is expected: it yields NaN, which nanmean skips.
      with np.errstate(divide='ignore', invalid='ignore'):
          iu = inter / union
      mean_IU = np.nanmean(iu)
      mean_IU_no_back = np.nanmean(iu[1:])
      mean_pixel_acc = correct / labeled
      return iu, mean_IU, mean_IU_no_back, mean_pixel_acc