AIlib2/segutils/core/data/dataloader/mscoco.py

137 lines
5.3 KiB
Python
Raw Normal View History

2025-04-26 10:35:59 +08:00
"""MSCOCO Semantic Segmentation pretraining for VOC."""
import os
import pickle
import torch
import numpy as np
from tqdm import trange
from PIL import Image
from .segbase import SegmentationDataset
class COCOSegmentation(SegmentationDataset):
"""COCO Semantic Segmentation Dataset for VOC Pre-training.
Parameters
----------
root : string
Path to ADE20K folder. Default is './datasets/coco'
split: string
'train', 'val' or 'test'
transform : callable, optional
A function that transforms the image
Examples
--------
>>> from torchvision import transforms
>>> import torch.utils.data as data
>>> # Transforms for Normalization
>>> input_transform = transforms.Compose([
>>> transforms.ToTensor(),
>>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
>>> ])
>>> # Create Dataset
>>> trainset = COCOSegmentation(split='train', transform=input_transform)
>>> # Create Training Loader
>>> train_data = data.DataLoader(
>>> trainset, 4, shuffle=True,
>>> num_workers=4)
"""
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
1, 64, 20, 63, 7, 72]
NUM_CLASS = 21
def __init__(self, root='../datasets/coco', split='train', mode=None, transform=None, **kwargs):
super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
# lazy import pycocotools
from pycocotools.coco import COCO
from pycocotools import mask
if split == 'train':
print('train set')
ann_file = os.path.join(root, 'annotations/instances_train2017.json')
ids_file = os.path.join(root, 'annotations/train_ids.mx')
self.root = os.path.join(root, 'train2017')
else:
print('val set')
ann_file = os.path.join(root, 'annotations/instances_val2017.json')
ids_file = os.path.join(root, 'annotations/val_ids.mx')
self.root = os.path.join(root, 'val2017')
self.coco = COCO(ann_file)
self.coco_mask = mask
if os.path.exists(ids_file):
with open(ids_file, 'rb') as f:
self.ids = pickle.load(f)
else:
ids = list(self.coco.imgs.keys())
self.ids = self._preprocess(ids, ids_file)
self.transform = transform
def __getitem__(self, index):
coco = self.coco
img_id = self.ids[index]
img_metadata = coco.loadImgs(img_id)[0]
path = img_metadata['file_name']
img = Image.open(os.path.join(self.root, path)).convert('RGB')
cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
mask = Image.fromarray(self._gen_seg_mask(
cocotarget, img_metadata['height'], img_metadata['width']))
# synchrosized transform
if self.mode == 'train':
img, mask = self._sync_transform(img, mask)
elif self.mode == 'val':
img, mask = self._val_sync_transform(img, mask)
else:
assert self.mode == 'testval'
img, mask = self._img_transform(img), self._mask_transform(mask)
# general resize, normalize and toTensor
if self.transform is not None:
img = self.transform(img)
return img, mask, os.path.basename(self.ids[index])
def _mask_transform(self, mask):
return torch.LongTensor(np.array(mask).astype('int32'))
def _gen_seg_mask(self, target, h, w):
mask = np.zeros((h, w), dtype=np.uint8)
coco_mask = self.coco_mask
for instance in target:
rle = coco_mask.frPyObjects(instance['Segmentation'], h, w)
m = coco_mask.decode(rle)
cat = instance['category_id']
if cat in self.CAT_LIST:
c = self.CAT_LIST.index(cat)
else:
continue
if len(m.shape) < 3:
mask[:, :] += (mask == 0) * (m * c)
else:
mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
return mask
def _preprocess(self, ids, ids_file):
print("Preprocessing mask, this will take a while." + \
"But don't worry, it only run once for each split.")
tbar = trange(len(ids))
new_ids = []
for i in tbar:
img_id = ids[i]
cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
img_metadata = self.coco.loadImgs(img_id)[0]
mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width'])
# more than 1k pixels
if (mask > 0).sum() > 1000:
new_ids.append(img_id)
tbar.set_description('Doing: {}/{}, got {} qualified images'. \
format(i, len(ids), len(new_ids)))
print('Found number of qualified images: ', len(new_ids))
with open(ids_file, 'wb') as f:
pickle.dump(new_ids, f)
return new_ids
@property
def classes(self):
"""Category names."""
return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
'tv')