高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

184 lines
6.5KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from torch.utils.data import Dataset
  6. import torchvision.transforms as transforms
  7. import os.path as osp
  8. import os
  9. from PIL import Image
  10. import numpy as np
  11. import json
  12. import cv2
  13. import time
  14. from transform import *
  15. class Heliushuju(Dataset):
  16. def __init__(self, rootpth, cropsize=(640, 480), mode='train',
  17. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  18. super(Heliushuju, self).__init__(*args, **kwargs)
  19. assert mode in ('train', 'val', 'test', 'trainval')
  20. self.mode = mode
  21. print('self.mode', self.mode)
  22. self.ignore_lb = 255
  23. with open('./heliushuju_info.json', 'r') as fr:
  24. labels_info = json.load(fr)
  25. # print('###line30:',labels_info)
  26. # self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  27. self.lb_map = {el['id']: el['color'] for el in labels_info}
  28. # print('###line32:', self.lb_map)
  29. # parse img directory
  30. self.imgs = {}
  31. imgnames = []
  32. impth = osp.join(rootpth, mode, 'images') # 图片所在目录的路径
  33. folders = os.listdir(impth) # 图片名列表
  34. names = [el.replace(el[-4:], '') for el in folders] # el是整个图片名,names是图片名前缀
  35. impths = [osp.join(impth, el) for el in folders] # 图片路径
  36. imgnames.extend(names) # 存放图片名前缀的列表
  37. self.imgs.update(dict(zip(names, impths)))
  38. # parse gt directory
  39. self.labels = {}
  40. gtnames = []
  41. gtpth = osp.join(rootpth, mode, 'labels_2')
  42. folders = os.listdir(gtpth)
  43. names = [el.replace(el[-4:], '') for el in folders]
  44. lbpths = [osp.join(gtpth, el) for el in folders]
  45. gtnames.extend(names)
  46. self.labels.update(dict(zip(names, lbpths)))
  47. self.imnames = imgnames
  48. self.len = len(self.imnames)
  49. print('self.len', self.mode, self.len)
  50. assert set(imgnames) == set(gtnames)
  51. assert set(self.imnames) == set(self.imgs.keys())
  52. assert set(self.imnames) == set(self.labels.keys())
  53. # pre-processing
  54. self.to_tensor = transforms.Compose([
  55. transforms.ToTensor(),
  56. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  57. ])
  58. self.trans_train = Compose([
  59. ColorJitter(
  60. brightness = 0.5,
  61. contrast = 0.5,
  62. saturation = 0.5),
  63. HorizontalFlip(),
  64. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  65. RandomScale(randomscale),
  66. # RandomScale((0.125, 1)),
  67. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  68. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  69. RandomCrop(cropsize)
  70. ])
  71. self.mean = (0.485, 0.456, 0.406)
  72. self.std = (0.229, 0.224, 0.225)
  73. def __getitem__(self, idx):
  74. fn = self.imnames[idx]
  75. impth = self.imgs[fn]
  76. lbpth = self.labels[fn]
  77. img = Image.open(impth).convert('RGB')
  78. # img = cv2.imread(impth);img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  79. # label = Image.open(lbpth) # 改动
  80. label = cv2.imread(lbpth) # 原始
  81. label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB) # 添加(训练交通事故数据,添加了这行代码使标签颜色正确)
  82. # plt.figure(1);plt.imshow(label);plt.show() # 添加
  83. if self.mode == 'train' or self.mode == 'trainval' or self.mode == 'val':
  84. label = Image.fromarray(label)
  85. im_lb = dict(im = img, lb = label)
  86. im_lb = self.trans_train(im_lb)
  87. img, label = im_lb['im'], im_lb['lb']
  88. # img = self.to_tensor(img)
  89. img = np.array(img);
  90. img_bak = img.copy()
  91. img = self.preprocess_image(img)
  92. label = cv2.resize(np.array(label), (640, 360))
  93. label = label.astype(np.int64)[np.newaxis, :] # 给行上增加维度
  94. # label = cv2.resize(label,(640,360))
  95. # print('###line108:', self.lb_map)
  96. label = self.convert_labels(label)
  97. # plt.figure(0);plt.imshow(label[0]);
  98. # plt.figure(1);plt.imshow(img_bak);plt.show()
  99. return img, label.astype(np.int64)
  100. def __len__(self):
  101. return self.len
  102. def convert_labels(self, label):
  103. b, h, w, c = label.shape
  104. # print('####line118:',label.shape)
  105. # b, h, w = label.shape # [1,360,640]
  106. label_index = np.zeros((b, h, w))
  107. for k, v in self.lb_map.items():
  108. t_0 = (label[..., 0] == v[0])
  109. t_1 = (label[..., 1] == v[1])
  110. t_2 = (label[..., 2] == v[2])
  111. t_loc = (t_0 & t_1 & t_2)
  112. label_index[t_loc] = k
  113. # label[label == k] = v
  114. # print(label)
  115. # print("6666666666666666")
  116. return label_index
  117. def preprocess_image(self, image):
  118. time0 = time.time()
  119. image = cv2.resize(image, (640, 360))
  120. time1 = time.time()
  121. image = image.astype(np.float32)
  122. image /= 255.0
  123. time2 = time.time()
  124. # image = image * 3.2 - 1.6
  125. image[:, :, 0] -= self.mean[0]
  126. image[:, :, 1] -= self.mean[1]
  127. image[:, :, 2] -= self.mean[2]
  128. time3 = time.time()
  129. image[:, :, 0] /= self.std[0]
  130. image[:, :, 1] /= self.std[1]
  131. image[:, :, 2] /= self.std[2]
  132. time4 = time.time()
  133. image = np.transpose(image, (2, 0, 1))
  134. time5 = time.time()
  135. image = torch.from_numpy(image).float()
  136. # image = image.unsqueeze(0)
  137. # outStr = '###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f ' % (
  138. # self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2), self.get_ms(time4, time3),
  139. # self.get_ms(time5, time4))
  140. # print(outStr)
  141. # print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  142. return image
  143. if __name__ == "__main__":
  144. from tqdm import tqdm
  145. # ds = Heliushuju('./data/', n_classes=2, mode='val') # 原始
  146. ds = Heliushuju('./data/', n_classes=3, mode='val') # 改动
  147. uni = []
  148. for im, lb in tqdm(ds):
  149. lb_uni = np.unique(lb).tolist()
  150. uni.extend(lb_uni)
  151. print(uni)
  152. print(set(uni))