- """Evaluate directional marking point detector."""
- import torch
- import config
- import util
- from thop import profile
- from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
- from data import ParkingSlotDataset
- from model import DirectionalPointDetector
- from train import generate_objective
-
-
- def is_gt_and_pred_matched(ground_truths, predictions, thresh):
- """Check if there is any false positive or false negative."""
- predictions = [pred for pred in predictions if pred[0] >= thresh]
- prediction_matched = [False] * len(predictions)
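    # Greedily match every ground truth against the thresholded predictions:
    # a ground truth without a match is a false negative, and a prediction
    # left unmatched at the end is a false positive.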
    for ground_truth in ground_truths:
        idx = util.match_gt_with_preds(ground_truth, predictions,
                                       match_marking_points)
        if idx < 0:
            return False
        prediction_matched[idx] = True
    if not all(prediction_matched):
        return False
    return True


def collect_error(ground_truths, predictions, thresh):
    """Collect errors for those correctly detected points."""
    dists = []
    angles = []
    predictions = [pred for pred in predictions if pred[0] >= thresh]
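    # For every ground truth that finds a matching prediction, record the
    # squared position distance and the direction (angle) difference.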
    for ground_truth in ground_truths:
        idx = util.match_gt_with_preds(ground_truth, predictions,
                                       match_marking_points)
        if idx >= 0:
            detected_point = predictions[idx][1]
            dists.append(calc_point_squre_dist(detected_point, ground_truth))
            angles.append(calc_point_direction_angle(
                detected_point, ground_truth))
    return dists, angles


def evaluate_detector(args):
- """Evaluate directional point detector."""
- args.cuda = not args.disable_cuda and torch.cuda.is_available()
- device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
- torch.set_grad_enabled(False)
-
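    # Build the detector, optionally restore trained weights, and switch the
    # network to eval mode for inference.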
    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    psdataset = ParkingSlotDataset(args.dataset_directory)
    logger = util.Logger(enable_visdom=args.enable_visdom)

    total_loss = 0
    position_errors = []
    direction_errors = []
    ground_truths_list = []
    predictions_list = []
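    # Run the detector over the whole dataset one image at a time, accumulating
    # the evaluation loss, per-point regression errors, and the ground-truth /
    # prediction pairs needed for the precision-recall computation below.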
    for iter_idx, (image, marking_points) in enumerate(psdataset):
        ground_truths_list.append(marking_points)

        image = torch.unsqueeze(image, 0).to(device)
        prediction = dp_detector(image)
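        # generate_objective builds the regression target together with a
        # gradient mask, so only the supervised output entries contribute to
        # the accumulated squared-error loss.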
        objective, gradient = generate_objective([marking_points], device)
        loss = (prediction - objective) ** 2
        total_loss += torch.sum(loss*gradient).item()

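        # Use a very low confidence threshold (0.01) here so that almost every
        # candidate point is kept for the precision-recall computation below.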
        pred_points = get_predicted_points(prediction[0], 0.01)
        predictions_list.append(pred_points)

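        # Position and direction errors are only collected for points detected
        # above config.CONFID_THRESH_FOR_POINT.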
        dists, angles = collect_error(marking_points, pred_points,
                                      config.CONFID_THRESH_FOR_POINT)
        position_errors += dists
        direction_errors += angles

        logger.log(iter=iter_idx, total_loss=total_loss)

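    # Compute the precision-recall curve over all images and reduce it to a
    # single average-precision value.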
    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_marking_points)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)

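    # Profile the network with thop on a dummy input to report FLOPs and the
    # parameter count alongside the accuracy metrics.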
    sample = torch.randn(1, 3, config.INPUT_IMAGE_SIZE,
                         config.INPUT_IMAGE_SIZE)
    flops, params = profile(dp_detector, inputs=(sample.to(device), ))
    logger.log(average_loss=total_loss / len(psdataset),
               average_precision=average_precision,
               flops=flops,
               params=params)


if __name__ == '__main__':
    evaluate_detector(config.get_parser_for_evaluation().parse_args())