|
- """Universal procedure of calculating precision and recall."""
- import bisect
-
-
def match_gt_with_preds(ground_truth, predictions, match_labels):
    """Find the highest-confidence prediction that matches a ground truth.

    Args:
        ground_truth: A single ground-truth label, passed as the first
            argument to ``match_labels``.
        predictions: Sequence of predictions where ``pred[0]`` is the
            confidence score and ``pred[1]`` is the label compared against
            ``ground_truth``.
        match_labels: Callable ``match_labels(ground_truth, label)`` that
            returns a truthy value when the two labels match.

    Returns:
        Index into ``predictions`` of the matching prediction with the
        highest confidence (the first such one on ties), or -1 when no
        prediction matches.
    """
    best_confidence = None
    matched_idx = -1
    for i, pred in enumerate(predictions):
        if not match_labels(ground_truth, pred[1]):
            continue
        # A None sentinel (instead of 0.) lets a matching prediction whose
        # confidence is 0 or negative still be reported as the match.
        if best_confidence is None or pred[0] > best_confidence:
            best_confidence = pred[0]
            matched_idx = i
    return matched_idx
-
-
def _best_unclaimed_match(ground_truth, predictions, claimed, match_labels):
    """Return the index of the best unclaimed matching prediction, or -1."""
    best_idx = -1
    for idx, pred in enumerate(predictions):
        if claimed[idx] or not match_labels(ground_truth, pred[1]):
            continue
        # Strict '>' keeps the first prediction on confidence ties.
        if best_idx < 0 or pred[0] > predictions[best_idx][0]:
            best_idx = idx
    return best_idx


def get_confidence_list(ground_truths_list, predictions_list, match_labels):
    """Collect confidence scores of true positives and false positives.

    Within each sample, ground truths are matched in order, each to the
    highest-confidence prediction whose label satisfies ``match_labels``
    and that has not already been claimed by an earlier ground truth of
    the same sample. A prediction is therefore credited at most once.

    Args:
        ground_truths_list: Per-sample sequences of ground-truth labels.
        predictions_list: Per-sample sequences of predictions, where
            ``pred[0]`` is the confidence and ``pred[1]`` the label.
        match_labels: Callable ``match_labels(ground_truth, label)`` that
            returns a truthy value when the two labels match.

    Returns:
        A tuple ``(true_positive_list, false_positive_list)``. The first
        list has one entry per ground truth holding the confidence of its
        matched prediction (``0.0`` when unmatched). The second list holds
        the confidences of every prediction left unmatched.

    Raises:
        ValueError: If the two per-sample lists differ in length.
    """
    if len(ground_truths_list) != len(predictions_list):
        raise ValueError(
            'ground_truths_list and predictions_list must have the same '
            'length, got %d and %d'
            % (len(ground_truths_list), len(predictions_list)))
    true_positive_list = []
    false_positive_list = []
    for ground_truths, predictions in zip(ground_truths_list,
                                          predictions_list):
        prediction_matched = [False] * len(predictions)
        for ground_truth in ground_truths:
            idx = _best_unclaimed_match(
                ground_truth, predictions, prediction_matched, match_labels)
            if idx >= 0:
                prediction_matched[idx] = True
                true_positive_list.append(predictions[idx][0])
            else:
                # Unmatched ground truth: recorded with confidence 0.
                true_positive_list.append(.0)
        false_positive_list.extend(
            pred[0]
            for pred, matched in zip(predictions, prediction_matched)
            if not matched)
    return true_positive_list, false_positive_list
-
-
def calc_precision_recall(ground_truths_list, predictions_list, match_labels):
    """Sweep a confidence threshold to produce precision/recall samples.

    Every distinct confidence that ``get_confidence_list`` assigns to a
    ground truth is used as a threshold, visited from highest to lowest.
    Scores at or above the threshold count as positives.

    Args:
        ground_truths_list: Per-sample sequences of ground-truth labels.
        predictions_list: Per-sample sequences of predictions.
        match_labels: Label-matching callable forwarded to
            ``get_confidence_list``.

    Returns:
        A tuple ``(precisions, recalls)`` of parallel lists, both starting
        with a ``0.`` anchor point. Reaching a threshold of ``0.`` (some
        ground truth went unmatched) appends a final point with precision
        ``0.`` and recall ``1.`` and stops the sweep.
    """
    tp_scores, fp_scores = get_confidence_list(
        ground_truths_list, predictions_list, match_labels)
    # Both lists are freshly built by get_confidence_list, so sorting them
    # in place affects nothing outside this function.
    tp_scores.sort()
    fp_scores.sort()
    num_gt = len(tp_scores)
    num_fp_scores = len(fp_scores)

    precisions, recalls = [0.], [0.]
    for thresh in sorted(set(tp_scores))[::-1]:
        if thresh == 0.:
            # Only unmatched ground truths remain below this point.
            precisions.append(0.)
            recalls.append(1.)
            break
        # bisect_left counts the scores strictly below the threshold.
        missed = bisect.bisect_left(tp_scores, thresh)
        hits = num_gt - missed
        false_alarms = num_fp_scores - bisect.bisect_left(fp_scores, thresh)
        precisions.append(hits / (hits + false_alarms))
        recalls.append(hits / (hits + missed))
    return precisions, recalls
-
-
def calc_average_precision(precisions, recalls):
    """Calculate the 11-point interpolated average precision (VOC style).

    For each recall level t in {0.0, 0.1, ..., 1.0}, the interpolated
    precision is the maximum precision over all points whose recall is at
    least t, or 0 when no point reaches t. The result is the mean of
    those 11 interpolated precisions.

    Args:
        precisions: Precision values of the curve points.
        recalls: Recall values, parallel to ``precisions``. They need not
            be sorted.

    Returns:
        The average precision as a float.
    """
    points = list(zip(precisions, recalls))
    total_precision = 0.
    for i in range(11):
        level = i / 10
        # default=0. handles recall levels the curve never reaches. The
        # old next(...) lookup leaked StopIteration in that case.
        total_precision += max(
            (precision for precision, recall in points if recall >= level),
            default=0.)
    return total_precision / 11
|