"""Pascal ADE20K Semantic Segmentation Dataset.""" import os import torch import numpy as np from PIL import Image from .segbase import SegmentationDataset class ADE20KSegmentation(SegmentationDataset): """ADE20K Semantic Segmentation Dataset. Parameters ---------- root : string Path to ADE20K folder. Default is './datasets/ade' split: string 'train', 'val' or 'test' transform : callable, optional A function that transforms the image Examples -------- >>> from torchvision import transforms >>> import torch.utils.data as data >>> # Transforms for Normalization >>> input_transform = transforms.Compose([ >>> transforms.ToTensor(), >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), >>> ]) >>> # Create Dataset >>> trainset = ADE20KSegmentation(split='train', transform=input_transform) >>> # Create Training Loader >>> train_data = data.DataLoader( >>> trainset, 4, shuffle=True, >>> num_workers=4) """ BASE_DIR = 'ADEChallengeData2016' NUM_CLASS = 150 def __init__(self, root='../datasets/ade', split='test', mode=None, transform=None, **kwargs): super(ADE20KSegmentation, self).__init__(root, split, mode, transform, **kwargs) root = os.path.join(root, self.BASE_DIR) assert os.path.exists(root), "Please setup the dataset using ../datasets/ade20k.py" self.images, self.masks = _get_ade20k_pairs(root, split) assert (len(self.images) == len(self.masks)) if len(self.images) == 0: raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") print('Found {} images in the folder {}'.format(len(self.images), root)) def __getitem__(self, index): img = Image.open(self.images[index]).convert('RGB') if self.mode == 'test': img = self._img_transform(img) if self.transform is not None: img = self.transform(img) return img, os.path.basename(self.images[index]) mask = Image.open(self.masks[index]) # synchrosized transform if self.mode == 'train': img, mask = self._sync_transform(img, mask) elif self.mode == 'val': img, mask = self._val_sync_transform(img, mask) else: assert self.mode == 'testval' img, mask = self._img_transform(img), self._mask_transform(mask) # general resize, normalize and to Tensor if self.transform is not None: img = self.transform(img) return img, mask, os.path.basename(self.images[index]) def _mask_transform(self, mask): return torch.LongTensor(np.array(mask).astype('int32') - 1) def __len__(self): return len(self.images) @property def pred_offset(self): return 1 @property def classes(self): """Category names.""" return ("wall", "building, edifice", "sky", "floor, flooring", "tree", "ceiling", "road, route", "bed", "windowpane, window", "grass", "cabinet", "sidewalk, pavement", "person, individual, someone, somebody, mortal, soul", "earth, ground", "door, double door", "table", "mountain, mount", "plant, flora, plant life", "curtain, drape, drapery, mantle, pall", "chair", "car, auto, automobile, machine, motorcar", "water", "painting, picture", "sofa, couch, lounge", "shelf", "house", "sea", "mirror", "rug, carpet, carpeting", "field", "armchair", "seat", "fence, fencing", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "bathtub, bathing tub, bath, tub", "railing, rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace, hearth, open fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs, steps", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table, cocktail table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove, kitchen stove, range, kitchen range, cooking stove", "palm, palm tree", "kitchen island", "computer, computing machine, computing device, data processor, " "electronic computer, information processing system", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, " "motorcoach, omnibus, passenger vehicle", "towel", "light, light source", "truck, motortruck", "tower", "chandelier, pendant, pendent", "awning, sunshade, sunblind", "streetlight, street lamp", "booth, cubicle, stall, kiosk", "television receiver, television, television set, tv, tv set, idiot " "box, boob tube, telly, goggle box", "airplane, aeroplane, plane", "dirt track", "apparel, wearing apparel, dress, clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "swimming pool, swimming bath, natatorium", "stool", "barrel, cask", "basket, handbasket", "waterfall, falls", "tent, collapsible shelter", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name, brand name, brand, marque", "microwave, microwave oven", "pot, flowerpot", "animal, animate being, beast, brute, creature, fauna", "bicycle, bike, wheel, cycle", "lake", "dishwasher, dish washer, dishwashing machine", "screen, silver screen, projection screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light, traffic signal, stoplight", "tray", "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, " "dustbin, trash barrel, trash bin", "fan", "pier, wharf, wharfage, dock", "crt screen", "plate", "monitor, monitoring device", "bulletin board, notice board", "shower", "radiator", "glass, drinking glass", "clock", "flag") def _get_ade20k_pairs(folder, mode='train'): img_paths = [] mask_paths = [] if mode == 'train': img_folder = os.path.join(folder, 'images/training') mask_folder = os.path.join(folder, 'annotations/training') else: img_folder = os.path.join(folder, 'images/validation') mask_folder = os.path.join(folder, 'annotations/validation') for filename in os.listdir(img_folder): basename, _ = os.path.splitext(filename) if filename.endswith(".jpg"): imgpath = os.path.join(img_folder, filename) maskname = basename + '.png' maskpath = os.path.join(mask_folder, maskname) if os.path.isfile(maskpath): img_paths.append(imgpath) mask_paths.append(maskpath) else: print('cannot find the mask:', maskpath) return img_paths, mask_paths if __name__ == '__main__': train_dataset = ADE20KSegmentation()