车位角点检测代码
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

101 行
3.7KB

  1. """Evaluate directional marking point detector."""
  2. import torch
  3. import config
  4. import util
  5. from thop import profile
  6. from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
  7. from data import ParkingSlotDataset
  8. from model import DirectionalPointDetector
  9. from train import generate_objective
  10. def is_gt_and_pred_matched(ground_truths, predictions, thresh):
  11. """Check if there is any false positive or false negative."""
  12. predictions = [pred for pred in predictions if pred[0] >= thresh]
  13. prediction_matched = [False] * len(predictions)
  14. for ground_truth in ground_truths:
  15. idx = util.match_gt_with_preds(ground_truth, predictions,
  16. match_marking_points)
  17. if idx < 0:
  18. return False
  19. prediction_matched[idx] = True
  20. if not all(prediction_matched):
  21. return False
  22. return True
  23. def collect_error(ground_truths, predictions, thresh):
  24. """Collect errors for those correctly detected points."""
  25. dists = []
  26. angles = []
  27. predictions = [pred for pred in predictions if pred[0] >= thresh]
  28. for ground_truth in ground_truths:
  29. idx = util.match_gt_with_preds(ground_truth, predictions,
  30. match_marking_points)
  31. if idx >= 0:
  32. detected_point = predictions[idx][1]
  33. dists.append(calc_point_squre_dist(detected_point, ground_truth))
  34. angles.append(calc_point_direction_angle(
  35. detected_point, ground_truth))
  36. else:
  37. continue
  38. return dists, angles
  39. def evaluate_detector(args):
  40. """Evaluate directional point detector."""
  41. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  42. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  43. torch.set_grad_enabled(False)
  44. dp_detector = DirectionalPointDetector(
  45. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  46. if args.detector_weights:
  47. dp_detector.load_state_dict(torch.load(args.detector_weights))
  48. dp_detector.eval()
  49. psdataset = ParkingSlotDataset(args.dataset_directory)
  50. logger = util.Logger(enable_visdom=args.enable_visdom)
  51. total_loss = 0
  52. position_errors = []
  53. direction_errors = []
  54. ground_truths_list = []
  55. predictions_list = []
  56. for iter_idx, (image, marking_points) in enumerate(psdataset):
  57. ground_truths_list.append(marking_points)
  58. image = torch.unsqueeze(image, 0).to(device)
  59. prediction = dp_detector(image)
  60. objective, gradient = generate_objective([marking_points], device)
  61. loss = (prediction - objective) ** 2
  62. total_loss += torch.sum(loss*gradient).item()
  63. pred_points = get_predicted_points(prediction[0], 0.01)
  64. predictions_list.append(pred_points)
  65. dists, angles = collect_error(marking_points, pred_points,
  66. config.CONFID_THRESH_FOR_POINT)
  67. position_errors += dists
  68. direction_errors += angles
  69. logger.log(iter=iter_idx, total_loss=total_loss)
  70. precisions, recalls = util.calc_precision_recall(
  71. ground_truths_list, predictions_list, match_marking_points)
  72. average_precision = util.calc_average_precision(precisions, recalls)
  73. if args.enable_visdom:
  74. logger.plot_curve(precisions, recalls)
  75. sample = torch.randn(1, 3, config.INPUT_IMAGE_SIZE,
  76. config.INPUT_IMAGE_SIZE)
  77. flops, params = profile(dp_detector, inputs=(sample.to(device), ))
  78. logger.log(average_loss=total_loss / len(psdataset),
  79. average_precision=average_precision,
  80. flops=flops,
  81. params=params)
  82. if __name__ == '__main__':
  83. evaluate_detector(config.get_parser_for_evaluation().parse_args())