"""Collect the value ranges of different properties of the ps dataset.

Reads every label file in a directory and gathers, per slot, the value of
``calc_point_squre_dist`` between its two marking points and the angles
between each marking point's direction and the slot's separator/bridge
directions. A histogram of each quantity is then plotted so that suitable
thresholds can be chosen by inspection.
"""
import argparse
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np

from data import MarkingPoint
from data.struct import calc_point_squre_dist, direction_diff
from prepare_dataset import generalize_marks


def get_parser():
    """Return argument parser for collecting thresholds."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--label_directory', required=True,
                        help="The location of label directory.")
    return parser


def _histogram_bins(count, divisor):
    """Return roughly ``count / divisor`` histogram bins, but never fewer than 1.

    ``numpy.histogram`` (used by ``plt.hist``) rejects a bin count of 0,
    which plain integer division produces for small datasets.
    """
    return max(1, count // divisor)


def collect_thresholds(args):
    """Collect range of value from ground truth to determine threshold.

    For every file in ``args.label_directory`` (each parsed as JSON with
    ``marks`` and ``slots`` entries), and for every slot in it, this records:

    * ``calc_point_squre_dist`` between the slot's two marking points;
    * for each of the two marking points, a ``direction_diff`` angle between
      the point's direction and either the slot's separator direction or its
      bridge direction, appended to ``separator_angles`` or ``bridge_angles``.

    One matplotlib figure with a histogram is created per quantity. The
    figures are not shown here; the script entry point calls ``plt.show()``.

    Args:
        args: Parsed arguments; only ``label_directory`` is read.

    Returns:
        Tuple ``(distances, separator_angles, bridge_angles)``, each a sorted
        list of the collected values.
    """
    distances = []
    separator_angles = []
    bridge_angles = []
    for label_file in os.listdir(args.label_directory):
        print(label_file)
        with open(os.path.join(args.label_directory, label_file), 'r',
                  encoding='utf-8') as file:
            label = json.load(file)
        # Force float dtype: an all-integer "marks" list would otherwise load
        # as an int array and the in-place float subtraction below would
        # raise a casting error.
        marks = np.array(label['marks'], dtype=np.float64)
        slots = np.array(label['slots'])
        if slots.size == 0:
            continue
        # A label with a single mark/slot loads as a 1-D array; promote it to
        # 2-D so row indexing below works uniformly.
        if len(marks.shape) < 2:
            marks = np.expand_dims(marks, axis=0)
        if len(slots.shape) < 2:
            slots = np.expand_dims(slots, axis=0)
        # NOTE(review): the first four columns are shifted by 300.5,
        # presumably re-centring image coordinates -- confirm against
        # prepare_dataset.
        marks[:, 0:4] -= 300.5
        marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]
        for slot in slots:
            # Slot entries reference marks with 1-based indices.
            mark_a = marks[slot[0] - 1]
            mark_b = marks[slot[1] - 1]
            distances.append(calc_point_squre_dist(mark_a, mark_b))
            vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
            vector_ab = vector_ab / np.linalg.norm(vector_ab)
            ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
            ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
            # Direction perpendicular to AB.
            separator_direction = math.atan2(-vector_ab[0], vector_ab[1])

            # Mark A: with shape > 0.5 its angle always counts as a separator
            # angle; otherwise it is filed under whichever of the separator /
            # A->B bridge angles is smaller.
            sangle = direction_diff(separator_direction, mark_a.direction)
            if mark_a.shape > 0.5:
                separator_angles.append(sangle)
            else:
                bangle = direction_diff(ab_bridge_direction, mark_a.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)

            # Mark B: with shape > 0.5 its angle always counts as a B->A
            # bridge angle; otherwise the smaller of the two wins.
            # NOTE(review): the A/B asymmetry for shape > 0.5 looks
            # intentional (labelling convention) -- confirm.
            bangle = direction_diff(ba_bridge_direction, mark_b.direction)
            if mark_b.shape > 0.5:
                bridge_angles.append(bangle)
            else:
                sangle = direction_diff(separator_direction, mark_b.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)

    distances = sorted(distances)
    separator_angles = sorted(separator_angles)
    bridge_angles = sorted(bridge_angles)
    plt.figure()
    plt.hist(distances, _histogram_bins(len(distances), 10))
    plt.figure()
    plt.hist(separator_angles, _histogram_bins(len(separator_angles), 10))
    plt.figure()
    plt.hist(bridge_angles, _histogram_bins(len(bridge_angles), 3))
    return distances, separator_angles, bridge_angles


if __name__ == '__main__':
    collect_thresholds(get_parser().parse_args())
    # collect_thresholds only creates the figures; display them so that
    # running this script actually shows the histograms.
    plt.show()