高速公路违停检测
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

cityscapes.py 3.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from torch.utils.data import Dataset
  5. import torchvision.transforms as transforms
  6. import os.path as osp
  7. import os
  8. from PIL import Image
  9. import numpy as np
  10. import json
  11. from transform import *
  12. class CityScapes(Dataset):
  13. def __init__(self, rootpth, cropsize=(640, 480), mode='train',
  14. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  15. super(CityScapes, self).__init__(*args, **kwargs)
  16. assert mode in ('train', 'val', 'test', 'trainval')
  17. self.mode = mode
  18. print('self.mode', self.mode)
  19. self.ignore_lb = 255
  20. with open('./cityscapes_info.json', 'r') as fr:
  21. labels_info = json.load(fr)
  22. self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  23. ## parse img directory
  24. self.imgs = {}
  25. imgnames = []
  26. impth = osp.join(rootpth, 'leftImg8bit', mode)
  27. folders = os.listdir(impth)
  28. for fd in folders:
  29. fdpth = osp.join(impth, fd)
  30. im_names = os.listdir(fdpth)
  31. names = [el.replace('_leftImg8bit.png', '') for el in im_names]
  32. impths = [osp.join(fdpth, el) for el in im_names]
  33. imgnames.extend(names)
  34. self.imgs.update(dict(zip(names, impths)))
  35. ## parse gt directory
  36. self.labels = {}
  37. gtnames = []
  38. gtpth = osp.join(rootpth, 'gtFine', mode)
  39. folders = os.listdir(gtpth)
  40. for fd in folders:
  41. fdpth = osp.join(gtpth, fd)
  42. lbnames = os.listdir(fdpth)
  43. lbnames = [el for el in lbnames if 'labelIds' in el]
  44. names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
  45. lbpths = [osp.join(fdpth, el) for el in lbnames]
  46. gtnames.extend(names)
  47. self.labels.update(dict(zip(names, lbpths)))
  48. self.imnames = imgnames
  49. self.len = len(self.imnames)
  50. print('self.len', self.mode, self.len)
  51. assert set(imgnames) == set(gtnames)
  52. assert set(self.imnames) == set(self.imgs.keys())
  53. assert set(self.imnames) == set(self.labels.keys())
  54. ## pre-processing
  55. self.to_tensor = transforms.Compose([
  56. transforms.ToTensor(),
  57. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  58. ])
  59. self.trans_train = Compose([
  60. ColorJitter(
  61. brightness = 0.5,
  62. contrast = 0.5,
  63. saturation = 0.5),
  64. HorizontalFlip(),
  65. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  66. RandomScale(randomscale),
  67. # RandomScale((0.125, 1)),
  68. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  69. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  70. RandomCrop(cropsize)
  71. ])
  72. def __getitem__(self, idx):
  73. fn = self.imnames[idx]
  74. impth = self.imgs[fn]
  75. lbpth = self.labels[fn]
  76. img = Image.open(impth).convert('RGB')
  77. label = Image.open(lbpth)
  78. if self.mode == 'train' or self.mode == 'trainval':
  79. im_lb = dict(im = img, lb = label)
  80. im_lb = self.trans_train(im_lb)
  81. img, label = im_lb['im'], im_lb['lb']
  82. img = self.to_tensor(img)
  83. label = np.array(label).astype(np.int64)[np.newaxis, :]
  84. label = self.convert_labels(label)
  85. return img, label
  86. def __len__(self):
  87. return self.len
  88. def convert_labels(self, label):
  89. for k, v in self.lb_map.items():
  90. label[label == k] = v
  91. return label
  92. if __name__ == "__main__":
  93. from tqdm import tqdm
  94. ds = CityScapes('./data/', n_classes=19, mode='val')
  95. uni = []
  96. for im, lb in tqdm(ds):
  97. lb_uni = np.unique(lb).tolist()
  98. uni.extend(lb_uni)
  99. print(uni)
  100. print(set(uni))