Update tensorboard logging
This commit is contained in:
parent
7f1640695b
commit
453acdec67
4
test.py
4
test.py
|
|
@ -191,9 +191,9 @@ def test(data,
|
||||||
|
|
||||||
# Plot images
|
# Plot images
|
||||||
if plots and batch_i < 1:
|
if plots and batch_i < 1:
|
||||||
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
|
f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
|
||||||
plot_images(img, targets, paths, str(f), names) # ground truth
|
plot_images(img, targets, paths, str(f), names) # ground truth
|
||||||
f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
|
f = save_dir / f'test_batch{batch_i}_pred.jpg'
|
||||||
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
|
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
|
||||||
|
|
||||||
# Compute statistics
|
# Compute statistics
|
||||||
|
|
|
||||||
8
train.py
8
train.py
|
|
@ -291,11 +291,11 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
|
|
||||||
# Plot
|
# Plot
|
||||||
if ni < 3:
|
if ni < 3:
|
||||||
f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
|
f = str(log_dir / f'train_batch{ni}.jpg') # filename
|
||||||
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
||||||
if tb_writer and result is not None:
|
# if tb_writer and result is not None:
|
||||||
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
|
||||||
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
# tb_writer.add_graph(model, imgs) # add model to tensorboard
|
||||||
|
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import yaml
|
import yaml
|
||||||
|
from PIL import Image
|
||||||
from scipy.cluster.vq import kmeans
|
from scipy.cluster.vq import kmeans
|
||||||
from scipy.signal import butter, filtfilt
|
from scipy.signal import butter, filtfilt
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
||||||
|
|
||||||
if fname is not None:
|
if fname is not None:
|
||||||
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
|
mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
|
||||||
cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))
|
# cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
|
||||||
|
Image.fromarray(mosaic).save(fname) # PIL save
|
||||||
return mosaic
|
return mosaic
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue