Highway Illegal Parking Detection

#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json

from transform import *  # local augmentation module: Compose, ColorJitter, HorizontalFlip, RandomScale, RandomCrop


class CityScapes(Dataset):
    def __init__(self, rootpth, cropsize=(640, 480), mode='train',
                 randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5),
                 *args, **kwargs):
        super(CityScapes, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test', 'trainval')
        self.mode = mode
        print('self.mode', self.mode)
        self.ignore_lb = 255  # label value ignored by the loss

        # map raw Cityscapes label ids to train ids
        with open('./cityscapes_info.json', 'r') as fr:
            labels_info = json.load(fr)
        self.lb_map = {el['id']: el['trainId'] for el in labels_info}

        ## parse img directory
        self.imgs = {}
        imgnames = []
        impth = osp.join(rootpth, 'leftImg8bit', mode)
        folders = os.listdir(impth)
        for fd in folders:
            fdpth = osp.join(impth, fd)
            im_names = os.listdir(fdpth)
            names = [el.replace('_leftImg8bit.png', '') for el in im_names]
            impths = [osp.join(fdpth, el) for el in im_names]
            imgnames.extend(names)
            self.imgs.update(dict(zip(names, impths)))

        ## parse gt directory
        self.labels = {}
        gtnames = []
        gtpth = osp.join(rootpth, 'gtFine', mode)
        folders = os.listdir(gtpth)
        for fd in folders:
            fdpth = osp.join(gtpth, fd)
            lbnames = os.listdir(fdpth)
            lbnames = [el for el in lbnames if 'labelIds' in el]
            names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
            lbpths = [osp.join(fdpth, el) for el in lbnames]
            gtnames.extend(names)
            self.labels.update(dict(zip(names, lbpths)))

        self.imnames = imgnames
        self.len = len(self.imnames)
        print('self.len', self.mode, self.len)
        # every image must have a matching label file
        assert set(imgnames) == set(gtnames)
        assert set(self.imnames) == set(self.imgs.keys())
        assert set(self.imnames) == set(self.labels.keys())

        ## pre-processing
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
            RandomScale(randomscale),
            # RandomScale((0.125, 1)),
            # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
            # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
            RandomCrop(cropsize)
        ])

    def __getitem__(self, idx):
        fn = self.imnames[idx]
        impth = self.imgs[fn]
        lbpth = self.labels[fn]
        img = Image.open(impth).convert('RGB')
        label = Image.open(lbpth)
        # augmentation is only applied in training modes
        if self.mode == 'train' or self.mode == 'trainval':
            im_lb = dict(im=img, lb=label)
            im_lb = self.trans_train(im_lb)
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        label = self.convert_labels(label)
        return img, label

    def __len__(self):
        return self.len

    def convert_labels(self, label):
        # remap raw Cityscapes label ids to train ids
        for k, v in self.lb_map.items():
            label[label == k] = v
        return label


if __name__ == "__main__":
    from tqdm import tqdm
    ds = CityScapes('./data/', mode='val')
    uni = []
    for im, lb in tqdm(ds):
        lb_uni = np.unique(lb).tolist()
        uni.extend(lb_uni)
        print(uni)
    print(set(uni))
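
For reference, a minimal sketch of how this dataset might be wrapped in a torch.utils.data.DataLoader for training. The batch size, worker count, and data root ('./data/') below are assumptions for illustration, not values taken from the repository.

# Usage sketch; batch_size, num_workers and the data root are assumed values.
from torch.utils.data import DataLoader

ds_train = CityScapes('./data/', cropsize=(640, 480), mode='train')
dl_train = DataLoader(ds_train,
                      batch_size=8,      # assumed; tune to GPU memory
                      shuffle=True,
                      num_workers=4,     # assumed
                      pin_memory=True,
                      drop_last=True)

for img, label in dl_train:
    # img: float tensor (N, 3, H, W), normalized with ImageNet statistics
    # label: int64 tensor (N, 1, H, W) holding train ids, 255 = ignore
    break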