88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
"""SBU Shadow Segmentation Dataset."""
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
from .segbase import SegmentationDataset
|
|
|
|
|
|
class SBUSegmentation(SegmentationDataset):
    """SBU Shadow Segmentation Dataset.

    Two-class dataset (``NUM_CLASS = 2``): any mask pixel with a value
    above 0 becomes label 1, all other pixels keep their value (0 for
    ordinary 8-bit masks).

    Image/mask path pairs are gathered by ``_get_sbu_pairs``. Per-mode
    preprocessing is delegated to the ``SegmentationDataset`` helpers:

    * ``'train'``   -> ``_sync_transform(img, mask)``
    * ``'val'``     -> ``_val_sync_transform(img, mask)``
    * ``'testval'`` -> ``_img_transform(img)`` / ``_mask_transform(mask)``
    * ``'test'``    -> image only (no mask is loaded)
    """

    NUM_CLASS = 2

    def __init__(self, root='../datasets/sbu', split='train', mode=None, transform=None, **kwargs):
        super().__init__(root, split, mode, transform, **kwargs)
        assert os.path.exists(self.root)
        images, masks = _get_sbu_pairs(self.root, self.split)
        self.images, self.masks = images, masks
        assert len(images) == len(masks)
        if not images:
            raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")

    def __getitem__(self, index):
        """Return ``(img, mask, filename)``, or ``(img, filename)`` in test mode."""
        img_path = self.images[index]
        img = Image.open(img_path).convert('RGB')

        if self.mode == 'test':
            # No ground truth is loaded in pure inference mode.
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(img_path)

        mask = Image.open(self.masks[index])

        # Geometric transforms must hit image and mask identically.
        mode = self.mode
        if mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert mode == 'testval'
            img = self._img_transform(img)
            mask = self._mask_transform(mask)

        # Optional user transform (e.g. normalisation / ToTensor) on the image only.
        if self.transform is not None:
            img = self.transform(img)
        return img, mask, os.path.basename(img_path)

    def _mask_transform(self, mask):
        """Convert a mask image to a ``torch.long`` tensor with positives clamped to 1."""
        raw = np.array(mask).astype('int32')
        # Every strictly positive value becomes 1; non-positive values are kept.
        binary = np.where(raw > 0, 1, raw)
        return torch.from_numpy(binary).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        """Offset to add to predicted class indices (none for this dataset)."""
        return 0
|
|
|
|
|
|
def _get_sbu_pairs(folder, split='train'):
|
|
def get_path_pairs(img_folder, mask_folder):
|
|
img_paths = []
|
|
mask_paths = []
|
|
for root, _, files in os.walk(img_folder):
|
|
print(root)
|
|
for filename in files:
|
|
if filename.endswith('.jpg'):
|
|
imgpath = os.path.join(root, filename)
|
|
maskname = filename.replace('.jpg', '.png')
|
|
maskpath = os.path.join(mask_folder, maskname)
|
|
if os.path.isfile(imgpath) and os.path.isfile(maskpath):
|
|
img_paths.append(imgpath)
|
|
mask_paths.append(maskpath)
|
|
else:
|
|
print('cannot find the mask or image:', imgpath, maskpath)
|
|
print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
|
|
return img_paths, mask_paths
|
|
|
|
if split == 'train':
|
|
img_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowImages')
|
|
mask_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowMasks')
|
|
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
|
|
else:
|
|
assert split in ('val', 'test')
|
|
img_folder = os.path.join(folder, 'SBU-Test/ShadowImages')
|
|
mask_folder = os.path.join(folder, 'SBU-Test/ShadowMasks')
|
|
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
|
|
return img_paths, mask_paths
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check: build the default (training) split with explicit sizes.
    base_size, crop_size = 280, 256
    dataset = SBUSegmentation(base_size=base_size, crop_size=crop_size)