车位角点检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 6.3KB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """Inference demo of directional point detector."""
  2. import math
  3. import cv2 as cv
  4. import numpy as np
  5. import torch
  6. from torchvision.transforms import ToTensor
  7. import config
  8. from data import get_predicted_points, pair_marking_points, filter_slots
  9. from model import DirectionalPointDetector
  10. from util import Timer
  11. def plot_points(image, pred_points):
  12. """Plot marking points on the image and show."""
  13. if not pred_points:
  14. return
  15. height = image.shape[0]
  16. width = image.shape[1]
  17. for confidence, marking_point in pred_points:
  18. p0_x = width * marking_point.x - 0.5
  19. p0_y = height * marking_point.y - 0.5
  20. cos_val = math.cos(marking_point.direction)
  21. sin_val = math.sin(marking_point.direction)
  22. p1_x = p0_x + 50*cos_val
  23. p1_y = p0_y + 50*sin_val
  24. p2_x = p0_x - 50*sin_val
  25. p2_y = p0_y + 50*cos_val
  26. p3_x = p0_x + 50*sin_val
  27. p3_y = p0_y - 50*cos_val
  28. p0_x = int(round(p0_x))
  29. p0_y = int(round(p0_y))
  30. p1_x = int(round(p1_x))
  31. p1_y = int(round(p1_y))
  32. p2_x = int(round(p2_x))
  33. p2_y = int(round(p2_y))
  34. cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
  35. cv.putText(image, str(confidence), (p0_x, p0_y),
  36. cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
  37. if marking_point.shape > 0.5:
  38. cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255))
  39. else:
  40. p3_x = int(round(p3_x))
  41. p3_y = int(round(p3_y))
  42. cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255))
  43. def plot_slots(image, pred_points, slots):
  44. if not pred_points or not slots:
  45. return
  46. marking_points = list(list(zip(*pred_points))[1])
  47. height = image.shape[0]
  48. width = image.shape[1]
  49. for slot in slots:
  50. point_a = marking_points[slot[0]]
  51. point_b = marking_points[slot[1]]
  52. p0_x = width * point_a.x - 0.5
  53. p0_y = height * point_a.y - 0.5
  54. p1_x = width * point_b.x - 0.5
  55. p1_y = height * point_b.y - 0.5
  56. vec = np.array([p1_x - p0_x, p1_y - p0_y])
  57. vec = vec / np.linalg.norm(vec)
  58. p2_x = p0_x + 200*vec[1]
  59. p2_y = p0_y - 200*vec[0]
  60. p3_x = p1_x + 200*vec[1]
  61. p3_y = p1_y - 200*vec[0]
  62. p0_x = int(round(p0_x))
  63. p0_y = int(round(p0_y))
  64. p1_x = int(round(p1_x))
  65. p1_y = int(round(p1_y))
  66. p2_x = int(round(p2_x))
  67. p2_y = int(round(p2_y))
  68. p3_x = int(round(p3_x))
  69. p3_y = int(round(p3_y))
  70. cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0))
  71. cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0))
  72. cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0))
  73. def preprocess_image(image):
  74. """Preprocess numpy image to torch tensor."""
  75. if image.shape[0] != 512 or image.shape[1] != 512:
  76. image = cv.resize(image, (512, 512))
  77. return torch.unsqueeze(ToTensor()(image), 0)
  78. def detect_marking_points(detector, image, thresh, device):
  79. """Given image read from opencv, return detected marking points."""
  80. prediction = detector(preprocess_image(image).to(device))
  81. return get_predicted_points(prediction[0], thresh)
  82. def inference_slots(marking_points):
  83. """Inference slots based on marking points."""
  84. num_detected = len(marking_points)
  85. slots = []
  86. for i in range(num_detected - 1):
  87. for j in range(i + 1, num_detected):
  88. result = pair_marking_points(marking_points[i], marking_points[j])
  89. if result == 1:
  90. slots.append((i, j))
  91. elif result == -1:
  92. slots.append((j, i))
  93. slots = filter_slots(marking_points, slots)
  94. return slots
  95. def detect_video(detector, device, args):
  96. """Demo for detecting video."""
  97. timer = Timer()
  98. input_video = cv.VideoCapture(args.video)
  99. frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
  100. frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
  101. output_video = cv.VideoWriter()
  102. if args.save:
  103. output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
  104. input_video.get(cv.CAP_PROP_FPS),
  105. (frame_width, frame_height))
  106. frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
  107. while input_video.read(frame)[0]:
  108. timer.tic()
  109. pred_points = detect_marking_points(
  110. detector, frame, args.thresh, device)
  111. slots = None
  112. if pred_points and args.inference_slot:
  113. marking_points = list(list(zip(*pred_points))[1])
  114. slots = inference_slots(marking_points)
  115. timer.toc()
  116. plot_points(frame, pred_points)
  117. plot_slots(frame, pred_points, slots)
  118. cv.imshow('demo', frame)
  119. cv.waitKey(1)
  120. if args.save:
  121. output_video.write(frame)
  122. print("Average time: ", timer.calc_average_time(), "s.")
  123. input_video.release()
  124. output_video.release()
  125. def detect_image(detector, device, args):
  126. """Demo for detecting images."""
  127. timer = Timer()
  128. while True:
  129. image_file = input('Enter image file path: ')
  130. image = cv.imread(image_file)
  131. timer.tic()
  132. pred_points = detect_marking_points(
  133. detector, image, args.thresh, device)
  134. if pred_points and args.inference_slot:
  135. marking_points = list(list(zip(*pred_points))[1])
  136. slots = inference_slots(marking_points)
  137. timer.toc()
  138. plot_points(image, pred_points)
  139. plot_slots(image, pred_points, slots)
  140. cv.imshow('demo', image)
  141. cv.waitKey(1)
  142. if args.save:
  143. cv.imwrite('save.jpg', image, [int(cv.IMWRITE_JPEG_QUALITY), 100])
  144. def inference_detector(args):
  145. """Inference demo of directional point detector."""
  146. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  147. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  148. torch.set_grad_enabled(False)
  149. dp_detector = DirectionalPointDetector(
  150. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  151. dp_detector.load_state_dict(torch.load(args.detector_weights))
  152. dp_detector.eval()
  153. if args.mode == "image":
  154. detect_image(dp_detector, device, args)
  155. elif args.mode == "video":
  156. detect_video(dp_detector, device, args)
  157. if __name__ == '__main__':
  158. inference_detector(config.get_parser_for_inference().parse_args())