Update tensorboard logging

This commit is contained in:
Glenn Jocher 2020-10-26 01:08:33 +01:00
parent 7f1640695b
commit 453acdec67
3 changed files with 9 additions and 8 deletions

View File

@ -191,9 +191,9 @@ def test(data,
# Plot images # Plot images
if plots and batch_i < 1: if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
plot_images(img, targets, paths, str(f), names) # ground truth plot_images(img, targets, paths, str(f), names) # ground truth
f = save_dir / ('test_batch%g_pred.jpg' % batch_i) f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
# Compute statistics # Compute statistics

View File

@ -291,11 +291,11 @@ def train(hyp, opt, device, tb_writer=None):
# Plot # Plot
if ni < 3: if ni < 3:
f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename f = str(log_dir / f'train_batch{ni}.jpg') # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None: # if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard # tb_writer.add_graph(model, imgs) # add model to tensorboard
# end batch ------------------------------------------------------------------------------------------------ # end batch ------------------------------------------------------------------------------------------------

View File

@ -19,6 +19,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import yaml import yaml
from PIL import Image
from scipy.cluster.vq import kmeans from scipy.cluster.vq import kmeans
from scipy.signal import butter, filtfilt from scipy.signal import butter, filtfilt
from tqdm import tqdm from tqdm import tqdm
@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
if fname is not None: if fname is not None:
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA) mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
Image.fromarray(mosaic).save(fname) # PIL save
return mosaic return mosaic