|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- """Defines data structure."""
- import math
- from collections import namedtuple
- from enum import Enum
- import config
-
-
- MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
- Slot = namedtuple('Slot', ['x1', 'y1', 'x2', 'y2'])
-
-
- class PointShape(Enum):
- """The point shape types used to pair two marking points into slot."""
- none = 0
- l_down = 1
- t_down = 2
- t_middle = 3
- t_up = 4
- l_up = 5
-
-
- def direction_diff(direction_a, direction_b):
- """Calculate the angle between two direction."""
- diff = abs(direction_a - direction_b)
- return diff if diff < math.pi else 2*math.pi - diff
-
-
- def detemine_point_shape(point, vector):
- """Determine which category the point is in."""
- vec_direct = math.atan2(vector[1], vector[0])
- vec_direct_up = math.atan2(-vector[0], vector[1])
- vec_direct_down = math.atan2(vector[0], -vector[1])
- if point.shape < 0.5:
- if direction_diff(vec_direct, point.direction) < config.BRIDGE_ANGLE_DIFF:
- return PointShape.t_middle
- if direction_diff(vec_direct_up, point.direction) < config.SEPARATOR_ANGLE_DIFF:
- return PointShape.t_up
- if direction_diff(vec_direct_down, point.direction) < config.SEPARATOR_ANGLE_DIFF:
- return PointShape.t_down
- else:
- if direction_diff(vec_direct, point.direction) < config.BRIDGE_ANGLE_DIFF:
- return PointShape.l_down
- if direction_diff(vec_direct_up, point.direction) < config.SEPARATOR_ANGLE_DIFF:
- return PointShape.l_up
- return PointShape.none
-
-
- def calc_point_squre_dist(point_a, point_b):
- """Calculate distance between two marking points."""
- distx = point_a.x - point_b.x
- disty = point_a.y - point_b.y
- return distx ** 2 + disty ** 2
-
-
- def calc_point_direction_angle(point_a, point_b):
- """Calculate angle between direction in rad."""
- return direction_diff(point_a.direction, point_b.direction)
-
-
- def match_marking_points(point_a, point_b):
- """Determine whether a detected point match ground truth."""
- dist_square = calc_point_squre_dist(point_a, point_b)
- angle = calc_point_direction_angle(point_a, point_b)
- if point_a.shape > 0.5 and point_b.shape < 0.5:
- return False
- if point_a.shape < 0.5 and point_b.shape > 0.5:
- return False
- return (dist_square < config.SQUARED_DISTANCE_THRESH
- and angle < config.DIRECTION_ANGLE_THRESH)
-
-
- def match_slots(slot_a, slot_b):
- """Determine whether a detected slot match ground truth."""
- dist_x1 = slot_b.x1 - slot_a.x1
- dist_y1 = slot_b.y1 - slot_a.y1
- squared_dist1 = dist_x1**2 + dist_y1**2
- dist_x2 = slot_b.x2 - slot_a.x2
- dist_y2 = slot_b.y2 - slot_a.y2
- squared_dist2 = dist_x2 ** 2 + dist_y2 ** 2
- return (squared_dist1 < config.SQUARED_DISTANCE_THRESH
- and squared_dist2 < config.SQUARED_DISTANCE_THRESH)
|