车位角点检测代码 (parking-slot corner detection code)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

97 lines
3.7KB

  1. """Defines related function to process defined data structure."""
  2. import math
  3. import numpy as np
  4. import torch
  5. import config
  6. from data.struct import MarkingPoint, detemine_point_shape
  7. def non_maximum_suppression(pred_points):
  8. """Perform non-maxmum suppression on marking points."""
  9. suppressed = [False] * len(pred_points)
  10. for i in range(len(pred_points) - 1):
  11. for j in range(i + 1, len(pred_points)):
  12. i_x = pred_points[i][1].x
  13. i_y = pred_points[i][1].y
  14. j_x = pred_points[j][1].x
  15. j_y = pred_points[j][1].y
  16. # 0.0625 = 1 / 16
  17. if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
  18. idx = i if pred_points[i][0] < pred_points[j][0] else j
  19. suppressed[idx] = True
  20. if any(suppressed):
  21. unsupres_pred_points = []
  22. for i, supres in enumerate(suppressed):
  23. if not supres:
  24. unsupres_pred_points.append(pred_points[i])
  25. return unsupres_pred_points
  26. return pred_points
  27. def get_predicted_points(prediction, thresh):
  28. """Get marking points from one predicted feature map."""
  29. assert isinstance(prediction, torch.Tensor)
  30. predicted_points = []
  31. prediction = prediction.detach().cpu().numpy()
  32. for i in range(prediction.shape[1]):
  33. for j in range(prediction.shape[2]):
  34. if prediction[0, i, j] >= thresh:
  35. xval = (j + prediction[2, i, j]) / prediction.shape[2]
  36. yval = (i + prediction[3, i, j]) / prediction.shape[1]
  37. if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
  38. and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
  39. continue
  40. cos_value = prediction[4, i, j]
  41. sin_value = prediction[5, i, j]
  42. direction = math.atan2(sin_value, cos_value)
  43. marking_point = MarkingPoint(
  44. xval, yval, direction, prediction[1, i, j])
  45. predicted_points.append((prediction[0, i, j], marking_point))
  46. return non_maximum_suppression(predicted_points)
  47. def pass_through_third_point(marking_points, i, j):
  48. """See whether the line between two points pass through a third point."""
  49. x_1 = marking_points[i].x
  50. y_1 = marking_points[i].y
  51. x_2 = marking_points[j].x
  52. y_2 = marking_points[j].y
  53. for point_idx, point in enumerate(marking_points):
  54. if point_idx == i or point_idx == j:
  55. continue
  56. x_0 = point.x
  57. y_0 = point.y
  58. vec1 = np.array([x_0 - x_1, y_0 - y_1])
  59. vec2 = np.array([x_2 - x_0, y_2 - y_0])
  60. vec1 = vec1 / np.linalg.norm(vec1)
  61. vec2 = vec2 / np.linalg.norm(vec2)
  62. if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
  63. return True
  64. return False
  65. def pair_marking_points(point_a, point_b):
  66. """See whether two marking points form a slot."""
  67. vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
  68. vector_ab = vector_ab / np.linalg.norm(vector_ab)
  69. point_shape_a = detemine_point_shape(point_a, vector_ab)
  70. point_shape_b = detemine_point_shape(point_b, -vector_ab)
  71. if point_shape_a.value == 0 or point_shape_b.value == 0:
  72. return 0
  73. if point_shape_a.value == 3 and point_shape_b.value == 3:
  74. return 0
  75. if point_shape_a.value > 3 and point_shape_b.value > 3:
  76. return 0
  77. if point_shape_a.value < 3 and point_shape_b.value < 3:
  78. return 0
  79. if point_shape_a.value != 3:
  80. if point_shape_a.value > 3:
  81. return 1
  82. if point_shape_a.value < 3:
  83. return -1
  84. if point_shape_a.value == 3:
  85. if point_shape_b.value < 3:
  86. return 1
  87. if point_shape_b.value > 3:
  88. return -1