# DMPR-PS/data/dataset.py

"""Defines the parking slot dataset for directional marking point detection."""
import json
import os
import os.path
import cv2 as cv
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from data.struct import MarkingPoint
class ParkingSlotDataset(Dataset):
    """Parking slot dataset backed by a directory of image/label pairs.

    Every ``<name>.json`` file found directly under ``root`` defines one
    sample; the matching image is read from ``<name>.jpg`` in the same
    directory.
    """

    def __init__(self, root):
        """Index the samples stored under ``root``.

        Args:
            root: Directory holding the ``.json`` label files and their
                ``.jpg`` images.
        """
        super(ParkingSlotDataset, self).__init__()
        self.root = root
        self.sample_names = self._list_sample_names(root)
        self.image_transform = ToTensor()

    @staticmethod
    def _list_sample_names(root):
        """Return the sorted base names of all ``.json`` files in ``root``."""
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        return sorted(
            os.path.splitext(entry)[0]
            for entry in os.listdir(root)
            if entry.endswith(".json")
        )

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Args:
            index: Position of the sample in ``self.sample_names``.

        Returns:
            A tuple ``(image, marking_points)``: the image after
            ``self.image_transform`` has been applied, and a list with one
            ``MarkingPoint`` per entry of the sample's JSON file (each entry
            is unpacked positionally into the constructor).

        Raises:
            FileNotFoundError: If the sample's ``.jpg`` image is missing or
                cannot be decoded.
        """
        name = self.sample_names[index]
        image_path = os.path.join(self.root, name + '.jpg')
        image = cv.imread(image_path)
        if image is None:
            # cv.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside ToTensor.
            raise FileNotFoundError(
                "Could not read image: {}".format(image_path))
        image = self.image_transform(image)
        with open(os.path.join(self.root, name + '.json'), 'r') as file:
            marking_points = [
                MarkingPoint(*label) for label in json.load(file)
            ]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)