# -*- coding: utf-8 -*-
import json
import os
import os.path

import cv2 as cv
from torch.utils.data import Dataset
from torchvision import transforms

from data import MarkingPoint


class ParkingSlotDataset(Dataset):
    """Parking slot dataset.

    Every sample is a pair of files sharing one base name inside ``root``:
    ``<name>.jpg`` holds the image and ``<name>.json`` holds its labels.
    Samples are discovered from the ``.json`` files present in ``root``.
    """

    def __init__(self, root):
        """Index every ``.json`` label file found directly in *root*.

        Args:
            root: Directory containing the ``<name>.jpg`` / ``<name>.json``
                sample pairs (not searched recursively).
        """
        super(ParkingSlotDataset, self).__init__()
        self.root = root
        self.image_transform = transforms.ToTensor()
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that a given index maps to the same sample on every run/host.
        self.sample_names = sorted(
            os.path.splitext(file_name)[0]
            for file_name in os.listdir(root)
            if file_name.endswith(".json")
        )

    def __getitem__(self, index):
        """Load the image and marking points of sample *index*.

        Args:
            index: Position in ``self.sample_names``.

        Returns:
            ``(image, marking_points)``:
            - ``image`` is the ``<name>.jpg`` file as loaded by ``cv.imread``
              (channel order as OpenCV returns it, i.e. BGR for color images)
              and converted with ``transforms.ToTensor``.
            - ``marking_points`` is a list with one ``MarkingPoint`` per entry
              of the JSON top-level list; each entry is unpacked positionally
              into ``MarkingPoint`` (presumably a flat list of numbers — the
              exact field layout is defined by ``MarkingPoint``; confirm there).

        Raises:
            FileNotFoundError: if the ``.jpg`` image cannot be read.
        """
        name = self.sample_names[index]

        image_path = os.path.join(self.root, name + '.jpg')
        image = cv.imread(image_path)
        # cv.imread signals failure by returning None instead of raising; fail
        # loudly here rather than with an unrelated error inside ToTensor.
        if image is None:
            raise FileNotFoundError("Could not read image: " + image_path)
        image = self.image_transform(image)

        label_path = os.path.join(self.root, name + '.json')
        with open(label_path, 'r', encoding='utf-8') as label_file:
            marking_points = [MarkingPoint(*label)
                              for label in json.load(label_file)]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)