车位角点检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

108 lines
3.8KB

  1. """Inference demo of directional point detector."""
  2. import math
  3. import cv2 as cv
  4. import numpy as np
  5. import torch
  6. from torchvision.transforms import ToTensor
  7. import config
  8. from data import get_predicted_points
  9. from model import DirectionalPointDetector
  10. from util import Timer
  11. def plot_marking_points(image, marking_points):
  12. """Plot marking points on the image and show."""
  13. height = image.shape[0]
  14. width = image.shape[1]
  15. for marking_point in marking_points:
  16. p0_x = width * marking_point.x - 0.5
  17. p0_y = height * marking_point.y - 0.5
  18. cos_val = math.cos(marking_point.direction)
  19. sin_val = math.sin(marking_point.direction)
  20. p1_x = p0_x + 50*cos_val
  21. p1_y = p0_y + 50*sin_val
  22. p2_x = p0_x - 50*sin_val
  23. p2_y = p0_y + 50*cos_val
  24. p3_x = p0_x + 50*sin_val
  25. p3_y = p0_y - 50*cos_val
  26. p0_x = int(round(p0_x))
  27. p0_y = int(round(p0_y))
  28. p1_x = int(round(p1_x))
  29. p1_y = int(round(p1_y))
  30. p2_x = int(round(p2_x))
  31. p2_y = int(round(p2_y))
  32. cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
  33. if marking_point.shape > 0.5:
  34. cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255))
  35. else:
  36. p3_x = int(round(p3_x))
  37. p3_y = int(round(p3_y))
  38. cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255))
  39. def preprocess_image(image):
  40. """Preprocess numpy image to torch tensor."""
  41. if image.shape[0] != 512 or image.shape[1] != 512:
  42. image = cv.resize(image, (512, 512))
  43. return torch.unsqueeze(ToTensor()(image), 0)
  44. def detect_video(detector, device, args):
  45. """Demo for detecting video."""
  46. timer = Timer()
  47. input_video = cv.VideoCapture(args.video)
  48. frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
  49. frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
  50. output_video = cv.VideoWriter()
  51. if args.save:
  52. output_video.open('record.avi', cv.VideoWriter_fourcc(* 'MJPG'),
  53. input_video.get(cv.CAP_PROP_FPS),
  54. (frame_width, frame_height))
  55. frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
  56. while input_video.read(frame)[0]:
  57. if args.timing:
  58. timer.tic()
  59. prediction = detector(preprocess_image(frame).to(device))
  60. if args.timing:
  61. timer.toc()
  62. pred_points = get_predicted_points(prediction[0], args.thresh)
  63. if pred_points:
  64. plot_marking_points(frame, list(list(zip(*pred_points))[1]))
  65. cv.imshow('demo', frame)
  66. cv.waitKey(1)
  67. if args.save:
  68. output_video.write(frame)
  69. input_video.release()
  70. output_video.release()
  71. def detect_image(detector, device, args):
  72. """Demo for detecting images."""
  73. image_file = input('Enter image file path: ')
  74. image = cv.imread(image_file)
  75. prediction = detector(preprocess_image(image).to(device))
  76. pred_points = get_predicted_points(prediction[0], args.thresh)
  77. if pred_points:
  78. plot_marking_points(image, list(list(zip(*pred_points))[1]))
  79. cv.imshow('demo', image)
  80. cv.waitKey(1)
  81. def inference_detector(args):
  82. """Inference demo of directional point detector."""
  83. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  84. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  85. torch.set_grad_enabled(False)
  86. dp_detector = DirectionalPointDetector(
  87. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  88. dp_detector.load_state_dict(torch.load(args.detector_weights))
  89. dp_detector.eval()
  90. if args.mode == "image":
  91. detect_image(dp_detector, device, args)
  92. elif args.mode == "video":
  93. detect_video(dp_detector, device, args)
  94. if __name__ == '__main__':
  95. inference_detector(config.get_parser_for_inference().parse_args())