34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
import json
|
|
import os
|
|
import os.path
|
|
import cv2 as cv
|
|
from torch.utils.data import Dataset
|
|
from torchvision import transforms
|
|
from data import MarkingPoint
|
|
|
|
|
|
class ParkingSlotDataset(Dataset):
    """Parking slot dataset read from a flat directory.

    Every ``<name>.json`` file found directly inside ``root`` defines one
    sample; its image is expected next to it as ``<name>.jpg``.
    """

    def __init__(self, root):
        """Index the samples available in ``root``.

        Args:
            root: Directory holding the ``.json`` label files and their
                matching ``.jpg`` images.

        Raises:
            OSError: If ``root`` cannot be listed.
        """
        super(ParkingSlotDataset, self).__init__()
        self.root = root
        self.sample_names = self._list_sample_names(root)
        self.image_transform = transforms.ToTensor()

    @staticmethod
    def _list_sample_names(root):
        """Return the sorted base names of all ``.json`` files in ``root``."""
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        return sorted(
            os.path.splitext(entry)[0]
            for entry in os.listdir(root)
            if entry.endswith(".json")
        )

    def __getitem__(self, index):
        """Load the sample stored at position ``index``.

        Args:
            index: Position into ``self.sample_names``.

        Returns:
            A tuple ``(image, marking_points)``: ``image`` is the result of
            applying ``self.image_transform`` to the array read by
            ``cv.imread``, and ``marking_points`` is a list with one
            ``MarkingPoint`` per entry of the sample's JSON file.

        Raises:
            IOError: If the image file is missing or cannot be decoded.
        """
        name = self.sample_names[index]
        image_path = os.path.join(self.root, name + '.jpg')
        image = cv.imread(image_path)
        # cv.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside the transform.
        if image is None:
            raise IOError("Failed to read image: {}".format(image_path))
        # NOTE(review): OpenCV loads channels as BGR and no conversion is
        # applied here -- confirm downstream code expects that order.
        image = self.image_transform(image)
        with open(os.path.join(self.root, name + '.json'), 'r') as label_file:
            # Each JSON entry is unpacked positionally into MarkingPoint.
            marking_points = [
                MarkingPoint(*label) for label in json.load(label_file)
            ]
        return image, marking_points

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.sample_names)
|