车位角点检测代码
Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

82 lines
2.9KB

  1. """Evaluate directional marking point detector."""
  2. import json
  3. import os
  4. import cv2 as cv
  5. import numpy as np
  6. import torch
  7. import config
  8. import util
  9. from data import match_slots, Slot
  10. from model import DirectionalPointDetector
  11. from inference import detect_marking_points, inference_slots
  12. def get_ground_truths(label):
  13. """Read label to get ground truth slot."""
  14. slots = np.array(label['slots'])
  15. if slots.size == 0:
  16. return []
  17. if len(slots.shape) < 2:
  18. slots = np.expand_dims(slots, axis=0)
  19. marks = np.array(label['marks'])
  20. if len(marks.shape) < 2:
  21. marks = np.expand_dims(marks, axis=0)
  22. ground_truths = []
  23. for slot in slots:
  24. mark_a = marks[slot[0] - 1]
  25. mark_b = marks[slot[1] - 1]
  26. coords = np.array([mark_a[0], mark_a[1], mark_b[0], mark_b[1]])
  27. coords = (coords - 0.5) / 600
  28. ground_truths.append(Slot(*coords))
  29. return ground_truths
  30. def psevaluate_detector(args):
  31. """Evaluate directional point detector."""
  32. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  33. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  34. torch.set_grad_enabled(False)
  35. dp_detector = DirectionalPointDetector(
  36. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  37. if args.detector_weights:
  38. dp_detector.load_state_dict(torch.load(args.detector_weights))
  39. dp_detector.eval()
  40. logger = util.Logger(enable_visdom=args.enable_visdom)
  41. ground_truths_list = []
  42. predictions_list = []
  43. for idx, label_file in enumerate(os.listdir(args.label_directory)):
  44. name = os.path.splitext(label_file)[0]
  45. print(idx, name)
  46. image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
  47. pred_points = detect_marking_points(
  48. dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)
  49. slots = []
  50. if pred_points:
  51. marking_points = list(list(zip(*pred_points))[1])
  52. slots = inference_slots(marking_points)
  53. pred_slots = []
  54. for slot in slots:
  55. point_a = marking_points[slot[0]]
  56. point_b = marking_points[slot[1]]
  57. prob = min((pred_points[slot[0]][0], pred_points[slot[1]][0]))
  58. pred_slots.append(
  59. (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
  60. predictions_list.append(pred_slots)
  61. with open(os.path.join(args.label_directory, label_file), 'r') as file:
  62. ground_truths_list.append(get_ground_truths(json.load(file)))
  63. precisions, recalls = util.calc_precision_recall(
  64. ground_truths_list, predictions_list, match_slots)
  65. average_precision = util.calc_average_precision(precisions, recalls)
  66. if args.enable_visdom:
  67. logger.plot_curve(precisions, recalls)
  68. logger.log(average_precision=average_precision)
  69. if __name__ == '__main__':
  70. psevaluate_detector(config.get_parser_for_ps_evaluation().parse_args())