# Source: DMPR-PS/ps_evaluate.py (81 lines, 2.9 KiB, Python)
"""Evaluate directional marking point detector."""
import json
import os
import cv2 as cv
import numpy as np
import torch
import config
import util
from data import match_slots, Slot
from model import DirectionalPointDetector
from inference import detect_marking_points, inference_slots
def get_ground_truths(label):
    """Convert a parsed label dict into a list of ground-truth ``Slot``s.

    Args:
        label: dict with ``'slots'`` (rows of 1-based mark-index pairs) and
            ``'marks'`` (rows whose first two entries are pixel x, y).

    Returns:
        list[Slot]: one slot per labelled pair, coordinates normalized.
    """
    slot_pairs = np.array(label['slots'])
    if slot_pairs.size == 0:
        return []
    # A lone slot/mark may be stored as a flat row; lift both to 2-D.
    slot_pairs = np.atleast_2d(slot_pairs)
    mark_table = np.atleast_2d(np.array(label['marks']))
    truths = []
    for pair in slot_pairs:
        # Slot entries index into the mark table 1-based.
        first = mark_table[pair[0] - 1]
        second = mark_table[pair[1] - 1]
        corner_xy = np.array([first[0], first[1], second[0], second[1]])
        # Map pixel centers to [0, 1) — assumes 600x600 images; TODO confirm.
        truths.append(Slot(*((corner_xy - 0.5) / 600)))
    return truths
def psevaluate_detector(args):
    """Evaluate the directional point detector on a labelled slot dataset.

    Args:
        args: parsed CLI namespace providing ``disable_cuda``, ``gpu_id``,
            ``depth_factor``, ``detector_weights``, ``enable_visdom``,
            ``image_directory`` and ``label_directory``.

    Side effects: prints progress, optionally plots a PR curve via visdom,
    and logs the final average precision.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # map_location lets CUDA-saved checkpoints load on a CPU-only host.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    # Evaluation mode must be set regardless of whether weights were loaded.
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)
        marking_points = []
        slots = []
        if pred_points:
            # pred_points is (confidence, MarkingPoint) pairs; unzip points.
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)
        pred_slots = []
        for slot in slots:
            point_a = marking_points[slot[0]]
            point_b = marking_points[slot[1]]
            # A slot is only as confident as its weaker end point.
            prob = min((pred_points[slot[0]][0], pred_points[slot[1]][0]))
            pred_slots.append(
                (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        # Append even when empty so predictions stay index-aligned with the
        # ground truths collected below — one entry per label file.
        predictions_list.append(pred_slots)

        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)
if __name__ == '__main__':
    # Script entry point: parse the parking-slot evaluation CLI arguments
    # and run the detector evaluation.
    psevaluate_detector(config.get_parser_for_ps_evaluation().parse_args())