车位角点检测代码
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

61 lines
2.3KB

  1. """Evaluate directional marking point detector."""
  2. import torch
  3. from torch.utils.data import DataLoader
  4. import config
  5. import util
  6. from data import get_predicted_points, match_marking_points
  7. from data import ParkingSlotDataset
  8. from model import DirectionalPointDetector
  9. from train import generate_objective
  10. def evaluate_detector(args):
  11. """Evaluate directional point detector."""
  12. args.cuda = not args.disable_cuda and torch.cuda.is_available()
  13. device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
  14. torch.set_grad_enabled(False)
  15. dp_detector = DirectionalPointDetector(
  16. 3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  17. if args.detector_weights:
  18. dp_detector.load_state_dict(torch.load(args.detector_weights))
  19. dp_detector.eval()
  20. torch.multiprocessing.set_sharing_strategy('file_system')
  21. data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
  22. batch_size=args.batch_size, shuffle=True,
  23. num_workers=args.data_loading_workers,
  24. collate_fn=lambda x: list(zip(*x)))
  25. logger = util.Logger(enable_visdom=args.enable_visdom)
  26. total_loss = 0
  27. num_evaluation = 0
  28. ground_truths_list = []
  29. predictions_list = []
  30. for iter_idx, (image, marking_points) in enumerate(data_loader):
  31. image = torch.stack(image)
  32. image = image.to(device)
  33. ground_truths_list += list(marking_points)
  34. prediction = dp_detector(image)
  35. objective, gradient = generate_objective(marking_points, device)
  36. loss = (prediction - objective) ** 2
  37. total_loss += torch.sum(loss*gradient).item()
  38. num_evaluation += loss.size(0)
  39. pred_points = [get_predicted_points(pred, 0.01) for pred in prediction]
  40. predictions_list += pred_points
  41. logger.log(iter=iter_idx, total_loss=total_loss)
  42. precisions, recalls = util.calc_precision_recall(
  43. ground_truths_list, predictions_list, match_marking_points)
  44. average_precision = util.calc_average_precision(precisions, recalls)
  45. if args.enable_visdom:
  46. logger.plot_curve(precisions, recalls)
  47. logger.log(average_loss=total_loss / num_evaluation,
  48. average_precision=average_precision)
  49. if __name__ == '__main__':
  50. evaluate_detector(config.get_parser_for_evaluation().parse_args())