2018-11-21 15:45:22 +08:00
|
|
|
"""Evaluate directional marking point detector."""
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import cv2 as cv
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import config
|
|
|
|
|
import util
|
|
|
|
|
from data import match_slots, Slot
|
|
|
|
|
from model import DirectionalPointDetector
|
|
|
|
|
from inference import detect_marking_points, inference_slots
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_ground_truths(label):
    """Build the list of ground-truth slots described by a label dict.

    Args:
        label: Mapping with a 'slots' entry and a 'marks' entry. Each slot
            row starts with two indices into 'marks'; the indices are
            treated as 1-based. Each mark row starts with two coordinates.

    Returns:
        A list of ``Slot`` objects, one per slot row, or an empty list when
        the label contains no slots.
    """
    slot_table = np.array(label['slots'])
    if not slot_table.size:
        return []
    # A single slot may be stored as a flat row; promote it to a 2-D table.
    if slot_table.ndim < 2:
        slot_table = slot_table[np.newaxis, ...]

    mark_table = np.array(label['marks'])
    # Same promotion for a single flat mark row.
    if mark_table.ndim < 2:
        mark_table = mark_table[np.newaxis, ...]

    def to_slot(entry):
        """Convert one slot row into a Slot with rescaled coordinates."""
        first = mark_table[entry[0] - 1]
        second = mark_table[entry[1] - 1]
        corners = np.array([first[0], first[1], second[0], second[1]])
        # Shift by half a unit and divide by 600 (presumably the image side
        # length in pixels -- confirm against the dataset).
        return Slot(*((corners - 0.5) / 600))

    return [to_slot(entry) for entry in slot_table]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def psevaluate_detector(args):
    """Evaluate parking-slot inference built on the directional point detector.

    For every label file in ``args.label_directory`` the matching
    ``<name>.jpg`` image is read from ``args.image_directory``, marking
    points are detected, slots are inferred from them, and the predictions
    are compared with the ground truths to obtain precision/recall values
    and the average precision.

    Args:
        args: Parsed command-line namespace. Reads ``disable_cuda``,
            ``gpu_id``, ``depth_factor``, ``detector_weights``,
            ``enable_visdom``, ``label_directory`` and ``image_directory``;
            writes ``cuda``.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # map_location lets a checkpoint saved on one device (e.g. a GPU)
        # be loaded on whichever device this evaluation runs on.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)

        # Start from an empty prediction list for every image so that an
        # image without detections neither reuses the previous image's
        # slots nor breaks the one-to-one pairing with ground_truths_list.
        pred_slots = []
        if pred_points:
            # Each prediction is indexed as [0] = score, [1] = point.
            marking_points = [pred[1] for pred in pred_points]
            for slot in inference_slots(marking_points):
                point_a = marking_points[slot[0]]
                point_b = marking_points[slot[1]]
                # A slot is scored by the weaker of its two points.
                prob = min((pred_points[slot[0]][0], pred_points[slot[1]][0]))
                pred_slots.append(
                    (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        predictions_list.append(pred_slots)

        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the evaluation options, then run the
    # parking-slot evaluation with them.
    cli_args = config.get_parser_for_ps_evaluation().parse_args()
    psevaluate_detector(cli_args)
|