137 lines
5.3 KiB
Python
137 lines
5.3 KiB
Python
"""MSCOCO Semantic Segmentation pretraining for VOC."""
|
|
import os
|
|
import pickle
|
|
import torch
|
|
import numpy as np
|
|
|
|
from tqdm import trange
|
|
from PIL import Image
|
|
from .segbase import SegmentationDataset
|
|
|
|
|
|
class COCOSegmentation(SegmentationDataset):
    """COCO Semantic Segmentation Dataset for VOC Pre-training.

    Only the COCO categories listed in ``CAT_LIST`` are kept; every other
    pixel is labelled 0 (background). Images whose mask has at most 1000
    labelled pixels are filtered out once per split and the surviving image
    ids are cached next to the annotation files.

    Parameters
    ----------
    root : string
        Path to COCO folder. Default is '../datasets/coco'
    split: string
        'train' selects the train2017 images and annotations; any other
        value selects val2017.
    mode : string, optional
        'train', 'val' or 'testval'. Selects which transform ``__getitem__``
        applies. Forwarded to the base class, which is presumably what
        stores it as ``self.mode`` (base class not visible here).
    transform : callable, optional
        A function that transforms the image

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    ...     transforms.ToTensor(),
    ...     transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
    ... ])
    >>> # Create Dataset
    >>> trainset = COCOSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    ...     trainset, 4, shuffle=True,
    ...     num_workers=4)
    """
    # COCO category ids; the position in this list is the training label,
    # in the same order as the names returned by ``classes``.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]
    NUM_CLASS = 21

    def __init__(self, root='../datasets/coco', split='train', mode=None, transform=None, **kwargs):
        super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
        # lazy import pycocotools
        from pycocotools.coco import COCO
        from pycocotools import mask
        if split == 'train':
            print('train set')
            ann_file = os.path.join(root, 'annotations/instances_train2017.json')
            ids_file = os.path.join(root, 'annotations/train_ids.mx')
            self.root = os.path.join(root, 'train2017')
        else:
            print('val set')
            ann_file = os.path.join(root, 'annotations/instances_val2017.json')
            ids_file = os.path.join(root, 'annotations/val_ids.mx')
            self.root = os.path.join(root, 'val2017')
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            # NOTE(security): pickle executes arbitrary code on load. This is
            # only safe because the cache file is written by _preprocess
            # below; do not point `root` at an untrusted directory.
            with open(ids_file, 'rb') as f:
                self.ids = pickle.load(f)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.transform = transform

    def __len__(self):
        """Number of qualified images in the selected split."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load one sample.

        Returns
        -------
        tuple
            ``(img, mask, file_name)`` where ``file_name`` is the base name
            of the image file as recorded in the COCO metadata.
        """
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        mask = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        # self.ids holds integer COCO image ids, which os.path.basename
        # cannot handle; the image file name is what callers need here.
        return img, mask, os.path.basename(path)

    def _mask_transform(self, mask):
        """Convert a label mask (PIL image or array-like) to a LongTensor."""
        return torch.LongTensor(np.array(mask).astype('int32'))

    def _gen_seg_mask(self, target, h, w):
        """Rasterize COCO annotations into an ``(h, w)`` uint8 label mask.

        Parameters
        ----------
        target : list of dict
            COCO annotation dicts for a single image.
        h, w : int
            Image height and width.

        Returns
        -------
        numpy.ndarray
            Mask of class indices (positions in ``CAT_LIST``); 0 is
            background. Where instances overlap, the first one wins.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            cat = instance['category_id']
            # Skip unused categories before paying for the mask decode.
            if cat not in self.CAT_LIST:
                continue
            c = self.CAT_LIST.index(cat)
            # COCO annotation dicts use the lowercase key 'segmentation'.
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            if len(m.shape) < 3:
                foreground = m > 0
            else:
                # Multi-part decode gives one channel per part; merge them.
                foreground = np.sum(m, axis=2) > 0
            # Only fill pixels that are still background so that earlier
            # instances are never overwritten.
            mask[(mask == 0) & foreground] = c
        return mask

    def _preprocess(self, ids, ids_file):
        """Keep image ids with more than 1000 labelled pixels and cache them.

        Parameters
        ----------
        ids : list
            Candidate COCO image ids.
        ids_file : string
            Path the filtered id list is pickled to.

        Returns
        -------
        list
            The qualifying image ids, in their original order.
        """
        print("Preprocessing mask, this will take a while."
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        with open(ids_file, 'wb') as f:
            pickle.dump(new_ids, f)
        return new_ids

    @property
    def classes(self):
        """Category names."""
        return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
                'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
                'tv')
|