# STDC-th/heliushuju_process.py
#
# 296 lines
# 9.3 KiB
# Python

#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json
import cv2
import time
from transform import *
class Heliushuju(Dataset):
    """Segmentation dataset pairing RGB images with colour-coded label images.

    Files are read from ``<rootpth>/<mode>/images`` and, for every mode except
    ``'test'``, from ``<rootpth>/<mode>/labels``. An image and its label are
    matched by file name without extension.

    Args:
        rootpth: Dataset root directory.
        cropsize: Output size; used for ``RandomCrop`` and passed to
            ``cv2.resize`` as ``dsize`` (i.e. ``(width, height)``).
        mode: One of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.
        labelJson: Path to a JSON list whose entries carry an ``'id'`` and a
            ``'color'``; ``color`` is compared against the RGB label pixels.
        randomscale: Scale factors handed to ``RandomScale``.
        *args, **kwargs: Accepted for backward compatibility and ignored.

    Items:
        ``(image, label)`` for non-test modes, where ``image`` is a float
        tensor laid out channel-first and ``label`` is an ``np.int64`` array
        of shape ``(1, H, W)`` holding class ids; ``(image, name)`` in
        ``'test'`` mode, where ``name`` is the file name without extension.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='train', labelJson='./heliushuju_info.json',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
        # Extra arguments are deliberately NOT forwarded: Dataset takes no
        # constructor arguments, so forwarding e.g. ``n_classes=3`` ends up in
        # object.__init__ and raises TypeError.
        super(Heliushuju, self).__init__()
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        self.modeSize = cropsize
        self.ignore_lb = 255

        with open(labelJson, 'r') as fr:
            print('labelJson:', labelJson)
            labels_info = json.load(fr)
        # class id -> colour triple used in the label images
        self.lb_map = {el['id']: el['color'] for el in labels_info}

        # Index the images: name without extension -> full path.
        imgnames, impths = self._scan_dir(osp.join(rootpth, mode, 'images'))
        self.imgs = dict(zip(imgnames, impths))

        if self.mode != 'test':
            # Index the label images the same way.
            gtnames, lbpths = self._scan_dir(osp.join(rootpth, mode, 'labels'))
            self.labels = dict(zip(gtnames, lbpths))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        if self.mode != 'test':
            assert set(imgnames) == set(gtnames), \
                'image and label file names (without extension) do not match'

        # pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize)
        ])
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    @staticmethod
    def _scan_dir(directory):
        """Return ``(stems, paths)`` for the entries of *directory*, in ``os.listdir`` order.

        ``stems`` are the file names with their extension removed via
        ``os.path.splitext``, so extensions of any length are handled and
        only the trailing extension is stripped.
        """
        fnames = os.listdir(directory)
        stems = [osp.splitext(fname)[0] for fname in fnames]
        paths = [osp.join(directory, fname) for fname in fnames]
        return stems, paths

    def __getitem__(self, idx):
        """Load, augment (non-test modes) and normalise the sample at *idx*."""
        fn = self.imnames[idx]
        img = Image.open(self.imgs[fn]).convert('RGB')

        label = None
        if self.mode != 'test':
            lbpth = self.labels[fn]
            label = cv2.imread(lbpth)
            if label is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError('failed to read label image: {}'.format(lbpth))
            # OpenCV loads BGR; the colours in lb_map are matched in RGB order.
            label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)
            # NOTE(review): every non-test mode, including 'val', goes through
            # the random training augmentation - confirm this is intended.
            im_lb = self.trans_train(dict(im=img, lb=Image.fromarray(label)))
            img, label = im_lb['im'], im_lb['lb']

        img = self.preprocess_image(np.array(img))
        if self.mode == 'test':
            return img, fn

        # Nearest-neighbour keeps label colours exact; interpolating would
        # blend colours at class borders into values that match no class.
        label = cv2.resize(np.array(label), self.modeSize, interpolation=cv2.INTER_NEAREST)
        label = label.astype(np.int64)[np.newaxis, :]  # add a leading axis
        label = self.convert_labels(label)
        return img, label.astype(np.int64)

    def __len__(self):
        """Return the number of indexed images."""
        return self.len

    def convert_labels(self, label):
        """Translate a ``(b, h, w, c)`` colour label array into class ids.

        Each pixel whose first three channels equal a colour in ``lb_map``
        receives that colour's id; pixels matching no colour stay ``0``.
        Returns a float array of shape ``(b, h, w)``.
        """
        b, h, w, c = label.shape
        label_index = np.zeros((b, h, w))
        rgb = label[..., :3]
        for class_id, color in self.lb_map.items():
            matches = np.all(rgb == np.asarray(color[:3]), axis=-1)
            label_index[matches] = class_id
        return label_index

    def preprocess_image(self, image):
        """Resize to ``modeSize``, scale to [0, 1], normalise, return a CHW float tensor.

        ``image`` is expected to be an ``H x W x 3`` array (the caller passes
        an RGB image converted with ``np.array``).
        """
        image = cv2.resize(image, self.modeSize).astype(np.float32)
        image /= 255.0
        # Per-channel normalisation, broadcast over the last (channel) axis.
        image -= np.asarray(self.mean, dtype=np.float32)
        image /= np.asarray(self.std, dtype=np.float32)
        image = np.transpose(image, (2, 0, 1))
        return torch.from_numpy(image).float()
class Heliushuju_test(Dataset):
    """Labelled evaluation dataset; every item is an ``(image, label)`` pair.

    Files are read from ``<rootpth>/<mode>/images`` and
    ``<rootpth>/<mode>/labels`` and matched by file name without extension.
    The label directory is required in every mode because ``__getitem__``
    always loads a label.

    Args:
        rootpth: Dataset root directory.
        cropsize: Output size; used for ``RandomCrop`` and passed to
            ``cv2.resize`` as ``dsize`` (i.e. ``(width, height)``).
        mode: One of ``'train'``, ``'val'``, ``'test'``, ``'trainval'``.
        labelJson: Path to a JSON list whose entries carry an ``'id'`` and a
            ``'color'``; ``color`` is compared against the RGB label pixels.
        randomscale: Scale factors handed to ``RandomScale``.
        *args, **kwargs: Accepted for backward compatibility and ignored.

    Items:
        ``(image, label)`` where ``image`` is a float tensor laid out
        channel-first and ``label`` is an ``np.int64`` array of shape
        ``(1, H, W)`` holding class ids.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='test', labelJson='./heliushuju_info.json',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
        # Extra arguments are deliberately NOT forwarded: Dataset takes no
        # constructor arguments, so forwarding them raises TypeError.
        super(Heliushuju_test, self).__init__()
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        self.modeSize = cropsize

        with open(labelJson, 'r') as fr:
            labels_info = json.load(fr)
        # class id -> colour triple used in the label images
        self.lb_map = {el['id']: el['color'] for el in labels_info}

        # Index the images: name without extension -> full path.
        imgnames, impths = self._scan_dir(osp.join(rootpth, mode, 'images'))
        self.imgs = dict(zip(imgnames, impths))

        # Index the label images; __getitem__ and the consistency check below
        # both depend on this mapping being present.
        gtnames, lbpths = self._scan_dir(osp.join(rootpth, mode, 'labels'))
        self.labels = dict(zip(gtnames, lbpths))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        assert set(imgnames) == set(gtnames), \
            'image and label file names (without extension) do not match'

        # pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize)
        ])
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    @staticmethod
    def _scan_dir(directory):
        """Return ``(stems, paths)`` for the entries of *directory*, in ``os.listdir`` order.

        ``stems`` are the file names with their extension removed via
        ``os.path.splitext``, so extensions of any length are handled and
        only the trailing extension is stripped.
        """
        fnames = os.listdir(directory)
        stems = [osp.splitext(fname)[0] for fname in fnames]
        paths = [osp.join(directory, fname) for fname in fnames]
        return stems, paths

    def __getitem__(self, idx):
        """Load and normalise the image/label pair at *idx*.

        The random training augmentation is applied only when ``mode`` is
        ``'train'``, ``'trainval'`` or ``'val'``; in ``'test'`` mode the pair
        is only resized.
        """
        fn = self.imnames[idx]
        lbpth = self.labels[fn]
        img = Image.open(self.imgs[fn]).convert('RGB')
        label = cv2.imread(lbpth)
        if label is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('failed to read label image: {}'.format(lbpth))
        # OpenCV loads BGR; the colours in lb_map are matched in RGB order.
        label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)

        if self.mode in ('train', 'trainval', 'val'):
            im_lb = self.trans_train(dict(im=img, lb=Image.fromarray(label)))
            img, label = im_lb['im'], im_lb['lb']

        img = self.preprocess_image(np.array(img))
        # Nearest-neighbour keeps label colours exact; interpolating would
        # blend colours at class borders into values that match no class.
        label = cv2.resize(np.array(label), self.modeSize, interpolation=cv2.INTER_NEAREST)
        label = label.astype(np.int64)[np.newaxis, :]  # add a leading axis
        label = self.convert_labels(label)
        return img, label.astype(np.int64)

    def __len__(self):
        """Return the number of indexed images."""
        return self.len

    def convert_labels(self, label):
        """Translate a ``(b, h, w, c)`` colour label array into class ids.

        Each pixel whose first three channels equal a colour in ``lb_map``
        receives that colour's id; pixels matching no colour stay ``0``.
        Returns a float array of shape ``(b, h, w)``.
        """
        b, h, w, c = label.shape
        label_index = np.zeros((b, h, w))
        rgb = label[..., :3]
        for class_id, color in self.lb_map.items():
            matches = np.all(rgb == np.asarray(color[:3]), axis=-1)
            label_index[matches] = class_id
        return label_index

    def preprocess_image(self, image):
        """Resize to ``modeSize``, scale to [0, 1], normalise, return a CHW float tensor.

        ``image`` is expected to be an ``H x W x 3`` array (the caller passes
        an RGB image converted with ``np.array``).
        """
        image = cv2.resize(image, self.modeSize).astype(np.float32)
        image /= 255.0
        # Per-channel normalisation, broadcast over the last (channel) axis.
        image -= np.asarray(self.mean, dtype=np.float32)
        image /= np.asarray(self.std, dtype=np.float32)
        image = np.transpose(image, (2, 0, 1))
        return torch.from_numpy(image).float()
if __name__ == "__main__":
    # Smoke test: walk the validation split and report which class ids occur.
    from tqdm import tqdm

    # No ``n_classes`` argument: the constructor has no such parameter, and an
    # unknown keyword forwarded to the Dataset base class raises TypeError.
    ds = Heliushuju('./data/', mode='val')
    uni = []
    for im, lb in tqdm(ds):
        uni.extend(np.unique(lb).tolist())
    print(uni)
    print(set(uni))