Parking-slot corner point detection — inference demo (193 lines, 7.2 KB).
  1. """Inference demo of directional point detector."""
  2. import math
  3. import cv2 as cv
  4. import numpy as np
  5. import torch
  6. from torchvision.transforms import ToTensor
  7. import config
  8. from data import get_predicted_points, pair_marking_points, calc_point_squre_dist, pass_through_third_point
  9. from model import DirectionalPointDetector
  10. from util import Timer
  11. def plot_points(image, pred_points):
  12. """Plot marking points on the image."""
  13. if not pred_points:
  14. return
  15. height = image.shape[0]
  16. width = image.shape[1]
  17. for confidence, marking_point in pred_points:
  18. p0_x = width * marking_point.x - 0.5
  19. p0_y = height * marking_point.y - 0.5
  20. cos_val = math.cos(marking_point.direction)
  21. sin_val = math.sin(marking_point.direction)
  22. p1_x = p0_x + 50*cos_val
  23. p1_y = p0_y + 50*sin_val
  24. p2_x = p0_x - 50*sin_val
  25. p2_y = p0_y + 50*cos_val
  26. p3_x = p0_x + 50*sin_val
  27. p3_y = p0_y - 50*cos_val
  28. p0_x = int(round(p0_x))
  29. p0_y = int(round(p0_y))
  30. p1_x = int(round(p1_x))
  31. p1_y = int(round(p1_y))
  32. p2_x = int(round(p2_x))
  33. p2_y = int(round(p2_y))
  34. cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), 2)
  35. cv.putText(image, str(confidence), (p0_x, p0_y),
  36. cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
  37. if marking_point.shape > 0.5:
  38. cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), 2)
  39. else:
  40. p3_x = int(round(p3_x))
  41. p3_y = int(round(p3_y))
  42. cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), 2)
  43. def plot_slots(image, pred_points, slots):
  44. """Plot parking slots on the image."""
  45. if not pred_points or not slots:
  46. return
  47. marking_points = list(list(zip(*pred_points))[1])
  48. height = image.shape[0]
  49. width = image.shape[1]
  50. for slot in slots:
  51. point_a = marking_points[slot[0]]
  52. point_b = marking_points[slot[1]]
  53. p0_x = width * point_a.x - 0.5
  54. p0_y = height * point_a.y - 0.5
  55. p1_x = width * point_b.x - 0.5
  56. p1_y = height * point_b.y - 0.5
  57. vec = np.array([p1_x - p0_x, p1_y - p0_y])
  58. vec = vec / np.linalg.norm(vec)
  59. distance = calc_point_squre_dist(point_a, point_b)
  60. if config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST:
  61. separating_length = config.LONG_SEPARATOR_LENGTH
  62. elif config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST:
  63. separating_length = config.SHORT_SEPARATOR_LENGTH
  64. p2_x = p0_x + height * separating_length * vec[1]
  65. p2_y = p0_y - width * separating_length * vec[0]
  66. p3_x = p1_x + height * separating_length * vec[1]
  67. p3_y = p1_y - width * separating_length * vec[0]
  68. p0_x = int(round(p0_x))
  69. p0_y = int(round(p0_y))
  70. p1_x = int(round(p1_x))
  71. p1_y = int(round(p1_y))
  72. p2_x = int(round(p2_x))
  73. p2_y = int(round(p2_y))
  74. p3_x = int(round(p3_x))
  75. p3_y = int(round(p3_y))
  76. cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0), 2)
  77. cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0), 2)
  78. cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0), 2)
  79. def preprocess_image(image):
  80. """Preprocess numpy image to torch tensor."""
  81. if image.shape[0] != 512 or image.shape[1] != 512:
  82. image = cv.resize(image, (512, 512))
  83. return torch.unsqueeze(ToTensor()(image), 0)
  84. def detect_marking_points(detector, image, thresh, device):
  85. """Given image read from opencv, return detected marking points."""
  86. prediction = detector(preprocess_image(image).to(device))
  87. return get_predicted_points(prediction[0], thresh)
  88. def inference_slots(marking_points):
  89. """Inference slots based on marking points."""
  90. num_detected = len(marking_points)
  91. slots = []
  92. for i in range(num_detected - 1):
  93. for j in range(i + 1, num_detected):
  94. point_i = marking_points[i]
  95. point_j = marking_points[j]
  96. # Step 1: length filtration.
  97. distance = calc_point_squre_dist(point_i, point_j)
  98. if not (config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST
  99. or config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST):
  100. continue
  101. # Step 2: pass through filtration.
  102. if pass_through_third_point(marking_points, i, j):
  103. continue
  104. result = pair_marking_points(point_i, point_j)
  105. if result == 1:
  106. slots.append((i, j))
  107. elif result == -1:
  108. slots.append((j, i))
  109. return slots
  110. def detect_video(detector, device, args):
  111. """Demo for detecting video."""
  112. timer = Timer()
  113. input_video = cv.VideoCapture(args.video)
  114. frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
  115. frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
  116. output_video = cv.VideoWriter()
  117. if args.save:
  118. output_video.open('record.avi', cv.VideoWriter_fourcc(*'XVID'),
  119. input_video.get(cv.CAP_PROP_FPS),
  120. (frame_width, frame_height), True)
  121. frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
  122. while input_video.read(frame)[0]:
  123. timer.tic()
  124. pred_points = detect_marking_points(
  125. detector, frame, args.thresh, device)
  126. slots = None
  127. if pred_points and args.inference_slot:
  128. marking_points = list(list(zip(*pred_points))[1])
  129. slots = inference_slots(marking_points)
  130. timer.toc()
  131. plot_points(frame, pred_points)
  132. plot_slots(frame, pred_points, slots)
  133. cv.imshow('demo', frame)
  134. cv.waitKey(1)
  135. if args.save:
  136. output_video.write(frame)
  137. print("Average time: ", timer.calc_average_time(), "s.")
  138. input_video.release()
  139. output_video.release()
  140. def detect_image(detector, device, args):
  141. """Demo for detecting images."""
  142. timer = Timer()
  143. while True:
  144. image_file = input('Enter image file path: ')
  145. image = cv.imread(image_file)
  146. timer.tic()
  147. pred_points = detect_marking_points(
  148. detector, image, args.thresh, device)
  149. slots = None
  150. if pred_points and args.inference_slot:
  151. marking_points = list(list(zip(*pred_points))[1])
  152. slots = inference_slots(marking_points)
  153. timer.toc()
  154. plot_points(image, pred_points)
  155. plot_slots(image, pred_points, slots)
  156. cv.imshow('demo', image)
  157. cv.waitKey(1)
  158. if args.save:
  159. cv.imwrite('save.jpg', image, [int(cv.IMWRITE_JPEG_QUALITY), 100])
  160. def inference_detector(args):
  161. """Inference demo of directional point detector."""
  162. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  163. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  164. torch.set_grad_enabled(False)
  165. dp_detector = DirectionalPointDetector(
  166. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  167. dp_detector.load_state_dict(torch.load(args.detector_weights))
  168. dp_detector.eval()
  169. if args.mode == "image":
  170. detect_image(dp_detector, device, args)
  171. elif args.mode == "video":
  172. detect_video(dp_detector, device, args)
  173. if __name__ == '__main__':
  174. inference_detector(config.get_parser_for_inference().parse_args())