DMPR-PS/data/dataset.py

34 lines
1.1 KiB
Python
Raw Normal View History

2018-10-02 17:16:16 +08:00
"""Defines the parking slot dataset for directional marking point detection."""
2018-10-02 15:54:42 +08:00
import json
import os
import os.path
import cv2 as cv
from torch.utils.data import Dataset
2018-10-04 09:30:25 +08:00
from torchvision.transforms import ToTensor
from data.struct import MarkingPoint
2018-10-02 15:54:42 +08:00
class ParkingSlotDataset(Dataset):
    """Parking slot dataset.

    Every ``<name>.json`` file in ``root`` defines one sample; the matching
    image is read from ``<name>.jpg`` in the same directory.
    """

    def __init__(self, root):
        """Index all samples found in ``root``.

        Args:
            root: Directory containing the ``.json`` label files and their
                ``.jpg`` images.
        """
        super().__init__()
        self.root = root
        self.image_transform = ToTensor()
        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order; sort so a given index refers to the same sample everywhere.
        self.sample_names = sorted(
            os.path.splitext(file_name)[0]
            for file_name in os.listdir(root)
            if file_name.endswith(".json")
        )

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.sample_names``.

        Returns:
            A ``(image, marking_points)`` tuple: the image after
            ``self.image_transform``, and a list with one ``MarkingPoint``
            per entry of the sample's JSON label list.

        Raises:
            FileNotFoundError: If the ``.jpg`` image (via ``cv.imread``) or
                the ``.json`` label file cannot be read.
        """
        name = self.sample_names[index]
        image_path = os.path.join(self.root, name + '.jpg')
        # OpenCV loads images in BGR channel order; no conversion is done here.
        image = cv.imread(image_path)
        # cv.imread signals a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path instead of inside
        # the transform.
        if image is None:
            raise FileNotFoundError("Cannot read image: " + image_path)
        image = self.image_transform(image)
        label_path = os.path.join(self.root, name + '.json')
        with open(label_path, 'r', encoding='utf-8') as label_file:
            # NOTE(review): assumes each label is a sequence of MarkingPoint
            # constructor arguments, in positional order — confirm.
            marking_points = [MarkingPoint(*label)
                              for label in json.load(label_file)]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)