|
- """Evaluate directional marking point detector."""
- import json
- import os
- import cv2 as cv
- import numpy as np
- import torch
- import config
- import util
- from data import match_slots, Slot
- from model import DirectionalPointDetector
- from inference import detect_marking_points, inference_slots
-
-
def get_ground_truths(label):
    """Convert a label dictionary into a list of ground-truth ``Slot`` objects.

    ``label['slots']`` holds one entry per slot whose first two values are
    1-based indices into ``label['marks']``. The first two values of each
    referenced mark are taken as its coordinates, shifted by 0.5 and divided
    by 600 before being packed into a ``Slot``.

    Returns an empty list when the label contains no slots; in that case
    ``label['marks']`` is never accessed.
    """
    slot_array = np.array(label['slots'])
    if slot_array.size == 0:
        return []
    # A single slot / mark may be stored as a flat list; promote it to 2-D.
    if slot_array.ndim < 2:
        slot_array = np.expand_dims(slot_array, axis=0)
    mark_array = np.array(label['marks'])
    if mark_array.ndim < 2:
        mark_array = np.expand_dims(mark_array, axis=0)

    def to_slot(entry):
        # Mark indices in the label are 1-based.
        first = mark_array[entry[0] - 1]
        second = mark_array[entry[1] - 1]
        endpoints = np.array([first[0], first[1], second[0], second[1]])
        return Slot(*((endpoints - 0.5) / 600))

    return [to_slot(entry) for entry in slot_array]
-
-
def psevaluate_detector(args):
    """Evaluate the directional point detector on parking-slot labels.

    For every file in ``args.label_directory`` the image with the same base
    name and a ``.jpg`` extension is read from ``args.image_directory``,
    marking points are detected on it, slots are inferred from those points,
    and the predictions are collected next to the ground-truth slots parsed
    from the label file. Precision/recall and average precision are then
    computed with ``match_slots`` as the matching function and logged.

    Args:
        args: Parsed command-line options. Attributes read here:
            ``disable_cuda``, ``gpu_id``, ``depth_factor``,
            ``detector_weights``, ``enable_visdom``, ``label_directory`` and
            ``image_directory``. ``args.cuda`` is written back onto it.

    Raises:
        FileNotFoundError: If the image belonging to a label file is missing
            or cannot be decoded.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # Remap stored tensors onto the evaluation device so a checkpoint
        # saved on a GPU still loads on a CPU-only host or another GPU id.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image_path = os.path.join(args.image_directory, name + '.jpg')
        image = cv.imread(image_path)
        if image is None:
            # cv.imread signals a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError('Cannot read image: ' + image_path)
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)

        pred_slots = []
        if pred_points:
            # Each prediction is indexed as [0] = confidence, [1] = point.
            marking_points = [pred[1] for pred in pred_points]
            for slot in inference_slots(marking_points):
                point_a = marking_points[slot[0]]
                point_b = marking_points[slot[1]]
                # A slot is only as confident as its weaker marking point.
                prob = min(pred_points[slot[0]][0], pred_points[slot[1]][0])
                pred_slots.append(
                    (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        predictions_list.append(pred_slots)

        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)
-
if __name__ == '__main__':
    # Script entry point: parse the evaluation options and run the evaluation.
    parser = config.get_parser_for_ps_evaluation()
    cli_args = parser.parse_args()
    psevaluate_detector(cli_args)
|