#!/usr/bin/python
# -*- encoding: utf-8 -*-
"""Datasets for colour-coded semantic segmentation data ("Heliushuju").

Expected directory layout, as read by the code below::

    <rootpth>/<mode>/images/<name>.<ext>
    <rootpth>/<mode>/labels/<name>.<ext>

Label images are colour-coded; the JSON file given by ``labelJson`` is a list
of objects with at least an ``id`` and a ``color`` entry, and is used to turn
each RGB colour into an integer class id.
"""

import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json
import cv2
import time

# Project-local augmentations (Compose, ColorJitter, HorizontalFlip,
# RandomScale, RandomCrop are used below).  Kept last so that its names take
# precedence exactly as before.
from transform import *


_MODES = ('train', 'val', 'test', 'trainval')
# Modes in which the random augmentation pipeline is applied to a sample.
_AUGMENTED_MODES = ('train', 'trainval', 'val')


def _index_dir(directory):
    """Return ``(stems, paths)`` for the entries of *directory*.

    ``stems`` lists every entry name without its extension, in
    ``os.listdir`` order; ``paths`` maps each stem to the entry's full path.
    ``os.path.splitext`` is used so extensions of any length (``.jpeg``,
    ``.png`` ...) are handled and nothing but the trailing extension is
    removed from the name.
    """
    entries = os.listdir(directory)
    stems = [osp.splitext(entry)[0] for entry in entries]
    paths = {stem: osp.join(directory, entry)
             for stem, entry in zip(stems, entries)}
    return stems, paths


class _HeliushujuBase(Dataset):
    """Shared loading and pre-processing logic for the public datasets."""

    def _setup(self, rootpth, cropsize, mode, labelJson, randomscale,
               with_labels):
        """Index the image (and optionally label) files and build transforms.

        Args:
            rootpth: dataset root directory.
            cropsize: ``(width, height)``; used both as the random-crop size
                and as the ``cv2.resize`` target size.
            mode: sub-directory of ``rootpth`` to read.
            labelJson: path of the JSON palette file.
            randomscale: scale factors handed to ``RandomScale``.
            with_labels: when true, ``<mode>/labels`` is indexed as well and
                must contain the same file stems as ``<mode>/images``.
        """
        self.mode = mode
        self.modeSize = cropsize

        with open(labelJson, 'r') as fr:
            labels_info = json.load(fr)
        # class id -> colour, compared against RGB pixels in convert_labels
        self.lb_map = {el['id']: el['color'] for el in labels_info}

        imgnames, self.imgs = _index_dir(osp.join(rootpth, mode, 'images'))
        gtnames = None
        if with_labels:
            gtnames, self.labels = _index_dir(
                osp.join(rootpth, mode, 'labels'))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)

        if with_labels:
            # Every image needs a label with the same stem, and vice versa.
            assert set(imgnames) == set(gtnames)
            assert set(self.imnames) == set(self.imgs.keys())
            assert set(self.imnames) == set(self.labels.keys())

        # pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize),
        ])
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def __len__(self):
        return self.len

    def _read_label(self, fn):
        """Read the label image for sample *fn* as an RGB array.

        Raises:
            IOError: if OpenCV cannot read the label file.
        """
        lbpth = self.labels[fn]
        label = cv2.imread(lbpth)
        if label is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('failed to read label image: {}'.format(lbpth))
        # OpenCV loads BGR; the palette colours are compared as RGB.
        return cv2.cvtColor(label, cv2.COLOR_BGR2RGB)

    def _augment(self, img, label):
        """Run the random augmentation pipeline on an image/label pair."""
        im_lb = self.trans_train(dict(im=img, lb=Image.fromarray(label)))
        return im_lb['im'], im_lb['lb']

    def _encode_label(self, label):
        """Resize a colour label to ``modeSize`` and convert it to class ids.

        Returns:
            ``int64`` array of shape ``(1, h, w)``.
        """
        # Nearest-neighbour keeps palette colours intact; interpolating would
        # blend colours at class boundaries and those pixels would no longer
        # match any palette entry.
        label = cv2.resize(np.array(label), self.modeSize,
                           interpolation=cv2.INTER_NEAREST)
        label = label.astype(np.int64)[np.newaxis, :]  # add a leading axis
        return self.convert_labels(label).astype(np.int64)

    def convert_labels(self, label):
        """Map a colour-coded label array to class indices.

        Args:
            label: array of shape ``(b, h, w, c)`` whose first three channels
                are compared with the colours in ``self.lb_map``.

        Returns:
            Float array of shape ``(b, h, w)`` holding the class id of each
            pixel.  Pixels whose colour is not in the palette stay ``0``.
        """
        b, h, w, c = label.shape
        label_index = np.zeros((b, h, w))
        rgb = label[..., :3]
        for class_id, color in self.lb_map.items():
            matches = np.all(rgb == np.asarray(color[:3]), axis=-1)
            label_index[matches] = class_id
        return label_index

    def preprocess_image(self, image):
        """Resize, scale to [0, 1], normalise and convert to a CHW tensor.

        Args:
            image: ``(h, w, 3)`` array as produced by ``np.array`` on the
                RGB PIL image.

        Returns:
            ``torch.float32`` tensor of shape ``(3, H, W)`` where
            ``(W, H) == self.modeSize``.
        """
        image = cv2.resize(image, self.modeSize).astype(np.float32)
        image /= 255.0
        image -= np.asarray(self.mean, dtype=np.float32)
        image /= np.asarray(self.std, dtype=np.float32)
        image = np.transpose(image, (2, 0, 1))
        return torch.from_numpy(image).float()


class Heliushuju(_HeliushujuBase):
    """Training / validation / inference dataset.

    In ``'test'`` mode no labels are required and ``__getitem__`` returns
    ``(image_tensor, file_stem)``; in every other mode it returns
    ``(image_tensor, label_index)``.

    Args:
        rootpth: dataset root directory.
        cropsize: ``(width, height)`` of the produced samples.
        mode: one of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.
        labelJson: path of the JSON palette file.
        randomscale: scale factors for the random-scale augmentation.
        *args, **kwargs: accepted for backward compatibility (e.g. a legacy
            ``n_classes=...`` argument) and ignored.  They are not forwarded
            to ``Dataset.__init__``, which accepts no arguments.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='train',
                 labelJson='./heliushuju_info.json',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875,
                              1.0, 1.25, 1.5),
                 *args, **kwargs):
        super(Heliushuju, self).__init__()
        assert mode in _MODES
        # NOTE: stored for use by callers; nothing in this module reads it.
        self.ignore_lb = 255
        print('labelJson:', labelJson)
        self._setup(rootpth, cropsize, mode, labelJson, randomscale,
                    with_labels=(mode != 'test'))

    def __getitem__(self, idx):
        fn = self.imnames[idx]
        img = Image.open(self.imgs[fn]).convert('RGB')

        if self.mode == 'test':
            return self.preprocess_image(np.array(img)), fn

        label = self._read_label(fn)
        if self.mode in _AUGMENTED_MODES:
            img, label = self._augment(img, label)
        img = self.preprocess_image(np.array(img))
        return img, self._encode_label(label)


class Heliushuju_test(_HeliushujuBase):
    """Evaluation dataset: always returns ``(image_tensor, label_index)``.

    Unlike ``Heliushuju``, labels are required in every mode, including the
    default ``'test'`` mode, in which no random augmentation is applied.

    Args:
        rootpth: dataset root directory.
        cropsize: ``(width, height)`` of the produced samples.
        mode: one of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.
        labelJson: path of the JSON palette file.
        randomscale: scale factors for the random-scale augmentation.
        *args, **kwargs: accepted for backward compatibility and ignored.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='test',
                 labelJson='./heliushuju_info.json',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875,
                              1.0, 1.25, 1.5),
                 *args, **kwargs):
        super(Heliushuju_test, self).__init__()
        assert mode in _MODES
        self._setup(rootpth, cropsize, mode, labelJson, randomscale,
                    with_labels=True)

    def __getitem__(self, idx):
        fn = self.imnames[idx]
        img = Image.open(self.imgs[fn]).convert('RGB')
        label = self._read_label(fn)
        if self.mode in _AUGMENTED_MODES:
            img, label = self._augment(img, label)
        img = self.preprocess_image(np.array(img))
        return img, self._encode_label(label)


if __name__ == "__main__":
    from tqdm import tqdm

    # Collect the class ids that occur in the validation labels.
    # n_classes is a legacy argument; the dataset accepts and ignores it.
    ds = Heliushuju('./data/', n_classes=3, mode='val')
    uni = []
    for im, lb in tqdm(ds):
        lb_uni = np.unique(lb).tolist()
        uni.extend(lb_uni)
    print(uni)
    print(set(uni))