Browse Source

Daemon thread plotting (#1561)

* Daemon thread plotting

* remove process_batch

* plot after print
5.0
Glenn Jocher GitHub 4 years ago
parent
commit
b6ed1104a6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 19 deletions
  1. +12
    -11
      test.py
  2. +5
    -5
      train.py
  3. +8
    -3
      utils/plots.py

+ 12
- 11
test.py View File

@@ -3,6 +3,7 @@ import glob
import json
import os
from pathlib import Path
from threading import Thread

import numpy as np
import torch
@@ -206,10 +207,10 @@ def test(data,

# Plot images
if plots and batch_i < 3:
f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename
plot_images(img, targets, paths, f, names) # labels
f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output), paths, f, names) # predictions
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
@@ -221,13 +222,6 @@ def test(data,
else:
nt = torch.zeros(1)

# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Print results
pf = '%20s' + '%12.3g' * 6 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
@@ -242,6 +236,13 @@ def test(data,
if not training:
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Save JSON
if save_json and len(jdict):
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights

+ 5
- 5
train.py View File

@@ -1,12 +1,13 @@
import argparse
import logging
import math
import os
import random
import time
from pathlib import Path
from threading import Thread
from warnings import warn

import math
import numpy as np
import torch.distributed as dist
import torch.nn as nn
@@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict

# Resume
start_epoch, best_fitness = 0, 0.0
@@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
plot_labels(labels, save_dir=save_dir)
Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
if tb_writer:
tb_writer.add_histogram('classes', c, 0)
if wandb:
wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})

# Anchors
if not opt.noautoanchor:
@@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Plot
if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
plot_images(images=imgs, targets=targets, paths=paths, fname=f)
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
# if tb_writer:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard

+ 8
- 3
utils/plots.py View File

@@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
plt.savefig('test_study.png', dpi=300)


def plot_labels(labels, save_dir=''):
def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
@@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''):
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50))
plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
plt.close()
except Exception as e:
pass
@@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''):
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
plt.savefig(save_dir / 'labels.png', dpi=200)
plt.close()

# loggers
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})


def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
# Plot hyperparameter evolution results in evolve.txt

Loading…
Cancel
Save