86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
"""Inference demo of directional point detector."""
|
|
import math
|
|
import cv2 as cv
|
|
import numpy as np
|
|
import torch
|
|
from torchvision.transforms import ToTensor
|
|
import config
|
|
from detector import DirectionalPointDetector
|
|
from utils import get_marking_points, Timer
|
|
|
|
|
|
def plot_marking_points(image, marking_points):
    """Draw each marking point as an arrow on the image, then display it.

    Args:
        image: Array-like image whose ``shape[0]`` is the height and
            ``shape[1]`` the width. The arrows are drawn directly onto it.
        marking_points: Iterable of indexable items. Index 0 is scaled by
            the image width and index 1 by the image height (presumably
            normalized coordinates -- confirm against ``get_marking_points``);
            index 2 is the arrow direction in radians.
    """
    rows, cols = image.shape[0], image.shape[1]
    for point in marking_points:
        # Arrow tail: the point position scaled to pixel units.
        tail_x = cols * point[0]
        tail_y = rows * point[1]
        # Arrow head: 50 pixels away from the tail along the direction.
        angle = point[2]
        head_x = tail_x + 50 * math.cos(angle)
        head_y = tail_y + 50 * math.sin(angle)
        # OpenCV drawing functions need integer pixel coordinates.
        tail = (int(round(tail_x)), int(round(tail_y)))
        head = (int(round(head_x)), int(round(head_y)))
        # (0, 0, 255) is red when the image is in OpenCV's BGR order.
        cv.arrowedLine(image, tail, head, (0, 0, 255))
    cv.imshow('demo', image)
    # A short wait gives the GUI a chance to actually paint the window.
    cv.waitKey(1)
|
|
|
|
|
|
def preprocess_image(image):
    """Turn a numpy image into a batched torch tensor for the detector.

    The image is resized to 512x512 unless its first two dimensions already
    match, converted with torchvision's ``ToTensor`` and given a leading
    batch dimension of size 1.

    Args:
        image: Numpy image; ``shape[0]`` and ``shape[1]`` are its height
            and width.

    Returns:
        The converted tensor with an extra dimension inserted at position 0.
    """
    target_size = (512, 512)
    if tuple(image.shape[:2]) != target_size:
        image = cv.resize(image, target_size)
    tensor = ToTensor()(image)
    return tensor.unsqueeze(0)
|
|
|
|
|
|
def detect_video(detector, device, args):
    """Run the detector on each frame of a video, display and optionally save.

    Every frame is preprocessed, passed through ``detector``, and the
    resulting marking points are drawn onto the frame and shown. When
    ``args.save`` is set, the annotated frames are also written to
    ``record.avi`` (MJPG) at the input video's frame rate and size.

    Args:
        detector: Callable model, invoked on the preprocessed frame tensor.
        device: Torch device the input tensor is moved to.
        args: Parsed options; reads ``video``, ``save``, ``timing`` and
            ``thresh``.
    """
    timer = Timer()
    input_video = cv.VideoCapture(args.video)
    output_video = cv.VideoWriter()
    # Release the capture/writer even if inference or drawing raises, so the
    # file handles are closed and a partially written record is finalized.
    try:
        frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
        frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
        if args.save:
            output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
                              input_video.get(cv.CAP_PROP_FPS),
                              (frame_width, frame_height))
        # Preallocated buffer that the capture can decode into.
        frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
        while True:
            # Use the array returned by read(): it is the buffer above when
            # the capture could reuse it, and a fresh array otherwise (e.g.
            # when the reported frame size does not match the decoded one).
            has_frame, frame = input_video.read(frame)
            if not has_frame:
                break
            if args.timing:
                timer.tic()
            # Inference only: no autograd graph is needed for the forward pass.
            with torch.no_grad():
                prediction = detector(preprocess_image(frame).to(device))
            if args.timing:
                timer.toc()
            pred_points = get_marking_points(prediction[0], args.thresh)
            plot_marking_points(frame, pred_points)
            if args.save:
                output_video.write(frame)
    finally:
        input_video.release()
        output_video.release()
|
|
|
|
|
|
def detect_image(detector, device, args):
    """Prompt for an image path, run the detector on it and show the result.

    Args:
        detector: Callable model, invoked on the preprocessed image tensor.
        device: Torch device the input tensor is moved to.
        args: Parsed options; reads ``thresh``.

    Raises:
        OSError: If the entered path cannot be read as an image.
    """
    image_file = input('Enter image file path: ')
    image = cv.imread(image_file)
    # cv.imread signals failure (missing file, unsupported format, ...) by
    # returning None instead of raising; fail here with a clear message
    # rather than with an AttributeError further down.
    if image is None:
        raise OSError('Failed to read image file: {!r}'.format(image_file))
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        prediction = detector(preprocess_image(image).to(device))
    pred_points = get_marking_points(prediction[0], args.thresh)
    plot_marking_points(image, pred_points)
|
|
|
|
|
|
def inference_detector(args):
    """Build the detector, load its weights and run the selected demo.

    Args:
        args: Parsed options; reads ``disable_cuda``, ``gpu_id``,
            ``depth_factor``, ``detector_weights`` and ``mode``. The
            attribute ``cuda`` is set on it as a side effect. ``mode``
            selects the demo: ``"image"`` or ``"video"``; any other value
            runs nothing.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpu_id) if args.cuda else "cpu")
    # NOTE(review): 3 and 5 are presumably the input channel count and the
    # output feature-map channel count -- confirm against
    # DirectionalPointDetector.
    dp_detector = DirectionalPointDetector(3, args.depth_factor, 5).to(device)
    # map_location remaps the stored tensors onto the chosen device, so
    # weights saved from a GPU run can still be loaded on a CPU-only machine.
    dp_detector.load_state_dict(
        torch.load(args.detector_weights, map_location=device))
    # Switch layers with train/eval-specific behavior (e.g. batch norm,
    # dropout) to inference mode.
    dp_detector.eval()
    if args.mode == "image":
        detect_image(dp_detector, device, args)
    elif args.mode == "video":
        detect_video(dp_detector, device, args)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command-line options and run the demo.
    cli_parser = config.get_parser_for_inference()
    cli_args = cli_parser.parse_args()
    inference_detector(cli_args)
|