车位角点检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

107 lines
4.2KB

  1. """Defines data structure and related function to process these data."""
  2. import math
  3. import numpy as np
  4. import torch
  5. import config
  6. from data.struct import MarkingPoint, calc_point_squre_dist, detemine_point_shape
  7. def non_maximum_suppression(pred_points):
  8. """Perform non-maxmum suppression on marking points."""
  9. suppressed = [False] * len(pred_points)
  10. for i in range(len(pred_points) - 1):
  11. for j in range(i + 1, len(pred_points)):
  12. i_x = pred_points[i][1].x
  13. i_y = pred_points[i][1].y
  14. j_x = pred_points[j][1].x
  15. j_y = pred_points[j][1].y
  16. # 0.0625 = 1 / 16
  17. if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
  18. idx = i if pred_points[i][0] < pred_points[j][0] else j
  19. suppressed[idx] = True
  20. if any(suppressed):
  21. unsupres_pred_points = []
  22. for i, supres in enumerate(suppressed):
  23. if not supres:
  24. unsupres_pred_points.append(pred_points[i])
  25. return unsupres_pred_points
  26. return pred_points
  27. def get_predicted_points(prediction, thresh):
  28. """Get marking points from one predicted feature map."""
  29. assert isinstance(prediction, torch.Tensor)
  30. predicted_points = []
  31. prediction = prediction.detach().cpu().numpy()
  32. for i in range(prediction.shape[1]):
  33. for j in range(prediction.shape[2]):
  34. if prediction[0, i, j] >= thresh:
  35. xval = (j + prediction[2, i, j]) / prediction.shape[2]
  36. yval = (i + prediction[3, i, j]) / prediction.shape[1]
  37. if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
  38. and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
  39. continue
  40. cos_value = prediction[4, i, j]
  41. sin_value = prediction[5, i, j]
  42. direction = math.atan2(sin_value, cos_value)
  43. marking_point = MarkingPoint(
  44. xval, yval, direction, prediction[1, i, j])
  45. predicted_points.append((prediction[0, i, j], marking_point))
  46. return non_maximum_suppression(predicted_points)
  47. def pair_marking_points(point_a, point_b):
  48. distance = calc_point_squre_dist(point_a, point_b)
  49. if not (config.VSLOT_MIN_DISTANCE <= distance <= config.VSLOT_MAX_DISTANCE
  50. or config.HSLOT_MIN_DISTANCE <= distance <= config.HSLOT_MAX_DISTANCE):
  51. return 0
  52. vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
  53. vector_ab = vector_ab / np.linalg.norm(vector_ab)
  54. point_shape_a = detemine_point_shape(point_a, vector_ab)
  55. point_shape_b = detemine_point_shape(point_b, -vector_ab)
  56. if point_shape_a.value == 0 or point_shape_b.value == 0:
  57. return 0
  58. if point_shape_a.value == 3 and point_shape_b.value == 3:
  59. return 0
  60. if point_shape_a.value > 3 and point_shape_b.value > 3:
  61. return 0
  62. if point_shape_a.value < 3 and point_shape_b.value < 3:
  63. return 0
  64. if point_shape_a.value != 3:
  65. if point_shape_a.value > 3:
  66. return 1
  67. if point_shape_a.value < 3:
  68. return -1
  69. if point_shape_a.value == 3:
  70. if point_shape_b.value < 3:
  71. return 1
  72. if point_shape_b.value > 3:
  73. return -1
  74. def filter_slots(marking_points, slots):
  75. suppressed = [False] * len(slots)
  76. for i, slot in enumerate(slots):
  77. x1 = marking_points[slot[0]].x
  78. y1 = marking_points[slot[0]].y
  79. x2 = marking_points[slot[1]].x
  80. y2 = marking_points[slot[1]].y
  81. for point_idx, point in enumerate(marking_points):
  82. if point_idx == slot[0] or point_idx == slot[1]:
  83. continue
  84. x0 = point.x
  85. y0 = point.y
  86. vec1 = np.array([x0 - x1, y0 - y1])
  87. vec2 = np.array([x2 - x0, y2 - y0])
  88. vec1 = vec1 / np.linalg.norm(vec1)
  89. vec2 = vec2 / np.linalg.norm(vec2)
  90. if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
  91. suppressed[i] = True
  92. if any(suppressed):
  93. unsupres_slots = []
  94. for i, supres in enumerate(suppressed):
  95. if not supres:
  96. unsupres_slots.append(slots[i])
  97. return unsupres_slots
  98. return slots