交通事故检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.9KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. from torch.utils.data import Dataset
  5. import torchvision.transforms as transforms
  6. import os.path as osp
  7. import os
  8. from PIL import Image
  9. import numpy as np
  10. import json
  11. from transform import *
  12. class CityScapes(Dataset):
  13. def __init__(self, rootpth, cropsize=(640, 480), mode='train',
  14. randomscale=(0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.25, 1.5), *args, **kwargs):
  15. super(CityScapes, self).__init__(*args, **kwargs)
  16. assert mode in ('train', 'val', 'test', 'trainval')
  17. self.mode = mode
  18. print('self.mode', self.mode)
  19. self.ignore_lb = 255
  20. with open('./cityscapes_info.json', 'r') as fr:
  21. labels_info = json.load(fr)
  22. self.lb_map = {el['id']: el['trainId'] for el in labels_info}
  23. ## parse img directory
  24. self.imgs = {}
  25. imgnames = []
  26. impth = osp.join(rootpth, 'leftImg8bit', mode)
  27. folders = os.listdir(impth)
  28. for fd in folders:
  29. fdpth = osp.join(impth, fd)
  30. im_names = os.listdir(fdpth)
  31. names = [el.replace('_leftImg8bit.png', '') for el in im_names]
  32. impths = [osp.join(fdpth, el) for el in im_names]
  33. imgnames.extend(names)
  34. self.imgs.update(dict(zip(names, impths)))
  35. ## parse gt directory
  36. self.labels = {}
  37. gtnames = []
  38. gtpth = osp.join(rootpth, 'gtFine', mode)
  39. folders = os.listdir(gtpth)
  40. for fd in folders:
  41. fdpth = osp.join(gtpth, fd)
  42. lbnames = os.listdir(fdpth)
  43. lbnames = [el for el in lbnames if 'labelIds' in el]
  44. names = [el.replace('_gtFine_labelIds.png', '') for el in lbnames]
  45. lbpths = [osp.join(fdpth, el) for el in lbnames]
  46. gtnames.extend(names)
  47. self.labels.update(dict(zip(names, lbpths)))
  48. self.imnames = imgnames
  49. self.len = len(self.imnames)
  50. print('self.len', self.mode, self.len)
  51. assert set(imgnames) == set(gtnames)
  52. assert set(self.imnames) == set(self.imgs.keys())
  53. assert set(self.imnames) == set(self.labels.keys())
  54. ## pre-processing
  55. self.to_tensor = transforms.Compose([
  56. transforms.ToTensor(),
  57. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  58. ])
  59. self.trans_train = Compose([
  60. ColorJitter(
  61. brightness = 0.5,
  62. contrast = 0.5,
  63. saturation = 0.5),
  64. HorizontalFlip(),
  65. # RandomScale((0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
  66. RandomScale(randomscale),
  67. # RandomScale((0.125, 1)),
  68. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0)),
  69. # RandomScale((0.125, 0.25, 0.375, 0.5, 0.675, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)),
  70. RandomCrop(cropsize)
  71. ])
  72. def __getitem__(self, idx):
  73. fn = self.imnames[idx]
  74. impth = self.imgs[fn]
  75. lbpth = self.labels[fn]
  76. img = Image.open(impth).convert('RGB')
  77. label = Image.open(lbpth)
  78. if self.mode == 'train' or self.mode == 'trainval':
  79. im_lb = dict(im = img, lb = label)
  80. im_lb = self.trans_train(im_lb)
  81. img, label = im_lb['im'], im_lb['lb']
  82. img = self.to_tensor(img)
  83. label = np.array(label).astype(np.int64)[np.newaxis, :]
  84. label = self.convert_labels(label)
  85. return img, label
  86. def __len__(self):
  87. return self.len
  88. def convert_labels(self, label):
  89. for k, v in self.lb_map.items():
  90. label[label == k] = v
  91. return label
  92. if __name__ == "__main__":
  93. from tqdm import tqdm
  94. ds = CityScapes('./data/', n_classes=19, mode='val')
  95. uni = []
  96. for im, lb in tqdm(ds):
  97. lb_uni = np.unique(lb).tolist()
  98. uni.extend(lb_uni)
  99. print(uni)
  100. print(set(uni))