高速公路违停检测
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

184 lignes
6.5KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from torch.utils.data import Dataset
  6. import torchvision.transforms as transforms
  7. import os.path as osp
  8. import os
  9. from PIL import Image
  10. import numpy as np
  11. import json
  12. import cv2
  13. import time
  14. from transform import *
  15. class Heliushuju(Dataset):
  16. def __init__(self, rootpth, cropsize=(640, 480), mode='train',
  17. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  18. super(Heliushuju, self).__init__(*args, **kwargs)
  19. assert mode in ('train', 'val', 'test', 'trainval')
  20. self.mode = mode
  21. print('self.mode', self.mode)
  22. self.ignore_lb = 255
  23. with open('./heliushuju_info.json', 'r') as fr:
  24. labels_info = json.load(fr)
  25. # print('###line30:',labels_info)
  26. # self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  27. self.lb_map = {el['id']: el['color'] for el in labels_info}
  28. # print('###line32:', self.lb_map)
  29. # parse img directory
  30. self.imgs = {}
  31. imgnames = []
  32. impth = osp.join(rootpth, mode, 'images') # 图片所在目录的路径
  33. folders = os.listdir(impth) # 图片名列表
  34. names = [el.replace(el[-4:], '') for el in folders] # el是整个图片名,names是图片名前缀
  35. impths = [osp.join(impth, el) for el in folders] # 图片路径
  36. imgnames.extend(names) # 存放图片名前缀的列表
  37. self.imgs.update(dict(zip(names, impths)))
  38. # parse gt directory
  39. self.labels = {}
  40. gtnames = []
  41. gtpth = osp.join(rootpth, mode, 'labels_2')
  42. folders = os.listdir(gtpth)
  43. names = [el.replace(el[-4:], '') for el in folders]
  44. lbpths = [osp.join(gtpth, el) for el in folders]
  45. gtnames.extend(names)
  46. self.labels.update(dict(zip(names, lbpths)))
  47. self.imnames = imgnames
  48. self.len = len(self.imnames)
  49. print('self.len', self.mode, self.len)
  50. assert set(imgnames) == set(gtnames)
  51. assert set(self.imnames) == set(self.imgs.keys())
  52. assert set(self.imnames) == set(self.labels.keys())
  53. # pre-processing
  54. self.to_tensor = transforms.Compose([
  55. transforms.ToTensor(),
  56. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  57. ])
  58. self.trans_train = Compose([
  59. ColorJitter(
  60. brightness = 0.5,
  61. contrast = 0.5,
  62. saturation = 0.5),
  63. HorizontalFlip(),
  64. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  65. RandomScale(randomscale),
  66. # RandomScale((0.125, 1)),
  67. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  68. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  69. RandomCrop(cropsize)
  70. ])
  71. self.mean = (0.485, 0.456, 0.406)
  72. self.std = (0.229, 0.224, 0.225)
  73. def __getitem__(self, idx):
  74. fn = self.imnames[idx]
  75. impth = self.imgs[fn]
  76. lbpth = self.labels[fn]
  77. img = Image.open(impth).convert('RGB')
  78. # img = cv2.imread(impth);img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  79. # label = Image.open(lbpth) # 改动
  80. label = cv2.imread(lbpth) # 原始
  81. label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB) # 添加(训练交通事故数据,添加了这行代码使标签颜色正确)
  82. # plt.figure(1);plt.imshow(label);plt.show() # 添加
  83. if self.mode == 'train' or self.mode == 'trainval' or self.mode == 'val':
  84. label = Image.fromarray(label)
  85. im_lb = dict(im = img, lb = label)
  86. im_lb = self.trans_train(im_lb)
  87. img, label = im_lb['im'], im_lb['lb']
  88. # img = self.to_tensor(img)
  89. img = np.array(img);
  90. img_bak = img.copy()
  91. img = self.preprocess_image(img)
  92. label = cv2.resize(np.array(label), (640, 360))
  93. label = label.astype(np.int64)[np.newaxis, :] # 给行上增加维度
  94. # label = cv2.resize(label,(640,360))
  95. # print('###line108:', self.lb_map)
  96. label = self.convert_labels(label)
  97. # plt.figure(0);plt.imshow(label[0]);
  98. # plt.figure(1);plt.imshow(img_bak);plt.show()
  99. return img, label.astype(np.int64)
  100. def __len__(self):
  101. return self.len
  102. def convert_labels(self, label):
  103. b, h, w, c = label.shape
  104. # print('####line118:',label.shape)
  105. # b, h, w = label.shape # [1,360,640]
  106. label_index = np.zeros((b, h, w))
  107. for k, v in self.lb_map.items():
  108. t_0 = (label[..., 0] == v[0])
  109. t_1 = (label[..., 1] == v[1])
  110. t_2 = (label[..., 2] == v[2])
  111. t_loc = (t_0 & t_1 & t_2)
  112. label_index[t_loc] = k
  113. # label[label == k] = v
  114. # print(label)
  115. # print("6666666666666666")
  116. return label_index
  117. def preprocess_image(self, image):
  118. time0 = time.time()
  119. image = cv2.resize(image, (640, 360))
  120. time1 = time.time()
  121. image = image.astype(np.float32)
  122. image /= 255.0
  123. time2 = time.time()
  124. # image = image * 3.2 - 1.6
  125. image[:, :, 0] -= self.mean[0]
  126. image[:, :, 1] -= self.mean[1]
  127. image[:, :, 2] -= self.mean[2]
  128. time3 = time.time()
  129. image[:, :, 0] /= self.std[0]
  130. image[:, :, 1] /= self.std[1]
  131. image[:, :, 2] /= self.std[2]
  132. time4 = time.time()
  133. image = np.transpose(image, (2, 0, 1))
  134. time5 = time.time()
  135. image = torch.from_numpy(image).float()
  136. # image = image.unsqueeze(0)
  137. # outStr = '###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f ' % (
  138. # self.get_ms(time1, time0), self.get_ms(time2, time1), self.get_ms(time3, time2), self.get_ms(time4, time3),
  139. # self.get_ms(time5, time4))
  140. # print(outStr)
  141. # print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  142. return image
  143. if __name__ == "__main__":
  144. from tqdm import tqdm
  145. # ds = Heliushuju('./data/', n_classes=2, mode='val') # 原始
  146. ds = Heliushuju('./data/', n_classes=3, mode='val') # 改动
  147. uni = []
  148. for im, lb in tqdm(ds):
  149. lb_uni = np.unique(lb).tolist()
  150. uni.extend(lb_uni)
  151. print(uni)
  152. print(set(uni))