85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
"""Collect the value range of different properties of ps dataset."""
|
|
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from data import MarkingPoint
|
|
from data.struct import calc_point_squre_dist, direction_diff
|
|
from prepare_dataset import generalize_marks
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the threshold collection script.

    The parser accepts a single, mandatory ``--label_directory`` option that
    names the directory holding the label files.
    """
    arg_parser = argparse.ArgumentParser()
    label_directory_options = {
        'required': True,
        'help': "The location of label directory.",
    }
    arg_parser.add_argument('--label_directory', **label_directory_options)
    return arg_parser
|
|
|
|
|
|
def _plot_histogram(values, samples_per_bin):
    """Plot ``values`` as a histogram in a new figure.

    The bin count is ``len(values) // samples_per_bin`` but never less than
    one: an integer bin count of zero is rejected by the histogram
    computation, which would crash the script on small label sets.
    """
    plt.figure()
    plt.hist(values, max(1, len(values) // samples_per_bin))


def collect_thresholds(args):
    """Collect range of value from ground truth to determine threshold.

    Every entry of ``args.label_directory`` that is a regular file is read as
    a JSON label with a ``'marks'`` and a ``'slots'`` entry. For each slot
    (a pair of 1-based mark indices) the function records:

    * the value of ``calc_point_squre_dist`` between the two marks,
    * for each of the two marks, the difference between the mark's direction
      and either the separator direction (perpendicular to the line joining
      the marks) or the bridge direction (along that line).

    The three collections are then drawn as histograms and shown.

    Args:
        args: Parsed command-line arguments; only ``label_directory`` is read.
    """
    distances = []
    separator_angles = []
    bridge_angles = []

    for label_file in os.listdir(args.label_directory):
        label_path = os.path.join(args.label_directory, label_file)
        # Sub-directories and other non-file entries cannot be opened as
        # labels; skip them instead of crashing.
        if not os.path.isfile(label_path):
            continue
        print(label_file)
        with open(label_path, 'r') as label_fp:
            label = json.load(label_fp)

        # Force a float dtype: if every coordinate in the JSON happens to be
        # an integer, numpy would infer an integer array and the in-place
        # float subtraction below would fail with a casting error.
        marks = np.array(label['marks'], dtype=float)
        slots = np.array(label['slots'])
        if slots.size == 0:
            continue
        # A single mark / slot is stored as a flat list; promote to 2-D so the
        # row-wise processing below works uniformly.
        if len(marks.shape) < 2:
            marks = np.expand_dims(marks, axis=0)
        if len(slots.shape) < 2:
            slots = np.expand_dims(slots, axis=0)
        # Shift the first four columns by a fixed offset.
        # NOTE(review): presumably this centres pixel coordinates of the
        # source images -- confirm against the dataset description.
        marks[:, 0:4] -= 300.5
        marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]

        for slot in slots:
            # Slot entries are 1-based indices into the mark list.
            mark_a = marks[int(slot[0]) - 1]
            mark_b = marks[int(slot[1]) - 1]
            distances.append(calc_point_squre_dist(mark_a, mark_b))

            vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
            length_ab = np.linalg.norm(vector_ab)
            if length_ab == 0:
                # Coincident marks have no direction; normalising would yield
                # NaN angles that later break the histograms.
                continue
            vector_ab = vector_ab / length_ab
            # Direction from A to B, from B to A, and the direction of the
            # vector (v_y, -v_x), which is perpendicular to A->B.
            ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
            ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
            separator_direction = math.atan2(-vector_ab[0], vector_ab[1])

            # Mark A: a mark with shape > 0.5 is always compared against the
            # separator direction; otherwise whichever of the separator /
            # bridge directions is closer to the mark's direction is used.
            sangle = direction_diff(separator_direction, mark_a.direction)
            if mark_a.shape > 0.5:
                separator_angles.append(sangle)
            else:
                bangle = direction_diff(ab_bridge_direction, mark_a.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)

            # Mark B: a mark with shape > 0.5 is always compared against the
            # bridge direction (pointing back to A); otherwise the closer of
            # the two directions is used, as for mark A.
            bangle = direction_diff(ba_bridge_direction, mark_b.direction)
            if mark_b.shape > 0.5:
                bridge_angles.append(bangle)
            else:
                sangle = direction_diff(separator_direction, mark_b.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)

    _plot_histogram(distances, 10)
    _plot_histogram(separator_angles, 10)
    _plot_histogram(bridge_angles, 3)
    # Without this call the figures are created but never displayed.
    plt.show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, then gather the statistics.
    cli_args = get_parser().parse_args()
    collect_thresholds(cli_args)
|