高速公路违停检测
Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

122 lines
3.9KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from torch.utils.data import Dataset
  5. import torchvision.transforms as transforms
  6. import os.path as osp
  7. import os
  8. from PIL import Image
  9. import numpy as np
  10. import json
  11. from transform import *
  12. class Heliushuju(Dataset):
  13. # def __init__(self, rootpth, cropsize=(640, 480), mode='train', # 原始
  14. def __init__(self, rootpth, cropsize=(640, 480), mode='test', # 改动
  15. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  16. super(Heliushuju, self).__init__(*args, **kwargs)
  17. assert mode in ('train', 'val', 'test', 'trainval')
  18. self.mode = mode
  19. print('self.mode', self.mode)
  20. self.ignore_lb = 255
  21. with open('./heliushuju_info.json', 'r') as fr:
  22. labels_info = json.load(fr)
  23. self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  24. ## parse img directory
  25. self.imgs = {}
  26. imgnames = []
  27. impth = osp.join(rootpth, mode, 'images')
  28. folders = os.listdir(impth)
  29. names = [el.replace(el[-4:], '') for el in folders]
  30. impths = [osp.join(impth, el) for el in folders]
  31. imgnames.extend(names)
  32. self.imgs.update(dict(zip(names, impths)))
  33. ## parse gt directory
  34. self.labels = {}
  35. gtnames = []
  36. gtpth = osp.join(rootpth, mode, 'labels_2')
  37. folders = os.listdir(gtpth)
  38. names = [el.replace(el[-4:], '') for el in folders]
  39. lbpths = [osp.join(gtpth, el) for el in folders]
  40. gtnames.extend(names)
  41. self.labels.update(dict(zip(names, lbpths)))
  42. self.imnames = imgnames
  43. self.len = len(self.imnames)
  44. print('self.len', self.mode, self.len)
  45. assert set(imgnames) == set(gtnames)
  46. assert set(self.imnames) == set(self.imgs.keys())
  47. assert set(self.imnames) == set(self.labels.keys())
  48. ## pre-processing
  49. self.to_tensor = transforms.Compose([
  50. transforms.ToTensor(),
  51. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  52. ])
  53. self.trans_train = Compose([
  54. ColorJitter(
  55. brightness = 0.5,
  56. contrast = 0.5,
  57. saturation = 0.5),
  58. HorizontalFlip(),
  59. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  60. RandomScale(randomscale),
  61. # RandomScale((0.125, 1)),
  62. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  63. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  64. RandomCrop(cropsize)###############################################################
  65. ])
  66. def __getitem__(self, idx):
  67. fn = self.imnames[idx]
  68. impth = self.imgs[fn]
  69. lbpth = self.labels[fn]
  70. img = Image.open(impth).convert('RGB')
  71. label = Image.open(lbpth)
  72. # if self.mode == 'train' or self.mode == 'trainval': # 原始
  73. if self.mode == 'train' or self.mode == 'trainval' or self.mode == 'test': # 改动
  74. im_lb = dict(im = img, lb = label)
  75. im_lb = self.trans_train(im_lb)
  76. img, label = im_lb['im'], im_lb['lb']
  77. img = self.to_tensor(img)
  78. label = np.array(label).astype(np.int64)[np.newaxis, :]
  79. label = self.convert_labels(label)
  80. return img, label
  81. def __len__(self):
  82. return self.len
  83. def convert_labels(self, label):
  84. for k, v in self.lb_map.items():
  85. label[label == k] = v
  86. return label
  87. if __name__ == "__main__":
  88. from tqdm import tqdm
  89. ds = Heliushuju('./data/', n_classes=2, mode='val') # 原始
  90. # ds = Heliushuju('./data/', n_classes=2, mode='test') # 改动
  91. uni = []
  92. for im, lb in tqdm(ds):
  93. lb_uni = np.unique(lb).tolist()
  94. uni.extend(lb_uni)
  95. print(uni)
  96. print(set(uni))