64 lines
2.6 KiB
Python
64 lines
2.6 KiB
Python
"""Universal procedure of calculating precision and recall."""
|
|
|
|
|
|
def match_gt_with_preds(ground_truth, predictions, match_labels):
    """Find the prediction that best matches a single ground truth.

    Scans every prediction, keeping the highest-confidence one whose
    label matches the ground truth according to ``match_labels``.

    Args:
        ground_truth: one ground-truth item.
        predictions: iterable of (confidence, label) pairs.
        match_labels: callable(ground_truth, label) -> bool deciding a match.

    Returns:
        Index of the matching prediction with the largest confidence,
        or -1 when nothing matches (or every match has confidence <= 0).
    """
    best_index = -1
    best_confidence = 0.
    for index, prediction in enumerate(predictions):
        confidence = prediction[0]
        if match_labels(ground_truth, prediction[1]) and best_confidence < confidence:
            best_confidence = confidence
            best_index = index
    return best_index
|
|
|
|
|
|
def get_confidence_list(ground_truths_list, predictions_list, match_labels):
    """Generate confidence lists of true positives and false positives.

    For each sample, every ground truth is matched against the sample's
    predictions. A matched ground truth contributes the matched
    prediction's confidence to the true-positive list; an unmatched
    ground truth contributes 0. Predictions never claimed by any ground
    truth contribute their confidence to the false-positive list.

    Fix: a prediction already claimed by an earlier ground truth is no
    longer counted as a true positive again for a later ground truth;
    the later ground truth is recorded as a miss (confidence 0.),
    keeping the matching one-to-one as in the VOC detection protocol.

    Args:
        ground_truths_list: per-sample lists of ground-truth items.
        predictions_list: per-sample lists of (confidence, label) pairs;
            must be the same length as ``ground_truths_list``.
        match_labels: callable(ground_truth, label) -> bool deciding a match.

    Returns:
        Tuple (true_positive_list, false_positive_list) of confidences.
    """
    assert len(ground_truths_list) == len(predictions_list)
    true_positive_list = []
    false_positive_list = []
    for ground_truths, predictions in zip(ground_truths_list, predictions_list):
        prediction_matched = [False] * len(predictions)
        for ground_truth in ground_truths:
            idx = match_gt_with_preds(ground_truth, predictions, match_labels)
            # Only an unclaimed prediction may become a true positive;
            # duplicate matches count as misses for the later ground truth.
            if idx >= 0 and not prediction_matched[idx]:
                prediction_matched[idx] = True
                true_positive_list.append(predictions[idx][0])
            else:
                true_positive_list.append(.0)
        for idx, matched in enumerate(prediction_matched):
            if not matched:
                false_positive_list.append(predictions[idx][0])
    return true_positive_list, false_positive_list
|
|
|
|
|
|
def calc_precision_recall(ground_truths_list, predictions_list, match_labels):
    """Sweep the confidence threshold to sample multiple precision/recall points.

    Uses every distinct true-positive confidence as a threshold, swept
    from highest to lowest, and records one (precision, recall) sample
    per threshold. Both returned lists start with a leading 0. sample.

    Args:
        ground_truths_list: per-sample lists of ground-truth items.
        predictions_list: per-sample lists of (confidence, label) pairs.
        match_labels: callable(ground_truth, label) -> bool deciding a match.

    Returns:
        Tuple (precisions, recalls) of equal-length sample lists.
    """
    tp_confidences, fp_confidences = get_confidence_list(
        ground_truths_list, predictions_list, match_labels)

    precisions = [0.]
    recalls = [0.]
    num_ground_truths = len(tp_confidences)

    for threshold in sorted(set(tp_confidences), reverse=True):
        if threshold == 0.:
            # Threshold 0. means some ground truth went unmatched; anchor
            # the curve at (recall=1, precision=0) before the computed point.
            recalls.append(1.)
            precisions.append(0.)
        num_tp = sum(conf >= threshold for conf in tp_confidences)
        num_fp = sum(conf >= threshold for conf in fp_confidences)
        # recall denominator: TP + FN == total number of ground truths.
        recalls.append(num_tp / num_ground_truths)
        precisions.append(num_tp / (num_tp + num_fp))
    return precisions, recalls
|
|
|
|
|
|
def calc_average_precision(precisions, recalls):
    """Calculate the 11-point interpolated average precision (VOC style).

    For each recall threshold t in {0.0, 0.1, ..., 1.0}, take the maximum
    precision over all samples whose recall is >= t, then average the 11
    values.

    Fix: when no sample reaches a given recall threshold, that point now
    contributes 0 to the average (the original ``next(...)`` had no
    default and raised StopIteration in that case, e.g. for an empty
    curve).

    Args:
        precisions: list of precision samples, aligned with ``recalls``.
        recalls: list of recall samples (same length as ``precisions``).

    Returns:
        The average precision as a float.
    """
    total_precision = 0.
    for step in range(11):
        threshold = step / 10
        # First sample whose recall reaches the threshold, or None when
        # the curve never gets that far.
        index = next(
            (i for i, recall in enumerate(recalls) if recall >= threshold),
            None)
        if index is not None:
            total_precision += max(precisions[index:])
    return total_precision / 11
|