车位角点检测代码 (parking-slot corner-point detection code)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

precision_recall.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. """Universal procedure of calculating precision and recall."""
  2. import bisect
  def match_gt_with_preds(ground_truth, predictions, match_labels):
      """Find the most confident prediction that matches one ground truth.

      Args:
          ground_truth: a single ground truth, passed unchanged as the first
              argument of ``match_labels``.
          predictions: indexable predictions; ``pred[0]`` is compared as a
              confidence and ``pred[1]`` is handed to ``match_labels``.
          match_labels: callable ``(ground_truth, pred[1]) -> bool`` deciding
              whether a prediction matches the ground truth.

      Returns:
          Index of the matching prediction with the highest confidence, or -1
          when nothing matches. Only confidences strictly greater than 0 can
          be selected, and on equal confidences the earliest index wins.
      """
      best_idx = -1
      best_conf = 0.
      for idx, pred in enumerate(predictions):
          # The label test runs first, exactly once per prediction.
          if not match_labels(ground_truth, pred[1]):
              continue
          # Strict comparison: ties keep the earlier index, and a confidence
          # of 0 (the starting value) can never be picked.
          if best_conf < pred[0]:
              best_conf = pred[0]
              best_idx = idx
      return best_idx
  def get_confidence_list(ground_truths_list, predictions_list, match_labels):
      """Collect true-positive and false-positive confidences over all samples.

      Args:
          ground_truths_list: per-sample collections of ground truths.
          predictions_list: per-sample prediction sequences (``pred[0]`` is the
              confidence); must have the same length as ``ground_truths_list``,
              which is checked with an ``assert``.
          match_labels: label-matching predicate forwarded to
              ``match_gt_with_preds``.

      Returns:
          ``(true_positive_list, false_positive_list)``. The first holds one
          entry per ground truth: the confidence of its matched prediction,
          or 0.0 when it was not matched. The second holds the confidence of
          every prediction that no ground truth matched.

      NOTE(review): each ground truth picks its match independently, so one
      prediction can be matched by several ground truths and its confidence
      counted as a true positive more than once. Confirm this is intended.
      """
      assert len(ground_truths_list) == len(predictions_list)
      tp_confidences = []
      fp_confidences = []
      for ground_truths, predictions in zip(ground_truths_list,
                                            predictions_list):
          claimed = set()  # indices of predictions matched by a ground truth
          for ground_truth in ground_truths:
              best = match_gt_with_preds(ground_truth, predictions, match_labels)
              if best < 0:
                  # Unmatched ground truth: keep a 0.0 placeholder so it still
                  # counts toward the total number of ground truths.
                  tp_confidences.append(.0)
                  continue
              claimed.add(best)
              tp_confidences.append(predictions[best][0])
          fp_confidences.extend(
              pred[0] for idx, pred in enumerate(predictions)
              if idx not in claimed)
      return tp_confidences, fp_confidences
  def calc_precision_recall(ground_truths_list, predictions_list, match_labels):
      """Sweep the confidence threshold to build a precision/recall curve.

      Every distinct true-positive confidence is used as a threshold, from the
      highest to the lowest; an entry is accepted when its confidence is
      >= the threshold. Recall is the accepted true positives divided by the
      total number of ground truths (the true-positive list has one entry per
      ground truth). Precision is the accepted true positives divided by all
      accepted entries. A threshold of 0.0 is recorded as recall 1.0 with
      precision 0.0 and ends the sweep.

      Args:
          ground_truths_list: per-sample collections of ground truths.
          predictions_list: per-sample prediction sequences.
          match_labels: label-matching predicate.

      Returns:
          ``(precisions, recalls)``: index-aligned lists that both start with a
          0.0 sentinel; recalls are non-decreasing along the lists.
      """
      tp_confidences, fp_confidences = get_confidence_list(
          ground_truths_list, predictions_list, match_labels)
      tp_sorted = sorted(tp_confidences)
      fp_sorted = sorted(fp_confidences)
      num_ground_truths = len(tp_sorted)
      num_unmatched_preds = len(fp_sorted)
      precisions, recalls = [0.], [0.]
      for thresh in sorted(set(tp_sorted), reverse=True):
          if thresh == 0.:
              # 0.0 is the placeholder stored for unmatched ground truths;
              # close the curve at (recall 1, precision 0) and stop.
              recalls.append(1.)
              precisions.append(0.)
              break
          # bisect_left counts entries strictly below thresh; the rest pass.
          hits = num_ground_truths - bisect.bisect_left(tp_sorted, thresh)
          false_alarms = (num_unmatched_preds
                          - bisect.bisect_left(fp_sorted, thresh))
          recalls.append(hits / num_ground_truths)
          precisions.append(hits / (hits + false_alarms))
      return precisions, recalls
  def calc_average_precision(precisions, recalls):
      """Calculate the 11-point interpolated average precision (PASCAL VOC).

      For each recall level r in {0.0, 0.1, ..., 1.0}, the interpolated
      precision is the maximum precision from the first point whose recall
      reaches r to the end of the curve. A level that no point reaches
      contributes 0, as in the VOC definition. The result is the mean over
      the 11 levels.

      Args:
          precisions: precision values, index-aligned with ``recalls``.
          recalls: recall values; expected to be non-decreasing (as produced
              by ``calc_precision_recall``) so that the suffix maximum equals
              the maximum precision over all points with recall >= r.

      Returns:
          The mean of the 11 interpolated precisions; 0.0 when ``recalls`` is
          empty.
      """
      total_precision = 0.
      for i in range(11):
          level = i / 10
          first = next(
              (idx for idx, recall in enumerate(recalls) if recall >= level),
              None)
          if first is None:
              # No point reaches this recall level, so it contributes 0.
              # Every higher level is unreachable as well, so stop here
              # instead of letting next() raise StopIteration.
              break
          total_precision += max(precisions[first:])
      return total_precision / 11