落水人员检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
1.7KB

  1. import torch
  2. import random
  3. import numpy as np
  4. from PIL import Image, ImageOps, ImageFilter
  5. class Normalize(object):
  6. """Normalize a tensor image with mean and standard deviation.
  7. Args:
  8. mean (tuple): means for each channel.
  9. std (tuple): standard deviations for each channel.
  10. """
  11. def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
  12. self.mean = mean
  13. self.std = std
  14. def __call__(self, sample):
  15. img = sample['image']
  16. # mask = sample['label']
  17. img = np.array(img).astype(np.float32)
  18. # mask = np.array(mask).astype(np.float32)
  19. img /= 255.0
  20. img -= self.mean
  21. img /= self.std
  22. # return {'image': img,
  23. # 'label': mask}
  24. return {'image': img}
  25. class ToTensor(object):
  26. """Convert ndarrays in sample to Tensors."""
  27. def __call__(self, sample):
  28. # swap color axis because
  29. # numpy image: H x W x C
  30. # torch image: C X H X W
  31. img = sample['image']
  32. # mask = sample['label']
  33. img = np.array(img).astype(np.float32).transpose((2, 0, 1))
  34. # mask = np.array(mask).astype(np.float32)
  35. img = torch.from_numpy(img).float()
  36. # mask = torch.from_numpy(mask).float()
  37. # return img, mask
  38. return img
  39. class FixedResize(object):
  40. def __init__(self, size):
  41. self.size = (size, size) # size: (h, w)
  42. def __call__(self, sample):
  43. img = sample['image']
  44. # mask = sample['label']
  45. # assert img.size == mask.size
  46. img = img.resize(self.size, Image.BILINEAR)
  47. # mask = mask.resize(self.size, Image.NEAREST)
  48. # return {'image': img,
  49. # 'label': mask}
  50. return {'image': img}