高速公路违停检测 (Highway illegal-parking detection)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

122 lines
3.9KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from torch.utils.data import Dataset
  5. import torchvision.transforms as transforms
  6. import os.path as osp
  7. import os
  8. from PIL import Image
  9. import numpy as np
  10. import json
  11. from transform import *
  12. class Heliushuju(Dataset):
  13. # def __init__(self, rootpth, cropsize=(640, 480), mode='train', # 原始
  14. def __init__(self, rootpth, cropsize=(640, 480), mode='test', # 改动
  15. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  16. super(Heliushuju, self).__init__(*args, **kwargs)
  17. assert mode in ('train', 'val', 'test', 'trainval')
  18. self.mode = mode
  19. print('self.mode', self.mode)
  20. self.ignore_lb = 255
  21. with open('./heliushuju_info.json', 'r') as fr:
  22. labels_info = json.load(fr)
  23. self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  24. ## parse img directory
  25. self.imgs = {}
  26. imgnames = []
  27. impth = osp.join(rootpth, mode, 'images')
  28. folders = os.listdir(impth)
  29. names = [el.replace(el[-4:], '') for el in folders]
  30. impths = [osp.join(impth, el) for el in folders]
  31. imgnames.extend(names)
  32. self.imgs.update(dict(zip(names, impths)))
  33. ## parse gt directory
  34. self.labels = {}
  35. gtnames = []
  36. gtpth = osp.join(rootpth, mode, 'labels_2')
  37. folders = os.listdir(gtpth)
  38. names = [el.replace(el[-4:], '') for el in folders]
  39. lbpths = [osp.join(gtpth, el) for el in folders]
  40. gtnames.extend(names)
  41. self.labels.update(dict(zip(names, lbpths)))
  42. self.imnames = imgnames
  43. self.len = len(self.imnames)
  44. print('self.len', self.mode, self.len)
  45. assert set(imgnames) == set(gtnames)
  46. assert set(self.imnames) == set(self.imgs.keys())
  47. assert set(self.imnames) == set(self.labels.keys())
  48. ## pre-processing
  49. self.to_tensor = transforms.Compose([
  50. transforms.ToTensor(),
  51. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  52. ])
  53. self.trans_train = Compose([
  54. ColorJitter(
  55. brightness = 0.5,
  56. contrast = 0.5,
  57. saturation = 0.5),
  58. HorizontalFlip(),
  59. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  60. RandomScale(randomscale),
  61. # RandomScale((0.125, 1)),
  62. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  63. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  64. RandomCrop(cropsize)###############################################################
  65. ])
  66. def __getitem__(self, idx):
  67. fn = self.imnames[idx]
  68. impth = self.imgs[fn]
  69. lbpth = self.labels[fn]
  70. img = Image.open(impth).convert('RGB')
  71. label = Image.open(lbpth)
  72. # if self.mode == 'train' or self.mode == 'trainval': # 原始
  73. if self.mode == 'train' or self.mode == 'trainval' or self.mode == 'test': # 改动
  74. im_lb = dict(im = img, lb = label)
  75. im_lb = self.trans_train(im_lb)
  76. img, label = im_lb['im'], im_lb['lb']
  77. img = self.to_tensor(img)
  78. label = np.array(label).astype(np.int64)[np.newaxis, :]
  79. label = self.convert_labels(label)
  80. return img, label
  81. def __len__(self):
  82. return self.len
  83. def convert_labels(self, label):
  84. for k, v in self.lb_map.items():
  85. label[label == k] = v
  86. return label
  87. if __name__ == "__main__":
  88. from tqdm import tqdm
  89. ds = Heliushuju('./data/', n_classes=2, mode='val') # 原始
  90. # ds = Heliushuju('./data/', n_classes=2, mode='test') # 改动
  91. uni = []
  92. for im, lb in tqdm(ds):
  93. lb_uni = np.unique(lb).tolist()
  94. uni.extend(lb_uni)
  95. print(uni)
  96. print(set(uni))