高速公路违停检测 (Highway Illegal Parking Detection)


#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json

from transform import *
class CityScapes(Dataset):
    def __init__(self, rootpth, cropsize=(640, 480), mode='train',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5),
                 *args, **kwargs):
        super(CityScapes, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        print('self.mode', self.mode)
        self.ignore_lb = 255

        # map raw Cityscapes label ids to train ids; each entry in the json
        # carries an 'id' and a 'trainId' field
        with open('./cityscapes_info.json', 'r') as fr:
            labels_info = json.load(fr)
        self.lb_map = {el['id']: el['trainId'] for el in labels_info}

        ## parse img directory
        self.imgs = {}
        imgnames = []
        impth = osp.join(rootpth, 'leftImg8bit', mode)
        folders = os.listdir(impth)
        for fd in folders:
            fdpth = osp.join(impth, fd)
            im_names = os.listdir(fdpth)
            names = [el.replace('_leftImg8bit.png', '') for el in im_names]
            impths = [osp.join(fdpth, el) for el in im_names]
            imgnames.extend(names)
            self.imgs.update(dict(zip(names, impths)))

        ## parse gt directory
        self.labels = {}
        gtnames = []
        gtpth = osp.join(rootpth, 'gtFine', mode)
        folders = os.listdir(gtpth)
        for fd in folders:
            fdpth = osp.join(gtpth, fd)
            lbnames = os.listdir(fdpth)
            lbnames = [el for el in lbnames if 'labelIds' in el]
            names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
            lbpths = [osp.join(fdpth, el) for el in lbnames]
            gtnames.extend(names)
            self.labels.update(dict(zip(names, lbpths)))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        # every image must have exactly one matching label file
        assert set(imgnames) == set(gtnames)
        assert set(self.imnames) == set(self.imgs.keys())
        assert set(self.imnames) == set(self.labels.keys())

        ## pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale(randomscale),
            RandomCrop(cropsize)
        ])
    def __getitem__(self, idx):
        fn = self.imnames[idx]
        impth = self.imgs[fn]
        lbpth = self.labels[fn]
        img_tt = impth.split('/')[-1].split('.')[0]  # image file stem, returned for bookkeeping
        img = Image.open(impth).convert('RGB')
        label = Image.open(lbpth)
        if self.mode == 'train' or self.mode == 'trainval':
            # joint augmentation keeps image and label geometrically aligned
            im_lb = dict(im=img, lb=label)
            im_lb = self.trans_train(im_lb)
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        label = self.convert_labels(label)
        return img, label, img_tt

    def __len__(self):
        return self.len

    def convert_labels(self, label):
        # remap raw Cityscapes label ids to train ids
        for k, v in self.lb_map.items():
            label[label == k] = v
        return label

if __name__ == "__main__":
    from tqdm import tqdm
    # sanity check: collect the set of trainIds present in the val split
    ds = CityScapes('./data/', mode='val')
    uni = []
    for im, lb, fn in tqdm(ds):
        lb_uni = np.unique(lb).tolist()
        uni.extend(lb_uni)
    print(uni)
    print(set(uni))
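
For context, below is a minimal sketch of how this dataset class is typically consumed in a training loop. It is not part of the original file: the dataset root './data/', the batch size, and the worker count are illustrative assumptions, and it presumes transform.py and cityscapes_info.json sit next to this module as the imports above require.

# Usage sketch (assumptions: Cityscapes-style layout under ./data/leftImg8bit/
# and ./data/gtFine/; batch_size and num_workers are placeholder values).
from torch.utils.data import DataLoader

ds_train = CityScapes('./data/', cropsize=(640, 480), mode='train')
dl_train = DataLoader(ds_train,
                      batch_size=8,
                      shuffle=True,
                      num_workers=4,
                      drop_last=True)

for img, label, name in dl_train:
    # img:   float tensor batch, ImageNet-normalized
    # label: int64 tensor batch of trainId maps (255 = ignore)
    # name:  tuple of image filename stems
    break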