落水人员检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

24 lines
1.2KB

  1. import os
  2. import torch
  3. from torchvision.utils import make_grid
  4. from tensorboardX import SummaryWriter
  5. from dataloaders.utils import decode_seg_map_sequence
  6. class TensorboardSummary(object):
  7. def __init__(self, directory):
  8. self.directory = directory
  9. def create_summary(self):
  10. writer = SummaryWriter(log_dir=os.path.join(self.directory))
  11. return writer
  12. def visualize_image(self, writer, dataset, image, target, output, global_step):
  13. grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
  14. writer.add_image('Image', grid_image, global_step)
  15. grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
  16. dataset=dataset), 3, normalize=False, range=(0, 255))
  17. writer.add_image('Predicted label', grid_image, global_step)
  18. grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
  19. dataset=dataset), 3, normalize=False, range=(0, 255))
  20. writer.add_image('Groundtruth label', grid_image, global_step)