DMPR-PS/dataset.py

34 lines
1.0 KiB
Python

# -*- coding: utf-8 -*-
import json
import os
import os.path
import cv2 as cv
from torch.utils.data import Dataset
from torchvision import transforms
from data import MarkingPoint
class ParkingSlotDataset(Dataset):
    """Parking slot dataset.

    Each sample is a pair of files in ``root`` that share a base name:
    ``<name>.json`` holds the marking-point labels and ``<name>.jpg`` holds
    the image. Samples are discovered from the ``.json`` files only.
    """

    def __init__(self, root):
        """Index every ``.json`` label file found directly in ``root``."""
        super(ParkingSlotDataset, self).__init__()
        self.root = root
        self.image_transform = transforms.ToTensor()
        # Sorted so that a given index refers to the same sample on every
        # machine; os.listdir returns entries in arbitrary order.
        self.sample_names = self._list_sample_names(root)

    @staticmethod
    def _list_sample_names(root):
        """Return the sorted base names of the ``.json`` files in ``root``."""
        return sorted(os.path.splitext(entry)[0]
                      for entry in os.listdir(root)
                      if entry.endswith(".json"))

    def __getitem__(self, index):
        """Return ``(image, marking_points)`` for sample ``index``.

        ``image`` is ``transforms.ToTensor()`` applied to the array loaded by
        ``cv.imread`` (OpenCV loads colour images in BGR channel order).
        ``marking_points`` is a list of ``MarkingPoint`` objects, one per
        entry of the list stored in the sample's JSON file.

        Raises:
            FileNotFoundError: if the sample's ``.jpg`` cannot be read.
        """
        name = self.sample_names[index]
        image_path = os.path.join(self.root, name + '.jpg')
        image = cv.imread(image_path)
        # cv.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than inside ToTensor.
        if image is None:
            raise FileNotFoundError("Cannot read image: %s" % image_path)
        image = self.image_transform(image)
        with open(os.path.join(self.root, name + '.json'), 'r') as label_file:
            # Each label is unpacked positionally, so its element order must
            # match MarkingPoint's constructor arguments.
            marking_points = [MarkingPoint(*label)
                              for label in json.load(label_file)]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)