Просмотр исходного кода

Precision-Recall Curve Feature Addition (#1107)

* initial commit

* Update general.py

Indent update

* Update general.py

refactor duplicate code

* 200 dpi
5.0
Glenn Jocher GitHub 4 лет назад
Родитель
Сommit
5fac5ad165
Не найден GPG ключ соответствующий данной подписи Идентификатор GPG ключа: 4AEE18F83AFDEB23
3 измененных файлов: 33 добавлений и 31 удалений
  1. +10
    -11
      test.py
  2. +2
    -4
      train.py
  3. +21
    -16
      utils/general.py

+ 10
- 11
test.py Просмотреть файл

@@ -30,9 +30,9 @@ def test(data,
verbose=False,
model=None,
dataloader=None,
save_dir='',
merge=False,
save_txt=False):
save_dir=Path(''), # for saving images
save_txt=False, # for auto-labelling
plots=True):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
@@ -41,7 +41,7 @@ def test(data,
else: # called directly
set_logging()
device = select_device(opt.device, batch_size=batch_size)
merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels
save_txt = opt.save_txt # save *.txt labels
if save_txt:
out = Path('inference/output')
if os.path.exists(out):
@@ -49,7 +49,7 @@ def test(data,
os.makedirs(out) # make new output folder

# Remove previous
for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
for f in glob.glob(str(save_dir / 'test_batch*.jpg')):
os.remove(f)

# Load model
@@ -110,7 +110,7 @@ def test(data,

# Run NMS
t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
t1 += time_synchronized() - t

# Statistics per image
@@ -186,16 +186,16 @@ def test(data,
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

# Plot images
if batch_i < 1:
f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename
if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
plot_images(img, targets, paths, str(f), names) # ground truth
f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
p, r, ap, f1, ap_class = ap_per_class(*stats)
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95]
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
@@ -261,7 +261,6 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
opt = parser.parse_args()

+ 2
- 4
train.py Просмотреть файл

@@ -1,5 +1,4 @@
import argparse
import glob
import logging
import math
import os
@@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None):
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
if final_epoch: # replot predictions
[os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)]
results, maps, times = test.test(opt.data,
batch_size=total_batch_size,
imgsz=imgsz_test,
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=log_dir)
save_dir=log_dir,
plots=epoch == 0 or final_epoch) # plot first and last

# Write
with open(results_file, 'a') as f:

+ 21
- 16
utils/general.py Просмотреть файл

@@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape):
boxes[:, 3].clamp_(0, img_shape[0]) # y2


def ap_per_class(tp, conf, pred_cls, target_cls):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
fname: Plot filename
# Returns
The average precision as computed in py-faster-rcnn.
"""
@@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
unique_classes = np.unique(target_cls)

# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
@@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score

# AP from recall-precision curve
py.append(np.interp(px, recall[:, 0], precision[:, 0])) # precision at mAP@0.5
for j in range(tp.shape[1]):
ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

# Plot
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# ax.plot(recall, precision)
# ax.set_xlabel('Recall')
# ax.set_ylabel('Precision')
# ax.set_xlim(0, 1.01)
# ax.set_ylim(0, 1.01)
# fig.tight_layout()
# fig.savefig('PR_curve.png', dpi=300)

# Compute F1 score (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)

if plot:
py = np.stack(py, axis=1)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend()
fig.tight_layout()
fig.savefig(fname, dpi=200)

return p, r, ap, f1, unique_classes.astype('int32')


@@ -1011,8 +1018,6 @@ def plot_wh_methods(): # from utils.general import *; plot_wh_methods()
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
tl = 3 # line thickness
tf = max(tl - 1, 1) # font thickness
if os.path.isfile(fname): # do not overwrite
return None

if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()

Загрузка…
Отмена
Сохранить