Drowning_Person_Detection/utils/summaries.py

24 lines
1.2 KiB
Python

import os
import torch
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from dataloaders.utils import decode_seg_map_sequence
class TensorboardSummary(object):
    """Small helper around a tensorboardX ``SummaryWriter``.

    It remembers a log directory, creates writers for it, and logs a
    three-sample preview (input, prediction, ground truth) of a batch.
    """

    def __init__(self, directory):
        """Store the log directory.

        Args:
            directory: Path later passed to ``SummaryWriter`` as ``log_dir``.
        """
        self.directory = directory

    def create_summary(self):
        """Create and return a new ``SummaryWriter`` for ``self.directory``."""
        # A single-argument join returns the path unchanged; kept so the
        # value handed to SummaryWriter is exactly what it was before.
        return SummaryWriter(log_dir=os.path.join(self.directory))

    def visualize_image(self, writer, dataset, image, target, output, global_step):
        """Log the first three samples of a batch as three image grids.

        Writes the tags ``'Image'``, ``'Predicted label'`` and
        ``'Groundtruth label'`` to ``writer`` at ``global_step``.

        Args:
            writer: Object exposing ``add_image(tag, img_tensor, global_step)``,
                e.g. the writer returned by ``create_summary``.
            dataset: Forwarded unchanged to ``decode_seg_map_sequence`` as its
                ``dataset`` keyword.
            image: Batch of input images; sliced along dim 0.
                Assumed to be (N, C, H, W) -- TODO confirm against caller.
            target: Batch of label maps; dim 1 is squeezed away if it has
                size 1. Presumably (N, 1, H, W) or (N, H, W) -- TODO confirm.
            output: Batch of network outputs; the arg-max over dim 1 is taken.
                Presumably per-class scores of shape (N, classes, H, W)
                -- TODO confirm.
            global_step: Step value recorded with every image.
        """
        # Input images: detach from autograd, copy, move to CPU, and let
        # make_grid rescale them for display (normalize=True).
        inputs = image[:3].detach().clone().cpu()
        writer.add_image('Image', make_grid(inputs, 3, normalize=True), global_step)

        # NOTE: the former ``range=(0, 255)`` argument was removed from both
        # calls below. make_grid only consults the range when normalize=True,
        # so it had no effect here, and recent torchvision releases reject the
        # ``range`` keyword (it was renamed ``value_range``).

        # Predicted labels: index of the maximum along dim 1 for each sample.
        predicted = torch.max(output[:3], 1)[1].detach().cpu().numpy()
        # decode_seg_map_sequence is defined elsewhere; presumably it maps
        # label indices to colour images -- verify against dataloaders.utils.
        predicted_rgb = decode_seg_map_sequence(predicted, dataset=dataset)
        writer.add_image('Predicted label',
                         make_grid(predicted_rgb, 3, normalize=False),
                         global_step)

        # Ground-truth labels: drop dim 1 when it is a singleton.
        truth = torch.squeeze(target[:3], 1).detach().cpu().numpy()
        truth_rgb = decode_seg_map_sequence(truth, dataset=dataset)
        writer.add_image('Groundtruth label',
                         make_grid(truth_rgb, 3, normalize=False),
                         global_step)