
New CSV Logger (#4148)

* New CSV Logger

* cleanup

* move batch plots into Logger

* rename comment

* Remove total loss from progress bar

* mloss :-1 bug fix

* Update plot_results()

* Update plot_results()

* plot_results bug fix
Glenn Jocher committed 3 years ago
commit 96e36a7c91
6 changed files with 68 additions and 109 deletions
  1. .gitignore (+1, −0)
  2. train.py (+11, −29)
  3. utils/loggers/__init__.py (+38, −25)
  4. utils/loss.py (+1, −2)
  5. utils/plots.py (+16, −52)
  6. val.py (+1, −1)

.gitignore (+1, −0)

!data/*.sh

results*.txt
+ results*.csv

# Datasets -------------------------------------------------------------------------------------------------------------
coco/

train.py (+11, −29)

import time
from copy import deepcopy
from pathlib import Path
- from threading import Thread

import math
import numpy as np
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
- from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
+ from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness

# Directories
w = save_dir / 'weights'  # weights dir
w.mkdir(parents=True, exist_ok=True)  # make dir
- last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
+ last, best = w / 'last.pt', w / 'best.pt'

# Hyperparameters
if isinstance(hyp, str):

# Loggers
if RANK in [-1, 0]:
-     loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+     loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
    if loggers.wandb and resume:
        weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
ema.updates = ckpt['updates']

- # Results
- if ckpt.get('training_results') is not None:
-     results_file.write_text(ckpt['training_results'])  # write results.txt

# Epochs
start_epoch = ckpt['epoch'] + 1
if resume:

# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

- mloss = torch.zeros(4, device=device)  # mean losses
+ mloss = torch.zeros(3, device=device)  # mean losses
if RANK != -1:
    train_loader.sampler.set_epoch(epoch)
pbar = enumerate(train_loader)
- LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
+ LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
if RANK in [-1, 0]:
    pbar = tqdm(pbar, total=nb)  # progress bar
optimizer.zero_grad()

ema.update(model)
last_opt_step = ni

- # Print
+ # Log
if RANK in [-1, 0]:
    mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
    mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
-     s = ('%10s' * 2 + '%10.4g' * 6) % (
-         f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
-     pbar.set_description(s)
-
-     # Plot
-     if plots:
-         if ni < 3:
-             f = save_dir / f'train_batch{ni}.jpg'  # filename
-             Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-         loggers.on_train_batch_end(ni, model, imgs)
+     pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
+         f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+     loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)

# end batch ------------------------------------------------------------------------------------------------

fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness:
    best_fitness = fi
- loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
+ loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)

# Save model
if (not nosave) or (final_epoch and not evolve):  # if save
    ckpt = {'epoch': epoch,
            'best_fitness': best_fitness,
-             'training_results': results_file.read_text(),
            'model': deepcopy(de_parallel(model)).half(),
            'ema': deepcopy(ema.ema).half(),
            'updates': ema.updates,

# end training -----------------------------------------------------------------------------------------------------
if RANK in [-1, 0]:
    LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
-     if plots:
-         plot_results(save_dir=save_dir)  # save as results.png

    if not evolve:
        if is_coco:  # COCO dataset
            for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
                    save_dir=save_dir,
                    save_json=True,
                    plots=False)

    # Strip optimizers
    for f in last, best:
        if f.exists():
            strip_optimizer(f)  # strip optimizers

- loggers.on_train_end(last, best)
+ loggers.on_train_end(last, best, plots)

torch.cuda.empty_cache()
return results
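With the 'total' column removed, the console header and the tqdm description must stay in lockstep at seven fields of width 10. A standalone sketch of the paired format strings follows; all names and values below are illustrative stand-ins, not output from a real run:

    # Paired header/row formats after this change (illustrative values only)
    epoch, epochs, mem = 0, 300, '3.51G'
    mloss = [0.118, 0.226, 0.072]  # running means of (box, obj, cls) losses
    nt, imgsz = 42, 640            # labels in the batch, train image size
    print(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
    print(('%10s' * 2 + '%10.4g' * 5) % (f'{epoch}/{epochs - 1}', mem, *mloss, nt, imgsz))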

utils/loggers/__init__.py (+38, −25)

# YOLOv5 experiment logging utils

import warnings
+ from threading import Thread

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
+ from utils.plots import plot_images, plot_results
from utils.torch_utils import de_parallel

- LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases
+ LOGGERS = ('csv', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

try:
    import wandb

class Loggers():
    # YOLOv5 Loggers class
-     def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
-                  data_dict=None, logger=None, include=LOGGERS):
+     def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
-         self.results_file = results_file
        self.weights = weights
        self.opt = opt
        self.hyp = hyp
            setattr(self, k, None)  # init empty logger dictionary

    def start(self):
-         self.txt = True  # always log to txt
+         self.csv = True  # always log to csv

        # Message
        try:

        return self

-     def on_train_batch_end(self, ni, model, imgs):
+     def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
        # Callback runs on train batch end
-         if ni == 0:
-             with warnings.catch_warnings():
-                 warnings.simplefilter('ignore')  # suppress jit trace warning
-                 self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-         if self.wandb and ni == 10:
-             files = sorted(self.save_dir.glob('train*.jpg'))
-             self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
+         if plots:
+             if ni == 0:
+                 with warnings.catch_warnings():
+                     warnings.simplefilter('ignore')  # suppress jit trace warning
+                     self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+             if ni < 3:
+                 f = self.save_dir / f'train_batch{ni}.jpg'  # filename
+                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+             if self.wandb and ni == 10:
+                 files = sorted(self.save_dir.glob('train*.jpg'))
+                 self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

    def on_train_epoch_end(self, epoch):
        # Callback runs on train epoch end
            files = sorted(self.save_dir.glob('val*.jpg'))
            self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

-     def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
-         # Callback runs on validation end during training
-         vals = list(mloss[:-1]) + list(results) + lr
-         tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
+     def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
+         # Callback runs on val end during training
+         vals = list(mloss) + list(results) + lr
+         keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+                 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                'x/lr0', 'x/lr1', 'x/lr2']  # params
-         if self.txt:
-             with open(self.results_file, 'a') as f:
-                 f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
+         x = {k: v for k, v in zip(keys, vals)}  # dict
+
+         if self.csv:
+             file = self.save_dir / 'results.csv'
+             n = len(x) + 1  # number of cols
+             s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # add header
+             with open(file, 'a') as f:
+                 f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

        if self.tb:
-             for x, tag in zip(vals, tags):
-                 self.tb.add_scalar(tag, x, epoch)  # TensorBoard
+             for k, v in x.items():
+                 self.tb.add_scalar(k, v, epoch)  # TensorBoard

        if self.wandb:
-             self.wandb.log({k: v for k, v in zip(tags, vals)})
+             self.wandb.log(x)
            self.wandb.end_epoch(best_result=best_fitness == fi)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
            self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

-     def on_train_end(self, last, best):
+     def on_train_end(self, last, best, plots):
        # Callback runs on training end
+         if plots:
+             plot_results(dir=self.save_dir)  # save results.png
        files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
        files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
        if self.wandb:
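The heart of the new logger is a header-once, append-per-epoch CSV pattern. A standalone sketch of that pattern follows, runnable on its own; the keys and values are illustrative stand-ins (three of the thirteen real keys) rather than output of a real run:

    from pathlib import Path

    save_dir = Path('.')  # assumption: current directory stands in for the run dir
    keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss']  # illustrative subset
    vals = [0.118, 0.226, 0.072]  # illustrative values only
    epoch = 0

    file = save_dir / 'results.csv'
    n = len(keys) + 1  # number of columns, including the leading 'epoch'
    # write the padded header only if the file does not exist yet, then append one row
    header = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')
    with open(file, 'a') as f:
        f.write(header + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

Appending a fixed-width row per epoch keeps results.csv both human-readable and trivially parseable with pandas, which the new plot_results() relies on.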

utils/loss.py (+1, −2)

        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

-         loss = lbox + lobj + lcls
-         return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
+         return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
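The detached loss vector shrinks from four items to three because the 'total' entry was always the sum of the other three; any consumer can recompute it. A quick equivalence check with made-up component values:

    import torch

    # Illustrative component losses; real values come from compute_loss internals.
    lbox, lobj, lcls = torch.tensor([0.05]), torch.tensor([0.07]), torch.tensor([0.01])
    old_items = torch.cat((lbox, lobj, lcls, lbox + lobj + lcls)).detach()  # 4 items (before)
    new_items = torch.cat((lbox, lobj, lcls)).detach()                      # 3 items (after)
    assert torch.equal(old_items[:3], new_items)          # first three unchanged
    assert torch.isclose(old_items[3], new_items.sum())   # 'total' is recoverable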

utils/plots.py (+16, −52)

# Plotting utils

- import glob
- import os
from copy import copy
from pathlib import Path

    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)


- def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
-     # Plot training 'results*.txt', overlaying train and val losses
-     s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
-     t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
-     for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
-         results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-         n = results.shape[1]  # number of rows
-         x = range(start, min(stop, n) if stop else n)
-         fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
-         ax = ax.ravel()
-         for i in range(5):
-             for j in [i, i + 5]:
-                 y = results[j, x]
-                 ax[i].plot(x, y, marker='.', label=s[j])
-                 # y_smooth = butter_lowpass_filtfilt(y)
-                 # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])
-
-             ax[i].set_title(t[i])
-             ax[i].legend()
-             ax[i].set_ylabel(f) if i == 0 else None  # add filename
-         fig.savefig(f.replace('.txt', '.png'), dpi=200)
-
-
- def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
-     # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
+ def plot_results(file='', dir=''):
+     # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+     save_dir = Path(file).parent if file else Path(dir)
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
-     s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
-          'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
-     if bucket:
-         # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
-         files = ['results%g.txt' % x for x in id]
-         c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
-         os.system(c)
-     else:
-         files = list(Path(save_dir).glob('results*.txt'))
-         assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
+     files = list(save_dir.glob('results*.csv'))
+     assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    for fi, f in enumerate(files):
        try:
-             results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
-             n = results.shape[1]  # number of rows
-             x = range(start, min(stop, n) if stop else n)
-             for i in range(10):
-                 y = results[i, x]
-                 if i in [0, 1, 2, 5, 6, 7]:
-                     y[y == 0] = np.nan  # don't show zero loss values
-                     # y /= y[0]  # normalize
-                 label = labels[fi] if len(labels) else f.stem
-                 ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
-                 ax[i].set_title(s[i])
-                 # if i in [5, 6, 7]:  # share train and val loss y axes
+             data = pd.read_csv(f)
+             s = [x.strip() for x in data.columns]
+             x = data.values[:, 0]
+             for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
+                 y = data.values[:, j]
+                 # y[y == 0] = np.nan  # don't show zero values
+                 ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
+                 ax[i].set_title(s[j], fontsize=12)
+                 # if j in [8, 9, 10]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
-             print('Warning: Plotting error for %s; %s' % (f, e))
+             print(f'Warning: Plotting error for {f}: {e}')
    ax[1].legend()
-     fig.savefig(Path(save_dir) / 'results.png', dpi=200)
+     fig.savefig(save_dir / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
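Because the CSV header cells are written with a padded '%20s,' format, the new plot_results() must strip whitespace from column names before using them as subplot titles. A self-contained illustration of that read path, with fabricated-for-illustration rows in the same format:

    import io

    import pandas as pd

    # Two illustrative rows in the format written by Loggers ('%20s,' pads with spaces).
    csv_text = ('               epoch,      train/box_loss\n'
                '                   0,               0.118\n'
                '                   1,               0.102\n')
    data = pd.read_csv(io.StringIO(csv_text))
    s = [x.strip() for x in data.columns]  # strip the '%20s' padding, as plot_results() does
    x = data.values[:, 0]                  # epoch column becomes the shared x axis
    print(s, x)                            # ['epoch', 'train/box_loss'] and the epoch values

Note also that the column remap [1, 2, 3, 4, 5, 8, 9, 10, 6, 7] keeps the historical subplot layout (train losses, P/R, val losses, then mAPs) while reading from the new CSV column order.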

val.py (+1, −1)



        # Compute loss
        if compute_loss:
-             loss += compute_loss([x.float() for x in train_out], targets)[1][:3]  # box, obj, cls
+             loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls

        # Run NMS
        targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
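val.py accumulates the detached loss vector across validation batches; since compute_loss() now returns exactly three items, the defensive [:3] slice is no longer needed. A minimal illustration with placeholder numbers:

    import torch

    loss = torch.zeros(3)                          # accumulated (box, obj, cls) over val batches
    loss_items = torch.tensor([0.05, 0.07, 0.01])  # illustrative second return of compute_loss()
    loss += loss_items                             # previously required loss_items[:3]
    print(loss)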
