Squashed commits:

* New CSV Logger
* cleanup
* move batch plots into Logger
* rename comment
* Remove total loss from progress bar
* mloss :-1 bug fix
* Update plot_results()
* Update plot_results()
* plot_results bug fix
* modifyDataloader
--- a/.gitignore
+++ b/.gitignore
@@ -31,6 +31,7 @@ data/*
 !data/*.sh
 results*.txt
+results*.csv

 # Datasets -------------------------------------------------------------------------------------------------------------
 coco/
--- a/train.py
+++ b/train.py
@@ -12,7 +12,6 @@ import sys
 import time
 from copy import deepcopy
 from pathlib import Path
-from threading import Thread

 import math
 import numpy as np
@@ -38,7 +37,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
     check_requirements, print_mutation, set_logging, one_cycle, colorstr
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
-from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
+from utils.plots import plot_labels, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
@@ -61,7 +60,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Directories
     w = save_dir / 'weights'  # weights dir
     w.mkdir(parents=True, exist_ok=True)  # make dir
-    last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
+    last, best = w / 'last.pt', w / 'best.pt'

     # Hyperparameters
     if isinstance(hyp, str):
@@ -88,7 +87,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Loggers
     if RANK in [-1, 0]:
-        loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+        loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
         if loggers.wandb and resume:
             weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
@@ -167,10 +166,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
             ema.updates = ckpt['updates']

-        # Results
-        if ckpt.get('training_results') is not None:
-            results_file.write_text(ckpt['training_results'])  # write results.txt
-
         # Epochs
         start_epoch = ckpt['epoch'] + 1
         if resume:
@@ -275,11 +270,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         #     b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
         #     dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

-        mloss = torch.zeros(4, device=device)  # mean losses
+        mloss = torch.zeros(3, device=device)  # mean losses
         if RANK != -1:
             train_loader.sampler.set_epoch(epoch)
         pbar = enumerate(train_loader)
-        LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
+        LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
         if RANK in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
@@ -327,20 +322,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     ema.update(model)
                 last_opt_step = ni

             # Log
             if RANK in [-1, 0]:
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
-                s = ('%10s' * 2 + '%10.4g' * 6) % (
-                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
-                pbar.set_description(s)
-
-                # Plot
-                if plots:
-                    if ni < 3:
-                        f = save_dir / f'train_batch{ni}.jpg'  # filename
-                        Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                    loggers.on_train_batch_end(ni, model, imgs)
+                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
+                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+                loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)

             # end batch ------------------------------------------------------------------------------------------------
@@ -371,13 +359,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
+            loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
-                        'training_results': results_file.read_text(),
                         'model': deepcopy(de_parallel(model)).half(),
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
@@ -395,9 +382,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in [-1, 0]:
         LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
-        if plots:
-            plot_results(save_dir=save_dir)  # save as results.png
-
         if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
@@ -411,13 +395,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                             save_dir=save_dir,
                                             save_json=True,
                                             plots=False)

         # Strip optimizers
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-
-        loggers.on_train_end(last, best)
+        loggers.on_train_end(last, best, plots)

     torch.cuda.empty_cache()
     return results
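
Note the slimmer progress bar that falls out of dropping the `total` column: the header now has 7 fields and the description line formats 5 numbers. A self-contained sketch of just that formatting, with invented stats standing in for live training values:

```python
# Minimal sketch of the slimmed-down progress line; all values are toy stand-ins.
import torch

mloss = torch.tensor([0.0456, 0.0671, 0.0123])  # running mean (box, obj, cls) losses
epoch, epochs, mem, nt, imgsz = 0, 300, '3.51G', 42, 640  # toy epoch/memory/label/size stats

print(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
print(('%10s' * 2 + '%10.4g' * 5) % (f'{epoch}/{epochs - 1}', mem, *mloss, nt, imgsz))
```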
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -1,15 +1,17 @@
 # YOLOv5 experiment logging utils

 import warnings
+from threading import Thread

 import torch
 from torch.utils.tensorboard import SummaryWriter

 from utils.general import colorstr, emojis
 from utils.loggers.wandb.wandb_utils import WandbLogger
+from utils.plots import plot_images, plot_results
 from utils.torch_utils import de_parallel

-LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases
+LOGGERS = ('csv', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

 try:
     import wandb
@@ -21,10 +23,8 @@ except (ImportError, AssertionError):

 class Loggers():
     # YOLOv5 Loggers class
-    def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
-                 data_dict=None, logger=None, include=LOGGERS):
+    def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
         self.save_dir = save_dir
-        self.results_file = results_file
         self.weights = weights
         self.opt = opt
         self.hyp = hyp
@@ -35,7 +35,7 @@ class Loggers(): | |||
setattr(self, k, None) # init empty logger dictionary | |||
def start(self): | |||
self.txt = True # always log to txt | |||
self.csv = True # always log to csv | |||
# Message | |||
try: | |||
@@ -63,15 +63,19 @@ class Loggers():

         return self

-    def on_train_batch_end(self, ni, model, imgs):
+    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
         # Callback runs on train batch end
-        if ni == 0:
-            with warnings.catch_warnings():
-                warnings.simplefilter('ignore')  # suppress jit trace warning
-                self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-        if self.wandb and ni == 10:
-            files = sorted(self.save_dir.glob('train*.jpg'))
-            self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
+        if plots:
+            if ni == 0:
+                with warnings.catch_warnings():
+                    warnings.simplefilter('ignore')  # suppress jit trace warning
+                    self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+            if ni < 3:
+                f = self.save_dir / f'train_batch{ni}.jpg'  # filename
+                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+            if self.wandb and ni == 10:
+                files = sorted(self.save_dir.glob('train*.jpg'))
+                self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

     def on_train_epoch_end(self, epoch):
         # Callback runs on train epoch end
@@ -89,21 +93,28 @@ class Loggers():
             files = sorted(self.save_dir.glob('val*.jpg'))
             self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

-    def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
-        # Callback runs on validation end during training
-        vals = list(mloss[:-1]) + list(results) + lr
-        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
+    def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
+        # Callback runs on val end during training
+        vals = list(mloss) + list(results) + lr
+        keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
                 'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                 'x/lr0', 'x/lr1', 'x/lr2']  # params
-        if self.txt:
-            with open(self.results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
+        x = {k: v for k, v in zip(keys, vals)}  # dict
+
+        if self.csv:
+            file = self.save_dir / 'results.csv'
+            n = len(x) + 1  # number of cols
+            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # add header
+            with open(file, 'a') as f:
+                f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
+
         if self.tb:
-            for x, tag in zip(vals, tags):
-                self.tb.add_scalar(tag, x, epoch)  # TensorBoard
+            for k, v in x.items():
+                self.tb.add_scalar(k, v, epoch)  # TensorBoard
+
         if self.wandb:
-            self.wandb.log({k: v for k, v in zip(tags, vals)})
+            self.wandb.log(x)
             self.wandb.end_epoch(best_result=best_fitness == fi)

     def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
@@ -112,8 +123,10 @@ class Loggers():
         if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
             self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

-    def on_train_end(self, last, best):
+    def on_train_end(self, last, best, plots):
         # Callback runs on training end
+        if plots:
+            plot_results(dir=self.save_dir)  # save results.png
         files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
         files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
         if self.wandb:
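
The new CSV writer is append-only: the padded header row is emitted once, when results.csv does not yet exist, and each subsequent epoch appends one row. A standalone sketch of that logic, with the `keys` list copied from the diff above and invented metric values:

```python
# Standalone sketch of the results.csv append logic from on_train_val_end();
# the numeric values below are made up for illustration only.
from pathlib import Path

keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',
        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
        'val/box_loss', 'val/obj_loss', 'val/cls_loss',
        'x/lr0', 'x/lr1', 'x/lr2']
vals = [0.046, 0.067, 0.012, 0.71, 0.58, 0.64, 0.41, 0.051, 0.072, 0.014, 0.0032, 0.0032, 0.07]
epoch = 0

file = Path('results.csv')
n = len(keys) + 1  # number of cols, counting the leading 'epoch'
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header once
with open(file, 'a') as f:
    f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
```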
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -162,8 +162,7 @@ class ComputeLoss:
         lcls *= self.hyp['cls']
         bs = tobj.shape[0]  # batch size

-        loss = lbox + lobj + lcls
-        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
+        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

     def build_targets(self, p, targets):
         # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
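
With the total no longer concatenated, `compute_loss` returns the batch-scaled sum plus a detached 3-element component vector. A toy check of the new return contract (component values invented):

```python
# Toy check of the new ComputeLoss return shape: the total is summed inline
# and loss_items carries only (box, obj, cls) -- no appended total column.
import torch

lbox, lobj, lcls = torch.tensor([0.05]), torch.tensor([0.07]), torch.tensor([0.01])  # toy components
bs = 16  # batch size

loss, loss_items = (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
assert loss_items.shape == (3,)  # was (4,) when the total was concatenated
assert torch.isclose(loss, loss_items.sum() * bs)  # total is recoverable from the components
```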
--- a/utils/plots.py
+++ b/utils/plots.py
@@ -1,7 +1,5 @@
 # Plotting utils

-import glob
-import os
 from copy import copy
 from pathlib import Path
@@ -387,63 +385,29 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
     plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)


-def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
-    # Plot training 'results*.txt', overlaying train and val losses
-    s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
-    t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
-    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
-        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-        n = results.shape[1]  # number of rows
-        x = range(start, min(stop, n) if stop else n)
-        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
-        ax = ax.ravel()
-        for i in range(5):
-            for j in [i, i + 5]:
-                y = results[j, x]
-                ax[i].plot(x, y, marker='.', label=s[j])
-                # y_smooth = butter_lowpass_filtfilt(y)
-                # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
-            ax[i].set_title(t[i])
-            ax[i].legend()
-            ax[i].set_ylabel(f) if i == 0 else None  # add filename
-        fig.savefig(f.replace('.txt', '.png'), dpi=200)
-
-
-def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
-    # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
+def plot_results(file='', dir=''):
+    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    save_dir = Path(file).parent if file else Path(dir)
     fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
     ax = ax.ravel()
-    s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
-         'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
-    if bucket:
-        # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
-        files = ['results%g.txt' % x for x in id]
-        c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
-        os.system(c)
-    else:
-        files = list(Path(save_dir).glob('results*.txt'))
-        assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
+    files = list(save_dir.glob('results*.csv'))
+    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
     for fi, f in enumerate(files):
         try:
-            results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-            n = results.shape[1]  # number of rows
-            x = range(start, min(stop, n) if stop else n)
-            for i in range(10):
-                y = results[i, x]
-                if i in [0, 1, 2, 5, 6, 7]:
-                    y[y == 0] = np.nan  # don't show zero loss values
-                    # y /= y[0]  # normalize
-                label = labels[fi] if len(labels) else f.stem
-                ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
-                ax[i].set_title(s[i])
-                # if i in [5, 6, 7]:  # share train and val loss y axes
+            data = pd.read_csv(f)
+            s = [x.strip() for x in data.columns]
+            x = data.values[:, 0]
+            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
+                y = data.values[:, j]
+                # y[y == 0] = np.nan  # don't show zero values
+                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
+                ax[i].set_title(s[j], fontsize=12)
+                # if j in [8, 9, 10]:  # share train and val loss y axes
                 #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
         except Exception as e:
-            print('Warning: Plotting error for %s; %s' % (f, e))
+            print(f'Warning: Plotting error for {f}: {e}')
     ax[1].legend()
-    fig.savefig(Path(save_dir) / 'results.png', dpi=200)
+    fig.savefig(save_dir / 'results.png', dpi=200)


 def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
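
Calling the reworked `plot_results()` follows the new `file`/`dir` signature; the run paths below are illustrative:

```python
# Example calls for the new signature: point at one results.csv, or at a
# directory to glob every results*.csv inside it.
from utils.plots import plot_results

plot_results(file='runs/train/exp/results.csv')  # plot a specific run
plot_results(dir='runs/train/exp')               # or all results*.csv in a dir
```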
--- a/val.py
+++ b/val.py
@@ -171,7 +171,7 @@ def run(data,
             # Compute loss
             if compute_loss:
-                loss += compute_loss([x.float() for x in train_out], targets)[1][:3]  # box, obj, cls
+                loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

             # Run NMS
             targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
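
The `[:3]` slice goes away because `loss_items` already holds exactly (box, obj, cls). A toy accumulation confirming the shapes still line up (per-batch values invented):

```python
# Toy check that the validation loss accumulator matches the 3-element loss_items.
import torch

loss = torch.zeros(3)  # (box, obj, cls) accumulator
for loss_items in (torch.tensor([0.05, 0.07, 0.01]),   # made-up batch 1
                   torch.tensor([0.04, 0.06, 0.02])):  # made-up batch 2
    loss += loss_items  # no [:3] slice needed any more
print((loss / 2).tolist())  # mean box/obj/cls over the two batches
```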