#!/usr/bin/python
# -*- encoding: utf-8 -*-
"""Segmentation dataset for the Heliushuju data.

On-disk layout read by ``Heliushuju.__init__``::

    <rootpth>/<mode>/images/<name>.<ext>
    <rootpth>/<mode>/labels_2/<name>.<ext>

Images and labels are paired by file name with the extension stripped.
"""

import json
import os
import os.path as osp

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from transform import *


class Heliushuju(Dataset):
    """Image/label pair dataset for semantic segmentation.

    Each item is ``(img, label)``. ``img`` is the output of ``self.to_tensor``
    (``ToTensor`` followed by ImageNet-mean/std ``Normalize``). ``label`` is
    an ``int64`` numpy array with a leading singleton axis added, whose raw
    ids have been remapped through the ``id -> trainId`` table loaded from
    ``info_path``.

    Args:
        rootpth: Root directory containing one sub-directory per split.
        cropsize: Size passed to ``RandomCrop``.
        mode: Split to load; one of 'train', 'val', 'test', 'trainval'.
            Originally defaulted to 'train'; changed to 'test'.
        randomscale: Scale factors passed to ``RandomScale``.
        info_path: JSON file holding a list of objects with ``"id"`` and
            ``"trainId"`` keys (keyword-only; defaults to the previously
            hard-coded path).
        *args, **kwargs: Forwarded to ``Dataset.__init__``.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='test',
            randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5),
            *args, info_path='./heliushuju_info.json', **kwargs):
        # NOTE(review): 0.675 breaks the 0.125-step pattern of the other
        # scales; possibly meant 0.625 — confirm before changing.
        super(Heliushuju, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        print('self.mode', self.mode)
        # Stored for consumers of the dataset; not used inside this class.
        self.ignore_lb = 255

        with open(info_path, 'r', encoding='utf-8') as fr:
            labels_info = json.load(fr)
        self.lb_map = {el['id']: el['trainId'] for el in labels_info}

        ## parse img directory
        imgnames, self.imgs = self._index_dir(osp.join(rootpth, mode, 'images'))
        ## parse gt directory
        gtnames, self.labels = self._index_dir(osp.join(rootpth, mode, 'labels_2'))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        assert set(imgnames) == set(gtnames), \
            'image and label file names (without extension) do not match'
        assert set(self.imnames) == set(self.imgs.keys())
        assert set(self.imnames) == set(self.labels.keys())

        ## pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
        self.trans_train = Compose([
            ColorJitter(
                brightness = 0.5,
                contrast = 0.5,
                saturation = 0.5),
            HorizontalFlip(),
            # Scale sets tried previously:
            #   RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
            #   RandomScale((0.125, 1)),
            #   RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
            #   RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
            RandomScale(randomscale),
            RandomCrop(cropsize),
            ])

    @staticmethod
    def _index_dir(dirpth):
        """Index the files in *dirpth* by extension-less name.

        Returns ``(names, paths)``: ``names`` is the list of file names with
        their extension removed, in sorted order (``os.listdir`` order is
        arbitrary, so sorting keeps sample indices reproducible); ``paths``
        maps each name to the file's full path.
        """
        files = sorted(os.listdir(dirpth))
        # splitext strips only the final extension of any length; the old
        # ``el.replace(el[-4:], '')`` assumed 4-char extensions and removed
        # every occurrence of that substring anywhere in the name.
        names = [osp.splitext(f)[0] for f in files]
        paths = dict(zip(names, (osp.join(dirpth, f) for f in files)))
        return names, paths

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(img_tensor, label_array)``."""
        fn = self.imnames[idx]
        impth = self.imgs[fn]
        lbpth = self.labels[fn]
        img = Image.open(impth).convert('RGB')
        label = Image.open(lbpth)
        # Originally only 'train' and 'trainval' went through trans_train;
        # 'test' was added deliberately.
        # NOTE(review): judging by their names, these transforms are random
        # (jitter/flip/scale/crop), which would make 'test' samples vary
        # between epochs — confirm this is intended.
        if self.mode in ('train', 'trainval', 'test'):
            im_lb = dict(im = img, lb = label)
            im_lb = self.trans_train(im_lb)
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        # Add a leading axis: (H, W) -> (1, H, W) for a 2-D label image.
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        label = self.convert_labels(label)
        return img, label

    def __len__(self):
        return self.len

    def convert_labels(self, label):
        """Map raw label ids to train ids using ``self.lb_map``.

        Every mask is computed against the unmodified input, so the mapping
        is applied simultaneously: a pixel rewritten for one key can no
        longer be matched again by a later key (the previous in-place loop
        turned e.g. ``{0: 1, 1: 0}`` into "everything becomes 0"). Values
        that are not keys of ``self.lb_map`` are left unchanged.

        Returns a new array; *label* itself is not modified.
        """
        converted = label.copy()
        for k, v in self.lb_map.items():
            converted[label == k] = v
        return converted


if __name__ == "__main__":
    from tqdm import tqdm
    # ``n_classes`` is not a Heliushuju parameter: passing it would forward
    # it through **kwargs to Dataset.__init__, which rejects it (TypeError).
    ds = Heliushuju('./data/', mode='val')  # original
    # ds = Heliushuju('./data/', mode='test')  # modified
    uni = []
    for im, lb in tqdm(ds):
        lb_uni = np.unique(lb).tolist()
        uni.extend(lb_uni)
    print(uni)
    print(set(uni))