车位角点检测代码
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

101 Zeilen
3.7KB

  1. """Evaluate directional marking point detector."""
  2. import torch
  3. import config
  4. import util
  5. from thop import profile
  6. from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
  7. from data import ParkingSlotDataset
  8. from model import DirectionalPointDetector
  9. from train import generate_objective
  10. def is_gt_and_pred_matched(ground_truths, predictions, thresh):
  11. """Check if there is any false positive or false negative."""
  12. predictions = [pred for pred in predictions if pred[0] >= thresh]
  13. prediction_matched = [False] * len(predictions)
  14. for ground_truth in ground_truths:
  15. idx = util.match_gt_with_preds(ground_truth, predictions,
  16. match_marking_points)
  17. if idx < 0:
  18. return False
  19. prediction_matched[idx] = True
  20. if not all(prediction_matched):
  21. return False
  22. return True
  23. def collect_error(ground_truths, predictions, thresh):
  24. """Collect errors for those correctly detected points."""
  25. dists = []
  26. angles = []
  27. predictions = [pred for pred in predictions if pred[0] >= thresh]
  28. for ground_truth in ground_truths:
  29. idx = util.match_gt_with_preds(ground_truth, predictions,
  30. match_marking_points)
  31. if idx >= 0:
  32. detected_point = predictions[idx][1]
  33. dists.append(calc_point_squre_dist(detected_point, ground_truth))
  34. angles.append(calc_point_direction_angle(
  35. detected_point, ground_truth))
  36. else:
  37. continue
  38. return dists, angles
  39. def evaluate_detector(args):
  40. """Evaluate directional point detector."""
  41. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  42. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  43. torch.set_grad_enabled(False)
  44. dp_detector = DirectionalPointDetector(
  45. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  46. if args.detector_weights:
  47. dp_detector.load_state_dict(torch.load(args.detector_weights))
  48. dp_detector.eval()
  49. psdataset = ParkingSlotDataset(args.dataset_directory)
  50. logger = util.Logger(enable_visdom=args.enable_visdom)
  51. total_loss = 0
  52. position_errors = []
  53. direction_errors = []
  54. ground_truths_list = []
  55. predictions_list = []
  56. for iter_idx, (image, marking_points) in enumerate(psdataset):
  57. ground_truths_list.append(marking_points)
  58. image = torch.unsqueeze(image, 0).to(device)
  59. prediction = dp_detector(image)
  60. objective, gradient = generate_objective([marking_points], device)
  61. loss = (prediction - objective) ** 2
  62. total_loss += torch.sum(loss*gradient).item()
  63. pred_points = get_predicted_points(prediction[0], 0.01)
  64. predictions_list.append(pred_points)
  65. dists, angles = collect_error(marking_points, pred_points,
  66. config.CONFID_THRESH_FOR_POINT)
  67. position_errors += dists
  68. direction_errors += angles
  69. logger.log(iter=iter_idx, total_loss=total_loss)
  70. precisions, recalls = util.calc_precision_recall(
  71. ground_truths_list, predictions_list, match_marking_points)
  72. average_precision = util.calc_average_precision(precisions, recalls)
  73. if args.enable_visdom:
  74. logger.plot_curve(precisions, recalls)
  75. sample = torch.randn(1, 3, config.INPUT_IMAGE_SIZE,
  76. config.INPUT_IMAGE_SIZE)
  77. flops, params = profile(dp_detector, inputs=(sample.to(device), ))
  78. logger.log(average_loss=total_loss / len(psdataset),
  79. average_precision=average_precision,
  80. flops=flops,
  81. params=params)
  82. if __name__ == '__main__':
  83. evaluate_detector(config.get_parser_for_evaluation().parse_args())