177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
"""Inference demo of directional point detector."""
|
|
import math
|
|
import cv2 as cv
|
|
import numpy as np
|
|
import torch
|
|
from torchvision.transforms import ToTensor
|
|
import config
|
|
from data import get_predicted_points, pair_marking_points, filter_slots
|
|
from model import DirectionalPointDetector
|
|
from util import Timer
|
|
|
|
|
|
def plot_points(image, pred_points):
    """Draw each predicted marking point onto ``image`` in place.

    Args:
        image: Array-like image that OpenCV can draw on; only ``shape[0]``
            (height) and ``shape[1]`` (width) are read here.
        pred_points: Iterable of ``(confidence, marking_point)`` pairs. Each
            marking point exposes ``x``, ``y``, ``direction`` and ``shape``.
            ``x`` and ``y`` are scaled by the image width and height, so they
            are presumably normalised to [0, 1] -- confirm against the
            detector output.

    Returns:
        None. Nothing is drawn when ``pred_points`` is empty or None.
    """
    if not pred_points:
        return

    def to_pixel(value):
        # Snap a floating-point coordinate to the nearest integer pixel.
        return int(round(value))

    rows, cols = image.shape[0], image.shape[1]
    line_colour = (0, 0, 255)
    for score, point in pred_points:
        origin_x = cols * point.x - 0.5
        origin_y = rows * point.y - 0.5
        # 50-pixel offsets along the point's direction.
        along_x = 50 * math.cos(point.direction)
        along_y = 50 * math.sin(point.direction)

        origin = (to_pixel(origin_x), to_pixel(origin_y))
        tip = (to_pixel(origin_x + along_x), to_pixel(origin_y + along_y))
        # Perpendicular to the direction: the offset rotated by 90 degrees.
        side = (to_pixel(origin_x - along_y), to_pixel(origin_y + along_x))

        cv.line(image, origin, tip, line_colour)
        cv.putText(image, str(score), origin,
                   cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
        if point.shape > 0.5:
            # Single perpendicular arm: the two strokes form an "L".
            cv.line(image, origin, side, line_colour)
        else:
            # Perpendicular bar through the origin: the strokes form a "T".
            opposite = (to_pixel(origin_x + along_y),
                        to_pixel(origin_y - along_x))
            cv.line(image, side, opposite, line_colour)
|
|
|
|
|
|
def plot_slots(image, pred_points, slots):
    """Draw the outline of every inferred slot onto ``image`` in place.

    For each slot, three segments are drawn: the segment joining its two
    marking points, plus one 200-pixel segment from each of those points,
    perpendicular to the joining segment.

    Args:
        image: Array-like image that OpenCV can draw on; only ``shape[0]``
            (height) and ``shape[1]`` (width) are read here.
        pred_points: Sequence of ``(confidence, marking_point)`` pairs.
        slots: Iterable of index pairs into ``pred_points``; ``slot[0]`` and
            ``slot[1]`` select the two marking points of a slot.

    Returns:
        None. Nothing is drawn when ``pred_points`` or ``slots`` is empty.
    """
    if not (pred_points and slots):
        return

    def to_pixel(value):
        # Snap a floating-point coordinate to the nearest integer pixel.
        return int(round(value))

    points = [entry[1] for entry in pred_points]
    rows, cols = image.shape[0], image.shape[1]
    slot_colour = (255, 0, 0)
    for slot in slots:
        first, second = points[slot[0]], points[slot[1]]
        ax = cols * first.x - 0.5
        ay = rows * first.y - 0.5
        bx = cols * second.x - 0.5
        by = rows * second.y - 0.5

        # Unit vector from the first marking point to the second one.
        entrance = np.array([bx - ax, by - ay])
        entrance = entrance / np.linalg.norm(entrance)
        # 200-pixel offset perpendicular to the entrance segment.
        depth_x = 200 * entrance[1]
        depth_y = 200 * entrance[0]

        corner_a = (to_pixel(ax), to_pixel(ay))
        corner_b = (to_pixel(bx), to_pixel(by))
        far_a = (to_pixel(ax + depth_x), to_pixel(ay - depth_y))
        far_b = (to_pixel(bx + depth_x), to_pixel(by - depth_y))

        cv.line(image, corner_a, corner_b, slot_colour)
        cv.line(image, corner_a, far_a, slot_colour)
        cv.line(image, corner_b, far_b, slot_colour)
|
|
|
|
|
|
def preprocess_image(image):
    """Turn an OpenCV image into a single-image batch tensor.

    The image is resized to 512x512 unless it already has that height and
    width, converted with torchvision's ``ToTensor``, and given a leading
    dimension of size one.

    Args:
        image: Image array exposing ``shape[0]`` (height) and ``shape[1]``
            (width).

    Returns:
        The converted tensor with an extra leading dimension.
    """
    rows, cols = image.shape[0], image.shape[1]
    if not (rows == 512 and cols == 512):
        image = cv.resize(image, (512, 512))
    tensor = ToTensor()(image)
    return torch.unsqueeze(tensor, 0)
|
|
|
|
|
|
def detect_marking_points(detector, image, thresh, device):
    """Run the detector on one OpenCV image and decode its marking points.

    Args:
        detector: Callable model applied to the preprocessed batch.
        image: Image as read by OpenCV.
        thresh: Threshold forwarded to ``get_predicted_points``.
        device: Torch device the input batch is moved to before inference.

    Returns:
        Whatever ``get_predicted_points`` produces for the first (and only)
        element of the prediction batch.
    """
    batch = preprocess_image(image)
    batch = batch.to(device)
    outputs = detector(batch)
    # The batch holds a single image, so only its first prediction is decoded.
    return get_predicted_points(outputs[0], thresh)
|
|
|
|
|
|
def inference_slots(marking_points):
    """Infer slots by testing every unordered pair of marking points.

    Args:
        marking_points: Sequence of marking points.

    Returns:
        The result of ``filter_slots`` applied to the candidate index pairs.
        A pair ``(i, j)`` with ``i < j`` is recorded as ``(i, j)`` when
        ``pair_marking_points`` returns 1 and as ``(j, i)`` when it returns
        -1; any other result discards the pair.
    """
    total = len(marking_points)
    candidates = []
    for first, point_a in enumerate(marking_points):
        for second in range(first + 1, total):
            verdict = pair_marking_points(point_a, marking_points[second])
            if verdict == 1:
                candidates.append((first, second))
            elif verdict == -1:
                # Same pair, but with the order of the two points swapped.
                candidates.append((second, first))
    return filter_slots(marking_points, candidates)
|
|
|
|
|
|
def detect_video(detector, device, args):
    """Demo for detecting video.

    Reads ``args.video`` frame by frame, detects marking points (and slots
    when ``args.inference_slot`` is set), draws them on the frame and shows
    it in the 'demo' window. When ``args.save`` is set, the annotated frames
    are also written to 'record.avi' as MJPG at the input frame rate.

    Args:
        detector: Model used by ``detect_marking_points``.
        device: Torch device the frames are moved to for inference.
        args: Parsed options; reads ``video``, ``save``, ``thresh`` and
            ``inference_slot``.

    Returns:
        None. Prints the average per-frame time reported by the timer.
    """
    timer = Timer()
    input_video = cv.VideoCapture(args.video)
    output_video = cv.VideoWriter()
    # Release the capture and the writer even if inference or drawing raises,
    # so the input handle is freed and a partially written file is finalised.
    try:
        frame_width = int(input_video.get(cv.CAP_PROP_FRAME_WIDTH))
        frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
        if args.save:
            output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
                              input_video.get(cv.CAP_PROP_FPS),
                              (frame_width, frame_height))
        # Frame buffer handed to read() on every iteration.
        frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
        while input_video.read(frame)[0]:
            timer.tic()
            pred_points = detect_marking_points(
                detector, frame, args.thresh, device)
            slots = None
            if pred_points and args.inference_slot:
                marking_points = list(list(zip(*pred_points))[1])
                slots = inference_slots(marking_points)
            timer.toc()
            plot_points(frame, pred_points)
            plot_slots(frame, pred_points, slots)
            cv.imshow('demo', frame)
            cv.waitKey(1)
            if args.save:
                output_video.write(frame)
        print("Average time: ", timer.calc_average_time(), "s.")
    finally:
        input_video.release()
        output_video.release()
|
|
|
|
|
|
def detect_image(detector, device, args):
    """Demo for detecting images.

    Repeatedly prompts for an image path, detects marking points (and slots
    when ``args.inference_slot`` is set), draws them on the image and shows
    it in the 'demo' window. When ``args.save`` is set, the annotated image
    is written to 'save.jpg'.

    Args:
        detector: Model used by ``detect_marking_points``.
        device: Torch device the image is moved to for inference.
        args: Parsed options; reads ``thresh``, ``inference_slot`` and
            ``save``.

    Returns:
        Never returns normally; the prompt loop runs until interrupted.
    """
    timer = Timer()
    while True:
        image_file = input('Enter image file path: ')
        image = cv.imread(image_file)
        if image is None:
            # cv.imread signals an unreadable or missing file by returning
            # None; ask again instead of crashing during preprocessing.
            print('Cannot read image file: ' + image_file)
            continue
        timer.tic()
        pred_points = detect_marking_points(
            detector, image, args.thresh, device)
        # Reset on every image: otherwise the name is unbound on the first
        # image without detections and stale on later ones.
        slots = None
        if pred_points and args.inference_slot:
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)
        timer.toc()
        plot_points(image, pred_points)
        plot_slots(image, pred_points, slots)
        cv.imshow('demo', image)
        cv.waitKey(1)
        if args.save:
            cv.imwrite('save.jpg', image, [int(cv.IMWRITE_JPEG_QUALITY), 100])
|
|
|
|
|
|
def inference_detector(args):
    """Inference demo of directional point detector.

    Selects the device, builds the detector, loads its weights and runs the
    image or video demo depending on ``args.mode``.

    Args:
        args: Parsed options; reads ``disable_cuda``, ``gpu_id``,
            ``depth_factor``, ``detector_weights`` and ``mode``. The
            attribute ``cuda`` is set on it as a side effect.

    Returns:
        None. Gradient computation is disabled globally as a side effect.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)
    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    # Remap stored tensors onto the selected device so that a checkpoint
    # saved on a GPU can still be loaded when running on the CPU.
    dp_detector.load_state_dict(
        torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()
    if args.mode == "image":
        detect_image(dp_detector, device, args)
    elif args.mode == "video":
        detect_video(dp_detector, device, args)
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse the inference command-line options and start the demo.
    cli_args = config.get_parser_for_inference().parse_args()
    inference_detector(cli_args)
|