"""Defines the parking slot dataset for directional marking point detection."""
import json
import os
import os.path

import cv2 as cv
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

from data.struct import MarkingPoint


class ParkingSlotDataset(Dataset):
    """Parking slot dataset.

    Each sample is a pair of files in ``root`` sharing a base name:
    ``<name>.jpg`` (the image) and ``<name>.json`` (a JSON list of labels,
    each unpacked into a ``MarkingPoint``). Samples are discovered from the
    ``.json`` files present in ``root``.
    """

    def __init__(self, root):
        """Index every ``*.json`` label file directly under ``root``.

        Args:
            root: Directory holding the paired ``.jpg`` images and ``.json``
                label files.
        """
        super(ParkingSlotDataset, self).__init__()
        self.root = root
        self.image_transform = ToTensor()
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that a given index always refers to the same sample.
        self.sample_names = sorted(
            os.path.splitext(file)[0]
            for file in os.listdir(root)
            if file.endswith(".json")
        )

    def __getitem__(self, index):
        """Load the image and marking points of one sample.

        Args:
            index: Position in ``self.sample_names``.

        Returns:
            A tuple ``(image, marking_points)``. ``image`` is the ``.jpg``
            decoded by OpenCV and converted with ``ToTensor``.
            ``marking_points`` is a list with one ``MarkingPoint`` per entry
            of the sample's JSON label list.

        Raises:
            FileNotFoundError: If the image file is missing or OpenCV cannot
                decode it.
        """
        name = self.sample_names[index]
        image_path = os.path.join(self.root, name + '.jpg')
        # cv.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of inside ToTensor.
        image = cv.imread(image_path)
        if image is None:
            raise FileNotFoundError(
                'Cannot read image: {}'.format(image_path))
        # NOTE: with default flags cv.imread yields BGR channel order; it is
        # passed through unchanged, so consumers must expect BGR input.
        image = self.image_transform(image)
        label_path = os.path.join(self.root, name + '.json')
        with open(label_path, 'r', encoding='utf-8') as file:
            # presumably each label is a flat sequence of MarkingPoint fields
            # in constructor order — TODO confirm against data.struct.
            marking_points = [MarkingPoint(*label) for label in json.load(file)]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)