* initial commit * Update general.py Indent update * Update general.py refactor duplicate code * 200 dpi5.0
@@ -30,9 +30,9 @@ def test(data, | |||
verbose=False, | |||
model=None, | |||
dataloader=None, | |||
save_dir='', | |||
merge=False, | |||
save_txt=False): | |||
save_dir=Path(''), # for saving images | |||
save_txt=False, # for auto-labelling | |||
plots=True): | |||
# Initialize/load model and set device | |||
training = model is not None | |||
if training: # called by train.py | |||
@@ -41,7 +41,7 @@ def test(data, | |||
else: # called directly | |||
set_logging() | |||
device = select_device(opt.device, batch_size=batch_size) | |||
merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels | |||
save_txt = opt.save_txt # save *.txt labels | |||
if save_txt: | |||
out = Path('inference/output') | |||
if os.path.exists(out): | |||
@@ -49,7 +49,7 @@ def test(data, | |||
os.makedirs(out) # make new output folder | |||
# Remove previous | |||
for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')): | |||
for f in glob.glob(str(save_dir / 'test_batch*.jpg')): | |||
os.remove(f) | |||
# Load model | |||
@@ -110,7 +110,7 @@ def test(data, | |||
# Run NMS | |||
t = time_synchronized() | |||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge) | |||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres) | |||
t1 += time_synchronized() - t | |||
# Statistics per image | |||
@@ -186,16 +186,16 @@ def test(data, | |||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) | |||
# Plot images | |||
if batch_i < 1: | |||
f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename | |||
if plots and batch_i < 1: | |||
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename | |||
plot_images(img, targets, paths, str(f), names) # ground truth | |||
f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i) | |||
f = save_dir / ('test_batch%g_pred.jpg' % batch_i) | |||
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions | |||
# Compute statistics | |||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | |||
if len(stats) and stats[0].any(): | |||
p, r, ap, f1, ap_class = ap_per_class(*stats) | |||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png') | |||
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95] | |||
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() | |||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class | |||
@@ -261,7 +261,6 @@ if __name__ == '__main__': | |||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | |||
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') | |||
parser.add_argument('--augment', action='store_true', help='augmented inference') | |||
parser.add_argument('--merge', action='store_true', help='use Merge NMS') | |||
parser.add_argument('--verbose', action='store_true', help='report mAP by class') | |||
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') | |||
opt = parser.parse_args() |
@@ -1,5 +1,4 @@ | |||
import argparse | |||
import glob | |||
import logging | |||
import math | |||
import os | |||
@@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None): | |||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride']) | |||
final_epoch = epoch + 1 == epochs | |||
if not opt.notest or final_epoch: # Calculate mAP | |||
if final_epoch: # replot predictions | |||
[os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)] | |||
results, maps, times = test.test(opt.data, | |||
batch_size=total_batch_size, | |||
imgsz=imgsz_test, | |||
model=ema.ema, | |||
single_cls=opt.single_cls, | |||
dataloader=testloader, | |||
save_dir=log_dir) | |||
save_dir=log_dir, | |||
plots=epoch == 0 or final_epoch) # plot first and last | |||
# Write | |||
with open(results_file, 'a') as f: |
@@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape): | |||
boxes[:, 3].clamp_(0, img_shape[0]) # y2 | |||
def ap_per_class(tp, conf, pred_cls, target_cls): | |||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'): | |||
""" Compute the average precision, given the recall and precision curves. | |||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | |||
# Arguments | |||
tp: True positives (nparray, nx1 or nx10). | |||
tp: True positives (nparray, nx1 or nx10). | |||
conf: Objectness value from 0-1 (nparray). | |||
pred_cls: Predicted object classes (nparray). | |||
target_cls: True object classes (nparray). | |||
pred_cls: Predicted object classes (nparray). | |||
target_cls: True object classes (nparray). | |||
plot: Plot precision-recall curve at mAP@0.5 | |||
fname: Plot filename | |||
# Returns | |||
The average precision as computed in py-faster-rcnn. | |||
""" | |||
@@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls): | |||
unique_classes = np.unique(target_cls) | |||
# Create Precision-Recall curve and compute AP for each class | |||
px, py = np.linspace(0, 1, 1000), [] # for plotting | |||
pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898 | |||
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95) | |||
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s) | |||
@@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls): | |||
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score | |||
# AP from recall-precision curve | |||
py.append(np.interp(px, recall[:, 0], precision[:, 0])) # precision at mAP@0.5 | |||
for j in range(tp.shape[1]): | |||
ap[ci, j] = compute_ap(recall[:, j], precision[:, j]) | |||
# Plot | |||
# fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | |||
# ax.plot(recall, precision) | |||
# ax.set_xlabel('Recall') | |||
# ax.set_ylabel('Precision') | |||
# ax.set_xlim(0, 1.01) | |||
# ax.set_ylim(0, 1.01) | |||
# fig.tight_layout() | |||
# fig.savefig('PR_curve.png', dpi=300) | |||
# Compute F1 score (harmonic mean of precision and recall) | |||
f1 = 2 * p * r / (p + r + 1e-16) | |||
if plot: | |||
py = np.stack(py, axis=1) | |||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | |||
ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision) | |||
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes') | |||
ax.set_xlabel('Recall') | |||
ax.set_ylabel('Precision') | |||
ax.set_xlim(0, 1) | |||
ax.set_ylim(0, 1) | |||
plt.legend() | |||
fig.tight_layout() | |||
fig.savefig(fname, dpi=200) | |||
return p, r, ap, f1, unique_classes.astype('int32') | |||
@@ -1011,8 +1018,6 @@ def plot_wh_methods(): # from utils.general import *; plot_wh_methods() | |||
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): | |||
tl = 3 # line thickness | |||
tf = max(tl - 1, 1) # font thickness | |||
if os.path.isfile(fname): # do not overwrite | |||
return None | |||
if isinstance(images, torch.Tensor): | |||
images = images.cpu().float().numpy() |