Browse Source

Explicitly compute TP, FP in val.py (#5727)

modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
36d12a500e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 7 deletions
  1. +15
    -6
      utils/metrics.py
  2. +1
    -1
      val.py

+ 15
- 6
utils/metrics.py View File

return (x[:, :4] * w).sum(1) return (x[:, :4] * w).sum(1)




def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
""" Compute the average precision, given the recall and precision curves. """ Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments # Arguments
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]


# Find unique classes # Find unique classes
unique_classes = np.unique(target_cls)
unique_classes, nt = np.unique(target_cls, return_counts=True)
nc = unique_classes.shape[0] # number of classes, number of detections nc = unique_classes.shape[0] # number of classes, number of detections


# Create Precision-Recall curve and compute AP for each class # Create Precision-Recall curve and compute AP for each class
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
for ci, c in enumerate(unique_classes): for ci, c in enumerate(unique_classes):
i = pred_cls == c i = pred_cls == c
n_l = (target_cls == c).sum() # number of labels
n_l = nt[ci] # number of labels
n_p = i.sum() # number of predictions n_p = i.sum() # number of predictions


if n_p == 0 or n_l == 0: if n_p == 0 or n_l == 0:
tpc = tp[i].cumsum(0) tpc = tp[i].cumsum(0)


# Recall # Recall
recall = tpc / (n_l + 1e-16) # recall curve
recall = tpc / (n_l + eps) # recall curve
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases


# Precision # Precision
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5


# Compute F1 (harmonic mean of precision and recall) # Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)
f1 = 2 * p * r / (p + r + eps)
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = {i: v for i, v in enumerate(names)} # to dict names = {i: v for i, v in enumerate(names)} # to dict
if plot: if plot:
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')


i = f1.mean(0).argmax() # max F1 index i = f1.mean(0).argmax() # max F1 index
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
p, r, f1 = p[:, i], r[:, i], f1[:, i]
tp = (r * nt).round() # true positives
fp = (tp / (p + eps) - tp).round() # false positives
return tp, fp, p, r, f1, ap, unique_classes.astype('int32')




def compute_ap(recall, precision): def compute_ap(recall, precision):
def matrix(self): def matrix(self):
return self.matrix return self.matrix


def tp_fp(self):
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

def plot(self, normalize=True, save_dir='', names=()): def plot(self, normalize=True, save_dir='', names=()):
try: try:
import seaborn as sn import seaborn as sn

+ 1
- 1
val.py View File

# Compute metrics # Compute metrics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any(): if len(stats) and stats[0].any():
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class

Loading…
Cancel
Save