* Daemon thread plotting * remove process_batch * plot after print5.0
@@ -3,6 +3,7 @@ import glob | |||
import json | |||
import os | |||
from pathlib import Path | |||
from threading import Thread | |||
import numpy as np | |||
import torch | |||
@@ -206,10 +207,10 @@ def test(data, | |||
# Plot images | |||
if plots and batch_i < 3: | |||
f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename | |||
plot_images(img, targets, paths, f, names) # labels | |||
f = save_dir / f'test_batch{batch_i}_pred.jpg' | |||
plot_images(img, output_to_target(output), paths, f, names) # predictions | |||
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels | |||
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() | |||
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions | |||
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start() | |||
# Compute statistics | |||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | |||
@@ -221,13 +222,6 @@ def test(data, | |||
else: | |||
nt = torch.zeros(1) | |||
# Plots | |||
if plots: | |||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) | |||
if wandb and wandb.run: | |||
wandb.log({"Images": wandb_images}) | |||
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]}) | |||
# Print results | |||
pf = '%20s' + '%12.3g' * 6 # print format | |||
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) | |||
@@ -242,6 +236,13 @@ def test(data, | |||
if not training: | |||
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) | |||
# Plots | |||
if plots: | |||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) | |||
if wandb and wandb.run: | |||
wandb.log({"Images": wandb_images}) | |||
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]}) | |||
# Save JSON | |||
if save_json and len(jdict): | |||
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights |
@@ -1,12 +1,13 @@ | |||
import argparse | |||
import logging | |||
import math | |||
import os | |||
import random | |||
import time | |||
from pathlib import Path | |||
from threading import Thread | |||
from warnings import warn | |||
import math | |||
import numpy as np | |||
import torch.distributed as dist | |||
import torch.nn as nn | |||
@@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, | |||
name=save_dir.stem, | |||
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) | |||
loggers = {'wandb': wandb} # loggers dict | |||
# Resume | |||
start_epoch, best_fitness = 0, 0.0 | |||
@@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency | |||
# model._initialize_biases(cf.to(device)) | |||
if plots: | |||
plot_labels(labels, save_dir=save_dir) | |||
Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start() | |||
if tb_writer: | |||
tb_writer.add_histogram('classes', c, 0) | |||
if wandb: | |||
wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]}) | |||
# Anchors | |||
if not opt.noautoanchor: | |||
@@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
# Plot | |||
if plots and ni < 3: | |||
f = save_dir / f'train_batch{ni}.jpg' # filename | |||
plot_images(images=imgs, targets=targets, paths=paths, fname=f) | |||
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() | |||
# if tb_writer: | |||
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) | |||
# tb_writer.add_graph(model, imgs) # add model to tensorboard |
@@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx | |||
plt.savefig('test_study.png', dpi=300) | |||
def plot_labels(labels, save_dir=''): | |||
def plot_labels(labels, save_dir=Path(''), loggers=None): | |||
# plot dataset labels | |||
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes | |||
nc = int(c.max() + 1) # number of classes | |||
@@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''): | |||
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', | |||
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), | |||
diag_kws=dict(bins=50)) | |||
plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200) | |||
plt.savefig(save_dir / 'labels_correlogram.png', dpi=200) | |||
plt.close() | |||
except Exception as e: | |||
pass | |||
@@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''): | |||
for a in [0, 1, 2, 3]: | |||
for s in ['top', 'right', 'left', 'bottom']: | |||
ax[a].spines[s].set_visible(False) | |||
plt.savefig(Path(save_dir) / 'labels.png', dpi=200) | |||
plt.savefig(save_dir / 'labels.png', dpi=200) | |||
plt.close() | |||
# loggers | |||
for k, v in loggers.items() or {}: | |||
if k == 'wandb' and v: | |||
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]}) | |||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() | |||
# Plot hyperparameter evolution results in evolve.txt |