车位角点检测代码
您最多选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

34 行
1.1KB

  1. """Defines the parking slot dataset for directional marking point detection."""
  2. import json
  3. import os
  4. import os.path
  5. import cv2 as cv
  6. from torch.utils.data import Dataset
  7. from torchvision.transforms import ToTensor
  8. from data.struct import MarkingPoint
  9. class ParkingSlotDataset(Dataset):
  10. """Parking slot dataset."""
  11. def __init__(self, root):
  12. super(ParkingSlotDataset, self).__init__()
  13. self.root = root
  14. self.sample_names = []
  15. self.image_transform = ToTensor()
  16. for file in os.listdir(root):
  17. if file.endswith(".json"):
  18. self.sample_names.append(os.path.splitext(file)[0])
  19. def __getitem__(self, index):
  20. name = self.sample_names[index]
  21. image = cv.imread(os.path.join(self.root, name+'.jpg'))
  22. image = self.image_transform(image)
  23. marking_points = []
  24. with open(os.path.join(self.root, name + '.json'), 'r') as file:
  25. for label in json.load(file):
  26. marking_points.append(MarkingPoint(*label))
  27. return image, marking_points
  28. def __len__(self):
  29. return len(self.sample_names)