148 lines
6.1 KiB
Python
148 lines
6.1 KiB
Python
|
|
"""Perform data augmentation and preprocessing."""
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
import os
|
||
|
|
import random
|
||
|
|
import cv2 as cv
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
def get_parser():
    """Build the command-line parser used when generating a dataset."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--dataset',
        required=True,
        choices=['trainval', 'test'],
        help="Generate trainval or test dataset.",
    )
    arg_parser.add_argument(
        '--val_prop',
        type=float,
        default=0.1,
        help="The proportion of val sample in trainval.",
    )
    # The three directory options differ only in name and help text.
    directory_options = (
        ('--label_directory', "The location of label directory."),
        ('--image_directory', "The location of image directory."),
        ('--output_directory', "The location of output directory."),
    )
    for flag, description in directory_options:
        arg_parser.add_argument(flag, required=True, help=description)
    return arg_parser
|
||
|
|
|
||
|
|
|
||
|
|
def boundary_check(centralied_marks):
    """Return True when no marking point appears too near to the border.

    A mark fails the check as soon as its x (index 0) or y (index 1)
    coordinate is below -260 or above 260.
    """
    limit = 260
    return not any(
        mark[0] < -limit or mark[0] > limit
        or mark[1] < -limit or mark[1] > limit
        for mark in centralied_marks
    )
|
||
|
|
|
||
|
|
|
||
|
|
def overlap_check(centralied_marks):
    """Return True when no two marking points appear in the same cell.

    Two marks collide when both their x distance and their y distance are
    smaller than one cell, where a cell is 600 / 16 wide. The marks are
    indexed as a two-dimensional array (row = mark, columns 0 and 1 = x, y).
    """
    cell_size = 600 / 16
    mark_count = len(centralied_marks)
    for first in range(mark_count - 1):
        first_x = centralied_marks[first, 0]
        first_y = centralied_marks[first, 1]
        # Only look at later marks so every pair is compared exactly once.
        for second in range(first + 1, mark_count):
            gap_x = abs(centralied_marks[second, 0] - first_x)
            gap_y = abs(centralied_marks[second, 1] - first_y)
            if gap_x < cell_size and gap_y < cell_size:
                return False
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def generalize_marks(centralied_marks):
    """Normalize mark coordinates and attach a direction angle.

    Each input mark is indexed as [x, y, x2, y2, label]. The output entry
    is [(x + 300) / 600, (y + 300) / 600, direction, label], where
    direction is the atan2 angle (radians) of the vector from (x, y) to
    (x2, y2) and label (index 4) is passed through untouched.
    """
    offset = 300
    extent = 600
    return [
        [
            (mark[0] + offset) / extent,
            (mark[1] + offset) / extent,
            math.atan2(mark[3] - mark[1], mark[2] - mark[0]),
            mark[4],
        ]
        for mark in centralied_marks
    ]
|
||
|
|
|
||
|
|
|
||
|
|
def write_image_and_label(name, image, centralied_marks, name_list):
    """Save one sample as '<name>.jpg' plus '<name>.json'.

    The base name of the sample is appended to name_list (which is mutated
    in place) and a progress line is printed. The image is resized to
    512x512 and stored as a quality-100 JPEG; the marks are converted with
    generalize_marks and dumped as JSON.
    """
    sample_name = os.path.basename(name)
    name_list.append(sample_name)
    print("Processing NO.%d samples: %s..." % (len(name_list), sample_name))

    resized_image = cv.resize(image, (512, 512))
    jpeg_options = [int(cv.IMWRITE_JPEG_QUALITY), 100]
    cv.imwrite(name + '.jpg', resized_image, jpeg_options)

    with open(name + '.json', 'w') as label_stream:
        json.dump(generalize_marks(centralied_marks), label_stream)
|
||
|
|
|
||
|
|
|
||
|
|
def rotate_vector(vector, angle_degree):
    """Rotate a 2-D vector by an angle given in degrees.

    Returns the tuple (x*cos + y*sin, -x*sin + y*cos), where x and y are
    the first two entries of vector.
    """
    angle_rad = math.pi * angle_degree / 180
    cosine = math.cos(angle_rad)
    sine = math.sin(angle_rad)
    x_in = vector[0]
    y_in = vector[1]
    return x_in*cosine + y_in*sine, -x_in*sine + y_in*cosine
|
||
|
|
|
||
|
|
|
||
|
|
def rotate_centralized_marks(centralied_marks, angle_degree):
    """Return a rotated copy of the centralized marks.

    For every row, columns 0:2 and columns 2:4 are each rotated as a 2-D
    vector by angle_degree (degrees); remaining columns are copied as-is.
    The input array is left unmodified.
    """
    rotated_marks = centralied_marks.copy()
    for row, mark in enumerate(centralied_marks):
        # Each mark carries two points: columns 0-1 and columns 2-3.
        for start in (0, 2):
            stop = start + 2
            rotated_marks[row, start:stop] = rotate_vector(
                mark[start:stop], angle_degree)
    return rotated_marks
|
||
|
|
|
||
|
|
|
||
|
|
def rotate_image(image, angle_degree):
    """Rotate image about its center by the given angle in degrees.

    The output has the same width and height as the input. Per the OpenCV
    documentation, positive angles rotate counter-clockwise with the
    coordinate origin at the top-left corner.
    """
    # Only the first two dimensions matter, so single-channel images work too.
    rows, cols = image.shape[:2]
    # OpenCV takes the center as (x, y) and the output size as
    # (width, height), i.e. columns first. Passing rows first is only
    # correct for square images.
    rotation_matrix = cv.getRotationMatrix2D((cols/2, rows/2), angle_degree, 1)
    return cv.warpAffine(image, rotation_matrix, (cols, rows))
|
||
|
|
|
||
|
|
|
||
|
|
def generate_dataset(args):
    """Generate dataset according to arguments.

    Reads every label file in args.label_directory, loads the image with
    the same base name ('.jpg') from args.image_directory, and writes the
    processed samples below args.output_directory:

    * 'test': samples go to '<output>/test', without augmentation.
    * 'trainval': samples go to '<output>/train', plus rotated copies for
      every angle from 5 to 355 degrees in steps of 5 that pass both the
      boundary and the overlap check. Afterwards a random args.val_prop
      share of all written samples is moved to '<output>/val'.

    Labels whose image cannot be read are skipped with a message.

    Note: args.output_directory is overwritten with the train/test
    sub-directory, as in the original implementation.
    """
    if args.dataset == 'trainval':
        val_directory = os.path.join(args.output_directory, 'val')
        args.output_directory = os.path.join(args.output_directory, 'train')
    elif args.dataset == 'test':
        args.output_directory = os.path.join(args.output_directory, 'test')
    os.makedirs(args.output_directory, exist_ok=True)

    name_list = []
    for label_file in os.listdir(args.label_directory):
        name = os.path.splitext(label_file)[0]
        image_path = os.path.join(args.image_directory, name + '.jpg')
        image = cv.imread(image_path)
        if image is None:
            # cv.imread signals a missing or unreadable file by returning
            # None; without this guard the failure surfaces later as an
            # unrelated error inside cv.resize or rotate_image.
            print("Skipping %s: cannot read image %s"
                  % (label_file, image_path))
            continue
        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            label = json.load(file)
        # Force a float dtype: an all-integer label would otherwise give
        # an integer array and the in-place subtraction of 300.5 below
        # would raise a casting error.
        centralied_marks = np.array(label['marks'], dtype=float)
        if len(centralied_marks.shape) < 2:
            # A single mark stored as a flat list becomes a one-row array.
            centralied_marks = np.expand_dims(centralied_marks, axis=0)
        # Shift both points of every mark (columns 0-3) by 300.5.
        centralied_marks[:, 0: 4] -= 300.5
        if boundary_check(centralied_marks):
            output_name = os.path.join(args.output_directory, name)
            write_image_and_label(output_name, image,
                                  centralied_marks, name_list)

        # Rotation augmentation is applied to the trainval set only.
        if args.dataset == 'test':
            continue
        for angle in range(5, 360, 5):
            rotated_marks = rotate_centralized_marks(centralied_marks, angle)
            if boundary_check(rotated_marks) and overlap_check(rotated_marks):
                rotated_image = rotate_image(image, angle)
                output_name = os.path.join(
                    args.output_directory, name + '_' + str(angle))
                write_image_and_label(
                    output_name, rotated_image, rotated_marks, name_list)

    if args.dataset == 'trainval':
        print("Dividing training set and validation set...")
        val_count = int(round(len(name_list)*args.val_prop))
        val_samples = random.sample(name_list, val_count)
        os.makedirs(val_directory, exist_ok=True)
        train_directory = args.output_directory
        for val_sample in val_samples:
            # Move the image and its label together.
            for extension in ('.jpg', '.json'):
                source = os.path.join(train_directory, val_sample + extension)
                target = os.path.join(val_directory, val_sample + extension)
                os.rename(source, target)
        print("Done.")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Parse the command line first, then hand the options to the generator.
    cli_args = get_parser().parse_args()
    generate_dataset(cli_args)
|