"""Prepare Cityscapes dataset"""
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
from .segbase import SegmentationDataset
|
|
|
|
|
|
class CitySegmentation(SegmentationDataset):
    """Cityscapes Semantic Segmentation Dataset.

    Parameters
    ----------
    root : string
        Path to the Cityscapes folder. Default is '../datasets/citys'.
    split : string
        'train', 'val' or 'test'.
    transform : callable, optional
        A function that transforms the image.

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    >>>     transforms.ToTensor(),
    >>>     transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
    >>> ])
    >>> # Create Dataset
    >>> trainset = CitySegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    >>>     trainset, 4, shuffle=True,
    >>>     num_workers=4)
    """
    BASE_DIR = 'cityscapes'
    NUM_CLASS = 19

    def __init__(self, root='../datasets/citys', split='train', mode=None, transform=None, **kwargs):
        super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs)
        # self.root = os.path.join(root, self.BASE_DIR)
        assert os.path.exists(self.root), "Please set up the dataset using ../datasets/cityscapes.py"
        self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
        assert len(self.images) == len(self.mask_paths)
        if len(self.images) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")
        # The 19 Cityscapes labelIds that are evaluated; all other ids are ignored.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                              23, 24, 25, 26, 27, 28, 31, 32, 33]
        # Lookup table from raw labelId (-1..33) to train id (0..18); ignored
        # classes map to -1.
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
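
    # With the tables above, _class_to_index maps raw Cityscapes labelIds to
    # the 19 train ids, e.g. labelId 7 (road) -> 0, 26 (car) -> 13, and
    # ignored ids such as 4 ('static') -> -1:
    #
    #     _class_to_index(np.array([[7, 26, 4]]))  # -> array([[ 0, 13, -1]])
    #
    # np.digitize(..., right=True) locates each value in self._mapping
    # (-1..33), and self._key translates that position to the train id.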

    def _class_to_index(self, mask):
        # Every value in the mask must be a known Cityscapes labelId.
        values = np.unique(mask)
        for value in values:
            assert value in self._mapping
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        if self.mode == 'test':
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(self.images[index])
        mask = Image.open(self.mask_paths[index])
        # synchronized transform of image and mask
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        return img, mask, os.path.basename(self.images[index])

    def _mask_transform(self, mask):
        target = self._class_to_index(np.array(mask).astype('int32'))
        return torch.LongTensor(np.array(target).astype('int32'))

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        return 0


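# Expected on-disk layout (standard Cityscapes packaging), assumed by the
# helper below:
#
#     <root>/leftImg8bit/<split>/<city>/<name>_leftImg8bit.png
#     <root>/gtFine/<split>/<city>/<name>_gtFine_labelIds.png
#
# _get_city_pairs walks the image tree and derives each mask path by
# substituting 'gtFine_labelIds' for 'leftImg8bit' in the file name.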
def _get_city_pairs(folder, split='train'):
    def get_path_pairs(img_folder, mask_folder):
        img_paths = []
        mask_paths = []
        for root, _, files in os.walk(img_folder):
            for filename in files:
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    # The label map shares the image's name, with the
                    # 'leftImg8bit' suffix replaced by 'gtFine_labelIds'.
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    if split in ('train', 'val'):
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        mask_folder = os.path.join(folder, 'gtFine/' + split)
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths
    else:
        assert split == 'trainval'
        print('trainval set')
        train_img_folder = os.path.join(folder, 'leftImg8bit/train')
        train_mask_folder = os.path.join(folder, 'gtFine/train')
        val_img_folder = os.path.join(folder, 'leftImg8bit/val')
        val_mask_folder = os.path.join(folder, 'gtFine/val')
        train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
        val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
        img_paths = train_img_paths + val_img_paths
        mask_paths = train_mask_paths + val_mask_paths
        return img_paths, mask_paths


if __name__ == '__main__':
    dataset = CitySegmentation()
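    # Minimal smoke-test sketch: assumes the dataset has been prepared under
    # ../datasets/citys and uses the default 'train' split. The exact types
    # returned for img/mask depend on SegmentationDataset's transforms.
    print('Loaded {} image/mask pairs'.format(len(dataset)))
    img, mask, name = dataset[0]
    print('First sample:', name)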