高速公路违停检测
Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

124 lines
3.9KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from torch.utils.data import Dataset
  5. import torchvision.transforms as transforms
  6. import os.path as osp
  7. import os
  8. from PIL import Image
  9. import numpy as np
  10. import json
  11. from transform import *
  12. class CityScapes(Dataset):
  13. def __init__(self, rootpth, cropsize=(640, 480), mode='train',
  14. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  15. super(CityScapes, self).__init__(*args, **kwargs)
  16. assert mode in ('train', 'val', 'test', 'trainval')
  17. self.mode = mode
  18. print('self.mode', self.mode)
  19. self.ignore_lb = 255
  20. with open('./cityscapes_info.json', 'r') as fr:
  21. labels_info = json.load(fr)
  22. self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  23. ## parse img directory
  24. self.imgs = {}
  25. imgnames = []
  26. impth = osp.join(rootpth, 'leftImg8bit', mode)
  27. folders = os.listdir(impth)
  28. for fd in folders:
  29. fdpth = osp.join(impth, fd)
  30. im_names = os.listdir(fdpth)
  31. names = [el.replace('_leftImg8bit.png', '') for el in im_names]
  32. impths = [osp.join(fdpth, el) for el in im_names]
  33. imgnames.extend(names)
  34. self.imgs.update(dict(zip(names, impths)))
  35. ## parse gt directory
  36. self.labels = {}
  37. gtnames = []
  38. gtpth = osp.join(rootpth, 'gtFine', mode)
  39. folders = os.listdir(gtpth)
  40. for fd in folders:
  41. fdpth = osp.join(gtpth, fd)
  42. lbnames = os.listdir(fdpth)
  43. lbnames = [el for el in lbnames if 'labelIds' in el]
  44. names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
  45. lbpths = [osp.join(fdpth, el) for el in lbnames]
  46. gtnames.extend(names)
  47. self.labels.update(dict(zip(names, lbpths)))
  48. self.imnames = imgnames
  49. self.len = len(self.imnames)
  50. print('self.len', self.mode, self.len)
  51. assert set(imgnames) == set(gtnames)
  52. assert set(self.imnames) == set(self.imgs.keys())
  53. assert set(self.imnames) == set(self.labels.keys())
  54. ## pre-processing
  55. self.to_tensor = transforms.Compose([
  56. transforms.ToTensor(),
  57. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  58. ])
  59. self.trans_train = Compose([
  60. ColorJitter(
  61. brightness = 0.5,
  62. contrast = 0.5,
  63. saturation = 0.5),
  64. HorizontalFlip(),
  65. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  66. RandomScale(randomscale),
  67. # RandomScale((0.125, 1)),
  68. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  69. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  70. RandomCrop(cropsize)
  71. ])
  72. def __getitem__(self, idx):
  73. fn = self.imnames[idx]
  74. impth = self.imgs[fn]
  75. lbpth = self.labels[fn]
  76. img = Image.open(impth).convert('RGB')
  77. label = Image.open(lbpth)
  78. if self.mode == 'train' or self.mode == 'trainval':
  79. im_lb = dict(im = img, lb = label)
  80. im_lb = self.trans_train(im_lb)
  81. img, label = im_lb['im'], im_lb['lb']
  82. img = self.to_tensor(img)
  83. label = np.array(label).astype(np.int64)[np.newaxis, :]
  84. label = self.convert_labels(label)
  85. return img, label
  86. def __len__(self):
  87. return self.len
  88. def convert_labels(self, label):
  89. for k, v in self.lb_map.items():
  90. label[label == k] = v
  91. return label
  92. if __name__ == "__main__":
  93. from tqdm import tqdm
  94. ds = CityScapes('./data/', n_classes=19, mode='val')
  95. uni = []
  96. for im, lb in tqdm(ds):
  97. lb_uni = np.unique(lb).tolist()
  98. uni.extend(lb_uni)
  99. print(uni)
  100. print(set(uni))