DMPR-PS/inference.py

86 lines
3.0 KiB
Python
Raw Normal View History

2018-07-20 16:25:15 +08:00
"""Inference demo of directional point detector."""
import math
import cv2 as cv
import numpy as np
import torch
from torchvision.transforms import ToTensor
import config
from detector import DirectionalPointDetector
from utils import get_marking_points, Timer
def plot_marking_points(image, marking_points):
"""Plot marking points on the image and show."""
height = image.shape[0]
width = image.shape[1]
for marking_point in marking_points:
p0_x = width * marking_point[0]
p0_y = height * marking_point[1]
p1_x = p0_x + 50 * math.cos(marking_point[2])
p1_y = p0_y + 50 * math.sin(marking_point[2])
p0_x = int(round(p0_x))
p0_y = int(round(p0_y))
p1_x = int(round(p1_x))
p1_y = int(round(p1_y))
cv.arrowedLine(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
cv.imshow('demo', image)
cv.waitKey(1)
def preprocess_image(image):
"""Preprocess numpy image to torch tensor."""
if image.shape[0] != 512 or image.shape[1] != 512:
image = cv.resize(image, (512, 512))
return torch.unsqueeze(ToTensor()(image), 0)
def detect_video(detector, device, args):
"""Demo for detecting video."""
timer = Timer()
input_video = cv.VideoCapture(args.video)
frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
output_video = cv.VideoWriter()
if args.save:
output_video.open('record.avi', cv.VideoWriter_fourcc(* 'MJPG'),
input_video.get(cv.CAP_PROP_FPS),
(frame_width, frame_height))
frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
while input_video.read(frame)[0]:
if args.timing:
timer.tic()
prediction = detector(preprocess_image(frame).to(device))
if args.timing:
timer.toc()
pred_points = get_marking_points(prediction[0], args.thresh)
plot_marking_points(frame, pred_points)
if args.save:
output_video.write(frame)
input_video.release()
output_video.release()
def detect_image(detector, device, args):
"""Demo for detecting images."""
image_file = input('Enter image file path: ')
image = cv.imread(image_file)
prediction = detector(preprocess_image(image).to(device))
pred_points = get_marking_points(prediction[0], args.thresh)
plot_marking_points(image, pred_points)
def inference_detector(args):
"""Inference demo of directional point detector."""
args.cuda = not args.disable_cuda and torch.cuda.is_available()
device = torch.device("cuda:" + str(args.gpu_id) if args.cuda else "cpu")
dp_detector = DirectionalPointDetector(3, args.depth_factor, 5).to(device)
dp_detector.load_state_dict(torch.load(args.detector_weights))
if args.mode == "image":
detect_image(dp_detector, device, args)
elif args.mode == "video":
detect_video(dp_detector, device, args)
if __name__ == '__main__':
inference_detector(config.get_parser_for_inference().parse_args())