#!/usr/bin/python
# -*- encoding: utf-8 -*-
"""Dataset that pairs RGB images with colour-coded label images."""

import json
import os
import os.path as osp
import time

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import Dataset

# Kept last (as in the original file) so the names it exports shadow the
# same names they did before.
from transform import *


class Heliushuju(Dataset):
    """Image / colour-coded-label dataset read from ``<rootpth>/<mode>/``.

    Images are taken from ``<rootpth>/<mode>/images`` and labels from
    ``<rootpth>/<mode>/labels_2``; files are paired by their name without the
    extension.  The colour -> class-id table is loaded from
    ``./heliushuju_info.json`` (resolved against the current working
    directory), whose entries must provide the keys ``'id'`` and ``'color'``.

    Each item is a pair ``(img, label)``:

    * ``img``   -- ``torch`` float tensor in channel-first layout, resized to
      ``OUTPUT_SIZE`` and normalised with ``self.mean`` / ``self.std``.
    * ``label`` -- ``numpy`` ``int64`` array with a leading axis of length 1,
      holding the class id of every pixel (0 where no colour matched).
    """

    # (width, height) passed to cv2.resize for both the image and the label.
    OUTPUT_SIZE = (640, 360)

    def __init__(self, rootpth, cropsize=(640, 480), mode='train',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875,
                              1.0, 1.25, 1.5),
                 *args, n_classes=None, **kwargs):
        """Index the image/label folders and build the transforms.

        Args:
            rootpth: Dataset root containing ``<mode>/images`` and
                ``<mode>/labels_2``.
            cropsize: Size handed to ``RandomCrop``.
            mode: One of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.
            randomscale: Scale factors handed to ``RandomScale``.
            n_classes: Accepted so callers may pass it (the ``__main__``
                script below does); it is stored on ``self.n_classes`` and
                otherwise unused here.  It must not reach
                ``Dataset.__init__``, which rejects keyword arguments.
            *args, **kwargs: Forwarded to ``Dataset.__init__``.

        Raises:
            AssertionError: If ``mode`` is invalid or the image and label
                folders do not contain the same set of file stems.
        """
        super(Heliushuju, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        self.n_classes = n_classes
        print('self.mode', self.mode)
        self.ignore_lb = 255

        with open('./heliushuju_info.json', 'r') as fr:
            labels_info = json.load(fr)
        # class id -> colour triple used for that class in the label images.
        self.lb_map = {el['id']: el['color'] for el in labels_info}

        # Parse the image and ground-truth directories: stem -> full path.
        self.imgs = self._index_directory(osp.join(rootpth, mode, 'images'))
        self.labels = self._index_directory(
            osp.join(rootpth, mode, 'labels_2'))

        self.imnames = list(self.imgs)
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        assert set(self.imnames) == set(self.labels.keys()), (
            'image/label file names differ: %s'
            % sorted(set(self.imnames) ^ set(self.labels.keys())))

        # Pre-processing.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize),
        ])
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    @staticmethod
    def _index_directory(dirpth):
        """Map each file stem (name without extension) in *dirpth* to its path.

        ``os.path.splitext`` is used so extensions of any length are stripped
        correctly; the listing is sorted so sample order is reproducible.
        """
        return {osp.splitext(fname)[0]: osp.join(dirpth, fname)
                for fname in sorted(os.listdir(dirpth))}

    def __getitem__(self, idx):
        """Load, augment and pre-process the sample at position *idx*.

        Raises:
            IndexError: If *idx* is out of range.
            FileNotFoundError: If the label image cannot be read.
        """
        fn = self.imnames[idx]
        impth = self.imgs[fn]
        lbpth = self.labels[fn]

        # convert() returns a new image, so the file can be closed right away.
        with Image.open(impth) as im_file:
            img = im_file.convert('RGB')

        label = cv2.imread(lbpth)
        if label is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read label image: %s' % lbpth)
        # OpenCV loads BGR; the colours in lb_map are matched in RGB order.
        label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)

        # NOTE(review): 'val' also goes through the random training
        # transforms -- confirm this is intended.
        if self.mode in ('train', 'trainval', 'val'):
            im_lb = self.trans_train(dict(im=img, lb=Image.fromarray(label)))
            img, label = im_lb['im'], im_lb['lb']

        img = self.preprocess_image(np.array(img))

        # Nearest-neighbour keeps every pixel an exact class colour; the
        # default bilinear mode blends colours along class boundaries, and
        # blended pixels would match no entry of lb_map.
        label = cv2.resize(np.array(label), self.OUTPUT_SIZE,
                           interpolation=cv2.INTER_NEAREST)
        label = label.astype(np.int64)[np.newaxis, :]  # add a leading axis
        label = self.convert_labels(label)
        return img, label.astype(np.int64)

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def convert_labels(self, label):
        """Turn a colour-coded label array into an array of class ids.

        Args:
            label: 4-D array ``(b, h, w, c)`` whose first three channels are
                compared against the colours in ``self.lb_map``.

        Returns:
            Float array of shape ``(b, h, w)`` holding the id of the matching
            colour for every pixel; pixels matching no colour stay 0.
        """
        b, h, w, _ = label.shape
        label_index = np.zeros((b, h, w))
        rgb = label[..., :3]
        for class_id, color in self.lb_map.items():
            matches = np.all(rgb == np.asarray(color[:3]), axis=-1)
            label_index[matches] = class_id
        return label_index

    def preprocess_image(self, image):
        """Resize, scale to [0, 1], normalise and convert *image* to a tensor.

        Args:
            image: ``(H, W, C)`` array with at least three channels.

        Returns:
            Float ``torch`` tensor in ``(C, H, W)`` layout.
        """
        image = cv2.resize(image, self.OUTPUT_SIZE).astype(np.float32)
        image /= 255.0
        # Per-channel normalisation of the first three channels.
        image[..., :3] -= np.asarray(self.mean, dtype=np.float32)
        image[..., :3] /= np.asarray(self.std, dtype=np.float32)
        image = np.transpose(image, (2, 0, 1))
        return torch.from_numpy(image).float()


if __name__ == "__main__":
    from tqdm import tqdm

    # Smoke test: collect every label id that occurs in the 'val' split.
    ds = Heliushuju('./data/', n_classes=3, mode='val')
    uni = []
    for im, lb in tqdm(ds):
        uni.extend(np.unique(lb).tolist())
    print(uni)
    print(set(uni))