车位角点检测代码
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

il y a 5 ans
il y a 5 ans
il y a 5 ans
il y a 5 ans
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. """Defines related function to process defined data structure."""
  2. import math
  3. import numpy as np
  4. import torch
  5. import config
  6. from data.struct import MarkingPoint, detemine_point_shape
  7. def non_maximum_suppression(pred_points):
  8. """Perform non-maxmum suppression on marking points."""
  9. suppressed = [False] * len(pred_points)
  10. for i in range(len(pred_points) - 1):
  11. for j in range(i + 1, len(pred_points)):
  12. i_x = pred_points[i][1].x
  13. i_y = pred_points[i][1].y
  14. j_x = pred_points[j][1].x
  15. j_y = pred_points[j][1].y
  16. # 0.0625 = 1 / 16
  17. if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
  18. idx = i if pred_points[i][0] < pred_points[j][0] else j
  19. suppressed[idx] = True
  20. if any(suppressed):
  21. unsupres_pred_points = []
  22. for i, supres in enumerate(suppressed):
  23. if not supres:
  24. unsupres_pred_points.append(pred_points[i])
  25. return unsupres_pred_points
  26. return pred_points
  27. def get_predicted_points(prediction, thresh):
  28. """Get marking points from one predicted feature map."""
  29. assert isinstance(prediction, torch.Tensor)
  30. predicted_points = []
  31. prediction = prediction.detach().cpu().numpy()
  32. for i in range(prediction.shape[1]):
  33. for j in range(prediction.shape[2]):
  34. if prediction[0, i, j] >= thresh:
  35. xval = (j + prediction[2, i, j]) / prediction.shape[2]
  36. yval = (i + prediction[3, i, j]) / prediction.shape[1]
  37. if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
  38. and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
  39. continue
  40. cos_value = prediction[4, i, j]
  41. sin_value = prediction[5, i, j]
  42. direction = math.atan2(sin_value, cos_value)
  43. marking_point = MarkingPoint(
  44. xval, yval, direction, prediction[1, i, j])
  45. predicted_points.append((prediction[0, i, j], marking_point))
  46. return non_maximum_suppression(predicted_points)
  47. def pass_through_third_point(marking_points, i, j):
  48. """See whether the line between two points pass through a third point."""
  49. x_1 = marking_points[i].x
  50. y_1 = marking_points[i].y
  51. x_2 = marking_points[j].x
  52. y_2 = marking_points[j].y
  53. for point_idx, point in enumerate(marking_points):
  54. if point_idx == i or point_idx == j:
  55. continue
  56. x_0 = point.x
  57. y_0 = point.y
  58. vec1 = np.array([x_0 - x_1, y_0 - y_1])
  59. vec2 = np.array([x_2 - x_0, y_2 - y_0])
  60. vec1 = vec1 / np.linalg.norm(vec1)
  61. vec2 = vec2 / np.linalg.norm(vec2)
  62. if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
  63. return True
  64. return False
  65. def pair_marking_points(point_a, point_b):
  66. """See whether two marking points form a slot."""
  67. vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
  68. vector_ab = vector_ab / np.linalg.norm(vector_ab)
  69. point_shape_a = detemine_point_shape(point_a, vector_ab)
  70. point_shape_b = detemine_point_shape(point_b, -vector_ab)
  71. if point_shape_a.value == 0 or point_shape_b.value == 0:
  72. return 0
  73. if point_shape_a.value == 3 and point_shape_b.value == 3:
  74. return 0
  75. if point_shape_a.value > 3 and point_shape_b.value > 3:
  76. return 0
  77. if point_shape_a.value < 3 and point_shape_b.value < 3:
  78. return 0
  79. if point_shape_a.value != 3:
  80. if point_shape_a.value > 3:
  81. return 1
  82. if point_shape_a.value < 3:
  83. return -1
  84. if point_shape_a.value == 3:
  85. if point_shape_b.value < 3:
  86. return 1
  87. if point_shape_b.value > 3:
  88. return -1